diff --git a/common/src/app/common/features.cljc b/common/src/app/common/features.cljc index 90ed0930a2..c628d8f0ab 100644 --- a/common/src/app/common/features.cljc +++ b/common/src/app/common/features.cljc @@ -56,6 +56,7 @@ "text-editor/v2-html-paste" "text-editor/v2" "render-wasm/v1" + "graph-wasm/v1" "variants/v1"}) ;; A set of features enabled by default @@ -79,7 +80,8 @@ "text-editor/v2-html-paste" "text-editor/v2" "tokens/numeric-input" - "render-wasm/v1"}) + "render-wasm/v1" + "graph-wasm/v1"}) ;; Features that are mainly backend only or there are a proper ;; fallback when frontend reports no support for it @@ -128,6 +130,7 @@ :feature-text-editor-v2 "text-editor/v2" :feature-text-editor-v2-html-paste "text-editor/v2-html-paste" :feature-render-wasm "render-wasm/v1" + :feature-graph-wasm "graph-wasm/v1" :feature-variants "variants/v1" :feature-token-input "tokens/numeric-input" nil)) diff --git a/frontend/resources/wasm-playground/graph.html b/frontend/resources/wasm-playground/graph.html new file mode 100644 index 0000000000..3a71334ee3 --- /dev/null +++ b/frontend/resources/wasm-playground/graph.html @@ -0,0 +1,23 @@ + + + + + + + + + diff --git a/frontend/shadow-cljs.edn b/frontend/shadow-cljs.edn index 5dceccf652..26c48f0701 100644 --- a/frontend/shadow-cljs.edn +++ b/frontend/shadow-cljs.edn @@ -92,7 +92,7 @@ {:main {:entries [app.worker] :web-worker true - :prepend-js "importScripts('./render.js');" + :prepend-js "importScripts('./render.js', './graph-wasm-worker.js');" :depends-on #{}}} :js-options diff --git a/frontend/src/app/graph_wasm.cljs b/frontend/src/app/graph_wasm.cljs new file mode 100644 index 0000000000..9191970611 --- /dev/null +++ b/frontend/src/app/graph_wasm.cljs @@ -0,0 +1,12 @@ +;; This Source Code Form is subject to the terms of the Mozilla Public +;; License, v. 2.0. If a copy of the MPL was not distributed with this +;; file, You can obtain one at http://mozilla.org/MPL/2.0/. +;; +;; Copyright (c) KALEIDOS INC + +(ns app.graph-wasm + "A WASM based render API" + (:require + [app.graph-wasm.api :as wasm.api])) + +(def module wasm.api/module) diff --git a/frontend/src/app/graph_wasm/api.cljs b/frontend/src/app/graph_wasm/api.cljs new file mode 100644 index 0000000000..9e352a124a --- /dev/null +++ b/frontend/src/app/graph_wasm/api.cljs @@ -0,0 +1,91 @@ +;; This Source Code Form is subject to the terms of the Mozilla Public +;; License, v. 2.0. If a copy of the MPL was not distributed with this +;; file, You can obtain one at http://mozilla.org/MPL/2.0/. 
+;; +;; Copyright (c) KALEIDOS INC + +(ns app.graph-wasm.api + (:require + [app.common.data.macros :as dm] + [app.common.uuid :as uuid] + [app.config :as cf] + [app.graph-wasm.wasm :as wasm] + [app.render-wasm.helpers :as h] + [app.render-wasm.serializers :as sr] + [app.util.modules :as mod] + [promesa.core :as p])) + +(defn hello [] + (h/call wasm/internal-module "_hello")) + +(defn init [] + (h/call wasm/internal-module "_init")) + +(defn use-shape + [id] + (let [buffer (uuid/get-u32 id)] + (println "use-shape" id) + (h/call wasm/internal-module "_use_shape" + (aget buffer 0) + (aget buffer 1) + (aget buffer 2) + (aget buffer 3)))) + +(defn set-shape-parent-id + [id] + (let [buffer (uuid/get-u32 id)] + (h/call wasm/internal-module "_set_shape_parent" + (aget buffer 0) + (aget buffer 1) + (aget buffer 2) + (aget buffer 3)))) + +(defn set-shape-type + [type] + (h/call wasm/internal-module "_set_shape_type" (sr/translate-shape-type type))) + +(defn set-shape-selrect + [selrect] + (h/call wasm/internal-module "_set_shape_selrect" + (dm/get-prop selrect :x1) + (dm/get-prop selrect :y1) + (dm/get-prop selrect :x2) + (dm/get-prop selrect :y2))) + +(defn set-object + [shape] + (let [id (dm/get-prop shape :id) + type (dm/get-prop shape :type) + parent-id (get shape :parent-id) + selrect (get shape :selrect) + children (get shape :shapes)] + (use-shape id) + (set-shape-type type) + (set-shape-parent-id parent-id) + (set-shape-selrect selrect))) + +(defn set-objects + [objects] + (doseq [shape (vals objects)] + (set-object shape))) + +(defn init-wasm-module + [module] + (let [default-fn (unchecked-get module "default") + href (cf/resolve-href "js/graph-wasm.wasm")] + (default-fn #js {:locateFile (constantly href)}))) + +(defonce module + (delay + (if (exists? js/dynamicImport) + (let [uri (cf/resolve-href "js/graph-wasm.js")] + (->> (mod/import uri) + (p/mcat init-wasm-module) + (p/fmap (fn [default] + (set! wasm/internal-module default) + true)) + (p/merr + (fn [cause] + (js/console.error cause) + (p/resolved false))))) + (p/resolved false)))) \ No newline at end of file diff --git a/frontend/src/app/graph_wasm/wasm.cljs b/frontend/src/app/graph_wasm/wasm.cljs new file mode 100644 index 0000000000..6792d8b7cc --- /dev/null +++ b/frontend/src/app/graph_wasm/wasm.cljs @@ -0,0 +1,9 @@ +;; This Source Code Form is subject to the terms of the Mozilla Public +;; License, v. 2.0. If a copy of the MPL was not distributed with this +;; file, You can obtain one at http://mozilla.org/MPL/2.0/. +;; +;; Copyright (c) KALEIDOS INC + +(ns app.graph-wasm.wasm) + +(defonce internal-module #js {}) diff --git a/frontend/src/app/main/data/workspace/componentize.cljs b/frontend/src/app/main/data/workspace/componentize.cljs new file mode 100644 index 0000000000..4e94297c8e --- /dev/null +++ b/frontend/src/app/main/data/workspace/componentize.cljs @@ -0,0 +1,288 @@ +;; This Source Code Form is subject to the terms of the Mozilla Public +;; License, v. 2.0. If a copy of the MPL was not distributed with this +;; file, You can obtain one at http://mozilla.org/MPL/2.0/. +;; +;; Copyright (c) KALEIDOS INC +;; +;; High level helpers to turn a shape subtree into a component and +;; replace equivalent subtrees by instances of that component. 
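+;;
+;; Usage sketch (mirrors the viewport integration added later in this
+;; diff; the argument names are illustrative only):
+;;
+;;   (st/emit! (componentize-similar-subtrees root-id similar-ids))
+;;
+;; where `root-id` is the id of the currently selected shape and
+;; `similar-ids` is the collection of shape ids returned by the
+;; graph-wasm "search similar shapes" worker query.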
+ +(ns app.main.data.workspace.componentize + (:require + [app.common.data :as d] + [app.common.data.macros :as dm] + [app.common.files.changes-builder :as pcb] + [app.common.files.helpers :as cfh] + [app.common.geom.point :as gpt] + [app.common.logic.libraries :as cll] + [app.common.logic.shapes :as cls] + [app.common.types.shape :as cts] + [app.common.uuid :as uuid] + [app.main.data.changes :as dch] + [app.main.data.helpers :as dsh] + [app.main.data.workspace.libraries :as dwl] + [app.main.data.workspace.selection :as dws] + [app.main.data.workspace.shapes :as dwsh] + [app.main.data.workspace.undo :as dwu] + [beicon.v2.core :as rx] + [potok.v2.core :as ptk])) + +;; NOTE: We keep this separate from `workspace.libraries` to avoid +;; introducing more complexity in that already big namespace. + +(def ^:private instance-structural-keys + "Keys we do NOT want to copy from the original shape when creating a + new component instance. These are identity / structural / component + metadata keys that must be managed by the component system itself." + #{:id + :parent-id + :frame-id + :shapes + ;; Component metadata + :component-id + :component-file + :component-root + :main-instance + :remote-synced + :shape-ref + :touched}) + +(def ^:private instance-geometry-keys + "Geometry-related keys that we *do* want to override per instance when + copying props from an existing subtree to a component instance." + #{:x + :y + :width + :height + :rotation + :flip-x + :flip-y + :selrect + :points + :proportion + :proportion-lock + :transform + :transform-inverse}) + +(defn- instantiate-similar-subtrees + "Internal helper. Given an atom `id-ref` that will contain the + `component-id`, replace each subtree rooted at the ids in + `similar-ids` by an instance of that component. + + The operation is performed in a single undo transaction: + - Instantiate the component once per similar id, roughly at the same + top-left position as the original root. + - Delete the original subtrees. + - Select the main instance plus all the new instances." + [id-ref root-id similar-ids] + (ptk/reify ::instantiate-similar-subtrees + ptk/WatchEvent + (watch [it state _] + (let [component-id @id-ref + similar-ids (vec (or similar-ids []))] + (if (or (uuid/zero? component-id) + (empty? similar-ids)) + (rx/empty) + (let [file-id (:current-file-id state) + page (dsh/lookup-page state) + page-id (:id page) + objects (:objects page) + libraries (dsh/lookup-libraries state) + fdata (dsh/lookup-file-data state file-id) + + ;; Reference subtree: shapes used to build the component. + ;; We'll compute per-shape deltas against this subtree so + ;; that we only override attributes that actually differ. + ref-subtree-ids (cfh/get-children-ids objects root-id) + ref-all-ids (into [root-id] ref-subtree-ids) + + undo-id (js/Symbol) + + ;; 1) Instantiate component at each similar root position, + ;; preserving per-instance overrides (geometry, style, etc.) + [changes new-root-ids] + (reduce + (fn [[changes acc] sid] + (if-let [shape (get objects sid)] + (let [position (gpt/point (:x shape) (:y shape)) + ;; Remember original parent and index so we can keep + ;; the same ordering among the parent's children. 
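+                            ;; (Both values are consumed below by
+                            ;; `pcb/change-parent`, which reattaches the new
+                            ;; instance at the original parent and index.)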
+ orig-root (get objects sid) + orig-parent-id (:parent-id orig-root) + orig-index (when orig-parent-id + (cfh/get-position-on-parent objects sid)) + ;; Instantiate a new component instance at the same position + [new-shape changes'] + (cll/generate-instantiate-component + (or changes + (-> (pcb/empty-changes it page-id) + (pcb/with-objects objects))) + objects + file-id + component-id + position + page + libraries) + ;; Build a structural mapping between the original subtree + ;; (rooted at `sid`) and the new instance subtree. + ;; NOTE 1: instantiating a component can introduce an extra + ;; wrapper frame, so we try to align the original root + ;; with the "equivalent" root inside the instance. + ;; NOTE 2: by default the instance may be created *inside* + ;; the original shape (because of layout / hit-testing). + ;; We explicitly move the new instance to the same parent + ;; and index as the original root, so that later deletes of + ;; the original subtree don't remove the new instances and + ;; the ordering among siblings is preserved. + changes' (cond-> changes' + (some? orig-parent-id) + (pcb/change-parent orig-parent-id [new-shape] orig-index + {:allow-altering-copies true + :ignore-touched true})) + objects' (pcb/get-objects changes') + orig-root (get objects sid) + new-root new-shape + orig-type (:type orig-root) + new-type (:type new-root) + ;; Full original subtree (root + descendants) + orig-subtree-ids (cfh/get-children-ids objects sid) + orig-all-ids (into [sid] orig-subtree-ids) + ;; Try to find an inner instance root matching the original type + ;; when the outer instance root type differs (e.g. rect -> frame+rect). + direct-new-children (cfh/get-children-ids objects' (:id new-root)) + candidate-instance-root + (when (and orig-type (not= orig-type new-type)) + (let [cands (->> direct-new-children + (filter (fn [nid] + (when-let [s (get objects' nid)] + (= (:type s) orig-type)))))] + (when (= 1 (count cands)) + (first cands)))) + instance-root-id (or candidate-instance-root (:id new-root)) + instance-root (get objects' instance-root-id) + new-subtree-ids (cfh/get-children-ids objects' instance-root-id) + new-all-ids (into [instance-root-id] new-subtree-ids) + id-pairs (map vector orig-all-ids new-all-ids) + changes'' + ;; Compute per-shape deltas against the reference + ;; subtree (root-id) and apply only the differences + ;; to the new instance subtree, so we don't blindly + ;; overwrite attributes that are the same. + (reduce + (fn [ch [idx orig-id new-id]] + (let [ref-id (nth ref-all-ids idx nil) + ref-shape (get objects ref-id) + orig-shape (get objects orig-id)] + (if (and ref-shape orig-shape) + (let [;; Style / layout / text props (see `extract-props`) + ref-style (cts/extract-props ref-shape) + orig-style (cts/extract-props orig-shape) + style-delta (reduce (fn [m k] + (let [rv (get ref-style k ::none) + ov (get orig-style k ::none)] + (if (= rv ov) + m + (assoc m k ov)))) + {} + (keys orig-style)) + + ;; Geometry props + ref-geom (select-keys ref-shape instance-geometry-keys) + orig-geom (select-keys orig-shape instance-geometry-keys) + geom-delta (reduce (fn [m k] + (let [rv (get ref-geom k ::none) + ov (get orig-geom k ::none)] + (if (= rv ov) + m + (assoc m k ov)))) + {} + (keys orig-geom)) + + ;; Text content: if the subtree reference and the + ;; original differ in `:content`, treat the whole + ;; content tree as an override for this instance. + content-delta? 
(not= (:content ref-shape) (:content orig-shape))] + (-> ch + ;; First patch style/text/layout props using the + ;; canonical helpers so we don't touch structural ids. + (cond-> (seq style-delta) + (pcb/update-shapes + [new-id] + (fn [s objs] (cts/patch-props s style-delta objs)) + {:with-objects? true})) + ;; Then patch geometry directly on the instance. + (cond-> (seq geom-delta) + (pcb/update-shapes + [new-id] + (d/patch-object geom-delta))) + ;; Finally, if text content differs between the + ;; reference subtree and the similar subtree, + ;; override the instance content with the original. + (cond-> content-delta? + (pcb/update-shapes + [new-id] + #(assoc % :content (:content orig-shape)))))) + ch))) + changes' + (map-indexed (fn [idx [orig-id new-id]] + [idx orig-id new-id]) + id-pairs))] + [changes'' (conj acc (:id new-shape))]) + ;; If the shape does not exist we just skip it + [changes acc])) + [nil []] + similar-ids) + + changes (or changes + (-> (pcb/empty-changes it page-id) + (pcb/with-objects objects))) + + ;; 2) Delete original similar subtrees + ;; NOTE: `d/ordered-set` with a single arg treats it as a single + ;; element, so we must use `into` when we already have a collection. + ids-to-delete (into (d/ordered-set) similar-ids) + [all-parents changes] + (cls/generate-delete-shapes + changes + fdata + page + objects + ids-to-delete + {:allow-altering-copies true}) + + ;; 3) Select main instance + new instances + ;; Root id is kept as-is; add all new roots. + sel-ids (into (d/ordered-set) (cons root-id new-root-ids))] + + (rx/of + (dwu/start-undo-transaction undo-id) + (dch/commit-changes changes) + (ptk/data-event :layout/update {:ids all-parents}) + (dwu/commit-undo-transaction undo-id)))))))) + +(defn componentize-similar-subtrees + "Turn the subtree rooted at `root-id` into a component, then replace + the subtrees rooted at `similar-ids` with instances of that component. + + This is implemented in two phases: + 1) Use the existing `dwl/add-component` flow to create a component + from `root-id` (and obtain its `component-id`). + 2) Using the new `component-id`, instantiate the component once per + entry in `similar-ids` and delete the old subtrees." + [root-id similar-ids] + (dm/assert! + "expected valid uuid for `root-id`" + (uuid? 
root-id)) + + (let [similar-ids (vec (or similar-ids []))] + (ptk/reify ::componentize-similar-subtrees + ptk/WatchEvent + (watch [_ _ _] + (let [id-ref (atom uuid/zero)] + (rx/concat + ;; 1) Create component using the existing pipeline + (rx/of (dwl/add-component id-ref (d/ordered-set root-id))) + ;; 2) Replace similar subtrees by instances of the new component + (rx/of (instantiate-similar-subtrees id-ref root-id similar-ids)))))))) + + diff --git a/frontend/src/app/main/ui/workspace/viewport.cljs b/frontend/src/app/main/ui/workspace/viewport.cljs index 7d08054fb4..52da3b5806 100644 --- a/frontend/src/app/main/ui/workspace/viewport.cljs +++ b/frontend/src/app/main/ui/workspace/viewport.cljs @@ -13,11 +13,16 @@ [app.common.geom.shapes :as gsh] [app.common.types.color :as clr] [app.common.types.component :as ctk] + [app.common.types.container :as ctn] [app.common.types.path :as path] [app.common.types.shape :as cts] [app.common.types.shape-tree :as ctt] [app.common.types.shape.layout :as ctl] + [app.config :as cf] + [app.graph-wasm.api :as graph-wasm.api] + [app.main.data.workspace.componentize :as dwc] [app.main.data.workspace.modifiers :as dwm] + [app.main.data.workspace.selection :as dws] [app.main.data.workspace.variants :as dwv] [app.main.features :as features] [app.main.refs :as refs] @@ -57,8 +62,11 @@ [app.main.ui.workspace.viewport.utils :as utils] [app.main.ui.workspace.viewport.viewport-ref :refer [create-viewport-ref]] [app.main.ui.workspace.viewport.widgets :as widgets] + [app.main.worker :as worker] [app.util.debug :as dbg] + [app.util.modules :as mod] [beicon.v2.core :as rx] + [promesa.core :as p] [rumext.v2 :as mf])) ;; --- Viewport @@ -134,6 +142,7 @@ mod? (mf/use-state false) space? (mf/use-state false) z? (mf/use-state false) + g? (mf/use-state false) cursor (mf/use-state #(utils/get-cursor :pointer-inner)) hover-ids (mf/use-state nil) hover (mf/use-state nil) @@ -302,12 +311,79 @@ (mf/use-fn (mf/deps first-shape) #(st/emit! - (dwv/add-new-variant (:id first-shape))))] + (dwv/add-new-variant (:id first-shape)))) + + graph-wasm-enabled? (features/use-feature "graph-wasm/v1")] + + (mf/with-effect [page-id] + (when graph-wasm-enabled? + ;; Initialize graph-wasm in the worker to avoid blocking main thread + (let [subscription + (->> (worker/ask! {:cmd :graph-wasm/init}) + (rx/filter #(= (:status %) :ok)) + (rx/take 1) + (rx/merge-map (fn [_] + (worker/ask! {:cmd :graph-wasm/set-objects + :objects base-objects}))))] + + (rx/subscribe subscription + (fn [result] + (when (= (:status result) :ok) + (js/console.debug "Graph WASM initialized in worker" + (select-keys result [:processed])))) + (fn [error] + (js/console.error "Error initializing graph-wasm in worker:" error)) + (fn [] + (js/console.debug "Graph WASM worker operations completed")))))) + + (mf/with-effect [selected @g?] + (when graph-wasm-enabled? + ;; Search for similar shapes when selection changes or when + ;; the user presses the \"c\" key while having a single + ;; selection. + (when (and @g? + (some? selected) + (= (count selected) 1)) + (let [selected-id (first selected) + selected-shape (get base-objects selected-id) + ;; Skip shapes that already belong to a component + non-component? (and (some? selected-shape) + (not (ctn/in-any-component? base-objects selected-shape)))] + + (println selected-shape) + (println (ctn/in-any-component? base-objects selected-shape)) + + (when non-component? + (let [subscription + (worker/ask! 
{:cmd :graph-wasm/search-similar-shapes + :shape-id selected-id})] + + (rx/subscribe subscription + (fn [result] + (when (= (:status result) :ok) + (let [raw-similar-shapes (:similar-shapes result) + ;; Filter out shapes that already belong to some component + ;; (main instance, instance head or inside a component copy). + similar-shapes (->> raw-similar-shapes + (remove (fn [sid] + (when-let [s (get base-objects sid)] + (ctn/in-any-component? base-objects s)))) + (into []))] + (when (d/not-empty? similar-shapes) + ;; Transform the selected subtree into a component and + ;; replace similar subtrees by instances of that component. + (st/emit! (dwc/componentize-similar-subtrees + selected-id + similar-shapes)))))) + (fn [error] + (js/console.error "Error searching similar shapes:" error)) + (fn [] + (js/console.debug "Similar shapes search completed"))))))))) (hooks/setup-dom-events zoom disable-paste-ref in-viewport-ref read-only? drawing-tool path-drawing?) (hooks/setup-viewport-size vport viewport-ref) (hooks/setup-cursor cursor alt? mod? space? panning drawing-tool path-drawing? path-editing? z? read-only?) - (hooks/setup-keyboard alt? mod? space? z? shift?) + (hooks/setup-keyboard alt? mod? space? z? shift? g?) (hooks/setup-hover-shapes page-id move-stream base-objects transform selected mod? hover measure-hover hover-ids hover-top-frame-id @hover-disabled? focus zoom show-measures?) (hooks/setup-viewport-modifiers modifiers base-objects) diff --git a/frontend/src/app/main/ui/workspace/viewport/hooks.cljs b/frontend/src/app/main/ui/workspace/viewport/hooks.cljs index 922e18057d..af8f4ad692 100644 --- a/frontend/src/app/main/ui/workspace/viewport/hooks.cljs +++ b/frontend/src/app/main/ui/workspace/viewport/hooks.cljs @@ -124,7 +124,7 @@ (reset! cursor new-cursor)))))) (defn setup-keyboard - [alt* mod* space* z* shift*] + [alt* mod* space* z* shift* g*] (let [kbd-zoom-s (mf/with-memo [] (->> ms/keyboard @@ -151,12 +151,22 @@ (rx/filter kbd/z?) (rx/filter (complement kbd/editing-event?)) (rx/map kbd/key-down-event?) - (rx/pipe (rxo/distinct-contiguous))))] + (rx/pipe (rxo/distinct-contiguous)))) + + kbd-g-s + (mf/with-memo [] + (let [c-pred (kbd/is-key-ignore-case? "g")] + (->> ms/keyboard + (rx/filter c-pred) + (rx/filter (complement kbd/editing-event?)) + (rx/map kbd/key-down-event?) + (rx/pipe (rxo/distinct-contiguous)))))] (hooks/use-stream ms/keyboard-alt (partial reset! alt*)) (hooks/use-stream ms/keyboard-space (partial reset! space*)) (hooks/use-stream kbd-z-s (partial reset! z*)) (hooks/use-stream kbd-shift-s (partial reset! shift*)) + (hooks/use-stream kbd-g-s (partial reset! g*)) (hooks/use-stream ms/keyboard-mod (fn [value] (reset! mod* value) diff --git a/frontend/src/app/main/ui/workspace/viewport_wasm.cljs b/frontend/src/app/main/ui/workspace/viewport_wasm.cljs index a667d3abc5..2d83f6c3d7 100644 --- a/frontend/src/app/main/ui/workspace/viewport_wasm.cljs +++ b/frontend/src/app/main/ui/workspace/viewport_wasm.cljs @@ -122,6 +122,7 @@ mod? (mf/use-state false) space? (mf/use-state false) z? (mf/use-state false) + c? (mf/use-state false) cursor (mf/use-state (utils/get-cursor :pointer-inner)) hover-ids (mf/use-state nil) hover (mf/use-state nil) @@ -360,7 +361,7 @@ (hooks/setup-dom-events zoom disable-paste-ref in-viewport-ref read-only? drawing-tool path-drawing?) (hooks/setup-viewport-size vport viewport-ref) (hooks/setup-cursor cursor alt? mod? space? panning drawing-tool path-drawing? path-editing? z? read-only?) - (hooks/setup-keyboard alt? mod? space? z? shift?) 
+ (hooks/setup-keyboard alt? mod? space? z? shift? c?) (hooks/setup-hover-shapes page-id move-stream base-objects transform selected mod? hover measure-hover hover-ids hover-top-frame-id @hover-disabled? focus zoom show-measures?) (hooks/setup-shortcuts path-editing? path-drawing? text-editing? grid-editing?) diff --git a/frontend/src/app/worker.cljs b/frontend/src/app/worker.cljs index e6ca3fb1d2..56e33d7d12 100644 --- a/frontend/src/app/worker.cljs +++ b/frontend/src/app/worker.cljs @@ -11,6 +11,7 @@ [app.common.schema :as sm] [app.common.types.objects-map] [app.util.object :as obj] + [app.worker.graph-wasm] [app.worker.impl :as impl] [app.worker.import] [app.worker.index] diff --git a/frontend/src/app/worker/graph_wasm.cljs b/frontend/src/app/worker/graph_wasm.cljs new file mode 100644 index 0000000000..07340b8f41 --- /dev/null +++ b/frontend/src/app/worker/graph_wasm.cljs @@ -0,0 +1,181 @@ +;; This Source Code Form is subject to the terms of the Mozilla Public +;; License, v. 2.0. If a copy of the MPL was not distributed with this +;; file, You can obtain one at http://mozilla.org/MPL/2.0/. +;; +;; Copyright (c) KALEIDOS INC + +(ns app.worker.graph-wasm + "Graph WASM operations within the worker." + (:require + [app.common.data.macros :as dm] + [app.common.logging :as log] + [app.common.uuid :as uuid] + [app.config :as cf] + [app.graph-wasm.wasm :as wasm] + [app.render-wasm.helpers :as h] + [app.render-wasm.serializers :as sr] + [app.worker.impl :as impl] + [beicon.v2.core :as rx] + [promesa.core :as p])) + +(log/set-level! :info) + +(defn- use-shape + [module id] + (let [buffer (uuid/get-u32 id)] + (h/call module "_use_shape" + (aget buffer 0) + (aget buffer 1) + (aget buffer 2) + (aget buffer 3)))) + +(defn- set-shape-parent-id + [module id] + (let [buffer (uuid/get-u32 id)] + (h/call module "_set_shape_parent" + (aget buffer 0) + (aget buffer 1) + (aget buffer 2) + (aget buffer 3)))) + +(defn- set-shape-type + [module type] + (h/call module "_set_shape_type" (sr/translate-shape-type type))) + +(defn- set-shape-selrect + [module selrect] + (h/call module "_set_shape_selrect" + (dm/get-prop selrect :x1) + (dm/get-prop selrect :y1) + (dm/get-prop selrect :x2) + (dm/get-prop selrect :y2))) + +(defn- set-object + [module shape] + (let [id (dm/get-prop shape :id) + type (dm/get-prop shape :type) + parent-id (get shape :parent-id) + selrect (get shape :selrect)] + (use-shape module id) + (set-shape-type module type) + (set-shape-parent-id module parent-id) + (set-shape-selrect module selrect))) + +(defonce ^:private graph-wasm-module + (delay + (let [module (unchecked-get js/globalThis "GraphWasmModule") + init-fn (unchecked-get module "default") + href (cf/resolve-href "js/graph-wasm.wasm")] + (->> (init-fn #js {:locateFile (constantly href)}) + (p/fnly (fn [module cause] + (if cause + (js/console.error cause) + (set! wasm/internal-module module)))))))) + +(defmethod impl/handler :graph-wasm/init + [message transfer] + (rx/create + (fn [subs] + (-> @graph-wasm-module + (p/then (fn [module] + (if module + (try + (h/call module "_init") + (rx/push! subs {:status :ok}) + (rx/end! subs) + (catch :default cause + (log/error :hint "Error in graph-wasm/init" :cause cause) + (rx/error! subs cause) + (rx/end! subs))) + (do + (log/warn :hint "Graph WASM module not available") + (rx/push! subs {:status :error :message "Module not available"}) + (rx/end! subs))))) + (p/catch (fn [cause] + (log/error :hint "Error loading graph-wasm module" :cause cause) + (rx/error! subs cause) + (rx/end! 
subs)))) + nil))) + +(defmethod impl/handler :graph-wasm/set-objects + [message transfer] + (let [objects (:objects message)] + (rx/create + (fn [subs] + (-> @graph-wasm-module + (p/then (fn [module] + (if module + (try + (doseq [shape (vals objects)] + (set-object module shape)) + (h/call module "_generate_db") + (rx/push! subs {:status :ok :processed (count objects)}) + (rx/end! subs) + (catch :default cause + (log/error :hint "Error in graph-wasm/set-objects" :cause cause) + (rx/error! subs cause) + (rx/end! subs))) + (do + (log/warn :hint "Graph WASM module not available") + (rx/push! subs {:status :error :message "Module not available"}) + (rx/end! subs))))) + (p/catch (fn [cause] + (log/error :hint "Error loading graph-wasm module" :cause cause) + (rx/error! subs cause) + (rx/end! subs)))) + nil)))) + +(defmethod impl/handler :graph-wasm/search-similar-shapes + [message transfer] + (let [shape-id (:shape-id message)] + (rx/create + (fn [subs] + (-> @graph-wasm-module + (p/then (fn [module] + (if module + (try + (let [buffer (uuid/get-u32 shape-id) + ptr-raw (h/call module "_search_similar_shapes" + (aget buffer 0) + (aget buffer 1) + (aget buffer 2) + (aget buffer 3)) + ;; Convert pointer to unsigned 32-bit (handle negative numbers from WASM) + ;; Use unsigned right shift to convert signed to unsigned 32-bit + ptr (unsigned-bit-shift-right ptr-raw 0) + heapu8 (unchecked-get module "HEAPU8") + + ;; Read count (first 4 bytes, little-endian u32) + count (bit-or (aget heapu8 ptr) + (bit-shift-left (aget heapu8 (+ ptr 1)) 8) + (bit-shift-left (aget heapu8 (+ ptr 2)) 16) + (bit-shift-left (aget heapu8 (+ ptr 3)) 24)) + ;; Read UUIDs (16 bytes each, starting at offset 4) + similar-shapes (loop [offset (+ ptr 4) + remaining count + result []] + (if (zero? remaining) + result + (let [uuid-bytes (.slice heapu8 offset (+ offset 16))] + (recur (+ offset 16) + (dec remaining) + (conj result (uuid/from-bytes uuid-bytes))))))] + + ;; Free the buffer + (h/call module "_free_similar_shapes_buffer") + + (rx/push! subs {:status :ok :similar-shapes similar-shapes}) + (rx/end! subs)) + (catch :default cause + (log/error :hint "Error in graph-wasm/search-similar-shapes" :cause cause) + (rx/error! subs cause) + (rx/end! subs))) + (do + (log/warn :hint "Graph WASM module not available") + (rx/push! subs {:status :error :message "Module not available"}) + (rx/end! subs))))) + (p/catch (fn [cause] + (log/error :hint "Error loading graph-wasm module" :cause cause) + (rx/error! subs cause) + (rx/end! 
subs)))) + nil)))) diff --git a/frontend/test/frontend_tests/svg_filters_test.cljs b/frontend/test/frontend_tests/svg_filters_test.cljs index d469183389..8fccd29bdb 100644 --- a/frontend/test/frontend_tests/svg_filters_test.cljs +++ b/frontend/test/frontend_tests/svg_filters_test.cljs @@ -47,3 +47,6 @@ result (svg-filters/apply-svg-filters shape)] (is (= shape result)))) + + + diff --git a/graph-wasm/.cargo/config.toml b/graph-wasm/.cargo/config.toml new file mode 100644 index 0000000000..46eec127dc --- /dev/null +++ b/graph-wasm/.cargo/config.toml @@ -0,0 +1,6 @@ +[target.wasm32-unknown-emscripten] +# Note: Not using atomics to avoid recompiling std library +# We're running without pthreads, so lbug needs to work single-threaded +rustflags = ["-C", "link-arg=-fexceptions"] +# Linker is configured via environment variable CARGO_TARGET_WASM32_UNKNOWN_EMSCRIPTEN_LINKER in _build_env + diff --git a/graph-wasm/.gitignore b/graph-wasm/.gitignore new file mode 100644 index 0000000000..391ed4d660 --- /dev/null +++ b/graph-wasm/.gitignore @@ -0,0 +1,5 @@ +target/ +debug/ + +**/*.rs.bk + diff --git a/graph-wasm/Cargo.lock b/graph-wasm/Cargo.lock new file mode 100644 index 0000000000..3526a993e2 --- /dev/null +++ b/graph-wasm/Cargo.lock @@ -0,0 +1,486 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "anstyle" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" + +[[package]] +name = "arrayvec" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" + +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + +[[package]] +name = "bumpalo" +version = "3.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43" + +[[package]] +name = "cc" +version = "1.2.49" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90583009037521a116abf44494efecd645ba48b6622457080f080b85544e2215" +dependencies = [ + "find-msvc-tools", + "shlex", +] + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "clap" +version = "4.5.53" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c9e340e012a1bf4935f5282ed1436d1489548e8f72308207ea5df0e23d2d03f8" +dependencies = [ + "clap_builder", +] + +[[package]] +name = "clap_builder" +version = "4.5.53" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d76b5d13eaa18c901fd2f7fca939fefe3a0727a953561fefdf3b2922b8569d00" +dependencies = [ + "anstyle", + "clap_lex", + "strsim", +] + +[[package]] +name = "clap_lex" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1d728cc89cf3aee9ff92b05e62b19ee65a02b5702cff7d5a377e32c6ae29d8d" + +[[package]] +name = "cmake" +version = "0.1.56" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b042e5d8a74ae91bb0961acd039822472ec99f8ab0948cbf6d1369588f8be586" +dependencies = [ + "cc", +] + +[[package]] +name = 
"codespan-reporting" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3538270d33cc669650c4b093848450d380def10c331d38c768e34cac80576e6e" +dependencies = [ + "termcolor", + "unicode-width", +] + +[[package]] +name = "cxx" +version = "1.0.138" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3956d60afa98653c5a57f60d7056edd513bfe0307ef6fb06f6167400c3884459" +dependencies = [ + "cc", + "cxxbridge-cmd", + "cxxbridge-flags", + "cxxbridge-macro", + "foldhash", + "link-cplusplus", +] + +[[package]] +name = "cxx-build" +version = "1.0.138" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a4b7522f539fe056f1d6fc8577d8ab731451f6f33a89b1e5912e22b76c553e7" +dependencies = [ + "cc", + "codespan-reporting", + "proc-macro2", + "quote", + "scratch", + "syn", +] + +[[package]] +name = "cxxbridge-cmd" +version = "1.0.138" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f01e92ab4ce9fd4d16e3bb11b158d98cbdcca803c1417aa43130a6526fbf208" +dependencies = [ + "clap", + "codespan-reporting", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "cxxbridge-flags" +version = "1.0.138" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c41cbfab344869e70998b388923f7d1266588f56c8ca284abf259b1c1ffc695" + +[[package]] +name = "cxxbridge-macro" +version = "1.0.138" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88d82a2f759f0ad3eae43b96604efd42b1d4729a35a6f2dc7bdb797ae25d9284" +dependencies = [ + "proc-macro2", + "quote", + "rustversion", + "syn", +] + +[[package]] +name = "deranged" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ececcb659e7ba858fb4f10388c250a7252eb0a27373f1a72b8748afdd248e587" +dependencies = [ + "powerfmt", +] + +[[package]] +name = "find-msvc-tools" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a3076410a55c90011c298b04d0cfa770b00fa04e1e3c97d3f6c9de105a03844" + +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + +[[package]] +name = "getrandom" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasip2", +] + +[[package]] +name = "graph" +version = "0.1.0" +dependencies = [ + "lbug", + "uuid", +] + +[[package]] +name = "js-sys" +version = "0.3.83" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "464a3709c7f55f1f721e5389aa6ea4e3bc6aba669353300af094b29ffbdde1d8" +dependencies = [ + "once_cell", + "wasm-bindgen", +] + +[[package]] +name = "lbug" +version = "0.12.2" +dependencies = [ + "cmake", + "cxx", + "cxx-build", + "rust_decimal", + "rustversion", + "time", + "uuid", +] + +[[package]] +name = "libc" +version = "0.2.178" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091" + +[[package]] +name = "link-cplusplus" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f78c730aaa7d0b9336a299029ea49f9ee53b0ed06e9202e8cb7db9bae7b8c82" +dependencies = [ + "cc", +] + +[[package]] +name = "num-conv" +version = "0.1.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" + +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" + +[[package]] +name = "proc-macro2" +version = "1.0.103" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ee95bc4ef87b8d5ba32e8b7714ccc834865276eab0aed5c9958d00ec45f49e8" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.42" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a338cc41d27e6cc6dce6cefc13a0729dfbb81c262b1f519331575dd80ef3067f" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + +[[package]] +name = "rust_decimal" +version = "1.39.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35affe401787a9bd846712274d97654355d21b2a2c092a3139aabe31e9022282" +dependencies = [ + "arrayvec", + "num-traits", +] + +[[package]] +name = "rustversion" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" + +[[package]] +name = "scratch" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d68f2ec51b097e4c1a75b681a8bec621909b5e91f15bb7b840c4f2f7b01148b2" + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + +[[package]] +name = "syn" +version = "2.0.111" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "390cc9a294ab71bdb1aa2e99d13be9c753cd2d7bd6560c77118597410c4d2e87" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "termcolor" +version = "1.4.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "time" +version = "0.3.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e7d9e3bb61134e77bde20dd4825b97c010155709965fedf0f49bb138e52a9d" +dependencies = [ + "deranged", + "num-conv", + "powerfmt", + "serde", + "time-core", +] + +[[package]] +name = "time-core" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40868e7c1d2f0b8d73e4a8c7f0ff63af4f6d19be117e90bd73eb1d62cf831c6b" + +[[package]] +name = "unicode-ident" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" + +[[package]] +name = "unicode-width" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" + +[[package]] +name = "uuid" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2e054861b4bd027cd373e18e8d8d8e6548085000e41290d95ce0c373a654b4a" +dependencies = [ + "getrandom", + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "wasip2" +version = "1.0.1+wasi-0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0562428422c63773dad2c345a1882263bbf4d65cf3f42e90921f787ef5ad58e7" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "wasm-bindgen" +version = "0.2.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d759f433fa64a2d763d1340820e46e111a7a5ab75f993d1852d70b03dbb80fd" +dependencies = [ + "cfg-if", + "once_cell", + "rustversion", + "wasm-bindgen-macro", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48cb0d2638f8baedbc542ed444afc0644a29166f1595371af4fecf8ce1e7eeb3" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cefb59d5cd5f92d9dcf80e4683949f15ca4b511f4ac0a6e14d4e1ac60c6ecd40" +dependencies = [ + "bumpalo", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cbc538057e648b67f72a982e708d485b2efa771e1ac05fec311f9f63e5800db4" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "winapi-util" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" +dependencies = [ + "windows-sys", +] + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "wit-bindgen" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59" diff --git a/graph-wasm/Cargo.toml b/graph-wasm/Cargo.toml new file mode 100644 index 0000000000..575781c60c --- /dev/null +++ b/graph-wasm/Cargo.toml @@ -0,0 +1,31 @@ +[package] +name = "graph" +version = "0.1.0" +edition = "2021" +repository = "https://github.com/penpot/penpot" +license-file = "../LICENSE" +description = "Wasm-based graph module for Penpot" + +build = "build.rs" + +[[bin]] +name = "graph_wasm" +path = "src/main.rs" + +[profile.release] +opt-level = "s" +# Note: We need panic=unwind because core requires it +# We'll compile std from source with atomics support instead + +[profile.dev] +# Note: We need panic=unwind because core requires it +# We'll compile std from source with atomics support instead + +[dependencies] +lbug = "0.12.2" +uuid = { version = "1.11.0", features = ["v4", "js"] } + +# Patch lbug to use local version with wasm32-unknown-emscripten support +[patch.crates-io] +lbug = { path = "./lbug-0.12.2" } + diff --git a/graph-wasm/README_WASM.md b/graph-wasm/README_WASM.md new file mode 100644 index 0000000000..f5912db5c4 --- /dev/null +++ b/graph-wasm/README_WASM.md @@ -0,0 +1,197 @@ +# lbug WASM Support + +This document describes the modifications made to the `lbug` crate to enable compilation for the `wasm32-unknown-emscripten` target, and the build requirements for using it in a WASM context. + +## Overview + +The `lbug` crate is a Rust wrapper around a C++ graph database library. To compile it for WebAssembly using Emscripten, several modifications were necessary to handle: + +1. C++ exception handling in WASM +2. Conditional compilation for WASM-specific code paths +3. Proper linking of static libraries for Emscripten +4. CMake configuration for single-threaded mode + +## Changes Made to lbug + +### 1. `build.rs` Modifications + +The build script (`build.rs`) was modified to detect and handle the `wasm32-unknown-emscripten` target: + +#### WASM Detection +```rust +fn is_wasm_emscripten() -> bool { + env::var("TARGET") + .map(|t| t == "wasm32-unknown-emscripten") + .unwrap_or(false) +} +``` + +#### CMake Configuration (`build_bundled_cmake()`) +- **Single-threaded mode**: Sets `SINGLE_THREADED=TRUE` for WASM builds (required by Emscripten) +- The Emscripten toolchain is automatically detected when `CC`/`CXX` point to `emcc`/`em++` + +#### FFI Build Configuration (`build_ffi()`) +- **C++20 standard**: Uses `-std=c++20` flag for WASM +- **Exception support**: Enables `-fexceptions` flag (exceptions must be enabled at compile time) +- Note: `-sDISABLE_EXCEPTION_CATCHING=0` is a linker flag and should be set via `EMCC_CFLAGS` + +#### Library Linking (`link_libraries()`) +- **Explicit dependency linking**: For WASM, all static dependencies are explicitly linked: + - `utf8proc`, `antlr4_cypher`, `antlr4_runtime`, `re2`, `fastpfor` + - `parquet`, `thrift`, `snappy`, `zstd`, `miniz` + - `mbedtls`, `brotlidec`, `brotlicommon`, `lz4` + - `roaring_bitmap`, `simsimd` +- **Linking order**: Libraries are linked after FFI compilation for WASM (different from native builds) + +### 2. `src/error.rs` Modifications + +The error handling code was modified to conditionally compile C++ exception support: + +#### Conditional C++ Exception Variant +The `Error::CxxException` variant and related implementations are conditionally compiled: + +```rust +#[cfg(not(target_arch = "wasm32"))] +pub enum Error { + // ... other variants ... + CxxException(cxx::Exception), + // ... 
+}
+```
+
+#### Exception Mapping for WASM
+In WASM builds, `cxx::Exception` is mapped to `Error::FailedQuery`:
+
+```rust
+impl From<cxx::Exception> for Error {
+    fn from(item: cxx::Exception) -> Self {
+        #[cfg(not(target_arch = "wasm32"))]
+        {
+            Error::CxxException(item)
+        }
+        #[cfg(target_arch = "wasm32")]
+        {
+            // In wasm, CxxException is not available, map to a generic error
+            Error::FailedQuery(item.to_string())
+        }
+    }
+}
+```
+
+**Note**: This change does not affect the rest of `lbug` due to `#[cfg]` guards, ensuring native builds remain unchanged.
+
+## Build Requirements
+
+### 1. Using the Modified lbug Crate
+
+To use the modified `lbug` crate in your project, add a `[patch.crates-io]` section to your `Cargo.toml`:
+
+```toml
+[dependencies]
+lbug = "0.12.2"
+
+# Patch lbug to use local version with wasm32-unknown-emscripten support
+[patch.crates-io]
+lbug = { path = "./lbug-0.12.2" }
+```
+
+### 2. Emscripten Environment Setup
+
+The build requires Emscripten to be properly configured. The following environment variables should be set:
+
+#### Memory Configuration
+```bash
+export EM_INITIAL_HEAP=$((256 * 1024 * 1024)) # 256 MB initial heap
+export EM_MAXIMUM_MEMORY=$((4 * 1024 * 1024 * 1024)) # 4 GB maximum
+export EM_MEMORY_GROWTH_GEOMETRIC_STEP=0.8
+export EM_MALLOC=dlmalloc
+```
+
+#### Compiler/Linker Configuration
+```bash
+# Prevent cc-rs from adding default flags that conflict with Emscripten
+export CRATE_CC_NO_DEFAULTS=1
+
+# Emscripten compiler flags
+export EMCC_CFLAGS="--no-entry \
+  -sASSERTIONS=1 \
+  -sALLOW_TABLE_GROWTH=1 \
+  -sALLOW_MEMORY_GROWTH=1 \
+  -sINITIAL_HEAP=$EM_INITIAL_HEAP \
+  -sMEMORY_GROWTH_GEOMETRIC_STEP=$EM_MEMORY_GROWTH_GEOMETRIC_STEP \
+  -sMAXIMUM_MEMORY=$EM_MAXIMUM_MEMORY \
+  -sERROR_ON_UNDEFINED_SYMBOLS=0 \
+  -sDISABLE_EXCEPTION_CATCHING=0 \
+  -sEXPORT_NAME=createGraphModule \
+  -sEXPORTED_RUNTIME_METHODS=stringToUTF8,HEAPU8 \
+  -sENVIRONMENT=web \
+  -sMODULARIZE=1 \
+  -sEXPORT_ES6=1"
+```
+
+#### Function Exports
+
+To control which functions are exported (avoiding issues with `$` symbols in auto-generated exports), use `RUSTFLAGS`:
+
+```bash
+export RUSTFLAGS="-C link-arg=-sEXPORTED_FUNCTIONS=@${SCRIPT_DIR}/exports.txt -C link-arg=-sEXPORT_ALL=0"
+```
+
+Where `exports.txt` contains the list of functions to export (one per line, with `_` prefix):
+```
+_hello
+_generate_db
+_init
+_search_similar_shapes
+# ... etc
+```
+
+### 3. Build Process
+
+1. **Source Emscripten environment**:
+   ```bash
+   source /opt/emsdk/emsdk_env.sh
+   ```
+
+2. **Set build environment**:
+   ```bash
+   source ./_build_env
+   ```
+
+3. **Build**:
+   ```bash
+   cargo build --target=wasm32-unknown-emscripten
+   ```
+
+## Key Differences from Native Builds
+
+1. **Single-threaded**: WASM builds use `SINGLE_THREADED=TRUE` in CMake
+2. **Exception handling**: C++ exceptions are enabled at compile time (`-fexceptions`) and runtime (`-sDISABLE_EXCEPTION_CATCHING=0`)
+3. **Linking order**: Libraries are linked after FFI compilation for WASM
+4. **Error handling**: C++ exceptions are mapped to `FailedQuery` errors in WASM
+5. 
**Function exports**: Manual control of exported functions via `EXPORTED_FUNCTIONS` file + +## Troubleshooting + +### Missing Symbols +If you encounter "missing function" errors at runtime, ensure: +- All required static libraries are listed in `link_libraries()` for WASM +- Libraries are linked in the correct order (after FFI compilation) +- `EXPORTED_FUNCTIONS` includes all functions you need to call from JavaScript + +### Invalid Export Names +If you see errors like `invalid export name: cxxbridge1$exception`: +- Use `EXPORT_ALL=0` and manually specify functions in `exports.txt` +- Avoid using `EXPORT_ALL=1` with auto-generated export lists that may contain `$` symbols + +### CMake Compiler Detection Errors +If CMake fails to detect the compiler: +- Ensure `CC` and `CXX` environment variables point to `emcc` and `em++` +- The Emscripten toolchain should be automatically detected by `cmake-rs` + +## References + +- [Emscripten Documentation](https://emscripten.org/docs/getting_started/index.html) +- [Rust and WebAssembly](https://rustwasm.github.io/docs/book/) +- [cxx crate documentation](https://cxx.rs/) + diff --git a/graph-wasm/_build_env b/graph-wasm/_build_env new file mode 100644 index 0000000000..a3617193c4 --- /dev/null +++ b/graph-wasm/_build_env @@ -0,0 +1,105 @@ +#!/usr/bin/env bash + +# -------------------- +# Build configuration +# -------------------- + +_SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +export CURRENT_VERSION="${CURRENT_VERSION:-develop}" +export BUILD_NAME="${BUILD_NAME:-graph-wasm}" +export CARGO_BUILD_TARGET="${CARGO_BUILD_TARGET:-wasm32-unknown-emscripten}" +# Keep the emscripten cache in-repo so system cache cleaners do not wipe it. +export EM_CACHE="${EM_CACHE:-${_SCRIPT_DIR}/.emsdk_cache}" +export CARGO_TARGET_DIR="${CARGO_TARGET_DIR:-${_SCRIPT_DIR}/target}" +if [[ -z "${CARGO_INCREMENTAL:-}" && "${NODE_ENV:-}" != "production" ]]; then + export CARGO_INCREMENTAL=1 +fi +if [[ -z "${RUSTC_WRAPPER:-}" ]] && command -v sccache >/dev/null 2>&1; then + export RUSTC_WRAPPER=sccache +fi +export CRATE_CC_NO_DEFAULTS=1 + +# BUILD_MODE +if [[ "${NODE_ENV:-}" == "production" ]]; then + BUILD_MODE=release +else + BUILD_MODE="${1:-debug}" +fi +export BUILD_MODE + +# -------------------- +# Emscripten memory +# -------------------- + +export EM_INITIAL_HEAP=$((256 * 1024 * 1024)) +export EM_MAXIMUM_MEMORY=$((4 * 1024 * 1024 * 1024)) +export EM_MEMORY_GROWTH_GEOMETRIC_STEP=0.8 +export EM_MALLOC=dlmalloc + +# -------------------- +# Flags +# -------------------- + +EMCC_COMMON_FLAGS=( + --no-entry + -sASSERTIONS=1 + -sALLOW_TABLE_GROWTH=1 + -sALLOW_MEMORY_GROWTH=1 + -sINITIAL_HEAP=$EM_INITIAL_HEAP + -sMEMORY_GROWTH_GEOMETRIC_STEP=$EM_MEMORY_GROWTH_GEOMETRIC_STEP + -sMAXIMUM_MEMORY=$EM_MAXIMUM_MEMORY + -sERROR_ON_UNDEFINED_SYMBOLS=0 + -sDISABLE_EXCEPTION_CATCHING=0 + -sEXPORT_NAME=createGraphModule + -sEXPORTED_RUNTIME_METHODS=stringToUTF8,HEAPU8 + -sENVIRONMENT=web + -sMODULARIZE=1 + -sEXPORT_ES6=1 +) + +export RUSTFLAGS="-C link-arg=-sEXPORTED_FUNCTIONS=@${_SCRIPT_DIR}/exports.txt -C link-arg=-sEXPORT_ALL=0" + +# Mode-specific flags +if [[ "$BUILD_MODE" == "release" ]]; then + export EMCC_CFLAGS="-Os ${EMCC_COMMON_FLAGS[*]}" + CARGO_PARAMS=(--release "${@:2}") +else + export EMCC_CFLAGS="-g -sVERBOSE=1 -sMALLOC=$EM_MALLOC ${EMCC_COMMON_FLAGS[*]}" + CARGO_PARAMS=("${@:2}") +fi + +export CARGO_PARAMS + +# -------------------- +# Tasks +# -------------------- + +clean() { + cargo clean +} + +setup() { + : +} + +build() { + cargo build "${CARGO_PARAMS[@]}" 
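+  # Note: no extra flags are needed here; cargo picks up CARGO_BUILD_TARGET
+  # (wasm32-unknown-emscripten) and the mode-specific CARGO_PARAMS exported
+  # above, where release builds already include --release.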
+} + +copy_artifacts() { + local dest=$1 + local base="target/$CARGO_BUILD_TARGET/$BUILD_MODE" + + mkdir -p "$dest" + + cp "$base/graph_wasm.js" "$dest/$BUILD_NAME.js" + cp "$base/graph_wasm.wasm" "$dest/$BUILD_NAME.wasm" + + sed -i "s/graph_wasm.wasm/$BUILD_NAME.wasm?version=$CURRENT_VERSION/g" \ + "$dest/$BUILD_NAME.js" +} + +copy_shared_artifact() { + : +} diff --git a/graph-wasm/build b/graph-wasm/build new file mode 100755 index 0000000000..1107f12c3f --- /dev/null +++ b/graph-wasm/build @@ -0,0 +1,20 @@ +#!/usr/bin/env bash + +EMSDK_QUIET=1 . /opt/emsdk/emsdk_env.sh + +_SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd); +pushd $_SCRIPT_DIR; + +. ./_build_env + +set -ex; + +setup; +build; +copy_artifacts "../frontend/resources/public/js"; +copy_shared_artifact; + +exit $?; + +popd + diff --git a/graph-wasm/build.log b/graph-wasm/build.log new file mode 100644 index 0000000000..aab4f223df --- /dev/null +++ b/graph-wasm/build.log @@ -0,0 +1,26 @@ +~/penpot/graph-wasm ~/penpot/graph-wasm +./_build_env: line 60: export: `CXX_wasm32-unknown-emscripten=/home/penpot/penpot/graph-wasm/wrapper-em++.sh': not a valid identifier +./_build_env: line 110: export: `CXXFLAGS_wasm32-unknown-emscripten=-ffunction-sections -fdata-sections -fexceptions --target=wasm32-unknown-emscripten': not a valid identifier ++ setup ++ true ++ build ++ cargo build + Compiling graph v0.1.0 (/home/penpot/penpot/graph-wasm) +warning: unused import: `Error` + --> src/main.rs:1:48 + | +1 | use lbug::{Database, Connection, SystemConfig, Error}; + | ^^^^^ + | + = note: `#[warn(unused_imports)]` (part of `#[warn(unused)]`) on by default + +warning: variable does not need to be mutable + --> src/main.rs:23:9 + | +23 | let mut conn = match Connection::new(&db) { + | ----^^^^ + | | + | help: remove this `mut` + | + = note: `#[warn(unused_mut)]` (part of `#[warn(unused)]`) on by default + diff --git a/graph-wasm/build.rs b/graph-wasm/build.rs new file mode 100644 index 0000000000..1f151b4e94 --- /dev/null +++ b/graph-wasm/build.rs @@ -0,0 +1,2 @@ +// We need this empty script so OUT_DIR is automatically set +fn main() {} diff --git a/graph-wasm/exports.txt b/graph-wasm/exports.txt new file mode 100644 index 0000000000..cc71b6d709 --- /dev/null +++ b/graph-wasm/exports.txt @@ -0,0 +1,11 @@ +_hello +_generate_db +_init +_search_similar_shapes +_free_similar_shapes_buffer +_set_shape_parent +_set_shape_selrect +_set_shape_type +_use_shape + + diff --git a/graph-wasm/lbug-0.12.2/.gitignore b/graph-wasm/lbug-0.12.2/.gitignore new file mode 100644 index 0000000000..391ed4d660 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/.gitignore @@ -0,0 +1,5 @@ +target/ +debug/ + +**/*.rs.bk + diff --git a/graph-wasm/lbug-0.12.2/Cargo.lock b/graph-wasm/lbug-0.12.2/Cargo.lock new file mode 100644 index 0000000000..14cd17d92d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/Cargo.lock @@ -0,0 +1,1234 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. 
+version = 3 + +[[package]] +name = "ahash" +version = "0.8.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" +dependencies = [ + "cfg-if", + "const-random", + "getrandom 0.3.3", + "once_cell", + "version_check", + "zerocopy", +] + +[[package]] +name = "aho-corasick" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +dependencies = [ + "memchr", +] + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + +[[package]] +name = "anstyle" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "862ed96ca487e809f1c8e5a8447f6ee2cf102f846893800b20cebdf541fc6bbd" + +[[package]] +name = "anyhow" +version = "1.0.99" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0674a1ddeecb70197781e945de4b3b8ffb61fa939a5597bcf48503737663100" + +[[package]] +name = "arrayvec" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" + +[[package]] +name = "arrow" +version = "55.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3f15b4c6b148206ff3a2b35002e08929c2462467b62b9c02036d9c34f9ef994" +dependencies = [ + "arrow-arith", + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-ord", + "arrow-row", + "arrow-schema", + "arrow-select", + "arrow-string", +] + +[[package]] +name = "arrow-arith" +version = "55.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30feb679425110209ae35c3fbf82404a39a4c0436bb3ec36164d8bffed2a4ce4" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "chrono", + "num", +] + +[[package]] +name = "arrow-array" +version = "55.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70732f04d285d49054a48b72c54f791bb3424abae92d27aafdf776c98af161c8" +dependencies = [ + "ahash", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "chrono", + "half", + "hashbrown", + "num", +] + +[[package]] +name = "arrow-buffer" +version = "55.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "169b1d5d6cb390dd92ce582b06b23815c7953e9dfaaea75556e89d890d19993d" +dependencies = [ + "bytes", + "half", + "num", +] + +[[package]] +name = "arrow-cast" +version = "55.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4f12eccc3e1c05a766cafb31f6a60a46c2f8efec9b74c6e0648766d30686af8" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "arrow-select", + "atoi", + "base64", + "chrono", + "half", + "lexical-core", + "num", + "ryu", +] + +[[package]] +name = "arrow-data" +version = "55.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8de1ce212d803199684b658fc4ba55fb2d7e87b213de5af415308d2fee3619c2" +dependencies = [ + "arrow-buffer", + "arrow-schema", + "half", + "num", +] + +[[package]] +name = "arrow-ord" +version = "55.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6506e3a059e3be23023f587f79c82ef0bcf6d293587e3272d20f2d30b969b5a7" 
+dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "arrow-select", +] + +[[package]] +name = "arrow-row" +version = "55.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52bf7393166beaf79b4bed9bfdf19e97472af32ce5b6b48169d321518a08cae2" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "half", +] + +[[package]] +name = "arrow-schema" +version = "55.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af7686986a3bf2254c9fb130c623cdcb2f8e1f15763e7c71c310f0834da3d292" +dependencies = [ + "bitflags", +] + +[[package]] +name = "arrow-select" +version = "55.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd2b45757d6a2373faa3352d02ff5b54b098f5e21dccebc45a21806bc34501e5" +dependencies = [ + "ahash", + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "num", +] + +[[package]] +name = "arrow-string" +version = "55.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0377d532850babb4d927a06294314b316e23311503ed580ec6ce6a0158f49d40" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "arrow-select", + "memchr", + "num", + "regex", + "regex-syntax", +] + +[[package]] +name = "atoi" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f28d99ec8bfea296261ca1af174f24225171fea9664ba9003cbebee704810528" +dependencies = [ + "num-traits", +] + +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + +[[package]] +name = "bitflags" +version = "2.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2261d10cca569e4643e526d8dc2e62e433cc8aba21ab764233731f8d369bf394" + +[[package]] +name = "bumpalo" +version = "3.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43" + +[[package]] +name = "bytes" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" + +[[package]] +name = "cc" +version = "1.2.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5252b3d2648e5eedbc1a6f501e3c795e07025c1e93bbf8bbdd6eef7f447a6d54" +dependencies = [ + "find-msvc-tools", + "shlex", +] + +[[package]] +name = "cfg-if" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2fd1289c04a9ea8cb22300a459a72a385d7c73d3259e2ed7dcb2af674838cfa9" + +[[package]] +name = "chrono" +version = "0.4.42" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "145052bdd345b87320e369255277e3fb5152762ad123a901ef5c262dd38fe8d2" +dependencies = [ + "iana-time-zone", + "num-traits", + "windows-link 0.2.0", +] + +[[package]] +name = "clap" +version = "4.5.47" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7eac00902d9d136acd712710d71823fb8ac8004ca445a89e73a41d45aa712931" +dependencies = [ + "clap_builder", +] + +[[package]] +name = "clap_builder" +version = "4.5.47" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ad9bbf750e73b5884fb8a211a9424a1906c1e156724260fdae972f31d70e1d6" +dependencies = [ + "anstyle", + "clap_lex", + "strsim", +] + +[[package]] +name = "clap_lex" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b94f61472cee1439c0b966b47e3aca9ae07e45d070759512cd390ea2bebc6675" + +[[package]] +name = "cmake" +version = "0.1.54" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0" +dependencies = [ + "cc", +] + +[[package]] +name = "codespan-reporting" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3538270d33cc669650c4b093848450d380def10c331d38c768e34cac80576e6e" +dependencies = [ + "termcolor", + "unicode-width", +] + +[[package]] +name = "const-random" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87e00182fe74b066627d63b85fd550ac2998d4b0bd86bfed477a0ae4c7c71359" +dependencies = [ + "const-random-macro", +] + +[[package]] +name = "const-random-macro" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" +dependencies = [ + "getrandom 0.2.16", + "once_cell", + "tiny-keccak", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + +[[package]] +name = "crunchy" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" + +[[package]] +name = "cxx" +version = "1.0.138" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3956d60afa98653c5a57f60d7056edd513bfe0307ef6fb06f6167400c3884459" +dependencies = [ + "cc", + "cxxbridge-cmd", + "cxxbridge-flags", + "cxxbridge-macro", + "foldhash", + "link-cplusplus", +] + +[[package]] +name = "cxx-build" +version = "1.0.138" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a4b7522f539fe056f1d6fc8577d8ab731451f6f33a89b1e5912e22b76c553e7" +dependencies = [ + "cc", + "codespan-reporting", + "proc-macro2", + "quote", + "scratch", + "syn", +] + +[[package]] +name = "cxxbridge-cmd" +version = "1.0.138" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f01e92ab4ce9fd4d16e3bb11b158d98cbdcca803c1417aa43130a6526fbf208" +dependencies = [ + "clap", + "codespan-reporting", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "cxxbridge-flags" +version = "1.0.138" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c41cbfab344869e70998b388923f7d1266588f56c8ca284abf259b1c1ffc695" + +[[package]] +name = "cxxbridge-macro" +version = "1.0.138" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88d82a2f759f0ad3eae43b96604efd42b1d4729a35a6f2dc7bdb797ae25d9284" +dependencies = [ + "proc-macro2", + "quote", + "rustversion", + "syn", +] + +[[package]] +name = "deranged" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d630bccd429a5bb5a64b5e94f693bfc48c9f8566418fda4c494cc94f911f87cc" +dependencies = [ + "powerfmt", +] + +[[package]] +name = "errno" +version = "0.3.13" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "778e2ac28f6c47af28e4907f13ffd1e1ddbd400980a9abd7c8df189bf578a5ad" +dependencies = [ + "libc", + "windows-sys 0.60.2", +] + +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + +[[package]] +name = "find-msvc-tools" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fd99930f64d146689264c637b5af2f0233a933bef0d8570e2526bf9e083192d" + +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + +[[package]] +name = "getrandom" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" +dependencies = [ + "cfg-if", + "libc", + "wasi 0.11.1+wasi-snapshot-preview1", +] + +[[package]] +name = "getrandom" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasi 0.14.4+wasi-0.2.4", +] + +[[package]] +name = "half" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "459196ed295495a68f7d7fe1d84f6c4b7ff0e21fe3017b2f283c6fac3ad803c9" +dependencies = [ + "cfg-if", + "crunchy", + "num-traits", +] + +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" + +[[package]] +name = "iana-time-zone" +version = "0.1.63" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0c919e5debc312ad217002b8048a17b7d83f80703865bbfcfebb0458b0b27d8" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "log", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + +[[package]] +name = "js-sys" +version = "0.3.78" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c0b063578492ceec17683ef2f8c5e89121fbd0b172cbc280635ab7567db2738" +dependencies = [ + "once_cell", + "wasm-bindgen", +] + +[[package]] +name = "lbug" +version = "0.12.2" +dependencies = [ + "anyhow", + "arrow", + "cmake", + "cxx", + "cxx-build", + "rust_decimal", + "rust_decimal_macros", + "rustversion", + "tempfile", + "time", + "uuid", +] + +[[package]] +name = "lexical-core" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b765c31809609075565a70b4b71402281283aeda7ecaf4818ac14a7b2ade8958" +dependencies = [ + "lexical-parse-float", + "lexical-parse-integer", + "lexical-util", + "lexical-write-float", + "lexical-write-integer", +] + +[[package]] +name = "lexical-parse-float" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de6f9cb01fb0b08060209a057c048fcbab8717b4c1ecd2eac66ebfe39a65b0f2" +dependencies = [ + "lexical-parse-integer", + "lexical-util", + "static_assertions", +] + +[[package]] +name = 
"lexical-parse-integer" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72207aae22fc0a121ba7b6d479e42cbfea549af1479c3f3a4f12c70dd66df12e" +dependencies = [ + "lexical-util", + "static_assertions", +] + +[[package]] +name = "lexical-util" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a82e24bf537fd24c177ffbbdc6ebcc8d54732c35b50a3f28cc3f4e4c949a0b3" +dependencies = [ + "static_assertions", +] + +[[package]] +name = "lexical-write-float" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c5afc668a27f460fb45a81a757b6bf2f43c2d7e30cb5a2dcd3abf294c78d62bd" +dependencies = [ + "lexical-util", + "lexical-write-integer", + "static_assertions", +] + +[[package]] +name = "lexical-write-integer" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "629ddff1a914a836fb245616a7888b62903aae58fa771e1d83943035efa0f978" +dependencies = [ + "lexical-util", + "static_assertions", +] + +[[package]] +name = "libc" +version = "0.2.175" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543" + +[[package]] +name = "libm" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" + +[[package]] +name = "link-cplusplus" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f78c730aaa7d0b9336a299029ea49f9ee53b0ed06e9202e8cb7db9bae7b8c82" +dependencies = [ + "cc", +] + +[[package]] +name = "linux-raw-sys" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd945864f07fe9f5371a27ad7b52a172b4b499999f1d97574c9fa68373937e12" + +[[package]] +name = "log" +version = "0.4.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34080505efa8e45a4b816c349525ebe327ceaa8559756f0356cba97ef3bf7432" + +[[package]] +name = "memchr" +version = "2.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a282da65faaf38286cf3be983213fcf1d2e2a58700e808f83f4ea9a4804bc0" + +[[package]] +name = "num" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35bd024e8b2ff75562e5f34e7f4905839deb4b22955ef5e73d2fea1b9813cb23" +dependencies = [ + "num-bigint", + "num-complex", + "num-integer", + "num-iter", + "num-rational", + "num-traits", +] + +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-conv" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = 
"num-iter" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1429034a0490724d0075ebb2bc9e875d6503c3cf69e235a8941aa757d83ef5bf" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" +dependencies = [ + "num-bigint", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", + "libm", +] + +[[package]] +name = "once_cell" +version = "1.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" + +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" + +[[package]] +name = "proc-macro2" +version = "1.0.101" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89ae43fd86e4158d6db51ad8e2b80f313af9cc74f5c0e03ccb87de09998732de" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + +[[package]] +name = "regex" +version = "1.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23d7fd106d8c02486a8d64e778353d1cffe08ce79ac2e82f540c86d0facf6912" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b9458fa0bfeeac22b5ca447c63aaf45f28439a709ccd244698632f9aa6394d6" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "caf4aa5b0f434c91fe5c7f1ecb6a5ece2130b02ad2a590589dda5146df959001" + +[[package]] +name = "rust_decimal" +version = "1.37.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b203a6425500a03e0919c42d3c47caca51e79f1132046626d2c8871c5092035d" +dependencies = [ + "arrayvec", + "num-traits", +] + +[[package]] +name = "rust_decimal_macros" +version = "1.37.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6268b74858287e1a062271b988a0c534bf85bbeb567fe09331bf40ed78113d5" +dependencies = [ + "quote", + "syn", +] + +[[package]] +name = "rustix" +version = "1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11181fbabf243db407ef8df94a6ce0b2f9a733bd8be4ad02b4eda9602296cac8" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.60.2", +] + +[[package]] +name = "rustversion" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" + +[[package]] +name = "ryu" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" + +[[package]] +name = "scratch" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d68f2ec51b097e4c1a75b681a8bec621909b5e91f15bb7b840c4f2f7b01148b2" + +[[package]] +name = "serde" +version = "1.0.219" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.219" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + +[[package]] +name = "syn" +version = "2.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ede7c438028d4436d71104916910f5bb611972c5cfd7f89b8300a8186e6fada6" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "tempfile" +version = "3.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15b61f8f20e3a6f7e0649d825294eaf317edce30f82cf6026e7e4cb9222a7d1e" +dependencies = [ + "fastrand", + "getrandom 0.3.3", + "once_cell", + "rustix", + "windows-sys 0.60.2", +] + +[[package]] +name = "termcolor" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "time" +version = "0.3.43" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83bde6f1ec10e72d583d91623c939f623002284ef622b87de38cfd546cbf2031" +dependencies = [ + "deranged", + "num-conv", + "powerfmt", + "serde", + "time-core", + "time-macros", +] + +[[package]] +name = "time-core" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40868e7c1d2f0b8d73e4a8c7f0ff63af4f6d19be117e90bd73eb1d62cf831c6b" + +[[package]] +name = "time-macros" +version = "0.2.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30cfb0125f12d9c277f35663a0a33f8c30190f4e4574868a330595412d34ebf3" +dependencies = [ + "num-conv", + "time-core", +] + +[[package]] +name = "tiny-keccak" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c9d3793400a45f954c52e73d068316d76b6f4e36977e3fcebb13a2721e80237" +dependencies = [ + "crunchy", +] + +[[package]] +name = "unicode-ident" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" + 
+[[package]] +name = "unicode-width" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" + +[[package]] +name = "uuid" +version = "1.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f87b8aa10b915a06587d0dec516c282ff295b475d94abf425d62b57710070a2" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + +[[package]] +name = "wasi" +version = "0.14.4+wasi-0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88a5f4a424faf49c3c2c344f166f0662341d470ea185e939657aaff130f0ec4a" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "wasm-bindgen" +version = "0.2.101" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e14915cadd45b529bb8d1f343c4ed0ac1de926144b746e2710f9cd05df6603b" +dependencies = [ + "cfg-if", + "once_cell", + "rustversion", + "wasm-bindgen-macro", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.101" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e28d1ba982ca7923fd01448d5c30c6864d0a14109560296a162f80f305fb93bb" +dependencies = [ + "bumpalo", + "log", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.101" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c3d463ae3eff775b0c45df9da45d68837702ac35af998361e2c84e7c5ec1b0d" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.101" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7bb4ce89b08211f923caf51d527662b75bdc9c9c7aab40f86dcb9fb85ac552aa" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.101" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f143854a3b13752c6950862c906306adb27c7e839f7414cec8fea35beab624c1" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "winapi-util" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" +dependencies = [ + "windows-sys 0.61.0", +] + +[[package]] +name = "windows-core" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0fdd3ddb90610c7638aa2b3a3ab2904fb9e5cdbecc643ddb3647212781c4ae3" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-link 0.1.3", + "windows-result", + "windows-strings", +] + +[[package]] +name = "windows-implement" +version = "0.60.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "windows-interface" +version = "0.59.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "windows-link" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a" + +[[package]] +name = "windows-link" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45e46c0661abb7180e7b9c281db115305d49ca1709ab8242adf09666d2173c65" + +[[package]] +name = "windows-result" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6" +dependencies = [ + "windows-link 0.1.3", +] + +[[package]] +name = "windows-strings" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57" +dependencies = [ + "windows-link 0.1.3", +] + +[[package]] +name = "windows-sys" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-sys" +version = "0.61.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e201184e40b2ede64bc2ea34968b28e33622acdbbf37104f0e4a33f7abe657aa" +dependencies = [ + "windows-link 0.2.0", +] + +[[package]] +name = "windows-targets" +version = "0.53.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5fe6031c4041849d7c496a8ded650796e7b6ecc19df1a431c1a363342e5dc91" +dependencies = [ + "windows-link 0.1.3", + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c" + +[[package]] +name = "windows_i686_gnu" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1dc67659d35f387f5f6c479dc4e28f1d4bb90ddd1a5d3da2e5d97b42d6272c3" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11" + +[[package]] +name = "windows_i686_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57" + +[[package]] +name = 
"windows_x86_64_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" + +[[package]] +name = "wit-bindgen" +version = "0.45.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c573471f125075647d03df72e026074b7203790d41351cd6edc96f46bcccd36" + +[[package]] +name = "zerocopy" +version = "0.8.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0894878a5fa3edfd6da3f88c4805f4c8558e2b996227a3d864f47fe11e38282c" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88d2b8d9c68ad2b9e4340d7832716a4d21a22a1154777ad56ea55c51a9cf3831" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] diff --git a/graph-wasm/lbug-0.12.2/Cargo.toml b/graph-wasm/lbug-0.12.2/Cargo.toml new file mode 100644 index 0000000000..ea839f84c3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/Cargo.toml @@ -0,0 +1,144 @@ +# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO +# +# When uploading crates to the registry Cargo will automatically +# "normalize" Cargo.toml files for maximal compatibility +# with all versions of Cargo and also rewrite `path` dependencies +# to registry (e.g., crates.io) dependencies. +# +# If you are reading this file be aware that the original Cargo.toml +# will likely look very different (and much more reasonable). +# See Cargo.toml.orig for the original contents. + +[package] +edition = "2021" +rust-version = "1.81" +name = "lbug" +version = "0.12.2" +build = "build.rs" +include = [ + "build.rs", + "/src", + "/include", + "/lbug-src/src", + "/lbug-src/cmake", + "/lbug-src/third_party", + "/lbug-src/CMakeLists.txt", + "/lbug-src/tools/CMakeLists.txt", +] +autolib = false +autobins = false +autoexamples = false +autotests = false +autobenches = false +description = "An in-process property graph database management system built for query speed and scalability" +homepage = "https://ladybugdb.com/" +readme = "lbug-src/README.md" +keywords = [ + "database", + "graph", + "ffi", +] +categories = ["database"] +license = "MIT" +repository = "https://github.com/lbugdb/lbug" + +[package.metadata.docs.rs] +all-features = true + +[features] +arrow = ["dep:arrow"] +default = [] +extension_tests = [] + +[lib] +name = "lbug" +path = "src/lib.rs" + +[dependencies.arrow] +version = "55" +features = ["ffi"] +optional = true +default-features = false + +[dependencies.cxx] +version = "=1.0.138" + +[dependencies.rust_decimal] +version = "1.37" +default-features = false + +[dependencies.time] +version = "0.3" + +[dependencies.uuid] +version = "1.6" + +[dev-dependencies.anyhow] +version = "1" + +[dev-dependencies.rust_decimal_macros] +version = "1.37" + +[dev-dependencies.tempfile] +version = "3" + +[dev-dependencies.time] +version = "0.3" +features = ["macros"] + +[build-dependencies.cmake] +version = "0.1" + +[build-dependencies.cxx-build] +version = "=1.0.138" + +[build-dependencies.rustversion] +version = "1" + +[lints.clippy] +inline_always = "allow" +missing_errors_doc = "allow" +missing_panics_doc = "allow" +module_name_repetitions = "allow" +must_use_candidate = "allow" +needless_pass_by_value = "allow" +redundant_closure_for_method_calls = "allow" +return_self_not_must_use = "allow" +similar_names = "allow" +struct_excessive_bools = "allow" +too_many_arguments = "allow" +too_many_lines = "allow" +type_complexity = 
"allow" +unreadable_literal = "allow" + +[lints.clippy.cargo] +level = "warn" +priority = -1 + +[lints.clippy.complexity] +level = "warn" +priority = -1 + +[lints.clippy.correctness] +level = "warn" +priority = -1 + +[lints.clippy.pedantic] +level = "warn" +priority = -1 + +[lints.clippy.perf] +level = "warn" +priority = -1 + +[lints.clippy.style] +level = "warn" +priority = -1 + +[lints.clippy.suspicious] +level = "warn" +priority = -1 + +[profile.relwithdebinfo] +debug = 2 +inherits = "release" diff --git a/graph-wasm/lbug-0.12.2/build.rs b/graph-wasm/lbug-0.12.2/build.rs new file mode 100644 index 0000000000..23f644a5af --- /dev/null +++ b/graph-wasm/lbug-0.12.2/build.rs @@ -0,0 +1,283 @@ +use std::env; +use std::path::{Path, PathBuf}; + +fn link_mode() -> &'static str { + if env::var("LBUG_SHARED").is_ok() { + "dylib" + } else { + "static" + } +} + +fn get_target() -> String { + env::var("PROFILE").unwrap() +} + +fn is_wasm_emscripten() -> bool { + env::var("TARGET") + .map(|t| t == "wasm32-unknown-emscripten") + .unwrap_or(false) +} + +fn link_libraries() { + // For wasm32-unknown-emscripten, we need to link lbug and all its dependencies + // These are built by CMake and need to be linked here + if is_wasm_emscripten() { + // Link all dependencies first (built by CMake) + for lib in [ + "utf8proc", + "antlr4_cypher", + "antlr4_runtime", + "re2", + "fastpfor", + "parquet", + "thrift", + "snappy", + "zstd", + "miniz", + "mbedtls", + "brotlidec", + "brotlicommon", + "lz4", + "roaring_bitmap", + "simsimd", + ] { + println!("cargo:rustc-link-lib=static={lib}"); + } + // Link the lbug static library (built by CMake) + println!("cargo:rustc-link-lib=static=lbug"); + // Don't link system libraries for wasm (they're handled by Emscripten) + return; + } + + // This also needs to be set by any crates using it if they want to use extensions + if !cfg!(windows) && link_mode() == "static" { + println!("cargo:rustc-link-arg=-rdynamic"); + } + if cfg!(windows) && link_mode() == "dylib" { + println!("cargo:rustc-link-lib=dylib=lbug_shared"); + } else if link_mode() == "dylib" { + println!("cargo:rustc-link-lib={}=lbug", link_mode()); + } else if rustversion::cfg!(since(1.82)) { + println!("cargo:rustc-link-lib=static:+whole-archive=lbug"); + } else { + println!("cargo:rustc-link-lib=static=lbug"); + } + if link_mode() == "static" { + if cfg!(windows) { + println!("cargo:rustc-link-lib=dylib=msvcrt"); + println!("cargo:rustc-link-lib=dylib=shell32"); + println!("cargo:rustc-link-lib=dylib=ole32"); + } else if cfg!(target_os = "macos") { + println!("cargo:rustc-link-lib=dylib=c++"); + } else { + println!("cargo:rustc-link-lib=dylib=stdc++"); + } + + for lib in [ + "utf8proc", + "antlr4_cypher", + "antlr4_runtime", + "re2", + "fastpfor", + "parquet", + "thrift", + "snappy", + "zstd", + "miniz", + "mbedtls", + "brotlidec", + "brotlicommon", + "lz4", + "roaring_bitmap", + "simsimd", + ] { + if rustversion::cfg!(since(1.82)) { + println!("cargo:rustc-link-lib=static:+whole-archive={lib}"); + } else { + println!("cargo:rustc-link-lib=static={lib}"); + } + } + } +} + +fn build_bundled_cmake() -> Vec { + let lbug_root = { + let root = Path::new(&std::env::var("CARGO_MANIFEST_DIR").unwrap()).join("lbug-src"); + if root.is_symlink() || root.is_dir() { + root + } else { + // If the path is not directory, this is probably an in-source build on windows where the + // symlink is unreadable. 
+ Path::new(&std::env::var("CARGO_MANIFEST_DIR").unwrap()).join("../..") + } + }; + + let mut build = cmake::Config::new(&lbug_root); + build + .no_build_target(true) + .define("BUILD_SHELL", "OFF") + .define("BUILD_SINGLE_FILE_HEADER", "OFF") + .define("AUTO_UPDATE_GRAMMAR", "OFF"); + + // Configure for wasm32-unknown-emscripten + if is_wasm_emscripten() { + // Same configuration as ladybug/tools/wasm build + build.define("SINGLE_THREADED", "TRUE"); + // cmake-rs should automatically detect emscripten toolchain when CC/CXX point to emcc/em++ + } else if cfg!(windows) { + build.generator("Ninja"); + build.cxxflag("/EHsc"); + build.define("CMAKE_MSVC_RUNTIME_LIBRARY", "MultiThreadedDLL"); + build.define("CMAKE_POLICY_DEFAULT_CMP0091", "NEW"); + } + if let Ok(jobs) = env::var("NUM_JOBS") { + // SAFETY: Setting environment variables in build scripts is safe + unsafe { + env::set_var("CMAKE_BUILD_PARALLEL_LEVEL", jobs); + } + } + let build_dir = build.build(); + + let lbug_lib_path = build_dir.join("build").join("src"); + println!("cargo:rustc-link-search=native={}", lbug_lib_path.display()); + + for dir in [ + "utf8proc", + "antlr4_cypher", + "antlr4_runtime", + "re2", + "brotli", + "alp", + "fastpfor", + "parquet", + "thrift", + "snappy", + "zstd", + "miniz", + "mbedtls", + "lz4", + "roaring_bitmap", + "simsimd", + ] { + let lib_path = build_dir + .join("build") + .join("third_party") + .join(dir) + .canonicalize() + .unwrap_or_else(|_| { + panic!( + "Could not find {}/build/third_party/{}", + build_dir.display(), + dir + ) + }); + println!("cargo:rustc-link-search=native={}", lib_path.display()); + } + + vec![ + lbug_root.join("src/include"), + build_dir.join("build/src"), + build_dir.join("build/src/include"), + lbug_root.join("third_party/nlohmann_json"), + lbug_root.join("third_party/fastpfor"), + lbug_root.join("third_party/alp/include"), + ] +} + +fn build_ffi( + bridge_file: &str, + out_name: &str, + source_file: &str, + bundled: bool, + include_paths: &Vec, +) { + let mut build = cxx_build::bridge(bridge_file); + build.file(source_file); + + if bundled { + build.define("LBUG_BUNDLED", None); + } + if get_target() == "debug" || get_target() == "relwithdebinfo" { + build.define("ENABLE_RUNTIME_CHECKS", "1"); + } + if link_mode() == "static" { + build.define("LBUG_STATIC_DEFINE", None); + } + + build.includes(include_paths); + + println!("cargo:rerun-if-env-changed=LBUG_SHARED"); + + println!("cargo:rerun-if-changed=include/lbug_rs.h"); + println!("cargo:rerun-if-changed=src/lbug_rs.cpp"); + // Note that this should match the lbug-src/* entries in the package.include list in Cargo.toml + // Unfortunately they appear to need to be specified individually since the symlink is + // considered to be changed each time. 
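+    // Each `cargo:rerun-if-changed` directive asks Cargo to re-run this build script whenever
+    // the named path changes, so edits to the bundled lbug sources trigger a fresh build.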
+ println!("cargo:rerun-if-changed=lbug-src/src"); + println!("cargo:rerun-if-changed=lbug-src/cmake"); + println!("cargo:rerun-if-changed=lbug-src/third_party"); + println!("cargo:rerun-if-changed=lbug-src/CMakeLists.txt"); + println!("cargo:rerun-if-changed=lbug-src/tools/CMakeLists.txt"); + + if is_wasm_emscripten() { + // For emscripten, use C++20 and enable exceptions + build.flag("-std=c++20"); + build.flag("-fexceptions"); + // Note: -sDISABLE_EXCEPTION_CATCHING=0 is a linker flag, not a compiler flag + // It should be set via EMCC_CFLAGS environment variable or cargo rustc-link-arg + } else if cfg!(windows) { + build.flag("/std:c++20"); + build.flag("/MD"); + } else { + build.flag("-std=c++2a"); + } + build.compile(out_name); +} + +fn main() { + if env::var("DOCS_RS").is_ok() { + // Do nothing; we're just building docs and don't need the C++ library + return; + } + + let mut bundled = false; + let mut include_paths = + vec![Path::new(&std::env::var("CARGO_MANIFEST_DIR").unwrap()).join("include")]; + + if let (Ok(lbug_lib_dir), Ok(lbug_include)) = + (env::var("LBUG_LIBRARY_DIR"), env::var("LBUG_INCLUDE_DIR")) + { + println!("cargo:rustc-link-search=native={lbug_lib_dir}"); + println!("cargo:rustc-link-arg=-Wl,-rpath,{lbug_lib_dir}"); + include_paths.push(Path::new(&lbug_include).to_path_buf()); + } else { + include_paths.extend(build_bundled_cmake()); + bundled = true; + } + // For wasm, we need to link libraries after building FFI to ensure proper symbol resolution + if !is_wasm_emscripten() && link_mode() == "static" { + link_libraries(); + } + build_ffi( + "src/ffi.rs", + "lbug_rs", + "src/lbug_rs.cpp", + bundled, + &include_paths, + ); + + if cfg!(feature = "arrow") { + build_ffi( + "src/ffi/arrow.rs", + "lbug_arrow_rs", + "src/lbug_arrow.cpp", + bundled, + &include_paths, + ); + } + // For wasm, link libraries after FFI; for dylib, link after FFI + if is_wasm_emscripten() || link_mode() == "dylib" { + link_libraries(); + } +} diff --git a/graph-wasm/lbug-0.12.2/include/lbug_arrow.h b/graph-wasm/lbug-0.12.2/include/lbug_arrow.h new file mode 100644 index 0000000000..a3587a7276 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/include/lbug_arrow.h @@ -0,0 +1,15 @@ +#pragma once + +#include "rust/cxx.h" +#ifdef LBUG_BUNDLED +#include "main/lbug.h" +#else +#include +#endif + +namespace lbug_arrow { + +ArrowSchema query_result_get_arrow_schema(const lbug::main::QueryResult& result); +ArrowArray query_result_get_next_arrow_chunk(lbug::main::QueryResult& result, uint64_t chunkSize); + +} // namespace lbug_arrow diff --git a/graph-wasm/lbug-0.12.2/include/lbug_rs.h b/graph-wasm/lbug-0.12.2/include/lbug_rs.h new file mode 100644 index 0000000000..64006650cd --- /dev/null +++ b/graph-wasm/lbug-0.12.2/include/lbug_rs.h @@ -0,0 +1,243 @@ +#pragma once + +#include +#include + +#include "rust/cxx.h" +#ifdef LBUG_BUNDLED +#include "common/type_utils.h" +#include "common/types/int128_t.h" +#include "common/types/types.h" +#include "common/types/value/nested.h" +#include "common/types/value/node.h" +#include "common/types/value/recursive_rel.h" +#include "common/types/value/rel.h" +#include "common/types/value/value.h" +#include "main/lbug.h" +#include "storage/storage_version_info.h" +#else +#include +#endif + +namespace lbug_rs { + +struct TypeListBuilder { + std::vector types; + + void insert(std::unique_ptr type) { + types.push_back(std::move(*type)); + } +}; + +std::unique_ptr create_type_list(); + +struct QueryParams { + std::unordered_map> inputParams; + + void insert(const rust::Str key, 
std::unique_ptr value) { + inputParams.insert(std::make_pair(key, std::move(value))); + } +}; + +std::unique_ptr new_params(); + +std::unique_ptr create_logical_type(lbug::common::LogicalTypeID id); +std::unique_ptr create_logical_type_list( + std::unique_ptr childType); +std::unique_ptr create_logical_type_array( + std::unique_ptr childType, uint64_t numElements); + +inline std::unique_ptr create_logical_type_struct( + const rust::Vec& fieldNames, std::unique_ptr fieldTypes) { + std::vector fields; + for (auto i = 0u; i < fieldNames.size(); i++) { + fields.emplace_back(std::string(fieldNames[i]), std::move(fieldTypes->types[i])); + } + return std::make_unique( + lbug::common::LogicalType::STRUCT(std::move(fields))); +} +inline std::unique_ptr create_logical_type_union( + const rust::Vec& fieldNames, std::unique_ptr fieldTypes) { + std::vector fields; + for (auto i = 0u; i < fieldNames.size(); i++) { + fields.emplace_back(std::string(fieldNames[i]), std::move(fieldTypes->types[i])); + } + return std::make_unique( + lbug::common::LogicalType::UNION(std::move(fields))); +} +std::unique_ptr create_logical_type_map( + std::unique_ptr keyType, + std::unique_ptr valueType); + +inline std::unique_ptr create_logical_type_decimal(uint32_t precision, + uint32_t scale) { + return std::make_unique( + lbug::common::LogicalType::DECIMAL(precision, scale)); +} + +std::unique_ptr logical_type_get_list_child_type( + const lbug::common::LogicalType& logicalType); +std::unique_ptr logical_type_get_array_child_type( + const lbug::common::LogicalType& logicalType); +uint64_t logical_type_get_array_num_elements(const lbug::common::LogicalType& logicalType); + +rust::Vec logical_type_get_struct_field_names(const lbug::common::LogicalType& value); +std::unique_ptr> logical_type_get_struct_field_types( + const lbug::common::LogicalType& value); + +inline uint32_t logical_type_get_decimal_precision(const lbug::common::LogicalType& logicalType) { + return lbug::common::DecimalType::getPrecision(logicalType); +} +inline uint32_t logical_type_get_decimal_scale(const lbug::common::LogicalType& logicalType) { + return lbug::common::DecimalType::getScale(logicalType); +} + +/* Database */ +std::unique_ptr new_database(std::string_view databasePath, + uint64_t bufferPoolSize, uint64_t maxNumThreads, bool enableCompression, bool readOnly, + uint64_t maxDBSize, bool autoCheckpoint, int64_t checkpointThreshold, + bool throwOnWalReplayFailure, bool enableChecksums); + +void database_set_logging_level(lbug::main::Database& database, const std::string& level); + +/* Connection */ +std::unique_ptr database_connect(lbug::main::Database& database); +std::unique_ptr connection_execute(lbug::main::Connection& connection, + lbug::main::PreparedStatement& query, std::unique_ptr params); +inline std::unique_ptr connection_query(lbug::main::Connection& connection, + std::string_view query) { + return connection.query(query); +} + +/* PreparedStatement */ +rust::String prepared_statement_error_message(const lbug::main::PreparedStatement& statement); + +/* QueryResult */ +rust::String query_result_to_string(const lbug::main::QueryResult& result); +rust::String query_result_get_error_message(const lbug::main::QueryResult& result); + +double query_result_get_compiling_time(const lbug::main::QueryResult& result); +double query_result_get_execution_time(const lbug::main::QueryResult& result); + +std::unique_ptr> query_result_column_data_types( + const lbug::main::QueryResult& query_result); +rust::Vec query_result_column_names(const 
lbug::main::QueryResult& query_result); + +/* NodeVal/RelVal */ +rust::String node_value_get_label_name(const lbug::common::Value& val); +rust::String rel_value_get_label_name(const lbug::common::Value& val); + +size_t node_value_get_num_properties(const lbug::common::Value& value); +size_t rel_value_get_num_properties(const lbug::common::Value& value); + +rust::String node_value_get_property_name(const lbug::common::Value& value, size_t index); +rust::String rel_value_get_property_name(const lbug::common::Value& value, size_t index); + +const lbug::common::Value& node_value_get_property_value(const lbug::common::Value& value, + size_t index); +const lbug::common::Value& rel_value_get_property_value(const lbug::common::Value& value, + size_t index); + +/* NodeVal */ +const lbug::common::Value& node_value_get_node_id(const lbug::common::Value& val); + +/* RelVal */ +const lbug::common::Value& rel_value_get_src_id(const lbug::common::Value& val); +std::array rel_value_get_dst_id(const lbug::common::Value& val); + +/* RecursiveRel */ +const lbug::common::Value& recursive_rel_get_nodes(const lbug::common::Value& val); +const lbug::common::Value& recursive_rel_get_rels(const lbug::common::Value& val); + +/* FlatTuple */ +const lbug::common::Value& flat_tuple_get_value(const lbug::processor::FlatTuple& flatTuple, + uint32_t index); + +/* Value */ +const std::string& value_get_string(const lbug::common::Value& value); + +template +std::unique_ptr value_get_unique(const lbug::common::Value& value) { + return std::make_unique(value.getValue()); +} + +int64_t value_get_interval_secs(const lbug::common::Value& value); +int32_t value_get_interval_micros(const lbug::common::Value& value); +int32_t value_get_date_days(const lbug::common::Value& value); +int64_t value_get_timestamp_ns(const lbug::common::Value& value); +int64_t value_get_timestamp_ms(const lbug::common::Value& value); +int64_t value_get_timestamp_sec(const lbug::common::Value& value); +int64_t value_get_timestamp_micros(const lbug::common::Value& value); +int64_t value_get_timestamp_tz(const lbug::common::Value& value); +std::array value_get_int128_t(const lbug::common::Value& value); +std::array value_get_internal_id(const lbug::common::Value& value); +uint32_t value_get_children_size(const lbug::common::Value& value); +const lbug::common::Value& value_get_child(const lbug::common::Value& value, uint32_t index); +lbug::common::LogicalTypeID value_get_data_type_id(const lbug::common::Value& value); +const lbug::common::LogicalType& value_get_data_type(const lbug::common::Value& value); +inline lbug::common::PhysicalTypeID value_get_physical_type(const lbug::common::Value& value) { + return value.getDataType().getPhysicalType(); +} +rust::String value_to_string(const lbug::common::Value& val); + +std::unique_ptr create_value_string(lbug::common::LogicalTypeID typ, + const rust::Slice value); +std::unique_ptr create_value_timestamp(const int64_t timestamp); +std::unique_ptr create_value_timestamp_tz(const int64_t timestamp); +std::unique_ptr create_value_timestamp_ns(const int64_t timestamp); +std::unique_ptr create_value_timestamp_ms(const int64_t timestamp); +std::unique_ptr create_value_timestamp_sec(const int64_t timestamp); +inline std::unique_ptr create_value_date(const int32_t date) { + return std::make_unique(lbug::common::date_t(date)); +} +std::unique_ptr create_value_interval(const int32_t months, const int32_t days, + const int64_t micros); +std::unique_ptr create_value_null( + std::unique_ptr typ); +std::unique_ptr 
create_value_int128_t(int64_t high, uint64_t low); +std::unique_ptr create_value_internal_id(uint64_t offset, uint64_t table); + +inline std::unique_ptr create_value_uuid_t(int64_t high, uint64_t low) { + return std::make_unique( + lbug::common::ku_uuid_t{lbug::common::int128_t(low, high)}); +} + +template +std::unique_ptr create_value(const T value) { + return std::make_unique(value); +} +inline std::unique_ptr create_value_decimal(int64_t high, uint64_t low, + uint32_t scale, uint32_t precision) { + auto value = + std::make_unique(lbug::common::LogicalType::DECIMAL(precision, scale), + std::vector>{}); + auto i128 = lbug::common::int128_t(low, high); + lbug::common::TypeUtils::visit( + value->getDataType().getPhysicalType(), + [&](lbug::common::int128_t) { value->val.int128Val = i128; }, + [&](int64_t) { value->val.int64Val = static_cast(i128); }, + [&](int32_t) { value->val.int32Val = static_cast(i128); }, + [&](int16_t) { value->val.int16Val = static_cast(i128); }, + [](auto) { KU_UNREACHABLE; }); + return value; +} + +struct ValueListBuilder { + std::vector> values; + + void insert(std::unique_ptr value) { values.push_back(std::move(value)); } +}; + +std::unique_ptr get_list_value(std::unique_ptr typ, + std::unique_ptr value); +std::unique_ptr create_list(); + +inline std::string_view string_view_from_str(rust::Str s) { + return {s.data(), s.size()}; +} + +inline lbug::storage::storage_version_t get_storage_version() { + return lbug::storage::StorageVersionInfo::getStorageVersion(); +} + +} // namespace lbug_rs diff --git a/graph-wasm/lbug-0.12.2/lbug-src/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/CMakeLists.txt new file mode 100644 index 0000000000..90890a5b51 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/CMakeLists.txt @@ -0,0 +1,454 @@ +cmake_minimum_required(VERSION 3.15) + +project(Lbug VERSION 0.12.2 LANGUAGES CXX C) + +option(SINGLE_THREADED "Single-threaded mode" FALSE) +if(SINGLE_THREADED) + set(__SINGLE_THREADED__ TRUE) + add_compile_definitions(__SINGLE_THREADED__) + message(STATUS "Single-threaded mode is enabled") +else() + message(STATUS "Multi-threaded mode is enabled: CMAKE_BUILD_PARALLEL_LEVEL=$ENV{CMAKE_BUILD_PARALLEL_LEVEL}") + find_package(Threads REQUIRED) +endif() + +set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD_REQUIRED TRUE) +set(CMAKE_CXX_VISIBILITY_PRESET hidden) +set(CMAKE_C_VISIBILITY_PRESET hidden) +set(CMAKE_EXPORT_COMPILE_COMMANDS TRUE) +set(CMAKE_FIND_PACKAGE_RESOLVE_SYMLINKS TRUE) +set(CMAKE_POSITION_INDEPENDENT_CODE ON) +set(CMAKE_VISIBILITY_INLINES_HIDDEN ON) +# On Linux, symbols in executables are not accessible by loaded shared libraries (e.g. via dlopen(3)). However, we need to export public symbols in executables so that extensions can access public symbols. This enables that behaviour. 
+set(CMAKE_ENABLE_EXPORTS TRUE) + +option(ENABLE_WERROR "Treat all warnings as errors" FALSE) +if(ENABLE_WERROR) + if (CMAKE_VERSION VERSION_GREATER "3.24.0" OR CMAKE_VERSION VERSION_EQUAL "3.24.0") + set(CMAKE_COMPILE_WARNING_AS_ERROR TRUE) + elseif (MSVC) + add_compile_options(\WX) + else () + add_compile_options(-Werror) + endif() +endif() + +# Detect OS and architecture, copied from DuckDB +set(OS_NAME "unknown") +set(OS_ARCH "amd64") + +string(REGEX MATCH "(arm64|aarch64)" IS_ARM "${CMAKE_SYSTEM_PROCESSOR}") +if(IS_ARM) + set(OS_ARCH "arm64") +elseif(FORCE_32_BIT) + set(OS_ARCH "i386") +endif() + +if(APPLE) + set(OS_NAME "osx") +endif() +if(WIN32) + set(OS_NAME "windows") +endif() +if(UNIX AND NOT APPLE) + set(OS_NAME "linux") # sorry BSD +endif() + +if(CMAKE_SIZEOF_VOID_P EQUAL 8) + message(STATUS "64-bit architecture detected") + add_compile_definitions(__64BIT__) +elseif(CMAKE_SIZEOF_VOID_P EQUAL 4) + message(STATUS "32-bit architecture detected") + add_compile_definitions(__32BIT__) + set(__32BIT__ TRUE) +endif() + +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE Release) +endif() + +if(DEFINED ENV{PYBIND11_PYTHON_VERSION}) + set(PYBIND11_PYTHON_VERSION $ENV{PYBIND11_PYTHON_VERSION}) +endif() + +if(DEFINED ENV{PYTHON_EXECUTABLE}) + set(PYTHON_EXECUTABLE $ENV{PYTHON_EXECUTABLE}) +endif() + +find_program(CCACHE_PROGRAM ccache) +if (CCACHE_PROGRAM) + set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}") + set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}") + message(STATUS "ccache found and enabled") +else () + find_program(CCACHE_PROGRAM sccache) + if (CCACHE_PROGRAM) + set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}") + set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}") + message(STATUS "sccache found and enabled") + endif () +endif () + +set(INSTALL_LIB_DIR + lib + CACHE PATH "Installation directory for libraries") +set(INSTALL_BIN_DIR + bin + CACHE PATH "Installation directory for executables") +set(INSTALL_INCLUDE_DIR + include + CACHE PATH "Installation directory for header files") +set(INSTALL_CMAKE_DIR + ${DEF_INSTALL_CMAKE_DIR} + CACHE PATH "Installation directory for CMake files") + +option(ENABLE_ADDRESS_SANITIZER "Enable address sanitizer." FALSE) +option(ENABLE_THREAD_SANITIZER "Enable thread sanitizer." FALSE) +option(ENABLE_UBSAN "Enable undefined behavior sanitizer." FALSE) +option(ENABLE_RUNTIME_CHECKS "Enable runtime coherency checks (e.g. asserts)" FALSE) +option(ENABLE_LTO "Enable Link-Time Optimization" FALSE) +option(ENABLE_MALLOC_BUFFER_MANAGER "Enable Buffer manager using malloc. Default option for webassembly" OFF) + +option(LBUG_DEFAULT_REL_STORAGE_DIRECTION "Only store fwd direction in rel tables by default." BOTH) +if(NOT LBUG_DEFAULT_REL_STORAGE_DIRECTION) + set(LBUG_DEFAULT_REL_STORAGE_DIRECTION BOTH) +endif() +set(LBUG_DEFAULT_REL_STORAGE_DIRECTION ${LBUG_DEFAULT_REL_STORAGE_DIRECTION}_REL_STORAGE) +option(LBUG_PAGE_SIZE_LOG2 "Log2 of the page size." 12) +if(NOT LBUG_PAGE_SIZE_LOG2) + set(LBUG_PAGE_SIZE_LOG2 12) +endif() +message(STATUS "LBUG_PAGE_SIZE_LOG2: ${LBUG_PAGE_SIZE_LOG2}") +option(LBUG_VECTOR_CAPACITY_LOG2 "Log2 of the vector capacity." 11) +if(NOT LBUG_VECTOR_CAPACITY_LOG2) + set(LBUG_VECTOR_CAPACITY_LOG2 11) +endif() +message(STATUS "LBUG_VECTOR_CAPACITY_LOG2: ${LBUG_VECTOR_CAPACITY_LOG2}") + +# 64 * 2048 nodes per group +option(LBUG_NODE_GROUP_SIZE_LOG2 "Log2 of the vector capacity." 
17) +if(NOT LBUG_NODE_GROUP_SIZE_LOG2) + set(LBUG_NODE_GROUP_SIZE_LOG2 17) +endif() +message(STATUS "LBUG_NODE_GROUP_SIZE_LOG2: ${LBUG_NODE_GROUP_SIZE_LOG2}") + +option(LBUG_MAX_SEGMENT_SIZE_LOG2 "Log2 of the maximum segment size in bytes." 18) +if(NOT LBUG_MAX_SEGMENT_SIZE_LOG2) + set(LBUG_MAX_SEGMENT_SIZE_LOG2 18) +endif() +message(STATUS "LBUG_MAX_SEGMENT_SIZE_LOG2: ${LBUG_MAX_SEGMENT_SIZE_LOG2}") + +configure_file(${CMAKE_CURRENT_SOURCE_DIR}/cmake/templates/system_config.h.in ${CMAKE_CURRENT_BINARY_DIR}/src/include/common/system_config.h @ONLY) + +include(CheckCXXSymbolExists) +check_cxx_symbol_exists(F_FULLFSYNC "fcntl.h" HAS_FULLFSYNC) +check_cxx_symbol_exists(fdatasync "unistd.h" HAS_FDATASYNC) +if(HAS_FULLFSYNC) + message(STATUS "✓ F_FULLFSYNC will be used on this platform") + add_compile_definitions(HAS_FULLFSYNC) +else() + message(STATUS "✗ F_FULLFSYNC not available") +endif() + +if(HAS_FDATASYNC) + message(STATUS "✓ fdatasync will be used on this platform") + add_compile_definitions(HAS_FDATASYNC) +else() + message(STATUS "✗ fdatasync not available, using fsync fallback") +endif() + +if(MSVC) + # Required for M_PI on Windows + add_compile_definitions(_USE_MATH_DEFINES) + add_compile_definitions(NOMINMAX) + add_compile_definitions(SERD_STATIC) + # This is a workaround for regex oom issue on windows in gtest. + add_compile_definitions(_REGEX_MAX_STACK_COUNT=0) + add_compile_definitions(_REGEX_MAX_COMPLEXITY_COUNT=0) + # Disable constexpr mutex constructor to avoid compatibility issues with + # older versions of the MSVC runtime library + # See: https://github.com/microsoft/STL/wiki/Changelog#vs-2022-1710 + add_compile_definitions(_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR) + # TODO (bmwinger): Figure out if this can be set automatically by cmake, + # or at least better integrated with user-specified options + # For now, hardcode _AMD64_ + # CMAKE_GENERATOR_PLATFORM can be used for visual studio builds, but not for ninja + add_compile_definitions(_AMD64_) + # Non-english windows system may use other encodings other than utf-8 (e.g. Chinese use GBK). + add_compile_options("/utf-8") + # Enables support for custom hardware exception handling + add_compile_options("/EHa") + # Reduces the size of the static library by roughly 1/2 + add_compile_options("/Zc:inline") + # Disable type conversion warnings + add_compile_options(/wd4244 /wd4267) + # Remove the default to avoid warnings + STRING(REPLACE "/EHsc" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") + STRING(REPLACE "/EHs" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") + # Store all libraries and binaries in the same directory so that lbug_shared.dll is found at runtime + set(LIBRARY_OUTPUT_PATH "${CMAKE_BINARY_DIR}/src") + set(EXECUTABLE_OUTPUT_PATH "${CMAKE_BINARY_DIR}/src") + # This is a workaround for regex stackoverflow issue on windows in gtest. 
+ set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} /STACK:8388608") + + string(REGEX REPLACE "/W[3|4]" "/w" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") + add_compile_options($<$:/W0>) +else() + add_compile_options(-Wall -Wextra) + # Disable warnings for unknown pragmas, which is used by several third-party libraries + add_compile_options(-Wno-unknown-pragmas) +endif() + +if(${BUILD_WASM}) + if(NOT __SINGLE_THREADED__) + add_compile_options(-pthread) + add_link_options(-pthread) + add_link_options(-sPTHREAD_POOL_SIZE=8) + endif() + add_compile_options(-s DISABLE_EXCEPTION_CATCHING=0) + add_link_options(-sSTACK_SIZE=4MB) + add_link_options(-sASSERTIONS=1) + add_link_options(-lembind) + add_link_options(-sWASM_BIGINT) + + if(BUILD_TESTS OR BUILD_EXTENSION_TESTS) + add_link_options(-sINITIAL_MEMORY=3892MB) + add_link_options(-sNODERAWFS=1) + elseif(WASM_NODEFS) + add_link_options(-sNODERAWFS=1) + add_link_options(-sALLOW_MEMORY_GROWTH=1) + add_link_options(-sMODULARIZE=1) + add_link_options(-sEXPORTED_RUNTIME_METHODS=FS,wasmMemory) + add_link_options(-sEXPORT_NAME=lbug) + add_link_options(-sMAXIMUM_MEMORY=4GB) + else() + add_link_options(-sSINGLE_FILE=1) + add_link_options(-sALLOW_MEMORY_GROWTH=1) + add_link_options(-sMODULARIZE=1) + add_link_options(-sEXPORTED_RUNTIME_METHODS=FS,wasmMemory) + add_link_options(-lidbfs.js) + add_link_options(-lworkerfs.js) + add_link_options(-sEXPORT_NAME=lbug) + add_link_options(-sMAXIMUM_MEMORY=4GB) + endif() + set(__WASM__ TRUE) + add_compile_options(-fexceptions) + add_link_options(-s DISABLE_EXCEPTION_CATCHING=0) + add_link_options(-fexceptions) + add_compile_definitions(__WASM__) + set(ENABLE_MALLOC_BUFFER_MANAGER ON) +endif() + +if(${BUILD_SWIFT}) + add_compile_definitions(__SWIFT__) + set(ENABLE_MALLOC_BUFFER_MANAGER ON) +endif() + +if (${ENABLE_MALLOC_BUFFER_MANAGER}) + add_compile_definitions(BM_MALLOC) +endif() + +if(ANDROID_ABI) + message(STATUS "Android ABI detected: ${ANDROID_ABI}") + add_compile_definitions(__ANDROID__) + set(__ANDROID__ TRUE) +endif() + +if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + add_compile_options(-Wno-restrict) # no restrict until https://gcc.gnu.org/bugzilla/show_bug.cgi?id=105651 is fixed +endif() + +if(${ENABLE_THREAD_SANITIZER} AND (NOT __SINGLE_THREADED__)) + if(MSVC) + message(FATAL_ERROR "Thread sanitizer is not supported on MSVC") + else() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=thread -fno-omit-frame-pointer") + endif() +endif() +if(${ENABLE_ADDRESS_SANITIZER}) + if(MSVC) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /fsanitize=address") + else() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address -fno-omit-frame-pointer") + endif() +endif() +if(${ENABLE_UBSAN}) + if(MSVC) + message(FATAL_ERROR "Undefined behavior sanitizer is not supported on MSVC") + else() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=undefined -fno-omit-frame-pointer") + endif() +endif() + +if(${ENABLE_RUNTIME_CHECKS}) + add_compile_definitions(LBUG_RUNTIME_CHECKS) +endif() + +if (${ENABLE_DESER_DEBUG}) + add_compile_definitions(LBUG_DESER_DEBUG) +endif() + +if(${ENABLE_LTO}) + set(CMAKE_INTERPROCEDURAL_OPTIMIZATION TRUE) +endif() + +option(AUTO_UPDATE_GRAMMAR "Automatically regenerate C++ grammar files on change." TRUE) +option(BUILD_BENCHMARK "Build benchmarks." FALSE) +option(BUILD_EXTENSIONS "Semicolon-separated list of extensions to build." "") +option(BUILD_EXAMPLES "Build examples." FALSE) +option(BUILD_JAVA "Build Java API." FALSE) +option(BUILD_NODEJS "Build NodeJS API." 
FALSE) +option(BUILD_PYTHON "Build Python API." FALSE) +option(BUILD_SHELL "Build Interactive Shell" TRUE) +option(BUILD_SINGLE_FILE_HEADER "Build single file header. Requires Python >= 3.9." TRUE) +option(BUILD_TESTS "Build C++ tests." FALSE) +option(BUILD_EXTENSION_TESTS "Build C++ extension tests." FALSE) +option(BUILD_LBUG "Build Lbug." TRUE) +option(ENABLE_BACKTRACES "Enable backtrace printing for exceptions and segfaults" FALSE) +option(USE_STD_FORMAT "Use std::format instead of a custom formatter." FALSE) +option(PREFER_SYSTEM_DEPS "Only download certain deps if not found on the system" TRUE) + +option(BUILD_LCOV "Build coverage report." FALSE) +if(${BUILD_LCOV}) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fprofile-arcs -ftest-coverage") +endif() + +if (ENABLE_BACKTRACES) + set(DOWNLOAD_CPPTRACE TRUE) + if(${PREFER_SYSTEM_DEPS}) + find_package(cpptrace QUIET) + if(cpptrace_FOUND) + message(STATUS "Using system cpptrace") + set(DOWNLOAD_CPPTRACE FALSE) + endif() + endif() + if(${DOWNLOAD_CPPTRACE}) + message(STATUS "Fetching cpptrace from GitHub...") + include(FetchContent) + FetchContent_Declare( + cpptrace + GIT_REPOSITORY https://github.com/jeremy-rifkin/cpptrace.git + GIT_TAG v0.8.3 + GIT_SHALLOW TRUE + ) + FetchContent_MakeAvailable(cpptrace) + endif() + add_compile_definitions(LBUG_BACKTRACE) +endif() + +if (USE_STD_FORMAT) + add_compile_definitions(USE_STD_FORMAT) +endif() + +function(add_lbug_test TEST_NAME) + set(SRCS ${ARGN}) + add_executable(${TEST_NAME} ${SRCS}) + target_link_libraries(${TEST_NAME} PRIVATE test_helper test_runner graph_test) + if (ENABLE_BACKTRACES) + target_link_libraries(${TEST_NAME} PRIVATE register_backtrace_signal_handler) + endif() + target_include_directories(${TEST_NAME} PRIVATE ${PROJECT_SOURCE_DIR}/test/include) + include(GoogleTest) + + if (TEST_NAME STREQUAL "e2e_test") + gtest_discover_tests(${TEST_NAME} + DISCOVERY_TIMEOUT 600 + DISCOVERY_MODE PRE_TEST + TEST_PREFIX e2e_test_ + ) + else() + gtest_discover_tests(${TEST_NAME} + DISCOVERY_TIMEOUT 600 + DISCOVERY_MODE PRE_TEST + ) + endif() +endfunction() + +function(add_lbug_api_test TEST_NAME) + set(SRCS ${ARGN}) + add_executable(${TEST_NAME} ${SRCS}) + target_link_libraries(${TEST_NAME} PRIVATE api_graph_test api_test_helper) + if (ENABLE_BACKTRACES) + target_link_libraries(${TEST_NAME} PRIVATE register_backtrace_signal_handler) + endif() + target_include_directories(${TEST_NAME} PRIVATE ${PROJECT_SOURCE_DIR}/test/include) + include(GoogleTest) + gtest_discover_tests(${TEST_NAME}) +endfunction() + +# Windows doesn't support dynamic lookup, so we have to link extensions against lbug. 
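+# (On Linux/macOS an extension shared library can leave lbug symbols undefined and
+# resolve them from the host process at load time; a Windows DLL must resolve every
+# import at link time, so lbug itself is forced on whenever extensions are requested.)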
+if (MSVC AND (NOT BUILD_EXTENSIONS EQUAL "")) + set(BUILD_LBUG TRUE) +endif () + +include_directories(third_party/antlr4_cypher/include) +include_directories(third_party/antlr4_runtime/src) +include_directories(third_party/brotli/c/include) +include_directories(third_party/fast_float/include) +include_directories(third_party/mbedtls/include) +include_directories(third_party/parquet) +include_directories(third_party/snappy) +include_directories(third_party/thrift) +include_directories(third_party/miniz) +include_directories(third_party/nlohmann_json) +include_directories(third_party/pybind11/include) +include_directories(third_party/pyparse) +include_directories(third_party/re2/include) +include_directories(third_party/alp/include) +if (${BUILD_TESTS} OR ${BUILD_EXTENSION_TESTS}) + include_directories(third_party/spdlog) +elseif (${BUILD_BENCHMARK}) + include_directories(third_party/spdlog) +endif () +include_directories(third_party/utf8proc/include) +include_directories(third_party/zstd/include) +include_directories(third_party/httplib) +include_directories(third_party/pcg) +include_directories(third_party/lz4) +include_directories(third_party/roaring_bitmap) +# Use SYSTEM to suppress warnings from simsimd +include_directories(SYSTEM third_party/simsimd/include) + +add_subdirectory(third_party) + +add_definitions(-DLBUG_ROOT_DIRECTORY="${PROJECT_SOURCE_DIR}") +add_definitions(-DLBUG_CMAKE_VERSION="${CMAKE_PROJECT_VERSION}") +add_definitions(-DLBUG_EXTENSION_VERSION="0.12.0") + +if(BUILD_LBUG) +include_directories( + src/include + ${CMAKE_CURRENT_BINARY_DIR}/src/include +) +endif() + +if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/extension/CMakeLists.txt") + add_subdirectory(extension) +endif () + +if(BUILD_LBUG) + +add_subdirectory(src) + +# Link extensions which require static linking. +foreach(ext IN LISTS STATICALLY_LINKED_EXTENSIONS) + if (${BUILD_EXTENSION_TESTS}) + add_compile_definitions(__STATIC_LINK_EXTENSION_TEST__) + endif () + target_link_libraries(lbug PRIVATE "lbug_${ext}_static_extension") + target_link_libraries(lbug_shared PRIVATE "lbug_${ext}_static_extension") +endforeach() + +if (${BUILD_TESTS} OR ${BUILD_EXTENSION_TESTS}) + add_subdirectory(test) +elseif (${BUILD_BENCHMARK}) + add_subdirectory(test/test_helper) +endif () +add_subdirectory(tools) +endif () + +if (${BUILD_EXAMPLES}) + add_subdirectory(examples/c) + add_subdirectory(examples/cpp) +endif() diff --git a/graph-wasm/lbug-0.12.2/lbug-src/README.md b/graph-wasm/lbug-0.12.2/lbug-src/README.md new file mode 100644 index 0000000000..8ffe98d7c9 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/README.md @@ -0,0 +1,73 @@ +
+<!-- Centered header: Ladybug logo, plus badges for GitHub Actions build status, Discord, and Twitter -->
+
+# Ladybug
+Ladybug is an embedded graph database built for query speed and scalability. Ladybug is
+optimized for handling complex analytical workloads on very large databases and provides
+a set of retrieval features, such as full-text search and vector indices. Our core
+feature set includes:
+
+- Flexible Property Graph Data Model and Cypher query language
+- Embeddable, serverless integration into applications
+- Native full-text search and vector indexing
+- Columnar disk-based storage
+- Columnar sparse row-based (CSR) adjacency list/join indices
+- Vectorized and factorized query processor
+- Novel and very fast join algorithms
+- Multi-core query parallelism
+- Serializable ACID transactions
+- Wasm (WebAssembly) bindings for fast, secure execution in the browser
+
+Ladybug is being developed by [LadybugDB Developers](https://github.com/LadybugDB) and
+is available under a permissive license. Try it out and help us make it better! We
+welcome your feedback and feature requests.
+
+The database was formerly known as [Kuzu](https://github.com/kuzudb/kuzu).
+
+## Installation
+
+> [!WARNING]
+> Many of these binary installation methods are not functional yet. We need to work
+> through package names, availability, and convention issues. For now, build from source.
+
+| Language | Installation |
+| -------- | ------------------------------------------------------------------------ |
+| Python   | `pip install real_ladybug` |
+| NodeJS   | `npm install lbug` |
+| Rust     | `cargo add lbug` |
+| Go       | `go get github.com/lbugdb/go-lbug` |
+| Swift    | [lbug-swift](https://github.com/lbugdb/lbug-swift) |
+| Java     | [Maven Central](https://central.sonatype.com/artifact/com.ladybugdb/lbug) |
+| C/C++    | [precompiled binaries](https://github.com/LadybugDB/ladybug/releases/latest) |
+| CLI      | [precompiled binaries](https://github.com/LadybugDB/ladybug/releases/latest) |
+
+To learn more about installation, see our [Installation](https://docs.ladybugdb.com/installation) page.
+
+## Getting Started
+
+Refer to our [Getting Started](https://docs.ladybugdb.com/get-started/) page for your first example.
+
+## Build from Source
+
+You can build from source using the instructions provided in the [developer guide](https://docs.ladybugdb.com/developer-guide/).
+
+## Contributing
+We welcome contributions to Ladybug. If you are interested in contributing, please read our [Contributing Guide](CONTRIBUTING.md).
+
+## License
+By contributing to Ladybug, you agree that your contributions will be licensed under the [MIT License](LICENSE).
+
+## Contact
+You can contact us at [social@ladybugdb.com](mailto:social@ladybugdb.com) or [join our Discord community](https://discord.com/invite/hXyHmvW3Vy).
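+
+## Example (illustrative)
+
+A minimal sketch of embedded usage from C++, assuming the API mirrors Kuzu's
+`Database`/`Connection` classes under the `lbug` namespace and that the single-file
+header produced by `BUILD_SINGLE_FILE_HEADER` is named `lbug.hpp`. These class and
+header names are not confirmed by this snapshot, so check the generated headers before
+relying on them; the Cypher statements themselves follow the grammar shipped in this tree.
+
+```cpp
+#include <iostream>
+
+#include "lbug.hpp" // single-file header (name assumed)
+
+int main() {
+    // Open (or create) an on-disk database and a connection to it.
+    lbug::main::Database db("./demo-db");
+    lbug::main::Connection conn(&db);
+
+    // DDL, data, and a query, all through Cypher.
+    conn.query("CREATE NODE TABLE Person(name STRING, age INT64, PRIMARY KEY(name));");
+    conn.query("CREATE (:Person {name: 'Alice', age: 30});");
+    auto result = conn.query("MATCH (p:Person) RETURN p.name, p.age;");
+
+    // Iterate over the result tuples and print them.
+    while (result->hasNext()) {
+        std::cout << result->getNext()->toString() << "\n";
+    }
+    return 0;
+}
+```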
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/cmake/templates/system_config.h.in b/graph-wasm/lbug-0.12.2/lbug-src/cmake/templates/system_config.h.in
new file mode 100644
index 0000000000..971c5874c2
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/cmake/templates/system_config.h.in
@@ -0,0 +1,75 @@
+/*
+ * This is a template header used for generating the header 'system_config.h'
+ * Any value in the format @VALUE_NAME@ can be substituted with a value passed into CMakeLists.txt
+ * See https://cmake.org/cmake/help/latest/command/configure_file.html for more details
+ */
+
+#pragma once
+
+#include <algorithm>
+#include <cstdint>
+
+#include "common/enums/extend_direction.h"
+
+#define BOTH_REL_STORAGE 0
+#define FWD_REL_STORAGE 1
+#define BWD_REL_STORAGE 2
+
+namespace lbug {
+namespace common {
+
+#define VECTOR_CAPACITY_LOG_2 @LBUG_VECTOR_CAPACITY_LOG2@
+#if VECTOR_CAPACITY_LOG_2 > 12
+#error "Vector capacity log2 should be less than or equal to 12"
+#endif
+constexpr uint64_t DEFAULT_VECTOR_CAPACITY = static_cast<uint64_t>(1) << VECTOR_CAPACITY_LOG_2;
+
+// Currently the system supports files with two different page sizes, which we refer to as
+// PAGE_SIZE and TEMP_PAGE_SIZE. PAGE_SIZE is the default size of the page which is the
+// unit of read/write to the database files.
+static constexpr uint64_t PAGE_SIZE_LOG2 = @LBUG_PAGE_SIZE_LOG2@; // Default to 4KB.
+static constexpr uint64_t LBUG_PAGE_SIZE = static_cast<uint64_t>(1) << PAGE_SIZE_LOG2;
+// Page size for files with large pages, e.g., temporary files that are used by operators that
+// may require large amounts of memory.
+static constexpr uint64_t TEMP_PAGE_SIZE_LOG2 = 18;
+static const uint64_t TEMP_PAGE_SIZE = static_cast<uint64_t>(1) << TEMP_PAGE_SIZE_LOG2;
+
+#define DEFAULT_REL_STORAGE_DIRECTION @LBUG_DEFAULT_REL_STORAGE_DIRECTION@
+#if DEFAULT_REL_STORAGE_DIRECTION == FWD_REL_STORAGE
+static constexpr ExtendDirection DEFAULT_EXTEND_DIRECTION = ExtendDirection::FWD;
+#elif DEFAULT_REL_STORAGE_DIRECTION == BWD_REL_STORAGE
+static constexpr ExtendDirection DEFAULT_EXTEND_DIRECTION = ExtendDirection::BWD;
+#else
+static constexpr ExtendDirection DEFAULT_EXTEND_DIRECTION = ExtendDirection::BOTH;
+#endif
+
+struct StorageConfig {
+    static constexpr uint64_t NODE_GROUP_SIZE_LOG2 = @LBUG_NODE_GROUP_SIZE_LOG2@;
+    static constexpr uint64_t NODE_GROUP_SIZE = static_cast<uint64_t>(1) << NODE_GROUP_SIZE_LOG2;
+    // The number of CSR lists in a leaf region.
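+    // With the default NODE_GROUP_SIZE_LOG2 of 17 this is min(10, 16) = 10, i.e. 1024
+    // CSR lists per leaf region; the min() only matters for node groups smaller than 2^11.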
+    static constexpr uint64_t CSR_LEAF_REGION_SIZE_LOG2 =
+        std::min(static_cast<uint64_t>(10), NODE_GROUP_SIZE_LOG2 - 1);
+    static constexpr uint64_t CSR_LEAF_REGION_SIZE = static_cast<uint64_t>(1)
+                                                     << CSR_LEAF_REGION_SIZE_LOG2;
+    static constexpr uint64_t CHUNKED_NODE_GROUP_CAPACITY =
+        std::min(static_cast<uint64_t>(2048), NODE_GROUP_SIZE);
+
+    // Maximum size for a segment in bytes
+    static constexpr uint64_t MAX_SEGMENT_SIZE_LOG2 = @LBUG_MAX_SEGMENT_SIZE_LOG2@;
+    static constexpr uint64_t MAX_SEGMENT_SIZE = 1 << MAX_SEGMENT_SIZE_LOG2;
+};
+
+struct OrderByConfig {
+    static constexpr uint64_t MIN_SIZE_TO_REDUCE = common::DEFAULT_VECTOR_CAPACITY * 5;
+};
+
+struct CopyConfig {
+    static constexpr uint64_t PANDAS_PARTITION_COUNT = 50 * DEFAULT_VECTOR_CAPACITY;
+};
+
+} // namespace common
+} // namespace lbug
+
+#undef BOTH_REL_STORAGE
+#undef FWD_REL_STORAGE
+#undef BWD_REL_STORAGE
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/CMakeLists.txt
new file mode 100644
index 0000000000..2250ce477f
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/src/CMakeLists.txt
@@ -0,0 +1,79 @@
+include_directories(${CMAKE_CURRENT_BINARY_DIR})
+
+# Have to pass this down to every subdirectory, which actually adds the files.
+# This doesn't affect parent directories.
+add_compile_definitions(LBUG_EXPORTS)
+add_compile_definitions(ANTLR4CPP_STATIC)
+add_subdirectory(binder)
+add_subdirectory(c_api)
+add_subdirectory(catalog)
+add_subdirectory(common)
+add_subdirectory(expression_evaluator)
+add_subdirectory(function)
+add_subdirectory(graph)
+add_subdirectory(main)
+add_subdirectory(optimizer)
+add_subdirectory(parser)
+add_subdirectory(planner)
+add_subdirectory(processor)
+add_subdirectory(storage)
+add_subdirectory(transaction)
+add_subdirectory(extension)
+
+add_library(lbug STATIC ${ALL_OBJECT_FILES})
+add_library(lbug_shared SHARED ${ALL_OBJECT_FILES})
+
+set(LBUG_LIBRARIES antlr4_cypher antlr4_runtime brotlidec brotlicommon fast_float utf8proc re2 fastpfor parquet snappy thrift yyjson zstd miniz mbedtls lz4 roaring_bitmap simsimd)
+if (NOT __SINGLE_THREADED__)
+    set(LBUG_LIBRARIES ${LBUG_LIBRARIES} Threads::Threads)
+endif()
+if(NOT WIN32)
+    set(LBUG_LIBRARIES dl ${LBUG_LIBRARIES})
+endif()
+# Seems to be needed for clang on linux only
+# for compiling std::atomic::compare_exchange_weak
+if ((NOT APPLE AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang") AND NOT __WASM__ AND NOT __SINGLE_THREADED__)
+    set(LBUG_LIBRARIES atomic ${LBUG_LIBRARIES})
+endif()
+if (ENABLE_BACKTRACES)
+    set(LBUG_LIBRARIES ${LBUG_LIBRARIES} cpptrace::cpptrace)
+endif()
+target_link_libraries(lbug PUBLIC ${LBUG_LIBRARIES})
+target_link_libraries(lbug_shared PUBLIC ${LBUG_LIBRARIES})
+unset(LBUG_LIBRARIES)
+
+set(LBUG_INCLUDES $ $ ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/include/c_api ${CMAKE_CURRENT_BINARY_DIR}/../src/include)
+target_include_directories(lbug PUBLIC ${LBUG_INCLUDES})
+target_include_directories(lbug_shared PUBLIC ${LBUG_INCLUDES})
+unset(LBUG_INCLUDES)
+
+if(WIN32)
+    # Anything linking against the static library must not use dllimport.
+    target_compile_definitions(lbug INTERFACE LBUG_STATIC_DEFINE)
+endif()
+
+if(NOT WIN32)
+    set_target_properties(lbug_shared PROPERTIES OUTPUT_NAME lbug)
+endif()
+
+install(TARGETS lbug lbug_shared)
+
+if(${BUILD_SINGLE_FILE_HEADER})
+    # Create a command to generate lbug.hpp, and then create a target that is
+    # always built that depends on it.
This allows our generator to detect when + # exactly to build lbug.hpp, while still building the target by default. + find_package(Python3 3.9...4 REQUIRED) + add_custom_command( + OUTPUT lbug.hpp + COMMAND + ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/scripts/collect-single-file-header.py ${CMAKE_CURRENT_BINARY_DIR}/.. + DEPENDS + ${PROJECT_SOURCE_DIR}/scripts/collect-single-file-header.py lbug_shared) + add_custom_target(single_file_header ALL DEPENDS lbug.hpp) +endif() + +install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/include/c_api/lbug.h TYPE INCLUDE) + +if(${BUILD_SINGLE_FILE_HEADER}) + install(FILES ${CMAKE_CURRENT_BINARY_DIR}/lbug.hpp TYPE INCLUDE) +endif() diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/antlr4/Cypher.g4 b/graph-wasm/lbug-0.12.2/lbug-src/src/antlr4/Cypher.g4 new file mode 100644 index 0000000000..e409515c72 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/antlr4/Cypher.g4 @@ -0,0 +1,917 @@ + +ku_Statements + : oC_Cypher ( SP? ';' SP? oC_Cypher )* SP? EOF ; + +oC_Cypher + : oC_AnyCypherOption? SP? ( oC_Statement ) ( SP? ';' )?; + +oC_Statement + : oC_Query + | kU_CreateUser + | kU_CreateRole + | kU_CreateNodeTable + | kU_CreateRelTable + | kU_CreateSequence + | kU_CreateType + | kU_Drop + | kU_AlterTable + | kU_CopyFrom + | kU_CopyFromByColumn + | kU_CopyTO + | kU_StandaloneCall + | kU_CreateMacro + | kU_CommentOn + | kU_Transaction + | kU_Extension + | kU_ExportDatabase + | kU_ImportDatabase + | kU_AttachDatabase + | kU_DetachDatabase + | kU_UseDatabase; + +kU_CopyFrom + : COPY SP oC_SchemaName kU_ColumnNames? SP FROM SP kU_ScanSource ( SP? '(' SP? kU_Options SP? ')' )? ; + +kU_ColumnNames + : SP? '(' SP? (oC_SchemaName ( SP? ',' SP? oC_SchemaName )* SP?)? ')'; + +kU_ScanSource + : kU_FilePaths + | '(' SP? oC_Query SP? ')' + | oC_Parameter + | oC_Variable + | oC_Variable '.' SP? oC_SchemaName + | oC_FunctionInvocation ; + +kU_CopyFromByColumn + : COPY SP oC_SchemaName SP FROM SP '(' SP? StringLiteral ( SP? ',' SP? StringLiteral )* ')' SP BY SP COLUMN ; + +kU_CopyTO + : COPY SP '(' SP? oC_Query SP? ')' SP TO SP StringLiteral ( SP? '(' SP? kU_Options SP? ')' )? ; + +kU_ExportDatabase + : EXPORT SP DATABASE SP StringLiteral ( SP? '(' SP? kU_Options SP? ')' )? ; + +kU_ImportDatabase + : IMPORT SP DATABASE SP StringLiteral; + +kU_AttachDatabase + : ATTACH SP StringLiteral (SP AS SP oC_SchemaName)? SP '(' SP? DBTYPE SP oC_SymbolicName (SP? ',' SP? kU_Options)? SP? ')' ; + +kU_Option + : oC_SymbolicName (SP? '=' SP? | SP*) oC_Literal | oC_SymbolicName; + +kU_Options + : kU_Option ( SP? ',' SP? kU_Option )* ; + +kU_DetachDatabase + : DETACH SP oC_SchemaName; + +kU_UseDatabase + : USE SP oC_SchemaName; + +kU_StandaloneCall + : CALL SP oC_SymbolicName SP? '=' SP? oC_Expression + | CALL SP oC_FunctionInvocation; + +kU_CommentOn + : COMMENT SP ON SP TABLE SP oC_SchemaName SP IS SP StringLiteral ; + +kU_CreateMacro + : CREATE SP MACRO SP oC_FunctionName SP? '(' SP? kU_PositionalArgs? SP? kU_DefaultArg? ( SP? ',' SP? kU_DefaultArg )* SP? ')' SP AS SP oC_Expression ; + +kU_PositionalArgs + : oC_SymbolicName ( SP? ',' SP? oC_SymbolicName )* ; + +kU_DefaultArg + : oC_SymbolicName SP? ':' '=' SP? oC_Literal ; + +kU_FilePaths + : '[' SP? StringLiteral ( SP? ',' SP? StringLiteral )* ']' + | StringLiteral + | GLOB SP? '(' SP? StringLiteral SP? ')' ; + +kU_IfNotExists + : IF SP NOT SP EXISTS ; + +kU_CreateNodeTable + : CREATE SP NODE SP TABLE SP (kU_IfNotExists SP)? oC_SchemaName ( SP? '(' SP? kU_PropertyDefinitions SP? ( ',' SP? kU_CreateNodeConstraint )? SP? 
')' | SP AS SP oC_Query ) ; + +kU_CreateRelTable + : CREATE SP REL SP TABLE ( SP GROUP )? ( SP kU_IfNotExists )? SP oC_SchemaName + SP? '(' SP? + kU_FromToConnections SP? ( + ( ',' SP? kU_PropertyDefinitions SP? )? + ( ',' SP? oC_SymbolicName SP? )? // Constraints + ')' + | ')' SP AS SP oC_Query ) + ( SP WITH SP? '(' SP? kU_Options SP? ')')? ; + +kU_FromToConnections + : kU_FromToConnection ( SP? ',' SP? kU_FromToConnection )* ; + +kU_FromToConnection + : FROM SP oC_SchemaName SP TO SP oC_SchemaName ; + +kU_CreateSequence + : CREATE SP SEQUENCE SP (kU_IfNotExists SP)? oC_SchemaName (SP kU_SequenceOptions)* ; + +kU_CreateType + : CREATE SP TYPE SP oC_SchemaName SP AS SP kU_DataType SP? ; + +kU_SequenceOptions + : kU_IncrementBy + | kU_MinValue + | kU_MaxValue + | kU_StartWith + | kU_Cycle; + +kU_WithPasswd + : SP WITH SP PASSWORD SP StringLiteral ; + +kU_CreateUser + : CREATE SP USER SP (kU_IfNotExists SP)? oC_Variable kU_WithPasswd? ; + +kU_CreateRole + : CREATE SP ROLE SP (kU_IfNotExists SP)? oC_Variable ; + +kU_IncrementBy : INCREMENT SP ( BY SP )? MINUS? oC_IntegerLiteral ; + +kU_MinValue : (NO SP MINVALUE) | (MINVALUE SP MINUS? oC_IntegerLiteral) ; + +kU_MaxValue : (NO SP MAXVALUE) | (MAXVALUE SP MINUS? oC_IntegerLiteral) ; + +kU_StartWith : START SP ( WITH SP )? MINUS? oC_IntegerLiteral ; + +kU_Cycle : (NO SP)? CYCLE ; + +kU_IfExists + : IF SP EXISTS ; + +kU_Drop + : DROP SP (TABLE | SEQUENCE | MACRO) SP (kU_IfExists SP)? oC_SchemaName ; + +kU_AlterTable + : ALTER SP TABLE SP oC_SchemaName SP kU_AlterOptions ; + +kU_AlterOptions + : kU_AddProperty + | kU_DropProperty + | kU_RenameTable + | kU_RenameProperty + | kU_AddFromToConnection + | kU_DropFromToConnection; + +kU_AddProperty + : ADD SP (kU_IfNotExists SP)? oC_PropertyKeyName SP kU_DataType ( SP kU_Default )? ; + +kU_Default + : DEFAULT SP oC_Expression ; + +kU_DropProperty + : DROP SP (kU_IfExists SP)? oC_PropertyKeyName ; + +kU_RenameTable + : RENAME SP TO SP oC_SchemaName ; + +kU_RenameProperty + : RENAME SP oC_PropertyKeyName SP TO SP oC_PropertyKeyName ; + +kU_AddFromToConnection + : ADD SP (kU_IfNotExists SP)? kU_FromToConnection ; + +kU_DropFromToConnection + : DROP SP (kU_IfExists SP)? kU_FromToConnection ; + +kU_ColumnDefinitions: kU_ColumnDefinition ( SP? ',' SP? kU_ColumnDefinition )* ; + +kU_ColumnDefinition : oC_PropertyKeyName SP kU_DataType ; + +kU_PropertyDefinitions : kU_PropertyDefinition ( SP? ',' SP? kU_PropertyDefinition )* ; + +kU_PropertyDefinition : kU_ColumnDefinition ( SP kU_Default )? ( SP PRIMARY SP KEY)?; + +kU_CreateNodeConstraint : PRIMARY SP KEY SP? '(' SP? oC_PropertyKeyName SP? ')' ; + +DECIMAL: ( 'D' | 'd' ) ( 'E' | 'e' ) ( 'C' | 'c' ) ( 'I' | 'i' ) ( 'M' | 'm' ) ( 'A' | 'a' ) ( 'L' | 'l' ) ; + +kU_UnionType + : UNION SP? '(' SP? kU_ColumnDefinitions SP? ')' ; + +kU_StructType + : STRUCT SP? '(' SP? kU_ColumnDefinitions SP? ')' ; + +kU_MapType + : MAP SP? '(' SP? kU_DataType SP? ',' SP? kU_DataType SP? ')' ; + +kU_DecimalType + : DECIMAL SP? '(' SP? oC_IntegerLiteral SP? ',' SP? oC_IntegerLiteral SP? ')' ; + +kU_DataType + : oC_SymbolicName + | kU_DataType kU_ListIdentifiers + | kU_UnionType + | kU_StructType + | kU_MapType + | kU_DecimalType ; + +kU_ListIdentifiers : kU_ListIdentifier ( kU_ListIdentifier )* ; + +kU_ListIdentifier : '[' oC_IntegerLiteral? ']' ; + +oC_AnyCypherOption + : oC_Explain + | oC_Profile ; + +oC_Explain + : EXPLAIN (SP LOGICAL)? 
; + +oC_Profile + : PROFILE ; + +kU_Transaction + : BEGIN SP TRANSACTION + | BEGIN SP TRANSACTION SP READ SP ONLY + | COMMIT + | ROLLBACK + | CHECKPOINT; + +kU_Extension + : kU_LoadExtension + | kU_InstallExtension + | kU_UninstallExtension + | kU_UpdateExtension ; + +kU_LoadExtension + : LOAD SP (EXTENSION SP)? ( StringLiteral | oC_Variable ) ; + +kU_InstallExtension + : (FORCE SP)? INSTALL SP oC_Variable (SP FROM SP StringLiteral)?; + +kU_UninstallExtension + : UNINSTALL SP oC_Variable; + +kU_UpdateExtension + : UPDATE SP oC_Variable; + +oC_Query + : oC_RegularQuery ; + +oC_RegularQuery + : oC_SingleQuery ( SP? oC_Union )* + | (oC_Return SP? )+ oC_SingleQuery { notifyReturnNotAtEnd($ctx->start); } + ; + +oC_Union + : ( UNION SP ALL SP? oC_SingleQuery ) + | ( UNION SP? oC_SingleQuery ) ; + +oC_SingleQuery + : oC_SinglePartQuery + | oC_MultiPartQuery + ; + +oC_SinglePartQuery + : ( oC_ReadingClause SP? )* oC_Return + | ( ( oC_ReadingClause SP? )* oC_UpdatingClause ( SP? oC_UpdatingClause )* ( SP? oC_Return )? ) + ; + +oC_MultiPartQuery + : ( kU_QueryPart SP? )+ oC_SinglePartQuery; + +kU_QueryPart + : (oC_ReadingClause SP? )* ( oC_UpdatingClause SP? )* oC_With ; + +oC_UpdatingClause + : oC_Create + | oC_Merge + | oC_Set + | oC_Delete + ; + +oC_ReadingClause + : oC_Match + | oC_Unwind + | kU_InQueryCall + | kU_LoadFrom + ; + +kU_LoadFrom + : LOAD ( SP WITH SP HEADERS SP? '(' SP? kU_ColumnDefinitions SP? ')' )? SP FROM SP kU_ScanSource (SP? '(' SP? kU_Options SP? ')')? (SP? oC_Where)? ; + + +oC_YieldItem + : ( oC_Variable SP AS SP )? oC_Variable ; + +oC_YieldItems + : oC_YieldItem ( SP? ',' SP? oC_YieldItem )* ; + +kU_InQueryCall + : CALL SP oC_FunctionInvocation (SP? oC_Where)? ( SP? YIELD SP oC_YieldItems )? ; + +oC_Match + : ( OPTIONAL SP )? MATCH SP? oC_Pattern ( SP oC_Where )? ( SP kU_Hint )? ; + +kU_Hint + : HINT SP kU_JoinNode; + +kU_JoinNode + : kU_JoinNode SP JOIN SP kU_JoinNode + | kU_JoinNode ( SP MULTI_JOIN SP oC_SchemaName)+ + | '(' SP? kU_JoinNode SP? ')' + | oC_SchemaName ; + +oC_Unwind : UNWIND SP? oC_Expression SP AS SP oC_Variable ; + +oC_Create + : CREATE SP? oC_Pattern ; + +// For unknown reason, openCypher use oC_PatternPart instead of oC_Pattern. There should be no difference in terms of planning. +// So we choose to be consistent with oC_Create and use oC_Pattern instead. +oC_Merge : MERGE SP? oC_Pattern ( SP oC_MergeAction )* ; + +oC_MergeAction + : ( ON SP MATCH SP oC_Set ) + | ( ON SP CREATE SP oC_Set ) + ; + +oC_Set + : SET SP? oC_SetItem ( SP? ',' SP? oC_SetItem )* + | SET SP? oC_Atom SP? '=' SP? kU_Properties; + +oC_SetItem + : ( oC_PropertyExpression SP? '=' SP? oC_Expression ) ; + +oC_Delete + : ( DETACH SP )? DELETE SP? oC_Expression ( SP? ',' SP? oC_Expression )*; + +oC_With + : WITH oC_ProjectionBody ( SP? oC_Where )? ; + +oC_Return + : RETURN oC_ProjectionBody ; + +oC_ProjectionBody + : ( SP? DISTINCT )? SP oC_ProjectionItems (SP oC_Order )? ( SP oC_Skip )? ( SP oC_Limit )? ; + +oC_ProjectionItems + : ( STAR ( SP? ',' SP? oC_ProjectionItem )* ) + | ( oC_ProjectionItem ( SP? ',' SP? oC_ProjectionItem )* ) + ; + +STAR : '*' ; + +oC_ProjectionItem + : ( oC_Expression SP AS SP oC_Variable ) + | oC_Expression + ; + +oC_Order + : ORDER SP BY SP oC_SortItem ( ',' SP? oC_SortItem )* ; + +oC_Skip + : L_SKIP SP oC_Expression ; + +L_SKIP : ( 'S' | 's' ) ( 'K' | 'k' ) ( 'I' | 'i' ) ( 'P' | 'p' ) ; + +oC_Limit + : LIMIT SP oC_Expression ; + +oC_SortItem + : oC_Expression ( SP? ( ASCENDING | ASC | DESCENDING | DESC ) )? 
; + +oC_Where + : WHERE SP oC_Expression ; + +oC_Pattern + : oC_PatternPart ( SP? ',' SP? oC_PatternPart )* ; + +oC_PatternPart + : ( oC_Variable SP? '=' SP? oC_AnonymousPatternPart ) + | oC_AnonymousPatternPart ; + +oC_AnonymousPatternPart + : oC_PatternElement ; + +oC_PatternElement + : ( oC_NodePattern ( SP? oC_PatternElementChain )* ) + | ( '(' oC_PatternElement ')' ) + ; + +oC_NodePattern + : '(' SP? ( oC_Variable SP? )? ( oC_NodeLabels SP? )? ( kU_Properties SP? )? ')' ; + +oC_PatternElementChain + : oC_RelationshipPattern SP? oC_NodePattern ; + +oC_RelationshipPattern + : ( oC_LeftArrowHead SP? oC_Dash SP? oC_RelationshipDetail? SP? oC_Dash ) + | ( oC_Dash SP? oC_RelationshipDetail? SP? oC_Dash SP? oC_RightArrowHead ) + | ( oC_Dash SP? oC_RelationshipDetail? SP? oC_Dash ) + ; + +oC_RelationshipDetail + : '[' SP? ( oC_Variable SP? )? ( oC_RelationshipTypes SP? )? ( kU_RecursiveDetail SP? )? ( kU_Properties SP? )? ']' ; + +// The original oC_Properties definition is oC_MapLiteral | oC_Parameter. +// We choose to not support parameter as properties which will be the decision for a long time. +// We then substitute with oC_MapLiteral definition. We create oC_MapLiteral only when we decide to add MAP type. +kU_Properties + : '{' SP? ( oC_PropertyKeyName SP? ':' SP? oC_Expression SP? ( ',' SP? oC_PropertyKeyName SP? ':' SP? oC_Expression SP? )* )? '}'; + +oC_RelationshipTypes + : ':' SP? oC_RelTypeName ( SP? '|' ':'? SP? oC_RelTypeName )* ; + +oC_NodeLabels + : ':' SP? oC_LabelName ( SP? ('|' ':'? | ':') SP? oC_LabelName )* ; + +kU_RecursiveDetail + : '*' ( SP? kU_RecursiveType)? ( SP? oC_RangeLiteral )? ( SP? kU_RecursiveComprehension )? ; + +kU_RecursiveType + : (ALL SP)? WSHORTEST SP? '(' SP? oC_PropertyKeyName SP? ')' + | SHORTEST + | ALL SP SHORTEST + | TRAIL + | ACYCLIC ; + +oC_RangeLiteral + : oC_LowerBound? SP? DOTDOT SP? oC_UpperBound? + | oC_IntegerLiteral ; + +kU_RecursiveComprehension + : '(' SP? oC_Variable SP? ',' SP? oC_Variable ( SP? '|' SP? oC_Where SP? )? ( SP? '|' SP? kU_RecursiveProjectionItems SP? ',' SP? kU_RecursiveProjectionItems SP? )? ')' ; + +kU_RecursiveProjectionItems + : '{' SP? oC_ProjectionItems? SP? '}' ; + +oC_LowerBound + : DecimalInteger ; + +oC_UpperBound + : DecimalInteger ; + +oC_LabelName + : oC_SchemaName ; + +oC_RelTypeName + : oC_SchemaName ; + +oC_Expression + : oC_OrExpression ; + +oC_OrExpression + : oC_XorExpression ( SP OR SP oC_XorExpression )* ; + +oC_XorExpression + : oC_AndExpression ( SP XOR SP oC_AndExpression )* ; + +oC_AndExpression + : oC_NotExpression ( SP AND SP oC_NotExpression )* ; + +oC_NotExpression + : ( NOT SP? )* oC_ComparisonExpression; + +oC_ComparisonExpression + : kU_BitwiseOrOperatorExpression ( SP? kU_ComparisonOperator SP? kU_BitwiseOrOperatorExpression )? + | kU_BitwiseOrOperatorExpression ( SP? INVALID_NOT_EQUAL SP? kU_BitwiseOrOperatorExpression ) { notifyInvalidNotEqualOperator($INVALID_NOT_EQUAL); } + | kU_BitwiseOrOperatorExpression SP? kU_ComparisonOperator SP? kU_BitwiseOrOperatorExpression ( SP? kU_ComparisonOperator SP? kU_BitwiseOrOperatorExpression )+ { notifyNonBinaryComparison($ctx->start); } + ; + +kU_ComparisonOperator : '=' | '<>' | '<' | '<=' | '>' | '>=' ; + +INVALID_NOT_EQUAL : '!=' ; + +kU_BitwiseOrOperatorExpression + : kU_BitwiseAndOperatorExpression ( SP? '|' SP? kU_BitwiseAndOperatorExpression )* ; + +kU_BitwiseAndOperatorExpression + : kU_BitShiftOperatorExpression ( SP? '&' SP? kU_BitShiftOperatorExpression )* ; + +kU_BitShiftOperatorExpression + : oC_AddOrSubtractExpression ( SP? 
kU_BitShiftOperator SP? oC_AddOrSubtractExpression )* ; + +kU_BitShiftOperator : '>>' | '<<' ; + +oC_AddOrSubtractExpression + : oC_MultiplyDivideModuloExpression ( SP? kU_AddOrSubtractOperator SP? oC_MultiplyDivideModuloExpression )* ; + +kU_AddOrSubtractOperator : '+' | '-' ; + +oC_MultiplyDivideModuloExpression + : oC_PowerOfExpression ( SP? kU_MultiplyDivideModuloOperator SP? oC_PowerOfExpression )* ; + +kU_MultiplyDivideModuloOperator : '*' | '/' | '%' ; + +oC_PowerOfExpression + : oC_StringListNullOperatorExpression ( SP? '^' SP? oC_StringListNullOperatorExpression )* ; + +oC_StringListNullOperatorExpression + : oC_UnaryAddSubtractOrFactorialExpression ( oC_StringOperatorExpression | oC_ListOperatorExpression+ | oC_NullOperatorExpression )? ; + +oC_ListOperatorExpression + : ( SP IN SP? oC_PropertyOrLabelsExpression ) + | ( '[' oC_Expression ']' ) + | ( '[' oC_Expression? ( COLON | DOTDOT ) oC_Expression? ']' ) ; + +COLON : ':' ; + +DOTDOT : '..' ; + +oC_StringOperatorExpression + : ( oC_RegularExpression | ( SP STARTS SP WITH ) | ( SP ENDS SP WITH ) | ( SP CONTAINS ) ) SP? oC_PropertyOrLabelsExpression ; + +oC_RegularExpression + : SP? '=~' ; + +oC_NullOperatorExpression + : ( SP IS SP NULL ) + | ( SP IS SP NOT SP NULL ) ; + +MINUS : '-' ; + +FACTORIAL : '!' ; + +oC_UnaryAddSubtractOrFactorialExpression + : ( MINUS SP? )* oC_PropertyOrLabelsExpression (SP? FACTORIAL)? ; + +oC_PropertyOrLabelsExpression + : oC_Atom ( SP? oC_PropertyLookup )* ; + +oC_Atom + : oC_Literal + | oC_Parameter + | oC_CaseExpression + | oC_ParenthesizedExpression + | oC_FunctionInvocation + | oC_PathPatterns + | oC_ExistCountSubquery + | oC_Variable + | oC_Quantifier + ; + +oC_Quantifier + : ( ALL SP? '(' SP? oC_FilterExpression SP? ')' ) + | ( ANY SP? '(' SP? oC_FilterExpression SP? ')' ) + | ( NONE SP? '(' SP? oC_FilterExpression SP? ')' ) + | ( SINGLE SP? '(' SP? oC_FilterExpression SP? ')' ) + ; + +oC_FilterExpression + : oC_IdInColl SP oC_Where ; + +oC_IdInColl + : oC_Variable SP IN SP oC_Expression ; + +oC_Literal + : oC_NumberLiteral + | StringLiteral + | oC_BooleanLiteral + | NULL + | oC_ListLiteral + | kU_StructLiteral + ; + +oC_BooleanLiteral + : TRUE + | FALSE + ; + +oC_ListLiteral + : '[' SP? ( oC_Expression SP? ( kU_ListEntry SP? )* )? ']' ; + +kU_ListEntry + : ',' SP? oC_Expression? ; + +kU_StructLiteral + : '{' SP? kU_StructField SP? ( ',' SP? kU_StructField SP? )* '}' ; + +kU_StructField + : ( oC_SymbolicName | StringLiteral ) SP? ':' SP? oC_Expression ; + +oC_ParenthesizedExpression + : '(' SP? oC_Expression SP? ')' ; + +oC_FunctionInvocation + : COUNT SP? '(' SP? '*' SP? ')' + | CAST SP? '(' SP? kU_FunctionParameter SP? ( ( AS SP? kU_DataType ) | ( ',' SP? kU_FunctionParameter ) ) SP? ')' + | oC_FunctionName SP? '(' SP? ( DISTINCT SP? )? ( kU_FunctionParameter SP? ( ',' SP? kU_FunctionParameter SP? )* )? ')' ; + +oC_FunctionName + : oC_SymbolicName ; + +kU_FunctionParameter + : ( oC_SymbolicName SP? ':' '=' SP? )? oC_Expression + | kU_LambdaParameter ; + +kU_LambdaParameter + : kU_LambdaVars SP? '-' '>' SP? oC_Expression SP? ; + +kU_LambdaVars + : oC_SymbolicName + | '(' SP? oC_SymbolicName SP? ( ',' SP? oC_SymbolicName SP?)* ')' ; + +oC_PathPatterns + : oC_NodePattern ( SP? oC_PatternElementChain )+; + +oC_ExistCountSubquery + : (EXISTS | COUNT) SP? '{' SP? MATCH SP? oC_Pattern ( SP? oC_Where )? ( SP? kU_Hint )? SP? '}' ; + +oC_PropertyLookup + : '.' SP? ( oC_PropertyKeyName | STAR ) ; + +oC_CaseExpression + : ( ( CASE ( SP? oC_CaseAlternative )+ ) | ( CASE SP? oC_Expression ( SP? 
oC_CaseAlternative )+ ) ) ( SP? ELSE SP? oC_Expression )? SP? END ; + +oC_CaseAlternative + : WHEN SP? oC_Expression SP? THEN SP? oC_Expression ; + +oC_Variable + : oC_SymbolicName ; + +StringLiteral + : ( '"' ( StringLiteral_0 | EscapedChar )* '"' ) + | ( '\'' ( StringLiteral_1 | EscapedChar )* '\'' ) + ; + +EscapedChar + : '\\' ( '\\' | '\'' | '"' | ( 'B' | 'b' ) | ( 'F' | 'f' ) | ( 'N' | 'n' ) | ( 'R' | 'r' ) | ( 'T' | 't' ) | ( ( 'X' | 'x' ) ( HexDigit HexDigit ) ) | ( ( 'U' | 'u' ) ( HexDigit HexDigit HexDigit HexDigit ) ) | ( ( 'U' | 'u' ) ( HexDigit HexDigit HexDigit HexDigit HexDigit HexDigit HexDigit HexDigit ) ) ) ; + +oC_NumberLiteral + : oC_DoubleLiteral + | oC_IntegerLiteral + ; + +oC_Parameter + : '$' ( oC_SymbolicName | DecimalInteger ) ; + +oC_PropertyExpression + : oC_Atom SP? oC_PropertyLookup ; + +oC_PropertyKeyName + : oC_SchemaName ; + +oC_IntegerLiteral + : DecimalInteger ; + +DecimalInteger + : ZeroDigit + | ( NonZeroDigit ( Digit )* ) + ; + +HexLetter + : ( 'A' | 'a' ) + | ( 'B' | 'b' ) + | ( 'C' | 'c' ) + | ( 'D' | 'd' ) + | ( 'E' | 'e' ) + | ( 'F' | 'f' ) + ; + +HexDigit + : Digit + | HexLetter + ; + +Digit + : ZeroDigit + | NonZeroDigit + ; + +NonZeroDigit + : NonZeroOctDigit + | '8' + | '9' + ; + +NonZeroOctDigit + : '1' + | '2' + | '3' + | '4' + | '5' + | '6' + | '7' + ; + +ZeroDigit + : '0' ; + +oC_DoubleLiteral + : ExponentDecimalReal + | RegularDecimalReal + ; + +ExponentDecimalReal + : ( ( Digit )+ | ( ( Digit )+ '.' ( Digit )+ ) | ( '.' ( Digit )+ ) ) ( 'E' | 'e' ) '-'? ( Digit )+ ; + +RegularDecimalReal + : ( Digit )* '.' ( Digit )+ ; + +oC_SchemaName + : oC_SymbolicName ; + +oC_SymbolicName + : UnescapedSymbolicName + | EscapedSymbolicName {if ($EscapedSymbolicName.text == "``") { notifyEmptyToken($EscapedSymbolicName); }} + | HexLetter + | kU_NonReservedKeywords + ; + +// example of BEGIN and END: TCKWith2.Scenario1 +kU_NonReservedKeywords + : COMMENT + | ADD + | ALTER + | AS + | ATTACH + | BEGIN + | BY + | CALL + | CHECKPOINT + | COMMENT + | COMMIT + | CONTAINS + | COPY + | COUNT + | CYCLE + | DATABASE + | DECIMAL + | DELETE + | DETACH + | DROP + | EXPLAIN + | EXPORT + | EXTENSION + | FORCE + | GRAPH + | IF + | IS + | IMPORT + | INCREMENT + | KEY + | LOAD + | LOGICAL + | MATCH + | MAXVALUE + | MERGE + | MINVALUE + | NO + | NODE + | PROJECT + | READ + | REL + | RENAME + | RETURN + | ROLLBACK + | ROLE + | SEQUENCE + | SET + | START + | STRUCT + | L_SKIP + | LIMIT + | TRANSACTION + | TYPE + | USE + | UNINSTALL + | UPDATE + | WRITE + | FROM + | TO + | YIELD + | USER + | PASSWORD + | MAP + ; + +UnescapedSymbolicName + : IdentifierStart ( IdentifierPart )* ; + +IdentifierStart + : ID_Start + | Pc + ; + +IdentifierPart + : ID_Continue + | Sc + ; + +EscapedSymbolicName + : ( '`' ( EscapedSymbolicName_0 )* '`' )+ ; + +SP + : ( WHITESPACE )+ ; + +WHITESPACE + : SPACE + | TAB + | LF + | VT + | FF + | CR + | FS + | GS + | RS + | US + | '\u1680' + | '\u180e' + | '\u2000' + | '\u2001' + | '\u2002' + | '\u2003' + | '\u2004' + | '\u2005' + | '\u2006' + | '\u2008' + | '\u2009' + | '\u200a' + | '\u2028' + | '\u2029' + | '\u205f' + | '\u3000' + | '\u00a0' + | '\u2007' + | '\u202f' + | CypherComment + ; + +CypherComment + : ( '/*' ( Comment_1 | ( '*' Comment_2 ) )* '*/' ) + | ( '//' ( Comment_3 )* CR? 
( LF | EOF ) ) + ; + +oC_LeftArrowHead + : '<' + | '\u27e8' + | '\u3008' + | '\ufe64' + | '\uff1c' + ; + +oC_RightArrowHead + : '>' + | '\u27e9' + | '\u3009' + | '\ufe65' + | '\uff1e' + ; + +oC_Dash + : '-' + | '\u00ad' + | '\u2010' + | '\u2011' + | '\u2012' + | '\u2013' + | '\u2014' + | '\u2015' + | '\u2212' + | '\ufe58' + | '\ufe63' + | '\uff0d' + ; + +fragment FF : [\f] ; + +fragment EscapedSymbolicName_0 : ~[`] ; + +fragment RS : [\u001E] ; + +fragment ID_Continue : [\p{ID_Continue}] ; + +fragment Comment_1 : ~[*] ; + +fragment StringLiteral_1 : ~['\\] ; + +fragment Comment_3 : ~[\n\r] ; + +fragment Comment_2 : ~[/] ; + +fragment GS : [\u001D] ; + +fragment FS : [\u001C] ; + +fragment CR : [\r] ; + +fragment Sc : [\p{Sc}] ; + +fragment SPACE : [ ] ; + +fragment Pc : [\p{Pc}] ; + +fragment TAB : [\t] ; + +fragment StringLiteral_0 : ~["\\] ; + +fragment LF : [\n] ; + +fragment VT : [\u000B] ; + +fragment US : [\u001F] ; + +fragment ID_Start : [\p{ID_Start}] ; + +// This is used to capture unknown lexer input (e.g. !) to avoid parser exception. +Unknown : .; diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/antlr4/README.md b/graph-wasm/lbug-0.12.2/lbug-src/src/antlr4/README.md new file mode 100644 index 0000000000..7bf65b6970 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/antlr4/README.md @@ -0,0 +1 @@ +Neither `Cypher.g4` nor `keywords.txt` can be individually used to generate Ladybug's grammar. Rather, the files are combined to generate `scripts/antlr4/Cypher.g4`, which is immediately digestible. diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/antlr4/keywords.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/antlr4/keywords.txt new file mode 100644 index 0000000000..563f123f3c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/antlr4/keywords.txt @@ -0,0 +1,115 @@ +ACYCLIC +ANY +ADD +ALL +ALTER +AND +AS +ASC +ASCENDING +ATTACH +BEGIN +BY +CALL +CASE +CAST +CHECKPOINT +COLUMN +COMMENT +COMMIT +COMMIT_SKIP_CHECKPOINT +CONTAINS +COPY +COUNT +CREATE +CYCLE +DATABASE +DBTYPE +DEFAULT +DELETE +DESC +DESCENDING +DETACH +DISTINCT +DROP +ELSE +END +ENDS +EXISTS +EXPLAIN +EXPORT +EXTENSION +FALSE +FROM +FORCE +GLOB +GRAPH +GROUP +HEADERS +HINT +IMPORT +IF +IN +INCREMENT +INSTALL +IS +JOIN +KEY +LIMIT +LOAD +LOGICAL +MACRO +MATCH +MAXVALUE +MERGE +MINVALUE +MULTI_JOIN +NO +NODE +NOT +NONE +NULL +ON +ONLY +OPTIONAL +OR +ORDER +PRIMARY +PROFILE +PROJECT +READ +REL +RENAME +RETURN +ROLLBACK +ROLLBACK_SKIP_CHECKPOINT +SEQUENCE +SET +SHORTEST +START +STARTS +STRUCT +TABLE +THEN +TO +TRAIL +TRANSACTION +TRUE +TYPE +UNION +UNWIND +UNINSTALL +UPDATE +USE +WHEN +WHERE +WITH +WRITE +WSHORTEST +XOR +SINGLE +YIELD +USER +PASSWORD +ROLE +MAP diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/CMakeLists.txt new file mode 100644 index 0000000000..8d8a03a51c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/CMakeLists.txt @@ -0,0 +1,22 @@ +add_subdirectory(bind) +add_subdirectory(bind_expression) +add_subdirectory(ddl) +add_subdirectory(expression) +add_subdirectory(query) +add_subdirectory(rewriter) +add_subdirectory(visitor) + +add_library(lbug_binder + OBJECT + binder.cpp + binder_scope.cpp + bound_statement_result.cpp + bound_scan_source.cpp + bound_statement_rewriter.cpp + bound_statement_visitor.cpp + expression_binder.cpp + expression_visitor.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/CMakeLists.txt 
b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/CMakeLists.txt new file mode 100644 index 0000000000..04787014be --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/CMakeLists.txt @@ -0,0 +1,32 @@ +add_subdirectory(copy) +add_subdirectory(ddl) +add_subdirectory(read) + +add_library( + lbug_binder_bind + OBJECT + bind_attach_database.cpp + bind_create_macro.cpp + bind_ddl.cpp + bind_detach_database.cpp + bind_explain.cpp + bind_extension_clause.cpp + bind_file_scan.cpp + bind_graph_pattern.cpp + bind_projection_clause.cpp + bind_query.cpp + bind_reading_clause.cpp + bind_standalone_call.cpp + bind_table_function.cpp + bind_transaction.cpp + bind_updating_clause.cpp + bind_extension.cpp + bind_export_database.cpp + bind_import_database.cpp + bind_use_database.cpp + bind_standalone_call_function.cpp + bind_table_function.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_attach_database.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_attach_database.cpp new file mode 100644 index 0000000000..8dbc95f229 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_attach_database.cpp @@ -0,0 +1,36 @@ +#include "binder/binder.h" +#include "binder/bound_attach_database.h" +#include "common/exception/binder.h" +#include "common/string_utils.h" +#include "parser/attach_database.h" +#include "parser/expression/parsed_literal_expression.h" + +namespace lbug { +namespace binder { + +static AttachInfo bindAttachInfo(const parser::AttachInfo& attachInfo) { + binder::AttachOption attachOption; + for (auto& [name, value] : attachInfo.options) { + if (value->getExpressionType() != common::ExpressionType::LITERAL) { + throw common::BinderException{"Attach option must be a literal expression."}; + } + auto val = value->constPtrCast()->getValue(); + attachOption.options.emplace(name, std::move(val)); + } + + if (common::StringUtils::getUpper(attachInfo.dbType) == common::ATTACHED_LBUG_DB_TYPE && + attachInfo.dbAlias.empty()) { + throw common::BinderException{"Attaching a lbug database without an alias is not allowed."}; + } + return binder::AttachInfo{attachInfo.dbPath, attachInfo.dbAlias, attachInfo.dbType, + std::move(attachOption)}; +} + +std::unique_ptr Binder::bindAttachDatabase(const parser::Statement& statement) { + auto& attachDatabase = statement.constCast(); + auto boundAttachInfo = bindAttachInfo(attachDatabase.getAttachInfo()); + return std::make_unique(std::move(boundAttachInfo)); +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_create_macro.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_create_macro.cpp new file mode 100644 index 0000000000..7f80046efc --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_create_macro.cpp @@ -0,0 +1,35 @@ +#include "binder/binder.h" +#include "binder/bound_create_macro.h" +#include "catalog/catalog.h" +#include "common/exception/binder.h" +#include "common/string_format.h" +#include "common/string_utils.h" +#include "parser/create_macro.h" +#include "transaction/transaction.h" + +using namespace lbug::common; +using namespace lbug::parser; + +namespace lbug { +namespace binder { + +std::unique_ptr Binder::bindCreateMacro(const Statement& statement) const { + auto& createMacro = ku_dynamic_cast(statement); + auto macroName = createMacro.getMacroName(); + StringUtils::toUpper(macroName); + if (catalog::Catalog::Get(*clientContext) + 
->containsMacro(transaction::Transaction::Get(*clientContext), macroName)) { + throw BinderException{stringFormat("Macro {} already exists.", macroName)}; + } + parser::default_macro_args defaultArgs; + for (auto& defaultArg : createMacro.getDefaultArgs()) { + defaultArgs.emplace_back(defaultArg.first, defaultArg.second->copy()); + } + auto scalarMacro = + std::make_unique(createMacro.getMacroExpression()->copy(), + createMacro.getPositionalArgs(), std::move(defaultArgs)); + return std::make_unique(std::move(macroName), std::move(scalarMacro)); +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_ddl.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_ddl.cpp new file mode 100644 index 0000000000..3169cc7b3f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_ddl.cpp @@ -0,0 +1,492 @@ +#include "binder/binder.h" +#include "binder/ddl/bound_alter.h" +#include "binder/ddl/bound_create_sequence.h" +#include "binder/ddl/bound_create_table.h" +#include "binder/ddl/bound_create_type.h" +#include "binder/ddl/bound_drop.h" +#include "binder/expression_visitor.h" +#include "catalog/catalog.h" +#include "catalog/catalog_entry/node_table_catalog_entry.h" +#include "catalog/catalog_entry/sequence_catalog_entry.h" +#include "common/enums/extend_direction_util.h" +#include "common/exception/binder.h" +#include "common/exception/message.h" +#include "common/string_format.h" +#include "common/system_config.h" +#include "common/types/types.h" +#include "function/cast/functions/cast_from_string_functions.h" +#include "function/sequence/sequence_functions.h" +#include "main/client_context.h" +#include "parser/ddl/alter.h" +#include "parser/ddl/create_sequence.h" +#include "parser/ddl/create_table.h" +#include "parser/ddl/create_table_info.h" +#include "parser/ddl/create_type.h" +#include "parser/ddl/drop.h" +#include "parser/expression/parsed_function_expression.h" +#include "parser/expression/parsed_literal_expression.h" +#include "transaction/transaction.h" + +using namespace lbug::common; +using namespace lbug::parser; +using namespace lbug::catalog; + +namespace lbug { +namespace binder { + +static void validatePropertyName(const std::vector& definitions) { + case_insensitve_set_t nameSet; + for (auto& definition : definitions) { + if (nameSet.contains(definition.getName())) { + throw BinderException(stringFormat( + "Duplicated column name: {}, column name must be unique.", definition.getName())); + } + if (Binder::reservedInColumnName(definition.getName())) { + throw BinderException( + stringFormat("{} is a reserved property name.", definition.getName())); + } + nameSet.insert(definition.getName()); + } +} + +std::vector Binder::bindPropertyDefinitions( + const std::vector& parsedDefinitions, const std::string& tableName) { + std::vector definitions; + for (auto& def : parsedDefinitions) { + auto type = LogicalType::convertFromString(def.getType(), clientContext); + auto defaultExpr = + resolvePropertyDefault(def.defaultExpr.get(), type, tableName, def.getName()); + auto boundExpr = expressionBinder.bindExpression(*defaultExpr); + if (boundExpr->dataType != type) { + expressionBinder.implicitCast(boundExpr, type); + } + auto columnDefinition = ColumnDefinition(def.getName(), std::move(type)); + definitions.emplace_back(std::move(columnDefinition), std::move(defaultExpr)); + } + validatePropertyName(definitions); + return definitions; +} + +std::unique_ptr Binder::resolvePropertyDefault(ParsedExpression* 
parsedDefault, + const LogicalType& type, const std::string& tableName, const std::string& propertyName) { + if (parsedDefault == nullptr) { // No default provided. + if (type.getLogicalTypeID() == LogicalTypeID::SERIAL) { + auto serialName = SequenceCatalogEntry::getSerialName(tableName, propertyName); + auto literalExpr = std::make_unique(Value(serialName), ""); + return std::make_unique(function::NextValFunction::name, + std::move(literalExpr), "" /* rawName */); + } else { + return std::make_unique(Value::createNullValue(type), "NULL"); + } + } else { + if (type.getLogicalTypeID() == LogicalTypeID::SERIAL) { + throw BinderException("No DEFAULT value should be set for SERIAL columns"); + } + return parsedDefault->copy(); + } +} + +static void validatePrimaryKey(const std::string& pkColName, + const std::vector& definitions) { + uint32_t primaryKeyIdx = UINT32_MAX; + for (auto i = 0u; i < definitions.size(); i++) { + if (definitions[i].getName() == pkColName) { + primaryKeyIdx = i; + } + } + if (primaryKeyIdx == UINT32_MAX) { + throw BinderException( + "Primary key " + pkColName + " does not match any of the predefined node properties."); + } + const auto& pkType = definitions[primaryKeyIdx].getType(); + if (!pkType.isInternalType()) { + throw BinderException(ExceptionMessage::invalidPKType(pkType.toString())); + } + switch (pkType.getPhysicalType()) { + case PhysicalTypeID::UINT8: + case PhysicalTypeID::UINT16: + case PhysicalTypeID::UINT32: + case PhysicalTypeID::UINT64: + case PhysicalTypeID::INT8: + case PhysicalTypeID::INT16: + case PhysicalTypeID::INT32: + case PhysicalTypeID::INT64: + case PhysicalTypeID::INT128: + case PhysicalTypeID::UINT128: + case PhysicalTypeID::STRING: + case PhysicalTypeID::FLOAT: + case PhysicalTypeID::DOUBLE: + break; + default: + throw BinderException(ExceptionMessage::invalidPKType(pkType.toString())); + } +} + +BoundCreateTableInfo Binder::bindCreateTableInfo(const CreateTableInfo* info) { + switch (info->type) { + case TableType::NODE: { + return bindCreateNodeTableInfo(info); + } + case TableType::REL: { + return bindCreateRelTableGroupInfo(info); + } + default: { + KU_UNREACHABLE; + } + } +} + +BoundCreateTableInfo Binder::bindCreateNodeTableInfo(const CreateTableInfo* info) { + auto propertyDefinitions = bindPropertyDefinitions(info->propertyDefinitions, info->tableName); + auto& extraInfo = info->extraInfo->constCast(); + validatePrimaryKey(extraInfo.pKName, propertyDefinitions); + auto boundExtraInfo = std::make_unique(extraInfo.pKName, + std::move(propertyDefinitions)); + return BoundCreateTableInfo(CatalogEntryType::NODE_TABLE_ENTRY, info->tableName, + info->onConflict, std::move(boundExtraInfo), clientContext->useInternalCatalogEntry()); +} + +void Binder::validateNodeTableType(const TableCatalogEntry* entry) { + if (entry->getType() != CatalogEntryType::NODE_TABLE_ENTRY) { + throw BinderException(stringFormat("{} is not of type NODE.", entry->getName())); + } +} + +void Binder::validateTableExistence(const main::ClientContext& context, + const std::string& tableName) { + auto transaction = transaction::Transaction::Get(context); + if (!Catalog::Get(context)->containsTable(transaction, tableName)) { + throw BinderException{stringFormat("Table {} does not exist.", tableName)}; + } +} + +void Binder::validateColumnExistence(const TableCatalogEntry* entry, + const std::string& columnName) { + if (!entry->containsProperty(columnName)) { + throw BinderException{ + stringFormat("Column {} does not exist in table {}.", columnName, entry->getName())}; 
+ } +} + +static ExtendDirection getStorageDirection(const case_insensitive_map_t& options) { + if (options.contains(TableOptionConstants::REL_STORAGE_DIRECTION_OPTION)) { + return ExtendDirectionUtil::fromString( + options.at(TableOptionConstants::REL_STORAGE_DIRECTION_OPTION).toString()); + } + return DEFAULT_EXTEND_DIRECTION; +} + +std::vector Binder::bindRelPropertyDefinitions(const CreateTableInfo& info) { + std::vector propertyDefinitions; + propertyDefinitions.emplace_back( + ColumnDefinition(InternalKeyword::ID, LogicalType::INTERNAL_ID())); + for (auto& definition : bindPropertyDefinitions(info.propertyDefinitions, info.tableName)) { + propertyDefinitions.push_back(definition.copy()); + } + return propertyDefinitions; +} + +BoundCreateTableInfo Binder::bindCreateRelTableGroupInfo(const CreateTableInfo* info) { + auto propertyDefinitions = bindRelPropertyDefinitions(*info); + auto& extraInfo = info->extraInfo->constCast(); + auto srcMultiplicity = RelMultiplicityUtils::getFwd(extraInfo.relMultiplicity); + auto dstMultiplicity = RelMultiplicityUtils::getBwd(extraInfo.relMultiplicity); + auto boundOptions = bindParsingOptions(extraInfo.options); + auto storageDirection = getStorageDirection(boundOptions); + // Bind from to pairs + node_table_id_pair_set_t nodePairsSet; + std::vector nodePairs; + for (auto& [srcTableName, dstTableName] : extraInfo.srcDstTablePairs) { + auto srcEntry = bindNodeTableEntry(srcTableName); + validateNodeTableType(srcEntry); + auto dstEntry = bindNodeTableEntry(dstTableName); + validateNodeTableType(dstEntry); + NodeTableIDPair pair{srcEntry->getTableID(), dstEntry->getTableID()}; + if (nodePairsSet.contains(pair)) { + throw BinderException( + stringFormat("Found duplicate FROM-TO {}-{} pairs.", srcTableName, dstTableName)); + } + nodePairsSet.insert(pair); + nodePairs.emplace_back(pair); + } + auto boundExtraInfo = + std::make_unique(std::move(propertyDefinitions), + srcMultiplicity, dstMultiplicity, storageDirection, std::move(nodePairs)); + return BoundCreateTableInfo(CatalogEntryType::REL_GROUP_ENTRY, info->tableName, + info->onConflict, std::move(boundExtraInfo), clientContext->useInternalCatalogEntry()); +} + +std::unique_ptr Binder::bindCreateTable(const Statement& statement) { + auto& createTable = statement.constCast(); + if (createTable.getSource()) { + return bindCreateTableAs(createTable); + } + auto boundCreateInfo = bindCreateTableInfo(createTable.getInfo()); + return std::make_unique(std::move(boundCreateInfo), + BoundStatementResult::createSingleStringColumnResult()); +} + +std::unique_ptr Binder::bindCreateTableAs(const Statement& statement) { + auto& createTable = statement.constCast(); + auto boundInnerQuery = bindQuery(*createTable.getSource()->statement.get()); + auto innerQueryResult = boundInnerQuery->getStatementResult(); + auto columnNames = innerQueryResult->getColumnNames(); + auto columnTypes = innerQueryResult->getColumnTypes(); + std::vector propertyDefinitions; + propertyDefinitions.reserve(columnNames.size()); + for (size_t i = 0; i < columnNames.size(); ++i) { + propertyDefinitions.emplace_back( + ColumnDefinition(std::string(columnNames[i]), columnTypes[i].copy())); + } + if (columnNames.empty()) { + throw BinderException("Subquery returns no columns"); + } + auto createInfo = createTable.getInfo(); + switch (createInfo->type) { + case TableType::NODE: { + // first column is primary key column temporarily for now + auto pkName = columnNames[0]; + validatePrimaryKey(pkName, propertyDefinitions); + auto boundCopyFromInfo = 
bindCopyNodeFromInfo(createInfo->tableName, propertyDefinitions, + createTable.getSource(), options_t{}, columnNames, columnTypes, false /* byColumn */); + auto boundExtraInfo = + std::make_unique(pkName, std::move(propertyDefinitions)); + auto boundCreateInfo = BoundCreateTableInfo(CatalogEntryType::NODE_TABLE_ENTRY, + createInfo->tableName, createInfo->onConflict, std::move(boundExtraInfo), + clientContext->useInternalCatalogEntry()); + auto boundCreateTable = std::make_unique(std::move(boundCreateInfo), + BoundStatementResult::createSingleStringColumnResult()); + boundCreateTable->setCopyInfo(std::move(boundCopyFromInfo)); + return boundCreateTable; + } + case TableType::REL: { + auto& extraInfo = createInfo->extraInfo->constCast(); + // Currently we don't support multiple from/to pairs for create rel table as + if (extraInfo.srcDstTablePairs.size() > 1) { + throw BinderException( + "Multiple FROM/TO pairs are not supported for CREATE REL TABLE AS."); + } + propertyDefinitions.insert(propertyDefinitions.begin(), + PropertyDefinition(ColumnDefinition(InternalKeyword::ID, LogicalType::INTERNAL_ID()))); + auto catalog = Catalog::Get(*clientContext); + auto transaction = transaction::Transaction::Get(*clientContext); + auto fromTable = + catalog->getTableCatalogEntry(transaction, extraInfo.srcDstTablePairs[0].first) + ->ptrCast(); + auto toTable = + catalog->getTableCatalogEntry(transaction, extraInfo.srcDstTablePairs[0].second) + ->ptrCast(); + auto boundCreateInfo = bindCreateRelTableGroupInfo(createInfo); + auto boundCopyFromInfo = bindCopyRelFromInfo(createInfo->tableName, propertyDefinitions, + createTable.getSource(), options_t{}, columnNames, columnTypes, fromTable, toTable); + boundCreateInfo.extraInfo->ptrCast()->propertyDefinitions = + std::move(propertyDefinitions); + auto boundCreateTable = std::make_unique(std::move(boundCreateInfo), + BoundStatementResult::createSingleStringColumnResult()); + boundCreateTable->setCopyInfo(std::move(boundCopyFromInfo)); + return boundCreateTable; + } + default: { + KU_UNREACHABLE; + } + } +} + +std::unique_ptr Binder::bindCreateType(const Statement& statement) const { + auto createType = statement.constPtrCast(); + auto name = createType->getName(); + LogicalType type = LogicalType::convertFromString(createType->getDataType(), clientContext); + auto transaction = transaction::Transaction::Get(*clientContext); + if (Catalog::Get(*clientContext)->containsType(transaction, name)) { + throw BinderException{stringFormat("Duplicated type name: {}.", name)}; + } + return std::make_unique(std::move(name), std::move(type)); +} + +std::unique_ptr Binder::bindCreateSequence(const Statement& statement) const { + auto& createSequence = statement.constCast(); + auto info = createSequence.getInfo(); + auto sequenceName = info.sequenceName; + int64_t startWith = 0; + int64_t increment = 0; + int64_t minValue = 0; + int64_t maxValue = 0; + auto transaction = transaction::Transaction::Get(*clientContext); + switch (info.onConflict) { + case ConflictAction::ON_CONFLICT_THROW: { + if (Catalog::Get(*clientContext)->containsSequence(transaction, sequenceName)) { + throw BinderException(sequenceName + " already exists in catalog."); + } + } break; + default: + break; + } + auto literal = ku_string_t{info.increment.c_str(), info.increment.length()}; + if (!function::CastString::tryCast(literal, increment)) { + throw BinderException("Out of bounds: SEQUENCE accepts integers within INT64."); + } + if (increment == 0) { + throw BinderException("INCREMENT must be 
non-zero."); + } + + if (info.minValue == "") { + minValue = increment > 0 ? 1 : std::numeric_limits::min(); + } else { + literal = ku_string_t{info.minValue.c_str(), info.minValue.length()}; + if (!function::CastString::tryCast(literal, minValue)) { + throw BinderException("Out of bounds: SEQUENCE accepts integers within INT64."); + } + } + if (info.maxValue == "") { + maxValue = increment > 0 ? std::numeric_limits::max() : -1; + } else { + literal = ku_string_t{info.maxValue.c_str(), info.maxValue.length()}; + if (!function::CastString::tryCast(literal, maxValue)) { + throw BinderException("Out of bounds: SEQUENCE accepts integers within INT64."); + } + } + if (info.startWith == "") { + startWith = increment > 0 ? minValue : maxValue; + } else { + literal = ku_string_t{info.startWith.c_str(), info.startWith.length()}; + if (!function::CastString::tryCast(literal, startWith)) { + throw BinderException("Out of bounds: SEQUENCE accepts integers within INT64."); + } + } + + if (maxValue < minValue) { + throw BinderException("SEQUENCE MAXVALUE should be greater than or equal to MINVALUE."); + } + if (startWith < minValue || startWith > maxValue) { + throw BinderException("SEQUENCE START value should be between MINVALUE and MAXVALUE."); + } + + auto boundInfo = BoundCreateSequenceInfo(sequenceName, startWith, increment, minValue, maxValue, + info.cycle, info.onConflict, false /* isInternal */); + return std::make_unique(std::move(boundInfo)); +} + +std::unique_ptr Binder::bindDrop(const Statement& statement) { + auto& drop = statement.constCast(); + return std::make_unique(drop.getDropInfo()); +} + +std::unique_ptr Binder::bindAlter(const Statement& statement) { + auto& alter = statement.constCast(); + switch (alter.getInfo()->type) { + case AlterType::RENAME: { + return bindRenameTable(statement); + } + case AlterType::ADD_PROPERTY: { + return bindAddProperty(statement); + } + case AlterType::DROP_PROPERTY: { + return bindDropProperty(statement); + } + case AlterType::RENAME_PROPERTY: { + return bindRenameProperty(statement); + } + case AlterType::COMMENT: { + return bindCommentOn(statement); + } + case AlterType::ADD_FROM_TO_CONNECTION: + case AlterType::DROP_FROM_TO_CONNECTION: { + return bindAlterFromToConnection(statement); + } + default: { + KU_UNREACHABLE; + } + } +} + +std::unique_ptr Binder::bindRenameTable(const Statement& statement) const { + auto& alter = statement.constCast(); + auto info = alter.getInfo(); + auto extraInfo = ku_dynamic_cast(info->extraInfo.get()); + auto tableName = info->tableName; + auto newName = extraInfo->newName; + auto boundExtraInfo = std::make_unique(newName); + auto boundInfo = + BoundAlterInfo(AlterType::RENAME, tableName, std::move(boundExtraInfo), info->onConflict); + return std::make_unique(std::move(boundInfo)); +} + +std::unique_ptr Binder::bindAddProperty(const Statement& statement) { + auto& alter = statement.constCast(); + auto info = alter.getInfo(); + auto extraInfo = info->extraInfo->ptrCast(); + auto tableName = info->tableName; + auto propertyName = extraInfo->propertyName; + auto type = LogicalType::convertFromString(extraInfo->dataType, clientContext); + auto columnDefinition = ColumnDefinition(propertyName, type.copy()); + auto defaultExpr = + resolvePropertyDefault(extraInfo->defaultValue.get(), type, tableName, propertyName); + auto boundDefault = expressionBinder.bindExpression(*defaultExpr); + boundDefault = expressionBinder.implicitCastIfNecessary(boundDefault, type); + if (ConstantExpressionVisitor::needFold(*boundDefault)) { + 
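+ // Constant default expressions are folded to a single literal before the ALTER is
+ // bound, so a statement along the lines of
+ //   ALTER TABLE person ADD score INT64 DEFAULT 2 + 3
+ // (illustrative syntax) would store the folded value 5 as the column default.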
boundDefault = expressionBinder.foldExpression(boundDefault); + } + auto propertyDefinition = + PropertyDefinition(std::move(columnDefinition), std::move(defaultExpr)); + auto boundExtraInfo = std::make_unique(std::move(propertyDefinition), + std::move(boundDefault)); + auto boundInfo = BoundAlterInfo(AlterType::ADD_PROPERTY, tableName, std::move(boundExtraInfo), + info->onConflict); + return std::make_unique(std::move(boundInfo)); +} + +std::unique_ptr Binder::bindDropProperty(const Statement& statement) const { + auto& alter = statement.constCast(); + auto info = alter.getInfo(); + auto extraInfo = info->extraInfo->constPtrCast(); + auto tableName = info->tableName; + auto propertyName = extraInfo->propertyName; + auto boundExtraInfo = std::make_unique(propertyName); + auto boundInfo = BoundAlterInfo(AlterType::DROP_PROPERTY, tableName, std::move(boundExtraInfo), + info->onConflict); + return std::make_unique(std::move(boundInfo)); +} + +std::unique_ptr Binder::bindRenameProperty(const Statement& statement) const { + auto& alter = statement.constCast(); + auto info = alter.getInfo(); + auto extraInfo = info->extraInfo->constPtrCast(); + auto tableName = info->tableName; + auto propertyName = extraInfo->propertyName; + auto newName = extraInfo->newName; + auto boundExtraInfo = std::make_unique(newName, propertyName); + auto boundInfo = BoundAlterInfo(AlterType::RENAME_PROPERTY, tableName, + std::move(boundExtraInfo), info->onConflict); + return std::make_unique(std::move(boundInfo)); +} + +std::unique_ptr Binder::bindCommentOn(const Statement& statement) const { + auto& alter = statement.constCast(); + auto info = alter.getInfo(); + auto extraInfo = info->extraInfo->constPtrCast(); + auto tableName = info->tableName; + auto comment = extraInfo->comment; + auto boundExtraInfo = std::make_unique(comment); + auto boundInfo = + BoundAlterInfo(AlterType::COMMENT, tableName, std::move(boundExtraInfo), info->onConflict); + return std::make_unique(std::move(boundInfo)); +} + +std::unique_ptr Binder::bindAlterFromToConnection( + const Statement& statement) const { + auto& alter = statement.constCast(); + auto info = alter.getInfo(); + auto extraInfo = info->extraInfo->constPtrCast(); + auto tableName = info->tableName; + auto srcTableEntry = bindNodeTableEntry(extraInfo->srcTableName); + auto dstTableEntry = bindNodeTableEntry(extraInfo->dstTableName); + auto srcTableID = srcTableEntry->getTableID(); + auto dstTableID = dstTableEntry->getTableID(); + auto boundExtraInfo = std::make_unique(srcTableID, dstTableID); + auto boundInfo = + BoundAlterInfo(info->type, tableName, std::move(boundExtraInfo), info->onConflict); + return std::make_unique(std::move(boundInfo)); +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_detach_database.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_detach_database.cpp new file mode 100644 index 0000000000..5506e9afc2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_detach_database.cpp @@ -0,0 +1,14 @@ +#include "binder/binder.h" +#include "binder/bound_detach_database.h" +#include "parser/detach_database.h" + +namespace lbug { +namespace binder { + +std::unique_ptr Binder::bindDetachDatabase(const parser::Statement& statement) { + auto& detachDatabase = statement.constCast(); + return std::make_unique(detachDatabase.getDBName()); +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_explain.cpp 
b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_explain.cpp new file mode 100644 index 0000000000..789d246c4b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_explain.cpp @@ -0,0 +1,16 @@ +#include "binder/binder.h" +#include "binder/bound_explain.h" +#include "parser/explain_statement.h" + +namespace lbug { +namespace binder { + +std::unique_ptr Binder::bindExplain(const parser::Statement& statement) { + auto& explain = statement.constCast(); + auto boundStatementToExplain = bind(*explain.getStatementToExplain()); + return std::make_unique(std::move(boundStatementToExplain), + explain.getExplainType()); +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_export_database.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_export_database.cpp new file mode 100644 index 0000000000..3da7586a27 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_export_database.cpp @@ -0,0 +1,169 @@ +#include "binder/bound_export_database.h" +#include "binder/query/bound_regular_query.h" +#include "catalog/catalog.h" +#include "catalog/catalog_entry/index_catalog_entry.h" +#include "catalog/catalog_entry/node_table_catalog_entry.h" +#include "catalog/catalog_entry/rel_group_catalog_entry.h" +#include "common/exception/binder.h" +#include "common/file_system/virtual_file_system.h" +#include "common/string_utils.h" +#include "main/client_context.h" +#include "parser/parser.h" +#include "parser/port_db.h" +#include "parser/query/regular_query.h" +#include "transaction/transaction.h" + +using namespace lbug::binder; +using namespace lbug::common; +using namespace lbug::parser; +using namespace lbug::catalog; +using namespace lbug::transaction; +using namespace lbug::storage; + +namespace lbug { +namespace binder { + +FileTypeInfo getFileType(case_insensitive_map_t& options) { + auto fileTypeInfo = + FileTypeInfo{FileType::PARQUET, PortDBConstants::DEFAULT_EXPORT_FORMAT_OPTION}; + if (options.contains(PortDBConstants::EXPORT_FORMAT_OPTION)) { + auto value = options.at(PortDBConstants::EXPORT_FORMAT_OPTION); + if (value.getDataType().getLogicalTypeID() != LogicalTypeID::STRING) { + throw BinderException("The type of format option must be a string."); + } + auto valueStr = value.getValue(); + StringUtils::toUpper(valueStr); + fileTypeInfo = FileTypeInfo{FileTypeUtils::fromString(valueStr), valueStr}; + options.erase(PortDBConstants::EXPORT_FORMAT_OPTION); + } + return fileTypeInfo; +} + +void bindExportTableData(ExportedTableData& tableData, const std::string& query, + main::ClientContext* context, Binder* binder) { + auto parsedStatement = Parser::parseQuery(query); + KU_ASSERT(parsedStatement.size() == 1); + auto parsedQuery = parsedStatement[0]->constPtrCast(); + context->setUseInternalCatalogEntry(true /* useInternalCatalogEntry */); + auto boundQuery = binder->bindQuery(*parsedQuery); + context->setUseInternalCatalogEntry(false /* useInternalCatalogEntry */); + auto columns = boundQuery->getStatementResult()->getColumns(); + for (auto& column : columns) { + auto columnName = column->hasAlias() ? 
column->getAlias() : column->toString(); + tableData.columnNames.push_back(columnName); + tableData.columnTypes.push_back(column->getDataType().copy()); + } + tableData.regularQuery = std::move(boundQuery); +} + +static std::string getExportNodeTableDataQuery(const TableCatalogEntry& entry) { + return stringFormat("match (a:`{}`) return a.*", entry.getName()); +} + +static std::string getExportRelTableDataQuery(const TableCatalogEntry& relGroupEntry, + const NodeTableCatalogEntry& srcEntry, const NodeTableCatalogEntry& dstEntry) { + return stringFormat("match (a:`{}`)-[r:`{}`]->(b:`{}`) return a.{},b.{},r.*;", + srcEntry.getName(), relGroupEntry.getName(), dstEntry.getName(), + srcEntry.getPrimaryKeyName(), dstEntry.getPrimaryKeyName()); +} + +static std::vector getExportInfo(const Catalog& catalog, + main::ClientContext* context, Binder* binder, FileTypeInfo& fileTypeInfo) { + auto transaction = Transaction::Get(*context); + std::vector exportData; + for (auto entry : catalog.getNodeTableEntries(transaction, false /*useInternal*/)) { + ExportedTableData tableData; + tableData.tableName = entry->getName(); + tableData.fileName = + entry->getName() + "." + StringUtils::getLower(fileTypeInfo.fileTypeStr); + auto query = getExportNodeTableDataQuery(*entry); + bindExportTableData(tableData, query, context, binder); + exportData.push_back(std::move(tableData)); + } + for (auto entry : catalog.getRelGroupEntries(transaction, false /* useInternal */)) { + auto& relGroupEntry = entry->constCast(); + for (auto& info : relGroupEntry.getRelEntryInfos()) { + ExportedTableData tableData; + auto srcTableID = info.nodePair.srcTableID; + auto dstTableID = info.nodePair.dstTableID; + auto& srcEntry = catalog.getTableCatalogEntry(transaction, srcTableID) + ->constCast(); + auto& dstEntry = catalog.getTableCatalogEntry(transaction, dstTableID) + ->constCast(); + tableData.tableName = entry->getName(); + tableData.fileName = + stringFormat("{}_{}_{}.{}", relGroupEntry.getName(), srcEntry.getName(), + dstEntry.getName(), StringUtils::getLower(fileTypeInfo.fileTypeStr)); + auto query = getExportRelTableDataQuery(relGroupEntry, srcEntry, dstEntry); + bindExportTableData(tableData, query, context, binder); + exportData.push_back(std::move(tableData)); + } + } + + for (auto indexEntry : catalog.getIndexEntries(transaction)) { + // Export + ExportedTableData tableData; + auto entry = indexEntry->getTableEntryToExport(context); + if (entry == nullptr) { + continue; + } + KU_ASSERT(entry->getTableType() == TableType::NODE); + tableData.tableName = entry->getName(); + tableData.fileName = + entry->getName() + "." 
+ StringUtils::getLower(fileTypeInfo.fileTypeStr); + auto query = getExportNodeTableDataQuery(*entry); + bindExportTableData(tableData, query, context, binder); + exportData.push_back(std::move(tableData)); + } + return exportData; +} + +static bool schemaOnly(case_insensitive_map_t& parsedOptions, + const parser::ExportDB& exportDB) { + auto isSchemaOnlyOption = [](const std::pair& option) -> bool { + if (option.first != PortDBConstants::SCHEMA_ONLY_OPTION) { + return false; + } + if (option.second.getDataType() != LogicalType::BOOL()) { + throw common::BinderException{common::stringFormat( + "The '{}' option must have a BOOL value.", PortDBConstants::SCHEMA_ONLY_OPTION)}; + } + return option.second.getValue(); + }; + auto exportSchemaOnly = + std::count_if(parsedOptions.begin(), parsedOptions.end(), isSchemaOnlyOption) != 0; + if (exportSchemaOnly && exportDB.getParsingOptionsRef().size() != 1) { + throw common::BinderException{ + common::stringFormat("When '{}' option is set to true in export " + "database, no other options are allowed.", + PortDBConstants::SCHEMA_ONLY_OPTION)}; + } + parsedOptions.erase(PortDBConstants::SCHEMA_ONLY_OPTION); + return exportSchemaOnly; +} + +std::unique_ptr Binder::bindExportDatabaseClause(const Statement& statement) { + auto& exportDB = statement.constCast(); + auto parsedOptions = bindParsingOptions(exportDB.getParsingOptionsRef()); + auto fileTypeInfo = getFileType(parsedOptions); + switch (fileTypeInfo.fileType) { + case FileType::CSV: + case FileType::PARQUET: + break; + default: + throw BinderException("Export database currently only supports csv and parquet files."); + } + auto exportSchemaOnly = schemaOnly(parsedOptions, exportDB); + if (!exportSchemaOnly && fileTypeInfo.fileType != FileType::CSV && parsedOptions.size() != 0) { + throw BinderException{"Only export to csv can have options."}; + } + auto exportData = + getExportInfo(*Catalog::Get(*clientContext), clientContext, this, fileTypeInfo); + auto boundFilePath = VirtualFileSystem::GetUnsafe(*clientContext) + ->expandPath(clientContext, exportDB.getFilePath()); + return std::make_unique(boundFilePath, fileTypeInfo, std::move(exportData), + std::move(parsedOptions), exportSchemaOnly); +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_extension.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_extension.cpp new file mode 100644 index 0000000000..1ba10b366c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_extension.cpp @@ -0,0 +1,79 @@ +#include "binder/binder.h" +#include "binder/bound_extension_statement.h" +#include "common/exception/binder.h" +#include "common/file_system/local_file_system.h" +#include "common/string_utils.h" +#include "extension/extension.h" +#include "parser/extension_statement.h" + +using namespace lbug::parser; + +namespace lbug { +namespace binder { + +static void bindInstallExtension(const ExtensionAuxInfo& auxInfo) { + if (!ExtensionUtils::isOfficialExtension(auxInfo.path)) { + throw common::BinderException( + common::stringFormat("{} is not an official extension.\nNon-official extensions " + "can be installed directly by: `LOAD EXTENSION [EXTENSION_PATH]`.", + auxInfo.path)); + } +} + +static void bindLoadExtension(main::ClientContext* context, const ExtensionAuxInfo& auxInfo) { + auto localFileSystem = common::LocalFileSystem(""); + if (ExtensionUtils::isOfficialExtension(auxInfo.path)) { + auto extensionName = common::StringUtils::getLower(auxInfo.path); + if 
(!localFileSystem.fileOrPathExists( + ExtensionUtils::getLocalPathForExtensionLib(context, extensionName))) { + throw common::BinderException( + common::stringFormat("Extension: {} is an official extension and has not been " + "installed.\nYou can install it by: install {}.", + extensionName, extensionName)); + } + return; + } + if (!localFileSystem.fileOrPathExists(auxInfo.path, nullptr /* clientContext */)) { + throw common::BinderException( + common::stringFormat("The extension {} is neither an official extension, nor does " + "the extension path: '{}' exists.", + auxInfo.path, auxInfo.path)); + } +} + +static void bindUninstallExtension(const ExtensionAuxInfo& auxInfo) { + if (!ExtensionUtils::isOfficialExtension(auxInfo.path)) { + throw common::BinderException( + common::stringFormat("The extension {} is not an official extension.\nOnly official " + "extensions can be uninstalled.", + auxInfo.path)); + } +} + +std::unique_ptr Binder::bindExtension(const Statement& statement) { +#ifdef __WASM__ + throw common::BinderException{"Extensions are not available in the WASM environment"}; +#endif + auto extensionStatement = statement.constPtrCast(); + auto auxInfo = extensionStatement->getAuxInfo(); + switch (auxInfo->action) { + case ExtensionAction::INSTALL: + bindInstallExtension(*auxInfo); + break; + case ExtensionAction::LOAD: + bindLoadExtension(clientContext, *auxInfo); + break; + case ExtensionAction::UNINSTALL: + bindUninstallExtension(*auxInfo); + break; + default: + KU_UNREACHABLE; + } + if (ExtensionUtils::isOfficialExtension(auxInfo->path)) { + common::StringUtils::toLower(auxInfo->path); + } + return std::make_unique(std::move(auxInfo)); +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_extension_clause.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_extension_clause.cpp new file mode 100644 index 0000000000..af937fc206 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_extension_clause.cpp @@ -0,0 +1,21 @@ +#include "binder/binder.h" +#include "extension/binder_extension.h" + +using namespace lbug::common; +using namespace lbug::parser; + +namespace lbug { +namespace binder { + +std::unique_ptr Binder::bindExtensionClause(const parser::Statement& statement) { + for (auto& binderExtension : binderExtensions) { + auto boundStatement = binderExtension->bind(statement); + if (boundStatement) { + return boundStatement; + } + } + KU_UNREACHABLE; +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_file_scan.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_file_scan.cpp new file mode 100644 index 0000000000..1627f3ee92 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_file_scan.cpp @@ -0,0 +1,312 @@ +#include "binder/binder.h" +#include "binder/bound_scan_source.h" +#include "binder/expression/literal_expression.h" +#include "binder/expression/parameter_expression.h" +#include "common/exception/binder.h" +#include "common/exception/copy.h" +#include "common/exception/message.h" +#include "common/file_system/local_file_system.h" +#include "common/file_system/virtual_file_system.h" +#include "common/string_format.h" +#include "common/string_utils.h" +#include "extension/extension_manager.h" +#include "function/table/bind_input.h" +#include "main/client_context.h" +#include "main/database_manager.h" +#include "parser/expression/parsed_function_expression.h" +#include "parser/scan_source.h" + +using 
namespace lbug::parser; +using namespace lbug::binder; +using namespace lbug::common; +using namespace lbug::function; +using namespace lbug::catalog; + +namespace lbug { +namespace binder { + +FileTypeInfo bindSingleFileType(const main::ClientContext* context, const std::string& filePath) { + std::filesystem::path fileName(filePath); + auto extension = VirtualFileSystem::GetUnsafe(*context)->getFileExtension(fileName); + return FileTypeInfo{FileTypeUtils::getFileTypeFromExtension(extension), + extension.substr(std::min(1, extension.length()))}; +} + +FileTypeInfo Binder::bindFileTypeInfo(const std::vector& filePaths) const { + auto expectedFileType = FileTypeInfo{FileType::UNKNOWN, "" /* fileTypeStr */}; + for (auto& filePath : filePaths) { + auto fileType = bindSingleFileType(clientContext, filePath); + expectedFileType = + (expectedFileType.fileType == FileType::UNKNOWN) ? fileType : expectedFileType; + if (fileType.fileType != expectedFileType.fileType) { + throw CopyException("Loading files with different types is not currently supported."); + } + } + return expectedFileType; +} + +std::vector Binder::bindFilePaths(const std::vector& filePaths) const { + std::vector boundFilePaths; + for (auto& filePath : filePaths) { + // This is a temporary workaround because we use duckdb to read from iceberg/delta/azure. + // When we read delta/iceberg/azure tables from s3/httpfs, we don't have the httpfs + // extension loaded meaning that we cannot handle remote paths. So we pass the file path to + // duckdb for validation when we bindFileScanSource. + const auto& loadedExtensions = + extension::ExtensionManager::Get(*clientContext)->getLoadedExtensions(); + const bool httpfsExtensionLoaded = + std::any_of(loadedExtensions.begin(), loadedExtensions.end(), + [](const auto& extension) { return extension.getExtensionName() == "HTTPFS"; }); + if (!httpfsExtensionLoaded && !LocalFileSystem::isLocalPath(filePath)) { + boundFilePaths.push_back(filePath); + continue; + } + auto globbedFilePaths = + VirtualFileSystem::GetUnsafe(*clientContext)->glob(clientContext, filePath); + if (globbedFilePaths.empty()) { + throw BinderException{ + stringFormat("No file found that matches the pattern: {}.", filePath)}; + } + for (auto& globbedPath : globbedFilePaths) { + boundFilePaths.push_back(globbedPath); + } + } + return boundFilePaths; +} + +case_insensitive_map_t Binder::bindParsingOptions(const options_t& parsingOptions) { + case_insensitive_map_t options; + for (auto& option : parsingOptions) { + auto name = option.first; + StringUtils::toUpper(name); + auto expr = expressionBinder.bindExpression(*option.second); + KU_ASSERT(expr->expressionType == ExpressionType::LITERAL); + auto literalExpr = ku_dynamic_cast(expr.get()); + options.insert({name, literalExpr->getValue()}); + } + return options; +} + +std::unique_ptr Binder::bindScanSource(const BaseScanSource* source, + const options_t& options, const std::vector& columnNames, + const std::vector& columnTypes) { + switch (source->type) { + case ScanSourceType::FILE: { + return bindFileScanSource(*source, options, columnNames, columnTypes); + } + case ScanSourceType::QUERY: { + return bindQueryScanSource(*source, options, columnNames, columnTypes); + } + case ScanSourceType::OBJECT: { + return bindObjectScanSource(*source, options, columnNames, columnTypes); + } + case ScanSourceType::TABLE_FUNC: { + return bindTableFuncScanSource(*source, options, columnNames, columnTypes); + } + case ScanSourceType::PARAM: { + return bindParameterScanSource(*source, options, 
columnNames, columnTypes); + } + default: + KU_UNREACHABLE; + } +} + +bool handleFileViaFunction(main::ClientContext* context, std::vector filePaths) { + bool handleFileViaFunction = false; + if (VirtualFileSystem::GetUnsafe(*context)->fileOrPathExists(filePaths[0], context)) { + handleFileViaFunction = + VirtualFileSystem::GetUnsafe(*context)->handleFileViaFunction(filePaths[0]); + } + return handleFileViaFunction; +} + +std::unique_ptr Binder::bindFileScanSource(const BaseScanSource& scanSource, + const options_t& options, const std::vector& columnNames, + const std::vector& columnTypes) { + auto fileSource = scanSource.constPtrCast(); + auto filePaths = bindFilePaths(fileSource->filePaths); + auto boundOptions = bindParsingOptions(options); + FileTypeInfo fileTypeInfo; + + if (boundOptions.contains(FileScanInfo::FILE_FORMAT_OPTION_NAME)) { + auto fileFormat = boundOptions.at(FileScanInfo::FILE_FORMAT_OPTION_NAME).toString(); + fileTypeInfo = FileTypeInfo{FileTypeUtils::fromString(fileFormat), fileFormat}; + } else { + fileTypeInfo = bindFileTypeInfo(filePaths); + } + // If we defined a certain FileType, we have to ensure the path is a file, not something else + // (e.g. an existed directory) + if (fileTypeInfo.fileType != FileType::UNKNOWN) { + for (const auto& filePath : filePaths) { + if (!LocalFileSystem::fileExists(filePath) && LocalFileSystem::isLocalPath(filePath)) { + throw BinderException{stringFormat("Provided path is not a file: {}.", filePath)}; + } + } + } + boundOptions.erase(FileScanInfo::FILE_FORMAT_OPTION_NAME); + // Bind file configuration + auto fileScanInfo = std::make_unique(std::move(fileTypeInfo), filePaths); + fileScanInfo->options = std::move(boundOptions); + TableFunction func; + if (handleFileViaFunction(clientContext, filePaths)) { + func = VirtualFileSystem::GetUnsafe(*clientContext)->getHandleFunction(filePaths[0]); + } else { + func = getScanFunction(fileScanInfo->fileTypeInfo, *fileScanInfo); + } + // Bind table function + auto bindInput = TableFuncBindInput(); + bindInput.addLiteralParam(Value::createValue(filePaths[0])); + auto extraInput = std::make_unique(); + extraInput->fileScanInfo = fileScanInfo->copy(); + extraInput->expectedColumnNames = columnNames; + extraInput->expectedColumnTypes = LogicalType::copy(columnTypes); + extraInput->tableFunction = &func; + bindInput.extraInput = std::move(extraInput); + bindInput.binder = this; + auto bindData = func.bindFunc(clientContext, &bindInput); + auto info = BoundTableScanInfo(func, std::move(bindData)); + return std::make_unique(ScanSourceType::FILE, std::move(info)); +} + +std::unique_ptr Binder::bindQueryScanSource(const BaseScanSource& scanSource, + const options_t& options, const std::vector& columnNames, + const std::vector&) { + auto querySource = scanSource.constPtrCast(); + auto boundStatement = bind(*querySource->statement); + auto columns = boundStatement->getStatementResult()->getColumns(); + if (columns.size() != columnNames.size()) { + throw BinderException(stringFormat("Query returns {} columns but {} columns were expected.", + columns.size(), columnNames.size())); + } + for (auto i = 0u; i < columns.size(); ++i) { + columns[i]->setAlias(columnNames[i]); + } + auto scanInfo = BoundQueryScanSourceInfo(bindParsingOptions(options)); + return std::make_unique(std::move(boundStatement), std::move(scanInfo)); +} + +static TableFunction getObjectScanFunc(const std::string& dbName, const std::string& tableName, + main::ClientContext* clientContext) { + // Bind external database table + auto 
attachedDB = main::DatabaseManager::Get(*clientContext)->getAttachedDatabase(dbName); + auto attachedCatalog = attachedDB->getCatalog(); + auto entry = attachedCatalog->getTableCatalogEntry( + transaction::Transaction::Get(*clientContext), tableName); + return entry->ptrCast()->getScanFunction(); +} + +BoundTableScanInfo bindTableScanSourceInfo(Binder& binder, TableFunction func, + const std::string& sourceName, std::unique_ptr bindData, + const std::vector& columnNames, const std::vector& columnTypes) { + expression_vector columns; + if (columnTypes.empty()) { + } else { + if (bindData->getNumColumns() != columnTypes.size()) { + throw BinderException(stringFormat("{} has {} columns but {} columns were expected.", + sourceName, bindData->getNumColumns(), columnTypes.size())); + } + for (auto i = 0u; i < bindData->getNumColumns(); ++i) { + auto column = + binder.createInvisibleVariable(columnNames[i], bindData->columns[i]->getDataType()); + binder.replaceExpressionInScope(bindData->columns[i]->toString(), columnNames[i], + column); + columns.push_back(column); + } + bindData->columns = columns; + } + return BoundTableScanInfo(func, std::move(bindData)); +} + +std::unique_ptr Binder::bindParameterScanSource( + const BaseScanSource& scanSource, const options_t& options, + const std::vector& columnNames, const std::vector& columnTypes) { + auto paramSource = scanSource.constPtrCast(); + auto paramExpr = expressionBinder.bindParameterExpression(*paramSource->paramExpression); + auto scanSourceValue = paramExpr->constCast().getValue(); + if (scanSourceValue.getDataType().getLogicalTypeID() != LogicalTypeID::POINTER) { + throw BinderException(stringFormat( + "Trying to scan from unsupported data type {}. The only parameter types that can be " + "scanned from are pandas/polars dataframes and pyarrow tables.", + scanSourceValue.getDataType().toString())); + } + TableFunction func; + std::unique_ptr bindData; + auto bindInput = TableFuncBindInput(); + bindInput.binder = this; + // Bind external object as table + auto replacementData = + clientContext->tryReplaceByHandle(scanSourceValue.getValue()); + func = replacementData->func; + auto replaceExtraInput = std::make_unique(); + replaceExtraInput->fileScanInfo.options = bindParsingOptions(options); + replacementData->bindInput.extraInput = std::move(replaceExtraInput); + replacementData->bindInput.binder = this; + bindData = func.bindFunc(clientContext, &replacementData->bindInput); + auto info = bindTableScanSourceInfo(*this, func, paramExpr->toString(), std::move(bindData), + columnNames, columnTypes); + return std::make_unique(ScanSourceType::OBJECT, std::move(info)); +} + +std::unique_ptr Binder::bindObjectScanSource(const BaseScanSource& scanSource, + const options_t& options, const std::vector& columnNames, + const std::vector& columnTypes) { + auto objectSource = scanSource.constPtrCast(); + TableFunction func; + std::unique_ptr bindData; + std::string objectName; + auto bindInput = TableFuncBindInput(); + bindInput.binder = this; + if (objectSource->objectNames.size() == 1) { + // Bind external object as table + objectName = objectSource->objectNames[0]; + auto replacementData = clientContext->tryReplaceByName(objectName); + if (replacementData != nullptr) { // Replace as python object + func = replacementData->func; + auto replaceExtraInput = std::make_unique(); + replaceExtraInput->fileScanInfo.options = bindParsingOptions(options); + replacementData->bindInput.extraInput = std::move(replaceExtraInput); + replacementData->bindInput.binder = 
this; + bindData = func.bindFunc(clientContext, &replacementData->bindInput); + } else if (main::DatabaseManager::Get(*clientContext)->hasDefaultDatabase()) { + auto dbName = main::DatabaseManager::Get(*clientContext)->getDefaultDatabase(); + func = getObjectScanFunc(dbName, objectSource->objectNames[0], clientContext); + bindData = func.bindFunc(clientContext, &bindInput); + } else { + throw BinderException(ExceptionMessage::variableNotInScope(objectName)); + } + } else if (objectSource->objectNames.size() == 2) { + // Bind external database table + objectName = objectSource->objectNames[0] + "." + objectSource->objectNames[1]; + func = getObjectScanFunc(objectSource->objectNames[0], objectSource->objectNames[1], + clientContext); + bindData = func.bindFunc(clientContext, &bindInput); + } else { + // LCOV_EXCL_START + throw BinderException(stringFormat("Cannot find object {}.", + StringUtils::join(objectSource->objectNames, ","))); + // LCOV_EXCL_STOP + } + auto info = bindTableScanSourceInfo(*this, func, objectName, std::move(bindData), columnNames, + columnTypes); + return std::make_unique(ScanSourceType::OBJECT, std::move(info)); +} + +std::unique_ptr Binder::bindTableFuncScanSource( + const BaseScanSource& scanSource, const options_t& options, + const std::vector& columnNames, const std::vector& columnTypes) { + if (!options.empty()) { + throw common::BinderException{"No option is supported when copying from table functions."}; + } + auto tableFuncScanSource = scanSource.constPtrCast(); + auto& parsedFuncExpression = + tableFuncScanSource->functionExpression->constCast(); + auto boundTableFunc = bindTableFunc(parsedFuncExpression.getFunctionName(), + *tableFuncScanSource->functionExpression, {} /* yieldVariables */); + auto& tableFunc = boundTableFunc.func; + auto info = bindTableScanSourceInfo(*this, tableFunc, tableFunc.name, + std::move(boundTableFunc.bindData), columnNames, columnTypes); + return std::make_unique(ScanSourceType::OBJECT, std::move(info)); +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_graph_pattern.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_graph_pattern.cpp new file mode 100644 index 0000000000..1ad1af5ee8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_graph_pattern.cpp @@ -0,0 +1,695 @@ +#include "binder/binder.h" +#include "binder/expression/expression_util.h" +#include "binder/expression/path_expression.h" +#include "binder/expression/property_expression.h" +#include "binder/expression_visitor.h" +#include "catalog/catalog.h" +#include "catalog/catalog_entry/node_table_catalog_entry.h" +#include "catalog/catalog_entry/rel_group_catalog_entry.h" +#include "common/enums/rel_direction.h" +#include "common/exception/binder.h" +#include "common/string_format.h" +#include "common/utils.h" +#include "function/cast/functions/cast_from_string_functions.h" +#include "function/gds/rec_joins.h" +#include "function/rewrite_function.h" +#include "function/schema/vector_node_rel_functions.h" +#include "main/client_context.h" +#include "transaction/transaction.h" + +using namespace lbug::common; +using namespace lbug::parser; +using namespace lbug::catalog; + +namespace lbug { +namespace binder { + +// A graph pattern contains node/rel and a set of key-value pairs associated with the variable. We +// bind node/rel as query graph and key-value pairs as a separate collection. This collection is +// interpreted in two different ways. 
+// - In MATCH clause, these are additional predicates to WHERE clause +// - In UPDATE clause, there are properties to set. +// We do not store key-value pairs in query graph primarily because we will merge key-value +// std::pairs with other predicates specified in WHERE clause. +BoundGraphPattern Binder::bindGraphPattern(const std::vector& graphPattern) { + auto queryGraphCollection = QueryGraphCollection(); + for (auto& patternElement : graphPattern) { + queryGraphCollection.addAndMergeQueryGraphIfConnected(bindPatternElement(patternElement)); + } + queryGraphCollection.finalize(); + auto boundPattern = BoundGraphPattern(); + boundPattern.queryGraphCollection = std::move(queryGraphCollection); + return boundPattern; +} + +// Grammar ensures pattern element is always connected and thus can be bound as a query graph. +QueryGraph Binder::bindPatternElement(const PatternElement& patternElement) { + auto queryGraph = QueryGraph(); + expression_vector nodeAndRels; + auto leftNode = bindQueryNode(*patternElement.getFirstNodePattern(), queryGraph); + nodeAndRels.push_back(leftNode); + for (auto i = 0u; i < patternElement.getNumPatternElementChains(); ++i) { + auto patternElementChain = patternElement.getPatternElementChain(i); + auto rightNode = bindQueryNode(*patternElementChain->getNodePattern(), queryGraph); + auto rel = + bindQueryRel(*patternElementChain->getRelPattern(), leftNode, rightNode, queryGraph); + nodeAndRels.push_back(rel); + nodeAndRels.push_back(rightNode); + leftNode = rightNode; + } + if (patternElement.hasPathName()) { + auto pathName = patternElement.getPathName(); + auto pathExpression = createPath(pathName, nodeAndRels); + addToScope(pathName, pathExpression); + } + return queryGraph; +} + +static LogicalType getRecursiveRelLogicalType(const LogicalType& nodeType, + const LogicalType& relType) { + auto nodesType = LogicalType::LIST(nodeType.copy()); + auto relsType = LogicalType::LIST(relType.copy()); + std::vector recursiveRelFields; + recursiveRelFields.emplace_back(InternalKeyword::NODES, std::move(nodesType)); + recursiveRelFields.emplace_back(InternalKeyword::RELS, std::move(relsType)); + return LogicalType::RECURSIVE_REL(std::move(recursiveRelFields)); +} + +static void extraFieldFromStructType(const LogicalType& structType, + std::unordered_set& set, std::vector& structFields) { + for (auto& field : StructType::getFields(structType)) { + if (!set.contains(field.getName())) { + set.insert(field.getName()); + structFields.emplace_back(field.getName(), field.getType().copy()); + } + } +} + +std::shared_ptr Binder::createPath(const std::string& pathName, + const expression_vector& children) { + std::unordered_set nodeFieldNameSet; + std::vector nodeFields; + std::unordered_set relFieldNameSet; + std::vector relFields; + for (auto& child : children) { + if (ExpressionUtil::isNodePattern(*child)) { + auto& node = child->constCast(); + extraFieldFromStructType(node.getDataType(), nodeFieldNameSet, nodeFields); + } else if (ExpressionUtil::isRelPattern(*child)) { + auto rel = ku_dynamic_cast(child.get()); + extraFieldFromStructType(rel->getDataType(), relFieldNameSet, relFields); + } else if (ExpressionUtil::isRecursiveRelPattern(*child)) { + auto recursiveRel = ku_dynamic_cast(child.get()); + auto recursiveInfo = recursiveRel->getRecursiveInfo(); + extraFieldFromStructType(recursiveInfo->node->getDataType(), nodeFieldNameSet, + nodeFields); + extraFieldFromStructType(recursiveInfo->rel->getDataType(), relFieldNameSet, relFields); + } else { + KU_UNREACHABLE; + } + } + 
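+ // The fields gathered above are the union of properties over every node, rel and
+ // recursive-rel pattern on the path, so a named path such as
+ //   MATCH p = (a:person)-[e:knows*1..2]->(b:person) RETURN p
+ // (labels illustrative) produces a RECURSIVE_REL value whose NODES/RELS lists expose
+ // all of the collected person and knows properties.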
auto nodeType = LogicalType::NODE(std::move(nodeFields)); + auto relType = LogicalType::REL(std::move(relFields)); + auto uniqueName = getUniqueExpressionName(pathName); + return std::make_shared(getRecursiveRelLogicalType(nodeType, relType), + uniqueName, pathName, std::move(nodeType), std::move(relType), children); +} + +static std::vector getPropertyNames(const std::vector& entries) { + std::vector result; + std::unordered_set propertyNamesSet; + for (auto& entry : entries) { + for (auto& property : entry->getProperties()) { + if (propertyNamesSet.contains(property.getName())) { + continue; + } + propertyNamesSet.insert(property.getName()); + result.push_back(property.getName()); + } + } + return result; +} + +static std::shared_ptr createPropertyExpression(const std::string& propertyName, + const std::string& uniqueVariableName, const std::string& rawVariableName, + const std::vector& entries) { + table_id_map_t infos; + std::vector dataTypes; + for (auto& entry : entries) { + bool exists = false; + if (entry->containsProperty(propertyName)) { + exists = true; + dataTypes.push_back(entry->getProperty(propertyName).getType().copy()); + } + // Bind isPrimaryKey + auto isPrimaryKey = false; + if (entry->getTableType() == TableType::NODE) { + auto nodeEntry = entry->constPtrCast(); + isPrimaryKey = nodeEntry->getPrimaryKeyName() == propertyName; + } + auto info = SingleLabelPropertyInfo(exists, isPrimaryKey); + infos.insert({entry->getTableID(), std::move(info)}); + } + LogicalType maxType = LogicalTypeUtils::combineTypes(dataTypes); + return std::make_shared(std::move(maxType), propertyName, + uniqueVariableName, rawVariableName, std::move(infos)); +} + +static void checkRelDirectionTypeAgainstStorageDirection(const RelExpression* rel) { + switch (rel->getDirectionType()) { + case RelDirectionType::SINGLE: + // Directed pattern is in the fwd direction + if (!containsValue(rel->getExtendDirections(), ExtendDirection::FWD)) { + throw BinderException(stringFormat("Querying table matched in rel pattern '{}' with " + "bwd-only storage direction isn't supported.", + rel->toString())); + } + break; + case RelDirectionType::BOTH: + if (rel->getExtendDirections().size() < NUM_REL_DIRECTIONS) { + throw BinderException( + stringFormat("Undirected rel pattern '{}' has at least one matched rel table with " + "storage type 'fwd' or 'bwd'. Undirected rel patterns are only " + "supported if every matched rel table has storage type 'both'.", + rel->toString())); + } + break; + default: + KU_UNREACHABLE; + } +} + +std::shared_ptr Binder::bindQueryRel(const RelPattern& relPattern, + const std::shared_ptr& leftNode, + const std::shared_ptr& rightNode, QueryGraph& queryGraph) { + auto parsedName = relPattern.getVariableName(); + if (scope.contains(parsedName)) { + auto prevVariable = scope.getExpression(parsedName); + auto expectedDataType = QueryRelTypeUtils::isRecursive(relPattern.getRelType()) ? 
+ LogicalTypeID::RECURSIVE_REL : + LogicalTypeID::REL; + ExpressionUtil::validateDataType(*prevVariable, expectedDataType); + throw BinderException("Bind relationship " + parsedName + + " to relationship with same name is not supported."); + } + auto entries = bindRelGroupEntries(relPattern.getTableNames()); + // bind src & dst node + RelDirectionType directionType = RelDirectionType::UNKNOWN; + std::shared_ptr srcNode; + std::shared_ptr dstNode; + switch (relPattern.getDirection()) { + case ArrowDirection::LEFT: { + srcNode = rightNode; + dstNode = leftNode; + directionType = RelDirectionType::SINGLE; + } break; + case ArrowDirection::RIGHT: { + srcNode = leftNode; + dstNode = rightNode; + directionType = RelDirectionType::SINGLE; + } break; + case ArrowDirection::BOTH: { + // For both direction, left and right will be written with the same label set. So either one + // being src will be correct. + srcNode = leftNode; + dstNode = rightNode; + directionType = RelDirectionType::BOTH; + } break; + default: + KU_UNREACHABLE; + } + // bind variable length + std::shared_ptr queryRel; + if (QueryRelTypeUtils::isRecursive(relPattern.getRelType())) { + queryRel = createRecursiveQueryRel(relPattern, entries, srcNode, dstNode, directionType); + } else { + queryRel = createNonRecursiveQueryRel(relPattern.getVariableName(), entries, srcNode, + dstNode, directionType); + for (auto& [propertyName, rhs] : relPattern.getPropertyKeyVals()) { + auto boundLhs = + expressionBinder.bindNodeOrRelPropertyExpression(*queryRel, propertyName); + auto boundRhs = expressionBinder.bindExpression(*rhs); + boundRhs = expressionBinder.implicitCastIfNecessary(boundRhs, boundLhs->dataType); + queryRel->addPropertyDataExpr(propertyName, std::move(boundRhs)); + } + } + queryRel->setLeftNode(leftNode); + queryRel->setRightNode(rightNode); + queryRel->setAlias(parsedName); + if (!parsedName.empty()) { + addToScope(parsedName, queryRel); + } + queryGraph.addQueryRel(queryRel); + checkRelDirectionTypeAgainstStorageDirection(queryRel.get()); + return queryRel; +} + +static std::vector getBaseNodeStructFields() { + std::vector fields; + fields.emplace_back(InternalKeyword::ID, LogicalType::INTERNAL_ID()); + fields.emplace_back(InternalKeyword::LABEL, LogicalType::STRING()); + return fields; +} + +static std::vector getBaseRelStructFields() { + std::vector fields; + fields.emplace_back(InternalKeyword::SRC, LogicalType::INTERNAL_ID()); + fields.emplace_back(InternalKeyword::DST, LogicalType::INTERNAL_ID()); + fields.emplace_back(InternalKeyword::LABEL, LogicalType::STRING()); + return fields; +} + +static std::shared_ptr construct(LogicalType type, + const std::string& propertyName, const Expression& child) { + KU_ASSERT(child.expressionType == ExpressionType::PATTERN); + auto& patternExpr = child.constCast(); + auto variableName = patternExpr.getVariableName(); + auto uniqueName = patternExpr.getUniqueName(); + // Assign an invalid property id for virtual property. 
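+ // Such virtual properties (e.g. the internal ID of a pattern or a recursive rel's
+ // LENGTH, both built through this helper) have no backing catalog column, so every
+ // per-table info below is created with exists = false.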
+ table_id_map_t infos; + for (auto& entry : patternExpr.getEntries()) { + infos.insert({entry->getTableID(), + SingleLabelPropertyInfo(false /* exists */, false /* isPrimaryKey */)}); + } + return std::make_unique(std::move(type), propertyName, uniqueName, + variableName, std::move(infos)); +} + +std::shared_ptr Binder::createNonRecursiveQueryRel(const std::string& parsedName, + const std::vector& entries, std::shared_ptr srcNode, + std::shared_ptr dstNode, RelDirectionType directionType) { + auto uniqueName = getUniqueExpressionName(parsedName); + // Bind properties + auto structFields = getBaseRelStructFields(); + std::vector> propertyExpressions; + if (entries.empty()) { + structFields.emplace_back(InternalKeyword::ID, LogicalType::INTERNAL_ID()); + } else { + for (auto& propertyName : getPropertyNames(entries)) { + auto property = createPropertyExpression(propertyName, uniqueName, parsedName, entries); + structFields.emplace_back(property->getPropertyName(), property->getDataType().copy()); + propertyExpressions.push_back(std::move(property)); + } + } + auto queryRel = std::make_shared(LogicalType::REL(std::move(structFields)), + uniqueName, parsedName, entries, std::move(srcNode), std::move(dstNode), directionType, + QueryRelType::NON_RECURSIVE); + queryRel->setAlias(parsedName); + if (entries.empty()) { + queryRel->addPropertyExpression( + construct(LogicalType::INTERNAL_ID(), InternalKeyword::ID, *queryRel)); + } else { + for (auto& property : propertyExpressions) { + queryRel->addPropertyExpression(property); + } + } + // Bind internal expressions. + if (directionType == RelDirectionType::BOTH) { + queryRel->setDirectionExpr(expressionBinder.createVariableExpression(LogicalType::BOOL(), + queryRel->getUniqueName() + InternalKeyword::DIRECTION)); + } + auto input = function::RewriteFunctionBindInput(clientContext, &expressionBinder, {queryRel}); + queryRel->setLabelExpression(function::LabelFunction::rewriteFunc(input)); + return queryRel; +} + +static void bindProjectionListAsStructField(const expression_vector& projectionList, + std::vector& fields) { + for (auto& expression : projectionList) { + if (expression->expressionType != ExpressionType::PROPERTY) { + throw BinderException(stringFormat("Unsupported projection item {} on recursive rel.", + expression->toString())); + } + auto& property = expression->constCast(); + fields.emplace_back(property.getPropertyName(), property.getDataType().copy()); + } +} + +static void checkWeightedShortestPathSupportedType(const LogicalType& type) { + switch (type.getLogicalTypeID()) { + case LogicalTypeID::INT8: + case LogicalTypeID::UINT8: + case LogicalTypeID::INT16: + case LogicalTypeID::UINT16: + case LogicalTypeID::INT32: + case LogicalTypeID::UINT32: + case LogicalTypeID::INT64: + case LogicalTypeID::UINT64: + case LogicalTypeID::DOUBLE: + case LogicalTypeID::FLOAT: + return; + default: + break; + } + throw BinderException(stringFormat( + "{} weight type is not supported for weighted shortest path.", type.toString())); +} + +std::shared_ptr Binder::createRecursiveQueryRel(const parser::RelPattern& relPattern, + const std::vector& entries, std::shared_ptr srcNode, + std::shared_ptr dstNode, RelDirectionType directionType) { + auto catalog = Catalog::Get(*clientContext); + auto transaction = transaction::Transaction::Get(*clientContext); + table_catalog_entry_set_t nodeEntrySet; + for (auto entry : entries) { + auto& relGroupEntry = entry->constCast(); + for (auto id : relGroupEntry.getSrcNodeTableIDSet()) { + 
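+ // Any src/dst node table of the matched rel group entries is a candidate for the
+ // recursive pattern's intermediate node (the n in a filter of the shape
+ // (r, n | WHERE ...) described further below).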
nodeEntrySet.insert(catalog->getTableCatalogEntry(transaction, id)); + } + for (auto id : relGroupEntry.getDstNodeTableIDSet()) { + nodeEntrySet.insert(catalog->getTableCatalogEntry(transaction, id)); + } + } + auto nodeEntries = std::vector{nodeEntrySet.begin(), nodeEntrySet.end()}; + auto recursivePatternInfo = relPattern.getRecursiveInfo(); + auto prevScope = saveScope(); + scope.clear(); + // Bind intermediate node. + auto node = createQueryNode(recursivePatternInfo->nodeName, nodeEntries); + addToScope(node->toString(), node); + auto nodeFields = getBaseNodeStructFields(); + auto nodeProjectionList = bindRecursivePatternNodeProjectionList(*recursivePatternInfo, *node); + bindProjectionListAsStructField(nodeProjectionList, nodeFields); + node->setDataType(LogicalType::NODE(std::move(nodeFields))); + auto nodeCopy = createQueryNode(recursivePatternInfo->nodeName, nodeEntries); + // Bind intermediate rel + auto rel = createNonRecursiveQueryRel(recursivePatternInfo->relName, entries, + nullptr /* srcNode */, nullptr /* dstNode */, directionType); + addToScope(rel->toString(), rel); + auto relProjectionList = bindRecursivePatternRelProjectionList(*recursivePatternInfo, *rel); + auto relFields = getBaseRelStructFields(); + relFields.emplace_back(InternalKeyword::ID, LogicalType::INTERNAL_ID()); + bindProjectionListAsStructField(relProjectionList, relFields); + rel->setDataType(LogicalType::REL(std::move(relFields))); + // Bind predicates in {}, e.g. [e* {date=1999-01-01}] + std::shared_ptr relPredicate = nullptr; + for (auto& [propertyName, rhs] : relPattern.getPropertyKeyVals()) { + auto boundLhs = expressionBinder.bindNodeOrRelPropertyExpression(*rel, propertyName); + auto boundRhs = expressionBinder.bindExpression(*rhs); + boundRhs = expressionBinder.implicitCastIfNecessary(boundRhs, boundLhs->dataType); + auto predicate = expressionBinder.createEqualityComparisonExpression(boundLhs, boundRhs); + relPredicate = expressionBinder.combineBooleanExpressions(ExpressionType::AND, relPredicate, + predicate); + } + // Bind predicates in (r, n | WHERE ) + bool emptyRecursivePattern = false; + std::shared_ptr nodePredicate = nullptr; + if (recursivePatternInfo->whereExpression != nullptr) { + expressionBinder.config.disableLabelFunctionLiteralRewrite = true; + auto wherePredicate = bindWhereExpression(*recursivePatternInfo->whereExpression); + expressionBinder.config.disableLabelFunctionLiteralRewrite = false; + for (auto& predicate : wherePredicate->splitOnAND()) { + auto collector = DependentVarNameCollector(); + collector.visit(predicate); + auto dependentVariableNames = collector.getVarNames(); + auto dependOnNode = dependentVariableNames.contains(node->getUniqueName()); + auto dependOnRel = dependentVariableNames.contains(rel->getUniqueName()); + if (dependOnNode && dependOnRel) { + throw BinderException( + stringFormat("Cannot evaluate {} because it depends on both {} and {}.", + predicate->toString(), node->toString(), rel->toString())); + } else if (dependOnNode) { + nodePredicate = expressionBinder.combineBooleanExpressions(ExpressionType::AND, + nodePredicate, predicate); + } else if (dependOnRel) { + relPredicate = expressionBinder.combineBooleanExpressions(ExpressionType::AND, + relPredicate, predicate); + } else { + if (!ExpressionUtil::isBoolLiteral(*predicate)) { + throw BinderException(stringFormat( + "Cannot evaluate {} because it does not depend on {} or {}. 
Treating it as " + "a node or relationship predicate is ambiguous.", + predicate->toString(), node->toString(), rel->toString())); + } + // If predicate is true literal, we ignore. + // If predicate is false literal, we mark this recursive relationship as empty + // and later in planner we replace it with EmptyResult. + if (!ExpressionUtil::getLiteralValue(*predicate)) { + emptyRecursivePattern = true; + } + } + } + } + // Bind rel + restoreScope(std::move(prevScope)); + auto parsedName = relPattern.getVariableName(); + auto prunedRelEntries = entries; + if (emptyRecursivePattern) { + prunedRelEntries.clear(); + } + auto queryRel = std::make_shared( + getRecursiveRelLogicalType(node->getDataType(), rel->getDataType()), + getUniqueExpressionName(parsedName), parsedName, prunedRelEntries, std::move(srcNode), + std::move(dstNode), directionType, relPattern.getRelType()); + // Bind graph entry. + auto graphEntry = graph::NativeGraphEntry(); + for (auto nodeEntry : node->getEntries()) { + graphEntry.nodeInfos.emplace_back(nodeEntry); + } + for (auto relEntry : rel->getEntries()) { + graphEntry.relInfos.emplace_back(relEntry, rel, relPredicate); + } + auto bindData = std::make_unique(graphEntry.copy()); + // Bind lower upper bound. + auto [lowerBound, upperBound] = bindVariableLengthRelBound(relPattern); + bindData->lowerBound = lowerBound; + bindData->upperBound = upperBound; + // Bind semantic. + bindData->semantic = QueryRelTypeUtils::getPathSemantic(queryRel->getRelType()); + // Bind path related expressions. + bindData->lengthExpr = construct(LogicalType::INT64(), InternalKeyword::LENGTH, *queryRel); + bindData->pathNodeIDsExpr = + createInvisibleVariable("pathNodeIDs", LogicalType::LIST(LogicalType::INTERNAL_ID())); + bindData->pathEdgeIDsExpr = + createInvisibleVariable("pathEdgeIDs", LogicalType::LIST(LogicalType::INTERNAL_ID())); + if (queryRel->getDirectionType() == RelDirectionType::BOTH) { + bindData->directionExpr = + createInvisibleVariable("pathEdgeDirections", LogicalType::LIST(LogicalType::BOOL())); + } + // Bind weighted path related expressions. 
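+ // Only applies to weighted shortest-path variants, e.g. something of the shape
+ //   MATCH (a)-[e* WSHORTEST(cost) 1..4]->(b)
+ // (syntax illustrative): the weight property named in the pattern must be numeric
+ // (checked just below) and a DOUBLE "<relName>_cost" output column is created for
+ // the computed path cost.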
+ if (QueryRelTypeUtils::isWeighted(queryRel->getRelType())) { + auto propertyExpr = expressionBinder.bindNodeOrRelPropertyExpression(*rel, + recursivePatternInfo->weightPropertyName); + checkWeightedShortestPathSupportedType(propertyExpr->getDataType()); + bindData->weightPropertyExpr = propertyExpr; + bindData->weightOutputExpr = + createInvisibleVariable(parsedName + "_cost", LogicalType::DOUBLE()); + } + + auto recursiveInfo = std::make_unique(); + recursiveInfo->node = node; + recursiveInfo->nodeCopy = nodeCopy; + recursiveInfo->rel = rel; + recursiveInfo->nodePredicate = std::move(nodePredicate); + recursiveInfo->relPredicate = std::move(relPredicate); + recursiveInfo->nodeProjectionList = std::move(nodeProjectionList); + recursiveInfo->relProjectionList = std::move(relProjectionList); + recursiveInfo->function = QueryRelTypeUtils::getFunction(queryRel->getRelType()); + recursiveInfo->bindData = std::move(bindData); + queryRel->setRecursiveInfo(std::move(recursiveInfo)); + return queryRel; +} + +expression_vector Binder::bindRecursivePatternNodeProjectionList( + const RecursiveRelPatternInfo& info, const NodeOrRelExpression& expr) { + expression_vector result; + if (!info.hasProjection) { + for (auto& expression : expr.getPropertyExpressions()) { + result.push_back(expression); + } + } else { + for (auto& expression : info.nodeProjectionList) { + result.push_back(expressionBinder.bindExpression(*expression)); + } + } + return result; +} + +expression_vector Binder::bindRecursivePatternRelProjectionList(const RecursiveRelPatternInfo& info, + const NodeOrRelExpression& expr) { + expression_vector result; + if (!info.hasProjection) { + for (auto& property : expr.getPropertyExpressions()) { + if (property->isInternalID()) { + continue; + } + result.push_back(property); + } + } else { + for (auto& expression : info.relProjectionList) { + result.push_back(expressionBinder.bindExpression(*expression)); + } + } + return result; +} + +std::pair Binder::bindVariableLengthRelBound(const RelPattern& relPattern) { + auto recursiveInfo = relPattern.getRecursiveInfo(); + uint32_t lowerBound = 0; + function::CastString::operation( + ku_string_t{recursiveInfo->lowerBound.c_str(), recursiveInfo->lowerBound.length()}, + lowerBound); + auto maxDepth = clientContext->getClientConfig()->varLengthMaxDepth; + auto upperBound = maxDepth; + if (!recursiveInfo->upperBound.empty()) { + function::CastString::operation( + ku_string_t{recursiveInfo->upperBound.c_str(), recursiveInfo->upperBound.length()}, + upperBound); + } + if (lowerBound > upperBound) { + throw BinderException(stringFormat("Lower bound of rel {} is greater than upperBound.", + relPattern.getVariableName())); + } + if (upperBound > maxDepth) { + throw BinderException(stringFormat("Upper bound of rel {} exceeds maximum: {}.", + relPattern.getVariableName(), std::to_string(maxDepth))); + } + if ((relPattern.getRelType() == QueryRelType::ALL_SHORTEST || + relPattern.getRelType() == QueryRelType::SHORTEST) && + lowerBound != 1) { + throw BinderException("Lower bound of shortest/all_shortest path must be 1."); + } + return std::make_pair(lowerBound, upperBound); +} + +std::shared_ptr Binder::bindQueryNode(const NodePattern& nodePattern, + QueryGraph& queryGraph) { + auto parsedName = nodePattern.getVariableName(); + std::shared_ptr queryNode; + if (scope.contains(parsedName)) { // bind to node in scope + auto prevVariable = scope.getExpression(parsedName); + if (!ExpressionUtil::isNodePattern(*prevVariable)) { + if 
(!scope.hasNodeReplacement(parsedName)) { + throw BinderException(stringFormat("Cannot bind {} as node pattern.", parsedName)); + } + queryNode = scope.getNodeReplacement(parsedName); + queryNode->addPropertyDataExpr(InternalKeyword::ID, queryNode->getInternalID()); + } else { + queryNode = std::static_pointer_cast(prevVariable); + // E.g. MATCH (a:person) MATCH (a:organisation) + // We bind to a single node with both labels + if (!nodePattern.getTableNames().empty()) { + auto otherNodeEntries = bindNodeTableEntries(nodePattern.getTableNames()); + queryNode->addEntries(otherNodeEntries); + } + } + } else { + queryNode = createQueryNode(nodePattern); + if (!parsedName.empty()) { + addToScope(parsedName, queryNode); + } + } + for (auto& [propertyName, rhs] : nodePattern.getPropertyKeyVals()) { + auto boundLhs = expressionBinder.bindNodeOrRelPropertyExpression(*queryNode, propertyName); + auto boundRhs = expressionBinder.bindExpression(*rhs); + boundRhs = expressionBinder.forceCast(boundRhs, boundLhs->dataType); + queryNode->addPropertyDataExpr(propertyName, std::move(boundRhs)); + } + queryGraph.addQueryNode(queryNode); + return queryNode; +} + +std::shared_ptr Binder::createQueryNode(const NodePattern& nodePattern) { + auto parsedName = nodePattern.getVariableName(); + return createQueryNode(parsedName, bindNodeTableEntries(nodePattern.getTableNames())); +} + +std::shared_ptr Binder::createQueryNode(const std::string& parsedName, + const std::vector& entries) { + auto uniqueName = getUniqueExpressionName(parsedName); + // Bind properties. + auto structFields = getBaseNodeStructFields(); + std::vector> propertyExpressions; + for (auto& propertyName : getPropertyNames(entries)) { + auto property = createPropertyExpression(propertyName, uniqueName, parsedName, entries); + structFields.emplace_back(property->getPropertyName(), property->getDataType().copy()); + propertyExpressions.push_back(std::move(property)); + } + auto queryNode = std::make_shared(LogicalType::NODE(std::move(structFields)), + uniqueName, parsedName, entries); + queryNode->setAlias(parsedName); + for (auto& property : propertyExpressions) { + queryNode->addPropertyExpression(property); + } + // Bind internal expressions + queryNode->setInternalID( + construct(LogicalType::INTERNAL_ID(), InternalKeyword::ID, *queryNode)); + auto input = function::RewriteFunctionBindInput(clientContext, &expressionBinder, {queryNode}); + queryNode->setLabelExpression(function::LabelFunction::rewriteFunc(input)); + return queryNode; +} + +static std::vector sortEntries(const table_catalog_entry_set_t& set) { + std::vector entries; + for (auto entry : set) { + entries.push_back(entry); + } + std::sort(entries.begin(), entries.end(), + [](const TableCatalogEntry* a, const TableCatalogEntry* b) { + return a->getTableID() < b->getTableID(); + }); + return entries; +} + +std::vector Binder::bindNodeTableEntries( + const std::vector& tableNames) const { + auto transaction = transaction::Transaction::Get(*clientContext); + auto catalog = Catalog::Get(*clientContext); + auto useInternal = clientContext->useInternalCatalogEntry(); + table_catalog_entry_set_t entrySet; + if (tableNames.empty()) { // Rewrite as all node tables in database. 
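+ // An unlabelled node pattern such as MATCH (n) RETURN n binds against every node
+ // table in the catalog (subject to the useInternal flag).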
+ for (auto entry : catalog->getNodeTableEntries(transaction, useInternal)) { + entrySet.insert(entry); + } + } else { + for (auto& name : tableNames) { + auto entry = bindNodeTableEntry(name); + if (entry->getType() != CatalogEntryType::NODE_TABLE_ENTRY) { + throw BinderException( + stringFormat("Cannot bind {} as a node pattern label.", entry->getName())); + } + entrySet.insert(entry); + } + } + return sortEntries(entrySet); +} + +TableCatalogEntry* Binder::bindNodeTableEntry(const std::string& name) const { + auto transaction = transaction::Transaction::Get(*clientContext); + auto catalog = Catalog::Get(*clientContext); + auto useInternal = clientContext->useInternalCatalogEntry(); + if (!catalog->containsTable(transaction, name, useInternal)) { + throw BinderException(stringFormat("Table {} does not exist.", name)); + } + return catalog->getTableCatalogEntry(transaction, name, useInternal); +} + +std::vector Binder::bindRelGroupEntries( + const std::vector& tableNames) const { + auto transaction = transaction::Transaction::Get(*clientContext); + auto catalog = Catalog::Get(*clientContext); + auto useInternal = clientContext->useInternalCatalogEntry(); + table_catalog_entry_set_t entrySet; + if (tableNames.empty()) { // Rewrite as all rel groups in database. + for (auto entry : catalog->getRelGroupEntries(transaction, useInternal)) { + entrySet.insert(entry); + } + } else { + for (auto& name : tableNames) { + if (catalog->containsTable(transaction, name)) { + auto entry = catalog->getTableCatalogEntry(transaction, name, useInternal); + if (entry->getType() != CatalogEntryType::REL_GROUP_ENTRY) { + throw BinderException(stringFormat( + "Cannot bind {} as a relationship pattern label.", entry->getName())); + } + entrySet.insert(entry); + } else { + throw BinderException(stringFormat("Table {} does not exist.", name)); + } + } + } + return sortEntries(entrySet); +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_import_database.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_import_database.cpp new file mode 100644 index 0000000000..24be90d5cf --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_import_database.cpp @@ -0,0 +1,144 @@ +#include "binder/binder.h" +#include "binder/bound_import_database.h" +#include "common/copier_config/csv_reader_config.h" +#include "common/exception/binder.h" +#include "common/file_system/virtual_file_system.h" +#include "main/client_context.h" +#include "parser/copy.h" +#include "parser/parser.h" +#include "parser/port_db.h" + +using namespace lbug::common; +using namespace lbug::parser; + +namespace lbug { +namespace binder { + +static std::string getQueryFromFile(VirtualFileSystem* vfs, const std::string& boundFilePath, + const std::string& fileName, main::ClientContext* context) { + auto filePath = vfs->joinPath(boundFilePath, fileName); + if (!vfs->fileOrPathExists(filePath, context)) { + if (fileName == PortDBConstants::COPY_FILE_NAME) { + return ""; + } + if (fileName == PortDBConstants::INDEX_FILE_NAME) { + return ""; + } + throw BinderException(stringFormat("File {} does not exist.", filePath)); + } + auto fileInfo = vfs->openFile(filePath, FileOpenFlags(FileFlags::READ_ONLY +#ifdef _WIN32 + | FileFlags::BINARY +#endif + )); + auto fsize = fileInfo->getFileSize(); + auto buffer = std::make_unique(fsize); + fileInfo->readFile(buffer.get(), fsize); + return std::string(buffer.get(), fsize); +} + +static std::string getColumnNamesToCopy(const CopyFrom& 
copyFrom) { + std::string columns = ""; + std::string delimiter = ""; + for (auto& column : copyFrom.getCopyColumnInfo().columnNames) { + columns += delimiter; + columns += "`" + column + "`"; + if (delimiter == "") { + delimiter = ","; + } + } + if (columns.empty()) { + return columns; + } + return stringFormat("({})", columns); +} + +static std::string getCopyFilePath(const std::string& boundFilePath, const std::string& filePath) { + if (filePath[0] == '/' || (std::isalpha(filePath[0]) && filePath[1] == ':')) { + // Note: + // Unix absolute path starts with '/' + // Windows absolute path starts with "[DiskID]://" + // This code path is for backward compatibility, we used to export the absolute path for + // csv files to copy.cypher files. + return filePath; + } + + auto path = boundFilePath + "/" + filePath; +#if defined(_WIN32) + // TODO(Ziyi): This is a temporary workaround because our parser requires input cypher queries + // to escape all special characters in string literal. E.g. The user input query is: 'IMPORT + // DATABASE 'C:\\db\\uw''. The parser removes any escaped characters and this function accepts + // the path parameter as 'C:\db\uw'. Then the ImportDatabase operator gives the file path to + // antlr4 parser directly without escaping any special characters in the path, which causes a + // parser exception. However, the parser exception is not thrown properly which leads to the + // undefined behaviour. + size_t pos = 0; + while ((pos = path.find('\\', pos)) != std::string::npos) { + path.replace(pos, 1, "\\\\"); + pos += 2; + } +#endif + return path; +} + +std::unique_ptr Binder::bindImportDatabaseClause(const Statement& statement) { + auto& importDB = statement.constCast(); + auto fs = VirtualFileSystem::GetUnsafe(*clientContext); + auto boundFilePath = fs->expandPath(clientContext, importDB.getFilePath()); + if (!fs->fileOrPathExists(boundFilePath, clientContext)) { + throw BinderException(stringFormat("Directory {} does not exist.", boundFilePath)); + } + std::string finalQueryStatements; + finalQueryStatements += + getQueryFromFile(fs, boundFilePath, PortDBConstants::SCHEMA_FILE_NAME, clientContext); + // replace the path in copy from statements with the bound path + auto copyQuery = + getQueryFromFile(fs, boundFilePath, PortDBConstants::COPY_FILE_NAME, clientContext); + if (!copyQuery.empty()) { + auto parsedStatements = Parser::parseQuery(copyQuery); + for (auto& parsedStatement : parsedStatements) { + KU_ASSERT(parsedStatement->getStatementType() == StatementType::COPY_FROM); + auto& copyFromStatement = parsedStatement->constCast(); + KU_ASSERT(copyFromStatement.getSource()->type == ScanSourceType::FILE); + auto filePaths = + copyFromStatement.getSource()->constPtrCast()->filePaths; + KU_ASSERT(filePaths.size() == 1); + auto fileTypeInfo = bindFileTypeInfo(filePaths); + std::string query; + auto copyFilePath = getCopyFilePath(boundFilePath, filePaths[0]); + auto columnNames = getColumnNamesToCopy(copyFromStatement); + auto parsingOptions = bindParsingOptions(copyFromStatement.getParsingOptions()); + std::unordered_map copyFromOptions; + if (parsingOptions.contains(CopyConstants::FROM_OPTION_NAME)) { + KU_ASSERT(parsingOptions.contains(CopyConstants::TO_OPTION_NAME)); + copyFromOptions[CopyConstants::FROM_OPTION_NAME] = stringFormat("'{}'", + parsingOptions.at(CopyConstants::FROM_OPTION_NAME).getValue()); + copyFromOptions[CopyConstants::TO_OPTION_NAME] = stringFormat("'{}'", + parsingOptions.at(CopyConstants::TO_OPTION_NAME).getValue()); + 
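+ // Keep the FROM/TO pair values aside and strip them from the parsing options below, so they are not passed to the CSV reader config; they are re-added to the generated COPY statement's option map afterwards.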
parsingOptions.erase(CopyConstants::FROM_OPTION_NAME); + parsingOptions.erase(CopyConstants::TO_OPTION_NAME); + } + if (fileTypeInfo.fileType == FileType::CSV) { + auto csvConfig = CSVReaderConfig::construct(parsingOptions); + csvConfig.option.autoDetection = false; + auto optionsMap = csvConfig.option.toOptionsMap(csvConfig.parallel); + if (!copyFromOptions.empty()) { + optionsMap.insert(copyFromOptions.begin(), copyFromOptions.end()); + } + query = + stringFormat("COPY `{}` {} FROM \"{}\" {};", copyFromStatement.getTableName(), + columnNames, copyFilePath, CSVOption::toCypher(optionsMap)); + } else { + query = + stringFormat("COPY `{}` {} FROM \"{}\" {};", copyFromStatement.getTableName(), + columnNames, copyFilePath, CSVOption::toCypher(copyFromOptions)); + } + finalQueryStatements += query; + } + } + return std::make_unique(boundFilePath, finalQueryStatements, + getQueryFromFile(fs, boundFilePath, PortDBConstants::INDEX_FILE_NAME, clientContext)); +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_projection_clause.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_projection_clause.cpp new file mode 100644 index 0000000000..31266f0849 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_projection_clause.cpp @@ -0,0 +1,308 @@ +#include "binder/binder.h" +#include "binder/expression/expression_util.h" +#include "binder/expression/lambda_expression.h" +#include "binder/expression_visitor.h" +#include "binder/query/return_with_clause/bound_return_clause.h" +#include "binder/query/return_with_clause/bound_with_clause.h" +#include "common/exception/binder.h" +#include "parser/expression/parsed_property_expression.h" +#include "parser/query/return_with_clause/with_clause.h" + +using namespace lbug::common; +using namespace lbug::parser; + +namespace lbug { +namespace binder { + +void validateColumnNamesAreUnique(const std::vector& columnNames) { + auto existColumnNames = std::unordered_set(); + for (auto& name : columnNames) { + if (existColumnNames.contains(name)) { + throw BinderException( + "Multiple result columns with the same name " + name + " are not supported."); + } + existColumnNames.insert(name); + } +} + +std::vector getColumnNames(const expression_vector& exprs, + const std::vector& aliases) { + std::vector columnNames; + for (auto i = 0u; i < exprs.size(); ++i) { + if (aliases[i].empty()) { + columnNames.push_back(exprs[i]->toString()); + } else { + columnNames.push_back(aliases[i]); + } + } + return columnNames; +} + +static void validateOrderByFollowedBySkipOrLimitInWithClause( + const BoundProjectionBody& boundProjectionBody) { + auto hasSkipOrLimit = boundProjectionBody.hasSkip() || boundProjectionBody.hasLimit(); + if (boundProjectionBody.hasOrderByExpressions() && !hasSkipOrLimit) { + throw BinderException("In WITH clause, ORDER BY must be followed by SKIP or LIMIT."); + } +} + +BoundWithClause Binder::bindWithClause(const WithClause& withClause) { + auto projectionBody = withClause.getProjectionBody(); + auto [projectionExprs, aliases] = bindProjectionList(*projectionBody); + // Check all expressions are aliased + for (auto& alias : aliases) { + if (alias.empty()) { + throw BinderException("Expression in WITH must be aliased (use AS)."); + } + } + auto columnNames = getColumnNames(projectionExprs, aliases); + validateColumnNamesAreUnique(columnNames); + auto boundProjectionBody = bindProjectionBody(*projectionBody, projectionExprs, aliases); + 
validateOrderByFollowedBySkipOrLimitInWithClause(boundProjectionBody); + // Update scope + scope.clear(); + for (auto i = 0u; i < projectionExprs.size(); ++i) { + addToScope(aliases[i], projectionExprs[i]); + } + auto boundWithClause = BoundWithClause(std::move(boundProjectionBody)); + if (withClause.hasWhereExpression()) { + boundWithClause.setWhereExpression(bindWhereExpression(*withClause.getWhereExpression())); + } + return boundWithClause; +} + +BoundReturnClause Binder::bindReturnClause(const ReturnClause& returnClause) { + auto projectionBody = returnClause.getProjectionBody(); + auto [projectionExprs, aliases] = bindProjectionList(*projectionBody); + auto columnNames = getColumnNames(projectionExprs, aliases); + auto boundProjectionBody = bindProjectionBody(*projectionBody, projectionExprs, aliases); + auto statementResult = BoundStatementResult(); + KU_ASSERT(columnNames.size() == projectionExprs.size()); + for (auto i = 0u; i < columnNames.size(); ++i) { + statementResult.addColumn(columnNames[i], projectionExprs[i]); + } + return BoundReturnClause(std::move(boundProjectionBody), std::move(statementResult)); +} + +static expression_vector getAggregateExpressions(const std::shared_ptr& expression, + const BinderScope& scope) { + expression_vector result; + if (expression->hasAlias() && scope.contains(expression->getAlias())) { + return result; + } + if (expression->expressionType == ExpressionType::AGGREGATE_FUNCTION) { + result.push_back(expression); + return result; + } + for (auto& child : ExpressionChildrenCollector::collectChildren(*expression)) { + for (auto& expr : getAggregateExpressions(child, scope)) { + result.push_back(expr); + } + } + return result; +} + +std::pair> Binder::bindProjectionList( + const ProjectionBody& projectionBody) { + expression_vector projectionExprs; + std::vector aliases; + for (auto& parsedExpr : projectionBody.getProjectionExpressions()) { + if (parsedExpr->getExpressionType() == ExpressionType::STAR) { + // Rewrite star expression as all expression in scope. + if (scope.empty()) { + throw BinderException( + "RETURN or WITH * is not allowed when there are no variables in scope."); + } + for (auto& expr : scope.getExpressions()) { + projectionExprs.push_back(expr); + aliases.push_back(expr->getAlias()); + } + } else if (parsedExpr->getExpressionType() == ExpressionType::PROPERTY) { + auto& propExpr = parsedExpr->constCast(); + if (propExpr.isStar()) { + // Rewrite property star expression + for (auto& expr : expressionBinder.bindPropertyStarExpression(*parsedExpr)) { + projectionExprs.push_back(expr); + aliases.push_back(""); + } + } else { + auto expr = expressionBinder.bindExpression(*parsedExpr); + projectionExprs.push_back(expr); + aliases.push_back(parsedExpr->getAlias()); + } + } else { + auto expr = expressionBinder.bindExpression(*parsedExpr); + projectionExprs.push_back(expr); + aliases.push_back(parsedExpr->hasAlias() ? 
parsedExpr->getAlias() : expr->getAlias()); + } + } + return {projectionExprs, aliases}; +} + +class NestedAggCollector final : public ExpressionVisitor { +public: + expression_vector exprs; + +protected: + void visitAggFunctionExpr(std::shared_ptr expr) override { exprs.push_back(expr); } + void visitChildren(const Expression& expr) override { + switch (expr.expressionType) { + case ExpressionType::CASE_ELSE: { + visitCaseExprChildren(expr); + } break; + case ExpressionType::LAMBDA: { + auto& lambda = expr.constCast(); + visit(lambda.getFunctionExpr()); + } break; + case ExpressionType::AGGREGATE_FUNCTION: { + // We do not traverse the child or aggregate because nested agg is validated recursively + // e.g. WITH SUM(1) AS x WITH SUM(x) AS y RETURN SUM(y) + // when validating SUM(y) we only need to check y and not x. + } break; + default: { + for (auto& child : expr.getChildren()) { + visit(child); + } + } + } + } +}; + +static void validateNestedAggregate(const Expression& expr, const BinderScope& scope) { + KU_ASSERT(expr.expressionType == ExpressionType::AGGREGATE_FUNCTION); + if (expr.getNumChildren() == 0) { // Skip COUNT(*) + return; + } + auto collector = NestedAggCollector(); + collector.visit(expr.getChild(0)); + for (auto& childAgg : collector.exprs) { + if (!scope.contains(childAgg->getAlias())) { + throw BinderException( + stringFormat("Expression {} contains nested aggregation.", expr.toString())); + } + } +} + +BoundProjectionBody Binder::bindProjectionBody(const parser::ProjectionBody& projectionBody, + const expression_vector& projectionExprs, const std::vector& aliases) { + expression_vector groupByExprs; + expression_vector aggregateExprs; + KU_ASSERT(projectionExprs.size() == aliases.size()); + for (auto i = 0u; i < projectionExprs.size(); ++i) { + auto expr = projectionExprs[i]; + auto aggExprs = getAggregateExpressions(expr, scope); + if (!aggExprs.empty()) { + for (auto& agg : aggExprs) { + aggregateExprs.push_back(agg); + } + } else { + groupByExprs.push_back(expr); + } + expr->setAlias(aliases[i]); + } + + auto boundProjectionBody = BoundProjectionBody(projectionBody.getIsDistinct()); + boundProjectionBody.setProjectionExpressions(projectionExprs); + + if (!aggregateExprs.empty()) { + for (auto& expr : aggregateExprs) { + validateNestedAggregate(*expr, scope); + } + if (!groupByExprs.empty()) { + // TODO(Xiyang): we can remove augment group by. But make sure we test sufficient + // including edge case and bug before release. + expression_vector augmentedGroupByExpressions = groupByExprs; + for (auto& expression : groupByExprs) { + if (ExpressionUtil::isNodePattern(*expression)) { + auto& node = expression->constCast(); + augmentedGroupByExpressions.push_back(node.getInternalID()); + } else if (ExpressionUtil::isRelPattern(*expression)) { + auto& rel = expression->constCast(); + augmentedGroupByExpressions.push_back(rel.getInternalID()); + } + } + boundProjectionBody.setGroupByExpressions(std::move(augmentedGroupByExpressions)); + } + boundProjectionBody.setAggregateExpressions(std::move(aggregateExprs)); + } + // Bind order by + if (projectionBody.hasOrderByExpressions()) { + // Cypher rule of ORDER BY expression scope: if projection contains aggregation, only + // expressions in projection are available. 
Otherwise, expressions before projection are + // also available + expression_vector orderByExprs; + if (boundProjectionBody.hasAggregateExpressions() || boundProjectionBody.isDistinct()) { + scope.clear(); + KU_ASSERT(projectionBody.getProjectionExpressions().size() == projectionExprs.size()); + std::vector tmpAliases; + for (auto& expr : projectionBody.getProjectionExpressions()) { + tmpAliases.push_back(expr->hasAlias() ? expr->getAlias() : expr->toString()); + } + addToScope(tmpAliases, projectionExprs); + expressionBinder.config.bindOrderByAfterAggregate = true; + orderByExprs = bindOrderByExpressions(projectionBody.getOrderByExpressions()); + expressionBinder.config.bindOrderByAfterAggregate = false; + } else { + addToScope(aliases, projectionExprs); + orderByExprs = bindOrderByExpressions(projectionBody.getOrderByExpressions()); + } + boundProjectionBody.setOrderByExpressions(std::move(orderByExprs), + projectionBody.getSortOrders()); + } + // Bind skip + if (projectionBody.hasSkipExpression()) { + boundProjectionBody.setSkipNumber( + bindSkipLimitExpression(*projectionBody.getSkipExpression())); + } + // Bind limit + if (projectionBody.hasLimitExpression()) { + boundProjectionBody.setLimitNumber( + bindSkipLimitExpression(*projectionBody.getLimitExpression())); + } + return boundProjectionBody; +} + +static bool isOrderByKeyTypeSupported(const LogicalType& dataType) { + switch (dataType.getLogicalTypeID()) { + case LogicalTypeID::NODE: + case LogicalTypeID::REL: + case LogicalTypeID::RECURSIVE_REL: + case LogicalTypeID::INTERNAL_ID: + case LogicalTypeID::LIST: + case LogicalTypeID::ARRAY: + case LogicalTypeID::STRUCT: + case LogicalTypeID::MAP: + case LogicalTypeID::UNION: + case LogicalTypeID::POINTER: + return false; + default: + return true; + } +} + +expression_vector Binder::bindOrderByExpressions( + const std::vector>& parsedExprs) { + expression_vector exprs; + for (auto& parsedExpr : parsedExprs) { + auto expr = expressionBinder.bindExpression(*parsedExpr); + if (!isOrderByKeyTypeSupported(expr->dataType)) { + throw BinderException(stringFormat("Cannot order by {}. 
Order by {} is not supported.", + expr->toString(), expr->dataType.toString())); + } + exprs.push_back(std::move(expr)); + } + return exprs; +} + +std::shared_ptr Binder::bindSkipLimitExpression(const ParsedExpression& expression) { + auto boundExpression = expressionBinder.bindExpression(expression); + if (boundExpression->expressionType != ExpressionType::LITERAL && + boundExpression->expressionType != ExpressionType::PARAMETER) { + throw BinderException( + "The number of rows to skip/limit must be a parameter/literal expression."); + } + return boundExpression; +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_query.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_query.cpp new file mode 100644 index 0000000000..45a305aa3d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_query.cpp @@ -0,0 +1,106 @@ +#include "binder/binder.h" +#include "binder/expression/expression_util.h" +#include "binder/query/return_with_clause/bound_return_clause.h" +#include "binder/query/return_with_clause/bound_with_clause.h" +#include "common/exception/binder.h" +#include "parser/query/regular_query.h" + +using namespace lbug::common; +using namespace lbug::parser; + +namespace lbug { +namespace binder { + +void validateUnionColumnsOfTheSameType( + const std::vector& normalizedSingleQueries) { + if (normalizedSingleQueries.size() <= 1) { + return; + } + auto columns = normalizedSingleQueries[0].getStatementResult()->getColumns(); + for (auto i = 1u; i < normalizedSingleQueries.size(); i++) { + auto otherColumns = normalizedSingleQueries[i].getStatementResult()->getColumns(); + if (columns.size() != otherColumns.size()) { + throw BinderException("The number of columns to union/union all must be the same."); + } + // Check whether the dataTypes in union expressions are exactly the same in each single + // query. + for (auto j = 0u; j < columns.size(); j++) { + ExpressionUtil::validateDataType(*otherColumns[j], columns[j]->getDataType()); + } + } +} + +void validateIsAllUnionOrUnionAll(const BoundRegularQuery& regularQuery) { + auto unionAllExpressionCounter = 0u; + for (auto i = 0u; i < regularQuery.getNumSingleQueries() - 1; i++) { + unionAllExpressionCounter += regularQuery.getIsUnionAll(i); + } + if ((0 < unionAllExpressionCounter) && + (unionAllExpressionCounter < regularQuery.getNumSingleQueries() - 1)) { + throw BinderException("Union and union all can not be used together."); + } +} + +std::unique_ptr Binder::bindQuery(const Statement& statement) { + auto& regularQuery = statement.constCast(); + std::vector normalizedSingleQueries; + for (auto i = 0u; i < regularQuery.getNumSingleQueries(); i++) { + // Don't clear scope within bindSingleQuery() yet because it is also used for subquery + // binding. 
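+ // Clearing here, between top-level single queries, keeps bindSingleQuery() reusable for subquery binding, where the outer scope must stay visible.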
+ scope.clear(); + normalizedSingleQueries.push_back(bindSingleQuery(*regularQuery.getSingleQuery(i))); + } + validateUnionColumnsOfTheSameType(normalizedSingleQueries); + KU_ASSERT(!normalizedSingleQueries.empty()); + auto boundRegularQuery = std::make_unique(regularQuery.getIsUnionAll(), + normalizedSingleQueries[0].getStatementResult()->copy()); + for (auto& normalizedSingleQuery : normalizedSingleQueries) { + boundRegularQuery->addSingleQuery(std::move(normalizedSingleQuery)); + } + validateIsAllUnionOrUnionAll(*boundRegularQuery); + return boundRegularQuery; +} + +NormalizedSingleQuery Binder::bindSingleQuery(const SingleQuery& singleQuery) { + auto normalizedSingleQuery = NormalizedSingleQuery(); + for (auto i = 0u; i < singleQuery.getNumQueryParts(); ++i) { + normalizedSingleQuery.appendQueryPart(bindQueryPart(*singleQuery.getQueryPart(i))); + } + auto lastQueryPart = NormalizedQueryPart(); + for (auto i = 0u; i < singleQuery.getNumReadingClauses(); i++) { + lastQueryPart.addReadingClause(bindReadingClause(*singleQuery.getReadingClause(i))); + } + for (auto i = 0u; i < singleQuery.getNumUpdatingClauses(); ++i) { + lastQueryPart.addUpdatingClause(bindUpdatingClause(*singleQuery.getUpdatingClause(i))); + } + auto statementResult = BoundStatementResult(); + if (singleQuery.hasReturnClause()) { + auto boundReturnClause = bindReturnClause(*singleQuery.getReturnClause()); + lastQueryPart.setProjectionBody(boundReturnClause.getProjectionBody()->copy()); + statementResult = boundReturnClause.getStatementResult()->copy(); + } else { + statementResult = BoundStatementResult::createEmptyResult(); + } + normalizedSingleQuery.appendQueryPart(std::move(lastQueryPart)); + normalizedSingleQuery.setStatementResult(std::move(statementResult)); + return normalizedSingleQuery; +} + +NormalizedQueryPart Binder::bindQueryPart(const QueryPart& queryPart) { + auto normalizedQueryPart = NormalizedQueryPart(); + for (auto i = 0u; i < queryPart.getNumReadingClauses(); i++) { + normalizedQueryPart.addReadingClause(bindReadingClause(*queryPart.getReadingClause(i))); + } + for (auto i = 0u; i < queryPart.getNumUpdatingClauses(); ++i) { + normalizedQueryPart.addUpdatingClause(bindUpdatingClause(*queryPart.getUpdatingClause(i))); + } + auto boundWithClause = bindWithClause(*queryPart.getWithClause()); + normalizedQueryPart.setProjectionBody(boundWithClause.getProjectionBody()->copy()); + if (boundWithClause.hasWhereExpression()) { + normalizedQueryPart.setProjectionBodyPredicate(boundWithClause.getWhereExpression()); + } + return normalizedQueryPart; +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_reading_clause.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_reading_clause.cpp new file mode 100644 index 0000000000..d22010e40d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_reading_clause.cpp @@ -0,0 +1,30 @@ +#include "binder/binder.h" +#include "parser/query/reading_clause/reading_clause.h" + +using namespace lbug::common; +using namespace lbug::parser; + +namespace lbug { +namespace binder { + +std::unique_ptr Binder::bindReadingClause(const ReadingClause& readingClause) { + switch (readingClause.getClauseType()) { + case ClauseType::MATCH: { + return bindMatchClause(readingClause); + } + case ClauseType::UNWIND: { + return bindUnwindClause(readingClause); + } + case ClauseType::IN_QUERY_CALL: { + return bindInQueryCall(readingClause); + } + case ClauseType::LOAD_FROM: { + return bindLoadFrom(readingClause); + 
} + default: + KU_UNREACHABLE; + } +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_standalone_call.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_standalone_call.cpp new file mode 100644 index 0000000000..c006540530 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_standalone_call.cpp @@ -0,0 +1,44 @@ +#include "binder/binder.h" +#include "binder/bound_standalone_call.h" +#include "binder/expression/expression_util.h" +#include "binder/expression_visitor.h" +#include "common/cast.h" +#include "common/exception/binder.h" +#include "main/client_context.h" +#include "main/db_config.h" +#include "parser/standalone_call.h" + +using namespace lbug::common; + +namespace lbug { +namespace binder { + +std::unique_ptr Binder::bindStandaloneCall(const parser::Statement& statement) { + auto& callStatement = ku_dynamic_cast(statement); + const main::Option* option = main::DBConfig::getOptionByName(callStatement.getOptionName()); + if (option == nullptr) { + option = clientContext->getExtensionOption(callStatement.getOptionName()); + } + if (option == nullptr) { + throw BinderException{"Invalid option name: " + callStatement.getOptionName() + "."}; + } + auto optionValue = expressionBinder.bindExpression(*callStatement.getOptionValue()); + ExpressionUtil::validateExpressionType(*optionValue, ExpressionType::LITERAL); + if (LogicalTypeUtils::isFloatingPoint(optionValue->dataType.getLogicalTypeID()) && + LogicalTypeUtils::isIntegral(LogicalType(option->parameterType))) { + throw BinderException{stringFormat( + "Expression {} has data type {} but expected {}. Implicit cast is not supported.", + optionValue->toString(), + LogicalTypeUtils::toString(optionValue->dataType.getLogicalTypeID()), + LogicalTypeUtils::toString(option->parameterType))}; + } + optionValue = + expressionBinder.implicitCastIfNecessary(optionValue, LogicalType(option->parameterType)); + if (ConstantExpressionVisitor::needFold(*optionValue)) { + optionValue = expressionBinder.foldExpression(optionValue); + } + return std::make_unique(option, std::move(optionValue)); +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_standalone_call_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_standalone_call_function.cpp new file mode 100644 index 0000000000..40ed05f28b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_standalone_call_function.cpp @@ -0,0 +1,35 @@ +#include "binder/binder.h" +#include "binder/bound_standalone_call_function.h" +#include "catalog/catalog.h" +#include "common/exception/binder.h" +#include "main/client_context.h" +#include "parser/expression/parsed_function_expression.h" +#include "parser/standalone_call_function.h" +#include "transaction/transaction.h" + +using namespace lbug::common; + +namespace lbug { +namespace binder { + +std::unique_ptr Binder::bindStandaloneCallFunction( + const parser::Statement& statement) { + auto& callStatement = statement.constCast(); + auto& funcExpr = + callStatement.getFunctionExpression()->constCast(); + auto funcName = funcExpr.getFunctionName(); + auto catalog = catalog::Catalog::Get(*clientContext); + auto transaction = transaction::Transaction::Get(*clientContext); + auto entry = + catalog->getFunctionEntry(transaction, funcName, clientContext->useInternalCatalogEntry()); + KU_ASSERT(entry); + if (entry->getType() != catalog::CatalogEntryType::STANDALONE_TABLE_FUNCTION_ENTRY) { + 
throw BinderException( + "Only standalone table functions can be called without return statement."); + } + auto boundTableFunction = bindTableFunc(funcName, funcExpr, {} /* yieldVariables */); + return std::make_unique(std::move(boundTableFunction)); +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_table_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_table_function.cpp new file mode 100644 index 0000000000..2277c1b6db --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_table_function.cpp @@ -0,0 +1,77 @@ +#include "binder/binder.h" +#include "binder/bound_table_scan_info.h" +#include "binder/expression/expression_util.h" +#include "binder/expression/literal_expression.h" +#include "catalog/catalog.h" +#include "function/built_in_function_utils.h" +#include "main/client_context.h" +#include "transaction/transaction.h" + +using namespace lbug::common; +using namespace lbug::function; + +namespace lbug { +namespace binder { + +BoundTableScanInfo Binder::bindTableFunc(const std::string& tableFuncName, + const parser::ParsedExpression& expr, std::vector yieldVariables) { + auto catalog = catalog::Catalog::Get(*clientContext); + auto transaction = transaction::Transaction::Get(*clientContext); + auto entry = catalog->getFunctionEntry(transaction, tableFuncName, + clientContext->useInternalCatalogEntry()); + expression_vector positionalParams; + std::vector positionalParamTypes; + optional_params_t optionalParams; + expression_vector optionalParamsLegacy; + for (auto i = 0u; i < expr.getNumChildren(); i++) { + auto& childExpr = *expr.getChild(i); + auto param = expressionBinder.bindExpression(childExpr); + ExpressionUtil::validateExpressionType(*param, + {ExpressionType::LITERAL, ExpressionType::PARAMETER, ExpressionType::PATTERN}); + if (!childExpr.hasAlias()) { + positionalParams.push_back(param); + positionalParamTypes.push_back(param->getDataType().copy()); + } else { + if (param->expressionType == ExpressionType::LITERAL) { + auto literalExpr = param->constPtrCast(); + optionalParams.emplace(childExpr.getAlias(), literalExpr->getValue()); + } + param->setAlias(expr.getChild(i)->getAlias()); + optionalParamsLegacy.push_back(param); + } + } + auto func = BuiltInFunctionsUtils::matchFunction(tableFuncName, positionalParamTypes, + entry->ptrCast()); + auto tableFunc = func->constPtrCast(); + std::vector inputTypes; + if (tableFunc->inferInputTypes) { + // For functions which take in nested data types, we have to use the input parameters to + // detect the input types. (E.g. query_hnsw_index takes in an ARRAY which needs the user + // input parameters to decide the array dimension). + inputTypes = tableFunc->inferInputTypes(positionalParams); + } else { + // For functions which don't have nested type parameters, we can simply use the types + // declared in the function signature. 
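+ // The resulting inputTypes are used below to implicitly cast and constant-fold literal positional parameters; parameters declared as ANY are left as bound.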
+ for (auto i = 0u; i < tableFunc->parameterTypeIDs.size(); i++) { + inputTypes.push_back(LogicalType(tableFunc->parameterTypeIDs[i])); + } + } + for (auto i = 0u; i < positionalParams.size(); ++i) { + auto parameterTypeID = tableFunc->parameterTypeIDs[i]; + if (positionalParams[i]->expressionType == ExpressionType::LITERAL && + parameterTypeID != LogicalTypeID::ANY) { + positionalParams[i] = expressionBinder.foldExpression( + expressionBinder.implicitCastIfNecessary(positionalParams[i], inputTypes[i])); + } + } + auto bindInput = TableFuncBindInput(); + bindInput.params = std::move(positionalParams); + bindInput.optionalParams = std::move(optionalParams); + bindInput.optionalParamsLegacy = std::move(optionalParamsLegacy); + bindInput.binder = this; + bindInput.yieldVariables = std::move(yieldVariables); + return BoundTableScanInfo{*tableFunc, tableFunc->bindFunc(clientContext, &bindInput)}; +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_transaction.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_transaction.cpp new file mode 100644 index 0000000000..21277ddadc --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_transaction.cpp @@ -0,0 +1,16 @@ +#include "binder/binder.h" +#include "binder/bound_transaction_statement.h" +#include "parser/transaction_statement.h" + +using namespace lbug::parser; + +namespace lbug { +namespace binder { + +std::unique_ptr Binder::bindTransaction(const Statement& statement) { + auto& transactionStatement = statement.constCast(); + return std::make_unique(transactionStatement.getTransactionAction()); +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_updating_clause.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_updating_clause.cpp new file mode 100644 index 0000000000..d5686b412b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_updating_clause.cpp @@ -0,0 +1,387 @@ +#include "binder/binder.h" +#include "binder/expression/expression_util.h" +#include "binder/expression/property_expression.h" +#include "binder/query/query_graph_label_analyzer.h" +#include "binder/query/updating_clause/bound_delete_clause.h" +#include "binder/query/updating_clause/bound_insert_clause.h" +#include "binder/query/updating_clause/bound_merge_clause.h" +#include "binder/query/updating_clause/bound_set_clause.h" +#include "catalog/catalog.h" +#include "catalog/catalog_entry/index_catalog_entry.h" +#include "catalog/catalog_entry/node_table_catalog_entry.h" +#include "catalog/catalog_entry/rel_group_catalog_entry.h" +#include "common/assert.h" +#include "common/exception/binder.h" +#include "common/string_format.h" +#include "parser/query/updating_clause/delete_clause.h" +#include "parser/query/updating_clause/insert_clause.h" +#include "parser/query/updating_clause/merge_clause.h" +#include "parser/query/updating_clause/set_clause.h" +#include "transaction/transaction.h" + +using namespace lbug::common; +using namespace lbug::parser; +using namespace lbug::catalog; + +namespace lbug { +namespace binder { + +std::unique_ptr Binder::bindUpdatingClause( + const UpdatingClause& updatingClause) { + switch (updatingClause.getClauseType()) { + case ClauseType::INSERT: { + return bindInsertClause(updatingClause); + } + case ClauseType::MERGE: { + return bindMergeClause(updatingClause); + } + case ClauseType::SET: { + return bindSetClause(updatingClause); + } + case ClauseType::DELETE_: { + return 
bindDeleteClause(updatingClause); + } + default: + KU_UNREACHABLE; + } +} + +static std::unordered_set populatePatternsScope(const BinderScope& scope) { + std::unordered_set result; + for (auto& expression : scope.getExpressions()) { + if (ExpressionUtil::isNodePattern(*expression) || + ExpressionUtil::isRelPattern(*expression)) { + result.insert(expression->toString()); + } else if (expression->expressionType == ExpressionType::VARIABLE) { + if (scope.hasNodeReplacement(expression->toString())) { + result.insert(expression->toString()); + } + } + } + return result; +} + +std::unique_ptr Binder::bindInsertClause( + const UpdatingClause& updatingClause) { + auto& insertClause = updatingClause.constCast(); + auto patternsScope = populatePatternsScope(scope); + // bindGraphPattern will update scope. + auto boundGraphPattern = bindGraphPattern(insertClause.getPatternElementsRef()); + auto insertInfos = bindInsertInfos(boundGraphPattern.queryGraphCollection, patternsScope); + return std::make_unique(std::move(insertInfos)); +} + +static expression_vector getColumnDataExprs(QueryGraphCollection& collection) { + expression_vector exprs; + for (auto i = 0u; i < collection.getNumQueryGraphs(); ++i) { + auto queryGraph = collection.getQueryGraph(i); + for (auto& pattern : queryGraph->getAllPatterns()) { + for (auto& [_, rhs] : pattern->getPropertyDataExprRef()) { + exprs.push_back(rhs); + } + } + } + return exprs; +} + +std::unique_ptr Binder::bindMergeClause(const UpdatingClause& updatingClause) { + auto& mergeClause = updatingClause.constCast(); + auto patternsScope = populatePatternsScope(scope); + // bindGraphPattern will update scope. + auto boundGraphPattern = bindGraphPattern(mergeClause.getPatternElementsRef()); + auto columnDataExprs = getColumnDataExprs(boundGraphPattern.queryGraphCollection); + // Rewrite key value pairs in MATCH clause as predicate + rewriteMatchPattern(boundGraphPattern); + auto existenceMark = + expressionBinder.createVariableExpression(LogicalType::BOOL(), std::string("__existence")); + auto distinctMark = + expressionBinder.createVariableExpression(LogicalType::BOOL(), std::string("__distinct")); + auto createInfos = bindInsertInfos(boundGraphPattern.queryGraphCollection, patternsScope); + auto boundMergeClause = + std::make_unique(columnDataExprs, std::move(existenceMark), + std::move(distinctMark), std::move(boundGraphPattern.queryGraphCollection), + std::move(boundGraphPattern.where), std::move(createInfos)); + if (mergeClause.hasOnMatchSetItems()) { + for (auto& [lhs, rhs] : mergeClause.getOnMatchSetItemsRef()) { + auto setPropertyInfo = bindSetPropertyInfo(lhs.get(), rhs.get()); + boundMergeClause->addOnMatchSetPropertyInfo(std::move(setPropertyInfo)); + } + } + if (mergeClause.hasOnCreateSetItems()) { + for (auto& [lhs, rhs] : mergeClause.getOnCreateSetItemsRef()) { + auto setPropertyInfo = bindSetPropertyInfo(lhs.get(), rhs.get()); + boundMergeClause->addOnCreateSetPropertyInfo(std::move(setPropertyInfo)); + } + } + return boundMergeClause; +} + +std::vector Binder::bindInsertInfos(QueryGraphCollection& queryGraphCollection, + const std::unordered_set& patternsInScope_) { + auto patternsInScope = patternsInScope_; + std::vector result; + auto analyzer = QueryGraphLabelAnalyzer(*clientContext, true /* throwOnViolate */); + for (auto i = 0u; i < queryGraphCollection.getNumQueryGraphs(); ++i) { + auto queryGraph = queryGraphCollection.getQueryGraphUnsafe(i); + // Ensure query graph does not violate declared schema. 
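+ // The analyzer was constructed with throwOnViolate = true, so patterns whose labels conflict with the declared schema are reported as binder errors.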
+ analyzer.pruneLabel(*queryGraph); + for (auto j = 0u; j < queryGraph->getNumQueryNodes(); ++j) { + auto node = queryGraph->getQueryNode(j); + if (node->getVariableName().empty()) { // Always create anonymous node. + bindInsertNode(node, result); + continue; + } + if (patternsInScope.contains(node->getVariableName())) { + continue; + } + patternsInScope.insert(node->getVariableName()); + bindInsertNode(node, result); + } + for (auto j = 0u; j < queryGraph->getNumQueryRels(); ++j) { + auto rel = queryGraph->getQueryRel(j); + if (rel->getVariableName().empty()) { // Always create anonymous rel. + bindInsertRel(rel, result); + continue; + } + if (patternsInScope.contains(rel->getVariableName())) { + continue; + } + patternsInScope.insert(rel->getVariableName()); + bindInsertRel(rel, result); + } + } + if (result.empty()) { + throw BinderException("Cannot resolve any node or relationship to create."); + } + return result; +} + +static void validatePrimaryKeyExistence(const NodeTableCatalogEntry* nodeTableEntry, + const NodeExpression& node, const expression_vector& defaultExprs) { + auto primaryKeyName = nodeTableEntry->getPrimaryKeyName(); + auto pkeyDefaultExpr = defaultExprs.at(nodeTableEntry->getPrimaryKeyID()); + if (!node.hasPropertyDataExpr(primaryKeyName) && + ExpressionUtil::isNullLiteral(*pkeyDefaultExpr)) { + throw BinderException(stringFormat("Create node {} expects primary key {} as input.", + node.toString(), primaryKeyName)); + } +} + +void Binder::bindInsertNode(std::shared_ptr node, + std::vector& infos) { + if (node->isMultiLabeled()) { + throw BinderException( + "Create node " + node->toString() + " with multiple node labels is not supported."); + } + if (node->isEmpty()) { + throw BinderException( + "Create node " + node->toString() + " with empty node labels is not supported."); + } + KU_ASSERT(node->getNumEntries() == 1); + auto entry = node->getEntry(0); + KU_ASSERT(entry->getTableType() == TableType::NODE); + auto insertInfo = BoundInsertInfo(TableType::NODE, node); + for (auto& property : node->getPropertyExpressions()) { + if (property->hasProperty(entry->getTableID())) { + insertInfo.columnExprs.push_back(property); + } + } + insertInfo.columnDataExprs = + bindInsertColumnDataExprs(node->getPropertyDataExprRef(), entry->getProperties()); + auto nodeEntry = entry->ptrCast(); + validatePrimaryKeyExistence(nodeEntry, *node, insertInfo.columnDataExprs); + // Check extension secondary index loaded + auto catalog = Catalog::Get(*clientContext); + auto transaction = transaction::Transaction::Get(*clientContext); + for (auto indexEntry : catalog->getIndexEntries(transaction, nodeEntry->getTableID())) { + if (!indexEntry->isLoaded()) { + throw BinderException(stringFormat( + "Trying to insert into an index on table {} but its extension is not loaded.", + nodeEntry->getName())); + } + } + infos.push_back(std::move(insertInfo)); +} + +static TableCatalogEntry* tryPruneMultiLabeled(const RelExpression& rel, + const TableCatalogEntry& srcEntry, const TableCatalogEntry& dstEntry) { + std::vector candidates; + for (auto& entry : rel.getEntries()) { + KU_ASSERT(entry->getType() == CatalogEntryType::REL_GROUP_ENTRY); + auto& relEntry = entry->constCast(); + if (relEntry.hasRelEntryInfo(srcEntry.getTableID(), dstEntry.getTableID())) { + candidates.push_back(entry); + } + } + if (candidates.size() > 1) { + throw BinderException(stringFormat( + "Create rel {} with multiple rel labels is not supported.", rel.toString())); + } + if (candidates.size() == 0) { + throw BinderException( + 
stringFormat("Cannot find a valid label in {} that connects {} and {}.", rel.toString(), + srcEntry.getName(), dstEntry.getName())); + } + return candidates[0]; +} + +void Binder::bindInsertRel(std::shared_ptr rel, + std::vector& infos) { + if (rel->isBoundByMultiLabeledNode()) { + throw BinderException(stringFormat( + "Create rel {} bound by multiple node labels is not supported.", rel->toString())); + } + if (rel->getDirectionType() == RelDirectionType::BOTH) { + throw BinderException(stringFormat("Create undirected relationship is not supported. " + "Try create 2 directed relationships instead.")); + } + if (ExpressionUtil::isRecursiveRelPattern(*rel)) { + throw BinderException(stringFormat("Cannot create recursive rel {}.", rel->toString())); + } + TableCatalogEntry* entry = nullptr; + if (!rel->isMultiLabeled()) { + KU_ASSERT(rel->getNumEntries() == 1); + entry = rel->getEntry(0); + } else { + auto srcEntry = rel->getSrcNode()->getEntry(0); + auto dstEntry = rel->getDstNode()->getEntry(0); + entry = tryPruneMultiLabeled(*rel, *srcEntry, *dstEntry); + } + rel->setEntries(std::vector{entry}); + auto insertInfo = BoundInsertInfo(TableType::REL, rel); + // Because we might prune entries, some property exprs may belong to pruned entry + for (auto& p : entry->getProperties()) { + insertInfo.columnExprs.push_back(rel->getPropertyExpression(p.getName())); + } + insertInfo.columnDataExprs = + bindInsertColumnDataExprs(rel->getPropertyDataExprRef(), entry->getProperties()); + infos.push_back(std::move(insertInfo)); +} + +expression_vector Binder::bindInsertColumnDataExprs( + const case_insensitive_map_t>& propertyDataExprs, + const std::vector& propertyDefinitions) { + expression_vector result; + for (auto& definition : propertyDefinitions) { + std::shared_ptr rhs; + if (propertyDataExprs.contains(definition.getName())) { + rhs = propertyDataExprs.at(definition.getName()); + } else { + rhs = expressionBinder.bindExpression(*definition.defaultExpr); + } + rhs = expressionBinder.implicitCastIfNecessary(rhs, definition.getType()); + result.push_back(std::move(rhs)); + } + return result; +} + +std::unique_ptr Binder::bindSetClause(const UpdatingClause& updatingClause) { + auto& setClause = updatingClause.constCast(); + auto boundSetClause = std::make_unique(); + for (auto& setItem : setClause.getSetItemsRef()) { + boundSetClause->addInfo(bindSetPropertyInfo(setItem.first.get(), setItem.second.get())); + } + return boundSetClause; +} + +BoundSetPropertyInfo Binder::bindSetPropertyInfo(const ParsedExpression* column, + const ParsedExpression* columnData) { + auto expr = expressionBinder.bindExpression(*column->getChild(0)); + auto isNode = ExpressionUtil::isNodePattern(*expr); + auto isRel = ExpressionUtil::isRelPattern(*expr); + if (!isNode && !isRel) { + throw BinderException( + stringFormat("Cannot set expression {} with type {}. Expect node or rel pattern.", + expr->toString(), ExpressionTypeUtil::toString(expr->expressionType))); + } + auto boundSetItem = bindSetItem(column, columnData); + auto boundColumn = boundSetItem.first; + auto boundColumnData = boundSetItem.second; + auto& nodeOrRel = expr->constCast(); + auto& property = boundSetItem.first->constCast(); + // Check secondary index constraint + auto catalog = Catalog::Get(*clientContext); + auto transaction = transaction::Transaction::Get(*clientContext); + for (auto entry : nodeOrRel.getEntries()) { + // When setting multi labeled node, skip checking if property is not in current table. 
+ if (!property.hasProperty(entry->getTableID())) { + continue; + } + auto propertyID = entry->getPropertyID(property.getPropertyName()); + if (catalog->containsUnloadedIndex(transaction, entry->getTableID(), propertyID)) { + throw BinderException( + stringFormat("Cannot set property {} in table {} because it is used in one or more " + "indexes which is unloaded.", + property.getPropertyName(), entry->getName())); + } + } + // Check primary key constraint + if (isNode) { + for (auto entry : nodeOrRel.getEntries()) { + if (property.isPrimaryKey(entry->getTableID())) { + throw BinderException( + stringFormat("Cannot set property {} in table {} because it is used as primary " + "key. Try delete and then insert.", + property.getPropertyName(), entry->getName())); + } + } + return BoundSetPropertyInfo(TableType::NODE, expr, boundColumn, boundColumnData); + } + return BoundSetPropertyInfo(TableType::REL, expr, boundColumn, boundColumnData); +} + +expression_pair Binder::bindSetItem(const ParsedExpression* column, + const ParsedExpression* columnData) { + auto boundColumn = expressionBinder.bindExpression(*column); + auto boundColumnData = expressionBinder.bindExpression(*columnData); + boundColumnData = + expressionBinder.implicitCastIfNecessary(boundColumnData, boundColumn->dataType); + return make_pair(std::move(boundColumn), std::move(boundColumnData)); +} + +std::unique_ptr Binder::bindDeleteClause( + const UpdatingClause& updatingClause) { + auto& deleteClause = updatingClause.constCast(); + auto deleteType = deleteClause.getDeleteClauseType(); + auto boundDeleteClause = std::make_unique(); + for (auto i = 0u; i < deleteClause.getNumExpressions(); ++i) { + auto pattern = expressionBinder.bindExpression(*deleteClause.getExpression(i)); + if (ExpressionUtil::isNodePattern(*pattern)) { + auto deleteNodeInfo = BoundDeleteInfo(deleteType, TableType::NODE, pattern); + auto& node = pattern->constCast(); + auto catalog = Catalog::Get(*clientContext); + auto transaction = transaction::Transaction::Get(*clientContext); + for (auto entry : node.getEntries()) { + for (auto index : catalog->getIndexEntries(transaction, entry->getTableID())) { + if (!index->isLoaded()) { + throw BinderException( + stringFormat("Trying to delete from an index on table {} but its " + "extension is not loaded.", + entry->getName())); + } + } + } + boundDeleteClause->addInfo(std::move(deleteNodeInfo)); + } else if (ExpressionUtil::isRelPattern(*pattern)) { + // LCOV_EXCL_START + if (deleteClause.getDeleteClauseType() == DeleteNodeType::DETACH_DELETE) { + throw BinderException("Detach delete on rel tables is not supported."); + } + // LCOV_EXCL_STOP + auto rel = pattern->constPtrCast(); + if (rel->getDirectionType() == RelDirectionType::BOTH) { + throw BinderException("Delete undirected rel is not supported."); + } + auto deleteRel = BoundDeleteInfo(deleteType, TableType::REL, pattern); + boundDeleteClause->addInfo(std::move(deleteRel)); + } else { + throw BinderException(stringFormat( + "Cannot delete expression {} with type {}. 
Expect node or rel pattern.", + pattern->toString(), ExpressionTypeUtil::toString(pattern->expressionType))); + } + } + return boundDeleteClause; +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_use_database.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_use_database.cpp new file mode 100644 index 0000000000..57dda7669d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/bind_use_database.cpp @@ -0,0 +1,14 @@ +#include "binder/binder.h" +#include "binder/bound_use_database.h" +#include "parser/use_database.h" + +namespace lbug { +namespace binder { + +std::unique_ptr Binder::bindUseDatabase(const parser::Statement& statement) { + auto useDatabase = statement.constCast(); + return std::make_unique(useDatabase.getDBName()); +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/copy/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/copy/CMakeLists.txt new file mode 100644 index 0000000000..c4b0da135d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/copy/CMakeLists.txt @@ -0,0 +1,8 @@ +add_library(lbug_binder_bind_copy + OBJECT + bind_copy_to.cpp + bind_copy_from.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/copy/bind_copy_from.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/copy/bind_copy_from.cpp new file mode 100644 index 0000000000..93a6627f7c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/copy/bind_copy_from.cpp @@ -0,0 +1,344 @@ +#include "binder/binder.h" +#include "binder/copy/bound_copy_from.h" +#include "catalog/catalog.h" +#include "catalog/catalog_entry/index_catalog_entry.h" +#include "catalog/catalog_entry/node_table_catalog_entry.h" +#include "catalog/catalog_entry/rel_group_catalog_entry.h" +#include "common/exception/binder.h" +#include "common/string_format.h" +#include "common/string_utils.h" +#include "parser/copy.h" +#include "transaction/transaction.h" + +using namespace lbug::binder; +using namespace lbug::catalog; +using namespace lbug::common; +using namespace lbug::parser; +using namespace lbug::function; + +namespace lbug { +namespace binder { + +static void throwTableNotExist(const std::string& tableName) { + throw BinderException(stringFormat("Table {} does not exist.", tableName)); +} + +std::unique_ptr Binder::bindLegacyCopyRelGroupFrom(const Statement& statement) { + auto& copyFrom = statement.constCast(); + auto catalog = Catalog::Get(*clientContext); + auto transaction = transaction::Transaction::Get(*clientContext); + auto tableName = copyFrom.getTableName(); + auto tableNameParts = common::StringUtils::split(tableName, "_"); + if (tableNameParts.size() != 3 || !catalog->containsTable(transaction, tableNameParts[0])) { + throwTableNotExist(tableName); + } + auto entry = catalog->getTableCatalogEntry(transaction, tableNameParts[0]); + if (entry->getType() != CatalogEntryType::REL_GROUP_ENTRY) { + throwTableNotExist(tableName); + } + auto relGroupEntry = entry->ptrCast(); + try { + return bindCopyRelFrom(copyFrom, *relGroupEntry, tableNameParts[1], tableNameParts[2]); + } catch (Exception& e) { + throwTableNotExist(tableName); + return nullptr; + } +} + +std::unique_ptr Binder::bindCopyFromClause(const Statement& statement) { + auto& copyStatement = statement.constCast(); + auto tableName = copyStatement.getTableName(); + auto catalog = Catalog::Get(*clientContext); + 
auto transaction = transaction::Transaction::Get(*clientContext); + if (!catalog->containsTable(transaction, tableName)) { + return bindLegacyCopyRelGroupFrom(statement); + } + auto tableEntry = catalog->getTableCatalogEntry(transaction, tableName); + switch (tableEntry->getType()) { + case CatalogEntryType::NODE_TABLE_ENTRY: { + return bindCopyNodeFrom(statement, *tableEntry->ptrCast()); + } + case CatalogEntryType::REL_GROUP_ENTRY: { + auto entry = tableEntry->ptrCast(); + auto properties = entry->getProperties(); + KU_ASSERT(entry->getNumRelTables() > 0); + if (entry->getNumRelTables() == 1) { + auto fromToNodePair = entry->getSingleRelEntryInfo().nodePair; + auto fromTable = catalog->getTableCatalogEntry(transaction, fromToNodePair.srcTableID); + auto toTable = catalog->getTableCatalogEntry(transaction, fromToNodePair.dstTableID); + return bindCopyRelFrom(statement, *entry, fromTable->getName(), toTable->getName()); + } else { + auto options = bindParsingOptions(copyStatement.getParsingOptions()); + if (!options.contains(CopyConstants::FROM_OPTION_NAME) || + !options.contains(CopyConstants::TO_OPTION_NAME)) { + throw BinderException(stringFormat( + "The table {} has multiple FROM and TO pairs defined in the schema. A " + "specific pair of FROM and TO options is expected when copying data " + "into " + "the {} table.", + tableName, tableName)); + } + auto from = options.at(CopyConstants::FROM_OPTION_NAME).getValue(); + auto to = options.at(CopyConstants::TO_OPTION_NAME).getValue(); + return bindCopyRelFrom(statement, *entry, from, to); + } + } + default: { + KU_UNREACHABLE; + } + } +} + +static void bindExpectedNodeColumns(const NodeTableCatalogEntry& entry, + const CopyFromColumnInfo& info, std::vector& columnNames, + std::vector& columnTypes); +static void bindExpectedRelColumns(const RelGroupCatalogEntry& entry, + const NodeTableCatalogEntry& fromEntry, const NodeTableCatalogEntry& toEntry, + const CopyFromColumnInfo& info, std::vector& columnNames, + std::vector& columnTypes); + +static std::pair> matchColumnExpression( + const expression_vector& columns, const PropertyDefinition& property, + ExpressionBinder& expressionBinder) { + for (auto& column : columns) { + if (property.getName() == column->toString()) { + if (column->dataType == property.getType()) { + return {ColumnEvaluateType::REFERENCE, column}; + } else { + return {ColumnEvaluateType::CAST, + expressionBinder.forceCast(column, property.getType())}; + } + } + } + return {ColumnEvaluateType::DEFAULT, expressionBinder.bindExpression(*property.defaultExpr)}; +} + +BoundCopyFromInfo Binder::bindCopyNodeFromInfo(std::string tableName, + const std::vector& properties, const BaseScanSource* source, + const options_t& parsingOptions, const std::vector& expectedColumnNames, + const std::vector& expectedColumnTypes, bool byColumn) { + auto boundSource = + bindScanSource(source, parsingOptions, expectedColumnNames, expectedColumnTypes); + expression_vector warningDataExprs = boundSource->getWarningColumns(); + if (boundSource->type == ScanSourceType::FILE) { + auto bindData = boundSource->constCast() + .info.bindData->constPtrCast(); + if (byColumn && bindData->fileScanInfo.fileTypeInfo.fileType != FileType::NPY) { + throw BinderException(stringFormat("Copy by column with {} file type is not supported.", + bindData->fileScanInfo.fileTypeInfo.fileTypeStr)); + } + } + expression_vector columns; + std::vector evaluateTypes; + for (auto& property : properties) { + auto [evaluateType, column] = + 
matchColumnExpression(boundSource->getColumns(), property, expressionBinder); + columns.push_back(column); + evaluateTypes.push_back(evaluateType); + } + columns.insert(columns.end(), warningDataExprs.begin(), warningDataExprs.end()); + auto offset = + createInvisibleVariable(std::string(InternalKeyword::ROW_OFFSET), LogicalType::INT64()); + return BoundCopyFromInfo(tableName, TableType::NODE, std::move(boundSource), std::move(offset), + std::move(columns), std::move(evaluateTypes), nullptr /* extraInfo */); +} + +std::unique_ptr Binder::bindCopyNodeFrom(const Statement& statement, + NodeTableCatalogEntry& nodeTableEntry) { + auto& copyStatement = statement.constCast(); + // Check extension secondary index loaded + auto catalog = Catalog::Get(*clientContext); + auto transaction = transaction::Transaction::Get(*clientContext); + for (auto indexEntry : catalog->getIndexEntries(transaction, nodeTableEntry.getTableID())) { + if (!indexEntry->isLoaded()) { + throw BinderException(stringFormat( + "Trying to insert into an index on table {} but its extension is not loaded.", + nodeTableEntry.getName())); + } + } + // Bind expected columns based on catalog information. + std::vector expectedColumnNames; + std::vector expectedColumnTypes; + bindExpectedNodeColumns(nodeTableEntry, copyStatement.getCopyColumnInfo(), expectedColumnNames, + expectedColumnTypes); + auto boundCopyFromInfo = + bindCopyNodeFromInfo(nodeTableEntry.getName(), nodeTableEntry.getProperties(), + copyStatement.getSource(), copyStatement.getParsingOptions(), expectedColumnNames, + expectedColumnTypes, copyStatement.byColumn()); + return std::make_unique(std::move(boundCopyFromInfo)); +} + +static options_t getScanSourceOptions(const CopyFrom& copyFrom) { + options_t options; + static case_insensitve_set_t copyFromPairsOptions = {CopyConstants::FROM_OPTION_NAME, + CopyConstants::TO_OPTION_NAME}; + for (auto& option : copyFrom.getParsingOptions()) { + if (copyFromPairsOptions.contains(option.first)) { + continue; + } + options.emplace(option.first, option.second->copy()); + } + return options; +} + +BoundCopyFromInfo Binder::bindCopyRelFromInfo(std::string tableName, + const std::vector& properties, const BaseScanSource* source, + const options_t& parsingOptions, const std::vector& expectedColumnNames, + const std::vector& expectedColumnTypes, const NodeTableCatalogEntry* fromTable, + const NodeTableCatalogEntry* toTable) { + auto boundSource = + bindScanSource(source, parsingOptions, expectedColumnNames, expectedColumnTypes); + expression_vector warningDataExprs = boundSource->getWarningColumns(); + auto columns = boundSource->getColumns(); + auto offset = + createInvisibleVariable(std::string(InternalKeyword::ROW_OFFSET), LogicalType::INT64()); + auto srcOffset = createVariable(std::string(InternalKeyword::SRC_OFFSET), LogicalType::INT64()); + auto dstOffset = createVariable(std::string(InternalKeyword::DST_OFFSET), LogicalType::INT64()); + expression_vector columnExprs{srcOffset, dstOffset, offset}; + std::vector evaluateTypes{ColumnEvaluateType::REFERENCE, + ColumnEvaluateType::REFERENCE, ColumnEvaluateType::REFERENCE}; + for (auto i = 1u; i < properties.size(); ++i) { // skip internal ID + auto& property = properties[i]; + auto [evaluateType, column] = + matchColumnExpression(boundSource->getColumns(), property, expressionBinder); + columnExprs.push_back(column); + evaluateTypes.push_back(evaluateType); + } + columnExprs.insert(columnExprs.end(), warningDataExprs.begin(), warningDataExprs.end()); + std::shared_ptr srcKey = 
nullptr, dstKey = nullptr; + if (expectedColumnTypes[0] != columns[0]->getDataType()) { + srcKey = expressionBinder.forceCast(columns[0], expectedColumnTypes[0]); + } else { + srcKey = columns[0]; + } + if (expectedColumnTypes[1] != columns[1]->getDataType()) { + dstKey = expressionBinder.forceCast(columns[1], expectedColumnTypes[1]); + } else { + dstKey = columns[1]; + } + auto srcLookUpInfo = + IndexLookupInfo(fromTable->getTableID(), srcOffset, srcKey, warningDataExprs); + auto dstLookUpInfo = + IndexLookupInfo(toTable->getTableID(), dstOffset, dstKey, warningDataExprs); + auto lookupInfos = std::vector{srcLookUpInfo, dstLookUpInfo}; + auto internalIDColumnIndices = std::vector{0, 1, 2}; + auto extraCopyRelInfo = std::make_unique(fromTable->getName(), + toTable->getName(), internalIDColumnIndices, lookupInfos); + return BoundCopyFromInfo(tableName, TableType::REL, boundSource->copy(), offset, + std::move(columnExprs), std::move(evaluateTypes), std::move(extraCopyRelInfo)); +} + +std::unique_ptr Binder::bindCopyRelFrom(const Statement& statement, + RelGroupCatalogEntry& relGroupEntry, const std::string& fromTableName, + const std::string& toTableName) { + auto& copyStatement = statement.constCast(); + if (copyStatement.byColumn()) { + throw BinderException( + stringFormat("Copy by column is not supported for relationship table.")); + } + // Bind from to tables + auto catalog = Catalog::Get(*clientContext); + auto transaction = transaction::Transaction::Get(*clientContext); + auto fromTable = + catalog->getTableCatalogEntry(transaction, fromTableName)->ptrCast(); + auto toTable = + catalog->getTableCatalogEntry(transaction, toTableName)->ptrCast(); + auto relInfo = relGroupEntry.getRelEntryInfo(fromTable->getTableID(), toTable->getTableID()); + if (relInfo == nullptr) { + throw BinderException(stringFormat("Rel table {} does not contain {}-{} from-to pair.", + relGroupEntry.getName(), fromTable->getName(), toTable->getName())); + } + // Bind expected columns based on catalog information. + std::vector expectedColumnNames; + std::vector expectedColumnTypes; + bindExpectedRelColumns(relGroupEntry, *fromTable, *toTable, copyStatement.getCopyColumnInfo(), + expectedColumnNames, expectedColumnTypes); + // Bind info + auto boundCopyFromInfo = + bindCopyRelFromInfo(relGroupEntry.getName(), relGroupEntry.getProperties(), + copyStatement.getSource(), getScanSourceOptions(copyStatement), expectedColumnNames, + expectedColumnTypes, fromTable, toTable); + return std::make_unique(std::move(boundCopyFromInfo)); +} + +static bool skipPropertyInFile(const PropertyDefinition& property) { + if (property.getName() == InternalKeyword::ID) { + return true; + } + return false; +} + +static bool skipPropertyInSchema(const PropertyDefinition& property) { + if (property.getType().getLogicalTypeID() == LogicalTypeID::SERIAL) { + return true; + } + if (property.getName() == InternalKeyword::ID) { + return true; + } + return false; +} + +static void bindExpectedColumns(const TableCatalogEntry& entry, const CopyFromColumnInfo& info, + std::vector& columnNames, std::vector& columnTypes) { + if (info.inputColumnOrder) { + std::unordered_set inputColumnNamesSet; + for (auto& columName : info.columnNames) { + if (inputColumnNamesSet.contains(columName)) { + throw BinderException( + stringFormat("Detect duplicate column name {} during COPY.", columName)); + } + inputColumnNamesSet.insert(columName); + } + // Search column data type for each input column. 
+ for (auto& columnName : info.columnNames) { + if (!entry.containsProperty(columnName)) { + throw BinderException(stringFormat("Table {} does not contain column {}.", + entry.getName(), columnName)); + } + auto& property = entry.getProperty(columnName); + if (skipPropertyInFile(property)) { + continue; + } + columnNames.push_back(columnName); + columnTypes.push_back(property.getType().copy()); + } + } else { + // No column specified. Fall back to schema columns. + for (auto& property : entry.getProperties()) { + if (skipPropertyInSchema(property)) { + continue; + } + columnNames.push_back(property.getName()); + columnTypes.push_back(property.getType().copy()); + } + } +} + +void bindExpectedNodeColumns(const NodeTableCatalogEntry& entry, const CopyFromColumnInfo& info, + std::vector& columnNames, std::vector& columnTypes) { + KU_ASSERT(columnNames.empty() && columnTypes.empty()); + bindExpectedColumns(entry, info, columnNames, columnTypes); +} + +void bindExpectedRelColumns(const RelGroupCatalogEntry& entry, + const NodeTableCatalogEntry& fromEntry, const NodeTableCatalogEntry& toEntry, + const CopyFromColumnInfo& info, std::vector& columnNames, + std::vector& columnTypes) { + KU_ASSERT(columnNames.empty() && columnTypes.empty()); + columnNames.push_back("from"); + columnNames.push_back("to"); + auto srcPKColumnType = fromEntry.getPrimaryKeyDefinition().getType().copy(); + if (srcPKColumnType.getLogicalTypeID() == LogicalTypeID::SERIAL) { + srcPKColumnType = LogicalType::INT64(); + } + auto dstPKColumnType = toEntry.getPrimaryKeyDefinition().getType().copy(); + if (dstPKColumnType.getLogicalTypeID() == LogicalTypeID::SERIAL) { + dstPKColumnType = LogicalType::INT64(); + } + columnTypes.push_back(std::move(srcPKColumnType)); + columnTypes.push_back(std::move(dstPKColumnType)); + bindExpectedColumns(entry, info, columnNames, columnTypes); +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/copy/bind_copy_to.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/copy/bind_copy_to.cpp new file mode 100644 index 0000000000..97bb97da61 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/copy/bind_copy_to.cpp @@ -0,0 +1,48 @@ +#include "binder/binder.h" +#include "binder/copy/bound_copy_to.h" +#include "catalog/catalog.h" +#include "common/exception/catalog.h" +#include "common/exception/runtime.h" +#include "function/built_in_function_utils.h" +#include "parser/copy.h" +#include "transaction/transaction.h" + +using namespace lbug::common; +using namespace lbug::parser; + +namespace lbug { +namespace binder { + +std::unique_ptr Binder::bindCopyToClause(const Statement& statement) { + auto& copyToStatement = statement.constCast(); + auto boundFilePath = copyToStatement.getFilePath(); + auto fileTypeInfo = bindFileTypeInfo({boundFilePath}); + std::vector columnNames; + auto parsedQuery = copyToStatement.getStatement(); + auto query = bindQuery(*parsedQuery); + auto columns = query->getStatementResult()->getColumns(); + auto fileTypeStr = fileTypeInfo.fileTypeStr; + auto name = stringFormat("COPY_{}", fileTypeStr); + catalog::CatalogEntry* entry = nullptr; + try { + entry = catalog::Catalog::Get(*clientContext) + ->getFunctionEntry(transaction::Transaction::Get(*clientContext), name); + } catch (CatalogException& exception) { + throw RuntimeException{common::stringFormat( + "Exporting query result to the '{}' file is currently not supported.", fileTypeStr)}; + } + auto exportFunc = 
function::BuiltInFunctionsUtils::matchFunction(name, + entry->ptrCast()) + ->constPtrCast(); + for (auto& column : columns) { + auto columnName = column->hasAlias() ? column->getAlias() : column->toString(); + columnNames.push_back(columnName); + } + function::ExportFuncBindInput bindInput{std::move(columnNames), std::move(boundFilePath), + bindParsingOptions(copyToStatement.getParsingOptions())}; + auto bindData = exportFunc->bind(bindInput); + return std::make_unique(std::move(bindData), *exportFunc, std::move(query)); +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/ddl/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/ddl/CMakeLists.txt new file mode 100644 index 0000000000..e5720d33d2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/ddl/CMakeLists.txt @@ -0,0 +1,7 @@ +add_library(lbug_binder_bind_ddl + OBJECT + bound_create_table_info.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/ddl/bound_create_table_info.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/ddl/bound_create_table_info.cpp new file mode 100644 index 0000000000..a777a1fb31 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/ddl/bound_create_table_info.cpp @@ -0,0 +1,43 @@ +#include "binder/ddl/bound_create_table_info.h" + +#include "catalog/catalog_entry/catalog_entry_type.h" +#include "catalog/catalog_entry/table_catalog_entry.h" + +using namespace lbug::parser; +using namespace lbug::common; +using namespace lbug::catalog; + +namespace lbug { +namespace binder { + +std::string BoundCreateTableInfo::toString() const { + std::string result = ""; + switch (type) { + case CatalogEntryType::NODE_TABLE_ENTRY: { + result += "Create Node Table: "; + result += tableName; + result += ",Properties: "; + auto nodeInfo = extraInfo->ptrCast(); + for (auto& definition : nodeInfo->propertyDefinitions) { + result += definition.getName(); + result += ", "; + } + } break; + case CatalogEntryType::REL_GROUP_ENTRY: { + result += "Create Relationship Table: "; + result += tableName; + auto relGroupInfo = extraInfo->ptrCast(); + result += "Properties: "; + for (auto& definition : relGroupInfo->propertyDefinitions) { + result += definition.getName(); + result += ", "; + } + } break; + default: + break; + } + return result; +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/read/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/read/CMakeLists.txt new file mode 100644 index 0000000000..1e7669d963 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/read/CMakeLists.txt @@ -0,0 +1,10 @@ +add_library(lbug_binder_bind_read + OBJECT + bind_in_query_call.cpp + bind_load_from.cpp + bind_match.cpp + bind_unwind.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/read/bind_in_query_call.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/read/bind_in_query_call.cpp new file mode 100644 index 0000000000..d06581ebc6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/read/bind_in_query_call.cpp @@ -0,0 +1,45 @@ +#include "binder/binder.h" +#include "binder/query/reading_clause/bound_table_function_call.h" +#include "catalog/catalog.h" +#include "common/exception/binder.h" +#include "parser/expression/parsed_function_expression.h" +#include 
"parser/query/reading_clause/in_query_call_clause.h" +#include "transaction/transaction.h" + +using namespace lbug::common; +using namespace lbug::catalog; +using namespace lbug::parser; +using namespace lbug::function; +using namespace lbug::catalog; + +namespace lbug { +namespace binder { + +std::unique_ptr Binder::bindInQueryCall(const ReadingClause& readingClause) { + auto& call = readingClause.constCast(); + auto expr = call.getFunctionExpression(); + auto functionExpr = expr->constPtrCast(); + auto functionName = functionExpr->getFunctionName(); + std::unique_ptr boundReadingClause; + auto transaction = transaction::Transaction::Get(*clientContext); + auto entry = Catalog::Get(*clientContext)->getFunctionEntry(transaction, functionName); + switch (entry->getType()) { + case CatalogEntryType::TABLE_FUNCTION_ENTRY: { + auto boundTableFunction = + bindTableFunc(functionName, *functionExpr, call.getYieldVariables()); + boundReadingClause = + std::make_unique(std::move(boundTableFunction)); + } break; + default: + throw BinderException( + stringFormat("{} is not a table or algorithm function.", functionName)); + } + if (call.hasWherePredicate()) { + auto wherePredicate = bindWhereExpression(*call.getWherePredicate()); + boundReadingClause->setPredicate(std::move(wherePredicate)); + } + return boundReadingClause; +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/read/bind_load_from.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/read/bind_load_from.cpp new file mode 100644 index 0000000000..6e399cae79 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/read/bind_load_from.cpp @@ -0,0 +1,64 @@ +#include "binder/binder.h" +#include "binder/bound_scan_source.h" +#include "binder/expression/expression_util.h" +#include "binder/query/reading_clause/bound_load_from.h" +#include "common/exception/binder.h" +#include "parser/query/reading_clause/load_from.h" +#include "parser/scan_source.h" + +using namespace lbug::function; +using namespace lbug::common; +using namespace lbug::parser; +using namespace lbug::catalog; + +namespace lbug { +namespace binder { + +std::unique_ptr Binder::bindLoadFrom(const ReadingClause& readingClause) { + auto& loadFrom = readingClause.constCast(); + auto source = loadFrom.getSource(); + std::unique_ptr boundLoadFrom; + std::vector columnNames; + std::vector columnTypes; + for (auto& [name, type] : loadFrom.getColumnDefinitions()) { + columnNames.push_back(name); + columnTypes.push_back(LogicalType::convertFromString(type, clientContext)); + } + switch (source->type) { + case ScanSourceType::OBJECT: { + auto objectSource = source->ptrCast(); + auto boundScanSource = bindObjectScanSource(*objectSource, loadFrom.getParsingOptions(), + columnNames, columnTypes); + auto& scanInfo = boundScanSource->constCast().info; + boundLoadFrom = std::make_unique(scanInfo.copy()); + } break; + case ScanSourceType::FILE: { + auto boundScanSource = + bindFileScanSource(*source, loadFrom.getParsingOptions(), columnNames, columnTypes); + auto& scanInfo = boundScanSource->constCast().info; + boundLoadFrom = std::make_unique(scanInfo.copy()); + } break; + case ScanSourceType::PARAM: { + auto boundScanSource = bindParameterScanSource(*source, loadFrom.getParsingOptions(), + columnNames, columnTypes); + auto& scanInfo = boundScanSource->constCast().info; + boundLoadFrom = std::make_unique(scanInfo.copy()); + } break; + default: + throw BinderException(stringFormat("LOAD FROM subquery is not supported.")); + } + 
if (!columnTypes.empty()) { + auto info = boundLoadFrom->getInfo(); + for (auto i = 0u; i < columnTypes.size(); ++i) { + ExpressionUtil::validateDataType(*info->bindData->columns[i], columnTypes[i]); + } + } + if (loadFrom.hasWherePredicate()) { + auto wherePredicate = bindWhereExpression(*loadFrom.getWherePredicate()); + boundLoadFrom->setPredicate(std::move(wherePredicate)); + } + return boundLoadFrom; +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/read/bind_match.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/read/bind_match.cpp new file mode 100644 index 0000000000..bd1a9b67e4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/read/bind_match.cpp @@ -0,0 +1,126 @@ +#include "binder/binder.h" +#include "binder/query/reading_clause/bound_match_clause.h" +#include "common/exception/binder.h" +#include "parser/query/reading_clause/match_clause.h" + +using namespace lbug::common; +using namespace lbug::parser; + +namespace lbug { +namespace binder { + +static void collectHintPattern(const BoundJoinHintNode& node, binder::expression_set& set) { + if (node.isLeaf()) { + set.insert(node.nodeOrRel); + return; + } + for (auto& child : node.children) { + collectHintPattern(*child, set); + } +} + +static void validateHintCompleteness(const BoundJoinHintNode& root, const QueryGraph& queryGraph) { + binder::expression_set set; + collectHintPattern(root, set); + for (auto& nodeOrRel : queryGraph.getAllPatterns()) { + if (nodeOrRel->getVariableName().empty()) { + throw BinderException( + "Cannot hint join order in a match patter with anonymous node or relationship."); + } + if (!set.contains(nodeOrRel)) { + throw BinderException( + stringFormat("Cannot find {} in join hint.", nodeOrRel->toString())); + } + } +} + +std::unique_ptr Binder::bindMatchClause(const ReadingClause& readingClause) { + auto& matchClause = readingClause.constCast(); + auto boundGraphPattern = bindGraphPattern(matchClause.getPatternElementsRef()); + if (matchClause.hasWherePredicate()) { + boundGraphPattern.where = bindWhereExpression(*matchClause.getWherePredicate()); + } + rewriteMatchPattern(boundGraphPattern); + auto boundMatch = std::make_unique( + std::move(boundGraphPattern.queryGraphCollection), matchClause.getMatchClauseType()); + if (matchClause.hasHint()) { + boundMatch->setHint( + bindJoinHint(*boundMatch->getQueryGraphCollection(), *matchClause.getHint())); + } + boundMatch->setPredicate(boundGraphPattern.where); + return boundMatch; +} + +std::shared_ptr Binder::bindJoinHint( + const QueryGraphCollection& queryGraphCollection, const JoinHintNode& joinHintNode) { + if (queryGraphCollection.getNumQueryGraphs() > 1) { + throw BinderException("Join hint on disconnected match pattern is not supported."); + } + auto hint = bindJoinNode(joinHintNode); + validateHintCompleteness(*hint, *queryGraphCollection.getQueryGraph(0)); + return hint; +} + +std::shared_ptr Binder::bindJoinNode(const JoinHintNode& joinHintNode) { + if (joinHintNode.isLeaf()) { + std::shared_ptr pattern = nullptr; + if (scope.contains(joinHintNode.variableName)) { + pattern = scope.getExpression(joinHintNode.variableName); + } + if (pattern == nullptr || pattern->expressionType != ExpressionType::PATTERN) { + throw BinderException(stringFormat("Cannot bind {} to a node or relationship pattern", + joinHintNode.variableName)); + } + return std::make_shared(std::move(pattern)); + } + auto node = std::make_shared(); + for (auto& child : joinHintNode.children) { + 
node->addChild(bindJoinNode(*child)); + } + return node; +} + +void Binder::rewriteMatchPattern(BoundGraphPattern& boundGraphPattern) { + // Rewrite self loop edge + // e.g. rewrite (a)-[e]->(a) as [a]-[e]->(b) WHERE id(a) = id(b) + expression_vector selfLoopEdgePredicates; + auto& graphCollection = boundGraphPattern.queryGraphCollection; + for (auto i = 0u; i < graphCollection.getNumQueryGraphs(); ++i) { + auto queryGraph = graphCollection.getQueryGraphUnsafe(i); + for (auto& queryRel : queryGraph->getQueryRels()) { + if (!queryRel->isSelfLoop()) { + continue; + } + auto src = queryRel->getSrcNode(); + auto dst = queryRel->getDstNode(); + auto newDst = createQueryNode(dst->getVariableName(), dst->getEntries()); + queryGraph->addQueryNode(newDst); + queryRel->setDstNode(newDst); + auto predicate = expressionBinder.createEqualityComparisonExpression( + src->getInternalID(), newDst->getInternalID()); + selfLoopEdgePredicates.push_back(std::move(predicate)); + } + } + auto where = boundGraphPattern.where; + for (auto& predicate : selfLoopEdgePredicates) { + where = expressionBinder.combineBooleanExpressions(ExpressionType::AND, predicate, where); + } + // Rewrite key value pairs in MATCH clause as predicate + for (auto i = 0u; i < graphCollection.getNumQueryGraphs(); ++i) { + auto queryGraph = graphCollection.getQueryGraphUnsafe(i); + for (auto& pattern : queryGraph->getAllPatterns()) { + for (auto& [propertyName, rhs] : pattern->getPropertyDataExprRef()) { + auto propertyExpr = + expressionBinder.bindNodeOrRelPropertyExpression(*pattern, propertyName); + auto predicate = + expressionBinder.createEqualityComparisonExpression(propertyExpr, rhs); + where = expressionBinder.combineBooleanExpressions(ExpressionType::AND, predicate, + where); + } + } + } + boundGraphPattern.where = std::move(where); +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/read/bind_unwind.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/read/bind_unwind.cpp new file mode 100644 index 0000000000..b5d4b43b9a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind/read/bind_unwind.cpp @@ -0,0 +1,55 @@ +#include "binder/binder.h" +#include "binder/expression/expression_util.h" +#include "binder/query/reading_clause/bound_unwind_clause.h" +#include "parser/query/reading_clause/unwind_clause.h" + +using namespace lbug::parser; +using namespace lbug::common; + +namespace lbug { +namespace binder { + +// E.g. UNWIND $1. We cannot validate $1 has data type LIST until we see the actual parameter. +static bool skipDataTypeValidation(const Expression& expr) { + return expr.expressionType == ExpressionType::PARAMETER && + expr.getDataType().getLogicalTypeID() == LogicalTypeID::ANY; +} + +std::unique_ptr Binder::bindUnwindClause(const ReadingClause& readingClause) { + auto& unwindClause = readingClause.constCast(); + auto boundExpression = expressionBinder.bindExpression(*unwindClause.getExpression()); + auto aliasName = unwindClause.getAlias(); + std::shared_ptr alias; + if (boundExpression->getDataType().getLogicalTypeID() == LogicalTypeID::ARRAY) { + auto targetType = + LogicalType::LIST(ArrayType::getChildType(boundExpression->dataType).copy()); + boundExpression = expressionBinder.implicitCast(boundExpression, targetType); + } + if (!skipDataTypeValidation(*boundExpression)) { + if (ExpressionUtil::isNullLiteral(*boundExpression)) { + // For the `unwind NULL` clause, we assign the STRING[] type to the NULL literal. 
+ alias = createVariable(aliasName, LogicalType::STRING()); + boundExpression = expressionBinder.implicitCast(boundExpression, + LogicalType::LIST(LogicalType::STRING())); + } else { + ExpressionUtil::validateDataType(*boundExpression, LogicalTypeID::LIST); + boundExpression = expressionBinder.implicitCastIfNecessary(boundExpression, + LogicalTypeUtils::purgeAny(boundExpression->dataType, LogicalType::STRING())); + alias = createVariable(aliasName, ListType::getChildType(boundExpression->dataType)); + } + } else { + alias = createVariable(aliasName, LogicalType::ANY()); + } + std::shared_ptr idExpr = nullptr; + if (scope.hasMemorizedTableIDs(boundExpression->getAlias())) { + auto entries = scope.getMemorizedTableEntries(boundExpression->getAlias()); + auto node = createQueryNode(aliasName, entries); + idExpr = node->getInternalID(); + scope.addNodeReplacement(node); + } + return make_unique(std::move(boundExpression), std::move(alias), + std::move(idExpr)); +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind_expression/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind_expression/CMakeLists.txt new file mode 100644 index 0000000000..d2a1a3f47c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind_expression/CMakeLists.txt @@ -0,0 +1,18 @@ +add_library( + lbug_binder_bind_expression + OBJECT + bind_boolean_expression.cpp + bind_case_expression.cpp + bind_comparison_expression.cpp + bind_function_expression.cpp + bind_literal_expression.cpp + bind_null_operator_expression.cpp + bind_parameter_expression.cpp + bind_property_expression.cpp + bind_subquery_expression.cpp + bind_variable_expression.cpp + bind_lambda_expression.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind_expression/bind_boolean_expression.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind_expression/bind_boolean_expression.cpp new file mode 100644 index 0000000000..89011b1f02 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind_expression/bind_boolean_expression.cpp @@ -0,0 +1,57 @@ +#include "binder/expression/scalar_function_expression.h" +#include "binder/expression_binder.h" +#include "function/boolean/vector_boolean_functions.h" + +using namespace lbug::common; +using namespace lbug::parser; +using namespace lbug::function; + +namespace lbug { +namespace binder { + +std::shared_ptr ExpressionBinder::bindBooleanExpression( + const ParsedExpression& parsedExpression) { + expression_vector children; + for (auto i = 0u; i < parsedExpression.getNumChildren(); ++i) { + children.push_back(bindExpression(*parsedExpression.getChild(i))); + } + return bindBooleanExpression(parsedExpression.getExpressionType(), children); +} + +std::shared_ptr ExpressionBinder::bindBooleanExpression(ExpressionType expressionType, + const expression_vector& children) { + expression_vector childrenAfterCast; + std::vector inputTypeIDs; + for (auto& child : children) { + childrenAfterCast.push_back(implicitCastIfNecessary(child, LogicalType::BOOL())); + inputTypeIDs.push_back(LogicalTypeID::BOOL); + } + auto functionName = ExpressionTypeUtil::toString(expressionType); + scalar_func_exec_t execFunc; + VectorBooleanFunction::bindExecFunction(expressionType, childrenAfterCast, execFunc); + scalar_func_select_t selectFunc; + VectorBooleanFunction::bindSelectFunction(expressionType, childrenAfterCast, selectFunc); + auto bindData = std::make_unique(LogicalType::BOOL()); + 
auto uniqueExpressionName = + ScalarFunctionExpression::getUniqueName(functionName, childrenAfterCast); + auto func = std::make_unique(functionName, inputTypeIDs, LogicalTypeID::BOOL, + execFunc, selectFunc); + return std::make_shared(expressionType, std::move(func), + std::move(bindData), std::move(childrenAfterCast), uniqueExpressionName); +} + +std::shared_ptr ExpressionBinder::combineBooleanExpressions( + ExpressionType expressionType, std::shared_ptr left, + std::shared_ptr right) { + if (left == nullptr) { + return right; + } else if (right == nullptr) { + return left; + } else { + return bindBooleanExpression(expressionType, + expression_vector{std::move(left), std::move(right)}); + } +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind_expression/bind_case_expression.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind_expression/bind_case_expression.cpp new file mode 100644 index 0000000000..695c5fb895 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind_expression/bind_case_expression.cpp @@ -0,0 +1,75 @@ +#include "binder/binder.h" +#include "binder/expression/case_expression.h" +#include "binder/expression/expression_util.h" +#include "binder/expression_binder.h" +#include "parser/expression/parsed_case_expression.h" + +using namespace lbug::common; +using namespace lbug::parser; + +namespace lbug { +namespace binder { + +std::shared_ptr ExpressionBinder::bindCaseExpression( + const ParsedExpression& parsedExpression) { + auto& parsedCaseExpression = parsedExpression.constCast(); + auto resultType = LogicalType::ANY(); + // Resolve result type by checking each then expression type. + for (auto i = 0u; i < parsedCaseExpression.getNumCaseAlternative(); ++i) { + auto alternative = parsedCaseExpression.getCaseAlternative(i); + auto boundThen = bindExpression(*alternative->thenExpression); + if (boundThen->getDataType().getLogicalTypeID() != LogicalTypeID::ANY) { + resultType = boundThen->getDataType().copy(); + } + } + // Resolve result type by else expression if above resolving fails. + if (resultType.getLogicalTypeID() == LogicalTypeID::ANY && + parsedCaseExpression.hasElseExpression()) { + auto elseExpression = bindExpression(*parsedCaseExpression.getElseExpression()); + resultType = elseExpression->getDataType().copy(); + } + auto name = binder->getUniqueExpressionName(parsedExpression.getRawName()); + // bind ELSE ... + std::shared_ptr elseExpression; + if (parsedCaseExpression.hasElseExpression()) { + elseExpression = bindExpression(*parsedCaseExpression.getElseExpression()); + } else { + elseExpression = createNullLiteralExpression(); + } + elseExpression = implicitCastIfNecessary(elseExpression, resultType); + auto boundCaseExpression = + make_shared(resultType.copy(), std::move(elseExpression), name); + // bind WHEN ... THEN ... 
+ if (parsedCaseExpression.hasCaseExpression()) { + auto boundCase = bindExpression(*parsedCaseExpression.getCaseExpression()); + for (auto i = 0u; i < parsedCaseExpression.getNumCaseAlternative(); ++i) { + auto caseAlternative = parsedCaseExpression.getCaseAlternative(i); + auto boundWhen = bindExpression(*caseAlternative->whenExpression); + boundWhen = implicitCastIfNecessary(boundWhen, boundCase->dataType); + // rewrite "CASE a.age WHEN 1" as "CASE WHEN a.age = 1" + if (ExpressionUtil::isNullLiteral(*boundWhen)) { + boundWhen = bindNullOperatorExpression(ExpressionType::IS_NULL, + expression_vector{boundWhen}); + } else { + boundWhen = bindComparisonExpression(ExpressionType::EQUALS, + expression_vector{boundCase, boundWhen}); + } + auto boundThen = bindExpression(*caseAlternative->thenExpression); + boundThen = implicitCastIfNecessary(boundThen, resultType); + boundCaseExpression->addCaseAlternative(boundWhen, boundThen); + } + } else { + for (auto i = 0u; i < parsedCaseExpression.getNumCaseAlternative(); ++i) { + auto caseAlternative = parsedCaseExpression.getCaseAlternative(i); + auto boundWhen = bindExpression(*caseAlternative->whenExpression); + boundWhen = implicitCastIfNecessary(boundWhen, LogicalType::BOOL()); + auto boundThen = bindExpression(*caseAlternative->thenExpression); + boundThen = implicitCastIfNecessary(boundThen, resultType); + boundCaseExpression->addCaseAlternative(boundWhen, boundThen); + } + } + return boundCaseExpression; +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind_expression/bind_comparison_expression.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind_expression/bind_comparison_expression.cpp new file mode 100644 index 0000000000..db1576e31c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind_expression/bind_comparison_expression.cpp @@ -0,0 +1,96 @@ +#include "binder/binder.h" +#include "binder/expression/expression_util.h" +#include "binder/expression/scalar_function_expression.h" +#include "binder/expression_binder.h" +#include "catalog/catalog.h" +#include "common/exception/binder.h" +#include "function/built_in_function_utils.h" +#include "transaction/transaction.h" + +using namespace lbug::common; +using namespace lbug::catalog; +using namespace lbug::parser; +using namespace lbug::function; + +namespace lbug { +namespace binder { + +std::shared_ptr ExpressionBinder::bindComparisonExpression( + const ParsedExpression& parsedExpression) { + expression_vector children; + for (auto i = 0u; i < parsedExpression.getNumChildren(); ++i) { + auto child = bindExpression(*parsedExpression.getChild(i)); + children.push_back(std::move(child)); + } + return bindComparisonExpression(parsedExpression.getExpressionType(), children); +} + +static bool isNodeOrRel(const Expression& expression) { + switch (expression.getDataType().getLogicalTypeID()) { + case LogicalTypeID::NODE: + case LogicalTypeID::REL: + return true; + default: + return false; + } +} + +std::shared_ptr ExpressionBinder::bindComparisonExpression( + ExpressionType expressionType, const expression_vector& children) { + // Rewrite node or rel comparison + KU_ASSERT(children.size() == 2); + if (isNodeOrRel(*children[0]) && isNodeOrRel(*children[1])) { + expression_vector newChildren; + newChildren.push_back(children[0]->constCast().getInternalID()); + newChildren.push_back(children[1]->constCast().getInternalID()); + return bindComparisonExpression(expressionType, newChildren); + } + + auto catalog = Catalog::Get(*context); 
+ auto transaction = transaction::Transaction::Get(*context); + auto functionName = ExpressionTypeUtil::toString(expressionType); + LogicalType combinedType(LogicalTypeID::ANY); + if (!ExpressionUtil::tryCombineDataType(children, combinedType)) { + throw BinderException(stringFormat("Type Mismatch: Cannot compare types {} and {}", + children[0]->dataType.toString(), children[1]->dataType.toString())); + } + if (combinedType.getLogicalTypeID() == LogicalTypeID::ANY) { + combinedType = LogicalType(LogicalTypeID::INT8); + } + std::vector childrenTypes; + for (auto i = 0u; i < children.size(); i++) { + childrenTypes.push_back(combinedType.copy()); + } + auto entry = + catalog->getFunctionEntry(transaction, functionName)->ptrCast(); + auto function = BuiltInFunctionsUtils::matchFunction(functionName, childrenTypes, entry) + ->ptrCast(); + expression_vector childrenAfterCast; + for (auto i = 0u; i < children.size(); ++i) { + if (children[i]->dataType != combinedType) { + childrenAfterCast.push_back(forceCast(children[i], combinedType)); + } else { + childrenAfterCast.push_back(children[i]); + } + } + if (function->bindFunc) { + // Resolve exec and select function if necessary + // Only used for decimal at the moment. See `bindDecimalCompare`. + function->bindFunc({childrenAfterCast, function, nullptr, + std::vector{} /* optionalParams */}); + } + auto bindData = std::make_unique(LogicalType(function->returnTypeID)); + auto uniqueExpressionName = + ScalarFunctionExpression::getUniqueName(function->name, childrenAfterCast); + return std::make_shared(expressionType, function->copy(), + std::move(bindData), std::move(childrenAfterCast), uniqueExpressionName); +} + +std::shared_ptr ExpressionBinder::createEqualityComparisonExpression( + std::shared_ptr left, std::shared_ptr right) { + return bindComparisonExpression(ExpressionType::EQUALS, + expression_vector{std::move(left), std::move(right)}); +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind_expression/bind_function_expression.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind_expression/bind_function_expression.cpp new file mode 100644 index 0000000000..18e11c96ea --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind_expression/bind_function_expression.cpp @@ -0,0 +1,215 @@ +#include "binder/binder.h" +#include "binder/expression/aggregate_function_expression.h" +#include "binder/expression/scalar_function_expression.h" +#include "binder/expression_binder.h" +#include "catalog/catalog.h" +#include "common/exception/binder.h" +#include "function/built_in_function_utils.h" +#include "function/cast/vector_cast_functions.h" +#include "function/rewrite_function.h" +#include "function/scalar_macro_function.h" +#include "parser/expression/parsed_expression_visitor.h" +#include "parser/expression/parsed_function_expression.h" +#include "transaction/transaction.h" + +using namespace lbug::common; +using namespace lbug::parser; +using namespace lbug::function; +using namespace lbug::catalog; + +namespace lbug { +namespace binder { + +std::shared_ptr ExpressionBinder::bindFunctionExpression(const ParsedExpression& expr) { + auto funcExpr = expr.constPtrCast(); + auto functionName = funcExpr->getNormalizedFunctionName(); + auto transaction = transaction::Transaction::Get(*context); + auto catalog = Catalog::Get(*context); + auto entry = catalog->getFunctionEntry(transaction, functionName); + switch (entry->getType()) { + case CatalogEntryType::SCALAR_FUNCTION_ENTRY: + return 
bindScalarFunctionExpression(expr, functionName); + case CatalogEntryType::REWRITE_FUNCTION_ENTRY: + return bindRewriteFunctionExpression(expr); + case CatalogEntryType::AGGREGATE_FUNCTION_ENTRY: + return bindAggregateFunctionExpression(expr, functionName, funcExpr->getIsDistinct()); + case CatalogEntryType::SCALAR_MACRO_ENTRY: + return bindMacroExpression(expr, functionName); + default: + throw BinderException( + stringFormat("{} is a {}. Scalar function, aggregate function or macro was expected. ", + functionName, CatalogEntryTypeUtils::toString(entry->getType()))); + } +} + +std::shared_ptr ExpressionBinder::bindScalarFunctionExpression( + const ParsedExpression& parsedExpression, const std::string& functionName) { + expression_vector children; + for (auto i = 0u; i < parsedExpression.getNumChildren(); ++i) { + auto expr = bindExpression(*parsedExpression.getChild(i)); + if (parsedExpression.getChild(i)->hasAlias()) { + expr->setAlias(parsedExpression.getChild(i)->getAlias()); + } + children.push_back(expr); + } + return bindScalarFunctionExpression(children, functionName, + parsedExpression.constCast().getOptionalArguments()); +} + +static std::vector getTypes(const expression_vector& exprs) { + std::vector result; + for (auto& expr : exprs) { + result.push_back(expr->getDataType().copy()); + } + return result; +} + +std::shared_ptr ExpressionBinder::bindScalarFunctionExpression( + const expression_vector& children, const std::string& functionName, + std::vector optionalArguments) { + auto catalog = Catalog::Get(*context); + auto transaction = transaction::Transaction::Get(*context); + auto childrenTypes = getTypes(children); + + auto entry = catalog->getFunctionEntry(transaction, functionName); + + auto function = BuiltInFunctionsUtils::matchFunction(functionName, childrenTypes, + entry->ptrCast()) + ->ptrCast() + ->copy(); + if (children.size() == 2 && children[1]->expressionType == ExpressionType::LAMBDA) { + if (!function->isListLambda) { + throw BinderException(stringFormat("{} does not support lambda input.", functionName)); + } + bindLambdaExpression(*children[0], *children[1]); + } + expression_vector childrenAfterCast; + std::unique_ptr bindData; + auto bindInput = + ScalarBindFuncInput{children, function.get(), context, std::move(optionalArguments)}; + if (functionName == CastAnyFunction::name) { + bindData = function->bindFunc(bindInput); + if (bindData == nullptr) { // No need to cast. + // TODO(Xiyang): We should return a deep copy otherwise the same expression might + // appear in the final projection list repeatedly. + // E.g. RETURN cast([NULL], "INT64[1][]"), cast([NULL], "INT64[1][][]") + return children[0]; + } + auto childAfterCast = children[0]; + if (children[0]->getDataType().getLogicalTypeID() == LogicalTypeID::ANY) { + childAfterCast = implicitCastIfNecessary(children[0], LogicalType::STRING()); + } + childrenAfterCast.push_back(std::move(childAfterCast)); + } else { + if (function->bindFunc) { + bindData = function->bindFunc(bindInput); + } else { + bindData = std::make_unique(LogicalType(function->returnTypeID)); + } + if (!bindData->paramTypes.empty()) { + for (auto i = 0u; i < children.size(); ++i) { + childrenAfterCast.push_back( + implicitCastIfNecessary(children[i], bindData->paramTypes[i])); + } + } else { + for (auto i = 0u; i < children.size(); ++i) { + auto id = function->isVarLength ? 
function->parameterTypeIDs[0] : + function->parameterTypeIDs[i]; + auto type = LogicalType(id); + childrenAfterCast.push_back(implicitCastIfNecessary(children[i], type)); + } + } + } + auto uniqueExpressionName = + ScalarFunctionExpression::getUniqueName(function->name, childrenAfterCast); + return std::make_shared(ExpressionType::FUNCTION, std::move(function), + std::move(bindData), std::move(childrenAfterCast), uniqueExpressionName); +} + +std::shared_ptr ExpressionBinder::bindRewriteFunctionExpression( + const ParsedExpression& expr) { + auto& funcExpr = expr.constCast(); + expression_vector children; + for (auto i = 0u; i < expr.getNumChildren(); ++i) { + children.push_back(bindExpression(*expr.getChild(i))); + } + auto childrenTypes = getTypes(children); + auto functionName = funcExpr.getNormalizedFunctionName(); + auto transaction = transaction::Transaction::Get(*context); + auto entry = Catalog::Get(*context)->getFunctionEntry(transaction, functionName); + auto match = BuiltInFunctionsUtils::matchFunction(functionName, childrenTypes, + entry->ptrCast()); + auto function = match->constPtrCast(); + KU_ASSERT(function->rewriteFunc != nullptr); + auto input = RewriteFunctionBindInput(context, this, children); + return function->rewriteFunc(input); +} + +std::shared_ptr ExpressionBinder::bindAggregateFunctionExpression( + const ParsedExpression& parsedExpression, const std::string& functionName, bool isDistinct) { + std::vector childrenTypes; + expression_vector children; + for (auto i = 0u; i < parsedExpression.getNumChildren(); ++i) { + auto child = bindExpression(*parsedExpression.getChild(i)); + childrenTypes.push_back(child->dataType.copy()); + children.push_back(std::move(child)); + } + auto transaction = transaction::Transaction::Get(*context); + auto entry = Catalog::Get(*context)->getFunctionEntry(transaction, functionName); + auto function = BuiltInFunctionsUtils::matchAggregateFunction(functionName, childrenTypes, + isDistinct, entry->ptrCast()) + ->copy(); + if (function.paramRewriteFunc) { + function.paramRewriteFunc(children); + } + if (functionName == CollectFunction::name && parsedExpression.hasAlias() && + children[0]->getDataType().getLogicalTypeID() == LogicalTypeID::NODE) { + auto& node = children[0]->constCast(); + binder->scope.memorizeTableEntries(parsedExpression.getAlias(), node.getEntries()); + } + auto uniqueExpressionName = + AggregateFunctionExpression::getUniqueName(function.name, children, function.isDistinct); + if (children.empty()) { + uniqueExpressionName = binder->getUniqueExpressionName(uniqueExpressionName); + } + std::unique_ptr bindData; + if (function.bindFunc) { + auto bindInput = ScalarBindFuncInput{children, &function, context, + std::vector{} /* optionalParams */}; + bindData = function.bindFunc(bindInput); + } else { + bindData = std::make_unique(LogicalType(function.returnTypeID)); + } + return std::make_shared(std::move(function), std::move(bindData), + std::move(children), uniqueExpressionName); +} + +std::shared_ptr ExpressionBinder::bindMacroExpression( + const ParsedExpression& parsedExpression, const std::string& macroName) { + auto transaction = transaction::Transaction::Get(*context); + auto scalarMacroFunction = + Catalog::Get(*context)->getScalarMacroFunction(transaction, macroName); + auto macroExpr = scalarMacroFunction->expression->copy(); + auto parameterVals = scalarMacroFunction->getDefaultParameterVals(); + auto& parsedFuncExpr = parsedExpression.constCast(); + auto positionalArgs = scalarMacroFunction->getPositionalArgs(); 
+ if (parsedFuncExpr.getNumChildren() > scalarMacroFunction->getNumArgs() || + parsedFuncExpr.getNumChildren() < positionalArgs.size()) { + throw BinderException{"Invalid number of arguments for macro " + macroName + "."}; + } + // Bind positional arguments. + for (auto i = 0u; i < positionalArgs.size(); i++) { + parameterVals[positionalArgs[i]] = parsedFuncExpr.getChild(i); + } + // Bind arguments with default values. + for (auto i = positionalArgs.size(); i < parsedFuncExpr.getNumChildren(); i++) { + auto parameterName = + scalarMacroFunction->getDefaultParameterName(i - positionalArgs.size()); + parameterVals[parameterName] = parsedFuncExpr.getChild(i); + } + auto replacer = MacroParameterReplacer(parameterVals); + return bindExpression(*replacer.replace(std::move(macroExpr))); +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind_expression/bind_lambda_expression.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind_expression/bind_lambda_expression.cpp new file mode 100644 index 0000000000..79ff810608 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind_expression/bind_lambda_expression.cpp @@ -0,0 +1,37 @@ +#include "binder/binder.h" +#include "binder/expression/expression_util.h" +#include "binder/expression/lambda_expression.h" +#include "parser/expression/parsed_lambda_expression.h" + +using namespace lbug::common; +using namespace lbug::parser; + +namespace lbug { +namespace binder { + +void ExpressionBinder::bindLambdaExpression(const Expression& lambdaInput, + Expression& lambdaExpr) const { + ExpressionUtil::validateDataType(lambdaInput, LogicalTypeID::LIST); + auto& listChildType = ListType::getChildType(lambdaInput.getDataType()); + auto& boundLambdaExpr = lambdaExpr.cast(); + auto& parsedLambdaExpr = + boundLambdaExpr.getParsedLambdaExpr()->constCast(); + auto prevScope = binder->saveScope(); + for (auto& varName : parsedLambdaExpr.getVarNames()) { + binder->createVariable(varName, listChildType); + } + auto funcExpr = + binder->getExpressionBinder()->bindExpression(*parsedLambdaExpr.getFunctionExpr()); + binder->restoreScope(std::move(prevScope)); + boundLambdaExpr.cast(funcExpr->getDataType().copy()); + boundLambdaExpr.setFunctionExpr(std::move(funcExpr)); +} + +std::shared_ptr ExpressionBinder::bindLambdaExpression( + const parser::ParsedExpression& parsedExpr) const { + auto uniqueName = getUniqueName(parsedExpr.getRawName()); + return std::make_shared(parsedExpr.copy(), uniqueName); +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind_expression/bind_literal_expression.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind_expression/bind_literal_expression.cpp new file mode 100644 index 0000000000..2dd049081d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind_expression/bind_literal_expression.cpp @@ -0,0 +1,43 @@ +#include "binder/binder.h" +#include "binder/expression/literal_expression.h" +#include "binder/expression_binder.h" +#include "parser/expression/parsed_literal_expression.h" + +using namespace lbug::parser; +using namespace lbug::common; + +namespace lbug { +namespace binder { + +std::shared_ptr ExpressionBinder::bindLiteralExpression( + const ParsedExpression& parsedExpression) const { + auto& literalExpression = parsedExpression.constCast(); + auto value = literalExpression.getValue(); + if (value.isNull()) { + return createNullLiteralExpression(value); + } + return createLiteralExpression(value); +} + 
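Illustrative note (not part of the diff): the literal binder above routes NULL values to a dedicated null-literal expression, so that later binders in this patch (for example the CASE and UNWIND binders) can still assign the NULL a concrete type once the surrounding expression fixes one. The sketch below only mirrors that dispatch shape with hypothetical stand-in types (ToyValue, ToyExpr); it is an assumption-laden illustration, not the lbug API.

#include <memory>
#include <optional>
#include <string>

// Hypothetical stand-ins; they only mirror the dispatch shape used above.
struct ToyValue {
    std::optional<std::string> payload; // an empty payload models a NULL literal
    bool isNull() const { return !payload.has_value(); }
};

struct ToyExpr {
    enum class Kind { Literal, NullLiteral };
    Kind kind;
    ToyValue value;
    std::string resolvedType; // a NULL literal's type is resolved by a later pass
};

std::shared_ptr<ToyExpr> bindToyLiteral(const ToyValue& v) {
    if (v.isNull()) {
        // Keep NULL distinct so a later pass (e.g. CASE/UNWIND binding) can
        // overwrite resolvedType once neighbouring expressions fix a type.
        return std::make_shared<ToyExpr>(ToyExpr{ToyExpr::Kind::NullLiteral, v, "ANY"});
    }
    return std::make_shared<ToyExpr>(ToyExpr{ToyExpr::Kind::Literal, v, "STRING"});
}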
+std::shared_ptr ExpressionBinder::createLiteralExpression(const Value& value) const { + auto uniqueName = binder->getUniqueExpressionName(value.toString()); + return std::make_unique(value, uniqueName); +} + +std::shared_ptr ExpressionBinder::createLiteralExpression( + const std::string& strVal) const { + return createLiteralExpression(Value(strVal)); +} + +std::shared_ptr ExpressionBinder::createNullLiteralExpression() const { + return make_shared(Value::createNullValue(), + binder->getUniqueExpressionName("NULL")); +} + +std::shared_ptr ExpressionBinder::createNullLiteralExpression( + const Value& value) const { + return make_shared(value, binder->getUniqueExpressionName("NULL")); +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind_expression/bind_null_operator_expression.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind_expression/bind_null_operator_expression.cpp new file mode 100644 index 0000000000..f6d8e48e91 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind_expression/bind_null_operator_expression.cpp @@ -0,0 +1,48 @@ +#include "binder/expression/scalar_function_expression.h" +#include "binder/expression_binder.h" +#include "function/null/vector_null_functions.h" + +using namespace lbug::common; +using namespace lbug::parser; +using namespace lbug::function; + +namespace lbug { +namespace binder { + +std::shared_ptr ExpressionBinder::bindNullOperatorExpression( + const ParsedExpression& parsedExpression) { + expression_vector children; + for (auto i = 0u; i < parsedExpression.getNumChildren(); ++i) { + children.push_back(bindExpression(*parsedExpression.getChild(i))); + } + return bindNullOperatorExpression(parsedExpression.getExpressionType(), children); +} + +std::shared_ptr ExpressionBinder::bindNullOperatorExpression( + ExpressionType expressionType, const expression_vector& children) { + expression_vector childrenAfterCast; + std::vector inputTypeIDs; + for (auto& child : children) { + inputTypeIDs.push_back(child->getDataType().getLogicalTypeID()); + if (child->dataType.getLogicalTypeID() == LogicalTypeID::ANY) { + childrenAfterCast.push_back(implicitCastIfNecessary(child, LogicalType::BOOL())); + } else { + childrenAfterCast.push_back(child); + } + } + auto functionName = ExpressionTypeUtil::toString(expressionType); + function::scalar_func_exec_t execFunc; + function::VectorNullFunction::bindExecFunction(expressionType, childrenAfterCast, execFunc); + function::scalar_func_select_t selectFunc; + function::VectorNullFunction::bindSelectFunction(expressionType, childrenAfterCast, selectFunc); + auto bindData = std::make_unique(LogicalType::BOOL()); + auto uniqueExpressionName = + ScalarFunctionExpression::getUniqueName(functionName, childrenAfterCast); + auto func = std::make_unique(functionName, inputTypeIDs, LogicalTypeID::BOOL, + execFunc, selectFunc); + return make_shared(expressionType, std::move(func), + std::move(bindData), std::move(childrenAfterCast), uniqueExpressionName); +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind_expression/bind_parameter_expression.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind_expression/bind_parameter_expression.cpp new file mode 100644 index 0000000000..de56757708 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind_expression/bind_parameter_expression.cpp @@ -0,0 +1,26 @@ +#include "binder/expression/parameter_expression.h" +#include "binder/expression_binder.h" +#include 
"common/exception/binder.h" +#include "parser/expression/parsed_parameter_expression.h" + +using namespace lbug::common; +using namespace lbug::parser; + +namespace lbug { +namespace binder { + +std::shared_ptr ExpressionBinder::bindParameterExpression( + const ParsedExpression& parsedExpression) { + auto& parsedParameterExpression = parsedExpression.constCast(); + auto parameterName = parsedParameterExpression.getParameterName(); + if (knownParameters.contains(parameterName)) { + return make_shared(parameterName, *knownParameters.at(parameterName)); + } + // LCOV_EXCL_START + throw BinderException( + stringFormat("Cannot find parameter {}. This should not happen.", parameterName)); + // LCOV_EXCL_STOP +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind_expression/bind_property_expression.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind_expression/bind_property_expression.cpp new file mode 100644 index 0000000000..2010a2f6cc --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind_expression/bind_property_expression.cpp @@ -0,0 +1,124 @@ +#include "binder/binder.h" +#include "binder/expression/expression_util.h" +#include "binder/expression/node_rel_expression.h" +#include "binder/expression_binder.h" +#include "common/cast.h" +#include "common/exception/binder.h" +#include "common/string_format.h" +#include "function/struct/vector_struct_functions.h" +#include "parser/expression/parsed_property_expression.h" + +using namespace lbug::common; +using namespace lbug::parser; +using namespace lbug::catalog; + +namespace lbug { +namespace binder { + +static bool isNodeOrRelPattern(const Expression& expression) { + return ExpressionUtil::isNodePattern(expression) || ExpressionUtil::isRelPattern(expression); +} + +static bool isStructPattern(const Expression& expression) { + auto logicalTypeID = expression.getDataType().getLogicalTypeID(); + return logicalTypeID == LogicalTypeID::NODE || logicalTypeID == LogicalTypeID::REL || + logicalTypeID == LogicalTypeID::STRUCT; +} + +expression_vector ExpressionBinder::bindPropertyStarExpression( + const parser::ParsedExpression& parsedExpression) { + auto child = bindExpression(*parsedExpression.getChild(0)); + if (isNodeOrRelPattern(*child)) { + return bindNodeOrRelPropertyStarExpression(*child); + } else if (isStructPattern(*child)) { + return bindStructPropertyStarExpression(child); + } else { + throw BinderException(stringFormat("Cannot bind property for expression {} with type {}.", + child->toString(), ExpressionTypeUtil::toString(child->expressionType))); + } +} + +expression_vector ExpressionBinder::bindNodeOrRelPropertyStarExpression(const Expression& child) { + expression_vector result; + auto& nodeOrRel = child.constCast(); + for (auto& property : nodeOrRel.getPropertyExpressions()) { + if (Binder::reservedInPropertyLookup(property->getPropertyName())) { + continue; + } + result.push_back(property); + } + return result; +} + +expression_vector ExpressionBinder::bindStructPropertyStarExpression( + const std::shared_ptr& child) { + expression_vector result; + const auto& childType = child->getDataType(); + for (auto& field : StructType::getFields(childType)) { + result.push_back(bindStructPropertyExpression(child, field.getName())); + } + return result; +} + +std::shared_ptr ExpressionBinder::bindPropertyExpression( + const ParsedExpression& parsedExpression) { + auto& propertyExpression = parsedExpression.constCast(); + if (propertyExpression.isStar()) { + throw 
BinderException(stringFormat("Cannot bind {} as a single property expression.", + parsedExpression.toString())); + } + auto propertyName = propertyExpression.getPropertyName(); + auto child = bindExpression(*parsedExpression.getChild(0)); + ExpressionUtil::validateDataType(*child, + std::vector{LogicalTypeID::NODE, LogicalTypeID::REL, LogicalTypeID::STRUCT, + LogicalTypeID::ANY}); + if (config.bindOrderByAfterAggregate) { + // See the declaration of this field for more information. + return bindStructPropertyExpression(child, propertyName); + } + if (isNodeOrRelPattern(*child)) { + if (Binder::reservedInPropertyLookup(propertyName)) { + // Note we don't expose direct access to internal properties in case user tries to + // modify them. However, we can expose indirect read-only access through function e.g. + // ID(). + throw BinderException( + propertyName + " is reserved for system usage. External access is not allowed."); + } + return bindNodeOrRelPropertyExpression(*child, propertyName); + } else if (isStructPattern(*child)) { + return bindStructPropertyExpression(child, propertyName); + } else if (child->getDataType().getLogicalTypeID() == LogicalTypeID::ANY) { + return createVariableExpression(LogicalType::ANY(), binder->getUniqueExpressionName("")); + } else { + throw BinderException(stringFormat("Cannot bind property for expression {} with type {}.", + child->toString(), ExpressionTypeUtil::toString(child->expressionType))); + } +} + +std::shared_ptr ExpressionBinder::bindNodeOrRelPropertyExpression( + const Expression& child, const std::string& propertyName) { + auto& nodeOrRel = child.constCast(); + // TODO(Xiyang): we should be able to remove l97-l100 after removing propertyDataExprs from node + // & rel expression. + if (propertyName == InternalKeyword::ID && + child.dataType.getLogicalTypeID() == common::LogicalTypeID::NODE) { + auto& node = ku_dynamic_cast(child); + return node.getInternalID(); + } + if (!nodeOrRel.hasPropertyExpression(propertyName)) { + throw BinderException( + "Cannot find property " + propertyName + " for " + child.toString() + "."); + } + // We always create new object when binding expression except when referring to an existing + // alias when binding variables. 
+ return nodeOrRel.getPropertyExpression(propertyName)->copy(); +} + +std::shared_ptr ExpressionBinder::bindStructPropertyExpression( + std::shared_ptr child, const std::string& propertyName) { + auto children = expression_vector{std::move(child), createLiteralExpression(propertyName)}; + return bindScalarFunctionExpression(children, function::StructExtractFunctions::name); +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind_expression/bind_subquery_expression.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind_expression/bind_subquery_expression.cpp new file mode 100644 index 0000000000..4e09f7bf6b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind_expression/bind_subquery_expression.cpp @@ -0,0 +1,74 @@ +#include "binder/binder.h" +#include "binder/expression/aggregate_function_expression.h" +#include "binder/expression/subquery_expression.h" +#include "binder/expression_binder.h" +#include "catalog/catalog.h" +#include "common/types/value/value.h" +#include "function/aggregate/count_star.h" +#include "function/built_in_function_utils.h" +#include "parser/expression/parsed_subquery_expression.h" +#include "transaction/transaction.h" + +using namespace lbug::parser; +using namespace lbug::common; +using namespace lbug::function; + +namespace lbug { +namespace binder { + +std::shared_ptr ExpressionBinder::bindSubqueryExpression( + const ParsedExpression& parsedExpr) { + auto& subqueryExpr = ku_dynamic_cast(parsedExpr); + auto prevScope = binder->saveScope(); + auto boundGraphPattern = binder->bindGraphPattern(subqueryExpr.getPatternElements()); + if (subqueryExpr.hasWhereClause()) { + boundGraphPattern.where = binder->bindWhereExpression(*subqueryExpr.getWhereClause()); + } + binder->rewriteMatchPattern(boundGraphPattern); + auto subqueryType = subqueryExpr.getSubqueryType(); + auto dataType = + subqueryType == SubqueryType::COUNT ? LogicalType::INT64() : LogicalType::BOOL(); + auto rawName = subqueryExpr.getRawName(); + auto uniqueName = binder->getUniqueExpressionName(rawName); + auto boundSubqueryExpr = make_shared(subqueryType, std::move(dataType), + std::move(boundGraphPattern.queryGraphCollection), uniqueName, std::move(rawName)); + boundSubqueryExpr->setWhereExpression(boundGraphPattern.where); + // Bind projection + auto entry = catalog::Catalog::Get(*context)->getFunctionEntry( + transaction::Transaction::Get(*context), CountStarFunction::name); + auto function = BuiltInFunctionsUtils::matchAggregateFunction(CountStarFunction::name, + std::vector{}, false, entry->ptrCast()); + auto bindData = std::make_unique(LogicalType(function->returnTypeID)); + auto countStarExpr = + std::make_shared(function->copy(), std::move(bindData), + expression_vector{}, binder->getUniqueExpressionName(CountStarFunction::name)); + boundSubqueryExpr->setCountStarExpr(countStarExpr); + std::shared_ptr projectionExpr; + switch (subqueryType) { + case SubqueryType::COUNT: { + // Rewrite COUNT subquery as COUNT(*) + projectionExpr = countStarExpr; + } break; + case SubqueryType::EXISTS: { + // Rewrite EXISTS subquery as COUNT(*) > 0 + auto literalExpr = createLiteralExpression(Value(static_cast(0))); + projectionExpr = bindComparisonExpression(ExpressionType::GREATER_THAN, + expression_vector{countStarExpr, literalExpr}); + } break; + default: + KU_UNREACHABLE; + } + // Use the same unique identifier for projection & subquery expression. We will replace subquery + // expression with projection expression during processing. 
+ projectionExpr->setUniqueName(uniqueName); + boundSubqueryExpr->setProjectionExpr(projectionExpr); + if (subqueryExpr.hasHint()) { + boundSubqueryExpr->setHint(binder->bindJoinHint( + *boundSubqueryExpr->getQueryGraphCollection(), *subqueryExpr.getHint())); + } + binder->restoreScope(std::move(prevScope)); + return boundSubqueryExpr; +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind_expression/bind_variable_expression.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind_expression/bind_variable_expression.cpp new file mode 100644 index 0000000000..e84fb1211e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bind_expression/bind_variable_expression.cpp @@ -0,0 +1,41 @@ +#include "binder/binder.h" +#include "binder/expression/variable_expression.h" +#include "binder/expression_binder.h" +#include "common/exception/binder.h" +#include "common/exception/message.h" +#include "parser/expression/parsed_variable_expression.h" + +using namespace lbug::common; +using namespace lbug::parser; + +namespace lbug { +namespace binder { + +std::shared_ptr ExpressionBinder::bindVariableExpression( + const ParsedExpression& parsedExpression) const { + auto& variableExpression = ku_dynamic_cast(parsedExpression); + auto variableName = variableExpression.getVariableName(); + return bindVariableExpression(variableName); +} + +std::shared_ptr ExpressionBinder::bindVariableExpression( + const std::string& varName) const { + if (binder->scope.contains(varName)) { + return binder->scope.getExpression(varName); + } + throw BinderException(ExceptionMessage::variableNotInScope(varName)); +} + +std::shared_ptr ExpressionBinder::createVariableExpression(LogicalType logicalType, + std::string_view name) const { + return createVariableExpression(std::move(logicalType), std::string(name)); +} + +std::shared_ptr ExpressionBinder::createVariableExpression(LogicalType logicalType, + std::string name) const { + return std::make_shared(std::move(logicalType), + binder->getUniqueExpressionName(name), std::move(name)); +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/binder.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/binder.cpp new file mode 100644 index 0000000000..d3cfb679a3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/binder.cpp @@ -0,0 +1,278 @@ +#include "binder/binder.h" + +#include "binder/bound_statement_rewriter.h" +#include "catalog/catalog.h" +#include "common/copier_config/csv_reader_config.h" +#include "common/exception/binder.h" +#include "common/string_format.h" +#include "common/string_utils.h" +#include "function/built_in_function_utils.h" +#include "function/table/table_function.h" +#include "parser/statement.h" +#include "processor/operator/persistent/reader/csv/parallel_csv_reader.h" +#include "processor/operator/persistent/reader/csv/serial_csv_reader.h" +#include "processor/operator/persistent/reader/npy/npy_reader.h" +#include "processor/operator/persistent/reader/parquet/parquet_reader.h" +#include "transaction/transaction.h" + +using namespace lbug::catalog; +using namespace lbug::common; +using namespace lbug::function; +using namespace lbug::parser; +using namespace lbug::processor; + +namespace lbug { +namespace binder { + +std::unique_ptr Binder::bind(const Statement& statement) { + std::unique_ptr boundStatement; + switch (statement.getStatementType()) { + case StatementType::CREATE_TABLE: { + boundStatement = bindCreateTable(statement); + } break; 
+ case StatementType::CREATE_TYPE: { + boundStatement = bindCreateType(statement); + } break; + case StatementType::CREATE_SEQUENCE: { + boundStatement = bindCreateSequence(statement); + } break; + case StatementType::COPY_FROM: { + boundStatement = bindCopyFromClause(statement); + } break; + case StatementType::COPY_TO: { + boundStatement = bindCopyToClause(statement); + } break; + case StatementType::DROP: { + boundStatement = bindDrop(statement); + } break; + case StatementType::ALTER: { + boundStatement = bindAlter(statement); + } break; + case StatementType::QUERY: { + boundStatement = bindQuery(statement); + } break; + case StatementType::STANDALONE_CALL: { + boundStatement = bindStandaloneCall(statement); + } break; + case StatementType::STANDALONE_CALL_FUNCTION: { + boundStatement = bindStandaloneCallFunction(statement); + } break; + case StatementType::EXPLAIN: { + boundStatement = bindExplain(statement); + } break; + case StatementType::CREATE_MACRO: { + boundStatement = bindCreateMacro(statement); + } break; + case StatementType::TRANSACTION: { + boundStatement = bindTransaction(statement); + } break; + case StatementType::EXTENSION: { + boundStatement = bindExtension(statement); + } break; + case StatementType::EXPORT_DATABASE: { + boundStatement = bindExportDatabaseClause(statement); + } break; + case StatementType::IMPORT_DATABASE: { + boundStatement = bindImportDatabaseClause(statement); + } break; + case StatementType::ATTACH_DATABASE: { + boundStatement = bindAttachDatabase(statement); + } break; + case StatementType::DETACH_DATABASE: { + boundStatement = bindDetachDatabase(statement); + } break; + case StatementType::USE_DATABASE: { + boundStatement = bindUseDatabase(statement); + } break; + case StatementType::EXTENSION_CLAUSE: { + boundStatement = bindExtensionClause(statement); + } break; + default: { + KU_UNREACHABLE; + } + } + BoundStatementRewriter::rewrite(*boundStatement, *clientContext); + return boundStatement; +} + +std::shared_ptr Binder::bindWhereExpression(const ParsedExpression& parsedExpression) { + auto whereExpression = expressionBinder.bindExpression(parsedExpression); + expressionBinder.implicitCastIfNecessary(whereExpression, LogicalType::BOOL()); + return whereExpression; +} + +std::shared_ptr Binder::createVariable(std::string_view name, LogicalTypeID typeID) { + return createVariable(std::string(name), LogicalType{typeID}); +} + +std::shared_ptr Binder::createVariable(const std::string& name, + LogicalTypeID logicalTypeID) { + return createVariable(name, LogicalType{logicalTypeID}); +} + +std::shared_ptr Binder::createVariable(const std::string& name, + const LogicalType& dataType) { + if (scope.contains(name)) { + throw BinderException("Variable " + name + " already exists."); + } + auto expression = expressionBinder.createVariableExpression(dataType.copy(), name); + expression->setAlias(name); + addToScope(name, expression); + return expression; +} + +std::shared_ptr Binder::createInvisibleVariable(const std::string& name, + const LogicalType& dataType) const { + auto expression = expressionBinder.createVariableExpression(dataType.copy(), name); + expression->setAlias(name); + return expression; +} + +expression_vector Binder::createVariables(const std::vector& names, + const std::vector& types) { + KU_ASSERT(names.size() == types.size()); + expression_vector variables; + for (auto i = 0u; i < names.size(); ++i) { + variables.push_back(createVariable(names[i], types[i])); + } + return variables; +} + +expression_vector 
Binder::createInvisibleVariables(const std::vector& names, + const std::vector& types) const { + KU_ASSERT(names.size() == types.size()); + expression_vector variables; + for (auto i = 0u; i < names.size(); ++i) { + variables.push_back(createInvisibleVariable(names[i], types[i])); + } + return variables; +} + +std::string Binder::getUniqueExpressionName(const std::string& name) { + return "_" + std::to_string(lastExpressionId++) + "_" + name; +} + +struct ReservedNames { + // Column name that might conflict with internal names. + static std::unordered_set getColumnNames() { + return { + InternalKeyword::ID, + InternalKeyword::LABEL, + InternalKeyword::SRC, + InternalKeyword::DST, + InternalKeyword::DIRECTION, + InternalKeyword::LENGTH, + InternalKeyword::NODES, + InternalKeyword::RELS, + InternalKeyword::PLACE_HOLDER, + StringUtils::getUpper(InternalKeyword::ROW_OFFSET), + StringUtils::getUpper(InternalKeyword::SRC_OFFSET), + StringUtils::getUpper(InternalKeyword::DST_OFFSET), + }; + } + + // Properties that should be hidden from user access. + static std::unordered_set getPropertyLookupName() { + return { + InternalKeyword::ID, + }; + } +}; + +bool Binder::reservedInColumnName(const std::string& name) { + auto normalizedName = StringUtils::getUpper(name); + return ReservedNames::getColumnNames().contains(normalizedName); +} + +bool Binder::reservedInPropertyLookup(const std::string& name) { + auto normalizedName = StringUtils::getUpper(name); + return ReservedNames::getPropertyLookupName().contains(normalizedName); +} + +void Binder::addToScope(const std::vector& names, const expression_vector& exprs) { + KU_ASSERT(names.size() == exprs.size()); + for (auto i = 0u; i < names.size(); ++i) { + addToScope(names[i], exprs[i]); + } +} + +void Binder::addToScope(const std::string& name, std::shared_ptr expr) { + // TODO(Xiyang): assert name not in scope. + // Note to Xiyang: I don't think the TODO still stands here. I tried adding the assertion, but + // it failed a few tests. You may want to revisit this TODO. 
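[Editorial sketch] The unique-name scheme in getUniqueExpressionName above is simple but worth spelling out; this standalone version (with a caller-supplied counter instead of the binder's member) shows why two identically written expressions never collide.

#include <cassert>
#include <cstdint>
#include <string>

// Same "_<counter>_<name>" scheme as above, using a local counter.
std::string uniqueName(uint64_t& lastId, const std::string& name) {
    return "_" + std::to_string(lastId++) + "_" + name;
}

int main() {
    uint64_t id = 0;
    assert(uniqueName(id, "a") == "_0_a");
    assert(uniqueName(id, "a") == "_1_a"); // same user-visible name, distinct binding
    return 0;
}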
+ scope.addExpression(name, std::move(expr)); +} + +BinderScope Binder::saveScope() const { + return scope.copy(); +} + +void Binder::restoreScope(BinderScope prevScope) { + scope = std::move(prevScope); +} + +void Binder::replaceExpressionInScope(const std::string& oldName, const std::string& newName, + std::shared_ptr expression) { + scope.replaceExpression(oldName, newName, expression); +} + +TableFunction Binder::getScanFunction(const FileTypeInfo& typeInfo, + const FileScanInfo& fileScanInfo) const { + Function* func = nullptr; + std::vector inputTypes; + inputTypes.push_back(LogicalType::STRING()); + auto catalog = Catalog::Get(*clientContext); + auto transaction = transaction::Transaction::Get(*clientContext); + switch (typeInfo.fileType) { + case FileType::PARQUET: { + auto entry = catalog->getFunctionEntry(transaction, ParquetScanFunction::name); + func = BuiltInFunctionsUtils::matchFunction(ParquetScanFunction::name, inputTypes, + entry->ptrCast()); + } break; + case FileType::NPY: { + auto entry = catalog->getFunctionEntry(transaction, NpyScanFunction::name); + func = BuiltInFunctionsUtils::matchFunction(NpyScanFunction::name, inputTypes, + entry->ptrCast()); + } break; + case FileType::CSV: { + bool containCompressedCSV = std::any_of(fileScanInfo.filePaths.begin(), + fileScanInfo.filePaths.end(), [&](const auto& file) { + return VirtualFileSystem::GetUnsafe(*clientContext)->isCompressedFile(file); + }); + auto csvConfig = CSVReaderConfig::construct(fileScanInfo.options); + // Parallel CSV scanning is only allowed: + // 1. No newline character inside the csv body. + // 2. The CSV file to scan is not compressed (because we couldn't perform seek in such + // case). + // 3. Not explicitly set by the user to use the serial csv reader. + auto name = (csvConfig.parallel && !containCompressedCSV) ? ParallelCSVScan::name : + SerialCSVScan::name; + auto entry = catalog->getFunctionEntry(transaction, name); + func = BuiltInFunctionsUtils::matchFunction(name, inputTypes, + entry->ptrCast()); + } break; + case FileType::UNKNOWN: { + try { + auto name = stringFormat("{}_SCAN", typeInfo.fileTypeStr); + auto entry = catalog->getFunctionEntry(transaction, name); + func = BuiltInFunctionsUtils::matchFunction(name, inputTypes, + entry->ptrCast()); + } catch (...) { + if (typeInfo.fileTypeStr == "") { + throw BinderException{"Cannot infer the format of the given file. Please " + "set the file format explicitly by (file_format=)."}; + } + throw BinderException{ + stringFormat("Cannot load from file type {}. 
If this file type is part of a lbug " + "extension please load the extension then try again.", + typeInfo.fileTypeStr)}; + } + } break; + default: + KU_UNREACHABLE; + } + return *func->ptrCast(); +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/binder_scope.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/binder_scope.cpp new file mode 100644 index 0000000000..765ed33d69 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/binder_scope.cpp @@ -0,0 +1,27 @@ +#include "binder/binder_scope.h" + +namespace lbug { +namespace binder { + +void BinderScope::addExpression(const std::string& varName, + std::shared_ptr expression) { + nameToExprIdx.insert({varName, expressions.size()}); + expressions.push_back(std::move(expression)); +} + +void BinderScope::replaceExpression(const std::string& oldName, const std::string& newName, + std::shared_ptr expression) { + KU_ASSERT(nameToExprIdx.contains(oldName)); + auto idx = nameToExprIdx.at(oldName); + expressions[idx] = std::move(expression); + nameToExprIdx.erase(oldName); + nameToExprIdx.insert({newName, idx}); +} + +void BinderScope::clear() { + expressions.clear(); + nameToExprIdx.clear(); +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bound_scan_source.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bound_scan_source.cpp new file mode 100644 index 0000000000..166b37f121 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bound_scan_source.cpp @@ -0,0 +1,36 @@ +#include "binder/bound_scan_source.h" + +using namespace lbug::common; + +namespace lbug { +namespace binder { + +expression_vector BoundTableScanSource::getWarningColumns() const { + expression_vector warningDataExprs; + auto& columns = info.bindData->columns; + switch (type) { + case ScanSourceType::FILE: { + auto bindData = info.bindData->constPtrCast(); + for (auto i = bindData->numWarningDataColumns; i >= 1; --i) { + KU_ASSERT(i < columns.size()); + warningDataExprs.push_back(columns[columns.size() - i]); + } + } break; + default: + break; + } + return warningDataExprs; +} + +bool BoundTableScanSource::getIgnoreErrorsOption() const { + return info.bindData->getIgnoreErrorsOption(); +} + +bool BoundQueryScanSource::getIgnoreErrorsOption() const { + return info.options.contains(CopyConstants::IGNORE_ERRORS_OPTION_NAME) ? 
+ info.options.at(CopyConstants::IGNORE_ERRORS_OPTION_NAME).getValue() : + CopyConstants::DEFAULT_IGNORE_ERRORS; +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bound_statement_result.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bound_statement_result.cpp new file mode 100644 index 0000000000..10fc8985fb --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bound_statement_result.cpp @@ -0,0 +1,20 @@ +#include "binder/bound_statement_result.h" + +#include "binder/expression/literal_expression.h" + +using namespace lbug::common; + +namespace lbug { +namespace binder { + +BoundStatementResult BoundStatementResult::createSingleStringColumnResult( + const std::string& columnName) { + auto result = BoundStatementResult(); + auto value = Value(LogicalType::STRING(), columnName); + auto stringColumn = std::make_shared(std::move(value), columnName); + result.addColumn(columnName, stringColumn); + return result; +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bound_statement_rewriter.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bound_statement_rewriter.cpp new file mode 100644 index 0000000000..368db62743 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bound_statement_rewriter.cpp @@ -0,0 +1,27 @@ +#include "binder/bound_statement_rewriter.h" + +#include "binder/rewriter/match_clause_pattern_label_rewriter.h" +#include "binder/rewriter/normalized_query_part_match_rewriter.h" +#include "binder/rewriter/with_clause_projection_rewriter.h" +#include "binder/visitor/default_type_solver.h" + +namespace lbug { +namespace binder { + +void BoundStatementRewriter::rewrite(BoundStatement& boundStatement, + main::ClientContext& clientContext) { + auto withClauseProjectionRewriter = WithClauseProjectionRewriter(); + withClauseProjectionRewriter.visitUnsafe(boundStatement); + + auto normalizedQueryPartMatchRewriter = NormalizedQueryPartMatchRewriter(&clientContext); + normalizedQueryPartMatchRewriter.visitUnsafe(boundStatement); + + auto matchClausePatternLabelRewriter = MatchClausePatternLabelRewriter(clientContext); + matchClausePatternLabelRewriter.visitUnsafe(boundStatement); + + auto defaultTypeSolver = DefaultTypeSolver(); + defaultTypeSolver.visit(boundStatement); +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bound_statement_visitor.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bound_statement_visitor.cpp new file mode 100644 index 0000000000..230e67cd39 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/bound_statement_visitor.cpp @@ -0,0 +1,222 @@ +#include "binder/bound_statement_visitor.h" + +#include "binder/bound_explain.h" +#include "binder/copy/bound_copy_from.h" +#include "binder/copy/bound_copy_to.h" +#include "binder/query/bound_regular_query.h" +#include "common/cast.h" + +using namespace lbug::common; + +namespace lbug { +namespace binder { + +void BoundStatementVisitor::visit(const BoundStatement& statement) { + switch (statement.getStatementType()) { + case StatementType::QUERY: { + visitRegularQuery(statement); + } break; + case StatementType::CREATE_SEQUENCE: { + visitCreateSequence(statement); + } break; + case StatementType::DROP: { + visitDrop(statement); + } break; + case StatementType::CREATE_TABLE: { + visitCreateTable(statement); + } break; + case StatementType::CREATE_TYPE: { + visitCreateType(statement); + } break; + case StatementType::ALTER: { + 
visitAlter(statement); + } break; + case StatementType::COPY_FROM: { + visitCopyFrom(statement); + } break; + case StatementType::COPY_TO: { + visitCopyTo(statement); + } break; + case StatementType::STANDALONE_CALL: { + visitStandaloneCall(statement); + } break; + case StatementType::EXPLAIN: { + visitExplain(statement); + } break; + case StatementType::CREATE_MACRO: { + visitCreateMacro(statement); + } break; + case StatementType::TRANSACTION: { + visitTransaction(statement); + } break; + case StatementType::EXTENSION: { + visitExtension(statement); + } break; + case StatementType::EXPORT_DATABASE: { + visitExportDatabase(statement); + } break; + case StatementType::IMPORT_DATABASE: { + visitImportDatabase(statement); + } break; + case StatementType::ATTACH_DATABASE: { + visitAttachDatabase(statement); + } break; + case StatementType::DETACH_DATABASE: { + visitDetachDatabase(statement); + } break; + case StatementType::USE_DATABASE: { + visitUseDatabase(statement); + } break; + case StatementType::STANDALONE_CALL_FUNCTION: { + visitStandaloneCallFunction(statement); + } break; + case StatementType::EXTENSION_CLAUSE: { + visitExtensionClause(statement); + } break; + default: + KU_UNREACHABLE; + } +} + +void BoundStatementVisitor::visitUnsafe(BoundStatement& statement) { + switch (statement.getStatementType()) { + case StatementType::QUERY: { + visitRegularQueryUnsafe(statement); + } break; + default: + break; + } +} + +void BoundStatementVisitor::visitCopyFrom(const BoundStatement& statement) { + auto& copyFrom = ku_dynamic_cast(statement); + if (copyFrom.getInfo()->source->type == ScanSourceType::QUERY) { + auto querySource = ku_dynamic_cast(copyFrom.getInfo()->source.get()); + visit(*querySource->statement); + } +} + +void BoundStatementVisitor::visitCopyTo(const BoundStatement& statement) { + auto& copyTo = ku_dynamic_cast(statement); + visitRegularQuery(*copyTo.getRegularQuery()); +} + +void BoundStatementVisitor::visitRegularQuery(const BoundStatement& statement) { + auto& regularQuery = ku_dynamic_cast(statement); + for (auto i = 0u; i < regularQuery.getNumSingleQueries(); ++i) { + visitSingleQuery(*regularQuery.getSingleQuery(i)); + } +} + +void BoundStatementVisitor::visitRegularQueryUnsafe(BoundStatement& statement) { + auto& regularQuery = statement.cast(); + for (auto i = 0u; i < regularQuery.getNumSingleQueries(); ++i) { + visitSingleQueryUnsafe(*regularQuery.getSingleQueryUnsafe(i)); + } +} + +void BoundStatementVisitor::visitSingleQuery(const NormalizedSingleQuery& singleQuery) { + for (auto i = 0u; i < singleQuery.getNumQueryParts(); ++i) { + visitQueryPart(*singleQuery.getQueryPart(i)); + } +} + +void BoundStatementVisitor::visitSingleQueryUnsafe(NormalizedSingleQuery& singleQuery) { + for (auto i = 0u; i < singleQuery.getNumQueryParts(); ++i) { + visitQueryPartUnsafe(*singleQuery.getQueryPartUnsafe(i)); + } +} + +void BoundStatementVisitor::visitQueryPart(const NormalizedQueryPart& queryPart) { + for (auto i = 0u; i < queryPart.getNumReadingClause(); ++i) { + visitReadingClause(*queryPart.getReadingClause(i)); + } + for (auto i = 0u; i < queryPart.getNumUpdatingClause(); ++i) { + visitUpdatingClause(*queryPart.getUpdatingClause(i)); + } + if (queryPart.hasProjectionBody()) { + visitProjectionBody(*queryPart.getProjectionBody()); + if (queryPart.hasProjectionBodyPredicate()) { + visitProjectionBodyPredicate(queryPart.getProjectionBodyPredicate()); + } + } +} + +void BoundStatementVisitor::visitQueryPartUnsafe(NormalizedQueryPart& queryPart) { + for (auto i = 0u; i < 
queryPart.getNumReadingClause(); ++i) { + visitReadingClauseUnsafe(*queryPart.getReadingClause(i)); + } + for (auto i = 0u; i < queryPart.getNumUpdatingClause(); ++i) { + visitUpdatingClause(*queryPart.getUpdatingClause(i)); + } + if (queryPart.hasProjectionBody()) { + visitProjectionBody(*queryPart.getProjectionBody()); + if (queryPart.hasProjectionBodyPredicate()) { + visitProjectionBodyPredicate(queryPart.getProjectionBodyPredicate()); + } + } +} + +void BoundStatementVisitor::visitExplain(const BoundStatement& statement) { + visit(*(statement.constCast()).getStatementToExplain()); +} + +void BoundStatementVisitor::visitReadingClause(const BoundReadingClause& readingClause) { + switch (readingClause.getClauseType()) { + case ClauseType::MATCH: { + visitMatch(readingClause); + } break; + case ClauseType::UNWIND: { + visitUnwind(readingClause); + } break; + case ClauseType::TABLE_FUNCTION_CALL: { + visitTableFunctionCall(readingClause); + } break; + case ClauseType::LOAD_FROM: { + visitLoadFrom(readingClause); + } break; + default: + KU_UNREACHABLE; + } +} + +void BoundStatementVisitor::visitReadingClauseUnsafe(BoundReadingClause& readingClause) { + switch (readingClause.getClauseType()) { + case ClauseType::MATCH: { + visitMatchUnsafe(readingClause); + } break; + case ClauseType::UNWIND: { + visitUnwind(readingClause); + } break; + case ClauseType::TABLE_FUNCTION_CALL: { + visitTableFunctionCall(readingClause); + } break; + case ClauseType::LOAD_FROM: { + visitLoadFrom(readingClause); + } break; + default: + KU_UNREACHABLE; + } +} + +void BoundStatementVisitor::visitUpdatingClause(const BoundUpdatingClause& updatingClause) { + switch (updatingClause.getClauseType()) { + case ClauseType::SET: { + visitSet(updatingClause); + } break; + case ClauseType::DELETE_: { + visitDelete(updatingClause); + } break; + case ClauseType::INSERT: { + visitInsert(updatingClause); + } break; + case ClauseType::MERGE: { + visitMerge(updatingClause); + } break; + default: + KU_UNREACHABLE; + } +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/ddl/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/ddl/CMakeLists.txt new file mode 100644 index 0000000000..d0630c0c49 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/ddl/CMakeLists.txt @@ -0,0 +1,9 @@ +add_library( + lbug_binder_ddl + OBJECT + bound_alter_info.cpp + property_definition.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/ddl/bound_alter_info.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/ddl/bound_alter_info.cpp new file mode 100644 index 0000000000..1ef37ccb50 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/ddl/bound_alter_info.cpp @@ -0,0 +1,43 @@ +#include "binder/ddl/bound_alter_info.h" + +namespace lbug { +namespace binder { + +std::string BoundAlterInfo::toString() const { + std::string result = "Operation: "; + switch (alterType) { + case common::AlterType::RENAME: { + auto renameInfo = common::ku_dynamic_cast(extraInfo.get()); + result += "Rename Table " + tableName + " to " + renameInfo->newName; + break; + } + case common::AlterType::ADD_PROPERTY: { + auto addPropInfo = common::ku_dynamic_cast(extraInfo.get()); + result += + "Add Property " + addPropInfo->propertyDefinition.getName() + " to Table " + tableName; + break; + } + case common::AlterType::DROP_PROPERTY: { + auto dropPropInfo = common::ku_dynamic_cast(extraInfo.get()); + result += "Drop Property " + 
dropPropInfo->propertyName + " from Table " + tableName; + break; + } + case common::AlterType::RENAME_PROPERTY: { + auto renamePropInfo = + common::ku_dynamic_cast(extraInfo.get()); + result += "Rename Property " + renamePropInfo->oldName + " to " + renamePropInfo->newName + + " in Table " + tableName; + break; + } + case common::AlterType::COMMENT: { + result += "Comment on Table " + tableName; + break; + } + default: + break; + } + return result; +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/ddl/property_definition.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/ddl/property_definition.cpp new file mode 100644 index 0000000000..f125ea80fe --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/ddl/property_definition.cpp @@ -0,0 +1,34 @@ +#include "binder/ddl/property_definition.h" + +#include "common/serializer/deserializer.h" +#include "common/serializer/serializer.h" +#include "parser/expression/parsed_literal_expression.h" + +using namespace lbug::common; +using namespace lbug::parser; + +namespace lbug { +namespace binder { + +PropertyDefinition::PropertyDefinition(ColumnDefinition columnDefinition) + : columnDefinition{std::move(columnDefinition)} { + defaultExpr = std::make_unique(Value::createNullValue(), "NULL"); +} + +void PropertyDefinition::serialize(Serializer& serializer) const { + serializer.serializeValue(columnDefinition.name); + columnDefinition.type.serialize(serializer); + defaultExpr->serialize(serializer); +} + +PropertyDefinition PropertyDefinition::deserialize(Deserializer& deserializer) { + std::string name; + deserializer.deserializeValue(name); + auto type = LogicalType::deserialize(deserializer); + auto columnDefinition = ColumnDefinition(name, std::move(type)); + auto defaultExpr = ParsedExpression::deserialize(deserializer); + return PropertyDefinition(std::move(columnDefinition), std::move(defaultExpr)); +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression/CMakeLists.txt new file mode 100644 index 0000000000..02da08a61e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression/CMakeLists.txt @@ -0,0 +1,19 @@ +add_library( + lbug_binder_expression + OBJECT + aggregate_function_expression.cpp + case_expression.cpp + expression.cpp + expression_util.cpp + literal_expression.cpp + node_expression.cpp + node_rel_expression.cpp + parameter_expression.cpp + property_expression.cpp + rel_expression.cpp + scalar_function_expression.cpp + variable_expression.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression/aggregate_function_expression.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression/aggregate_function_expression.cpp new file mode 100644 index 0000000000..f88907e843 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression/aggregate_function_expression.cpp @@ -0,0 +1,22 @@ +#include "binder/expression/aggregate_function_expression.h" + +#include "binder/expression/expression_util.h" + +using namespace lbug::common; + +namespace lbug { +namespace binder { + +std::string AggregateFunctionExpression::toStringInternal() const { + return stringFormat("{}({}{})", function.name, function.isDistinct ? 
"DISTINCT " : "", + ExpressionUtil::toString(children)); +} + +std::string AggregateFunctionExpression::getUniqueName(const std::string& functionName, + const expression_vector& children, bool isDistinct) { + return stringFormat("{}({}{})", functionName, isDistinct ? "DISTINCT " : "", + ExpressionUtil::getUniqueName(children)); +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression/case_expression.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression/case_expression.cpp new file mode 100644 index 0000000000..aac05b480c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression/case_expression.cpp @@ -0,0 +1,17 @@ +#include "binder/expression/case_expression.h" + +namespace lbug { +namespace binder { + +std::string CaseExpression::toStringInternal() const { + std::string result = "CASE "; + for (auto& caseAlternative : caseAlternatives) { + result += "WHEN " + caseAlternative->whenExpression->toString() + " THEN " + + caseAlternative->thenExpression->toString(); + } + result += " ELSE " + elseExpression->toString(); + return result; +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression/expression.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression/expression.cpp new file mode 100644 index 0000000000..1c29348b06 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression/expression.cpp @@ -0,0 +1,34 @@ +#include "binder/expression/expression.h" + +#include "common/exception/binder.h" + +using namespace lbug::common; + +namespace lbug { +namespace binder { + +Expression::~Expression() = default; + +void Expression::cast(const LogicalType&) { + // LCOV_EXCL_START + throw BinderException( + stringFormat("Data type of expression {} should not be modified.", toString())); + // LCOV_EXCL_STOP +} + +expression_vector Expression::splitOnAND() { + expression_vector result; + if (ExpressionType::AND == expressionType) { + for (auto& child : children) { + for (auto& exp : child->splitOnAND()) { + result.push_back(exp); + } + } + } else { + result.push_back(shared_from_this()); + } + return result; +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression/expression_util.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression/expression_util.cpp new file mode 100644 index 0000000000..0708a416d9 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression/expression_util.cpp @@ -0,0 +1,568 @@ +#include "binder/expression/expression_util.h" + +#include + +#include "binder/binder.h" +#include "binder/expression/literal_expression.h" +#include "binder/expression/node_rel_expression.h" +#include "binder/expression/parameter_expression.h" +#include "common/exception/binder.h" +#include "common/exception/runtime.h" +#include "common/type_utils.h" +#include "common/types/value/nested.h" + +using namespace lbug::common; + +namespace lbug { +namespace binder { + +expression_vector ExpressionUtil::getExpressionsWithDataType(const expression_vector& expressions, + LogicalTypeID dataTypeID) { + expression_vector result; + for (auto& expression : expressions) { + if (expression->dataType.getLogicalTypeID() == dataTypeID) { + result.push_back(expression); + } + } + return result; +} + +uint32_t ExpressionUtil::find(const Expression* target, const expression_vector& expressions) { + for (auto i = 0u; i < expressions.size(); ++i) { + if (target->getUniqueName() == 
expressions[i]->getUniqueName()) { + return i; + } + } + return INVALID_IDX; +} + +std::string ExpressionUtil::toString(const expression_vector& expressions) { + if (expressions.empty()) { + return std::string{}; + } + auto result = expressions[0]->toString(); + for (auto i = 1u; i < expressions.size(); ++i) { + result += "," + expressions[i]->toString(); + } + return result; +} + +std::string ExpressionUtil::toStringOrdered(const expression_vector& expressions) { + auto expressions_ = expressions; + std::sort(expressions_.begin(), expressions_.end(), + [](const std::shared_ptr& a, const std::shared_ptr& b) { + return a->toString() < b->toString(); + }); + return toString(expressions_); +} + +std::string ExpressionUtil::toString(const std::vector& expressionPairs) { + if (expressionPairs.empty()) { + return std::string{}; + } + auto result = toString(expressionPairs[0]); + for (auto i = 1u; i < expressionPairs.size(); ++i) { + result += "," + toString(expressionPairs[i]); + } + return result; +} + +std::string ExpressionUtil::toString(const expression_pair& expressionPair) { + return expressionPair.first->toString() + "=" + expressionPair.second->toString(); +} + +std::string ExpressionUtil::getUniqueName(const expression_vector& expressions) { + if (expressions.empty()) { + return std::string(); + } + auto result = expressions[0]->getUniqueName(); + for (auto i = 1u; i < expressions.size(); ++i) { + result += "," + expressions[i]->getUniqueName(); + } + return result; +} + +expression_vector ExpressionUtil::excludeExpression(const expression_vector& exprs, + const Expression& exprToExclude) { + expression_vector result; + for (auto& expr : exprs) { + if (*expr != exprToExclude) { + result.push_back(expr); + } + } + return result; +} + +expression_vector ExpressionUtil::excludeExpressions(const expression_vector& expressions, + const expression_vector& expressionsToExclude) { + expression_set excludeSet; + for (auto& expression : expressionsToExclude) { + excludeSet.insert(expression); + } + expression_vector result; + for (auto& expression : expressions) { + if (!excludeSet.contains(expression)) { + result.push_back(expression); + } + } + return result; +} + +logical_type_vec_t ExpressionUtil::getDataTypes(const expression_vector& expressions) { + std::vector result; + result.reserve(expressions.size()); + for (auto& expression : expressions) { + result.push_back(expression->getDataType().copy()); + } + return result; +} + +expression_vector ExpressionUtil::removeDuplication(const expression_vector& expressions) { + expression_vector result; + expression_set expressionSet; + for (auto& expression : expressions) { + if (expressionSet.contains(expression)) { + continue; + } + result.push_back(expression); + expressionSet.insert(expression); + } + return result; +} + +bool ExpressionUtil::isEmptyPattern(const Expression& expression) { + if (expression.expressionType != ExpressionType::PATTERN) { + return false; + } + return expression.constCast().isEmpty(); +} + +bool ExpressionUtil::isNodePattern(const Expression& expression) { + if (expression.expressionType != ExpressionType::PATTERN) { + return false; + } + return expression.dataType.getLogicalTypeID() == LogicalTypeID::NODE; +}; + +bool ExpressionUtil::isRelPattern(const Expression& expression) { + if (expression.expressionType != ExpressionType::PATTERN) { + return false; + } + return expression.dataType.getLogicalTypeID() == LogicalTypeID::REL; +} + +bool ExpressionUtil::isRecursiveRelPattern(const Expression& expression) { + if 
(expression.expressionType != ExpressionType::PATTERN) { + return false; + } + return expression.dataType.getLogicalTypeID() == LogicalTypeID::RECURSIVE_REL; +} + +bool ExpressionUtil::isNullLiteral(const Expression& expression) { + if (expression.expressionType != ExpressionType::LITERAL) { + return false; + } + return expression.constCast().getValue().isNull(); +} + +bool ExpressionUtil::isBoolLiteral(const Expression& expression) { + if (expression.expressionType != ExpressionType::LITERAL) { + return false; + } + return expression.dataType == LogicalType::BOOL(); +} + +bool ExpressionUtil::isFalseLiteral(const Expression& expression) { + if (expression.expressionType != ExpressionType::LITERAL) { + return false; + } + return expression.constCast().getValue().getValue() == false; +} + +bool ExpressionUtil::isEmptyList(const Expression& expression) { + auto val = Value::createNullValue(); + switch (expression.expressionType) { + case ExpressionType::LITERAL: { + val = expression.constCast().getValue(); + } break; + case ExpressionType::PARAMETER: { + val = expression.constCast().getValue(); + } break; + default: + return false; + } + if (val.getDataType().getLogicalTypeID() != LogicalTypeID::LIST) { + return false; + } + return val.getChildrenSize() == 0; +} + +void ExpressionUtil::validateExpressionType(const Expression& expr, ExpressionType expectedType) { + if (expr.expressionType == expectedType) { + return; + } + throw BinderException(stringFormat("{} has type {} but {} was expected.", expr.toString(), + ExpressionTypeUtil::toString(expr.expressionType), + ExpressionTypeUtil::toString(expectedType))); +} + +void ExpressionUtil::validateExpressionType(const Expression& expr, + std::vector expectedType) { + if (std::find(expectedType.begin(), expectedType.end(), expr.expressionType) != + expectedType.end()) { + return; + } + std::string expectedTypesStr = ""; + std::for_each(expectedType.begin(), expectedType.end(), + [&expectedTypesStr](ExpressionType type) { + expectedTypesStr += expectedTypesStr.empty() ? 
ExpressionTypeUtil::toString(type) : + "," + ExpressionTypeUtil::toString(type); + }); + throw BinderException(stringFormat("{} has type {} but {} was expected.", expr.toString(), + ExpressionTypeUtil::toString(expr.expressionType), expectedTypesStr)); +} + +void ExpressionUtil::validateDataType(const Expression& expr, const LogicalType& expectedType) { + if (expr.getDataType() == expectedType) { + return; + } + throw BinderException(stringFormat("{} has data type {} but {} was expected.", expr.toString(), + expr.getDataType().toString(), expectedType.toString())); +} + +void ExpressionUtil::validateDataType(const Expression& expr, LogicalTypeID expectedTypeID) { + if (expr.getDataType().getLogicalTypeID() == expectedTypeID) { + return; + } + throw BinderException(stringFormat("{} has data type {} but {} was expected.", expr.toString(), + expr.getDataType().toString(), LogicalTypeUtils::toString(expectedTypeID))); +} + +void ExpressionUtil::validateDataType(const Expression& expr, + const std::vector& expectedTypeIDs) { + auto targetsSet = + std::unordered_set{expectedTypeIDs.begin(), expectedTypeIDs.end()}; + if (targetsSet.contains(expr.getDataType().getLogicalTypeID())) { + return; + } + throw BinderException(stringFormat("{} has data type {} but {} was expected.", expr.toString(), + expr.getDataType().toString(), LogicalTypeUtils::toString(expectedTypeIDs))); +} + +template<> +uint64_t ExpressionUtil::getLiteralValue(const Expression& expr) { + validateExpressionType(expr, ExpressionType::LITERAL); + validateDataType(expr, LogicalType::UINT64()); + return expr.constCast().getValue().getValue(); +} +template<> +int64_t ExpressionUtil::getLiteralValue(const Expression& expr) { + validateExpressionType(expr, ExpressionType::LITERAL); + validateDataType(expr, LogicalType::INT64()); + return expr.constCast().getValue().getValue(); +} +template<> +bool ExpressionUtil::getLiteralValue(const Expression& expr) { + validateExpressionType(expr, ExpressionType::LITERAL); + validateDataType(expr, LogicalType::BOOL()); + return expr.constCast().getValue().getValue(); +} +template<> +std::string ExpressionUtil::getLiteralValue(const Expression& expr) { + validateExpressionType(expr, ExpressionType::LITERAL); + validateDataType(expr, LogicalType::STRING()); + return expr.constCast().getValue().getValue(); +} +template<> +double ExpressionUtil::getLiteralValue(const Expression& expr) { + validateExpressionType(expr, ExpressionType::LITERAL); + validateDataType(expr, LogicalType::DOUBLE()); + return expr.constCast().getValue().getValue(); +} + +// For primitive types, two types are compatible if they have the same id. +// For nested types, two types are compatible if they have the same id and their children are also +// compatible. +// E.g. [NULL] is compatible with [1,2] +// E.g. 
{a: NULL, b: NULL} is compatible with {a: [1,2], b: ['c']} +static bool compatible(const LogicalType& type, const LogicalType& target) { + if (type.isInternalType() != target.isInternalType()) { + return false; + } + if (type.getLogicalTypeID() == LogicalTypeID::ANY) { + return true; + } + if (type.getLogicalTypeID() != target.getLogicalTypeID()) { + return false; + } + switch (type.getLogicalTypeID()) { + case LogicalTypeID::LIST: { + return compatible(ListType::getChildType(type), ListType::getChildType(target)); + } + case LogicalTypeID::ARRAY: { + return compatible(ArrayType::getChildType(type), ArrayType::getChildType(target)); + } + case LogicalTypeID::STRUCT: { + if (StructType::getNumFields(type) != StructType::getNumFields(target)) { + return false; + } + for (auto i = 0u; i < StructType::getNumFields(type); ++i) { + if (!compatible(StructType::getField(type, i).getType(), + StructType::getField(target, i).getType())) { + return false; + } + } + return true; + } + case LogicalTypeID::DECIMAL: + case LogicalTypeID::UNION: + case LogicalTypeID::MAP: + case LogicalTypeID::NODE: + case LogicalTypeID::REL: + case LogicalTypeID::RECURSIVE_REL: + return false; + default: + return true; + } +} + +// Handle special cases where value can be compatible to a type. This happens when a value is a +// nested value but does not have any child. +// E.g. [] is compatible with [1,2] +static bool compatible(const Value& value, const LogicalType& targetType) { + if (value.isNull()) { // Value is null. We can safely change its type. + return true; + } + if (value.getDataType().getLogicalTypeID() != targetType.getLogicalTypeID()) { + return false; + } + switch (value.getDataType().getLogicalTypeID()) { + case LogicalTypeID::LIST: { + if (!value.hasNoneNullChildren()) { // Empty list free to change. + return true; + } + for (auto i = 0u; i < NestedVal::getChildrenSize(&value); ++i) { + if (!compatible(*NestedVal::getChildVal(&value, i), + ListType::getChildType(targetType))) { + return false; + } + } + return true; + } + case LogicalTypeID::ARRAY: { + if (!value.hasNoneNullChildren()) { // Empty array free to change. + return true; + } + for (auto i = 0u; i < NestedVal::getChildrenSize(&value); ++i) { + if (!compatible(*NestedVal::getChildVal(&value, i), + ArrayType::getChildType(targetType))) { + return false; + } + } + return true; + } + case LogicalTypeID::MAP: { + if (!value.hasNoneNullChildren()) { // Empty map free to change. 
+ return true; + } + const auto& keyType = MapType::getKeyType(targetType); + const auto& valType = MapType::getValueType(targetType); + for (auto i = 0u; i < NestedVal::getChildrenSize(&value); ++i) { + auto childVal = NestedVal::getChildVal(&value, i); + KU_ASSERT(NestedVal::getChildrenSize(childVal) == 2); + auto key = NestedVal::getChildVal(childVal, 0); + auto val = NestedVal::getChildVal(childVal, 1); + if (!compatible(*key, keyType) || !compatible(*val, valType)) { + return false; + } + } + return true; + } + default: + break; + } + return compatible(value.getDataType(), targetType); +} + +bool ExpressionUtil::tryCombineDataType(const expression_vector& expressions, LogicalType& result) { + std::vector secondaryValues; + std::vector primaryTypes; + for (auto& expr : expressions) { + if (expr->expressionType != ExpressionType::LITERAL) { + primaryTypes.push_back(expr->getDataType().copy()); + continue; + } + auto literalExpr = expr->constPtrCast(); + if (literalExpr->getValue().allowTypeChange()) { + secondaryValues.push_back(literalExpr->getValue()); + continue; + } + primaryTypes.push_back(expr->getDataType().copy()); + } + if (!LogicalTypeUtils::tryGetMaxLogicalType(primaryTypes, result)) { + return false; + } + for (auto& value : secondaryValues) { + if (compatible(value, result)) { + continue; + } + if (!LogicalTypeUtils::tryGetMaxLogicalType(result, value.getDataType(), result)) { + return false; + } + } + return true; +} + +bool ExpressionUtil::canCastStatically(const Expression& expr, const LogicalType& targetType) { + switch (expr.expressionType) { + case ExpressionType::LITERAL: { + auto value = expr.constPtrCast()->getValue(); + return compatible(value, targetType); + } + case ExpressionType::PARAMETER: { + auto value = expr.constPtrCast()->getValue(); + return compatible(value, targetType); + } + default: + return compatible(expr.getDataType(), targetType); + } +} + +bool ExpressionUtil::canEvaluateAsLiteral(const Expression& expr) { + switch (expr.expressionType) { + case ExpressionType::LITERAL: + return true; + case ExpressionType::PARAMETER: + return expr.getDataType().getLogicalTypeID() != LogicalTypeID::ANY; + default: + return false; + } +} + +Value ExpressionUtil::evaluateAsLiteralValue(const Expression& expr) { + KU_ASSERT(canEvaluateAsLiteral(expr)); + auto value = Value::createDefaultValue(expr.dataType); + switch (expr.expressionType) { + case ExpressionType::LITERAL: { + value = expr.constCast().getValue(); + } break; + case ExpressionType::PARAMETER: { + value = expr.constCast().getValue(); + } break; + default: + KU_UNREACHABLE; + } + return value; +} + +uint64_t ExpressionUtil::evaluateAsSkipLimit(const Expression& expr) { + auto value = evaluateAsLiteralValue(expr); + auto errorMsg = "The number of rows to skip/limit must be a non-negative integer."; + uint64_t number = INVALID_LIMIT; + TypeUtils::visit( + value.getDataType(), + [&](T) { + if (value.getValue() < 0) { + throw RuntimeException{errorMsg}; + } + number = (uint64_t)value.getValue(); + }, + [&](auto) { throw RuntimeException{errorMsg}; }); + return number; +} + +template +T ExpressionUtil::getExpressionVal(const Expression& expr, const Value& value, + const LogicalType& targetType, validate_param_func validateParamFunc) { + if (value.getDataType() != targetType) { + throw RuntimeException{common::stringFormat("Parameter: {} must be a {} literal.", + expr.getAlias(), targetType.toString())}; + } + T val = value.getValue(); + if (validateParamFunc != nullptr) { + validateParamFunc(val); + } + 
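[Editorial sketch] The compatible() helpers above encode a recursive rule: ANY matches anything, otherwise the type ids must agree and children must be pairwise compatible, which is what lets a literal such as [NULL] or [] adopt the type of [1,2]. Below is a self-contained restatement over a toy type tree; the Type struct and its ids are invented for illustration and are not the library's LogicalType API.

#include <cassert>
#include <string>
#include <vector>

// Simplified type tree: an id such as "ANY", "INT64", "STRING", or "LIST",
// plus child types for nested types.
struct Type {
    std::string id;
    std::vector<Type> children;
};

// ANY is compatible with anything; otherwise ids must match and every child
// must be compatible, following the rule described above.
bool compatibleSketch(const Type& type, const Type& target) {
    if (type.id == "ANY") return true;
    if (type.id != target.id) return false;
    if (type.children.size() != target.children.size()) return false;
    for (size_t i = 0; i < type.children.size(); ++i) {
        if (!compatibleSketch(type.children[i], target.children[i])) return false;
    }
    return true;
}

int main() {
    Type listOfAny{"LIST", {Type{"ANY", {}}}};   // type of the literal [NULL]
    Type listOfInt{"LIST", {Type{"INT64", {}}}}; // type of the literal [1,2]
    assert(compatibleSketch(listOfAny, listOfInt));            // [NULL] can adopt [1,2]'s type
    assert(!compatibleSketch(Type{"STRING", {}}, Type{"INT64", {}}));
    return 0;
}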
return val; +} + +template +T ExpressionUtil::evaluateLiteral(main::ClientContext* context, + std::shared_ptr expression, const common::LogicalType& type, + validate_param_func validateParamFunc) { + if (!canEvaluateAsLiteral(*expression)) { + std::string errMsg; + switch (expression->expressionType) { + case ExpressionType::PARAMETER: { + errMsg = common::stringFormat( + "The expression: '{}' is a parameter expression. Please assign it a value.", + expression->toString()); + } break; + default: { + errMsg = + common::stringFormat("The expression: '{}' must be a parameter/literal expression.", + expression->toString()); + ; + } break; + } + throw RuntimeException{errMsg}; + } + if (expression->getDataType() != type) { + binder::Binder binder{context}; + auto literalExpr = std::make_shared( + ExpressionUtil::evaluateAsLiteralValue(*expression), expression->getUniqueName()); + expression = binder.getExpressionBinder()->implicitCast(literalExpr, type.copy()); + expression = binder.getExpressionBinder()->foldExpression(expression); + } + auto value = evaluateAsLiteralValue(*expression); + return getExpressionVal(*expression, value, type, validateParamFunc); +} + +std::shared_ptr ExpressionUtil::applyImplicitCastingIfNecessary( + main::ClientContext* context, std::shared_ptr expr, + common::LogicalType targetType) { + if (expr->getDataType() != targetType) { + binder::Binder binder{context}; + expr = binder.getExpressionBinder()->implicitCastIfNecessary(expr, targetType); + expr = binder.getExpressionBinder()->foldExpression(expr); + } + return expr; +} + +template LBUG_API std::string ExpressionUtil::getExpressionVal(const Expression& expr, + const common::Value& value, const common::LogicalType& targetType, + validate_param_func validateParamFunc); + +template LBUG_API double ExpressionUtil::getExpressionVal(const Expression& expr, + const common::Value& value, const common::LogicalType& targetType, + validate_param_func validateParamFunc); + +template LBUG_API int64_t ExpressionUtil::getExpressionVal(const Expression& expr, + const common::Value& value, const common::LogicalType& targetType, + validate_param_func validateParamFunc); + +template LBUG_API bool ExpressionUtil::getExpressionVal(const Expression& expr, + const common::Value& value, const common::LogicalType& targetType, + validate_param_func validateParamFunc); + +template LBUG_API std::string ExpressionUtil::evaluateLiteral( + main::ClientContext* context, std::shared_ptr expression, + const common::LogicalType& type, validate_param_func validateParamFunc); + +template LBUG_API double ExpressionUtil::evaluateLiteral(main::ClientContext* context, + std::shared_ptr expression, const LogicalType& type, + validate_param_func validateParamFunc); + +template LBUG_API int64_t ExpressionUtil::evaluateLiteral(main::ClientContext* context, + std::shared_ptr expression, const LogicalType& type, + validate_param_func validateParamFunc); + +template LBUG_API bool ExpressionUtil::evaluateLiteral(main::ClientContext* context, + std::shared_ptr expression, const LogicalType& type, + validate_param_func validateParamFunc); + +template LBUG_API uint64_t ExpressionUtil::evaluateLiteral(main::ClientContext* context, + std::shared_ptr expression, const LogicalType& type, + validate_param_func validateParamFunc); + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression/literal_expression.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression/literal_expression.cpp new file mode 100644 index 
0000000000..faddcef800 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression/literal_expression.cpp @@ -0,0 +1,24 @@ +#include "binder/expression/literal_expression.h" + +#include "common/exception/binder.h" + +using namespace lbug::common; + +namespace lbug { +namespace binder { + +void LiteralExpression::cast(const LogicalType& type) { + // The following is a safeguard to make sure we are not changing literal type unexpectedly. + if (!value.allowTypeChange()) { + // LCOV_EXCL_START + throw BinderException( + stringFormat("Cannot change literal expression data type from {} to {}.", + dataType.toString(), type.toString())); + // LCOV_EXCL_STOP + } + dataType = type.copy(); + value.setDataType(type); +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression/node_expression.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression/node_expression.cpp new file mode 100644 index 0000000000..72e235f5c3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression/node_expression.cpp @@ -0,0 +1,18 @@ +#include "binder/expression/node_expression.h" + +namespace lbug { +namespace binder { + +NodeExpression::~NodeExpression() = default; + +std::shared_ptr NodeExpression::getPrimaryKey(common::table_id_t tableID) const { + for (auto& property : propertyExprs) { + if (property->isPrimaryKey(tableID)) { + return property; + } + } + KU_UNREACHABLE; +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression/node_rel_expression.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression/node_rel_expression.cpp new file mode 100644 index 0000000000..a8d414eb3c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression/node_rel_expression.cpp @@ -0,0 +1,44 @@ +#include "binder/expression/node_rel_expression.h" + +#include "catalog/catalog_entry/table_catalog_entry.h" + +using namespace lbug::catalog; +using namespace lbug::common; + +namespace lbug { +namespace binder { + +table_id_vector_t NodeOrRelExpression::getTableIDs() const { + table_id_vector_t result; + for (auto& entry : entries) { + result.push_back(entry->getTableID()); + } + return result; +} + +table_id_set_t NodeOrRelExpression::getTableIDsSet() const { + table_id_set_t result; + for (auto& entry : entries) { + result.insert(entry->getTableID()); + } + return result; +} + +void NodeOrRelExpression::addEntries(const std::vector& entries_) { + auto tableIDsSet = getTableIDsSet(); + for (auto& entry : entries_) { + if (!tableIDsSet.contains(entry->getTableID())) { + entries.push_back(entry); + } + } +} + +void NodeOrRelExpression::addPropertyExpression(std::shared_ptr property) { + auto propertyName = property->getPropertyName(); + KU_ASSERT(!propertyNameToIdx.contains(propertyName)); + propertyNameToIdx.insert({propertyName, propertyExprs.size()}); + propertyExprs.push_back(std::move(property)); +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression/parameter_expression.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression/parameter_expression.cpp new file mode 100644 index 0000000000..42a33e8a50 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression/parameter_expression.cpp @@ -0,0 +1,23 @@ +#include "binder/expression/parameter_expression.h" + +#include "common/exception/binder.h" + +namespace lbug { +using namespace common; + +namespace binder { + +void ParameterExpression::cast(const 
LogicalType& type) { + if (!dataType.containsAny()) { + // LCOV_EXCL_START + throw BinderException( + stringFormat("Cannot change parameter expression data type from {} to {}.", + dataType.toString(), type.toString())); + // LCOV_EXCL_STOP + } + dataType = type.copy(); + value.setDataType(type); +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression/property_expression.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression/property_expression.cpp new file mode 100644 index 0000000000..05b8648fbd --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression/property_expression.cpp @@ -0,0 +1,31 @@ +#include "binder/expression/property_expression.h" + +using namespace lbug::common; +using namespace lbug::catalog; + +namespace lbug { +namespace binder { + +bool PropertyExpression::isPrimaryKey() const { + for (auto& [id, info] : infos) { + if (!info.isPrimaryKey) { + return false; + } + } + return true; +} + +bool PropertyExpression::isPrimaryKey(table_id_t tableID) const { + if (!infos.contains(tableID)) { + return false; + } + return infos.at(tableID).isPrimaryKey; +} + +bool PropertyExpression::hasProperty(table_id_t tableID) const { + KU_ASSERT(infos.contains(tableID)); + return infos.at(tableID).exists; +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression/rel_expression.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression/rel_expression.cpp new file mode 100644 index 0000000000..79ddcf246a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression/rel_expression.cpp @@ -0,0 +1,88 @@ +#include "binder/expression/rel_expression.h" + +#include "catalog/catalog_entry/rel_group_catalog_entry.h" +#include "catalog/catalog_entry/table_catalog_entry.h" +#include "common/enums/extend_direction_util.h" +#include "common/exception/binder.h" +#include "common/utils.h" + +using namespace lbug::common; + +namespace lbug { +namespace binder { + +bool RelExpression::isMultiLabeled() const { + if (entries.size() > 1) { + return true; + } + for (auto& entry : entries) { + auto relGroupEntry = entry->ptrCast(); + if (relGroupEntry->getNumRelTables() > 1) { + return true; + } + } + return false; +} + +std::string RelExpression::detailsToString() const { + std::string result = toString(); + switch (relType) { + case QueryRelType::SHORTEST: { + result += "SHORTEST"; + } break; + case QueryRelType::ALL_SHORTEST: { + result += "ALL SHORTEST"; + } break; + case QueryRelType::WEIGHTED_SHORTEST: { + result += "WEIGHTED SHORTEST"; + } break; + case QueryRelType::ALL_WEIGHTED_SHORTEST: { + result += "ALL WEIGHTED SHORTEST"; + } break; + default: + break; + } + if (QueryRelTypeUtils::isRecursive(relType)) { + result += std::to_string(recursiveInfo->bindData->lowerBound); + result += ".."; + result += std::to_string(recursiveInfo->bindData->upperBound); + } + return result; +} + +std::vector RelExpression::getExtendDirections() const { + std::vector ret; + for (const auto direction : {ExtendDirection::FWD, ExtendDirection::BWD}) { + const bool addDirection = std::all_of(entries.begin(), entries.end(), + [direction](const catalog::TableCatalogEntry* tableEntry) { + const auto* entry = tableEntry->constPtrCast(); + return common::containsValue(entry->getRelDataDirections(), + ExtendDirectionUtil::getRelDataDirection(direction)); + }); + if (addDirection) { + ret.push_back(direction); + } + } + if (ret.empty()) { + throw BinderException(stringFormat( + "There 
are no common storage directions among the rel " + "tables matched by pattern '{}' (some tables have storage direction 'fwd' " + "while others have storage direction 'bwd'). Scanning different tables matching the " + "same pattern in different directions is currently unsupported.", + toString())); + } + return ret; +} + +std::vector RelExpression::getInnerRelTableIDs() const { + std::vector innerTableIDs; + for (auto& entry : entries) { + for (auto& info : entry->cast().getRelEntryInfos()) { + innerTableIDs.push_back(info.oid); + } + } + return innerTableIDs; +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression/scalar_function_expression.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression/scalar_function_expression.cpp new file mode 100644 index 0000000000..c24bfaed85 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression/scalar_function_expression.cpp @@ -0,0 +1,24 @@ +#include "binder/expression/scalar_function_expression.h" + +#include "binder/expression/expression_util.h" + +using namespace lbug::common; + +namespace lbug { +namespace binder { + +std::string ScalarFunctionExpression::toStringInternal() const { + if (function->name.starts_with("CAST")) { + return stringFormat("CAST({}, {})", ExpressionUtil::toString(children), + bindData->resultType.toString()); + } + return stringFormat("{}({})", function->name, ExpressionUtil::toString(children)); +} + +std::string ScalarFunctionExpression::getUniqueName(const std::string& functionName, + const expression_vector& children) { + return stringFormat("{}({})", functionName, ExpressionUtil::getUniqueName(children)); +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression/variable_expression.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression/variable_expression.cpp new file mode 100644 index 0000000000..0df190cb7e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression/variable_expression.cpp @@ -0,0 +1,22 @@ +#include "binder/expression/variable_expression.h" + +#include "common/exception/binder.h" + +using namespace lbug::common; + +namespace lbug { +namespace binder { + +void VariableExpression::cast(const LogicalType& type) { + if (!dataType.containsAny()) { + // LCOV_EXCL_START + throw BinderException( + stringFormat("Cannot change variable expression data type from {} to {}.", + dataType.toString(), type.toString())); + // LCOV_EXCL_STOP + } + dataType = type.copy(); +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression_binder.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression_binder.cpp new file mode 100644 index 0000000000..40aad2d390 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression_binder.cpp @@ -0,0 +1,154 @@ +#include "binder/expression_binder.h" + +#include "binder/binder.h" +#include "binder/expression/expression_util.h" +#include "binder/expression/parameter_expression.h" +#include "binder/expression_visitor.h" +#include "common/exception/binder.h" +#include "common/exception/not_implemented.h" +#include "common/string_format.h" +#include "expression_evaluator/expression_evaluator_utils.h" +#include "function/cast/vector_cast_functions.h" +#include "parser/expression/parsed_expression_visitor.h" +#include "parser/expression/parsed_parameter_expression.h" + +using namespace lbug::common; +using namespace lbug::function; +using namespace lbug::parser; + 
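[Editorial sketch] As context for RelExpression::getExtendDirections defined above: it keeps only the extend directions that every matched rel table can serve from storage, and raises the error quoted there when that intersection is empty. The following is a hedged, standalone restatement using plain enums rather than the catalog entries the real code consults.

#include <algorithm>
#include <vector>

enum class Direction { FWD, BWD };

// A pattern can only be extended in directions that every matched table
// stores; an empty result corresponds to the BinderException above.
std::vector<Direction> commonDirections(const std::vector<std::vector<Direction>>& perTable) {
    std::vector<Direction> result;
    for (auto dir : {Direction::FWD, Direction::BWD}) {
        bool supportedByAll = std::all_of(perTable.begin(), perTable.end(),
            [dir](const std::vector<Direction>& dirs) {
                return std::find(dirs.begin(), dirs.end(), dir) != dirs.end();
            });
        if (supportedByAll) {
            result.push_back(dir);
        }
    }
    return result;
}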
+namespace lbug { +namespace binder { + +std::shared_ptr ExpressionBinder::bindExpression( + const ParsedExpression& parsedExpression) { + // Normally u can only reference an existing expression through alias which is a parsed + // VARIABLE expression. + // An exception is order by binding, e.g. RETURN a, COUNT(*) ORDER BY COUNT(*) + // the later COUNT(*) should reference the one in projection list. So we need to explicitly + // check scope when binding order by list. + if (config.bindOrderByAfterAggregate && binder->scope.contains(parsedExpression.toString())) { + return binder->scope.getExpression(parsedExpression.toString()); + } + auto collector = ParsedParamExprCollector(); + collector.visit(&parsedExpression); + if (collector.hasParamExprs()) { + bool allParamExist = true; + for (auto& parsedExpr : collector.getParamExprs()) { + auto name = parsedExpr->constCast().getParameterName(); + if (!knownParameters.contains(name)) { + unknownParameters.insert(name); + allParamExist = false; + } + } + if (!allParamExist) { + auto expr = std::make_shared(binder->getUniqueExpressionName(""), + Value::createNullValue()); + if (parsedExpression.hasAlias()) { + expr->setAlias(parsedExpression.getAlias()); + } + return expr; + } + } + std::shared_ptr expression; + auto expressionType = parsedExpression.getExpressionType(); + if (ExpressionTypeUtil::isBoolean(expressionType)) { + expression = bindBooleanExpression(parsedExpression); + } else if (ExpressionTypeUtil::isComparison(expressionType)) { + expression = bindComparisonExpression(parsedExpression); + } else if (ExpressionTypeUtil::isNullOperator(expressionType)) { + expression = bindNullOperatorExpression(parsedExpression); + } else if (ExpressionType::FUNCTION == expressionType) { + expression = bindFunctionExpression(parsedExpression); + } else if (ExpressionType::PROPERTY == expressionType) { + expression = bindPropertyExpression(parsedExpression); + } else if (ExpressionType::PARAMETER == expressionType) { + expression = bindParameterExpression(parsedExpression); + } else if (ExpressionType::LITERAL == expressionType) { + expression = bindLiteralExpression(parsedExpression); + } else if (ExpressionType::VARIABLE == expressionType) { + expression = bindVariableExpression(parsedExpression); + } else if (ExpressionType::SUBQUERY == expressionType) { + expression = bindSubqueryExpression(parsedExpression); + } else if (ExpressionType::CASE_ELSE == expressionType) { + expression = bindCaseExpression(parsedExpression); + } else if (ExpressionType::LAMBDA == expressionType) { + expression = bindLambdaExpression(parsedExpression); + } else { + throw NotImplementedException( + "bindExpression(" + ExpressionTypeUtil::toString(expressionType) + ")."); + } + if (ConstantExpressionVisitor::needFold(*expression)) { + return foldExpression(expression); + } + return expression; +} + +std::shared_ptr ExpressionBinder::foldExpression( + const std::shared_ptr& expression) const { + auto value = + evaluator::ExpressionEvaluatorUtils::evaluateConstantExpression(expression, context); + auto result = createLiteralExpression(value); + // Fold result should preserve the alias original expression. E.g. + // RETURN 2, 1 + 1 AS x + // Once folded, 1 + 1 will become 2 and have the same identifier as the first RETURN element. + // We preserve alias (x) to avoid such conflict. 
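+    // For example, an aliased `1 + 1 AS x` folds to the literal 2 aliased `x`,
+    // while an unaliased `1 + 1` folds to the literal 2 aliased with the original
+    // expression text, keeping the two RETURN items distinct.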
+ if (expression->hasAlias()) { + result->setAlias(expression->getAlias()); + } else { + result->setAlias(expression->toString()); + } + return result; +} + +static std::string unsupportedImplicitCastException(const Expression& expression, + const std::string& targetTypeStr) { + return stringFormat( + "Expression {} has data type {} but expected {}. Implicit cast is not supported.", + expression.toString(), expression.dataType.toString(), targetTypeStr); +} + +std::shared_ptr ExpressionBinder::implicitCastIfNecessary( + const std::shared_ptr& expression, const LogicalType& targetType) { + auto& type = expression->dataType; + if (type == targetType || targetType.containsAny()) { // No need to cast. + return expression; + } + if (!type.isInternalType() || !targetType.isInternalType()) { + return implicitCast(expression, targetType); + } + if (ExpressionUtil::canCastStatically(*expression, targetType)) { + expression->cast(targetType); + return expression; + } + return implicitCast(expression, targetType); +} + +std::shared_ptr ExpressionBinder::implicitCast( + const std::shared_ptr& expression, const LogicalType& targetType) { + if (CastFunction::hasImplicitCast(expression->dataType, targetType)) { + return forceCast(expression, targetType); + } else { + throw BinderException(unsupportedImplicitCastException(*expression, targetType.toString())); + } +} + +// cast without implicit checking. +std::shared_ptr ExpressionBinder::forceCast( + const std::shared_ptr& expression, const LogicalType& targetType) { + auto functionName = "CAST"; + auto children = + expression_vector{expression, createLiteralExpression(Value(targetType.toString()))}; + return bindScalarFunctionExpression(children, functionName); +} + +std::string ExpressionBinder::getUniqueName(const std::string& name) const { + return binder->getUniqueExpressionName(name); +} + +void ExpressionBinder::addParameter(const std::string& name, std::shared_ptr value) { + KU_ASSERT(!knownParameters.contains(name)); + knownParameters[name] = value; +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression_visitor.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression_visitor.cpp new file mode 100644 index 0000000000..b95af484fe --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/expression_visitor.cpp @@ -0,0 +1,337 @@ +#include "binder/expression_visitor.h" + +#include "binder/expression/case_expression.h" +#include "binder/expression/lambda_expression.h" +#include "binder/expression/node_expression.h" +#include "binder/expression/property_expression.h" +#include "binder/expression/rel_expression.h" +#include "binder/expression/scalar_function_expression.h" +#include "binder/expression/subquery_expression.h" +#include "common/exception/not_implemented.h" +#include "function/arithmetic/vector_arithmetic_functions.h" +#include "function/sequence/sequence_functions.h" +#include "function/uuid/vector_uuid_functions.h" + +using namespace lbug::common; + +namespace lbug { +namespace binder { + +void ExpressionVisitor::visit(std::shared_ptr expr) { + visitChildren(*expr); + visitSwitch(expr); +} + +void ExpressionVisitor::visitSwitch(std::shared_ptr expr) { + switch (expr->expressionType) { + case ExpressionType::OR: + case ExpressionType::XOR: + case ExpressionType::AND: + case ExpressionType::NOT: + case ExpressionType::EQUALS: + case ExpressionType::NOT_EQUALS: + case ExpressionType::GREATER_THAN: + case ExpressionType::GREATER_THAN_EQUALS: + case ExpressionType::LESS_THAN: + 
case ExpressionType::LESS_THAN_EQUALS: + case ExpressionType::IS_NULL: + case ExpressionType::IS_NOT_NULL: + case ExpressionType::FUNCTION: { + visitFunctionExpr(expr); + } break; + case ExpressionType::AGGREGATE_FUNCTION: { + visitAggFunctionExpr(expr); + } break; + case ExpressionType::PROPERTY: { + visitPropertyExpr(expr); + } break; + case ExpressionType::LITERAL: { + visitLiteralExpr(expr); + } break; + case ExpressionType::VARIABLE: { + visitVariableExpr(expr); + } break; + case ExpressionType::PATH: { + visitPathExpr(expr); + } break; + case ExpressionType::PATTERN: { + visitNodeRelExpr(expr); + } break; + case ExpressionType::PARAMETER: { + visitParamExpr(expr); + } break; + case ExpressionType::SUBQUERY: { + visitSubqueryExpr(expr); + } break; + case ExpressionType::CASE_ELSE: { + visitCaseExpr(expr); + } break; + case ExpressionType::GRAPH: { + visitGraphExpr(expr); + } break; + case ExpressionType::LAMBDA: { + visitLambdaExpr(expr); + } break; + // LCOV_EXCL_START + default: + throw NotImplementedException("ExpressionVisitor::visitSwitch"); + // LCOV_EXCL_STOP + } +} + +void ExpressionVisitor::visitChildren(const Expression& expr) { + switch (expr.expressionType) { + case ExpressionType::CASE_ELSE: { + visitCaseExprChildren(expr); + } break; + case ExpressionType::LAMBDA: { + auto& lambda = expr.constCast(); + visit(lambda.getFunctionExpr()); + } break; + default: { + for (auto& child : expr.getChildren()) { + visit(child); + } + } + } +} + +void ExpressionVisitor::visitCaseExprChildren(const Expression& expr) { + auto& caseExpr = expr.constCast(); + for (auto i = 0u; i < caseExpr.getNumCaseAlternatives(); ++i) { + auto caseAlternative = caseExpr.getCaseAlternative(i); + visit(caseAlternative->whenExpression); + visit(caseAlternative->thenExpression); + } + visit(caseExpr.getElseExpression()); +} + +expression_vector ExpressionChildrenCollector::collectChildren(const Expression& expression) { + switch (expression.expressionType) { + case ExpressionType::CASE_ELSE: { + return collectCaseChildren(expression); + } + case ExpressionType::SUBQUERY: { + return collectSubqueryChildren(expression); + } + case ExpressionType::PATTERN: { + switch (expression.dataType.getLogicalTypeID()) { + case LogicalTypeID::NODE: { + return collectNodeChildren(expression); + } + case LogicalTypeID::REL: { + return collectRelChildren(expression); + } + default: { + return expression_vector{}; + } + } + } + default: { + return expression.children; + } + } +} + +expression_vector ExpressionChildrenCollector::collectCaseChildren(const Expression& expression) { + expression_vector result; + auto& caseExpression = expression.constCast(); + for (auto i = 0u; i < caseExpression.getNumCaseAlternatives(); ++i) { + auto caseAlternative = caseExpression.getCaseAlternative(i); + result.push_back(caseAlternative->whenExpression); + result.push_back(caseAlternative->thenExpression); + } + result.push_back(caseExpression.getElseExpression()); + return result; +} + +expression_vector ExpressionChildrenCollector::collectSubqueryChildren( + const Expression& expression) { + expression_vector result; + auto& subqueryExpression = expression.constCast(); + for (auto& node : subqueryExpression.getQueryGraphCollection()->getQueryNodes()) { + result.push_back(node->getInternalID()); + } + if (subqueryExpression.hasWhereExpression()) { + result.push_back(subqueryExpression.getWhereExpression()); + } + return result; +} + +expression_vector ExpressionChildrenCollector::collectNodeChildren(const Expression& expression) { + 
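+    // A node pattern's children are its bound property expressions plus the
+    // node's internal ID expression.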
expression_vector result; + auto& node = expression.constCast(); + for (auto& property : node.getPropertyExpressions()) { + result.push_back(property); + } + result.push_back(node.getInternalID()); + return result; +} + +expression_vector ExpressionChildrenCollector::collectRelChildren(const Expression& expression) { + expression_vector result; + auto& rel = expression.constCast(); + result.push_back(rel.getSrcNode()->getInternalID()); + result.push_back(rel.getDstNode()->getInternalID()); + for (auto& property : rel.getPropertyExpressions()) { + result.push_back(property); + } + if (rel.hasDirectionExpr()) { + result.push_back(rel.getDirectionExpr()); + } + return result; +} + +bool ExpressionVisitor::isRandom(const Expression& expression) { + if (expression.expressionType != ExpressionType::FUNCTION) { + return false; + } + auto& funcExpr = expression.constCast(); + auto funcName = funcExpr.getFunction().name; + if (funcName == function::GenRandomUUIDFunction::name || + funcName == function::RandFunction::name) { + return true; + } + for (auto& child : ExpressionChildrenCollector::collectChildren(expression)) { + if (isRandom(*child)) { + return true; + } + } + return false; +} + +void DependentVarNameCollector::visitSubqueryExpr(std::shared_ptr expr) { + auto& subqueryExpr = expr->constCast(); + for (auto& node : subqueryExpr.getQueryGraphCollection()->getQueryNodes()) { + varNames.insert(node->getUniqueName()); + } + if (subqueryExpr.hasWhereExpression()) { + visit(subqueryExpr.getWhereExpression()); + } +} + +void DependentVarNameCollector::visitPropertyExpr(std::shared_ptr expr) { + varNames.insert(expr->constCast().getVariableName()); +} + +void DependentVarNameCollector::visitNodeRelExpr(std::shared_ptr expr) { + varNames.insert(expr->getUniqueName()); + if (expr->getDataType().getLogicalTypeID() == LogicalTypeID::REL) { + auto& rel = expr->constCast(); + varNames.insert(rel.getSrcNodeName()); + varNames.insert(rel.getDstNodeName()); + } +} + +void DependentVarNameCollector::visitVariableExpr(std::shared_ptr expr) { + varNames.insert(expr->getUniqueName()); +} + +void PropertyExprCollector::visitSubqueryExpr(std::shared_ptr expr) { + auto& subqueryExpr = expr->constCast(); + for (auto& rel : subqueryExpr.getQueryGraphCollection()->getQueryRels()) { + if (rel->isEmpty() || rel->getRelType() != QueryRelType::NON_RECURSIVE) { + // If a query rel is empty then it does not have an internal id property. + continue; + } + expressions.push_back(rel->getInternalID()); + } + if (subqueryExpr.hasWhereExpression()) { + visit(subqueryExpr.getWhereExpression()); + } +} + +void PropertyExprCollector::visitPropertyExpr(std::shared_ptr expr) { + expressions.push_back(expr); +} + +void PropertyExprCollector::visitNodeRelExpr(std::shared_ptr expr) { + for (auto& property : expr->constCast().getPropertyExpressions()) { + expressions.push_back(property); + } +} + +bool ConstantExpressionVisitor::needFold(const Expression& expr) { + if (expr.expressionType == common::ExpressionType::LITERAL) { + return false; // No need to fold a literal. 
+ } + return isConstant(expr); +} + +bool ConstantExpressionVisitor::isConstant(const Expression& expr) { + switch (expr.expressionType) { + case ExpressionType::LITERAL: + return true; + case ExpressionType::AGGREGATE_FUNCTION: + case ExpressionType::PROPERTY: + case ExpressionType::VARIABLE: + case ExpressionType::PATH: + case ExpressionType::PATTERN: + case ExpressionType::PARAMETER: + case ExpressionType::SUBQUERY: + case ExpressionType::GRAPH: + case ExpressionType::LAMBDA: + return false; + case ExpressionType::FUNCTION: + return visitFunction(expr); + case ExpressionType::CASE_ELSE: + return visitCase(expr); + case ExpressionType::OR: + case ExpressionType::XOR: + case ExpressionType::AND: + case ExpressionType::NOT: + case ExpressionType::EQUALS: + case ExpressionType::NOT_EQUALS: + case ExpressionType::GREATER_THAN: + case ExpressionType::GREATER_THAN_EQUALS: + case ExpressionType::LESS_THAN: + case ExpressionType::LESS_THAN_EQUALS: + case ExpressionType::IS_NULL: + case ExpressionType::IS_NOT_NULL: + return visitChildren(expr); + // LCOV_EXCL_START + default: + throw NotImplementedException("ConstantExpressionVisitor::isConstant"); + // LCOV_EXCL_STOP + } +} + +bool ConstantExpressionVisitor::visitFunction(const Expression& expr) { + auto& funcExpr = expr.constCast(); + if (funcExpr.getFunction().name == function::NextValFunction::name) { + return false; + } + if (funcExpr.getFunction().name == function::GenRandomUUIDFunction::name) { + return false; + } + if (funcExpr.getFunction().name == function::RandFunction::name) { + return false; + } + return visitChildren(expr); +} + +bool ConstantExpressionVisitor::visitCase(const Expression& expr) { + auto& caseExpr = expr.constCast(); + for (auto i = 0u; i < caseExpr.getNumCaseAlternatives(); ++i) { + auto caseAlternative = caseExpr.getCaseAlternative(i); + if (!isConstant(*caseAlternative->whenExpression)) { + return false; + } + if (!isConstant(*caseAlternative->thenExpression)) { + return false; + } + } + return isConstant(*caseExpr.getElseExpression()); +} + +bool ConstantExpressionVisitor::visitChildren(const Expression& expr) { + for (auto& child : expr.getChildren()) { + if (!isConstant(*child)) { + return false; + } + } + return true; +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/query/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/query/CMakeLists.txt new file mode 100644 index 0000000000..a80019b7c9 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/query/CMakeLists.txt @@ -0,0 +1,13 @@ +add_library( + lbug_binder_query + OBJECT + bound_insert_clause.cpp + bound_delete_clause.cpp + bound_merge_clause.cpp + bound_set_clause.cpp + query_graph.cpp + query_graph_label_analyzer.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/query/bound_delete_clause.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/query/bound_delete_clause.cpp new file mode 100644 index 0000000000..e0270a7368 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/query/bound_delete_clause.cpp @@ -0,0 +1,29 @@ +#include "binder/query/updating_clause/bound_delete_clause.h" + +using namespace lbug::common; + +namespace lbug { +namespace binder { + +bool BoundDeleteClause::hasInfo(const std::function& check) const { + for (auto& info : infos) { + if (check(info)) { + return true; + } + } + return false; +} + +std::vector BoundDeleteClause::getInfos( + const std::function& check) const { + std::vector 
result; + for (auto& info : infos) { + if (check(info)) { + result.push_back(info.copy()); + } + } + return result; +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/query/bound_insert_clause.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/query/bound_insert_clause.cpp new file mode 100644 index 0000000000..18982e3cb3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/query/bound_insert_clause.cpp @@ -0,0 +1,29 @@ +#include "binder/query/updating_clause/bound_insert_clause.h" + +using namespace lbug::common; + +namespace lbug { +namespace binder { + +bool BoundInsertClause::hasInfo(const std::function& check) const { + for (auto& info : infos) { + if (check(info)) { + return true; + } + } + return false; +} + +std::vector BoundInsertClause::getInfos( + const std::function& check) const { + std::vector result; + for (auto& info : infos) { + if (check(info)) { + result.push_back(&info); + } + } + return result; +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/query/bound_merge_clause.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/query/bound_merge_clause.cpp new file mode 100644 index 0000000000..9ea91b10fa --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/query/bound_merge_clause.cpp @@ -0,0 +1,72 @@ +#include "binder/query/updating_clause/bound_merge_clause.h" + +using namespace lbug::common; + +namespace lbug { +namespace binder { + +bool BoundMergeClause::hasInsertInfo( + const std::function& check) const { + for (auto& info : insertInfos) { + if (check(info)) { + return true; + } + } + return false; +} + +std::vector BoundMergeClause::getInsertInfos( + const std::function& check) const { + std::vector result; + for (auto& info : insertInfos) { + if (check(info)) { + result.push_back(&info); + } + } + return result; +} + +bool BoundMergeClause::hasOnMatchSetInfo( + const std::function& check) const { + for (auto& info : onMatchSetPropertyInfos) { + if (check(info)) { + return true; + } + } + return false; +} + +std::vector BoundMergeClause::getOnMatchSetInfos( + const std::function& check) const { + std::vector result; + for (auto& info : onMatchSetPropertyInfos) { + if (check(info)) { + result.push_back(info.copy()); + } + } + return result; +} + +bool BoundMergeClause::hasOnCreateSetInfo( + const std::function& check) const { + for (auto& info : onCreateSetPropertyInfos) { + if (check(info)) { + return true; + } + } + return false; +} + +std::vector BoundMergeClause::getOnCreateSetInfos( + const std::function& check) const { + std::vector result; + for (auto& info : onCreateSetPropertyInfos) { + if (check(info)) { + result.push_back(info.copy()); + } + } + return result; +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/query/bound_set_clause.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/query/bound_set_clause.cpp new file mode 100644 index 0000000000..0bb2dca194 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/query/bound_set_clause.cpp @@ -0,0 +1,29 @@ +#include "binder/query/updating_clause/bound_set_clause.h" + +using namespace lbug::common; + +namespace lbug { +namespace binder { + +bool BoundSetClause::hasInfo(const std::function& check) const { + for (auto& info : infos) { + if (check(info)) { + return true; + } + } + return false; +} + +std::vector BoundSetClause::getInfos( + const std::function& check) const { + std::vector result; + for (auto& info : infos) { + if 
(check(info)) { + result.push_back(info.copy()); + } + } + return result; +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/query/query_graph.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/query/query_graph.cpp new file mode 100644 index 0000000000..0358a1afef --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/query/query_graph.cpp @@ -0,0 +1,338 @@ +#include "binder/query/query_graph.h" + +#include "binder/expression_visitor.h" + +namespace lbug { +namespace binder { + +std::size_t SubqueryGraphHasher::operator()(const SubqueryGraph& key) const { + if (0 == key.queryRelsSelector.count()) { + return std::hash>{}(key.queryNodesSelector); + } + return std::hash>{}(key.queryRelsSelector); +} + +bool SubqueryGraph::containAllVariables(const std::unordered_set& variables) const { + for (auto& var : variables) { + if (queryGraph.containsQueryNode(var) && + !queryNodesSelector[queryGraph.getQueryNodeIdx(var)]) { + return false; + } + if (queryGraph.containsQueryRel(var) && + !queryRelsSelector[queryGraph.getQueryRelIdx(var)]) { + return false; + } + } + return true; +} + +std::unordered_set SubqueryGraph::getNodeNbrPositions() const { + std::unordered_set result; + for (auto relPos = 0u; relPos < queryGraph.getNumQueryRels(); ++relPos) { + if (!queryRelsSelector[relPos]) { // rel not in subgraph, no need to check + continue; + } + auto rel = queryGraph.getQueryRel(relPos); + auto srcNodePos = queryGraph.getQueryNodeIdx(*rel->getSrcNode()); + if (!queryNodesSelector[srcNodePos]) { + result.insert(srcNodePos); + } + auto dstNodePos = queryGraph.getQueryNodeIdx(*rel->getDstNode()); + if (!queryNodesSelector[dstNodePos]) { + result.insert(dstNodePos); + } + } + return result; +} + +std::unordered_set SubqueryGraph::getRelNbrPositions() const { + std::unordered_set result; + for (auto relPos = 0u; relPos < queryGraph.getNumQueryRels(); ++relPos) { + if (queryRelsSelector[relPos]) { // rel already in subgraph, cannot be rel neighbour + continue; + } + auto rel = queryGraph.getQueryRel(relPos); + auto srcNodePos = queryGraph.getQueryNodeIdx(*rel->getSrcNode()); + auto dstNodePos = queryGraph.getQueryNodeIdx(*rel->getDstNode()); + if (queryNodesSelector[srcNodePos] || queryNodesSelector[dstNodePos]) { + result.insert(relPos); + } + } + return result; +} + +subquery_graph_set_t SubqueryGraph::getNbrSubgraphs(uint32_t size) const { + auto result = getBaseNbrSubgraph(); + for (auto i = 1u; i < size; ++i) { + std::unordered_set tmp; + for (auto& prevNbr : result) { + for (auto& subgraph : getNextNbrSubgraphs(prevNbr)) { + tmp.insert(subgraph); + } + } + result = std::move(tmp); + } + return result; +} + +std::vector SubqueryGraph::getConnectedNodePos(const SubqueryGraph& nbr) const { + std::vector result; + for (auto& nodePos : getNodeNbrPositions()) { + if (nbr.queryNodesSelector[nodePos]) { + result.push_back(nodePos); + } + } + for (auto& nodePos : nbr.getNodeNbrPositions()) { + if (queryNodesSelector[nodePos]) { + result.push_back(nodePos); + } + } + return result; +} + +std::unordered_set SubqueryGraph::getNodePositionsIgnoringNodeSelector() const { + std::unordered_set result; + for (auto nodePos = 0u; nodePos < queryGraph.getNumQueryNodes(); ++nodePos) { + if (queryNodesSelector[nodePos]) { + result.insert(nodePos); + } + } + for (auto relPos = 0u; relPos < queryGraph.getNumQueryRels(); ++relPos) { + auto rel = queryGraph.getQueryRel(relPos); + if (queryRelsSelector[relPos]) { + 
result.insert(queryGraph.getQueryNodeIdx(rel->getSrcNodeName())); + result.insert(queryGraph.getQueryNodeIdx(rel->getDstNodeName())); + } + } + return result; +} + +std::vector SubqueryGraph::getNbrNodeIndices() const { + std::unordered_set result; + for (auto i = 0u; i < queryGraph.getNumQueryRels(); ++i) { + if (!queryRelsSelector[i]) { + continue; + } + auto rel = queryGraph.getQueryRel(i); + auto srcNodePos = queryGraph.getQueryNodeIdx(rel->getSrcNodeName()); + auto dstNodePos = queryGraph.getQueryNodeIdx(rel->getDstNodeName()); + if (!queryNodesSelector[srcNodePos]) { + result.insert(srcNodePos); + } + if (!queryNodesSelector[dstNodePos]) { + result.insert(dstNodePos); + } + } + return std::vector{result.begin(), result.end()}; +} + +subquery_graph_set_t SubqueryGraph::getBaseNbrSubgraph() const { + subquery_graph_set_t result; + for (auto& nodePos : getNodeNbrPositions()) { + auto nbr = SubqueryGraph(queryGraph); + nbr.addQueryNode(nodePos); + result.insert(nbr); + } + for (auto& relPos : getRelNbrPositions()) { + auto nbr = SubqueryGraph(queryGraph); + nbr.addQueryRel(relPos); + result.insert(nbr); + } + return result; +} + +subquery_graph_set_t SubqueryGraph::getNextNbrSubgraphs(const SubqueryGraph& prevNbr) const { + subquery_graph_set_t result; + for (auto& nodePos : prevNbr.getNodeNbrPositions()) { + if (queryNodesSelector[nodePos]) { + continue; + } + auto nbr = prevNbr; + nbr.addQueryNode(nodePos); + result.insert(nbr); + } + for (auto& relPos : prevNbr.getRelNbrPositions()) { + if (queryRelsSelector[relPos]) { + continue; + } + auto nbr = prevNbr; + nbr.addQueryRel(relPos); + result.insert(nbr); + } + return result; +} + +bool QueryGraph::isEmpty() const { + for (auto& n : queryNodes) { + if (n->isEmpty()) { + return true; + } + } + for (auto& r : queryRels) { + if (r->isEmpty()) { + return true; + } + } + return false; +} + +std::vector> QueryGraph::getAllPatterns() const { + std::vector> patterns; + for (auto& p : queryNodes) { + patterns.push_back(p); + } + for (auto& p : queryRels) { + patterns.push_back(p); + } + return patterns; +} + +void QueryGraph::addQueryNode(std::shared_ptr queryNode) { + // Note that a node may be added multiple times. We should only keep one of it. + // E.g. 
MATCH (a:person)-[:knows]->(b:person), (a)-[:knows]->(c:person) + // a will be added twice during binding + if (containsQueryNode(queryNode->getUniqueName())) { + return; + } + queryNodeNameToPosMap.insert({queryNode->getUniqueName(), queryNodes.size()}); + queryNodes.push_back(std::move(queryNode)); +} + +void QueryGraph::addQueryRel(std::shared_ptr queryRel) { + if (containsQueryRel(queryRel->getUniqueName())) { + return; + } + queryRelNameToPosMap.insert({queryRel->getUniqueName(), queryRels.size()}); + queryRels.push_back(std::move(queryRel)); +} + +void QueryGraph::merge(const QueryGraph& other) { + for (auto& otherNode : other.queryNodes) { + addQueryNode(otherNode); + } + for (auto& otherRel : other.queryRels) { + addQueryRel(otherRel); + } +} + +bool QueryGraph::canProjectExpression(const std::shared_ptr& expression) const { + auto collector = DependentVarNameCollector(); + collector.visit(expression); + for (auto& variable : collector.getVarNames()) { + if (!containsQueryNode(variable) && !containsQueryRel(variable)) { + return false; + } + } + return true; +} + +bool QueryGraph::isConnected(const QueryGraph& other) const { + for (auto& queryNode : queryNodes) { + if (other.containsQueryNode(queryNode->getUniqueName())) { + return true; + } + } + return false; +} + +void QueryGraphCollection::merge(const QueryGraphCollection& other) { + for (auto& queryGraph : other.queryGraphs) { + addAndMergeQueryGraphIfConnected(queryGraph); + } + finalize(); +} + +void QueryGraphCollection::addAndMergeQueryGraphIfConnected(QueryGraph queryGraphToAdd) { + auto newQueryGraphSet = std::vector(); + for (auto i = 0u; i < queryGraphs.size(); i++) { + auto queryGraph = std::move(queryGraphs[i]); + if (queryGraph.isConnected(queryGraphToAdd)) { + queryGraphToAdd.merge(queryGraph); + } else { + newQueryGraphSet.push_back(std::move(queryGraph)); + } + } + newQueryGraphSet.push_back(std::move(queryGraphToAdd)); + queryGraphs = std::move(newQueryGraphSet); +} + +void QueryGraphCollection::finalize() { + common::idx_t baseGraphIdx = 0; + while (true) { + auto prevNumGraphs = queryGraphs.size(); + queryGraphs = mergeGraphs(baseGraphIdx++); + if (queryGraphs.size() == prevNumGraphs || baseGraphIdx == queryGraphs.size()) { + return; + } + } +} + +std::vector QueryGraphCollection::mergeGraphs(common::idx_t baseGraphIdx) { + KU_ASSERT(baseGraphIdx < queryGraphs.size()); + auto& baseGraph = queryGraphs[baseGraphIdx]; + std::unordered_set mergedGraphIndices; + mergedGraphIndices.insert(baseGraphIdx); + while (true) { + // find graph to merge + common::idx_t graphToMergeIdx = common::INVALID_IDX; + for (auto i = 0u; i < queryGraphs.size(); ++i) { + if (mergedGraphIndices.contains(i)) { // graph has been merged. + continue; + } + if (baseGraph.isConnected(queryGraphs[i])) { // find graph to merge. + graphToMergeIdx = i; + break; + } + } + if (graphToMergeIdx == common::INVALID_IDX) { // No graph can be merged. Terminate. 
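+            // Reaching this branch means no remaining graph shares a node with the
+            // current base graph, so the base graph is maximal and finalize() moves
+            // on to the next base index.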
+ break; + } + // Perform merge + baseGraph.merge(queryGraphs[graphToMergeIdx]); + mergedGraphIndices.insert(graphToMergeIdx); + } + std::vector finalGraphs; + for (auto i = 0u; i < queryGraphs.size(); ++i) { + if (i == baseGraphIdx) { + finalGraphs.push_back(baseGraph); + continue; + } + if (mergedGraphIndices.contains(i)) { + continue; + } + finalGraphs.push_back(std::move(queryGraphs[i])); + } + return finalGraphs; +} + +bool QueryGraphCollection::contains(const std::string& name) const { + for (auto& queryGraph : queryGraphs) { + if (queryGraph.containsQueryNode(name) || queryGraph.containsQueryRel(name)) { + return true; + } + } + return false; +} + +std::vector> QueryGraphCollection::getQueryNodes() const { + std::vector> result; + for (auto& queryGraph : queryGraphs) { + for (auto& node : queryGraph.getQueryNodes()) { + result.push_back(node); + } + } + return result; +} + +std::vector> QueryGraphCollection::getQueryRels() const { + std::vector> result; + for (auto& queryGraph : queryGraphs) { + for (auto& rel : queryGraph.getQueryRels()) { + result.push_back(rel); + } + } + return result; +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/query/query_graph_label_analyzer.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/query/query_graph_label_analyzer.cpp new file mode 100644 index 0000000000..094716e4db --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/query/query_graph_label_analyzer.cpp @@ -0,0 +1,159 @@ +#include "binder/query/query_graph_label_analyzer.h" + +#include "catalog/catalog.h" +#include "catalog/catalog_entry/rel_group_catalog_entry.h" +#include "common/exception/binder.h" +#include "common/string_format.h" +#include "transaction/transaction.h" + +using namespace lbug::common; +using namespace lbug::catalog; +using namespace lbug::transaction; + +namespace lbug { +namespace binder { + +// NOLINTNEXTLINE(readability-non-const-parameter): graph is supposed to be modified. 
+void QueryGraphLabelAnalyzer::pruneLabel(QueryGraph& graph) const { + for (auto i = 0u; i < graph.getNumQueryNodes(); ++i) { + pruneNode(graph, *graph.getQueryNode(i)); + } + for (auto i = 0u; i < graph.getNumQueryRels(); ++i) { + pruneRel(*graph.getQueryRel(i)); + } +} + +struct Candidates { + table_id_set_t idSet; + std::unordered_set nameSet; + + void insert(const table_id_set_t& idsToInsert, Catalog* catalog, Transaction* transaction) { + for (auto id : idsToInsert) { + auto name = catalog->getTableCatalogEntry(transaction, id)->getName(); + idSet.insert(id); + nameSet.insert(name); + } + } + + bool empty() const { return idSet.empty(); } + + bool contains(const table_id_t& id) const { return idSet.contains(id); } + + std::string toString() const { + auto names = std::vector{nameSet.begin(), nameSet.end()}; + auto result = names[0]; + for (auto j = 1u; j < names.size(); ++j) { + result += ", " + names[j]; + } + return result; + } +}; + +void QueryGraphLabelAnalyzer::pruneNode(const QueryGraph& graph, NodeExpression& node) const { + auto catalog = Catalog::Get(clientContext); + for (auto i = 0u; i < graph.getNumQueryRels(); ++i) { + auto queryRel = graph.getQueryRel(i); + if (queryRel->isRecursive()) { + continue; + } + Candidates candidates; + auto isSrcConnect = *queryRel->getSrcNode() == node; + auto isDstConnect = *queryRel->getDstNode() == node; + auto tx = transaction::Transaction::Get(clientContext); + if (queryRel->getDirectionType() == RelDirectionType::BOTH) { + if (isSrcConnect || isDstConnect) { + for (auto entry : queryRel->getEntries()) { + auto& relEntry = entry->constCast(); + candidates.insert(relEntry.getSrcNodeTableIDSet(), catalog, tx); + candidates.insert(relEntry.getDstNodeTableIDSet(), catalog, tx); + } + } + } else { + if (isSrcConnect) { + for (auto entry : queryRel->getEntries()) { + auto& relEntry = entry->constCast(); + candidates.insert(relEntry.getSrcNodeTableIDSet(), catalog, tx); + } + } else if (isDstConnect) { + for (auto entry : queryRel->getEntries()) { + auto& relEntry = entry->constCast(); + candidates.insert(relEntry.getDstNodeTableIDSet(), catalog, tx); + } + } + } + if (candidates.empty()) { // No need to prune. + continue; + } + std::vector prunedEntries; + for (auto entry : node.getEntries()) { + if (!candidates.contains(entry->getTableID())) { + continue; + } + prunedEntries.push_back(entry); + } + node.setEntries(prunedEntries); + if (prunedEntries.empty()) { + if (throwOnViolate) { + throw BinderException( + stringFormat("Query node {} violates schema. 
Expected labels are {}.", + node.toString(), candidates.toString())); + } + } + } +} + +bool hasOverlap(const table_id_set_t& left, const table_id_set_t& right) { + for (auto id : left) { + if (right.contains(id)) { + return true; + } + } + return false; +} + +void QueryGraphLabelAnalyzer::pruneRel(RelExpression& rel) const { + if (rel.isRecursive()) { + return; + } + std::vector prunedEntries; + auto srcTableIDSet = rel.getSrcNode()->getTableIDsSet(); + auto dstTableIDSet = rel.getDstNode()->getTableIDsSet(); + if (rel.getDirectionType() == RelDirectionType::BOTH) { + for (auto& entry : rel.getEntries()) { + auto& relEntry = entry->constCast(); + auto fwdSrcOverlap = hasOverlap(srcTableIDSet, relEntry.getSrcNodeTableIDSet()); + auto fwdDstOverlap = hasOverlap(dstTableIDSet, relEntry.getDstNodeTableIDSet()); + auto fwdOverlap = fwdSrcOverlap && fwdDstOverlap; + auto bwdSrcOverlap = hasOverlap(dstTableIDSet, relEntry.getSrcNodeTableIDSet()); + auto bwdDstOverlap = hasOverlap(srcTableIDSet, relEntry.getDstNodeTableIDSet()); + auto bwdOverlap = bwdSrcOverlap && bwdDstOverlap; + if (fwdOverlap || bwdOverlap) { + prunedEntries.push_back(entry); + } + } + } else { + for (auto& entry : rel.getEntries()) { + auto& relEntry = entry->constCast(); + auto srcOverlap = hasOverlap(srcTableIDSet, relEntry.getSrcNodeTableIDSet()); + auto dstOverlap = hasOverlap(dstTableIDSet, relEntry.getDstNodeTableIDSet()); + if (srcOverlap && dstOverlap) { + prunedEntries.push_back(entry); + } + } + } + rel.setEntries(prunedEntries); + // Note the pruning for node should guarantee the following exception won't be triggered. + // For safety (and consistency) reason, we still write the check but skip coverage check. + // LCOV_EXCL_START + if (prunedEntries.empty()) { + if (throwOnViolate) { + throw BinderException(stringFormat("Cannot find a label for relationship {} that " + "connects to all of its neighbour nodes.", + rel.toString())); + } + } + // LCOV_EXCL_STOP +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/rewriter/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/rewriter/CMakeLists.txt new file mode 100644 index 0000000000..1b3788ca13 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/rewriter/CMakeLists.txt @@ -0,0 +1,10 @@ +add_library( + lbug_binder_rewriter + OBJECT + match_clause_pattern_label_rewriter.cpp + normalized_query_part_match_rewriter.cpp + with_clause_projection_rewriter.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/rewriter/match_clause_pattern_label_rewriter.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/rewriter/match_clause_pattern_label_rewriter.cpp new file mode 100644 index 0000000000..405eb12af1 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/rewriter/match_clause_pattern_label_rewriter.cpp @@ -0,0 +1,22 @@ +#include "binder/rewriter/match_clause_pattern_label_rewriter.h" + +#include "binder/query/reading_clause/bound_match_clause.h" + +using namespace lbug::common; + +namespace lbug { +namespace binder { + +void MatchClausePatternLabelRewriter::visitMatchUnsafe(BoundReadingClause& readingClause) { + auto& matchClause = readingClause.cast(); + if (matchClause.getMatchClauseType() == MatchClauseType::OPTIONAL_MATCH) { + return; + } + auto collection = matchClause.getQueryGraphCollectionUnsafe(); + for (auto i = 0u; i < collection->getNumQueryGraphs(); ++i) { + analyzer.pruneLabel(*collection->getQueryGraphUnsafe(i)); + 
} +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/rewriter/normalized_query_part_match_rewriter.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/rewriter/normalized_query_part_match_rewriter.cpp new file mode 100644 index 0000000000..b4db8e2652 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/rewriter/normalized_query_part_match_rewriter.cpp @@ -0,0 +1,48 @@ +#include "binder/rewriter/normalized_query_part_match_rewriter.h" + +#include "binder/binder.h" +#include "binder/query/reading_clause/bound_match_clause.h" + +using namespace lbug::common; + +namespace lbug { +namespace binder { + +static bool canRewrite(const BoundMatchClause& matchClause) { + return !matchClause.hasHint() && + matchClause.getMatchClauseType() != MatchClauseType::OPTIONAL_MATCH; +} + +void NormalizedQueryPartMatchRewriter::visitQueryPartUnsafe(NormalizedQueryPart& queryPart) { + if (queryPart.getNumReadingClause() == 0) { + return; + } + for (auto i = 0u; i < queryPart.getNumReadingClause(); i++) { + if (queryPart.getReadingClause(i)->getClauseType() != ClauseType::MATCH) { + return; + } + auto& match = queryPart.getReadingClause(i)->constCast(); + if (!canRewrite(match)) { + return; + } + } + // Merge consecutive match clauses + std::vector> newReadingClauses; + newReadingClauses.push_back(std::move(queryPart.readingClauses[0])); + auto& leadingMatchClause = newReadingClauses[0]->cast(); + auto binder = Binder(clientContext); + auto expressionBinder = binder.getExpressionBinder(); + for (auto idx = 1u; idx < queryPart.getNumReadingClause(); idx++) { + auto& otherMatchClause = queryPart.readingClauses[idx]->constCast(); + leadingMatchClause.getQueryGraphCollectionUnsafe()->merge( + *otherMatchClause.getQueryGraphCollection()); + auto predicate = expressionBinder->combineBooleanExpressions(ExpressionType::AND, + leadingMatchClause.getPredicate(), otherMatchClause.getPredicate()); + leadingMatchClause.setPredicate(std::move(predicate)); + } + // Move remaining reading clause + queryPart.readingClauses = std::move(newReadingClauses); +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/rewriter/with_clause_projection_rewriter.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/rewriter/with_clause_projection_rewriter.cpp new file mode 100644 index 0000000000..63dba5a0a6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/rewriter/with_clause_projection_rewriter.cpp @@ -0,0 +1,101 @@ +#include "binder/rewriter/with_clause_projection_rewriter.h" + +#include "binder/expression/expression_util.h" +#include "binder/expression/node_expression.h" +#include "binder/expression/property_expression.h" +#include "binder/expression/rel_expression.h" +#include "binder/expression_visitor.h" +#include "binder/visitor/property_collector.h" + +using namespace lbug::common; + +namespace lbug { +namespace binder { + +static void rewrite(std::shared_ptr expr, expression_vector& projectionList, + const std::unordered_map& varNameToProperties) { + std::string varName; + if (ExpressionUtil::isNodePattern(*expr)) { + auto& node = expr->constCast(); + projectionList.push_back(node.getInternalID()); + varName = node.getUniqueName(); + } else if (ExpressionUtil::isRelPattern(*expr)) { + auto& rel = expr->constCast(); + projectionList.push_back(rel.getSrcNode()->getInternalID()); + projectionList.push_back(rel.getDstNode()->getInternalID()); + projectionList.push_back(rel.getInternalID()); + if (rel.hasDirectionExpr()) 
{ + projectionList.push_back(rel.getDirectionExpr()); + } + varName = rel.getUniqueName(); + } else if (ExpressionUtil::isRecursiveRelPattern(*expr)) { + auto& rel = expr->constCast(); + projectionList.push_back(rel.getLengthExpression()); + projectionList.push_back(expr); + varName = rel.getUniqueName(); + } + if (!varName.empty()) { + if (varNameToProperties.contains(varName)) { + for (auto& property : varNameToProperties.at(varName)) { + projectionList.push_back(property); + } + } + } else { + projectionList.push_back(expr); + } +} + +static expression_vector rewrite(const expression_vector& exprs, + const std::unordered_map& varNameToProperties) { + expression_vector projectionList; + for (auto& expr : exprs) { + rewrite(expr, projectionList, varNameToProperties); + } + return projectionList; +} + +void WithClauseProjectionRewriter::visitSingleQueryUnsafe(NormalizedSingleQuery& singleQuery) { + auto propertyCollector = PropertyCollector(); + propertyCollector.visitSingleQuerySkipNodeRel(singleQuery); + std::unordered_map varNameToProperties; + for (auto& expr : propertyCollector.getProperties()) { + auto& property = expr->constCast(); + if (!varNameToProperties.contains(property.getVariableName())) { + varNameToProperties.insert({property.getVariableName(), expression_vector{}}); + } + varNameToProperties.at(property.getVariableName()).push_back(expr); + } + // Rewrite WITH clause node, relationship pattern projection as node.* & rel.* + // Because we want to delay the evaluation of node and rel as a struct. + for (auto i = 0u; i < singleQuery.getNumQueryParts() - 1; ++i) { + auto queryPart = singleQuery.getQueryPartUnsafe(i); + auto projectionBody = queryPart->getProjectionBodyUnsafe(); + auto newProjectionExprs = + rewrite(projectionBody->getProjectionExpressions(), varNameToProperties); + projectionBody->setProjectionExpressions(std::move(newProjectionExprs)); + auto newGroupByExprs = + rewrite(projectionBody->getGroupByExpressions(), varNameToProperties); + projectionBody->setGroupByExpressions(std::move(newGroupByExprs)); + } + // Remove constant expressions from WITH clause projection list. + for (auto i = 0u; i < singleQuery.getNumQueryParts() - 1; ++i) { + auto queryPart = singleQuery.getQueryPartUnsafe(i); + auto projectionBody = queryPart->getProjectionBodyUnsafe(); + // Avoid rewrite in the case of ORDER BY 1 or aggregate by constant. Because operator + // implementation replies on expressions to be projected first. 
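+        // For example, in `WITH n, 1 AS one ORDER BY one`, the constant `one` must
+        // stay in the projection list because ORDER BY reads it from the projected
+        // columns.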
+ if (projectionBody->hasOrderByExpressions() || projectionBody->hasAggregateExpressions()) { + continue; + } + expression_vector nonConstantProjectionExprs; + for (auto& expr : projectionBody->getProjectionExpressions()) { + if (ConstantExpressionVisitor::isConstant(*expr)) { + continue; + } + nonConstantProjectionExprs.push_back(expr); + } + projectionBody->setProjectionExpressions(nonConstantProjectionExprs); + } +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/visitor/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/visitor/CMakeLists.txt new file mode 100644 index 0000000000..413008bf8b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/visitor/CMakeLists.txt @@ -0,0 +1,10 @@ +add_library( + lbug_binder_visitor + OBJECT + confidential_statement_analyzer.cpp + default_type_solver.cpp + property_collector.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/visitor/confidential_statement_analyzer.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/visitor/confidential_statement_analyzer.cpp new file mode 100644 index 0000000000..17f741b304 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/visitor/confidential_statement_analyzer.cpp @@ -0,0 +1,17 @@ +#include "binder/visitor/confidential_statement_analyzer.h" + +#include "binder/bound_standalone_call.h" +#include "main/db_config.h" + +using namespace lbug::common; + +namespace lbug { +namespace binder { + +void ConfidentialStatementAnalyzer::visitStandaloneCall(const BoundStatement& boundStatement) { + auto& standaloneCall = boundStatement.constCast(); + confidentialStatement = standaloneCall.getOption()->isConfidential; +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/visitor/default_type_solver.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/visitor/default_type_solver.cpp new file mode 100644 index 0000000000..947ef3e208 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/visitor/default_type_solver.cpp @@ -0,0 +1,25 @@ +#include "binder/visitor/default_type_solver.h" + +using namespace lbug::common; + +namespace lbug { +namespace binder { + +static void resolveAnyType(Expression& expr) { + if (expr.getDataType().getLogicalTypeID() != LogicalTypeID::ANY) { + return; + } + expr.cast(LogicalType::STRING()); +} + +void DefaultTypeSolver::visitProjectionBody(const BoundProjectionBody& projectionBody) { + for (auto& expr : projectionBody.getProjectionExpressions()) { + resolveAnyType(*expr); + } + for (auto& expr : projectionBody.getOrderByExpressions()) { + resolveAnyType(*expr); + } +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/binder/visitor/property_collector.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/visitor/property_collector.cpp new file mode 100644 index 0000000000..3f0902b087 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/binder/visitor/property_collector.cpp @@ -0,0 +1,180 @@ +#include "binder/visitor/property_collector.h" + +#include "binder/expression/expression_util.h" +#include "binder/expression_visitor.h" +#include "binder/query/reading_clause/bound_load_from.h" +#include "binder/query/reading_clause/bound_match_clause.h" +#include "binder/query/reading_clause/bound_table_function_call.h" +#include "binder/query/reading_clause/bound_unwind_clause.h" +#include "binder/query/updating_clause/bound_delete_clause.h" +#include 
"binder/query/updating_clause/bound_insert_clause.h" +#include "binder/query/updating_clause/bound_merge_clause.h" +#include "binder/query/updating_clause/bound_set_clause.h" +#include "catalog/catalog_entry/table_catalog_entry.h" + +using namespace lbug::common; + +namespace lbug { +namespace binder { + +expression_vector PropertyCollector::getProperties() const { + expression_vector result; + for (auto& property : properties) { + result.push_back(property); + } + return result; +} + +void PropertyCollector::visitSingleQuerySkipNodeRel(const NormalizedSingleQuery& singleQuery) { + KU_ASSERT(singleQuery.getNumQueryParts() != 0); + auto numQueryParts = singleQuery.getNumQueryParts(); + for (auto i = 0u; i < numQueryParts - 1; ++i) { + visitQueryPartSkipNodeRel(*singleQuery.getQueryPart(i)); + } + visitQueryPart(*singleQuery.getQueryPart(numQueryParts - 1)); +} + +void PropertyCollector::visitQueryPartSkipNodeRel(const NormalizedQueryPart& queryPart) { + for (auto i = 0u; i < queryPart.getNumReadingClause(); ++i) { + visitReadingClause(*queryPart.getReadingClause(i)); + } + for (auto i = 0u; i < queryPart.getNumUpdatingClause(); ++i) { + visitUpdatingClause(*queryPart.getUpdatingClause(i)); + } + if (queryPart.hasProjectionBody()) { + visitProjectionBodySkipNodeRel(*queryPart.getProjectionBody()); + if (queryPart.hasProjectionBodyPredicate()) { + visitProjectionBodyPredicate(queryPart.getProjectionBodyPredicate()); + } + } +} + +void PropertyCollector::visitMatch(const BoundReadingClause& readingClause) { + auto& matchClause = readingClause.constCast(); + if (matchClause.hasPredicate()) { + collectProperties(matchClause.getPredicate()); + } +} + +void PropertyCollector::visitUnwind(const BoundReadingClause& readingClause) { + auto& unwindClause = readingClause.constCast(); + collectProperties(unwindClause.getInExpr()); +} + +void PropertyCollector::visitLoadFrom(const BoundReadingClause& readingClause) { + auto& loadFromClause = readingClause.constCast(); + if (loadFromClause.hasPredicate()) { + collectProperties(loadFromClause.getPredicate()); + } +} + +void PropertyCollector::visitTableFunctionCall(const BoundReadingClause& readingClause) { + auto& call = readingClause.constCast(); + if (call.hasPredicate()) { + collectProperties(call.getPredicate()); + } +} + +void PropertyCollector::visitSet(const BoundUpdatingClause& updatingClause) { + auto& boundSetClause = updatingClause.constCast(); + for (auto& info : boundSetClause.getInfos()) { + collectProperties(info.columnData); + } + for (const auto& info : boundSetClause.getRelInfos()) { + auto& rel = info.pattern->constCast(); + KU_ASSERT(!rel.isEmpty() && rel.getRelType() == QueryRelType::NON_RECURSIVE); + properties.insert(rel.getInternalID()); + } +} + +void PropertyCollector::visitDelete(const BoundUpdatingClause& updatingClause) { + auto& boundDeleteClause = updatingClause.constCast(); + // Read primary key if we are deleting nodes; + for (const auto& info : boundDeleteClause.getNodeInfos()) { + auto& node = info.pattern->constCast(); + for (const auto entry : node.getEntries()) { + properties.insert(node.getPrimaryKey(entry->getTableID())); + } + } + // Read rel internal id if we are deleting relationships. 
+ for (const auto& info : boundDeleteClause.getRelInfos()) { + auto& rel = info.pattern->constCast(); + if (!rel.isEmpty() && rel.getRelType() == QueryRelType::NON_RECURSIVE) { + properties.insert(rel.getInternalID()); + } + } +} + +void PropertyCollector::visitInsert(const BoundUpdatingClause& updatingClause) { + auto& insertClause = updatingClause.constCast(); + for (auto& info : insertClause.getInfos()) { + for (auto& expr : info.columnDataExprs) { + collectProperties(expr); + } + } +} + +void PropertyCollector::visitMerge(const BoundUpdatingClause& updatingClause) { + auto& boundMergeClause = updatingClause.constCast(); + for (auto& rel : boundMergeClause.getQueryGraphCollection()->getQueryRels()) { + if (rel->getRelType() == QueryRelType::NON_RECURSIVE) { + properties.insert(rel->getInternalID()); + } + } + if (boundMergeClause.hasPredicate()) { + collectProperties(boundMergeClause.getPredicate()); + } + for (auto& info : boundMergeClause.getInsertInfosRef()) { + for (auto& expr : info.columnDataExprs) { + collectProperties(expr); + } + } + for (auto& info : boundMergeClause.getOnMatchSetInfosRef()) { + collectProperties(info.columnData); + } + for (auto& info : boundMergeClause.getOnCreateSetInfosRef()) { + collectProperties(info.columnData); + } +} + +void PropertyCollector::visitProjectionBodySkipNodeRel(const BoundProjectionBody& projectionBody) { + for (auto& expression : projectionBody.getProjectionExpressions()) { + collectPropertiesSkipNodeRel(expression); + } + for (auto& expression : projectionBody.getOrderByExpressions()) { + collectPropertiesSkipNodeRel(expression); + } +} + +void PropertyCollector::visitProjectionBody(const BoundProjectionBody& projectionBody) { + for (auto& expression : projectionBody.getProjectionExpressions()) { + collectProperties(expression); + } + for (auto& expression : projectionBody.getOrderByExpressions()) { + collectProperties(expression); + } +} + +void PropertyCollector::visitProjectionBodyPredicate(const std::shared_ptr& predicate) { + collectProperties(predicate); +} + +void PropertyCollector::collectProperties(const std::shared_ptr& expression) { + auto collector = PropertyExprCollector(); + collector.visit(expression); + for (auto& expr : collector.getPropertyExprs()) { + properties.insert(expr); + } +} + +void PropertyCollector::collectPropertiesSkipNodeRel( + const std::shared_ptr& expression) { + if (ExpressionUtil::isNodePattern(*expression) || ExpressionUtil::isRelPattern(*expression) || + ExpressionUtil::isRecursiveRelPattern(*expression)) { + return; + } + collectProperties(expression); +} + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/c_api/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/c_api/CMakeLists.txt new file mode 100644 index 0000000000..619e092e13 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/c_api/CMakeLists.txt @@ -0,0 +1,16 @@ +add_library(lbug_c_api + OBJECT + connection.cpp + database.cpp + data_type.cpp + helpers.cpp + flat_tuple.cpp + prepared_statement.cpp + query_result.cpp + query_summary.cpp + value.cpp + version.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/c_api/connection.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/c_api/connection.cpp new file mode 100644 index 0000000000..ab70353174 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/c_api/connection.cpp @@ -0,0 +1,157 @@ +#include "c_api/lbug.h" +#include "common/exception/exception.h" +#include "main/lbug.h" + 
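+// A minimal usage sketch of this C API (illustrative; the database path and the
+// query string are placeholders, and return codes should be checked against
+// LbugSuccess in real code):
+//
+//     lbug_database db;
+//     lbug_connection conn;
+//     lbug_query_result result;
+//     lbug_database_init("demo.db", lbug_default_system_config(), &db);
+//     lbug_connection_init(&db, &conn);
+//     lbug_connection_query(&conn, "MATCH (n) RETURN n;", &result);
+//     /* consume the result, then release it and the connection/database handles */
+//     lbug_connection_destroy(&conn);
+//     lbug_database_destroy(&db);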
+namespace lbug { +namespace common { +class Value; +} +} // namespace lbug + +using namespace lbug::common; +using namespace lbug::main; + +lbug_state lbug_connection_init(lbug_database* database, lbug_connection* out_connection) { + if (database == nullptr || database->_database == nullptr) { + out_connection->_connection = nullptr; + return LbugError; + } + try { + out_connection->_connection = new Connection(static_cast(database->_database)); + } catch (Exception& e) { + out_connection->_connection = nullptr; + return LbugError; + } + return LbugSuccess; +} + +void lbug_connection_destroy(lbug_connection* connection) { + if (connection == nullptr) { + return; + } + if (connection->_connection != nullptr) { + delete static_cast(connection->_connection); + } +} + +lbug_state lbug_connection_set_max_num_thread_for_exec(lbug_connection* connection, + uint64_t num_threads) { + if (connection == nullptr || connection->_connection == nullptr) { + return LbugError; + } + try { + static_cast(connection->_connection)->setMaxNumThreadForExec(num_threads); + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +lbug_state lbug_connection_get_max_num_thread_for_exec(lbug_connection* connection, + uint64_t* out_result) { + if (connection == nullptr || connection->_connection == nullptr) { + return LbugError; + } + try { + *out_result = static_cast(connection->_connection)->getMaxNumThreadForExec(); + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +lbug_state lbug_connection_query(lbug_connection* connection, const char* query, + lbug_query_result* out_query_result) { + if (connection == nullptr || connection->_connection == nullptr) { + return LbugError; + } + try { + auto query_result = + static_cast(connection->_connection)->query(query).release(); + if (query_result == nullptr) { + return LbugError; + } + out_query_result->_query_result = query_result; + out_query_result->_is_owned_by_cpp = false; + if (!query_result->isSuccess()) { + return LbugError; + } + return LbugSuccess; + } catch (Exception& e) { + return LbugError; + } +} + +lbug_state lbug_connection_prepare(lbug_connection* connection, const char* query, + lbug_prepared_statement* out_prepared_statement) { + if (connection == nullptr || connection->_connection == nullptr) { + return LbugError; + } + try { + auto prepared_statement = + static_cast(connection->_connection)->prepare(query).release(); + if (prepared_statement == nullptr) { + return LbugError; + } + out_prepared_statement->_prepared_statement = prepared_statement; + out_prepared_statement->_bound_values = + new std::unordered_map>; + return LbugSuccess; + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +lbug_state lbug_connection_execute(lbug_connection* connection, + lbug_prepared_statement* prepared_statement, lbug_query_result* out_query_result) { + if (connection == nullptr || connection->_connection == nullptr || + prepared_statement == nullptr || prepared_statement->_prepared_statement == nullptr || + prepared_statement->_bound_values == nullptr) { + return LbugError; + } + try { + auto prepared_statement_ptr = + static_cast(prepared_statement->_prepared_statement); + auto bound_values = static_cast>*>( + prepared_statement->_bound_values); + + // Must copy the parameters for safety, and so that the parameters in the prepared statement + // stay the same. 
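+        // The copies are moved into executeWithParams, so the map stored in
+        // _bound_values stays intact and the same prepared statement can be executed
+        // again with the same bindings.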
+ std::unordered_map> copied_bound_values; + for (auto& [name, value] : *bound_values) { + copied_bound_values.emplace(name, value->copy()); + } + + auto query_result = + static_cast(connection->_connection) + ->executeWithParams(prepared_statement_ptr, std::move(copied_bound_values)) + .release(); + if (query_result == nullptr) { + return LbugError; + } + out_query_result->_query_result = query_result; + out_query_result->_is_owned_by_cpp = false; + if (!query_result->isSuccess()) { + return LbugError; + } + return LbugSuccess; + } catch (Exception& e) { + return LbugError; + } +} +void lbug_connection_interrupt(lbug_connection* connection) { + static_cast(connection->_connection)->interrupt(); +} + +lbug_state lbug_connection_set_query_timeout(lbug_connection* connection, uint64_t timeout_in_ms) { + if (connection == nullptr || connection->_connection == nullptr) { + return LbugError; + } + try { + static_cast(connection->_connection)->setQueryTimeOut(timeout_in_ms); + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/c_api/data_type.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/c_api/data_type.cpp new file mode 100644 index 0000000000..37c57ca848 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/c_api/data_type.cpp @@ -0,0 +1,70 @@ +#include "c_api/lbug.h" +#include "common/types/types.h" + +using namespace lbug::common; + +namespace lbug::common { +struct CAPIHelper { + static inline LogicalType* createLogicalType(LogicalTypeID typeID, + std::unique_ptr extraTypeInfo) { + return new LogicalType(typeID, std::move(extraTypeInfo)); + } +}; +} // namespace lbug::common + +void lbug_data_type_create(lbug_data_type_id id, lbug_logical_type* child_type, + uint64_t num_elements_in_array, lbug_logical_type* out_data_type) { + uint8_t data_type_id_u8 = id; + LogicalType* data_type = nullptr; + auto logicalTypeID = static_cast(data_type_id_u8); + if (child_type == nullptr) { + data_type = new LogicalType(logicalTypeID); + } else { + auto child_type_pty = static_cast(child_type->_data_type)->copy(); + auto extraTypeInfo = + num_elements_in_array > 0 ? 
+ std::make_unique(std::move(child_type_pty), num_elements_in_array) : + std::make_unique(std::move(child_type_pty)); + data_type = CAPIHelper::createLogicalType(logicalTypeID, std::move(extraTypeInfo)); + } + out_data_type->_data_type = data_type; +} + +void lbug_data_type_clone(lbug_logical_type* data_type, lbug_logical_type* out_data_type) { + out_data_type->_data_type = + new LogicalType(static_cast(data_type->_data_type)->copy()); +} + +void lbug_data_type_destroy(lbug_logical_type* data_type) { + if (data_type == nullptr) { + return; + } + if (data_type->_data_type != nullptr) { + delete static_cast(data_type->_data_type); + } +} + +bool lbug_data_type_equals(lbug_logical_type* data_type1, lbug_logical_type* data_type2) { + return *static_cast(data_type1->_data_type) == + *static_cast(data_type2->_data_type); +} + +lbug_data_type_id lbug_data_type_get_id(lbug_logical_type* data_type) { + auto data_type_id_u8 = + static_cast(static_cast(data_type->_data_type)->getLogicalTypeID()); + return static_cast(data_type_id_u8); +} + +lbug_state lbug_data_type_get_num_elements_in_array(lbug_logical_type* data_type, + uint64_t* out_result) { + auto parent_type = static_cast(data_type->_data_type); + if (parent_type->getLogicalTypeID() != LogicalTypeID::ARRAY) { + return LbugError; + } + try { + *out_result = ArrayType::getNumElements(*parent_type); + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/c_api/database.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/c_api/database.cpp new file mode 100644 index 0000000000..80c60689dd --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/c_api/database.cpp @@ -0,0 +1,49 @@ +#include "c_api/lbug.h" +#include "common/exception/exception.h" +#include "main/lbug.h" +using namespace lbug::main; +using namespace lbug::common; + +lbug_state lbug_database_init(const char* database_path, lbug_system_config config, + lbug_database* out_database) { + try { + std::string database_path_str = database_path; + auto systemConfig = SystemConfig(config.buffer_pool_size, config.max_num_threads, + config.enable_compression, config.read_only, config.max_db_size, config.auto_checkpoint, + config.checkpoint_threshold); + +#if defined(__APPLE__) + systemConfig.threadQos = config.thread_qos; +#endif + out_database->_database = new Database(database_path_str, systemConfig); + } catch (Exception& e) { + out_database->_database = nullptr; + return LbugError; + } + return LbugSuccess; +} + +void lbug_database_destroy(lbug_database* database) { + if (database == nullptr) { + return; + } + if (database->_database != nullptr) { + delete static_cast(database->_database); + } +} + +lbug_system_config lbug_default_system_config() { + SystemConfig config = SystemConfig(); + auto cSystemConfig = lbug_system_config(); + cSystemConfig.buffer_pool_size = config.bufferPoolSize; + cSystemConfig.max_num_threads = config.maxNumThreads; + cSystemConfig.enable_compression = config.enableCompression; + cSystemConfig.read_only = config.readOnly; + cSystemConfig.max_db_size = config.maxDBSize; + cSystemConfig.auto_checkpoint = config.autoCheckpoint; + cSystemConfig.checkpoint_threshold = config.checkpointThreshold; +#if defined(__APPLE__) + cSystemConfig.thread_qos = config.threadQos; +#endif + return cSystemConfig; +} diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/c_api/flat_tuple.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/c_api/flat_tuple.cpp new file mode 100644 index 0000000000..a4580a20a6 --- /dev/null +++ 
b/graph-wasm/lbug-0.12.2/lbug-src/src/c_api/flat_tuple.cpp @@ -0,0 +1,38 @@ +#include "processor/result/flat_tuple.h" + +#include "c_api/helpers.h" +#include "c_api/lbug.h" +#include "common/exception/exception.h" + +using namespace lbug::common; +using namespace lbug::processor; + +void lbug_flat_tuple_destroy(lbug_flat_tuple* flat_tuple) { + if (flat_tuple == nullptr) { + return; + } + if (flat_tuple->_flat_tuple != nullptr && !flat_tuple->_is_owned_by_cpp) { + delete static_cast(flat_tuple->_flat_tuple); + } +} + +lbug_state lbug_flat_tuple_get_value(lbug_flat_tuple* flat_tuple, uint64_t index, + lbug_value* out_value) { + auto flat_tuple_ptr = static_cast(flat_tuple->_flat_tuple); + Value* _value = nullptr; + try { + _value = flat_tuple_ptr->getValue(index); + } catch (Exception& e) { + return LbugError; + } + out_value->_value = _value; + // We set the ownership of the value to C++, so it will not be deleted if the value is destroyed + // in C. + out_value->_is_owned_by_cpp = true; + return LbugSuccess; +} + +char* lbug_flat_tuple_to_string(lbug_flat_tuple* flat_tuple) { + auto flat_tuple_ptr = static_cast(flat_tuple->_flat_tuple); + return convertToOwnedCString(flat_tuple_ptr->toString()); +} diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/c_api/helpers.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/c_api/helpers.cpp new file mode 100644 index 0000000000..4eb8955dd1 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/c_api/helpers.cpp @@ -0,0 +1,54 @@ +#include "c_api/helpers.h" + +#include + +#ifdef _WIN32 +const uint64_t NS_TO_SEC = 10000000ULL; +const uint64_t SEC_TO_UNIX_EPOCH = 11644473600ULL; + +time_t convertTmToTime(struct tm tm) { + SYSTEMTIME st; + st.wYear = tm.tm_year + 1900; + st.wMonth = tm.tm_mon + 1; + st.wDay = tm.tm_mday; + st.wHour = tm.tm_hour; + st.wMinute = tm.tm_min; + st.wSecond = tm.tm_sec; + st.wMilliseconds = 0; + FILETIME ft; + if (!SystemTimeToFileTime(&st, &ft)) { + return -1; + } + ULARGE_INTEGER ull; + ull.LowPart = ft.dwLowDateTime; + ull.HighPart = ft.dwHighDateTime; + return static_cast((ull.QuadPart / NS_TO_SEC) - SEC_TO_UNIX_EPOCH); +} + +int32_t convertTimeToTm(time_t time, struct tm* out_tm) { + ULARGE_INTEGER ull; + ull.QuadPart = (time + SEC_TO_UNIX_EPOCH) * NS_TO_SEC; + FILETIME ft; + ft.dwLowDateTime = ull.LowPart; + ft.dwHighDateTime = ull.HighPart; + SYSTEMTIME st; + if (!FileTimeToSystemTime(&ft, &st)) { + return -1; + } + out_tm->tm_year = st.wYear - 1900; + out_tm->tm_mon = st.wMonth - 1; + out_tm->tm_mday = st.wDay; + out_tm->tm_hour = st.wHour; + out_tm->tm_min = st.wMinute; + out_tm->tm_sec = st.wSecond; + return 0; +} +#endif + +char* convertToOwnedCString(const std::string& str) { + size_t src_len = str.size(); + auto* c_str = (char*)malloc(sizeof(char) * (src_len + 1)); + memcpy(c_str, str.c_str(), src_len); + c_str[src_len] = '\0'; + return c_str; +} diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/c_api/prepared_statement.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/c_api/prepared_statement.cpp new file mode 100644 index 0000000000..4000f2a8a2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/c_api/prepared_statement.cpp @@ -0,0 +1,283 @@ +#include "main/prepared_statement.h" + +#include "c_api/helpers.h" +#include "c_api/lbug.h" +#include "common/types/value/value.h" + +using namespace lbug::common; +using namespace lbug::main; + +void lbug_prepared_statement_bind_cpp_value(lbug_prepared_statement* prepared_statement, + const char* param_name, std::unique_ptr value) { + auto* bound_values = static_cast>*>( + 
prepared_statement->_bound_values); + bound_values->erase(param_name); + bound_values->insert({param_name, std::move(value)}); +} + +void lbug_prepared_statement_destroy(lbug_prepared_statement* prepared_statement) { + if (prepared_statement == nullptr) { + return; + } + if (prepared_statement->_prepared_statement != nullptr) { + delete static_cast(prepared_statement->_prepared_statement); + } + if (prepared_statement->_bound_values != nullptr) { + delete static_cast>*>( + prepared_statement->_bound_values); + } +} + +bool lbug_prepared_statement_is_success(lbug_prepared_statement* prepared_statement) { + return static_cast(prepared_statement->_prepared_statement)->isSuccess(); +} + +char* lbug_prepared_statement_get_error_message(lbug_prepared_statement* prepared_statement) { + auto error_message = + static_cast(prepared_statement->_prepared_statement)->getErrorMessage(); + if (error_message.empty()) { + return nullptr; + } + return convertToOwnedCString(error_message); +} + +lbug_state lbug_prepared_statement_bind_bool(lbug_prepared_statement* prepared_statement, + const char* param_name, bool value) { + try { + auto value_ptr = std::make_unique(value); + lbug_prepared_statement_bind_cpp_value(prepared_statement, param_name, + std::move(value_ptr)); + return LbugSuccess; + } catch (Exception& e) { + return LbugError; + } +} + +lbug_state lbug_prepared_statement_bind_int64(lbug_prepared_statement* prepared_statement, + const char* param_name, int64_t value) { + try { + auto value_ptr = std::make_unique(value); + lbug_prepared_statement_bind_cpp_value(prepared_statement, param_name, + std::move(value_ptr)); + return LbugSuccess; + } catch (Exception& e) { + return LbugError; + } +} + +lbug_state lbug_prepared_statement_bind_int32(lbug_prepared_statement* prepared_statement, + const char* param_name, int32_t value) { + try { + auto value_ptr = std::make_unique(value); + lbug_prepared_statement_bind_cpp_value(prepared_statement, param_name, + std::move(value_ptr)); + return LbugSuccess; + } catch (Exception& e) { + return LbugError; + } +} + +lbug_state lbug_prepared_statement_bind_int16(lbug_prepared_statement* prepared_statement, + const char* param_name, int16_t value) { + try { + auto value_ptr = std::make_unique(value); + lbug_prepared_statement_bind_cpp_value(prepared_statement, param_name, + std::move(value_ptr)); + return LbugSuccess; + } catch (Exception& e) { + return LbugError; + } +} + +lbug_state lbug_prepared_statement_bind_int8(lbug_prepared_statement* prepared_statement, + const char* param_name, int8_t value) { + try { + auto value_ptr = std::make_unique(value); + lbug_prepared_statement_bind_cpp_value(prepared_statement, param_name, + std::move(value_ptr)); + return LbugSuccess; + } catch (Exception& e) { + return LbugError; + } +} + +lbug_state lbug_prepared_statement_bind_uint64(lbug_prepared_statement* prepared_statement, + const char* param_name, uint64_t value) { + try { + auto value_ptr = std::make_unique(value); + lbug_prepared_statement_bind_cpp_value(prepared_statement, param_name, + std::move(value_ptr)); + return LbugSuccess; + } catch (Exception& e) { + return LbugError; + } +} + +lbug_state lbug_prepared_statement_bind_uint32(lbug_prepared_statement* prepared_statement, + const char* param_name, uint32_t value) { + try { + auto value_ptr = std::make_unique(value); + lbug_prepared_statement_bind_cpp_value(prepared_statement, param_name, + std::move(value_ptr)); + return LbugSuccess; + } catch (Exception& e) { + return LbugError; + } +} + +lbug_state 
lbug_prepared_statement_bind_uint16(lbug_prepared_statement* prepared_statement, + const char* param_name, uint16_t value) { + try { + auto value_ptr = std::make_unique(value); + lbug_prepared_statement_bind_cpp_value(prepared_statement, param_name, + std::move(value_ptr)); + return LbugSuccess; + } catch (Exception& e) { + return LbugError; + } +} + +lbug_state lbug_prepared_statement_bind_uint8(lbug_prepared_statement* prepared_statement, + const char* param_name, uint8_t value) { + try { + auto value_ptr = std::make_unique(value); + lbug_prepared_statement_bind_cpp_value(prepared_statement, param_name, + std::move(value_ptr)); + return LbugSuccess; + } catch (Exception& e) { + return LbugError; + } +} + +lbug_state lbug_prepared_statement_bind_double(lbug_prepared_statement* prepared_statement, + const char* param_name, double value) { + try { + auto value_ptr = std::make_unique(value); + lbug_prepared_statement_bind_cpp_value(prepared_statement, param_name, + std::move(value_ptr)); + return LbugSuccess; + } catch (Exception& e) { + return LbugError; + } +} + +lbug_state lbug_prepared_statement_bind_float(lbug_prepared_statement* prepared_statement, + const char* param_name, float value) { + try { + auto value_ptr = std::make_unique(value); + lbug_prepared_statement_bind_cpp_value(prepared_statement, param_name, + std::move(value_ptr)); + return LbugSuccess; + } catch (Exception& e) { + return LbugError; + } +} + +lbug_state lbug_prepared_statement_bind_date(lbug_prepared_statement* prepared_statement, + const char* param_name, lbug_date_t value) { + try { + auto value_ptr = std::make_unique(date_t(value.days)); + lbug_prepared_statement_bind_cpp_value(prepared_statement, param_name, + std::move(value_ptr)); + return LbugSuccess; + } catch (Exception& e) { + return LbugError; + } +} + +lbug_state lbug_prepared_statement_bind_timestamp_ns(lbug_prepared_statement* prepared_statement, + const char* param_name, lbug_timestamp_ns_t value) { + try { + auto value_ptr = std::make_unique(timestamp_ns_t(value.value)); + lbug_prepared_statement_bind_cpp_value(prepared_statement, param_name, + std::move(value_ptr)); + return LbugSuccess; + } catch (Exception& e) { + return LbugError; + } +} + +lbug_state lbug_prepared_statement_bind_timestamp_ms(lbug_prepared_statement* prepared_statement, + const char* param_name, lbug_timestamp_ms_t value) { + try { + auto value_ptr = std::make_unique(timestamp_ms_t(value.value)); + lbug_prepared_statement_bind_cpp_value(prepared_statement, param_name, + std::move(value_ptr)); + return LbugSuccess; + } catch (Exception& e) { + return LbugError; + } +} + +lbug_state lbug_prepared_statement_bind_timestamp_sec(lbug_prepared_statement* prepared_statement, + const char* param_name, lbug_timestamp_sec_t value) { + try { + auto value_ptr = std::make_unique(timestamp_sec_t(value.value)); + lbug_prepared_statement_bind_cpp_value(prepared_statement, param_name, + std::move(value_ptr)); + return LbugSuccess; + } catch (Exception& e) { + return LbugError; + } +} + +lbug_state lbug_prepared_statement_bind_timestamp_tz(lbug_prepared_statement* prepared_statement, + const char* param_name, lbug_timestamp_tz_t value) { + try { + auto value_ptr = std::make_unique(timestamp_tz_t(value.value)); + lbug_prepared_statement_bind_cpp_value(prepared_statement, param_name, + std::move(value_ptr)); + return LbugSuccess; + } catch (Exception& e) { + return LbugError; + } +} + +lbug_state lbug_prepared_statement_bind_timestamp(lbug_prepared_statement* prepared_statement, + const char* 
param_name, lbug_timestamp_t value) { + try { + auto value_ptr = std::make_unique(timestamp_t(value.value)); + lbug_prepared_statement_bind_cpp_value(prepared_statement, param_name, + std::move(value_ptr)); + return LbugSuccess; + } catch (Exception& e) { + return LbugError; + } +} + +lbug_state lbug_prepared_statement_bind_interval(lbug_prepared_statement* prepared_statement, + const char* param_name, lbug_interval_t value) { + try { + auto value_ptr = + std::make_unique(interval_t(value.months, value.days, value.micros)); + lbug_prepared_statement_bind_cpp_value(prepared_statement, param_name, + std::move(value_ptr)); + return LbugSuccess; + } catch (Exception& e) { + return LbugError; + } +} + +lbug_state lbug_prepared_statement_bind_string(lbug_prepared_statement* prepared_statement, + const char* param_name, const char* value) { + try { + auto value_ptr = std::make_unique(std::string(value)); + lbug_prepared_statement_bind_cpp_value(prepared_statement, param_name, + std::move(value_ptr)); + return LbugSuccess; + } catch (Exception& e) { + return LbugError; + } +} + +lbug_state lbug_prepared_statement_bind_value(lbug_prepared_statement* prepared_statement, + const char* param_name, lbug_value* value) { + try { + auto value_ptr = std::make_unique(*static_cast(value->_value)); + lbug_prepared_statement_bind_cpp_value(prepared_statement, param_name, + std::move(value_ptr)); + return LbugSuccess; + } catch (Exception& e) { + return LbugError; + } +} diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/c_api/query_result.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/c_api/query_result.cpp new file mode 100644 index 0000000000..bb3f77aae7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/c_api/query_result.cpp @@ -0,0 +1,136 @@ +#include "main/query_result.h" + +#include "c_api/helpers.h" +#include "c_api/lbug.h" + +using namespace lbug::main; +using namespace lbug::common; +using namespace lbug::processor; + +void lbug_query_result_destroy(lbug_query_result* query_result) { + if (query_result == nullptr) { + return; + } + if (query_result->_query_result != nullptr) { + if (!query_result->_is_owned_by_cpp) { + delete static_cast(query_result->_query_result); + } + } +} + +bool lbug_query_result_is_success(lbug_query_result* query_result) { + return static_cast(query_result->_query_result)->isSuccess(); +} + +char* lbug_query_result_get_error_message(lbug_query_result* query_result) { + auto error_message = static_cast(query_result->_query_result)->getErrorMessage(); + if (error_message.empty()) { + return nullptr; + } + return convertToOwnedCString(error_message); +} + +uint64_t lbug_query_result_get_num_columns(lbug_query_result* query_result) { + return static_cast(query_result->_query_result)->getNumColumns(); +} + +lbug_state lbug_query_result_get_column_name(lbug_query_result* query_result, uint64_t index, + char** out_column_name) { + auto column_names = static_cast(query_result->_query_result)->getColumnNames(); + if (index >= column_names.size()) { + return LbugError; + } + *out_column_name = convertToOwnedCString(column_names[index]); + return LbugSuccess; +} + +lbug_state lbug_query_result_get_column_data_type(lbug_query_result* query_result, uint64_t index, + lbug_logical_type* out_column_data_type) { + auto column_data_types = + static_cast(query_result->_query_result)->getColumnDataTypes(); + if (index >= column_data_types.size()) { + return LbugError; + } + const auto& column_data_type = column_data_types[index]; + out_column_data_type->_data_type = new 
LogicalType(column_data_type.copy()); + return LbugSuccess; +} + +uint64_t lbug_query_result_get_num_tuples(lbug_query_result* query_result) { + return static_cast(query_result->_query_result)->getNumTuples(); +} + +lbug_state lbug_query_result_get_query_summary(lbug_query_result* query_result, + lbug_query_summary* out_query_summary) { + if (out_query_summary == nullptr) { + return LbugError; + } + auto query_summary = static_cast(query_result->_query_result)->getQuerySummary(); + out_query_summary->_query_summary = query_summary; + return LbugSuccess; +} + +bool lbug_query_result_has_next(lbug_query_result* query_result) { + return static_cast(query_result->_query_result)->hasNext(); +} + +bool lbug_query_result_has_next_query_result(lbug_query_result* query_result) { + return static_cast(query_result->_query_result)->hasNextQueryResult(); +} + +lbug_state lbug_query_result_get_next_query_result(lbug_query_result* query_result, + lbug_query_result* out_query_result) { + if (!lbug_query_result_has_next_query_result(query_result)) { + return LbugError; + } + auto next_query_result = + static_cast(query_result->_query_result)->getNextQueryResult(); + if (next_query_result == nullptr) { + return LbugError; + } + out_query_result->_query_result = next_query_result; + out_query_result->_is_owned_by_cpp = true; + return LbugSuccess; +} + +lbug_state lbug_query_result_get_next(lbug_query_result* query_result, + lbug_flat_tuple* out_flat_tuple) { + try { + auto flat_tuple = static_cast(query_result->_query_result)->getNext(); + out_flat_tuple->_flat_tuple = flat_tuple.get(); + out_flat_tuple->_is_owned_by_cpp = true; + return LbugSuccess; + } catch (Exception& e) { + return LbugError; + } +} + +char* lbug_query_result_to_string(lbug_query_result* query_result) { + std::string result_string = static_cast(query_result->_query_result)->toString(); + return convertToOwnedCString(result_string); +} + +void lbug_query_result_reset_iterator(lbug_query_result* query_result) { + static_cast(query_result->_query_result)->resetIterator(); +} + +lbug_state lbug_query_result_get_arrow_schema(lbug_query_result* query_result, + ArrowSchema* out_schema) { + try { + *out_schema = *static_cast(query_result->_query_result)->getArrowSchema(); + return LbugSuccess; + } catch (Exception& e) { + return LbugError; + } +} + +lbug_state lbug_query_result_get_next_arrow_chunk(lbug_query_result* query_result, + int64_t chunk_size, ArrowArray* out_arrow_array) { + try { + *out_arrow_array = + *static_cast(query_result->_query_result)->getNextArrowChunk(chunk_size); + return LbugSuccess; + } catch (Exception& e) { + return LbugError; + } +} diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/c_api/query_summary.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/c_api/query_summary.cpp new file mode 100644 index 0000000000..7f75090d84 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/c_api/query_summary.cpp @@ -0,0 +1,23 @@ +#include "main/query_summary.h" + +#include + +#include "c_api/lbug.h" + +using namespace lbug::main; + +void lbug_query_summary_destroy(lbug_query_summary* query_summary) { + if (query_summary == nullptr) { + return; + } + // The query summary is owned by the query result, so it should not be deleted here. 
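+ // Only the wrapper's pointer is cleared below; the underlying QuerySummary is freed
+ // together with the lbug_query_result that produced it.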
+ query_summary->_query_summary = nullptr; +} + +double lbug_query_summary_get_compiling_time(lbug_query_summary* query_summary) { + return static_cast(query_summary->_query_summary)->getCompilingTime(); +} + +double lbug_query_summary_get_execution_time(lbug_query_summary* query_summary) { + return static_cast(query_summary->_query_summary)->getExecutionTime(); +} diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/c_api/value.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/c_api/value.cpp new file mode 100644 index 0000000000..24335fe7c2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/c_api/value.cpp @@ -0,0 +1,1192 @@ +#include "common/types/value/value.h" + +#include "c_api/helpers.h" +#include "c_api/lbug.h" +#include "common/constants.h" +#include "common/types/types.h" +#include "common/types/value/nested.h" +#include "common/types/value/node.h" +#include "common/types/value/recursive_rel.h" +#include "common/types/value/rel.h" +#include "function/cast/functions/cast_from_string_functions.h" + +using namespace lbug::common; + +lbug_value* lbug_value_create_null() { + auto* c_value = (lbug_value*)calloc(1, sizeof(lbug_value)); + c_value->_value = new Value(Value::createNullValue()); + return c_value; +} + +lbug_value* lbug_value_create_null_with_data_type(lbug_logical_type* data_type) { + auto* c_value = (lbug_value*)calloc(1, sizeof(lbug_value)); + c_value->_value = + new Value(Value::createNullValue(*static_cast(data_type->_data_type))); + return c_value; +} + +bool lbug_value_is_null(lbug_value* value) { + return static_cast(value->_value)->isNull(); +} + +void lbug_value_set_null(lbug_value* value, bool is_null) { + static_cast(value->_value)->setNull(is_null); +} + +lbug_value* lbug_value_create_default(lbug_logical_type* data_type) { + auto* c_value = (lbug_value*)calloc(1, sizeof(lbug_value)); + c_value->_value = + new Value(Value::createDefaultValue(*static_cast(data_type->_data_type))); + return c_value; +} + +lbug_value* lbug_value_create_bool(bool val_) { + auto* c_value = (lbug_value*)calloc(1, sizeof(lbug_value)); + c_value->_value = new Value(val_); + return c_value; +} + +lbug_value* lbug_value_create_int8(int8_t val_) { + auto* c_value = (lbug_value*)calloc(1, sizeof(lbug_value)); + c_value->_value = new Value(val_); + return c_value; +} + +lbug_value* lbug_value_create_int16(int16_t val_) { + auto* c_value = (lbug_value*)calloc(1, sizeof(lbug_value)); + c_value->_value = new Value(val_); + return c_value; +} + +lbug_value* lbug_value_create_int32(int32_t val_) { + auto* c_value = (lbug_value*)calloc(1, sizeof(lbug_value)); + c_value->_value = new Value(val_); + return c_value; +} + +lbug_value* lbug_value_create_int64(int64_t val_) { + auto* c_value = (lbug_value*)calloc(1, sizeof(lbug_value)); + c_value->_value = new Value(val_); + return c_value; +} + +lbug_value* lbug_value_create_uint8(uint8_t val_) { + auto* c_value = (lbug_value*)calloc(1, sizeof(lbug_value)); + c_value->_value = new Value(val_); + return c_value; +} + +lbug_value* lbug_value_create_uint16(uint16_t val_) { + auto* c_value = (lbug_value*)calloc(1, sizeof(lbug_value)); + c_value->_value = new Value(val_); + return c_value; +} + +lbug_value* lbug_value_create_uint32(uint32_t val_) { + auto* c_value = (lbug_value*)calloc(1, sizeof(lbug_value)); + c_value->_value = new Value(val_); + return c_value; +} + +lbug_value* lbug_value_create_uint64(uint64_t val_) { + auto* c_value = (lbug_value*)calloc(1, sizeof(lbug_value)); + c_value->_value = new Value(val_); + return c_value; +} + +lbug_value* 
lbug_value_create_int128(lbug_int128_t val_) { + auto* c_value = (lbug_value*)calloc(1, sizeof(lbug_value)); + int128_t int128(val_.low, val_.high); + c_value->_value = new Value(int128); + return c_value; +} + +lbug_value* lbug_value_create_float(float val_) { + auto* c_value = (lbug_value*)calloc(1, sizeof(lbug_value)); + c_value->_value = new Value(val_); + return c_value; +} + +lbug_value* lbug_value_create_double(double val_) { + auto* c_value = (lbug_value*)calloc(1, sizeof(lbug_value)); + c_value->_value = new Value(val_); + return c_value; +} + +lbug_value* lbug_value_create_internal_id(lbug_internal_id_t val_) { + auto* c_value = (lbug_value*)calloc(1, sizeof(lbug_value)); + internalID_t id(val_.offset, val_.table_id); + c_value->_value = new Value(id); + return c_value; +} + +lbug_value* lbug_value_create_date(lbug_date_t val_) { + auto* c_value = (lbug_value*)calloc(1, sizeof(lbug_value)); + auto date = date_t(val_.days); + c_value->_value = new Value(date); + return c_value; +} + +lbug_value* lbug_value_create_timestamp_ns(lbug_timestamp_ns_t val_) { + auto* c_value = (lbug_value*)calloc(1, sizeof(lbug_value)); + auto timestamp_ns = timestamp_ns_t(val_.value); + c_value->_value = new Value(timestamp_ns); + return c_value; +} + +lbug_value* lbug_value_create_timestamp_ms(lbug_timestamp_ms_t val_) { + auto* c_value = (lbug_value*)calloc(1, sizeof(lbug_value)); + auto timestamp_ms = timestamp_ms_t(val_.value); + c_value->_value = new Value(timestamp_ms); + return c_value; +} + +lbug_value* lbug_value_create_timestamp_sec(lbug_timestamp_sec_t val_) { + auto* c_value = (lbug_value*)calloc(1, sizeof(lbug_value)); + auto timestamp_sec = timestamp_sec_t(val_.value); + c_value->_value = new Value(timestamp_sec); + return c_value; +} + +lbug_value* lbug_value_create_timestamp_tz(lbug_timestamp_tz_t val_) { + auto* c_value = (lbug_value*)calloc(1, sizeof(lbug_value)); + auto timestamp_tz = timestamp_tz_t(val_.value); + c_value->_value = new Value(timestamp_tz); + return c_value; +} + +lbug_value* lbug_value_create_timestamp(lbug_timestamp_t val_) { + auto* c_value = (lbug_value*)calloc(1, sizeof(lbug_value)); + auto timestamp = timestamp_t(val_.value); + c_value->_value = new Value(timestamp); + return c_value; +} + +lbug_value* lbug_value_create_interval(lbug_interval_t val_) { + auto* c_value = (lbug_value*)calloc(1, sizeof(lbug_value)); + auto interval = interval_t(val_.months, val_.days, val_.micros); + c_value->_value = new Value(interval); + return c_value; +} + +lbug_value* lbug_value_create_string(const char* val_) { + auto* c_value = (lbug_value*)calloc(1, sizeof(lbug_value)); + c_value->_value = new Value(val_); + return c_value; +} + +lbug_state lbug_value_create_list(uint64_t num_elements, lbug_value** elements, + lbug_value** out_value) { + if (num_elements == 0) { + return LbugError; + } + auto* c_value = (lbug_value*)calloc(1, sizeof(lbug_value)); + std::vector> children; + + auto first_element = static_cast(elements[0]->_value); + auto type = first_element->getDataType().copy(); + + for (uint64_t i = 0; i < num_elements; ++i) { + auto child = static_cast(elements[i]->_value); + if (child->getDataType() != type) { + free(c_value); + return LbugError; + } + // Copy the value to the list value to transfer ownership to the C++ side. 
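+ // (Value::copy() produces a deep copy, so the caller's input values stay valid and
+ // must still be released with lbug_value_destroy.)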
+ children.push_back(child->copy()); + } + auto list_type = LogicalType::LIST(first_element->getDataType().copy()); + c_value->_value = new Value(list_type.copy(), std::move(children)); + c_value->_is_owned_by_cpp = false; + *out_value = c_value; + return LbugSuccess; +} + +lbug_state lbug_value_create_struct(uint64_t num_fields, const char** field_names, + lbug_value** field_values, lbug_value** out_value) { + if (num_fields == 0) { + return LbugError; + } + auto* c_value = (lbug_value*)calloc(1, sizeof(lbug_value)); + std::vector> children; + auto struct_fields = std::vector{}; + for (uint64_t i = 0; i < num_fields; ++i) { + auto field_name = std::string(field_names[i]); + auto field_value = static_cast(field_values[i]->_value); + auto field_type = field_value->getDataType().copy(); + struct_fields.emplace_back(std::move(field_name), std::move(field_type)); + children.push_back(field_value->copy()); + } + auto struct_type = LogicalType::STRUCT(std::move(struct_fields)); + c_value->_value = new Value(std::move(struct_type), std::move(children)); + c_value->_is_owned_by_cpp = false; + *out_value = c_value; + return LbugSuccess; +} + +lbug_state lbug_value_create_map(uint64_t num_fields, lbug_value** keys, lbug_value** values, + lbug_value** out_value) { + if (num_fields == 0) { + return LbugError; + } + auto* c_value = (lbug_value*)calloc(1, sizeof(lbug_value)); + std::vector> children; + + auto first_key = static_cast(keys[0]->_value); + auto first_value = static_cast(values[0]->_value); + auto key_type = first_key->getDataType().copy(); + auto value_type = first_value->getDataType().copy(); + + for (uint64_t i = 0; i < num_fields; ++i) { + auto key = static_cast(keys[i]->_value); + auto value = static_cast(values[i]->_value); + if (key->getDataType() != key_type || value->getDataType() != value_type) { + free(c_value); + return LbugError; + } + std::vector struct_fields; + struct_fields.emplace_back(InternalKeyword::MAP_KEY, key_type.copy()); + struct_fields.emplace_back(InternalKeyword::MAP_VALUE, value_type.copy()); + std::vector> struct_values; + struct_values.push_back(key->copy()); + struct_values.push_back(value->copy()); + auto struct_type = LogicalType::STRUCT(std::move(struct_fields)); + auto struct_value = new Value(std::move(struct_type), std::move(struct_values)); + children.push_back(std::unique_ptr(struct_value)); + } + auto map_type = LogicalType::MAP(key_type.copy(), value_type.copy()); + c_value->_value = new Value(map_type.copy(), std::move(children)); + c_value->_is_owned_by_cpp = false; + *out_value = c_value; + return LbugSuccess; +} + +lbug_value* lbug_value_clone(lbug_value* value) { + auto* c_value = (lbug_value*)calloc(1, sizeof(lbug_value)); + c_value->_value = new Value(*static_cast(value->_value)); + return c_value; +} + +void lbug_value_copy(lbug_value* value, lbug_value* other) { + static_cast(value->_value)->copyValueFrom(*static_cast(other->_value)); +} + +void lbug_value_destroy(lbug_value* value) { + if (value == nullptr) { + return; + } + if (!value->_is_owned_by_cpp) { + if (value->_value != nullptr) { + delete static_cast(value->_value); + } + free(value); + } +} + +lbug_state lbug_value_get_list_size(lbug_value* value, uint64_t* out_result) { + if (static_cast(value->_value)->getDataType().getLogicalTypeID() != + LogicalTypeID::LIST) { + return LbugError; + } + *out_result = NestedVal::getChildrenSize(static_cast(value->_value)); + return LbugSuccess; +} + +lbug_state lbug_value_get_list_element(lbug_value* value, uint64_t index, lbug_value* 
out_value) { + auto physical_type_id = static_cast(value->_value)->getDataType().getPhysicalType(); + if (physical_type_id != PhysicalTypeID::ARRAY && physical_type_id != PhysicalTypeID::STRUCT && + physical_type_id != PhysicalTypeID::LIST) { + return LbugError; + } + auto listValue = static_cast(value->_value); + if (index >= NestedVal::getChildrenSize(listValue)) { + return LbugError; + } + try { + auto val = NestedVal::getChildVal(listValue, index); + out_value->_value = val; + out_value->_is_owned_by_cpp = true; + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +lbug_state lbug_value_get_struct_num_fields(lbug_value* value, uint64_t* out_result) { + auto physical_type_id = static_cast(value->_value)->getDataType().getPhysicalType(); + if (physical_type_id != PhysicalTypeID::STRUCT) { + return LbugError; + } + auto val = static_cast(value->_value); + const auto& data_type = val->getDataType(); + try { + *out_result = StructType::getNumFields(data_type); + return LbugSuccess; + } catch (Exception& e) { + return LbugError; + } +} + +lbug_state lbug_value_get_struct_field_name(lbug_value* value, uint64_t index, char** out_result) { + auto physical_type_id = static_cast(value->_value)->getDataType().getPhysicalType(); + if (physical_type_id != PhysicalTypeID::STRUCT) { + return LbugError; + } + auto val = static_cast(value->_value); + const auto& data_type = val->getDataType(); + if (index >= StructType::getNumFields(data_type)) { + return LbugError; + } + std::string struct_field_name = StructType::getFields(data_type)[index].getName(); + if (struct_field_name.empty()) { + return LbugError; + } + *out_result = convertToOwnedCString(struct_field_name); + return LbugSuccess; +} + +lbug_state lbug_value_get_struct_field_value(lbug_value* value, uint64_t index, + lbug_value* out_value) { + return lbug_value_get_list_element(value, index, out_value); +} + +lbug_state lbug_value_get_map_size(lbug_value* value, uint64_t* out_result) { + auto logical_type_id = static_cast(value->_value)->getDataType().getLogicalTypeID(); + if (logical_type_id != LogicalTypeID::MAP) { + return LbugError; + } + auto listValue = static_cast(value->_value); + *out_result = NestedVal::getChildrenSize(listValue); + return LbugSuccess; +} + +lbug_state lbug_value_get_map_key(lbug_value* value, uint64_t index, lbug_value* out_key) { + lbug_value map_entry; + if (lbug_value_get_list_element(value, index, &map_entry) == LbugError) { + return LbugError; + } + return lbug_value_get_struct_field_value(&map_entry, 0, out_key); +} + +lbug_state lbug_value_get_map_value(lbug_value* value, uint64_t index, lbug_value* out_value) { + lbug_value map_entry; + if (lbug_value_get_list_element(value, index, &map_entry) == LbugError) { + return LbugError; + } + return lbug_value_get_struct_field_value(&map_entry, 1, out_value); +} + +lbug_state lbug_value_get_recursive_rel_node_list(lbug_value* value, lbug_value* out_value) { + auto logical_type_id = static_cast(value->_value)->getDataType().getLogicalTypeID(); + if (logical_type_id != LogicalTypeID::RECURSIVE_REL) { + return LbugError; + } + out_value->_is_owned_by_cpp = true; + try { + out_value->_value = RecursiveRelVal::getNodes(static_cast(value->_value)); + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +lbug_state lbug_value_get_recursive_rel_rel_list(lbug_value* value, lbug_value* out_value) { + auto logical_type_id = static_cast(value->_value)->getDataType().getLogicalTypeID(); + if (logical_type_id != 
LogicalTypeID::RECURSIVE_REL) { + return LbugError; + } + out_value->_is_owned_by_cpp = true; + try { + out_value->_value = RecursiveRelVal::getRels(static_cast(value->_value)); + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +void lbug_value_get_data_type(lbug_value* value, lbug_logical_type* out_data_type) { + out_data_type->_data_type = + new LogicalType(static_cast(value->_value)->getDataType().copy()); +} + +lbug_state lbug_value_get_bool(lbug_value* value, bool* out_result) { + auto logical_type_id = static_cast(value->_value)->getDataType().getLogicalTypeID(); + if (logical_type_id != LogicalTypeID::BOOL) { + return LbugError; + } + try { + *out_result = static_cast(value->_value)->getValue(); + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +lbug_state lbug_value_get_int8(lbug_value* value, int8_t* out_result) { + auto logical_type_id = static_cast(value->_value)->getDataType().getLogicalTypeID(); + if (logical_type_id != LogicalTypeID::INT8) { + return LbugError; + } + try { + *out_result = static_cast(value->_value)->getValue(); + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +lbug_state lbug_value_get_int16(lbug_value* value, int16_t* out_result) { + auto logical_type_id = static_cast(value->_value)->getDataType().getLogicalTypeID(); + if (logical_type_id != LogicalTypeID::INT16) { + return LbugError; + } + try { + *out_result = static_cast(value->_value)->getValue(); + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +lbug_state lbug_value_get_int32(lbug_value* value, int32_t* out_result) { + auto logical_type_id = static_cast(value->_value)->getDataType().getLogicalTypeID(); + if (logical_type_id != LogicalTypeID::INT32) { + return LbugError; + } + try { + *out_result = static_cast(value->_value)->getValue(); + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +lbug_state lbug_value_get_int64(lbug_value* value, int64_t* out_result) { + auto logical_type_id = static_cast(value->_value)->getDataType().getLogicalTypeID(); + if (logical_type_id != LogicalTypeID::INT64) { + return LbugError; + } + try { + *out_result = static_cast(value->_value)->getValue(); + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +lbug_state lbug_value_get_uint8(lbug_value* value, uint8_t* out_result) { + auto logical_type_id = static_cast(value->_value)->getDataType().getLogicalTypeID(); + if (logical_type_id != LogicalTypeID::UINT8) { + return LbugError; + } + try { + *out_result = static_cast(value->_value)->getValue(); + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +lbug_state lbug_value_get_uint16(lbug_value* value, uint16_t* out_result) { + auto logical_type_id = static_cast(value->_value)->getDataType().getLogicalTypeID(); + if (logical_type_id != LogicalTypeID::UINT16) { + return LbugError; + } + try { + *out_result = static_cast(value->_value)->getValue(); + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +lbug_state lbug_value_get_uint32(lbug_value* value, uint32_t* out_result) { + auto logical_type_id = static_cast(value->_value)->getDataType().getLogicalTypeID(); + if (logical_type_id != LogicalTypeID::UINT32) { + return LbugError; + } + try { + *out_result = static_cast(value->_value)->getValue(); + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +lbug_state lbug_value_get_uint64(lbug_value* value, uint64_t* out_result) { + auto 
logical_type_id = static_cast(value->_value)->getDataType().getLogicalTypeID(); + if (logical_type_id != LogicalTypeID::UINT64) { + return LbugError; + } + try { + *out_result = static_cast(value->_value)->getValue(); + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +lbug_state lbug_value_get_int128(lbug_value* value, lbug_int128_t* out_result) { + auto logical_type_id = static_cast(value->_value)->getDataType().getLogicalTypeID(); + if (logical_type_id != LogicalTypeID::INT128) { + return LbugError; + } + try { + auto int128_val = static_cast(value->_value)->getValue(); + out_result->low = int128_val.low; + out_result->high = int128_val.high; + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +lbug_state lbug_int128_t_from_string(const char* str, lbug_int128_t* out_result) { + int128_t int128_val = 0; + try { + lbug::function::CastString::operation(ku_string_t{str, strlen(str)}, int128_val); + out_result->low = int128_val.low; + out_result->high = int128_val.high; + } catch (ConversionException& e) { + return LbugError; + } + return LbugSuccess; +} + +lbug_state lbug_int128_t_to_string(lbug_int128_t int128_val, char** out_result) { + int128_t c_int128 = 0; + c_int128.low = int128_val.low; + c_int128.high = int128_val.high; + try { + *out_result = convertToOwnedCString(TypeUtils::toString(c_int128)); + } catch (ConversionException& e) { + return LbugError; + } + return LbugSuccess; +} +// TODO: bind all int128_t supported functions + +lbug_state lbug_value_get_float(lbug_value* value, float* out_result) { + auto logical_type_id = static_cast(value->_value)->getDataType().getLogicalTypeID(); + if (logical_type_id != LogicalTypeID::FLOAT) { + return LbugError; + } + try { + *out_result = static_cast(value->_value)->getValue(); + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +lbug_state lbug_value_get_double(lbug_value* value, double* out_result) { + auto logical_type_id = static_cast(value->_value)->getDataType().getLogicalTypeID(); + if (logical_type_id != LogicalTypeID::DOUBLE) { + return LbugError; + } + try { + *out_result = static_cast(value->_value)->getValue(); + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +lbug_state lbug_value_get_internal_id(lbug_value* value, lbug_internal_id_t* out_result) { + auto logical_type_id = static_cast(value->_value)->getDataType().getLogicalTypeID(); + if (logical_type_id != LogicalTypeID::INTERNAL_ID) { + return LbugError; + } + try { + auto id = static_cast(value->_value)->getValue(); + out_result->offset = id.offset; + out_result->table_id = id.tableID; + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +lbug_state lbug_value_get_date(lbug_value* value, lbug_date_t* out_result) { + auto logical_type_id = static_cast(value->_value)->getDataType().getLogicalTypeID(); + if (logical_type_id != LogicalTypeID::DATE) { + return LbugError; + } + try { + auto date_val = static_cast(value->_value)->getValue(); + out_result->days = date_val.days; + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +lbug_state lbug_value_get_timestamp(lbug_value* value, lbug_timestamp_t* out_result) { + auto logical_type_id = static_cast(value->_value)->getDataType().getLogicalTypeID(); + if (logical_type_id != LogicalTypeID::TIMESTAMP) { + return LbugError; + } + try { + auto timestamp_val = static_cast(value->_value)->getValue(); + out_result->value = timestamp_val.value; + } catch (Exception& e) { + return 
LbugError; + } + return LbugSuccess; +} + +lbug_state lbug_value_get_timestamp_ns(lbug_value* value, lbug_timestamp_ns_t* out_result) { + auto logical_type_id = static_cast(value->_value)->getDataType().getLogicalTypeID(); + if (logical_type_id != LogicalTypeID::TIMESTAMP_NS) { + return LbugError; + } + try { + auto timestamp_val = static_cast(value->_value)->getValue(); + out_result->value = timestamp_val.value; + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +lbug_state lbug_value_get_timestamp_ms(lbug_value* value, lbug_timestamp_ms_t* out_result) { + auto logical_type_id = static_cast(value->_value)->getDataType().getLogicalTypeID(); + if (logical_type_id != LogicalTypeID::TIMESTAMP_MS) { + return LbugError; + } + try { + auto timestamp_val = static_cast(value->_value)->getValue(); + out_result->value = timestamp_val.value; + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +lbug_state lbug_value_get_timestamp_sec(lbug_value* value, lbug_timestamp_sec_t* out_result) { + auto logical_type_id = static_cast(value->_value)->getDataType().getLogicalTypeID(); + if (logical_type_id != LogicalTypeID::TIMESTAMP_SEC) { + return LbugError; + } + try { + auto timestamp_val = static_cast(value->_value)->getValue(); + out_result->value = timestamp_val.value; + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +lbug_state lbug_value_get_timestamp_tz(lbug_value* value, lbug_timestamp_tz_t* out_result) { + auto logical_type_id = static_cast(value->_value)->getDataType().getLogicalTypeID(); + if (logical_type_id != LogicalTypeID::TIMESTAMP_TZ) { + return LbugError; + } + try { + auto timestamp_val = static_cast(value->_value)->getValue(); + out_result->value = timestamp_val.value; + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +lbug_state lbug_value_get_decimal_as_string(lbug_value* value, char** out_result) { + auto decimal_val = static_cast(value->_value); + auto logical_type_id = decimal_val->getDataType().getLogicalTypeID(); + if (logical_type_id != LogicalTypeID::DECIMAL) { + return LbugError; + } + + *out_result = convertToOwnedCString(decimal_val->toString()); + return LbugSuccess; +} + +lbug_state lbug_value_get_interval(lbug_value* value, lbug_interval_t* out_result) { + auto logical_type_id = static_cast(value->_value)->getDataType().getLogicalTypeID(); + if (logical_type_id != LogicalTypeID::INTERVAL) { + return LbugError; + } + try { + auto interval_val = static_cast(value->_value)->getValue(); + out_result->months = interval_val.months; + out_result->days = interval_val.days; + out_result->micros = interval_val.micros; + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +lbug_state lbug_value_get_string(lbug_value* value, char** out_result) { + auto logical_type_id = static_cast(value->_value)->getDataType().getLogicalTypeID(); + if (logical_type_id != LogicalTypeID::STRING) { + return LbugError; + } + try { + *out_result = + convertToOwnedCString(static_cast(value->_value)->getValue()); + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +lbug_state lbug_value_get_blob(lbug_value* value, uint8_t** out_result, uint64_t* out_length) { + auto logical_type_id = static_cast(value->_value)->getDataType().getLogicalTypeID(); + if (logical_type_id != LogicalTypeID::BLOB) { + return LbugError; + } + try { + auto blob = static_cast(value->_value)->getValue(); + *out_length = blob.size(); + auto* buffer = (uint8_t*)malloc(sizeof(uint8_t) 
* blob.size()); + memcpy(buffer, blob.data(), blob.size()); + *out_result = buffer; + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +lbug_state lbug_value_get_uuid(lbug_value* value, char** out_result) { + auto logical_type_id = static_cast(value->_value)->getDataType().getLogicalTypeID(); + if (logical_type_id != LogicalTypeID::UUID) { + return LbugError; + } + try { + *out_result = + convertToOwnedCString(static_cast(value->_value)->getValue()); + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +char* lbug_value_to_string(lbug_value* value) { + return convertToOwnedCString(static_cast(value->_value)->toString()); +} + +lbug_state lbug_node_val_get_id_val(lbug_value* node_val, lbug_value* out_value) { + auto logical_type_id = static_cast(node_val->_value)->getDataType().getLogicalTypeID(); + if (logical_type_id != LogicalTypeID::NODE) { + return LbugError; + } + try { + auto id_val = NodeVal::getNodeIDVal(static_cast(node_val->_value)); + out_value->_value = id_val; + out_value->_is_owned_by_cpp = true; + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +lbug_state lbug_node_val_get_label_val(lbug_value* node_val, lbug_value* out_value) { + auto logical_type_id = static_cast(node_val->_value)->getDataType().getLogicalTypeID(); + if (logical_type_id != LogicalTypeID::NODE) { + return LbugError; + } + try { + auto label_val = NodeVal::getLabelVal(static_cast(node_val->_value)); + out_value->_value = label_val; + out_value->_is_owned_by_cpp = true; + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +lbug_state lbug_node_val_get_property_size(lbug_value* node_val, uint64_t* out_result) { + auto logical_type_id = static_cast(node_val->_value)->getDataType().getLogicalTypeID(); + if (logical_type_id != LogicalTypeID::NODE) { + return LbugError; + } + try { + *out_result = NodeVal::getNumProperties(static_cast(node_val->_value)); + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +lbug_state lbug_node_val_get_property_name_at(lbug_value* node_val, uint64_t index, + char** out_result) { + auto logical_type_id = static_cast(node_val->_value)->getDataType().getLogicalTypeID(); + if (logical_type_id != LogicalTypeID::NODE) { + return LbugError; + } + try { + std::string property_name = + NodeVal::getPropertyName(static_cast(node_val->_value), index); + if (property_name.empty()) { + return LbugError; + } + *out_result = convertToOwnedCString(property_name); + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +lbug_state lbug_node_val_get_property_value_at(lbug_value* node_val, uint64_t index, + lbug_value* out_value) { + auto logical_type_id = static_cast(node_val->_value)->getDataType().getLogicalTypeID(); + if (logical_type_id != LogicalTypeID::NODE) { + return LbugError; + } + try { + auto value = NodeVal::getPropertyVal(static_cast(node_val->_value), index); + out_value->_value = value; + out_value->_is_owned_by_cpp = true; + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +lbug_state lbug_node_val_to_string(lbug_value* node_val, char** out_result) { + auto logical_type_id = static_cast(node_val->_value)->getDataType().getLogicalTypeID(); + if (logical_type_id != LogicalTypeID::NODE) { + return LbugError; + } + try { + *out_result = + convertToOwnedCString(NodeVal::toString(static_cast(node_val->_value))); + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +lbug_state 
lbug_rel_val_get_id_val(lbug_value* rel_val, lbug_value* out_value) { + auto logical_type_id = static_cast(rel_val->_value)->getDataType().getLogicalTypeID(); + if (logical_type_id != LogicalTypeID::REL) { + return LbugError; + } + try { + auto id_val = RelVal::getIDVal(static_cast(rel_val->_value)); + out_value->_value = id_val; + out_value->_is_owned_by_cpp = true; + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +lbug_state lbug_rel_val_get_src_id_val(lbug_value* rel_val, lbug_value* out_value) { + auto logical_type_id = static_cast(rel_val->_value)->getDataType().getLogicalTypeID(); + if (logical_type_id != LogicalTypeID::REL) { + return LbugError; + } + try { + auto src_id_val = RelVal::getSrcNodeIDVal(static_cast(rel_val->_value)); + out_value->_value = src_id_val; + out_value->_is_owned_by_cpp = true; + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +lbug_state lbug_rel_val_get_dst_id_val(lbug_value* rel_val, lbug_value* out_value) { + auto logical_type_id = static_cast(rel_val->_value)->getDataType().getLogicalTypeID(); + if (logical_type_id != LogicalTypeID::REL) { + return LbugError; + } + try { + auto dst_id_val = RelVal::getDstNodeIDVal(static_cast(rel_val->_value)); + out_value->_value = dst_id_val; + out_value->_is_owned_by_cpp = true; + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +lbug_state lbug_rel_val_get_label_val(lbug_value* rel_val, lbug_value* out_value) { + auto logical_type_id = static_cast(rel_val->_value)->getDataType().getLogicalTypeID(); + if (logical_type_id != LogicalTypeID::REL) { + return LbugError; + } + try { + auto label_val = RelVal::getLabelVal(static_cast(rel_val->_value)); + out_value->_value = label_val; + out_value->_is_owned_by_cpp = true; + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +lbug_state lbug_rel_val_get_property_size(lbug_value* rel_val, uint64_t* out_result) { + auto logical_type_id = static_cast(rel_val->_value)->getDataType().getLogicalTypeID(); + if (logical_type_id != LogicalTypeID::REL) { + return LbugError; + } + try { + *out_result = RelVal::getNumProperties(static_cast(rel_val->_value)); + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} +lbug_state lbug_rel_val_get_property_name_at(lbug_value* rel_val, uint64_t index, + char** out_result) { + auto logical_type_id = static_cast(rel_val->_value)->getDataType().getLogicalTypeID(); + if (logical_type_id != LogicalTypeID::REL) { + return LbugError; + } + try { + std::string property_name = + RelVal::getPropertyName(static_cast(rel_val->_value), index); + if (property_name.empty()) { + return LbugError; + } + *out_result = convertToOwnedCString(property_name); + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +lbug_state lbug_rel_val_get_property_value_at(lbug_value* rel_val, uint64_t index, + lbug_value* out_value) { + auto logical_type_id = static_cast(rel_val->_value)->getDataType().getLogicalTypeID(); + if (logical_type_id != LogicalTypeID::REL) { + return LbugError; + } + try { + auto value = RelVal::getPropertyVal(static_cast(rel_val->_value), index); + out_value->_value = value; + out_value->_is_owned_by_cpp = true; + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +lbug_state lbug_rel_val_to_string(lbug_value* rel_val, char** out_result) { + auto logical_type_id = static_cast(rel_val->_value)->getDataType().getLogicalTypeID(); + if (logical_type_id != LogicalTypeID::REL) { + 
return LbugError; + } + try { + *out_result = convertToOwnedCString(RelVal::toString(static_cast(rel_val->_value))); + } catch (Exception& e) { + return LbugError; + } + return LbugSuccess; +} + +void lbug_destroy_string(char* str) { + free(str); +} + +void lbug_destroy_blob(uint8_t* blob) { + free(blob); +} + +lbug_state lbug_timestamp_ns_to_tm(lbug_timestamp_ns_t timestamp, struct tm* out_result) { + time_t time = timestamp.value / 1000000000; +#ifdef _WIN32 + if (convertTimeToTm(time, out_result) != 0) { + return LbugError; + } +#else + if (gmtime_r(&time, out_result) == nullptr) { + return LbugError; + } +#endif + return LbugSuccess; +} + +lbug_state lbug_timestamp_ms_to_tm(lbug_timestamp_ms_t timestamp, struct tm* out_result) { + time_t time = timestamp.value / 1000; +#ifdef _WIN32 + if (convertTimeToTm(time, out_result) != 0) { + return LbugError; + } +#else + if (gmtime_r(&time, out_result) == nullptr) { + return LbugError; + } +#endif + return LbugSuccess; +} + +lbug_state lbug_timestamp_sec_to_tm(lbug_timestamp_sec_t timestamp, struct tm* out_result) { + time_t time = timestamp.value; +#ifdef _WIN32 + if (convertTimeToTm(time, out_result) != 0) { + return LbugError; + } +#else + if (gmtime_r(&time, out_result) == nullptr) { + return LbugError; + } +#endif + return LbugSuccess; +} + +lbug_state lbug_timestamp_tz_to_tm(lbug_timestamp_tz_t timestamp, struct tm* out_result) { + time_t time = timestamp.value / 1000000; +#ifdef _WIN32 + if (convertTimeToTm(time, out_result) != 0) { + return LbugError; + } +#else + if (gmtime_r(&time, out_result) == nullptr) { + return LbugError; + } +#endif + return LbugSuccess; +} + +lbug_state lbug_timestamp_to_tm(lbug_timestamp_t timestamp, struct tm* out_result) { + time_t time = timestamp.value / 1000000; +#ifdef _WIN32 + if (convertTimeToTm(time, out_result) != 0) { + return LbugError; + } +#else + if (gmtime_r(&time, out_result) == nullptr) { + return LbugError; + } +#endif + return LbugSuccess; +} + +lbug_state lbug_timestamp_ns_from_tm(struct tm tm, lbug_timestamp_ns_t* out_result) { +#ifdef _WIN32 + int64_t time = convertTmToTime(tm); +#else + int64_t time = timegm(&tm); +#endif + if (time == -1) { + return LbugError; + } + out_result->value = time * 1000000000; + return LbugSuccess; +} + +lbug_state lbug_timestamp_ms_from_tm(struct tm tm, lbug_timestamp_ms_t* out_result) { +#ifdef _WIN32 + int64_t time = convertTmToTime(tm); +#else + int64_t time = timegm(&tm); +#endif + if (time == -1) { + return LbugError; + } + out_result->value = time * 1000; + return LbugSuccess; +} + +lbug_state lbug_timestamp_sec_from_tm(struct tm tm, lbug_timestamp_sec_t* out_result) { +#ifdef _WIN32 + int64_t time = convertTmToTime(tm); +#else + int64_t time = timegm(&tm); +#endif + if (time == -1) { + return LbugError; + } + out_result->value = time; + return LbugSuccess; +} + +lbug_state lbug_timestamp_tz_from_tm(struct tm tm, lbug_timestamp_tz_t* out_result) { +#ifdef _WIN32 + int64_t time = convertTmToTime(tm); +#else + int64_t time = timegm(&tm); +#endif + if (time == -1) { + return LbugError; + } + out_result->value = time * 1000000; + return LbugSuccess; +} + +lbug_state lbug_timestamp_from_tm(struct tm tm, lbug_timestamp_t* out_result) { +#ifdef _WIN32 + int64_t time = convertTmToTime(tm); +#else + int64_t time = timegm(&tm); +#endif + if (time == -1) { + return LbugError; + } + out_result->value = time * 1000000; + return LbugSuccess; +} + +lbug_state lbug_date_to_tm(lbug_date_t date, struct tm* out_result) { + time_t time = date.days * 86400; +#ifdef _WIN32 
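+ // gmtime_r is not available on Windows, so use the FILETIME-based conversion helper
+ // from helpers.cpp instead.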
+ if (convertTimeToTm(time, out_result) != 0) { + return LbugError; + } +#else + if (gmtime_r(&time, out_result) == nullptr) { + return LbugError; + } +#endif + out_result->tm_hour = 0; + out_result->tm_min = 0; + out_result->tm_sec = 0; + return LbugSuccess; +} + +lbug_state lbug_date_from_tm(struct tm tm, lbug_date_t* out_result) { +#ifdef _WIN32 + int64_t time = convertTmToTime(tm); +#else + int64_t time = timegm(&tm); +#endif + if (time == -1) { + return LbugError; + } + out_result->days = time / 86400; + return LbugSuccess; +} + +lbug_state lbug_date_to_string(lbug_date_t date, char** out_result) { + tm tm{}; + if (lbug_date_to_tm(date, &tm) != LbugSuccess) { + return LbugError; + } + char buffer[80]; + if (strftime(buffer, 80, "%Y-%m-%d", &tm) == 0) { + return LbugError; + } + *out_result = convertToOwnedCString(buffer); + return LbugSuccess; +} + +lbug_state lbug_date_from_string(const char* str, lbug_date_t* out_result) { + try { + date_t date = Date::fromCString(str, strlen(str)); + out_result->days = date.days; + } catch (ConversionException& e) { + return LbugError; + } + return LbugSuccess; +} + +void lbug_interval_to_difftime(lbug_interval_t interval, double* out_result) { + auto micros = interval.micros + interval.months * Interval::MICROS_PER_MONTH + + interval.days * Interval::MICROS_PER_DAY; + double seconds = micros / 1000000.0; + *out_result = seconds; +} + +void lbug_interval_from_difftime(double difftime, lbug_interval_t* out_result) { + int64_t total_micros = static_cast(difftime * 1000000); + out_result->months = total_micros / Interval::MICROS_PER_MONTH; + total_micros -= out_result->months * Interval::MICROS_PER_MONTH; + out_result->days = total_micros / Interval::MICROS_PER_DAY; + total_micros -= out_result->days * Interval::MICROS_PER_DAY; + out_result->micros = total_micros; +} diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/c_api/version.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/c_api/version.cpp new file mode 100644 index 0000000000..dd1d2651e6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/c_api/version.cpp @@ -0,0 +1,12 @@ +#include "main/version.h" + +#include "c_api/helpers.h" +#include "c_api/lbug.h" + +char* lbug_get_version() { + return convertToOwnedCString(lbug::main::Version::getVersion()); +} + +uint64_t lbug_get_storage_version() { + return lbug::main::Version::getStorageVersion(); +} diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/CMakeLists.txt new file mode 100644 index 0000000000..ace8f72ba8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/CMakeLists.txt @@ -0,0 +1,11 @@ +add_subdirectory(catalog_entry) + +add_library(lbug_catalog + OBJECT + catalog.cpp + catalog_set.cpp + property_definition_collection.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog.cpp new file mode 100644 index 0000000000..edc82180f8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog.cpp @@ -0,0 +1,607 @@ +#include "catalog/catalog.h" + +#include "binder/ddl/bound_create_sequence_info.h" +#include "binder/ddl/bound_create_table_info.h" +#include "catalog/catalog_entry/function_catalog_entry.h" +#include "catalog/catalog_entry/index_catalog_entry.h" +#include "catalog/catalog_entry/node_table_catalog_entry.h" +#include "catalog/catalog_entry/rel_group_catalog_entry.h" +#include 
"catalog/catalog_entry/scalar_macro_catalog_entry.h" +#include "catalog/catalog_entry/sequence_catalog_entry.h" +#include "catalog/catalog_entry/type_catalog_entry.h" +#include "common/exception/catalog.h" +#include "common/exception/runtime.h" +#include "common/serializer/deserializer.h" +#include "common/serializer/serializer.h" +#include "common/string_format.h" +#include "extension/extension_manager.h" +#include "function/function_collection.h" +#include "main/client_context.h" +#include "transaction/transaction.h" + +using namespace lbug::binder; +using namespace lbug::common; +using namespace lbug::storage; +using namespace lbug::transaction; + +namespace lbug { +namespace catalog { + +Catalog::Catalog() : version{0} { + initCatalogSets(); + registerBuiltInFunctions(); +} + +Catalog* Catalog::Get(const main::ClientContext& context) { + if (context.getAttachedDatabase()) { + return context.getAttachedDatabase()->getCatalog(); + } else { + return context.getDatabase()->getCatalog(); + } +} + +void Catalog::initCatalogSets() { + tables = std::make_unique(); + sequences = std::make_unique(); + functions = std::make_unique(); + types = std::make_unique(); + indexes = std::make_unique(); + macros = std::make_unique(); + internalTables = std::make_unique(true /* isInternal */); + internalSequences = std::make_unique(true /* isInternal */); + internalFunctions = std::make_unique(true /* isInternal */); +} + +bool Catalog::containsTable(const Transaction* transaction, const std::string& tableName, + bool useInternal) const { + if (tables->containsEntry(transaction, tableName)) { + return true; + } + if (useInternal) { + return internalTables->containsEntry(transaction, tableName); + } + return false; +} + +bool Catalog::containsTable(const Transaction* transaction, table_id_t tableID, + bool useInternal) const { + if (tables->getEntryOfOID(transaction, tableID) != nullptr) { + return true; + } + if (useInternal) { + return internalTables->getEntryOfOID(transaction, tableID) != nullptr; + } + return false; +} + +TableCatalogEntry* Catalog::getTableCatalogEntry(const Transaction* transaction, + table_id_t tableID) const { + auto result = tables->getEntryOfOID(transaction, tableID); + if (result == nullptr) { + result = internalTables->getEntryOfOID(transaction, tableID); + } + // LCOV_EXCL_START + if (result == nullptr) { + throw RuntimeException( + stringFormat("Cannot find table catalog entry with id {}.", std::to_string(tableID))); + } + // LCOV_EXCL_STOP + return result->ptrCast(); +} + +TableCatalogEntry* Catalog::getTableCatalogEntry(const Transaction* transaction, + const std::string& tableName, bool useInternal) const { + CatalogEntry* result = nullptr; + if (!tables->containsEntry(transaction, tableName)) { + if (!useInternal) { + throw CatalogException(stringFormat("{} does not exist in catalog.", tableName)); + } else { + result = internalTables->getEntry(transaction, tableName); + } + } else { + result = tables->getEntry(transaction, tableName); + } + // LCOV_EXCL_STOP + return result->ptrCast(); +} + +template +std::vector Catalog::getTableEntries(const Transaction* transaction, bool useInternal, + CatalogEntryType entryType) const { + std::vector result; + for (auto& [_, entry] : tables->getEntries(transaction)) { + if (entry->getType() != entryType) { + continue; + } + result.push_back(entry->template ptrCast()); + } + if (useInternal) { + for (auto& [_, entry] : internalTables->getEntries(transaction)) { + if (entry->getType() != entryType) { + continue; + } + 
result.push_back(entry->template ptrCast()); + } + } + return result; +} + +std::vector Catalog::getNodeTableEntries(const Transaction* transaction, + bool useInternal) const { + return getTableEntries(transaction, useInternal, + CatalogEntryType::NODE_TABLE_ENTRY); +} + +std::vector Catalog::getRelGroupEntries(const Transaction* transaction, + bool useInternal) const { + return getTableEntries(transaction, useInternal, + CatalogEntryType::REL_GROUP_ENTRY); +} + +std::vector Catalog::getTableEntries(const Transaction* transaction, + bool useInternal) const { + std::vector result; + for (auto& [_, entry] : tables->getEntries(transaction)) { + result.push_back(entry->ptrCast()); + } + if (useInternal) { + for (auto& [_, entry] : internalTables->getEntries(transaction)) { + result.push_back(entry->ptrCast()); + } + } + return result; +} + +void Catalog::dropTableEntryAndIndex(Transaction* transaction, const std::string& name) { + auto tableID = getTableCatalogEntry(transaction, name)->getTableID(); + dropAllIndexes(transaction, tableID); + dropTableEntry(transaction, tableID); +} + +void Catalog::dropTableEntry(Transaction* transaction, table_id_t tableID) { + dropTableEntry(transaction, getTableCatalogEntry(transaction, tableID)); +} + +void Catalog::dropTableEntry(Transaction* transaction, const TableCatalogEntry* entry) { + dropSerialSequence(transaction, entry); + if (tables->containsEntry(transaction, entry->getName())) { + tables->dropEntry(transaction, entry->getName(), entry->getOID()); + } else { + internalTables->dropEntry(transaction, entry->getName(), entry->getOID()); + } +} +void Catalog::dropMacroEntry(Transaction* transaction, const lbug::common::oid_t macroID) { + dropMacroEntry(transaction, getScalarMacroCatalogEntry(transaction, macroID)); +} + +void Catalog::dropMacroEntry(Transaction* transaction, const ScalarMacroCatalogEntry* entry) { + macros->dropEntry(transaction, entry->getName(), entry->getOID()); +} + +void Catalog::alterTableEntry(Transaction* transaction, const BoundAlterInfo& info) { + tables->alterTableEntry(transaction, info); +} + +CatalogEntry* Catalog::createRelGroupEntry(Transaction* transaction, + const BoundCreateTableInfo& info) { + const auto extraInfo = info.extraInfo->ptrCast(); + std::vector relTableInfos; + KU_ASSERT(extraInfo->nodePairs.size() > 0); + for (auto& nodePair : extraInfo->nodePairs) { + relTableInfos.emplace_back(nodePair, tables->getNextOID()); + } + auto relGroupEntry = + std::make_unique(info.tableName, extraInfo->srcMultiplicity, + extraInfo->dstMultiplicity, extraInfo->storageDirection, std::move(relTableInfos)); + for (auto& definition : extraInfo->propertyDefinitions) { + relGroupEntry->addProperty(definition); + } + KU_ASSERT(info.hasParent == false); + relGroupEntry->setHasParent(info.hasParent); + createSerialSequence(transaction, relGroupEntry.get(), info.isInternal); + auto catalogSet = info.isInternal ? 
internalTables.get() : tables.get(); + catalogSet->createEntry(transaction, std::move(relGroupEntry)); + return catalogSet->getEntry(transaction, info.tableName); +} + +bool Catalog::containsSequence(const Transaction* transaction, const std::string& name) const { + return sequences->containsEntry(transaction, name); +} + +SequenceCatalogEntry* Catalog::getSequenceEntry(const Transaction* transaction, + const std::string& sequenceName, bool useInternalSeq) const { + CatalogEntry* entry = nullptr; + if (!sequences->containsEntry(transaction, sequenceName) && useInternalSeq) { + entry = internalSequences->getEntry(transaction, sequenceName); + } else { + entry = sequences->getEntry(transaction, sequenceName); + } + KU_ASSERT(entry); + return entry->ptrCast(); +} + +SequenceCatalogEntry* Catalog::getSequenceEntry(const Transaction* transaction, + sequence_id_t sequenceID) const { + auto entry = internalSequences->getEntryOfOID(transaction, sequenceID); + if (entry == nullptr) { + entry = sequences->getEntryOfOID(transaction, sequenceID); + } + KU_ASSERT(entry); + return entry->ptrCast(); +} + +std::vector Catalog::getSequenceEntries( + const Transaction* transaction) const { + std::vector result; + for (auto& [_, entry] : sequences->getEntries(transaction)) { + result.push_back(entry->ptrCast()); + } + return result; +} + +sequence_id_t Catalog::createSequence(Transaction* transaction, + const BoundCreateSequenceInfo& info) { + auto entry = std::make_unique(info); + entry->setHasParent(info.hasParent); + if (info.isInternal) { + return internalSequences->createEntry(transaction, std::move(entry)); + } else { + return sequences->createEntry(transaction, std::move(entry)); + } +} + +void Catalog::dropSequence(Transaction* transaction, const std::string& name) { + const auto entry = getSequenceEntry(transaction, name); + dropSequence(transaction, entry->getOID()); +} + +void Catalog::dropSequence(Transaction* transaction, sequence_id_t sequenceID) { + const auto sequenceEntry = getSequenceEntry(transaction, sequenceID); + CatalogSet* set = nullptr; + set = sequences->containsEntry(transaction, sequenceEntry->getName()) ? sequences.get() : + internalSequences.get(); + set->dropEntry(transaction, sequenceEntry->getName(), sequenceEntry->getOID()); +} + +void Catalog::createType(Transaction* transaction, std::string name, LogicalType type) { + if (types->containsEntry(transaction, name)) { + return; + } + auto entry = std::make_unique(std::move(name), std::move(type)); + types->createEntry(transaction, std::move(entry)); +} + +static std::string getInstallExtensionMessage(std::string_view extensionName, + std::string_view entryType) { + return stringFormat("This {} exists in the {} " + "extension. 
You can install and load the " + "extension by running 'INSTALL {}; LOAD EXTENSION {};'.", + entryType, extensionName, extensionName, extensionName); +} + +static std::string getTypeDoesNotExistMessage(std::string_view entryName) { + std::string message = + stringFormat("{} is neither an internal type nor a user defined type.", entryName); + const auto matchingExtensionFunction = + extension::ExtensionManager::lookupExtensionsByTypeName(entryName); + if (matchingExtensionFunction.has_value()) { + message = stringFormat("{} {}", message, + getInstallExtensionMessage(matchingExtensionFunction->extensionName, "type")); + } + return message; +} + +LogicalType Catalog::getType(const Transaction* transaction, const std::string& name) const { + if (!types->containsEntry(transaction, name)) { + throw CatalogException{getTypeDoesNotExistMessage(name)}; + } + return types->getEntry(transaction, name) + ->constCast() + .getLogicalType() + .copy(); +} + +bool Catalog::containsType(const Transaction* transaction, const std::string& typeName) const { + return types->containsEntry(transaction, typeName); +} + +void Catalog::createIndex(Transaction* transaction, + std::unique_ptr indexCatalogEntry) { + KU_ASSERT(indexCatalogEntry->getType() == CatalogEntryType::INDEX_ENTRY); + indexes->createEntry(transaction, std::move(indexCatalogEntry)); +} + +IndexCatalogEntry* Catalog::getIndex(const Transaction* transaction, table_id_t tableID, + const std::string& indexName) const { + auto internalName = IndexCatalogEntry::getInternalIndexName(tableID, indexName); + return indexes->getEntry(transaction, internalName)->ptrCast(); +} + +std::vector Catalog::getIndexEntries(const Transaction* transaction) const { + std::vector result; + for (auto& [_, entry] : indexes->getEntries(transaction)) { + result.push_back(entry->ptrCast()); + } + return result; +} + +std::vector Catalog::getIndexEntries(const Transaction* transaction, + table_id_t tableID) const { + std::vector result; + for (auto& [_, entry] : indexes->getEntries(transaction)) { + auto indexEntry = entry->ptrCast(); + if (indexEntry->getTableID() == tableID) { + result.push_back(indexEntry); + } + } + return result; +} + +bool Catalog::containsIndex(const Transaction* transaction, table_id_t tableID, + const std::string& indexName) const { + return indexes->containsEntry(transaction, + IndexCatalogEntry::getInternalIndexName(tableID, indexName)); +} + +bool Catalog::containsIndex(const Transaction* transaction, table_id_t tableID, + property_id_t propertyID) const { + for (auto& [_, entry] : indexes->getEntries(transaction)) { + auto indexEntry = entry->ptrCast(); + if (indexEntry->getTableID() != tableID) { + continue; + } + if (indexEntry->containsPropertyID(propertyID)) { + return true; + } + } + return false; +} + +bool Catalog::containsUnloadedIndex(const Transaction* transaction, common::table_id_t tableID, + common::property_id_t propertyID) const { + for (auto& [_, entry] : indexes->getEntries(transaction)) { + auto indexEntry = entry->ptrCast(); + if (indexEntry->getTableID() != tableID || !indexEntry->containsPropertyID(propertyID)) { + continue; + } + if (!indexEntry->isLoaded()) { + return true; + } + } + return false; +} + +void Catalog::dropAllIndexes(Transaction* transaction, table_id_t tableID) { + for (auto catalogEntry : indexes->getEntries(transaction)) { + auto& indexCatalogEntry = catalogEntry.second->constCast(); + if (indexCatalogEntry.getTableID() == tableID) { + indexes->dropEntry(transaction, catalogEntry.first, 
catalogEntry.second->getOID()); + } + } +} + +void Catalog::dropIndex(Transaction* transaction, table_id_t tableID, + const std::string& indexName) const { + auto uniqueName = IndexCatalogEntry::getInternalIndexName(tableID, indexName); + const auto entry = indexes->getEntry(transaction, uniqueName); + indexes->dropEntry(transaction, uniqueName, entry->getOID()); +} + +void Catalog::dropIndex(Transaction* transaction, oid_t indexOID) { + const auto entry = indexes->getEntryOfOID(transaction, indexOID); + if (entry == nullptr) { + throw CatalogException{stringFormat("Index with OID {} does not exist.", indexOID)}; + } + indexes->dropEntry(transaction, entry->getName(), indexOID); +} + +bool Catalog::containsFunction(const Transaction* transaction, const std::string& name, + bool useInternal) const { + auto hasEntry = functions->containsEntry(transaction, name); + if (!hasEntry && useInternal) { + return internalFunctions->containsEntry(transaction, name); + } + return hasEntry; +} + +void Catalog::addFunction(Transaction* transaction, CatalogEntryType entryType, std::string name, + function::function_set functionSet, bool isInternal) { + auto& catalogSet = isInternal ? internalFunctions : functions; + if (catalogSet->containsEntry(transaction, name)) { + throw CatalogException{stringFormat("function {} already exists.", name)}; + } + catalogSet->createEntry(transaction, + std::make_unique(entryType, std::move(name), std::move(functionSet))); +} + +static std::string getFunctionDoesNotExistMessage(std::string_view entryName) { + std::string message = stringFormat("function {} does not exist.", entryName); + const auto matchingExtensionFunction = + extension::ExtensionManager::lookupExtensionsByFunctionName(entryName); + if (matchingExtensionFunction.has_value()) { + message = stringFormat("function {} is not defined. 
{}", entryName, + getInstallExtensionMessage(matchingExtensionFunction->extensionName, "function")); + } + return message; +} + +void Catalog::dropFunction(Transaction* transaction, const std::string& name) { + if (!containsFunction(transaction, name)) { + throw CatalogException{stringFormat("function {} doesn't exist.", name)}; + } + auto entry = getFunctionEntry(transaction, name); + functions->dropEntry(transaction, name, entry->getOID()); +} + +CatalogEntry* Catalog::getFunctionEntry(const Transaction* transaction, const std::string& name, + bool useInternal) const { + CatalogEntry* result = nullptr; + if (functions->containsEntry(transaction, name)) { + result = functions->getEntry(transaction, name); + } else if (macros->containsEntry(transaction, name)) { + result = macros->getEntry(transaction, name); + } else if (useInternal) { + result = internalFunctions->getEntry(transaction, name); + } else { + throw CatalogException(getFunctionDoesNotExistMessage(name)); + } + return result; +} + +std::vector Catalog::getMacroEntries( + const Transaction* transaction) const { + std::vector result; + for (auto& [_, entry] : macros->getEntries(transaction)) { + KU_ASSERT(entry->getType() == CatalogEntryType::SCALAR_MACRO_ENTRY); + result.push_back(entry->ptrCast()); + } + return result; +} + +std::vector Catalog::getFunctionEntries( + const Transaction* transaction) const { + std::vector result; + for (auto& [_, entry] : functions->getEntries(transaction)) { + result.push_back(entry->ptrCast()); + } + return result; +} + +bool Catalog::containsMacro(const Transaction* transaction, const std::string& macroName) const { + return macros->containsEntry(transaction, macroName); +} + +function::ScalarMacroFunction* Catalog::getScalarMacroFunction(const Transaction* transaction, + const std::string& name) const { + return macros->getEntry(transaction, name) + ->constCast() + .getMacroFunction(); +} + +// addScalarMacroFunction +void Catalog::addScalarMacroFunction(Transaction* transaction, std::string name, + std::unique_ptr macro) { + auto entry = std::make_unique(std::move(name), std::move(macro)); + macros->createEntry(transaction, std::move(entry)); +} + +ScalarMacroCatalogEntry* Catalog::getScalarMacroCatalogEntry(const Transaction* transaction, + lbug::common::oid_t macroID) const { + auto result = functions->getEntryOfOID(transaction, macroID); + if (result == nullptr) { + throw RuntimeException( + stringFormat("Cannot find macro catalog entry with id {}.", std::to_string(macroID))); + } + + return result->ptrCast(); +} + +std::vector Catalog::getMacroNames(const Transaction* transaction) const { + std::vector macroNames; + for (auto& [_, function] : macros->getEntries(transaction)) { + KU_ASSERT(function->getType() == CatalogEntryType::SCALAR_MACRO_ENTRY); + macroNames.push_back(function->getName()); + } + return macroNames; +} + +void Catalog::dropMacro(Transaction* transaction, std::string& name) { + if (!containsMacro(transaction, name)) { + throw CatalogException{stringFormat("Marco {} doesn't exist.", name)}; + } + auto entry = getFunctionEntry(transaction, name); + macros->dropEntry(transaction, name, entry->getOID()); +} + +void Catalog::registerBuiltInFunctions() { + auto functionCollection = function::FunctionCollection::getFunctions(); + for (auto i = 0u; functionCollection[i].name != nullptr; ++i) { + auto& f = functionCollection[i]; + auto functionSet = f.getFunctionSetFunc(); + functions->createEntry(&DUMMY_TRANSACTION, + std::make_unique(f.catalogEntryType, f.name, + 
std::move(functionSet))); + } +} + +CatalogEntry* Catalog::createTableEntry(Transaction* transaction, + const BoundCreateTableInfo& info) { + switch (info.type) { + case CatalogEntryType::NODE_TABLE_ENTRY: { + return createNodeTableEntry(transaction, info); + } + case CatalogEntryType::REL_GROUP_ENTRY: { + return createRelGroupEntry(transaction, info); + } + default: + KU_UNREACHABLE; + } +} + +CatalogEntry* Catalog::createNodeTableEntry(Transaction* transaction, + const BoundCreateTableInfo& info) { + const auto extraInfo = info.extraInfo->constPtrCast(); + auto entry = std::make_unique(info.tableName, extraInfo->primaryKeyName); + for (auto& definition : extraInfo->propertyDefinitions) { + entry->addProperty(definition); + } + entry->setHasParent(info.hasParent); + createSerialSequence(transaction, entry.get(), info.isInternal); + auto catalogSet = info.isInternal ? internalTables.get() : tables.get(); + catalogSet->createEntry(transaction, std::move(entry)); + return catalogSet->getEntry(transaction, info.tableName); +} + +void Catalog::createSerialSequence(Transaction* transaction, const TableCatalogEntry* entry, + bool isInternal) { + for (auto& definition : entry->getProperties()) { + if (definition.getType().getLogicalTypeID() != LogicalTypeID::SERIAL) { + continue; + } + const auto seqName = + SequenceCatalogEntry::getSerialName(entry->getName(), definition.getName()); + auto seqInfo = + BoundCreateSequenceInfo(seqName, 0, 1, 0, std::numeric_limits::max(), false, + ConflictAction::ON_CONFLICT_THROW, isInternal); + seqInfo.hasParent = true; + createSequence(transaction, seqInfo); + } +} + +void Catalog::dropSerialSequence(Transaction* transaction, const TableCatalogEntry* entry) { + for (auto& definition : entry->getProperties()) { + if (definition.getType().getLogicalTypeID() != LogicalTypeID::SERIAL) { + continue; + } + auto seqName = SequenceCatalogEntry::getSerialName(entry->getName(), definition.getName()); + dropSequence(transaction, seqName); + } +} + +void Catalog::serialize(Serializer& ser) const { + tables->serialize(ser); + sequences->serialize(ser); + functions->serialize(ser); + types->serialize(ser); + indexes->serialize(ser); + macros->serialize(ser); + internalTables->serialize(ser); + internalSequences->serialize(ser); + internalFunctions->serialize(ser); +} + +void Catalog::deserialize(Deserializer& deSer) { + tables = CatalogSet::deserialize(deSer); + sequences = CatalogSet::deserialize(deSer); + functions = CatalogSet::deserialize(deSer); + registerBuiltInFunctions(); + types = CatalogSet::deserialize(deSer); + indexes = CatalogSet::deserialize(deSer); + macros = CatalogSet::deserialize(deSer); + internalTables = CatalogSet::deserialize(deSer); + internalSequences = CatalogSet::deserialize(deSer); + internalFunctions = CatalogSet::deserialize(deSer); +} + +} // namespace catalog +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog_entry/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog_entry/CMakeLists.txt new file mode 100644 index 0000000000..e7f697c93c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog_entry/CMakeLists.txt @@ -0,0 +1,17 @@ +add_library(lbug_catalog_entry + OBJECT + catalog_entry.cpp + catalog_entry_type.cpp + function_catalog_entry.cpp + table_catalog_entry.cpp + node_table_catalog_entry.cpp + node_table_id_pair.cpp + rel_group_catalog_entry.cpp + scalar_macro_catalog_entry.cpp + type_catalog_entry.cpp + sequence_catalog_entry.cpp + index_catalog_entry.cpp) + 
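+# ALL_OBJECT_FILES is accumulated in the parent scope; presumably the top-level
+# build links these OBJECT files into a single lbug library target.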
+set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog_entry/catalog_entry.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog_entry/catalog_entry.cpp new file mode 100644 index 0000000000..73250d06e5 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog_entry/catalog_entry.cpp @@ -0,0 +1,78 @@ +#include "catalog/catalog_entry/catalog_entry.h" + +#include "catalog/catalog_entry/index_catalog_entry.h" +#include "catalog/catalog_entry/scalar_macro_catalog_entry.h" +#include "catalog/catalog_entry/sequence_catalog_entry.h" +#include "catalog/catalog_entry/table_catalog_entry.h" +#include "catalog/catalog_entry/type_catalog_entry.h" +#include "common/serializer/deserializer.h" +#include "transaction/transaction.h" + +namespace lbug { +namespace catalog { + +void CatalogEntry::serialize(common::Serializer& serializer) const { + serializer.writeDebuggingInfo("type"); + serializer.write(type); + serializer.writeDebuggingInfo("name"); + serializer.write(name); + serializer.writeDebuggingInfo("oid"); + serializer.write(oid); + serializer.writeDebuggingInfo("hasParent_"); + serializer.write(hasParent_); +} + +std::unique_ptr CatalogEntry::deserialize(common::Deserializer& deserializer) { + std::string debuggingInfo; + auto type = CatalogEntryType::DUMMY_ENTRY; + std::string name; + common::oid_t oid = common::INVALID_OID; + bool hasParent_ = false; + deserializer.validateDebuggingInfo(debuggingInfo, "type"); + deserializer.deserializeValue(type); + deserializer.validateDebuggingInfo(debuggingInfo, "name"); + deserializer.deserializeValue(name); + deserializer.validateDebuggingInfo(debuggingInfo, "oid"); + deserializer.deserializeValue(oid); + deserializer.validateDebuggingInfo(debuggingInfo, "hasParent_"); + deserializer.deserializeValue(hasParent_); + std::unique_ptr entry; + switch (type) { + case CatalogEntryType::NODE_TABLE_ENTRY: + case CatalogEntryType::REL_GROUP_ENTRY: { + entry = TableCatalogEntry::deserialize(deserializer, type); + } break; + case CatalogEntryType::SCALAR_MACRO_ENTRY: { + entry = ScalarMacroCatalogEntry::deserialize(deserializer); + } break; + case CatalogEntryType::SEQUENCE_ENTRY: { + entry = SequenceCatalogEntry::deserialize(deserializer); + } break; + case CatalogEntryType::TYPE_ENTRY: { + entry = TypeCatalogEntry::deserialize(deserializer); + } break; + case CatalogEntryType::INDEX_ENTRY: { + entry = IndexCatalogEntry::deserialize(deserializer); + } break; + default: + KU_UNREACHABLE; + } + entry->type = type; + entry->name = std::move(name); + entry->oid = oid; + entry->hasParent_ = hasParent_; + entry->timestamp = transaction::Transaction::DUMMY_START_TIMESTAMP; + return entry; +} + +void CatalogEntry::copyFrom(const CatalogEntry& other) { + type = other.type; + name = other.name; + oid = other.oid; + timestamp = other.timestamp; + deleted = other.deleted; + hasParent_ = other.hasParent_; +} + +} // namespace catalog +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog_entry/catalog_entry_type.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog_entry/catalog_entry_type.cpp new file mode 100644 index 0000000000..7779e4e5e9 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog_entry/catalog_entry_type.cpp @@ -0,0 +1,61 @@ +#include "catalog/catalog_entry/catalog_entry_type.h" + +#include "common/assert.h" + +namespace lbug { +namespace catalog { + +std::string CatalogEntryTypeUtils::toString(CatalogEntryType type) { + 
switch (type) { + case CatalogEntryType::NODE_TABLE_ENTRY: + return "NODE_TABLE_ENTRY"; + case CatalogEntryType::REL_GROUP_ENTRY: + return "REL_GROUP_ENTRY"; + case CatalogEntryType::FOREIGN_TABLE_ENTRY: + return "FOREIGN_TABLE_ENTRY"; + case CatalogEntryType::SCALAR_MACRO_ENTRY: + return "SCALAR_MACRO_ENTRY"; + case CatalogEntryType::AGGREGATE_FUNCTION_ENTRY: + return "AGGREGATE_FUNCTION_ENTRY"; + case CatalogEntryType::SCALAR_FUNCTION_ENTRY: + return "SCALAR_FUNCTION_ENTRY"; + case CatalogEntryType::REWRITE_FUNCTION_ENTRY: + return "REWRITE_FUNCTION_ENTRY"; + case CatalogEntryType::TABLE_FUNCTION_ENTRY: + return "TABLE_FUNCTION_ENTRY"; + case CatalogEntryType::STANDALONE_TABLE_FUNCTION_ENTRY: + return "STANDALONE_TABLE_FUNCTION_ENTRY"; + case CatalogEntryType::COPY_FUNCTION_ENTRY: + return "COPY_FUNCTION_ENTRY"; + case CatalogEntryType::DUMMY_ENTRY: + return "DUMMY_ENTRY"; + case CatalogEntryType::SEQUENCE_ENTRY: + return "SEQUENCE_ENTRY"; + default: + KU_UNREACHABLE; + } +} + +std::string FunctionEntryTypeUtils::toString(CatalogEntryType type) { + switch (type) { + case CatalogEntryType::SCALAR_MACRO_ENTRY: + return "MACRO FUNCTION"; + case CatalogEntryType::AGGREGATE_FUNCTION_ENTRY: + return "AGGREGATE FUNCTION"; + case CatalogEntryType::SCALAR_FUNCTION_ENTRY: + return "SCALAR FUNCTION"; + case CatalogEntryType::REWRITE_FUNCTION_ENTRY: + return "REWRITE FUNCTION"; + case CatalogEntryType::TABLE_FUNCTION_ENTRY: + return "TABLE FUNCTION"; + case CatalogEntryType::STANDALONE_TABLE_FUNCTION_ENTRY: + return "STANDALONE TABLE FUNCTION"; + case CatalogEntryType::COPY_FUNCTION_ENTRY: + return "COPY FUNCTION"; + default: + KU_UNREACHABLE; + } +} + +} // namespace catalog +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog_entry/function_catalog_entry.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog_entry/function_catalog_entry.cpp new file mode 100644 index 0000000000..c4dc949151 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog_entry/function_catalog_entry.cpp @@ -0,0 +1,11 @@ +#include "catalog/catalog_entry/function_catalog_entry.h" + +namespace lbug { +namespace catalog { + +FunctionCatalogEntry::FunctionCatalogEntry(CatalogEntryType entryType, std::string name, + function::function_set functionSet) + : CatalogEntry{entryType, std::move(name)}, functionSet{std::move(functionSet)} {} + +} // namespace catalog +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog_entry/index_catalog_entry.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog_entry/index_catalog_entry.cpp new file mode 100644 index 0000000000..2d4d83e7d2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog_entry/index_catalog_entry.cpp @@ -0,0 +1,84 @@ +#include "catalog/catalog_entry/index_catalog_entry.h" + +#include "common/exception/runtime.h" +#include "common/serializer/buffer_writer.h" + +namespace lbug { +namespace catalog { + +std::shared_ptr IndexAuxInfo::serialize() const { + return std::make_shared(0 /*maximumSize*/); +} + +void IndexCatalogEntry::setAuxInfo(std::unique_ptr auxInfo_) { + auxInfo = std::move(auxInfo_); + auxBuffer = nullptr; + auxBufferSize = 0; +} + +bool IndexCatalogEntry::containsPropertyID(common::property_id_t propertyID) const { + for (auto id : propertyIDs) { + if (id == propertyID) { + return true; + } + } + return false; +} + +void IndexCatalogEntry::serialize(common::Serializer& serializer) const { + CatalogEntry::serialize(serializer); + serializer.write(type); + 
serializer.write(tableID); + serializer.write(indexName); + serializer.serializeVector(propertyIDs); + if (isLoaded()) { + const auto bufferedWriter = auxInfo->serialize(); + serializer.write(bufferedWriter->getSize()); + serializer.write(bufferedWriter->getData().data.get(), bufferedWriter->getSize()); + } else { + serializer.write(auxBufferSize); + serializer.write(auxBuffer.get(), auxBufferSize); + } +} + +std::unique_ptr IndexCatalogEntry::deserialize( + common::Deserializer& deserializer) { + std::string type; + common::table_id_t tableID = common::INVALID_TABLE_ID; + std::string indexName; + std::vector propertyIDs; + deserializer.deserializeValue(type); + deserializer.deserializeValue(tableID); + deserializer.deserializeValue(indexName); + deserializer.deserializeVector(propertyIDs); + auto indexEntry = std::make_unique(type, tableID, std::move(indexName), + std::move(propertyIDs), nullptr /* auxInfo */); + uint64_t auxBufferSize = 0; + deserializer.deserializeValue(auxBufferSize); + indexEntry->auxBuffer = std::make_unique(auxBufferSize); + indexEntry->auxBufferSize = auxBufferSize; + deserializer.read(indexEntry->auxBuffer.get(), auxBufferSize); + return indexEntry; +} + +void IndexCatalogEntry::copyFrom(const CatalogEntry& other) { + CatalogEntry::copyFrom(other); + auto& otherTable = other.constCast(); + tableID = otherTable.tableID; + indexName = otherTable.indexName; + if (auxInfo) { + auxInfo = otherTable.auxInfo->copy(); + } +} +std::unique_ptr IndexCatalogEntry::getAuxBufferReader() const { + // LCOV_EXCL_START + if (!auxBuffer) { + throw common::RuntimeException( + common::stringFormat("Auxiliary buffer for index \"{}\" is not set.", indexName)); + } + // LCOV_EXCL_STOP + return std::make_unique(auxBuffer.get(), auxBufferSize); +} + +} // namespace catalog +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog_entry/node_table_catalog_entry.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog_entry/node_table_catalog_entry.cpp new file mode 100644 index 0000000000..53515154ed --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog_entry/node_table_catalog_entry.cpp @@ -0,0 +1,56 @@ +#include "catalog/catalog_entry/node_table_catalog_entry.h" + +#include "binder/ddl/bound_create_table_info.h" +#include "common/serializer/deserializer.h" +#include "common/string_utils.h" + +using namespace lbug::binder; + +namespace lbug { +namespace catalog { + +void NodeTableCatalogEntry::renameProperty(const std::string& propertyName, + const std::string& newName) { + TableCatalogEntry::renameProperty(propertyName, newName); + if (common::StringUtils::caseInsensitiveEquals(propertyName, primaryKeyName)) { + primaryKeyName = newName; + } +} + +void NodeTableCatalogEntry::serialize(common::Serializer& serializer) const { + TableCatalogEntry::serialize(serializer); + serializer.writeDebuggingInfo("primaryKeyName"); + serializer.write(primaryKeyName); +} + +std::unique_ptr NodeTableCatalogEntry::deserialize( + common::Deserializer& deserializer) { + std::string debuggingInfo; + std::string primaryKeyName; + deserializer.validateDebuggingInfo(debuggingInfo, "primaryKeyName"); + deserializer.deserializeValue(primaryKeyName); + auto nodeTableEntry = std::make_unique(); + nodeTableEntry->primaryKeyName = primaryKeyName; + return nodeTableEntry; +} + +std::string NodeTableCatalogEntry::toCypher(const ToCypherInfo& /*info*/) const { + return common::stringFormat("CREATE NODE TABLE `{}` ({} PRIMARY KEY(`{}`));", getName(), + 
propertyCollection.toCypher(), primaryKeyName); +} + +std::unique_ptr NodeTableCatalogEntry::copy() const { + auto other = std::make_unique(); + other->primaryKeyName = primaryKeyName; + other->copyFrom(*this); + return other; +} + +std::unique_ptr NodeTableCatalogEntry::getBoundExtraCreateInfo( + transaction::Transaction*) const { + return std::make_unique(primaryKeyName, + copyVector(getProperties())); +} + +} // namespace catalog +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog_entry/node_table_id_pair.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog_entry/node_table_id_pair.cpp new file mode 100644 index 0000000000..3d98763d04 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog_entry/node_table_id_pair.cpp @@ -0,0 +1,30 @@ +#include "catalog/catalog_entry/node_table_id_pair.h" + +#include "common/serializer/deserializer.h" +#include "common/serializer/serializer.h" + +using namespace lbug::common; + +namespace lbug { +namespace catalog { + +void NodeTableIDPair::serialize(Serializer& serializer) const { + serializer.writeDebuggingInfo("srcTableID"); + serializer.serializeValue(srcTableID); + serializer.writeDebuggingInfo("dstTableID"); + serializer.serializeValue(dstTableID); +} + +NodeTableIDPair NodeTableIDPair::deserialize(Deserializer& deser) { + std::string debuggingInfo; + table_id_t srcTableID = INVALID_TABLE_ID; + table_id_t dstTableID = INVALID_TABLE_ID; + deser.validateDebuggingInfo(debuggingInfo, "srcTableID"); + deser.deserializeValue(srcTableID); + deser.validateDebuggingInfo(debuggingInfo, "dstTableID"); + deser.deserializeValue(dstTableID); + return NodeTableIDPair{srcTableID, dstTableID}; +} + +} // namespace catalog +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog_entry/rel_group_catalog_entry.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog_entry/rel_group_catalog_entry.cpp new file mode 100644 index 0000000000..b87722c6bf --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog_entry/rel_group_catalog_entry.cpp @@ -0,0 +1,187 @@ +#include "catalog/catalog_entry/rel_group_catalog_entry.h" + +#include + +#include "binder/ddl/bound_create_table_info.h" +#include "catalog/catalog.h" +#include "common/serializer/deserializer.h" +#include "transaction/transaction.h" + +using namespace lbug::common; +using namespace lbug::main; + +namespace lbug { +namespace catalog { + +void RelGroupCatalogEntry::addFromToConnection(table_id_t srcTableID, table_id_t dstTableID, + oid_t oid) { + relTableInfos.emplace_back(NodeTableIDPair{srcTableID, dstTableID}, oid); +} + +void RelGroupCatalogEntry::dropFromToConnection(table_id_t srcTableID, table_id_t dstTableID) { + auto tmpInfos = relTableInfos; + relTableInfos.clear(); + for (auto& tmpInfo : tmpInfos) { + if (tmpInfo.nodePair.srcTableID == srcTableID && + tmpInfo.nodePair.dstTableID == dstTableID) { + continue; + } + relTableInfos.emplace_back(tmpInfo); + } +} + +void RelTableCatalogInfo::serialize(Serializer& ser) const { + ser.writeDebuggingInfo("nodePair"); + nodePair.serialize(ser); + ser.writeDebuggingInfo("oid"); + ser.serializeValue(oid); +} + +RelTableCatalogInfo RelTableCatalogInfo::deserialize(Deserializer& deser) { + std::string debuggingInfo; + oid_t oid = INVALID_OID; + deser.validateDebuggingInfo(debuggingInfo, "nodePair"); + auto nodePair = NodeTableIDPair::deserialize(deser); + deser.validateDebuggingInfo(debuggingInfo, "oid"); + deser.deserializeValue(oid); + return 
RelTableCatalogInfo{nodePair, oid}; +} + +bool RelGroupCatalogEntry::isParent(table_id_t tableID) { + for (auto& info : relTableInfos) { + if (info.nodePair.srcTableID == tableID || info.nodePair.dstTableID == tableID) { + return true; + } + } + return false; +} + +const RelTableCatalogInfo& RelGroupCatalogEntry::getSingleRelEntryInfo() const { + KU_ASSERT(relTableInfos.size() == 1); + return relTableInfos[0]; +} + +const RelTableCatalogInfo* RelGroupCatalogEntry::getRelEntryInfo(table_id_t srcTableID, + table_id_t dstTableID) const { + for (auto& info : relTableInfos) { + if (info.nodePair.srcTableID == srcTableID && info.nodePair.dstTableID == dstTableID) { + return &info; + } + } + return nullptr; +} + +std::unordered_set RelGroupCatalogEntry::getSrcNodeTableIDSet() const { + std::unordered_set result; + for (auto& info : relTableInfos) { + result.insert(info.nodePair.srcTableID); + } + return result; +} + +std::unordered_set RelGroupCatalogEntry::getDstNodeTableIDSet() const { + std::unordered_set result; + for (auto& info : relTableInfos) { + result.insert(info.nodePair.dstTableID); + } + return result; +} + +void RelGroupCatalogEntry::serialize(Serializer& serializer) const { + TableCatalogEntry::serialize(serializer); + serializer.writeDebuggingInfo("srcMultiplicity"); + serializer.serializeValue(srcMultiplicity); + serializer.writeDebuggingInfo("dstMultiplicity"); + serializer.serializeValue(dstMultiplicity); + serializer.writeDebuggingInfo("storageDirection"); + serializer.serializeValue(storageDirection); + serializer.writeDebuggingInfo("relTableInfos"); + serializer.serializeVector(relTableInfos); +} + +std::unique_ptr RelGroupCatalogEntry::deserialize( + Deserializer& deserializer) { + std::string debuggingInfo; + auto srcMultiplicity = RelMultiplicity::MANY; + auto dstMultiplicity = RelMultiplicity::MANY; + auto storageDirection = ExtendDirection::BOTH; + std::vector relTableInfos; + deserializer.validateDebuggingInfo(debuggingInfo, "srcMultiplicity"); + deserializer.deserializeValue(srcMultiplicity); + deserializer.validateDebuggingInfo(debuggingInfo, "dstMultiplicity"); + deserializer.deserializeValue(dstMultiplicity); + deserializer.validateDebuggingInfo(debuggingInfo, "storageDirection"); + deserializer.deserializeValue(storageDirection); + deserializer.validateDebuggingInfo(debuggingInfo, "relTableInfos"); + deserializer.deserializeVector(relTableInfos); + auto relGroupEntry = std::make_unique(); + relGroupEntry->srcMultiplicity = srcMultiplicity; + relGroupEntry->dstMultiplicity = dstMultiplicity; + relGroupEntry->storageDirection = storageDirection; + relGroupEntry->relTableInfos = relTableInfos; + return relGroupEntry; +} + +static std::string getFromToStr(const NodeTableIDPair& pair, const Catalog* catalog, + const transaction::Transaction* transaction) { + auto srcTableName = catalog->getTableCatalogEntry(transaction, pair.srcTableID)->getName(); + auto dstTableName = catalog->getTableCatalogEntry(transaction, pair.dstTableID)->getName(); + return stringFormat("FROM `{}` TO `{}`", srcTableName, dstTableName); +} + +std::string RelGroupCatalogEntry::toCypher(const ToCypherInfo& info) const { + auto relGroupInfo = info.constCast(); + auto catalog = Catalog::Get(*relGroupInfo.context); + auto transaction = transaction::Transaction::Get(*relGroupInfo.context); + std::stringstream ss; + ss << stringFormat("CREATE REL TABLE `{}` (", getName()); + KU_ASSERT(!relTableInfos.empty()); + ss << getFromToStr(relTableInfos[0].nodePair, catalog, transaction); + for (auto i = 1u; i 
< relTableInfos.size(); ++i) { + ss << stringFormat(", {}", getFromToStr(relTableInfos[i].nodePair, catalog, transaction)); + } + ss << ", " << propertyCollection.toCypher() << RelMultiplicityUtils::toString(srcMultiplicity) + << "_" << RelMultiplicityUtils::toString(dstMultiplicity) << ");"; + return ss.str(); +} + +std::vector RelGroupCatalogEntry::getRelDataDirections() const { + switch (storageDirection) { + case ExtendDirection::FWD: { + return {RelDataDirection::FWD}; + } + case ExtendDirection::BWD: { + return {RelDataDirection::BWD}; + } + case ExtendDirection::BOTH: { + return {RelDataDirection::FWD, RelDataDirection::BWD}; + } + default: { + KU_UNREACHABLE; + } + } +} + +std::unique_ptr RelGroupCatalogEntry::copy() const { + auto other = std::make_unique(); + other->srcMultiplicity = srcMultiplicity; + other->dstMultiplicity = dstMultiplicity; + other->storageDirection = storageDirection; + other->relTableInfos = relTableInfos; + other->copyFrom(*this); + return other; +} + +std::unique_ptr +RelGroupCatalogEntry::getBoundExtraCreateInfo(transaction::Transaction*) const { + std::vector nodePairs; + for (auto& relTableInfo : relTableInfos) { + nodePairs.push_back(relTableInfo.nodePair); + } + return std::make_unique( + copyVector(propertyCollection.getDefinitions()), srcMultiplicity, dstMultiplicity, + storageDirection, std::move(nodePairs)); +} + +} // namespace catalog +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog_entry/scalar_macro_catalog_entry.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog_entry/scalar_macro_catalog_entry.cpp new file mode 100644 index 0000000000..7cebd6d328 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog_entry/scalar_macro_catalog_entry.cpp @@ -0,0 +1,29 @@ +#include "catalog/catalog_entry/scalar_macro_catalog_entry.h" + +namespace lbug { +namespace catalog { + +ScalarMacroCatalogEntry::ScalarMacroCatalogEntry(std::string name, + std::unique_ptr macroFunction) + : CatalogEntry{CatalogEntryType::SCALAR_MACRO_ENTRY, std::move(name)}, + macroFunction{std::move(macroFunction)} {} + +void ScalarMacroCatalogEntry::serialize(common::Serializer& serializer) const { + CatalogEntry::serialize(serializer); + macroFunction->serialize(serializer); +} + +std::unique_ptr ScalarMacroCatalogEntry::deserialize( + common::Deserializer& deserializer) { + auto scalarMacroCatalogEntry = std::make_unique(); + scalarMacroCatalogEntry->macroFunction = + function::ScalarMacroFunction::deserialize(deserializer); + return scalarMacroCatalogEntry; +} + +std::string ScalarMacroCatalogEntry::toCypher(const ToCypherInfo& /*info*/) const { + return macroFunction->toCypher(getName()); +} + +} // namespace catalog +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog_entry/sequence_catalog_entry.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog_entry/sequence_catalog_entry.cpp new file mode 100644 index 0000000000..aeb8c8d641 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog_entry/sequence_catalog_entry.cpp @@ -0,0 +1,173 @@ +#include "catalog/catalog_entry/sequence_catalog_entry.h" + +#include "binder/ddl/bound_create_sequence_info.h" +#include "common/exception/catalog.h" +#include "common/exception/overflow.h" +#include "common/serializer/deserializer.h" +#include "common/vector/value_vector.h" +#include "function/arithmetic/add.h" +#include "transaction/transaction.h" + +using namespace lbug::binder; +using namespace lbug::common; + +namespace lbug { 
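+// File-level note: the sequence state (currVal, usageCount, ...) is guarded by the
+// entry's mutex; nextKVal snapshots a SequenceRollbackData and registers it with the
+// transaction via pushSequenceChange so rollbackVal can restore it on abort
+// (presumably driven by the undo buffer, mirroring DuckDB as the comment below notes).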
+namespace catalog { + +SequenceData SequenceCatalogEntry::getSequenceData() { + std::lock_guard lck(mtx); + return sequenceData; +} + +int64_t SequenceCatalogEntry::currVal() { + std::lock_guard lck(mtx); + if (sequenceData.usageCount == 0) { + throw CatalogException( + "currval: sequence \"" + name + + "\" is not yet defined. To define the sequence, call nextval first."); + } + return sequenceData.currVal; +} + +void SequenceCatalogEntry::nextValNoLock() { + if (sequenceData.usageCount == 0) { + // initialization of sequence + sequenceData.usageCount++; + return; + } + bool overflow = false; + auto next = sequenceData.currVal; + try { + function::Add::operation(next, sequenceData.increment, next); + } catch (const OverflowException&) { + overflow = true; + } + if (sequenceData.cycle) { + if (overflow) { + next = sequenceData.increment < 0 ? sequenceData.maxValue : sequenceData.minValue; + } else if (next < sequenceData.minValue) { + next = sequenceData.maxValue; + } else if (next > sequenceData.maxValue) { + next = sequenceData.minValue; + } + } else { + const bool minError = overflow ? sequenceData.increment < 0 : next < sequenceData.minValue; + const bool maxError = overflow ? sequenceData.increment > 0 : next > sequenceData.maxValue; + if (minError) { + throw CatalogException("nextval: reached minimum value of sequence \"" + name + "\" " + + std::to_string(sequenceData.minValue)); + } + if (maxError) { + throw CatalogException("nextval: reached maximum value of sequence \"" + name + "\" " + + std::to_string(sequenceData.maxValue)); + } + } + sequenceData.currVal = next; + sequenceData.usageCount++; +} + +// referenced from DuckDB +void SequenceCatalogEntry::nextKVal(transaction::Transaction* transaction, const uint64_t& count) { + KU_ASSERT(count > 0); + SequenceRollbackData rollbackData{}; + { + std::lock_guard lck(mtx); + rollbackData = SequenceRollbackData{sequenceData.usageCount, sequenceData.currVal}; + for (auto i = 0ul; i < count; i++) { + nextValNoLock(); + } + } + transaction->pushSequenceChange(this, count, rollbackData); +} + +void SequenceCatalogEntry::nextKVal(transaction::Transaction* transaction, const uint64_t& count, + ValueVector& resultVector) { + KU_ASSERT(count > 0); + SequenceRollbackData rollbackData{}; + { + std::lock_guard lck(mtx); + rollbackData = SequenceRollbackData{sequenceData.usageCount, sequenceData.currVal}; + for (auto i = 0ul; i < count; i++) { + nextValNoLock(); + resultVector.setValue(i, sequenceData.currVal); + } + } + transaction->pushSequenceChange(this, count, rollbackData); +} + +void SequenceCatalogEntry::rollbackVal(const uint64_t& usageCount, const int64_t& currVal) { + std::lock_guard lck(mtx); + sequenceData.usageCount = usageCount; + sequenceData.currVal = currVal; +} + +void SequenceCatalogEntry::serialize(Serializer& serializer) const { + CatalogEntry::serialize(serializer); + serializer.writeDebuggingInfo("usageCount"); + serializer.write(sequenceData.usageCount); + serializer.writeDebuggingInfo("currVal"); + serializer.write(sequenceData.currVal); + serializer.writeDebuggingInfo("increment"); + serializer.write(sequenceData.increment); + serializer.writeDebuggingInfo("startValue"); + serializer.write(sequenceData.startValue); + serializer.writeDebuggingInfo("minValue"); + serializer.write(sequenceData.minValue); + serializer.writeDebuggingInfo("maxValue"); + serializer.write(sequenceData.maxValue); + serializer.writeDebuggingInfo("cycle"); + serializer.write(sequenceData.cycle); +} + +std::unique_ptr 
SequenceCatalogEntry::deserialize( + Deserializer& deserializer) { + std::string debuggingInfo; + uint64_t usageCount = 0; + int64_t currVal = 0; + int64_t increment = 0; + int64_t startValue = 0; + int64_t minValue = 0; + int64_t maxValue = 0; + bool cycle = false; + deserializer.validateDebuggingInfo(debuggingInfo, "usageCount"); + deserializer.deserializeValue(usageCount); + deserializer.validateDebuggingInfo(debuggingInfo, "currVal"); + deserializer.deserializeValue(currVal); + deserializer.validateDebuggingInfo(debuggingInfo, "increment"); + deserializer.deserializeValue(increment); + deserializer.validateDebuggingInfo(debuggingInfo, "startValue"); + deserializer.deserializeValue(startValue); + deserializer.validateDebuggingInfo(debuggingInfo, "minValue"); + deserializer.deserializeValue(minValue); + deserializer.validateDebuggingInfo(debuggingInfo, "maxValue"); + deserializer.deserializeValue(maxValue); + deserializer.validateDebuggingInfo(debuggingInfo, "cycle"); + deserializer.deserializeValue(cycle); + auto result = std::make_unique(); + result->sequenceData.usageCount = usageCount; + result->sequenceData.currVal = currVal; + result->sequenceData.increment = increment; + result->sequenceData.startValue = startValue; + result->sequenceData.minValue = minValue; + result->sequenceData.maxValue = maxValue; + result->sequenceData.cycle = cycle; + return result; +} + +std::string SequenceCatalogEntry::toCypher(const ToCypherInfo& /* info */) const { + return stringFormat("DROP SEQUENCE IF EXISTS `{}`;\n" + "CREATE SEQUENCE IF NOT EXISTS `{}` START {} INCREMENT {} MINVALUE {} " + "MAXVALUE {} {} CYCLE;\n" + "RETURN nextval('{}');", + getName(), getName(), sequenceData.currVal, sequenceData.increment, sequenceData.minValue, + sequenceData.maxValue, sequenceData.cycle ? 
"" : "NO", getName()); +} + +BoundCreateSequenceInfo SequenceCatalogEntry::getBoundCreateSequenceInfo(bool isInternal) const { + return BoundCreateSequenceInfo(name, sequenceData.startValue, sequenceData.increment, + sequenceData.minValue, sequenceData.maxValue, sequenceData.cycle, + ConflictAction::ON_CONFLICT_THROW, isInternal); +} + +} // namespace catalog +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog_entry/table_catalog_entry.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog_entry/table_catalog_entry.cpp new file mode 100644 index 0000000000..4ec34c733a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog_entry/table_catalog_entry.cpp @@ -0,0 +1,153 @@ +#include "catalog/catalog_entry/table_catalog_entry.h" + +#include "binder/ddl/bound_alter_info.h" +#include "catalog/catalog.h" +#include "catalog/catalog_entry/node_table_catalog_entry.h" +#include "catalog/catalog_entry/rel_group_catalog_entry.h" +#include "common/serializer/deserializer.h" + +using namespace lbug::binder; +using namespace lbug::common; + +namespace lbug { +namespace catalog { + +std::unique_ptr TableCatalogEntry::alter(transaction_t timestamp, + const BoundAlterInfo& alterInfo, CatalogSet* tables) const { + KU_ASSERT(!deleted); + auto newEntry = copy(); + switch (alterInfo.alterType) { + case AlterType::RENAME: { + auto& renameTableInfo = *alterInfo.extraInfo->constPtrCast(); + newEntry->rename(renameTableInfo.newName); + } break; + case AlterType::RENAME_PROPERTY: { + auto& renamePropInfo = *alterInfo.extraInfo->constPtrCast(); + newEntry->renameProperty(renamePropInfo.oldName, renamePropInfo.newName); + } break; + case AlterType::ADD_PROPERTY: { + auto& addPropInfo = *alterInfo.extraInfo->constPtrCast(); + newEntry->addProperty(addPropInfo.propertyDefinition); + } break; + case AlterType::DROP_PROPERTY: { + auto& dropPropInfo = *alterInfo.extraInfo->constPtrCast(); + newEntry->dropProperty(dropPropInfo.propertyName); + } break; + case AlterType::COMMENT: { + auto& commentInfo = *alterInfo.extraInfo->constPtrCast(); + newEntry->setComment(commentInfo.comment); + } break; + case AlterType::ADD_FROM_TO_CONNECTION: { + auto& connectionInfo = + *alterInfo.extraInfo->constPtrCast(); + newEntry->ptrCast()->addFromToConnection(connectionInfo.fromTableID, + connectionInfo.toTableID, tables->getNextOIDNoLock()); + } break; + case AlterType::DROP_FROM_TO_CONNECTION: { + auto& connectionInfo = + *alterInfo.extraInfo->constPtrCast(); + newEntry->ptrCast()->dropFromToConnection(connectionInfo.fromTableID, + connectionInfo.toTableID); + } break; + default: { + KU_UNREACHABLE; + } + } + newEntry->setOID(oid); + newEntry->setTimestamp(timestamp); + return newEntry; +} + +column_id_t TableCatalogEntry::getMaxColumnID() const { + return propertyCollection.getMaxColumnID(); +} + +void TableCatalogEntry::vacuumColumnIDs(column_id_t nextColumnID) { + propertyCollection.vacuumColumnIDs(nextColumnID); +} + +bool TableCatalogEntry::containsProperty(const std::string& propertyName) const { + return propertyCollection.contains(propertyName); +} + +property_id_t TableCatalogEntry::getPropertyID(const std::string& propertyName) const { + return propertyCollection.getPropertyID(propertyName); +} + +const PropertyDefinition& TableCatalogEntry::getProperty(const std::string& propertyName) const { + return propertyCollection.getDefinition(propertyName); +} + +const PropertyDefinition& TableCatalogEntry::getProperty(idx_t idx) const { + return propertyCollection.getDefinition(idx); 
+} + +column_id_t TableCatalogEntry::getColumnID(const std::string& propertyName) const { + return propertyCollection.getColumnID(propertyName); +} + +common::column_id_t TableCatalogEntry::getColumnID(common::idx_t idx) const { + return propertyCollection.getColumnID(idx); +} + +void TableCatalogEntry::addProperty(const PropertyDefinition& propertyDefinition) { + propertyCollection.add(propertyDefinition); +} + +void TableCatalogEntry::dropProperty(const std::string& propertyName) { + propertyCollection.drop(propertyName); +} + +void TableCatalogEntry::renameProperty(const std::string& propertyName, + const std::string& newName) { + propertyCollection.rename(propertyName, newName); +} + +void TableCatalogEntry::serialize(Serializer& serializer) const { + CatalogEntry::serialize(serializer); + serializer.writeDebuggingInfo("comment"); + serializer.write(comment); + serializer.writeDebuggingInfo("properties"); + propertyCollection.serialize(serializer); +} + +std::unique_ptr TableCatalogEntry::deserialize(Deserializer& deserializer, + CatalogEntryType type) { + std::string debuggingInfo; + std::string comment; + deserializer.validateDebuggingInfo(debuggingInfo, "comment"); + deserializer.deserializeValue(comment); + deserializer.validateDebuggingInfo(debuggingInfo, "properties"); + auto propertyCollection = PropertyDefinitionCollection::deserialize(deserializer); + std::unique_ptr result; + switch (type) { + case CatalogEntryType::NODE_TABLE_ENTRY: + result = NodeTableCatalogEntry::deserialize(deserializer); + break; + case CatalogEntryType::REL_GROUP_ENTRY: + result = RelGroupCatalogEntry::deserialize(deserializer); + break; + default: + KU_UNREACHABLE; + } + result->comment = std::move(comment); + result->propertyCollection = std::move(propertyCollection); + return result; +} + +void TableCatalogEntry::copyFrom(const CatalogEntry& other) { + CatalogEntry::copyFrom(other); + auto& otherTable = ku_dynamic_cast(other); + comment = otherTable.comment; + propertyCollection = otherTable.propertyCollection.copy(); +} + +BoundCreateTableInfo TableCatalogEntry::getBoundCreateTableInfo( + transaction::Transaction* transaction, bool isInternal) const { + auto extraInfo = getBoundExtraCreateInfo(transaction); + return BoundCreateTableInfo(type, name, ConflictAction::ON_CONFLICT_THROW, std::move(extraInfo), + isInternal, hasParent_); +} + +} // namespace catalog +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog_entry/type_catalog_entry.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog_entry/type_catalog_entry.cpp new file mode 100644 index 0000000000..215cda26e9 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog_entry/type_catalog_entry.cpp @@ -0,0 +1,24 @@ +#include "catalog/catalog_entry/type_catalog_entry.h" + +#include "common/serializer/deserializer.h" + +namespace lbug { +namespace catalog { + +void TypeCatalogEntry::serialize(common::Serializer& serializer) const { + CatalogEntry::serialize(serializer); + serializer.writeDebuggingInfo("type"); + type.serialize(serializer); +} + +std::unique_ptr TypeCatalogEntry::deserialize( + common::Deserializer& deserializer) { + std::string debuggingInfo; + auto typeCatalogEntry = std::make_unique(); + deserializer.validateDebuggingInfo(debuggingInfo, "type"); + typeCatalogEntry->type = common::LogicalType::deserialize(deserializer); + return typeCatalogEntry; +} + +} // namespace catalog +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog_set.cpp 
b/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog_set.cpp new file mode 100644 index 0000000000..051cf65daf --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/catalog_set.cpp @@ -0,0 +1,302 @@ +#include "catalog/catalog_set.h" + +#include + +#include "binder/ddl/bound_alter_info.h" +#include "catalog/catalog_entry/dummy_catalog_entry.h" +#include "catalog/catalog_entry/table_catalog_entry.h" +#include "common/assert.h" +#include "common/exception/catalog.h" +#include "common/serializer/deserializer.h" +#include "common/string_format.h" +#include "transaction/transaction.h" + +using namespace lbug::common; +using namespace lbug::transaction; + +namespace lbug { +namespace catalog { + +CatalogSet::CatalogSet(bool isInternal) { + if (isInternal) { + nextOID = INTERNAL_CATALOG_SET_START_OID; + } +} + +static bool checkWWConflict(const Transaction* transaction, const CatalogEntry* entry) { + return (entry->getTimestamp() >= Transaction::START_TRANSACTION_ID && + entry->getTimestamp() != transaction->getID()) || + (entry->getTimestamp() < Transaction::START_TRANSACTION_ID && + entry->getTimestamp() > transaction->getStartTS()); +} + +bool CatalogSet::containsEntry(const Transaction* transaction, const std::string& name) { + std::shared_lock lck{mtx}; + return containsEntryNoLock(transaction, name); +} + +bool CatalogSet::containsEntryNoLock(const Transaction* transaction, + const std::string& name) const { + if (!entries.contains(name)) { + return false; + } + // Check versions. + const auto entry = + traverseVersionChainsForTransactionNoLock(transaction, entries.at(name).get()); + return !entry->isDeleted(); +} + +CatalogEntry* CatalogSet::getEntry(const Transaction* transaction, const std::string& name) { + std::shared_lock lck{mtx}; + return getEntryNoLock(transaction, name); +} + +CatalogEntry* CatalogSet::getEntryNoLock(const Transaction* transaction, + const std::string& name) const { + // LCOV_EXCL_START + validateExistNoLock(transaction, name); + // LCOV_EXCL_STOP + const auto entry = + traverseVersionChainsForTransactionNoLock(transaction, entries.at(name).get()); + KU_ASSERT(entry != nullptr && !entry->isDeleted()); + return entry; +} + +oid_t CatalogSet::createEntry(Transaction* transaction, std::unique_ptr entry) { + CatalogEntry* entryPtr = nullptr; + oid_t oid = INVALID_OID; + { + std::unique_lock lck{mtx}; + oid = nextOID++; + entry->setOID(oid); + entryPtr = createEntryNoLock(transaction, std::move(entry)); + } + KU_ASSERT(entryPtr); + if (transaction->shouldAppendToUndoBuffer()) { + transaction->pushCreateDropCatalogEntry(*this, *entryPtr, isInternal()); + } + return oid; +} + +CatalogEntry* CatalogSet::createEntryNoLock(const Transaction* transaction, + std::unique_ptr entry) { + // LCOV_EXCL_START + validateNotExistNoLock(transaction, entry->getName()); + // LCOV_EXCL_STOP + entry->setTimestamp(transaction->getID()); + if (entries.contains(entry->getName())) { + const auto existingEntry = entries.at(entry->getName()).get(); + if (checkWWConflict(transaction, existingEntry)) { + throw CatalogException(stringFormat( + "Write-write conflict on creating catalog entry with name {}.", entry->getName())); + } + if (!existingEntry->isDeleted()) { + throw CatalogException( + stringFormat("Catalog entry with name {} already exists.", entry->getName())); + } + } + auto dummyEntry = createDummyEntryNoLock(entry->getName(), entry->getOID()); + entries.emplace(entry->getName(), std::move(dummyEntry)); + const auto entryPtr = entry.get(); + emplaceNoLock(std::move(entry)); + 
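checkWWConflict above is the standard MVCC write-write rule: the newest version conflicts if it was produced by another transaction that is still running, or if it committed after the current transaction's start timestamp. A minimal stand-alone restatement of that rule, assuming an illustrative START_TXN_ID constant and sample timestamps chosen only to exercise both branches:

#include <cassert>
#include <cstdint>
#include <cstdio>

constexpr uint64_t START_TXN_ID = 1ULL << 63; // ids at or above this belong to uncommitted transactions

struct EntryVersion {
    uint64_t timestamp; // transaction id while uncommitted, commit timestamp afterwards
};

bool wwConflict(uint64_t myTxnID, uint64_t myStartTS, const EntryVersion& e) {
    bool writtenByOtherLiveTxn = e.timestamp >= START_TXN_ID && e.timestamp != myTxnID;
    bool committedAfterMyStart = e.timestamp < START_TXN_ID && e.timestamp > myStartTS;
    return writtenByOtherLiveTxn || committedAfterMyStart;
}

int main() {
    uint64_t myID = START_TXN_ID + 7, myStart = 100;
    assert(wwConflict(myID, myStart, {START_TXN_ID + 8}));  // another live writer touched it
    assert(wwConflict(myID, myStart, {150}));               // committed after this txn started
    assert(!wwConflict(myID, myStart, {START_TXN_ID + 7})); // this txn's own write
    assert(!wwConflict(myID, myStart, {90}));               // committed before this txn started
    std::puts("write-write rules hold");
}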
return entryPtr->getPrev(); +} + +void CatalogSet::emplaceNoLock(std::unique_ptr entry) { + if (entries.contains(entry->getName())) { + entry->setPrev(std::move(entries.at(entry->getName()))); + entries.erase(entry->getName()); + } + entries.emplace(entry->getName(), std::move(entry)); +} + +void CatalogSet::eraseNoLock(const std::string& name) { + entries.erase(name); +} + +std::unique_ptr CatalogSet::createDummyEntryNoLock(std::string name, oid_t oid) { + return std::make_unique(std::move(name), oid); +} + +CatalogEntry* CatalogSet::traverseVersionChainsForTransactionNoLock(const Transaction* transaction, + CatalogEntry* currentEntry) { + while (currentEntry) { + if (currentEntry->getTimestamp() == transaction->getID()) { + // This entry is created by the current transaction. + break; + } + if (currentEntry->getTimestamp() <= transaction->getStartTS()) { + // This entry was committed before the current transaction starts. + break; + } + currentEntry = currentEntry->getPrev(); + } + return currentEntry; +} + +CatalogEntry* CatalogSet::getCommittedEntryNoLock(CatalogEntry* entry) { + while (entry) { + if (entry->getTimestamp() < Transaction::START_TRANSACTION_ID) { + break; + } + entry = entry->getPrev(); + } + return entry; +} + +void CatalogSet::dropEntry(Transaction* transaction, const std::string& name, oid_t oid) { + CatalogEntry* entryPtr = nullptr; + { + std::unique_lock lck{mtx}; + entryPtr = dropEntryNoLock(transaction, name, oid); + } + KU_ASSERT(entryPtr); + if (transaction->shouldAppendToUndoBuffer()) { + transaction->pushCreateDropCatalogEntry(*this, *entryPtr, isInternal()); + } +} + +CatalogEntry* CatalogSet::dropEntryNoLock(const Transaction* transaction, const std::string& name, + oid_t oid) { + // LCOV_EXCL_START + validateExistNoLock(transaction, name); + // LCOV_EXCL_STOP + auto tombstone = createDummyEntryNoLock(name, oid); + tombstone->setTimestamp(transaction->getID()); + const auto tombstonePtr = tombstone.get(); + emplaceNoLock(std::move(tombstone)); + return tombstonePtr->getPrev(); +} + +void CatalogSet::alterTableEntry(Transaction* transaction, + const binder::BoundAlterInfo& alterInfo) { + std::unique_lock lck{mtx}; + // LCOV_EXCL_START + validateExistNoLock(transaction, alterInfo.tableName); + // LCOV_EXCL_STOP + auto entry = getEntryNoLock(transaction, alterInfo.tableName); + KU_ASSERT(entry->getType() == CatalogEntryType::NODE_TABLE_ENTRY || + entry->getType() == CatalogEntryType::REL_GROUP_ENTRY); + const auto tableEntry = entry->ptrCast(); + auto newEntry = tableEntry->alter(transaction->getID(), alterInfo, this); + switch (alterInfo.alterType) { + case AlterType::RENAME: { + // We treat rename table as drop and create. 
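That comment is the heart of the RENAME branch that follows: because entries are keyed by name, changing the name means writing a tombstone under the old key and a fresh entry under the new key, with both steps pushed to the undo buffer. A reduced sketch of just the key handling, using illustrative types rather than the lbug classes (the real code additionally chains prior versions behind each key):

#include <iostream>
#include <map>
#include <memory>
#include <string>

struct Versioned {
    std::string name;
    bool deleted;
};

using CatalogMap = std::map<std::string, std::unique_ptr<Versioned>>;

void renameEntry(CatalogMap& entries, const std::string& oldName, const std::string& newName) {
    // 1) tombstone the old key so readers of the old name now see "deleted"
    entries[oldName] = std::make_unique<Versioned>(Versioned{oldName, true});
    // 2) create a live entry under the new key
    entries[newName] = std::make_unique<Versioned>(Versioned{newName, false});
}

int main() {
    CatalogMap entries;
    entries["person"] = std::make_unique<Versioned>(Versioned{"person", false});
    renameEntry(entries, "person", "user");
    for (const auto& [name, entry] : entries)
        std::cout << name << (entry->deleted ? " (tombstone)\n" : " (live)\n");
}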
+ dropEntryNoLock(transaction, alterInfo.tableName, entry->getOID()); + auto createdEntry = createEntryNoLock(transaction, std::move(newEntry)); + if (transaction->shouldAppendToUndoBuffer()) { + transaction->pushAlterCatalogEntry(*this, *entry, alterInfo); + transaction->pushCreateDropCatalogEntry(*this, *createdEntry, isInternal(), + true /* skipLoggingToWAL */); + } + } break; + case AlterType::COMMENT: + case AlterType::ADD_PROPERTY: + case AlterType::DROP_PROPERTY: + case AlterType::RENAME_PROPERTY: + case AlterType::ADD_FROM_TO_CONNECTION: + case AlterType::DROP_FROM_TO_CONNECTION: { + emplaceNoLock(std::move(newEntry)); + if (transaction->shouldAppendToUndoBuffer()) { + transaction->pushAlterCatalogEntry(*this, *entry, alterInfo); + } + } break; + default: { + KU_UNREACHABLE; + } + } +} + +CatalogEntrySet CatalogSet::getEntries(const Transaction* transaction) { + CatalogEntrySet result; + std::shared_lock lck{mtx}; + for (auto& [name, entry] : entries) { + auto currentEntry = traverseVersionChainsForTransactionNoLock(transaction, entry.get()); + if (currentEntry->isDeleted()) { + continue; + } + result.emplace(name, currentEntry); + } + return result; +} + +CatalogEntry* CatalogSet::getEntryOfOID(const Transaction* transaction, oid_t oid) { + std::shared_lock lck{mtx}; + for (auto& [_, entry] : entries) { + if (entry->getOID() != oid) { + continue; + } + const auto currentEntry = + traverseVersionChainsForTransactionNoLock(transaction, entry.get()); + if (currentEntry->isDeleted()) { + continue; + } + return currentEntry; + } + return nullptr; +} + +void CatalogSet::serialize(Serializer serializer) const { + std::vector entriesToSerialize; + for (auto& [_, entry] : entries) { + switch (entry->getType()) { + case CatalogEntryType::SCALAR_FUNCTION_ENTRY: + case CatalogEntryType::REWRITE_FUNCTION_ENTRY: + case CatalogEntryType::AGGREGATE_FUNCTION_ENTRY: + case CatalogEntryType::COPY_FUNCTION_ENTRY: + case CatalogEntryType::TABLE_FUNCTION_ENTRY: + case CatalogEntryType::STANDALONE_TABLE_FUNCTION_ENTRY: + continue; + default: { + auto committedEntry = getCommittedEntryNoLock(entry.get()); + if (committedEntry && !committedEntry->isDeleted()) { + entriesToSerialize.push_back(committedEntry); + } + } + } + } + serializer.writeDebuggingInfo("nextOID"); + serializer.serializeValue(nextOID); + serializer.writeDebuggingInfo("numEntries"); + const uint64_t numEntriesToSerialize = entriesToSerialize.size(); + serializer.serializeValue(numEntriesToSerialize); + for (const auto entry : entriesToSerialize) { + entry->serialize(serializer); + } +} + +std::unique_ptr CatalogSet::deserialize(Deserializer& deserializer) { + std::string debuggingInfo; + auto catalogSet = std::make_unique(); + deserializer.validateDebuggingInfo(debuggingInfo, "nextOID"); + deserializer.deserializeValue(catalogSet->nextOID); + uint64_t numEntries = 0; + deserializer.validateDebuggingInfo(debuggingInfo, "numEntries"); + deserializer.deserializeValue(numEntries); + for (uint64_t i = 0; i < numEntries; i++) { + auto entry = CatalogEntry::deserialize(deserializer); + if (entry != nullptr) { + catalogSet->emplaceNoLock(std::move(entry)); + } + } + return catalogSet; +} + +// Ideally we should not trigger the following check. Instead, we should throw more informative +// error message at catalog level. 
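getEntries and getEntryOfOID above apply the same visibility walk as traverseVersionChainsForTransactionNoLock: starting at the newest version, take the first one that the transaction wrote itself or that committed before the transaction started. A minimal model of that rule with stand-in types and constants:

#include <cstdint>
#include <iostream>

constexpr uint64_t START_TXN_ID = 1ULL << 63;

struct Version {
    uint64_t timestamp; // transaction id while uncommitted, commit timestamp afterwards
    bool deleted;
    const Version* prev;
};

const Version* visibleVersion(const Version* head, uint64_t txnID, uint64_t startTS) {
    for (const Version* v = head; v != nullptr; v = v->prev) {
        if (v->timestamp == txnID) return v;   // this transaction's own uncommitted write
        if (v->timestamp <= startTS) return v; // committed before this transaction started
    }
    return nullptr;
}

int main() {
    Version v1{100, false, nullptr};         // committed at timestamp 100
    Version v2{START_TXN_ID + 5, true, &v1}; // uncommitted drop by transaction 5
    // A reader that started at timestamp 120 and is not transaction 5 still sees v1.
    const Version* v = visibleVersion(&v2, START_TXN_ID + 9, 120);
    std::cout << (v && !v->deleted ? "visible\n" : "hidden\n");
}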
+void CatalogSet::validateExistNoLock(const Transaction* transaction, + const std::string& name) const { + if (!containsEntryNoLock(transaction, name)) { + throw CatalogException(stringFormat("{} does not exist in catalog.", name)); + } +} + +void CatalogSet::validateNotExistNoLock(const Transaction* transaction, + const std::string& name) const { + if (containsEntryNoLock(transaction, name)) { + throw CatalogException(stringFormat("{} already exists in catalog.", name)); + } +} + +} // namespace catalog +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/property_definition_collection.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/property_definition_collection.cpp new file mode 100644 index 0000000000..fcf0ddbb66 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/catalog/property_definition_collection.cpp @@ -0,0 +1,147 @@ +#include "catalog/property_definition_collection.h" + +#include +#include + +#include "common/serializer/deserializer.h" +#include "common/serializer/serializer.h" +#include "common/string_utils.h" + +using namespace lbug::binder; +using namespace lbug::common; + +namespace lbug { +namespace catalog { + +std::vector PropertyDefinitionCollection::getDefinitions() const { + std::vector propertyDefinitions; + for (auto i = 0u; i < nextPropertyID; i++) { + if (definitions.contains(i)) { + propertyDefinitions.push_back(definitions.at(i).copy()); + } + } + return propertyDefinitions; +} + +const PropertyDefinition& PropertyDefinitionCollection::getDefinition( + const std::string& name) const { + return getDefinition(getPropertyID(name)); +} + +const PropertyDefinition& PropertyDefinitionCollection::getDefinition( + property_id_t propertyID) const { + KU_ASSERT(definitions.contains(propertyID)); + return definitions.at(propertyID); +} + +column_id_t PropertyDefinitionCollection::getColumnID(const std::string& name) const { + return getColumnID(getPropertyID(name)); +} + +column_id_t PropertyDefinitionCollection::getColumnID(property_id_t propertyID) const { + KU_ASSERT(columnIDs.contains(propertyID)); + return columnIDs.at(propertyID); +} + +void PropertyDefinitionCollection::vacuumColumnIDs(column_id_t nextColumnID) { + this->nextColumnID = nextColumnID; + columnIDs.clear(); + for (auto& [propertyID, definition] : definitions) { + columnIDs.emplace(propertyID, this->nextColumnID++); + } +} + +void PropertyDefinitionCollection::add(const PropertyDefinition& definition) { + auto propertyID = nextPropertyID++; + columnIDs.emplace(propertyID, nextColumnID++); + definitions.emplace(propertyID, definition.copy()); + nameToPropertyIDMap.emplace(definition.getName(), propertyID); +} + +void PropertyDefinitionCollection::drop(const std::string& name) { + KU_ASSERT(contains(name)); + auto propertyID = nameToPropertyIDMap.at(name); + definitions.erase(propertyID); + columnIDs.erase(propertyID); + nameToPropertyIDMap.erase(name); +} + +void PropertyDefinitionCollection::rename(const std::string& name, const std::string& newName) { + KU_ASSERT(contains(name)); + auto idx = nameToPropertyIDMap.at(name); + definitions[idx].rename(newName); + nameToPropertyIDMap.erase(name); + nameToPropertyIDMap.insert({newName, idx}); +} + +column_id_t PropertyDefinitionCollection::getMaxColumnID() const { + column_id_t maxID = 0; + for (auto [_, id] : columnIDs) { + if (id > maxID) { + maxID = id; + } + } + return maxID; +} + +property_id_t PropertyDefinitionCollection::getPropertyID(const std::string& name) const { + KU_ASSERT(contains(name)); + return 
nameToPropertyIDMap.at(name); +} + +std::string PropertyDefinitionCollection::toCypher() const { + std::stringstream ss; + for (auto& [_, def] : definitions) { + auto& dataType = def.getType(); + // Avoid exporting internal ID + if (dataType.getPhysicalType() == PhysicalTypeID::INTERNAL_ID) { + continue; + } + auto typeStr = dataType.toString(); + StringUtils::replaceAll(typeStr, ":", " "); + if (typeStr.find("MAP") != std::string::npos) { + StringUtils::replaceAll(typeStr, " ", ","); + } + ss << "`" << def.getName() << "`" << " " << typeStr << ","; + } + return ss.str(); +} + +void PropertyDefinitionCollection::serialize(Serializer& serializer) const { + serializer.writeDebuggingInfo("nextColumnID"); + serializer.serializeValue(nextColumnID); + serializer.writeDebuggingInfo("nextPropertyID"); + serializer.serializeValue(nextPropertyID); + serializer.writeDebuggingInfo("definitions"); + serializer.serializeMap(definitions); + serializer.writeDebuggingInfo("columnIDs"); + serializer.serializeUnorderedMap(columnIDs); +} + +PropertyDefinitionCollection PropertyDefinitionCollection::deserialize(Deserializer& deserializer) { + std::string debuggingInfo; + column_id_t nextColumnID = 0; + deserializer.validateDebuggingInfo(debuggingInfo, "nextColumnID"); + deserializer.deserializeValue(nextColumnID); + property_id_t nextPropertyID = 0; + deserializer.validateDebuggingInfo(debuggingInfo, "nextPropertyID"); + deserializer.deserializeValue(nextPropertyID); + std::map definitions; + deserializer.validateDebuggingInfo(debuggingInfo, "definitions"); + deserializer.deserializeMap(definitions); + std::unordered_map columnIDs; + deserializer.validateDebuggingInfo(debuggingInfo, "columnIDs"); + deserializer.deserializeUnorderedMap(columnIDs); + auto collection = PropertyDefinitionCollection(); + for (auto& [propertyID, definition] : definitions) { + collection.nameToPropertyIDMap.insert({definition.getName(), propertyID}); + } + collection.nextColumnID = nextColumnID; + collection.nextPropertyID = nextPropertyID; + collection.definitions = std::move(definitions); + collection.columnIDs = std::move(columnIDs); + return collection; +} + +} // namespace catalog +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/common/CMakeLists.txt new file mode 100644 index 0000000000..a9e1c5d3bf --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/CMakeLists.txt @@ -0,0 +1,37 @@ +add_subdirectory(arrow) +add_subdirectory(copier_config) +add_subdirectory(data_chunk) +add_subdirectory(enums) +add_subdirectory(exception) +add_subdirectory(serializer) +add_subdirectory(signal) +add_subdirectory(task_system) +add_subdirectory(types) +add_subdirectory(vector) +add_subdirectory(file_system) + +add_library(lbug_common + OBJECT + case_insensitive_map.cpp + checksum.cpp + constants.cpp + database_lifecycle_manager.cpp + expression_type.cpp + in_mem_overflow_buffer.cpp + mask.cpp + md5.cpp + metric.cpp + null_mask.cpp + profiler.cpp + random_engine.cpp + roaring_mask.cpp + sha256.cpp + string_utils.cpp + system_message.cpp + type_utils.cpp + utils.cpp + windows_utils.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/arrow/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/common/arrow/CMakeLists.txt new file mode 100644 index 0000000000..e9659d26a4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/arrow/CMakeLists.txt @@ -0,0 +1,11 @@ 
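Looking back at the PropertyDefinitionCollection methods above: each property carries a stable property id, a column id that can develop gaps once properties are dropped, and an entry in a name index, and vacuumColumnIDs re-packs only the column ids. A compressed sketch of that bookkeeping under assumed names:

#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <unordered_map>

struct Props {
    uint32_t nextPropertyID = 0, nextColumnID = 0;
    std::map<uint32_t, std::string> names;             // propertyID -> name
    std::unordered_map<uint32_t, uint32_t> columnIDs;  // propertyID -> columnID
    std::unordered_map<std::string, uint32_t> byName;  // name -> propertyID

    void add(const std::string& name) {
        uint32_t pid = nextPropertyID++;
        names[pid] = name;
        columnIDs[pid] = nextColumnID++;
        byName[name] = pid;
    }
    void drop(const std::string& name) {
        uint32_t pid = byName.at(name);
        names.erase(pid);
        columnIDs.erase(pid);
        byName.erase(name);
    }
    void vacuum(uint32_t startColumnID) {
        nextColumnID = startColumnID;
        columnIDs.clear();
        for (const auto& [pid, name] : names) columnIDs[pid] = nextColumnID++;
    }
};

int main() {
    Props p;
    p.add("a"); p.add("b"); p.add("c");
    p.drop("b");
    p.vacuum(0); // "a" -> column 0, "c" -> column 1; property ids stay 0 and 2
    for (const auto& [pid, name] : p.names)
        std::cout << name << ": prop " << pid << ", col " << p.columnIDs[pid] << "\n";
}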
+add_library(lbug_common_arrow + OBJECT + arrow_array_scan.cpp + arrow_converter.cpp + arrow_null_mask_tree.cpp + arrow_row_batch.cpp + arrow_type.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/arrow/arrow_array_scan.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/arrow/arrow_array_scan.cpp new file mode 100644 index 0000000000..ae23c27074 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/arrow/arrow_array_scan.cpp @@ -0,0 +1,599 @@ +#include "common/arrow/arrow_converter.h" +#include "common/exception/runtime.h" +#include "common/types/int128_t.h" +#include "common/types/interval_t.h" +#include "common/types/types.h" +#include "common/vector/value_vector.h" + +namespace lbug { +namespace common { + +// scans are based on data specification found here +// https://arrow.apache.org/docs/format/Columnar.html + +// all offsets are measured by value, not physical size + +template +static void scanArrowArrayFixedSizePrimitive(const ArrowArray* array, ValueVector& outputVector, + ArrowNullMaskTree* mask, uint64_t srcOffset, uint64_t dstOffset, uint64_t count) { + auto arrayBuffer = (const uint8_t*)array->buffers[1]; + mask->copyToValueVector(&outputVector, dstOffset, count); + memcpy(outputVector.getData() + dstOffset * outputVector.getNumBytesPerValue(), + arrayBuffer + srcOffset * sizeof(T), count * sizeof(T)); +} + +template +static void scanArrowArrayFixedSizePrimitiveAndCastTo(const ArrowArray* array, + ValueVector& outputVector, ArrowNullMaskTree* mask, uint64_t srcOffset, uint64_t dstOffset, + uint64_t count) { + auto arrayBuffer = (const SRC*)array->buffers[1]; + mask->copyToValueVector(&outputVector, dstOffset, count); + for (uint64_t i = 0; i < count; i++) { + if (!mask->isNull(i)) { + auto curValue = arrayBuffer[i + srcOffset]; + outputVector.setValue(i + dstOffset, (DST)curValue); + } + } +} + +template<> +void scanArrowArrayFixedSizePrimitive(const ArrowArray* array, ValueVector& outputVector, + ArrowNullMaskTree* mask, uint64_t srcOffset, uint64_t dstOffset, uint64_t count) { + auto arrayBuffer = (const uint8_t*)array->buffers[1]; + mask->copyToValueVector(&outputVector, dstOffset, count); + for (uint64_t i = 0; i < count; i++) { + outputVector.setValue(i + dstOffset, + NullMask::isNull((const uint64_t*)arrayBuffer, i + srcOffset)); + } +} + +static void scanArrowArrayDurationScaledUp(const ArrowArray* array, ValueVector& outputVector, + ArrowNullMaskTree* mask, int64_t scaleFactor, uint64_t srcOffset, uint64_t dstOffset, + uint64_t count) { + auto arrayBuffer = ((const int64_t*)array->buffers[1]) + srcOffset; + mask->copyToValueVector(&outputVector, dstOffset, count); + for (uint64_t i = 0; i < count; i++) { + if (!mask->isNull(i)) { + auto curValue = arrayBuffer[i]; + outputVector.setValue(i + dstOffset, + interval_t(0, 0, curValue * scaleFactor)); + } + } +} + +static void scanArrowArrayDurationScaledDown(const ArrowArray* array, ValueVector& outputVector, + ArrowNullMaskTree* mask, int64_t scaleFactor, uint64_t srcOffset, uint64_t dstOffset, + uint64_t count) { + auto arrayBuffer = ((const int64_t*)array->buffers[1]) + srcOffset; + mask->copyToValueVector(&outputVector, dstOffset, count); + for (uint64_t i = 0; i < count; i++) { + if (!mask->isNull(i)) { + auto curValue = arrayBuffer[i]; + outputVector.setValue(i + dstOffset, + interval_t(0, 0, curValue / scaleFactor)); + } + } +} + +static void scanArrowArrayMonthInterval(const ArrowArray* array, ValueVector& outputVector, + 
ArrowNullMaskTree* mask, uint64_t srcOffset, uint64_t dstOffset, uint64_t count) { + auto arrayBuffer = ((const int32_t*)array->buffers[1]) + srcOffset; + mask->copyToValueVector(&outputVector, dstOffset, count); + for (uint64_t i = 0; i < count; i++) { + if (!mask->isNull(i)) { + auto curValue = arrayBuffer[i]; + outputVector.setValue(i + dstOffset, interval_t(curValue, 0, 0)); + } + } +} + +static void scanArrowArrayDayTimeInterval(const ArrowArray* array, ValueVector& outputVector, + ArrowNullMaskTree* mask, uint64_t srcOffset, uint64_t dstOffset, uint64_t count) { + auto arrayBuffer = ((const int64_t*)array->buffers[1]) + srcOffset; + mask->copyToValueVector(&outputVector, dstOffset, count); + for (uint64_t i = 0; i < count; i++) { + if (!mask->isNull(i)) { + int64_t curValue = arrayBuffer[i]; + int32_t day = curValue; + int64_t micros = (curValue >> (4 * sizeof(int64_t))) * 1000; + // arrow stores ms, while we store us + outputVector.setValue(i + dstOffset, interval_t(0, day, micros)); + } + } +} + +static void scanArrowArrayMonthDayNanoInterval(const ArrowArray* array, ValueVector& outputVector, + ArrowNullMaskTree* mask, uint64_t srcOffset, uint64_t dstOffset, uint64_t count) { + auto arrayBuffer = + (const int64_t*)((const uint8_t*)array->buffers[1] + srcOffset * 16); // 16 bits per value + mask->copyToValueVector(&outputVector, dstOffset, count); + for (uint64_t i = 0; i < count; i++) { + if (!mask->isNull(i)) { + int64_t curValue = arrayBuffer[2 * i]; + int32_t month = curValue; + int32_t day = curValue >> (4 * sizeof(int64_t)); + int64_t micros = arrayBuffer[2 * i + 1] / 1000; + outputVector.setValue(i + dstOffset, interval_t(month, day, micros)); + } + } +} + +template +static void scanArrowArrayBLOB(const ArrowArray* array, ValueVector& outputVector, + ArrowNullMaskTree* mask, uint64_t srcOffset, uint64_t dstOffset, uint64_t count) { + auto offsets = ((const offsetsT*)array->buffers[1]) + srcOffset; + auto arrayBuffer = (const uint8_t*)array->buffers[2]; + mask->copyToValueVector(&outputVector, dstOffset, count); + for (uint64_t i = 0; i < count; i++) { + if (!mask->isNull(i)) { + auto curOffset = offsets[i], nextOffset = offsets[i + 1]; + const uint8_t* data = arrayBuffer + curOffset; + auto length = nextOffset - curOffset; + BlobVector::addBlob(&outputVector, i + dstOffset, data, length); + } + } +} + +static void scanArrowArrayBLOBView(const ArrowArray* array, ValueVector& outputVector, + ArrowNullMaskTree* mask, uint64_t srcOffset, uint64_t dstOffset, uint64_t count) { + auto arrayBuffer = (const uint8_t*)(array->buffers[1]); + auto valueBuffs = (const uint8_t**)(array->buffers + 2); + // BLOB value buffers begin from index 2 onwards + mask->copyToValueVector(&outputVector, dstOffset, count); + for (uint64_t i = 0; i < count; i++) { + if (!mask->isNull(i)) { + auto curView = (const int32_t*)(arrayBuffer + (i + srcOffset) * 16); + // view structures are 16 bytes long + auto viewLength = curView[0]; + if (viewLength <= 12) { + BlobVector::addBlob(&outputVector, i + dstOffset, (uint8_t*)(curView + 1), + viewLength); + } else { + auto bufIndex = curView[2]; + auto offset = curView[3]; + BlobVector::addBlob(&outputVector, i + dstOffset, valueBuffs[bufIndex] + offset, + viewLength); + } + } + } +} + +static void scanArrowArrayFixedBLOB(const ArrowArray* array, ValueVector& outputVector, + ArrowNullMaskTree* mask, int64_t BLOBsize, uint64_t srcOffset, uint64_t dstOffset, + uint64_t count) { + auto arrayBuffer = ((const uint8_t*)array->buffers[1]) + srcOffset * BLOBsize; + 
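For orientation, scanArrowArrayBLOB above consumes Arrow's variable-size binary layout: buffers[1] holds n+1 offsets and buffers[2] the concatenated bytes, so value i is the slice [offsets[i], offsets[i+1]). A hand-built example of that layout, with plain vectors standing in for the real ArrowArray buffers:

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

int main() {
    // three values: "graph", "", "wasm"
    std::vector<int32_t> offsets = {0, 5, 5, 9};
    std::string data = "graphwasm";
    for (size_t i = 0; i + 1 < offsets.size(); i++) {
        int32_t begin = offsets[i], end = offsets[i + 1];
        std::cout << i << ": \"" << data.substr(begin, end - begin) << "\"\n";
    }
}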
mask->copyToValueVector(&outputVector, dstOffset, count); + for (uint64_t i = 0; i < count; i++) { + if (!mask->isNull(i)) { + BlobVector::addBlob(&outputVector, i + dstOffset, arrayBuffer + i * BLOBsize, BLOBsize); + } + } +} + +template +static void scanArrowArrayList(const ArrowSchema* schema, const ArrowArray* array, + ValueVector& outputVector, ArrowNullMaskTree* mask, uint64_t srcOffset, uint64_t dstOffset, + uint64_t count) { + auto offsets = ((const offsetsT*)array->buffers[1]) + srcOffset; + mask->copyToValueVector(&outputVector, dstOffset, count); + uint64_t auxDstPosition = 0; + for (uint64_t i = 0; i < count; i++) { + auto curOffset = offsets[i], nextOffset = offsets[i + 1]; + // don't check for validity, since we still need to update the offsets + auto newEntry = ListVector::addList(&outputVector, nextOffset - curOffset); + outputVector.setValue(i + dstOffset, newEntry); + if (i == 0) { + auxDstPosition = newEntry.offset; + } + } + ValueVector* auxiliaryBuffer = ListVector::getDataVector(&outputVector); + ArrowConverter::fromArrowArray(schema->children[0], array->children[0], *auxiliaryBuffer, + mask->getChild(0), offsets[0] + array->children[0]->offset, auxDstPosition, + offsets[count] - offsets[0]); +} + +template +static void scanArrowArrayListView(const ArrowSchema* schema, const ArrowArray* array, + ValueVector& outputVector, ArrowNullMaskTree* mask, uint64_t srcOffset, uint64_t dstOffset, + uint64_t count) { + auto offsets = ((const offsetsT*)array->buffers[1]) + srcOffset; + auto sizes = ((const offsetsT*)array->buffers[2]) + srcOffset; + mask->copyToValueVector(&outputVector, dstOffset, count); + ValueVector* auxiliaryBuffer = ListVector::getDataVector(&outputVector); + for (uint64_t i = 0; i < count; i++) { + if (!mask->isNull(i)) { + auto curOffset = offsets[i], size = sizes[i]; + auto newEntry = ListVector::addList(&outputVector, size); + outputVector.setValue(i + dstOffset, newEntry); + ArrowNullMaskTree childTree(schema->children[0], array->children[0], srcOffset, count); + // make our own child here. 
precomputing through the mask tree is too complicated + ArrowConverter::fromArrowArray(schema->children[0], array->children[0], + *auxiliaryBuffer, &childTree, curOffset, newEntry.offset, newEntry.size); + } + } +} + +static void scanArrowArrayFixedList(const ArrowSchema* schema, const ArrowArray* array, + ValueVector& outputVector, ArrowNullMaskTree* mask, uint64_t srcOffset, uint64_t dstOffset, + uint64_t count) { + mask->copyToValueVector(&outputVector, dstOffset, count); + auto numElements = ArrayType::getNumElements(outputVector.dataType); + for (auto i = 0u; i < count; ++i) { + auto newEntry = ListVector::addList(&outputVector, numElements); + outputVector.setValue(i + dstOffset, newEntry); + } + auto auxiliaryBuffer = ListVector::getDataVector(&outputVector); + ArrowConverter::fromArrowArray(schema->children[0], array->children[0], *auxiliaryBuffer, + mask->getChild(0), srcOffset * numElements + array->children[0]->offset, + dstOffset * numElements, count * numElements); +} + +static void scanArrowArrayStruct(const ArrowSchema* schema, const ArrowArray* array, + ValueVector& outputVector, ArrowNullMaskTree* mask, uint64_t srcOffset, uint64_t dstOffset, + uint64_t count) { + mask->copyToValueVector(&outputVector, dstOffset, count); + for (uint64_t i = 0; i < count; i++) { + if (!mask->isNull(i)) { + outputVector.setValue(i + dstOffset, + i + dstOffset); // struct_entry_t doesn't work for some reason + } + } + for (int64_t j = 0; j < schema->n_children; j++) { + ArrowConverter::fromArrowArray(schema->children[j], array->children[j], + *StructVector::getFieldVector(&outputVector, j).get(), mask->getChild(j), + srcOffset + array->children[j]->offset, dstOffset, count); + } +} + +static void scanArrowArrayDenseUnion(const ArrowSchema* schema, const ArrowArray* array, + ValueVector& outputVector, ArrowNullMaskTree* mask, uint64_t srcOffset, uint64_t dstOffset, + uint64_t count) { + auto types = ((const uint8_t*)array->buffers[0]) + srcOffset; + auto dstTypes = (uint16_t*)UnionVector::getTagVector(&outputVector)->getData(); + auto offsets = ((const int32_t*)array->buffers[1]) + srcOffset; + mask->copyToValueVector(&outputVector, dstOffset, count); + std::vector firstIncident(array->n_children, INT32_MAX); + for (auto i = 0u; i < count; i++) { + auto curType = types[i]; + auto curOffset = offsets[i]; + if (curOffset < firstIncident[curType]) { + firstIncident[curType] = curOffset; + } + if (!mask->isNull(i)) { + dstTypes[i] = curType; + auto childOffset = + mask->getChild(curType)->offsetBy(curOffset - firstIncident[curType]); + ArrowConverter::fromArrowArray(schema->children[curType], array->children[curType], + *UnionVector::getValVector(&outputVector, curType), &childOffset, + curOffset + array->children[curType]->offset, i + dstOffset, 1); + // may be inefficient, since we're only scanning a single value + } + } +} + +static void scanArrowArraySparseUnion(const ArrowSchema* schema, const ArrowArray* array, + ValueVector& outputVector, ArrowNullMaskTree* mask, uint64_t srcOffset, uint64_t dstOffset, + uint64_t count) { + auto types = ((const uint8_t*)array->buffers[0]) + srcOffset; + auto dstTypes = (uint16_t*)UnionVector::getTagVector(&outputVector)->getData(); + mask->copyToValueVector(&outputVector, dstOffset, count); + for (uint64_t i = 0; i < count; i++) { + if (!mask->isNull(i)) { + dstTypes[i] = types[i]; + } + } + // it is specified that values that aren't selected in the type buffer + // must also be semantically correct. this is why this scanning works. 
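A hand-rolled illustration of the layout difference the two union scanners rely on: a dense union carries a per-row offset into the selected child, while a sparse union stores only type ids and keeps every child as long as the union itself, which is why even unselected slots must hold well-formed values. The buffers below are made up for the example and are not the Arrow C structs:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
    // children: type 0 = ints, type 1 = doubles
    std::vector<int64_t> ints = {10, 20};
    std::vector<double> dbls = {0.5};

    // dense: type ids plus offsets into the chosen child
    std::vector<int8_t> denseTypes = {0, 1, 0};
    std::vector<int32_t> denseOffsets = {0, 0, 1};
    for (size_t i = 0; i < denseTypes.size(); i++) {
        if (denseTypes[i] == 0) std::cout << ints[denseOffsets[i]] << "\n";
        else std::cout << dbls[denseOffsets[i]] << "\n";
    }

    // sparse: no offsets; every child is as long as the union itself
    std::vector<int64_t> sparseInts = {10, 0 /* unselected but still valid */, 20};
    std::vector<double> sparseDbls = {0.0, 0.5, 0.0};
    std::vector<int8_t> sparseTypes = {0, 1, 0};
    for (size_t i = 0; i < sparseTypes.size(); i++) {
        if (sparseTypes[i] == 0) std::cout << sparseInts[i] << "\n";
        else std::cout << sparseDbls[i] << "\n";
    }
}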
+ // however, there is possibly room for optimization here. + // eg. nulling out unselected children + for (int8_t i = 0; i < array->n_children; i++) { + ArrowConverter::fromArrowArray(schema->children[i], array->children[i], + *UnionVector::getValVector(&outputVector, i), mask->getChild(i), + srcOffset + array->children[i]->offset, dstOffset, count); + } +} + +template +static void scanArrowArrayDictionaryEncoded(const ArrowSchema* schema, const ArrowArray* array, + ValueVector& outputVector, ArrowNullMaskTree* mask, uint64_t srcOffset, uint64_t dstOffset, + uint64_t count) { + + auto values = ((const offsetsT*)array->buffers[1]) + srcOffset; + mask->copyToValueVector(&outputVector, dstOffset, count); + for (uint64_t i = 0; i < count; i++) { + if (!mask->isNull(i)) { + auto dictOffseted = mask->getDictionary()->offsetBy(values[i]); + ArrowConverter::fromArrowArray(schema->dictionary, array->dictionary, outputVector, + &dictOffseted, values[i] + array->dictionary->offset, i + dstOffset, + 1); // possibly inefficient? + } + } +} + +static void scanArrowArrayRunEndEncoded(const ArrowSchema* schema, const ArrowArray* array, + ValueVector& outputVector, ArrowNullMaskTree* mask, uint64_t srcOffset, uint64_t dstOffset, + uint64_t count) { + + const ArrowArray* runEndArray = array->children[0]; + auto runEndBuffer = (const uint32_t*)runEndArray->buffers[1]; + + // binary search run end corresponding to srcOffset + auto runEndIdx = runEndArray->offset; + { + auto L = runEndArray->offset, H = L + runEndArray->length; + while (H >= L) { + auto M = (H + L) >> 1; + if (runEndBuffer[M] < srcOffset) { + runEndIdx = M; + H = M - 1; + } else { + L = M + 1; + } + } + } + + for (uint64_t i = 0; i < count; i++) { + while (i + srcOffset >= runEndBuffer[runEndIdx + 1]) { + runEndIdx++; + } + auto valuesOffseted = mask->getChild(1)->offsetBy(runEndIdx); + ArrowConverter::fromArrowArray(schema->children[1], array->children[1], outputVector, + &valuesOffseted, runEndIdx, i + dstOffset, + 1); // there is optimization to be made here... 
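A self-contained model of the run-end decoding this loop performs row by row; the real scanner binary-searches the starting run once and then advances it incrementally, and runEnds and values below are example buffers only:

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Returns the index of the run covering logical position i:
// the first j with i < runEnds[j].
static size_t runIndex(const std::vector<uint32_t>& runEnds, uint64_t i) {
    size_t lo = 0, hi = runEnds.size();
    while (lo < hi) {
        size_t mid = (lo + hi) / 2;
        if (i < runEnds[mid]) hi = mid;
        else lo = mid + 1;
    }
    return lo;
}

int main() {
    // logical array of length 6: a a a b c c
    std::vector<uint32_t> runEnds = {3, 4, 6};
    std::vector<std::string> values = {"a", "b", "c"};
    for (uint64_t i = 0; i < 6; i++)
        std::cout << i << " -> " << values[runIndex(runEnds, i)] << "\n";
}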
+ } +} + +void ArrowConverter::fromArrowArray(const ArrowSchema* schema, const ArrowArray* array, + ValueVector& outputVector, ArrowNullMaskTree* mask, uint64_t srcOffset, uint64_t dstOffset, + uint64_t count) { + const auto arrowType = schema->format; + if (array->dictionary != nullptr) { + switch (arrowType[0]) { + case 'c': + return scanArrowArrayDictionaryEncoded(schema, array, outputVector, mask, + srcOffset, dstOffset, count); + case 'C': + return scanArrowArrayDictionaryEncoded(schema, array, outputVector, mask, + srcOffset, dstOffset, count); + case 's': + return scanArrowArrayDictionaryEncoded(schema, array, outputVector, mask, + srcOffset, dstOffset, count); + case 'S': + return scanArrowArrayDictionaryEncoded(schema, array, outputVector, mask, + srcOffset, dstOffset, count); + case 'i': + return scanArrowArrayDictionaryEncoded(schema, array, outputVector, mask, + srcOffset, dstOffset, count); + case 'I': + return scanArrowArrayDictionaryEncoded(schema, array, outputVector, mask, + srcOffset, dstOffset, count); + case 'l': + return scanArrowArrayDictionaryEncoded(schema, array, outputVector, mask, + srcOffset, dstOffset, count); + case 'L': + return scanArrowArrayDictionaryEncoded(schema, array, outputVector, mask, + srcOffset, dstOffset, count); + default: + throw RuntimeException("Invalid Index Type: " + std::string(arrowType)); + } + } + switch (arrowType[0]) { + case 'n': + // NULL + outputVector.setAllNull(); + return; + case 'b': + // BOOL + return scanArrowArrayFixedSizePrimitive(array, outputVector, mask, srcOffset, + dstOffset, count); + case 'c': + // INT8 + return scanArrowArrayFixedSizePrimitive(array, outputVector, mask, srcOffset, + dstOffset, count); + case 'C': + // UINT8 + return scanArrowArrayFixedSizePrimitive(array, outputVector, mask, srcOffset, + dstOffset, count); + case 's': + // INT16 + return scanArrowArrayFixedSizePrimitive(array, outputVector, mask, srcOffset, + dstOffset, count); + case 'S': + // UINT16 + return scanArrowArrayFixedSizePrimitive(array, outputVector, mask, srcOffset, + dstOffset, count); + case 'i': + // INT32 + return scanArrowArrayFixedSizePrimitive(array, outputVector, mask, srcOffset, + dstOffset, count); + case 'I': + // UINT32 + return scanArrowArrayFixedSizePrimitive(array, outputVector, mask, srcOffset, + dstOffset, count); + case 'l': + // INT64 + return scanArrowArrayFixedSizePrimitive(array, outputVector, mask, srcOffset, + dstOffset, count); + case 'L': + // UINT64 + return scanArrowArrayFixedSizePrimitive(array, outputVector, mask, srcOffset, + dstOffset, count); + case 'f': + // FLOAT + return scanArrowArrayFixedSizePrimitive(array, outputVector, mask, srcOffset, + dstOffset, count); + case 'g': + // DOUBLE + return scanArrowArrayFixedSizePrimitive(array, outputVector, mask, srcOffset, + dstOffset, count); + case 'z': + // BLOB + return scanArrowArrayBLOB(array, outputVector, mask, srcOffset, dstOffset, count); + case 'Z': + // LONG BLOB + return scanArrowArrayBLOB(array, outputVector, mask, srcOffset, dstOffset, count); + case 'u': + // STRING + return scanArrowArrayBLOB(array, outputVector, mask, srcOffset, dstOffset, count); + case 'U': + // LONG STRING + return scanArrowArrayBLOB(array, outputVector, mask, srcOffset, dstOffset, count); + case 'v': + switch (arrowType[1]) { + case 'z': + // BINARY VIEW + case 'u': + // STRING VIEW + return scanArrowArrayBLOBView(array, outputVector, mask, srcOffset, dstOffset, count); + default: + KU_UNREACHABLE; + } + case 'd': { + switch (outputVector.dataType.getPhysicalType()) { + 
case PhysicalTypeID::INT16: + return scanArrowArrayFixedSizePrimitiveAndCastTo(array, outputVector, + mask, srcOffset, dstOffset, count); + case PhysicalTypeID::INT32: + return scanArrowArrayFixedSizePrimitiveAndCastTo(array, outputVector, + mask, srcOffset, dstOffset, count); + case PhysicalTypeID::INT64: + return scanArrowArrayFixedSizePrimitiveAndCastTo(array, outputVector, + mask, srcOffset, dstOffset, count); + case PhysicalTypeID::INT128: + return scanArrowArrayFixedSizePrimitive(array, outputVector, mask, srcOffset, + dstOffset, count); + default: + KU_UNREACHABLE; + } + } + case 'w': + // FIXED BLOB + return scanArrowArrayFixedBLOB(array, outputVector, mask, std::stoi(arrowType + 2), + srcOffset, dstOffset, count); + case 't': + switch (arrowType[1]) { + case 'd': + // DATE + if (arrowType[2] == 'D') { + // days since unix epoch + return scanArrowArrayFixedSizePrimitive(array, outputVector, mask, + srcOffset, dstOffset, count); + } else { + // ms since unix epoch + return scanArrowArrayFixedSizePrimitive(array, outputVector, mask, + srcOffset, dstOffset, count); + } + case 't': + // TODO pure time type + KU_UNREACHABLE; + case 's': + // TIMESTAMP + return scanArrowArrayFixedSizePrimitive(array, outputVector, mask, srcOffset, + dstOffset, count); + case 'D': + // DURATION (LBUG INTERVAL) + switch (arrowType[2]) { + case 's': + // consider implement overflow checking here? + return scanArrowArrayDurationScaledUp(array, outputVector, mask, 1000000, srcOffset, + dstOffset, count); + case 'm': + return scanArrowArrayDurationScaledUp(array, outputVector, mask, 1000, srcOffset, + dstOffset, count); + case 'u': + return scanArrowArrayDurationScaledUp(array, outputVector, mask, 1, srcOffset, + dstOffset, count); + case 'n': + return scanArrowArrayDurationScaledDown(array, outputVector, mask, 1000, srcOffset, + dstOffset, count); + default: + KU_UNREACHABLE; + } + case 'i': + // INTERVAL + switch (arrowType[2]) { + case 'M': + return scanArrowArrayMonthInterval(array, outputVector, mask, srcOffset, dstOffset, + count); + case 'D': + return scanArrowArrayDayTimeInterval(array, outputVector, mask, srcOffset, + dstOffset, count); + case 'n': + return scanArrowArrayMonthDayNanoInterval(array, outputVector, mask, srcOffset, + dstOffset, count); + default: + KU_UNREACHABLE; + } + default: + KU_UNREACHABLE; + } + case '+': + switch (arrowType[1]) { + case 'r': + // RUN END ENCODED + return scanArrowArrayRunEndEncoded(schema, array, outputVector, mask, srcOffset, + dstOffset, count); + case 'l': + // LIST + return scanArrowArrayList(schema, array, outputVector, mask, srcOffset, + dstOffset, count); + case 'L': + // LONG LIST + return scanArrowArrayList(schema, array, outputVector, mask, srcOffset, + dstOffset, count); + case 'w': { + // ARRAY + RUNTIME_CHECK({ + auto arrowNumElements = std::stoul(arrowType + 3); + auto outputNumElements = ArrayType::getNumElements(outputVector.dataType); + KU_ASSERT(arrowNumElements == outputNumElements); + }); + return scanArrowArrayFixedList(schema, array, outputVector, mask, srcOffset, dstOffset, + count); + } + case 's': + // STRUCT + return scanArrowArrayStruct(schema, array, outputVector, mask, srcOffset, dstOffset, + count); + case 'm': + // MAP + return scanArrowArrayList(schema, array, outputVector, mask, srcOffset, + dstOffset, count); + case 'u': + if (arrowType[2] == 'd') { + // DENSE UNION + return scanArrowArrayDenseUnion(schema, array, outputVector, mask, srcOffset, + dstOffset, count); + } else { + // SPARSE UNION + return 
scanArrowArraySparseUnion(schema, array, outputVector, mask, srcOffset, + dstOffset, count); + } + case 'v': + switch (arrowType[2]) { + case 'l': + return scanArrowArrayListView(schema, array, outputVector, mask, srcOffset, + dstOffset, count); + case 'L': + return scanArrowArrayListView(schema, array, outputVector, mask, srcOffset, + dstOffset, count); + // LONG LIST VIEW + default: + KU_UNREACHABLE; + } + default: + KU_UNREACHABLE; + } + default: + KU_UNREACHABLE; + } +} + +void ArrowConverter::fromArrowArray(const ArrowSchema* schema, const ArrowArray* array, + ValueVector& outputVector) { + ArrowNullMaskTree mask(schema, array, array->offset, array->length); + return fromArrowArray(schema, array, outputVector, &mask, array->offset, 0, array->length); +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/arrow/arrow_converter.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/arrow/arrow_converter.cpp new file mode 100644 index 0000000000..ce74a18db5 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/arrow/arrow_converter.cpp @@ -0,0 +1,322 @@ +#include "common/arrow/arrow_converter.h" + +#include + +#include "common/arrow/arrow_row_batch.h" +#include "common/exception/runtime.h" + +namespace lbug { +namespace common { + +static void releaseArrowSchema(ArrowSchema* schema) { + if (!schema || !schema->release) { + return; + } + schema->release = nullptr; + auto holder = static_cast(schema->private_data); + delete holder; +} + +// Copies the given string into the arrow holder's owned names and returns a pointer to the owned +// version +static const char* copyName(ArrowSchemaHolder& rootHolder, const std::string& name) { + auto length = name.length(); + std::unique_ptr namePtr = std::make_unique(length + 1); + std::memcpy(namePtr.get(), name.c_str(), length); + namePtr[length] = '\0'; + rootHolder.ownedTypeNames.push_back(std::move(namePtr)); + return rootHolder.ownedTypeNames.back().get(); +} + +// The resulting byte array follows the format described here: +// https://arrow.apache.org/docs/format/CDataInterface.html#c.ArrowSchema.metadata +static std::unique_ptr serializeMetadata( + const std::map& metadata) { + // Calculate size of byte array + auto numEntries = metadata.size(); + auto size = (2 * numEntries + 1) * sizeof(int32_t); + for (const auto& [k, v] : metadata) { + size += k.size() + v.size(); + } + std::unique_ptr bytes(new char[size]); + // Copy data into byte array + char* ptr = bytes.get(); + memcpy(ptr, &numEntries, sizeof(int32_t)); + ptr += sizeof(int32_t); + for (const auto& [k, v] : metadata) { + auto ksz = k.size(), vsz = v.size(); + memcpy(ptr, &ksz, sizeof(int32_t)); + ptr += sizeof(int32_t); + memcpy(ptr, k.c_str(), ksz); + ptr += ksz; + memcpy(ptr, &vsz, sizeof(int32_t)); + ptr += sizeof(int32_t); + memcpy(ptr, v.c_str(), vsz); + ptr += vsz; + } + return bytes; +} + +static const char* copyMetadata(ArrowSchemaHolder& rootHolder, + const std::map& metadata) { + rootHolder.ownedMetadatas.push_back(serializeMetadata(metadata)); + return rootHolder.ownedMetadatas.back().get(); +} + +void ArrowConverter::initializeChild(ArrowSchema& child, const std::string& name) { + //! Child is cleaned up by parent + child.private_data = nullptr; + child.release = releaseArrowSchema; + + //! 
Store the child schema + child.flags = ARROW_FLAG_NULLABLE; + child.name = name.c_str(); + child.n_children = 0; + child.children = nullptr; + child.metadata = nullptr; + child.dictionary = nullptr; +} + +void ArrowConverter::setArrowFormatForStruct(ArrowSchemaHolder& rootHolder, ArrowSchema& child, + const LogicalType& dataType, bool fallbackExtensionTypes) { + child.format = "+s"; + // name is set by parent. + child.n_children = (std::int64_t)StructType::getNumFields(dataType); + rootHolder.nestedChildren.emplace_back(); + rootHolder.nestedChildren.back().resize(child.n_children); + rootHolder.nestedChildrenPtr.emplace_back(); + rootHolder.nestedChildrenPtr.back().resize(child.n_children); + for (auto i = 0u; i < child.n_children; i++) { + rootHolder.nestedChildrenPtr.back()[i] = &rootHolder.nestedChildren.back()[i]; + } + child.children = &rootHolder.nestedChildrenPtr.back()[0]; + for (auto i = 0u; i < child.n_children; i++) { + initializeChild(*child.children[i]); + const auto& structField = StructType::getField(dataType, i); + child.children[i]->name = copyName(rootHolder, structField.getName()); + setArrowFormat(rootHolder, *child.children[i], structField.getType(), + fallbackExtensionTypes); + } +} + +void ArrowConverter::setArrowFormatForUnion(ArrowSchemaHolder& rootHolder, ArrowSchema& child, + const LogicalType& dataType, bool fallbackExtensionTypes) { + std::string formatStr = "+ud"; + child.n_children = (std::int64_t)UnionType::getNumFields(dataType); + rootHolder.nestedChildren.emplace_back(); + rootHolder.nestedChildren.back().resize(child.n_children); + rootHolder.nestedChildrenPtr.emplace_back(); + rootHolder.nestedChildrenPtr.back().resize(child.n_children); + for (auto i = 0u; i < child.n_children; i++) { + rootHolder.nestedChildrenPtr.back()[i] = &rootHolder.nestedChildren.back()[i]; + } + child.children = &rootHolder.nestedChildrenPtr.back()[0]; + for (auto i = 0u; i < child.n_children; i++) { + initializeChild(*child.children[i]); + const auto& unionFieldType = UnionType::getFieldType(dataType, i); + auto unionFieldName = UnionType::getFieldName(dataType, i); + child.children[i]->name = copyName(rootHolder, unionFieldName); + setArrowFormat(rootHolder, *child.children[i], unionFieldType, fallbackExtensionTypes); + formatStr += (i == 0u ? ":" : ",") + std::to_string(i); + } + child.format = copyName(rootHolder, formatStr); +} + +void ArrowConverter::setArrowFormatForInternalID(ArrowSchemaHolder& rootHolder, ArrowSchema& child, + const LogicalType& /*dataType*/, bool fallbackExtensionTypes) { + child.format = "+s"; + // name is set by parent. 
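As a condensed reference for the format strings assigned throughout these setArrowFormat helpers, the single-character codes come from the Arrow C data interface specification (https://arrow.apache.org/docs/format/CDataInterface.html); the enum below is only a stand-in for lbug's LogicalTypeID:

#include <iostream>
#include <string>

enum class Kind { Bool, Int64, Int32, Double, String, Date, TimestampMicros, Struct, List };

std::string arrowFormat(Kind k) {
    switch (k) {
    case Kind::Bool: return "b";
    case Kind::Int64: return "l";
    case Kind::Int32: return "i";
    case Kind::Double: return "g";
    case Kind::String: return "u";
    case Kind::Date: return "tdD";             // days since the UNIX epoch
    case Kind::TimestampMicros: return "tsu:"; // microseconds, no timezone suffix
    case Kind::Struct: return "+s";
    case Kind::List: return "+l";
    }
    return "?";
}

int main() {
    std::cout << arrowFormat(Kind::TimestampMicros) << "\n"; // prints: tsu:
}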
+ child.n_children = 2; + rootHolder.nestedChildren.emplace_back(); + rootHolder.nestedChildren.back().resize(child.n_children); + rootHolder.nestedChildrenPtr.emplace_back(); + rootHolder.nestedChildrenPtr.back().resize(child.n_children); + for (auto i = 0u; i < child.n_children; i++) { + rootHolder.nestedChildrenPtr.back()[i] = &rootHolder.nestedChildren.back()[i]; + } + child.children = &rootHolder.nestedChildrenPtr.back()[0]; + initializeChild(*child.children[0]); + child.children[0]->name = copyName(rootHolder, "offset"); + setArrowFormat(rootHolder, *child.children[0], LogicalType::INT64(), fallbackExtensionTypes); + initializeChild(*child.children[1]); + child.children[1]->name = copyName(rootHolder, "table"); + setArrowFormat(rootHolder, *child.children[1], LogicalType::INT64(), fallbackExtensionTypes); +} + +void ArrowConverter::setArrowFormat(ArrowSchemaHolder& rootHolder, ArrowSchema& child, + const LogicalType& dataType, bool fallbackExtensionTypes) { + switch (dataType.getLogicalTypeID()) { + case LogicalTypeID::BOOL: { + child.format = "b"; + } break; + case LogicalTypeID::INT128: { + child.format = "d:38,0"; + } break; + case LogicalTypeID::SERIAL: + case LogicalTypeID::INT64: { + child.format = "l"; + } break; + case LogicalTypeID::INT32: { + child.format = "i"; + } break; + case LogicalTypeID::INT16: { + child.format = "s"; + } break; + case LogicalTypeID::INT8: { + child.format = "c"; + } break; + case LogicalTypeID::UINT64: { + child.format = "L"; + } break; + case LogicalTypeID::UINT32: { + child.format = "I"; + } break; + case LogicalTypeID::UINT16: { + child.format = "S"; + } break; + case LogicalTypeID::UINT8: { + child.format = "C"; + } break; + case LogicalTypeID::DOUBLE: { + child.format = "g"; + } break; + case LogicalTypeID::FLOAT: { + child.format = "f"; + } break; + case LogicalTypeID::DECIMAL: { + auto formatString = "d:" + std::to_string(DecimalType::getPrecision(dataType)) + "," + + std::to_string(DecimalType::getScale(dataType)); + child.format = copyName(rootHolder, formatString); + } break; + case LogicalTypeID::DATE: { + child.format = "tdD"; + } break; + case LogicalTypeID::TIMESTAMP_MS: { + child.format = "tsm:"; + } break; + case LogicalTypeID::TIMESTAMP_NS: { + child.format = "tsn:"; + } break; + case LogicalTypeID::TIMESTAMP_SEC: { + child.format = "tss:"; + } break; + case LogicalTypeID::TIMESTAMP_TZ: { + auto format = "tsu:UTC"; + child.format = copyName(rootHolder, format); + } break; + case LogicalTypeID::TIMESTAMP: { + child.format = "tsu:"; + } break; + case LogicalTypeID::INTERVAL: { + child.format = "tDu"; + } break; + case LogicalTypeID::UUID: { + if (!fallbackExtensionTypes) { + child.format = "w:16"; + child.metadata = copyMetadata(rootHolder, + {{"ARROW:extension:name", "arrow.uuid"}, {"ARROW:extension:metadata", ""}}); + break; + } + [[fallthrough]]; + } + case LogicalTypeID::STRING: { + child.format = "u"; + } break; + case LogicalTypeID::BLOB: { + child.format = "z"; + } break; + case LogicalTypeID::LIST: { + child.format = "+l"; + child.n_children = 1; + rootHolder.nestedChildren.emplace_back(); + rootHolder.nestedChildren.back().resize(1); + rootHolder.nestedChildrenPtr.emplace_back(); + rootHolder.nestedChildrenPtr.back().push_back(&rootHolder.nestedChildren.back()[0]); + initializeChild(rootHolder.nestedChildren.back()[0]); + child.children = &rootHolder.nestedChildrenPtr.back()[0]; + child.children[0]->name = "l"; + setArrowFormat(rootHolder, **child.children, ListType::getChildType(dataType), + fallbackExtensionTypes); + } 
break; + case LogicalTypeID::ARRAY: { + auto numValuesPerArray = "+w:" + std::to_string(ArrayType::getNumElements(dataType)); + child.format = copyName(rootHolder, numValuesPerArray); + child.n_children = 1; + rootHolder.nestedChildren.emplace_back(); + rootHolder.nestedChildren.back().resize(1); + rootHolder.nestedChildrenPtr.emplace_back(); + rootHolder.nestedChildrenPtr.back().push_back(&rootHolder.nestedChildren.back()[0]); + initializeChild(rootHolder.nestedChildren.back()[0]); + child.children = &rootHolder.nestedChildrenPtr.back()[0]; + child.children[0]->name = "l"; + setArrowFormat(rootHolder, **child.children, ArrayType::getChildType(dataType), + fallbackExtensionTypes); + } break; + case LogicalTypeID::MAP: { + child.format = "+m"; + child.n_children = 1; + rootHolder.nestedChildren.emplace_back(); + rootHolder.nestedChildren.back().resize(1); + rootHolder.nestedChildrenPtr.emplace_back(); + rootHolder.nestedChildrenPtr.back().push_back(&rootHolder.nestedChildren.back()[0]); + initializeChild(rootHolder.nestedChildren.back()[0]); + child.children = &rootHolder.nestedChildrenPtr.back()[0]; + child.children[0]->name = "entries"; + setArrowFormat(rootHolder, **child.children, ListType::getChildType(dataType), + fallbackExtensionTypes); + child.children[0]->children[0]->flags &= + ~ARROW_FLAG_NULLABLE; // Map's keys must be non-nullable + } break; + case LogicalTypeID::STRUCT: + case LogicalTypeID::NODE: + case LogicalTypeID::REL: + case LogicalTypeID::RECURSIVE_REL: + setArrowFormatForStruct(rootHolder, child, dataType, fallbackExtensionTypes); + break; + case LogicalTypeID::INTERNAL_ID: + setArrowFormatForInternalID(rootHolder, child, dataType, fallbackExtensionTypes); + break; + case LogicalTypeID::UNION: + setArrowFormatForUnion(rootHolder, child, dataType, fallbackExtensionTypes); + break; + default: + throw RuntimeException( + stringFormat("{} cannot be exported to arrow.", dataType.toString())); + } +} + +std::unique_ptr ArrowConverter::toArrowSchema( + const std::vector& dataTypes, const std::vector& columnNames, + bool fallbackExtensionTypes) { + auto outSchema = std::make_unique(); + auto rootHolder = std::make_unique(); + + auto columnCount = (int64_t)dataTypes.size(); + rootHolder->children.resize(columnCount); + rootHolder->childrenPtrs.resize(columnCount); + for (auto i = 0u; i < columnCount; i++) { + rootHolder->childrenPtrs[i] = &rootHolder->children[i]; + } + outSchema->children = rootHolder->childrenPtrs.data(); + outSchema->n_children = columnCount; + + outSchema->format = "+s"; // struct apparently + outSchema->flags = 0; + outSchema->metadata = nullptr; + outSchema->name = "lbug_query_result"; + outSchema->dictionary = nullptr; + + for (auto i = 0u; i < columnCount; i++) { + auto& child = rootHolder->children[i]; + initializeChild(child); + child.name = copyName(*rootHolder, columnNames[i]); + setArrowFormat(*rootHolder, child, dataTypes[i], fallbackExtensionTypes); + } + + outSchema->private_data = rootHolder.release(); + outSchema->release = releaseArrowSchema; + return outSchema; +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/arrow/arrow_null_mask_tree.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/arrow/arrow_null_mask_tree.cpp new file mode 100644 index 0000000000..0ad4fa4034 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/arrow/arrow_null_mask_tree.cpp @@ -0,0 +1,222 @@ +#include + +#include "common/arrow/arrow.h" +#include "common/arrow/arrow_nullmask_tree.h" + +namespace lbug { 
+namespace common { + +// scans are based on data specification found here +// https://arrow.apache.org/docs/format/Columnar.html + +// all offsets are measured by value, not physical size + +void ArrowNullMaskTree::copyToValueVector(ValueVector* vec, uint64_t dstOffset, uint64_t count) { + vec->setNullFromBits(mask->getData(), offset, dstOffset, count); +} + +ArrowNullMaskTree ArrowNullMaskTree::offsetBy(int64_t offset) { + // this operation is mostly a special case for dictionary/run-end encoding + ArrowNullMaskTree ret(*this); + ret.offset += offset; + return ret; +} + +bool ArrowNullMaskTree::copyFromBuffer(const void* buffer, uint64_t srcOffset, uint64_t count) { + if (buffer == nullptr) { + mask->setAllNonNull(); + return false; + } + mask->copyFromNullBits((const uint64_t*)buffer, srcOffset, 0, count, true); + return true; +} + +bool ArrowNullMaskTree::applyParentBitmap(const NullMask* parent) { + if (parent == nullptr) { + return false; + } + if (parent->getData() != nullptr) { + *mask |= *parent; + return true; + } + return false; +} + +template +void ArrowNullMaskTree::scanListPushDown(const ArrowSchema* schema, const ArrowArray* array, + uint64_t srcOffset, uint64_t count) { + const offsetsT* offsets = ((const offsetsT*)array->buffers[1]) + srcOffset; + offsetsT auxiliaryLength = offsets[count] - offsets[0]; + NullMask pushDownMask(auxiliaryLength); + for (uint64_t i = 0; i < count; i++) { + pushDownMask.setNullFromRange(offsets[i] - offsets[0], offsets[i + 1] - offsets[i], + isNull(i)); + } + children->push_back(ArrowNullMaskTree(schema->children[0], array->children[0], + offsets[0] + array->children[0]->offset, auxiliaryLength, &pushDownMask)); +} + +void ArrowNullMaskTree::scanArrayPushDown(const ArrowSchema* schema, const ArrowArray* array, + uint64_t srcOffset, uint64_t count) { + auto numElements = std::stoul(schema->format + 3); + auto auxiliaryLength = count * numElements; + NullMask pushDownMask(auxiliaryLength); + for (auto i = 0u; i < count; ++i) { + pushDownMask.setNullFromRange(i * numElements, numElements, isNull(i)); + } + children->push_back(ArrowNullMaskTree(schema->children[0], array->children[0], + srcOffset * numElements + array->children[0]->offset, auxiliaryLength, &pushDownMask)); +} + +void ArrowNullMaskTree::scanStructPushDown(const ArrowSchema* schema, const ArrowArray* array, + uint64_t srcOffset, uint64_t count) { + for (int64_t i = 0; i < array->n_children; i++) { + children->push_back(ArrowNullMaskTree(schema->children[i], array->children[i], + srcOffset + array->children[i]->offset, count, mask.get())); + } +} + +ArrowNullMaskTree::ArrowNullMaskTree(const ArrowSchema* schema, const ArrowArray* array, + uint64_t srcOffset, uint64_t count, const NullMask* parentBitmap) + : offset{0}, mask{std::make_shared(count)}, + children(std::make_shared>()) { + if (schema->dictionary != nullptr) { + copyFromBuffer(array->buffers[0], srcOffset, count); + applyParentBitmap(parentBitmap); + dictionary = std::make_shared(schema->dictionary, array->dictionary, + array->dictionary->offset, array->dictionary->length); + return; + } + const char* arrowType = schema->format; + std::vector structFields; + switch (arrowType[0]) { + case 'n': + mask->setAllNull(); + break; + case 'b': + case 'c': + case 'C': + case 's': + case 'S': + case 'i': + case 'I': + case 'l': + case 'L': + case 'd': + case 'f': + case 'g': + copyFromBuffer(array->buffers[0], srcOffset, count); + break; + case 'z': + case 'Z': + case 'u': + case 'U': + case 'v': + case 'w': + case 't': + 
copyFromBuffer(array->buffers[0], srcOffset, count); + applyParentBitmap(parentBitmap); + break; + case '+': + switch (arrowType[1]) { + case 'l': + copyFromBuffer(array->buffers[0], srcOffset, count); + applyParentBitmap(parentBitmap); + scanListPushDown(schema, array, srcOffset, count); + break; + case 'L': + copyFromBuffer(array->buffers[0], srcOffset, count); + applyParentBitmap(parentBitmap); + scanListPushDown(schema, array, srcOffset, count); + break; + case 'w': + copyFromBuffer(array->buffers[0], srcOffset, count); + applyParentBitmap(parentBitmap); + scanArrayPushDown(schema, array, srcOffset, count); + break; + case 's': + copyFromBuffer(array->buffers[0], srcOffset, count); + applyParentBitmap(parentBitmap); + scanStructPushDown(schema, array, srcOffset, count); + break; + case 'm': + copyFromBuffer(array->buffers[0], srcOffset, count); + applyParentBitmap(parentBitmap); + scanListPushDown(schema, array, srcOffset, count); + break; + case 'u': { + auto types = (const int8_t*)array->buffers[0]; + if (schema->format[2] == 'd') { + auto offsets = (const int32_t*)array->buffers[1]; + std::vector countChildren(array->n_children), + lowestOffsets(array->n_children); + std::vector highestOffsets(array->n_children); + for (auto i = srcOffset; i < srcOffset + count; i++) { + int32_t curOffset = offsets[i]; + int32_t curType = types[i]; + if (countChildren[curType] == 0) { + lowestOffsets[curType] = curOffset; + } + highestOffsets[curType] = curOffset; + countChildren[curType]++; + } + for (int64_t i = 0; i < array->n_children; i++) { + children->push_back(ArrowNullMaskTree(schema->children[i], array->children[i], + lowestOffsets[i] + array->children[i]->offset, + highestOffsets[i] - lowestOffsets[i] + 1)); + } + for (auto i = 0u; i < count; i++) { + int32_t curOffset = offsets[i + srcOffset]; + int8_t curType = types[i + srcOffset]; + mask->setNull(i, + children->operator[](curType).isNull(curOffset - lowestOffsets[curType])); + } + } else { + for (int64_t i = 0; i < array->n_children; i++) { + children->push_back(ArrowNullMaskTree(schema->children[i], array->children[i], + srcOffset + array->children[i]->offset, count)); + } + for (auto i = 0u; i < count; i++) { + int8_t curType = types[i + srcOffset]; + mask->setNull(i, children->operator[](curType).isNull(i)); + // this isn't specified in the arrow specification, but is it valid to + // compute this using a bitwise OR? 
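Two facts drive most of the masking logic in this constructor: Arrow validity bitmaps are LSB-first with a set bit meaning valid (and a missing buffer meaning all-valid), and a null parent list nulls its entire child range, which is what the scan*PushDown helpers propagate. A hand-built example that is not the real NullMask API:

#include <cstdint>
#include <iostream>
#include <vector>

bool arrowIsNull(const uint8_t* validity, uint64_t i) {
    if (validity == nullptr) return false;          // missing bitmap means all values are valid
    return ((validity[i / 8] >> (i % 8)) & 1) == 0; // a set bit means valid
}

int main() {
    // 3 lists over 5 child values; list 1 is null. offsets: [0,2,4,5]
    uint8_t parentValidity = 0b00000101; // bits 0 and 2 set: lists 0 and 2 are valid
    std::vector<int32_t> offsets = {0, 2, 4, 5};
    std::vector<bool> childIsNull(5, false);
    for (uint64_t i = 0; i < 3; i++) {
        if (arrowIsNull(&parentValidity, i)) {
            // push the parent's nullness down onto the child range
            for (int32_t j = offsets[i]; j < offsets[i + 1]; j++) childIsNull[j] = true;
        }
    }
    for (bool n : childIsNull) std::cout << (n ? "null " : "ok ");
    std::cout << "\n"; // ok ok null null ok
}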
+ } + } + if (parentBitmap != nullptr) { + *mask |= *parentBitmap; + } + } break; + case 'v': + // list views *suck*, especially when trying to write code that can support + // parallelization for this, we generate child NullMaskTrees on the fly, rather than + // attempt any precomputation + if (array->buffers[0] == nullptr) { + mask->setAllNonNull(); + } else { + mask->copyFromNullBits((const uint64_t*)array->buffers[0], srcOffset, 0, count, + true); + } + if (parentBitmap != nullptr) { + *mask |= *parentBitmap; + } + break; + case 'r': + // it's better to resolve validity during the actual scanning for run-end encoded arrays + // so for this, let's just resolve child validities and move on + for (int64_t i = 0; i < array->n_children; i++) { + children->push_back(ArrowNullMaskTree(schema->children[i], array->children[i], + array->children[i]->offset, array->children[i]->length)); + } + break; + default: + KU_UNREACHABLE; + } + break; + default: + KU_UNREACHABLE; + } +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/arrow/arrow_row_batch.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/arrow/arrow_row_batch.cpp new file mode 100644 index 0000000000..dba091ef8b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/arrow/arrow_row_batch.cpp @@ -0,0 +1,1022 @@ +#include "common/arrow/arrow_row_batch.h" + +#include + +#include "common/exception/runtime.h" +#include "common/types/value/node.h" +#include "common/types/value/rel.h" +#include "common/types/value/value.h" +#include "processor/result/flat_tuple.h" +#include "storage/storage_utils.h" + +namespace lbug { +namespace common { + +static void resizeVector(ArrowVector* vector, const LogicalType& type, int64_t capacity, + bool fallbackExtensionTypes); + +ArrowRowBatch::ArrowRowBatch(const std::vector& types, std::int64_t capacity, + bool fallbackExtensionTypes) + : numTuples{0}, fallbackExtensionTypes{fallbackExtensionTypes} { + vectors.resize(types.size()); + for (auto i = 0u; i < types.size(); i++) { + vectors[i] = std::make_unique(); + resizeVector(vectors[i].get(), types[i], capacity, fallbackExtensionTypes); + } +} + +static uint64_t getArrowMainBufferSize(const LogicalType& type, uint64_t capacity, + bool fallbackExtensionTypes) { + switch (type.getLogicalTypeID()) { + case LogicalTypeID::BOOL: + return getNumBytesForBits(capacity); + case LogicalTypeID::SERIAL: + case LogicalTypeID::TIMESTAMP: + case LogicalTypeID::TIMESTAMP_SEC: + case LogicalTypeID::TIMESTAMP_MS: + case LogicalTypeID::TIMESTAMP_NS: + case LogicalTypeID::TIMESTAMP_TZ: + case LogicalTypeID::INTERVAL: + case LogicalTypeID::UINT64: + case LogicalTypeID::INT64: + return sizeof(int64_t) * capacity; + case LogicalTypeID::DATE: + case LogicalTypeID::UINT32: + case LogicalTypeID::INT32: + return sizeof(int32_t) * capacity; + case LogicalTypeID::UINT16: + case LogicalTypeID::INT16: + return sizeof(int16_t) * capacity; + case LogicalTypeID::UNION: + case LogicalTypeID::UINT8: + case LogicalTypeID::INT8: + return sizeof(int8_t) * capacity; + case LogicalTypeID::DECIMAL: + case LogicalTypeID::INT128: + return sizeof(int128_t) * capacity; + case LogicalTypeID::DOUBLE: + return sizeof(double) * capacity; + case LogicalTypeID::FLOAT: + return sizeof(float) * capacity; + case LogicalTypeID::UUID: { + if (!fallbackExtensionTypes) { + return sizeof(char) * 16 * capacity; + } + [[fallthrough]]; + } + case LogicalTypeID::STRING: + case LogicalTypeID::BLOB: + case LogicalTypeID::LIST: + case LogicalTypeID::MAP: + return 
sizeof(int32_t) * (capacity + 1); + case LogicalTypeID::ARRAY: + case LogicalTypeID::STRUCT: + case LogicalTypeID::INTERNAL_ID: + case LogicalTypeID::RECURSIVE_REL: + case LogicalTypeID::NODE: + case LogicalTypeID::REL: + return 0; // no main buffer + default: + KU_UNREACHABLE; // should enumerate all types. + } +} + +static void resizeValidityBuffer(ArrowVector* vector, int64_t capacity) { + vector->validity.resize(getNumBytesForBits(capacity), 0xFF); +} + +static void resizeMainBuffer(ArrowVector* vector, const LogicalType& type, int64_t capacity, + bool fallbackExtensionTypes) { + vector->data.resize(getArrowMainBufferSize(type, capacity, fallbackExtensionTypes)); +} + +static void resizeBLOBOverflow(ArrowVector* vector, int64_t capacity) { + vector->overflow.resize(capacity); +} + +static void resizeUnionOverflow(ArrowVector* vector, int64_t capacity) { + vector->overflow.resize(capacity * sizeof(int32_t)); +} + +static void resizeChildVectors(ArrowVector* vector, const std::vector& childTypes, + int64_t childCapacity, bool fallbackExtensionTypes) { + for (auto i = 0u; i < childTypes.size(); i++) { + if (i >= vector->childData.size()) { + vector->childData.push_back(std::make_unique()); + } + resizeVector(vector->childData[i].get(), childTypes[i], childCapacity, + fallbackExtensionTypes); + } +} + +static void resizeGeneric(ArrowVector* vector, const LogicalType& type, int64_t capacity, + bool fallbackExtensionTypes) { + if (vector->capacity >= capacity) { + return; + } + while (vector->capacity < capacity) { + if (vector->capacity == 0) { + vector->capacity = 1; + } else { + vector->capacity *= 2; + } + } + resizeValidityBuffer(vector, vector->capacity); + resizeMainBuffer(vector, type, vector->capacity, fallbackExtensionTypes); +} + +static void resizeBLOBVector(ArrowVector* vector, const LogicalType& type, int64_t capacity, + int64_t overflowCapacity, bool fallbackExtensionTypes) { + resizeGeneric(vector, type, capacity, fallbackExtensionTypes); + resizeBLOBOverflow(vector, overflowCapacity); +} + +static void resizeFixedListVector(ArrowVector* vector, const LogicalType& type, int64_t capacity, + bool fallbackExtensionTypes) { + resizeGeneric(vector, type, capacity, fallbackExtensionTypes); + std::vector typeVec; + typeVec.push_back(ArrayType::getChildType(type).copy()); + resizeChildVectors(vector, typeVec, vector->capacity * ArrayType::getNumElements(type), + fallbackExtensionTypes); +} + +static void resizeListVector(ArrowVector* vector, const LogicalType& type, int64_t capacity, + int64_t childCapacity, bool fallbackExtensionTypes) { + resizeGeneric(vector, type, capacity, fallbackExtensionTypes); + std::vector typeVec; + typeVec.push_back(ListType::getChildType(type).copy()); + resizeChildVectors(vector, typeVec, childCapacity, fallbackExtensionTypes); +} + +static void resizeStructVector(ArrowVector* vector, const LogicalType& type, int64_t capacity, + bool fallbackExtensionTypes) { + resizeGeneric(vector, type, capacity, fallbackExtensionTypes); + std::vector typeVec; + for (auto i : StructType::getFieldTypes(type)) { + typeVec.push_back(i->copy()); + } + resizeChildVectors(vector, typeVec, vector->capacity, fallbackExtensionTypes); +} + +static void resizeUnionVector(ArrowVector* vector, const LogicalType& type, int64_t capacity, + bool fallbackExtensionTypes) { + if (vector->capacity < capacity) { + while (vector->capacity < capacity) { + if (vector->capacity == 0) { + vector->capacity = 1; + } else { + vector->capacity *= 2; + } + } + resizeMainBuffer(vector, type, 
vector->capacity, fallbackExtensionTypes); + } + resizeUnionOverflow(vector, vector->capacity); + std::vector childTypes; + for (auto i = 0u; i < UnionType::getNumFields(type); i++) { + childTypes.push_back(UnionType::getFieldType(type, i).copy()); + } + resizeChildVectors(vector, childTypes, vector->capacity, fallbackExtensionTypes); +} + +static void resizeInternalIDVector(ArrowVector* vector, const LogicalType& type, int64_t capacity, + bool fallbackExtensionTypes) { + resizeGeneric(vector, type, capacity, fallbackExtensionTypes); + std::vector typeVec; + typeVec.push_back(LogicalType::INT64()); + typeVec.push_back(LogicalType::INT64()); + resizeChildVectors(vector, typeVec, vector->capacity, fallbackExtensionTypes); +} + +static void resizeVector(ArrowVector* vector, const LogicalType& type, std::int64_t capacity, + bool fallbackExtensionTypes) { + auto result = std::make_unique(); + switch (type.getLogicalTypeID()) { + case LogicalTypeID::UUID: { + if (fallbackExtensionTypes) { + resizeBLOBVector(vector, type, capacity, capacity, fallbackExtensionTypes); + return; + } + [[fallthrough]]; + } + case LogicalTypeID::BOOL: + case LogicalTypeID::DECIMAL: + case LogicalTypeID::INT128: + case LogicalTypeID::SERIAL: + case LogicalTypeID::INT64: + case LogicalTypeID::INT32: + case LogicalTypeID::INT16: + case LogicalTypeID::INT8: + case LogicalTypeID::UINT64: + case LogicalTypeID::UINT32: + case LogicalTypeID::UINT16: + case LogicalTypeID::UINT8: + case LogicalTypeID::DOUBLE: + case LogicalTypeID::FLOAT: + case LogicalTypeID::DATE: + case LogicalTypeID::TIMESTAMP_MS: + case LogicalTypeID::TIMESTAMP_NS: + case LogicalTypeID::TIMESTAMP_SEC: + case LogicalTypeID::TIMESTAMP_TZ: + case LogicalTypeID::TIMESTAMP: + case LogicalTypeID::INTERVAL: + return resizeGeneric(vector, type, capacity, fallbackExtensionTypes); + case LogicalTypeID::BLOB: + case LogicalTypeID::STRING: + return resizeBLOBVector(vector, type, capacity, capacity, fallbackExtensionTypes); + case LogicalTypeID::LIST: + case LogicalTypeID::MAP: + return resizeListVector(vector, type, capacity, capacity, fallbackExtensionTypes); + case LogicalTypeID::ARRAY: + return resizeFixedListVector(vector, type, capacity, fallbackExtensionTypes); + case LogicalTypeID::RECURSIVE_REL: + case LogicalTypeID::NODE: + case LogicalTypeID::REL: + case LogicalTypeID::STRUCT: + return resizeStructVector(vector, type, capacity, fallbackExtensionTypes); + case LogicalTypeID::UNION: + return resizeUnionVector(vector, type, capacity, fallbackExtensionTypes); + case LogicalTypeID::INTERNAL_ID: + return resizeInternalIDVector(vector, type, capacity, fallbackExtensionTypes); + default: { + // LCOV_EXCL_START + throw common::RuntimeException{ + common::stringFormat("Unsupported type: {} for arrow conversion.", type.toString())}; + // LCOV_EXCL_STOP + } + } +} + +static void getBitPosition(std::int64_t pos, std::int64_t& bytePos, std::int64_t& bitOffset) { + bytePos = pos >> 3; + bitOffset = pos - (bytePos << 3); +} + +static void setBitToZero(std::uint8_t* data, std::int64_t pos) { + std::int64_t bytePos = 0, bitOffset = 0; + getBitPosition(pos, bytePos, bitOffset); + data[bytePos] &= ~((std::uint64_t)1 << bitOffset); +} + +static void setBitToOne(std::uint8_t* data, std::int64_t pos) { + std::int64_t bytePos = 0, bitOffset = 0; + getBitPosition(pos, bytePos, bitOffset); + data[bytePos] |= ((std::uint64_t)1 << bitOffset); +} + +void ArrowRowBatch::appendValue(ArrowVector* vector, const Value& value, + bool fallbackExtensionTypes) { + if (value.isNull()) { + 
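Two sizing conventions recur in the resize helpers above: capacity grows by doubling from 1, and variable-size columns (STRING, BLOB, LIST, MAP) reserve capacity + 1 int32 offsets because Arrow's variable-size layouts store both the start and the end of every value. A small self-contained sketch of that arithmetic (the numbers are illustrative only):

#include <cstdint>
#include <cstdio>

int main() {
    int64_t capacity = 0;
    const int64_t requested = 1000;
    while (capacity < requested) {
        capacity = capacity == 0 ? 1 : capacity * 2; // same doubling as resizeGeneric
    }
    std::printf("grown capacity: %lld\n", static_cast<long long>(capacity)); // 1024
    // A column of N variable-size values carries N + 1 offsets: offsets[i] and
    // offsets[i + 1] bracket value i, so the buffer needs one extra slot.
    std::printf("offset slots:   %lld\n", static_cast<long long>(capacity + 1));
    return 0;
}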
copyNullValue(vector, value, vector->numValues); + } else { + copyNonNullValue(vector, value, vector->numValues, fallbackExtensionTypes); + } + vector->numValues++; +} + +template +void ArrowRowBatch::templateCopyNonNullValue(ArrowVector* vector, const Value& value, + std::int64_t pos, bool) { + auto valSize = storage::StorageUtils::getDataTypeSize(LogicalType{DT}); + std::memcpy(vector->data.data() + pos * valSize, &value.val, valSize); +} + +template<> +void ArrowRowBatch::templateCopyNonNullValue(ArrowVector* vector, + const Value& value, std::int64_t pos, bool) { + auto valSize = storage::StorageUtils::getDataTypeSize(value.getDataType()); + std::memcpy(vector->data.data() + pos * 16, &value.val, valSize); +} + +template<> +void ArrowRowBatch::templateCopyNonNullValue(ArrowVector* vector, + const Value& value, std::int64_t pos, bool) { + auto destAddr = (int64_t*)(vector->data.data() + pos * sizeof(std::int64_t)); + auto intervalVal = value.val.intervalVal; + *destAddr = intervalVal.micros + intervalVal.days * Interval::MICROS_PER_DAY + + intervalVal.months * Interval::MICROS_PER_MONTH; +} + +template<> +void ArrowRowBatch::templateCopyNonNullValue(ArrowVector* vector, + const Value& value, std::int64_t pos, bool) { + if (value.val.booleanVal) { + setBitToOne(vector->data.data(), pos); + } else { + setBitToZero(vector->data.data(), pos); + } +} + +template<> +void ArrowRowBatch::templateCopyNonNullValue(ArrowVector* vector, + const Value& value, std::int64_t pos, bool) { + auto offsets = (std::uint32_t*)vector->data.data(); + auto strLength = value.strVal.length(); + if (pos == 0) { + offsets[pos] = 0; + } + offsets[pos + 1] = offsets[pos] + strLength; + vector->overflow.resize(offsets[pos + 1] + 1); + std::memcpy(vector->overflow.data() + offsets[pos], value.strVal.data(), strLength); +} + +template<> +void ArrowRowBatch::templateCopyNonNullValue(ArrowVector* vector, + const Value& value, std::int64_t pos, bool fallbackExtensionTypes) { + if (!fallbackExtensionTypes) { + auto valSize = sizeof(int128_t); + auto val = value.val.int128Val; + val.high ^= (int64_t(1) << 63); // MSB is stored flipped internally + // Convert to little-endian + auto valPtr = reinterpret_cast(&val); + for (auto i = 0u; i < valSize / 2; ++i) { + std::swap(valPtr[i], valPtr[valSize - i - 1]); + } + std::memcpy(vector->data.data() + pos * valSize, &val, valSize); + } else { + auto offsets = (std::uint32_t*)vector->data.data(); + auto str = UUID::toString(value.val.int128Val); + auto strLength = str.length(); + if (pos == 0) { + offsets[pos] = 0; + } + offsets[pos + 1] = offsets[pos] + strLength; + vector->overflow.resize(offsets[pos + 1]); + std::memcpy(vector->overflow.data() + offsets[pos], str.data(), strLength); + } +} + +template<> +void ArrowRowBatch::templateCopyNonNullValue(ArrowVector* vector, + const Value& value, std::int64_t pos, bool fallbackExtensionTypes) { + auto offsets = (std::uint32_t*)vector->data.data(); + auto numElements = value.childrenSize; + if (pos == 0) { + offsets[pos] = 0; + } + offsets[pos + 1] = offsets[pos] + numElements; + std::vector typeVec; + typeVec.push_back(ListType::getChildType(value.getDataType()).copy()); + resizeChildVectors(vector, typeVec, offsets[pos + 1] + 1, fallbackExtensionTypes); + for (auto i = 0u; i < numElements; i++) { + appendValue(vector->childData[0].get(), *value.children[i], fallbackExtensionTypes); + } +} + +template<> +void ArrowRowBatch::templateCopyNonNullValue(ArrowVector* vector, + const Value& value, std::int64_t /*pos*/, bool 
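The STRING/BLOB copy above follows Arrow's variable-size binary layout: offsets[pos + 1] = offsets[pos] + length, with the bytes themselves appended to a separate data buffer (called overflow here). A minimal standalone version of that append, using illustrative names rather than the lbug ArrowVector:

#include <cstdint>
#include <string>
#include <vector>

struct VarBinaryColumn {
    std::vector<int32_t> offsets{0}; // always one more offset than values
    std::vector<uint8_t> data;       // concatenated value bytes
};

static void appendString(VarBinaryColumn& col, const std::string& value) {
    col.data.insert(col.data.end(), value.begin(), value.end());
    col.offsets.push_back(static_cast<int32_t>(col.data.size()));
}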
fallbackExtensionTypes) { + auto numElements = value.childrenSize; + for (auto i = 0u; i < numElements; i++) { + appendValue(vector->childData[0].get(), *value.children[i], fallbackExtensionTypes); + } +} + +template<> +void ArrowRowBatch::templateCopyNonNullValue(ArrowVector* vector, + const Value& value, std::int64_t pos, bool fallbackExtensionTypes) { + // Verify all keys are not null + for (auto i = 0u; i < value.childrenSize; ++i) { + if (value.children[i]->children[0]->isNull()) { + throw RuntimeException{ + stringFormat("Cannot convert map with null key to Arrow: {}", value.toString())}; + } + } + return templateCopyNonNullValue(vector, value, pos, + fallbackExtensionTypes); +} + +template<> +void ArrowRowBatch::templateCopyNonNullValue(ArrowVector* vector, + const Value& value, std::int64_t /*pos*/, bool fallbackExtensionTypes) { + for (auto i = 0u; i < value.childrenSize; i++) { + appendValue(vector->childData[i].get(), *value.children[i], fallbackExtensionTypes); + } +} + +template<> +void ArrowRowBatch::templateCopyNonNullValue(ArrowVector* vector, + const Value& value, std::int64_t pos, bool fallbackExtensionTypes) { + auto typeBuffer = (std::uint8_t*)vector->data.data(); + auto offsetsBuffer = (std::int32_t*)vector->overflow.data(); + auto& type = value.getDataType(); + for (auto i = 0u; i < UnionType::getNumFields(type); i++) { + if (UnionType::getFieldType(type, i) == value.children[0]->dataType) { + typeBuffer[pos] = i; + offsetsBuffer[pos] = vector->childData[i]->numValues; + return appendValue(vector->childData[i].get(), *value.children[0], + fallbackExtensionTypes); + } + } + KU_UNREACHABLE; // We should always be able to find a matching type +} + +template<> +void ArrowRowBatch::templateCopyNonNullValue(ArrowVector* vector, + const Value& value, std::int64_t /*pos*/, bool fallbackExtensionTypes) { + auto nodeID = value.getValue(); + Value offsetVal((std::int64_t)nodeID.offset); + Value tableIDVal((std::int64_t)nodeID.tableID); + appendValue(vector->childData[0].get(), offsetVal, fallbackExtensionTypes); + appendValue(vector->childData[1].get(), tableIDVal, fallbackExtensionTypes); +} + +template<> +void ArrowRowBatch::templateCopyNonNullValue(ArrowVector* vector, + const Value& value, std::int64_t /*pos*/, bool fallbackExtensionTypes) { + appendValue(vector->childData[0].get(), *NodeVal::getNodeIDVal(&value), fallbackExtensionTypes); + appendValue(vector->childData[1].get(), *NodeVal::getLabelVal(&value), fallbackExtensionTypes); + std::int64_t propertyId = 2; + auto numProperties = NodeVal::getNumProperties(&value); + for (auto i = 0u; i < numProperties; i++) { + auto val = NodeVal::getPropertyVal(&value, i); + appendValue(vector->childData[propertyId].get(), *val, fallbackExtensionTypes); + propertyId++; + } +} + +template<> +void ArrowRowBatch::templateCopyNonNullValue(ArrowVector* vector, + const Value& value, std::int64_t /*pos*/, bool fallbackExtensionTypes) { + appendValue(vector->childData[0].get(), *RelVal::getSrcNodeIDVal(&value), + fallbackExtensionTypes); + appendValue(vector->childData[1].get(), *RelVal::getDstNodeIDVal(&value), + fallbackExtensionTypes); + appendValue(vector->childData[2].get(), *RelVal::getLabelVal(&value), fallbackExtensionTypes); + appendValue(vector->childData[3].get(), *RelVal::getIDVal(&value), fallbackExtensionTypes); + common::property_id_t propertyID = 4; + auto numProperties = RelVal::getNumProperties(&value); + for (auto i = 0u; i < numProperties; i++) { + auto val = RelVal::getPropertyVal(&value, i); + 
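The UNION specialization above emits Arrow's dense-union encoding: a per-row type-id byte selecting a child, a per-row int32 offset into that child, and only the selected child actually receives the value. A compact sketch of the same bookkeeping with made-up container names:

#include <cstdint>
#include <vector>

struct DenseUnionColumn {
    std::vector<int8_t> typeIds;                // one entry per row: selected child
    std::vector<int32_t> offsets;               // one entry per row: index inside that child
    std::vector<std::vector<int64_t>> children; // illustrative: int64-only children,
                                                // pre-sized to the number of members
};

static void appendToUnion(DenseUnionColumn& col, int8_t child, int64_t value) {
    col.typeIds.push_back(child);
    col.offsets.push_back(static_cast<int32_t>(col.children[child].size()));
    col.children[child].push_back(value);
}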
appendValue(vector->childData[propertyID].get(), *val, fallbackExtensionTypes); + propertyID++; + } +} + +void ArrowRowBatch::copyNonNullValue(ArrowVector* vector, const Value& value, std::int64_t pos, + bool fallbackExtensionTypes) { + switch (value.getDataType().getLogicalTypeID()) { + case LogicalTypeID::BOOL: { + templateCopyNonNullValue(vector, value, pos, fallbackExtensionTypes); + } break; + case LogicalTypeID::DECIMAL: + case LogicalTypeID::INT128: { + templateCopyNonNullValue(vector, value, pos, fallbackExtensionTypes); + } break; + case LogicalTypeID::UUID: { + templateCopyNonNullValue(vector, value, pos, fallbackExtensionTypes); + } break; + case LogicalTypeID::SERIAL: + case LogicalTypeID::INT64: { + templateCopyNonNullValue(vector, value, pos, fallbackExtensionTypes); + } break; + case LogicalTypeID::INT32: { + templateCopyNonNullValue(vector, value, pos, fallbackExtensionTypes); + } break; + case LogicalTypeID::INT16: { + templateCopyNonNullValue(vector, value, pos, fallbackExtensionTypes); + } break; + case LogicalTypeID::INT8: { + templateCopyNonNullValue(vector, value, pos, fallbackExtensionTypes); + } break; + case LogicalTypeID::UINT64: { + templateCopyNonNullValue(vector, value, pos, fallbackExtensionTypes); + } break; + case LogicalTypeID::UINT32: { + templateCopyNonNullValue(vector, value, pos, fallbackExtensionTypes); + } break; + case LogicalTypeID::UINT16: { + templateCopyNonNullValue(vector, value, pos, fallbackExtensionTypes); + } break; + case LogicalTypeID::UINT8: { + templateCopyNonNullValue(vector, value, pos, fallbackExtensionTypes); + } break; + case LogicalTypeID::DOUBLE: { + templateCopyNonNullValue(vector, value, pos, fallbackExtensionTypes); + } break; + case LogicalTypeID::FLOAT: { + templateCopyNonNullValue(vector, value, pos, fallbackExtensionTypes); + } break; + case LogicalTypeID::DATE: { + templateCopyNonNullValue(vector, value, pos, fallbackExtensionTypes); + } break; + case LogicalTypeID::TIMESTAMP: { + templateCopyNonNullValue(vector, value, pos, + fallbackExtensionTypes); + } break; + case LogicalTypeID::TIMESTAMP_TZ: { + templateCopyNonNullValue(vector, value, pos, + fallbackExtensionTypes); + } break; + case LogicalTypeID::TIMESTAMP_NS: { + templateCopyNonNullValue(vector, value, pos, + fallbackExtensionTypes); + } break; + case LogicalTypeID::TIMESTAMP_MS: { + templateCopyNonNullValue(vector, value, pos, + fallbackExtensionTypes); + } break; + case LogicalTypeID::TIMESTAMP_SEC: { + templateCopyNonNullValue(vector, value, pos, + fallbackExtensionTypes); + } break; + case LogicalTypeID::INTERVAL: { + templateCopyNonNullValue(vector, value, pos, + fallbackExtensionTypes); + } break; + case LogicalTypeID::BLOB: + case LogicalTypeID::STRING: { + templateCopyNonNullValue(vector, value, pos, fallbackExtensionTypes); + } break; + case LogicalTypeID::LIST: { + templateCopyNonNullValue(vector, value, pos, fallbackExtensionTypes); + } break; + case LogicalTypeID::ARRAY: { + templateCopyNonNullValue(vector, value, pos, fallbackExtensionTypes); + } break; + case LogicalTypeID::MAP: { + templateCopyNonNullValue(vector, value, pos, fallbackExtensionTypes); + } break; + case LogicalTypeID::RECURSIVE_REL: + case LogicalTypeID::STRUCT: { + templateCopyNonNullValue(vector, value, pos, fallbackExtensionTypes); + } break; + case LogicalTypeID::UNION: { + templateCopyNonNullValue(vector, value, pos, fallbackExtensionTypes); + } break; + case LogicalTypeID::INTERNAL_ID: { + templateCopyNonNullValue(vector, value, pos, + fallbackExtensionTypes); + } break; + case 
LogicalTypeID::NODE: { + templateCopyNonNullValue(vector, value, pos, fallbackExtensionTypes); + } break; + case LogicalTypeID::REL: { + templateCopyNonNullValue(vector, value, pos, fallbackExtensionTypes); + } break; + default: { + KU_UNREACHABLE; + } + } +} + +template +void ArrowRowBatch::templateCopyNullValue(ArrowVector* vector, std::int64_t pos) { + // TODO(Guodong): make this as a function. + setBitToZero(vector->validity.data(), pos); + vector->numNulls++; +} + +template<> +void ArrowRowBatch::templateCopyNullValue(ArrowVector* vector, + std::int64_t pos) { + auto offsets = (std::uint32_t*)vector->data.data(); + if (pos == 0) { + offsets[pos] = 0; + } + offsets[pos + 1] = offsets[pos]; + setBitToZero(vector->validity.data(), pos); + vector->numNulls++; +} + +template<> +void ArrowRowBatch::templateCopyNullValue(ArrowVector* vector, + std::int64_t pos) { + auto offsets = (std::uint32_t*)vector->data.data(); + if (pos == 0) { + offsets[pos] = 0; + } + offsets[pos + 1] = offsets[pos]; + setBitToZero(vector->validity.data(), pos); + vector->numNulls++; +} + +template<> +void ArrowRowBatch::templateCopyNullValue(ArrowVector* vector, + std::int64_t pos) { + return templateCopyNullValue(vector, pos); +} + +template<> +void ArrowRowBatch::templateCopyNullValue(ArrowVector* vector, + std::int64_t pos) { + setBitToZero(vector->validity.data(), pos); + vector->numNulls++; +} + +void ArrowRowBatch::copyNullValueUnion(ArrowVector* vector, const Value& value, std::int64_t pos) { + auto typeBuffer = (std::uint8_t*)vector->data.data(); + auto offsetsBuffer = (std::int32_t*)vector->overflow.data(); + typeBuffer[pos] = 0; + offsetsBuffer[pos] = vector->childData[0]->numValues; + copyNullValue(vector->childData[0].get(), *value.children[0], pos); + vector->numNulls++; +} + +static void copyArrowArray(ArrowVector* vector, std::int64_t pos, uint64_t numElements) { + setBitToZero(vector->validity.data(), pos); + vector->numNulls++; + auto& child = vector->childData[0]; + child->numValues += numElements; +} + +void ArrowRowBatch::copyNullValue(ArrowVector* vector, const Value& value, std::int64_t pos) { + switch (value.dataType.getLogicalTypeID()) { + case LogicalTypeID::BOOL: { + templateCopyNullValue(vector, pos); + } break; + case LogicalTypeID::DECIMAL: + case LogicalTypeID::INT128: { + templateCopyNullValue(vector, pos); + } break; + case LogicalTypeID::SERIAL: + case LogicalTypeID::INT64: { + templateCopyNullValue(vector, pos); + } break; + case LogicalTypeID::INT32: { + templateCopyNullValue(vector, pos); + } break; + case LogicalTypeID::INT16: { + templateCopyNullValue(vector, pos); + } break; + case LogicalTypeID::INT8: { + templateCopyNullValue(vector, pos); + } break; + case LogicalTypeID::UINT64: { + templateCopyNullValue(vector, pos); + } break; + case LogicalTypeID::UINT32: { + templateCopyNullValue(vector, pos); + } break; + case LogicalTypeID::UINT16: { + templateCopyNullValue(vector, pos); + } break; + case LogicalTypeID::UINT8: { + templateCopyNullValue(vector, pos); + } break; + case LogicalTypeID::DOUBLE: { + templateCopyNullValue(vector, pos); + } break; + case LogicalTypeID::FLOAT: { + templateCopyNullValue(vector, pos); + } break; + case LogicalTypeID::DATE: { + templateCopyNullValue(vector, pos); + } break; + case LogicalTypeID::TIMESTAMP_MS: { + templateCopyNullValue(vector, pos); + } break; + case LogicalTypeID::TIMESTAMP_NS: { + templateCopyNullValue(vector, pos); + } break; + case LogicalTypeID::TIMESTAMP_SEC: { + templateCopyNullValue(vector, pos); + } break; + case 
LogicalTypeID::TIMESTAMP_TZ: { + templateCopyNullValue(vector, pos); + } break; + case LogicalTypeID::TIMESTAMP: { + templateCopyNullValue(vector, pos); + } break; + case LogicalTypeID::INTERVAL: { + templateCopyNullValue(vector, pos); + } break; + case LogicalTypeID::UUID: { + templateCopyNullValue(vector, pos); + } break; + case LogicalTypeID::BLOB: + case LogicalTypeID::STRING: { + templateCopyNullValue(vector, pos); + } break; + case LogicalTypeID::LIST: { + templateCopyNullValue(vector, pos); + } break; + case LogicalTypeID::ARRAY: { + copyArrowArray(vector, pos, ArrayType::getNumElements(value.dataType)); + } break; + case LogicalTypeID::MAP: { + templateCopyNullValue(vector, pos); + } break; + case LogicalTypeID::INTERNAL_ID: { + templateCopyNullValue(vector, pos); + } break; + case LogicalTypeID::RECURSIVE_REL: + case LogicalTypeID::STRUCT: { + templateCopyNullValue(vector, pos); + } break; + case LogicalTypeID::UNION: { + copyNullValueUnion(vector, value, pos); + } break; + case LogicalTypeID::NODE: { + templateCopyNullValue(vector, pos); + } break; + case LogicalTypeID::REL: { + templateCopyNullValue(vector, pos); + } break; + default: { + KU_UNREACHABLE; + } + } +} + +static void releaseArrowVector(ArrowArray* array) { + if (!array || !array->release) { + return; + } + array->release = nullptr; + auto holder = static_cast(array->private_data); + delete holder; +} + +static std::unique_ptr createArrayFromVector(ArrowVector& vector) { + auto result = std::make_unique(); + result->private_data = nullptr; + result->release = releaseArrowVector; + result->n_children = 0; + result->offset = 0; + result->dictionary = nullptr; + result->buffers = vector.buffers.data(); + result->null_count = vector.numNulls; + result->length = vector.numValues; + result->n_buffers = 1; + result->buffers[0] = vector.validity.data(); + if (vector.data.data() != nullptr) { + result->n_buffers++; + result->buffers[1] = vector.data.data(); + } + return result; +} + +template +ArrowArray* ArrowRowBatch::templateCreateArray(ArrowVector& vector, const LogicalType& /*type*/, + bool) { + auto result = createArrayFromVector(vector); + vector.array = std::move(result); + return vector.array.get(); +} + +template<> +ArrowArray* ArrowRowBatch::templateCreateArray(ArrowVector& vector, + const LogicalType& /*type*/, bool) { + auto result = createArrayFromVector(vector); + result->n_buffers = 3; + result->buffers[2] = vector.overflow.data(); + vector.array = std::move(result); + return vector.array.get(); +} + +template<> +ArrowArray* ArrowRowBatch::templateCreateArray(ArrowVector& vector, + const LogicalType& type, bool fallbackExtensionTypes) { + auto result = createArrayFromVector(vector); + vector.childPointers.resize(1); + result->children = vector.childPointers.data(); + result->n_children = 1; + vector.childPointers[0] = convertVectorToArray(*vector.childData[0], + ListType::getChildType(type), fallbackExtensionTypes); + vector.array = std::move(result); + return vector.array.get(); +} + +template<> +ArrowArray* ArrowRowBatch::templateCreateArray(ArrowVector& vector, + const LogicalType& type, bool fallbackExtensionTypes) { + auto result = createArrayFromVector(vector); + vector.childPointers.resize(1); + result->n_buffers = 1; + result->children = vector.childPointers.data(); + result->n_children = 1; + vector.childPointers[0] = convertVectorToArray(*vector.childData[0], + ArrayType::getChildType(type), fallbackExtensionTypes); + vector.array = std::move(result); + return vector.array.get(); +} + +template<> 
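The createArray specializations in this file reproduce the buffer counts mandated by the Arrow C data interface: validity sits in buffer 0 whenever it is present, variable-size binary adds offsets and data, and unions replace validity with type ids and offsets. A small illustrative helper (not part of lbug) summarizing that mapping:

#include <cstdint>
#include <string_view>

static int64_t expectedArrowBufferCount(std::string_view layout) {
    if (layout == "struct" || layout == "fixed-size-list") return 1; // validity only
    if (layout == "primitive" || layout == "list") return 2;         // validity + data/offsets
    if (layout == "string" || layout == "binary") return 3;          // validity + offsets + data
    if (layout == "dense-union") return 2;                           // type ids + offsets
    return -1; // layout not covered by this sketch
}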
+ArrowArray* ArrowRowBatch::templateCreateArray(ArrowVector& vector, + const LogicalType& type, bool fallbackExtensionTypes) { + return templateCreateArray(vector, type, fallbackExtensionTypes); +} + +template<> +ArrowArray* ArrowRowBatch::templateCreateArray(ArrowVector& vector, + const LogicalType& type, bool fallbackExtensionTypes) { + return convertStructVectorToArray(vector, type, fallbackExtensionTypes); +} + +ArrowArray* ArrowRowBatch::convertStructVectorToArray(ArrowVector& vector, const LogicalType& type, + bool fallbackExtensionTypes) { + auto result = createArrayFromVector(vector); + result->n_buffers = 1; + vector.childPointers.resize(StructType::getNumFields(type)); + result->children = vector.childPointers.data(); + result->n_children = (std::int64_t)StructType::getNumFields(type); + for (auto i = 0u; i < StructType::getNumFields(type); i++) { + const auto& childType = StructType::getFieldType(type, i); + vector.childPointers[i] = + convertVectorToArray(*vector.childData[i], childType, fallbackExtensionTypes); + } + vector.array = std::move(result); + return vector.array.get(); +} + +ArrowArray* ArrowRowBatch::convertInternalIDVectorToArray(ArrowVector& vector, + const LogicalType& /*type*/, bool fallbackExtensionTypes) { + auto result = createArrayFromVector(vector); + result->n_buffers = 1; + vector.childPointers.resize(2); + result->children = vector.childPointers.data(); + result->n_children = 2; + for (auto i = 0; i < 2; i++) { + auto childType = LogicalType::INT64(); + vector.childPointers[i] = + convertVectorToArray(*vector.childData[i], childType, fallbackExtensionTypes); + } + vector.array = std::move(result); + return vector.array.get(); +} + +template<> +ArrowArray* ArrowRowBatch::templateCreateArray(ArrowVector& vector, + const LogicalType& type, bool fallbackExtensionTypes) { + // since union is a special case, we make the ArrowArray ourselves instead of using + // createArrayFromVector + auto nChildren = UnionType::getNumFields(type); + vector.array = std::make_unique(); + vector.array->private_data = nullptr; + vector.array->release = releaseArrowVector; + vector.array->n_children = nChildren; + vector.childPointers.resize(nChildren); + vector.array->children = vector.childPointers.data(); + vector.array->offset = 0; + vector.array->dictionary = nullptr; + vector.array->buffers = vector.buffers.data(); + vector.array->null_count = 0; + vector.array->length = vector.numValues; + vector.array->n_buffers = 2; + vector.array->buffers[0] = vector.data.data(); + vector.array->buffers[1] = vector.overflow.data(); + for (auto i = 0u; i < nChildren; i++) { + const auto& childType = UnionType::getFieldType(type, i); + vector.childPointers[i] = + convertVectorToArray(*vector.childData[i], childType, fallbackExtensionTypes); + } + return vector.array.get(); +} + +template<> +ArrowArray* ArrowRowBatch::templateCreateArray(ArrowVector& vector, + const LogicalType& type, bool fallbackExtensionTypes) { + return convertInternalIDVectorToArray(vector, type, fallbackExtensionTypes); +} + +template<> +ArrowArray* ArrowRowBatch::templateCreateArray(ArrowVector& vector, + const LogicalType& type, bool fallbackExtensionTypes) { + return convertStructVectorToArray(vector, type, fallbackExtensionTypes); +} + +template<> +ArrowArray* ArrowRowBatch::templateCreateArray(ArrowVector& vector, + const LogicalType& type, bool fallbackExtensionTypes) { + return convertStructVectorToArray(vector, type, fallbackExtensionTypes); +} + +ArrowArray* ArrowRowBatch::convertVectorToArray(ArrowVector& 
vector, const LogicalType& type, + bool fallbackExtensionTypes) { + switch (type.getLogicalTypeID()) { + case LogicalTypeID::BOOL: { + return templateCreateArray(vector, type, fallbackExtensionTypes); + } + case LogicalTypeID::DECIMAL: + case LogicalTypeID::INT128: { + return templateCreateArray(vector, type, fallbackExtensionTypes); + } + case LogicalTypeID::SERIAL: + case LogicalTypeID::INT64: { + return templateCreateArray(vector, type, fallbackExtensionTypes); + } + case LogicalTypeID::INT32: { + return templateCreateArray(vector, type, fallbackExtensionTypes); + } + case LogicalTypeID::INT16: { + return templateCreateArray(vector, type, fallbackExtensionTypes); + } + case LogicalTypeID::INT8: { + return templateCreateArray(vector, type, fallbackExtensionTypes); + } + case LogicalTypeID::UINT64: { + return templateCreateArray(vector, type, fallbackExtensionTypes); + } + case LogicalTypeID::UINT32: { + return templateCreateArray(vector, type, fallbackExtensionTypes); + } + case LogicalTypeID::UINT16: { + return templateCreateArray(vector, type, fallbackExtensionTypes); + } + case LogicalTypeID::UINT8: { + return templateCreateArray(vector, type, fallbackExtensionTypes); + } + case LogicalTypeID::DOUBLE: { + return templateCreateArray(vector, type, fallbackExtensionTypes); + } + case LogicalTypeID::FLOAT: { + return templateCreateArray(vector, type, fallbackExtensionTypes); + } + case LogicalTypeID::DATE: { + return templateCreateArray(vector, type, fallbackExtensionTypes); + } + case LogicalTypeID::TIMESTAMP_MS: { + return templateCreateArray(vector, type, + fallbackExtensionTypes); + } + case LogicalTypeID::TIMESTAMP_NS: { + return templateCreateArray(vector, type, + fallbackExtensionTypes); + } + case LogicalTypeID::TIMESTAMP_SEC: { + return templateCreateArray(vector, type, + fallbackExtensionTypes); + } + case LogicalTypeID::TIMESTAMP_TZ: { + return templateCreateArray(vector, type, + fallbackExtensionTypes); + } + case LogicalTypeID::TIMESTAMP: { + return templateCreateArray(vector, type, fallbackExtensionTypes); + } + case LogicalTypeID::INTERVAL: { + return templateCreateArray(vector, type, fallbackExtensionTypes); + } + case LogicalTypeID::UUID: { + if (!fallbackExtensionTypes) { + return templateCreateArray(vector, type, fallbackExtensionTypes); + } + [[fallthrough]]; + } + case LogicalTypeID::BLOB: + case LogicalTypeID::STRING: { + return templateCreateArray(vector, type, fallbackExtensionTypes); + } + case LogicalTypeID::LIST: { + return templateCreateArray(vector, type, fallbackExtensionTypes); + } + case LogicalTypeID::ARRAY: { + return templateCreateArray(vector, type, fallbackExtensionTypes); + } + case LogicalTypeID::MAP: { + return templateCreateArray(vector, type, fallbackExtensionTypes); + } + case LogicalTypeID::RECURSIVE_REL: + case LogicalTypeID::STRUCT: { + return templateCreateArray(vector, type, fallbackExtensionTypes); + } + case LogicalTypeID::UNION: { + return templateCreateArray(vector, type, fallbackExtensionTypes); + } + case LogicalTypeID::INTERNAL_ID: { + return templateCreateArray(vector, type, + fallbackExtensionTypes); + } + case LogicalTypeID::NODE: { + return templateCreateArray(vector, type, fallbackExtensionTypes); + } + case LogicalTypeID::REL: { + return templateCreateArray(vector, type, fallbackExtensionTypes); + } + default: { + KU_UNREACHABLE; + } + } +} + +ArrowArray ArrowRowBatch::toArray(const std::vector& types) { + auto rootHolder = std::make_unique(); + ArrowArray result{}; + rootHolder->childPointers.resize(vectors.size()); + 
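toArray() hands ownership of the vectors to the ArrowArray it returns (through private_data and the release callback), so per the Arrow C data interface the eventual consumer must invoke that callback exactly once. A hedged consumer-side sketch, assuming the standard ArrowArray C struct is in scope:

static void consumeBatch(ArrowArray arr) {
    // ... pass &arr to any C-data-interface aware reader here ...
    if (arr.release != nullptr) {
        arr.release(&arr); // releaseArrowVector frees the holders kept in private_data
    }
}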
result.children = rootHolder->childPointers.data(); + result.n_children = (std::int64_t)vectors.size(); + result.length = numTuples; + result.n_buffers = 1; + result.buffers = rootHolder->buffers.data(); // no actual buffer + result.offset = 0; + result.null_count = 0; + result.dictionary = nullptr; + rootHolder->childData = std::move(vectors); + for (auto i = 0u; i < rootHolder->childData.size(); i++) { + rootHolder->childPointers[i] = + convertVectorToArray(*rootHolder->childData[i], types[i], fallbackExtensionTypes); + } + result.private_data = rootHolder.release(); + result.release = releaseArrowVector; + return result; +} + +void ArrowRowBatch::append(const processor::FlatTuple& tuple) { + for (auto i = 0u; i < vectors.size(); i++) { + appendValue(vectors[i].get(), tuple[i], fallbackExtensionTypes); + } + numTuples++; +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/arrow/arrow_type.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/arrow/arrow_type.cpp new file mode 100644 index 0000000000..7d17fc541a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/arrow/arrow_type.cpp @@ -0,0 +1,151 @@ +#include "common/arrow/arrow_converter.h" +#include "common/exception/not_implemented.h" +#include "common/string_utils.h" + +namespace lbug { +namespace common { + +// Pyarrow format string specifications can be found here +// https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings + +LogicalType ArrowConverter::fromArrowSchema(const ArrowSchema* schema) { + const char* arrowType = schema->format; + std::vector structFields; + // If we have a dictionary, then the logical type of the column is dependent upon the + // logical type of the dict + if (schema->dictionary != nullptr) { + return fromArrowSchema(schema->dictionary); + } + switch (arrowType[0]) { + case 'n': + return LogicalType(LogicalTypeID::ANY); + case 'b': + return LogicalType(LogicalTypeID::BOOL); + case 'c': + return LogicalType(LogicalTypeID::INT8); + case 'C': + return LogicalType(LogicalTypeID::UINT8); + case 's': + return LogicalType(LogicalTypeID::INT16); + case 'S': + return LogicalType(LogicalTypeID::UINT16); + case 'i': + return LogicalType(LogicalTypeID::INT32); + case 'I': + return LogicalType(LogicalTypeID::UINT32); + case 'l': + return LogicalType(LogicalTypeID::INT64); + case 'L': + return LogicalType(LogicalTypeID::UINT64); + case 'e': + throw NotImplementedException("16 bit floats are not supported"); + case 'f': + return LogicalType(LogicalTypeID::FLOAT); + case 'g': + return LogicalType(LogicalTypeID::DOUBLE); + case 'z': + case 'Z': + return LogicalType(LogicalTypeID::BLOB); + case 'u': + case 'U': + return LogicalType(LogicalTypeID::STRING); + case 'v': + switch (arrowType[1]) { + case 'z': + return LogicalType(LogicalTypeID::BLOB); + case 'u': + return LogicalType(LogicalTypeID::STRING); + default: + KU_UNREACHABLE; + } + + case 'd': { + auto split = StringUtils::splitComma(std::string(arrowType + 2)); + if (split.size() > 2 && split[2] != "128") { + throw NotImplementedException("Decimal bitwidths other than 128 are not implemented"); + } + return LogicalType::DECIMAL(stoul(split[0]), stoul(split[1])); + } + case 'w': + return LogicalType(LogicalTypeID::BLOB); // fixed width binary + case 't': + switch (arrowType[1]) { + case 'd': + if (arrowType[2] == 'D') { + return LogicalType(LogicalTypeID::DATE); + } else { + return LogicalType(LogicalTypeID::TIMESTAMP_MS); + } + case 't': + // TODO implement pure time 
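For reference, a few concrete format strings and the logical types the switch in fromArrowSchema maps them to (the codes come from the Arrow C data interface document linked above):

//   "l"        -> INT64                  "u"       -> STRING
//   "d:38,10"  -> DECIMAL(38, 10)        "w:16"    -> BLOB (fixed-width binary)
//   "tdD"      -> DATE                   "tsu:"    -> TIMESTAMP (microseconds)
//   "+l"       -> LIST of child 0        "+s"      -> STRUCT over all children
//   "+m"       -> MAP (key/value taken from the entries struct)
//   "+ud:0,1"  -> UNION over the listed children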
type + throw NotImplementedException("Pure time types are not supported"); + case 's': + // TODO maxwell: timezone support + switch (arrowType[2]) { + case 's': + return LogicalType(LogicalTypeID::TIMESTAMP_SEC); + case 'm': + return LogicalType(LogicalTypeID::TIMESTAMP_MS); + case 'u': + return LogicalType(LogicalTypeID::TIMESTAMP); + case 'n': + return LogicalType(LogicalTypeID::TIMESTAMP_NS); + default: + KU_UNREACHABLE; + } + case 'D': + // duration + case 'i': + // interval + return LogicalType(LogicalTypeID::INTERVAL); + default: + KU_UNREACHABLE; + } + case '+': + KU_ASSERT(schema->n_children > 0); + switch (arrowType[1]) { + // complex types need a complementary ExtraTypeInfo object + case 'l': + case 'L': + return LogicalType::LIST(LogicalType(fromArrowSchema(schema->children[0]))); + case 'w': + return LogicalType::ARRAY(LogicalType(fromArrowSchema(schema->children[0])), + std::stoul(arrowType + 3)); + case 's': + for (int64_t i = 0; i < schema->n_children; i++) { + structFields.emplace_back(std::string(schema->children[i]->name), + LogicalType(fromArrowSchema(schema->children[i]))); + } + return LogicalType::STRUCT(std::move(structFields)); + case 'm': + return LogicalType::MAP(LogicalType(fromArrowSchema(schema->children[0]->children[0])), + LogicalType(fromArrowSchema(schema->children[0]->children[1]))); + case 'u': { + for (int64_t i = 0; i < schema->n_children; i++) { + structFields.emplace_back(std::to_string(i), + LogicalType(fromArrowSchema(schema->children[i]))); + } + return LogicalType::UNION(std::move(structFields)); + } + case 'v': + switch (arrowType[2]) { + case 'l': + case 'L': + return LogicalType::LIST(LogicalType(fromArrowSchema(schema->children[0]))); + default: + KU_UNREACHABLE; + } + case 'r': + // logical type corresponds to second child + return fromArrowSchema(schema->children[1]); + default: + KU_UNREACHABLE; + } + default: + KU_UNREACHABLE; + } + // refer to arrow_converted.cpp:65 +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/case_insensitive_map.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/case_insensitive_map.cpp new file mode 100644 index 0000000000..b00a217757 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/case_insensitive_map.cpp @@ -0,0 +1,18 @@ +#include "common/case_insensitive_map.h" + +#include "common/string_utils.h" + +namespace lbug { +namespace common { + +uint64_t CaseInsensitiveStringHashFunction::operator()(const std::string& str) const { + return common::StringUtils::caseInsensitiveHash(str); +} + +bool CaseInsensitiveStringEquality::operator()(const std::string& lhs, + const std::string& rhs) const { + return common::StringUtils::caseInsensitiveEquals(lhs, rhs); +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/checksum.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/checksum.cpp new file mode 100644 index 0000000000..2abeb6cc6a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/checksum.cpp @@ -0,0 +1,88 @@ +/** + * The implementation for checksumming is taken from DuckDB: + * https://github.com/duckdb/duckdb/blob/v1.3-ossivalis/src/common/checksum.cpp + * https://github.com/duckdb/duckdb/blob/v1.3-ossivalis/LICENSE + */ + +#include "common/checksum.h" + +#include "common/types/types.h" + +namespace lbug::common { + +hash_t checksum(uint64_t x) { + return x * UINT64_C(0xbf58476d1ce4e5b9); +} + +// MIT License +// Copyright (c) 2018-2021 Martin Ankerl +// 
https://github.com/martinus/robin-hood-hashing/blob/3.11.5/LICENSE +hash_t checksumRemainder(void* ptr, size_t len) noexcept { + static constexpr uint64_t M = UINT64_C(0xc6a4a7935bd1e995); + static constexpr uint64_t SEED = UINT64_C(0xe17a1465); + static constexpr unsigned int R = 47; + + auto const* const data64 = static_cast(ptr); + uint64_t h = SEED ^ (len * M); + + size_t const n_blocks = len / 8; + for (size_t i = 0; i < n_blocks; ++i) { + auto k = *reinterpret_cast(data64 + i); + + k *= M; + k ^= k >> R; + k *= M; + + h ^= k; + h *= M; + } + + auto const* const data8 = reinterpret_cast(data64 + n_blocks); + switch (len & 7U) { + case 7: + h ^= static_cast(data8[6]) << 48U; + [[fallthrough]]; + case 6: + h ^= static_cast(data8[5]) << 40U; + [[fallthrough]]; + case 5: + h ^= static_cast(data8[4]) << 32U; + [[fallthrough]]; + case 4: + h ^= static_cast(data8[3]) << 24U; + [[fallthrough]]; + case 3: + h ^= static_cast(data8[2]) << 16U; + [[fallthrough]]; + case 2: + h ^= static_cast(data8[1]) << 8U; + [[fallthrough]]; + case 1: + h ^= static_cast(data8[0]); + h *= M; + [[fallthrough]]; + default: + break; + } + h ^= h >> R; + h *= M; + h ^= h >> R; + return static_cast(h); +} + +uint64_t checksum(uint8_t* buffer, size_t size) { + uint64_t result = 5381; + uint64_t* ptr = reinterpret_cast(buffer); + size_t i{}; + // for efficiency, we first checksum uint64_t values + for (i = 0; i < size / 8; i++) { + result ^= checksum(ptr[i]); + } + if (size - i * 8 > 0) { + // the remaining 0-7 bytes we hash using a string hash + result ^= checksumRemainder(buffer + i * 8, size - i * 8); + } + return result; +} + +} // namespace lbug::common diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/constants.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/constants.cpp new file mode 100644 index 0000000000..ab9e2f3bec --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/constants.cpp @@ -0,0 +1,6 @@ +namespace lbug { +namespace common { + +const char* LBUG_VERSION = LBUG_CMAKE_VERSION; +} +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/copier_config/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/common/copier_config/CMakeLists.txt new file mode 100644 index 0000000000..2ea2e7492a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/copier_config/CMakeLists.txt @@ -0,0 +1,8 @@ +add_library(lbug_common_copier_config + OBJECT + csv_reader_config.cpp + reader_config.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/copier_config/csv_reader_config.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/copier_config/csv_reader_config.cpp new file mode 100644 index 0000000000..39f978efb2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/copier_config/csv_reader_config.cpp @@ -0,0 +1,170 @@ +#include "common/copier_config/csv_reader_config.h" + +#include + +#include "common/exception/binder.h" +#include "common/exception/runtime.h" +#include "common/string_utils.h" +#include "common/types/value/nested.h" + +namespace lbug { +namespace common { + +static char bindParsingOptionValue(const std::string& value) { + if (value == "\\t") { + return '\t'; + } + if (value.length() < 1 || value.length() > 2 || (value.length() == 2 && value[0] != '\\')) { + throw BinderException("Copy csv option value must be a single character with an " + "optional escape character."); + } + return value[value.length() - 1]; +} + +static void bindBoolParsingOption(CSVReaderConfig& config, const 
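The checksum() overloads above combine a multiplicative mix of each aligned 64-bit word with a MurmurHash-style pass over the trailing 0 to 7 bytes. A hedged usage sketch (the header path is taken from the include in this file; the buffer contents are arbitrary):

#include <cstdint>
#include <cstdio>
#include <vector>

#include "common/checksum.h"

int main() {
    std::vector<uint8_t> page(4096, 0xAB);
    const uint64_t sum = lbug::common::checksum(page.data(), page.size());
    std::printf("page checksum: %llu\n", static_cast<unsigned long long>(sum));
    return 0;
}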
std::string& optionName, + bool optionValue) { + if (optionName == "HEADER") { + config.option.hasHeader = optionValue; + config.option.setHeader = true; + } else if (optionName == "PARALLEL") { + config.parallel = optionValue; + } else if (optionName == "LIST_UNBRACED") { + config.option.allowUnbracedList = optionValue; + } else if (optionName == CopyConstants::IGNORE_ERRORS_OPTION_NAME) { + config.option.ignoreErrors = optionValue; + } else if (optionName == "AUTODETECT" || optionName == "AUTO_DETECT") { + config.option.autoDetection = optionValue; + } else { + KU_UNREACHABLE; + } +} + +static void bindStringParsingOption(CSVReaderConfig& config, const std::string& optionName, + const std::string& optionValue) { + auto parsingOptionValue = bindParsingOptionValue(optionValue); + if (optionName == "ESCAPE") { + config.option.escapeChar = parsingOptionValue; + config.option.setEscape = true; + } else if (optionName == "DELIM" || optionName == "DELIMITER") { + config.option.delimiter = parsingOptionValue; + config.option.setDelim = true; + } else if (optionName == "QUOTE") { + config.option.quoteChar = parsingOptionValue; + config.option.setQuote = true; + } else { + KU_UNREACHABLE; + } +} + +static void bindIntParsingOption(CSVReaderConfig& config, const std::string& optionName, + const int64_t& optionValue) { + if (optionName == "SKIP") { + if (optionValue < 0) { + throw RuntimeException{"Skip number must be a non-negative integer"}; + } + config.option.skipNum = optionValue; + } else if (optionName == "SAMPLE_SIZE") { + if (optionValue < 0) { + // technically impossible at the moment since negative values aren't supported + // in parameters + throw RuntimeException{"Sample size must be a non-negative integer"}; + } + config.option.sampleSize = optionValue; + } else { + KU_UNREACHABLE; + } +} + +static void bindListParsingOption(CSVReaderConfig& config, const std::string& optionName, + const std::vector& optionValue) { + if (optionName == "NULL_STRINGS") { + config.option.nullStrings = optionValue; + } else { + KU_UNREACHABLE; + } +} + +template +static bool hasOption(const char* const (&arr)[size], const std::string& option) { + return std::find(std::begin(arr), std::end(arr), option) != std::end(arr); +} + +static bool validateBoolParsingOptionName(const std::string& parsingOptionName) { + return hasOption(CopyConstants::BOOL_CSV_PARSING_OPTIONS, parsingOptionName); +} + +static bool validateStringParsingOptionName(const std::string& parsingOptionName) { + return hasOption(CopyConstants::STRING_CSV_PARSING_OPTIONS, parsingOptionName); +} + +static bool validateIntParsingOptionName(const std::string& parsingOptionName) { + return hasOption(CopyConstants::INT_CSV_PARSING_OPTIONS, parsingOptionName); +} + +static bool validateListParsingOptionName(const std::string& parsingOptionName) { + return hasOption(CopyConstants::LIST_CSV_PARSING_OPTIONS, parsingOptionName); +} + +static bool isValidBooleanOptionValue(const Value& value, const std::string& name) { + // Normalize and check if the string is a valid Boolean representation + auto strValue = value.toString(); + StringUtils::toUpper(strValue); + + // Check for valid Boolean string representations + if (strValue == "TRUE" || strValue == "1") { + return true; + } else if (strValue == "FALSE" || strValue == "0") { + return false; + } else { + // In this case the boolean is not valid + throw BinderException( + stringFormat("The type of csv parsing option {} must be a boolean.", name)); + } +} + +CSVReaderConfig CSVReaderConfig::construct(const 
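Putting the binders above together, each recognized option name is dispatched by the type of its bound value, and construct() below enforces that typing. For orientation (option names are taken from the code; the COPY statement at the end is only an illustrative example of where these options come from):

//   BOOL    : HEADER, PARALLEL, LIST_UNBRACED, AUTODETECT / AUTO_DETECT,
//             CopyConstants::IGNORE_ERRORS_OPTION_NAME
//   STRING  : DELIM / DELIMITER, QUOTE, ESCAPE   (single character, "\\t" allowed)
//   INT64   : SKIP, SAMPLE_SIZE                  (must be non-negative; SKIP > 0
//                                                 also disables parallel reading)
//   STRING[]: NULL_STRINGS
//   e.g.  COPY person FROM "person.csv" (HEADER=true, DELIM="|", SKIP=2);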
case_insensitive_map_t& options) { + auto config = CSVReaderConfig(); + for (auto& op : options) { + auto name = op.first; + auto isValidStringParsingOption = validateStringParsingOptionName(name); + auto isValidBoolParsingOption = validateBoolParsingOptionName(name); + auto isValidIntParsingOption = validateIntParsingOptionName(name); + auto isValidListParsingOption = validateListParsingOptionName(name); + if (isValidBoolParsingOption) { + bindBoolParsingOption(config, name, isValidBooleanOptionValue(op.second, name)); + } else if (isValidStringParsingOption) { + if (op.second.getDataType() != LogicalType::STRING()) { + throw BinderException( + stringFormat("The type of csv parsing option {} must be a string.", name)); + } + bindStringParsingOption(config, name, op.second.getValue()); + } else if (isValidIntParsingOption) { + if (op.second.getDataType() != LogicalType::INT64()) { + throw BinderException( + stringFormat("The type of csv parsing option {} must be a INT64.", name)); + } + bindIntParsingOption(config, name, op.second.getValue()); + } else if (isValidListParsingOption) { + if (op.second.getDataType() != LogicalType::LIST(LogicalType::STRING())) { + throw BinderException( + stringFormat("The type of csv parsing option {} must be a STRING[].", name)); + } + std::vector optionValues; + for (auto i = 0u; i < op.second.getChildrenSize(); i++) { + optionValues.push_back( + NestedVal::getChildVal(&op.second, i)->getValue()); + } + bindListParsingOption(config, name, optionValues); + } else { + throw BinderException(stringFormat("Unrecognized csv parsing option: {}.", name)); + } + } + if (config.option.skipNum > 0) { + // If the user sets the number of rows to skip, we cannot read in parallel mode. + config.parallel = false; + } + return config; +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/copier_config/reader_config.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/copier_config/reader_config.cpp new file mode 100644 index 0000000000..2aaae3b2d1 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/copier_config/reader_config.cpp @@ -0,0 +1,56 @@ +#include "common/assert.h" +#include "common/copier_config/file_scan_info.h" +#include "common/string_utils.h" + +namespace lbug { +namespace common { + +FileType FileTypeUtils::getFileTypeFromExtension(std::string_view extension) { + if (extension == ".csv") { + return FileType::CSV; + } + if (extension == ".parquet") { + return FileType::PARQUET; + } + if (extension == ".npy") { + return FileType::NPY; + } + return FileType::UNKNOWN; +} + +std::string FileTypeUtils::toString(FileType fileType) { + switch (fileType) { + case FileType::UNKNOWN: { + return "UNKNOWN"; + } + case FileType::CSV: { + return "CSV"; + } + case FileType::PARQUET: { + return "PARQUET"; + } + case FileType::NPY: { + return "NPY"; + } + default: { + KU_UNREACHABLE; + } + } +} + +FileType FileTypeUtils::fromString(std::string fileType) { + fileType = common::StringUtils::getUpper(fileType); + if (fileType == "CSV") { + return FileType::CSV; + } else if (fileType == "PARQUET") { + return FileType::PARQUET; + } else if (fileType == "NPY") { + return FileType::NPY; + } else { + return FileType::UNKNOWN; + // throw BinderException(stringFormat("Unsupported file type: {}.", fileType)); + } +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/data_chunk/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/common/data_chunk/CMakeLists.txt new file mode 
100644 index 0000000000..d2fe29e0de --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/data_chunk/CMakeLists.txt @@ -0,0 +1,10 @@ +add_library(lbug_common_data_chunk + OBJECT + data_chunk.cpp + data_chunk_collection.cpp + data_chunk_state.cpp + sel_vector.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/data_chunk/data_chunk.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/data_chunk/data_chunk.cpp new file mode 100644 index 0000000000..85d08396a5 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/data_chunk/data_chunk.cpp @@ -0,0 +1,19 @@ +#include "common/data_chunk/data_chunk.h" + +namespace lbug { +namespace common { + +void DataChunk::insert(uint32_t pos, std::shared_ptr valueVector) { + valueVector->setState(state); + KU_ASSERT(valueVectors.size() > pos); + valueVectors[pos] = std::move(valueVector); +} + +void DataChunk::resetAuxiliaryBuffer() { + for (auto& valueVector : valueVectors) { + valueVector->resetAuxiliaryBuffer(); + } +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/data_chunk/data_chunk_collection.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/data_chunk/data_chunk_collection.cpp new file mode 100644 index 0000000000..1b02a69bc4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/data_chunk/data_chunk_collection.cpp @@ -0,0 +1,57 @@ +#include "common/data_chunk/data_chunk_collection.h" + +#include "common/system_config.h" + +namespace lbug { +namespace common { + +DataChunkCollection::DataChunkCollection(storage::MemoryManager* mm) : mm{mm} {} + +void DataChunkCollection::append(DataChunk& chunk) { + auto numTuplesToAppend = chunk.state->getSelVector().getSelSize(); + auto numTuplesAppended = 0u; + while (numTuplesAppended < numTuplesToAppend) { + if (chunks.empty() || + chunks.back().state->getSelVector().getSelSize() == DEFAULT_VECTOR_CAPACITY) { + allocateChunk(chunk); + } + auto& chunkToAppend = chunks.back(); + auto numTuplesToCopy = std::min((uint64_t)numTuplesToAppend - numTuplesAppended, + DEFAULT_VECTOR_CAPACITY - chunkToAppend.state->getSelVector().getSelSize()); + for (auto vectorIdx = 0u; vectorIdx < chunk.getNumValueVectors(); vectorIdx++) { + for (auto i = 0u; i < numTuplesToCopy; i++) { + auto srcPos = chunk.state->getSelVector()[numTuplesAppended + i]; + auto dstPos = chunkToAppend.state->getSelVector().getSelSize() + i; + chunkToAppend.getValueVectorMutable(vectorIdx).copyFromVectorData(dstPos, + &chunk.getValueVector(vectorIdx), srcPos); + } + } + chunkToAppend.state->getSelVectorUnsafe().incrementSelSize(numTuplesToCopy); + numTuplesAppended += numTuplesToCopy; + } +} + +void DataChunkCollection::initTypes(const DataChunk& chunk) { + types.clear(); + types.reserve(chunk.getNumValueVectors()); + for (auto vectorIdx = 0u; vectorIdx < chunk.getNumValueVectors(); vectorIdx++) { + types.push_back(chunk.getValueVector(vectorIdx).dataType.copy()); + } +} + +void DataChunkCollection::allocateChunk(const DataChunk& chunk) { + if (chunks.empty()) { + types.reserve(chunk.getNumValueVectors()); + for (auto vectorIdx = 0u; vectorIdx < chunk.getNumValueVectors(); vectorIdx++) { + types.push_back(chunk.getValueVector(vectorIdx).dataType.copy()); + } + } + DataChunk newChunk(types.size(), std::make_shared()); + for (auto i = 0u; i < types.size(); i++) { + newChunk.insert(i, std::make_shared(types[i].copy(), mm)); + } + chunks.push_back(std::move(newChunk)); +} + +} // namespace common +} // namespace 
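DataChunkCollection::append above repacks incoming tuples into chunks holding at most DEFAULT_VECTOR_CAPACITY selected values, allocating a fresh chunk whenever the last one is full. The per-iteration copy count is simply the smaller of what remains and what still fits; a tiny sketch of that arithmetic (2048 is only an assumed value for DEFAULT_VECTOR_CAPACITY):

#include <algorithm>
#include <cstdint>

static uint64_t tuplesToCopy(uint64_t remainingToAppend, uint64_t usedInLastChunk,
    uint64_t vectorCapacity = 2048 /* assumed DEFAULT_VECTOR_CAPACITY */) {
    return std::min(remainingToAppend, vectorCapacity - usedInLastChunk);
}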
lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/data_chunk/data_chunk_state.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/data_chunk/data_chunk_state.cpp new file mode 100644 index 0000000000..36afb266a7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/data_chunk/data_chunk_state.cpp @@ -0,0 +1,18 @@ +#include "common/data_chunk/data_chunk_state.h" + +#include "common/system_config.h" + +namespace lbug { +namespace common { + +DataChunkState::DataChunkState() : DataChunkState(DEFAULT_VECTOR_CAPACITY) {} + +std::shared_ptr DataChunkState::getSingleValueDataChunkState() { + auto state = std::make_shared(1); + state->initOriginalAndSelectedSize(1); + state->setToFlat(); + return state; +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/data_chunk/sel_vector.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/data_chunk/sel_vector.cpp new file mode 100644 index 0000000000..14aec93f50 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/data_chunk/sel_vector.cpp @@ -0,0 +1,48 @@ +#include "common/data_chunk/sel_vector.h" + +#include +#include + +#include "common/system_config.h" +#include "common/types/types.h" +#include "common/vector/value_vector.h" + +namespace lbug { +namespace common { + +// NOLINTNEXTLINE(cert-err58-cpp): always evaluated at compile time, and even not it would not throw +static const std::array INCREMENTAL_SELECTED_POS = + []() constexpr noexcept { + std::array selectedPos{}; + std::iota(selectedPos.begin(), selectedPos.end(), 0); + return selectedPos; + }(); + +SelectionView::SelectionView(sel_t selectedSize) + : selectedPositions{INCREMENTAL_SELECTED_POS.data()}, selectedSize{selectedSize}, + state{State::STATIC} {} + +SelectionVector::SelectionVector() : SelectionVector{DEFAULT_VECTOR_CAPACITY} {} + +void SelectionVector::setToUnfiltered() { + selectedPositions = INCREMENTAL_SELECTED_POS.data(); + state = State::STATIC; +} +void SelectionVector::setToUnfiltered(sel_t size) { + KU_ASSERT(size <= capacity); + selectedPositions = INCREMENTAL_SELECTED_POS.data(); + selectedSize = size; + state = State::STATIC; +} + +std::vector SelectionVector::fromValueVectors( + const std::vector>& vec) { + std::vector ret(vec.size()); + for (size_t i = 0; i < vec.size(); ++i) { + ret[i] = vec[i]->getSelVectorPtr(); + } + return ret; +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/database_lifecycle_manager.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/database_lifecycle_manager.cpp new file mode 100644 index 0000000000..48b7714f41 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/database_lifecycle_manager.cpp @@ -0,0 +1,14 @@ +#include "common/database_lifecycle_manager.h" + +#include "common/exception/runtime.h" + +namespace lbug { +namespace common { +void DatabaseLifeCycleManager::checkDatabaseClosedOrThrow() const { + if (isDatabaseClosed) { + throw RuntimeException( + "The current operation is not allowed because the parent database is closed."); + } +} +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/enums/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/common/enums/CMakeLists.txt new file mode 100644 index 0000000000..f789e52a0d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/enums/CMakeLists.txt @@ -0,0 +1,17 @@ +add_library(lbug_common_enums + OBJECT + accumulate_type.cpp + path_semantic.cpp + query_rel_type.cpp + rel_direction.cpp + 
rel_multiplicity.cpp + scan_source_type.cpp + table_type.cpp + transaction_action.cpp + drop_type.cpp + extend_direction_util.cpp + conflict_action.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/enums/accumulate_type.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/enums/accumulate_type.cpp new file mode 100644 index 0000000000..b6a2d56e91 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/enums/accumulate_type.cpp @@ -0,0 +1,22 @@ +#include "common/enums/accumulate_type.h" + +#include "common/assert.h" + +namespace lbug { +namespace common { + +std::string AccumulateTypeUtil::toString(AccumulateType type) { + switch (type) { + case AccumulateType::REGULAR: { + return "REGULAR"; + } + case AccumulateType::OPTIONAL_: { + return "OPTIONAL"; + } + default: + KU_UNREACHABLE; + } +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/enums/conflict_action.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/enums/conflict_action.cpp new file mode 100644 index 0000000000..054069eb18 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/enums/conflict_action.cpp @@ -0,0 +1,22 @@ +#include "common/enums/conflict_action.h" + +#include "common/assert.h" + +namespace lbug { +namespace common { + +std::string ConflictActionUtil::toString(ConflictAction action) { + switch (action) { + case ConflictAction::ON_CONFLICT_THROW: { + return "ON_CONFLICT_THROW"; + } + case ConflictAction::ON_CONFLICT_DO_NOTHING: { + return "ON_CONFLICT_DO_NOTHING"; + } + default: + KU_UNREACHABLE; + } +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/enums/drop_type.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/enums/drop_type.cpp new file mode 100644 index 0000000000..3ac4ad7630 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/enums/drop_type.cpp @@ -0,0 +1,20 @@ +#include "common/enums/drop_type.h" + +#include "common/assert.h" + +namespace lbug { +namespace common { + +std::string DropTypeUtils::toString(DropType type) { + switch (type) { + case DropType::TABLE: + return "Table"; + case DropType::SEQUENCE: + return "Sequence"; + default: + KU_UNREACHABLE; + } +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/enums/extend_direction_util.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/enums/extend_direction_util.cpp new file mode 100644 index 0000000000..fc4f249bda --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/enums/extend_direction_util.cpp @@ -0,0 +1,36 @@ +#include "common/enums/extend_direction_util.h" + +#include "common/exception/runtime.h" +#include "common/string_utils.h" + +namespace lbug { +namespace common { + +ExtendDirection ExtendDirectionUtil::fromString(const std::string& str) { + auto normalizedString = StringUtils::getUpper(str); + if (normalizedString == "FWD") { + return ExtendDirection::FWD; + } else if (normalizedString == "BWD") { + return ExtendDirection::BWD; + } else if (normalizedString == "BOTH") { + return ExtendDirection::BOTH; + } else { + throw RuntimeException(stringFormat("Cannot parse {} as ExtendDirection.", str)); + } +} + +std::string ExtendDirectionUtil::toString(ExtendDirection direction) { + switch (direction) { + case ExtendDirection::FWD: + return "fwd"; + case ExtendDirection::BWD: + return "bwd"; + case ExtendDirection::BOTH: + return "both"; + default: + KU_UNREACHABLE; + } +} + +} // namespace 
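The fromString helpers in these enum utilities all normalise their input to upper case before matching, so "fwd", "Fwd" and "FWD" parse identically. A standalone sketch of that pattern, using ExtendDirection as the example; the local toUpper replaces StringUtils::getUpper, which is defined outside this diff:

#include <algorithm>
#include <cctype>
#include <stdexcept>
#include <string>

enum class Direction { FWD, BWD, BOTH };

// Local stand-in for StringUtils::getUpper.
static std::string toUpper(std::string s) {
    std::transform(s.begin(), s.end(), s.begin(),
        [](unsigned char c) { return static_cast<char>(std::toupper(c)); });
    return s;
}

// Case-insensitive parse mirroring ExtendDirectionUtil::fromString: anything other
// than the three known spellings is rejected with an error naming the raw input.
Direction parseDirection(const std::string& str) {
    auto norm = toUpper(str);
    if (norm == "FWD") return Direction::FWD;
    if (norm == "BWD") return Direction::BWD;
    if (norm == "BOTH") return Direction::BOTH;
    throw std::invalid_argument("Cannot parse '" + str + "' as ExtendDirection.");
}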
common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/enums/path_semantic.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/enums/path_semantic.cpp new file mode 100644 index 0000000000..9ab229e09c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/enums/path_semantic.cpp @@ -0,0 +1,40 @@ +#include "common/enums/path_semantic.h" + +#include "common/assert.h" +#include "common/exception/binder.h" +#include "common/string_format.h" +#include "common/string_utils.h" + +namespace lbug { +namespace common { + +PathSemantic PathSemanticUtils::fromString(const std::string& str) { + auto normalizedStr = StringUtils::getUpper(str); + if (normalizedStr == "WALK") { + return PathSemantic::WALK; + } + if (normalizedStr == "TRAIL") { + return PathSemantic::TRAIL; + } + if (normalizedStr == "ACYCLIC") { + return PathSemantic::ACYCLIC; + } + throw BinderException(stringFormat( + "Cannot parse {} as a path semantic. Supported inputs are [WALK, TRAIL, ACYCLIC]", str)); +} + +std::string PathSemanticUtils::toString(PathSemantic semantic) { + switch (semantic) { + case PathSemantic::WALK: + return "WALK"; + case PathSemantic::TRAIL: + return "TRAIL"; + case PathSemantic::ACYCLIC: + return "ACYCLIC"; + default: + KU_UNREACHABLE; + } +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/enums/query_rel_type.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/enums/query_rel_type.cpp new file mode 100644 index 0000000000..fb9f417b46 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/enums/query_rel_type.cpp @@ -0,0 +1,53 @@ +#include "common/enums/query_rel_type.h" + +#include "common/assert.h" +#include "function/gds/gds_function_collection.h" + +using namespace lbug::function; + +namespace lbug { +namespace common { + +PathSemantic QueryRelTypeUtils::getPathSemantic(QueryRelType queryRelType) { + switch (queryRelType) { + case QueryRelType::VARIABLE_LENGTH_WALK: + return PathSemantic::WALK; + case QueryRelType::VARIABLE_LENGTH_TRAIL: + return PathSemantic::TRAIL; + case QueryRelType::VARIABLE_LENGTH_ACYCLIC: + case QueryRelType::SHORTEST: + case QueryRelType::ALL_SHORTEST: + case QueryRelType::WEIGHTED_SHORTEST: + case QueryRelType::ALL_WEIGHTED_SHORTEST: + return PathSemantic::ACYCLIC; + default: + KU_UNREACHABLE; + } +} + +std::unique_ptr QueryRelTypeUtils::getFunction(QueryRelType type) { + switch (type) { + case QueryRelType::VARIABLE_LENGTH_WALK: + case QueryRelType::VARIABLE_LENGTH_TRAIL: + case QueryRelType::VARIABLE_LENGTH_ACYCLIC: { + return VarLenJoinsFunction::getAlgorithm(); + } + case QueryRelType::SHORTEST: { + return SingleSPPathsFunction::getAlgorithm(); + } + case QueryRelType::ALL_SHORTEST: { + return AllSPPathsFunction::getAlgorithm(); + } + case QueryRelType::WEIGHTED_SHORTEST: { + return WeightedSPPathsFunction::getAlgorithm(); + } + case QueryRelType::ALL_WEIGHTED_SHORTEST: { + return AllWeightedSPPathsFunction::getAlgorithm(); + } + default: + KU_UNREACHABLE; + } +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/enums/rel_direction.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/enums/rel_direction.cpp new file mode 100644 index 0000000000..310cf9da33 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/enums/rel_direction.cpp @@ -0,0 +1,44 @@ +#include "common/enums/rel_direction.h" + +#include + +#include "common/assert.h" + +namespace lbug { +namespace common { + +RelDataDirection 
RelDirectionUtils::getOppositeDirection(RelDataDirection direction) { + static constexpr std::array oppositeDirections = {RelDataDirection::BWD, RelDataDirection::FWD}; + return oppositeDirections[relDirectionToKeyIdx(direction)]; +} + +std::string RelDirectionUtils::relDirectionToString(RelDataDirection direction) { + static constexpr std::array directionStrs = {"fwd", "bwd"}; + return directionStrs[relDirectionToKeyIdx(direction)]; +} + +idx_t RelDirectionUtils::relDirectionToKeyIdx(RelDataDirection direction) { + switch (direction) { + case RelDataDirection::FWD: + return 0; + case RelDataDirection::BWD: + return 1; + default: + KU_UNREACHABLE; + } +} + +table_id_t RelDirectionUtils::getNbrTableID(RelDataDirection direction, table_id_t srcTableID, + table_id_t dstTableID) { + switch (direction) { + case RelDataDirection::FWD: + return dstTableID; + case RelDataDirection::BWD: + return srcTableID; + default: + KU_UNREACHABLE; + } +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/enums/rel_multiplicity.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/enums/rel_multiplicity.cpp new file mode 100644 index 0000000000..2f1a6bef16 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/enums/rel_multiplicity.cpp @@ -0,0 +1,43 @@ +#include "common/enums/rel_multiplicity.h" + +#include "common/assert.h" +#include "common/exception/binder.h" +#include "common/string_format.h" +#include "common/string_utils.h" + +namespace lbug { +namespace common { + +RelMultiplicity RelMultiplicityUtils::getFwd(const std::string& str) { + auto normStr = common::StringUtils::getUpper(str); + if ("ONE_ONE" == normStr || "ONE_MANY" == normStr) { + return RelMultiplicity::ONE; + } else if ("MANY_ONE" == normStr || "MANY_MANY" == normStr) { + return RelMultiplicity::MANY; + } + throw BinderException(stringFormat("Cannot bind {} as relationship multiplicity.", str)); +} + +RelMultiplicity RelMultiplicityUtils::getBwd(const std::string& str) { + auto normStr = common::StringUtils::getUpper(str); + if ("ONE_ONE" == normStr || "MANY_ONE" == normStr) { + return RelMultiplicity::ONE; + } else if ("ONE_MANY" == normStr || "MANY_MANY" == normStr) { + return RelMultiplicity::MANY; + } + throw BinderException(stringFormat("Cannot bind {} as relationship multiplicity.", str)); +} + +std::string RelMultiplicityUtils::toString(RelMultiplicity multiplicity) { + switch (multiplicity) { + case RelMultiplicity::ONE: + return "ONE"; + case RelMultiplicity::MANY: + return "MANY"; + default: + KU_UNREACHABLE; + } +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/enums/scan_source_type.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/enums/scan_source_type.cpp new file mode 100644 index 0000000000..84908a2b26 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/enums/scan_source_type.cpp @@ -0,0 +1,28 @@ +#include "common/enums/scan_source_type.h" + +#include "common/assert.h" + +namespace lbug { +namespace common { + +std::string ScanSourceTypeUtils::toString(ScanSourceType type) { + switch (type) { + case ScanSourceType::EMPTY: { + return "EMPTY"; + } + case ScanSourceType::FILE: { + return "FILE"; + } + case ScanSourceType::OBJECT: { + return "OBJECT"; + } + case ScanSourceType::QUERY: { + return "QUERY"; + } + default: + KU_UNREACHABLE; + } +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/enums/table_type.cpp 
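RelMultiplicityUtils splits the table-level spelling "X_Y" into a forward and a backward multiplicity: getFwd reads the part before the underscore, getBwd the part after it. A small sketch of that decomposition with local stand-in names; case normalisation is omitted here, whereas the real helpers upper-case the input first:

#include <stdexcept>
#include <string>
#include <utility>

enum class Multiplicity { ONE, MANY };

static Multiplicity parseToken(const std::string& token) {
    if (token == "ONE") return Multiplicity::ONE;
    if (token == "MANY") return Multiplicity::MANY;
    throw std::invalid_argument("Cannot bind '" + token + "' as relationship multiplicity.");
}

// "MANY_ONE" -> {MANY, ONE}: forward multiplicity comes before the underscore,
// backward after it, matching the four cases accepted by getFwd / getBwd.
std::pair<Multiplicity, Multiplicity> parseFwdBwd(const std::string& spelling) {
    auto pos = spelling.find('_');
    if (pos == std::string::npos) {
        throw std::invalid_argument("Expected a spelling of the form ONE_MANY.");
    }
    return {parseToken(spelling.substr(0, pos)), parseToken(spelling.substr(pos + 1))};
}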
b/graph-wasm/lbug-0.12.2/lbug-src/src/common/enums/table_type.cpp new file mode 100644 index 0000000000..e68252ff6d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/enums/table_type.cpp @@ -0,0 +1,28 @@ +#include "common/enums/table_type.h" + +#include "common/assert.h" + +namespace lbug { +namespace common { + +std::string TableTypeUtils::toString(TableType tableType) { + switch (tableType) { + case TableType::UNKNOWN: { + return "UNKNOWN"; + } + case TableType::NODE: { + return "NODE"; + } + case TableType::REL: { + return "REL"; + } + case TableType::FOREIGN: { + return "ATTACHED"; + } + default: + KU_UNREACHABLE; + } +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/enums/transaction_action.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/enums/transaction_action.cpp new file mode 100644 index 0000000000..5321b61b08 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/enums/transaction_action.cpp @@ -0,0 +1,31 @@ +#include "transaction/transaction_action.h" + +#include "common/assert.h" + +namespace lbug { +namespace transaction { + +std::string TransactionActionUtils::toString(TransactionAction action) { + switch (action) { + case TransactionAction::BEGIN_READ: { + return "BEGIN_READ"; + } + case TransactionAction::BEGIN_WRITE: { + return "BEGIN_WRITE"; + } + case TransactionAction::COMMIT: { + return "COMMIT"; + } + case TransactionAction::ROLLBACK: { + return "ROLLBACK"; + } + case TransactionAction::CHECKPOINT: { + return "CHECKPOINT"; + } + default: + KU_UNREACHABLE; + } +} + +} // namespace transaction +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/exception/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/common/exception/CMakeLists.txt new file mode 100644 index 0000000000..7f7036a7ad --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/exception/CMakeLists.txt @@ -0,0 +1,12 @@ +add_library(lbug_common_exception + OBJECT + exception.cpp + message.cpp) + +if (ENABLE_BACKTRACES) + target_link_libraries(lbug_common_exception PRIVATE cpptrace::cpptrace) +endif() + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/exception/exception.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/exception/exception.cpp new file mode 100644 index 0000000000..550d1bb567 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/exception/exception.cpp @@ -0,0 +1,17 @@ +#include "common/exception/exception.h" + +#ifdef LBUG_BACKTRACE +#include +#endif + +namespace lbug { +namespace common { + +Exception::Exception(std::string msg) : exception(), exception_message_(std::move(msg)) { +#ifdef LBUG_BACKTRACE + cpptrace::generate_trace(1 /*skip this function's frame*/).print(); +#endif +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/exception/message.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/exception/message.cpp new file mode 100644 index 0000000000..30570f6d5b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/exception/message.cpp @@ -0,0 +1,70 @@ +#include "common/exception/message.h" + +#include "common/string_format.h" + +namespace lbug { +namespace common { + +std::string ExceptionMessage::duplicatePKException(const std::string& pkString) { + return stringFormat("Found duplicated primary key value {}, which violates the uniqueness" + " constraint of the primary key column.", + pkString); +} + +std::string 
ExceptionMessage::nonExistentPKException(const std::string& pkString) { + return stringFormat("Unable to find primary key value {}.", pkString); +} + +std::string ExceptionMessage::invalidPKType(const std::string& type) { + return stringFormat("Invalid primary key column type {}. Primary keys must be either STRING or " + "a numeric type.", + type); +} + +std::string ExceptionMessage::nullPKException() { + return "Found NULL, which violates the non-null constraint of the primary key column."; +} + +std::string ExceptionMessage::overLargeStringPKValueException(uint64_t length) { + return stringFormat("The maximum length of primary key strings is 262144 bytes. The input " + "string's length was {}.", + length); +} + +std::string ExceptionMessage::overLargeStringValueException(uint64_t length) { + return stringFormat( + "The maximum length of strings is 262144 bytes. The input string's length was {}.", length); +} + +std::string ExceptionMessage::violateDeleteNodeWithConnectedEdgesConstraint( + const std::string& tableName, const std::string& offset, const std::string& direction) { + return stringFormat( + "Node(nodeOffset: {}) has connected edges in table {} in the {} direction, " + "which cannot be deleted. Please delete the edges first or try DETACH DELETE.", + offset, tableName, direction); +} + +std::string ExceptionMessage::violateRelMultiplicityConstraint(const std::string& tableName, + const std::string& offset, const std::string& direction) { + return stringFormat("Node(nodeOffset: {}) has more than one neighbour in table {} in the {} " + "direction, which violates the rel multiplicity constraint.", + offset, tableName, direction); +} + +std::string ExceptionMessage::variableNotInScope(const std::string& varName) { + return stringFormat("Variable {} is not in scope.", varName); +} + +std::string ExceptionMessage::listFunctionIncompatibleChildrenType(const std::string& functionName, + const std::string& leftType, const std::string& rightType) { + return std::string("Cannot bind " + functionName + " with parameter type " + leftType + + " and " + rightType + "."); +} + +std::string ExceptionMessage::invalidSkipLimitParam(const std::string& exprName, + const std::string& skipOrLimit) { + return stringFormat("Cannot evaluate {} as a valid {} number.", exprName, skipOrLimit); +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/expression_type.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/expression_type.cpp new file mode 100644 index 0000000000..4210a3fe07 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/expression_type.cpp @@ -0,0 +1,136 @@ +#include "common/enums/expression_type.h" + +#include "common/assert.h" +#include "function/comparison/vector_comparison_functions.h" + +using namespace lbug::function; + +namespace lbug { +namespace common { + +bool ExpressionTypeUtil::isUnary(ExpressionType type) { + return ExpressionType::NOT == type || ExpressionType::IS_NULL == type || + ExpressionType::IS_NOT_NULL == type; +} + +bool ExpressionTypeUtil::isBinary(ExpressionType type) { + return isComparison(type) || ExpressionType::OR == type || ExpressionType::XOR == type || + ExpressionType::AND == type; +} + +bool ExpressionTypeUtil::isBoolean(ExpressionType type) { + return ExpressionType::OR == type || ExpressionType::XOR == type || + ExpressionType::AND == type || ExpressionType::NOT == type; +} + +bool ExpressionTypeUtil::isComparison(ExpressionType type) { + return ExpressionType::EQUALS == type || 
ExpressionType::NOT_EQUALS == type || + ExpressionType::GREATER_THAN == type || ExpressionType::GREATER_THAN_EQUALS == type || + ExpressionType::LESS_THAN == type || ExpressionType::LESS_THAN_EQUALS == type; +} + +bool ExpressionTypeUtil::isNullOperator(ExpressionType type) { + return ExpressionType::IS_NULL == type || ExpressionType::IS_NOT_NULL == type; +} + +ExpressionType ExpressionTypeUtil::reverseComparisonDirection(ExpressionType type) { + KU_ASSERT(isComparison(type)); + switch (type) { + case ExpressionType::GREATER_THAN: + return ExpressionType::LESS_THAN; + case ExpressionType::GREATER_THAN_EQUALS: + return ExpressionType::LESS_THAN_EQUALS; + case ExpressionType::LESS_THAN: + return ExpressionType::GREATER_THAN; + case ExpressionType::LESS_THAN_EQUALS: + return ExpressionType::GREATER_THAN_EQUALS; + default: + return type; + } +} + +// LCOV_EXCL_START +std::string ExpressionTypeUtil::toString(ExpressionType type) { + switch (type) { + case ExpressionType::OR: + return "OR"; + case ExpressionType::XOR: + return "XOR"; + case ExpressionType::AND: + return "AND"; + case ExpressionType::NOT: + return "NOT"; + case ExpressionType::EQUALS: + return EqualsFunction::name; + case ExpressionType::NOT_EQUALS: + return NotEqualsFunction::name; + case ExpressionType::GREATER_THAN: + return GreaterThanFunction::name; + case ExpressionType::GREATER_THAN_EQUALS: + return GreaterThanEqualsFunction::name; + case ExpressionType::LESS_THAN: + return LessThanFunction::name; + case ExpressionType::LESS_THAN_EQUALS: + return LessThanEqualsFunction::name; + case ExpressionType::IS_NULL: + return "IS_NULL"; + case ExpressionType::IS_NOT_NULL: + return "IS_NOT_NULL"; + case ExpressionType::PROPERTY: + return "PROPERTY"; + case ExpressionType::LITERAL: + return "LITERAL"; + case ExpressionType::STAR: + return "STAR"; + case ExpressionType::VARIABLE: + return "VARIABLE"; + case ExpressionType::PATH: + return "PATH"; + case ExpressionType::PATTERN: + return "PATTERN"; + case ExpressionType::PARAMETER: + return "PARAMETER"; + case ExpressionType::FUNCTION: + return "SCALAR_FUNCTION"; + case ExpressionType::AGGREGATE_FUNCTION: + return "AGGREGATE_FUNCTION"; + case ExpressionType::SUBQUERY: + return "SUBQUERY"; + case ExpressionType::CASE_ELSE: + return "CASE_ELSE"; + case ExpressionType::GRAPH: + return "GRAPH"; + case ExpressionType::LAMBDA: + return "LAMBDA"; + default: + KU_UNREACHABLE; + } +} + +std::string ExpressionTypeUtil::toParsableString(ExpressionType type) { + switch (type) { + case ExpressionType::EQUALS: + return "="; + case ExpressionType::NOT_EQUALS: + return "<>"; + case ExpressionType::GREATER_THAN: + return ">"; + case ExpressionType::GREATER_THAN_EQUALS: + return ">="; + case ExpressionType::LESS_THAN: + return "<"; + case ExpressionType::LESS_THAN_EQUALS: + return "<="; + case ExpressionType::IS_NULL: + return "IS NULL"; + case ExpressionType::IS_NOT_NULL: + return "IS NOT NULL"; + default: + throw RuntimeException(stringFormat( + "ExpressionTypeUtil::toParsableString not implemented for {}", toString(type))); + } +} +// LCOV_EXCL_STOP + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/file_system/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/common/file_system/CMakeLists.txt new file mode 100644 index 0000000000..0b91ee771a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/file_system/CMakeLists.txt @@ -0,0 +1,14 @@ +add_library(lbug_file_system + OBJECT + compressed_file_system.cpp + file_info.cpp + file_system.cpp + 
local_file_system.cpp + virtual_file_system.cpp + gzip_file_system.cpp) + +target_link_libraries(lbug_file_system Glob) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/file_system/compressed_file_system.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/file_system/compressed_file_system.cpp new file mode 100644 index 0000000000..be1f9053f3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/file_system/compressed_file_system.cpp @@ -0,0 +1,123 @@ +#include "common/file_system/compressed_file_system.h" + +#include + +#include "common/exception/io.h" + +namespace lbug { +namespace common { + +int64_t CompressedFileSystem::readFile(lbug::common::FileInfo& fileInfo, void* buf, + size_t numBytes) const { + auto& compressedFileInfo = fileInfo.cast(); + return compressedFileInfo.readData(buf, numBytes); +} + +void CompressedFileSystem::reset(lbug::common::FileInfo& fileInfo) { + auto& compressedFileInfo = fileInfo.cast(); + compressedFileInfo.childFileInfo->reset(); + compressedFileInfo.initialize(); +} + +uint64_t CompressedFileSystem::getFileSize(const lbug::common::FileInfo& fileInfo) const { + auto& compressedFileInfo = fileInfo.constCast(); + return compressedFileInfo.childFileInfo->getFileSize(); +} + +void CompressedFileSystem::syncFile(const lbug::common::FileInfo& fileInfo) const { + auto& compressedFileInfo = fileInfo.constCast(); + return compressedFileInfo.childFileInfo->syncFile(); +} + +void CompressedFileSystem::readFromFile(FileInfo& /*fileInfo*/, void* /*buffer*/, + uint64_t /*numBytes*/, uint64_t /*position*/) const { + throw IOException("Only sequential read is allowed in compressed file system."); +} + +void CompressedFileInfo::initialize() { + close(); + streamData.inputBufSize = compressedFS.getInputBufSize(); + streamData.outputBufSize = compressedFS.getOutputBufSize(); + streamData.inputBuf = std::make_unique(streamData.inputBufSize); + streamData.inputBufStart = streamData.inputBuf.get(); + streamData.inputBufEnd = streamData.inputBuf.get(); + streamData.outputBuf = std::make_unique(streamData.outputBufSize); + streamData.outputBufStart = streamData.outputBuf.get(); + streamData.outputBufEnd = streamData.outputBuf.get(); + currentPos = 0; + stream_wrapper = compressedFS.createStream(); + stream_wrapper->initialize(*this); +} + +int64_t CompressedFileInfo::readData(void* buffer, size_t numBytes) { + common::idx_t totalNumBytesRead = 0; + while (true) { + if (streamData.outputBufStart != streamData.outputBufEnd) { + auto available = + std::min(numBytes, streamData.outputBufEnd - streamData.outputBufStart); + memcpy(reinterpret_cast(buffer) + totalNumBytesRead, + streamData.outputBufStart, available); + streamData.outputBufStart += available; + totalNumBytesRead += available; + numBytes -= available; + if (numBytes == 0) { + return totalNumBytesRead; + } + } + if (!stream_wrapper) { + return totalNumBytesRead; + } + currentPos += streamData.inputBufEnd - streamData.inputBufStart; + streamData.outputBufStart = streamData.outputBuf.get(); + streamData.outputBufEnd = streamData.outputBuf.get(); + if (streamData.refresh && + (streamData.inputBufEnd == streamData.inputBuf.get() + streamData.inputBufSize)) { + auto numBytesLeftInBuf = streamData.inputBufEnd - streamData.inputBufStart; + memmove(streamData.inputBuf.get(), streamData.inputBufStart, numBytesLeftInBuf); + streamData.inputBufStart = streamData.inputBuf.get(); + auto sz = childFileInfo->readFile(streamData.inputBufStart + 
numBytesLeftInBuf, + streamData.inputBufSize - numBytesLeftInBuf); + streamData.inputBufEnd = streamData.inputBufStart + numBytesLeftInBuf + sz; + if (sz <= 0) { + stream_wrapper.reset(); + break; + } + } + + if (streamData.inputBufStart == streamData.inputBufEnd) { + streamData.inputBufStart = streamData.inputBuf.get(); + streamData.inputBufEnd = streamData.inputBufStart; + auto sz = childFileInfo->readFile(streamData.inputBuf.get(), streamData.inputBufSize); + if (sz <= 0) { + stream_wrapper.reset(); + break; + } + streamData.inputBufEnd = streamData.inputBufStart + sz; + } + + auto finished = stream_wrapper->read(streamData); + if (finished) { + stream_wrapper.reset(); + } + } + return totalNumBytesRead; +} + +void CompressedFileInfo::close() { + if (stream_wrapper) { + stream_wrapper->close(); + stream_wrapper.reset(); + } + streamData.inputBuf.reset(); + streamData.outputBuf.reset(); + streamData.outputBufStart = nullptr; + streamData.outputBufEnd = nullptr; + streamData.inputBufStart = nullptr; + streamData.inputBufEnd = nullptr; + streamData.inputBufSize = 0; + streamData.outputBufSize = 0; + streamData.refresh = false; +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/file_system/file_info.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/file_system/file_info.cpp new file mode 100644 index 0000000000..72355ec266 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/file_system/file_info.cpp @@ -0,0 +1,51 @@ +#include "common/file_system/file_info.h" + +#include "common/file_system/file_system.h" + +#if defined(_WIN32) +#include +#else +#include +#endif + +namespace lbug { +namespace common { + +uint64_t FileInfo::getFileSize() const { + return fileSystem->getFileSize(*this); +} + +void FileInfo::readFromFile(void* buffer, uint64_t numBytes, uint64_t position) { + fileSystem->readFromFile(*this, buffer, numBytes, position); +} + +int64_t FileInfo::readFile(void* buf, size_t nbyte) { + return fileSystem->readFile(*this, buf, nbyte); +} + +void FileInfo::writeFile(const uint8_t* buffer, uint64_t numBytes, uint64_t offset) { + fileSystem->writeFile(*this, buffer, numBytes, offset); +} + +void FileInfo::syncFile() const { + fileSystem->syncFile(*this); +} + +int64_t FileInfo::seek(uint64_t offset, int whence) { + return fileSystem->seek(*this, offset, whence); +} + +void FileInfo::reset() { + fileSystem->reset(*this); +} + +void FileInfo::truncate(uint64_t size) { + fileSystem->truncate(*this, size); +} + +bool FileInfo::canPerformSeek() const { + return fileSystem->canPerformSeek(); +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/file_system/file_system.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/file_system/file_system.cpp new file mode 100644 index 0000000000..4993c106de --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/file_system/file_system.cpp @@ -0,0 +1,72 @@ +#include "common/file_system/file_system.h" + +#include "common/string_utils.h" + +namespace lbug { +namespace common { + +void FileSystem::overwriteFile(const std::string& /*from*/, const std::string& /*to*/) { + KU_UNREACHABLE; +} + +void FileSystem::copyFile(const std::string& /*from*/, const std::string& /*to*/) { + KU_UNREACHABLE; +} + +void FileSystem::createDir(const std::string& /*dir*/) const { + KU_UNREACHABLE; +} + +void FileSystem::removeFileIfExists(const std::string&, const main::ClientContext* /*context*/) { + KU_UNREACHABLE; +} + +bool FileSystem::fileOrPathExists(const 
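CompressedFileInfo::readData is a drain/refill loop: it first serves bytes already decompressed into the output buffer, then refills the input buffer from the child file and asks the stream wrapper to inflate more. A stripped-down sketch of the drain step only, using plain pointers instead of the StreamData struct; it is an illustration of the control flow, not the lbug API:

#include <algorithm>
#include <cstdint>
#include <cstring>

// Copies up to `wanted` bytes out of a decompressed window [outStart, outEnd) and
// advances the window start, mirroring the first branch of readData. The caller
// refills and decompresses more data whenever the return value is < wanted.
static size_t drainOutputBuffer(uint8_t*& outStart, const uint8_t* outEnd,
    void* dst, size_t wanted) {
    size_t available = std::min<size_t>(wanted, static_cast<size_t>(outEnd - outStart));
    std::memcpy(dst, outStart, available);
    outStart += available;
    return available;
}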
std::string& /*path*/, main::ClientContext* /*context*/) { + KU_UNREACHABLE; +} + +std::string FileSystem::expandPath(main::ClientContext* /*context*/, + const std::string& path) const { + return path; +} + +std::string FileSystem::joinPath(const std::string& base, const std::string& part) { + return base + "/" + part; +} + +std::string FileSystem::getFileExtension(const std::filesystem::path& path) { + auto extension = path.extension(); + if (isCompressedFile(path)) { + extension = path.stem().extension(); + } + return extension.string(); +} + +bool FileSystem::isCompressedFile(const std::filesystem::path& path) { + return isGZIPCompressed(path); +} + +std::string FileSystem::getFileName(const std::filesystem::path& path) { + return path.filename().string(); +} + +void FileSystem::writeFile(FileInfo& /*fileInfo*/, const uint8_t* /*buffer*/, uint64_t /*numBytes*/, + uint64_t /*offset*/) const { + KU_UNREACHABLE; +} + +void FileSystem::truncate(FileInfo& /*fileInfo*/, uint64_t /*size*/) const { + KU_UNREACHABLE; +} + +void FileSystem::reset(FileInfo& fileInfo) { + fileInfo.seek(0, SEEK_SET); +} + +bool FileSystem::isGZIPCompressed(const std::filesystem::path& path) { + auto extensionLowerCase = StringUtils::getLower(path.extension().string()); + return extensionLowerCase == ".gz" || extensionLowerCase == ".gzip"; +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/file_system/gzip_file_system.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/file_system/gzip_file_system.cpp new file mode 100644 index 0000000000..b3eb0e5264 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/file_system/gzip_file_system.cpp @@ -0,0 +1,166 @@ +#include "common/file_system/gzip_file_system.h" + +#include "common/exception/io.h" +#include "miniz.hpp" + +namespace lbug { +namespace common { + +std::unique_ptr GZipFileSystem::openCompressedFile(std::unique_ptr fileInfo) { + return std::make_unique(*this, std::move(fileInfo)); +} + +static idx_t consumeStr(FileInfo& input) { + idx_t size = 1; // terminator + char buffer[1]; + while (input.readFile(buffer, 1) == 1) { + if (buffer[0] == '\0') { + break; + } + size++; + } + return size; +} + +struct MiniZStreamWrapper : public StreamWrapper { + ~MiniZStreamWrapper() override { MiniZStreamWrapper::close(); } + + CompressedFileInfo* file = nullptr; + std::unique_ptr mzStreamPtr = nullptr; + miniz::mz_ulong crc = 0; + idx_t total_size = 0; + +public: + void initialize(CompressedFileInfo& fileInfo) override; + + bool read(StreamData& stream_data) override; + + void close() override; +}; + +static void verifyGZIPHeader(const uint8_t gzip_hdr[], idx_t read_count) { + if (read_count != GZipFileSystem::GZIP_HEADER_MINSIZE) { + throw IOException("Input is not a GZIP stream."); + } + if (gzip_hdr[0] != 0x1F || gzip_hdr[1] != 0x8B) { + throw IOException("Input is not a GZIP stream."); + } + if (gzip_hdr[2] != GZipFileSystem::GZIP_COMPRESSION_DEFLATE) { + throw IOException("Unsupported GZIP compression method."); + } + if (gzip_hdr[3] & GZipFileSystem::GZIP_FLAG_UNSUPPORTED) { + throw IOException("Unsupported GZIP archive."); + } +} + +void MiniZStreamWrapper::initialize(CompressedFileInfo& fileInfo) { + close(); + this->file = &fileInfo; + mzStreamPtr = std::make_unique(); + memset(mzStreamPtr.get(), 0, sizeof(miniz::mz_stream)); + uint8_t gzipHdr[GZipFileSystem::GZIP_HEADER_MINSIZE]; + + idx_t dataStart = GZipFileSystem::GZIP_HEADER_MINSIZE; + auto numBytesRead = + fileInfo.childFileInfo->readFile(gzipHdr, 
GZipFileSystem::GZIP_HEADER_MINSIZE); + verifyGZIPHeader(gzipHdr, numBytesRead); + if (gzipHdr[3] & GZipFileSystem::GZIP_FLAG_EXTRA) { + uint8_t gzipXLen[2]; + fileInfo.childFileInfo->seek(dataStart, SEEK_SET); + fileInfo.childFileInfo->readFile(gzipXLen, 2); + auto xlen = (uint8_t)gzipXLen[0] | (uint8_t)gzipXLen[1] << 8; + dataStart += xlen + 2; + } + if (gzipHdr[3] & GZipFileSystem::GZIP_FLAG_NAME) { + fileInfo.childFileInfo->seek(dataStart, SEEK_SET); + dataStart += consumeStr(*fileInfo.childFileInfo); + } + fileInfo.childFileInfo->seek(dataStart, SEEK_SET); + auto ret = miniz::mz_inflateInit2(mzStreamPtr.get(), -MZ_DEFAULT_WINDOW_BITS); + // LCOV_EXCL_START + if (ret != miniz::MZ_OK) { + throw InternalException("Failed to initialize miniz"); + } + // LCOV_EXCL_STOP +} + +bool MiniZStreamWrapper::read(StreamData& sd) { + if (sd.refresh) { + uint32_t available = sd.inputBufEnd - sd.inputBufStart; + if (available <= GZipFileSystem::GZIP_FOOTER_SIZE) { + close(); + return true; + } + + sd.refresh = false; + auto bodyPtr = sd.inputBufStart + GZipFileSystem::GZIP_FOOTER_SIZE; + uint8_t gzipHdr[GZipFileSystem::GZIP_HEADER_MINSIZE]; + memcpy(gzipHdr, bodyPtr, GZipFileSystem::GZIP_HEADER_MINSIZE); + verifyGZIPHeader(gzipHdr, GZipFileSystem::GZIP_HEADER_MINSIZE); + bodyPtr += GZipFileSystem::GZIP_HEADER_MINSIZE; + if (gzipHdr[3] & GZipFileSystem::GZIP_FLAG_EXTRA) { + auto xlen = (uint8_t)*bodyPtr | (uint8_t) * (bodyPtr + 1) << 8; + bodyPtr += xlen + 2; + KU_ASSERT((common::idx_t)(GZipFileSystem::GZIP_FOOTER_SIZE + + GZipFileSystem::GZIP_HEADER_MINSIZE + 2 + xlen) < + GZipFileSystem::GZIP_HEADER_MAXSIZE); + } + if (gzipHdr[3] & GZipFileSystem::GZIP_FLAG_NAME) { + char c = '\0'; + do { + c = *bodyPtr; + bodyPtr++; + } while (c != '\0' && bodyPtr < sd.inputBufEnd); + KU_ASSERT(bodyPtr - sd.inputBufStart < GZipFileSystem::GZIP_HEADER_MAXSIZE); + } + sd.inputBufStart = bodyPtr; + if (sd.inputBufEnd - sd.inputBufStart < 1) { + close(); + return true; + } + miniz::mz_inflateEnd(mzStreamPtr.get()); + auto sta = miniz::mz_inflateInit2(mzStreamPtr.get(), -MZ_DEFAULT_WINDOW_BITS); + // LCOV_EXCL_START + if (sta != miniz::MZ_OK) { + throw InternalException("Failed to initialize miniz"); + } + // LCOV_EXCL_STOP + } + + mzStreamPtr->next_in = sd.inputBufStart; + mzStreamPtr->avail_in = sd.inputBufEnd - sd.inputBufStart; + mzStreamPtr->next_out = sd.outputBufEnd; + mzStreamPtr->avail_out = sd.outputBuf.get() + sd.outputBufSize - sd.outputBufEnd; + auto ret = miniz::mz_inflate(mzStreamPtr.get(), miniz::MZ_NO_FLUSH); + // LCOV_EXCL_START + if (ret != miniz::MZ_OK && ret != miniz::MZ_STREAM_END) { + throw IOException( + common::stringFormat("Failed to decode gzip stream: {}", miniz::mz_error(ret))); + } + // LCOV_EXCL_STOP + sd.inputBufStart = (uint8_t*)mzStreamPtr->next_in; + sd.inputBufEnd = sd.inputBufStart + mzStreamPtr->avail_in; + sd.outputBufEnd = (uint8_t*)mzStreamPtr->next_out; + KU_ASSERT(sd.outputBufEnd + mzStreamPtr->avail_out == sd.outputBuf.get() + sd.outputBufSize); + + if (ret == miniz::MZ_STREAM_END) { + sd.refresh = true; + } + return false; +} + +void MiniZStreamWrapper::close() { + if (!mzStreamPtr) { + return; + } + miniz::mz_inflateEnd(mzStreamPtr.get()); + mzStreamPtr = nullptr; + file = nullptr; +} + +std::unique_ptr GZipFileSystem::createStream() { + return std::make_unique(); +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/file_system/local_file_system.cpp 
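verifyGZIPHeader above enforces the fixed part of an RFC 1952 header: the 0x1F 0x8B magic bytes, the deflate compression method, and a flag byte free of unsupported bits. A standalone check with the constants spelled out per RFC 1952; lbug's own GZIP_* constant names and values are not repeated here:

#include <cstddef>
#include <cstdint>

// Minimal RFC 1952 header check: ID1=0x1F, ID2=0x8B, CM=8 (deflate), and none of the
// reserved flag bits (5..7) set. FEXTRA/FNAME handling happens later, as in
// MiniZStreamWrapper::initialize. lbug's GZIP_FLAG_UNSUPPORTED may reject additional
// flags; this sketch only enforces the RFC's reserved bits.
bool looksLikeGzip(const uint8_t* hdr, size_t len) {
    if (len < 4) return false;
    if (hdr[0] != 0x1F || hdr[1] != 0x8B) return false;
    if (hdr[2] != 8) return false;   // only deflate is supported
    if (hdr[3] & 0xE0) return false; // reserved flag bits must be zero
    return true;
}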
b/graph-wasm/lbug-0.12.2/lbug-src/src/common/file_system/local_file_system.cpp new file mode 100644 index 0000000000..1903b7d465 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/file_system/local_file_system.cpp @@ -0,0 +1,553 @@ +#include "common/file_system/local_file_system.h" + +#include "common/assert.h" +#include "common/exception/io.h" +#include "common/string_format.h" +#include "common/string_utils.h" +#include "common/system_message.h" +#include "glob/glob.hpp" +#include "main/client_context.h" +#include "main/settings.h" + +#if defined(_WIN32) +#include + +#include "common/windows_utils.h" +#include +#include +#include +#else +#include "sys/stat.h" +#include +#endif + +#include + +#include + +#include "storage/storage_utils.h" + +namespace lbug { +namespace common { + +LocalFileInfo::~LocalFileInfo() { +#ifdef _WIN32 + if (handle != nullptr) { + CloseHandle((HANDLE)handle); + } +#else + if (fd != -1) { + close(fd); + } +#endif +} + +static void validateFileFlags(uint8_t flags) { + const bool isRead = flags & FileFlags::READ_ONLY; + const bool isWrite = flags & FileFlags::WRITE; + KU_UNUSED(isRead); + KU_UNUSED(isWrite); + // Require either READ or WRITE (or both). + KU_ASSERT(isRead || isWrite); + // CREATE flags require writing. + KU_ASSERT(isWrite || !(flags & FileFlags::CREATE_IF_NOT_EXISTS)); + KU_ASSERT(isWrite || !(flags & FileFlags::CREATE_AND_TRUNCATE_IF_EXISTS)); + // CREATE_IF_NOT_EXISTS and CREATE_AND_TRUNCATE_IF_EXISTS flags cannot be combined. + KU_ASSERT(!(flags & FileFlags::CREATE_IF_NOT_EXISTS && + flags & FileFlags::CREATE_AND_TRUNCATE_IF_EXISTS)); +} + +std::unique_ptr LocalFileSystem::openFile(const std::string& path, FileOpenFlags flags, + main::ClientContext* context) { + auto fullPath = expandPath(context, path); + auto fileFlags = flags.flags; + validateFileFlags(fileFlags); + + int openFlags = 0; + bool readMode = fileFlags & FileFlags::READ_ONLY; + bool writeMode = fileFlags & FileFlags::WRITE; + if (readMode && writeMode) { + openFlags = O_RDWR; + } else if (readMode) { + openFlags = O_RDONLY; + } else if (writeMode) { + openFlags = O_WRONLY; + } else { + // LCOV_EXCL_START + throw InternalException("READ, WRITE or both should be specified when opening a file."); + // LCOV_EXCL_STOP + } + if (writeMode) { + KU_ASSERT(fileFlags & FileFlags::WRITE); + if (fileFlags & FileFlags::CREATE_IF_NOT_EXISTS) { + openFlags |= O_CREAT; + } else if (fileFlags & FileFlags::CREATE_AND_TRUNCATE_IF_EXISTS) { + openFlags |= O_CREAT | O_TRUNC; + } + } + +#if defined(_WIN32) + auto dwDesiredAccess = 0ul; + int dwCreationDisposition; + if (fileFlags & FileFlags::CREATE_IF_NOT_EXISTS) { + dwCreationDisposition = OPEN_ALWAYS; + } else if (fileFlags & FileFlags::CREATE_AND_TRUNCATE_IF_EXISTS) { + dwCreationDisposition = CREATE_ALWAYS; + } else { + dwCreationDisposition = OPEN_EXISTING; + } + auto dwShareMode = FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE; + if (openFlags & (O_CREAT | O_WRONLY | O_RDWR)) { + dwDesiredAccess |= GENERIC_WRITE; + } + // O_RDONLY is 0 in practice, so openFlags & (O_RDONLY | O_RDWR) doesn't work. + if (!(openFlags & O_WRONLY)) { + dwDesiredAccess |= GENERIC_READ; + } + if (openFlags & FileFlags::BINARY) { + dwDesiredAccess |= _O_BINARY; + } + + HANDLE handle = CreateFileA(fullPath.c_str(), dwDesiredAccess, dwShareMode, nullptr, + dwCreationDisposition, FILE_ATTRIBUTE_NORMAL, nullptr); + if (handle == INVALID_HANDLE_VALUE) { + throw IOException(stringFormat("Cannot open file. 
path: {} - Error {}: {}", fullPath, + GetLastError(), std::system_category().message(GetLastError()))); + } + if (flags.lockType != FileLockType::NO_LOCK) { + DWORD dwFlags = flags.lockType == FileLockType::READ_LOCK ? + LOCKFILE_FAIL_IMMEDIATELY : + LOCKFILE_FAIL_IMMEDIATELY | LOCKFILE_EXCLUSIVE_LOCK; + OVERLAPPED overlapped = {0}; + overlapped.Offset = 0; + BOOL rc = LockFileEx(handle, dwFlags, 0 /*reserved*/, 1 /*numBytesLow*/, 0 /*numBytesHigh*/, + &overlapped); + if (!rc) { + throw IOException( + "Could not set lock on file : " + fullPath + "\n" + + "See the docs: https://docs.ladybugdb.com/concurrency for more information."); + } + } + return std::make_unique(fullPath, handle, this); +#else + int fd = open(fullPath.c_str(), openFlags, 0644); + if (fd == -1) { + throw IOException(stringFormat("Cannot open file {}: {}", fullPath, posixErrMessage())); + } + if (flags.lockType != FileLockType::NO_LOCK) { + struct flock fl {}; + memset(&fl, 0, sizeof fl); + fl.l_type = flags.lockType == FileLockType::READ_LOCK ? F_RDLCK : F_WRLCK; + fl.l_whence = SEEK_SET; + fl.l_start = 0; + fl.l_len = 0; + int rc = fcntl(fd, F_SETLK, &fl); + if (rc == -1) { + throw IOException( + "Could not set lock on file : " + fullPath + "\n" + + "See the docs: https://docs.ladybugdb.com/concurrency for more information."); + } + } + return std::make_unique(fullPath, fd, this); +#endif +} + +std::vector LocalFileSystem::glob(main::ClientContext* context, + const std::string& path) const { + if (path.empty()) { + return std::vector(); + } + std::vector pathsToGlob; + if (path[0] == '/' || (std::isalpha(path[0]) && path[1] == ':')) { + // Note: + // Unix absolute path starts with '/' + // Windows absolute path starts with "[DiskID]://" + pathsToGlob.push_back(path); + } else if (path[0] == '~') { + // Expands home directory + auto homeDirectory = + context->getCurrentSetting(main::HomeDirectorySetting::name).getValue(); + pathsToGlob.push_back(homeDirectory + path.substr(1)); + } else { + // Relative path to the file search path. + auto globbedPaths = glob::glob(path); + if (!globbedPaths.empty()) { + pathsToGlob.push_back(path); + } else { + auto fileSearchPath = context->getCurrentSetting(main::FileSearchPathSetting::name) + .getValue(); + if (fileSearchPath != "") { + auto searchPaths = StringUtils::split(fileSearchPath, ","); + for (auto& searchPath : searchPaths) { + pathsToGlob.push_back(stringFormat("{}/{}", searchPath, path)); + } + } + } + } + std::vector result; + for (auto& pathToGlob : pathsToGlob) { + for (auto& resultPath : glob::glob(pathToGlob)) { + result.emplace_back(resultPath.string()); + } + } + return result; +} + +void LocalFileSystem::overwriteFile(const std::string& from, const std::string& to) { + if (!fileOrPathExists(from) || !fileOrPathExists(to)) { + return; + } + std::error_code errorCode; + if (!std::filesystem::copy_file(from, to, std::filesystem::copy_options::overwrite_existing, + errorCode)) { + // LCOV_EXCL_START + throw IOException(stringFormat("Error copying file {} to {}. ErrorMessage: {}", from, to, + errorCode.message())); + // LCOV_EXCL_STOP + } +} + +void LocalFileSystem::copyFile(const std::string& from, const std::string& to) { + if (!fileOrPathExists(from)) { + return; + } + std::error_code errorCode; + if (!std::filesystem::copy_file(from, to, std::filesystem::copy_options::none, errorCode)) { + // LCOV_EXCL_START + throw IOException(stringFormat("Error copying file {} to {}. 
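On the POSIX side, openFile takes a whole-file advisory lock with fcntl(F_SETLK) so a conflicting process fails fast instead of blocking. A minimal sketch of the same pattern, with the error handling reduced to a bool where the lbug wrapper throws IOException and points at the concurrency docs:

#include <cstring>
#include <fcntl.h>
#include <unistd.h>

// Try to take a non-blocking advisory lock on the whole file: shared for readers,
// exclusive for writers. Returns false if another process already holds a
// conflicting lock, mirroring the F_SETLK branch in LocalFileSystem::openFile.
bool tryLockWholeFile(int fd, bool exclusive) {
    struct flock fl;
    std::memset(&fl, 0, sizeof(fl));
    fl.l_type = exclusive ? F_WRLCK : F_RDLCK;
    fl.l_whence = SEEK_SET;
    fl.l_start = 0;
    fl.l_len = 0; // 0 means "to end of file"
    return fcntl(fd, F_SETLK, &fl) != -1;
}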
ErrorMessage: {}", from, to, + errorCode.message())); + // LCOV_EXCL_STOP + } +} + +void LocalFileSystem::createDir(const std::string& dir) const { + try { + if (std::filesystem::exists(dir)) { + // LCOV_EXCL_START + throw IOException(stringFormat("Directory {} already exists.", dir)); + // LCOV_EXCL_STOP + } + auto directoryToCreate = dir; + if (directoryToCreate.ends_with('/') +#if defined(_WIN32) + || directoryToCreate.ends_with('\\') +#endif + ) { + // This is a known issue with std::filesystem::create_directories. (link: + // https://github.com/llvm/llvm-project/issues/60634). We have to manually remove the + // last '/' if the path ends with '/'. (Added the second one for windows) + directoryToCreate = directoryToCreate.substr(0, directoryToCreate.size() - 1); + } + std::error_code errCode; + if (!std::filesystem::create_directories(directoryToCreate, errCode)) { + // LCOV_EXCL_START + throw IOException( + stringFormat("Directory {} cannot be created. Check if it exists and remove it.", + directoryToCreate)); + // LCOV_EXCL_STOP + } + if (errCode) { + // LCOV_EXCL_START + throw IOException(stringFormat("Failed to create directory: {}, error message: {}.", + dir, errCode.message())); + // LCOV_EXCL_STOP + } + } catch (std::exception& e) { + // LCOV_EXCL_START + throw IOException(stringFormat("Failed to create directory {} due to: {}", dir, e.what())); + // LCOV_EXCL_STOP + } +} + +static std::unordered_set getDatabaseFileSet(const std::string& path) { + std::unordered_set result; + result.insert(storage::StorageUtils::getWALFilePath(path)); + result.insert(storage::StorageUtils::getShadowFilePath(path)); + result.insert(storage::StorageUtils::getTmpFilePath(path)); + return result; +} + +static bool isExtensionFile(const main::ClientContext* context, const std::string& path) { + if (context == nullptr) { + return false; + } + auto extensionDir = context->getExtensionDir(); + std::filesystem::path rel = std::filesystem::relative(path, extensionDir); + for (const auto& part : rel) { + if (part == "..") { + return false; + } + } + return true; +} + +void LocalFileSystem::removeFileIfExists(const std::string& path, + const main::ClientContext* context) { + if (!fileOrPathExists(path)) { + return; + } + if (!getDatabaseFileSet(dbPath).contains(path) && !isExtensionFile(context, path)) { + throw IOException(stringFormat( + "Error: Path {} is not within the allowed list of files to be removed.", path)); + } + std::error_code errCode; + bool success = false; + if (std::filesystem::is_directory(path)) { + success = std::filesystem::remove_all(path, errCode); + } else { + success = std::filesystem::remove(path, errCode); + } + if (!success) { + // LCOV_EXCL_START + throw IOException(stringFormat("Error removing directory or file {}. 
Error Message: {}", + path, errCode.message())); + // LCOV_EXCL_STOP + } +} + +bool LocalFileSystem::fileOrPathExists(const std::string& path, main::ClientContext* /*context*/) { + return std::filesystem::exists(path); +} + +#ifndef _WIN32 +bool LocalFileSystem::fileExists(const std::string& filename) { + if (!filename.empty()) { + if (access(filename.c_str(), 0) == 0) { + struct stat status = {}; + stat(filename.c_str(), &status); + if (S_ISREG(status.st_mode)) { + return true; + } + } + } + // if any condition fails + return false; +} +#else +bool LocalFileSystem::fileExists(const std::string& filename) { + auto unicode_path = WindowsUtils::utf8ToUnicode(filename.c_str()); + const wchar_t* wpath = unicode_path.c_str(); + if (_waccess(wpath, 0) == 0) { + struct _stati64 status = {}; + _wstati64(wpath, &status); + if (status.st_mode & _S_IFREG) { + return true; + } + } + return false; +} +#endif + +std::string LocalFileSystem::expandPath(main::ClientContext* context, + const std::string& path) const { + auto fullPath = path; + if (path.starts_with('~')) { + fullPath = + context->getCurrentSetting(main::HomeDirectorySetting::name).getValue() + + fullPath.substr(1); + } + return fullPath; +} + +bool LocalFileSystem::isLocalPath(const std::string& path) { + return path.rfind("s3://", 0) != 0 && path.rfind("gs://", 0) != 0 && + path.rfind("gcs://", 0) != 0 && path.rfind("http://", 0) != 0 && + path.rfind("https://", 0) != 0 && path.rfind("az://", 0) != 0 && + path.rfind("abfss://", 0) != 0; +} + +void LocalFileSystem::readFromFile(FileInfo& fileInfo, void* buffer, uint64_t numBytes, + uint64_t position) const { + auto localFileInfo = fileInfo.constPtrCast(); + KU_ASSERT(localFileInfo->getFileSize() >= position + numBytes); +#if defined(_WIN32) + DWORD numBytesRead; + OVERLAPPED overlapped{0, 0, 0, 0}; + overlapped.Offset = position & 0xffffffff; + overlapped.OffsetHigh = position >> 32; + if (!ReadFile((HANDLE)localFileInfo->handle, buffer, numBytes, &numBytesRead, &overlapped)) { + auto error = GetLastError(); + throw IOException( + stringFormat("Cannot read from file: {} handle: {} " + "numBytesRead: {} numBytesToRead: {} position: {}. 
Error {}: {}", + fileInfo.path, (intptr_t)localFileInfo->handle, numBytesRead, numBytes, position, + error, std::system_category().message(error))); + } + if (numBytesRead != numBytes && fileInfo.getFileSize() != position + numBytesRead) { + throw IOException(stringFormat("Cannot read from file: {} handle: {} " + "numBytesRead: {} numBytesToRead: {} position: {}", + fileInfo.path, (intptr_t)localFileInfo->handle, numBytesRead, numBytes, position)); + } +#else + auto numBytesRead = pread(localFileInfo->fd, buffer, numBytes, position); + if (static_cast(numBytesRead) != numBytes && + localFileInfo->getFileSize() != position + numBytesRead) { + // LCOV_EXCL_START + throw IOException(stringFormat("Cannot read from file: {} fileDescriptor: {} " + "numBytesRead: {} numBytesToRead: {} position: {}", + fileInfo.path, localFileInfo->fd, numBytesRead, numBytes, position)); + // LCOV_EXCL_STOP + } +#endif +} + +int64_t LocalFileSystem::readFile(FileInfo& fileInfo, void* buf, size_t nbyte) const { + auto localFileInfo = fileInfo.constPtrCast(); +#if defined(_WIN32) + DWORD numBytesRead; + ReadFile((HANDLE)localFileInfo->handle, buf, nbyte, &numBytesRead, nullptr); + return numBytesRead; +#else + return read(localFileInfo->fd, buf, nbyte); +#endif +} + +void LocalFileSystem::writeFile(FileInfo& fileInfo, const uint8_t* buffer, uint64_t numBytes, + uint64_t offset) const { + auto localFileInfo = fileInfo.constPtrCast(); + uint64_t remainingNumBytesToWrite = numBytes; + uint64_t bufferOffset = 0; + // Split large writes to 1GB at a time + uint64_t maxBytesToWriteAtOnce = 1ull << 30; // 1ull << 30 = 1G + while (remainingNumBytesToWrite > 0) { + uint64_t numBytesToWrite = std::min(remainingNumBytesToWrite, maxBytesToWriteAtOnce); + +#if defined(_WIN32) + DWORD numBytesWritten; + OVERLAPPED overlapped{0, 0, 0, 0}; + overlapped.Offset = offset & 0xffffffff; + overlapped.OffsetHigh = offset >> 32; + if (!WriteFile((HANDLE)localFileInfo->handle, buffer + bufferOffset, numBytesToWrite, + &numBytesWritten, &overlapped)) { + auto error = GetLastError(); + throw IOException( + stringFormat("Cannot write to file. path: {} handle: {} offsetToWrite: {} " + "numBytesToWrite: {} numBytesWritten: {}. Error {}: {}.", + fileInfo.path, (intptr_t)localFileInfo->handle, offset, numBytesToWrite, + numBytesWritten, error, std::system_category().message(error))); + } +#else + auto numBytesWritten = + pwrite(localFileInfo->fd, buffer + bufferOffset, numBytesToWrite, offset); + if (numBytesWritten != static_cast(numBytesToWrite)) { + // LCOV_EXCL_START + throw IOException( + stringFormat("Cannot write to file. path: {} fileDescriptor: {} offsetToWrite: {} " + "numBytesToWrite: {} numBytesWritten: {}. Error: {}", + fileInfo.path, localFileInfo->fd, offset, numBytesToWrite, numBytesWritten, + posixErrMessage())); + // LCOV_EXCL_STOP + } +#endif + remainingNumBytesToWrite -= numBytesWritten; + offset += numBytesWritten; + bufferOffset += numBytesWritten; + } +} + +void LocalFileSystem::syncFile(const FileInfo& fileInfo) const { + auto localFileInfo = fileInfo.constPtrCast(); +#if defined(_WIN32) + // Note that `FlushFileBuffers` returns 0 when fails, while `fsync` returns 0 when succeeds. + if (FlushFileBuffers((HANDLE)localFileInfo->handle) == 0) { + auto error = GetLastError(); + throw IOException(stringFormat("Failed to sync file {}. 
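writeFile above splits large buffers into slices of at most 1 GiB per pwrite call and advances the offset and buffer position by whatever each call wrote. A POSIX-only sketch of that loop; error handling is reduced to a bool, whereas the real code throws IOException with posixErrMessage():

#include <algorithm>
#include <cstdint>
#include <unistd.h>

// Writes `numBytes` from `buffer` at `offset`, issuing pwrite in slices of at most
// 1 GiB, as LocalFileSystem::writeFile does. Returns false on a short or failed write.
bool writeAllAt(int fd, const uint8_t* buffer, uint64_t numBytes, uint64_t offset) {
    constexpr uint64_t kMaxSlice = 1ull << 30; // 1 GiB
    while (numBytes > 0) {
        uint64_t slice = std::min(numBytes, kMaxSlice);
        ssize_t written = pwrite(fd, buffer, slice, static_cast<off_t>(offset));
        if (written != static_cast<ssize_t>(slice)) {
            return false;
        }
        buffer += written;
        offset += static_cast<uint64_t>(written);
        numBytes -= static_cast<uint64_t>(written);
    }
    return true;
}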
Error {}: {}", fileInfo.path, error, + std::system_category().message(error))); + } +#else +#if HAS_FULLFSYNC and defined(__APPLE__) + // Try F_FULLFSYNC first on macOS/iOS, which is required to guarantee durability past power + // failures. + if (fcntl(localFileInfo->fd, F_FULLFSYNC) == 0) { + return; + } + if (errno != ENOTSUP && errno != EINVAL) { + // LCOV_EXCL_START + if (errno == EIO) { + throw IOException("Fatal error: fsync failed!"); + } + throw IOException( + stringFormat("Failed to sync file {}: {}", fileInfo.path, posixErrMessage())); + // LCOV_EXCL_STOP + } +#endif + bool syncSuccess = false; +#if HAS_FDATASYNC + syncSuccess = fdatasync(localFileInfo->fd) == 0; // Only sync file data + essential metadata. +#else + syncSuccess = fsync(localFileInfo->fd) == 0; // Sync file data + all metadata. +#endif + if (!syncSuccess) { + throw IOException(stringFormat("Failed to sync file {}.", fileInfo.path)); + } +#endif +} + +int64_t LocalFileSystem::seek(FileInfo& fileInfo, uint64_t offset, int whence) const { + auto localFileInfo = fileInfo.constPtrCast(); +#if defined(_WIN32) + LARGE_INTEGER result; + LARGE_INTEGER offset_; + offset_.QuadPart = offset; + SetFilePointerEx((HANDLE)localFileInfo->handle, offset_, &result, whence); + return result.QuadPart; +#else + return lseek(localFileInfo->fd, offset, whence); +#endif +} + +void LocalFileSystem::truncate(FileInfo& fileInfo, uint64_t size) const { + auto localFileInfo = fileInfo.constPtrCast(); +#if defined(_WIN32) + auto offsetHigh = (LONG)(size >> 32); + LONG* offsetHighPtr = NULL; + if (offsetHigh > 0) + offsetHighPtr = &offsetHigh; + if (SetFilePointer((HANDLE)localFileInfo->handle, size & 0xffffffff, offsetHighPtr, + FILE_BEGIN) == INVALID_SET_FILE_POINTER) { + auto error = GetLastError(); + throw IOException(stringFormat("Cannot set file pointer for file: {} handle: {} " + "new position: {}. Error {}: {}", + fileInfo.path, (intptr_t)localFileInfo->handle, size, error, + std::system_category().message(error))); + } + if (!SetEndOfFile((HANDLE)localFileInfo->handle)) { + auto error = GetLastError(); + throw IOException(stringFormat("Cannot truncate file: {} handle: {} " + "size: {}. Error {}: {}", + fileInfo.path, (intptr_t)localFileInfo->handle, size, error, + std::system_category().message(error))); + } +#else + if (ftruncate(localFileInfo->fd, size) < 0) { + // LCOV_EXCL_START + throw IOException( + stringFormat("Failed to truncate file {}: {}", fileInfo.path, posixErrMessage())); + // LCOV_EXCL_STOP + } +#endif +} + +uint64_t LocalFileSystem::getFileSize(const FileInfo& fileInfo) const { + auto localFileInfo = fileInfo.constPtrCast(); +#ifdef _WIN32 + LARGE_INTEGER size; + if (!GetFileSizeEx((HANDLE)localFileInfo->handle, &size)) { + auto error = GetLastError(); + throw IOException(stringFormat("Cannot read size of file. path: {} - Error {}: {}", + fileInfo.path, error, systemErrMessage(error))); + } + return size.QuadPart; +#else + struct stat s {}; + if (fstat(localFileInfo->fd, &s) == -1) { + throw IOException(stringFormat("Cannot read size of file. 
path: {} - Error {}: {}", + fileInfo.path, errno, posixErrMessage())); + } + KU_ASSERT(s.st_size >= 0); + return s.st_size; +#endif +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/file_system/virtual_file_system.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/file_system/virtual_file_system.cpp new file mode 100644 index 0000000000..869a7db67a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/file_system/virtual_file_system.cpp @@ -0,0 +1,139 @@ +#include "common/file_system/virtual_file_system.h" + +#include "common/assert.h" +#include "common/exception/io.h" +#include "common/file_system/gzip_file_system.h" +#include "common/file_system/local_file_system.h" +#include "common/string_utils.h" +#include "main/client_context.h" +#include "main/database.h" + +namespace lbug { +namespace common { + +VirtualFileSystem::VirtualFileSystem() : VirtualFileSystem{""} {} + +VirtualFileSystem::VirtualFileSystem(std::string homeDir) { + defaultFS = std::make_unique(homeDir); + compressedFileSystem.emplace(FileCompressionType::GZIP, std::make_unique()); +} + +VirtualFileSystem::~VirtualFileSystem() = default; + +void VirtualFileSystem::registerFileSystem(std::unique_ptr fileSystem) { + subSystems.push_back(std::move(fileSystem)); +} + +FileCompressionType VirtualFileSystem::autoDetectCompressionType(const std::string& path) { + if (isGZIPCompressed(path)) { + return FileCompressionType::GZIP; + } + return FileCompressionType::UNCOMPRESSED; +} + +std::unique_ptr VirtualFileSystem::openFile(const std::string& path, FileOpenFlags flags, + main::ClientContext* context) { + auto compressionType = flags.compressionType; + if (compressionType == FileCompressionType::AUTO_DETECT) { + compressionType = autoDetectCompressionType(path); + } + auto fileHandle = findFileSystem(path)->openFile(path, flags, context); + if (compressionType == FileCompressionType::UNCOMPRESSED) { + return fileHandle; + } + if (flags.flags & FileFlags::WRITE) { + throw IOException{"Writing to compressed files is not supported yet."}; + } + if (StringUtils::getLower(getFileExtension(path)) != ".csv") { + throw IOException{"Lbug currently only supports reading from compressed csv files."}; + } + return compressedFileSystem.at(compressionType)->openCompressedFile(std::move(fileHandle)); +} + +std::vector VirtualFileSystem::glob(main::ClientContext* context, + const std::string& path) const { + return findFileSystem(path)->glob(context, path); +} + +void VirtualFileSystem::overwriteFile(const std::string& from, const std::string& to) { + findFileSystem(from)->overwriteFile(from, to); +} + +void VirtualFileSystem::createDir(const std::string& dir) const { + findFileSystem(dir)->createDir(dir); +} + +void VirtualFileSystem::removeFileIfExists(const std::string& path, + const main::ClientContext* context) { + findFileSystem(path)->removeFileIfExists(path, context); +} + +bool VirtualFileSystem::fileOrPathExists(const std::string& path, main::ClientContext* context) { + return findFileSystem(path)->fileOrPathExists(path, context); +} + +std::string VirtualFileSystem::expandPath(main::ClientContext* context, + const std::string& path) const { + return findFileSystem(path)->expandPath(context, path); +} + +void VirtualFileSystem::readFromFile(FileInfo& /*fileInfo*/, void* /*buffer*/, + uint64_t /*numBytes*/, uint64_t /*position*/) const { + KU_UNREACHABLE; +} + +int64_t VirtualFileSystem::readFile(FileInfo& /*fileInfo*/, void* /*buf*/, size_t /*nbyte*/) const { + 
KU_UNREACHABLE; +} + +void VirtualFileSystem::writeFile(FileInfo& /*fileInfo*/, const uint8_t* /*buffer*/, + uint64_t /*numBytes*/, uint64_t /*offset*/) const { + KU_UNREACHABLE; +} + +void VirtualFileSystem::syncFile(const FileInfo& fileInfo) const { + findFileSystem(fileInfo.path)->syncFile(fileInfo); +} + +void VirtualFileSystem::cleanUP(main::ClientContext* context) { + for (auto& subSystem : subSystems) { + subSystem->cleanUP(context); + } + defaultFS->cleanUP(context); +} + +bool VirtualFileSystem::handleFileViaFunction(const std::string& path) const { + return findFileSystem(path)->handleFileViaFunction(path); +} + +function::TableFunction VirtualFileSystem::getHandleFunction(const std::string& path) const { + return findFileSystem(path)->getHandleFunction(path); +} + +int64_t VirtualFileSystem::seek(FileInfo& /*fileInfo*/, uint64_t /*offset*/, int /*whence*/) const { + KU_UNREACHABLE; +} + +void VirtualFileSystem::truncate(FileInfo& /*fileInfo*/, uint64_t /*size*/) const { + KU_UNREACHABLE; +} + +uint64_t VirtualFileSystem::getFileSize(const FileInfo& /*fileInfo*/) const { + KU_UNREACHABLE; +} + +FileSystem* VirtualFileSystem::findFileSystem(const std::string& path) const { + for (auto& subSystem : subSystems) { + if (subSystem->canHandleFile(path)) { + return subSystem.get(); + } + } + return defaultFS.get(); +} + +VirtualFileSystem* VirtualFileSystem::GetUnsafe(const main::ClientContext& context) { + return context.getDatabase()->getVFS(); +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/in_mem_overflow_buffer.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/in_mem_overflow_buffer.cpp new file mode 100644 index 0000000000..b118c77787 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/in_mem_overflow_buffer.cpp @@ -0,0 +1,69 @@ +#include "common/in_mem_overflow_buffer.h" + +#include "common/system_config.h" +#include "storage/buffer_manager/memory_manager.h" +#include + +using namespace lbug::storage; + +namespace lbug { +namespace common { + +BufferBlock::BufferBlock(std::unique_ptr block) + : currentOffset{0}, block{std::move(block)} {} + +BufferBlock::~BufferBlock() = default; + +uint64_t BufferBlock::size() const { + return block->getBuffer().size(); +} + +uint8_t* BufferBlock::data() const { + return block->getBuffer().data(); +} + +uint8_t* InMemOverflowBuffer::allocateSpace(uint64_t size) { + if (requireNewBlock(size)) { + if (!blocks.empty() && currentBlock()->currentOffset == 0) { + blocks.pop_back(); + } + allocateNewBlock(size); + } + auto data = currentBlock()->data() + currentBlock()->currentOffset; + currentBlock()->currentOffset += size; + return data; +} + +void InMemOverflowBuffer::resetBuffer() { + if (!blocks.empty()) { + // Last block is usually the largest + auto lastBlock = std::move(blocks.back()); + blocks.clear(); + lastBlock->resetCurrentOffset(); + blocks.push_back(std::move(lastBlock)); + } +} + +void InMemOverflowBuffer::preventDestruction() { + for (auto& block : blocks) { + block->block->preventDestruction(); + } +} + +void InMemOverflowBuffer::allocateNewBlock(uint64_t size) { + std::unique_ptr newBlock; + if (blocks.empty()) { + newBlock = make_unique( + memoryManager->allocateBuffer(false /* do not initialize to zero */, size)); + } else { + // Use the doubling strategy so that the initial allocations are small, but if we need many + // allocations they approach the TEMP_PAGE_SIZE quickly + auto min = std::min(TEMP_PAGE_SIZE, std::bit_ceil(currentBlock()->size() * 2)); + newBlock 
= make_unique(memoryManager->allocateBuffer( + false /* do not initialize to zero */, std::max(min, size))); + } + blocks.push_back(std::move(newBlock)); +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/mask.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/mask.cpp new file mode 100644 index 0000000000..daf440a263 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/mask.cpp @@ -0,0 +1,24 @@ +#include "common/mask.h" + +#include "common/roaring_mask.h" + +namespace lbug { +namespace common { + +std::unique_ptr SemiMaskUtil::createMask(offset_t maxOffset) { + if (maxOffset > std::numeric_limits::max()) { + return std::make_unique(maxOffset); + } + return std::make_unique(maxOffset); +} + +offset_t NodeOffsetMaskMap::getNumMaskedNode() const { + offset_t numNodes = 0; + for (auto& [tableID, mask] : maskMap) { + numNodes += mask->getNumMaskedNodes(); + } + return numNodes; +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/md5.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/md5.cpp new file mode 100644 index 0000000000..6dc16e0221 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/md5.cpp @@ -0,0 +1,233 @@ +/* +** This code taken from the SQLite test library (can be found at +** https://www.sqlite.org/sqllogictest/doc/trunk/about.wiki). +** Originally found on the internet. The original header comment follows this comment. +** The code has been refactored, but the algorithm stays the same. +*/ +/* + * This code implements the MD5 message-digest algorithm. + * The algorithm is due to Ron Rivest. This code was + * written by Colin Plumb in 1993, no copyright is claimed. + * This code is in the public domain; do with it what you wish. + * + * Equivalent code is available from RSA Data Security, Inc. + * This code has been tested against that, and is equivalent, + * except that you don't need to include two pages of legalese + * with every copy. + * + * To compute the message digest of a chunk of bytes, declare an + * MD5Context structure, pass it to MD5Init, call MD5Update as + * needed on buffers full of bytes, and then call MD5Final, which + * will fill a supplied 16-byte array with the digest. + */ + +#include "common/md5.h" + +#include + +namespace lbug { +namespace common { + +void MD5::byteReverse(unsigned char* buf, unsigned longs) { + uint32_t t = 0; + do { + t = (uint32_t)((unsigned)buf[3] << 8 | buf[2]) << 16 | ((unsigned)buf[1] << 8 | buf[0]); + *(uint32_t*)buf = t; + buf += 4; + } while (--longs); +} + +// The four core functions - F1 is optimized somewhat +#define F1(x, y, z) ((z) ^ ((x) & ((y) ^ (z)))) +#define F2(x, y, z) F1(z, x, y) +#define F3(x, y, z) ((x) ^ (y) ^ (z)) +#define F4(x, y, z) ((y) ^ ((x) | ~(z))) +// This is the central step in the MD5 algorithm. 
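+// Expanded, MD5STEP(f, w, x, y, z, data, s) performs
+//   w += f(x, y, z) + data;  w = rotl32(w, s);  w += x;
+// so the first round-1 step below amounts to a = b + rotl32(a + F1(b, c, d) + in[0] + 0xd76aa478, 7).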
+#define MD5STEP(f, w, x, y, z, data, s) \ + ((w) += f(x, y, z) + (data), (w) = (w) << (s) | (w) >> (32 - (s)), (w) += (x)) + +void MD5::MD5Transform(uint32_t buf[4], const uint32_t in[16]) { + uint32_t a = buf[0]; + uint32_t b = buf[1]; + uint32_t c = buf[2]; + uint32_t d = buf[3]; + + MD5STEP(F1, a, b, c, d, in[0] + 0xd76aa478, 7); + MD5STEP(F1, d, a, b, c, in[1] + 0xe8c7b756, 12); + MD5STEP(F1, c, d, a, b, in[2] + 0x242070db, 17); + MD5STEP(F1, b, c, d, a, in[3] + 0xc1bdceee, 22); + MD5STEP(F1, a, b, c, d, in[4] + 0xf57c0faf, 7); + MD5STEP(F1, d, a, b, c, in[5] + 0x4787c62a, 12); + MD5STEP(F1, c, d, a, b, in[6] + 0xa8304613, 17); + MD5STEP(F1, b, c, d, a, in[7] + 0xfd469501, 22); + MD5STEP(F1, a, b, c, d, in[8] + 0x698098d8, 7); + MD5STEP(F1, d, a, b, c, in[9] + 0x8b44f7af, 12); + MD5STEP(F1, c, d, a, b, in[10] + 0xffff5bb1, 17); + MD5STEP(F1, b, c, d, a, in[11] + 0x895cd7be, 22); + MD5STEP(F1, a, b, c, d, in[12] + 0x6b901122, 7); + MD5STEP(F1, d, a, b, c, in[13] + 0xfd987193, 12); + MD5STEP(F1, c, d, a, b, in[14] + 0xa679438e, 17); + MD5STEP(F1, b, c, d, a, in[15] + 0x49b40821, 22); + + MD5STEP(F2, a, b, c, d, in[1] + 0xf61e2562, 5); + MD5STEP(F2, d, a, b, c, in[6] + 0xc040b340, 9); + MD5STEP(F2, c, d, a, b, in[11] + 0x265e5a51, 14); + MD5STEP(F2, b, c, d, a, in[0] + 0xe9b6c7aa, 20); + MD5STEP(F2, a, b, c, d, in[5] + 0xd62f105d, 5); + MD5STEP(F2, d, a, b, c, in[10] + 0x02441453, 9); + MD5STEP(F2, c, d, a, b, in[15] + 0xd8a1e681, 14); + MD5STEP(F2, b, c, d, a, in[4] + 0xe7d3fbc8, 20); + MD5STEP(F2, a, b, c, d, in[9] + 0x21e1cde6, 5); + MD5STEP(F2, d, a, b, c, in[14] + 0xc33707d6, 9); + MD5STEP(F2, c, d, a, b, in[3] + 0xf4d50d87, 14); + MD5STEP(F2, b, c, d, a, in[8] + 0x455a14ed, 20); + MD5STEP(F2, a, b, c, d, in[13] + 0xa9e3e905, 5); + MD5STEP(F2, d, a, b, c, in[2] + 0xfcefa3f8, 9); + MD5STEP(F2, c, d, a, b, in[7] + 0x676f02d9, 14); + MD5STEP(F2, b, c, d, a, in[12] + 0x8d2a4c8a, 20); + + MD5STEP(F3, a, b, c, d, in[5] + 0xfffa3942, 4); + MD5STEP(F3, d, a, b, c, in[8] + 0x8771f681, 11); + MD5STEP(F3, c, d, a, b, in[11] + 0x6d9d6122, 16); + MD5STEP(F3, b, c, d, a, in[14] + 0xfde5380c, 23); + MD5STEP(F3, a, b, c, d, in[1] + 0xa4beea44, 4); + MD5STEP(F3, d, a, b, c, in[4] + 0x4bdecfa9, 11); + MD5STEP(F3, c, d, a, b, in[7] + 0xf6bb4b60, 16); + MD5STEP(F3, b, c, d, a, in[10] + 0xbebfbc70, 23); + MD5STEP(F3, a, b, c, d, in[13] + 0x289b7ec6, 4); + MD5STEP(F3, d, a, b, c, in[0] + 0xeaa127fa, 11); + MD5STEP(F3, c, d, a, b, in[3] + 0xd4ef3085, 16); + MD5STEP(F3, b, c, d, a, in[6] + 0x04881d05, 23); + MD5STEP(F3, a, b, c, d, in[9] + 0xd9d4d039, 4); + MD5STEP(F3, d, a, b, c, in[12] + 0xe6db99e5, 11); + MD5STEP(F3, c, d, a, b, in[15] + 0x1fa27cf8, 16); + MD5STEP(F3, b, c, d, a, in[2] + 0xc4ac5665, 23); + + MD5STEP(F4, a, b, c, d, in[0] + 0xf4292244, 6); + MD5STEP(F4, d, a, b, c, in[7] + 0x432aff97, 10); + MD5STEP(F4, c, d, a, b, in[14] + 0xab9423a7, 15); + MD5STEP(F4, b, c, d, a, in[5] + 0xfc93a039, 21); + MD5STEP(F4, a, b, c, d, in[12] + 0x655b59c3, 6); + MD5STEP(F4, d, a, b, c, in[3] + 0x8f0ccc92, 10); + MD5STEP(F4, c, d, a, b, in[10] + 0xffeff47d, 15); + MD5STEP(F4, b, c, d, a, in[1] + 0x85845dd1, 21); + MD5STEP(F4, a, b, c, d, in[8] + 0x6fa87e4f, 6); + MD5STEP(F4, d, a, b, c, in[15] + 0xfe2ce6e0, 10); + MD5STEP(F4, c, d, a, b, in[6] + 0xa3014314, 15); + MD5STEP(F4, b, c, d, a, in[13] + 0x4e0811a1, 21); + MD5STEP(F4, a, b, c, d, in[4] + 0xf7537e82, 6); + MD5STEP(F4, d, a, b, c, in[11] + 0xbd3af235, 10); + MD5STEP(F4, c, d, a, b, in[2] + 0x2ad7d2bb, 15); + MD5STEP(F4, b, c, d, a, in[9] + 0xeb86d391, 
21); + + buf[0] += a; + buf[1] += b; + buf[2] += c; + buf[3] += d; +} + +void MD5::MD5Init() { + ctx.isInit = 1; + ctx.buf[0] = 0x67452301; + ctx.buf[1] = 0xefcdab89; + ctx.buf[2] = 0x98badcfe; + ctx.buf[3] = 0x10325476; + ctx.bits[0] = 0; + ctx.bits[1] = 0; +} + +void MD5::MD5Update(const unsigned char* buf, unsigned int len) { + // Update bitcount + + uint32_t t = ctx.bits[0]; + ctx.bits[0] = t + ((uint32_t)len << 3); + if (ctx.bits[0] < t) { + ctx.bits[1]++; // Carry from low to high + } + ctx.bits[1] += len >> 29; + + t = (t >> 3) & 0x3f; // Bytes already in shsInfo->data + + // Handle any leading odd-sized chunks + + if (t) { + unsigned char* p = (unsigned char*)ctx.in + t; + + t = 64 - t; + if (len < t) { + std::memcpy(p, buf, len); + return; + } + std::memcpy(p, buf, t); + byteReverse(ctx.in, 16); + MD5Transform(ctx.buf, (uint32_t*)ctx.in); + buf += t; + len -= t; + } + + // Process data in 64-byte chunks + + while (len >= 64) { + std::memcpy(ctx.in, buf, 64); + byteReverse(ctx.in, 16); + MD5Transform(ctx.buf, (uint32_t*)ctx.in); + buf += 64; + len -= 64; + } + + // Handle any remaining bytes of data. + + std::memcpy(ctx.in, buf, len); +} + +void MD5::MD5Final(unsigned char digest[16]) { + // Compute number of bytes mod 64 */ + unsigned count = (ctx.bits[0] >> 3) & 0x3F; + + // Set the first char of padding to 0x80. This is safe since there is + // always at least one byte free + unsigned char* p = ctx.in + count; + *p++ = 0x80; + + // Bytes of padding needed to make 64 bytes + count = 64 - 1 - count; + + // Pad out to 56 mod 64 + if (count < 8) { + // Two lots of padding: Pad the first block to 64 bytes + std::memset(p, 0, count); + byteReverse(ctx.in, 16); + MD5Transform(ctx.buf, (uint32_t*)ctx.in); + + // Now fill the next block with 56 bytes + std::memset(ctx.in, 0, 56); + } else { + // Pad block to 56 bytes */ + std::memset(p, 0, count - 8); + } + byteReverse(ctx.in, 14); + + // Append length in bits and transform + ((uint32_t*)ctx.in)[14] = ctx.bits[0]; + ((uint32_t*)ctx.in)[15] = ctx.bits[1]; + + MD5Transform(ctx.buf, (uint32_t*)ctx.in); + byteReverse((unsigned char*)ctx.buf, 4); + std::memcpy(digest, ctx.buf, 16); + std::memset(&ctx, 0, sizeof(ctx)); // In case it is sensitive +} + +void MD5::DigestToBase16(const unsigned char* digest, char* zBuf) { + static char const zEncode[] = "0123456789abcdef"; + int i = 0, j = 0; + + for (j = i = 0; i < 16; i++) { + int a = digest[i]; + zBuf[j++] = zEncode[(a >> 4) & 0xf]; + zBuf[j++] = zEncode[a & 0xf]; + } + zBuf[j] = 0; +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/metric.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/metric.cpp new file mode 100644 index 0000000000..3954cecd44 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/metric.cpp @@ -0,0 +1,55 @@ +#include "common/metric.h" + +namespace lbug { +namespace common { + +TimeMetric::TimeMetric(bool enable) : Metric(enable) { + accumulatedTime = 0; + isStarted = false; + timer = Timer(); +} + +void TimeMetric::start() { + if (!enabled) { + return; + } + isStarted = true; + timer.start(); +} + +void TimeMetric::stop() { + if (!enabled) { + return; + } + if (!isStarted) { + throw Exception("Timer metric has not started."); + } + timer.stop(); + accumulatedTime += timer.getDuration(); + isStarted = false; +} + +double TimeMetric::getElapsedTimeMS() const { + return accumulatedTime / 1000; +} + +NumericMetric::NumericMetric(bool enable) : Metric(enable) { + accumulatedValue = 0u; +} + +void 
NumericMetric::increase(uint64_t value) { + if (!enabled) { + return; + } + accumulatedValue += value; +} + +void NumericMetric::incrementByOne() { + if (!enabled) { + return; + } + accumulatedValue++; +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/null_mask.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/null_mask.cpp new file mode 100644 index 0000000000..7ced00b80a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/null_mask.cpp @@ -0,0 +1,291 @@ +#include "common/null_mask.h" + +#include +#include +#include + +#include "common/assert.h" +#include + +namespace lbug { +namespace common { + +void NullMask::setNull(uint64_t* nullEntries, uint32_t pos, bool isNull) { + auto entryPos = pos >> NUM_BITS_PER_NULL_ENTRY_LOG2; + auto bitPosInEntry = pos - (entryPos << NUM_BITS_PER_NULL_ENTRY_LOG2); + if (isNull) { + nullEntries[entryPos] |= NULL_BITMASKS_WITH_SINGLE_ONE[bitPosInEntry]; + } else { + nullEntries[entryPos] &= NULL_BITMASKS_WITH_SINGLE_ZERO[bitPosInEntry]; + } +} + +bool NullMask::copyNullMask(const uint64_t* srcNullEntries, uint64_t srcOffset, + uint64_t* dstNullEntries, uint64_t dstOffset, uint64_t numBitsToCopy, bool invert) { + // From benchmarks using setNull/isNull is faster for up to 3 bits + // (~4x faster for a single bit copy) + if (numBitsToCopy <= 3) { + bool anyNull = false; + for (size_t i = 0; i < numBitsToCopy; i++) { + bool isNull = NullMask::isNull(srcNullEntries, srcOffset + i); + NullMask::setNull(dstNullEntries, dstOffset + i, isNull ^ invert); + anyNull |= isNull; + } + return anyNull; + } + // If both offsets are aligned relative to each other then copy up to the first byte using the + // non-aligned method, then copy aligned, then copy the end unaligned again. 
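+ // Worked example: with srcOffset = 13, dstOffset = 5 and numBitsToCopy = 70 both offsets are
+ // congruent mod 8, so bits [13, 16) go through copyUnaligned, the next 64 bits are copied as
+ // eight whole bytes with memcpy, and the trailing 3 bits go through copyUnaligned again.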
+ if (!invert && (srcOffset % 8 == dstOffset % 8) && numBitsToCopy >= 8 && + numBitsToCopy - (srcOffset % 8) >= 8) { + auto numBitsInFirstByte = 0; + bool hasNullInSrcNullMask = false; + if (srcOffset != 0) { + numBitsInFirstByte = 8 - (srcOffset % 8); + if (copyUnaligned(srcNullEntries, srcOffset, dstNullEntries, dstOffset, + numBitsInFirstByte, false)) { + hasNullInSrcNullMask = true; + } + } + auto* src = + reinterpret_cast(srcNullEntries) + (srcOffset + numBitsInFirstByte) / 8; + auto* dst = + reinterpret_cast(dstNullEntries) + (dstOffset + numBitsInFirstByte) / 8; + auto numBytesForAlignedCopy = (numBitsToCopy - numBitsInFirstByte) / 8; + memcpy(dst, src, numBytesForAlignedCopy); + if (std::any_of(src, src + numBytesForAlignedCopy, [&](uint8_t val) { return val != 0; })) { + hasNullInSrcNullMask = true; + } + auto lastByteStart = numBitsInFirstByte + numBytesForAlignedCopy * 8; + auto numBitsInLastByte = numBitsToCopy - numBitsInFirstByte - numBytesForAlignedCopy * 8; + if (numBitsInLastByte > 0) { + return copyUnaligned(srcNullEntries, srcOffset + lastByteStart, dstNullEntries, + dstOffset + lastByteStart, numBitsInLastByte, false) || + hasNullInSrcNullMask; + } else { + return hasNullInSrcNullMask; + } + } else { + return copyUnaligned(srcNullEntries, srcOffset, dstNullEntries, dstOffset, numBitsToCopy, + invert); + } +} + +bool NullMask::copyUnaligned(const uint64_t* srcNullEntries, uint64_t srcOffset, + uint64_t* dstNullEntries, uint64_t dstOffset, uint64_t numBitsToCopy, bool invert) { + uint64_t bitPos = 0; + bool hasNullInSrcNullMask = false; + auto [srcNullEntryPos, srcNullBitPos] = getNullEntryAndBitPos(srcOffset + bitPos); + auto [dstNullEntryPos, dstNullBitPos] = getNullEntryAndBitPos(dstOffset + bitPos); + while (bitPos < numBitsToCopy) { + auto curDstNullEntryPos = dstNullEntryPos; + auto curDstNullBitPos = dstNullBitPos; + uint64_t numBitsToReadInCurrentEntry = 0; + uint64_t srcNullMaskEntry = + invert ? ~srcNullEntries[srcNullEntryPos] : srcNullEntries[srcNullEntryPos]; + if (dstNullBitPos < srcNullBitPos) { + numBitsToReadInCurrentEntry = + std::min(NullMask::NUM_BITS_PER_NULL_ENTRY - srcNullBitPos, numBitsToCopy - bitPos); + // Mask higher bits out of current read range to 0. + srcNullMaskEntry &= ~NULL_HIGH_MASKS[NullMask::NUM_BITS_PER_NULL_ENTRY - + (srcNullBitPos + numBitsToReadInCurrentEntry)]; + // Shift right to align the bit in the src and dst entry. + srcNullMaskEntry = srcNullMaskEntry >> (srcNullBitPos - dstNullBitPos); + // Mask lower bits out of current read range to 0. + srcNullMaskEntry &= ~NULL_LOWER_MASKS[dstNullBitPos]; + // Move to the next null entry in src null mask. + srcNullEntryPos++; + srcNullBitPos = 0; + dstNullBitPos += numBitsToReadInCurrentEntry; + } else if (dstNullBitPos > srcNullBitPos) { + numBitsToReadInCurrentEntry = + std::min(NullMask::NUM_BITS_PER_NULL_ENTRY - dstNullBitPos, numBitsToCopy - bitPos); + // Mask lower bits out of current read range to 0. + srcNullMaskEntry &= ~NULL_LOWER_MASKS[srcNullBitPos]; + // Shift left to align the bit in the src and dst entry. + srcNullMaskEntry = srcNullMaskEntry << (dstNullBitPos - srcNullBitPos); + // Mask higher bits out of current read range to 0. + srcNullMaskEntry &= ~NULL_HIGH_MASKS[NullMask::NUM_BITS_PER_NULL_ENTRY - + (dstNullBitPos + numBitsToReadInCurrentEntry)]; + // Move to the next null entry in dst null mask. 
+ dstNullEntryPos++; + dstNullBitPos = 0; + srcNullBitPos += numBitsToReadInCurrentEntry; + } else { + numBitsToReadInCurrentEntry = + std::min(NullMask::NUM_BITS_PER_NULL_ENTRY - dstNullBitPos, numBitsToCopy - bitPos); + // Mask lower bits out of current read range to 0. + srcNullMaskEntry &= ~NULL_LOWER_MASKS[srcNullBitPos]; + // Mask higher bits out of current read range to 0. + srcNullMaskEntry &= ~NULL_HIGH_MASKS[NullMask::NUM_BITS_PER_NULL_ENTRY - + (dstNullBitPos + numBitsToReadInCurrentEntry)]; + // The input entry and the result entry are already aligned. + srcNullEntryPos++; + dstNullEntryPos++; + srcNullBitPos = dstNullBitPos = 0; + } + bitPos += numBitsToReadInCurrentEntry; + dstNullEntries[curDstNullEntryPos] &= + ~(NULL_LOWER_MASKS[numBitsToReadInCurrentEntry] << curDstNullBitPos); + if (srcNullMaskEntry != 0) { + dstNullEntries[curDstNullEntryPos] |= srcNullMaskEntry; + hasNullInSrcNullMask = true; + } + } + return hasNullInSrcNullMask; +} + +void NullMask::resize(uint64_t capacity) { + auto numNullEntries = (capacity + NUM_BITS_PER_NULL_ENTRY - 1) / NUM_BITS_PER_NULL_ENTRY; + auto resizedBuffer = std::make_unique(numNullEntries); + memcpy(resizedBuffer.get(), data.data(), data.size_bytes()); + buffer = std::move(resizedBuffer); + data = std::span(buffer.get(), numNullEntries); +} + +bool NullMask::copyFromNullBits(const uint64_t* srcNullEntries, uint64_t srcOffset, + uint64_t dstOffset, uint64_t numBitsToCopy, bool invert) { + KU_ASSERT(dstOffset + numBitsToCopy <= getNumNullBits(data)); + if (copyNullMask(srcNullEntries, srcOffset, this->data.data(), dstOffset, numBitsToCopy, + invert)) { + this->mayContainNulls = true; + return true; + } + return false; +} + +void NullMask::setNullFromRange(uint64_t offset, uint64_t numBitsToSet, bool isNull) { + if (isNull) { + this->mayContainNulls = true; + } + KU_ASSERT(offset + numBitsToSet <= getNumNullBits(data)); + setNullRange(data.data(), offset, numBitsToSet, isNull); +} + +void NullMask::setNullRange(uint64_t* nullEntries, uint64_t offset, uint64_t numBitsToSet, + bool isNull) { + if (numBitsToSet == 0) { + return; + } + + auto [firstEntryPos, firstBitPos] = getNullEntryAndBitPos(offset); + auto [lastEntryPos, lastBitPos] = getNullEntryAndBitPos(offset + numBitsToSet); + + // If the range spans multiple entries, set the entries in the middle to the appropriate value + // with std::fill + if (lastEntryPos > firstEntryPos + 1) { + std::fill(nullEntries + firstEntryPos + 1, nullEntries + lastEntryPos, + isNull ? 
ALL_NULL_ENTRY : NO_NULL_ENTRY); + } + + if (firstEntryPos == lastEntryPos) { + if (isNull) { + // Set bits between the first and the last bit pos to true + nullEntries[firstEntryPos] |= (~NULL_LOWER_MASKS[firstBitPos] & + ~NULL_HIGH_MASKS[NUM_BITS_PER_NULL_ENTRY - lastBitPos]); + } else { + // Set bits between the first and the last bit pos to false + nullEntries[firstEntryPos] &= (NULL_LOWER_MASKS[firstBitPos] | + NULL_HIGH_MASKS[NUM_BITS_PER_NULL_ENTRY - lastBitPos]); + } + } else { + if (isNull) { + // Set bits including and after the first bit pos to true + nullEntries[firstEntryPos] |= ~NULL_LOWER_MASKS[firstBitPos]; + if (lastBitPos > 0) { + // Set bits before the last bit pos to true + nullEntries[lastEntryPos] |= NULL_LOWER_MASKS[lastBitPos]; + } + } else { + // Set bits including and after the first bit pos to false + nullEntries[firstEntryPos] &= NULL_LOWER_MASKS[firstBitPos]; + if (lastBitPos > 0) { + // Set bits before the last bit pos to false + nullEntries[lastEntryPos] &= ~NULL_LOWER_MASKS[lastBitPos]; + } + } + } +} + +void NullMask::operator|=(const NullMask& other) { + KU_ASSERT(other.data.size() == data.size()); + for (size_t i = 0; i < data.size(); i++) { + data[i] |= other.getData()[i]; + } +} + +std::pair NullMask::getMinMax(const uint64_t* nullEntries, uint64_t offset, + uint64_t numValues) { + nullEntries += offset / NUM_BITS_PER_NULL_ENTRY; + offset = offset % NUM_BITS_PER_NULL_ENTRY; + + // If the offset+numValues are both within a word, just combine the appropriate masks and + // compare to 0/mask to determine if they are all 1s/all 0s (else a mix) + if (offset + numValues <= NUM_BITS_PER_NULL_ENTRY) { + auto mask = NULL_HIGH_MASKS[NUM_BITS_PER_NULL_ENTRY - offset] & + NULL_LOWER_MASKS[offset + numValues]; + auto masked = *nullEntries & mask; + if (masked == 0) { + return std::make_pair(false, false); + } else if (masked == mask) { + return std::make_pair(true, true); + } else { + return std::make_pair(false, true); + } + } + // If the range spans multiple entries, check the first one by masking the start + bool min = false, max = false; + if (offset > 0) { + auto mask = NULL_HIGH_MASKS[NUM_BITS_PER_NULL_ENTRY - offset]; + auto masked = *nullEntries & mask; + if (masked == 0) { + min = max = false; + } else if (masked == mask) { + min = max = true; + } else { + return std::make_pair(false, true); + } + nullEntries++; + numValues -= NUM_BITS_PER_NULL_ENTRY - offset; + } else { + // Rest of the entry will be checked in the loop below + min = max = isNull(nullEntries, 0); + } + + // Check central full bytes, which can be compared in a single operation since we don't ignore + // any bits If there was no offset, then we calculate the baseline based on the first bit, and + // compare that to the actual entry + auto baseline = min ? 
~static_cast(0) : 0; + for (size_t i = 0; i < numValues / NUM_BITS_PER_NULL_ENTRY; i++) { + if (nullEntries[i] != baseline) { + return std::make_pair(false, true); + } + } + nullEntries += numValues / NUM_BITS_PER_NULL_ENTRY; + numValues = numValues % NUM_BITS_PER_NULL_ENTRY; + if (numValues > 0) { + // Check last entry + auto mask = NULL_LOWER_MASKS[numValues]; + auto masked = *nullEntries & mask; + if (masked == 0) { + return std::make_pair(false, max); + } else if (masked == mask) { + return std::make_pair(min, true); + } else { + return std::make_pair(false, true); + } + } + return std::make_pair(min, max); +} + +uint64_t NullMask::countNulls() const { + // If capacity % 64 != 0 then there may be unused bits at the end of the last entry, + // but these should always be 0. + uint64_t sum = 0; + for (auto entry : data) { + sum += std::popcount(entry); + } + return sum; +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/profiler.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/profiler.cpp new file mode 100644 index 0000000000..973468f167 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/profiler.cpp @@ -0,0 +1,51 @@ +#include "common/profiler.h" + +namespace lbug { +namespace common { + +TimeMetric* Profiler::registerTimeMetric(const std::string& key) { + auto timeMetric = std::make_unique(enabled); + auto metricPtr = timeMetric.get(); + addMetric(key, std::move(timeMetric)); + return metricPtr; +} + +NumericMetric* Profiler::registerNumericMetric(const std::string& key) { + auto numericMetric = std::make_unique(enabled); + auto metricPtr = numericMetric.get(); + addMetric(key, std::move(numericMetric)); + return metricPtr; +} + +double Profiler::sumAllTimeMetricsWithKey(const std::string& key) { + auto sum = 0.0; + if (!metrics.contains(key)) { + return sum; + } + for (auto& metric : metrics.at(key)) { + sum += ((TimeMetric*)metric.get())->getElapsedTimeMS(); + } + return sum; +} + +uint64_t Profiler::sumAllNumericMetricsWithKey(const std::string& key) { + auto sum = 0ul; + if (!metrics.contains(key)) { + return sum; + } + for (auto& metric : metrics.at(key)) { + sum += ((NumericMetric*)metric.get())->accumulatedValue; + } + return sum; +} + +void Profiler::addMetric(const std::string& key, std::unique_ptr metric) { + std::lock_guard lck(mtx); + if (!metrics.contains(key)) { + metrics.insert({key, std::vector>()}); + } + metrics.at(key).push_back(std::move(metric)); +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/random_engine.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/random_engine.cpp new file mode 100644 index 0000000000..9f991e83c2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/random_engine.cpp @@ -0,0 +1,38 @@ +#include "common/random_engine.h" + +#include + +#include "main/client_context.h" + +namespace lbug { +namespace common { + +RandomEngine::RandomEngine() : randomState(RandomState()) { + randomState.pcg.seed(pcg_extras::seed_seq_from()); +} + +RandomEngine::RandomEngine(uint64_t seed, uint64_t stream) : randomState(RandomState()) { + randomState.pcg.seed(seed, stream); +} + +void RandomEngine::setSeed(uint64_t seed) { + std::unique_lock xLck{mtx}; + randomState.pcg.seed(seed); +} + +uint32_t RandomEngine::nextRandomInteger() { + std::unique_lock xLck{mtx}; + return randomState.pcg(); +} + +uint32_t RandomEngine::nextRandomInteger(uint32_t upper) { + std::unique_lock xLck{mtx}; + return randomState.pcg(upper); +} + +RandomEngine* 
RandomEngine::Get(const main::ClientContext& context) { + return context.randomEngine.get(); +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/roaring_mask.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/roaring_mask.cpp new file mode 100644 index 0000000000..2a5876cea1 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/roaring_mask.cpp @@ -0,0 +1,63 @@ +#include "common/roaring_mask.h" + +namespace lbug { +namespace common { + +offset_vec_t Roaring32BitmapSemiMask::collectMaskedNodes(uint64_t size) const { + offset_vec_t result; + result.reserve(size); + auto it = roaring->begin(); + for (; it != roaring->end(); it++) { + auto value = *it; + result.push_back(value); + if (result.size() == size) { + break; + } + } + return result; +} + +offset_vec_t Roaring32BitmapSemiMask::range(uint32_t start, uint32_t end) { + auto it = roaring->begin(); + it.equalorlarger(start); + offset_vec_t ans; + for (; it != roaring->end(); it++) { + auto value = *it; + if (value >= end) { + break; + } + ans.push_back(value); + } + return ans; +} + +offset_vec_t Roaring64BitmapSemiMask::collectMaskedNodes(uint64_t size) const { + offset_vec_t result; + result.reserve(size); + auto it = roaring->begin(); + for (; it != roaring->end(); it++) { + auto value = *it; + result.push_back(value); + if (result.size() == size) { + break; + } + } + return result; +} + +offset_vec_t Roaring64BitmapSemiMask::range(uint32_t start, uint32_t end) { + auto it = roaring->begin(); + it.move(start); + offset_vec_t ans; + for (; it != roaring->end(); it++) { + auto value = *it; + if (value >= end) { + break; + } + ans.push_back(value); + } + return ans; +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/serializer/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/common/serializer/CMakeLists.txt new file mode 100644 index 0000000000..8a8fba4f80 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/serializer/CMakeLists.txt @@ -0,0 +1,11 @@ +add_library(lbug_common_serializer + OBJECT + buffer_writer.cpp + buffered_file.cpp + deserializer.cpp + in_mem_file_writer.cpp + serializer.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/serializer/buffer_writer.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/serializer/buffer_writer.cpp new file mode 100644 index 0000000000..4545bd134e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/serializer/buffer_writer.cpp @@ -0,0 +1,30 @@ +#include "common/serializer/buffer_writer.h" + +#include + +namespace lbug { +namespace common { + +BufferWriter::BufferWriter(uint64_t maximumSize) : maximumSize(maximumSize) { + blob.data = std::make_unique(maximumSize); + blob.size = 0; + data = blob.data.get(); +} + +void BufferWriter::write(const uint8_t* buffer, uint64_t len) { + if (blob.size + len >= maximumSize) { + do { + maximumSize *= 2; + } while (blob.size + len > maximumSize); + auto new_data = std::make_unique(maximumSize); + memcpy(new_data.get(), data, blob.size); + data = new_data.get(); + blob.data = std::move(new_data); + } + + memcpy(data + blob.size, buffer, len); + blob.size += len; +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/serializer/buffered_file.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/serializer/buffered_file.cpp new file mode 100644 index 0000000000..3ccd794623 --- /dev/null +++ 
b/graph-wasm/lbug-0.12.2/lbug-src/src/common/serializer/buffered_file.cpp @@ -0,0 +1,114 @@ +#include "common/serializer/buffered_file.h" + +#include + +#include "common/assert.h" +#include "common/exception/runtime.h" +#include "common/file_system/file_info.h" +#include "common/system_config.h" + +namespace lbug { +namespace common { + +static constexpr uint64_t BUFFER_SIZE = LBUG_PAGE_SIZE; + +BufferedFileWriter::BufferedFileWriter(FileInfo& fileInfo) + : buffer(std::make_unique(BUFFER_SIZE)), fileOffset(0), bufferOffset(0), + fileInfo(fileInfo) {} + +BufferedFileWriter::~BufferedFileWriter() { + flush(); +} + +void BufferedFileWriter::write(const uint8_t* data, uint64_t size) { + if (size > BUFFER_SIZE) { + flush(); + fileInfo.writeFile(data, size, fileOffset); + fileOffset += size; + return; + } + KU_ASSERT(size <= BUFFER_SIZE); + if (bufferOffset + size <= BUFFER_SIZE) { + memcpy(&buffer[bufferOffset], data, size); + bufferOffset += size; + } else { + auto toCopy = BUFFER_SIZE - bufferOffset; + memcpy(&buffer[bufferOffset], data, toCopy); + bufferOffset += toCopy; + flush(); + auto remaining = size - toCopy; + memcpy(buffer.get(), data + toCopy, remaining); + bufferOffset += remaining; + } +} + +void BufferedFileWriter::clear() { + fileInfo.truncate(0); + resetOffsets(); +} + +void BufferedFileWriter::flush() { + if (bufferOffset == 0) { + return; + } + fileInfo.writeFile(buffer.get(), bufferOffset, fileOffset); + fileOffset += bufferOffset; + bufferOffset = 0; + memset(buffer.get(), 0, BUFFER_SIZE); +} + +void BufferedFileWriter::sync() { + fileInfo.syncFile(); +} + +uint64_t BufferedFileWriter::getSize() const { + return fileInfo.getFileSize() + bufferOffset; +} + +BufferedFileReader::BufferedFileReader(FileInfo& fileInfo) + : buffer(std::make_unique(BUFFER_SIZE)), fileOffset(0), bufferOffset(0), + fileInfo(fileInfo), bufferSize{0} { + fileSize = this->fileInfo.getFileSize(); + readNextPage(); +} + +void BufferedFileReader::read(uint8_t* data, uint64_t size) { + if (size > BUFFER_SIZE) { + // Clear read buffer. 
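+ // readNextPage() advanced fileOffset past the prefetched buffer, so rewind to the first
+ // unconsumed byte, issue the oversized read directly against the file, and mark the buffer
+ // as fully consumed (bufferOffset = bufferSize) so the next small read triggers a refill.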
+ fileOffset -= bufferSize; + fileOffset += bufferOffset; + fileInfo.readFromFile(data, size, fileOffset); + fileOffset += size; + bufferOffset = bufferSize; + } else if (bufferOffset + size <= bufferSize) { + memcpy(data, &buffer[bufferOffset], size); + bufferOffset += size; + } else { + auto toCopy = bufferSize - bufferOffset; + memcpy(data, &buffer[bufferOffset], toCopy); + bufferOffset += toCopy; + readNextPage(); + auto remaining = size - toCopy; + memcpy(data + toCopy, buffer.get(), remaining); + bufferOffset += remaining; + } +} + +bool BufferedFileReader::finished() { + return bufferOffset >= bufferSize && fileSize <= fileOffset; +} + +void BufferedFileReader::readNextPage() { + if (fileSize <= fileOffset) { + throw RuntimeException( + stringFormat("Reading past the end of the file {} with size {} at offset {}", + fileInfo.path, fileSize, fileOffset)); + } + bufferSize = std::min(fileSize - fileOffset, BUFFER_SIZE); + fileInfo.readFromFile(buffer.get(), bufferSize, fileOffset); + fileOffset += bufferSize; + bufferOffset = 0; +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/serializer/deserializer.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/serializer/deserializer.cpp new file mode 100644 index 0000000000..e2512f4a5c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/serializer/deserializer.cpp @@ -0,0 +1,25 @@ +#include "common/serializer/deserializer.h" + +namespace lbug { +namespace common { + +template<> +void Deserializer::deserializeValue(std::string& value) { + uint64_t valueLength = 0; + deserializeValue(valueLength); + value.resize(valueLength); + reader->read(reinterpret_cast(value.data()), valueLength); +} + +void Deserializer::validateDebuggingInfo(std::string& value, const std::string& expectedVal) { +#if defined(LBUG_DESER_DEBUG) && (defined(LBUG_RUNTIME_CHECKS) || !defined(NDEBUG)) + deserializeValue(value); + KU_ASSERT(value == expectedVal); +#endif + // DO NOTHING + KU_UNUSED(value); + KU_UNUSED(expectedVal); +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/serializer/in_mem_file_writer.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/serializer/in_mem_file_writer.cpp new file mode 100644 index 0000000000..dfee92209f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/serializer/in_mem_file_writer.cpp @@ -0,0 +1,83 @@ +#include "common/serializer/in_mem_file_writer.h" + +#include "storage/file_handle.h" +#include "storage/shadow_file.h" +#include "storage/shadow_utils.h" + +namespace lbug { +namespace common { + +InMemFileWriter::InMemFileWriter(storage::MemoryManager& mm) : mm{mm}, pageOffset{0} {} + +void InMemFileWriter::write(const uint8_t* data, uint64_t size) { + auto remaining = size; + while (remaining > 0) { + if (needNewBuffer(size)) { + const auto lastPage = pages.empty() ? 
nullptr : pages.back().get(); + if (lastPage) { + auto toCopy = std::min(size, LBUG_PAGE_SIZE - pageOffset); + memcpy(lastPage->getData() + pageOffset, data + (size - remaining), toCopy); + remaining -= toCopy; + } + pages.push_back(mm.allocateBuffer(false, LBUG_PAGE_SIZE)); + pageOffset = 0; + } + auto toCopy = std::min(remaining, LBUG_PAGE_SIZE - pageOffset); + memcpy(pages.back()->getData() + pageOffset, data + (size - remaining), toCopy); + pageOffset += toCopy; + remaining -= toCopy; + } +} + +storage::PageRange InMemFileWriter::flush(storage::PageAllocator& pageAllocator, + storage::ShadowFile& shadowFile) const { + auto numPagesToFlush = getNumPagesToFlush(); + auto pageRange = pageAllocator.allocatePageRange(numPagesToFlush); + flush(pageRange, pageAllocator.getDataFH(), shadowFile); + return pageRange; +} + +void InMemFileWriter::flush(storage::PageRange allocatedPageRange, storage::FileHandle* fileHandle, + storage::ShadowFile& shadowFile) const { + auto numPagesToWrite = getNumPagesToFlush(); + KU_ASSERT(allocatedPageRange.numPages >= numPagesToWrite); + auto numPagesBeforeAllocate = allocatedPageRange.startPageIdx; + for (auto i = 0u; i < numPagesToWrite; i++) { + auto pageIdx = allocatedPageRange.startPageIdx + i; + auto insertingNewPage = pageIdx >= numPagesBeforeAllocate; + auto shadowPageAndFrame = storage::ShadowUtils::createShadowVersionIfNecessaryAndPinPage( + pageIdx, insertingNewPage, *fileHandle, shadowFile); + memcpy(shadowPageAndFrame.frame, pages[i]->getData(), LBUG_PAGE_SIZE); + shadowFile.getShadowingFH().unpinPage(shadowPageAndFrame.shadowPage); + } + + // Write zeroes to any extra pages + // This ensures that the size of the data file matches the size expected from allocations + // even if we reload the database immediately after this + for (auto i = numPagesToWrite; i < allocatedPageRange.numPages; i++) { + auto pageIdx = allocatedPageRange.startPageIdx + i; + auto insertingNewPage = pageIdx >= numPagesBeforeAllocate; + auto shadowPageAndFrame = storage::ShadowUtils::createShadowVersionIfNecessaryAndPinPage( + pageIdx, insertingNewPage, *fileHandle, shadowFile); + memset(shadowPageAndFrame.frame, 0u, LBUG_PAGE_SIZE); + shadowFile.getShadowingFH().unpinPage(shadowPageAndFrame.shadowPage); + } +} + +void InMemFileWriter::flush(Writer& writer) const { + for (auto i = 0u; i < pages.size(); i++) { + auto sizeToFlush = (i == pages.size() - 1) ? 
pageOffset : LBUG_PAGE_SIZE; + writer.write(pages[i]->getData(), sizeToFlush); + } +} + +bool InMemFileWriter::needNewBuffer(uint64_t size) const { + return pages.empty() || pageOffset + size > LBUG_PAGE_SIZE; +} + +uint64_t InMemFileWriter::getPageSize() { + return LBUG_PAGE_SIZE; +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/serializer/serializer.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/serializer/serializer.cpp new file mode 100644 index 0000000000..10eafeb1be --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/serializer/serializer.cpp @@ -0,0 +1,24 @@ +#include "common/serializer/serializer.h" + +#include "common/assert.h" + +namespace lbug { +namespace common { + +template<> +void Serializer::serializeValue(const std::string& value) { + uint64_t valueLength = value.length(); + writer->write((uint8_t*)&valueLength, sizeof(uint64_t)); + writer->write((uint8_t*)value.data(), valueLength); +} + +void Serializer::writeDebuggingInfo(const std::string& value) { +#if defined(LBUG_DESER_DEBUG) && (defined(LBUG_RUNTIME_CHECKS) || !defined(NDEBUG)) + serializeValue(value); +#endif + // DO NOTHING + KU_UNUSED(value); +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/sha256.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/sha256.cpp new file mode 100644 index 0000000000..8079f7ee75 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/sha256.cpp @@ -0,0 +1,52 @@ +#include "common/sha256.h" + +#include "common/exception/runtime.h" + +namespace lbug { +namespace common { + +SHA256::SHA256() : shaContext{} { + mbedtls_sha256_init(&shaContext); + + // These errors would only occur if there's an issue with shaContext which is wrapped inside + // SHA256, or with the mbedtls library itself + if (mbedtls_sha256_starts(&shaContext, false)) { + throw RuntimeException{"SHA256 Error"}; + } +} + +SHA256::~SHA256() { + mbedtls_sha256_free(&shaContext); +} + +void SHA256::addString(const std::string& str) { + if (mbedtls_sha256_update(&shaContext, reinterpret_cast(str.data()), + str.size())) { + throw RuntimeException{"SHA256 Error"}; + } +} + +void SHA256::finishSHA256(char* out) { + std::string hash; + hash.resize(SHA256_HASH_LENGTH_BYTES); + + if (mbedtls_sha256_finish(&shaContext, reinterpret_cast(hash.data()))) { + throw RuntimeException{"SHA256 Error"}; + } + + toBase16(hash.c_str(), out, SHA256_HASH_LENGTH_BYTES); +} + +void SHA256::toBase16(const char* in, char* out, size_t len) { + static char const HEX_CODES[] = "0123456789abcdef"; + size_t i = 0, j = 0; + + for (j = i = 0; i < len; i++) { + int a = in[i]; + out[j++] = HEX_CODES[(a >> 4) & 0xf]; + out[j++] = HEX_CODES[a & 0xf]; + } +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/signal/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/common/signal/CMakeLists.txt new file mode 100644 index 0000000000..af256cc72e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/signal/CMakeLists.txt @@ -0,0 +1,4 @@ +if (ENABLE_BACKTRACES) + add_library(register_backtrace_signal_handler OBJECT register.cpp) + target_link_libraries(register_backtrace_signal_handler cpptrace::cpptrace) +endif() diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/signal/register.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/signal/register.cpp new file mode 100644 index 0000000000..8da3340832 --- /dev/null +++ 
b/graph-wasm/lbug-0.12.2/lbug-src/src/common/signal/register.cpp @@ -0,0 +1,32 @@ +#ifdef LBUG_BACKTRACE +#include +#include +#include +#include + +#include + +namespace { + +void handler(int signo) { + // Not safe. Safe method would be writing directly to stderr with a pre-defined string + // But since the below isn't safe either... + std::cerr << "Fatal signal " << signo << std::endl; + // This is not safe, however the safe version, described at the link below, + // was causing hangs when the tracer program can't be found. + // Since this is only used in CI, the occasional failure/hang is probably acceptable. + // https://github.com/jeremy-rifkin/cpptrace/blob/main/docs/signal-safe-tracing.md + cpptrace::generate_trace(1 /*skip this function's frame*/).print(); + std::_Exit(1); +} + +int register_signal_handlers() noexcept { + std::signal(SIGSEGV, handler); + std::signal(SIGFPE, handler); + cpptrace::register_terminate_handler(); + return 0; +} + +static int ignore = register_signal_handlers(); +}; // namespace +#endif diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/string_utils.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/string_utils.cpp new file mode 100644 index 0000000000..e474de398d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/string_utils.cpp @@ -0,0 +1,281 @@ +#include "common/string_utils.h" + +#include +#include + +#include "common/exception/runtime.h" +#include "function/string/functions/base_lower_upper_function.h" +#include "utf8proc_wrapper.h" + +namespace lbug { +namespace common { + +std::vector StringUtils::splitComma(const std::string& input) { + auto result = std::vector(); + auto currentPos = 0u; + auto lvl = 0u; + while (currentPos < input.length()) { + if (input[currentPos] == '(') { + lvl++; + } else if (input[currentPos] == ')') { + lvl--; + } else if (lvl == 0 && input[currentPos] == ',') { + break; + } + currentPos++; + } + result.push_back(input.substr(0, currentPos)); + result.push_back(input.substr(currentPos == input.length() ? input.length() : currentPos + 1)); + return result; +} + +static char openingBracket(char c) { + if (c == ')') { + return '('; + } + if (c == ']') { + return '['; + } + if (c == '}') { + return '{'; + } + return c; +} + +std::vector StringUtils::smartSplit(std::string_view input, char splitChar, + uint64_t maxNumEle) { + if (input.size() == 0) { + return {}; + } + std::vector result; + auto currentItem = 0u; + std::vector stk; + bool insideSingleQuote = false; + for (auto i = 0u; i < input.size(); i++) { + char c = input[i]; + + if (c == '\'' && (stk.size() == 0 || stk.back() != '\'')) { + // Entering/Exiting a single-quoted block. + insideSingleQuote = !insideSingleQuote; + } else if (c == splitChar && stk.size() == 0u && !insideSingleQuote) { + if (result.size() + 1 == maxNumEle) { + result.push_back(input.substr(currentItem)); + return result; + } else { + result.push_back(input.substr(currentItem, i - currentItem)); + currentItem = i + 1; + } + } else if (c == '{' || c == '(' || c == '[' || + (c == '\"' && (stk.size() == 0u || stk.back() != '\"'))) { + stk.push_back(c); + } else if (stk.size() > 0 && openingBracket(c) == stk.back()) { + stk.pop_back(); + } + } + result.push_back(input.substr(currentItem)); + return result; +} + +uint64_t findDelim(const std::string& input, const std::string& delimiter, uint64_t prevPos) { + if (delimiter != "") { + return input.find(delimiter, prevPos); + } + return prevPos < input.size() - 1 ? 
prevPos + 1 : std::string::npos; +} + +std::vector StringUtils::split(const std::string& input, const std::string& delimiter, + bool ignoreEmptyStringParts) { + auto result = std::vector(); + auto prevPos = 0u; + auto currentPos = findDelim(input, delimiter, prevPos); + while (currentPos != std::string::npos) { + auto stringPart = input.substr(prevPos, currentPos - prevPos); + if (!ignoreEmptyStringParts || !stringPart.empty()) { + result.push_back(input.substr(prevPos, currentPos - prevPos)); + } + prevPos = currentPos + delimiter.size(); + currentPos = findDelim(input, delimiter, prevPos); + } + result.push_back(input.substr(prevPos)); + return result; +} + +std::vector StringUtils::splitBySpace(const std::string& input) { + std::istringstream iss(input); + std::vector result; + std::string token; + while (iss >> token) { + result.push_back(token); + } + return result; +} + +std::string StringUtils::getUpper(const std::string& input) { + auto result = input; + toUpper(result); + return result; +} + +std::string StringUtils::getUpper(const std::string_view& input) { + auto result = std::string(input); + toUpper(result); + return result; +} + +std::string StringUtils::getLower(const std::string& input) { + auto result = input; + toLower(result); + return result; +} + +template +static void changeCase(std::string& input) { + if (!utf8proc::Utf8Proc::isValid(input.c_str(), input.length())) { + throw RuntimeException{"Invalid UTF8-encoded string."}; + } + auto resultLen = function::BaseLowerUpperFunction::getResultLen((char*)input.data(), + input.length(), toUpper); + std::string result(resultLen, '\0' /* char */); + function::BaseLowerUpperFunction::convertCase((char*)result.data(), input.length(), + input.data(), toUpper); + input = result; +} + +void StringUtils::toLower(std::string& input) { + changeCase(input); +} + +void StringUtils::toUpper(std::string& input) { + changeCase(input); +} + +void StringUtils::removeCStringWhiteSpaces(const char*& input, uint64_t& len) { + // skip leading/trailing spaces + while (len > 0 && isspace(input[0])) { + input++; + len--; + } + while (len > 0 && isspace(input[len - 1])) { + len--; + } +} + +void StringUtils::replaceAll(std::string& str, const std::string& search, + const std::string& replacement) { + size_t pos = 0; + while ((pos = str.find(search, pos)) != std::string::npos) { + str.replace(pos, search.length(), replacement); + pos += replacement.length(); + } +} + +std::string StringUtils::extractStringBetween(const std::string& input, char delimiterStart, + char delimiterEnd, bool includeDelimiter) { + std::string::size_type posStart = input.find_first_of(delimiterStart); + std::string::size_type posEnd = input.find_last_of(delimiterEnd); + if (posStart == std::string::npos || posEnd == std::string::npos || posStart >= posEnd) { + return ""; + } + if (includeDelimiter) { + posEnd++; + } else { + posStart++; + } + return input.substr(posStart, posEnd - posStart); +} + +// Jenkins hash function: https://en.wikipedia.org/wiki/Jenkins_hash_function. +// We transform each character to its lower case and apply one_at_a_time hash. 
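+// Lower-casing each byte before mixing makes the hash case-insensitive, e.g. "Person",
+// "PERSON" and "person" all hash to the same value, while the final shift/xor avalanche
+// keeps distinct keys well distributed.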
+uint64_t StringUtils::caseInsensitiveHash(const std::string& str) { + uint32_t hash = 0; + for (auto c : str) { + hash += tolower(c); + hash += hash << 10; + hash ^= hash >> 6; + } + hash += hash << 3; + hash ^= hash >> 11; + hash += hash << 15; + return hash; +} + +bool StringUtils::caseInsensitiveEquals(std::string_view left, std::string_view right) { + if (left.size() != right.size()) { + return false; + } + for (auto c = 0u; c < left.size(); c++) { + if (asciiToLowerCaseMap[(uint8_t)left[c]] != asciiToLowerCaseMap[(uint8_t)right[c]]) { + return false; + } + } + return true; +} + +std::string StringUtils::join(const std::vector& input, const std::string& separator) { + return StringUtils::join(input, input.size(), separator, + [](const std::string& s) { return s; }); +} + +std::string StringUtils::join(const std::span input, + const std::string& separator) { + return StringUtils::join(input, input.size(), separator, + [](const std::string_view s) { return std::string{s}; }); +} + +template +std::string StringUtils::join(const C& input, S count, const std::string& separator, Func f) { + std::string result; + // if the input isn't empty, append the first element. We do this so we + // don't need to introduce an if into the loop. + if (count > 0) { + result += f(input[0]); + } + // append the remaining input components, after the first + for (size_t i = 1; i < count; i++) { + result += separator + f(input[i]); + } + return result; +} + +std::string StringUtils::ltrimNewlines(const std::string& input) { + auto s = input; + s.erase(s.begin(), + find_if(s.begin(), s.end(), [](unsigned char ch) { return !characterIsNewLine(ch); })); + return s; +} + +std::string StringUtils::rtrimNewlines(const std::string& input) { + auto s = input; + s.erase(find_if(s.rbegin(), s.rend(), [](unsigned char ch) { return !characterIsNewLine(ch); }) + .base(), + s.end()); + return s; +} + +std::string StringUtils::encodeURL(const std::string& input, bool encodeSlash) { + static const char* hex_digit = "0123456789ABCDEF"; + static constexpr std::string_view unreserved_chars = "_.-~"; + constexpr auto isUnreserved = [](char ch) { + return std::isalnum(ch) || unreserved_chars.find(ch) != std::string_view::npos; + }; + std::string result; + result.reserve(input.size()); + for (auto ch : input) { + if (isUnreserved(ch)) { + result += ch; + } else if (ch == '/') { + if (encodeSlash) { + result += std::string("%2F"); + } else { + result += ch; + } + } else { + result += std::string("%"); + result += hex_digit[static_cast(ch) >> 4]; + result += hex_digit[static_cast(ch) & 15]; + } + } + return result; +} +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/system_message.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/system_message.cpp new file mode 100644 index 0000000000..8abe2b82a5 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/system_message.cpp @@ -0,0 +1,37 @@ +#include "common/system_message.h" + +#ifdef _WIN32 +#include "windows.h" +#else +#include +#endif + +namespace lbug { +namespace common { + +std::string dlErrMessage() { +#ifdef _WIN32 + DWORD errorMessageID = GetLastError(); + if (errorMessageID == 0) { + return std::string(); + } + + LPSTR messageBuffer = nullptr; + auto size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | + FORMAT_MESSAGE_IGNORE_INSERTS, + NULL, errorMessageID, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&messageBuffer, 0, + NULL); + + std::string message(messageBuffer, size); + + // 
Free the buffer. + LocalFree(messageBuffer); + + return message; +#else + return dlerror(); // NOLINT(concurrency-mt-unsafe): load can only be executed in single thread. +#endif +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/task_system/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/common/task_system/CMakeLists.txt new file mode 100644 index 0000000000..141c08fbeb --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/task_system/CMakeLists.txt @@ -0,0 +1,10 @@ +add_library(lbug_common_task_system + OBJECT + task.cpp + task_scheduler.cpp + progress_bar.cpp + terminal_progress_bar_display.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/task_system/progress_bar.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/task_system/progress_bar.cpp new file mode 100644 index 0000000000..b53e418747 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/task_system/progress_bar.cpp @@ -0,0 +1,79 @@ +#include "common/task_system/progress_bar.h" + +#include "common/task_system/terminal_progress_bar_display.h" +#include "main/client_context.h" + +namespace lbug { +namespace common { + +ProgressBar::ProgressBar(bool enableProgressBar) { + display = DefaultProgressBarDisplay(); + numPipelines = 0; + numPipelinesFinished = 0; + trackProgress = enableProgressBar; +} + +std::shared_ptr ProgressBar::DefaultProgressBarDisplay() { + return std::make_shared(); +} + +void ProgressBar::setDisplay(std::shared_ptr progressBarDipslay) { + display = progressBarDipslay; +} + +void ProgressBar::startProgress(uint64_t queryID) { + if (!trackProgress) { + return; + } + std::lock_guard lock(progressBarLock); + updateDisplay(queryID, 0.0); +} + +void ProgressBar::endProgress(uint64_t queryID) { + std::lock_guard lock(progressBarLock); + resetProgressBar(queryID); +} + +void ProgressBar::addPipeline() { + if (!trackProgress) { + return; + } + numPipelines++; + display->setNumPipelines(numPipelines); +} + +void ProgressBar::finishPipeline(uint64_t queryID) { + if (!trackProgress) { + return; + } + numPipelinesFinished++; + updateProgress(queryID, 0.0); +} + +void ProgressBar::updateProgress(uint64_t queryID, double curPipelineProgress) { + if (!trackProgress) { + return; + } + updateDisplay(queryID, curPipelineProgress); +} + +void ProgressBar::resetProgressBar(uint64_t queryID) { + numPipelines = 0; + numPipelinesFinished = 0; + display->finishProgress(queryID); +} + +void ProgressBar::updateDisplay(uint64_t queryID, double curPipelineProgress) { + display->updateProgress(queryID, curPipelineProgress, numPipelinesFinished); +} + +void ProgressBar::toggleProgressBarPrinting(bool enable) { + trackProgress = enable; +} + +ProgressBar* ProgressBar::Get(const main::ClientContext& context) { + return context.progressBar.get(); +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/task_system/task.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/task_system/task.cpp new file mode 100644 index 0000000000..c0cd09d009 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/task_system/task.cpp @@ -0,0 +1,32 @@ +#include "common/task_system/task.h" + +namespace lbug { +namespace common { + +bool Task::registerThread() { + lock_t lck{taskMtx}; + if (!hasExceptionNoLock() && canRegisterNoLock()) { + numThreadsRegistered++; + return true; + } + return false; +} + +void Task::deRegisterThreadAndFinalizeTask() { + lock_t 
lck{taskMtx}; + ++numThreadsFinished; + if (!hasExceptionNoLock() && isCompletedNoLock()) { + try { + finalize(); + } catch (std::exception& e) { + setExceptionNoLock(std::current_exception()); + } + } + if (isCompletedNoLock()) { + lck.unlock(); + cv.notify_all(); + } +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/task_system/task_scheduler.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/task_system/task_scheduler.cpp new file mode 100644 index 0000000000..59ef8d56dc --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/task_system/task_scheduler.cpp @@ -0,0 +1,230 @@ +#include "common/task_system/task_scheduler.h" + +#include "main/client_context.h" +#include "main/database.h" +#include "processor/processor.h" + +#if defined(__APPLE__) +#include + +#include +#endif + +namespace lbug { +namespace common { + +#ifndef __SINGLE_THREADED__ + +#if defined(__APPLE__) +TaskScheduler::TaskScheduler(uint64_t numWorkerThreads, uint32_t threadQos) +#else +TaskScheduler::TaskScheduler(uint64_t numWorkerThreads) +#endif + : stopWorkerThreads{false}, nextScheduledTaskID{0} { +#if defined(__APPLE__) + this->threadQos = threadQos; +#endif + for (auto n = 0u; n < numWorkerThreads; ++n) { + workerThreads.emplace_back([&] { runWorkerThread(); }); + } +} + +TaskScheduler::~TaskScheduler() { + lock_t lck{taskSchedulerMtx}; + stopWorkerThreads = true; + lck.unlock(); + cv.notify_all(); + for (auto& thread : workerThreads) { + thread.join(); + } +} + +void TaskScheduler::scheduleTaskAndWaitOrError(const std::shared_ptr& task, + processor::ExecutionContext* context, bool launchNewWorkerThread) { + for (auto& dependency : task->children) { + scheduleTaskAndWaitOrError(dependency, context); + if (dependency->terminate()) { + return; + } + } + std::thread newWorkerThread; + if (launchNewWorkerThread) { + // Note that newWorkerThread is not executing yet. However, we still call + // task->registerThread() function because the call in the next line will guarantee + // that the thread starts working on it. registerThread() function only increases the + // numThreadsRegistered field of the task, tt does not keep track of the thread ids or + // anything specific to the thread. + task->registerThread(); + newWorkerThread = std::thread(runTask, task.get()); + } + auto scheduledTask = pushTaskIntoQueue(task); + cv.notify_all(); + std::unique_lock taskLck{task->taskMtx, std::defer_lock}; + while (true) { + taskLck.lock(); + bool timedWait = false; + auto timeout = 0u; + if (task->isCompletedNoLock()) { + // Note: we do not remove completed tasks from the queue in this function. They will be + // removed by the worker threads when they traverse down the queue for a task to work on + // (see getTaskAndRegister()). + taskLck.unlock(); + break; + } + if (context->clientContext->hasTimeout()) { + timeout = context->clientContext->getTimeoutRemainingInMS(); + if (timeout == 0) { + context->clientContext->interrupt(); + } else { + timedWait = true; + } + } else if (task->hasExceptionNoLock()) { + // Interrupt tasks that errored, so other threads can stop working on them early. 
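+ // This is the same cancellation path the timeout branch above uses: both cases rely on
+ // clientContext->interrupt(), so worker threads bail out early and the waiting thread
+ // can observe the failure on its next wake-up.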
+ context->clientContext->interrupt(); + } + if (timedWait) { + task->cv.wait_for(taskLck, std::chrono::milliseconds(timeout)); + } else { + task->cv.wait(taskLck); + } + taskLck.unlock(); + } + if (launchNewWorkerThread) { + newWorkerThread.join(); + } + if (task->hasException()) { + removeErroringTask(scheduledTask->ID); + std::rethrow_exception(task->getExceptionPtr()); + } +} + +void TaskScheduler::runWorkerThread() { +#if defined(__APPLE__) + qos_class_t qosClass = (qos_class_t)threadQos; + if (qosClass != QOS_CLASS_DEFAULT && qosClass != QOS_CLASS_UNSPECIFIED) { + auto pthreadQosStatus = pthread_set_qos_class_self_np(qosClass, 0); + KU_UNUSED(pthreadQosStatus); + } +#endif + std::unique_lock lck{taskSchedulerMtx, std::defer_lock}; + std::exception_ptr exceptionPtr = nullptr; + std::shared_ptr scheduledTask = nullptr; + while (true) { + // Warning: Threads acquire a global lock (using taskSchedulerMutex) right before + // deregistering themselves from a task (and they immediately register themselves for + // another task without releasing the lock). This acquire-right-before-deregistering ensures + // that all writes that were done by threads in Task_j happen before a Task_{j+1} which + // depends on Task_j can start. That's because before Task_{j+1} can start, each thread T_i + // working on Task_j will need to deregister itself using the global lock. Therefore, by the + // time any thread gets to start on Task_{j+1}, all writes made to Task_j by T_i will become + // globally visible because T_i grabbed the global lock before deregistering (and without + // T_i deregistering Task_{j+1} cannot start). + lck.lock(); + if (scheduledTask != nullptr) { + if (exceptionPtr != nullptr) { + scheduledTask->task->setException(exceptionPtr); + exceptionPtr = nullptr; + } + scheduledTask->task->deRegisterThreadAndFinalizeTask(); + scheduledTask = nullptr; + } + cv.wait(lck, [&] { + scheduledTask = getTaskAndRegister(); + return scheduledTask != nullptr || stopWorkerThreads; + }); + lck.unlock(); + if (stopWorkerThreads) { + return; + } + try { + scheduledTask->task->run(); + } catch (std::exception& e) { + exceptionPtr = std::current_exception(); + } + } +} +#else +// Single-threaded version of TaskScheduler +TaskScheduler::TaskScheduler(uint64_t) : stopWorkerThreads{false}, nextScheduledTaskID{0} {} + +TaskScheduler::~TaskScheduler() { + stopWorkerThreads = true; +} + +void TaskScheduler::scheduleTaskAndWaitOrError(const std::shared_ptr& task, + processor::ExecutionContext* context, bool) { + for (auto& dependency : task->children) { + scheduleTaskAndWaitOrError(dependency, context); + if (dependency->terminate()) { + return; + } + } + task->registerThread(); + // runTask deregisters, so we don't need to deregister explicitly here + runTask(task.get()); + if (task->hasException()) { + removeErroringTask(task->ID); + std::rethrow_exception(task->getExceptionPtr()); + } +} +#endif + +std::shared_ptr TaskScheduler::pushTaskIntoQueue(const std::shared_ptr& task) { + lock_t lck{taskSchedulerMtx}; + auto scheduledTask = std::make_shared(task, nextScheduledTaskID++); + taskQueue.push_back(scheduledTask); + return scheduledTask; +} + +std::shared_ptr TaskScheduler::getTaskAndRegister() { + if (taskQueue.empty()) { + return nullptr; + } + auto it = taskQueue.begin(); + while (it != taskQueue.end()) { + auto task = (*it)->task; + if (!task->registerThread()) { + // If we cannot register for a thread it is because of three possibilities: + // (i) maximum number of threads have registered for task and the 
task is completed + // without an exception; or (ii) same as (i) but the task has not yet successfully + // completed; or (iii) task has an exception; Only in (i) we remove the task from the + // queue. For (ii) and (iii) we keep the task in queue. Recall erroring tasks need to be + // manually removed. + if (task->isCompletedSuccessfully()) { // option (i) + it = taskQueue.erase(it); + } else { // option (ii) or (iii): keep the task in the queue. + ++it; + } + } else { + return *it; + } + } + return nullptr; +} + +void TaskScheduler::removeErroringTask(uint64_t scheduledTaskID) { + lock_t lck{taskSchedulerMtx}; + for (auto it = taskQueue.begin(); it != taskQueue.end(); ++it) { + if (scheduledTaskID == (*it)->ID) { + taskQueue.erase(it); + return; + } + } +} + +void TaskScheduler::runTask(Task* task) { + try { + task->run(); + task->deRegisterThreadAndFinalizeTask(); + } catch (std::exception& e) { + task->setException(std::current_exception()); + task->deRegisterThreadAndFinalizeTask(); + } +} + +TaskScheduler* TaskScheduler::Get(const main::ClientContext& context) { + return context.getDatabase()->getQueryProcessor()->getTaskScheduler(); +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/task_system/terminal_progress_bar_display.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/task_system/terminal_progress_bar_display.cpp new file mode 100644 index 0000000000..e275ef558b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/task_system/terminal_progress_bar_display.cpp @@ -0,0 +1,66 @@ +#include "common/task_system/terminal_progress_bar_display.h" + +#include "common/assert.h" + +namespace lbug { +namespace common { + +void TerminalProgressBarDisplay::updateProgress(uint64_t /*queryID*/, double newPipelineProgress, + uint32_t newNumPipelinesFinished) { + KU_ASSERT(0.0 <= newPipelineProgress && newPipelineProgress <= 1.0); + + // There can still be data races as the comparison + update of cur/old progress is not done + // atomically + // However this error does not build up over time and we don't require perfect progress bar + // accuracy + // So we implement it this way with atomics (instead of mutexes) for better performance + uint32_t curPipelineProgress = (uint32_t)(newPipelineProgress * 100.0); + uint32_t oldPipelineProgress = (uint32_t)(pipelineProgress * 100.0); + if (curPipelineProgress > oldPipelineProgress || + newNumPipelinesFinished > numPipelinesFinished) { + pipelineProgress.store(newPipelineProgress); + numPipelinesFinished.store(newNumPipelinesFinished); + printProgressBar(); + } +} + +void TerminalProgressBarDisplay::finishProgress(uint64_t /*queryID*/) { + if (printing) { + std::cout << "\033[2A\033[2K\033[1B\033[2K\033[1A"; + std::cout.flush(); + } + printing = false; + numPipelines = 0; + numPipelinesFinished = 0; + pipelineProgress = 0; +} + +void TerminalProgressBarDisplay::printProgressBar() { + // If a different thread is already printing the progress skip the current update + // As we do not require the displayed value to be perfectly up to date + bool falseValue{false}; + if (currentlyPrintingProgress.compare_exchange_strong(falseValue, true)) { + setGreenFont(); + if (printing) { + if (pipelineProgress == 0) { + std::cout << "\033[1A\033[2K\033[1A"; + printing = false; + } else { + std::cout << "\033[1A"; + } + } + if (!printing) { + std::cout << "Pipelines Finished: " << numPipelinesFinished << "/" << numPipelines + << "\n"; + printing = true; + } + std::cout << "Current Pipeline Progress: " << 
uint32_t(pipelineProgress * 100.0) << "%" + << "\n"; + setDefaultFont(); + std::cout.flush(); + currentlyPrintingProgress.store(false); + } +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/type_utils.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/type_utils.cpp new file mode 100644 index 0000000000..f2f896978c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/type_utils.cpp @@ -0,0 +1,298 @@ +#include "common/type_utils.h" + +#include "common/exception/runtime.h" +#include "common/vector/value_vector.h" + +namespace lbug { +namespace common { + +std::string TypeUtils::entryToString(const LogicalType& dataType, const uint8_t* value, + ValueVector* vector) { + auto valueVector = reinterpret_cast(vector); + switch (dataType.getLogicalTypeID()) { + case LogicalTypeID::BOOL: + return TypeUtils::toString(*reinterpret_cast(value)); + case LogicalTypeID::SERIAL: + case LogicalTypeID::INT64: + return TypeUtils::toString(*reinterpret_cast(value)); + case LogicalTypeID::INT32: + return TypeUtils::toString(*reinterpret_cast(value)); + case LogicalTypeID::INT16: + return TypeUtils::toString(*reinterpret_cast(value)); + case LogicalTypeID::INT8: + return TypeUtils::toString(*reinterpret_cast(value)); + case LogicalTypeID::UINT64: + return TypeUtils::toString(*reinterpret_cast(value)); + case LogicalTypeID::UINT32: + return TypeUtils::toString(*reinterpret_cast(value)); + case LogicalTypeID::UINT16: + return TypeUtils::toString(*reinterpret_cast(value)); + case LogicalTypeID::UINT8: + return TypeUtils::toString(*reinterpret_cast(value)); + case LogicalTypeID::INT128: + return TypeUtils::toString(*reinterpret_cast(value)); + case LogicalTypeID::DOUBLE: + return TypeUtils::toString(*reinterpret_cast(value)); + case LogicalTypeID::FLOAT: + return TypeUtils::toString(*reinterpret_cast(value)); + case LogicalTypeID::DECIMAL: + switch (dataType.getPhysicalType()) { + case PhysicalTypeID::INT16: + return DecimalType::insertDecimalPoint( + TypeUtils::toString(*reinterpret_cast(value)), + DecimalType::getScale(dataType)); + case PhysicalTypeID::INT32: + return DecimalType::insertDecimalPoint( + TypeUtils::toString(*reinterpret_cast(value)), + DecimalType::getScale(dataType)); + case PhysicalTypeID::INT64: + return DecimalType::insertDecimalPoint( + TypeUtils::toString(*reinterpret_cast(value)), + DecimalType::getScale(dataType)); + case PhysicalTypeID::INT128: + return DecimalType::insertDecimalPoint( + TypeUtils::toString(*reinterpret_cast(value)), + DecimalType::getScale(dataType)); + default: + // decimals should always be backed by one of these four + KU_UNREACHABLE; + } + case LogicalTypeID::DATE: + return TypeUtils::toString(*reinterpret_cast(value)); + case LogicalTypeID::TIMESTAMP_NS: + return TypeUtils::toString(*reinterpret_cast(value)); + case LogicalTypeID::TIMESTAMP_MS: + return TypeUtils::toString(*reinterpret_cast(value)); + case LogicalTypeID::TIMESTAMP_SEC: + return TypeUtils::toString(*reinterpret_cast(value)); + case LogicalTypeID::TIMESTAMP_TZ: + return TypeUtils::toString(*reinterpret_cast(value)); + case LogicalTypeID::TIMESTAMP: + return TypeUtils::toString(*reinterpret_cast(value)); + case LogicalTypeID::INTERVAL: + return TypeUtils::toString(*reinterpret_cast(value)); + case LogicalTypeID::BLOB: + return TypeUtils::toString(*reinterpret_cast(value)); + case LogicalTypeID::STRING: + return TypeUtils::toString(*reinterpret_cast(value)); + case LogicalTypeID::INTERNAL_ID: + return 
TypeUtils::toString(*reinterpret_cast(value)); + case LogicalTypeID::UINT128: + return TypeUtils::toString(*reinterpret_cast(value)); + case LogicalTypeID::ARRAY: + case LogicalTypeID::LIST: + return TypeUtils::toString(*reinterpret_cast(value), valueVector); + case LogicalTypeID::MAP: + return TypeUtils::toString(*reinterpret_cast(value), valueVector); + case LogicalTypeID::STRUCT: + return TypeUtils::toString(*reinterpret_cast(value), valueVector); + case LogicalTypeID::UNION: + return TypeUtils::toString(*reinterpret_cast(value), valueVector); + case LogicalTypeID::UUID: + return TypeUtils::toString(*reinterpret_cast(value)); + case LogicalTypeID::NODE: + return TypeUtils::nodeToString(*reinterpret_cast(value), + valueVector); + case LogicalTypeID::REL: + return TypeUtils::relToString(*reinterpret_cast(value), valueVector); + default: + throw common::RuntimeException{ + common::stringFormat("Unsupported type: {} to string.", dataType.toString())}; + } +} + +static std::string entryToStringWithPos(sel_t pos, ValueVector* vector) { + if (vector->isNull(pos)) { + return ""; + } + return TypeUtils::entryToString(vector->dataType, + vector->getData() + vector->getNumBytesPerValue() * pos, vector); +} + +template<> +std::string TypeUtils::toString(const int128_t& val, void* /*valueVector*/) { + return Int128_t::toString(val); +} + +template<> +std::string TypeUtils::toString(const uint128_t& val, void* /*valueVector*/) { + return UInt128_t::toString(val); +} + +template<> +std::string TypeUtils::toString(const bool& val, void* /*valueVector*/) { + return val ? "True" : "False"; +} + +template<> +std::string TypeUtils::toString(const internalID_t& val, void* /*valueVector*/) { + return std::to_string(val.tableID) + ":" + std::to_string(val.offset); +} + +template<> +std::string TypeUtils::toString(const date_t& val, void* /*valueVector*/) { + return Date::toString(val); +} + +template<> +std::string TypeUtils::toString(const timestamp_ns_t& val, void* /*valueVector*/) { + return toString(Timestamp::fromEpochNanoSeconds(val.value)); +} + +template<> +std::string TypeUtils::toString(const timestamp_ms_t& val, void* /*valueVector*/) { + return toString(Timestamp::fromEpochMilliSeconds(val.value)); +} + +template<> +std::string TypeUtils::toString(const timestamp_sec_t& val, void* /*valueVector*/) { + return toString(Timestamp::fromEpochSeconds(val.value)); +} + +template<> +std::string TypeUtils::toString(const timestamp_tz_t& val, void* /*valueVector*/) { + return toString(static_cast(val)) + "+00"; +} + +template<> +std::string TypeUtils::toString(const timestamp_t& val, void* /*valueVector*/) { + return Timestamp::toString(val); +} + +template<> +std::string TypeUtils::toString(const interval_t& val, void* /*valueVector*/) { + return Interval::toString(val); +} + +template<> +std::string TypeUtils::toString(const ku_string_t& val, void* /*valueVector*/) { + return val.getAsString(); +} + +template<> +std::string TypeUtils::toString(const blob_t& val, void* /*valueVector*/) { + return Blob::toString(val.value.getData(), val.value.len); +} + +template<> +std::string TypeUtils::toString(const ku_uuid_t& val, void* /*valueVector*/) { + return UUID::toString(val); +} + +template<> +std::string TypeUtils::toString(const list_entry_t& val, void* valueVector) { + auto listVector = (ValueVector*)valueVector; + if (val.size == 0) { + return "[]"; + } + std::string result = "["; + auto dataVector = ListVector::getDataVector(listVector); + for (auto i = 0u; i < val.size - 1; ++i) { + result += 
entryToStringWithPos(val.offset + i, dataVector); + result += ","; + } + result += entryToStringWithPos(val.offset + val.size - 1, dataVector); + result += "]"; + return result; +} + +static std::string getMapEntryStr(sel_t pos, ValueVector* dataVector, ValueVector* keyVector, + ValueVector* valVector) { + if (dataVector->isNull(pos)) { + return ""; + } + return entryToStringWithPos(pos, keyVector) + "=" + entryToStringWithPos(pos, valVector); +} + +template<> +std::string TypeUtils::toString(const map_entry_t& val, void* valueVector) { + auto mapVector = (ValueVector*)valueVector; + if (val.entry.size == 0) { + return "{}"; + } + std::string result = "{"; + auto dataVector = ListVector::getDataVector(mapVector); + auto keyVector = MapVector::getKeyVector(mapVector); + auto valVector = MapVector::getValueVector(mapVector); + for (auto i = 0u; i < val.entry.size - 1; ++i) { + auto pos = val.entry.offset + i; + result += getMapEntryStr(pos, dataVector, keyVector, valVector); + result += ", "; + } + auto pos = val.entry.offset + val.entry.size - 1; + result += getMapEntryStr(pos, dataVector, keyVector, valVector); + result += "}"; + return result; +} + +template +static std::string structToString(const struct_entry_t& val, ValueVector* vector) { + const auto& fields = StructType::getFields(vector->dataType); + if (fields.size() == 0) { + return "{}"; + } + std::string result = "{"; + auto i = 0u; + for (; i < fields.size() - 1; ++i) { + auto fieldVector = StructVector::getFieldVector(vector, i); + if constexpr (SKIP_NULL_ENTRY) { + if (fieldVector->isNull(val.pos)) { + continue; + } + } + if (i != 0) { + result += ", "; + } + result += StructType::getField(vector->dataType, i).getName(); + result += ": "; + result += entryToStringWithPos(val.pos, fieldVector.get()); + } + auto fieldVector = StructVector::getFieldVector(vector, i); + if constexpr (SKIP_NULL_ENTRY) { + if (fieldVector->isNull(val.pos)) { + result += "}"; + return result; + } + } + if (i != 0) { + result += ", "; + } + result += StructType::getField(vector->dataType, i).getName(); + result += ": "; + result += entryToStringWithPos(val.pos, fieldVector.get()); + result += "}"; + return result; +} + +std::string TypeUtils::nodeToString(const struct_entry_t& val, ValueVector* vector) { + // Internal ID vector is the first field vector. + if (StructVector::getFieldVector(vector, 0)->isNull(val.pos)) { + return ""; + } + return structToString(val, vector); +} + +std::string TypeUtils::relToString(const struct_entry_t& val, ValueVector* vector) { + // Internal ID vector is the third field vector. 
+ if (StructVector::getFieldVector(vector, 3)->isNull(val.pos)) { + return ""; + } + return structToString(val, vector); +} + +template<> +std::string TypeUtils::toString(const struct_entry_t& val, void* valVector) { + return structToString(val, (ValueVector*)valVector); +} + +template<> +std::string TypeUtils::toString(const union_entry_t& val, void* valVector) { + auto structVector = (ValueVector*)valVector; + auto unionFieldIdx = + UnionVector::getTagVector(structVector)->getValue(val.entry.pos); + auto unionFieldVector = UnionVector::getValVector(structVector, unionFieldIdx); + return entryToStringWithPos(val.entry.pos, unionFieldVector); +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/CMakeLists.txt new file mode 100644 index 0000000000..c87bb7acc2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/CMakeLists.txt @@ -0,0 +1,19 @@ +add_subdirectory(value) + +add_library(lbug_common_types + OBJECT + blob.cpp + date_t.cpp + dtime_t.cpp + interval_t.cpp + ku_list.cpp + ku_string.cpp + timestamp_t.cpp + types.cpp + int128_t.cpp + uint128_t.cpp + uuid.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/blob.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/blob.cpp new file mode 100644 index 0000000000..c98494089e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/blob.cpp @@ -0,0 +1,101 @@ +#include "common/types/blob.h" + +#include "common/exception/conversion.h" +#include "common/string_format.h" + +namespace lbug { +namespace common { + +const int HexFormatConstants::HEX_MAP[256] = {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, -1, -1, -1, -1, -1, + -1, -1, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1}; + +static bool isRegularChar(char c) { + return c >= 32 && c <= 126 && c != '\\' && c != '\'' && c != '"'; +} + +uint64_t Blob::getBlobSize(const ku_string_t& blob) { + uint64_t blobSize = 0; + auto length = blob.len; + auto blobStr = blob.getData(); + for (auto i = 0u; i < length; i++) { + if (blobStr[i] == '\\') { + validateHexCode(blobStr, length, i); + blobSize++; + i += HexFormatConstants::LENGTH - 1; + } else if (blobStr[i] <= 127) { + blobSize++; + } else { + throw ConversionException( + "Invalid byte encountered in STRING -> BLOB conversion. All non-ascii characters " + "must be escaped with hex codes (e.g. 
\\xAA)"); + } + } + return blobSize; +} + +uint64_t Blob::fromString(const char* str, uint64_t length, uint8_t* resultBuffer) { + auto resultPos = 0u; + for (auto i = 0u; i < length; i++) { + if (str[i] == '\\') { + validateHexCode(reinterpret_cast(str), length, i); + auto firstByte = HexFormatConstants::HEX_MAP[( + unsigned char)str[i + HexFormatConstants::FIRST_BYTE_POS]]; + auto secondByte = HexFormatConstants::HEX_MAP[( + unsigned char)str[i + HexFormatConstants::SECOND_BYTES_POS]]; + resultBuffer[resultPos++] = + (firstByte << HexFormatConstants::NUM_BYTES_TO_SHIFT_FOR_FIRST_BYTE) + secondByte; + i += HexFormatConstants::LENGTH - 1; + } else { + resultBuffer[resultPos++] = str[i]; + } + } + return resultPos; +} + +std::string Blob::toString(const uint8_t* value, uint64_t len) { + std::string result; + for (auto i = 0u; i < len; i++) { + if (isRegularChar(value[i])) { + // ascii characters are rendered as-is. + result += value[i]; + } else { + auto firstByte = value[i] >> HexFormatConstants::NUM_BYTES_TO_SHIFT_FOR_FIRST_BYTE; + auto secondByte = value[i] & HexFormatConstants::SECOND_BYTE_MASK; + // non-ascii characters are rendered as hexadecimal (e.g. \x00). + result += '\\'; + result += 'x'; + result += HexFormatConstants::HEX_TABLE[firstByte]; + result += HexFormatConstants::HEX_TABLE[secondByte]; + } + } + return result; +} + +void Blob::validateHexCode(const uint8_t* blobStr, uint64_t length, uint64_t curPos) { + if (curPos + HexFormatConstants::LENGTH > length) { + throw ConversionException( + "Invalid hex escape code encountered in string -> blob conversion: " + "unterminated escape code at end of string"); + } + if (memcmp(blobStr + curPos, HexFormatConstants::PREFIX, HexFormatConstants::PREFIX_LENGTH) != + 0 || + HexFormatConstants::HEX_MAP[blobStr[curPos + HexFormatConstants::FIRST_BYTE_POS]] < 0 || + HexFormatConstants::HEX_MAP[blobStr[curPos + HexFormatConstants::SECOND_BYTES_POS]] < 0) { + throw ConversionException( + stringFormat("Invalid hex escape code encountered in string -> blob conversion: {}", + std::string((char*)blobStr + curPos, HexFormatConstants::LENGTH))); + } +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/date_t.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/date_t.cpp new file mode 100644 index 0000000000..c7ab372ee0 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/date_t.cpp @@ -0,0 +1,485 @@ +#include "common/types/date_t.h" + +#include "common/assert.h" +#include "common/exception/conversion.h" +#include "common/string_format.h" +#include "common/string_utils.h" +#include "common/types/cast_helpers.h" +#include "common/types/timestamp_t.h" +#include "re2.h" + +namespace lbug { +namespace common { + +date_t::date_t() : days{0} {} + +date_t::date_t(int32_t days_p) : days(days_p) {} + +bool date_t::operator==(const date_t& rhs) const { + return days == rhs.days; +} + +bool date_t::operator!=(const date_t& rhs) const { + return days != rhs.days; +} + +bool date_t::operator<=(const date_t& rhs) const { + return days <= rhs.days; +} + +bool date_t::operator<(const date_t& rhs) const { + return days < rhs.days; +} + +bool date_t::operator>(const date_t& rhs) const { + return days > rhs.days; +} + +bool date_t::operator>=(const date_t& rhs) const { + return days >= rhs.days; +} + +date_t date_t::operator+(const interval_t& interval) const { + date_t result{}; + if (interval.months != 0) { + int32_t year = 0, month = 0, day = 0, maxDayInMonth = 0; + Date::convert(*this, 
year, month, day); + int32_t year_diff = interval.months / Interval::MONTHS_PER_YEAR; + year += year_diff; + month += interval.months - year_diff * Interval::MONTHS_PER_YEAR; + if (month > Interval::MONTHS_PER_YEAR) { + year++; + month -= Interval::MONTHS_PER_YEAR; + } else if (month <= 0) { + year--; + month += Interval::MONTHS_PER_YEAR; + } + // handle date overflow + // example: 2020-01-31 + "1 months" + maxDayInMonth = Date::monthDays(year, month); + day = day > maxDayInMonth ? maxDayInMonth : day; + result = Date::fromDate(year, month, day); + } else { + result = *this; + } + if (interval.days != 0) { + result.days += interval.days; + } + if (interval.micros != 0) { + result.days += int32_t(interval.micros / Interval::MICROS_PER_DAY); + } + return result; +} + +date_t date_t::operator-(const interval_t& interval) const { + interval_t inverseRight{}; + inverseRight.months = -interval.months; + inverseRight.days = -interval.days; + inverseRight.micros = -interval.micros; + return *this + inverseRight; +} + +int64_t date_t::operator-(const date_t& rhs) const { + return (*this).days - rhs.days; +} + +bool date_t::operator==(const timestamp_t& rhs) const { + return Timestamp::fromDateTime(*this, dtime_t(0)).value == rhs.value; +} + +bool date_t::operator!=(const timestamp_t& rhs) const { + return !(*this == rhs); +} + +bool date_t::operator<(const timestamp_t& rhs) const { + return Timestamp::fromDateTime(*this, dtime_t(0)).value < rhs.value; +} + +bool date_t::operator<=(const timestamp_t& rhs) const { + return (*this) < rhs || (*this) == rhs; +} + +bool date_t::operator>(const timestamp_t& rhs) const { + return !(*this <= rhs); +} + +bool date_t::operator>=(const timestamp_t& rhs) const { + return !(*this < rhs); +} + +date_t date_t::operator+(const int32_t& day) const { + return date_t(this->days + day); +}; +date_t date_t::operator-(const int32_t& day) const { + return date_t(this->days - day); +}; + +const int32_t Date::NORMAL_DAYS[] = {0, 31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31}; +const int32_t Date::LEAP_DAYS[] = {0, 31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31}; +const int32_t Date::CUMULATIVE_LEAP_DAYS[] = {0, 31, 60, 91, 121, 152, 182, 213, 244, 274, 305, 335, + 366}; +const int32_t Date::CUMULATIVE_DAYS[] = {0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334, + 365}; +const int8_t Date::MONTH_PER_DAY_OF_YEAR[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 10, + 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, + 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, + 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, + 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 
12, 12, 12, 12, 12, 12, 12, 12, 12}; +const int8_t Date::LEAP_MONTH_PER_DAY_OF_YEAR[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, + 9, 9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, + 10, 10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, + 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, + 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12}; +const int32_t Date::CUMULATIVE_YEAR_DAYS[] = {0, 365, 730, 1096, 1461, 1826, 2191, 2557, 2922, 3287, + 3652, 4018, 4383, 4748, 5113, 5479, 5844, 6209, 6574, 6940, 7305, 7670, 8035, 8401, 8766, 9131, + 9496, 9862, 10227, 10592, 10957, 11323, 11688, 12053, 12418, 12784, 13149, 13514, 13879, 14245, + 14610, 14975, 15340, 15706, 16071, 16436, 16801, 17167, 17532, 17897, 18262, 18628, 18993, + 19358, 19723, 20089, 20454, 20819, 21184, 21550, 21915, 22280, 22645, 23011, 23376, 23741, + 24106, 24472, 24837, 25202, 25567, 25933, 26298, 26663, 27028, 27394, 27759, 28124, 28489, + 28855, 29220, 29585, 29950, 30316, 30681, 31046, 31411, 31777, 32142, 32507, 32872, 33238, + 33603, 33968, 34333, 34699, 35064, 35429, 35794, 36160, 36525, 36890, 37255, 37621, 37986, + 38351, 38716, 39082, 39447, 39812, 40177, 40543, 40908, 41273, 41638, 42004, 42369, 42734, + 43099, 43465, 43830, 44195, 44560, 44926, 45291, 45656, 46021, 46387, 46752, 47117, 47482, + 47847, 48212, 48577, 48942, 49308, 49673, 50038, 50403, 50769, 51134, 51499, 51864, 52230, + 52595, 52960, 53325, 53691, 54056, 54421, 54786, 55152, 55517, 55882, 56247, 56613, 56978, + 57343, 57708, 58074, 58439, 58804, 59169, 59535, 59900, 60265, 60630, 60996, 61361, 61726, + 62091, 62457, 62822, 63187, 63552, 63918, 64283, 64648, 65013, 65379, 65744, 66109, 66474, + 66840, 67205, 67570, 67935, 68301, 68666, 69031, 69396, 69762, 70127, 70492, 70857, 71223, + 71588, 71953, 72318, 72684, 73049, 73414, 73779, 74145, 74510, 74875, 75240, 75606, 75971, + 76336, 76701, 77067, 77432, 77797, 78162, 78528, 78893, 79258, 79623, 79989, 80354, 80719, + 81084, 81450, 81815, 82180, 82545, 82911, 83276, 83641, 84006, 84371, 84736, 85101, 85466, + 85832, 86197, 86562, 86927, 87293, 87658, 88023, 88388, 88754, 89119, 89484, 89849, 90215, + 90580, 90945, 91310, 91676, 92041, 92406, 92771, 93137, 93502, 93867, 94232, 94598, 94963, + 95328, 95693, 96059, 96424, 96789, 97154, 97520, 97885, 98250, 98615, 98981, 99346, 99711, + 100076, 100442, 100807, 101172, 101537, 101903, 102268, 102633, 102998, 103364, 103729, 104094, + 104459, 104825, 105190, 105555, 105920, 106286, 106651, 107016, 107381, 107747, 108112, 108477, + 108842, 109208, 109573, 109938, 110303, 110669, 111034, 111399, 111764, 112130, 
112495, 112860, + 113225, 113591, 113956, 114321, 114686, 115052, 115417, 115782, 116147, 116513, 116878, 117243, + 117608, 117974, 118339, 118704, 119069, 119435, 119800, 120165, 120530, 120895, 121260, 121625, + 121990, 122356, 122721, 123086, 123451, 123817, 124182, 124547, 124912, 125278, 125643, 126008, + 126373, 126739, 127104, 127469, 127834, 128200, 128565, 128930, 129295, 129661, 130026, 130391, + 130756, 131122, 131487, 131852, 132217, 132583, 132948, 133313, 133678, 134044, 134409, 134774, + 135139, 135505, 135870, 136235, 136600, 136966, 137331, 137696, 138061, 138427, 138792, 139157, + 139522, 139888, 140253, 140618, 140983, 141349, 141714, 142079, 142444, 142810, 143175, 143540, + 143905, 144271, 144636, 145001, 145366, 145732, 146097}; + +void Date::extractYearOffset(int32_t& n, int32_t& year, int32_t& year_offset) { + year = Date::EPOCH_YEAR; + // first we normalize n to be in the year range [1970, 2370] + // since leap years repeat every 400 years, we can safely normalize just by "shifting" the + // CumulativeYearDays array + while (n < 0) { + n += Date::DAYS_PER_YEAR_INTERVAL; + year -= Date::YEAR_INTERVAL; + } + while (n >= Date::DAYS_PER_YEAR_INTERVAL) { + n -= Date::DAYS_PER_YEAR_INTERVAL; + year += Date::YEAR_INTERVAL; + } + // interpolation search + // we can find an upper bound of the year by assuming each year has 365 days + year_offset = n / 365; + // because of leap years we might be off by a little bit: compensate by decrementing the year + // offset until we find our year + while (n < Date::CUMULATIVE_YEAR_DAYS[year_offset]) { + year_offset--; + KU_ASSERT(year_offset >= 0); + } + year += year_offset; + KU_ASSERT(n >= Date::CUMULATIVE_YEAR_DAYS[year_offset]); +} + +void Date::convert(date_t date, int32_t& out_year, int32_t& out_month, int32_t& out_day) { + auto n = date.days; + int32_t year_offset = 0; + Date::extractYearOffset(n, out_year, year_offset); + + out_day = n - Date::CUMULATIVE_YEAR_DAYS[year_offset]; + KU_ASSERT(out_day >= 0 && out_day <= 365); + + bool is_leap_year = (Date::CUMULATIVE_YEAR_DAYS[year_offset + 1] - + Date::CUMULATIVE_YEAR_DAYS[year_offset]) == 366; + if (is_leap_year) { + out_month = Date::LEAP_MONTH_PER_DAY_OF_YEAR[out_day]; + out_day -= Date::CUMULATIVE_LEAP_DAYS[out_month - 1]; + } else { + out_month = Date::MONTH_PER_DAY_OF_YEAR[out_day]; + out_day -= Date::CUMULATIVE_DAYS[out_month - 1]; + } + out_day++; + KU_ASSERT(out_day > 0 && out_day <= (is_leap_year ? Date::LEAP_DAYS[out_month] : + Date::NORMAL_DAYS[out_month])); + KU_ASSERT(out_month > 0 && out_month <= 12); + KU_ASSERT(Date::isValid(out_year, out_month, out_day)); +} + +date_t Date::fromDate(int32_t year, int32_t month, int32_t day) { + int32_t n = 0; + if (!Date::isValid(year, month, day)) { + throw ConversionException(stringFormat("Date out of range: {}-{}-{}.", year, month, day)); + } + while (year < 1970) { + year += Date::YEAR_INTERVAL; + n -= Date::DAYS_PER_YEAR_INTERVAL; + } + while (year >= 2370) { + year -= Date::YEAR_INTERVAL; + n += Date::DAYS_PER_YEAR_INTERVAL; + } + n += Date::CUMULATIVE_YEAR_DAYS[year - 1970]; + n += Date::isLeapYear(year) ? 
Date::CUMULATIVE_LEAP_DAYS[month - 1] : + Date::CUMULATIVE_DAYS[month - 1]; + n += day - 1; + return date_t(n); +} + +bool Date::parseDoubleDigit(const char* buf, uint64_t len, uint64_t& pos, int32_t& result) { + if (pos < len && StringUtils::CharacterIsDigit(buf[pos])) { + result = buf[pos++] - '0'; + if (pos < len && StringUtils::CharacterIsDigit(buf[pos])) { + result = (buf[pos++] - '0') + result * 10; + } + return true; + } + return false; +} + +// Checks if the date std::string given in buf complies with the YYYY:MM:DD format. Ignores leading +// and trailing spaces. Removes from the original DuckDB code the following features: +// 1) we don't parse "negative years", i.e., date formats that start with -. +// 2) we don't parse dates that end with trailing "BC". +bool Date::tryConvertDate(const char* buf, uint64_t len, uint64_t& pos, date_t& result, + bool allowTrailing) { + if (len == 0) { + return false; + } + + int32_t day = 0; + int32_t month = -1; + int32_t year = 0; + + // skip leading spaces + while (pos < len && StringUtils::isSpace(buf[pos])) { + pos++; + } + + if (pos >= len) { + return false; + } + + if (!StringUtils::CharacterIsDigit(buf[pos])) { + return false; + } + // first parse the year + for (; pos < len && StringUtils::CharacterIsDigit(buf[pos]); pos++) { + year = (buf[pos] - '0') + year * 10; + if (year > Date::MAX_YEAR) { + break; + } + } + + if (pos >= len) { + return false; + } + + // fetch the separator + char sep = buf[pos++]; + if (sep != ' ' && sep != '-' && sep != '/' && sep != '\\') { + // invalid separator + return false; + } + + // parse the month + if (!Date::parseDoubleDigit(buf, len, pos, month)) { + return false; + } + + // Also checks that the separator is not the end of the string + if (pos + 1 >= len) { + return false; + } + + if (buf[pos++] != sep) { + return false; + } + + // now parse the day + if (!Date::parseDoubleDigit(buf, len, pos, day)) { + return false; + } + + // skip trailing spaces + while (pos < len && StringUtils::isSpace((unsigned char)buf[pos])) { + pos++; + } + // check position. if end was not reached, non-space chars remaining + if (pos < len && !allowTrailing) { + return false; + } + + try { + result = Date::fromDate(year, month, day); + } catch (ConversionException& exc) { + return false; + } + return true; +} + +date_t Date::fromCString(const char* str, uint64_t len) { + date_t result; + uint64_t pos = 0; + if (!tryConvertDate(str, len, pos, result)) { + throw ConversionException("Error occurred during parsing date. Given: \"" + + std::string(str, len) + "\". Expected format: (YYYY-MM-DD)"); + } + return result; +} + +std::string Date::toString(date_t date) { + int32_t dateUnits[3]; + uint64_t yearLength = 0; + bool addBC = false; + Date::convert(date, dateUnits[0], dateUnits[1], dateUnits[2]); + + auto length = DateToStringCast::Length(dateUnits, yearLength, addBC); + auto buffer = std::make_unique(length); + DateToStringCast::Format(buffer.get(), dateUnits, yearLength, addBC); + return std::string(buffer.get(), length); +} + +bool Date::isLeapYear(int32_t year) { + return year % 4 == 0 && (year % 100 != 0 || year % 400 == 0); +} + +bool Date::isValid(int32_t year, int32_t month, int32_t day) { + if (month < 1 || month > 12) { + return false; + } + if (year < Date::MIN_YEAR || year > Date::MAX_YEAR) { + return false; + } + if (day < 1) { + return false; + } + return Date::isLeapYear(year) ? 
day <= Date::LEAP_DAYS[month] : day <= Date::NORMAL_DAYS[month]; +} + +int32_t Date::monthDays(int32_t year, int32_t month) { + KU_ASSERT(month >= 1 && month <= 12); + return Date::isLeapYear(year) ? Date::LEAP_DAYS[month] : Date::NORMAL_DAYS[month]; +} + +std::string Date::getDayName(date_t date) { + std::string dayNames[] = {"Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", + "Saturday"}; + return dayNames[(date.days < 0 ? 7 - ((-date.days + 3) % 7) : ((date.days + 3) % 7) + 1) % 7]; +} + +std::string Date::getMonthName(date_t date) { + std::string monthNames[] = {"January", "February", "March", "April", "May", "June", "July", + "August", "September", "October", "November", "December"}; + int32_t year = 0, month = 0, day = 0; + Date::convert(date, year, month, day); + return monthNames[month - 1]; +} + +date_t Date::getLastDay(date_t date) { + int32_t year = 0, month = 0, day = 0; + Date::convert(date, year, month, day); + year += (month / 12); + month %= 12; + ++month; + return Date::fromDate(year, month, 1) - 1; +} + +int32_t Date::getDatePart(DatePartSpecifier specifier, date_t date) { + int32_t year = 0, month = 0, day = 0; + Date::convert(date, year, month, day); + switch (specifier) { + case DatePartSpecifier::YEAR: { + int32_t yearOffset = 0; + extractYearOffset(date.days, year, yearOffset); + return year; + } + case DatePartSpecifier::MONTH: + return month; + case DatePartSpecifier::DAY: + return day; + case DatePartSpecifier::DECADE: + return year / 10; + case DatePartSpecifier::CENTURY: + // From the PG docs: + // "The first century starts at 0001-01-01 00:00:00 AD, although they did not know it at the + // time. This definition applies to all Gregorian calendar countries. There is no century + // number 0, you go from -1 century to 1 century. If you disagree with this, please write + // your complaint to: Pope, Cathedral Saint-Peter of Roma, Vatican." (To be fair, His + // Holiness had nothing to do with this - it was the lack of zero in the counting systems of + // the time...). + return year > 0 ? ((year - 1) / 100) + 1 : (year / 100) - 1; + case DatePartSpecifier::MILLENNIUM: + return year > 0 ? 
((year - 1) / 1000) + 1 : (year / 1000) - 1; + case DatePartSpecifier::QUARTER: + return (month - 1) / Interval::MONTHS_PER_QUARTER + 1; + default: + return 0; + } +} + +date_t Date::trunc(DatePartSpecifier specifier, date_t date) { + switch (specifier) { + case DatePartSpecifier::YEAR: + return Date::fromDate(Date::getDatePart(DatePartSpecifier::YEAR, date), 1 /* month */, + 1 /* day */); + case DatePartSpecifier::MONTH: + return Date::fromDate(Date::getDatePart(DatePartSpecifier::YEAR, date), + Date::getDatePart(DatePartSpecifier::MONTH, date), 1 /* day */); + case DatePartSpecifier::DAY: + return date; + case DatePartSpecifier::DECADE: + return Date::fromDate((Date::getDatePart(DatePartSpecifier::YEAR, date) / 10) * 10, + 1 /* month */, 1 /* day */); + case DatePartSpecifier::CENTURY: + return Date::fromDate((Date::getDatePart(DatePartSpecifier::YEAR, date) / 100) * 100, + 1 /* month */, 1 /* day */); + case DatePartSpecifier::MILLENNIUM: + return Date::fromDate((Date::getDatePart(DatePartSpecifier::YEAR, date) / 1000) * 1000, + 1 /* month */, 1 /* day */); + case DatePartSpecifier::QUARTER: { + int32_t year = 0, month = 0, day = 0; + Date::convert(date, year, month, day); + month = 1 + (((month - 1) / 3) * 3); + return Date::fromDate(year, month, 1); + } + default: + return date; + } +} + +int64_t Date::getEpochNanoSeconds(const date_t& date) { + return ((int64_t)date.days) * (Interval::MICROS_PER_DAY * Interval::NANOS_PER_MICRO); +} + +const regex::RE2& Date::regexPattern() { + static regex::RE2 retval("\\d{4}/\\d{1,2}/\\d{1,2}|\\d{4}-\\d{1,2}-\\d{1,2}|\\d{4} \\d{1,2} " + "\\d{1,2}|\\d{4}\\\\\\d{1,2}\\\\\\d{1,2}"); + return retval; +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/dtime_t.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/dtime_t.cpp new file mode 100644 index 0000000000..eaa7836075 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/dtime_t.cpp @@ -0,0 +1,222 @@ +#include "common/types/dtime_t.h" + +#include + +#include "common/assert.h" +#include "common/exception/conversion.h" +#include "common/string_format.h" +#include "common/types/cast_helpers.h" +#include "common/types/date_t.h" + +namespace lbug { +namespace common { + +static_assert(sizeof(dtime_t) == sizeof(int64_t), "dtime_t was padded"); + +dtime_t::dtime_t() : micros(0) {} + +dtime_t::dtime_t(int64_t micros_p) : micros(micros_p) {} + +dtime_t& dtime_t::operator=(int64_t micros_p) { + micros = micros_p; + return *this; +} + +dtime_t::operator int64_t() const { + return micros; +} + +dtime_t::operator double() const { + return micros; +} + +bool dtime_t::operator==(const dtime_t& rhs) const { + return micros == rhs.micros; +} + +bool dtime_t::operator!=(const dtime_t& rhs) const { + return micros != rhs.micros; +} + +bool dtime_t::operator<=(const dtime_t& rhs) const { + return micros <= rhs.micros; +} + +bool dtime_t::operator<(const dtime_t& rhs) const { + return micros < rhs.micros; +} + +bool dtime_t::operator>(const dtime_t& rhs) const { + return micros > rhs.micros; +} + +bool dtime_t::operator>=(const dtime_t& rhs) const { + return micros >= rhs.micros; +} + +bool Time::tryConvertInternal(const char* buf, uint64_t len, uint64_t& pos, dtime_t& result) { + int32_t hour = -1, min = -1, sec = -1, micros = -1; + pos = 0; + + if (len == 0) { + return false; + } + + // skip leading spaces + while (pos < len && isspace(buf[pos])) { + pos++; + } + + if (pos >= len) { + return false; + } + + if (!isdigit(buf[pos])) { + return 
false; + } + + // Allow up to 9 digit hours to support intervals + hour = 0; + for (int32_t digits = 9; pos < len && isdigit(buf[pos]); ++pos) { + if (digits-- > 0) { + hour = hour * 10 + (buf[pos] - '0'); + } else { + return false; + } + } + + if (pos >= len) { + return false; + } + + // fetch the separator + char sep = buf[pos++]; + if (sep != ':') { + // invalid separator + return false; + } + + if (!Date::parseDoubleDigit(buf, len, pos, min)) { + return false; + } + if (min < 0 || min >= 60) { + return false; + } + + if (pos >= len) { + return false; + } + + if (buf[pos++] != sep) { + return false; + } + + if (!Date::parseDoubleDigit(buf, len, pos, sec)) { + return false; + } + if (sec < 0 || sec >= 60) { + return false; + } + + micros = 0; + if (pos < len && buf[pos] == '.') { + pos++; + // we expect some microseconds + int32_t mult = 100000; + for (; pos < len && isdigit(buf[pos]); pos++, mult /= 10) { + if (mult > 0) { + micros += (buf[pos] - '0') * mult; + } + } + } + + result = Time::fromTimeInternal(hour, min, sec, micros); + return true; +} + +bool Time::tryConvertInterval(const char* buf, uint64_t len, uint64_t& pos, dtime_t& result) { + if (!Time::tryConvertInternal(buf, len, pos, result)) { + return false; + } + // check remaining string for non-space characters + // skip trailing spaces + while (pos < len && isspace(buf[pos])) { + pos++; + } + // check position. if end was not reached, non-space chars remaining + if (pos < len) { + return false; + } + return true; +} + +// string format is hh:mm:ss[.mmmmmm] (ISO 8601) (m represent microseconds) +// microseconds is optional, timezone is currently not supported +bool Time::tryConvertTime(const char* buf, uint64_t len, uint64_t& pos, dtime_t& result) { + if (!Time::tryConvertInternal(buf, len, pos, result)) { + return false; + } + return result.micros < Interval::MICROS_PER_DAY; +} + +dtime_t Time::fromCString(const char* buf, uint64_t len) { + dtime_t result; + uint64_t pos = 0; + if (!Time::tryConvertTime(buf, len, pos, result)) { + throw ConversionException(stringFormat("Error occurred during parsing time. Given: \"{}\". 
" + "Expected format: (hh:mm:ss[.zzzzzz]).", + std::string(buf, len))); + } + return result; +} + +std::string Time::toString(dtime_t time) { + int32_t time_units[4]; + Time::convert(time, time_units[0], time_units[1], time_units[2], time_units[3]); + + char micro_buffer[6]; + auto length = TimeToStringCast::Length(time_units, micro_buffer); + auto buffer = std::unique_ptr(new char[length]); + TimeToStringCast::Format(buffer.get(), length, time_units, micro_buffer); + return std::string(buffer.get(), length); +} + +bool Time::isValid(int32_t hour, int32_t minute, int32_t second, int32_t microseconds) { + if (hour > 23 || hour < 0 || minute > 59 || minute < 0 || second > 59 || second < 0 || + microseconds > 999999 || microseconds < 0) { + return false; + } + return true; +} + +dtime_t Time::fromTimeInternal(int32_t hour, int32_t minute, int32_t second, int32_t microseconds) { + int64_t result = 0; + result = hour; // hours + result = result * Interval::MINS_PER_HOUR + minute; // hours -> minutes + result = result * Interval::SECS_PER_MINUTE + second; // minutes -> seconds + result = result * Interval::MICROS_PER_SEC + microseconds; // seconds -> microseconds + return dtime_t(result); +} + +dtime_t Time::fromTime(int32_t hour, int32_t minute, int32_t second, int32_t microseconds) { + if (!Time::isValid(hour, minute, second, microseconds)) { + throw ConversionException(stringFormat("Time field value out of range: {}:{}:{}[.{}].", + hour, minute, second, microseconds)); + } + return Time::fromTimeInternal(hour, minute, second, microseconds); +} + +void Time::convert(dtime_t dtime, int32_t& hour, int32_t& min, int32_t& sec, int32_t& micros) { + int64_t time = dtime.micros; + hour = int32_t(time / Interval::MICROS_PER_HOUR); + time -= int64_t(hour) * Interval::MICROS_PER_HOUR; + min = int32_t(time / Interval::MICROS_PER_MINUTE); + time -= int64_t(min) * Interval::MICROS_PER_MINUTE; + sec = int32_t(time / Interval::MICROS_PER_SEC); + time -= int64_t(sec) * Interval::MICROS_PER_SEC; + micros = int32_t(time); + KU_ASSERT(Time::isValid(hour, min, sec, micros)); +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/int128_t.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/int128_t.cpp new file mode 100644 index 0000000000..f1d5cd6b1f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/int128_t.cpp @@ -0,0 +1,778 @@ +#include "common/types/int128_t.h" + +#include +#include + +#include "common/exception/runtime.h" +#include "common/numeric_utils.h" +#include "common/type_utils.h" +#include "common/types/uint128_t.h" +#include "function/cast/functions/numeric_limits.h" +#include "function/hash/hash_functions.h" +#include + +namespace lbug::common { + +static uint8_t positiveInt128BitsAmount(int128_t input) { + if (input.high) { + return 128 - std::countl_zero((uint64_t)input.high); + } else { + return 64 - std::countl_zero(input.low); + } +} + +static bool positiveInt128IsBitSet(int128_t input, uint8_t bit) { + if (bit < 64) { + return input.low & (1ULL << uint64_t(bit)); + } else { + return input.high & (1ULL << uint64_t(bit - 64)); + } +} + +int128_t positiveInt128LeftShift(int128_t lhs, uint32_t amount) { + int128_t result{}; + result.low = lhs.low << amount; + result.high = (lhs.high << amount) + (lhs.low >> (64 - amount)); + return result; +} + +int128_t Int128_t::divModPositive(int128_t lhs, uint64_t rhs, uint64_t& remainder) { + int128_t result{0}; + remainder = 0; + + for (uint8_t i = positiveInt128BitsAmount(lhs); i 
> 0; i--) { + result = positiveInt128LeftShift(result, 1); + remainder <<= 1; + if (positiveInt128IsBitSet(lhs, i - 1)) { + remainder++; + } + if (remainder >= rhs) { + remainder -= rhs; + result.low++; + if (result.low == 0) { + result.high++; + } + } + } + return result; +} + +std::string Int128_t::toString(int128_t input) { + bool isMin = (input.high == INT64_MIN && input.low == 0); + bool negative = input.high < 0; + + if (isMin) { + + uint64_t remainder = 0; + int128_t quotient = divModPositive(input, 10, remainder); + + std::string result = toString(quotient); + + result += static_cast('0' + remainder); + + return "-" + result; + } + + if (negative) { + negateInPlace(input); + } + + std::string result; + uint64_t remainder = 0; + + while (input.high != 0 || input.low != 0) { + input = divModPositive(input, 10, remainder); + result = std::string(1, '0' + remainder) + std::move(result); + } + + if (result.empty()) { + result = "0"; + } + + return negative ? "-" + result : result; +} + +bool Int128_t::addInPlace(int128_t& lhs, int128_t rhs) { + bool lhsPositive = lhs.high >= 0; + bool rhsPositive = rhs.high >= 0; + int overflow = lhs.low + rhs.low < lhs.low; + if (rhs.high >= 0) { + if (lhs.high > INT64_MAX - rhs.high - overflow) { + return false; + } + lhs.high = lhs.high + rhs.high + overflow; + } else { + if (lhs.high < INT64_MIN - rhs.high - overflow) { + return false; + } + lhs.high = lhs.high + rhs.high + overflow; + } + lhs.low += rhs.low; + if (lhsPositive && rhsPositive && lhs.high == INT64_MIN && lhs.low == 0) { + return false; + } + return true; +} + +bool Int128_t::subInPlace(int128_t& lhs, int128_t rhs) { + int underflow = lhs.low - rhs.low > lhs.low; + if (rhs.high >= 0) { + if (lhs.high < INT64_MIN + rhs.high + underflow) { + return false; + } + lhs.high = lhs.high - rhs.high - underflow; + } else { + if (lhs.high > INT64_MIN && lhs.high - 1 >= INT64_MAX + rhs.high + underflow) { + return false; + } + lhs.high = lhs.high - rhs.high - underflow; + } + lhs.low -= rhs.low; + if (lhs.high == INT64_MIN && lhs.low == 0) { + return false; + } + return true; +} + +int128_t Int128_t::Add(int128_t lhs, const int128_t rhs) { + if (!addInPlace(lhs, rhs)) { + throw common::OverflowException("INT128 is out of range: cannot add."); + } + return lhs; +} + +int128_t Int128_t::Sub(int128_t lhs, const int128_t rhs) { + if (!subInPlace(lhs, rhs)) { + throw common::OverflowException("INT128 is out of range: cannot subtract."); + } + return lhs; +} + +bool Int128_t::tryMultiply(int128_t lhs, int128_t rhs, int128_t& result) { + bool lhs_negative = lhs.high < 0; + bool rhs_negative = rhs.high < 0; + if (lhs_negative) { + negateInPlace(lhs); + } + if (rhs_negative) { + negateInPlace(rhs); + } +#if ((__GNUC__ >= 5) || defined(__clang__)) && defined(__SIZEOF_INT128__) + __uint128_t left = __uint128_t(lhs.low) + (__uint128_t(lhs.high) << 64); + __uint128_t right = __uint128_t(rhs.low) + (__uint128_t(rhs.high) << 64); + __uint128_t result_i128 = 0; + if (__builtin_mul_overflow(left, right, &result_i128)) { + return false; + } + auto high = uint64_t(result_i128 >> 64); + if (high & 0x8000000000000000) { + return false; + } + result.high = int64_t(high); + result.low = uint64_t(result_i128 & 0xffffffffffffffff); +#else + // Multiply code adapted from: + // https://github.com/calccrypto/uint128_t/blob/master/uint128_t.cpp + // License: https://github.com/calccrypto/uint128_t/blob/c%2B%2B11_14/LICENSE + uint64_t top[4] = {uint64_t(lhs.high) >> 32, uint64_t(lhs.high) & 0xffffffff, lhs.low >> 32, + lhs.low 
& 0xffffffff}; + uint64_t bottom[4] = {uint64_t(rhs.high) >> 32, uint64_t(rhs.high) & 0xffffffff, rhs.low >> 32, + rhs.low & 0xffffffff}; + uint64_t products[4][4]; + + // multiply each component of the values + for (auto x = 0; x < 4; x++) { + for (auto y = 0; y < 4; y++) { + products[x][y] = top[x] * bottom[y]; + } + } + + // if any of these products are set to a non-zero value, there is always an overflow + if (products[0][0] || products[0][1] || products[0][2] || products[1][0] || products[2][0] || + products[1][1]) { + return false; + } + // if the high bits of any of these are set, there is always an overflow + if ((products[0][3] & 0xffffffff80000000) || (products[1][2] & 0xffffffff80000000) || + (products[2][1] & 0xffffffff80000000) || (products[3][0] & 0xffffffff80000000)) { + return false; + } + + // otherwise we merge the result of the different products together in-order + + // first row + uint64_t fourth32 = (products[3][3] & 0xffffffff); + uint64_t third32 = (products[3][2] & 0xffffffff) + (products[3][3] >> 32); + uint64_t second32 = (products[3][1] & 0xffffffff) + (products[3][2] >> 32); + uint64_t first32 = (products[3][0] & 0xffffffff) + (products[3][1] >> 32); + + // second row + third32 += (products[2][3] & 0xffffffff); + second32 += (products[2][2] & 0xffffffff) + (products[2][3] >> 32); + first32 += (products[2][1] & 0xffffffff) + (products[2][2] >> 32); + + // third row + second32 += (products[1][3] & 0xffffffff); + first32 += (products[1][2] & 0xffffffff) + (products[1][3] >> 32); + + // fourth row + first32 += (products[0][3] & 0xffffffff); + + // move carry to next digit + third32 += fourth32 >> 32; + second32 += third32 >> 32; + first32 += second32 >> 32; + + // check if the combination of the different products resulted in an overflow + if (first32 & 0xffffff80000000) { + return false; + } + + // remove carry from current digit + fourth32 &= 0xffffffff; + third32 &= 0xffffffff; + second32 &= 0xffffffff; + first32 &= 0xffffffff; + + // combine components + result.low = (third32 << 32) | fourth32; + result.high = (first32 << 32) | second32; +#endif + if (lhs_negative ^ rhs_negative) { + negateInPlace(result); + } + return true; +} + +int128_t Int128_t::Mul(int128_t lhs, int128_t rhs) { + int128_t result{}; + if (!tryMultiply(lhs, rhs, result)) { + throw common::OverflowException("INT128 is out of range: cannot multiply."); + } + return result; +} + +int128_t Int128_t::divMod(int128_t lhs, int128_t rhs, int128_t& remainder) { + bool lhs_negative = lhs.high < 0; + bool rhs_negative = rhs.high < 0; + if (lhs_negative) { + negateInPlace(lhs); + } + if (rhs_negative) { + negateInPlace(rhs); + } + + // divMod code adapted from: + // https://github.com/calccrypto/uint128_t/blob/master/uint128_t.cpp + // License: https://github.com/calccrypto/uint128_t/blob/c%2B%2B11_14/LICENSE + // initialize the result and remainder to 0 + int128_t div_result{0}; + remainder.low = 0; + remainder.high = 0; + + // now iterate over the amount of bits that are set in the LHS + for (uint8_t x = positiveInt128BitsAmount(lhs); x > 0; x--) { + // left-shift the current result and remainder by 1 + div_result = positiveInt128LeftShift(div_result, 1); + remainder = positiveInt128LeftShift(remainder, 1); + + // we get the value of the bit at position X, where position 0 is the least-significant bit + if (positiveInt128IsBitSet(lhs, x - 1)) { + // increment the remainder + addInPlace(remainder, 1); + } + if (greaterThanOrEquals(remainder, rhs)) { + // the remainder has passed the division 
multiplier: add one to the divide result + remainder = Sub(remainder, rhs); + addInPlace(div_result, 1); + } + } + if (lhs_negative ^ rhs_negative) { + negateInPlace(div_result); + } + if (lhs_negative) { + negateInPlace(remainder); + } + return div_result; +} + +int128_t Int128_t::Div(int128_t lhs, int128_t rhs) { + if (rhs.high == 0 && rhs.low == 0) { + throw common::RuntimeException("Divide by zero."); + } + int128_t remainder{}; + return divMod(lhs, rhs, remainder); +} + +int128_t Int128_t::Mod(int128_t lhs, int128_t rhs) { + if (rhs.high == 0 && rhs.low == 0) { + throw common::RuntimeException("Modulo by zero."); + } + int128_t result{}; + divMod(lhs, rhs, result); + return result; +} + +int128_t Int128_t::Xor(int128_t lhs, int128_t rhs) { + int128_t result{lhs.low ^ rhs.low, lhs.high ^ rhs.high}; + return result; +} + +int128_t Int128_t::BinaryAnd(int128_t lhs, int128_t rhs) { + int128_t result{lhs.low & rhs.low, lhs.high & rhs.high}; + return result; +} + +int128_t Int128_t::BinaryOr(int128_t lhs, int128_t rhs) { + int128_t result{lhs.low | rhs.low, lhs.high | rhs.high}; + return result; +} + +int128_t Int128_t::BinaryNot(int128_t val) { + return int128_t{~val.low, ~val.high}; +} + +int128_t Int128_t::LeftShift(int128_t lhs, int amount) { + // adapted from + // https://github.com/abseil/abseil-cpp/blob/master/absl/numeric/int128.h + return amount >= 64 ? int128_t(0, lhs.low << (amount - 64)) : + amount == 0 ? lhs : + int128_t{lhs.low << amount, + (lhs.high << amount) | + (numeric_utils::makeValueSigned(lhs.low >> (64 - amount)))}; +} + +int128_t Int128_t::RightShift(int128_t lhs, int amount) { + // adapted from + // https://github.com/abseil/abseil-cpp/blob/master/absl/numeric/int128.h + return amount >= 64 ? + // we shift the high value regardless for sign extension + int128_t(lhs.high >> (amount - 64), lhs.high >> 63) : + amount == 0 ? + lhs : + int128_t((lhs.low >> amount) | (lhs.high << (64 - amount)), lhs.high >> amount); +} + +//=============================================================================================== +// Cast operation +//=============================================================================================== +template +bool TryCastInt128Template(int128_t input, DST& result) { + switch (input.high) { + case 0: + if (input.low <= uint64_t(function::NumericLimits::maximum())) { + result = static_cast(input.low); + return true; + } + break; + case -1: + if constexpr (!SIGNED) { + throw common::OverflowException( + "Cast failed. 
Cannot cast " + Int128_t::toString(input) + " to unsigned type."); + } + if (input.low >= function::NumericLimits::maximum() - + uint64_t(function::NumericLimits::maximum())) { + result = -DST(function::NumericLimits::maximum() - input.low) - 1; + return true; + } + break; + default: + break; + } + return false; +} +// we can use the above template if we can get max using something like DST.max + +template<> +bool Int128_t::tryCast(int128_t input, int8_t& result) { + return TryCastInt128Template(input, result); +} + +template<> +bool Int128_t::tryCast(int128_t input, int16_t& result) { + return TryCastInt128Template(input, result); +} + +template<> +bool Int128_t::tryCast(int128_t input, int32_t& result) { + return TryCastInt128Template(input, result); +} + +template<> +bool Int128_t::tryCast(int128_t input, int64_t& result) { + return TryCastInt128Template(input, result); +} + +template<> +bool Int128_t::tryCast(int128_t input, uint8_t& result) { + return TryCastInt128Template(input, result); +} + +template<> +bool Int128_t::tryCast(int128_t input, uint16_t& result) { + return TryCastInt128Template(input, result); +} + +template<> +bool Int128_t::tryCast(int128_t input, uint32_t& result) { + return TryCastInt128Template(input, result); +} + +template<> +bool Int128_t::tryCast(int128_t input, uint64_t& result) { + return TryCastInt128Template(input, result); +} + +template<> +bool Int128_t::tryCast(int128_t input, uint128_t& result) { + if (input.high < 0) { + return false; + } + result.low = input.low; + result.high = uint64_t(input.high); + return true; +} + +template<> +bool Int128_t::tryCast(int128_t input, float& result) { + double temp_res = NAN; + tryCast(input, temp_res); + result = static_cast(temp_res); + return true; +} + +template +bool CastInt128ToFloating(int128_t input, REAL_T& result) { + switch (input.high) { + case -1: + result = -REAL_T(function::NumericLimits::maximum() - input.low) - 1; + break; + default: + result = REAL_T(input.high) * REAL_T(function::NumericLimits::maximum()) + + REAL_T(input.low); + break; + } + return true; +} + +template<> +bool Int128_t::tryCast(int128_t input, double& result) { + return CastInt128ToFloating(input, result); +} + +template<> +bool Int128_t::tryCast(int128_t input, long double& result) { + return CastInt128ToFloating(input, result); +} + +template +int128_t tryCastToTemplate(SRC value) { + int128_t result{}; + result.low = (uint64_t)value; + result.high = (value < 0) * -1; + return result; +} + +template<> +bool Int128_t::tryCastTo(int8_t value, int128_t& result) { + result = tryCastToTemplate(value); + return true; +} + +template<> +bool Int128_t::tryCastTo(int16_t value, int128_t& result) { + result = tryCastToTemplate(value); + return true; +} + +template<> +bool Int128_t::tryCastTo(int32_t value, int128_t& result) { + result = tryCastToTemplate(value); + return true; +} + +template<> +bool Int128_t::tryCastTo(int64_t value, int128_t& result) { + result = tryCastToTemplate(value); + return true; +} + +template<> +bool Int128_t::tryCastTo(uint8_t value, int128_t& result) { + result = tryCastToTemplate(value); + return true; +} + +template<> +bool Int128_t::tryCastTo(uint16_t value, int128_t& result) { + result = tryCastToTemplate(value); + return true; +} + +template<> +bool Int128_t::tryCastTo(uint32_t value, int128_t& result) { + result = tryCastToTemplate(value); + return true; +} + +template<> +bool Int128_t::tryCastTo(uint64_t value, int128_t& result) { + result = tryCastToTemplate(value); + return true; +} + +template<> 
+bool Int128_t::tryCastTo(int128_t value, int128_t& result) { + result = value; + return true; +} + +template<> +bool Int128_t::tryCastTo(float value, int128_t& result) { + return tryCastTo(double(value), result); +} + +template +bool castFloatingToInt128(REAL_T value, int128_t& result) { + // TODO: Maybe need to add func isFinite in value.h to see if every type is finite. + if (value <= -170141183460469231731687303715884105728.0 || + value >= 170141183460469231731687303715884105727.0) { + return false; + } + bool negative = value < 0; + if (negative) { + value = -value; + } + value = std::nearbyint(value); + result.low = (uint64_t)fmod(value, REAL_T(function::NumericLimits::maximum())); + result.high = (uint64_t)(value / REAL_T(function::NumericLimits::maximum())); + if (negative) { + Int128_t::negateInPlace(result); + } + return true; +} + +template<> +bool Int128_t::tryCastTo(double value, int128_t& result) { + return castFloatingToInt128(value, result); +} + +template<> +bool Int128_t::tryCastTo(long double value, int128_t& result) { + return castFloatingToInt128(value, result); +} +//=============================================================================================== + +template +void constructInt128Template(T value, int128_t& result) { + int128_t casted = Int128_t::castTo(value); + result.low = casted.low; + result.high = casted.high; +} + +int128_t::int128_t(int64_t value) { // NOLINT: fields are constructed by the template + constructInt128Template(value, *this); +} + +int128_t::int128_t(int32_t value) { // NOLINT: fields are constructed by the template + constructInt128Template(value, *this); +} + +int128_t::int128_t(int16_t value) { // NOLINT: fields are constructed by the template + constructInt128Template(value, *this); +} + +int128_t::int128_t(int8_t value) { // NOLINT: fields are constructed by the template + constructInt128Template(value, *this); +} + +int128_t::int128_t(uint64_t value) { // NOLINT: fields are constructed by the template + constructInt128Template(value, *this); +} + +int128_t::int128_t(uint32_t value) { // NOLINT: fields are constructed by the template + constructInt128Template(value, *this); +} + +int128_t::int128_t(uint16_t value) { // NOLINT: fields are constructed by the template + constructInt128Template(value, *this); +} + +int128_t::int128_t(uint8_t value) { // NOLINT: fields are constructed by the template + constructInt128Template(value, *this); +} + +int128_t::int128_t(double value) { // NOLINT: fields are constructed by the template + constructInt128Template(value, *this); +} + +int128_t::int128_t(float value) { // NOLINT: fields are constructed by the template + constructInt128Template(value, *this); +} + +//============================================================================================ +bool operator==(const int128_t& lhs, const int128_t& rhs) { + return Int128_t::equals(lhs, rhs); +} + +bool operator!=(const int128_t& lhs, const int128_t& rhs) { + return Int128_t::notEquals(lhs, rhs); +} + +bool operator>(const int128_t& lhs, const int128_t& rhs) { + return Int128_t::greaterThan(lhs, rhs); +} + +bool operator>=(const int128_t& lhs, const int128_t& rhs) { + return Int128_t::greaterThanOrEquals(lhs, rhs); +} + +bool operator<(const int128_t& lhs, const int128_t& rhs) { + return Int128_t::lessThan(lhs, rhs); +} + +bool operator<=(const int128_t& lhs, const int128_t& rhs) { + return Int128_t::lessThanOrEquals(lhs, rhs); +} + +int128_t int128_t::operator-() const { + return Int128_t::negate(*this); +} + +// support for 
operations like (int32_t)x + (int128_t)y + +int128_t operator+(const int128_t& lhs, const int128_t& rhs) { + return Int128_t::Add(lhs, rhs); +} +int128_t operator-(const int128_t& lhs, const int128_t& rhs) { + return Int128_t::Sub(lhs, rhs); +} +int128_t operator*(const int128_t& lhs, const int128_t& rhs) { + return Int128_t::Mul(lhs, rhs); +} +int128_t operator/(const int128_t& lhs, const int128_t& rhs) { + return Int128_t::Div(lhs, rhs); +} +int128_t operator%(const int128_t& lhs, const int128_t& rhs) { + return Int128_t::Mod(lhs, rhs); +} + +int128_t operator^(const int128_t& lhs, const int128_t& rhs) { + return Int128_t::Xor(lhs, rhs); +} + +int128_t operator&(const int128_t& lhs, const int128_t& rhs) { + return Int128_t::BinaryAnd(lhs, rhs); +} + +int128_t operator|(const int128_t& lhs, const int128_t& rhs) { + return Int128_t::BinaryOr(lhs, rhs); +} + +int128_t operator~(const int128_t& val) { + return Int128_t::BinaryNot(val); +} + +int128_t operator<<(const int128_t& lhs, int amount) { + return Int128_t::LeftShift(lhs, amount); +} + +int128_t operator>>(const int128_t& lhs, int amount) { + return Int128_t::RightShift(lhs, amount); +} + +// inplace arithmetic operators +int128_t& int128_t::operator+=(const int128_t& rhs) { + if (!Int128_t::addInPlace(*this, rhs)) { + throw common::OverflowException("INT128 is out of range: cannot add in place."); + } + return *this; +} + +int128_t& int128_t::operator*=(const int128_t& rhs) { + *this = Int128_t::Mul(*this, rhs); + return *this; +} + +int128_t& int128_t::operator|=(const int128_t& rhs) { + *this = Int128_t::BinaryOr(*this, rhs); + return *this; +} + +int128_t& int128_t::operator&=(const int128_t& rhs) { + *this = Int128_t::BinaryAnd(*this, rhs); + return *this; +} + +template +static T NarrowCast(const int128_t& input) { + return static_cast(input.low); +} + +int128_t::operator int64_t() const { + return NarrowCast(*this); +} + +int128_t::operator int32_t() const { + return NarrowCast(*this); +} + +int128_t::operator int16_t() const { + return NarrowCast(*this); +} + +int128_t::operator int8_t() const { + return NarrowCast(*this); +} + +int128_t::operator uint64_t() const { + return NarrowCast(*this); +} + +int128_t::operator uint32_t() const { + return NarrowCast(*this); +} + +int128_t::operator uint16_t() const { + return NarrowCast(*this); +} + +int128_t::operator uint8_t() const { + return NarrowCast(*this); +} + +int128_t::operator double() const { + double result = NAN; + if (!Int128_t::tryCast(*this, result)) { // LCOV_EXCL_START + throw common::OverflowException(common::stringFormat("Value {} is not within DOUBLE range", + common::TypeUtils::toString(*this))); + } // LCOV_EXCL_STOP + return result; +} + +int128_t::operator float() const { + float result = NAN; + if (!Int128_t::tryCast(*this, result)) { // LCOV_EXCL_START + throw common::OverflowException(common::stringFormat("Value {} is not within FLOAT range", + common::TypeUtils::toString(*this))); + } // LCOV_EXCL_STOP + return result; +} + +int128_t::operator uint128_t() const { + uint128_t result{}; + if (!Int128_t::tryCast(*this, result)) { + throw common::OverflowException(common::stringFormat( + "Cannot cast negative INT128 value {} to UINT128", common::TypeUtils::toString(*this))); + } + return result; +} + +} // namespace lbug::common + +std::size_t std::hash::operator()( + const lbug::common::int128_t& v) const noexcept { + lbug::common::hash_t hash = 0; + lbug::function::Hash::operation(v, hash); + return hash; +} diff --git 
a/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/interval_t.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/interval_t.cpp new file mode 100644 index 0000000000..0591e2e0d7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/interval_t.cpp @@ -0,0 +1,483 @@ +#include "common/types/interval_t.h" + +#include "common/assert.h" +#include "common/exception/conversion.h" +#include "common/exception/overflow.h" +#include "common/string_utils.h" +#include "common/types/cast_helpers.h" +#include "common/types/timestamp_t.h" +#include "function/arithmetic/add.h" +#include "function/arithmetic/multiply.h" +#include "function/cast/functions/cast_from_string_functions.h" +#include "function/cast/functions/cast_functions.h" +#include "re2.h" + +namespace lbug { +namespace common { + +interval_t::interval_t() = default; + +interval_t::interval_t(int32_t months_p, int32_t days_p, int64_t micros_p) + : months(months_p), days(days_p), micros(micros_p) {} + +bool interval_t::operator==(const interval_t& rhs) const { + return this->days == rhs.days && this->months == rhs.months && this->micros == rhs.micros; +} + +bool interval_t::operator!=(const interval_t& rhs) const { + return !(*this == rhs); +} + +bool interval_t::operator>(const interval_t& rhs) const { + return Interval::greaterThan(*this, rhs); +} + +bool interval_t::operator<=(const interval_t& rhs) const { + return !(*this > rhs); +} + +bool interval_t::operator<(const interval_t& rhs) const { + return !(*this >= rhs); +} + +bool interval_t::operator>=(const interval_t& rhs) const { + return *this > rhs || *this == rhs; +} + +interval_t interval_t::operator+(const interval_t& rhs) const { + interval_t result{}; + result.months = months + rhs.months; + result.days = days + rhs.days; + result.micros = micros + rhs.micros; + return result; +} + +timestamp_t interval_t::operator+(const timestamp_t& rhs) const { + return rhs + *this; +} + +date_t interval_t::operator+(const date_t& rhs) const { + return rhs + *this; +} + +interval_t interval_t::operator-(const interval_t& rhs) const { + interval_t result{}; + result.months = months - rhs.months; + result.days = days - rhs.days; + result.micros = micros - rhs.micros; + return result; +} + +interval_t interval_t::operator/(const uint64_t& rhs) const { + interval_t result{}; + int32_t monthsRemainder = months % rhs; + int32_t daysRemainder = (days + monthsRemainder * Interval::DAYS_PER_MONTH) % rhs; + result.months = months / rhs; + result.days = (days + monthsRemainder * Interval::DAYS_PER_MONTH) / rhs; + result.micros = (micros + daysRemainder * Interval::MICROS_PER_DAY) / rhs; + return result; +} + +void Interval::addition(interval_t& result, uint64_t number, std::string specifierStr) { + StringUtils::toLower(specifierStr); + if (specifierStr == "year" || specifierStr == "years" || specifierStr == "y") { + result.months += number * MONTHS_PER_YEAR; + } else if (specifierStr == "month" || specifierStr == "months" || specifierStr == "mon") { + result.months += number; + } else if (specifierStr == "day" || specifierStr == "days" || specifierStr == "d") { + result.days += number; + } else if (specifierStr == "hour" || specifierStr == "hours" || specifierStr == "h") { + result.micros += number * MICROS_PER_HOUR; + } else if (specifierStr == "minute" || specifierStr == "minutes" || specifierStr == "m") { + result.micros += number * MICROS_PER_MINUTE; + } else if (specifierStr == "second" || specifierStr == "seconds" || specifierStr == "s") { + result.micros += number * 
MICROS_PER_SEC;
+    } else if (specifierStr == "millisecond" || specifierStr == "milliseconds" ||
+               specifierStr == "ms" || specifierStr == "msec") {
+        result.micros += number * MICROS_PER_MSEC;
+    } else if (specifierStr == "microsecond" || specifierStr == "microseconds" ||
+               specifierStr == "us") {
+        result.micros += number;
+    } else {
+        throw ConversionException("Unrecognized interval specifier string: " + specifierStr + ".");
+    }
+}
+
+template<typename T>
+T intervalTryCastInteger(int64_t input) {
+    if (std::is_same<T, int8_t>()) {
+        int8_t result = 0;
+        function::CastToInt8::operation(input, result);
+        return result;
+    } else if (std::is_same<T, int16_t>()) {
+        int16_t result = 0;
+        function::CastToInt16::operation(input, result);
+        return result;
+    } else if (std::is_same<T, int32_t>()) {
+        int32_t result = 0;
+        function::CastToInt32::operation(input, result);
+        return result;
+    } else if (std::is_same<T, int64_t>()) {
+        int64_t result = 0;
+        function::CastToInt64::operation(input, result);
+        return result;
+    } else {
+        throw ConversionException("The destination is not an integer");
+    }
+}
+
+template<typename T>
+void intervalTryAddition(T& target, int64_t input, int64_t multiplier, int64_t fraction = 0) {
+    int64_t addition = 0;
+    try {
+        function::Multiply::operation(input, multiplier, addition);
+    } catch (const OverflowException& e) {
+        throw OverflowException{"Interval value is out of range"};
+    }
+    T additionBase = intervalTryCastInteger<T>(addition);
+    try {
+        function::Add::operation(target, additionBase, target);
+    } catch (const OverflowException& e) {
+        throw OverflowException{"Interval value is out of range"};
+    }
+    if (fraction) {
+        // Add in (fraction * multiplier) / MICROS_PER_SEC
+        // This is always in range
+        addition = (fraction * multiplier) / Interval::MICROS_PER_SEC;
+        additionBase = intervalTryCastInteger<T>(addition);
+        try {
+            function::Add::operation(target, additionBase, target);
+        } catch (const OverflowException& e) {
+            throw OverflowException{"Interval fraction is out of range"};
+        }
+    }
+}
+
+interval_t Interval::fromCString(const char* str, uint64_t len) {
+    interval_t result;
+    uint64_t pos = 0;
+    uint64_t startPos = 0;
+    bool foundAny = false;
+    int64_t number = 0;
+    int64_t fraction = 0;
+    DatePartSpecifier specifier{};
+    std::string specifierStr{};
+
+    result.days = 0;
+    result.micros = 0;
+    result.months = 0;
+
+    if (len == 0) {
+        throw ConversionException("Error occurred during parsing interval. Given empty string.");
+    }
+
+    if (str[pos] == '@') {
+        pos++;
+    }
+
+parse_interval:
+    for (; pos < len; pos++) {
+        char c = str[pos];
+        if (isspace(c)) {
+            // skip spaces
+            continue;
+        } else if (isdigit(c)) {
+            // start parsing a number
+            goto interval_parse_number;
+        } else {
+            // unrecognized character, expected a number or end of string
+            throw ConversionException("Error occurred during parsing interval. 
Given: \"" + + std::string(str, len) + "\"."); + } + } + goto end_of_string; + +interval_parse_number: + startPos = pos; + for (; pos < len; pos++) { + char c = str[pos]; + if (isdigit(c)) { + // the number continues + continue; + } else if (c == ':') { + // colon: we are parsing a time + goto interval_parse_time; + } else { + // finished the number, parse it from the string + function::CastString::operation(ku_string_t{str + startPos, pos - startPos}, number); + fraction = 0; + if (c == '.') { + // we expect some microseconds + int32_t mult = 100000; + for (++pos; pos < len && isdigit(str[pos]); ++pos, mult /= 10) { + if (mult > 0) { + fraction += int64_t(str[pos] - '0') * mult; + } + } + } + goto interval_parse_identifier; + } + } + goto interval_parse_identifier; + +interval_parse_time: { + // parse the remainder of the time as a Time type + dtime_t time; + uint64_t tmpPos = 0; + if (!Time::tryConvertInterval(str + startPos, len - startPos, tmpPos, time)) { + throw ConversionException("Error occurred during parsing time. Given: \"" + + std::string(str + startPos, len - startPos) + "\"."); + } + result.micros += time.micros; + foundAny = true; + goto end_of_string; +} + +interval_parse_identifier: + for (; pos < len; pos++) { + char c = str[pos]; + if (isspace(c)) { + // skip spaces at the start + continue; + } else { + break; + } + } + // now parse the identifier + startPos = pos; + for (; pos < len; pos++) { + char c = str[pos]; + if (!isspace(c)) { + // keep parsing the string + continue; + } else { + break; + } + } + specifierStr = std::string(str + startPos, pos - startPos); + + // Specifier string is empty, missing field name + if (specifierStr.empty()) { + throw ConversionException("Error occurred during parsing interval. Field name is missing."); + } + + tryGetDatePartSpecifier(specifierStr, specifier); + + switch (specifier) { + case DatePartSpecifier::MILLENNIUM: + intervalTryAddition(result.months, number, MONTHS_PER_MILLENIUM, fraction); + break; + case DatePartSpecifier::CENTURY: + intervalTryAddition(result.months, number, MONTHS_PER_CENTURY, fraction); + break; + case DatePartSpecifier::DECADE: + intervalTryAddition(result.months, number, MONTHS_PER_DECADE, fraction); + break; + case DatePartSpecifier::YEAR: + intervalTryAddition(result.months, number, MONTHS_PER_YEAR, fraction); + break; + case DatePartSpecifier::QUARTER: + intervalTryAddition(result.months, number, MONTHS_PER_QUARTER, fraction); + // Reduce to fraction of a month + fraction *= MONTHS_PER_QUARTER; + fraction %= MICROS_PER_SEC; + intervalTryAddition(result.days, 0, DAYS_PER_MONTH, fraction); + break; + case DatePartSpecifier::MONTH: + intervalTryAddition(result.months, number, 1); + intervalTryAddition(result.days, 0, DAYS_PER_MONTH, fraction); + break; + case DatePartSpecifier::DAY: + intervalTryAddition(result.days, number, 1); + intervalTryAddition(result.micros, 0, MICROS_PER_DAY, fraction); + break; + case DatePartSpecifier::WEEK: + intervalTryAddition(result.days, number, DAYS_PER_WEEK, fraction); + // Reduce to fraction of a day + fraction *= DAYS_PER_WEEK; + fraction %= MICROS_PER_SEC; + intervalTryAddition(result.micros, 0, MICROS_PER_DAY, fraction); + break; + case DatePartSpecifier::HOUR: + intervalTryAddition(result.micros, number, MICROS_PER_HOUR, fraction); + break; + case DatePartSpecifier::MINUTE: + intervalTryAddition(result.micros, number, MICROS_PER_MINUTE, fraction); + break; + case DatePartSpecifier::SECOND: + intervalTryAddition(result.micros, number, MICROS_PER_SEC, fraction); + 
break; + case DatePartSpecifier::MILLISECOND: + intervalTryAddition(result.micros, number, MICROS_PER_MSEC, fraction); + break; + case DatePartSpecifier::MICROSECOND: + // Round the fraction + number += (fraction * 2) / MICROS_PER_SEC; + intervalTryAddition(result.micros, number, 1); + break; + default: + throw ConversionException("Unrecognized interval specifier string: " + specifierStr + "."); + } + + foundAny = true; + goto parse_interval; + +end_of_string: + if (!foundAny) { + throw ConversionException( + "Error occurred during parsing interval. Given: \"" + std::string(str, len) + "\"."); + } + return result; +} + +std::string Interval::toString(interval_t interval) { + char buffer[70]; + uint64_t length = IntervalToStringCast::Format(interval, buffer); + return std::string(buffer, length); +} + +// helper function of interval comparison +void Interval::normalizeIntervalEntries(interval_t input, int64_t& months, int64_t& days, + int64_t& micros) { + int64_t extra_months_d = input.days / Interval::DAYS_PER_MONTH; + int64_t extra_months_micros = input.micros / Interval::MICROS_PER_MONTH; + input.days -= extra_months_d * Interval::DAYS_PER_MONTH; + input.micros -= extra_months_micros * Interval::MICROS_PER_MONTH; + + int64_t extra_days_micros = input.micros / Interval::MICROS_PER_DAY; + input.micros -= extra_days_micros * Interval::MICROS_PER_DAY; + + months = input.months + extra_months_d + extra_months_micros; + days = input.days + extra_days_micros; + micros = input.micros; +} + +bool Interval::greaterThan(const interval_t& left, const interval_t& right) { + int64_t lMonths = 0, lDays = 0, lMicros = 0; + int64_t rMonths = 0, rDays = 0, rMicros = 0; + normalizeIntervalEntries(left, lMonths, lDays, lMicros); + normalizeIntervalEntries(right, rMonths, rDays, rMicros); + if (lMonths > rMonths) { + return true; + } else if (lMonths < rMonths) { + return false; + } + if (lDays > rDays) { + return true; + } else if (lDays < rDays) { + return false; + } + return lMicros > rMicros; +} + +void Interval::tryGetDatePartSpecifier(std::string specifier, DatePartSpecifier& result) { + StringUtils::toLower(specifier); + if (specifier == "year" || specifier == "yr" || specifier == "y" || specifier == "years" || + specifier == "yrs") { + result = DatePartSpecifier::YEAR; + } else if (specifier == "month" || specifier == "mon" || specifier == "months" || + specifier == "mons") { + result = DatePartSpecifier::MONTH; + } else if (specifier == "day" || specifier == "days" || specifier == "d" || + specifier == "dayofmonth") { + result = DatePartSpecifier::DAY; + } else if (specifier == "decade" || specifier == "dec" || specifier == "decades" || + specifier == "decs") { + result = DatePartSpecifier::DECADE; + } else if (specifier == "century" || specifier == "cent" || specifier == "centuries" || + specifier == "c") { + result = DatePartSpecifier::CENTURY; + } else if (specifier == "millennium" || specifier == "mil" || specifier == "millenniums" || + specifier == "millennia" || specifier == "mils" || specifier == "millenium" || + specifier == "milleniums") { + result = DatePartSpecifier::MILLENNIUM; + } else if (specifier == "microseconds" || specifier == "microsecond" || specifier == "us" || + specifier == "usec" || specifier == "usecs" || specifier == "usecond" || + specifier == "useconds") { + result = DatePartSpecifier::MICROSECOND; + } else if (specifier == "milliseconds" || specifier == "millisecond" || specifier == "ms" || + specifier == "msec" || specifier == "msecs" || specifier == "msecond" || + 
specifier == "mseconds") { + result = DatePartSpecifier::MILLISECOND; + } else if (specifier == "second" || specifier == "sec" || specifier == "seconds" || + specifier == "secs" || specifier == "s") { + result = DatePartSpecifier::SECOND; + } else if (specifier == "minute" || specifier == "min" || specifier == "minutes" || + specifier == "mins" || specifier == "m") { + result = DatePartSpecifier::MINUTE; + } else if (specifier == "hour" || specifier == "hr" || specifier == "hours" || + specifier == "hrs" || specifier == "h") { + result = DatePartSpecifier::HOUR; + } else if (specifier == "week" || specifier == "weeks" || specifier == "w" || + specifier == "weekofyear") { + // ISO week number + result = DatePartSpecifier::WEEK; + } else if (specifier == "quarter" || specifier == "quarters") { + // quarter of the year (1-4) + result = DatePartSpecifier::QUARTER; + } else { + throw ConversionException("Unrecognized interval specifier string: " + specifier + "."); + } +} + +int32_t Interval::getIntervalPart(DatePartSpecifier specifier, interval_t interval) { + switch (specifier) { + case DatePartSpecifier::YEAR: + return interval.months / Interval::MONTHS_PER_YEAR; + case DatePartSpecifier::MONTH: + return interval.months % Interval::MONTHS_PER_YEAR; + case DatePartSpecifier::DAY: + return interval.days; + case DatePartSpecifier::DECADE: + return interval.months / Interval::MONTHS_PER_DECADE; + case DatePartSpecifier::CENTURY: + return interval.months / Interval::MONTHS_PER_CENTURY; + case DatePartSpecifier::MILLENNIUM: + return interval.months / Interval::MONTHS_PER_MILLENIUM; + case DatePartSpecifier::QUARTER: + return getIntervalPart(DatePartSpecifier::MONTH, interval) / Interval::MONTHS_PER_QUARTER + + 1; + case DatePartSpecifier::MICROSECOND: + return interval.micros % Interval::MICROS_PER_MINUTE; + case DatePartSpecifier::MILLISECOND: + return getIntervalPart(DatePartSpecifier::MICROSECOND, interval) / + Interval::MICROS_PER_MSEC; + case DatePartSpecifier::SECOND: + return getIntervalPart(DatePartSpecifier::MICROSECOND, interval) / Interval::MICROS_PER_SEC; + case DatePartSpecifier::MINUTE: + return (interval.micros % Interval::MICROS_PER_HOUR) / Interval::MICROS_PER_MINUTE; + case DatePartSpecifier::HOUR: + return interval.micros / Interval::MICROS_PER_HOUR; + default: + KU_UNREACHABLE; + } +} + +int64_t Interval::getMicro(const interval_t& val) { + return val.micros + val.months * MICROS_PER_MONTH + val.days * MICROS_PER_DAY; +} + +int64_t Interval::getNanoseconds(const interval_t& val) { + return getMicro(val) * NANOS_PER_MICRO; +} + +const regex::RE2& Interval::regexPattern1() { + static regex::RE2 retval( + "(?i)((0|[1-9]\\d*) " + "+(YEARS?|YRS?|Y|MONS?|MONTHS?|DAYS?|D|DAYOFMONTH|DECADES?|DECS?|CENTURY|CENTURIES|CENT|C|" + "MILLENN?IUMS?|MILS?|MILLENNIA|MICROSECONDS?|US|USECS?|USECONDS?|MILLISECONDS?|MS|SECONDS?|" + "SECS?|S|MINUTES?|MINS?|M|HOURS?|HRS?|H|WEEKS?|WEEKOFYEAR|W|QUARTERS?))( +(0|[1-9]\\d*) " + "+(YEARS?|YRS?|Y|MONS?|MONTHS?|DAYS?|D|DAYOFMONTH|DECADES?|DECS?|CENTURY|CENTURIES|CENT|C|" + "MILLENN?IUMS?|MILS?|MILLENNIA|MICROSECONDS?|US|USECS?|USECONDS?|MILLISECONDS?|MS|SECONDS?|" + "SECS?|S|MINUTES?|MINS?|M|HOURS?|HRS?|H|WEEKS?|WEEKOFYEAR|W|QUARTERS?))*( " + "+\\d+:\\d{2}:\\d{2}(\\.\\d+)?)?"); + return retval; +} + +const regex::RE2& Interval::regexPattern2() { + static regex::RE2 retval("\\d+:\\d{2}:\\d{2}(\\.\\d+)?"); + return retval; +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/ku_list.cpp 
b/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/ku_list.cpp
new file mode 100644
index 0000000000..8cfb117a19
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/ku_list.cpp
@@ -0,0 +1,25 @@
+#include "common/types/ku_list.h"
+
+#include <cstring>
+
+#include "storage/storage_utils.h"
+
+namespace lbug {
+namespace common {
+
+void ku_list_t::set(const uint8_t* values, const LogicalType& dataType) const {
+    memcpy(reinterpret_cast<uint8_t*>(overflowPtr), values,
+        size * storage::StorageUtils::getDataTypeSize(ListType::getChildType(dataType)));
+}
+
+void ku_list_t::set(const std::vector<uint8_t*>& parameters, LogicalTypeID childTypeId) {
+    this->size = parameters.size();
+    auto numBytesOfListElement = storage::StorageUtils::getDataTypeSize(LogicalType{childTypeId});
+    for (auto i = 0u; i < parameters.size(); i++) {
+        memcpy(reinterpret_cast<uint8_t*>(this->overflowPtr) + (i * numBytesOfListElement),
+            parameters[i], numBytesOfListElement);
+    }
+}
+
+} // namespace common
+} // namespace lbug
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/ku_string.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/ku_string.cpp
new file mode 100644
index 0000000000..f1e257019e
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/ku_string.cpp
@@ -0,0 +1,79 @@
+#include "common/types/ku_string.h"
+
+namespace lbug {
+namespace common {
+
+ku_string_t::ku_string_t(const char* value, uint64_t length) : len(length), prefix{} {
+    if (isShortString(length)) {
+        memcpy(prefix, value, length);
+        return;
+    }
+    overflowPtr = (uint64_t)(value);
+    memcpy(prefix, value, PREFIX_LENGTH);
+}
+
+void ku_string_t::set(const std::string& value) {
+    set(value.data(), value.length());
+}
+
+void ku_string_t::set(const char* value, uint64_t length) {
+    if (length <= SHORT_STR_LENGTH) {
+        setShortString(value, length);
+    } else {
+        setLongString(value, length);
+    }
+}
+
+void ku_string_t::set(const ku_string_t& value) {
+    if (value.len <= SHORT_STR_LENGTH) {
+        setShortString(value);
+    } else {
+        setLongString(value);
+    }
+}
+
+std::string ku_string_t::getAsShortString() const {
+    return std::string((char*)prefix, len);
+}
+
+std::string ku_string_t::getAsString() const {
+    return std::string(getAsStringView());
+}
+
+std::string_view ku_string_t::getAsStringView() const {
+    if (len <= SHORT_STR_LENGTH) {
+        return std::string_view((char*)prefix, len);
+    } else {
+        return std::string_view(reinterpret_cast<const char*>(overflowPtr), len);
+    }
+}
+
+bool ku_string_t::operator==(const ku_string_t& rhs) const {
+    // First compare the length and prefix of the strings.
+    auto numBytesOfLenAndPrefix =
+        sizeof(uint32_t) +
+        std::min((uint64_t)len, static_cast<uint64_t>(ku_string_t::PREFIX_LENGTH));
+    if (!memcmp(this, &rhs, numBytesOfLenAndPrefix)) {
+        // If length and prefix of a and b are equal, we compare the overflow buffer.
+        return !memcmp(getData(), rhs.getData(), len);
+    }
+    return false;
+}
+
+bool ku_string_t::operator>(const ku_string_t& rhs) const {
+    // Compare ku_string_t up to the shared length.
+    // If there is a tie, we just need to compare the std::string lengths.
+    auto sharedLen = std::min(len, rhs.len);
+    auto memcmpResult = memcmp(prefix, rhs.prefix,
+        sharedLen <= ku_string_t::PREFIX_LENGTH ? 
sharedLen : ku_string_t::PREFIX_LENGTH); + if (memcmpResult == 0 && len > ku_string_t::PREFIX_LENGTH) { + memcmpResult = memcmp(getData(), rhs.getData(), sharedLen); + } + if (memcmpResult == 0) { + return len > rhs.len; + } + return memcmpResult > 0; +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/timestamp_t.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/timestamp_t.cpp new file mode 100644 index 0000000000..641ecb762a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/timestamp_t.cpp @@ -0,0 +1,383 @@ +#include "common/types/timestamp_t.h" + +#include +#include + +#include "common/exception/conversion.h" +#include "common/string_format.h" +#include "function/arithmetic/multiply.h" + +namespace lbug { +namespace common { + +timestamp_t::timestamp_t() : value(0) {} + +timestamp_t::timestamp_t(int64_t value_p) : value(value_p) {} + +timestamp_t& timestamp_t::operator=(int64_t value_p) { + value = value_p; + return *this; +} + +timestamp_t::operator int64_t() const { + return value; +} + +bool timestamp_t::operator==(const timestamp_t& rhs) const { + return value == rhs.value; +} + +bool timestamp_t::operator!=(const timestamp_t& rhs) const { + return value != rhs.value; +} + +bool timestamp_t::operator<=(const timestamp_t& rhs) const { + return value <= rhs.value; +} + +bool timestamp_t::operator<(const timestamp_t& rhs) const { + return value < rhs.value; +} + +bool timestamp_t::operator>(const timestamp_t& rhs) const { + return value > rhs.value; +} + +bool timestamp_t::operator>=(const timestamp_t& rhs) const { + return value >= rhs.value; +} + +bool timestamp_t::operator==(const date_t& rhs) const { + return rhs == *this; +} + +bool timestamp_t::operator!=(const date_t& rhs) const { + return !(rhs == *this); +} + +bool timestamp_t::operator<(const date_t& rhs) const { + return rhs > *this; +} + +bool timestamp_t::operator<=(const date_t& rhs) const { + return rhs >= *this; +} + +bool timestamp_t::operator>(const date_t& rhs) const { + return rhs < *this; +} + +bool timestamp_t::operator>=(const date_t& rhs) const { + return rhs <= *this; +} + +timestamp_t timestamp_t::operator+(const interval_t& interval) const { + date_t date{}; + date_t result_date{}; + dtime_t time{}; + Timestamp::convert(*this, date, time); + result_date = date + interval; + date = result_date; + int64_t diff = + interval.micros - ((interval.micros / Interval::MICROS_PER_DAY) * Interval::MICROS_PER_DAY); + time.micros += diff; + if (time.micros >= Interval::MICROS_PER_DAY) { + time.micros -= Interval::MICROS_PER_DAY; + date.days++; + } else if (time.micros < 0) { + time.micros += Interval::MICROS_PER_DAY; + date.days--; + } + return Timestamp::fromDateTime(date, time); +} + +timestamp_t timestamp_t::operator-(const interval_t& interval) const { + interval_t inverseRight{}; + inverseRight.months = -interval.months; + inverseRight.days = -interval.days; + inverseRight.micros = -interval.micros; + return (*this) + inverseRight; +} + +interval_t timestamp_t::operator-(const timestamp_t& rhs) const { + interval_t result{}; + uint64_t diff = std::abs(value - rhs.value); + result.months = 0; + result.days = diff / Interval::MICROS_PER_DAY; + result.micros = diff % Interval::MICROS_PER_DAY; + if (value < rhs.value) { + result.days = -result.days; + result.micros = -result.micros; + } + return result; +} + +static_assert(sizeof(timestamp_t) == sizeof(int64_t), "timestamp_t was padded"); + +bool Timestamp::tryConvertTimestamp(const char* 
str, uint64_t len, timestamp_t& result) { + uint64_t pos = 0; + date_t date; + dtime_t time; + + if (!Date::tryConvertDate(str, len, pos, date, true /*allowTrailing*/)) { + return false; + } + if (pos == len) { + // no time: only a date + result = fromDateTime(date, dtime_t(0)); + return true; + } + // try to parse a time field + if (str[pos] == ' ' || str[pos] == 'T') { + pos++; + } + uint64_t time_pos = 0; + if (!Time::tryConvertTime(str + pos, len - pos, time_pos, time)) { + return false; + } + pos += time_pos; + result = fromDateTime(date, time); + if (pos < len) { + // skip a "Z" at the end (as per the ISO8601 specs) + if (str[pos] == 'Z') { + pos++; + } + int hour_offset = 0, minute_offset = 0; + if (Timestamp::tryParseUTCOffset(str, pos, len, hour_offset, minute_offset)) { + result.value -= hour_offset * Interval::MICROS_PER_HOUR + + minute_offset * Interval::MICROS_PER_MINUTE; + } + // skip any spaces at the end + while (pos < len && isspace(str[pos])) { + pos++; + } + if (pos < len) { + return false; + } + } + return true; +} + +// string format is YYYY-MM-DDThh:mm:ss[.mmmmmm] +// T may be a space, timezone is not supported yet +// ISO 8601 +timestamp_t Timestamp::fromCString(const char* str, uint64_t len) { + timestamp_t result; + if (!tryConvertTimestamp(str, len, result)) { + throw ConversionException(getTimestampConversionExceptionMsg(str, len)); + } + return result; +} + +bool Timestamp::tryParseUTCOffset(const char* str, uint64_t& pos, uint64_t len, int& hour_offset, + int& minute_offset) { + minute_offset = 0; + uint64_t curpos = pos; + // parse the next 3 characters + if (curpos + 3 > len) { + // no characters left to parse + return false; + } + char sign_char = str[curpos]; + if (sign_char != '+' && sign_char != '-') { + // expected either + or - + return false; + } + curpos++; + if (!isdigit(str[curpos]) || !isdigit(str[curpos + 1])) { + // expected +HH or -HH + return false; + } + hour_offset = (str[curpos] - '0') * 10 + (str[curpos + 1] - '0'); + if (sign_char == '-') { + hour_offset = -hour_offset; + } + curpos += 2; + + // optional minute specifier: expected either "MM" or ":MM" + if (curpos >= len) { + // done, nothing left + pos = curpos; + return true; + } + if (str[curpos] == ':') { + curpos++; + } + if (curpos + 2 > len || !isdigit(str[curpos]) || !isdigit(str[curpos + 1])) { + // no MM specifier + pos = curpos; + return true; + } + // we have an MM specifier: parse it + minute_offset = (str[curpos] - '0') * 10 + (str[curpos + 1] - '0'); + if (sign_char == '-') { + minute_offset = -minute_offset; + } + pos = curpos + 2; + return true; +} + +std::string Timestamp::toString(timestamp_t timestamp) { + date_t date; + dtime_t time; + Timestamp::convert(timestamp, date, time); + return Date::toString(date) + " " + Time::toString(time); +} + +// Date header is in the format: %Y%m%d. +std::string Timestamp::getDateHeader(const timestamp_t& timestamp) { + auto date = Timestamp::getDate(timestamp); + int32_t year = 0, month = 0, day = 0; + std::string yearStr, monthStr, dayStr; + Date::convert(date, year, month, day); + yearStr = std::to_string(year); + monthStr = std::to_string(month); + dayStr = std::to_string(day); + if (month < 10) { + monthStr = "0" + monthStr; + } + if (day < 10) { + dayStr = "0" + dayStr; + } + return stringFormat("{}{}{}", yearStr, monthStr, dayStr); +} + +// Timestamp header is in the format: %Y%m%dT%H%M%SZ. 
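+// For example (illustrative value only), 2024-03-05 07:08:09 is rendered as "20240305T070809Z":
+// single-digit fields are zero-padded below and sub-second precision is dropped.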
+std::string Timestamp::getDateTimeHeader(const timestamp_t& timestamp) { + auto dateHeader = getDateHeader(timestamp); + auto time = Timestamp::getTime(timestamp); + int32_t hours = 0, minutes = 0, seconds = 0, micros = 0; + std::string hoursStr, minutesStr, secondsStr; + Time::convert(time, hours, minutes, seconds, micros); + hoursStr = std::to_string(hours); + minutesStr = std::to_string(minutes); + secondsStr = std::to_string(seconds); + + if (hours < 10) { + hoursStr = "0" + hoursStr; + } + if (minutes < 10) { + minutesStr = "0" + minutesStr; + } + if (seconds < 10) { + secondsStr = "0" + secondsStr; + } + return stringFormat("{}T{}{}{}Z", dateHeader, hoursStr, minutesStr, secondsStr); +} + +date_t Timestamp::getDate(timestamp_t timestamp) { + return date_t((timestamp.value + (timestamp.value < 0)) / Interval::MICROS_PER_DAY - + (timestamp.value < 0)); +} + +dtime_t Timestamp::getTime(timestamp_t timestamp) { + date_t date = Timestamp::getDate(timestamp); + return dtime_t(timestamp.value - (int64_t(date.days) * int64_t(Interval::MICROS_PER_DAY))); +} + +timestamp_t Timestamp::fromDateTime(date_t date, dtime_t time) { + timestamp_t result; + int32_t year = 0, month = 0, day = 0, hour = 0, minute = 0, second = 0, microsecond = -1; + Date::convert(date, year, month, day); + Time::convert(time, hour, minute, second, microsecond); + result.value = date.days * Interval::MICROS_PER_DAY + time.micros; + return result; +} + +void Timestamp::convert(timestamp_t timestamp, date_t& out_date, dtime_t& out_time) { + out_date = getDate(timestamp); + out_time = getTime(timestamp); +} + +timestamp_t Timestamp::fromEpochMicroSeconds(int64_t micros) { + return timestamp_t(micros); +} + +timestamp_t Timestamp::fromEpochMilliSeconds(int64_t ms) { + int64_t microSeconds = 0; + function::Multiply::operation(ms, Interval::MICROS_PER_MSEC, microSeconds); + return fromEpochMicroSeconds(microSeconds); +} + +// LCOV_EXCL_START +// TODO(Kebing): will add the tests in the timestamp PR +timestamp_t Timestamp::fromEpochSeconds(int64_t sec) { + int64_t microSeconds = 0; + function::Multiply::operation(sec, Interval::MICROS_PER_SEC, microSeconds); + return fromEpochMicroSeconds(microSeconds); +} +// LCOV_EXCL_STOP + +timestamp_t Timestamp::fromEpochNanoSeconds(int64_t ns) { + return fromEpochMicroSeconds(ns / 1000); +} + +int32_t Timestamp::getTimestampPart(DatePartSpecifier specifier, timestamp_t timestamp) { + switch (specifier) { + case DatePartSpecifier::MICROSECOND: + return getTime(timestamp).micros % Interval::MICROS_PER_MINUTE; + case DatePartSpecifier::MILLISECOND: + return getTimestampPart(DatePartSpecifier::MICROSECOND, timestamp) / + Interval::MICROS_PER_MSEC; + case DatePartSpecifier::SECOND: + return getTimestampPart(DatePartSpecifier::MICROSECOND, timestamp) / + Interval::MICROS_PER_SEC; + case DatePartSpecifier::MINUTE: + return (getTime(timestamp).micros % Interval::MICROS_PER_HOUR) / + Interval::MICROS_PER_MINUTE; + case DatePartSpecifier::HOUR: + return getTime(timestamp).micros / Interval::MICROS_PER_HOUR; + default: + date_t date = getDate(timestamp); + return Date::getDatePart(specifier, date); + } +} + +timestamp_t Timestamp::trunc(DatePartSpecifier specifier, timestamp_t timestamp) { + int32_t hour = 0, min = 0, sec = 0, micros = 0; + date_t date; + dtime_t time; + Timestamp::convert(timestamp, date, time); + Time::convert(time, hour, min, sec, micros); + switch (specifier) { + case DatePartSpecifier::MICROSECOND: + return timestamp; + case DatePartSpecifier::MILLISECOND: + micros -= micros % 
Interval::MICROS_PER_MSEC; + return Timestamp::fromDateTime(date, Time::fromTime(hour, min, sec, micros)); + case DatePartSpecifier::SECOND: + return Timestamp::fromDateTime(date, Time::fromTime(hour, min, sec, 0 /* microseconds */)); + case DatePartSpecifier::MINUTE: + return Timestamp::fromDateTime(date, + Time::fromTime(hour, min, 0 /* seconds */, 0 /* microseconds */)); + case DatePartSpecifier::HOUR: + return Timestamp::fromDateTime(date, + Time::fromTime(hour, 0 /* minutes */, 0 /* seconds */, 0 /* microseconds */)); + default: + date = getDate(timestamp); + return fromDateTime(Date::trunc(specifier, date), dtime_t(0)); + } +} + +int64_t Timestamp::getEpochNanoSeconds(const timestamp_t& timestamp) { + int64_t result = 0; + function::Multiply::operation(timestamp.value, Interval::NANOS_PER_MICRO, result); + return result; +} + +int64_t Timestamp::getEpochMilliSeconds(const timestamp_t& timestamp) { + return timestamp.value / Interval::MICROS_PER_MSEC; +} + +int64_t Timestamp::getEpochSeconds(const timestamp_t& timestamp) { + return timestamp.value / Interval::MICROS_PER_SEC; +} + +timestamp_t Timestamp::getCurrentTimestamp() { + auto now = std::chrono::system_clock::now(); + return Timestamp::fromEpochMilliSeconds( + duration_cast(now.time_since_epoch()).count()); +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/types.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/types.cpp new file mode 100644 index 0000000000..017e266a83 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/types.cpp @@ -0,0 +1,1969 @@ +#include "common/types/types.h" + +#include + +#include "catalog/catalog.h" +#include "common/cast.h" +#include "common/constants.h" +#include "common/exception/binder.h" +#include "common/exception/conversion.h" +#include "common/exception/runtime.h" +#include "common/null_buffer.h" +#include "common/serializer/deserializer.h" +#include "common/serializer/serializer.h" +#include "common/string_utils.h" +#include "common/types/int128_t.h" +#include "common/types/interval_t.h" +#include "common/types/ku_list.h" +#include "common/types/ku_string.h" +#include "common/types/uint128_t.h" +#include "function/built_in_function_utils.h" +#include "function/cast/functions/numeric_limits.h" +#include "storage/compression/float_compression.h" +#include "transaction/transaction.h" + +using lbug::function::BuiltInFunctionsUtils; + +namespace lbug { +namespace common { + +internalID_t::internalID_t() : offset{INVALID_OFFSET}, tableID{INVALID_TABLE_ID} {} + +internalID_t::internalID_t(offset_t offset, table_id_t tableID) + : offset(offset), tableID(tableID) {} + +bool internalID_t::operator==(const internalID_t& rhs) const { + return offset == rhs.offset && tableID == rhs.tableID; +} + +bool internalID_t::operator!=(const internalID_t& rhs) const { + return offset != rhs.offset || tableID != rhs.tableID; +} + +bool internalID_t::operator>(const internalID_t& rhs) const { + return (tableID > rhs.tableID) || (tableID == rhs.tableID && offset > rhs.offset); +} + +bool internalID_t::operator>=(const internalID_t& rhs) const { + return (tableID > rhs.tableID) || (tableID == rhs.tableID && offset >= rhs.offset); +} + +bool internalID_t::operator<(const internalID_t& rhs) const { + return (tableID < rhs.tableID) || (tableID == rhs.tableID && offset < rhs.offset); +} + +bool internalID_t::operator<=(const internalID_t& rhs) const { + return (tableID < rhs.tableID) || (tableID == rhs.tableID && offset <= rhs.offset); +} + 
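+// Illustrative calls (not taken from the source): insertDecimalPoint("12345", 2) yields
+// "123.45", while insertDecimalPoint("45", 4) takes the zero-padding branch and yields "0.0045".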
+std::string DecimalType::insertDecimalPoint(const std::string& value, uint32_t positionFromEnd) { + if (positionFromEnd == 0) { + return value; + // Don't want to end up with cases where integral values are followed by a useless dot + } + std::string retval; + if (positionFromEnd > value.size()) { + auto greaterBy = positionFromEnd - value.size(); + retval = "0."; + for (auto i = 0u; i < greaterBy; i++) { + retval += "0"; + } + retval += value; + } else { + auto lessBy = value.size() - positionFromEnd; + retval = value.substr(0, lessBy); + if (retval == "" || retval == "-") { + retval += '0'; + } + retval += "."; + retval += value.substr(lessBy); + } + return retval; +} + +bool UDTTypeInfo::operator==(const lbug::common::ExtraTypeInfo& other) const { + return typeName == other.constPtrCast()->typeName; +} + +std::unique_ptr UDTTypeInfo::copy() const { + return std::make_unique(typeName); +} + +std::unique_ptr UDTTypeInfo::deserialize(Deserializer& deserializer) { + std::string typeName; + deserializer.deserializeValue(typeName); + return std::make_unique(std::move(typeName)); +} + +void UDTTypeInfo::serializeInternal(Serializer& serializer) const { + serializer.serializeValue(typeName); +} + +uint32_t DecimalType::getPrecision(const LogicalType& type) { + KU_ASSERT(type.getLogicalTypeID() == LogicalTypeID::DECIMAL); + auto decimalTypeInfo = type.extraTypeInfo->constPtrCast(); + return decimalTypeInfo->getPrecision(); +} + +uint32_t DecimalType::getScale(const LogicalType& type) { + KU_ASSERT(type.getLogicalTypeID() == LogicalTypeID::DECIMAL); + auto decimalTypeInfo = type.extraTypeInfo->constPtrCast(); + return decimalTypeInfo->getScale(); +} + +const LogicalType& ListType::getChildType(const lbug::common::LogicalType& type) { + KU_ASSERT(type.getPhysicalType() == PhysicalTypeID::LIST || + type.getPhysicalType() == PhysicalTypeID::ARRAY); + auto listTypeInfo = type.extraTypeInfo->constPtrCast(); + return listTypeInfo->getChildType(); +} + +const LogicalType& ArrayType::getChildType(const LogicalType& type) { + KU_ASSERT(type.getPhysicalType() == PhysicalTypeID::ARRAY); + auto arrayTypeInfo = type.extraTypeInfo->constPtrCast(); + return arrayTypeInfo->getChildType(); +} + +uint64_t ArrayType::getNumElements(const LogicalType& type) { + KU_ASSERT(type.getPhysicalType() == PhysicalTypeID::ARRAY); + auto arrayTypeInfo = type.extraTypeInfo->constPtrCast(); + return arrayTypeInfo->getNumElements(); +} + +std::vector StructType::getFieldTypes(const LogicalType& type) { + KU_ASSERT(type.getPhysicalType() == PhysicalTypeID::STRUCT); + auto structTypeInfo = type.extraTypeInfo->constPtrCast(); + return structTypeInfo->getChildrenTypes(); +} + +const LogicalType& StructType::getFieldType(const LogicalType& type, struct_field_idx_t idx) { + return StructType::getField(type, idx).getType(); +} + +const LogicalType& StructType::getFieldType(const LogicalType& type, const std::string& key) { + return StructType::getField(type, key).getType(); +} + +std::vector StructType::getFieldNames(const LogicalType& type) { + KU_ASSERT(type.getPhysicalType() == PhysicalTypeID::STRUCT); + auto structTypeInfo = type.extraTypeInfo->constPtrCast(); + return structTypeInfo->getChildrenNames(); +} + +uint64_t StructType::getNumFields(const LogicalType& type) { + KU_ASSERT(type.getPhysicalType() == PhysicalTypeID::STRUCT); + return getFields(type).size(); +} + +const std::vector& StructType::getFields(const LogicalType& type) { + KU_ASSERT(type.getPhysicalType() == PhysicalTypeID::STRUCT); + auto structTypeInfo = 
type.extraTypeInfo->constPtrCast(); + return structTypeInfo->getStructFields(); +} + +bool StructType::hasField(const LogicalType& type, const std::string& key) { + KU_ASSERT(type.getPhysicalType() == PhysicalTypeID::STRUCT); + auto structTypeInfo = type.extraTypeInfo->constPtrCast(); + return structTypeInfo->hasField(key); +} + +const StructField& StructType::getField(const LogicalType& type, struct_field_idx_t idx) { + KU_ASSERT(type.getPhysicalType() == PhysicalTypeID::STRUCT); + auto structTypeInfo = type.extraTypeInfo->constPtrCast(); + return structTypeInfo->getStructField(idx); +} + +const StructField& StructType::getField(const LogicalType& type, const std::string& key) { + KU_ASSERT(type.getPhysicalType() == PhysicalTypeID::STRUCT); + auto structTypeInfo = type.extraTypeInfo->constPtrCast(); + return structTypeInfo->getStructField(key); +} + +struct_field_idx_t StructType::getFieldIdx(const LogicalType& type, const std::string& key) { + KU_ASSERT(type.getPhysicalType() == PhysicalTypeID::STRUCT); + auto structTypeInfo = type.extraTypeInfo->constPtrCast(); + return structTypeInfo->getStructFieldIdx(key); +} + +const LogicalType& MapType::getKeyType(const LogicalType& type) { + KU_ASSERT(type.getLogicalTypeID() == LogicalTypeID::MAP); + return *StructType::getFieldTypes(ListType::getChildType(type))[0]; +} + +const LogicalType& MapType::getValueType(const LogicalType& type) { + KU_ASSERT(type.getLogicalTypeID() == LogicalTypeID::MAP); + return *StructType::getFieldTypes(ListType::getChildType(type))[1]; +} + +union_field_idx_t UnionType::getInternalFieldIdx(union_field_idx_t idx) { + return idx + 1; +} + +std::string UnionType::getFieldName(const LogicalType& type, union_field_idx_t idx) { + KU_ASSERT(type.getLogicalTypeID() == LogicalTypeID::UNION); + return StructType::getFieldNames(type)[getInternalFieldIdx(idx)]; +} + +const LogicalType& UnionType::getFieldType(const LogicalType& type, union_field_idx_t idx) { + KU_ASSERT(type.getLogicalTypeID() == LogicalTypeID::UNION); + return StructType::getFieldType(type, getInternalFieldIdx(idx)); +} + +const LogicalType& UnionType::getFieldType(const LogicalType& type, const std::string& key) { + KU_ASSERT(type.getLogicalTypeID() == LogicalTypeID::UNION); + return StructType::getFieldType(type, key); +} + +uint64_t UnionType::getNumFields(const LogicalType& type) { + KU_ASSERT(type.getLogicalTypeID() == LogicalTypeID::UNION); + return StructType::getNumFields(type) - 1; +} + +bool UnionType::hasField(const LogicalType& type, const std::string& key) { + KU_ASSERT(type.getLogicalTypeID() == LogicalTypeID::UNION); + return StructType::hasField(type, key); +} + +union_field_idx_t UnionType::getFieldIdx(const LogicalType& type, const std::string& key) { + KU_ASSERT(type.getLogicalTypeID() == LogicalTypeID::UNION); + return StructType::getFieldIdx(type, key) - 1; // inverse of getInternalFieldIdx +} + +std::string PhysicalTypeUtils::toString(PhysicalTypeID physicalType) { + // LCOV_EXCL_START + switch (physicalType) { + case PhysicalTypeID::BOOL: + return "BOOL"; + case PhysicalTypeID::INT64: + return "INT64"; + case PhysicalTypeID::INT32: + return "INT32"; + case PhysicalTypeID::INT16: + return "INT16"; + case PhysicalTypeID::INT8: + return "INT8"; + case PhysicalTypeID::UINT64: + return "UINT64"; + case PhysicalTypeID::UINT32: + return "UINT32"; + case PhysicalTypeID::UINT16: + return "UINT16"; + case PhysicalTypeID::UINT8: + return "UINT8"; + case PhysicalTypeID::INT128: + return "INT128"; + case PhysicalTypeID::DOUBLE: + return "DOUBLE"; + 
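+    // These scalar names map 1:1 onto fixed-size C++ representations; the byte widths are
+    // reported by getFixedTypeSize further down (e.g. DOUBLE corresponds to sizeof(double),
+    // 8 bytes on the platforms typically targeted).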
case PhysicalTypeID::FLOAT: + return "FLOAT"; + case PhysicalTypeID::INTERVAL: + return "INTERVAL"; + case PhysicalTypeID::INTERNAL_ID: + return "INTERNAL_ID"; + case PhysicalTypeID::UINT128: + return "UINT128"; + case PhysicalTypeID::STRING: + return "STRING"; + case PhysicalTypeID::STRUCT: + return "STRUCT"; + case PhysicalTypeID::LIST: + return "LIST"; + case PhysicalTypeID::ARRAY: + return "ARRAY"; + case PhysicalTypeID::POINTER: + return "POINTER"; + case PhysicalTypeID::ALP_EXCEPTION_FLOAT: + return "ALP_EXCEPTION_FLOAT"; + case PhysicalTypeID::ALP_EXCEPTION_DOUBLE: + return "ALP_EXCEPTION_DOUBLE"; + default: + KU_UNREACHABLE; + } + // LCOV_EXCL_STOP +} + +uint32_t PhysicalTypeUtils::getFixedTypeSize(PhysicalTypeID physicalType) { + switch (physicalType) { + case PhysicalTypeID::BOOL: + return sizeof(bool); + case PhysicalTypeID::INT64: + return sizeof(int64_t); + case PhysicalTypeID::INT32: + return sizeof(int32_t); + case PhysicalTypeID::INT16: + return sizeof(int16_t); + case PhysicalTypeID::INT8: + return sizeof(int8_t); + case PhysicalTypeID::UINT64: + return sizeof(uint64_t); + case PhysicalTypeID::UINT32: + return sizeof(uint32_t); + case PhysicalTypeID::UINT16: + return sizeof(uint16_t); + case PhysicalTypeID::UINT8: + return sizeof(uint8_t); + case PhysicalTypeID::INT128: + return sizeof(int128_t); + case PhysicalTypeID::DOUBLE: + return sizeof(double); + case PhysicalTypeID::FLOAT: + return sizeof(float); + case PhysicalTypeID::INTERVAL: + return sizeof(interval_t); + case PhysicalTypeID::INTERNAL_ID: + return sizeof(internalID_t); + case PhysicalTypeID::UINT128: + return sizeof(uint128_t); + case PhysicalTypeID::ALP_EXCEPTION_FLOAT: + return storage::EncodeException::sizeInBytes(); + case PhysicalTypeID::ALP_EXCEPTION_DOUBLE: + return storage::EncodeException::sizeInBytes(); + default: + KU_UNREACHABLE; + } +} + +bool DecimalTypeInfo::operator==(const ExtraTypeInfo& other) const { + auto otherDecimalTypeInfo = ku_dynamic_cast(&other); + if (otherDecimalTypeInfo) { + return precision == otherDecimalTypeInfo->precision && scale == otherDecimalTypeInfo->scale; + } + return false; +} + +std::unique_ptr DecimalTypeInfo::copy() const { + return std::make_unique(precision, scale); +} + +std::unique_ptr DecimalTypeInfo::deserialize(Deserializer& deserializer) { + uint32_t precision = 0, scale = 0; + deserializer.deserializeValue(precision); + deserializer.deserializeValue(scale); + return std::make_unique(precision, scale); +} + +void DecimalTypeInfo::serializeInternal(Serializer& serializer) const { + serializer.serializeValue(precision); + serializer.serializeValue(scale); +} + +bool ListTypeInfo::containsAny() const { + return childType.containsAny(); +} + +bool ListTypeInfo::operator==(const ExtraTypeInfo& other) const { + auto otherListTypeInfo = ku_dynamic_cast(&other); + if (otherListTypeInfo) { + return childType == otherListTypeInfo->childType; + } + return false; +} + +std::unique_ptr ListTypeInfo::copy() const { + return std::make_unique(childType.copy()); +} + +std::unique_ptr ListTypeInfo::deserialize(Deserializer& deserializer) { + return std::make_unique(LogicalType::deserialize(deserializer)); +} + +void ListTypeInfo::serializeInternal(Serializer& serializer) const { + childType.serialize(serializer); +} + +bool ArrayTypeInfo::operator==(const ExtraTypeInfo& other) const { + auto otherArrayTypeInfo = ku_dynamic_cast(&other); + if (otherArrayTypeInfo) { + return childType == otherArrayTypeInfo->childType && + numElements == otherArrayTypeInfo->numElements; + } + 
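+    // Reaching this point means the dynamic cast above returned null, i.e. `other` is not an
+    // ArrayTypeInfo, so the two types cannot compare equal.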
return false; +} + +std::unique_ptr ArrayTypeInfo::deserialize(Deserializer& deserializer) { + auto childType = LogicalType::deserialize(deserializer); + uint64_t numElements = 0; + deserializer.deserializeValue(numElements); + return std::make_unique(std::move(childType), numElements); +} + +std::unique_ptr ArrayTypeInfo::copy() const { + return std::make_unique(childType.copy(), numElements); +} + +void ArrayTypeInfo::serializeInternal(Serializer& serializer) const { + ListTypeInfo::serializeInternal(serializer); + serializer.serializeValue(numElements); +} + +bool StructField::containsAny() const { + return type.containsAny(); +} + +bool StructField::operator==(const StructField& other) const { + return type == other.type; +} + +void StructField::serialize(Serializer& serializer) const { + serializer.serializeValue(name); + type.serialize(serializer); +} + +StructField StructField::deserialize(Deserializer& deserializer) { + std::string name; + deserializer.deserializeValue(name); + auto type = LogicalType::deserialize(deserializer); + return StructField(std::move(name), std::move(type)); +} + +StructField StructField::copy() const { + return StructField(name, type.copy()); +} + +StructTypeInfo::StructTypeInfo(std::vector&& fields) : fields{std::move(fields)} { + for (auto i = 0u; i < this->fields.size(); i++) { + auto fieldName = this->fields[i].getName(); + StringUtils::toUpper(fieldName); + fieldNameToIdxMap.emplace(std::move(fieldName), i); + } +} + +StructTypeInfo::StructTypeInfo(const std::vector& fieldNames, + const std::vector& fieldTypes) { + for (auto i = 0u; i < fieldNames.size(); ++i) { + auto fieldName = fieldNames[i]; + auto normalizedFieldName = fieldName; + StringUtils::toUpper(normalizedFieldName); + fieldNameToIdxMap.emplace(normalizedFieldName, i); + fields.emplace_back(fieldName, fieldTypes[i].copy()); + } +} + +bool StructTypeInfo::hasField(const std::string& fieldName) const { + auto copy = fieldName; + StringUtils::toUpper(copy); + return fieldNameToIdxMap.contains(copy); +} + +struct_field_idx_t StructTypeInfo::getStructFieldIdx(std::string fieldName) const { + StringUtils::toUpper(fieldName); + if (fieldNameToIdxMap.contains(fieldName)) { + return fieldNameToIdxMap.at(fieldName); + } + return INVALID_STRUCT_FIELD_IDX; +} + +const StructField& StructTypeInfo::getStructField(struct_field_idx_t idx) const { + return fields[idx]; +} + +const StructField& StructTypeInfo::getStructField(const std::string& fieldName) const { + auto idx = getStructFieldIdx(fieldName); + if (idx == INVALID_STRUCT_FIELD_IDX) { + throw BinderException("Cannot find field " + fieldName + " in STRUCT."); + } + return fields[idx]; +} + +const LogicalType& StructTypeInfo::getChildType(lbug::common::struct_field_idx_t idx) const { + return fields[idx].getType(); +} + +std::vector StructTypeInfo::getChildrenTypes() const { + std::vector childrenTypesToReturn; + for (auto i = 0u; i < fields.size(); i++) { + childrenTypesToReturn.push_back(&fields[i].getType()); + } + return childrenTypesToReturn; +} + +std::vector StructTypeInfo::getChildrenNames() const { + std::vector childrenNames{fields.size()}; + for (auto i = 0u; i < fields.size(); i++) { + childrenNames[i] = fields[i].getName(); + } + return childrenNames; +} + +const std::vector& StructTypeInfo::getStructFields() const { + return fields; +} + +bool StructTypeInfo::containsAny() const { + for (auto& field : fields) { + if (field.containsAny()) { + return true; + } + } + return false; +} + +bool StructTypeInfo::operator==(const 
ExtraTypeInfo& other) const { + auto otherStructTypeInfo = ku_dynamic_cast(&other); + if (otherStructTypeInfo) { + if (fields.size() != otherStructTypeInfo->fields.size()) { + return false; + } + for (auto i = 0u; i < fields.size(); ++i) { + if (fields[i] != otherStructTypeInfo->fields[i]) { + return false; + } + } + return true; + } + return false; +} + +std::unique_ptr StructTypeInfo::deserialize(Deserializer& deserializer) { + std::vector fields; + deserializer.deserializeVector(fields); + return std::make_unique(std::move(fields)); +} + +std::unique_ptr StructTypeInfo::copy() const { + std::vector structFields{fields.size()}; + for (auto i = 0u; i < fields.size(); i++) { + structFields[i] = fields[i].copy(); + } + return std::make_unique(std::move(structFields)); +} + +void StructTypeInfo::serializeInternal(Serializer& serializer) const { + serializer.serializeVector(fields); +} + +static std::string getIncompleteTypeErrMsg(LogicalTypeID id) { + return "Trying to create nested type " + LogicalTypeUtils::toString(id) + + " without child information."; +} + +LogicalType::LogicalType(LogicalTypeID typeID, TypeCategory info) + : typeID{typeID}, extraTypeInfo{nullptr}, category{info} { + // LCOV_EXCL_START + switch (typeID) { + case LogicalTypeID::DECIMAL: + case LogicalTypeID::LIST: + case LogicalTypeID::ARRAY: + case LogicalTypeID::STRUCT: + case LogicalTypeID::MAP: + case LogicalTypeID::UNION: + throw BinderException(getIncompleteTypeErrMsg(typeID)); + default: + break; + } + physicalType = getPhysicalType(typeID); + // LCOV_EXCL_STOP +} + +LogicalType::LogicalType(LogicalTypeID typeID, std::unique_ptr extraTypeInfo) + : typeID{typeID}, extraTypeInfo{std::move(extraTypeInfo)} { + physicalType = getPhysicalType(typeID, this->extraTypeInfo); +} + +LogicalType::LogicalType(const LogicalType& other) { + typeID = other.typeID; + physicalType = other.physicalType; + if (other.extraTypeInfo != nullptr) { + extraTypeInfo = other.extraTypeInfo->copy(); + } + category = other.category; +} + +bool LogicalType::containsAny() const { + if (extraTypeInfo != nullptr) { + return extraTypeInfo->containsAny(); + } + return typeID == LogicalTypeID::ANY; +} + +bool LogicalType::operator==(const LogicalType& other) const { + if (typeID != other.typeID || category != other.category) { + return false; + } + if (extraTypeInfo) { + return *extraTypeInfo == *other.extraTypeInfo; + } + return true; +} + +bool LogicalType::operator!=(const LogicalType& other) const { + return !((*this) == other); +} + +std::string LogicalType::toString() const { + if (!isInternalType()) { + return extraTypeInfo->constPtrCast()->getTypeName(); + } + switch (typeID) { + case LogicalTypeID::MAP: { + auto structType = ku_dynamic_cast(extraTypeInfo.get())->getChildType(); + auto fieldTypes = StructType::getFieldTypes(structType); + return "MAP(" + fieldTypes[0]->toString() + ", " + fieldTypes[1]->toString() + ")"; + } + case LogicalTypeID::LIST: { + auto listTypeInfo = ku_dynamic_cast(extraTypeInfo.get()); + return listTypeInfo->getChildType().toString() + "[]"; + } + case LogicalTypeID::ARRAY: { + auto arrayTypeInfo = ku_dynamic_cast(extraTypeInfo.get()); + return arrayTypeInfo->getChildType().toString() + "[" + + std::to_string(arrayTypeInfo->getNumElements()) + "]"; + } + case LogicalTypeID::UNION: { + auto unionTypeInfo = ku_dynamic_cast(extraTypeInfo.get()); + std::string dataTypeStr = LogicalTypeUtils::toString(typeID) + "("; + auto numFields = unionTypeInfo->getChildrenTypes().size(); + auto fieldNames = 
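// LogicalType::toString is recursive: LIST renders as "child[]", ARRAY as "child[n]",
// MAP as "MAP(key, value)" and STRUCT spells out "name type" pairs. Minimal standalone
// type tree showing the same rendering; the Ty node below is a hypothetical stand-in, not
// the library's representation (uses vector of an incomplete type, so C++17 or later).
#include <cassert>
#include <string>
#include <vector>

struct Ty {
    std::string base;           // used when there are no children, e.g. "INT64"
    std::string kind;           // "", "LIST", "ARRAY", "MAP"
    std::vector<Ty> children;
    size_t numElements = 0;     // ARRAY only
};

static std::string toString(const Ty& t) {
    if (t.kind == "LIST") {
        return toString(t.children[0]) + "[]";
    }
    if (t.kind == "ARRAY") {
        return toString(t.children[0]) + "[" + std::to_string(t.numElements) + "]";
    }
    if (t.kind == "MAP") {
        return "MAP(" + toString(t.children[0]) + ", " + toString(t.children[1]) + ")";
    }
    return t.base;
}

int main() {
    Ty i64{"INT64"};
    Ty str{"STRING"};
    Ty list{"", "LIST", {i64}};
    Ty arr{"", "ARRAY", {str}, 4};
    Ty map{"", "MAP", {str, list}};
    assert(toString(list) == "INT64[]");
    assert(toString(arr) == "STRING[4]");
    assert(toString(map) == "MAP(STRING, INT64[])");
}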
unionTypeInfo->getChildrenNames(); + for (auto i = 1u; i < numFields; i++) { + dataTypeStr += fieldNames[i] + " "; + dataTypeStr += unionTypeInfo->getChildType(i).toString(); + dataTypeStr += (i == numFields - 1 ? ")" : ", "); + } + return dataTypeStr; + } + case LogicalTypeID::STRUCT: { + auto structTypeInfo = ku_dynamic_cast(extraTypeInfo.get()); + std::string dataTypeStr = LogicalTypeUtils::toString(typeID) + "("; + auto numFields = structTypeInfo->getChildrenTypes().size(); + auto fieldNames = structTypeInfo->getChildrenNames(); + for (auto i = 0u; i < numFields; i++) { + dataTypeStr += fieldNames[i] + " "; + dataTypeStr += structTypeInfo->getChildType(i).toString(); + if (i + 1 != numFields) { + dataTypeStr += ", "; + } + } + return dataTypeStr + ")"; + } + case LogicalTypeID::DECIMAL: { + auto decimalTypeInfo = ku_dynamic_cast(extraTypeInfo.get()); + return "DECIMAL(" + std::to_string(decimalTypeInfo->getPrecision()) + ", " + + std::to_string(decimalTypeInfo->getScale()) + ")"; + } + case LogicalTypeID::ANY: + case LogicalTypeID::NODE: + case LogicalTypeID::REL: + case LogicalTypeID::RECURSIVE_REL: + case LogicalTypeID::INTERNAL_ID: + case LogicalTypeID::BOOL: + case LogicalTypeID::INT64: + case LogicalTypeID::INT32: + case LogicalTypeID::INT16: + case LogicalTypeID::INT8: + case LogicalTypeID::UINT64: + case LogicalTypeID::UINT32: + case LogicalTypeID::UINT16: + case LogicalTypeID::UINT8: + case LogicalTypeID::INT128: + case LogicalTypeID::DOUBLE: + case LogicalTypeID::FLOAT: + case LogicalTypeID::DATE: + case LogicalTypeID::TIMESTAMP_NS: + case LogicalTypeID::TIMESTAMP_MS: + case LogicalTypeID::TIMESTAMP_SEC: + case LogicalTypeID::TIMESTAMP_TZ: + case LogicalTypeID::TIMESTAMP: + case LogicalTypeID::INTERVAL: + case LogicalTypeID::UINT128: + case LogicalTypeID::BLOB: + case LogicalTypeID::UUID: + case LogicalTypeID::STRING: + case LogicalTypeID::SERIAL: + return LogicalTypeUtils::toString(typeID); + default: + KU_UNREACHABLE; + } +} + +static bool tryGetIDFromString(const std::string& trimmedStr, LogicalTypeID& id); +static std::vector parseStructFields(const std::string& structTypeStr); +static LogicalType parseListType(const std::string& trimmedStr, + main::ClientContext* context = nullptr); +static LogicalType parseArrayType(const std::string& trimmedStr, + main::ClientContext* context = nullptr); +static std::vector parseStructTypeInfo(const std::string& structTypeStr, + main::ClientContext* context, std::string defType); +static LogicalType parseStructType(const std::string& trimmedStr, + main::ClientContext* context = nullptr); +static LogicalType parseMapType(const std::string& trimmedStr, + main::ClientContext* context = nullptr); +static LogicalType parseUnionType(const std::string& trimmedStr, + main::ClientContext* context = nullptr); +static LogicalType parseDecimalType(const std::string& trimmedStr); + +bool LogicalType::isBuiltInType(const std::string& str) { + auto trimmedStr = StringUtils::ltrim(StringUtils::rtrim(str)); + auto upperDataTypeString = StringUtils::getUpper(trimmedStr); + auto id = LogicalTypeID::ANY; + try { + if (upperDataTypeString.ends_with("[]")) { + parseListType(trimmedStr); + } else if (upperDataTypeString.ends_with("]")) { + parseArrayType(trimmedStr); + } else if (upperDataTypeString.starts_with("STRUCT")) { + parseStructType(trimmedStr); + } else if (upperDataTypeString.starts_with("MAP")) { + parseMapType(trimmedStr); + } else if (upperDataTypeString.starts_with("UNION")) { + parseUnionType(trimmedStr); + } else if 
(upperDataTypeString.starts_with("DECIMAL") || + upperDataTypeString.starts_with("NUMERIC")) { + parseDecimalType(trimmedStr); + } else if (!tryGetIDFromString(upperDataTypeString, id)) { + return false; + } + } catch (...) { + return false; + } + return true; +} + +LogicalType LogicalType::convertFromString(const std::string& str, main::ClientContext* context) { + LogicalType type; + auto trimmedStr = StringUtils::ltrim(StringUtils::rtrim(str)); + auto upperDataTypeString = StringUtils::getUpper(trimmedStr); + if (upperDataTypeString.ends_with("[]")) { + type = parseListType(trimmedStr, context); + } else if (upperDataTypeString.ends_with("]")) { + type = parseArrayType(trimmedStr, context); + } else if (upperDataTypeString.starts_with("STRUCT")) { + type = parseStructType(trimmedStr, context); + } else if (upperDataTypeString.starts_with("MAP")) { + type = parseMapType(trimmedStr, context); + } else if (upperDataTypeString.starts_with("UNION")) { + type = parseUnionType(trimmedStr, context); + } else if (upperDataTypeString.starts_with("DECIMAL") || + upperDataTypeString.starts_with("NUMERIC")) { + type = parseDecimalType(trimmedStr); + } else if (tryGetIDFromString(upperDataTypeString, type.typeID)) { + type.physicalType = LogicalType::getPhysicalType(type.typeID, type.extraTypeInfo); + } else if (context != nullptr) { + auto transaction = transaction::Transaction::Get(*context); + type = catalog::Catalog::Get(*context)->getType(transaction, upperDataTypeString); + } else { + throw common::RuntimeException{"Invalid datatype string: " + str}; + } + return type; +} + +void LogicalType::serialize(Serializer& serializer) const { + serializer.serializeValue(typeID); + serializer.serializeValue(physicalType); + serializer.serializeValue(category); + if (extraTypeInfo != nullptr) { + extraTypeInfo->serialize(serializer); + } +} + +LogicalType LogicalType::deserialize(Deserializer& deserializer) { + auto typeID = LogicalTypeID::ANY; + deserializer.deserializeValue(typeID); + auto physicalType = PhysicalTypeID::ANY; + deserializer.deserializeValue(physicalType); + TypeCategory typeCategory{}; + deserializer.deserializeValue(typeCategory); + std::unique_ptr extraTypeInfo; + if (typeCategory == TypeCategory::UDT) { + extraTypeInfo = UDTTypeInfo::deserialize(deserializer); + } else { + switch (physicalType) { + case PhysicalTypeID::LIST: { + extraTypeInfo = ListTypeInfo::deserialize(deserializer); + } break; + case PhysicalTypeID::ARRAY: { + extraTypeInfo = ArrayTypeInfo::deserialize(deserializer); + } break; + case PhysicalTypeID::STRUCT: { + extraTypeInfo = StructTypeInfo::deserialize(deserializer); + } break; + default: + if (typeID == LogicalTypeID::DECIMAL) { + extraTypeInfo = DecimalTypeInfo::deserialize(deserializer); + } else { + extraTypeInfo = nullptr; + } + } + } + auto result = LogicalType(); + result.typeID = typeID; + result.physicalType = physicalType; + result.extraTypeInfo = std::move(extraTypeInfo); + result.category = typeCategory; + return result; +} + +std::vector LogicalType::copy(const std::vector& types) { + std::vector typesCopy; + for (auto& type : types) { + typesCopy.push_back(type.copy()); + } + return typesCopy; +} + +std::vector LogicalType::copy(const std::vector& types) { + std::vector typesCopy; + typesCopy.reserve(types.size()); + for (auto& type : types) { + typesCopy.push_back(type->copy()); + } + return typesCopy; +} + +PhysicalTypeID LogicalType::getPhysicalType(LogicalTypeID typeID, + const std::unique_ptr& extraTypeInfo) { + switch (typeID) { + case 
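// convertFromString dispatches on the shape of the trimmed, upper-cased string: a "[]"
// suffix means LIST, a bare "]" suffix a fixed-size ARRAY, and STRUCT/MAP/UNION/DECIMAL
// prefixes route to the dedicated parsers; everything else is tried as a plain type name.
// Standalone sketch of that dispatch; classify() and its return labels are illustrative only.
#include <cassert>
#include <string>

// Assumes the input is already trimmed and upper-cased. Note "[]" must be tested before "]".
static std::string classify(const std::string& s) {
    auto endsWith = [&](const std::string& suf) {
        return s.size() >= suf.size() && s.compare(s.size() - suf.size(), suf.size(), suf) == 0;
    };
    auto startsWith = [&](const std::string& pre) { return s.rfind(pre, 0) == 0; };
    if (endsWith("[]")) return "LIST";
    if (endsWith("]")) return "ARRAY";       // e.g. INT64[3]
    if (startsWith("STRUCT")) return "STRUCT";
    if (startsWith("MAP")) return "MAP";
    if (startsWith("UNION")) return "UNION";
    if (startsWith("DECIMAL") || startsWith("NUMERIC")) return "DECIMAL";
    return "SIMPLE";                          // falls back to the plain-name lookup
}

int main() {
    assert(classify("INT64[]") == "LIST");
    assert(classify("INT64[3]") == "ARRAY");
    assert(classify("STRUCT(A INT64, B STRING)") == "STRUCT");
    assert(classify("DECIMAL(10, 2)") == "DECIMAL");
    assert(classify("TIMESTAMP_TZ") == "SIMPLE");
}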
LogicalTypeID::ANY: { + return PhysicalTypeID::ANY; + } + case LogicalTypeID::BOOL: { + return PhysicalTypeID::BOOL; + } + case LogicalTypeID::TIMESTAMP_MS: + case LogicalTypeID::TIMESTAMP_NS: + case LogicalTypeID::TIMESTAMP_TZ: + case LogicalTypeID::TIMESTAMP_SEC: + case LogicalTypeID::TIMESTAMP: + case LogicalTypeID::SERIAL: + case LogicalTypeID::INT64: { + return PhysicalTypeID::INT64; + } + case LogicalTypeID::DATE: + case LogicalTypeID::INT32: { + return PhysicalTypeID::INT32; + } + case LogicalTypeID::INT16: { + return PhysicalTypeID::INT16; + } + case LogicalTypeID::INT8: { + return PhysicalTypeID::INT8; + } + case LogicalTypeID::UINT64: { + return PhysicalTypeID::UINT64; + } + case LogicalTypeID::UINT32: { + return PhysicalTypeID::UINT32; + } + case LogicalTypeID::UINT16: { + return PhysicalTypeID::UINT16; + } + case LogicalTypeID::UINT8: { + return PhysicalTypeID::UINT8; + } + case LogicalTypeID::UUID: + case LogicalTypeID::INT128: { + return PhysicalTypeID::INT128; + } + case LogicalTypeID::DOUBLE: { + return PhysicalTypeID::DOUBLE; + } + case LogicalTypeID::FLOAT: { + return PhysicalTypeID::FLOAT; + } + case LogicalTypeID::DECIMAL: { + if (extraTypeInfo == nullptr) { + throw BinderException(getIncompleteTypeErrMsg(typeID)); + } + auto decimalTypeInfo = extraTypeInfo->constPtrCast(); + auto precision = decimalTypeInfo->getPrecision(); + if (precision <= 4) { + return PhysicalTypeID::INT16; + } else if (precision <= 9) { + return PhysicalTypeID::INT32; + } else if (precision <= 18) { + return PhysicalTypeID::INT64; + } else if (precision <= 38) { + return PhysicalTypeID::INT128; + } else { + throw BinderException("Precision of decimal must be no greater than 38"); + } + } + case LogicalTypeID::INTERVAL: { + return PhysicalTypeID::INTERVAL; + } + case LogicalTypeID::INTERNAL_ID: { + return PhysicalTypeID::INTERNAL_ID; + } + case LogicalTypeID::UINT128: { + return PhysicalTypeID::UINT128; + } + case LogicalTypeID::BLOB: + case LogicalTypeID::STRING: { + return PhysicalTypeID::STRING; + } + case LogicalTypeID::MAP: + case LogicalTypeID::LIST: { + return PhysicalTypeID::LIST; + } + case LogicalTypeID::ARRAY: { + return PhysicalTypeID::ARRAY; + } + case LogicalTypeID::NODE: + case LogicalTypeID::REL: + case LogicalTypeID::RECURSIVE_REL: + case LogicalTypeID::UNION: + case LogicalTypeID::STRUCT: { + return PhysicalTypeID::STRUCT; + } + case LogicalTypeID::POINTER: { + return PhysicalTypeID::POINTER; + } + default: + KU_UNREACHABLE; + } +} + +bool tryGetIDFromString(const std::string& str, LogicalTypeID& id) { + auto upperStr = StringUtils::getUpper(str); + if ("INTERNAL_ID" == upperStr) { + id = LogicalTypeID::INTERNAL_ID; + } else if ("INT64" == upperStr) { + id = LogicalTypeID::INT64; + } else if ("INT32" == upperStr || "INT" == upperStr) { + id = LogicalTypeID::INT32; + } else if ("INT16" == upperStr) { + id = LogicalTypeID::INT16; + } else if ("INT8" == upperStr) { + id = LogicalTypeID::INT8; + } else if ("UINT64" == upperStr) { + id = LogicalTypeID::UINT64; + } else if ("UINT32" == upperStr) { + id = LogicalTypeID::UINT32; + } else if ("UINT16" == upperStr) { + id = LogicalTypeID::UINT16; + } else if ("UINT8" == upperStr) { + id = LogicalTypeID::UINT8; + } else if ("INT128" == upperStr) { + id = LogicalTypeID::INT128; + } else if ("UINT128" == upperStr) { + id = LogicalTypeID::UINT128; + } else if ("DOUBLE" == upperStr || "FLOAT8" == upperStr) { + id = LogicalTypeID::DOUBLE; + } else if ("FLOAT" == upperStr || "FLOAT4" == upperStr || "REAL" == upperStr) { + id = 
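// The physical type of a DECIMAL depends only on its precision: up to 4 digits fits INT16,
// up to 9 INT32, up to 18 INT64 and up to 38 INT128; anything larger is rejected.
// Small standalone sketch of that mapping expressed as storage bytes (not a library API).
#include <cassert>
#include <cstdint>
#include <stdexcept>

static uint32_t decimalStorageBytes(uint32_t precision) {
    if (precision <= 4) return 2;    // INT16 holds up to 9999
    if (precision <= 9) return 4;    // INT32 holds up to 999'999'999
    if (precision <= 18) return 8;   // INT64
    if (precision <= 38) return 16;  // INT128
    throw std::invalid_argument("precision of decimal must be no greater than 38");
}

int main() {
    assert(decimalStorageBytes(4) == 2);
    assert(decimalStorageBytes(10) == 8);
    assert(decimalStorageBytes(38) == 16);
}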
LogicalTypeID::FLOAT; + } else if ("DECIMAL" == upperStr || "NUMERIC" == upperStr) { + id = LogicalTypeID::DECIMAL; + } else if ("BOOLEAN" == upperStr || "BOOL" == upperStr) { + id = LogicalTypeID::BOOL; + } else if ("BYTEA" == upperStr || "BLOB" == upperStr) { + id = LogicalTypeID::BLOB; + } else if ("UUID" == upperStr) { + id = LogicalTypeID::UUID; + } else if ("STRING" == upperStr) { + id = LogicalTypeID::STRING; + } else if ("DATE" == upperStr) { + id = LogicalTypeID::DATE; + } else if ("TIMESTAMP" == upperStr) { + id = LogicalTypeID::TIMESTAMP; + } else if ("TIMESTAMP_NS" == upperStr) { + id = LogicalTypeID::TIMESTAMP_NS; + } else if ("TIMESTAMP_MS" == upperStr) { + id = LogicalTypeID::TIMESTAMP_MS; + } else if ("TIMESTAMP_SEC" == upperStr || "TIMESTAMP_S" == upperStr) { + id = LogicalTypeID::TIMESTAMP_SEC; + } else if ("TIMESTAMP_TZ" == upperStr) { + id = LogicalTypeID::TIMESTAMP_TZ; + } else if ("INTERVAL" == upperStr || "DURATION" == upperStr) { + id = LogicalTypeID::INTERVAL; + } else if ("SERIAL" == upperStr) { + id = LogicalTypeID::SERIAL; + } else { + return false; + } + return true; +} + +std::string LogicalTypeUtils::toString(LogicalTypeID dataTypeID) { + // LCOV_EXCL_START + switch (dataTypeID) { + case LogicalTypeID::ANY: + return "ANY"; + case LogicalTypeID::NODE: + return "NODE"; + case LogicalTypeID::REL: + return "REL"; + case LogicalTypeID::RECURSIVE_REL: + return "RECURSIVE_REL"; + case LogicalTypeID::INTERNAL_ID: + return "INTERNAL_ID"; + case LogicalTypeID::BOOL: + return "BOOL"; + case LogicalTypeID::INT64: + return "INT64"; + case LogicalTypeID::INT32: + return "INT32"; + case LogicalTypeID::INT16: + return "INT16"; + case LogicalTypeID::INT8: + return "INT8"; + case LogicalTypeID::UINT64: + return "UINT64"; + case LogicalTypeID::UINT32: + return "UINT32"; + case LogicalTypeID::UINT16: + return "UINT16"; + case LogicalTypeID::UINT8: + return "UINT8"; + case LogicalTypeID::INT128: + return "INT128"; + case LogicalTypeID::UINT128: + return "UINT128"; + case LogicalTypeID::DOUBLE: + return "DOUBLE"; + case LogicalTypeID::FLOAT: + return "FLOAT"; + case LogicalTypeID::DATE: + return "DATE"; + case LogicalTypeID::TIMESTAMP_NS: + return "TIMESTAMP_NS"; + case LogicalTypeID::TIMESTAMP_MS: + return "TIMESTAMP_MS"; + case LogicalTypeID::TIMESTAMP_SEC: + return "TIMESTAMP_SEC"; + case LogicalTypeID::TIMESTAMP_TZ: + return "TIMESTAMP_TZ"; + case LogicalTypeID::TIMESTAMP: + return "TIMESTAMP"; + case LogicalTypeID::INTERVAL: + return "INTERVAL"; + case LogicalTypeID::DECIMAL: + return "DECIMAL"; + case LogicalTypeID::BLOB: + return "BLOB"; + case LogicalTypeID::UUID: + return "UUID"; + case LogicalTypeID::STRING: + return "STRING"; + case LogicalTypeID::LIST: + return "LIST"; + case LogicalTypeID::ARRAY: + return "ARRAY"; + case LogicalTypeID::STRUCT: + return "STRUCT"; + case LogicalTypeID::SERIAL: + return "SERIAL"; + case LogicalTypeID::MAP: + return "MAP"; + case LogicalTypeID::UNION: + return "UNION"; + case LogicalTypeID::POINTER: + return "POINTER"; + default: + KU_UNREACHABLE; + } + // LCOV_EXCL_STOP +} + +std::string LogicalTypeUtils::toString(const std::vector& dataTypes) { + if (dataTypes.empty()) { + return {""}; + } + std::string result = "(" + dataTypes[0].toString(); + for (auto i = 1u; i < dataTypes.size(); ++i) { + result += "," + dataTypes[i].toString(); + } + result += ")"; + return result; +} + +std::string LogicalTypeUtils::toString(const std::vector& dataTypeIDs) { + if (dataTypeIDs.empty()) { + return {"()"}; + } + std::string result = "(" + 
LogicalTypeUtils::toString(dataTypeIDs[0]); + for (auto i = 1u; i < dataTypeIDs.size(); ++i) { + result += "," + LogicalTypeUtils::toString(dataTypeIDs[i]); + } + result += ")"; + return result; +} + +uint32_t LogicalTypeUtils::getRowLayoutSize(const LogicalType& type) { + switch (type.getPhysicalType()) { + case PhysicalTypeID::STRING: { + return sizeof(ku_string_t); + } + case PhysicalTypeID::ARRAY: + case PhysicalTypeID::LIST: { + return sizeof(ku_list_t); + } + case PhysicalTypeID::STRUCT: { + uint32_t size = 0; + auto fieldsTypes = StructType::getFieldTypes(type); + for (const auto& fieldType : fieldsTypes) { + size += getRowLayoutSize(*fieldType); + } + size += NullBuffer::getNumBytesForNullValues(fieldsTypes.size()); + return size; + } + default: + return PhysicalTypeUtils::getFixedTypeSize(type.getPhysicalType()); + } +} + +bool LogicalTypeUtils::isDate(const LogicalType& dataType) { + return isDate(dataType.typeID); +} + +bool LogicalTypeUtils::isDate(const LogicalTypeID& dataType) { + return dataType == LogicalTypeID::DATE; +} + +bool LogicalTypeUtils::isTimestamp(const LogicalType& dataType) { + return isTimestamp(dataType.typeID); +} + +bool LogicalTypeUtils::isTimestamp(const LogicalTypeID& dataType) { + switch (dataType) { + case LogicalTypeID::TIMESTAMP: + case LogicalTypeID::TIMESTAMP_SEC: + case LogicalTypeID::TIMESTAMP_MS: + case LogicalTypeID::TIMESTAMP_NS: + return true; + default: + return false; + } +} + +bool LogicalTypeUtils::isUnsigned(const LogicalType& dataType) { + return isUnsigned(dataType.typeID); +} + +bool LogicalTypeUtils::isUnsigned(const LogicalTypeID& dataType) { + switch (dataType) { + case LogicalTypeID::UINT64: + case LogicalTypeID::UINT32: + case LogicalTypeID::UINT16: + case LogicalTypeID::UINT8: + case LogicalTypeID::UINT128: + return true; + default: + return false; + } +} + +bool LogicalTypeUtils::isIntegral(const LogicalType& dataType) { + return isIntegral(dataType.typeID); +} + +bool LogicalTypeUtils::isIntegral(const LogicalTypeID& dataType) { + switch (dataType) { + case LogicalTypeID::INT64: + case LogicalTypeID::INT32: + case LogicalTypeID::INT16: + case LogicalTypeID::INT8: + case LogicalTypeID::UINT64: + case LogicalTypeID::UINT32: + case LogicalTypeID::UINT16: + case LogicalTypeID::UINT8: + case LogicalTypeID::INT128: + case LogicalTypeID::UINT128: + case LogicalTypeID::SERIAL: + return true; + default: + return false; + } +} + +bool LogicalTypeUtils::isNumerical(const LogicalType& dataType) { + return isNumerical(dataType.typeID); +} + +bool LogicalTypeUtils::isNumerical(const LogicalTypeID& dataType) { + switch (dataType) { + case LogicalTypeID::INT64: + case LogicalTypeID::INT32: + case LogicalTypeID::INT16: + case LogicalTypeID::INT8: + case LogicalTypeID::UINT64: + case LogicalTypeID::UINT32: + case LogicalTypeID::UINT16: + case LogicalTypeID::UINT8: + case LogicalTypeID::INT128: + case LogicalTypeID::UINT128: + case LogicalTypeID::DOUBLE: + case LogicalTypeID::FLOAT: + case LogicalTypeID::SERIAL: + case LogicalTypeID::DECIMAL: + return true; + default: + return false; + } +} + +bool LogicalTypeUtils::isFloatingPoint(const LogicalTypeID& dataType) { + switch (dataType) { + case LogicalTypeID::DOUBLE: + case LogicalTypeID::FLOAT: + case LogicalTypeID::SERIAL: + case LogicalTypeID::DECIMAL: + return true; + default: + return false; + } +} + +bool LogicalTypeUtils::isNested(const LogicalType& dataType) { + return isNested(dataType.typeID); +} + +bool LogicalTypeUtils::isNested(lbug::common::LogicalTypeID logicalTypeID) { + switch 
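// getRowLayoutSize treats a STRUCT as the sum of its field sizes plus a null bitmap for the
// fields. Standalone sketch assuming one null bit per field rounded up to whole bytes; the
// real NullBuffer::getNumBytesForNullValues may round differently.
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: bytes for one null bit per field, rounded up.
static uint32_t nullBytes(size_t numFields) {
    return static_cast<uint32_t>((numFields + 7) / 8);
}

// Row layout size of a struct, given the row sizes of its fields.
static uint32_t structRowSize(const std::vector<uint32_t>& fieldSizes) {
    uint32_t size = 0;
    for (auto s : fieldSizes) {
        size += s;
    }
    return size + nullBytes(fieldSizes.size());
}

int main() {
    // e.g. STRUCT(a INT64, b INT32, c BOOL): 8 + 4 + 1 payload bytes + 1 null byte.
    assert(structRowSize({8, 4, 1}) == 14);
}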
(logicalTypeID) { + case LogicalTypeID::STRUCT: + case LogicalTypeID::LIST: + case LogicalTypeID::ARRAY: + case LogicalTypeID::UNION: + case LogicalTypeID::MAP: + case LogicalTypeID::NODE: + case LogicalTypeID::REL: + case LogicalTypeID::RECURSIVE_REL: + return true; + default: + return false; + } +} + +std::vector LogicalTypeUtils::getAllValidComparableLogicalTypes() { + return std::vector{LogicalTypeID::BOOL, LogicalTypeID::INT64, + LogicalTypeID::INT32, LogicalTypeID::INT16, LogicalTypeID::INT8, LogicalTypeID::UINT64, + LogicalTypeID::UINT32, LogicalTypeID::UINT16, LogicalTypeID::UINT8, LogicalTypeID::INT128, + LogicalTypeID::UINT128, LogicalTypeID::DOUBLE, LogicalTypeID::FLOAT, LogicalTypeID::DATE, + LogicalTypeID::TIMESTAMP, LogicalTypeID::TIMESTAMP_NS, LogicalTypeID::TIMESTAMP_MS, + LogicalTypeID::TIMESTAMP_SEC, LogicalTypeID::TIMESTAMP_TZ, LogicalTypeID::INTERVAL, + LogicalTypeID::BLOB, LogicalTypeID::UUID, LogicalTypeID::STRING, LogicalTypeID::SERIAL}; +} + +std::vector LogicalTypeUtils::getIntegerTypeIDs() { + return std::vector{LogicalTypeID::INT128, LogicalTypeID::INT64, + LogicalTypeID::INT32, LogicalTypeID::INT16, LogicalTypeID::INT8, LogicalTypeID::SERIAL, + LogicalTypeID::UINT128, LogicalTypeID::UINT64, LogicalTypeID::UINT32, LogicalTypeID::UINT16, + LogicalTypeID::UINT8}; +} + +std::vector LogicalTypeUtils::getFloatingPointTypeIDs() { + return std::vector{LogicalTypeID::DOUBLE, LogicalTypeID::FLOAT}; +} + +std::vector LogicalTypeUtils::getNumericalLogicalTypeIDs() { + auto integerTypes = getIntegerTypeIDs(); + auto floatingPointTypes = getFloatingPointTypeIDs(); + integerTypes.insert(integerTypes.end(), floatingPointTypes.begin(), floatingPointTypes.end()); + // integerTypes.push_back(LogicalTypeID::DECIMAL); // fixed point numeric + return integerTypes; +} + +std::vector LogicalTypeUtils::getAllValidLogicTypeIDs() { + return std::vector{LogicalTypeID::INTERNAL_ID, LogicalTypeID::BOOL, + LogicalTypeID::INT64, LogicalTypeID::INT32, LogicalTypeID::INT16, LogicalTypeID::INT8, + LogicalTypeID::UINT64, LogicalTypeID::UINT32, LogicalTypeID::UINT16, LogicalTypeID::UINT8, + LogicalTypeID::INT128, LogicalTypeID::UINT128, LogicalTypeID::DOUBLE, LogicalTypeID::STRING, + LogicalTypeID::BLOB, LogicalTypeID::UUID, LogicalTypeID::DATE, LogicalTypeID::TIMESTAMP, + LogicalTypeID::TIMESTAMP_NS, LogicalTypeID::TIMESTAMP_MS, LogicalTypeID::TIMESTAMP_SEC, + LogicalTypeID::TIMESTAMP_TZ, LogicalTypeID::INTERVAL, LogicalTypeID::LIST, + LogicalTypeID::ARRAY, LogicalTypeID::MAP, LogicalTypeID::FLOAT, LogicalTypeID::SERIAL, + LogicalTypeID::NODE, LogicalTypeID::REL, LogicalTypeID::RECURSIVE_REL, + LogicalTypeID::STRUCT, LogicalTypeID::UNION}; +} + +std::vector LogicalTypeUtils::getAllValidLogicTypes() { + std::vector typeVec; + typeVec.push_back(LogicalType::INTERNAL_ID()); + typeVec.push_back(LogicalType::BOOL()); + typeVec.push_back(LogicalType::INT32()); + typeVec.push_back(LogicalType::INT64()); + typeVec.push_back(LogicalType::INT16()); + typeVec.push_back(LogicalType::INT8()); + typeVec.push_back(LogicalType::UINT64()); + typeVec.push_back(LogicalType::UINT32()); + typeVec.push_back(LogicalType::UINT16()); + typeVec.push_back(LogicalType::UINT8()); + typeVec.push_back(LogicalType::INT128()); + typeVec.push_back(LogicalType::UINT128()); + typeVec.push_back(LogicalType::DOUBLE()); + typeVec.push_back(LogicalType::STRING()); + typeVec.push_back(LogicalType::BLOB()); + typeVec.push_back(LogicalType::UUID()); + typeVec.push_back(LogicalType::DATE()); + typeVec.push_back(LogicalType::TIMESTAMP()); + 
typeVec.push_back(LogicalType::TIMESTAMP_NS()); + typeVec.push_back(LogicalType::TIMESTAMP_MS()); + typeVec.push_back(LogicalType::TIMESTAMP_SEC()); + typeVec.push_back(LogicalType::TIMESTAMP_TZ()); + typeVec.push_back(LogicalType::INTERVAL()); + typeVec.push_back(LogicalType::LIST(LogicalType::ANY())); + typeVec.push_back(LogicalType::ARRAY(LogicalType::ANY(), 0)); + typeVec.push_back(LogicalType::MAP(LogicalType::ANY(), LogicalType::ANY())); + typeVec.push_back(LogicalType::FLOAT()); + typeVec.push_back(LogicalType::SERIAL()); + typeVec.push_back(LogicalType::NODE({})); + typeVec.push_back(LogicalType::REL({})); + typeVec.push_back(LogicalType::STRUCT({})); + typeVec.push_back(LogicalType::UNION({})); + return typeVec; +} + +std::vector parseStructFields(const std::string& structTypeStr) { + std::vector structFieldsStr; + auto startPos = 0u; + auto curPos = 0u; + auto numOpenBrackets = 0u; + while (curPos < structTypeStr.length()) { + switch (structTypeStr[curPos]) { + case '(': { + numOpenBrackets++; + } break; + case ')': { + numOpenBrackets--; + } break; + case ',': { + if (numOpenBrackets == 0) { + structFieldsStr.push_back( + StringUtils::ltrim(structTypeStr.substr(startPos, curPos - startPos))); + startPos = curPos + 1; + } + } break; + default: { + // Normal character, continue. + } + } + curPos++; + } + structFieldsStr.push_back( + StringUtils::ltrim(structTypeStr.substr(startPos, curPos - startPos))); + return structFieldsStr; +} + +LogicalType parseListType(const std::string& trimmedStr, main::ClientContext* context) { + return LogicalType::LIST( + LogicalType::convertFromString(trimmedStr.substr(0, trimmedStr.size() - 2), context)); +} + +LogicalType parseArrayType(const std::string& trimmedStr, main::ClientContext* context) { + auto leftBracketPos = trimmedStr.find_last_of('['); + auto rightBracketPos = trimmedStr.find_last_of(']'); + auto childType = + LogicalType(LogicalType::convertFromString(trimmedStr.substr(0, leftBracketPos), context)); + auto numElements = std::strtoll( + trimmedStr.substr(leftBracketPos + 1, rightBracketPos - leftBracketPos - 1).c_str(), + nullptr, 0 /* base */); + if (numElements <= 0) { + // Note: the parser already guarantees that the number of elements is a non-negative + // number. However, we still need to check whether the number of elements is 0. + throw BinderException("The number of elements in an array must be greater than 0. Given: " + + std::to_string(numElements) + "."); + } + return LogicalType::ARRAY(std::move(childType), numElements); +} + +std::vector parseStructTypeInfo(const std::string& structTypeStr, + main::ClientContext* context, std::string defType) { + auto leftBracketPos = structTypeStr.find('('); + auto rightBracketPos = structTypeStr.find_last_of(')'); + if (leftBracketPos == std::string::npos || rightBracketPos == std::string::npos) { + throw Exception("Cannot parse struct type: " + structTypeStr); + } + // Remove the leading and trailing brackets. 
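// parseStructFields splits the field list on commas, but only at bracket depth zero, so
// nested STRUCT(...)/MAP(...) arguments are not cut apart. The same technique as a
// standalone helper; splitTopLevel is a hypothetical name.
#include <cassert>
#include <string>
#include <vector>

// Split on ',' only when not inside parentheses.
static std::vector<std::string> splitTopLevel(const std::string& s) {
    std::vector<std::string> parts;
    size_t start = 0;
    int depth = 0;
    for (size_t i = 0; i < s.size(); i++) {
        if (s[i] == '(') depth++;
        else if (s[i] == ')') depth--;
        else if (s[i] == ',' && depth == 0) {
            parts.push_back(s.substr(start, i - start));
            start = i + 1;
        }
    }
    parts.push_back(s.substr(start));
    return parts;
}

int main() {
    auto parts = splitTopLevel("a INT64, b MAP(STRING, INT64), c STRING");
    assert(parts.size() == 3);
    assert(parts[1] == " b MAP(STRING, INT64)"); // the comma inside MAP(...) is not a split point
}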
+ auto structFieldsStr = + structTypeStr.substr(leftBracketPos + 1, rightBracketPos - leftBracketPos - 1); + std::vector structFields; + auto structFieldStrs = parseStructFields(structFieldsStr); + auto numFields = structFieldStrs.size(); + if (numFields > INVALID_STRUCT_FIELD_IDX + 1) { + throw BinderException(stringFormat("Too many fields in {} definition (max {}, got {})", + defType, INVALID_STRUCT_FIELD_IDX + 1, numFields)); + } + std::set fieldNames; + for (auto& structFieldStr : structFieldStrs) { + auto pos = structFieldStr.find(' '); + auto fieldName = structFieldStr.substr(0, pos); + if (!fieldNames.insert(fieldName).second) { + throw BinderException( + stringFormat("Duplicate field '{}' in {} definition", fieldName, defType)); + } + auto fieldTypeString = structFieldStr.substr(pos + 1); + LogicalType fieldType = LogicalType::convertFromString(fieldTypeString, context); + structFields.emplace_back(fieldName, std::move(fieldType)); + } + return structFields; +} + +LogicalType parseStructType(const std::string& trimmedStr, main::ClientContext* context) { + return LogicalType::STRUCT(parseStructTypeInfo(trimmedStr, context, "STRUCT")); +} + +LogicalType parseMapType(const std::string& trimmedStr, main::ClientContext* context) { + auto leftBracketPos = trimmedStr.find('('); + auto rightBracketPos = trimmedStr.find_last_of(')'); + if (leftBracketPos == std::string::npos || rightBracketPos == std::string::npos) { + throw Exception("Cannot parse map type: " + trimmedStr); + } + auto mapTypeStr = trimmedStr.substr(leftBracketPos + 1, rightBracketPos - leftBracketPos - 1); + auto keyValueTypes = StringUtils::splitComma(mapTypeStr); + return LogicalType::MAP(LogicalType::convertFromString(keyValueTypes[0], context), + LogicalType::convertFromString(keyValueTypes[1], context)); +} + +LogicalType parseUnionType(const std::string& trimmedStr, main::ClientContext* context) { + return LogicalType::UNION(parseStructTypeInfo(trimmedStr, context, "UNION")); +} + +LogicalType parseDecimalType(const std::string& trimmedStr) { + auto leftBracketPos = trimmedStr.find_last_of('('); + auto rightBracketPos = trimmedStr.find_last_of(')'); + if (leftBracketPos == std::string::npos) { + return LogicalType::DECIMAL(18, 3); + } + auto paramSubstr = StringUtils::ltrim(StringUtils::rtrim( + trimmedStr.substr(leftBracketPos + 1, rightBracketPos - leftBracketPos - 1))); + auto commaPos = paramSubstr.find_last_of(','); + if (commaPos == std::string::npos) { + throw BinderException("Only found 1 parameter for NUMERIC/DECIMAL type, expected 2"); + } + auto precisionStr = StringUtils::ltrim(StringUtils::rtrim(paramSubstr.substr(0, commaPos))); + auto scaleStr = StringUtils::ltrim(StringUtils::rtrim(paramSubstr.substr(commaPos + 1))); + auto precision = std::strtoll(precisionStr.c_str(), nullptr, 0); + auto scale = std::strtoll(scaleStr.c_str(), nullptr, 0); + if (precision <= 0 || precision > 38) { + throw BinderException( + "Precision of DECIMAL/NUMERIC must be a positive integer no greater than 38"); + } + if (scale < 0 || scale > precision) { + throw BinderException( + "Scale of DECIMAL/NUMERIC must be a nonnegative integer no greater than the precision"); + } + return LogicalType::DECIMAL((uint32_t)precision, (uint32_t)scale); +} + +LogicalType LogicalType::DECIMAL(uint32_t precision, uint32_t scale) { + return LogicalType(LogicalTypeID::DECIMAL, std::make_unique(precision, scale)); +} + +LogicalType LogicalType::STRUCT(std::vector&& fields) { + return LogicalType(LogicalTypeID::STRUCT, 
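// parseDecimalType defaults to DECIMAL(18, 3) when no parameters are given, requires both
// precision and scale when a parenthesis is present, and enforces 1 <= precision <= 38 and
// 0 <= scale <= precision. Compact standalone version of those rules; parseDecimalParams is
// a hypothetical helper with error handling reduced to standard exceptions.
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <utility>

static std::pair<uint32_t, uint32_t> parseDecimalParams(const std::string& s) {
    auto lp = s.find_last_of('(');
    if (lp == std::string::npos) {
        return {18, 3};                              // bare DECIMAL / NUMERIC
    }
    auto rp = s.find_last_of(')');
    auto params = s.substr(lp + 1, rp - lp - 1);
    auto comma = params.find_last_of(',');
    if (comma == std::string::npos) {
        throw std::invalid_argument("expected 2 parameters for DECIMAL/NUMERIC");
    }
    long long precision = std::stoll(params.substr(0, comma));
    long long scale = std::stoll(params.substr(comma + 1));
    if (precision <= 0 || precision > 38) {
        throw std::invalid_argument("precision must be in [1, 38]");
    }
    if (scale < 0 || scale > precision) {
        throw std::invalid_argument("scale must be in [0, precision]");
    }
    return {static_cast<uint32_t>(precision), static_cast<uint32_t>(scale)};
}

int main() {
    auto d = parseDecimalParams("DECIMAL");
    assert(d.first == 18 && d.second == 3);
    auto n = parseDecimalParams("NUMERIC(10, 2)");
    assert(n.first == 10 && n.second == 2);
}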
std::make_unique(std::move(fields))); +} + +LogicalType LogicalType::RECURSIVE_REL(std::vector&& fields) { + return LogicalType(LogicalTypeID::RECURSIVE_REL, + std::make_unique(std::move(fields))); +} + +LogicalType LogicalType::NODE(std::vector&& fields) { + return LogicalType(LogicalTypeID::NODE, std::make_unique(std::move(fields))); +} + +LogicalType LogicalType::REL(std::vector&& fields) { + return LogicalType(LogicalTypeID::REL, std::make_unique(std::move(fields))); +} + +LogicalType LogicalType::UNION(std::vector&& fields) { + // TODO(Ziy): Use UINT8 to represent tag value. + fields.insert(fields.begin(), + StructField(UnionType::TAG_FIELD_NAME, LogicalType(UnionType::TAG_FIELD_TYPE))); + return LogicalType(LogicalTypeID::UNION, std::make_unique(std::move(fields))); +} + +LogicalType LogicalType::LIST(LogicalType childType) { + return LogicalType(LogicalTypeID::LIST, std::make_unique(std::move(childType))); +} + +LogicalType LogicalType::MAP(LogicalType keyType, LogicalType valueType) { + std::vector structFields; + structFields.emplace_back(InternalKeyword::MAP_KEY, std::move(keyType)); + structFields.emplace_back(InternalKeyword::MAP_VALUE, std::move(valueType)); + auto mapStructType = LogicalType::STRUCT(std::move(structFields)); + return LogicalType(LogicalTypeID::MAP, + std::make_unique(std::move(mapStructType))); +} + +LogicalType LogicalType::ARRAY(LogicalType childType, uint64_t numElements) { + return LogicalType(LogicalTypeID::ARRAY, + std::make_unique(std::move(childType), numElements)); +} + +// If we can combine the child types, then we can combine the list +static bool tryCombineListTypes(const LogicalType& left, const LogicalType& right, + LogicalType& result) { + LogicalType childType; + if (!LogicalTypeUtils::tryGetMaxLogicalType(ListType::getChildType(left), + ListType::getChildType(right), childType)) { + return false; + } + result = LogicalType::LIST(std::move(childType)); + return true; +} + +static bool tryCombineArrayTypes(const LogicalType& left, const LogicalType& right, + LogicalType& result) { + if (ArrayType::getNumElements(left) != ArrayType::getNumElements(right)) { + return tryCombineListTypes(left, right, result); + } + LogicalType childType; + if (!LogicalTypeUtils::tryGetMaxLogicalType(ArrayType::getChildType(left), + ArrayType::getChildType(right), childType)) { + return false; + } + result = LogicalType::ARRAY(std::move(childType), ArrayType::getNumElements(left)); + return true; +} + +static bool tryCombineListArrayTypes(const LogicalType& left, const LogicalType& right, + LogicalType& result) { + LogicalType childType; + if (!LogicalTypeUtils::tryGetMaxLogicalType(ListType::getChildType(left), + ArrayType::getChildType(right), childType)) { + return false; + } + result = LogicalType::LIST(std::move(childType)); + return true; +} + +// If we can match child labels and combine their types, then we can combine +// the struct +static bool tryCombineStructTypes(const LogicalType& left, const LogicalType& right, + LogicalType& result) { + const auto& leftFields = StructType::getFields(left); + const auto& rightFields = StructType::getFields(right); + if (leftFields.size() != rightFields.size()) { + return false; + } + std::vector newFields; + for (auto i = 0u; i < leftFields.size(); i++) { + if (leftFields[i].getName() != rightFields[i].getName()) { + return false; + } + LogicalType combinedType; + if (LogicalTypeUtils::tryGetMaxLogicalType(leftFields[i].getType(), + rightFields[i].getType(), combinedType)) { + 
newFields.push_back(StructField(leftFields[i].getName(), std::move(combinedType))); + } else { + return false; + } + } + result = LogicalType::STRUCT(std::move(newFields)); + return true; +} + +// If we can combine the key and value, then we cna combine the map +static bool tryCombineMapTypes(const LogicalType& left, const LogicalType& right, + LogicalType& result) { + const auto& leftKeyType = MapType::getKeyType(left); + const auto& leftValueType = MapType::getValueType(left); + const auto& rightKeyType = MapType::getKeyType(right); + const auto& rightValueType = MapType::getValueType(right); + LogicalType resultKeyType, resultValueType; + if (!LogicalTypeUtils::tryGetMaxLogicalType(leftKeyType, rightKeyType, resultKeyType) || + !LogicalTypeUtils::tryGetMaxLogicalType(leftValueType, rightValueType, resultValueType)) { + return false; + } + result = LogicalType::MAP(std::move(resultKeyType), std::move(resultValueType)); + return true; +} + +/* +// If one of the unions labels is a subset of the other labels, and we can +// combine corresponding labels, then we can combine the union +static bool tryCombineUnionTypes(const LogicalType& left, const LogicalType& right, + LogicalType& result) { + auto leftFields = StructType::getFields(left), rightFields = StructType::getFields(right); + if (leftFields.size() > rightFields.size()) { + std::swap(leftFields, rightFields); + } + std::vector newFields; + for (auto i = 1u, j = 1u; i < leftFields.size(); i++) { + while (j < rightFields.size() && leftFields[i].getName() != rightFields[j].getName()) { + j++; + } + if (j == rightFields.size()) { + return false; + } + LogicalType combinedType; + if (!LogicalTypeUtils::tryGetMaxLogicalType(leftFields[i].getType(), + rightFields[j].getType(), combinedType)) { + newFields.push_back( + StructField(leftFields[i].getName(), LogicalType(combinedType))); + } + } + result = LogicalType::UNION(std::move(newFields)); + return true; +} +*/ + +static LogicalTypeID joinToWiderType(const LogicalTypeID& left, const LogicalTypeID& right) { + KU_ASSERT(LogicalTypeUtils::isIntegral(left)); + KU_ASSERT(LogicalTypeUtils::isIntegral(right)); + if (PhysicalTypeUtils::getFixedTypeSize(LogicalType::getPhysicalType(left)) > + PhysicalTypeUtils::getFixedTypeSize(LogicalType::getPhysicalType(right))) { + return left; + } else { + return right; + } +} + +static bool tryUnsignedToSigned(const LogicalTypeID& input, LogicalTypeID& result) { + switch (input) { + case LogicalTypeID::UINT8: + result = LogicalTypeID::INT16; + break; + case LogicalTypeID::UINT16: + result = LogicalTypeID::INT32; + break; + case LogicalTypeID::UINT32: + result = LogicalTypeID::INT64; + break; + case LogicalTypeID::UINT64: + result = LogicalTypeID::INT128; + break; + default: + return false; + } + return true; +} + +static LogicalTypeID joinDifferentSignIntegrals(const LogicalTypeID& signedType, + const LogicalTypeID& unsignedType) { + auto unsignedToSigned = LogicalTypeID::ANY; + if (!tryUnsignedToSigned(unsignedType, unsignedToSigned)) { + return LogicalTypeID::DOUBLE; + } else { + return joinToWiderType(signedType, unsignedToSigned); + } +} + +static uint32_t internalTimeOrder(const LogicalTypeID& type) { + switch (type) { + case LogicalTypeID::DATE: + return 50; + case LogicalTypeID::TIMESTAMP_SEC: + return 51; + case LogicalTypeID::TIMESTAMP_MS: + return 52; + case LogicalTypeID::TIMESTAMP: + return 53; + case LogicalTypeID::TIMESTAMP_TZ: + return 54; + case LogicalTypeID::TIMESTAMP_NS: + return 55; + default: + return 0; // return 0 if not timestamp + } 
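// When joining integers of different signedness, the unsigned side is promoted to the next
// wider signed type (UINT8->INT16, UINT16->INT32, UINT32->INT64, UINT64->INT128) and the
// physically wider of the two wins; UINT128 has no wider signed type, so that pairing falls
// back to DOUBLE. Standalone sketch with a hypothetical TypeId enum and byte-width table.
#include <cassert>

enum class TypeId { INT16, INT32, INT64, INT128, UINT8, UINT16, UINT32, UINT64, UINT128, DOUBLE };

static int widthBytes(TypeId t) {
    switch (t) {
    case TypeId::UINT8: return 1;
    case TypeId::INT16: case TypeId::UINT16: return 2;
    case TypeId::INT32: case TypeId::UINT32: return 4;
    case TypeId::INT64: case TypeId::UINT64: return 8;
    case TypeId::INT128: case TypeId::UINT128: return 16;
    default: return 8;
    }
}

// Next wider signed type that can hold every value of the unsigned input.
static bool unsignedToSigned(TypeId u, TypeId& out) {
    switch (u) {
    case TypeId::UINT8: out = TypeId::INT16; return true;
    case TypeId::UINT16: out = TypeId::INT32; return true;
    case TypeId::UINT32: out = TypeId::INT64; return true;
    case TypeId::UINT64: out = TypeId::INT128; return true;
    default: return false;                      // UINT128 has no signed superset
    }
}

static TypeId joinMixedSign(TypeId signedTy, TypeId unsignedTy) {
    TypeId promoted{};
    if (!unsignedToSigned(unsignedTy, promoted)) {
        return TypeId::DOUBLE;
    }
    return widthBytes(signedTy) > widthBytes(promoted) ? signedTy : promoted;
}

int main() {
    assert(joinMixedSign(TypeId::INT32, TypeId::UINT32) == TypeId::INT64);
    assert(joinMixedSign(TypeId::INT64, TypeId::UINT8) == TypeId::INT64);
    assert(joinMixedSign(TypeId::INT16, TypeId::UINT128) == TypeId::DOUBLE);
}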
+} + +static int alwaysCastOrder(const LogicalTypeID& typeID) { + switch (typeID) { + case LogicalTypeID::ANY: + return 0; + case LogicalTypeID::STRING: + return 2; + default: + return -1; + } +} + +static bool canAlwaysCast(const LogicalTypeID& typeID) { + switch (typeID) { + case LogicalTypeID::ANY: + case LogicalTypeID::STRING: + return true; + default: + return false; + } +} + +bool LogicalTypeUtils::tryGetMaxLogicalTypeID(const LogicalTypeID& left, const LogicalTypeID& right, + LogicalTypeID& result) { + if (canAlwaysCast(left) && canAlwaysCast(right)) { + if (alwaysCastOrder(left) > alwaysCastOrder(right)) { + result = left; + } else { + result = right; + } + return true; + } + if (left == right || canAlwaysCast(left)) { + result = right; + return true; + } + if (canAlwaysCast(right)) { + result = left; + return true; + } + auto leftToRight = BuiltInFunctionsUtils::getCastCost(left, right); + auto rightToLeft = BuiltInFunctionsUtils::getCastCost(right, left); + if (leftToRight != UNDEFINED_CAST_COST || rightToLeft != UNDEFINED_CAST_COST) { + if (leftToRight < rightToLeft) { + result = right; + } else { + result = left; + } + return true; + } + if (isIntegral(left) && isIntegral(right)) { + if (isUnsigned(left) && !isUnsigned(right)) { + result = joinDifferentSignIntegrals(right, left); + return true; + } else if (isUnsigned(right) && !isUnsigned(left)) { + result = joinDifferentSignIntegrals(left, right); + return true; + } + } + + // check timestamp combination + // note: this will become obsolete if implicit casting + // between timestamps is allowed + auto leftOrder = internalTimeOrder(left); + auto rightOrder = internalTimeOrder(right); + if (leftOrder && rightOrder) { + if (leftOrder > rightOrder) { + result = left; + } else { + result = right; + } + return true; + } + + return false; +} + +static inline bool isSemanticallyNested(LogicalTypeID ID) { + return LogicalTypeUtils::isNested(ID); +} + +static inline bool tryCombineDecimalTypes(const LogicalType& left, const LogicalType& right, + LogicalType& result) { + auto precisionLeft = DecimalType::getPrecision(left); + auto scaleLeft = DecimalType::getScale(left); + auto precisionRight = DecimalType::getPrecision(right); + auto scaleRight = DecimalType::getScale(right); + auto resultingScale = std::max(scaleLeft, scaleRight); + auto resultingPrecision = + std::max(precisionLeft - scaleLeft, precisionRight - scaleRight) + resultingScale; + if (resultingPrecision > DECIMAL_PRECISION_LIMIT) { + result = LogicalType::DOUBLE(); + return true; + } + result = LogicalType::DECIMAL(resultingPrecision, resultingScale); + return true; +} + +static inline bool tryCombineDecimalWithNumeric(const LogicalType& dec, const LogicalType& nonDec, + LogicalType& result) { + auto precision = DecimalType::getPrecision(dec); + auto scale = DecimalType::getScale(dec); + uint32_t requiredDigits = 0; + // How many digits before the decimal point does result require? 
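// Combining two decimals keeps the larger scale and enough integer digits for either side:
// scale = max(s1, s2), precision = max(p1 - s1, p2 - s2) + scale, degrading to DOUBLE past
// the 38-digit limit. Worked standalone sketch; combineDecimal and PRECISION_LIMIT are
// illustrative stand-ins for the logic and DECIMAL_PRECISION_LIMIT above.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>
#include <utility>

constexpr uint32_t PRECISION_LIMIT = 38;

// Returns the combined {precision, scale}, or nullopt when the result should be DOUBLE.
static std::optional<std::pair<uint32_t, uint32_t>> combineDecimal(uint32_t p1, uint32_t s1,
    uint32_t p2, uint32_t s2) {
    uint32_t scale = std::max(s1, s2);
    uint32_t intDigits = std::max(p1 - s1, p2 - s2);
    uint32_t precision = intDigits + scale;
    if (precision > PRECISION_LIMIT) {
        return std::nullopt;
    }
    return std::make_pair(precision, scale);
}

int main() {
    // DECIMAL(10,2) with DECIMAL(6,4): 8 integer digits + 4 fractional -> DECIMAL(12,4).
    auto r = combineDecimal(10, 2, 6, 4);
    assert(r && r->first == 12 && r->second == 4);
    // DECIMAL(38,0) with DECIMAL(5,5): would need 43 digits -> fall back to DOUBLE.
    assert(!combineDecimal(38, 0, 5, 5));
}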
+ switch (nonDec.getLogicalTypeID()) { + case LogicalTypeID::INT8: + requiredDigits = function::NumericLimits::maxNumDigits(); + break; + case LogicalTypeID::UINT8: + requiredDigits = function::NumericLimits::maxNumDigits(); + break; + case LogicalTypeID::INT16: + requiredDigits = function::NumericLimits::maxNumDigits(); + break; + case LogicalTypeID::UINT16: + requiredDigits = function::NumericLimits::maxNumDigits(); + break; + case LogicalTypeID::INT32: + requiredDigits = function::NumericLimits::maxNumDigits(); + break; + case LogicalTypeID::UINT32: + requiredDigits = function::NumericLimits::maxNumDigits(); + break; + case LogicalTypeID::INT64: + requiredDigits = function::NumericLimits::maxNumDigits(); + break; + case LogicalTypeID::UINT64: + requiredDigits = function::NumericLimits::maxNumDigits(); + break; + case LogicalTypeID::INT128: + requiredDigits = function::NumericLimits::maxNumDigits(); + break; + case LogicalTypeID::UINT128: + requiredDigits = function::NumericLimits::maxNumDigits(); + break; + default: + requiredDigits = DECIMAL_PRECISION_LIMIT + 1; + } + if (requiredDigits + scale > DECIMAL_PRECISION_LIMIT) { + result = LogicalType::DOUBLE(); + return true; + } + result = LogicalType::DECIMAL(std::max(requiredDigits + scale, precision), scale); + return true; +} + +bool LogicalTypeUtils::tryGetMaxLogicalType(const LogicalType& left, const LogicalType& right, + LogicalType& result) { + if (canAlwaysCast(left.typeID) && canAlwaysCast(right.typeID)) { + if (alwaysCastOrder(left.typeID) > alwaysCastOrder(right.typeID)) { + result = left.copy(); + } else { + result = right.copy(); + } + return true; + } + if (left == right || canAlwaysCast(left.typeID)) { + result = right.copy(); + return true; + } + if (canAlwaysCast(right.typeID)) { + result = left.copy(); + return true; + } + if (left.typeID == LogicalTypeID::DECIMAL && right.typeID == LogicalTypeID::DECIMAL) { + return tryCombineDecimalTypes(left, right, result); + } + if (left.typeID == LogicalTypeID::DECIMAL && LogicalTypeUtils::isNumerical(right.typeID)) { + return tryCombineDecimalWithNumeric(left, right, result); + } + if (right.typeID == LogicalTypeID::DECIMAL && LogicalTypeUtils::isNumerical(left.typeID)) { + return tryCombineDecimalWithNumeric(right, left, result); + } + if (isSemanticallyNested(left.typeID) || isSemanticallyNested(right.typeID)) { + if (left.typeID == LogicalTypeID::LIST && right.typeID == LogicalTypeID::ARRAY) { + return tryCombineListArrayTypes(left, right, result); + } + if (left.typeID == LogicalTypeID::ARRAY && right.typeID == LogicalTypeID::LIST) { + return tryCombineListArrayTypes(right, left, result); + } + if (left.typeID != right.typeID) { + return false; + } + switch (left.typeID) { + case LogicalTypeID::LIST: + return tryCombineListTypes(left, right, result); + case LogicalTypeID::ARRAY: + return tryCombineArrayTypes(left, right, result); + case LogicalTypeID::STRUCT: + return tryCombineStructTypes(left, right, result); + case LogicalTypeID::MAP: + return tryCombineMapTypes(left, right, result); + // LCOV_EXCL_START + case LogicalTypeID::UNION: + throw ConversionException("Union casting is not supported"); + // return tryCombineUnionTypes(left, right, result); + default: + throw RuntimeException(stringFormat("Casting between {} and {} is not implemented.", + left.toString(), right.toString())); + // LCOV_EXCL_END + } + } + auto resultID = LogicalTypeID::ANY; + if (!tryGetMaxLogicalTypeID(left.typeID, right.typeID, resultID)) { + return false; + } + // attempt to make complete types 
first + if (resultID == left.typeID) { + result = left.copy(); + } else if (resultID == right.typeID) { + result = right.copy(); + } else { + result = LogicalType(resultID); + } + return true; +} + +bool LogicalTypeUtils::tryGetMaxLogicalType(const std::vector& types, + LogicalType& result) { + LogicalType combinedType(LogicalTypeID::ANY); + for (auto& type : types) { + if (!tryGetMaxLogicalType(combinedType, type, combinedType)) { + return false; + } + } + result = combinedType.copy(); + return true; +} + +LogicalType LogicalTypeUtils::combineTypes(const LogicalType& lft, + const LogicalType& rit) { // always succeeds + if (lft.getLogicalTypeID() == LogicalTypeID::STRING || + rit.getLogicalTypeID() == LogicalTypeID::STRING) { + return LogicalType::STRING(); + } + if (isSemanticallyNested(lft.getLogicalTypeID()) && + isSemanticallyNested(rit.getLogicalTypeID())) {} + if (lft.getLogicalTypeID() == rit.getLogicalTypeID() && + lft.getLogicalTypeID() == LogicalTypeID::STRUCT) { + std::vector resultingFields; + for (const auto& i : StructType::getFields(lft)) { + auto name = i.getName(); + if (StructType::hasField(rit, name)) { + resultingFields.emplace_back(name, + combineTypes(i.getType(), StructType::getFieldType(rit, name))); + } else { + resultingFields.push_back(i.copy()); + } + } + for (const auto& i : StructType::getFields(rit)) { + auto name = i.getName(); + if (!StructType::hasField(lft, name)) { + resultingFields.push_back(i.copy()); + } + } + return LogicalType::STRUCT(std::move(resultingFields)); + } + if (lft.getLogicalTypeID() == rit.getLogicalTypeID() && + lft.getLogicalTypeID() == LogicalTypeID::LIST) { + const auto& lftChild = ListType::getChildType(lft); + const auto& ritChild = ListType::getChildType(rit); + return LogicalType::LIST(combineTypes(lftChild, ritChild)); + } + if (lft.getLogicalTypeID() == rit.getLogicalTypeID() && + lft.getLogicalTypeID() == LogicalTypeID::MAP) { + const auto& lftKey = MapType::getKeyType(lft); + const auto& lftValue = MapType::getValueType(lft); + const auto& ritKey = MapType::getKeyType(rit); + const auto& ritValue = MapType::getValueType(rit); + return LogicalType::MAP(combineTypes(lftKey, ritKey), combineTypes(lftValue, ritValue)); + } + common::LogicalType result; + if (!tryGetMaxLogicalType(lft, rit, result)) { + return LogicalType::STRING(); + } + return result; +} + +LogicalType LogicalTypeUtils::combineTypes(const std::vector& types) { + if (types.empty()) { + // LCOV_EXCL_START + throw RuntimeException( + stringFormat("Trying to combine empty types. 
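// combineTypes on two STRUCTs takes the union of the field names, recursively combining the
// fields both sides share and copying the rest. Standalone sketch of the field-union step
// with types reduced to strings; the leaf rule here is a deliberately crude placeholder
// (identical -> keep, otherwise STRING), not the real tryGetMaxLogicalType resolution.
#include <cassert>
#include <map>
#include <string>

using Fields = std::map<std::string, std::string>; // field name -> type name

// Placeholder leaf rule for the sketch only.
static std::string combine(const std::string& a, const std::string& b) {
    return a == b ? a : "STRING";
}

static Fields combineStructs(const Fields& lft, const Fields& rit) {
    Fields result = lft;
    for (const auto& [name, type] : rit) {
        auto it = result.find(name);
        if (it == result.end()) {
            result.emplace(name, type);             // only on the right: copy as-is
        } else {
            it->second = combine(it->second, type); // shared: combine recursively
        }
    }
    return result;
}

int main() {
    Fields a{{"id", "INT64"}, {"name", "STRING"}};
    Fields b{{"id", "INT32"}, {"score", "DOUBLE"}};
    Fields c = combineStructs(a, b);
    assert(c.size() == 3);
    assert(c.at("id") == "STRING"); // INT64 vs INT32 -> STRING under the placeholder rule
    assert(c.at("score") == "DOUBLE");
}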
This should never happen.")); + // LCOV_EXCL_STOP + } + if (types.size() == 1) { + return types[0].copy(); + } + auto result = combineTypes(types[0], types[1]); + for (auto i = 2u; i < types.size(); i++) { + result = combineTypes(result, types[i]); + } + return result; +} + +LogicalType LogicalTypeUtils::purgeAny(const LogicalType& type, const LogicalType& replacement) { + switch (type.getLogicalTypeID()) { + case LogicalTypeID::ANY: + return replacement.copy(); + case LogicalTypeID::LIST: + return LogicalType::LIST(purgeAny(ListType::getChildType(type), replacement)); + case LogicalTypeID::ARRAY: + return LogicalType::ARRAY(purgeAny(ArrayType::getChildType(type), replacement), + ArrayType::getNumElements(type)); + case LogicalTypeID::MAP: + return LogicalType::MAP(purgeAny(MapType::getKeyType(type), replacement), + purgeAny(MapType::getValueType(type), replacement)); + case LogicalTypeID::STRUCT: { + std::vector fields; + for (const auto& i : StructType::getFields(type)) { + fields.emplace_back(i.getName(), purgeAny(i.getType(), replacement)); + } + return LogicalType::STRUCT(std::move(fields)); + } + default: + return type.copy(); + } +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/uint128_t.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/uint128_t.cpp new file mode 100644 index 0000000000..f70b081812 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/uint128_t.cpp @@ -0,0 +1,673 @@ +#include "common/types/uint128_t.h" + +#include + +#include "common/exception/runtime.h" +#include "common/type_utils.h" +#include "common/types/int128_t.h" +#include "function/cast/functions/numeric_limits.h" +#include "function/hash/hash_functions.h" +#include + +namespace lbug::common { + +static uint8_t uint128BitsAmount(uint128_t input) { + if (input.high) { + return 128 - std::countl_zero(input.high); + } else { + return 64 - std::countl_zero(input.low); + } +} + +static bool uint128IsBitSet(uint128_t input, uint8_t bit) { + if (bit < 64) { + return input.low & (1ULL << uint64_t(bit)); + } else { + return input.high & (1ULL << uint64_t(bit - 64)); + } +} + +uint128_t uint128LeftShift(uint128_t lhs, uint32_t amount) { + uint128_t result{}; + result.low = lhs.low << amount; + result.high = (lhs.high << amount) + (lhs.low >> (64 - amount)); + return result; +} + +uint128_t UInt128_t::divModPositive(uint128_t lhs, uint64_t rhs, uint64_t& remainder) { + uint128_t result{0}; + remainder = 0; + + for (uint8_t i = uint128BitsAmount(lhs); i > 0; i--) { + result = uint128LeftShift(result, 1); + remainder <<= 1; + if (uint128IsBitSet(lhs, i - 1)) { + remainder++; + } + if (remainder >= rhs) { + remainder -= rhs; + result.low++; + if (result.low == 0) { + result.high++; + } + } + } + return result; +} + +std::string UInt128_t::toString(uint128_t input) { + std::string result; + uint64_t remainder = 0; + + while (input.high != 0 || input.low != 0) { + input = divModPositive(input, 10, remainder); + result = std::string(1, '0' + remainder) + std::move(result); + } + + if (result.empty()) { + result = "0"; + } + + return result; +} + +bool UInt128_t::addInPlace(uint128_t& lhs, uint128_t rhs) { + int overflow = lhs.low + rhs.low < lhs.low; + if (lhs.high > UINT64_MAX - rhs.high - overflow || + (rhs.high == UINT64_MAX && + lhs.high + overflow != 0)) { // need second condition in case the unsigned (UINT64_MAX - + // rhs.high - overflow) evaluates to -1 + return false; + } + lhs.high = lhs.high + rhs.high + overflow; + lhs.low += 
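// UInt128_t::toString produces decimal digits by repeatedly dividing by 10 and prepending
// the remainder, so it only needs a divmod primitive. The same loop standalone, using the
// GCC/Clang unsigned __int128 extension instead of the library's uint128_t struct.
#include <cassert>
#include <string>

static std::string toDecimalString(unsigned __int128 v) {
    std::string out;
    while (v != 0) {
        unsigned digit = static_cast<unsigned>(v % 10); // divmod by 10 per digit
        out.insert(out.begin(), static_cast<char>('0' + digit));
        v /= 10;
    }
    return out.empty() ? "0" : out;
}

int main() {
    unsigned __int128 v = static_cast<unsigned __int128>(1) << 127; // 2^127
    assert(toDecimalString(0) == "0");
    assert(toDecimalString(v) == "170141183460469231731687303715884105728");
}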
rhs.low; + return true; +} + +bool UInt128_t::subInPlace(uint128_t& lhs, uint128_t rhs) { + // check if lhs > rhs; if so return false + if (UInt128_t::lessThan(lhs, rhs)) { + return false; + } + int underflow = lhs.low - rhs.low > lhs.low; + lhs.high = lhs.high - rhs.high - underflow; + lhs.low -= rhs.low; + return true; +} + +uint128_t UInt128_t::Add(uint128_t lhs, const uint128_t rhs) { + if (!addInPlace(lhs, rhs)) { + throw common::OverflowException("UINT128 is out of range: cannot add."); + } + return lhs; +} + +uint128_t UInt128_t::Sub(uint128_t lhs, const uint128_t rhs) { + if (!subInPlace(lhs, rhs)) { + throw common::OverflowException("UINT128 is out of range: cannot subtract."); + } + return lhs; +} + +bool UInt128_t::tryMultiply(uint128_t lhs, uint128_t rhs, uint128_t& result) { +#if ((__GNUC__ >= 5) || defined(__clang__)) && defined(__SIZEOF_INT128__) + __uint128_t left = __uint128_t(lhs.low) + (__uint128_t(lhs.high) << 64); + __uint128_t right = __uint128_t(rhs.low) + (__uint128_t(rhs.high) << 64); + __uint128_t result_ui128 = 0; + if (__builtin_mul_overflow(left, right, &result_ui128)) { + return false; + } + result.high = uint64_t(result_ui128 >> 64); + result.low = uint64_t(result_ui128 & 0xffffffffffffffff); +#else + // Multiply code adapted from: + // https://github.com/calccrypto/uint128_t/blob/master/uint128_t.cpp + // License: https://github.com/calccrypto/uint128_t/blob/c%2B%2B11_14/LICENSE + uint64_t top[4] = {uint64_t(lhs.high) >> 32, uint64_t(lhs.high) & 0xffffffff, lhs.low >> 32, + lhs.low & 0xffffffff}; + uint64_t bottom[4] = {uint64_t(rhs.high) >> 32, uint64_t(rhs.high) & 0xffffffff, rhs.low >> 32, + rhs.low & 0xffffffff}; + uint64_t products[4][4]; + + // multiply each component of the values + for (auto x = 0; x < 4; x++) { + for (auto y = 0; y < 4; y++) { + products[x][y] = top[x] * bottom[y]; + } + } + + // if any of these products are set to a non-zero value, there is always an overflow + if (products[0][0] || products[0][1] || products[0][2] || products[1][0] || products[2][0] || + products[1][1]) { + return false; + } + // if the high bits of any of these are set, there is always an overflow + if ((products[0][3] & 0xffffffff00000000) || (products[1][2] & 0xffffffff00000000) || + (products[2][1] & 0xffffffff00000000) || (products[3][0] & 0xffffffff00000000)) { + return false; + } + + // otherwise we merge the result of the different products together in-order + + // first row + uint64_t fourth32 = (products[3][3] & 0xffffffff); + uint64_t third32 = (products[3][2] & 0xffffffff) + (products[3][3] >> 32); + uint64_t second32 = (products[3][1] & 0xffffffff) + (products[3][2] >> 32); + uint64_t first32 = (products[3][0] & 0xffffffff) + (products[3][1] >> 32); + + // second row + third32 += (products[2][3] & 0xffffffff); + second32 += (products[2][2] & 0xffffffff) + (products[2][3] >> 32); + first32 += (products[2][1] & 0xffffffff) + (products[2][2] >> 32); + + // third row + second32 += (products[1][3] & 0xffffffff); + first32 += (products[1][2] & 0xffffffff) + (products[1][3] >> 32); + + // fourth row + first32 += (products[0][3] & 0xffffffff); + + // move carry to next digit + third32 += fourth32 >> 32; + second32 += third32 >> 32; + first32 += second32 >> 32; + + // check if the combination of the different products resulted in an overflow + if (first32 & 0xffffff00000000) { + return false; + } + + // remove carry from current digit + fourth32 &= 0xffffffff; + third32 &= 0xffffffff; + second32 &= 0xffffffff; + first32 &= 0xffffffff; + + // combine components 
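// The portable branch of tryMultiply is schoolbook multiplication: split each operand into
// 32-bit limbs, sum the partial products column by column, and treat any contribution that
// would land above bit 127 as overflow. The core of that technique for a single 64x64 -> 128
// product, checked against the compiler's __uint128_t (mul64x64 is a hypothetical helper).
#include <cassert>
#include <cstdint>

// Full 64x64 -> 128-bit product using 32-bit limbs, no 128-bit arithmetic required.
static void mul64x64(uint64_t a, uint64_t b, uint64_t& hi, uint64_t& lo) {
    uint64_t aLo = a & 0xffffffffu, aHi = a >> 32;
    uint64_t bLo = b & 0xffffffffu, bHi = b >> 32;
    uint64_t p0 = aLo * bLo; // contributes to bits 0..63
    uint64_t p1 = aLo * bHi; // contributes to bits 32..95
    uint64_t p2 = aHi * bLo; // contributes to bits 32..95
    uint64_t p3 = aHi * bHi; // contributes to bits 64..127
    // Sum the middle column to extract the carry into the high word.
    uint64_t mid = (p0 >> 32) + (p1 & 0xffffffffu) + (p2 & 0xffffffffu);
    lo = (mid << 32) | (p0 & 0xffffffffu);
    hi = p3 + (p1 >> 32) + (p2 >> 32) + (mid >> 32);
}

int main() {
    uint64_t a = 0xfedcba9876543210ull, b = 0x0f1e2d3c4b5a6978ull;
    uint64_t hi = 0, lo = 0;
    mul64x64(a, b, hi, lo);
    __uint128_t ref = static_cast<__uint128_t>(a) * b; // GCC/Clang builtin, used only to check
    assert(lo == static_cast<uint64_t>(ref));
    assert(hi == static_cast<uint64_t>(ref >> 64));
}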
+ result.low = (third32 << 32) | fourth32; + result.high = (first32 << 32) | second32; +#endif + return true; +} + +uint128_t UInt128_t::Mul(uint128_t lhs, uint128_t rhs) { + uint128_t result{}; + if (!tryMultiply(lhs, rhs, result)) { + throw common::OverflowException("UINT128 is out of range: cannot multiply."); + } + return result; +} + +uint128_t UInt128_t::divMod(uint128_t lhs, uint128_t rhs, uint128_t& remainder) { + // divMod code adapted from: + // https://github.com/calccrypto/uint128_t/blob/master/uint128_t.cpp + // License: https://github.com/calccrypto/uint128_t/blob/c%2B%2B11_14/LICENSE + // initialize the result and remainder to 0 + uint128_t div_result{0}; + remainder.low = 0; + remainder.high = 0; + + // now iterate over the amount of bits that are set in the LHS + for (uint8_t x = uint128BitsAmount(lhs); x > 0; x--) { + // left-shift the current result and remainder by 1 + div_result = uint128LeftShift(div_result, 1); + remainder = uint128LeftShift(remainder, 1); + + // we get the value of the bit at position X, where position 0 is the least-significant bit + if (uint128IsBitSet(lhs, x - 1)) { + // increment the remainder + addInPlace(remainder, 1); + } + if (greaterThanOrEquals(remainder, rhs)) { + // the remainder has passed the division multiplier: add one to the divide result + remainder = Sub(remainder, rhs); + addInPlace(div_result, 1); + } + } + return div_result; +} + +uint128_t UInt128_t::Div(uint128_t lhs, uint128_t rhs) { + if (rhs.high == 0 && rhs.low == 0) { + throw common::RuntimeException("Divide by zero."); + } + uint128_t remainder{}; + return divMod(lhs, rhs, remainder); +} + +uint128_t UInt128_t::Mod(uint128_t lhs, uint128_t rhs) { + if (rhs.high == 0 && rhs.low == 0) { + throw common::RuntimeException("Modulo by zero."); + } + uint128_t result{}; + divMod(lhs, rhs, result); + return result; +} + +uint128_t UInt128_t::Xor(uint128_t lhs, uint128_t rhs) { + uint128_t result{lhs.low ^ rhs.low, lhs.high ^ rhs.high}; + return result; +} + +uint128_t UInt128_t::BinaryAnd(uint128_t lhs, uint128_t rhs) { + uint128_t result{lhs.low & rhs.low, lhs.high & rhs.high}; + return result; +} + +uint128_t UInt128_t::BinaryOr(uint128_t lhs, uint128_t rhs) { + uint128_t result{lhs.low | rhs.low, lhs.high | rhs.high}; + return result; +} + +uint128_t UInt128_t::BinaryNot(uint128_t val) { + return uint128_t{~val.low, ~val.high}; +} + +uint128_t UInt128_t::LeftShift(uint128_t lhs, int amount) { + return amount >= 64 ? + uint128_t(0, lhs.low << (amount - 64)) : + amount == 0 ? + lhs : + uint128_t{lhs.low << amount, (lhs.high << amount) | (lhs.low >> (64 - amount))}; +} + +uint128_t UInt128_t::RightShift(uint128_t lhs, int amount) { + return amount >= 64 ? + uint128_t(lhs.high >> (amount - 64), 0) : + amount == 0 ? 
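// divMod is shift-subtract long division: walk the dividend from its most significant bit,
// shift quotient and remainder left, bring down the next bit, and subtract the divisor when
// the remainder reaches it. Same loop on uint64_t, checked against the native operators;
// divMod64 is a hypothetical helper, and this sketch assumes rhs fits in 63 bits so the
// remainder shift cannot overflow.
#include <cassert>
#include <cstdint>

// Bit-at-a-time long division: returns the quotient, writes the remainder.
static uint64_t divMod64(uint64_t lhs, uint64_t rhs, uint64_t& remainder) {
    uint64_t quotient = 0;
    remainder = 0;
    for (int bit = 63; bit >= 0; bit--) {
        quotient <<= 1;
        remainder = (remainder << 1) | ((lhs >> bit) & 1u); // bring down the next bit
        if (remainder >= rhs) {
            remainder -= rhs;
            quotient |= 1u;
        }
    }
    return quotient;
}

int main() {
    uint64_t r = 0;
    assert(divMod64(1000003, 97, r) == 1000003 / 97 && r == 1000003 % 97);
    assert(divMod64(0xfedcba9876543210ull, 12345, r) == 0xfedcba9876543210ull / 12345);
    assert(r == 0xfedcba9876543210ull % 12345);
}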
+               lhs :
+               uint128_t((lhs.low >> amount) | (lhs.high << (64 - amount)), lhs.high >> amount);
+}
+
+//===============================================================================================
+// Cast operation
+//===============================================================================================
+template<typename DST>
+bool TryCastUint128Template(uint128_t input, DST& result) {
+    if (input.high == 0 && input.low <= uint64_t(function::NumericLimits<DST>::maximum())) {
+        result = static_cast<DST>(input.low);
+        return true;
+    }
+    return false;
+}
+// we can use the above template if we can get max using something like DST.max
+
+template<>
+bool UInt128_t::tryCast(uint128_t input, int8_t& result) {
+    return TryCastUint128Template(input, result);
+}
+
+template<>
+bool UInt128_t::tryCast(uint128_t input, int16_t& result) {
+    return TryCastUint128Template(input, result);
+}
+
+template<>
+bool UInt128_t::tryCast(uint128_t input, int32_t& result) {
+    return TryCastUint128Template(input, result);
+}
+
+template<>
+bool UInt128_t::tryCast(uint128_t input, int64_t& result) {
+    return TryCastUint128Template(input, result);
+}
+
+template<>
+bool UInt128_t::tryCast(uint128_t input, uint8_t& result) {
+    return TryCastUint128Template(input, result);
+}
+
+template<>
+bool UInt128_t::tryCast(uint128_t input, uint16_t& result) {
+    return TryCastUint128Template(input, result);
+}
+
+template<>
+bool UInt128_t::tryCast(uint128_t input, uint32_t& result) {
+    return TryCastUint128Template(input, result);
+}
+
+template<>
+bool UInt128_t::tryCast(uint128_t input, uint64_t& result) {
+    return TryCastUint128Template(input, result);
+}
+
+template<>
+bool UInt128_t::tryCast(uint128_t input, int128_t& result) { // unsigned to signed
+    if (input.high > (uint64_t)(function::NumericLimits<int64_t>::maximum())) {
+        return false;
+    }
+    result = {input.low, int64_t(input.high)};
+    return true;
+}
+
+template<>
+bool UInt128_t::tryCast(uint128_t input, float& result) {
+    double temp_res = NAN;
+    tryCast(input, temp_res);
+    result = static_cast<float>(temp_res);
+    return true;
+}
+
+template<class REAL_T>
+bool CastUint128ToFloating(uint128_t input, REAL_T& result) {
+    result = REAL_T(input.high) * REAL_T(function::NumericLimits<uint64_t>::maximum()) +
+             REAL_T(input.low);
+    return true;
+}
+
+template<>
+bool UInt128_t::tryCast(uint128_t input, double& result) {
+    return CastUint128ToFloating(input, result);
+}
+
+template<>
+bool UInt128_t::tryCast(uint128_t input, long double& result) {
+    return CastUint128ToFloating(input, result);
+}
+
+template<typename SRC>
+uint128_t tryCastToTemplate(SRC value) {
+    if (value < 0) {
+        throw common::OverflowException("Cannot cast negative value to UINT128.");
+    }
+    uint128_t result{};
+    result.low = (uint64_t)value;
+    result.high = 0;
+    return result;
+}
+
+template<>
+bool UInt128_t::tryCastTo(int8_t value, uint128_t& result) {
+    result = tryCastToTemplate(value);
+    return true;
+}
+
+template<>
+bool UInt128_t::tryCastTo(int16_t value, uint128_t& result) {
+    result = tryCastToTemplate(value);
+    return true;
+}
+
+template<>
+bool UInt128_t::tryCastTo(int32_t value, uint128_t& result) {
+    result = tryCastToTemplate(value);
+    return true;
+}
+
+template<>
+bool UInt128_t::tryCastTo(int64_t value, uint128_t& result) {
+    result = tryCastToTemplate(value);
+    return true;
+}
+
+template<>
+bool UInt128_t::tryCastTo(uint8_t value, uint128_t& result) {
+    result = tryCastToTemplate(value);
+    return true;
+}
+
+template<>
+bool UInt128_t::tryCastTo(uint16_t value, uint128_t& result) {
+    result = tryCastToTemplate(value);
+    return true;
+}
+
+template<>
+bool UInt128_t::tryCastTo(uint32_t value, uint128_t& result) {
+    result = tryCastToTemplate(value);
+    return true;
+}
+
+template<>
+bool UInt128_t::tryCastTo(uint64_t value, uint128_t& result) {
+    result = tryCastToTemplate(value);
+    return true;
+}
+
+template<>
+bool UInt128_t::tryCastTo(uint128_t value, uint128_t& result) {
+    result = value;
+    return true;
+}
+
+template<>
+bool UInt128_t::tryCastTo(float value, uint128_t& result) {
+    return tryCastTo(double(value), result);
+}
+
+template<class REAL_T>
+bool castFloatingToUint128(REAL_T value, uint128_t& result) {
+    if (value < 0.0 || value >= 340282366920938463463374607431768211455.0) {
+        return false;
+    }
+    value = std::nearbyint(value);
+    result.low = (uint64_t)fmod(value, REAL_T(function::NumericLimits<uint64_t>::maximum()));
+    result.high = (uint64_t)(value / REAL_T(function::NumericLimits<uint64_t>::maximum()));
+    return true;
+}
+
+template<>
+bool UInt128_t::tryCastTo(double value, uint128_t& result) {
+    return castFloatingToUint128(value, result);
+}
+
+template<>
+bool UInt128_t::tryCastTo(long double value, uint128_t& result) {
+    return castFloatingToUint128(value, result);
+}
+//===============================================================================================
+
+template<typename T>
+void constructUInt128Template(T value, uint128_t& result) {
+    uint128_t casted = UInt128_t::castTo(value);
+    result.low = casted.low;
+    result.high = casted.high;
+}
+
+uint128_t::uint128_t(int64_t value) {
+    auto result = UInt128_t::castTo(value);
+    this->low = result.low;
+    this->high = result.high;
+}
+
+uint128_t::uint128_t(int32_t value) { // NOLINT: fields are constructed by the template
+    constructUInt128Template(value, *this);
+}
+
+uint128_t::uint128_t(int16_t value) { // NOLINT: fields are constructed by the template
+    constructUInt128Template(value, *this);
+}
+
+uint128_t::uint128_t(int8_t value) { // NOLINT: fields are constructed by the template
+    constructUInt128Template(value, *this);
+}
+
+uint128_t::uint128_t(uint64_t value) { // NOLINT: fields are constructed by the template
+    constructUInt128Template(value, *this);
+}
+
+uint128_t::uint128_t(uint32_t value) { // NOLINT: fields are constructed by the template
+    constructUInt128Template(value, *this);
+}
+
+uint128_t::uint128_t(uint16_t value) { // NOLINT: fields are constructed by the template
+    constructUInt128Template(value, *this);
+}
+
+uint128_t::uint128_t(uint8_t value) { // NOLINT: fields are constructed by the template
+    constructUInt128Template(value, *this);
+}
+
+uint128_t::uint128_t(double value) { // NOLINT: fields are constructed by the template
+    constructUInt128Template(value, *this);
+}
+
+uint128_t::uint128_t(float value) { // NOLINT: fields are constructed by the template
+    constructUInt128Template(value, *this);
+}
+
+//============================================================================================
+bool operator==(const uint128_t& lhs, const uint128_t& rhs) {
+    return UInt128_t::equals(lhs, rhs);
+}
+
+bool operator!=(const uint128_t& lhs, const uint128_t& rhs) {
+    return UInt128_t::notEquals(lhs, rhs);
+}
+
+bool operator>(const uint128_t& lhs, const uint128_t& rhs) {
+    return UInt128_t::greaterThan(lhs, rhs);
+}
+
+bool operator>=(const uint128_t& lhs, const uint128_t& rhs) {
+    return UInt128_t::greaterThanOrEquals(lhs, rhs);
+}
+
+bool operator<(const uint128_t& lhs, const uint128_t& rhs) {
+    return UInt128_t::lessThan(lhs, rhs);
+}
+
+bool operator<=(const uint128_t& lhs, const uint128_t& rhs) {
+    return UInt128_t::lessThanOrEquals(lhs, rhs);
+}
+
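+// Usage sketch (illustrative, not part of the library; assumes uint128_t.h
+// declares the converting constructors and operator overloads defined in this
+// file):
+//
+//     uint128_t a(uint64_t(1) << 63); // 2^63
+//     uint128_t b = a + a;            // 2^64: b.low == 0, b.high == 1
+//     bool gt = b > a;                // true
+//
+// operator+ (below) throws common::OverflowException if the 128-bit result
+// would wrap, mirroring UInt128_t::Add.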
+uint128_t uint128_t::operator-() const {
+    return UInt128_t::negate(*this);
+}
+
+// support for operations like (int32_t)x + (uint128_t)y
+
+uint128_t operator+(const uint128_t& lhs, const uint128_t& rhs) {
+    return UInt128_t::Add(lhs, rhs);
+}
+uint128_t operator-(const uint128_t& lhs, const uint128_t& rhs) {
+    return UInt128_t::Sub(lhs, rhs);
+}
+uint128_t operator*(const uint128_t& lhs, const uint128_t& rhs) {
+    return UInt128_t::Mul(lhs, rhs);
+}
+uint128_t operator/(const uint128_t& lhs, const uint128_t& rhs) {
+    return UInt128_t::Div(lhs, rhs);
+}
+uint128_t operator%(const uint128_t& lhs, const uint128_t& rhs) {
+    return UInt128_t::Mod(lhs, rhs);
+}
+
+uint128_t operator^(const uint128_t& lhs, const uint128_t& rhs) {
+    return UInt128_t::Xor(lhs, rhs);
+}
+
+uint128_t operator&(const uint128_t& lhs, const uint128_t& rhs) {
+    return UInt128_t::BinaryAnd(lhs, rhs);
+}
+
+uint128_t operator|(const uint128_t& lhs, const uint128_t& rhs) {
+    return UInt128_t::BinaryOr(lhs, rhs);
+}
+
+uint128_t operator~(const uint128_t& val) {
+    return UInt128_t::BinaryNot(val);
+}
+
+uint128_t operator<<(const uint128_t& lhs, int amount) {
+    return UInt128_t::LeftShift(lhs, amount);
+}
+
+uint128_t operator>>(const uint128_t& lhs, int amount) {
+    return UInt128_t::RightShift(lhs, amount);
+}
+
+// inplace arithmetic operators
+uint128_t& uint128_t::operator+=(const uint128_t& rhs) {
+    if (!UInt128_t::addInPlace(*this, rhs)) {
+        throw common::OverflowException("UINT128 is out of range: cannot add in place.");
+    }
+    return *this;
+}
+
+uint128_t& uint128_t::operator*=(const uint128_t& rhs) {
+    *this = UInt128_t::Mul(*this, rhs);
+    return *this;
+}
+
+uint128_t& uint128_t::operator|=(const uint128_t& rhs) {
+    *this = UInt128_t::BinaryOr(*this, rhs);
+    return *this;
+}
+
+uint128_t& uint128_t::operator&=(const uint128_t& rhs) {
+    *this = UInt128_t::BinaryAnd(*this, rhs);
+    return *this;
+}
+
+template<class T>
+static T NarrowCast(const uint128_t& input) {
+    return static_cast<T>(input.low);
+}
+
+uint128_t::operator int64_t() const {
+    return NarrowCast<int64_t>(*this);
+}
+
+uint128_t::operator int32_t() const {
+    return NarrowCast<int32_t>(*this);
+}
+
+uint128_t::operator int16_t() const {
+    return NarrowCast<int16_t>(*this);
+}
+
+uint128_t::operator int8_t() const {
+    return NarrowCast<int8_t>(*this);
+}
+
+uint128_t::operator uint64_t() const {
+    return NarrowCast<uint64_t>(*this);
+}
+
+uint128_t::operator uint32_t() const {
+    return NarrowCast<uint32_t>(*this);
+}
+
+uint128_t::operator uint16_t() const {
+    return NarrowCast<uint16_t>(*this);
+}
+
+uint128_t::operator uint8_t() const {
+    return NarrowCast<uint8_t>(*this);
+}
+
+uint128_t::operator double() const {
+    double result = NAN;
+    [[maybe_unused]] bool success =
+        UInt128_t::tryCast(*this, result); // casting to double should always succeed
+    KU_ASSERT(success);
+    return result;
+}
+
+uint128_t::operator float() const {
+    float result = NAN;
+    [[maybe_unused]] bool success = UInt128_t::tryCast(*this,
+        result); // casting overly large values to float currently returns inf
+    KU_ASSERT(success);
+    return result;
+}
+
+uint128_t::operator int128_t() const {
+    int128_t result{};
+    if (!UInt128_t::tryCast(*this, result)) {
+        throw common::OverflowException(common::stringFormat("Value {} is not within INT128 range.",
+            common::TypeUtils::toString(*this)));
+    }
+    return result;
+}
+
+} // namespace lbug::common
+
+std::size_t std::hash<lbug::common::uint128_t>::operator()(
+    const lbug::common::uint128_t& v) const noexcept {
+    lbug::common::hash_t hash = 0;
+    lbug::function::Hash::operation(v, hash);
+    return hash;
+}
diff --git
a/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/uuid.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/uuid.cpp new file mode 100644 index 0000000000..2d3cb89833 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/uuid.cpp @@ -0,0 +1,156 @@ +#include "common/types/uuid.h" + +#include "common/exception/conversion.h" +#include "common/random_engine.h" +#include "re2.h" + +namespace lbug { +namespace common { + +void UUID::byteToHex(char byteVal, char* buf, uint64_t& pos) { + buf[pos++] = HEX_DIGITS[(byteVal >> 4) & 0xf]; + buf[pos++] = HEX_DIGITS[byteVal & 0xf]; +} + +unsigned char UUID::hex2Char(char ch) { + if (ch >= '0' && ch <= '9') { + return ch - '0'; + } + if (ch >= 'a' && ch <= 'f') { + return 10 + ch - 'a'; + } + if (ch >= 'A' && ch <= 'F') { + return 10 + ch - 'A'; + } + return 0; +} + +bool UUID::isHex(char ch) { + return (ch >= '0' && ch <= '9') || (ch >= 'a' && ch <= 'f') || (ch >= 'A' && ch <= 'F'); +} + +bool UUID::fromString(std::string str, int128_t& result) { + if (str.empty()) { + return false; + } + uint32_t numBrackets = 0; + if (str.front() == '{') { + numBrackets = 1; + } + // LCOV_EXCL_START + if (numBrackets && str.back() != '}') { + return false; + } + // LCOV_EXCL_STOP + + result.low = 0; + result.high = 0; + uint32_t count = 0; + for (auto i = numBrackets; i < str.size() - numBrackets; ++i) { + if (str[i] == '-') { + continue; + } + if (count >= 32 || !isHex(str[i])) { + return false; + } + if (count >= 16) { + result.low = (result.low << 4) | hex2Char(str[i]); + } else { + result.high = (result.high << 4) | hex2Char(str[i]); + } + count++; + } + // Flip the first bit to make `order by uuid` same as `order by uuid::varchar` + result.high ^= (int64_t(1) << 63); + return count == 32; +} + +int128_t UUID::fromString(std::string str) { + int128_t result = 0; + if (!fromString(str, result)) { + throw ConversionException("Invalid UUID: " + str); + } + return result; +} + +int128_t UUID::fromCString(const char* str, uint64_t len) { + return fromString(std::string(str, len)); +} + +void UUID::toString(int128_t input, char* buf) { + // Flip back before convert to string + int64_t high = input.high ^ (int64_t(1) << 63); + uint64_t pos = 0; + byteToHex(high >> 56 & 0xFF, buf, pos); + byteToHex(high >> 48 & 0xFF, buf, pos); + byteToHex(high >> 40 & 0xFF, buf, pos); + byteToHex(high >> 32 & 0xFF, buf, pos); + buf[pos++] = '-'; + byteToHex(high >> 24 & 0xFF, buf, pos); + byteToHex(high >> 16 & 0xFF, buf, pos); + buf[pos++] = '-'; + byteToHex(high >> 8 & 0xFF, buf, pos); + byteToHex(high & 0xFF, buf, pos); + buf[pos++] = '-'; + byteToHex(input.low >> 56 & 0xFF, buf, pos); + byteToHex(input.low >> 48 & 0xFF, buf, pos); + buf[pos++] = '-'; + byteToHex(input.low >> 40 & 0xFF, buf, pos); + byteToHex(input.low >> 32 & 0xFF, buf, pos); + byteToHex(input.low >> 24 & 0xFF, buf, pos); + byteToHex(input.low >> 16 & 0xFF, buf, pos); + byteToHex(input.low >> 8 & 0xFF, buf, pos); + byteToHex(input.low & 0xFF, buf, pos); +} + +std::string UUID::toString(int128_t input) { + char buff[UUID_STRING_LENGTH]; + toString(input, buff); + return std::string(buff, UUID_STRING_LENGTH); +} + +std::string UUID::toString(ku_uuid_t val) { + return toString(val.value); +} + +ku_uuid_t UUID::generateRandomUUID(RandomEngine* engine) { + uint8_t bytes[16]; + for (int i = 0; i < 16; i += 4) { + *reinterpret_cast(bytes + i) = engine->nextRandomInteger(); + } + // variant must be 10xxxxxx + bytes[8] &= 0xBF; + bytes[8] |= 0x80; + // version must be 0100xxxx + bytes[6] &= 0x4F; + bytes[6] 
|= 0x40; + + int128_t result = 0; + result.high = 0; + result.high |= ((int64_t)bytes[0] << 56); + result.high |= ((int64_t)bytes[1] << 48); + result.high |= ((int64_t)bytes[2] << 40); + result.high |= ((int64_t)bytes[3] << 32); + result.high |= ((int64_t)bytes[4] << 24); + result.high |= ((int64_t)bytes[5] << 16); + result.high |= ((int64_t)bytes[6] << 8); + result.high |= bytes[7]; + result.low = 0; + result.low |= ((uint64_t)bytes[8] << 56); + result.low |= ((uint64_t)bytes[9] << 48); + result.low |= ((uint64_t)bytes[10] << 40); + result.low |= ((uint64_t)bytes[11] << 32); + result.low |= ((uint64_t)bytes[12] << 24); + result.low |= ((uint64_t)bytes[13] << 16); + result.low |= ((uint64_t)bytes[14] << 8); + result.low |= bytes[15]; + return ku_uuid_t{result}; +} + +const regex::RE2& UUID::regexPattern() { + static regex::RE2 retval("(?i)[0-9A-F]{8}-[0-9A-F]{4}-[0-9A-F]{4}-[0-9A-F]{4}-[0-9A-F]{12}"); + return retval; +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/value/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/value/CMakeLists.txt new file mode 100644 index 0000000000..45101be3f2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/value/CMakeLists.txt @@ -0,0 +1,11 @@ +add_library(lbug_common_types_value + OBJECT + nested.cpp + node.cpp + recursive_rel.cpp + rel.cpp + value.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/value/nested.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/value/nested.cpp new file mode 100644 index 0000000000..0361d43e86 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/value/nested.cpp @@ -0,0 +1,21 @@ +#include "common/types/value/nested.h" + +#include "common/exception/runtime.h" +#include "common/types/value/value.h" + +namespace lbug { +namespace common { + +uint32_t NestedVal::getChildrenSize(const Value* val) { + return val->childrenSize; +} + +Value* NestedVal::getChildVal(const Value* val, uint32_t idx) { + if (idx > val->childrenSize) { + throw RuntimeException("NestedVal::getChildVal index out of bound."); + } + return val->children[idx].get(); +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/value/node.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/value/node.cpp new file mode 100644 index 0000000000..74105165bb --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/value/node.cpp @@ -0,0 +1,77 @@ +#include "common/types/value/node.h" + +#include "common/constants.h" +#include "common/string_format.h" +#include "common/types/types.h" +#include "common/types/value/value.h" + +namespace lbug { +namespace common { + +std::vector>> NodeVal::getProperties( + const Value* val) { + throwIfNotNode(val); + std::vector>> properties; + auto fieldNames = StructType::getFieldNames(val->dataType); + for (auto i = 0u; i < val->childrenSize; ++i) { + auto currKey = fieldNames[i]; + if (currKey == InternalKeyword::ID || currKey == InternalKeyword::LABEL) { + continue; + } + properties.emplace_back(currKey, val->children[i]->copy()); + } + return properties; +} + +uint64_t NodeVal::getNumProperties(const Value* val) { + throwIfNotNode(val); + auto fieldNames = StructType::getFieldNames(val->dataType); + return fieldNames.size() - OFFSET; +} + +std::string NodeVal::getPropertyName(const Value* val, uint64_t index) { + throwIfNotNode(val); + auto fieldNames = 
StructType::getFieldNames(val->dataType); + if (index >= fieldNames.size() - OFFSET) { + return ""; + } + return fieldNames[index + OFFSET]; +} + +Value* NodeVal::getPropertyVal(const Value* val, uint64_t index) { + throwIfNotNode(val); + auto fieldNames = StructType::getFieldNames(val->dataType); + if (index >= fieldNames.size() - OFFSET) { + return nullptr; + } + return val->children[index + OFFSET].get(); +} + +Value* NodeVal::getNodeIDVal(const Value* val) { + throwIfNotNode(val); + auto fieldIdx = StructType::getFieldIdx(val->dataType, InternalKeyword::ID); + return val->children[fieldIdx].get(); +} + +Value* NodeVal::getLabelVal(const Value* val) { + throwIfNotNode(val); + auto fieldIdx = StructType::getFieldIdx(val->dataType, InternalKeyword::LABEL); + return val->children[fieldIdx].get(); +} + +std::string NodeVal::toString(const Value* val) { + throwIfNotNode(val); + return val->toString(); +} + +void NodeVal::throwIfNotNode(const Value* val) { + // LCOV_EXCL_START + if (val->dataType.getLogicalTypeID() != LogicalTypeID::NODE) { + throw Exception( + stringFormat("Expected NODE type, but got {} type", val->dataType.toString())); + } + // LCOV_EXCL_STOP +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/value/recursive_rel.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/value/recursive_rel.cpp new file mode 100644 index 0000000000..aaac11567a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/value/recursive_rel.cpp @@ -0,0 +1,31 @@ +#include "common/types/value/recursive_rel.h" + +#include "common/exception/exception.h" +#include "common/string_format.h" +#include "common/types/types.h" +#include "common/types/value/value.h" + +namespace lbug { +namespace common { + +Value* RecursiveRelVal::getNodes(const Value* val) { + throwIfNotRecursiveRel(val); + return val->children[0].get(); +} + +Value* RecursiveRelVal::getRels(const Value* val) { + throwIfNotRecursiveRel(val); + return val->children[1].get(); +} + +void RecursiveRelVal::throwIfNotRecursiveRel(const Value* val) { + // LCOV_EXCL_START + if (val->dataType.getLogicalTypeID() != LogicalTypeID::RECURSIVE_REL) { + throw Exception( + stringFormat("Expected RECURSIVE_REL type, but got {} type", val->dataType.toString())); + } + // LCOV_EXCL_STOP +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/value/rel.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/value/rel.cpp new file mode 100644 index 0000000000..d75c471187 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/value/rel.cpp @@ -0,0 +1,86 @@ +#include "common/types/value/rel.h" + +#include "common/constants.h" +#include "common/string_format.h" +#include "common/types/value/value.h" + +namespace lbug { +namespace common { + +std::vector>> RelVal::getProperties( + const Value* val) { + throwIfNotRel(val); + std::vector>> properties; + auto fieldNames = StructType::getFieldNames(val->dataType); + for (auto i = 0u; i < val->childrenSize; ++i) { + auto currKey = fieldNames[i]; + if (currKey == InternalKeyword::ID || currKey == InternalKeyword::LABEL || + currKey == InternalKeyword::SRC || currKey == InternalKeyword::DST) { + continue; + } + auto currVal = val->children[i]->copy(); + properties.emplace_back(currKey, std::move(currVal)); + } + return properties; +} + +uint64_t RelVal::getNumProperties(const Value* val) { + throwIfNotRel(val); + auto fieldNames = StructType::getFieldNames(val->dataType); + return 
fieldNames.size() - OFFSET; +} + +std::string RelVal::getPropertyName(const Value* val, uint64_t index) { + throwIfNotRel(val); + auto fieldNames = StructType::getFieldNames(val->dataType); + if (index >= fieldNames.size() - OFFSET) { + return ""; + } + return fieldNames[index + OFFSET]; +} + +Value* RelVal::getPropertyVal(const Value* val, uint64_t index) { + throwIfNotRel(val); + auto fieldNames = StructType::getFieldNames(val->dataType); + if (index >= fieldNames.size() - OFFSET) { + return nullptr; + } + return val->children[index + OFFSET].get(); +} + +Value* RelVal::getIDVal(const Value* val) { + auto fieldIdx = StructType::getFieldIdx(val->dataType, InternalKeyword::ID); + return val->children[fieldIdx].get(); +} + +Value* RelVal::getSrcNodeIDVal(const Value* val) { + auto fieldIdx = StructType::getFieldIdx(val->dataType, InternalKeyword::SRC); + return val->children[fieldIdx].get(); +} + +Value* RelVal::getDstNodeIDVal(const Value* val) { + auto fieldIdx = StructType::getFieldIdx(val->dataType, InternalKeyword::DST); + return val->children[fieldIdx].get(); +} + +Value* RelVal::getLabelVal(const Value* val) { + auto fieldIdx = StructType::getFieldIdx(val->dataType, InternalKeyword::LABEL); + return val->children[fieldIdx].get(); +} + +std::string RelVal::toString(const Value* val) { + throwIfNotRel(val); + return val->toString(); +} + +void RelVal::throwIfNotRel(const Value* val) { + // LCOV_EXCL_START + if (val->dataType.getLogicalTypeID() != LogicalTypeID::REL) { + throw Exception( + stringFormat("Expected REL type, but got {} type", val->dataType.toString())); + } + // LCOV_EXCL_STOP +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/value/value.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/value/value.cpp new file mode 100644 index 0000000000..f9f7f9734d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/types/value/value.cpp @@ -0,0 +1,1162 @@ +#include "common/types/value/value.h" + +#include + +#include "common/exception/binder.h" +#include "common/null_buffer.h" +#include "common/serializer/deserializer.h" +#include "common/serializer/serializer.h" +#include "common/type_utils.h" +#include "common/types/blob.h" +#include "common/types/ku_string.h" +#include "common/types/uuid.h" +#include "common/vector/value_vector.h" +#include "function/hash/hash_functions.h" +#include "storage/storage_utils.h" + +namespace lbug { +namespace common { + +bool Value::operator==(const Value& rhs) const { + if (dataType != rhs.dataType || isNull_ != rhs.isNull_) { + return false; + } + switch (dataType.getPhysicalType()) { + case PhysicalTypeID::BOOL: + return val.booleanVal == rhs.val.booleanVal; + case PhysicalTypeID::INT128: + return val.int128Val == rhs.val.int128Val; + case PhysicalTypeID::INT64: + return val.int64Val == rhs.val.int64Val; + case PhysicalTypeID::INT32: + return val.int32Val == rhs.val.int32Val; + case PhysicalTypeID::INT16: + return val.int16Val == rhs.val.int16Val; + case PhysicalTypeID::INT8: + return val.int8Val == rhs.val.int8Val; + case PhysicalTypeID::UINT64: + return val.uint64Val == rhs.val.uint64Val; + case PhysicalTypeID::UINT32: + return val.uint32Val == rhs.val.uint32Val; + case PhysicalTypeID::UINT16: + return val.uint16Val == rhs.val.uint16Val; + case PhysicalTypeID::UINT8: + return val.uint8Val == rhs.val.uint8Val; + case PhysicalTypeID::DOUBLE: + return val.doubleVal == rhs.val.doubleVal; + case PhysicalTypeID::FLOAT: + return val.floatVal == rhs.val.floatVal; + case 
PhysicalTypeID::POINTER: + return val.pointer == rhs.val.pointer; + case PhysicalTypeID::INTERVAL: + return val.intervalVal == rhs.val.intervalVal; + case PhysicalTypeID::INTERNAL_ID: + return val.internalIDVal == rhs.val.internalIDVal; + case PhysicalTypeID::UINT128: + return val.uint128Val == rhs.val.uint128Val; + case PhysicalTypeID::STRING: + return strVal == rhs.strVal; + case PhysicalTypeID::ARRAY: + case PhysicalTypeID::LIST: + case PhysicalTypeID::STRUCT: { + if (childrenSize != rhs.childrenSize) { + return false; + } + for (auto i = 0u; i < childrenSize; ++i) { + if (*children[i] != *rhs.children[i]) { + return false; + } + } + return true; + } + default: + KU_UNREACHABLE; + } +} + +void Value::setDataType(const LogicalType& dataType_) { + KU_ASSERT(allowTypeChange()); + dataType = dataType_.copy(); +} + +const LogicalType& Value::getDataType() const { + return dataType; +} + +void Value::setNull(bool flag) { + isNull_ = flag; +} + +void Value::setNull() { + isNull_ = true; +} + +bool Value::isNull() const { + return isNull_; +} + +std::unique_ptr Value::copy() const { + return std::make_unique(*this); +} + +Value Value::createNullValue() { + return {}; +} + +Value Value::createNullValue(const LogicalType& dataType) { + return Value(dataType); +} + +Value Value::createDefaultValue(const LogicalType& dataType) { + switch (dataType.getLogicalTypeID()) { + case LogicalTypeID::SERIAL: + case LogicalTypeID::INT64: + return Value((int64_t)0); + case LogicalTypeID::INT32: + return Value((int32_t)0); + case LogicalTypeID::INT16: + return Value((int16_t)0); + case LogicalTypeID::INT8: + return Value((int8_t)0); + case LogicalTypeID::UINT64: + return Value((uint64_t)0); + case LogicalTypeID::UINT32: + return Value((uint32_t)0); + case LogicalTypeID::UINT16: + return Value((uint16_t)0); + case LogicalTypeID::UINT8: + return Value((uint8_t)0); + case LogicalTypeID::INT128: + return Value(int128_t(0)); + case LogicalTypeID::BOOL: + return Value(true); + case LogicalTypeID::DOUBLE: + return Value((double)0); + case LogicalTypeID::DATE: + return Value(date_t()); + case LogicalTypeID::TIMESTAMP_NS: + return Value(timestamp_ns_t()); + case LogicalTypeID::TIMESTAMP_MS: + return Value(timestamp_ms_t()); + case LogicalTypeID::TIMESTAMP_SEC: + return Value(timestamp_sec_t()); + case LogicalTypeID::TIMESTAMP_TZ: + return Value(timestamp_tz_t()); + case LogicalTypeID::TIMESTAMP: + return Value(timestamp_t()); + case LogicalTypeID::INTERVAL: + return Value(interval_t()); + case LogicalTypeID::INTERNAL_ID: + return Value(nodeID_t()); + case LogicalTypeID::UINT128: + return Value(uint128_t(0)); + case LogicalTypeID::BLOB: + return Value(LogicalType::BLOB(), std::string("")); + case LogicalTypeID::UUID: + return Value(LogicalType::UUID(), std::string("")); + case LogicalTypeID::STRING: + return Value(LogicalType::STRING(), std::string("")); + case LogicalTypeID::FLOAT: + return Value((float)0); + case LogicalTypeID::DECIMAL: { + Value ret(dataType.copy()); + ret.val.int128Val = 0; + ret.isNull_ = false; + ret.childrenSize = 0; + return ret; + } + case LogicalTypeID::ARRAY: { + std::vector> children; + const auto& childType = ArrayType::getChildType(dataType); + auto arraySize = ArrayType::getNumElements(dataType); + children.reserve(arraySize); + for (auto i = 0u; i < arraySize; ++i) { + children.push_back(std::make_unique(createDefaultValue(childType))); + } + return Value(dataType.copy(), std::move(children)); + } + case LogicalTypeID::MAP: + case LogicalTypeID::LIST: + case LogicalTypeID::UNION: { + // 
We can't create a default value for the union since the + // selected variant is runtime information. Default value + // is initialized when copying (see Value::copyFromUnion). + return Value(dataType.copy(), std::vector>{}); + } + case LogicalTypeID::NODE: + case LogicalTypeID::REL: + case LogicalTypeID::RECURSIVE_REL: + case LogicalTypeID::STRUCT: { + std::vector> children; + for (auto& field : StructType::getFields(dataType)) { + children.push_back(std::make_unique(createDefaultValue(field.getType()))); + } + return Value(dataType.copy(), std::move(children)); + } + case LogicalTypeID::ANY: { + return createNullValue(); + } + default: + KU_UNREACHABLE; + } +} + +Value::Value(bool val_) : isNull_{false}, childrenSize{0} { + dataType = LogicalType::BOOL(); + val.booleanVal = val_; +} + +Value::Value(int8_t val_) : isNull_{false}, childrenSize{0} { + dataType = LogicalType::INT8(); + val.int8Val = val_; +} + +Value::Value(int16_t val_) : isNull_{false}, childrenSize{0} { + dataType = LogicalType::INT16(); + val.int16Val = val_; +} + +Value::Value(int32_t val_) : isNull_{false}, childrenSize{0} { + dataType = LogicalType::INT32(); + val.int32Val = val_; +} + +Value::Value(int64_t val_) : isNull_{false}, childrenSize{0} { + dataType = LogicalType::INT64(); + val.int64Val = val_; +} + +Value::Value(uint8_t val_) : isNull_{false}, childrenSize{0} { + dataType = LogicalType::UINT8(); + val.uint8Val = val_; +} + +Value::Value(uint16_t val_) : isNull_{false}, childrenSize{0} { + dataType = LogicalType::UINT16(); + val.uint16Val = val_; +} + +Value::Value(uint32_t val_) : isNull_{false}, childrenSize{0} { + dataType = LogicalType::UINT32(); + val.uint32Val = val_; +} + +Value::Value(uint64_t val_) : isNull_{false}, childrenSize{0} { + dataType = LogicalType::UINT64(); + val.uint64Val = val_; +} + +Value::Value(int128_t val_) : isNull_{false}, childrenSize{0} { + dataType = LogicalType::INT128(); + val.int128Val = val_; +} + +Value::Value(ku_uuid_t val_) : isNull_{false}, childrenSize{0} { + dataType = LogicalType::UUID(); + val.int128Val = val_.value; +} + +Value::Value(float val_) : isNull_{false}, childrenSize{0} { + dataType = LogicalType::FLOAT(); + val.floatVal = val_; +} + +Value::Value(double val_) : isNull_{false}, childrenSize{0} { + dataType = LogicalType::DOUBLE(); + val.doubleVal = val_; +} + +Value::Value(date_t val_) : isNull_{false}, childrenSize{0} { + dataType = LogicalType::DATE(); + val.int32Val = val_.days; +} + +Value::Value(timestamp_ns_t val_) : isNull_{false}, childrenSize{0} { + dataType = LogicalType::TIMESTAMP_NS(); + val.int64Val = val_.value; +} + +Value::Value(timestamp_ms_t val_) : isNull_{false}, childrenSize{0} { + dataType = LogicalType::TIMESTAMP_MS(); + val.int64Val = val_.value; +} + +Value::Value(timestamp_sec_t val_) : isNull_{false}, childrenSize{0} { + dataType = LogicalType::TIMESTAMP_SEC(); + val.int64Val = val_.value; +} + +Value::Value(timestamp_tz_t val_) : isNull_{false}, childrenSize{0} { + dataType = LogicalType::TIMESTAMP_TZ(); + val.int64Val = val_.value; +} + +Value::Value(timestamp_t val_) : isNull_{false}, childrenSize{0} { + dataType = LogicalType::TIMESTAMP(); + val.int64Val = val_.value; +} + +Value::Value(interval_t val_) : isNull_{false}, childrenSize{0} { + dataType = LogicalType::INTERVAL(); + val.intervalVal = val_; +} + +Value::Value(internalID_t val_) : isNull_{false}, childrenSize{0} { + dataType = LogicalType::INTERNAL_ID(); + val.internalIDVal = val_; +} + +Value::Value(uint128_t val_) : isNull_{false}, childrenSize{0} { + dataType 
= LogicalType::UINT128(); + val.uint128Val = val_; +} + +Value::Value(const char* val_) : isNull_{false}, childrenSize{0} { + dataType = LogicalType::STRING(); + strVal = std::string(val_); +} + +Value::Value(const std::string& val_) : isNull_{false}, childrenSize{0} { + dataType = LogicalType::STRING(); + strVal = val_; +} + +Value::Value(uint8_t* val_) : isNull_{false}, childrenSize{0} { + dataType = LogicalType::POINTER(); + val.pointer = val_; +} + +Value::Value(LogicalType type, std::string val_) + : dataType{std::move(type)}, isNull_{false}, childrenSize{0} { + strVal = std::move(val_); +} + +Value::Value(LogicalType dataType_, std::vector> children) + : dataType{std::move(dataType_)}, isNull_{false} { + this->children = std::move(children); + childrenSize = this->children.size(); +} + +Value::Value(const Value& other) : isNull_{other.isNull_} { + dataType = other.dataType.copy(); + copyValueFrom(other); + childrenSize = other.childrenSize; +} + +void Value::copyFromRowLayout(const uint8_t* value) { + switch (dataType.getLogicalTypeID()) { + case LogicalTypeID::SERIAL: + case LogicalTypeID::TIMESTAMP_NS: + case LogicalTypeID::TIMESTAMP_MS: + case LogicalTypeID::TIMESTAMP_SEC: + case LogicalTypeID::TIMESTAMP_TZ: + case LogicalTypeID::TIMESTAMP: + case LogicalTypeID::INT64: { + val.int64Val = *((int64_t*)value); + } break; + case LogicalTypeID::DATE: + case LogicalTypeID::INT32: { + val.int32Val = *((int32_t*)value); + } break; + case LogicalTypeID::INT16: { + val.int16Val = *((int16_t*)value); + } break; + case LogicalTypeID::INT8: { + val.int8Val = *((int8_t*)value); + } break; + case LogicalTypeID::UINT64: { + val.uint64Val = *((uint64_t*)value); + } break; + case LogicalTypeID::UINT32: { + val.uint32Val = *((uint32_t*)value); + } break; + case LogicalTypeID::UINT16: { + val.uint16Val = *((uint16_t*)value); + } break; + case LogicalTypeID::UINT8: { + val.uint8Val = *((uint8_t*)value); + } break; + case LogicalTypeID::INT128: { + val.int128Val = *((int128_t*)value); + } break; + case LogicalTypeID::BOOL: { + val.booleanVal = *((bool*)value); + } break; + case LogicalTypeID::DOUBLE: { + val.doubleVal = *((double*)value); + } break; + case LogicalTypeID::FLOAT: { + val.floatVal = *((float*)value); + } break; + case LogicalTypeID::DECIMAL: { + switch (dataType.getPhysicalType()) { + case PhysicalTypeID::INT16: + val.int16Val = (*(int16_t*)value); + break; + case PhysicalTypeID::INT32: + val.int32Val = (*(int32_t*)value); + break; + case PhysicalTypeID::INT64: + val.int64Val = (*(int64_t*)value); + break; + case PhysicalTypeID::INT128: + val.int128Val = (*(int128_t*)value); + break; + default: + KU_UNREACHABLE; + } + } break; + case LogicalTypeID::INTERVAL: { + val.intervalVal = *((interval_t*)value); + } break; + case LogicalTypeID::INTERNAL_ID: { + val.internalIDVal = *((nodeID_t*)value); + } break; + case LogicalTypeID::UINT128: { + val.uint128Val = *((uint128_t*)value); + } break; + case LogicalTypeID::BLOB: { + strVal = ((blob_t*)value)->value.getAsString(); + } break; + case LogicalTypeID::UUID: { + val.int128Val = ((ku_uuid_t*)value)->value; + strVal = UUID::toString(*((ku_uuid_t*)value)); + } break; + case LogicalTypeID::STRING: { + strVal = ((ku_string_t*)value)->getAsString(); + } break; + case LogicalTypeID::MAP: + case LogicalTypeID::LIST: { + copyFromRowLayoutList(*(ku_list_t*)value, ListType::getChildType(dataType)); + } break; + case LogicalTypeID::ARRAY: { + copyFromRowLayoutList(*(ku_list_t*)value, ArrayType::getChildType(dataType)); + } break; + case 
LogicalTypeID::UNION: { + copyFromUnion(value); + } break; + case LogicalTypeID::NODE: + case LogicalTypeID::REL: + case LogicalTypeID::RECURSIVE_REL: + case LogicalTypeID::STRUCT: { + copyFromRowLayoutStruct(value); + } break; + case LogicalTypeID::POINTER: { + val.pointer = *((uint8_t**)value); + } break; + default: + KU_UNREACHABLE; + } +} + +void Value::copyFromColLayout(const uint8_t* value, ValueVector* vector) { + switch (dataType.getPhysicalType()) { + case PhysicalTypeID::INT64: { + val.int64Val = *((int64_t*)value); + } break; + case PhysicalTypeID::INT32: { + val.int32Val = *((int32_t*)value); + } break; + case PhysicalTypeID::INT16: { + val.int16Val = *((int16_t*)value); + } break; + case PhysicalTypeID::INT8: { + val.int8Val = *((int8_t*)value); + } break; + case PhysicalTypeID::UINT64: { + val.uint64Val = *((uint64_t*)value); + } break; + case PhysicalTypeID::UINT32: { + val.uint32Val = *((uint32_t*)value); + } break; + case PhysicalTypeID::UINT16: { + val.uint16Val = *((uint16_t*)value); + } break; + case PhysicalTypeID::UINT8: { + val.uint8Val = *((uint8_t*)value); + } break; + case PhysicalTypeID::INT128: { + val.int128Val = *((int128_t*)value); + } break; + case PhysicalTypeID::BOOL: { + val.booleanVal = *((bool*)value); + } break; + case PhysicalTypeID::DOUBLE: { + val.doubleVal = *((double*)value); + } break; + case PhysicalTypeID::FLOAT: { + val.floatVal = *((float*)value); + } break; + case PhysicalTypeID::INTERVAL: { + val.intervalVal = *((interval_t*)value); + } break; + case PhysicalTypeID::STRING: { + strVal = ((ku_string_t*)value)->getAsString(); + } break; + case PhysicalTypeID::ARRAY: + case PhysicalTypeID::LIST: { + copyFromColLayoutList(*(list_entry_t*)value, vector); + } break; + case PhysicalTypeID::STRUCT: { + copyFromColLayoutStruct(*(struct_entry_t*)value, vector); + } break; + case PhysicalTypeID::INTERNAL_ID: { + val.internalIDVal = *((nodeID_t*)value); + } break; + case PhysicalTypeID::UINT128: { + val.uint128Val = *((uint128_t*)value); + } break; + default: + KU_UNREACHABLE; + } +} + +void Value::copyValueFrom(const Value& other) { + if (other.isNull()) { + isNull_ = true; + return; + } + isNull_ = false; + KU_ASSERT(dataType == other.dataType); + switch (dataType.getPhysicalType()) { + case PhysicalTypeID::BOOL: { + val.booleanVal = other.val.booleanVal; + } break; + case PhysicalTypeID::INT64: { + val.int64Val = other.val.int64Val; + } break; + case PhysicalTypeID::INT32: { + val.int32Val = other.val.int32Val; + } break; + case PhysicalTypeID::INT16: { + val.int16Val = other.val.int16Val; + } break; + case PhysicalTypeID::INT8: { + val.int8Val = other.val.int8Val; + } break; + case PhysicalTypeID::UINT64: { + val.uint64Val = other.val.uint64Val; + } break; + case PhysicalTypeID::UINT32: { + val.uint32Val = other.val.uint32Val; + } break; + case PhysicalTypeID::UINT16: { + val.uint16Val = other.val.uint16Val; + } break; + case PhysicalTypeID::UINT8: { + val.uint8Val = other.val.uint8Val; + } break; + case PhysicalTypeID::INT128: { + val.int128Val = other.val.int128Val; + } break; + case PhysicalTypeID::DOUBLE: { + val.doubleVal = other.val.doubleVal; + } break; + case PhysicalTypeID::FLOAT: { + val.floatVal = other.val.floatVal; + } break; + case PhysicalTypeID::INTERVAL: { + val.intervalVal = other.val.intervalVal; + } break; + case PhysicalTypeID::INTERNAL_ID: { + val.internalIDVal = other.val.internalIDVal; + } break; + case PhysicalTypeID::UINT128: { + val.uint128Val = other.val.uint128Val; + } break; + case PhysicalTypeID::STRING: { + strVal = 
other.strVal; + } break; + case PhysicalTypeID::ARRAY: + case PhysicalTypeID::LIST: + case PhysicalTypeID::STRUCT: { + for (auto& child : other.children) { + children.push_back(child->copy()); + } + } break; + case PhysicalTypeID::POINTER: { + val.pointer = other.val.pointer; + } break; + default: + KU_UNREACHABLE; + } +} + +std::string Value::toString() const { + if (isNull_) { + return ""; + } + switch (dataType.getLogicalTypeID()) { + case LogicalTypeID::BOOL: + return TypeUtils::toString(val.booleanVal); + case LogicalTypeID::SERIAL: + case LogicalTypeID::INT64: + return TypeUtils::toString(val.int64Val); + case LogicalTypeID::INT32: + return TypeUtils::toString(val.int32Val); + case LogicalTypeID::INT16: + return TypeUtils::toString(val.int16Val); + case LogicalTypeID::INT8: + return TypeUtils::toString(val.int8Val); + case LogicalTypeID::UINT64: + return TypeUtils::toString(val.uint64Val); + case LogicalTypeID::UINT32: + return TypeUtils::toString(val.uint32Val); + case LogicalTypeID::UINT16: + return TypeUtils::toString(val.uint16Val); + case LogicalTypeID::UINT8: + return TypeUtils::toString(val.uint8Val); + case LogicalTypeID::INT128: + return TypeUtils::toString(val.int128Val); + case LogicalTypeID::DOUBLE: + return TypeUtils::toString(val.doubleVal); + case LogicalTypeID::FLOAT: + return TypeUtils::toString(val.floatVal); + case LogicalTypeID::DECIMAL: + return decimalToString(); + case LogicalTypeID::POINTER: + return TypeUtils::toString((uint64_t)val.pointer); + case LogicalTypeID::DATE: + return TypeUtils::toString(date_t{val.int32Val}); + case LogicalTypeID::TIMESTAMP_NS: + return TypeUtils::toString(timestamp_ns_t{val.int64Val}); + case LogicalTypeID::TIMESTAMP_MS: + return TypeUtils::toString(timestamp_ms_t{val.int64Val}); + case LogicalTypeID::TIMESTAMP_SEC: + return TypeUtils::toString(timestamp_sec_t{val.int64Val}); + case LogicalTypeID::TIMESTAMP_TZ: + return TypeUtils::toString(timestamp_tz_t{val.int64Val}); + case LogicalTypeID::TIMESTAMP: + return TypeUtils::toString(timestamp_t{val.int64Val}); + case LogicalTypeID::INTERVAL: + return TypeUtils::toString(val.intervalVal); + case LogicalTypeID::INTERNAL_ID: + return TypeUtils::toString(val.internalIDVal); + case LogicalTypeID::UINT128: + return TypeUtils::toString(val.uint128Val); + case LogicalTypeID::BLOB: + return Blob::toString(reinterpret_cast(strVal.c_str()), strVal.length()); + case LogicalTypeID::UUID: + return UUID::toString(val.int128Val); + case LogicalTypeID::STRING: + return strVal; + case LogicalTypeID::MAP: { + return mapToString(); + } + case LogicalTypeID::LIST: + case LogicalTypeID::ARRAY: { + return listToString(); + } + case LogicalTypeID::UNION: { + // Only one member in the union can be active at a time and that member is always stored + // at index 0. 
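+        // (Value::copyFromUnion keeps only the active member, storing it as
+        // children[0].)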
+ return children[0]->toString(); + } + case LogicalTypeID::RECURSIVE_REL: + case LogicalTypeID::STRUCT: { + return structToString(); + } + case LogicalTypeID::NODE: { + return nodeToString(); + } + case LogicalTypeID::REL: { + return relToString(); + } + default: + KU_UNREACHABLE; + } +} + +Value::Value() : isNull_{true}, childrenSize{0} { + dataType = LogicalType(LogicalTypeID::ANY); +} + +Value::Value(const LogicalType& dataType_) : isNull_{true}, childrenSize{0} { + dataType = dataType_.copy(); +} + +void Value::resizeChildrenVector(uint64_t size, const LogicalType& childType) { + if (size > children.size()) { + children.reserve(size); + for (auto i = children.size(); i < size; ++i) { + children.push_back(std::make_unique(createDefaultValue(childType))); + } + } + childrenSize = size; +} + +void Value::copyFromRowLayoutList(const ku_list_t& list, const LogicalType& childType) { + resizeChildrenVector(list.size, childType); + auto numBytesPerElement = storage::StorageUtils::getDataTypeSize(childType); + auto listNullBytes = reinterpret_cast(list.overflowPtr); + auto numBytesForNullValues = NullBuffer::getNumBytesForNullValues(list.size); + auto listValues = listNullBytes + numBytesForNullValues; + for (auto i = 0u; i < list.size; i++) { + auto childValue = children[i].get(); + if (NullBuffer::isNull(listNullBytes, i)) { + childValue->setNull(true); + } else { + childValue->setNull(false); + childValue->copyFromRowLayout(listValues); + } + listValues += numBytesPerElement; + } +} + +void Value::copyFromColLayoutList(const list_entry_t& listEntry, ValueVector* vec) { + auto dataVec = ListVector::getDataVector(vec); + resizeChildrenVector(listEntry.size, dataVec->dataType); + for (auto i = 0u; i < listEntry.size; i++) { + auto childValue = children[i].get(); + childValue->setNull(dataVec->isNull(listEntry.offset + i)); + if (!childValue->isNull()) { + childValue->copyFromColLayout(ListVector::getListValuesWithOffset(vec, listEntry, i), + dataVec); + } + } +} + +void Value::copyFromRowLayoutStruct(const uint8_t* kuStruct) { + auto numFields = childrenSize; + auto structNullValues = kuStruct; + auto structValues = structNullValues + NullBuffer::getNumBytesForNullValues(numFields); + for (auto i = 0u; i < numFields; i++) { + auto childValue = children[i].get(); + if (NullBuffer::isNull(structNullValues, i)) { + childValue->setNull(true); + } else { + childValue->setNull(false); + childValue->copyFromRowLayout(structValues); + } + structValues += storage::StorageUtils::getDataTypeSize(childValue->dataType); + } +} + +void Value::copyFromColLayoutStruct(const struct_entry_t& structEntry, ValueVector* vec) { + for (auto i = 0u; i < childrenSize; i++) { + children[i]->setNull(StructVector::getFieldVector(vec, i)->isNull(structEntry.pos)); + if (!children[i]->isNull()) { + auto fieldVector = StructVector::getFieldVector(vec, i); + children[i]->copyFromColLayout(fieldVector->getData() + + fieldVector->getNumBytesPerValue() * structEntry.pos, + fieldVector.get()); + } + } +} + +void Value::copyFromUnion(const uint8_t* kuUnion) { + auto childrenTypes = StructType::getFieldTypes(dataType); + auto unionNullValues = kuUnion; + auto unionValues = unionNullValues + NullBuffer::getNumBytesForNullValues(childrenTypes.size()); + // For union dataType, only one member can be active at a time. So we don't need to copy all + // union fields into value. 
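+    // Assumed row layout: [null bytes for all fields][tag field (field 0)]
+    // [member payloads in field order]. The tag is read below, then the loop
+    // advances unionValues past the fields that precede the active member.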
+ auto activeFieldIdx = UnionType::getInternalFieldIdx(*(union_field_idx_t*)unionValues); + // Create default value now that we know the active field + auto childValue = Value::createDefaultValue(*childrenTypes[activeFieldIdx]); + auto curMemberIdx = 0u; + // Seek to the current active member value. + while (curMemberIdx < activeFieldIdx) { + unionValues += storage::StorageUtils::getDataTypeSize(*childrenTypes[curMemberIdx]); + curMemberIdx++; + } + if (NullBuffer::isNull(unionNullValues, activeFieldIdx)) { + childValue.setNull(true); + } else { + childValue.setNull(false); + childValue.copyFromRowLayout(unionValues); + } + if (children.empty()) { + children.push_back(std::make_unique(std::move(childValue))); + childrenSize = 1; + } else { + children[0] = std::make_unique(std::move(childValue)); + } +} + +void Value::serialize(Serializer& serializer) const { + dataType.serialize(serializer); + serializer.serializeValue(isNull_); + serializer.serializeValue(childrenSize); + + switch (dataType.getPhysicalType()) { + case PhysicalTypeID::BOOL: { + serializer.serializeValue(val.booleanVal); + } break; + case PhysicalTypeID::INT64: { + serializer.serializeValue(val.int64Val); + } break; + case PhysicalTypeID::INT32: { + serializer.serializeValue(val.int32Val); + } break; + case PhysicalTypeID::INT16: { + serializer.serializeValue(val.int16Val); + } break; + case PhysicalTypeID::INT8: { + serializer.serializeValue(val.int8Val); + } break; + case PhysicalTypeID::UINT64: { + serializer.serializeValue(val.uint64Val); + } break; + case PhysicalTypeID::UINT32: { + serializer.serializeValue(val.uint32Val); + } break; + case PhysicalTypeID::UINT16: { + serializer.serializeValue(val.uint16Val); + } break; + case PhysicalTypeID::UINT8: { + serializer.serializeValue(val.uint8Val); + } break; + case PhysicalTypeID::INT128: { + serializer.serializeValue(val.int128Val); + } break; + case PhysicalTypeID::DOUBLE: { + serializer.serializeValue(val.doubleVal); + } break; + case PhysicalTypeID::FLOAT: { + serializer.serializeValue(val.floatVal); + } break; + case PhysicalTypeID::INTERVAL: { + serializer.serializeValue(val.intervalVal); + } break; + case PhysicalTypeID::INTERNAL_ID: { + serializer.serializeValue(val.internalIDVal); + } break; + case PhysicalTypeID::UINT128: { + serializer.serializeValue(val.uint128Val); + } break; + case PhysicalTypeID::STRING: { + serializer.serializeValue(strVal); + } break; + case PhysicalTypeID::ARRAY: + case PhysicalTypeID::LIST: + case PhysicalTypeID::STRUCT: { + for (auto i = 0u; i < childrenSize; ++i) { + children[i]->serialize(serializer); + } + } break; + case PhysicalTypeID::ANY: { + // We want to be able to ser/deser values that are meant to just be null + if (!isNull_) { + KU_UNREACHABLE; + } + } break; + default: { + KU_UNREACHABLE; + } + } +} + +std::unique_ptr Value::deserialize(Deserializer& deserializer) { + LogicalType dataType = LogicalType::deserialize(deserializer); + std::unique_ptr val = std::make_unique(createDefaultValue(dataType)); + deserializer.deserializeValue(val->isNull_); + deserializer.deserializeValue(val->childrenSize); + switch (dataType.getPhysicalType()) { + case PhysicalTypeID::BOOL: { + deserializer.deserializeValue(val->val.booleanVal); + } break; + case PhysicalTypeID::INT64: { + deserializer.deserializeValue(val->val.int64Val); + } break; + case PhysicalTypeID::INT32: { + deserializer.deserializeValue(val->val.int32Val); + } break; + case PhysicalTypeID::INT16: { + deserializer.deserializeValue(val->val.int16Val); + } break; + case 
PhysicalTypeID::INT8: { + deserializer.deserializeValue(val->val.int8Val); + } break; + case PhysicalTypeID::UINT64: { + deserializer.deserializeValue(val->val.uint64Val); + } break; + case PhysicalTypeID::UINT32: { + deserializer.deserializeValue(val->val.uint32Val); + } break; + case PhysicalTypeID::UINT16: { + deserializer.deserializeValue(val->val.uint16Val); + } break; + case PhysicalTypeID::UINT8: { + deserializer.deserializeValue(val->val.uint8Val); + } break; + case PhysicalTypeID::INT128: { + deserializer.deserializeValue(val->val.int128Val); + } break; + case PhysicalTypeID::DOUBLE: { + deserializer.deserializeValue(val->val.doubleVal); + } break; + case PhysicalTypeID::FLOAT: { + deserializer.deserializeValue(val->val.floatVal); + } break; + case PhysicalTypeID::INTERVAL: { + deserializer.deserializeValue(val->val.intervalVal); + } break; + case PhysicalTypeID::INTERNAL_ID: { + deserializer.deserializeValue(val->val.internalIDVal); + } break; + case PhysicalTypeID::UINT128: { + deserializer.deserializeValue(val->val.uint128Val); + } break; + case PhysicalTypeID::STRING: { + deserializer.deserializeValue(val->strVal); + } break; + case PhysicalTypeID::ARRAY: + case PhysicalTypeID::LIST: + case PhysicalTypeID::STRUCT: { + val->children.resize(val->childrenSize); + for (auto i = 0u; i < val->childrenSize; i++) { + val->children[i] = deserialize(deserializer); + } + } break; + case PhysicalTypeID::ANY: { + // We want to be able to ser/deser values that are meant to just be null + if (!val->isNull_) { + KU_UNREACHABLE; + } + } break; + default: { + KU_UNREACHABLE; + } + } + return val; +} + +void Value::validateType(LogicalTypeID targetTypeID) const { + if (dataType.getLogicalTypeID() == targetTypeID) { + return; + } + throw BinderException(stringFormat("{} has data type {} but {} was expected.", toString(), + dataType.toString(), LogicalTypeUtils::toString(targetTypeID))); +} + +bool Value::hasNoneNullChildren() const { + for (auto i = 0u; i < childrenSize; ++i) { + if (!children[i]->isNull()) { + return true; + } + } + return false; +} + +// Handle the case of casting empty list to a different type. 
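+// Returns true when this value, or recursively one of its children, can still be
+// re-typed: it is NULL, has ANY type, or is an empty (or still re-typable)
+// LIST/ARRAY/MAP, for example an empty list literal `[]` that should adopt the
+// type expected by its context.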
+bool Value::allowTypeChange() const { + if (isNull_ || !dataType.isInternalType()) { + return true; + } + switch (dataType.getLogicalTypeID()) { + case LogicalTypeID::ANY: + return true; + case LogicalTypeID::LIST: + case LogicalTypeID::ARRAY: { + if (childrenSize == 0) { + return true; + } + for (auto i = 0u; i < childrenSize; ++i) { + if (children[i]->allowTypeChange()) { + return true; + } + } + return false; + } + case LogicalTypeID::STRUCT: { + for (auto i = 0u; i < childrenSize; ++i) { + if (children[i]->allowTypeChange()) { + return true; + } + } + return false; + } + case LogicalTypeID::MAP: { + if (childrenSize == 0) { + return true; + } + for (auto i = 0u; i < childrenSize; ++i) { + auto k = children[i]->children[0].get(); + auto v = children[i]->children[1].get(); + if (k->allowTypeChange() || v->allowTypeChange()) { + return true; + } + } + return false; + } + default: + return false; + } +} + +uint64_t Value::computeHash() const { + if (isNull_) { + return function::NULL_HASH; + } + hash_t hashValue = 0; + switch (dataType.getPhysicalType()) { + case PhysicalTypeID::BOOL: { + function::Hash::operation(val.booleanVal, hashValue); + } break; + case PhysicalTypeID::INT128: { + function::Hash::operation(val.int128Val, hashValue); + } break; + case PhysicalTypeID::INT64: { + function::Hash::operation(val.int64Val, hashValue); + } break; + case PhysicalTypeID::INT32: { + function::Hash::operation(val.int32Val, hashValue); + } break; + case PhysicalTypeID::INT16: { + function::Hash::operation(val.int16Val, hashValue); + } break; + case PhysicalTypeID::INT8: { + function::Hash::operation(val.int8Val, hashValue); + } break; + case PhysicalTypeID::UINT64: { + function::Hash::operation(val.uint64Val, hashValue); + } break; + case PhysicalTypeID::UINT32: { + function::Hash::operation(val.uint32Val, hashValue); + } break; + case PhysicalTypeID::UINT16: { + function::Hash::operation(val.uint16Val, hashValue); + } break; + case PhysicalTypeID::UINT8: { + function::Hash::operation(val.uint8Val, hashValue); + } break; + case PhysicalTypeID::DOUBLE: { + function::Hash::operation(val.doubleVal, hashValue); + } break; + case PhysicalTypeID::FLOAT: { + function::Hash::operation(val.floatVal, hashValue); + } break; + case PhysicalTypeID::INTERVAL: { + function::Hash::operation(val.intervalVal, hashValue); + } break; + case PhysicalTypeID::INTERNAL_ID: { + function::Hash::operation(val.internalIDVal, hashValue); + } break; + case PhysicalTypeID::UINT128: { + function::Hash::operation(val.uint128Val, hashValue); + } break; + case PhysicalTypeID::STRING: { + function::Hash::operation(strVal, hashValue); + } break; + case PhysicalTypeID::ARRAY: + case PhysicalTypeID::LIST: + case PhysicalTypeID::STRUCT: { + if (childrenSize == 0) { + return function::NULL_HASH; + } + hashValue = children[0]->computeHash(); + for (auto i = 1u; i < childrenSize; i++) { + hashValue = function::combineHashScalar(hashValue, children[i]->computeHash()); + } + } break; + default: { + KU_UNREACHABLE; + } + } + return hashValue; +} + +std::string Value::mapToString() const { + std::string result = "{"; + for (auto i = 0u; i < childrenSize; ++i) { + auto structVal = children[i].get(); + result += structVal->children[0]->toString(); + result += "="; + result += structVal->children[1]->toString(); + result += (i == childrenSize - 1 ? 
"" : ", "); + } + result += "}"; + return result; +} + +std::string Value::listToString() const { + std::string result = "["; + for (auto i = 0u; i < childrenSize; ++i) { + result += children[i]->toString(); + if (i != childrenSize - 1) { + result += ","; + } + } + result += "]"; + return result; +} + +std::string Value::structToString() const { + std::string result = "{"; + auto fieldNames = StructType::getFieldNames(dataType); + for (auto i = 0u; i < childrenSize; ++i) { + result += fieldNames[i] + ": "; + result += children[i]->toString(); + if (i != childrenSize - 1) { + result += ", "; + } + } + result += "}"; + return result; +} + +std::string Value::nodeToString() const { + if (children[0]->isNull_) { + // NODE is represented as STRUCT. We don't have a way to represent STRUCT as null. + // Instead, we check the internal ID entry to decide if a NODE is NULL. + return ""; + } + std::string result = "{"; + auto fieldNames = StructType::getFieldNames(dataType); + for (auto i = 0u; i < childrenSize; ++i) { + if (children[i]->isNull_) { + // Avoid printing null key value pair. + continue; + } + if (i != 0) { + result += ", "; + } + result += fieldNames[i] + ": " + children[i]->toString(); + } + result += "}"; + return result; +} + +std::string Value::relToString() const { + if (children[3]->isNull_) { + // REL is represented as STRUCT. We don't have a way to represent STRUCT as null. + // Instead, we check the internal ID entry to decide if a REL is NULL. + return ""; + } + std::string result = "(" + children[0]->toString() + ")-{"; + auto fieldNames = StructType::getFieldNames(dataType); + for (auto i = 2u; i < childrenSize; ++i) { + if (children[i]->isNull_) { + // Avoid printing null key value pair. + continue; + } + if (i != 2) { + result += ", "; + } + result += fieldNames[i] + ": " + children[i]->toString(); + } + result += "}->(" + children[1]->toString() + ")"; + return result; +} + +std::string Value::decimalToString() const { + switch (dataType.getPhysicalType()) { + case PhysicalTypeID::INT16: + return DecimalType::insertDecimalPoint(TypeUtils::toString(val.int16Val), + DecimalType::getScale(dataType)); + case PhysicalTypeID::INT32: + return DecimalType::insertDecimalPoint(TypeUtils::toString(val.int32Val), + DecimalType::getScale(dataType)); + case PhysicalTypeID::INT64: + return DecimalType::insertDecimalPoint(TypeUtils::toString(val.int64Val), + DecimalType::getScale(dataType)); + case PhysicalTypeID::INT128: + return DecimalType::insertDecimalPoint(TypeUtils::toString(val.int128Val), + DecimalType::getScale(dataType)); + default: + KU_UNREACHABLE; + } +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/utils.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/utils.cpp new file mode 100644 index 0000000000..2475408a65 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/utils.cpp @@ -0,0 +1,29 @@ +#include "common/utils.h" + +namespace lbug { +namespace common { + +uint64_t nextPowerOfTwo(uint64_t v) { + v--; + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; + v |= v >> 32; + v++; + return v; +} + +uint64_t prevPowerOfTwo(uint64_t v) { + return nextPowerOfTwo((v / 2) + 1); +} + +bool isLittleEndian() { + // Little endian arch stores the least significant value in the lower bytes. 
+ int testNumber = 1; + return *(uint8_t*)&testNumber == 1; +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/vector/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/common/vector/CMakeLists.txt new file mode 100644 index 0000000000..b5b23d9bf2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/vector/CMakeLists.txt @@ -0,0 +1,8 @@ +add_library(lbug_common_vector + OBJECT + auxiliary_buffer.cpp + value_vector.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/vector/auxiliary_buffer.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/vector/auxiliary_buffer.cpp new file mode 100644 index 0000000000..0b596a9794 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/vector/auxiliary_buffer.cpp @@ -0,0 +1,95 @@ +#include "common/vector/auxiliary_buffer.h" + +#include + +#include "common/constants.h" +#include "common/system_config.h" +#include "common/vector/value_vector.h" + +namespace lbug { +namespace common { + +StructAuxiliaryBuffer::StructAuxiliaryBuffer(const LogicalType& type, + storage::MemoryManager* memoryManager) { + auto fieldTypes = StructType::getFieldTypes(type); + childrenVectors.reserve(fieldTypes.size()); + for (const auto& fieldType : fieldTypes) { + childrenVectors.push_back(std::make_shared(fieldType->copy(), memoryManager)); + } +} + +ListAuxiliaryBuffer::ListAuxiliaryBuffer(const LogicalType& dataVectorType, + storage::MemoryManager* memoryManager) + : capacity{DEFAULT_VECTOR_CAPACITY}, size{0}, + dataVector{std::make_shared(dataVectorType.copy(), memoryManager)} {} + +list_entry_t ListAuxiliaryBuffer::addList(list_size_t listSize) { + auto listEntry = list_entry_t{size, listSize}; + bool needResizeDataVector = size + listSize > capacity; + while (size + listSize > capacity) { + capacity *= CHUNK_RESIZE_RATIO; + } + if (needResizeDataVector) { + resizeDataVector(dataVector.get()); + } + size += listSize; + return listEntry; +} + +void ListAuxiliaryBuffer::resize(uint64_t numValues) { + if (numValues <= capacity) { + size = numValues; + return; + } + bool needResizeDataVector = numValues > capacity; + while (numValues > capacity) { + capacity *= 2; + KU_ASSERT(capacity != 0); + } + if (needResizeDataVector) { + resizeDataVector(dataVector.get()); + } + size = numValues; +} + +void ListAuxiliaryBuffer::resizeDataVector(ValueVector* dataVector) { + auto buffer = std::make_unique(capacity * dataVector->getNumBytesPerValue()); + memcpy(buffer.get(), dataVector->valueBuffer.get(), size * dataVector->getNumBytesPerValue()); + dataVector->valueBuffer = std::move(buffer); + dataVector->nullMask.resize(capacity); + // If the dataVector is a struct vector, we need to resize its field vectors. 
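+    // A STRUCT data vector only stores struct_entry_t positions in its value buffer;
+    // the field values live in separate child vectors, so those child vectors (and the
+    // position entries beyond the old size) are refreshed by resizeStructDataVector below.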
+ if (dataVector->dataType.getPhysicalType() == PhysicalTypeID::STRUCT) { + resizeStructDataVector(dataVector); + } +} + +void ListAuxiliaryBuffer::resizeStructDataVector(ValueVector* dataVector) { + std::iota(reinterpret_cast( + dataVector->getData() + dataVector->getNumBytesPerValue() * size), + reinterpret_cast( + dataVector->getData() + dataVector->getNumBytesPerValue() * capacity), + size); + auto fieldVectors = StructVector::getFieldVectors(dataVector); + for (auto& fieldVector : fieldVectors) { + resizeDataVector(fieldVector.get()); + } +} + +std::unique_ptr AuxiliaryBufferFactory::getAuxiliaryBuffer(LogicalType& type, + storage::MemoryManager* memoryManager) { + switch (type.getPhysicalType()) { + case PhysicalTypeID::STRING: + return std::make_unique(memoryManager); + case PhysicalTypeID::STRUCT: + return std::make_unique(type, memoryManager); + case PhysicalTypeID::LIST: + return std::make_unique(ListType::getChildType(type), memoryManager); + case PhysicalTypeID::ARRAY: + return std::make_unique(ArrayType::getChildType(type), memoryManager); + default: + return nullptr; + } +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/vector/value_vector.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/vector/value_vector.cpp new file mode 100644 index 0000000000..4a3b6f9c8b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/vector/value_vector.cpp @@ -0,0 +1,704 @@ +#include "common/vector/value_vector.h" + +#include + +#include "common/exception/runtime.h" +#include "common/null_buffer.h" +#include "common/serializer/deserializer.h" +#include "common/serializer/serializer.h" +#include "common/system_config.h" +#include "common/types/uint128_t.h" +#include "common/types/value/nested.h" +#include "common/types/value/value.h" +#include "common/vector/auxiliary_buffer.h" + +namespace lbug { +namespace common { + +ValueVector::ValueVector(LogicalType dataType, storage::MemoryManager* memoryManager, + std::shared_ptr dataChunkState) + : dataType{std::move(dataType)}, nullMask{DEFAULT_VECTOR_CAPACITY} { + if (this->dataType.getLogicalTypeID() == LogicalTypeID::ANY) { + // LCOV_EXCL_START + // Alternatively we can assign a default type here but I don't think it's a good practice. + throw RuntimeException("Trying to a create a vector with ANY type. This should not happen. 
" + "Data type is expected to be resolved during binding."); + // LCOV_EXCL_STOP + } + numBytesPerValue = getDataTypeSize(this->dataType); + initializeValueBuffer(); + auxiliaryBuffer = AuxiliaryBufferFactory::getAuxiliaryBuffer(this->dataType, memoryManager); + if (dataChunkState) { + setState(dataChunkState); + } +} + +void ValueVector::setState(const std::shared_ptr& state_) { + this->state = state_; + if (dataType.getPhysicalType() == PhysicalTypeID::STRUCT) { + auto childrenVectors = StructVector::getFieldVectors(this); + for (auto& childVector : childrenVectors) { + childVector->setState(state_); + } + } +} + +uint32_t ValueVector::countNonNull() const { + if (hasNoNullsGuarantee()) { + return state->getSelVector().getSelSize(); + } else if (state->getSelVector().isUnfiltered() && + state->getSelVector().getSelSize() == DEFAULT_VECTOR_CAPACITY) { + return DEFAULT_VECTOR_CAPACITY - nullMask.countNulls(); + } else { + uint32_t count = 0; + forEachNonNull([&](auto) { count++; }); + return count; + } +} + +bool ValueVector::discardNull(ValueVector& vector) { + if (vector.hasNoNullsGuarantee()) { + return true; + } + auto selectedPos = 0u; + if (vector.state->getSelVector().isUnfiltered()) { + auto buffer = vector.state->getSelVectorUnsafe().getMutableBuffer(); + for (auto i = 0u; i < vector.state->getSelVector().getSelSize(); i++) { + buffer[selectedPos] = i; + selectedPos += !vector.isNull(i); + } + vector.state->getSelVectorUnsafe().setToFiltered(); + } else { + for (auto i = 0u; i < vector.state->getSelVector().getSelSize(); i++) { + auto pos = vector.state->getSelVector()[i]; + vector.state->getSelVectorUnsafe()[selectedPos] = pos; + selectedPos += !vector.isNull(pos); + } + } + vector.state->getSelVectorUnsafe().setSelSize(selectedPos); + return selectedPos > 0; +} + +bool ValueVector::setNullFromBits(const uint64_t* srcNullEntries, uint64_t srcOffset, + uint64_t dstOffset, uint64_t numBitsToCopy, bool invert) { + return nullMask.copyFromNullBits(srcNullEntries, srcOffset, dstOffset, numBitsToCopy, invert); +} + +template +void ValueVector::setValue(uint32_t pos, T val) { + ((T*)valueBuffer.get())[pos] = val; +} + +void ValueVector::copyFromRowData(uint32_t pos, const uint8_t* rowData) { + switch (dataType.getPhysicalType()) { + case PhysicalTypeID::STRUCT: { + StructVector::copyFromRowData(this, pos, rowData); + } break; + case PhysicalTypeID::ARRAY: + case PhysicalTypeID::LIST: { + ListVector::copyFromRowData(this, pos, rowData); + } break; + case PhysicalTypeID::STRING: { + StringVector::addString(this, pos, *(ku_string_t*)rowData); + } break; + default: { + auto dataTypeSize = LogicalTypeUtils::getRowLayoutSize(dataType); + memcpy(getData() + pos * dataTypeSize, rowData, dataTypeSize); + } + } +} + +void ValueVector::copyToRowData(uint32_t pos, uint8_t* rowData, + InMemOverflowBuffer* rowOverflowBuffer) const { + switch (dataType.getPhysicalType()) { + case PhysicalTypeID::STRUCT: { + StructVector::copyToRowData(this, pos, rowData, rowOverflowBuffer); + } break; + case PhysicalTypeID::ARRAY: + case PhysicalTypeID::LIST: { + ListVector::copyToRowData(this, pos, rowData, rowOverflowBuffer); + } break; + case PhysicalTypeID::STRING: { + StringVector::copyToRowData(this, pos, rowData, rowOverflowBuffer); + } break; + default: { + auto dataTypeSize = LogicalTypeUtils::getRowLayoutSize(dataType); + memcpy(rowData, getData() + pos * dataTypeSize, dataTypeSize); + } + } +} + +void ValueVector::copyFromVectorData(uint8_t* dstData, const ValueVector* srcVector, + const uint8_t* 
srcVectorData) { + KU_ASSERT(srcVector->dataType.getPhysicalType() == dataType.getPhysicalType()); + switch (srcVector->dataType.getPhysicalType()) { + case PhysicalTypeID::STRUCT: { + StructVector::copyFromVectorData(this, dstData, srcVector, srcVectorData); + } break; + case PhysicalTypeID::ARRAY: + case PhysicalTypeID::LIST: { + ListVector::copyFromVectorData(this, dstData, srcVector, srcVectorData); + } break; + case PhysicalTypeID::STRING: { + StringVector::addString(this, *(ku_string_t*)dstData, *(ku_string_t*)srcVectorData); + } break; + default: { + memcpy(dstData, srcVectorData, srcVector->getNumBytesPerValue()); + } + } +} + +void ValueVector::copyFromVectorData(uint64_t dstPos, const ValueVector* srcVector, + uint64_t srcPos) { + setNull(dstPos, srcVector->isNull(srcPos)); + if (!isNull(dstPos)) { + copyFromVectorData(getData() + dstPos * getNumBytesPerValue(), srcVector, + srcVector->getData() + srcPos * srcVector->getNumBytesPerValue()); + } +} + +void ValueVector::copyFromValue(uint64_t pos, const Value& value) { + if (value.isNull()) { + setNull(pos, true); + return; + } + setNull(pos, false); + auto dstValue = valueBuffer.get() + pos * numBytesPerValue; + switch (dataType.getPhysicalType()) { + case PhysicalTypeID::INT64: { + memcpy(dstValue, &value.val.int64Val, numBytesPerValue); + } break; + case PhysicalTypeID::INT32: { + memcpy(dstValue, &value.val.int32Val, numBytesPerValue); + } break; + case PhysicalTypeID::INT16: { + memcpy(dstValue, &value.val.int16Val, numBytesPerValue); + } break; + case PhysicalTypeID::INT8: { + memcpy(dstValue, &value.val.int8Val, numBytesPerValue); + } break; + case PhysicalTypeID::UINT64: { + memcpy(dstValue, &value.val.uint64Val, numBytesPerValue); + } break; + case PhysicalTypeID::UINT32: { + memcpy(dstValue, &value.val.uint32Val, numBytesPerValue); + } break; + case PhysicalTypeID::UINT16: { + memcpy(dstValue, &value.val.uint16Val, numBytesPerValue); + } break; + case PhysicalTypeID::UINT8: { + memcpy(dstValue, &value.val.uint8Val, numBytesPerValue); + } break; + case PhysicalTypeID::INT128: { + memcpy(dstValue, &value.val.int128Val, numBytesPerValue); + } break; + case PhysicalTypeID::DOUBLE: { + memcpy(dstValue, &value.val.doubleVal, numBytesPerValue); + } break; + case PhysicalTypeID::FLOAT: { + memcpy(dstValue, &value.val.floatVal, numBytesPerValue); + } break; + case PhysicalTypeID::BOOL: { + memcpy(dstValue, &value.val.booleanVal, numBytesPerValue); + } break; + case PhysicalTypeID::INTERVAL: { + memcpy(dstValue, &value.val.intervalVal, numBytesPerValue); + } break; + case PhysicalTypeID::STRING: { + StringVector::addString(this, *(ku_string_t*)dstValue, value.strVal.data(), + value.strVal.length()); + } break; + case PhysicalTypeID::ARRAY: + case PhysicalTypeID::LIST: { + auto listEntry = reinterpret_cast(dstValue); + auto numValues = NestedVal::getChildrenSize(&value); + *listEntry = ListVector::addList(this, numValues); + auto dstDataVector = ListVector::getDataVector(this); + for (auto i = 0u; i < numValues; ++i) { + auto childVal = NestedVal::getChildVal(&value, i); + dstDataVector->setNull(listEntry->offset + i, childVal->isNull()); + if (!childVal->isNull()) { + dstDataVector->copyFromValue(listEntry->offset + i, + *NestedVal::getChildVal(&value, i)); + } + } + } break; + case PhysicalTypeID::STRUCT: { + auto structFields = StructVector::getFieldVectors(this); + for (auto i = 0u; i < structFields.size(); ++i) { + structFields[i]->copyFromValue(pos, *NestedVal::getChildVal(&value, i)); + } + } break; + case 
PhysicalTypeID::INTERNAL_ID: { + memcpy(dstValue, &value.val.internalIDVal, numBytesPerValue); + } break; + case PhysicalTypeID::UINT128: { + memcpy(dstValue, &value.val.uint128Val, numBytesPerValue); + } break; + default: { + KU_UNREACHABLE; + } + } +} + +std::unique_ptr ValueVector::getAsValue(uint64_t pos) const { + auto value = Value::createNullValue(dataType).copy(); + if (isNull(pos)) { + return value; + } + value->setNull(false); + value->dataType = dataType.copy(); + switch (dataType.getPhysicalType()) { + case PhysicalTypeID::INT64: { + value->val.int64Val = getValue(pos); + } break; + case PhysicalTypeID::INT32: { + value->val.int32Val = getValue(pos); + } break; + case PhysicalTypeID::INT16: { + value->val.int16Val = getValue(pos); + } break; + case PhysicalTypeID::INT8: { + value->val.int8Val = getValue(pos); + } break; + case PhysicalTypeID::UINT64: { + value->val.uint64Val = getValue(pos); + } break; + case PhysicalTypeID::UINT32: { + value->val.uint32Val = getValue(pos); + } break; + case PhysicalTypeID::UINT16: { + value->val.uint16Val = getValue(pos); + } break; + case PhysicalTypeID::UINT8: { + value->val.uint8Val = getValue(pos); + } break; + case PhysicalTypeID::INT128: { + value->val.int128Val = getValue(pos); + } break; + case PhysicalTypeID::DOUBLE: { + value->val.doubleVal = getValue(pos); + } break; + case PhysicalTypeID::FLOAT: { + value->val.floatVal = getValue(pos); + } break; + case PhysicalTypeID::BOOL: { + value->val.booleanVal = getValue(pos); + } break; + case PhysicalTypeID::INTERVAL: { + value->val.intervalVal = getValue(pos); + } break; + case PhysicalTypeID::STRING: { + value->strVal = getValue(pos).getAsString(); + } break; + case PhysicalTypeID::ARRAY: + case PhysicalTypeID::LIST: { + auto dataVector = ListVector::getDataVector(this); + auto listEntry = getValue(pos); + std::vector> children; + children.reserve(listEntry.size); + for (auto i = 0u; i < listEntry.size; ++i) { + children.push_back(dataVector->getAsValue(listEntry.offset + i)); + } + value->childrenSize = children.size(); + value->children = std::move(children); + } break; + case PhysicalTypeID::STRUCT: { + auto& fieldVectors = StructVector::getFieldVectors(this); + std::vector> children; + children.reserve(fieldVectors.size()); + for (auto& fieldVector : fieldVectors) { + children.push_back(fieldVector->getAsValue(pos)); + } + value->childrenSize = children.size(); + value->children = std::move(children); + } break; + case PhysicalTypeID::INTERNAL_ID: { + value->val.internalIDVal = getValue(pos); + } break; + case PhysicalTypeID::UINT128: { + value->val.uint128Val = getValue(pos); + } break; + default: { + KU_UNREACHABLE; + } + } + return value; +} + +void ValueVector::resetAuxiliaryBuffer() { + switch (dataType.getPhysicalType()) { + case PhysicalTypeID::STRING: { + ku_dynamic_cast(auxiliaryBuffer.get())->resetOverflowBuffer(); + return; + } + case PhysicalTypeID::ARRAY: + case PhysicalTypeID::LIST: { + auto listAuxiliaryBuffer = ku_dynamic_cast(auxiliaryBuffer.get()); + listAuxiliaryBuffer->resetSize(); + listAuxiliaryBuffer->getDataVector()->resetAuxiliaryBuffer(); + return; + } + case PhysicalTypeID::STRUCT: { + auto structAuxiliaryBuffer = ku_dynamic_cast(auxiliaryBuffer.get()); + for (auto& vector : structAuxiliaryBuffer->getFieldVectors()) { + vector->resetAuxiliaryBuffer(); + } + return; + } + default: + return; + } +} + +uint32_t ValueVector::getDataTypeSize(const LogicalType& type) { + switch (type.getPhysicalType()) { + case PhysicalTypeID::STRING: { + return 
sizeof(ku_string_t); + } + case PhysicalTypeID::STRUCT: { + return sizeof(struct_entry_t); + } + case PhysicalTypeID::ARRAY: + case PhysicalTypeID::LIST: { + return sizeof(list_entry_t); + } + default: { + return PhysicalTypeUtils::getFixedTypeSize(type.getPhysicalType()); + } + } +} + +void ValueVector::initializeValueBuffer() { + valueBuffer = std::make_unique(numBytesPerValue * DEFAULT_VECTOR_CAPACITY); + if (dataType.getPhysicalType() == PhysicalTypeID::STRUCT) { + // For struct valueVectors, each struct_entry_t stores its current position in the + // valueVector. + std::iota(reinterpret_cast(getData()), + reinterpret_cast(getData() + getNumBytesPerValue() * DEFAULT_VECTOR_CAPACITY), + 0); + } +} + +void ValueVector::serialize(Serializer& ser) const { + // dataType, num_values, data, nullMask, aux + ser.writeDebuggingInfo("data_type"); + dataType.serialize(ser); + ser.writeDebuggingInfo("num_values"); + const auto selSize = state->getSelVector().getSelSize(); + ser.write(selSize); + for (auto i = 0u; i < selSize; i++) { + const auto pos = state->getSelVector()[i]; + ser.write(nullMask.isNull(pos)); + } + ser.writeDebuggingInfo("values"); + for (auto i = 0u; i < selSize; i++) { + getAsValue(state->getSelVector()[i])->serialize(ser); + } +} + +std::unique_ptr ValueVector::deSerialize(Deserializer& deSer, + storage::MemoryManager* mm, std::shared_ptr dataChunkState) { + std::string key; + deSer.validateDebuggingInfo(key, "data_type"); + auto dataType = LogicalType::deserialize(deSer); + auto result = std::make_unique(std::move(dataType), mm); + result->setState(dataChunkState); + deSer.validateDebuggingInfo(key, "num_values"); + sel_t numValues = 0; + deSer.deserializeValue(numValues); + result->state->getSelVectorUnsafe().setSelSize(numValues); + KU_ASSERT(result->state->getSelVector().isUnfiltered()); + bool isNull = false; + for (auto i = 0u; i < numValues; i++) { + deSer.deserializeValue(isNull); + result->setNull(i, isNull); + } + deSer.validateDebuggingInfo(key, "values"); + for (auto i = 0u; i < numValues; i++) { + auto val = Value::deserialize(deSer); + result->copyFromValue(result->state->getSelVector()[i], *val); + } + return result; +} + +template LBUG_API void ValueVector::setValue(uint32_t pos, nodeID_t val); +template LBUG_API void ValueVector::setValue(uint32_t pos, bool val); +template LBUG_API void ValueVector::setValue(uint32_t pos, int64_t val); +template LBUG_API void ValueVector::setValue(uint32_t pos, int32_t val); +template LBUG_API void ValueVector::setValue(uint32_t pos, int16_t val); +template LBUG_API void ValueVector::setValue(uint32_t pos, int8_t val); +template LBUG_API void ValueVector::setValue(uint32_t pos, uint64_t val); +template LBUG_API void ValueVector::setValue(uint32_t pos, uint32_t val); +template LBUG_API void ValueVector::setValue(uint32_t pos, uint16_t val); +template LBUG_API void ValueVector::setValue(uint32_t pos, uint8_t val); +template LBUG_API void ValueVector::setValue(uint32_t pos, int128_t val); +template LBUG_API void ValueVector::setValue(uint32_t pos, uint128_t val); +template LBUG_API void ValueVector::setValue(uint32_t pos, double val); +template LBUG_API void ValueVector::setValue(uint32_t pos, float val); +template LBUG_API void ValueVector::setValue(uint32_t pos, date_t val); +template LBUG_API void ValueVector::setValue(uint32_t pos, timestamp_t val); +template LBUG_API void ValueVector::setValue(uint32_t pos, timestamp_ns_t val); +template LBUG_API void ValueVector::setValue(uint32_t pos, timestamp_ms_t val); +template 
LBUG_API void ValueVector::setValue(uint32_t pos, timestamp_sec_t val); +template LBUG_API void ValueVector::setValue(uint32_t pos, timestamp_tz_t val); +template LBUG_API void ValueVector::setValue(uint32_t pos, interval_t val); +template LBUG_API void ValueVector::setValue(uint32_t pos, list_entry_t val); +template LBUG_API void ValueVector::setValue(uint32_t pos, ku_uuid_t val); + +template<> +void ValueVector::setValue(uint32_t pos, ku_string_t val) { + StringVector::addString(this, pos, val); +} +template<> +void ValueVector::setValue(uint32_t pos, std::string val) { + StringVector::addString(this, pos, val.data(), val.length()); +} +template<> +void ValueVector::setValue(uint32_t pos, std::string_view val) { + StringVector::addString(this, pos, val.data(), val.length()); +} + +void ValueVector::setNull(uint32_t pos, bool isNull) { + nullMask.setNull(pos, isNull); +} + +void StringVector::addString(ValueVector* vector, uint32_t vectorPos, ku_string_t& srcStr) { + KU_ASSERT(vector->dataType.getPhysicalType() == PhysicalTypeID::STRING); + auto stringBuffer = ku_dynamic_cast(vector->auxiliaryBuffer.get()); + auto& dstStr = vector->getValue(vectorPos); + if (ku_string_t::isShortString(srcStr.len)) { + dstStr.setShortString(srcStr); + } else { + dstStr.overflowPtr = reinterpret_cast(stringBuffer->allocateOverflow(srcStr.len)); + dstStr.setLongString(srcStr); + } +} + +void StringVector::addString(ValueVector* vector, uint32_t vectorPos, const char* srcStr, + uint64_t length) { + KU_ASSERT(vector->dataType.getPhysicalType() == PhysicalTypeID::STRING); + auto stringBuffer = ku_dynamic_cast(vector->auxiliaryBuffer.get()); + auto& dstStr = vector->getValue(vectorPos); + if (ku_string_t::isShortString(length)) { + dstStr.setShortString(srcStr, length); + } else { + dstStr.overflowPtr = reinterpret_cast(stringBuffer->allocateOverflow(length)); + dstStr.setLongString(srcStr, length); + } +} + +void StringVector::addString(ValueVector* vector, uint32_t vectorPos, std::string_view srcStr) { + addString(vector, vectorPos, srcStr.data(), srcStr.length()); +} + +ku_string_t& StringVector::reserveString(ValueVector* vector, uint32_t vectorPos, uint64_t length) { + KU_ASSERT(vector->dataType.getPhysicalType() == PhysicalTypeID::STRING); + auto stringBuffer = ku_dynamic_cast(vector->auxiliaryBuffer.get()); + auto& dstStr = vector->getValue(vectorPos); + dstStr.len = length; + if (!ku_string_t::isShortString(length)) { + dstStr.overflowPtr = reinterpret_cast(stringBuffer->allocateOverflow(length)); + } + return dstStr; +} + +void StringVector::reserveString(ValueVector* vector, ku_string_t& dstStr, uint64_t length) { + KU_ASSERT(vector->dataType.getPhysicalType() == PhysicalTypeID::STRING); + auto stringBuffer = ku_dynamic_cast(vector->auxiliaryBuffer.get()); + dstStr.len = length; + if (!ku_string_t::isShortString(length)) { + dstStr.overflowPtr = reinterpret_cast(stringBuffer->allocateOverflow(length)); + } +} + +void StringVector::addString(ValueVector* vector, ku_string_t& dstStr, ku_string_t& srcStr) { + KU_ASSERT(vector->dataType.getPhysicalType() == PhysicalTypeID::STRING); + auto stringBuffer = ku_dynamic_cast(vector->auxiliaryBuffer.get()); + if (ku_string_t::isShortString(srcStr.len)) { + dstStr.setShortString(srcStr); + } else { + dstStr.overflowPtr = reinterpret_cast(stringBuffer->allocateOverflow(srcStr.len)); + dstStr.setLongString(srcStr); + } +} + +void StringVector::addString(ValueVector* vector, ku_string_t& dstStr, const char* srcStr, + uint64_t length) { + 
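+    // Short strings fit inline in the ku_string_t; longer strings are copied into the
+    // vector's string auxiliary buffer and referenced through overflowPtr.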
KU_ASSERT(vector->dataType.getPhysicalType() == PhysicalTypeID::STRING); + auto stringBuffer = ku_dynamic_cast(vector->auxiliaryBuffer.get()); + if (ku_string_t::isShortString(length)) { + dstStr.setShortString(srcStr, length); + } else { + dstStr.overflowPtr = reinterpret_cast(stringBuffer->allocateOverflow(length)); + dstStr.setLongString(srcStr, length); + } +} + +void StringVector::addString(lbug::common::ValueVector* vector, ku_string_t& dstStr, + const std::string& srcStr) { + addString(vector, dstStr, srcStr.data(), srcStr.length()); +} + +void StringVector::copyToRowData(const ValueVector* vector, uint32_t pos, uint8_t* rowData, + InMemOverflowBuffer* rowOverflowBuffer) { + auto& srcStr = vector->getValue(pos); + auto& dstStr = *(ku_string_t*)rowData; + if (ku_string_t::isShortString(srcStr.len)) { + dstStr.setShortString(srcStr); + } else { + dstStr.overflowPtr = + reinterpret_cast(rowOverflowBuffer->allocateSpace(srcStr.len)); + dstStr.setLongString(srcStr); + } +} + +void ListVector::copyListEntryAndBufferMetaData(ValueVector& vector, + const SelectionVector& selVector, const ValueVector& other, + const SelectionVector& otherSelVector) { + KU_ASSERT(selVector.getSelSize() == otherSelVector.getSelSize()); + // Copy list entries + for (auto i = 0u; i < otherSelVector.getSelSize(); ++i) { + auto pos = selVector[i]; + auto otherPos = otherSelVector[i]; + vector.setNull(pos, other.isNull(pos)); + if (!other.isNull(otherPos)) { + vector.setValue(pos, other.getValue(otherPos)); + } + } + // Copy buffer metadata + auto& buffer = getAuxBufferUnsafe(vector); + auto& otherBuffer = getAuxBuffer(other); + buffer.capacity = otherBuffer.capacity; + buffer.size = otherBuffer.size; +} + +void ListVector::copyFromRowData(ValueVector* vector, uint32_t pos, const uint8_t* rowData) { + KU_ASSERT(validateType(*vector)); + auto& srcKuList = *(ku_list_t*)rowData; + auto srcNullBytes = reinterpret_cast(srcKuList.overflowPtr); + auto srcListValues = srcNullBytes + NullBuffer::getNumBytesForNullValues(srcKuList.size); + auto dstListEntry = addList(vector, srcKuList.size); + vector->setValue(pos, dstListEntry); + auto resultDataVector = getDataVector(vector); + auto rowLayoutSize = LogicalTypeUtils::getRowLayoutSize(resultDataVector->dataType); + for (auto i = 0u; i < srcKuList.size; i++) { + auto dstListValuePos = dstListEntry.offset + i; + if (NullBuffer::isNull(srcNullBytes, i)) { + resultDataVector->setNull(dstListValuePos, true); + } else { + resultDataVector->setNull(dstListValuePos, false); + resultDataVector->copyFromRowData(dstListValuePos, srcListValues); + } + srcListValues += rowLayoutSize; + } +} + +void ListVector::copyToRowData(const ValueVector* vector, uint32_t pos, uint8_t* rowData, + InMemOverflowBuffer* rowOverflowBuffer) { + auto& srcListEntry = vector->getValue(pos); + auto srcListDataVector = ListVector::getDataVector(vector); + auto& dstListEntry = *(ku_list_t*)rowData; + dstListEntry.size = srcListEntry.size; + auto nullBytesSize = NullBuffer::getNumBytesForNullValues(dstListEntry.size); + auto dataRowLayoutSize = LogicalTypeUtils::getRowLayoutSize(srcListDataVector->dataType); + auto dstListOverflowSize = dataRowLayoutSize * dstListEntry.size + nullBytesSize; + auto dstListOverflow = rowOverflowBuffer->allocateSpace(dstListOverflowSize); + dstListEntry.overflowPtr = reinterpret_cast(dstListOverflow); + NullBuffer::initNullBytes(dstListOverflow, dstListEntry.size); + auto dstListValues = dstListOverflow + nullBytesSize; + for (auto i = 0u; i < srcListEntry.size; i++) { + if 
(srcListDataVector->isNull(srcListEntry.offset + i)) { + NullBuffer::setNull(dstListOverflow, i); + } else { + srcListDataVector->copyToRowData(srcListEntry.offset + i, dstListValues, + rowOverflowBuffer); + } + dstListValues += dataRowLayoutSize; + } +} + +void ListVector::copyFromVectorData(ValueVector* dstVector, uint8_t* dstData, + const ValueVector* srcVector, const uint8_t* srcData) { + auto& srcListEntry = *(list_entry_t*)(srcData); + auto& dstListEntry = *(list_entry_t*)(dstData); + dstListEntry = addList(dstVector, srcListEntry.size); + auto srcDataVector = getDataVector(srcVector); + auto srcPos = srcListEntry.offset; + auto dstDataVector = getDataVector(dstVector); + auto dstPos = dstListEntry.offset; + for (auto i = 0u; i < srcListEntry.size; i++) { + dstDataVector->copyFromVectorData(dstPos++, srcDataVector, srcPos++); + } +} + +void ListVector::appendDataVector(ValueVector* dstVector, ValueVector* srcDataVector, + uint64_t numValuesToAppend) { + auto offset = getDataVectorSize(dstVector); + resizeDataVector(dstVector, offset + numValuesToAppend); + auto dstDataVector = getDataVector(dstVector); + for (auto i = 0u; i < numValuesToAppend; i++) { + dstDataVector->copyFromVectorData(offset + i, srcDataVector, i); + } +} + +void ListVector::sliceDataVector(ValueVector* vectorToSlice, uint64_t offset, uint64_t numValues) { + if (offset == 0) { + return; + } + for (auto i = 0u; i < numValues - offset; i++) { + vectorToSlice->copyFromVectorData(i, vectorToSlice, i + offset); + } +} + +void StructVector::copyFromRowData(ValueVector* vector, uint32_t pos, const uint8_t* rowData) { + KU_ASSERT(vector->dataType.getPhysicalType() == PhysicalTypeID::STRUCT); + auto& structFields = getFieldVectors(vector); + auto structNullBytes = rowData; + auto structValues = structNullBytes + NullBuffer::getNumBytesForNullValues(structFields.size()); + for (auto i = 0u; i < structFields.size(); i++) { + auto structField = structFields[i]; + if (NullBuffer::isNull(structNullBytes, i)) { + structField->setNull(pos, true /* isNull */); + } else { + structField->setNull(pos, false /* isNull */); + structField->copyFromRowData(pos, structValues); + } + structValues += LogicalTypeUtils::getRowLayoutSize(structField->dataType); + } +} + +void StructVector::copyToRowData(const ValueVector* vector, uint32_t pos, uint8_t* rowData, + InMemOverflowBuffer* rowOverflowBuffer) { + // The storage structure of STRUCT type in factorizedTable is: + // [NULLBYTES, FIELD1, FIELD2, ...] 
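+    // A leading null bitmap covers all fields; each non-null field is then written in its
+    // row-layout size, while null fields only get their bit set in the bitmap.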
+ auto& structFields = StructVector::getFieldVectors(vector); + NullBuffer::initNullBytes(rowData, structFields.size()); + auto structNullBytes = rowData; + auto structValues = structNullBytes + NullBuffer::getNumBytesForNullValues(structFields.size()); + for (auto i = 0u; i < structFields.size(); i++) { + auto structField = structFields[i]; + if (structField->isNull(pos)) { + NullBuffer::setNull(structNullBytes, i); + } else { + structField->copyToRowData(pos, structValues, rowOverflowBuffer); + } + structValues += LogicalTypeUtils::getRowLayoutSize(structField->dataType); + } +} + +void StructVector::copyFromVectorData(ValueVector* dstVector, const uint8_t* dstData, + const ValueVector* srcVector, const uint8_t* srcData) { + auto& srcPos = (*(struct_entry_t*)srcData).pos; + auto& dstPos = (*(struct_entry_t*)dstData).pos; + auto& srcFieldVectors = getFieldVectors(srcVector); + auto& dstFieldVectors = getFieldVectors(dstVector); + for (auto i = 0u; i < srcFieldVectors.size(); i++) { + auto srcFieldVector = srcFieldVectors[i]; + auto dstFieldVector = dstFieldVectors[i]; + dstFieldVector->copyFromVectorData(dstPos, srcFieldVector.get(), srcPos); + } +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/common/windows_utils.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/common/windows_utils.cpp new file mode 100644 index 0000000000..ab499a39df --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/common/windows_utils.cpp @@ -0,0 +1,43 @@ +#if defined(_WIN32) +#include "common/windows_utils.h" + +#include + +#include "common/exception/io.h" + +namespace lbug { +namespace common { + +std::wstring WindowsUtils::utf8ToUnicode(const char* input) { + uint32_t result; + + result = MultiByteToWideChar(CP_UTF8, 0, input, -1, nullptr, 0); + if (result == 0) { + throw IOException("Failure in MultiByteToWideChar"); + } + auto buffer = std::make_unique(result); + result = MultiByteToWideChar(CP_UTF8, 0, input, -1, buffer.get(), result); + if (result == 0) { + throw IOException("Failure in MultiByteToWideChar"); + } + return std::wstring(buffer.get(), result); +} + +std::string WindowsUtils::unicodeToUTF8(LPCWSTR input) { + uint64_t resultSize; + + resultSize = WideCharToMultiByte(CP_UTF8, 0, input, -1, 0, 0, 0, 0); + if (resultSize == 0) { + throw IOException("Failure in WideCharToMultiByte"); + } + auto buffer = std::make_unique(resultSize); + resultSize = WideCharToMultiByte(CP_UTF8, 0, input, -1, buffer.get(), resultSize, 0, 0); + if (resultSize == 0) { + throw IOException("Failure in WideCharToMultiByte"); + } + return std::string(buffer.get(), resultSize - 1); +} + +} // namespace common +} // namespace lbug +#endif diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/expression_evaluator/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/expression_evaluator/CMakeLists.txt new file mode 100644 index 0000000000..301161287a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/expression_evaluator/CMakeLists.txt @@ -0,0 +1,17 @@ +add_library(lbug_expression_evaluator + OBJECT + case_evaluator.cpp + expression_evaluator.cpp + expression_evaluator_utils.cpp + expression_evaluator_visitor.cpp + function_evaluator.cpp + lambda_evaluator.cpp + list_slice_info.cpp + literal_evaluator.cpp + pattern_evaluator.cpp + path_evaluator.cpp + reference_evaluator.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/expression_evaluator/case_evaluator.cpp 
b/graph-wasm/lbug-0.12.2/lbug-src/src/expression_evaluator/case_evaluator.cpp new file mode 100644 index 0000000000..d4e072a06f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/expression_evaluator/case_evaluator.cpp @@ -0,0 +1,108 @@ +#include "expression_evaluator/case_evaluator.h" + +using namespace lbug::main; +using namespace lbug::common; +using namespace lbug::processor; +using namespace lbug::storage; + +namespace lbug { +namespace evaluator { + +void CaseAlternativeEvaluator::init(const ResultSet& resultSet, + main::ClientContext* clientContext) { + whenEvaluator->init(resultSet, clientContext); + thenEvaluator->init(resultSet, clientContext); + whenSelVector = std::make_unique(DEFAULT_VECTOR_CAPACITY); + whenSelVector->setToFiltered(); +} + +void CaseExpressionEvaluator::init(const ResultSet& resultSet, main::ClientContext* clientContext) { + for (auto& alternativeEvaluator : alternativeEvaluators) { + alternativeEvaluator.init(resultSet, clientContext); + } + elseEvaluator->init(resultSet, clientContext); + ExpressionEvaluator::init(resultSet, clientContext); +} + +void CaseExpressionEvaluator::evaluate() { + filledMask.reset(); + for (auto& alternativeEvaluator : alternativeEvaluators) { + auto whenSelVector = alternativeEvaluator.whenSelVector.get(); + // the sel vector is already set to filtered in init() + auto hasAtLeastOneValue = alternativeEvaluator.whenEvaluator->select(*whenSelVector, false); + if (!hasAtLeastOneValue) { + continue; + } + alternativeEvaluator.thenEvaluator->evaluate(); + auto thenVector = alternativeEvaluator.thenEvaluator->resultVector.get(); + if (alternativeEvaluator.whenEvaluator->isResultFlat()) { + fillAll(thenVector); + } else { + fillSelected(*whenSelVector, thenVector); + } + if (filledMask.count() == resultVector->state->getSelVector().getSelSize()) { + return; + } + } + elseEvaluator->evaluate(); + fillAll(elseEvaluator->resultVector.get()); +} + +bool CaseExpressionEvaluator::selectInternal(SelectionVector& selVector) { + evaluate(); + KU_ASSERT(resultVector->state->getSelVector().getSelSize() != 0); + KU_ASSERT(selVector.getSelSize() != 0); + KU_ASSERT(resultVector->state->getSelVector().getSelSize() == selVector.getSelSize()); + auto numSelectedValues = 0u; + auto selectedPosBuffer = selVector.getMutableBuffer(); + for (auto i = 0u; i < selVector.getSelSize(); ++i) { + auto selVectorPos = selVector[i]; + auto resultVectorPos = resultVector->state->getSelVector()[i]; + selectedPosBuffer[numSelectedValues] = selVectorPos; + const bool selectCurrentValue = + !resultVector->isNull(resultVectorPos) && resultVector->getValue(resultVectorPos); + numSelectedValues += selectCurrentValue; + } + selVector.setSelSize(numSelectedValues); + return numSelectedValues > 0; +} + +void CaseExpressionEvaluator::resolveResultVector(const ResultSet& /*resultSet*/, + MemoryManager* memoryManager) { + resultVector = std::make_shared(expression->dataType.copy(), memoryManager); + std::vector inputEvaluators; + for (auto& alternative : alternativeEvaluators) { + inputEvaluators.push_back(alternative.whenEvaluator.get()); + inputEvaluators.push_back(alternative.thenEvaluator.get()); + } + inputEvaluators.push_back(elseEvaluator.get()); + resolveResultStateFromChildren(inputEvaluators); +} + +void CaseExpressionEvaluator::fillSelected(const SelectionVector& selVector, + ValueVector* srcVector) { + for (auto i = 0u; i < selVector.getSelSize(); ++i) { + auto resultPos = selVector[i]; + fillEntry(resultPos, srcVector); + } +} + +void 
CaseExpressionEvaluator::fillAll(ValueVector* srcVector) { + auto& resultSelVector = resultVector->state->getSelVector(); + for (auto i = 0u; i < resultSelVector.getSelSize(); ++i) { + auto resultPos = resultSelVector[i]; + fillEntry(resultPos, srcVector); + } +} + +void CaseExpressionEvaluator::fillEntry(sel_t resultPos, ValueVector* srcVector) { + if (filledMask[resultPos]) { + return; + } + filledMask[resultPos] = true; + auto srcPos = srcVector->state->isFlat() ? srcVector->state->getSelVector()[0] : resultPos; + resultVector->copyFromVectorData(resultPos, srcVector, srcPos); +} + +} // namespace evaluator +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/expression_evaluator/expression_evaluator.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/expression_evaluator/expression_evaluator.cpp new file mode 100644 index 0000000000..1d2b2fdf95 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/expression_evaluator/expression_evaluator.cpp @@ -0,0 +1,59 @@ +#include "expression_evaluator/expression_evaluator.h" + +#include "common/exception/runtime.h" +#include "main/client_context.h" +#include "storage/buffer_manager/memory_manager.h" + +using namespace lbug::common; + +namespace lbug { +namespace evaluator { + +void ExpressionEvaluator::init(const processor::ResultSet& resultSet, + main::ClientContext* clientContext) { + localState.clientContext = clientContext; + for (auto& child : children) { + child->init(resultSet, clientContext); + } + resolveResultVector(resultSet, storage::MemoryManager::Get(*clientContext)); +} + +void ExpressionEvaluator::resolveResultStateFromChildren( + const std::vector& inputEvaluators) { + if (resultVector->state != nullptr) { + return; + } + for (auto& input : inputEvaluators) { + if (!input->isResultFlat()) { + isResultFlat_ = false; + resultVector->setState(input->resultVector->state); + return; + } + } + // All children are flat. + isResultFlat_ = true; + // We need to leave capacity for multiple evaluations + resultVector->setState(std::make_shared()); + resultVector->state->initOriginalAndSelectedSize(1); + resultVector->state->setToFlat(); +} + +void ExpressionEvaluator::evaluate(common::sel_t) { + // LCOV_EXCL_START + throw RuntimeException(stringFormat("Cannot evaluate expression {} with count. 
This should " + "never happen.", + expression->toString())); + // LCOV_EXCL_STOP +} + +bool ExpressionEvaluator::select(common::SelectionVector& selVector, + bool shouldSetSelVectorToFiltered) { + bool ret = selectInternal(selVector); + if (shouldSetSelVectorToFiltered && selVector.isUnfiltered()) { + selVector.setToFiltered(); + } + return ret; +} + +} // namespace evaluator +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/expression_evaluator/expression_evaluator_utils.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/expression_evaluator/expression_evaluator_utils.cpp new file mode 100644 index 0000000000..1fee13ab8f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/expression_evaluator/expression_evaluator_utils.cpp @@ -0,0 +1,25 @@ +#include "expression_evaluator/expression_evaluator_utils.h" + +#include "common/types/value/value.h" +#include "processor/expression_mapper.h" + +using namespace lbug::common; +using namespace lbug::processor; + +namespace lbug { +namespace evaluator { + +Value ExpressionEvaluatorUtils::evaluateConstantExpression( + std::shared_ptr expression, main::ClientContext* clientContext) { + auto exprMapper = ExpressionMapper(); + auto evaluator = exprMapper.getConstantEvaluator(expression); + auto emptyResultSet = std::make_unique(0); + evaluator->init(*emptyResultSet, clientContext); + evaluator->evaluate(); + auto& selVector = evaluator->resultVector->state->getSelVector(); + KU_ASSERT(selVector.getSelSize() == 1); + return *evaluator->resultVector->getAsValue(selVector[0]); +} + +} // namespace evaluator +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/expression_evaluator/expression_evaluator_visitor.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/expression_evaluator/expression_evaluator_visitor.cpp new file mode 100644 index 0000000000..669b1f50d7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/expression_evaluator/expression_evaluator_visitor.cpp @@ -0,0 +1,67 @@ +#include "expression_evaluator/expression_evaluator_visitor.h" + +#include "expression_evaluator/case_evaluator.h" + +namespace lbug { +namespace evaluator { + +void ExpressionEvaluatorVisitor::visitSwitch(ExpressionEvaluator* evaluator) { + switch (evaluator->getEvaluatorType()) { + case EvaluatorType::CASE_ELSE: { + visitCase(evaluator); + } break; + case EvaluatorType::FUNCTION: { + visitFunction(evaluator); + } break; + case EvaluatorType::LAMBDA_PARAM: { + visitLambdaParam(evaluator); + } break; + case EvaluatorType::LIST_LAMBDA: { + visitListLambda(evaluator); + } break; + case EvaluatorType::LITERAL: { + visitLiteral(evaluator); + } break; + case EvaluatorType::PATH: { + visitPath(evaluator); + } break; + case EvaluatorType::NODE_REL: { + visitPattern(evaluator); + } break; + case EvaluatorType::REFERENCE: { + visitReference(evaluator); + } break; + default: + KU_UNREACHABLE; + } +} + +void LambdaParamEvaluatorCollector::visit(ExpressionEvaluator* evaluator) { + std::vector children; + switch (evaluator->getEvaluatorType()) { + case EvaluatorType::CASE_ELSE: { + auto& caseEvaluator = evaluator->constCast(); + children.push_back(caseEvaluator.getElseEvaluator()); + for (auto& alternativeEvaluator : caseEvaluator.getAlternativeEvaluators()) { + children.push_back(alternativeEvaluator.whenEvaluator.get()); + children.push_back(alternativeEvaluator.thenEvaluator.get()); + } + } break; + case EvaluatorType::LAMBDA_PARAM: { + evaluators.push_back(evaluator); + return; + } + default: { + for (auto& child : evaluator->getChildren()) { + 
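+            // Generic evaluators: collect every child so the recursion below can keep
+            // searching for lambda parameter evaluators.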
children.push_back(child.get()); + } + } + } + for (auto& child : children) { + visit(child); + visitSwitch(child); + } +} + +} // namespace evaluator +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/expression_evaluator/function_evaluator.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/expression_evaluator/function_evaluator.cpp new file mode 100644 index 0000000000..cc1802de0c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/expression_evaluator/function_evaluator.cpp @@ -0,0 +1,81 @@ +#include "expression_evaluator/function_evaluator.h" + +#include "binder/expression/scalar_function_expression.h" +#include "function/sequence/sequence_functions.h" + +using namespace lbug::common; +using namespace lbug::processor; +using namespace lbug::storage; +using namespace lbug::main; +using namespace lbug::binder; +using namespace lbug::function; + +namespace lbug { +namespace evaluator { + +FunctionExpressionEvaluator::FunctionExpressionEvaluator(std::shared_ptr expression, + std::vector> children) + : ExpressionEvaluator{type_, std::move(expression), std::move(children)} { + auto& functionExpr = this->expression->constCast(); + function = functionExpr.getFunction().copy(); + bindData = functionExpr.getBindData()->copy(); +} + +void FunctionExpressionEvaluator::evaluate() { + auto ctx = localState.clientContext; + for (auto& child : children) { + child->evaluate(); + } + if (function->execFunc != nullptr) { + bindData->clientContext = ctx; + runExecFunc(bindData.get()); + } +} + +void FunctionExpressionEvaluator::evaluate(common::sel_t count) { + KU_ASSERT(expression->constCast().getFunction().name == + NextValFunction::name); + for (auto& child : children) { + child->evaluate(count); + } + bindData->count = count; + bindData->clientContext = localState.clientContext; + runExecFunc(bindData.get()); +} + +bool FunctionExpressionEvaluator::selectInternal(SelectionVector& selVector) { + for (auto& child : children) { + child->evaluate(); + } + // Temporary code path for function whose return type is BOOL but select interface is not + // implemented (e.g. list_contains). We should remove this if statement eventually. 
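+    // Fallback: run the regular exec function to materialize a BOOL result vector and
+    // derive the selection from it, instead of calling a dedicated select function.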
+ if (function->selectFunc == nullptr) { + KU_ASSERT(resultVector->dataType.getLogicalTypeID() == LogicalTypeID::BOOL); + runExecFunc(); + return updateSelectedPos(selVector); + } + return function->selectFunc(parameters, selVector, bindData.get()); +} + +void FunctionExpressionEvaluator::runExecFunc(void* dataPtr) { + function->execFunc(parameters, common::SelectionVector::fromValueVectors(parameters), + *resultVector, resultVector->getSelVectorPtr(), dataPtr); +} + +void FunctionExpressionEvaluator::resolveResultVector(const ResultSet& /*resultSet*/, + MemoryManager* memoryManager) { + resultVector = std::make_shared(expression->dataType.copy(), memoryManager); + std::vector inputEvaluators; + inputEvaluators.reserve(children.size()); + for (auto& child : children) { + parameters.push_back(child->resultVector); + inputEvaluators.push_back(child.get()); + } + resolveResultStateFromChildren(inputEvaluators); + if (function->compileFunc != nullptr) { + function->compileFunc(bindData.get(), parameters, resultVector); + } +} + +} // namespace evaluator +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/expression_evaluator/lambda_evaluator.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/expression_evaluator/lambda_evaluator.cpp new file mode 100644 index 0000000000..7e809b59d9 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/expression_evaluator/lambda_evaluator.cpp @@ -0,0 +1,120 @@ +#include "expression_evaluator/lambda_evaluator.h" + +#include "binder/expression/lambda_expression.h" +#include "common/exception/runtime.h" +#include "expression_evaluator/expression_evaluator_visitor.h" +#include "expression_evaluator/list_slice_info.h" +#include "function/list/vector_list_functions.h" +#include "main/client_context.h" +#include "parser/expression/parsed_lambda_expression.h" +#include "storage/buffer_manager/memory_manager.h" + +using namespace lbug::common; +using namespace lbug::processor; +using namespace lbug::main; +using namespace lbug::storage; + +namespace lbug { +namespace evaluator { + +void ListLambdaEvaluator::init(const ResultSet& resultSet, ClientContext* clientContext) { + for (auto& child : children) { + child->init(resultSet, clientContext); + } + KU_ASSERT(children.size() == 1); + auto listInputVector = children[0]->resultVector.get(); + // Find all param in lambda, e.g. find x in x->x+1 + auto collector = LambdaParamEvaluatorCollector(); + collector.visit(lambdaRootEvaluator.get()); + auto evaluators = collector.getEvaluators(); + auto lambdaVarState = std::make_shared(); + memoryManager = MemoryManager::Get(*clientContext); + for (auto& evaluator : evaluators) { + // For list_filter, list_transform: + // The resultVector of lambdaEvaluator should be the list dataVector. + // For list_reduce: + // We should create two vectors for each lambda variable resultVector since we are going to + // update the list elements during execution. + evaluator->resultVector = + listLambdaType != ListLambdaType::LIST_REDUCE ? 
+ ListVector::getSharedDataVector(listInputVector) : + std::make_shared( + ListType::getChildType(listInputVector->dataType).copy(), memoryManager); + evaluator->resultVector->state = lambdaVarState; + lambdaParamEvaluators.push_back(evaluator->ptrCast()); + } + lambdaRootEvaluator->init(resultSet, clientContext); + resolveResultVector(resultSet, memoryManager); + params.push_back(children[0]->resultVector); + params.push_back(lambdaRootEvaluator->resultVector); + auto paramIndices = getParamIndices(); + bindData = ListLambdaBindData{lambdaParamEvaluators, paramIndices, lambdaRootEvaluator.get()}; +} + +void ListLambdaEvaluator::evaluateInternal() { + auto* inputVector = params[0].get(); + if (resultVector->dataType.getPhysicalType() == PhysicalTypeID::LIST) { + ListVector::resizeDataVector(resultVector.get(), + ListVector::getDataVectorSize(inputVector)); + } + ListSliceInfo sliceInfo{inputVector}; + bindData.sliceInfo = &sliceInfo; + auto selVectors = SelectionVector::fromValueVectors(params); + do { + sliceInfo.nextSlice(); + execFunc(params, selVectors, *resultVector, resultVector->getSelVectorPtr(), &bindData); + } while (!sliceInfo.done()); +} + +void ListLambdaEvaluator::evaluate() { + KU_ASSERT(children.size() == 1); + children[0]->evaluate(); + evaluateInternal(); +} + +bool ListLambdaEvaluator::selectInternal(SelectionVector& selVector) { + KU_ASSERT(children.size() == 1); + children[0]->evaluate(); + evaluateInternal(); + return updateSelectedPos(selVector); +} + +void ListLambdaEvaluator::resolveResultVector(const ResultSet&, MemoryManager* memoryManager) { + resultVector = std::make_shared(expression->getDataType().copy(), memoryManager); + resultVector->state = children[0]->resultVector->state; + isResultFlat_ = children[0]->isResultFlat(); +} + +std::vector ListLambdaEvaluator::getParamIndices() { + const auto& paramNames = getExpression() + ->getChild(1) + ->constCast() + .getParsedLambdaExpr() + ->constCast() + .getVarNames(); + std::vector index(lambdaParamEvaluators.size()); + for (idx_t i = 0; i < lambdaParamEvaluators.size(); i++) { + auto paramName = lambdaParamEvaluators[i]->getVarName(); + auto it = std::find(paramNames.begin(), paramNames.end(), paramName); + if (it != paramNames.end()) { + index[i] = it - paramNames.begin(); + } else { + throw RuntimeException(stringFormat("Lambda paramName {} cannot found.", paramName)); + } + } + return index; +} +ListLambdaType ListLambdaEvaluator::checkListLambdaTypeWithFunctionName(std::string functionName) { + if (0 == functionName.compare(function::ListTransformFunction::name)) { + return ListLambdaType::LIST_TRANSFORM; + } else if (0 == functionName.compare(function::ListFilterFunction::name)) { + return ListLambdaType::LIST_FILTER; + } else if (0 == functionName.compare(function::ListReduceFunction::name)) { + return ListLambdaType::LIST_REDUCE; + } else { + return ListLambdaType::DEFAULT; + } +} + +} // namespace evaluator +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/expression_evaluator/list_slice_info.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/expression_evaluator/list_slice_info.cpp new file mode 100644 index 0000000000..7d24f4e70f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/expression_evaluator/list_slice_info.cpp @@ -0,0 +1,106 @@ + +#include "expression_evaluator/list_slice_info.h" + +#include "common/system_config.h" +#include "expression_evaluator/lambda_evaluator.h" + +namespace lbug { +namespace evaluator { + +ListEntryTracker::ListEntryTracker(common::ValueVector* 
listVector) + : listVector(listVector), listEntryIdx(0), offsetInList(0), + listEntries(listVector->state->getSelSize()) { + // it is not guaranteed that the list entries in a list vector are sorted by offset (the + // case evaluator breaks this) so we need to sort it manually + for (common::sel_t i = 0; i < listVector->state->getSelSize(); ++i) { + listEntries[i] = listVector->state->getSelVector()[i]; + } + std::sort(listEntries.begin(), listEntries.end(), [listVector](const auto a, const auto b) { + return listVector->getValue(a).offset < + listVector->getValue(b).offset; + }); + updateListEntry(); +} + +common::offset_t ListEntryTracker::getNextDataOffset() { + ++offsetInList; + if (offsetInList >= getCurListEntry().size) { + ++listEntryIdx; + updateListEntry(); + offsetInList = 0; + if (done()) { + return common::INVALID_OFFSET; + } + } + return getCurDataOffset(); +} + +void ListEntryTracker::updateListEntry() { + while (true) { + if (listEntryIdx >= listEntries.size()) { + break; + } + const auto newEntry = listVector->getValue(getListEntryPos()); + if (!listVector->isNull(getListEntryPos()) && newEntry.size > 0) { + break; + } + ++listEntryIdx; + } +} + +std::vector> ListSliceInfo::overrideAndSaveParamStates( + std::span lambdaParamEvaluators) { + std::vector> savedStates; + + // The sel states of the result vectors in evaluator trees often point to the same state + // First set the states to the unfiltered slice size + // This makes sure upstream evaluators have the correct input size and don't use the sliced + // offset + for (auto& lambdaParamEvaluator : lambdaParamEvaluators) { + auto param = lambdaParamEvaluator->resultVector.get(); + param->state->getSelVectorUnsafe().setToUnfiltered(getSliceSize()); + savedStates.push_back(param->state); + } + + // Then override the output sel state of the param's result vector + // This will be a list data vector that we need to get data from using the sliced offset + for (auto& lambdaParamEvaluator : lambdaParamEvaluators) { + auto param = lambdaParamEvaluator->resultVector.get(); + param->state = sliceDataState; + } + return savedStates; +} + +bool ListSliceInfo::done() const { + return listEntryTracker.done(); +} + +void ListSliceInfo::restoreParamStates( + std::span lambdaParamEvaluators, + std::vector> savedStates) { + for (size_t i = 0; i < lambdaParamEvaluators.size(); ++i) { + auto param = lambdaParamEvaluators[i]->resultVector.get(); + param->state = savedStates[i]; + } +} + +void ListSliceInfo::nextSlice() { + updateSelVector(); +} + +void ListSliceInfo::updateSelVector() { + auto& dataSel = sliceDataState->getSelVectorUnsafe(); + auto& listEntrySel = sliceListEntryState->getSelVectorUnsafe(); + common::offset_t sliceSize = 0; + while (!listEntryTracker.done() && sliceSize < common::DEFAULT_VECTOR_CAPACITY) { + dataSel[sliceSize] = listEntryTracker.getCurDataOffset(); + listEntrySel[sliceSize] = listEntryTracker.getListEntryPos(); + listEntryTracker.getNextDataOffset(); + ++sliceSize; + } + dataSel.setSelSize(sliceSize); + listEntrySel.setSelSize(sliceSize); +} + +} // namespace evaluator +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/expression_evaluator/literal_evaluator.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/expression_evaluator/literal_evaluator.cpp new file mode 100644 index 0000000000..f11db492ab --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/expression_evaluator/literal_evaluator.cpp @@ -0,0 +1,43 @@ +#include "expression_evaluator/literal_evaluator.h" + +#include 
"common/types/value/value.h" + +using namespace lbug::common; +using namespace lbug::storage; +using namespace lbug::main; + +namespace lbug { +namespace evaluator { + +void LiteralExpressionEvaluator::evaluate() {} + +void LiteralExpressionEvaluator::evaluate(sel_t count) { + unFlatState->getSelVectorUnsafe().setSelSize(count); + resultVector->setState(unFlatState); + for (auto i = 1ul; i < count; i++) { + resultVector->copyFromVectorData(i, resultVector.get(), 0); + } +} + +bool LiteralExpressionEvaluator::selectInternal(SelectionVector&) { + KU_ASSERT(resultVector->dataType.getLogicalTypeID() == LogicalTypeID::BOOL); + auto pos = resultVector->state->getSelVector()[0]; + KU_ASSERT(pos == 0u); + return resultVector->getValue(pos) && (!resultVector->isNull(pos)); +} + +void LiteralExpressionEvaluator::resolveResultVector(const processor::ResultSet& /*resultSet*/, + MemoryManager* memoryManager) { + resultVector = std::make_shared(value.getDataType().copy(), memoryManager); + flatState = DataChunkState::getSingleValueDataChunkState(); + unFlatState = std::make_shared(); + resultVector->setState(flatState); + if (value.isNull()) { + resultVector->setNull(0 /* pos */, true); + } else { + resultVector->copyFromValue(resultVector->state->getSelVector()[0], value); + } +} + +} // namespace evaluator +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/expression_evaluator/path_evaluator.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/expression_evaluator/path_evaluator.cpp new file mode 100644 index 0000000000..b654b19eac --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/expression_evaluator/path_evaluator.cpp @@ -0,0 +1,227 @@ +#include "expression_evaluator/path_evaluator.h" + +#include "binder/expression/path_expression.h" +#include "binder/expression/rel_expression.h" +#include "common/string_utils.h" + +using namespace lbug::main; +using namespace lbug::common; +using namespace lbug::binder; + +namespace lbug { +namespace evaluator { + +// For each result field vector, find its corresponding input field vector if exist. 
+static std::vector getFieldVectors(const LogicalType& inputType, + const LogicalType& resultType, ValueVector* inputVector) { + std::vector result; + for (auto& field : StructType::getFields(resultType)) { + auto fieldName = StringUtils::getUpper(field.getName()); + if (StructType::hasField(inputType, fieldName)) { + auto idx = StructType::getFieldIdx(inputType, fieldName); + result.push_back(StructVector::getFieldVector(inputVector, idx).get()); + } else { + result.push_back(nullptr); + } + } + return result; +} + +void PathExpressionEvaluator::init(const processor::ResultSet& resultSet, + main::ClientContext* clientContext) { + ExpressionEvaluator::init(resultSet, clientContext); + auto resultNodesIdx = StructType::getFieldIdx(resultVector->dataType, InternalKeyword::NODES); + resultNodesVector = StructVector::getFieldVector(resultVector.get(), resultNodesIdx).get(); + auto resultNodesDataVector = ListVector::getDataVector(resultNodesVector); + for (auto& fieldVector : StructVector::getFieldVectors(resultNodesDataVector)) { + resultNodesFieldVectors.push_back(fieldVector.get()); + } + auto resultRelsIdx = StructType::getFieldIdx(resultVector->dataType, InternalKeyword::RELS); + resultRelsVector = StructVector::getFieldVector(resultVector.get(), resultRelsIdx).get(); + auto resultRelsDataVector = ListVector::getDataVector(resultRelsVector); + for (auto& fieldVector : StructVector::getFieldVectors(resultRelsDataVector)) { + resultRelsFieldVectors.push_back(fieldVector.get()); + } + auto pathExpression = (PathExpression*)expression.get(); + for (auto i = 0u; i < expression->getNumChildren(); ++i) { + auto child = expression->getChild(i).get(); + auto vectors = std::make_unique(); + vectors->input = children[i]->resultVector.get(); + switch (child->dataType.getLogicalTypeID()) { + case LogicalTypeID::NODE: { + vectors->nodeFieldVectors = + getFieldVectors(child->dataType, pathExpression->getNodeType(), vectors->input); + } break; + case LogicalTypeID::REL: { + vectors->relFieldVectors = + getFieldVectors(child->dataType, pathExpression->getRelType(), vectors->input); + } break; + case LogicalTypeID::RECURSIVE_REL: { + auto rel = (RelExpression*)child; + auto recursiveNode = rel->getRecursiveInfo()->node; + auto recursiveRel = rel->getRecursiveInfo()->rel; + auto nodeFieldIdx = StructType::getFieldIdx(child->dataType, InternalKeyword::NODES); + vectors->nodesInput = StructVector::getFieldVector(vectors->input, nodeFieldIdx).get(); + vectors->nodesDataInput = ListVector::getDataVector(vectors->nodesInput); + vectors->nodeFieldVectors = getFieldVectors(recursiveNode->dataType, + pathExpression->getNodeType(), vectors->nodesDataInput); + auto relFieldIdx = + StructType::getFieldIdx(vectors->input->dataType, InternalKeyword::RELS); + vectors->relsInput = StructVector::getFieldVector(vectors->input, relFieldIdx).get(); + vectors->relsDataInput = ListVector::getDataVector(vectors->relsInput); + vectors->relFieldVectors = getFieldVectors(recursiveRel->dataType, + pathExpression->getRelType(), vectors->relsDataInput); + } break; + default: + KU_UNREACHABLE; + } + inputVectorsPerChild.push_back(std::move(vectors)); + } +} + +void PathExpressionEvaluator::evaluate() { + resultVector->resetAuxiliaryBuffer(); + for (auto& child : children) { + child->evaluate(); + } + auto& selVector = resultVector->state->getSelVector(); + for (auto i = 0u; i < selVector.getSelSize(); ++i) { + auto pos = selVector[i]; + auto numRels = copyRels(pos); + copyNodes(pos, numRels == 0); + } +} + +static inline uint32_t 
getCurrentPos(ValueVector* vector, uint32_t pos) { + if (vector->state->isFlat()) { + return vector->state->getSelVector()[0]; + } + return pos; +} + +void PathExpressionEvaluator::copyNodes(sel_t resultPos, bool isEmptyRels) { + auto listSize = 0u; + // Calculate list size. + for (auto i = 0u; i < expression->getNumChildren(); ++i) { + auto child = expression->getChild(i).get(); + switch (child->dataType.getLogicalTypeID()) { + case LogicalTypeID::NODE: { + listSize++; + } break; + case LogicalTypeID::RECURSIVE_REL: { + auto vectors = inputVectorsPerChild[i].get(); + auto inputPos = getCurrentPos(vectors->input, resultPos); + listSize += vectors->nodesInput->getValue(inputPos).size; + } break; + default: + break; + } + } + if (isEmptyRels) { + listSize = 1; + } + // Add list entry. + auto entry = ListVector::addList(resultNodesVector, listSize); + resultNodesVector->setValue(resultPos, entry); + // Copy field vectors + offset_t resultDataPos = entry.offset; + auto numChildrenToCopy = isEmptyRels ? 1 : expression->getNumChildren(); + for (auto i = 0u; i < numChildrenToCopy; ++i) { + auto child = expression->getChild(i).get(); + auto vectors = inputVectorsPerChild[i].get(); + auto inputPos = getCurrentPos(vectors->input, resultPos); + switch (child->dataType.getLogicalTypeID()) { + case LogicalTypeID::NODE: { + copyFieldVectors(inputPos, vectors->nodeFieldVectors, resultDataPos, + resultNodesFieldVectors); + } break; + case LogicalTypeID::RECURSIVE_REL: { + auto& listEntry = vectors->nodesInput->getValue(inputPos); + for (auto j = 0u; j < listEntry.size; ++j) { + copyFieldVectors(listEntry.offset + j, vectors->nodeFieldVectors, resultDataPos, + resultNodesFieldVectors); + } + } break; + default: + break; + } + } +} + +uint64_t PathExpressionEvaluator::copyRels(sel_t resultPos) { + auto listSize = 0u; + // Calculate list size. + for (auto i = 0u; i < expression->getNumChildren(); ++i) { + auto child = expression->getChild(i).get(); + switch (child->dataType.getLogicalTypeID()) { + case LogicalTypeID::REL: { + listSize++; + } break; + case LogicalTypeID::RECURSIVE_REL: { + auto vectors = inputVectorsPerChild[i].get(); + auto inputPos = getCurrentPos(vectors->input, resultPos); + listSize += vectors->relsInput->getValue(inputPos).size; + } break; + default: + break; + } + } + // Add list entry. 
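+ // listSize may be 0 when the path contains no rels; the returned size lets evaluate() tell + // copyNodes to fall back to copying a single node.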
+ auto entry = ListVector::addList(resultRelsVector, listSize); + resultRelsVector->setValue(resultPos, entry); + // Copy field vectors + offset_t resultDataPos = entry.offset; + for (auto i = 0u; i < expression->getNumChildren(); ++i) { + auto child = expression->getChild(i).get(); + auto vectors = inputVectorsPerChild[i].get(); + auto inputPos = getCurrentPos(vectors->input, resultPos); + switch (child->dataType.getLogicalTypeID()) { + case LogicalTypeID::REL: { + copyFieldVectors(inputPos, vectors->relFieldVectors, resultDataPos, + resultRelsFieldVectors); + } break; + case LogicalTypeID::RECURSIVE_REL: { + auto& listEntry = vectors->relsInput->getValue<list_entry_t>(inputPos); + for (auto j = 0u; j < listEntry.size; ++j) { + copyFieldVectors(listEntry.offset + j, vectors->relFieldVectors, resultDataPos, + resultRelsFieldVectors); + } + } break; + default: + break; + } + } + return listSize; +} + +void PathExpressionEvaluator::copyFieldVectors(offset_t inputVectorPos, + const std::vector<ValueVector*>& inputFieldVectors, offset_t& resultVectorPos, + const std::vector<ValueVector*>& resultFieldVectors) { + KU_ASSERT(resultFieldVectors.size() == inputFieldVectors.size()); + for (auto i = 0u; i < inputFieldVectors.size(); ++i) { + auto inputFieldVector = inputFieldVectors[i]; + auto resultFieldVector = resultFieldVectors[i]; + if (inputFieldVector == nullptr || inputFieldVector->isNull(inputVectorPos)) { + resultFieldVector->setNull(resultVectorPos, true); + continue; + } + resultFieldVector->setNull(resultVectorPos, false); + KU_ASSERT(inputFieldVector->dataType == resultFieldVector->dataType); + resultFieldVector->copyFromVectorData(resultVectorPos, inputFieldVector, inputVectorPos); + } + resultVectorPos++; +} + +void PathExpressionEvaluator::resolveResultVector(const processor::ResultSet& /*resultSet*/, + storage::MemoryManager* memoryManager) { + resultVector = std::make_shared<ValueVector>(expression->getDataType().copy(), memoryManager); + std::vector<ExpressionEvaluator*> inputEvaluators; + inputEvaluators.reserve(children.size()); + for (auto& child : children) { + inputEvaluators.push_back(child.get()); + } + resolveResultStateFromChildren(inputEvaluators); +} + +} // namespace evaluator +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/expression_evaluator/pattern_evaluator.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/expression_evaluator/pattern_evaluator.cpp new file mode 100644 index 0000000000..3fb1245244 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/expression_evaluator/pattern_evaluator.cpp @@ -0,0 +1,88 @@ +#include "expression_evaluator/pattern_evaluator.h" + +#include "common/constants.h" +#include "function/struct/vector_struct_functions.h" + +using namespace lbug::storage; +using namespace lbug::main; +using namespace lbug::common; +using namespace lbug::function; +using namespace lbug::processor; + +namespace lbug { +namespace evaluator { + +static void updateNullPattern(ValueVector& patternVector, const ValueVector& idVector) { + // If internal id is NULL, we should mark entire struct/node/rel as NULL.
+ for (auto i = 0u; i < patternVector.state->getSelVector().getSelSize(); ++i) { + auto pos = patternVector.state->getSelVector()[i]; + patternVector.setNull(pos, idVector.isNull(pos)); + } +} + +void PatternExpressionEvaluator::evaluate() { + for (auto& child : children) { + child->evaluate(); + } + StructPackFunctions::execFunc(parameters, SelectionVector::fromValueVectors(parameters), + *resultVector, resultVector->getSelVectorPtr()); + updateNullPattern(*resultVector, *idVector); +} + +void PatternExpressionEvaluator::resolveResultVector(const ResultSet& resultSet, + MemoryManager* memoryManager) { + const auto& dataType = expression->getDataType(); + resultVector = std::make_shared(dataType.copy(), memoryManager); + std::vector inputEvaluators; + inputEvaluators.reserve(children.size()); + for (auto& child : children) { + parameters.push_back(child->resultVector); + inputEvaluators.push_back(child.get()); + } + resolveResultStateFromChildren(inputEvaluators); + initFurther(resultSet); +} + +void PatternExpressionEvaluator::initFurther(const ResultSet&) { + StructPackFunctions::compileFunc(nullptr, parameters, resultVector); + const auto& dataType = expression->getDataType(); + auto fieldIdx = StructType::getFieldIdx(dataType.copy(), InternalKeyword::ID); + KU_ASSERT(fieldIdx != INVALID_STRUCT_FIELD_IDX); + idVector = StructVector::getFieldVector(resultVector.get(), fieldIdx).get(); +} + +void UndirectedRelExpressionEvaluator::evaluate() { + for (auto& child : children) { + child->evaluate(); + } + StructPackFunctions::undirectedRelPackExecFunc(parameters, *resultVector); + updateNullPattern(*resultVector, *idVector); + directionEvaluator->evaluate(); + auto& selVector = resultVector->state->getSelVector(); + for (auto i = 0u; i < selVector.getSelSize(); ++i) { + if (!directionVector->getValue(directionVector->state->getSelVector()[i])) { + continue; + } + auto pos = selVector[i]; + auto srcID = srcIDVector->getValue(pos); + auto dstID = dstIDVector->getValue(pos); + srcIDVector->setValue(pos, dstID); + dstIDVector->setValue(pos, srcID); + } +} + +void UndirectedRelExpressionEvaluator::initFurther(const ResultSet& resultSet) { + directionEvaluator->init(resultSet, localState.clientContext); + directionVector = directionEvaluator->resultVector.get(); + StructPackFunctions::undirectedRelCompileFunc(nullptr, parameters, resultVector); + const auto& dataType = expression->getDataType(); + auto idFieldIdx = StructType::getFieldIdx(dataType, InternalKeyword::ID); + auto srcFieldIdx = StructType::getFieldIdx(dataType, InternalKeyword::SRC); + auto dstFieldIdx = StructType::getFieldIdx(dataType, InternalKeyword::DST); + idVector = StructVector::getFieldVector(resultVector.get(), idFieldIdx).get(); + srcIDVector = StructVector::getFieldVector(resultVector.get(), srcFieldIdx).get(); + dstIDVector = StructVector::getFieldVector(resultVector.get(), dstFieldIdx).get(); +} + +} // namespace evaluator +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/expression_evaluator/reference_evaluator.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/expression_evaluator/reference_evaluator.cpp new file mode 100644 index 0000000000..1f40db48e8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/expression_evaluator/reference_evaluator.cpp @@ -0,0 +1,34 @@ +#include "expression_evaluator/reference_evaluator.h" + +using namespace lbug::common; +using namespace lbug::main; + +namespace lbug { +namespace evaluator { + +inline static bool isTrue(ValueVector& vector, uint64_t pos) { + 
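+ // A NULL boolean is treated as false.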
KU_ASSERT(vector.dataType.getLogicalTypeID() == LogicalTypeID::BOOL); + return !vector.isNull(pos) && vector.getValue<bool>(pos); +} + +bool ReferenceExpressionEvaluator::selectInternal(SelectionVector& selVector) { + uint64_t numSelectedValues = 0; + auto selectedBuffer = resultVector->state->getSelVectorUnsafe().getMutableBuffer(); + if (resultVector->state->getSelVector().isUnfiltered()) { + for (auto i = 0u; i < resultVector->state->getSelVector().getSelSize(); i++) { + selectedBuffer[numSelectedValues] = i; + numSelectedValues += isTrue(*resultVector, i); + } + } else { + for (auto i = 0u; i < resultVector->state->getSelVector().getSelSize(); i++) { + auto pos = resultVector->state->getSelVector()[i]; + selectedBuffer[numSelectedValues] = pos; + numSelectedValues += isTrue(*resultVector, pos); + } + } + selVector.setSelSize(numSelectedValues); + return numSelectedValues > 0; +} + +} // namespace evaluator +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/extension/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/extension/CMakeLists.txt new file mode 100644 index 0000000000..b3ff980655 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/extension/CMakeLists.txt @@ -0,0 +1,70 @@ +set(INCLUDES "") +set(LOAD_LINKED_EXTENSION "") + +foreach (EXT_NAME IN LISTS STATICALLY_LINKED_EXTENSIONS) + string(TOLOWER "${EXT_NAME}" EXT_NAME_LOWER) + # Convert to camel case + string(REPLACE "_" ";" EXT_NAME_SPLIT ${EXT_NAME}) + set(EXT_NAME_CAMELCASE "") + foreach (EXT_NAME_PART IN LISTS EXT_NAME_SPLIT) + string(SUBSTRING ${EXT_NAME_PART} 0 1 FIRST_LETTER) + string(SUBSTRING ${EXT_NAME_PART} 1 -1 REMAINDER) + string(TOUPPER ${FIRST_LETTER} FIRST_LETTER) + set(EXT_NAME_CAMELCASE "${EXT_NAME_CAMELCASE}${FIRST_LETTER}${REMAINDER}") + endforeach () + + set(LOAD_LINKED_EXTENSION "${LOAD_LINKED_EXTENSION}\ +{ + ${EXT_NAME_LOWER}_extension::${EXT_NAME_CAMELCASE}Extension extension{}; + extension.load(context); + loadedExtensions.push_back(LoadedExtension(${EXT_NAME_LOWER}_extension::${EXT_NAME_CAMELCASE}Extension::EXTENSION_NAME, \" \", + ExtensionSource::STATIC_LINKED)); + }\n") + include_directories(${PROJECT_SOURCE_DIR}/extension/${EXT_NAME}/src/include/main) + set(INCLUDES "${INCLUDES}#include \"${EXT_NAME_LOWER}_extension.h\"\n") +endforeach () + +configure_file( + "generated_extension_loader.h.in" + "${CMAKE_CURRENT_BINARY_DIR}/codegen/include/generated_extension_loader.h" + @ONLY +) + +configure_file( + "generated_extension_loader.cpp.in" + "${CMAKE_CURRENT_BINARY_DIR}/codegen/generated_extension_loader.cpp" + @ONLY +) + +set(GENERATED_CPP_FILE + ${CMAKE_CURRENT_BINARY_DIR}/codegen/generated_extension_loader.cpp) + +include_directories( + "${CMAKE_CURRENT_BINARY_DIR}/codegen/include/" + ${PROJECT_SOURCE_DIR}/third_party/httplib +) + +add_library(lbug_generated_extension_loader OBJECT ${GENERATED_CPP_FILE}) + +# Both the lbug source and the httpfs extension include httplib, which is a +# header-only library. Httpfs requires OpenSSL support to be enabled in httplib, so +# we also have to enable it here if httpfs is statically linked.
+if ("httpfs" IN_LIST STATICALLY_LINKED_EXTENSIONS) + add_compile_definitions(CPPHTTPLIB_OPENSSL_SUPPORT) + include_directories(${OPENSSL_INCLUDE_DIR}) +endif () + +add_library(lbug_extension + OBJECT + catalog_extension.cpp + extension.cpp + extension_entries.cpp + extension_installer.cpp + extension_manager.cpp + loaded_extension.cpp +) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/extension/catalog_extension.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/extension/catalog_extension.cpp new file mode 100644 index 0000000000..2518e71294 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/extension/catalog_extension.cpp @@ -0,0 +1,12 @@ +#include "extension/catalog_extension.h" + +namespace lbug { +namespace extension { + +void CatalogExtension::invalidateCache() { + tables = std::make_unique(); + init(); +} + +} // namespace extension +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/extension/extension.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/extension/extension.cpp new file mode 100644 index 0000000000..9895591f37 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/extension/extension.cpp @@ -0,0 +1,247 @@ +#include "extension/extension.h" + +#include "common/exception/io.h" +#include "common/string_format.h" +#include "common/string_utils.h" +#include "common/system_message.h" +#include "main/client_context.h" +#include "main/database.h" +#include "storage/storage_manager.h" + +#ifdef _WIN32 + +#include "windows.h" +#define RTLD_NOW 0 +#define RTLD_LOCAL 0 + +#else +#include +#endif + +namespace lbug { +namespace extension { + +std::string getOS() { + std::string os = "linux"; +#if !defined(_GLIBCXX_USE_CXX11_ABI) || _GLIBCXX_USE_CXX11_ABI == 0 + if (os == "linux") { + os = "linux_old"; + } +#endif +#ifdef _WIN32 + os = "win"; +#elif defined(__APPLE__) + os = "osx"; +#endif + return os; +} + +std::string getArch() { + std::string arch = "amd64"; +#if defined(__i386__) || defined(_M_IX86) + arch = "x86"; +#elif defined(__aarch64__) || defined(__ARM_ARCH_ISA_A64) + arch = "arm64"; +#endif + return arch; +} + +std::string getPlatform() { + return getOS() + "_" + getArch(); +} + +static ExtensionRepoInfo getExtensionRepoInfo(std::string& extensionURL) { + common::StringUtils::replaceAll(extensionURL, "http://", ""); + auto hostNamePos = extensionURL.find('/'); + auto hostName = extensionURL.substr(0, hostNamePos); + auto hostURL = "http://" + hostName; + auto hostPath = extensionURL.substr(hostNamePos); + return {hostPath, hostURL, extensionURL}; +} + +std::string ExtensionSourceUtils::toString(ExtensionSource source) { + switch (source) { + case ExtensionSource::OFFICIAL: + return "OFFICIAL"; + case ExtensionSource::USER: + return "USER"; + case ExtensionSource::STATIC_LINKED: + return "STATIC LINK"; + default: + KU_UNREACHABLE; + } +} + +static ExtensionRepoInfo getExtensionFilePath(const std::string& extensionName, + const std::string& extensionRepo, const std::string& fileName) { + auto extensionURL = common::stringFormat(ExtensionUtils::EXTENSION_FILE_REPO_PATH, + extensionRepo, LBUG_EXTENSION_VERSION, getPlatform(), extensionName, fileName); + return getExtensionRepoInfo(extensionURL); +} + +ExtensionRepoInfo ExtensionUtils::getExtensionLibRepoInfo(const std::string& extensionName, + const std::string& extensionRepo) { + return getExtensionFilePath(extensionName, extensionRepo, getExtensionFileName(extensionName)); +} + +ExtensionRepoInfo 
ExtensionUtils::getExtensionLoaderRepoInfo(const std::string& extensionName, + const std::string& extensionRepo) { + return getExtensionFilePath(extensionName, extensionRepo, + getExtensionFileName(extensionName + EXTENSION_LOADER_SUFFIX)); +} + +ExtensionRepoInfo ExtensionUtils::getExtensionInstallerRepoInfo(const std::string& extensionName, + const std::string& extensionRepo) { + return getExtensionFilePath(extensionName, extensionRepo, + getExtensionFileName(extensionName + EXTENSION_INSTALLER_SUFFIX)); +} + +ExtensionRepoInfo ExtensionUtils::getSharedLibRepoInfo(const std::string& fileName, + const std::string& extensionRepo) { + auto extensionURL = common::stringFormat(SHARED_LIB_REPO, extensionRepo, LBUG_EXTENSION_VERSION, + getPlatform(), fileName); + return getExtensionRepoInfo(extensionURL); +} + +std::string ExtensionUtils::getExtensionFileName(const std::string& name) { + return common::stringFormat(EXTENSION_FILE_NAME, common::StringUtils::getLower(name), + EXTENSION_FILE_SUFFIX); +} + +std::string ExtensionUtils::getLocalPathForExtensionLib(main::ClientContext* context, + const std::string& extensionName) { + return common::stringFormat("{}/{}", getLocalDirForExtension(context, extensionName), + getExtensionFileName(extensionName)); +} + +std::string ExtensionUtils::getLocalPathForExtensionLoader(main::ClientContext* context, + const std::string& extensionName) { + return common::stringFormat("{}/{}", getLocalDirForExtension(context, extensionName), + getExtensionFileName(extensionName + EXTENSION_LOADER_SUFFIX)); +} + +std::string ExtensionUtils::getLocalPathForExtensionInstaller(main::ClientContext* context, + const std::string& extensionName) { + return common::stringFormat("{}/{}", getLocalDirForExtension(context, extensionName), + getExtensionFileName(extensionName + EXTENSION_INSTALLER_SUFFIX)); +} + +std::string ExtensionUtils::getLocalDirForExtension(main::ClientContext* context, + const std::string& extensionName) { + return common::stringFormat("{}{}", context->getExtensionDir(), extensionName); +} + +std::string ExtensionUtils::appendLibSuffix(const std::string& libName) { + auto os = getOS(); + std::string suffix; + if (os == "linux" || os == "linux_old") { + suffix = "so"; + } else if (os == "osx") { + suffix = "dylib"; + } else { + KU_UNREACHABLE; + } + return common::stringFormat("{}.{}", libName, suffix); +} + +std::string ExtensionUtils::getLocalPathForSharedLib(main::ClientContext* context, + const std::string& libName) { + return common::stringFormat("{}common/{}", context->getExtensionDir(), libName); +} + +std::string ExtensionUtils::getLocalPathForSharedLib(main::ClientContext* context) { + return common::stringFormat("{}common/", context->getExtensionDir()); +} + +bool ExtensionUtils::isOfficialExtension(const std::string& extension) { + auto extensionUpperCase = common::StringUtils::getUpper(extension); + for (auto& officialExtension : OFFICIAL_EXTENSION) { + if (officialExtension == extensionUpperCase) { + return true; + } + } + return false; +} + +void ExtensionUtils::registerIndexType(main::Database& database, storage::IndexType type) { + database.getStorageManager()->registerIndexType(std::move(type)); +} + +ExtensionLibLoader::ExtensionLibLoader(const std::string& extensionName, const std::string& path) + : extensionName{extensionName} { + libHdl = dlopen(path.c_str(), RTLD_NOW | RTLD_LOCAL); + if (libHdl == nullptr) { + throw common::IOException(common::stringFormat( + "Failed to load library: {} which is needed by extension: {}.\nError: {}.", 
path, + extensionName, common::dlErrMessage())); + } +} + +ext_load_func_t ExtensionLibLoader::getLoadFunc() { + return (ext_load_func_t)getDynamicLibFunc(EXTENSION_LOAD_FUNC_NAME); +} + +ext_init_func_t ExtensionLibLoader::getInitFunc() { + return (ext_init_func_t)getDynamicLibFunc(EXTENSION_INIT_FUNC_NAME); +} + +ext_name_func_t ExtensionLibLoader::getNameFunc() { + return (ext_name_func_t)getDynamicLibFunc(EXTENSION_NAME_FUNC_NAME); +} + +ext_install_func_t ExtensionLibLoader::getInstallFunc() { + return (ext_install_func_t)getDynamicLibFunc(EXTENSION_INSTALL_FUNC_NAME); +} + +void ExtensionLibLoader::unload() { + KU_ASSERT(libHdl != nullptr); + dlclose(libHdl); + libHdl = nullptr; +} + +void* ExtensionLibLoader::getDynamicLibFunc(const std::string& funcName) { + KU_ASSERT(libHdl != nullptr); + auto sym = dlsym(libHdl, funcName.c_str()); + if (sym == nullptr) { + throw common::IOException( + common::stringFormat("Failed to load {} function in extension {}.\nError: {}", funcName, + extensionName, common::dlErrMessage())); + } + return sym; +} + +#ifdef _WIN32 +std::wstring utf8ToUnicode(const char* input) { + uint32_t result; + + result = MultiByteToWideChar(CP_UTF8, 0, input, -1, nullptr, 0); + if (result == 0) { + throw common::IOException("Failure in MultiByteToWideChar"); + } + auto buffer = std::make_unique(result); + result = MultiByteToWideChar(CP_UTF8, 0, input, -1, buffer.get(), result); + if (result == 0) { + throw common::IOException("Failure in MultiByteToWideChar"); + } + return std::wstring(buffer.get(), result); +} + +void* dlopen(const char* file, int /*mode*/) { + KU_ASSERT(file); + auto fpath = utf8ToUnicode(file); + return (void*)LoadLibraryW(fpath.c_str()); +} + +void* dlsym(void* handle, const char* name) { + KU_ASSERT(handle); + return (void*)GetProcAddress((HINSTANCE)handle, name); +} + +void dlclose(void* handle) { + KU_ASSERT(handle); + FreeLibrary((HINSTANCE)handle); +} +#endif + +} // namespace extension +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/extension/extension_entries.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/extension/extension_entries.cpp new file mode 100644 index 0000000000..ae3029cbb3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/extension/extension_entries.cpp @@ -0,0 +1,75 @@ +#include "common/string_utils.h" +#include "extension/extension_manager.h" + +namespace lbug { +namespace extension { + +struct EntriesForExtension { + const char* extensionName; + std::span entries; + size_t numEntries; +}; + +static constexpr std::array ftsExtensionFunctions = {"STEM", "QUERY_FTS_INDEX", "CREATE_FTS_INDEX", + "DROP_FTS_INDEX"}; +static constexpr std::array jsonExtensionFunctions = {"TO_JSON", "JSON_QUOTE", "ARRAY_TO_JSON", + "ROW_TO_JSON", "CAST_TO_JSON", "JSON_ARRAY", "JSON_OBJECT", "JSON_MERGE_PATCH", "COPY_JSON", + "JSON_EXTRACT", "JSON_ARRAY_LENGTH", "JSON_CONTAINS", "JSON_KEYS", "JSON_STRUCTURE", + "JSON_TYPE", "JSON_VALID", "JSON"}; +static constexpr std::array duckdbExtensionFunctions = {"CLEAR_ATTACHED_DB_CACHE"}; +static constexpr std::array deltaExtensionFunctions = {"DELTA_SCAN"}; +static constexpr std::array icebergExtensionFunctions = {"ICEBERG_SCAN", "ICEBERG_METADATA", + "ICEBERG_SNAPSHOTS"}; +static constexpr std::array azureExtensionFunctions = {"AZURE_SCAN"}; +static constexpr std::array vectorExtensionFunctions = {"QUERY_VECTOR_INDEX", "CREATE_VECTOR_INDEX", + "DROP_VECTOR_INDEX"}; +static constexpr std::array llmExtensionFunctions = {"CREATE_EMBEDDING"}; +static constexpr std::array 
neo4jExtensionFunctions = {"NEO4J_MIGRATE"}; +static constexpr std::array algoExtensionFunctions = {"K_CORE_DECOMPOSITION", "PAGE_RANK", + "STRONGLY_CONNECTED_COMPONENTS_KOSARAJU", "STRONGLY_CONNECTED_COMPONENTS", + "WEAKLY_CONNECTED_COMPONENTS"}; + +static constexpr EntriesForExtension functionsForExtensionsRaw[] = { + {"FTS", ftsExtensionFunctions, ftsExtensionFunctions.size()}, + {"DUCKDB", duckdbExtensionFunctions, duckdbExtensionFunctions.size()}, + {"DELTA", deltaExtensionFunctions, deltaExtensionFunctions.size()}, + {"ICEBERG", icebergExtensionFunctions, icebergExtensionFunctions.size()}, + {"AZURE", azureExtensionFunctions, azureExtensionFunctions.size()}, + {"JSON", jsonExtensionFunctions, jsonExtensionFunctions.size()}, + {"VECTOR", vectorExtensionFunctions, vectorExtensionFunctions.size()}, + {"LLM", llmExtensionFunctions, llmExtensionFunctions.size()}, + {"NEO4J", neo4jExtensionFunctions, neo4jExtensionFunctions.size()}, + {"ALGO", algoExtensionFunctions, algoExtensionFunctions.size()}, +}; +static constexpr std::array functionsForExtensions = std::to_array(functionsForExtensionsRaw); + +static constexpr std::array jsonExtensionTypes = {"JSON"}; +static constexpr std::array typesForExtensions = { + EntriesForExtension{"JSON", jsonExtensionTypes, jsonExtensionTypes.size()}}; + +static std::optional<ExtensionEntry> lookupExtensionsByEntryName(std::string_view functionName, + std::span<const EntriesForExtension> entriesForExtensions) { + for (const auto extension : entriesForExtensions) { + for (const auto* entry : extension.entries) { + if (entry == functionName) { + return ExtensionEntry{.name = entry, .extensionName = extension.extensionName}; + } + } + } + return {}; +} + +std::optional<ExtensionEntry> ExtensionManager::lookupExtensionsByFunctionName( + std::string_view functionName) { + return lookupExtensionsByEntryName(common::StringUtils::getUpper(functionName), + functionsForExtensions); +} + +std::optional<ExtensionEntry> ExtensionManager::lookupExtensionsByTypeName( + std::string_view typeName) { + return lookupExtensionsByEntryName(common::StringUtils::getUpper(typeName), typesForExtensions); +} + +} // namespace extension +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/extension/extension_installer.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/extension/extension_installer.cpp new file mode 100644 index 0000000000..16fa63b158 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/extension/extension_installer.cpp @@ -0,0 +1,94 @@ +#include "extension/extension_installer.h" + +#include "common/exception/io.h" +#include "common/file_system/virtual_file_system.h" +#include "httplib.h" +#include "main/client_context.h" + +namespace lbug { +namespace extension { + +void ExtensionInstaller::tryDownloadExtensionFile(const ExtensionRepoInfo& repoInfo, + const std::string& localFilePath) { + httplib::Client cli(repoInfo.hostURL.c_str()); + httplib::Headers headers = { + {"User-Agent", common::stringFormat("lbug/v{}", LBUG_EXTENSION_VERSION)}}; + auto res = cli.Get(repoInfo.hostPath.c_str(), headers); + if (!res || res->status != 200) { + if (res.error() == httplib::Error::Success) { + // LCOV_EXCL_START + throw common::IOException(common::stringFormat( + "HTTP Returns: {}, Failed to download extension: \"{}\" from {}.", + res.value().status, info.name, repoInfo.repoURL)); + // LCOV_EXCL_STOP + } else { + throw common::IOException( + common::stringFormat("Failed to download extension: {} at URL {} (ERROR: {})", + info.name, repoInfo.repoURL, to_string(res.error()))); + } + } + + auto vfs =
common::VirtualFileSystem::GetUnsafe(context); + auto fileInfo = vfs->openFile(localFilePath, + common::FileOpenFlags(common::FileFlags::WRITE | common::FileFlags::READ_ONLY | + common::FileFlags::CREATE_AND_TRUNCATE_IF_EXISTS)); + fileInfo->writeFile(reinterpret_cast(res->body.c_str()), res->body.size(), + 0 /* offset */); + fileInfo->syncFile(); +} + +bool ExtensionInstaller::install() { + auto install = installExtension(); + if (install) { + installDependencies(); + } + return install; +} + +bool ExtensionInstaller::installExtension() { + auto vfs = common::VirtualFileSystem::GetUnsafe(context); + auto localExtensionDir = context.getExtensionDir(); + if (!vfs->fileOrPathExists(localExtensionDir, &context)) { + vfs->createDir(localExtensionDir); + } + auto localDirForExtension = + extension::ExtensionUtils::getLocalDirForExtension(&context, info.name); + if (!vfs->fileOrPathExists(localDirForExtension)) { + vfs->createDir(localDirForExtension); + } + auto localLibFilePath = + extension::ExtensionUtils::getLocalPathForExtensionLib(&context, info.name); + if (vfs->fileOrPathExists(localLibFilePath) && !info.forceInstall) { + // The extension has been installed, skip downloading from the repo. + return false; + } + auto localDirForSharedLib = extension::ExtensionUtils::getLocalPathForSharedLib(&context); + if (!vfs->fileOrPathExists(localDirForSharedLib)) { + vfs->createDir(localDirForSharedLib); + } + auto libFileRepoInfo = extension::ExtensionUtils::getExtensionLibRepoInfo(info.name, info.repo); + + tryDownloadExtensionFile(libFileRepoInfo, localLibFilePath); + return true; +} + +void ExtensionInstaller::installDependencies() { + auto extensionRepoInfo = ExtensionUtils::getExtensionInstallerRepoInfo(info.name, info.repo); + httplib::Client cli(extensionRepoInfo.hostURL.c_str()); + httplib::Headers headers = { + {"User-Agent", common::stringFormat("lbug/v{}", LBUG_EXTENSION_VERSION)}}; + auto res = cli.Get(extensionRepoInfo.hostPath.c_str(), headers); + if (!res || res->status != 200) { + // The extension doesn't have an installer. 
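+ // Any failed or non-200 response is treated as "no installer": dependency installation is skipped.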
+ return; + } + auto extensionInstallerPath = + ExtensionUtils::getLocalPathForExtensionInstaller(&context, info.name); + tryDownloadExtensionFile(extensionRepoInfo, extensionInstallerPath); + auto libLoader = ExtensionLibLoader(info.name, extensionInstallerPath.c_str()); + auto install = libLoader.getInstallFunc(); + (*install)(info.repo, context); +} + +} // namespace extension +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/extension/extension_manager.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/extension/extension_manager.cpp new file mode 100644 index 0000000000..d01f841596 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/extension/extension_manager.cpp @@ -0,0 +1,105 @@ +#include "extension/extension_manager.h" + +#include "common/file_system/virtual_file_system.h" +#include "common/string_utils.h" +#include "extension/extension.h" +#include "generated_extension_loader.h" +#include "storage/wal/local_wal.h" +#include "transaction/transaction_context.h" + +namespace lbug { +namespace extension { + +static void executeExtensionLoader(main::ClientContext* context, const std::string& extensionName) { + auto loaderPath = ExtensionUtils::getLocalPathForExtensionLoader(context, extensionName); + if (common::VirtualFileSystem::GetUnsafe(*context)->fileOrPathExists(loaderPath)) { + auto libLoader = ExtensionLibLoader(extensionName, loaderPath); + auto load = libLoader.getLoadFunc(); + (*load)(context); + } +} + +void ExtensionManager::loadExtension(const std::string& path, main::ClientContext* context) { + auto fullPath = path; + bool isOfficial = ExtensionUtils::isOfficialExtension(path); + if (isOfficial) { + auto localPathForSharedLib = ExtensionUtils::getLocalPathForSharedLib(context); + if (!common::VirtualFileSystem::GetUnsafe(*context)->fileOrPathExists( + localPathForSharedLib)) { + common::VirtualFileSystem::GetUnsafe(*context)->createDir(localPathForSharedLib); + } + executeExtensionLoader(context, path); + fullPath = ExtensionUtils::getLocalPathForExtensionLib(context, path); + } + + auto libLoader = ExtensionLibLoader(path, fullPath); + auto name = libLoader.getNameFunc(); + std::string extensionName = (*name)(); + if (std::any_of(loadedExtensions.begin(), loadedExtensions.end(), + [&](const LoadedExtension& ext) { return ext.getExtensionName() == extensionName; })) { + libLoader.unload(); + return; + } + auto init = libLoader.getInitFunc(); + (*init)(context); + loadedExtensions.push_back(LoadedExtension(extensionName, fullPath, + isOfficial ? ExtensionSource::OFFICIAL : ExtensionSource::USER)); + auto transaction = transaction::Transaction::Get(*context); + if (transaction->shouldLogToWAL()) { + transaction->getLocalWAL().logLoadExtension(path); + } +} + +std::string ExtensionManager::toCypher() { + std::string cypher; + for (auto& extension : loadedExtensions) { + cypher += extension.toCypher(); + } + return cypher; +} + +void ExtensionManager::addExtensionOption(std::string name, common::LogicalTypeID type, + common::Value defaultValue, bool isConfidential) { + if (getExtensionOption(name) != nullptr) { + // One extension option can be shared by multiple extensions. + return; + } + common::StringUtils::toLower(name); + extensionOptions.emplace(name, + main::ExtensionOption{name, type, std::move(defaultValue), isConfidential}); +} + +const main::ExtensionOption* ExtensionManager::getExtensionOption(std::string name) const { + common::StringUtils::toLower(name); + return extensionOptions.contains(name) ? 
&extensionOptions.at(name) : nullptr; +} + +void ExtensionManager::registerStorageExtension(std::string name, + std::unique_ptr storageExtension) { + if (storageExtensions.contains(name)) { + return; + } + storageExtensions.emplace(std::move(name), std::move(storageExtension)); +} + +std::vector ExtensionManager::getStorageExtensions() { + std::vector storageExtensionsToReturn; + for (auto& [name, storageExtension] : storageExtensions) { + storageExtensionsToReturn.push_back(storageExtension.get()); + } + return storageExtensionsToReturn; +} + +void ExtensionManager::autoLoadLinkedExtensions(main::ClientContext* context) { + auto trxContext = transaction::TransactionContext::Get(*context); + trxContext->beginRecoveryTransaction(); + loadLinkedExtensions(context, loadedExtensions); + trxContext->commit(); +} + +ExtensionManager* ExtensionManager::Get(const main::ClientContext& context) { + return context.getDatabase()->getExtensionManager(); +} + +} // namespace extension +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/extension/generated_extension_loader.cpp.in b/graph-wasm/lbug-0.12.2/lbug-src/src/extension/generated_extension_loader.cpp.in new file mode 100644 index 0000000000..dd8e7c5e3f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/extension/generated_extension_loader.cpp.in @@ -0,0 +1,13 @@ +#include "extension/loaded_extension.h" +#include "generated_extension_loader.h" + +namespace lbug { +namespace extension { + +void loadLinkedExtensions([[maybe_unused]] main::ClientContext* context, + [[maybe_unused]] std::vector& loadedExtensions) { + @LOAD_LINKED_EXTENSION@ +} + +} // namespace extension +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/extension/generated_extension_loader.h.in b/graph-wasm/lbug-0.12.2/lbug-src/src/extension/generated_extension_loader.h.in new file mode 100644 index 0000000000..0e60030ca8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/extension/generated_extension_loader.h.in @@ -0,0 +1,12 @@ +#pragma once + +#include "main/client_context.h" +@INCLUDES@ + +namespace lbug { +namespace extension { + +void loadLinkedExtensions(main::ClientContext* context, std::vector& loadedExtensions); + +} // namespace extension +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/extension/loaded_extension.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/extension/loaded_extension.cpp new file mode 100644 index 0000000000..041771ac4e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/extension/loaded_extension.cpp @@ -0,0 +1,23 @@ +#include "extension/loaded_extension.h" + +#include "common/assert.h" + +namespace lbug { +namespace extension { + +std::string LoadedExtension::toCypher() { + switch (source) { + case ExtensionSource::OFFICIAL: + return common::stringFormat("INSTALL {};\nLOAD EXTENSION {};\n", extensionName, + extensionName); + case ExtensionSource::USER: + return common::stringFormat("LOAD EXTENSION '{}';\n", fullPath); + case ExtensionSource::STATIC_LINKED: + return ""; + default: + KU_UNREACHABLE; + } +} + +} // namespace extension +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/function/CMakeLists.txt new file mode 100644 index 0000000000..43de0feb64 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/CMakeLists.txt @@ -0,0 +1,48 @@ +add_subdirectory(aggregate) +add_subdirectory(arithmetic) +add_subdirectory(array) +add_subdirectory(cast) +add_subdirectory(date) +add_subdirectory(gds) 
+add_subdirectory(list) +add_subdirectory(map) +add_subdirectory(path) +add_subdirectory(pattern) +add_subdirectory(sequence) +add_subdirectory(struct) +add_subdirectory(table) +add_subdirectory(union) +add_subdirectory(utility) +add_subdirectory(uuid) +add_subdirectory(string) +add_subdirectory(export) +add_subdirectory(internal_id) +add_subdirectory(timestamp) + +add_library(lbug_function + OBJECT + aggregate_function.cpp + base_lower_upper_operation.cpp + built_in_function_utils.cpp + cast_string_non_nested_functions.cpp + cast_from_string_functions.cpp + comparison_functions.cpp + find_function.cpp + function.cpp + function_collection.cpp + scalar_macro_function.cpp + vector_arithmetic_functions.cpp + vector_boolean_functions.cpp + vector_cast_functions.cpp + vector_date_functions.cpp + vector_hash_functions.cpp + vector_null_functions.cpp + vector_node_rel_functions.cpp + vector_string_functions.cpp + vector_timestamp_functions.cpp + vector_blob_functions.cpp + vector_uuid_functions.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/aggregate/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/function/aggregate/CMakeLists.txt new file mode 100644 index 0000000000..ba4550c3f7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/aggregate/CMakeLists.txt @@ -0,0 +1,12 @@ +add_library(lbug_function_aggregate + OBJECT + count.cpp + count_star.cpp + collect.cpp + min_max.cpp + sum.cpp + avg.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/aggregate/avg.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/aggregate/avg.cpp new file mode 100644 index 0000000000..d84e1ce434 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/aggregate/avg.cpp @@ -0,0 +1,17 @@ +#include "function/aggregate/avg.h" + +namespace lbug { +namespace function { + +using namespace lbug::common; + +function_set AggregateAvgFunction::getFunctionSet() { + function_set result; + for (auto typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { + AggregateFunctionUtils::appendSumOrAvgFuncs(name, typeID, result); + } + return result; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/aggregate/collect.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/aggregate/collect.cpp new file mode 100644 index 0000000000..4a78e202c9 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/aggregate/collect.cpp @@ -0,0 +1,185 @@ +#include "function/aggregate_function.h" +#include "storage/storage_utils.h" + +using namespace lbug::binder; +using namespace lbug::common; +using namespace lbug::storage; +using namespace lbug::processor; + +namespace lbug { +namespace function { + +/** + * For collect each grouped key corresponds to a list of values + * We store this value as a linked list where each element is allocated from the shared overflow + * buffer + * The format for each element in the list is the following in order: + * - The value of the current element + * - The pointer to the next element in the list + */ +struct CollectListElement { + CollectListElement() : elementPtr(nullptr) {} + explicit CollectListElement(uint8_t* elementPtr) : elementPtr(elementPtr) {} + + CollectListElement getNextElement() const { return CollectListElement{*getNextElementPtr()}; } + uint8_t** getNextElementPtr() const { return reinterpret_cast(elementPtr); } + void setNextElement(CollectListElement next) const 
{ + KU_ASSERT(*getNextElementPtr() == nullptr); + *getNextElementPtr() = next.elementPtr; + } + void setNextElement(std::nullptr_t next) const { *getNextElementPtr() = next; } + uint8_t* getDataPtr() const { return elementPtr + sizeof(uint8_t*); } + + static uint64_t size(LogicalType& elementType) { + return sizeof(uint8_t*) + StorageUtils::getDataTypeSize(elementType); + } + + bool valid() const { return elementPtr; } + + uint8_t* elementPtr; +}; + +struct CollectState : public AggregateStateWithNull { + CollectState() = default; + uint32_t getStateSize() const override { return sizeof(*this); } + void writeToVector(common::ValueVector* outputVector, uint64_t pos) override; + + void appendElement(ValueVector* input, uint32_t pos, InMemOverflowBuffer* overflowBuffer); + void resetList(); + void appendList(const CollectState& o); + + // We store the head + tail of the linked list + CollectListElement head; + CollectListElement tail; + uint64_t listSize = 0; + + // CollectStates are stored in factorizedTable entries. When the factorizedTable is + // destructed, the destructor of CollectStates won't be called. Therefore, we need to make sure + // that no additional actions are required for destructing the head/tail outside of deallocating + // their memory + + static_assert(std::is_trivially_destructible_v); +}; + +void CollectState::appendList(const CollectState& o) { + if (head.valid()) { + KU_ASSERT(tail.valid()); + tail.setNextElement(o.head); + tail = o.tail; + } else { + head = o.head; + tail = o.tail; + } + listSize += o.listSize; +} + +void CollectState::appendElement(ValueVector* input, uint32_t pos, + InMemOverflowBuffer* overflowBuffer) { + CollectListElement newElement{ + overflowBuffer->allocateSpace(CollectListElement::size(input->dataType))}; + newElement.setNextElement(nullptr); + input->copyToRowData(pos, newElement.getDataPtr(), overflowBuffer); + + if (tail.valid()) { + tail.setNextElement(newElement); + } else { + KU_ASSERT(!head.valid()); + head = newElement; + } + tail = newElement; + + ++listSize; +} + +void CollectState::resetList() { + head = {}; + tail = {}; + listSize = 0; +} + +void CollectState::writeToVector(common::ValueVector* outputVector, uint64_t pos) { + auto listEntry = common::ListVector::addList(outputVector, listSize); + outputVector->setValue(pos, listEntry); + auto outputDataVector = common::ListVector::getDataVector(outputVector); + CollectListElement curElement = head; + for (auto i = 0u; i < listEntry.size; i++) { + KU_ASSERT(curElement.valid()); + outputDataVector->copyFromRowData(listEntry.offset + i, curElement.getDataPtr()); + curElement = curElement.getNextElement(); + } +} + +static std::unique_ptr initialize() { + return std::make_unique(); +} + +static void updateSingleValue(CollectState* state, ValueVector* input, uint32_t pos, + uint64_t multiplicity, InMemOverflowBuffer* overflowBuffer) { + for (auto i = 0u; i < multiplicity; ++i) { + state->isNull = false; + state->appendElement(input, pos, overflowBuffer); + } +} + +static void updateAll(uint8_t* state_, ValueVector* input, uint64_t multiplicity, + InMemOverflowBuffer* overflowBuffer) { + KU_ASSERT(!input->state->isFlat()); + auto state = reinterpret_cast(state_); + auto& inputSelVector = input->state->getSelVector(); + if (input->hasNoNullsGuarantee()) { + for (auto i = 0u; i < inputSelVector.getSelSize(); ++i) { + auto pos = inputSelVector[i]; + updateSingleValue(state, input, pos, multiplicity, overflowBuffer); + } + } else { + for (auto i = 0u; i < inputSelVector.getSelSize(); ++i) 
{ + auto pos = inputSelVector[i]; + if (!input->isNull(pos)) { + updateSingleValue(state, input, pos, multiplicity, overflowBuffer); + } + } + } +} + +static void updatePos(uint8_t* state_, ValueVector* input, uint64_t multiplicity, uint32_t pos, + InMemOverflowBuffer* overflowBuffer) { + auto state = reinterpret_cast(state_); + updateSingleValue(state, input, pos, multiplicity, overflowBuffer); +} + +static void finalize(uint8_t* /*state_*/) {} + +static void combine(uint8_t* state_, uint8_t* otherState_, + InMemOverflowBuffer* /*overflowBuffer*/) { + auto otherState = reinterpret_cast(otherState_); + if (otherState->isNull) { + return; + } + auto state = reinterpret_cast(state_); + state->appendList(*otherState); + state->isNull = false; + otherState->resetList(); + otherState->isNull = true; +} + +static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { + KU_ASSERT(input.arguments.size() == 1); + auto aggFuncDefinition = reinterpret_cast(input.definition); + aggFuncDefinition->parameterTypeIDs[0] = input.arguments[0]->dataType.getLogicalTypeID(); + auto returnType = LogicalType::LIST(input.arguments[0]->dataType.copy()); + return std::make_unique(std::move(returnType)); +} + +function_set CollectFunction::getFunctionSet() { + function_set result; + for (auto isDistinct : std::vector{true, false}) { + result.push_back(std::make_unique(name, + std::vector{LogicalTypeID::ANY}, LogicalTypeID::LIST, initialize, + updateAll, updatePos, combine, finalize, isDistinct, bindFunc, + nullptr /* paramRewriteFunc */)); + } + return result; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/aggregate/count.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/aggregate/count.cpp new file mode 100644 index 0000000000..d6321b7607 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/aggregate/count.cpp @@ -0,0 +1,43 @@ +#include "function/aggregate/count.h" + +#include "binder/expression/expression_util.h" +#include "binder/expression/node_expression.h" +#include "binder/expression/rel_expression.h" + +using namespace lbug::common; +using namespace lbug::storage; +using namespace lbug::binder; + +namespace lbug { +namespace function { + +void CountFunction::updateAll(uint8_t* state_, ValueVector* input, uint64_t multiplicity, + InMemOverflowBuffer* /*overflowBuffer*/) { + auto state = reinterpret_cast(state_); + state->count += multiplicity * input->countNonNull(); +} + +void CountFunction::paramRewriteFunc(expression_vector& arguments) { + KU_ASSERT(arguments.size() == 1); + if (ExpressionUtil::isNodePattern(*arguments[0])) { + arguments[0] = arguments[0]->constCast().getInternalID(); + } else if (ExpressionUtil::isRelPattern(*arguments[0])) { + arguments[0] = arguments[0]->constCast().getInternalID(); + } +} + +function_set CountFunction::getFunctionSet() { + function_set result; + for (auto& type : LogicalTypeUtils::getAllValidLogicTypeIDs()) { + for (auto isDistinct : std::vector{true, false}) { + auto func = AggregateFunctionUtils::getAggFunc(name, type, + LogicalTypeID::INT64, isDistinct, paramRewriteFunc); + func->needToHandleNulls = true; + result.push_back(std::move(func)); + } + } + return result; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/aggregate/count_star.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/aggregate/count_star.cpp new file mode 100644 index 0000000000..8e354fb258 --- /dev/null +++ 
b/graph-wasm/lbug-0.12.2/lbug-src/src/function/aggregate/count_star.cpp @@ -0,0 +1,35 @@ +#include "function/aggregate/count_star.h" + +using namespace lbug::common; +using namespace lbug::storage; + +namespace lbug { +namespace function { + +void CountStarFunction::updateAll(uint8_t* state_, ValueVector* input, uint64_t multiplicity, + InMemOverflowBuffer* /*overflowBuffer*/) { + auto state = reinterpret_cast(state_); + KU_ASSERT(input == nullptr); + (void)input; + state->count += multiplicity; +} + +void CountStarFunction::updatePos(uint8_t* state_, ValueVector* input, uint64_t multiplicity, + uint32_t /*pos*/, InMemOverflowBuffer* /*overflowBuffer*/) { + auto state = reinterpret_cast(state_); + KU_ASSERT(input == nullptr); + (void)input; + state->count += multiplicity; +} + +function_set CountStarFunction::getFunctionSet() { + function_set result; + auto aggFunc = std::make_unique(name, std::vector{}, + LogicalTypeID::INT64, initialize, updateAll, updatePos, combine, finalize, false); + aggFunc->needToHandleNulls = true; + result.push_back(std::move(aggFunc)); + return result; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/aggregate/min_max.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/aggregate/min_max.cpp new file mode 100644 index 0000000000..98920a903d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/aggregate/min_max.cpp @@ -0,0 +1,45 @@ +#include "function/aggregate/min_max.h" + +#include "common/type_utils.h" +#include "function/comparison/comparison_functions.h" + +namespace lbug { +namespace function { + +using namespace lbug::common; + +template +static void getMinMaxFunction(std::string name, function_set& set) { + std::unique_ptr func; + for (auto& type : LogicalTypeUtils::getAllValidComparableLogicalTypes()) { + auto inputTypes = std::vector{type}; + for (auto isDistinct : std::vector{true, false}) { + common::TypeUtils::visit( + LogicalType::getPhysicalType(type), + [&](T) { + func = std::make_unique(name, inputTypes, type, + MinMaxFunction::initialize, MinMaxFunction::template updateAll, + MinMaxFunction::template updatePos, + MinMaxFunction::template combine, MinMaxFunction::finalize, + isDistinct); + }, + [](auto) { KU_UNREACHABLE; }); + set.push_back(std::move(func)); + } + } +} + +function_set AggregateMinFunction::getFunctionSet() { + function_set result; + getMinMaxFunction(AggregateMinFunction::name, result); + return result; +} + +function_set AggregateMaxFunction::getFunctionSet() { + function_set result; + getMinMaxFunction(AggregateMaxFunction::name, result); + return result; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/aggregate/sum.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/aggregate/sum.cpp new file mode 100644 index 0000000000..0efb546502 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/aggregate/sum.cpp @@ -0,0 +1,17 @@ +#include "function/aggregate/sum.h" + +namespace lbug { +namespace function { + +using namespace lbug::common; + +function_set AggregateSumFunction::getFunctionSet() { + function_set result; + for (auto typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { + AggregateFunctionUtils::appendSumOrAvgFuncs(name, typeID, result); + } + return result; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/aggregate_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/aggregate_function.cpp new file mode 100644 
index 0000000000..aeb0d7ccb4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/aggregate_function.cpp @@ -0,0 +1,62 @@ +#include "function/aggregate_function.h" + +#include "common/type_utils.h" +#include "function/aggregate/avg.h" +#include "function/aggregate/sum.h" + +using namespace lbug::common; +using namespace lbug::function; + +namespace lbug { +namespace function { + +AggregateFunction::AggregateFunction(const AggregateFunction& other) + : ScalarOrAggregateFunction{other.name, other.parameterTypeIDs, other.returnTypeID, + other.bindFunc} { + needToHandleNulls = other.needToHandleNulls; + isDistinct = other.isDistinct; + initializeFunc = other.initializeFunc; + updateAllFunc = other.updateAllFunc; + updatePosFunc = other.updatePosFunc; + combineFunc = other.combineFunc; + finalizeFunc = other.finalizeFunc; + paramRewriteFunc = other.paramRewriteFunc; + initialNullAggregateState = createInitialNullAggregateState(); +} + +template void AggregateFunctionUtils::appendSumOrAvgFuncs(std::string name, + common::LogicalTypeID inputType, function_set& result); +template void AggregateFunctionUtils::appendSumOrAvgFuncs(std::string name, + common::LogicalTypeID inputType, function_set& result); + +template class FunctionType> +void AggregateFunctionUtils::appendSumOrAvgFuncs(std::string name, common::LogicalTypeID inputType, + function_set& result) { + std::unique_ptr aggFunc; + for (auto isDistinct : std::vector{true, false}) { + TypeUtils::visit( + LogicalType{inputType}, + [&](T) { + using ResultType = + std::conditional, uint128_t, int128_t>::type; + LogicalTypeID resultType = + UnsignedIntegerTypes ? LogicalTypeID::UINT128 : LogicalTypeID::INT128; + // For avg aggregate functions, the result type is always double. + if constexpr (std::is_same_v, + AvgFunction>) { + resultType = LogicalTypeID::DOUBLE; + } + aggFunc = AggregateFunctionUtils::getAggFunc>(name, + inputType, resultType, isDistinct); + }, + [&](T) { + aggFunc = AggregateFunctionUtils::getAggFunc>(name, + inputType, LogicalTypeID::DOUBLE, isDistinct); + }, + [](auto) { KU_UNREACHABLE; }); + result.push_back(std::move(aggFunc)); + } +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/arithmetic/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/function/arithmetic/CMakeLists.txt new file mode 100644 index 0000000000..fb68e258f8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/arithmetic/CMakeLists.txt @@ -0,0 +1,15 @@ +add_library(lbug_function_arithmetic + OBJECT + multiply.cpp + add.cpp + subtract.cpp + divide.cpp + modulo.cpp + negate.cpp + abs.cpp + rand_function.cpp + set_seed.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/arithmetic/abs.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/arithmetic/abs.cpp new file mode 100644 index 0000000000..e8fb117527 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/arithmetic/abs.cpp @@ -0,0 +1,88 @@ +#include "function/arithmetic/abs.h" + +#include "common/exception/overflow.h" +#include "common/string_format.h" +#include "common/type_utils.h" +#include "function/cast/functions/numeric_limits.h" + +namespace lbug { +namespace function { + +// reference from duckDB arithmetic.cpp +template +static inline bool AbsInPlaceWithOverflowCheck(SRC_TYPE input, SRC_TYPE& result) { + if (input == NumericLimits::minimum()) { + return false; + } + result = std::abs(input); + return true; +} + +struct AbsInPlace { + 
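+ // Computes abs(input) in place, returning false instead of throwing when the result would + // overflow (i.e. when input is the minimum value of its signed type).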
template + static inline bool operation(T& input, T& result); +}; + +template<> +bool inline AbsInPlace::operation(int8_t& input, int8_t& result) { + return AbsInPlaceWithOverflowCheck(input, result); +} + +template<> +bool inline AbsInPlace::operation(int16_t& input, int16_t& result) { + return AbsInPlaceWithOverflowCheck(input, result); +} + +template<> +bool inline AbsInPlace::operation(int32_t& input, int32_t& result) { + return AbsInPlaceWithOverflowCheck(input, result); +} + +template<> +bool AbsInPlace::operation(int64_t& input, int64_t& result) { + return AbsInPlaceWithOverflowCheck(input, result); +} + +template<> +void Abs::operation(int8_t& input, int8_t& result) { + if (!AbsInPlace::operation(input, result)) { + throw common::OverflowException{ + common::stringFormat("Cannot take the absolute value of {} within INT8 range.", + common::TypeUtils::toString(input))}; + } +} + +template<> +void Abs::operation(int16_t& input, int16_t& result) { + if (!AbsInPlace::operation(input, result)) { + throw common::OverflowException{ + common::stringFormat("Cannot take the absolute value of {} within INT16 range.", + common::TypeUtils::toString(input))}; + } +} + +template<> +void Abs::operation(int32_t& input, int32_t& result) { + if (!AbsInPlace::operation(input, result)) { + throw common::OverflowException{ + common::stringFormat("Cannot take the absolute value of {} within INT32 range.", + common::TypeUtils::toString(input))}; + } +} + +template<> +void Abs::operation(int64_t& input, int64_t& result) { + if (!AbsInPlace::operation(input, result)) { + throw common::OverflowException{ + common::stringFormat("Cannot take the absolute value of {} within INT64 range.", + common::TypeUtils::toString(input))}; + } +} + +template<> +void Abs::operation(common::int128_t& input, common::int128_t& result) { + result = input < 0 ? 
-input : input; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/arithmetic/add.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/arithmetic/add.cpp new file mode 100644 index 0000000000..0957d9c12d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/arithmetic/add.cpp @@ -0,0 +1,157 @@ +#include "function/arithmetic/add.h" + +#include "common/exception/overflow.h" +#include "common/string_format.h" +#include "common/type_utils.h" +#include "function/cast/functions/numeric_limits.h" + +namespace lbug { +namespace function { + +// reference from duckDB add.cpp +template +static inline bool addInPlaceWithOverflowCheck(SRC_TYPE left, SRC_TYPE right, SRC_TYPE& result) { + DST_TYPE uresult; + uresult = static_cast(left) + static_cast(right); + if (uresult < NumericLimits::minimum() || + uresult > NumericLimits::maximum()) { + return false; + } + result = static_cast(uresult); + return true; +} + +struct AddInPlace { + template + static inline bool operation(A& left, B& right, R& result); +}; + +template<> +bool inline AddInPlace::operation(uint8_t& left, uint8_t& right, uint8_t& result) { + return addInPlaceWithOverflowCheck(left, right, result); +} + +template<> +bool inline AddInPlace::operation(uint16_t& left, uint16_t& right, uint16_t& result) { + return addInPlaceWithOverflowCheck(left, right, result); +} + +template<> +bool inline AddInPlace::operation(uint32_t& left, uint32_t& right, uint32_t& result) { + return addInPlaceWithOverflowCheck(left, right, result); +} + +template<> +bool AddInPlace::operation(uint64_t& left, uint64_t& right, uint64_t& result) { + if (NumericLimits::maximum() - left < right) { + return false; + } + return addInPlaceWithOverflowCheck(left, right, result); +} + +template<> +bool inline AddInPlace::operation(int8_t& left, int8_t& right, int8_t& result) { + return addInPlaceWithOverflowCheck(left, right, result); +} + +template<> +bool inline AddInPlace::operation(int16_t& left, int16_t& right, int16_t& result) { + return addInPlaceWithOverflowCheck(left, right, result); +} + +template<> +bool inline AddInPlace::operation(int32_t& left, int32_t& right, int32_t& result) { + return addInPlaceWithOverflowCheck(left, right, result); +} + +template<> +bool AddInPlace::operation(int64_t& left, int64_t& right, int64_t& result) { +#if (__GNUC__ >= 5) || defined(__clang__) + if (__builtin_add_overflow(left, right, &result)) { + return false; + } +#else + // https://blog.regehr.org/archives/1139 + int64_t tmp = int64_t((uint64_t)left + (uint64_t)right); + if ((left < 0 && right < 0 && tmp >= 0) || (left >= 0 && right >= 0 && tmp < 0)) { + return false; + } + result = std::move(tmp); +#endif + return true; +} + +template<> +void Add::operation(uint8_t& left, uint8_t& right, uint8_t& result) { + if (!AddInPlace::operation(left, right, result)) { + throw common::OverflowException{ + common::stringFormat("Value {} + {} is not within UINT8 range.", + common::TypeUtils::toString(left), common::TypeUtils::toString(right))}; + } +} + +template<> +void Add::operation(uint16_t& left, uint16_t& right, uint16_t& result) { + if (!AddInPlace::operation(left, right, result)) { + throw common::OverflowException{ + common::stringFormat("Value {} + {} is not within UINT16 range.", + common::TypeUtils::toString(left), common::TypeUtils::toString(right))}; + } +} + +template<> +void Add::operation(uint32_t& left, uint32_t& right, uint32_t& result) { + if (!AddInPlace::operation(left, right, result)) { + throw 
common::OverflowException{ + common::stringFormat("Value {} + {} is not within UINT32 range.", + common::TypeUtils::toString(left), common::TypeUtils::toString(right))}; + } +} + +template<> +void Add::operation(uint64_t& left, uint64_t& right, uint64_t& result) { + if (!AddInPlace::operation(left, right, result)) { + throw common::OverflowException{ + common::stringFormat("Value {} + {} is not within UINT64 range.", + common::TypeUtils::toString(left), common::TypeUtils::toString(right))}; + } +} + +template<> +void Add::operation(int8_t& left, int8_t& right, int8_t& result) { + if (!AddInPlace::operation(left, right, result)) { + throw common::OverflowException{ + common::stringFormat("Value {} + {} is not within INT8 range.", + common::TypeUtils::toString(left), common::TypeUtils::toString(right))}; + } +} + +template<> +void Add::operation(int16_t& left, int16_t& right, int16_t& result) { + if (!AddInPlace::operation(left, right, result)) { + throw common::OverflowException{ + common::stringFormat("Value {} + {} is not within INT16 range.", + common::TypeUtils::toString(left), common::TypeUtils::toString(right))}; + } +} + +template<> +void Add::operation(int32_t& left, int32_t& right, int32_t& result) { + if (!AddInPlace::operation(left, right, result)) { + throw common::OverflowException{ + common::stringFormat("Value {} + {} is not within INT32 range.", + common::TypeUtils::toString(left), common::TypeUtils::toString(right))}; + } +} + +template<> +void Add::operation(int64_t& left, int64_t& right, int64_t& result) { + if (!AddInPlace::operation(left, right, result)) { + throw common::OverflowException{ + common::stringFormat("Value {} + {} is not within INT64 range.", + common::TypeUtils::toString(left), common::TypeUtils::toString(right))}; + } +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/arithmetic/divide.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/arithmetic/divide.cpp new file mode 100644 index 0000000000..74e7778719 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/arithmetic/divide.cpp @@ -0,0 +1,170 @@ +#include "function/arithmetic/divide.h" + +#include "common/exception/overflow.h" +#include "common/exception/runtime.h" +#include "common/string_format.h" +#include "common/type_utils.h" +#include "function/cast/functions/numeric_limits.h" + +namespace lbug { +namespace function { + +// reference from duckDB multiply.cpp +template +static inline bool tryDivideWithOverflowCheck(SRC_TYPE left, SRC_TYPE right, SRC_TYPE& result) { + DST_TYPE uresult; + uresult = static_cast(left) / static_cast(right); + if (uresult < NumericLimits::minimum() || + uresult > NumericLimits::maximum()) { + return false; + } + result = static_cast(uresult); + return true; +} + +struct TryDivide { + template + static inline bool operation(A& left, B& right, R& result); +}; + +template<> +bool inline TryDivide::operation(uint8_t& left, uint8_t& right, uint8_t& result) { + return tryDivideWithOverflowCheck(left, right, result); +} + +template<> +bool inline TryDivide::operation(uint16_t& left, uint16_t& right, uint16_t& result) { + return tryDivideWithOverflowCheck(left, right, result); +} + +template<> +bool inline TryDivide::operation(uint32_t& left, uint32_t& right, uint32_t& result) { + return tryDivideWithOverflowCheck(left, right, result); +} + +template<> +bool TryDivide::operation(uint64_t& left, uint64_t& right, uint64_t& result) { + return tryDivideWithOverflowCheck(left, right, result); +} + +template<> 
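Apart from the explicit divide-by-zero checks in Divide::operation, the only signed division that can overflow is INT64_MIN / -1, which is exactly the case the int64_t TryDivide specialization below rejects. A minimal, self-contained sketch of that guard, assuming nothing about the lbug API (the helper name tryCheckedDivideI64 and the use of <limits> are illustrative only, not part of the vendored sources):

#include <cstdint>
#include <limits>

// Illustrative sketch, not the lbug API: mirrors the guards used by Divide/TryDivide.
static bool tryCheckedDivideI64(int64_t left, int64_t right, int64_t& result) {
    if (right == 0) {
        return false; // the vendored Divide::operation reports this as a RuntimeException instead
    }
    // The single overflowing signed division: INT64_MIN / -1 would be INT64_MAX + 1.
    if (left == std::numeric_limits<int64_t>::min() && right == -1) {
        return false;
    }
    result = left / right;
    return true;
}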
+bool inline TryDivide::operation(int8_t& left, int8_t& right, int8_t& result) { + return tryDivideWithOverflowCheck(left, right, result); +} + +template<> +bool inline TryDivide::operation(int16_t& left, int16_t& right, int16_t& result) { + return tryDivideWithOverflowCheck(left, right, result); +} + +template<> +bool inline TryDivide::operation(int32_t& left, int32_t& right, int32_t& result) { + return tryDivideWithOverflowCheck(left, right, result); +} + +template<> +bool TryDivide::operation(int64_t& left, int64_t& right, int64_t& result) { + if (left == NumericLimits::minimum() && right == -1) { + return false; + } + return tryDivideWithOverflowCheck(left, right, result); +} + +template<> +void Divide::operation(uint8_t& left, uint8_t& right, uint8_t& result) { + if (right == 0) { + throw common::RuntimeException("Divide by zero."); + } + if (!TryDivide::operation(left, right, result)) { + throw common::OverflowException{ + common::stringFormat("Value {} / {} is not within UINT8 range.", + common::TypeUtils::toString(left), common::TypeUtils::toString(right))}; + } +} + +template<> +void Divide::operation(uint16_t& left, uint16_t& right, uint16_t& result) { + if (right == 0) { + throw common::RuntimeException("Divide by zero."); + } + if (!TryDivide::operation(left, right, result)) { + throw common::OverflowException{ + common::stringFormat("Value {} / {} is not within UINT16 range.", + common::TypeUtils::toString(left), common::TypeUtils::toString(right))}; + } +} + +template<> +void Divide::operation(uint32_t& left, uint32_t& right, uint32_t& result) { + if (right == 0) { + throw common::RuntimeException("Divide by zero."); + } + if (!TryDivide::operation(left, right, result)) { + throw common::OverflowException{ + common::stringFormat("Value {} / {} is not within UINT32 range.", + common::TypeUtils::toString(left), common::TypeUtils::toString(right))}; + } +} + +template<> +void Divide::operation(uint64_t& left, uint64_t& right, uint64_t& result) { + if (right == 0) { + throw common::RuntimeException("Divide by zero."); + } + if (!TryDivide::operation(left, right, result)) { + throw common::OverflowException{ + common::stringFormat("Value {} / {} is not within UINT64 range.", + common::TypeUtils::toString(left), common::TypeUtils::toString(right))}; + } +} + +template<> +void Divide::operation(int8_t& left, int8_t& right, int8_t& result) { + if (right == 0) { + throw common::RuntimeException("Divide by zero."); + } + if (!TryDivide::operation(left, right, result)) { + throw common::OverflowException{ + common::stringFormat("Value {} / {} is not within INT8 range.", + common::TypeUtils::toString(left), common::TypeUtils::toString(right))}; + } +} + +template<> +void Divide::operation(int16_t& left, int16_t& right, int16_t& result) { + if (right == 0) { + throw common::RuntimeException("Divide by zero."); + } + if (!TryDivide::operation(left, right, result)) { + throw common::OverflowException{ + common::stringFormat("Value {} / {} is not within INT16 range.", + common::TypeUtils::toString(left), common::TypeUtils::toString(right))}; + } +} + +template<> +void Divide::operation(int32_t& left, int32_t& right, int32_t& result) { + if (right == 0) { + throw common::RuntimeException("Divide by zero."); + } + if (!TryDivide::operation(left, right, result)) { + throw common::OverflowException{ + common::stringFormat("Value {} / {} is not within INT32 range.", + common::TypeUtils::toString(left), common::TypeUtils::toString(right))}; + } +} + +template<> +void Divide::operation(int64_t& 
left, int64_t& right, int64_t& result) { + if (right == 0) { + throw common::RuntimeException("Divide by zero."); + } + if (!TryDivide::operation(left, right, result)) { + throw common::OverflowException{ + common::stringFormat("Value {} / {} is not within INT64 range.", + common::TypeUtils::toString(left), common::TypeUtils::toString(right))}; + } +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/arithmetic/modulo.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/arithmetic/modulo.cpp new file mode 100644 index 0000000000..55be6cdd69 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/arithmetic/modulo.cpp @@ -0,0 +1,144 @@ +#include "function/arithmetic/modulo.h" + +#include "common/exception/overflow.h" +#include "common/exception/runtime.h" +#include "common/string_format.h" +#include "common/type_utils.h" +#include "function/cast/functions/numeric_limits.h" + +namespace lbug { +namespace function { + +// reference from duckDB multiply.cpp +template +static inline bool tryModuloWithOverflowCheck(SRC_TYPE left, SRC_TYPE right, SRC_TYPE& result) { + DST_TYPE uresult; + if (left == NumericLimits::minimum() && right == -1) { + return false; + } + uresult = static_cast(left) % static_cast(right); + result = static_cast(uresult); + return true; +} + +struct TryModulo { + template + static inline bool operation(A& left, B& right, R& result); +}; + +template<> +bool inline TryModulo::operation(int8_t& left, int8_t& right, int8_t& result) { + return tryModuloWithOverflowCheck(left, right, result); +} + +template<> +bool inline TryModulo::operation(int16_t& left, int16_t& right, int16_t& result) { + return tryModuloWithOverflowCheck(left, right, result); +} + +template<> +bool inline TryModulo::operation(int32_t& left, int32_t& right, int32_t& result) { + return tryModuloWithOverflowCheck(left, right, result); +} + +template<> +bool TryModulo::operation(int64_t& left, int64_t& right, int64_t& result) { + if (left == NumericLimits::minimum() && right == -1) { + return false; + } + return tryModuloWithOverflowCheck(left, right, result); +} + +template<> +void Modulo::operation(uint8_t& left, uint8_t& right, uint8_t& result) { + if (right == 0) { + throw common::RuntimeException("Modulo by zero."); + } + result = left % right; +} + +template<> +void Modulo::operation(uint16_t& left, uint16_t& right, uint16_t& result) { + if (right == 0) { + throw common::RuntimeException("Modulo by zero."); + } + result = left % right; +} + +template<> +void Modulo::operation(uint32_t& left, uint32_t& right, uint32_t& result) { + if (right == 0) { + throw common::RuntimeException("Modulo by zero."); + } + result = left % right; +} + +template<> +void Modulo::operation(uint64_t& left, uint64_t& right, uint64_t& result) { + if (right == 0) { + throw common::RuntimeException("Modulo by zero."); + } + result = left % right; +} + +template<> +void Modulo::operation(int8_t& left, int8_t& right, int8_t& result) { + if (right == 0) { + throw common::RuntimeException("Modulo by zero."); + } + if (!TryModulo::operation(left, right, result)) { + throw common::OverflowException{ + common::stringFormat("Value {} % {} is not within INT8 range.", + common::TypeUtils::toString(left), common::TypeUtils::toString(right))}; + } +} + +template<> +void Modulo::operation(int16_t& left, int16_t& right, int16_t& result) { + if (right == 0) { + throw common::RuntimeException("Modulo by zero."); + } + if (!TryModulo::operation(left, right, result)) { + throw 
common::OverflowException{ + common::stringFormat("Value {} % {} is not within INT16 range.", + common::TypeUtils::toString(left), common::TypeUtils::toString(right))}; + } +} + +template<> +void Modulo::operation(int32_t& left, int32_t& right, int32_t& result) { + if (right == 0) { + throw common::RuntimeException("Modulo by zero."); + } + if (!TryModulo::operation(left, right, result)) { + throw common::OverflowException{ + common::stringFormat("Value {} % {} is not within INT32 range.", + common::TypeUtils::toString(left), common::TypeUtils::toString(right))}; + } +} + +template<> +void Modulo::operation(int64_t& left, int64_t& right, int64_t& result) { + if (right == 0) { + throw common::RuntimeException("Modulo by zero."); + } + if (!TryModulo::operation(left, right, result)) { + throw common::OverflowException{ + common::stringFormat("Value {} % {} is not within INT64 range.", + common::TypeUtils::toString(left), common::TypeUtils::toString(right))}; + } +} + +template<> +void Modulo::operation(common::int128_t& left, common::int128_t& right, common::int128_t& result) { + result = left % right; +} + +template<> +void Modulo::operation(common::uint128_t& left, common::uint128_t& right, + common::uint128_t& result) { + result = left % right; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/arithmetic/multiply.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/arithmetic/multiply.cpp new file mode 100644 index 0000000000..0b9a3f15c1 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/arithmetic/multiply.cpp @@ -0,0 +1,230 @@ +#include "function/arithmetic/multiply.h" + +#include "common/exception/overflow.h" +#include "common/string_format.h" +#include "common/type_utils.h" +#include "function/cast/functions/numeric_limits.h" + +namespace lbug { +namespace function { + +// reference from duckDB multiply.cpp +template +static inline bool tryMultiplyWithOverflowCheck(SRC_TYPE left, SRC_TYPE right, SRC_TYPE& result) { + DST_TYPE uresult; + uresult = static_cast(left) * static_cast(right); + if (uresult < NumericLimits::minimum() || + uresult > NumericLimits::maximum()) { + return false; + } + result = static_cast(uresult); + return true; +} + +struct TryMultiply { + template + static inline bool operation(A& left, B& right, R& result); +}; + +template<> +bool inline TryMultiply::operation(uint8_t& left, uint8_t& right, uint8_t& result) { + return tryMultiplyWithOverflowCheck(left, right, result); +} + +template<> +bool inline TryMultiply::operation(uint16_t& left, uint16_t& right, uint16_t& result) { + return tryMultiplyWithOverflowCheck(left, right, result); +} + +template<> +bool inline TryMultiply::operation(uint32_t& left, uint32_t& right, uint32_t& result) { + return tryMultiplyWithOverflowCheck(left, right, result); +} + +template<> +bool TryMultiply::operation(uint64_t& left, uint64_t& right, uint64_t& result) { + if (left > right) { + std::swap(left, right); + } + if (left > NumericLimits::maximum()) { + return false; + } + uint32_t c = right >> 32; + uint32_t d = NumericLimits::maximum() & right; + uint64_t r = left * c; + uint64_t s = left * d; + if (r > NumericLimits::maximum()) { + return false; + } + r <<= 32; + if (NumericLimits::maximum() - s < r) { + return false; + } + return tryMultiplyWithOverflowCheck(left, right, result); +} + +template<> +bool inline TryMultiply::operation(int8_t& left, int8_t& right, int8_t& result) { + return tryMultiplyWithOverflowCheck(left, right, result); +} + +template<> 
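For the narrow integer types, the arithmetic files in this diff all rely on the same widen-and-range-check idea that tryMultiplyWithOverflowCheck implements: perform the operation in a wider DST_TYPE, verify the value still fits the SRC_TYPE limits, then narrow. A minimal sketch of that pattern, assuming std::numeric_limits in place of the library's own NumericLimits helper and using illustrative names:

#include <cstdint>
#include <limits>

// Illustrative sketch of the widen-and-check pattern, not part of the vendored sources.
template<class SRC, class DST>
static bool tryNarrowMultiply(SRC left, SRC right, SRC& result) {
    // Compute in a type wide enough that the product of two SRC values cannot overflow...
    DST wide = static_cast<DST>(left) * static_cast<DST>(right);
    // ...then make sure the value still fits back into SRC.
    if (wide < static_cast<DST>(std::numeric_limits<SRC>::min()) ||
        wide > static_cast<DST>(std::numeric_limits<SRC>::max())) {
        return false;
    }
    result = static_cast<SRC>(wide);
    return true;
}

// Example: int16_t times int16_t evaluated in int32_t.
//   int16_t r;
//   tryNarrowMultiply<int16_t, int32_t>(300, 300, r);  // returns false: 90000 > 32767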
+bool inline TryMultiply::operation(int16_t& left, int16_t& right, int16_t& result) { + return tryMultiplyWithOverflowCheck(left, right, result); +} + +template<> +bool inline TryMultiply::operation(int32_t& left, int32_t& right, int32_t& result) { + return tryMultiplyWithOverflowCheck(left, right, result); +} + +template<> +bool TryMultiply::operation(int64_t& left, int64_t& right, int64_t& result) { +#if (__GNUC__ >= 5) || defined(__clang__) + if (__builtin_mul_overflow(left, right, &result)) { + return false; + } +#else + if (left == std::numeric_limits::min()) { + if (right == 0) { + result = 0; + return true; + } + if (right == 1) { + result = left; + return true; + } + return false; + } + if (right == std::numeric_limits::min()) { + if (left == 0) { + result = 0; + return true; + } + if (left == 1) { + result = right; + return true; + } + return false; + } + uint64_t left_non_negative = uint64_t(std::abs(left)); + uint64_t right_non_negative = uint64_t(std::abs(right)); + // split values into 2 32-bit parts + uint64_t left_high_bits = left_non_negative >> 32; + uint64_t left_low_bits = left_non_negative & 0xffffffff; + uint64_t right_high_bits = right_non_negative >> 32; + uint64_t right_low_bits = right_non_negative & 0xffffffff; + + // check the high bits of both + // the high bits define the overflow + if (left_high_bits == 0) { + if (right_high_bits != 0) { + // only the right has high bits set + // multiply the high bits of right with the low bits of left + // multiply the low bits, and carry any overflow to the high bits + // then check for any overflow + auto low_low = left_low_bits * right_low_bits; + auto low_high = left_low_bits * right_high_bits; + auto high_bits = low_high + (low_low >> 32); + if (high_bits & 0xffffff80000000) { + // there is! abort + return false; + } + } + } else if (right_high_bits == 0) { + // only the left has high bits set + // multiply the high bits of left with the low bits of right + // multiply the low bits, and carry any overflow to the high bits + // then check for any overflow + auto low_low = left_low_bits * right_low_bits; + auto high_low = left_high_bits * right_low_bits; + auto high_bits = high_low + (low_low >> 32); + if (high_bits & 0xffffff80000000) { + // there is! abort + return false; + } + } else { + // both left and right have high bits set: guaranteed overflow + // abort! 
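The half-word analysis in this int64_t fallback is easier to see as a standalone routine: split both operands into 32-bit halves and bound the partial products. A simplified, self-contained variant of the same idea (the exact mask and the sign handling used by the vendored path differ; the name mulU64Overflows is illustrative only):

#include <cstdint>

// Illustrative sketch: detect unsigned 64-bit multiplication overflow from 32-bit halves.
static bool mulU64Overflows(uint64_t a, uint64_t b) {
    uint64_t aHi = a >> 32, aLo = a & 0xffffffffu;
    uint64_t bHi = b >> 32, bLo = b & 0xffffffffu;
    if (aHi != 0 && bHi != 0) {
        return true;                             // both operands >= 2^32: product >= 2^64
    }
    uint64_t cross = aHi * bLo + aLo * bHi;      // at most one term is non-zero here
    if (cross >> 32) {
        return true;                             // the cross term alone reaches bit 64
    }
    uint64_t low = aLo * bLo;
    return (cross << 32) > UINT64_MAX - low;     // adding the low product would wrap
}

// mulU64Overflows(1ull << 32, 1ull << 32) -> true  (2^64 does not fit)
// mulU64Overflows(1ull << 31, 1ull << 32) -> false (2^63 fits in uint64_t)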
+ return false; + } + // now we know that there is no overflow, we can just perform the multiplication + result = left * right; +#endif + return true; +} + +template<> +void Multiply::operation(uint8_t& left, uint8_t& right, uint8_t& result) { + if (!TryMultiply::operation(left, right, result)) { + throw common::OverflowException{ + common::stringFormat("Value {} * {} is not within UINT8 range.", + common::TypeUtils::toString(left), common::TypeUtils::toString(right))}; + } +} + +template<> +void Multiply::operation(uint16_t& left, uint16_t& right, uint16_t& result) { + if (!TryMultiply::operation(left, right, result)) { + throw common::OverflowException{ + common::stringFormat("Value {} * {} is not within UINT16 range.", + common::TypeUtils::toString(left), common::TypeUtils::toString(right))}; + } +} + +template<> +void Multiply::operation(uint32_t& left, uint32_t& right, uint32_t& result) { + if (!TryMultiply::operation(left, right, result)) { + throw common::OverflowException{ + common::stringFormat("Value {} * {} is not within UINT32 range.", + common::TypeUtils::toString(left), common::TypeUtils::toString(right))}; + } +} + +template<> +void Multiply::operation(uint64_t& left, uint64_t& right, uint64_t& result) { + if (!TryMultiply::operation(left, right, result)) { + throw common::OverflowException{ + common::stringFormat("Value {} * {} is not within UINT64 range.", + common::TypeUtils::toString(left), common::TypeUtils::toString(right))}; + } +} + +template<> +void Multiply::operation(int8_t& left, int8_t& right, int8_t& result) { + if (!TryMultiply::operation(left, right, result)) { + throw common::OverflowException{ + common::stringFormat("Value {} * {} is not within INT8 range.", + common::TypeUtils::toString(left), common::TypeUtils::toString(right))}; + } +} + +template<> +void Multiply::operation(int16_t& left, int16_t& right, int16_t& result) { + if (!TryMultiply::operation(left, right, result)) { + throw common::OverflowException{ + common::stringFormat("Value {} * {} is not within INT16 range.", + common::TypeUtils::toString(left), common::TypeUtils::toString(right))}; + } +} + +template<> +void Multiply::operation(int32_t& left, int32_t& right, int32_t& result) { + if (!TryMultiply::operation(left, right, result)) { + throw common::OverflowException{ + common::stringFormat("Value {} * {} is not within INT32 range.", + common::TypeUtils::toString(left), common::TypeUtils::toString(right))}; + } +} + +template<> +void Multiply::operation(int64_t& left, int64_t& right, int64_t& result) { + if (!TryMultiply::operation(left, right, result)) { + throw common::OverflowException{ + common::stringFormat("Value {} * {} is not within INT64 range.", + common::TypeUtils::toString(left), common::TypeUtils::toString(right))}; + } +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/arithmetic/negate.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/arithmetic/negate.cpp new file mode 100644 index 0000000000..4efa5c291c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/arithmetic/negate.cpp @@ -0,0 +1,79 @@ +#include "function/arithmetic/negate.h" + +#include "common/exception/overflow.h" +#include "common/string_format.h" +#include "common/type_utils.h" +#include "function/cast/functions/numeric_limits.h" + +namespace lbug { +namespace function { + +// reference from duckDB arithmetic.cpp +template +static inline bool NegateInPlaceWithOverflowCheck(SRC_TYPE input, SRC_TYPE& result) { + if (input == 
NumericLimits::minimum()) { + return false; + } + result = -input; + return true; +} + +struct NegateInPlace { + template + static inline bool operation(T& input, T& result); +}; + +template<> +bool inline NegateInPlace::operation(int8_t& input, int8_t& result) { + return NegateInPlaceWithOverflowCheck(input, result); +} + +template<> +bool inline NegateInPlace::operation(int16_t& input, int16_t& result) { + return NegateInPlaceWithOverflowCheck(input, result); +} + +template<> +bool inline NegateInPlace::operation(int32_t& input, int32_t& result) { + return NegateInPlaceWithOverflowCheck(input, result); +} + +template<> +bool NegateInPlace::operation(int64_t& input, int64_t& result) { + return NegateInPlaceWithOverflowCheck(input, result); +} + +template<> +void Negate::operation(int8_t& input, int8_t& result) { + if (!NegateInPlace::operation(input, result)) { + throw common::OverflowException{common::stringFormat( + "Value {} cannot be negated within INT8 range.", common::TypeUtils::toString(input))}; + } +} + +template<> +void Negate::operation(int16_t& input, int16_t& result) { + if (!NegateInPlace::operation(input, result)) { + throw common::OverflowException{common::stringFormat( + "Value {} cannot be negated within INT16 range.", common::TypeUtils::toString(input))}; + } +} + +template<> +void Negate::operation(int32_t& input, int32_t& result) { + if (!NegateInPlace::operation(input, result)) { + throw common::OverflowException{common::stringFormat( + "Value {} cannot be negated within INT32 range.", common::TypeUtils::toString(input))}; + } +} + +template<> +void Negate::operation(int64_t& input, int64_t& result) { + if (!NegateInPlace::operation(input, result)) { + throw common::OverflowException{common::stringFormat( + "Value {} cannot be negated within INT64 range.", common::TypeUtils::toString(input))}; + } +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/arithmetic/rand_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/arithmetic/rand_function.cpp new file mode 100644 index 0000000000..6298ec6df8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/arithmetic/rand_function.cpp @@ -0,0 +1,27 @@ +#include "common/random_engine.h" +#include "function/arithmetic/vector_arithmetic_functions.h" +#include "function/scalar_function.h" +#include "main/client_context.h" + +namespace lbug { +namespace function { + +using namespace lbug::common; + +struct Rand { + static void operation(double& result, void* dataPtr) { + auto context = reinterpret_cast(dataPtr)->clientContext; + result = static_cast(RandomEngine::Get(*context)->nextRandomInteger()) / + static_cast(UINT32_MAX); + } +}; + +function_set RandFunction::getFunctionSet() { + function_set result; + result.push_back(std::make_unique(name, std::vector{}, + LogicalTypeID::DOUBLE, ScalarFunction::NullaryAuxilaryExecFunction)); + return result; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/arithmetic/set_seed.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/arithmetic/set_seed.cpp new file mode 100644 index 0000000000..c0cc4e07d8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/arithmetic/set_seed.cpp @@ -0,0 +1,27 @@ +#include "common/random_engine.h" +#include "function/arithmetic/vector_arithmetic_functions.h" +#include "function/scalar_function.h" + +namespace lbug { +namespace function { + +using namespace lbug::common; + +struct SetSeed { + static void 
operation(double& seed, void* dataPtr) { + auto context = reinterpret_cast(dataPtr)->clientContext; + RandomEngine::Get(*context)->setSeed( + static_cast(seed * static_cast(UINT64_MAX))); + } +}; + +function_set SetSeedFunction::getFunctionSet() { + function_set result; + result.push_back( + std::make_unique(name, std::vector{LogicalTypeID::DOUBLE}, + LogicalTypeID::INT32, ScalarFunction::UnarySetSeedFunction)); + return result; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/arithmetic/subtract.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/arithmetic/subtract.cpp new file mode 100644 index 0000000000..07b05b7b94 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/arithmetic/subtract.cpp @@ -0,0 +1,171 @@ +#include "function/arithmetic/subtract.h" + +#include "common/exception/overflow.h" +#include "common/string_format.h" +#include "common/type_utils.h" +#include "function/cast/functions/numeric_limits.h" + +namespace lbug { +namespace function { + +// reference from duckDB subtract.cpp +template +static inline bool SubtractInPlaceWithOverflowCheck(SRC_TYPE left, SRC_TYPE right, + SRC_TYPE& result) { + DST_TYPE uresult; + uresult = static_cast(left) - static_cast(right); + if (uresult < NumericLimits::minimum() || + uresult > NumericLimits::maximum()) { + return false; + } + result = static_cast(uresult); + return true; +} + +struct SubtractInPlace { + template + static inline bool operation(A& left, B& right, R& result); +}; + +template<> +bool inline SubtractInPlace::operation(uint8_t& left, uint8_t& right, uint8_t& result) { + if (right > left) { + return false; + } + return SubtractInPlaceWithOverflowCheck(left, right, result); +} + +template<> +bool inline SubtractInPlace::operation(uint16_t& left, uint16_t& right, uint16_t& result) { + if (right > left) { + return false; + } + return SubtractInPlaceWithOverflowCheck(left, right, result); +} + +template<> +bool inline SubtractInPlace::operation(uint32_t& left, uint32_t& right, uint32_t& result) { + if (right > left) { + return false; + } + return SubtractInPlaceWithOverflowCheck(left, right, result); +} + +template<> +bool SubtractInPlace::operation(uint64_t& left, uint64_t& right, uint64_t& result) { + if (right > left) { + return false; + } + return SubtractInPlaceWithOverflowCheck(left, right, result); +} + +template<> +bool inline SubtractInPlace::operation(int8_t& left, int8_t& right, int8_t& result) { + return SubtractInPlaceWithOverflowCheck(left, right, result); +} + +template<> +bool inline SubtractInPlace::operation(int16_t& left, int16_t& right, int16_t& result) { + return SubtractInPlaceWithOverflowCheck(left, right, result); +} + +template<> +bool inline SubtractInPlace::operation(int32_t& left, int32_t& right, int32_t& result) { + return SubtractInPlaceWithOverflowCheck(left, right, result); +} + +template<> +bool SubtractInPlace::operation(int64_t& left, int64_t& right, int64_t& result) { +#if (__GNUC__ >= 5) || defined(__clang__) + if (__builtin_sub_overflow(left, right, &result)) { + return false; + } +#else + if (right < 0) { + if (NumericLimits::maximum() + right < left) { + return false; + } + } else { + if (NumericLimits::minimum() + right > left) { + return false; + } + } + result = left - right; +#endif + return true; +} + +template<> +void Subtract::operation(uint8_t& left, uint8_t& right, uint8_t& result) { + if (!SubtractInPlace::operation(left, right, result)) { + throw common::OverflowException{ + common::stringFormat("Value {} 
- {} is not within UINT8 range.", + common::TypeUtils::toString(left), common::TypeUtils::toString(right))}; + } +} + +template<> +void Subtract::operation(uint16_t& left, uint16_t& right, uint16_t& result) { + if (!SubtractInPlace::operation(left, right, result)) { + throw common::OverflowException{ + common::stringFormat("Value {} - {} is not within UINT16 range.", + common::TypeUtils::toString(left), common::TypeUtils::toString(right))}; + } +} + +template<> +void Subtract::operation(uint32_t& left, uint32_t& right, uint32_t& result) { + if (!SubtractInPlace::operation(left, right, result)) { + throw common::OverflowException{ + common::stringFormat("Value {} - {} is not within UINT32 range.", + common::TypeUtils::toString(left), common::TypeUtils::toString(right))}; + } +} + +template<> +void Subtract::operation(uint64_t& left, uint64_t& right, uint64_t& result) { + if (!SubtractInPlace::operation(left, right, result)) { + throw common::OverflowException{ + common::stringFormat("Value {} - {} is not within UINT64 range.", + common::TypeUtils::toString(left), common::TypeUtils::toString(right))}; + } +} + +template<> +void Subtract::operation(int8_t& left, int8_t& right, int8_t& result) { + if (!SubtractInPlace::operation(left, right, result)) { + throw common::OverflowException{ + common::stringFormat("Value {} - {} is not within INT8 range.", + common::TypeUtils::toString(left), common::TypeUtils::toString(right))}; + } +} + +template<> +void Subtract::operation(int16_t& left, int16_t& right, int16_t& result) { + if (!SubtractInPlace::operation(left, right, result)) { + throw common::OverflowException{ + common::stringFormat("Value {} - {} is not within INT16 range.", + common::TypeUtils::toString(left), common::TypeUtils::toString(right))}; + } +} + +template<> +void Subtract::operation(int32_t& left, int32_t& right, int32_t& result) { + if (!SubtractInPlace::operation(left, right, result)) { + throw common::OverflowException{ + common::stringFormat("Value {} - {} is not within INT32 range.", + common::TypeUtils::toString(left), common::TypeUtils::toString(right))}; + } +} + +template<> +void Subtract::operation(int64_t& left, int64_t& right, int64_t& result) { + if (!SubtractInPlace::operation(left, right, result)) { + throw common::OverflowException{ + common::stringFormat("Value {} - {} is not within INT64 range.", + common::TypeUtils::toString(left), common::TypeUtils::toString(right))}; + } +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/array/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/function/array/CMakeLists.txt new file mode 100644 index 0000000000..a085ca8168 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/array/CMakeLists.txt @@ -0,0 +1,8 @@ +add_library(lbug_function_array + OBJECT + array_functions.cpp + array_value.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/array/array_functions.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/array/array_functions.cpp new file mode 100644 index 0000000000..be9f5d030b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/array/array_functions.cpp @@ -0,0 +1,205 @@ +#include "binder/expression/literal_expression.h" +#include "common/exception/binder.h" +#include "function/array/functions/array_cosine_similarity.h" +#include "function/array/functions/array_cross_product.h" +#include "function/array/functions/array_distance.h" +#include 
"function/array/functions/array_inner_product.h" +#include "function/array/functions/array_squared_distance.h" +#include "function/array/vector_array_functions.h" +#include "function/scalar_function.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +static LogicalType interpretLogicalType(const binder::Expression* expr) { + if (expr->expressionType == ExpressionType::LITERAL && + expr->dataType.getLogicalTypeID() == LogicalTypeID::LIST) { + auto numChildren = + expr->constPtrCast()->getValue().getChildrenSize(); + return LogicalType::ARRAY(ListType::getChildType(expr->dataType).copy(), numChildren); + } + return expr->dataType.copy(); +} + +std::unique_ptr ArrayCrossProductBindFunc(const ScalarBindFuncInput& input) { + auto leftType = interpretLogicalType(input.arguments[0].get()); + auto rightType = interpretLogicalType(input.arguments[1].get()); + if (leftType != rightType) { + throw BinderException( + stringFormat("{} requires both arrays to have the same element type and size of 3", + ArrayCrossProductFunction::name)); + } + scalar_func_exec_t execFunc; + switch (ArrayType::getChildType(leftType).getLogicalTypeID()) { + case LogicalTypeID::INT128: + execFunc = ScalarFunction::BinaryExecListStructFunction>; + break; + case LogicalTypeID::INT64: + execFunc = ScalarFunction::BinaryExecListStructFunction>; + break; + case LogicalTypeID::INT32: + execFunc = ScalarFunction::BinaryExecListStructFunction>; + break; + case LogicalTypeID::INT16: + execFunc = ScalarFunction::BinaryExecListStructFunction>; + break; + case LogicalTypeID::INT8: + execFunc = ScalarFunction::BinaryExecListStructFunction>; + break; + case LogicalTypeID::FLOAT: + execFunc = ScalarFunction::BinaryExecListStructFunction>; + break; + case LogicalTypeID::DOUBLE: + execFunc = ScalarFunction::BinaryExecListStructFunction>; + break; + default: + throw BinderException{ + stringFormat("{} can only be applied on array of floating points or signed integers", + ArrayCrossProductFunction::name)}; + } + input.definition->ptrCast()->execFunc = execFunc; + const auto resultType = LogicalType::ARRAY(ArrayType::getChildType(leftType).copy(), + ArrayType::getNumElements(leftType)); + return FunctionBindData::getSimpleBindData(input.arguments, resultType); +} + +function_set ArrayCrossProductFunction::getFunctionSet() { + function_set result; + auto func = std::make_unique(name, + std::vector{ + LogicalTypeID::ARRAY, + LogicalTypeID::ARRAY, + }, + LogicalTypeID::ARRAY); + func->bindFunc = ArrayCrossProductBindFunc; + result.push_back(std::move(func)); + return result; +} + +static LogicalType getChildType(const LogicalType& type) { + switch (type.getLogicalTypeID()) { + case LogicalTypeID::ARRAY: + return ArrayType::getChildType(type).copy(); + case LogicalTypeID::LIST: + return ListType::getChildType(type).copy(); + // LCOV_EXCL_START + default: + throw BinderException(stringFormat( + "Cannot retrieve child type of type {}. 
LIST or ARRAY is expected.", type.toString())); + // LCOV_EXCL_STOP + } +} + +static void validateChildType(const LogicalType& type, const std::string& functionName) { + switch (type.getLogicalTypeID()) { + case LogicalTypeID::DOUBLE: + case LogicalTypeID::FLOAT: + return; + default: + throw BinderException( + stringFormat("{} requires argument type to be FLOAT[] or DOUBLE[].", functionName)); + } +} + +static LogicalType validateArrayFunctionParameters(const LogicalType& leftType, + const LogicalType& rightType, const std::string& functionName) { + auto leftChildType = getChildType(leftType); + auto rightChildType = getChildType(rightType); + validateChildType(leftChildType, functionName); + validateChildType(rightChildType, functionName); + if (leftType.getLogicalTypeID() == common::LogicalTypeID::ARRAY) { + return leftType.copy(); + } else if (rightType.getLogicalTypeID() == common::LogicalTypeID::ARRAY) { + return rightType.copy(); + } + throw BinderException( + stringFormat("{} requires at least one argument to be ARRAY but all parameters are LIST.", + functionName)); +} + +template +static scalar_func_exec_t getBinaryArrayExecFuncSwitchResultType() { + auto execFunc = + ScalarFunction::BinaryExecListStructFunction; + return execFunc; +} + +template +scalar_func_exec_t getScalarExecFunc(LogicalType type) { + scalar_func_exec_t execFunc; + switch (ArrayType::getChildType(type).getLogicalTypeID()) { + case LogicalTypeID::FLOAT: + execFunc = getBinaryArrayExecFuncSwitchResultType(); + break; + case LogicalTypeID::DOUBLE: + execFunc = getBinaryArrayExecFuncSwitchResultType(); + break; + default: + KU_UNREACHABLE; + } + return execFunc; +} + +template +std::unique_ptr arrayTemplateBindFunc(std::string functionName, + ScalarBindFuncInput input) { + auto leftType = interpretLogicalType(input.arguments[0].get()); + auto rightType = interpretLogicalType(input.arguments[1].get()); + auto paramType = validateArrayFunctionParameters(leftType, rightType, functionName); + input.definition->ptrCast()->execFunc = + std::move(getScalarExecFunc(paramType.copy())); + auto bindData = std::make_unique(ArrayType::getChildType(paramType).copy()); + std::vector paramTypes; + for (auto& _ : input.arguments) { + (void)_; + bindData->paramTypes.push_back(paramType.copy()); + } + return bindData; +} + +template +function_set templateGetFunctionSet(const std::string& functionName) { + function_set result; + auto function = std::make_unique(functionName, + std::vector{ + LogicalTypeID::ARRAY, + LogicalTypeID::ARRAY, + }, + LogicalTypeID::ANY); + function->bindFunc = + std::bind(arrayTemplateBindFunc, functionName, std::placeholders::_1); + result.push_back(std::move(function)); + return result; +} + +function_set ArrayCosineSimilarityFunction::getFunctionSet() { + return templateGetFunctionSet(name); +} + +function_set ArrayDistanceFunction::getFunctionSet() { + return templateGetFunctionSet(name); +} + +function_set ArraySquaredDistanceFunction::getFunctionSet() { + return templateGetFunctionSet(name); +} + +function_set ArrayInnerProductFunction::getFunctionSet() { + return templateGetFunctionSet(name); +} + +function_set ArrayDotProductFunction::getFunctionSet() { + return templateGetFunctionSet(name); +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/array/array_value.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/array/array_value.cpp new file mode 100644 index 0000000000..4582e219f4 --- /dev/null +++ 
b/graph-wasm/lbug-0.12.2/lbug-src/src/function/array/array_value.cpp @@ -0,0 +1,38 @@ +#include "binder/expression/expression_util.h" +#include "function/array/vector_array_functions.h" +#include "function/list/vector_list_functions.h" +#include "function/scalar_function.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { + LogicalType combinedType(LogicalTypeID::ANY); + binder::ExpressionUtil::tryCombineDataType(input.arguments, combinedType); + if (combinedType.getLogicalTypeID() == LogicalTypeID::ANY) { + combinedType = LogicalType::STRING(); + } + auto resultType = LogicalType::ARRAY(combinedType.copy(), input.arguments.size()); + auto bindData = std::make_unique(std::move(resultType)); + for (auto& _ : input.arguments) { + (void)_; + bindData->paramTypes.push_back(combinedType.copy()); + } + return bindData; +} + +function_set ArrayValueFunction::getFunctionSet() { + function_set result; + auto function = + std::make_unique(name, std::vector{LogicalTypeID::ANY}, + LogicalTypeID::ARRAY, ListCreationFunction::execFunc); + function->bindFunc = bindFunc; + function->isVarLength = true; + result.push_back(std::move(function)); + return result; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/base_lower_upper_operation.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/base_lower_upper_operation.cpp new file mode 100644 index 0000000000..38bd1a1f38 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/base_lower_upper_operation.cpp @@ -0,0 +1,68 @@ +#include "common/exception/runtime.h" +#include "common/string_format.h" +#include "function/string/functions/base_lower_upper_function.h" +#include "utf8proc.h" + +using namespace lbug::common; +using namespace lbug::utf8proc; + +namespace lbug { +namespace function { + +uint32_t BaseLowerUpperFunction::getResultLen(char* inputStr, uint32_t inputLen, bool isUpper) { + uint32_t outputLength = 0; + for (uint32_t i = 0; i < inputLen;) { + // For UTF-8 characters, changing case can increase / decrease total byte length. + // Eg.: 'ß' lower case -> 'SS' upper case [more bytes + more chars] + if (inputStr[i] & 0x80) { + int size = 0; + int codepoint = utf8proc_codepoint(inputStr + i, size); + if (codepoint < 0) { + // LCOV_EXCL_START + // TODO(Xiyang): We shouldn't allow invalid UTF-8 to enter a string column. + std::string funcName = isUpper ? "UPPER" : "LOWER"; + throw RuntimeException( + common::stringFormat("Failed calling {}: Invalid UTF-8.", funcName)); + // LCOV_EXCL_STOP + } + int convertedCodepoint = + isUpper ? utf8proc_toupper(codepoint) : utf8proc_tolower(codepoint); + int newSize = utf8proc_codepoint_length(convertedCodepoint); + KU_ASSERT(newSize >= 0); + outputLength += newSize; + i += size; + } else { + outputLength++; + i++; + } + } + return outputLength; +} + +void BaseLowerUpperFunction::convertCharCase(char* result, const char* input, int32_t charPos, + bool toUpper, int& originalSize, int& newSize) { + originalSize = 1; + newSize = 1; + if (input[charPos] & 0x80) { + auto codepoint = utf8proc_codepoint(input + charPos, originalSize); + KU_ASSERT(codepoint >= 0); // Validity ensured by getResultLen. + int convertedCodepoint = + toUpper ? utf8proc_toupper(codepoint) : utf8proc_tolower(codepoint); + utf8proc_codepoint_to_utf8(convertedCodepoint, newSize, result); + } else { + *result = toUpper ? 
toupper(input[charPos]) : tolower(input[charPos]); + } +} + +void BaseLowerUpperFunction::convertCase(char* result, uint32_t len, char* input, bool toUpper) { + int originalSize = 0; + int newSize = 0; + for (auto i = 0u; i < len;) { + convertCharCase(result, input, i, toUpper, originalSize, newSize); + i += originalSize; + result += newSize; + } +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/built_in_function_utils.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/built_in_function_utils.cpp new file mode 100644 index 0000000000..c445af2ab2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/built_in_function_utils.cpp @@ -0,0 +1,558 @@ +#include "function/built_in_function_utils.h" + +#include + +#include "catalog/catalog_entry/function_catalog_entry.h" +#include "common/exception/binder.h" +#include "function/aggregate_function.h" +#include "function/arithmetic/vector_arithmetic_functions.h" +#include "function/scalar_function.h" + +using namespace lbug::common; +using namespace lbug::catalog; +using namespace lbug::processor; + +namespace lbug { +namespace function { + +static void validateNonEmptyCandidateFunctions(std::vector& candidateFunctions, + const std::string& name, const std::vector& inputTypes, bool isDistinct, + const function::function_set& set); +static void validateNonEmptyCandidateFunctions(std::vector& candidateFunctions, + const std::string& name, const std::vector& inputTypes, + const function::function_set& set); + +Function* BuiltInFunctionsUtils::matchFunction(const std::string& name, + const std::vector& inputTypes, + const catalog::FunctionCatalogEntry* functionEntry) { + auto& functionSet = functionEntry->getFunctionSet(); + std::vector candidateFunctions; + uint32_t minCost = UINT32_MAX; + for (auto& function : functionSet) { + auto func = function.get(); + auto cost = getFunctionCost(inputTypes, func, functionEntry->getType()); + if (cost == UINT32_MAX) { + continue; + } + if (cost < minCost) { + candidateFunctions.clear(); + candidateFunctions.push_back(func); + minCost = cost; + } else if (cost == minCost) { + candidateFunctions.push_back(func); + } + } + validateNonEmptyCandidateFunctions(candidateFunctions, name, inputTypes, functionSet); + if (candidateFunctions.size() > 1) { + return getBestMatch(candidateFunctions); + } + validateSpecialCases(candidateFunctions, name, inputTypes, functionSet); + return candidateFunctions[0]; +} + +AggregateFunction* BuiltInFunctionsUtils::matchAggregateFunction(const std::string& name, + const std::vector& inputTypes, bool isDistinct, + const catalog::FunctionCatalogEntry* functionEntry) { + auto& functionSet = functionEntry->getFunctionSet(); + std::vector candidateFunctions; + for (auto& function : functionSet) { + auto aggregateFunction = function->ptrCast(); + auto cost = getAggregateFunctionCost(inputTypes, isDistinct, aggregateFunction); + if (cost == UINT32_MAX) { + continue; + } + candidateFunctions.push_back(aggregateFunction); + } + validateNonEmptyCandidateFunctions(candidateFunctions, name, inputTypes, isDistinct, + functionSet); + KU_ASSERT(candidateFunctions.size() == 1); + return candidateFunctions[0]; +} + +uint32_t BuiltInFunctionsUtils::getCastCost(LogicalTypeID inputTypeID, LogicalTypeID targetTypeID) { + if (inputTypeID == targetTypeID) { + return 0; + } + // TODO(Jiamin): should check any type + if (inputTypeID == LogicalTypeID::ANY || targetTypeID == LogicalTypeID::ANY) { + // anything can be cast to ANY type for (almost 
no) cost + return 1; + } + if (targetTypeID == LogicalTypeID::STRING) { + return castFromString(inputTypeID); + } + switch (inputTypeID) { + case LogicalTypeID::INT64: + return castInt64(targetTypeID); + case LogicalTypeID::INT32: + return castInt32(targetTypeID); + case LogicalTypeID::INT16: + return castInt16(targetTypeID); + case LogicalTypeID::INT8: + return castInt8(targetTypeID); + case LogicalTypeID::UINT64: + return castUInt64(targetTypeID); + case LogicalTypeID::UINT32: + return castUInt32(targetTypeID); + case LogicalTypeID::UINT16: + return castUInt16(targetTypeID); + case LogicalTypeID::UINT8: + return castUInt8(targetTypeID); + case LogicalTypeID::INT128: + return castInt128(targetTypeID); + case LogicalTypeID::DOUBLE: + return castDouble(targetTypeID); + case LogicalTypeID::FLOAT: + return castFloat(targetTypeID); + case LogicalTypeID::DECIMAL: + return castDecimal(targetTypeID); + case LogicalTypeID::DATE: + return castDate(targetTypeID); + case LogicalTypeID::UUID: + return castUUID(targetTypeID); + case LogicalTypeID::SERIAL: + return castSerial(targetTypeID); + case LogicalTypeID::TIMESTAMP_SEC: + case LogicalTypeID::TIMESTAMP_MS: + case LogicalTypeID::TIMESTAMP_NS: + case LogicalTypeID::TIMESTAMP_TZ: + // currently don't allow timestamp to other timestamp types + // When we implement this in the future, revise tryGetMaxLogicalTypeID + return castTimestamp(targetTypeID); + case LogicalTypeID::LIST: + return castList(targetTypeID); + case LogicalTypeID::ARRAY: + return castArray(targetTypeID); + default: + return UNDEFINED_CAST_COST; + } +} + +uint32_t BuiltInFunctionsUtils::getTargetTypeCost(LogicalTypeID typeID) { + switch (typeID) { + case LogicalTypeID::SERIAL: + case LogicalTypeID::INT16: + return 100; + case LogicalTypeID::INT64: + return 101; + case LogicalTypeID::INT32: + return 102; + case LogicalTypeID::INT128: + return 103; + case LogicalTypeID::DECIMAL: + return 104; + case LogicalTypeID::DOUBLE: + return 105; + case LogicalTypeID::TIMESTAMP: + return 120; + case LogicalTypeID::STRING: + return 149; + case LogicalTypeID::STRUCT: + case LogicalTypeID::MAP: + case LogicalTypeID::ARRAY: + case LogicalTypeID::LIST: + case LogicalTypeID::UNION: + return 160; + default: + return 110; + } +} + +uint32_t BuiltInFunctionsUtils::castInt64(LogicalTypeID targetTypeID) { + switch (targetTypeID) { + case LogicalTypeID::INT128: + case LogicalTypeID::FLOAT: + case LogicalTypeID::DOUBLE: + case LogicalTypeID::DECIMAL: + return getTargetTypeCost(targetTypeID); + case LogicalTypeID::SERIAL: + return 0; + default: + return UNDEFINED_CAST_COST; + } +} + +uint32_t BuiltInFunctionsUtils::castInt32(LogicalTypeID targetTypeID) { + switch (targetTypeID) { + case LogicalTypeID::SERIAL: + case LogicalTypeID::INT64: + case LogicalTypeID::INT128: + case LogicalTypeID::FLOAT: + case LogicalTypeID::DOUBLE: + case LogicalTypeID::DECIMAL: + return getTargetTypeCost(targetTypeID); + default: + return UNDEFINED_CAST_COST; + } +} + +uint32_t BuiltInFunctionsUtils::castInt16(LogicalTypeID targetTypeID) { + switch (targetTypeID) { + case LogicalTypeID::SERIAL: + case LogicalTypeID::INT32: + case LogicalTypeID::INT64: + case LogicalTypeID::INT128: + case LogicalTypeID::FLOAT: + case LogicalTypeID::DOUBLE: + case LogicalTypeID::DECIMAL: + return getTargetTypeCost(targetTypeID); + default: + return UNDEFINED_CAST_COST; + } +} + +uint32_t BuiltInFunctionsUtils::castInt8(LogicalTypeID targetTypeID) { + switch (targetTypeID) { + case LogicalTypeID::SERIAL: + case LogicalTypeID::INT16: + case 
LogicalTypeID::INT32: + case LogicalTypeID::INT64: + case LogicalTypeID::INT128: + case LogicalTypeID::FLOAT: + case LogicalTypeID::DOUBLE: + case LogicalTypeID::DECIMAL: + return getTargetTypeCost(targetTypeID); + default: + return UNDEFINED_CAST_COST; + } +} + +uint32_t BuiltInFunctionsUtils::castUInt64(LogicalTypeID targetTypeID) { + switch (targetTypeID) { + case LogicalTypeID::INT128: + case LogicalTypeID::FLOAT: + case LogicalTypeID::DOUBLE: + case LogicalTypeID::DECIMAL: + return getTargetTypeCost(targetTypeID); + default: + return UNDEFINED_CAST_COST; + } +} + +uint32_t BuiltInFunctionsUtils::castUInt32(LogicalTypeID targetTypeID) { + switch (targetTypeID) { + case LogicalTypeID::SERIAL: + case LogicalTypeID::INT64: + case LogicalTypeID::INT128: + case LogicalTypeID::UINT64: + case LogicalTypeID::FLOAT: + case LogicalTypeID::DOUBLE: + case LogicalTypeID::DECIMAL: + return getTargetTypeCost(targetTypeID); + default: + return UNDEFINED_CAST_COST; + } +} + +uint32_t BuiltInFunctionsUtils::castUInt16(LogicalTypeID targetTypeID) { + switch (targetTypeID) { + case LogicalTypeID::INT32: + case LogicalTypeID::SERIAL: + case LogicalTypeID::INT64: + case LogicalTypeID::INT128: + case LogicalTypeID::UINT32: + case LogicalTypeID::UINT64: + case LogicalTypeID::FLOAT: + case LogicalTypeID::DOUBLE: + case LogicalTypeID::DECIMAL: + return getTargetTypeCost(targetTypeID); + default: + return UNDEFINED_CAST_COST; + } +} + +uint32_t BuiltInFunctionsUtils::castUInt8(LogicalTypeID targetTypeID) { + switch (targetTypeID) { + case LogicalTypeID::INT16: + case LogicalTypeID::INT32: + case LogicalTypeID::SERIAL: + case LogicalTypeID::INT64: + case LogicalTypeID::INT128: + case LogicalTypeID::UINT16: + case LogicalTypeID::UINT32: + case LogicalTypeID::UINT64: + case LogicalTypeID::FLOAT: + case LogicalTypeID::DOUBLE: + case LogicalTypeID::DECIMAL: + return getTargetTypeCost(targetTypeID); + default: + return UNDEFINED_CAST_COST; + } +} + +uint32_t BuiltInFunctionsUtils::castInt128(LogicalTypeID targetTypeID) { + switch (targetTypeID) { + case LogicalTypeID::FLOAT: + case LogicalTypeID::DOUBLE: + case LogicalTypeID::DECIMAL: + return getTargetTypeCost(targetTypeID); + default: + return UNDEFINED_CAST_COST; + } +} + +uint32_t BuiltInFunctionsUtils::castUUID(LogicalTypeID targetTypeID) { + switch (targetTypeID) { + case LogicalTypeID::STRING: + return getTargetTypeCost(targetTypeID); + default: + return UNDEFINED_CAST_COST; + } +} + +uint32_t BuiltInFunctionsUtils::castDouble(LogicalTypeID targetTypeID) { + switch (targetTypeID) { + default: + return UNDEFINED_CAST_COST; + } +} + +uint32_t BuiltInFunctionsUtils::castFloat(LogicalTypeID targetTypeID) { + switch (targetTypeID) { + case LogicalTypeID::DOUBLE: + return getTargetTypeCost(targetTypeID); + default: + return UNDEFINED_CAST_COST; + } +} + +uint32_t BuiltInFunctionsUtils::castDecimal(LogicalTypeID targetTypeID) { + switch (targetTypeID) { + case LogicalTypeID::FLOAT: + case LogicalTypeID::DOUBLE: + return getTargetTypeCost(targetTypeID); + default: + return UNDEFINED_CAST_COST; + } +} + +uint32_t BuiltInFunctionsUtils::castDate(LogicalTypeID targetTypeID) { + switch (targetTypeID) { + case LogicalTypeID::TIMESTAMP: + return getTargetTypeCost(targetTypeID); + default: + return UNDEFINED_CAST_COST; + } +} + +uint32_t BuiltInFunctionsUtils::castSerial(LogicalTypeID targetTypeID) { + switch (targetTypeID) { + case LogicalTypeID::INT64: + return 0; + default: + return UNDEFINED_CAST_COST; + } +} + +uint32_t BuiltInFunctionsUtils::castTimestamp(LogicalTypeID 
targetTypeID) { + switch (targetTypeID) { + case LogicalTypeID::TIMESTAMP: + return getTargetTypeCost(targetTypeID); + default: + return UNDEFINED_CAST_COST; + } +} + +uint32_t BuiltInFunctionsUtils::castFromString(LogicalTypeID inputTypeID) { + switch (inputTypeID) { + case LogicalTypeID::BLOB: + case LogicalTypeID::INTERNAL_ID: + case LogicalTypeID::NODE: + case LogicalTypeID::REL: + case LogicalTypeID::RECURSIVE_REL: + return UNDEFINED_CAST_COST; + default: // Any other inputTypeID can be cast to String, but this cast has a high cost + return getTargetTypeCost(LogicalTypeID::STRING); + } +} + +uint32_t BuiltInFunctionsUtils::castList(LogicalTypeID targetTypeID) { + switch (targetTypeID) { + case LogicalTypeID::ARRAY: + return getTargetTypeCost(targetTypeID); + default: + return UNDEFINED_CAST_COST; + } +} + +uint32_t BuiltInFunctionsUtils::castArray(LogicalTypeID targetTypeID) { + switch (targetTypeID) { + case LogicalTypeID::LIST: + return getTargetTypeCost(targetTypeID); + default: + return UNDEFINED_CAST_COST; + } +} + +// When there is multiple candidates functions, e.g. double + int and double + double for input +// "1.5 + parameter", we prefer the one without any implicit casting i.e. double + double. +// Additionally, we prefer function with string parameter because string is most permissive and +// can be cast to any type. +Function* BuiltInFunctionsUtils::getBestMatch(std::vector& functionsToMatch) { + KU_ASSERT(functionsToMatch.size() > 1); + Function* result = nullptr; + auto cost = UNDEFINED_CAST_COST; + for (auto& function : functionsToMatch) { + auto currentCost = 0u; + std::unordered_set distinctParameterTypes; + for (auto& parameterTypeID : function->parameterTypeIDs) { + if (parameterTypeID != LogicalTypeID::STRING) { + currentCost++; + } + if (!distinctParameterTypes.contains(parameterTypeID)) { + currentCost++; + distinctParameterTypes.insert(parameterTypeID); + } + } + if (currentCost < cost) { + cost = currentCost; + result = function; + } + } + KU_ASSERT(result != nullptr); + return result; +} + +uint32_t BuiltInFunctionsUtils::getFunctionCost(const std::vector& inputTypes, + Function* function, CatalogEntryType type) { + bool isVarLength = (type == CatalogEntryType::SCALAR_FUNCTION_ENTRY ? 
+ function->constPtrCast()->isVarLength : + false); + if (isVarLength) { + KU_ASSERT(function->parameterTypeIDs.size() == 1); + return matchVarLengthParameters(inputTypes, function->parameterTypeIDs[0]); + } + return matchParameters(inputTypes, function->parameterTypeIDs); +} + +uint32_t BuiltInFunctionsUtils::getAggregateFunctionCost(const std::vector& inputTypes, + bool isDistinct, AggregateFunction* function) { + if (inputTypes.size() != function->parameterTypeIDs.size() || + isDistinct != function->isDistinct) { + return UINT32_MAX; + } + for (auto i = 0u; i < inputTypes.size(); ++i) { + if (function->parameterTypeIDs[i] == LogicalTypeID::ANY) { + continue; + } else if (inputTypes[i].getLogicalTypeID() != function->parameterTypeIDs[i]) { + return UINT32_MAX; + } + } + return 0; +} + +uint32_t BuiltInFunctionsUtils::matchParameters(const std::vector& inputTypes, + const std::vector& targetTypeIDs) { + if (inputTypes.size() != targetTypeIDs.size()) { + return UINT32_MAX; + } + auto cost = 0u; + for (auto i = 0u; i < inputTypes.size(); ++i) { + auto castCost = getCastCost(inputTypes[i].getLogicalTypeID(), targetTypeIDs[i]); + if (castCost == UNDEFINED_CAST_COST) { + return UINT32_MAX; + } + cost += castCost; + } + return cost; +} + +uint32_t BuiltInFunctionsUtils::matchVarLengthParameters(const std::vector& inputTypes, + LogicalTypeID targetTypeID) { + auto cost = 0u; + for (const auto& inputType : inputTypes) { + auto castCost = getCastCost(inputType.getLogicalTypeID(), targetTypeID); + if (castCost == UNDEFINED_CAST_COST) { + return UINT32_MAX; + } + cost += castCost; + } + return cost; +} + +void BuiltInFunctionsUtils::validateSpecialCases(std::vector& candidateFunctions, + const std::string& name, const std::vector& inputTypes, + const function::function_set& set) { + // special case for add func + if (name == AddFunction::name) { + auto targetType0 = candidateFunctions[0]->parameterTypeIDs[0]; + auto targetType1 = candidateFunctions[0]->parameterTypeIDs[1]; + auto inputType0 = inputTypes[0].getLogicalTypeID(); + auto inputType1 = inputTypes[1].getLogicalTypeID(); + if ((inputType0 != LogicalTypeID::STRING || inputType1 != LogicalTypeID::STRING) && + targetType0 == LogicalTypeID::STRING && targetType1 == LogicalTypeID::STRING) { + std::string supportedInputsString; + for (auto& function : set) { + supportedInputsString += function->signatureToString() + "\n"; + } + throw BinderException("Cannot match a built-in function for given function " + name + + LogicalTypeUtils::toString(inputTypes) + + ". Supported inputs are\n" + supportedInputsString); + } + } +} + +static std::string alignedString(const std::string& input) { + std::istringstream stream(input); + std::ostringstream result; + std::string line; + std::string prefix = "Expected: "; + std::string padding(prefix.length(), ' '); + bool firstLine = true; + while (std::getline(stream, line)) { + if (firstLine) { + result << line << '\n'; + firstLine = false; + } else { + result << padding << line << '\n'; + } + } + return result.str(); +} + +std::string BuiltInFunctionsUtils::getFunctionMatchFailureMsg(const std::string name, + const std::vector& inputTypes, const std::string& supportedInputs, + bool isDistinct) { + std::string result = stringFormat("Function {} did not receive correct arguments:\n", name); + result += stringFormat("Actual: {}{}\n", isDistinct ? "DISTINCT " : "", + inputTypes.empty() ? "()" : LogicalTypeUtils::toString(inputTypes)); + result += stringFormat("Expected: {}\n", + supportedInputs.empty() ? 
"()" : alignedString(supportedInputs)); + return result; +} + +void validateNonEmptyCandidateFunctions(std::vector& candidateFunctions, + const std::string& name, const std::vector& inputTypes, bool isDistinct, + const function::function_set& set) { + if (candidateFunctions.empty()) { + std::string supportedInputsString; + for (auto& function : set) { + auto aggregateFunction = function->constPtrCast(); + if (aggregateFunction->isDistinct) { + supportedInputsString += "DISTINCT "; + } + supportedInputsString += aggregateFunction->signatureToString() + "\n"; + } + throw BinderException(BuiltInFunctionsUtils::getFunctionMatchFailureMsg(name, inputTypes, + supportedInputsString, isDistinct)); + } +} + +void validateNonEmptyCandidateFunctions(std::vector& candidateFunctions, + const std::string& name, const std::vector& inputTypes, + const function::function_set& set) { + if (candidateFunctions.empty()) { + std::string supportedInputsString; + for (auto& function : set) { + if (function->parameterTypeIDs.empty()) { + continue; + } + supportedInputsString += function->signatureToString() + "\n"; + } + throw BinderException(BuiltInFunctionsUtils::getFunctionMatchFailureMsg(name, inputTypes, + supportedInputsString)); + } +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/cast/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/function/cast/CMakeLists.txt new file mode 100644 index 0000000000..a32d293977 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/cast/CMakeLists.txt @@ -0,0 +1,7 @@ +add_library(lbug_function_cast + OBJECT + cast_array.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/cast/cast_array.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/cast/cast_array.cpp new file mode 100644 index 0000000000..4f3c500eb6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/cast/cast_array.cpp @@ -0,0 +1,155 @@ +#include "function/cast/functions/cast_array.h" + +#include "common/exception/conversion.h" +#include "common/type_utils.h" + +namespace lbug { +namespace function { + +bool CastArrayHelper::checkCompatibleNestedTypes(LogicalTypeID sourceTypeID, + LogicalTypeID targetTypeID) { + switch (sourceTypeID) { + case LogicalTypeID::ANY: { + return true; + } + case LogicalTypeID::LIST: { + return targetTypeID == LogicalTypeID::ARRAY || targetTypeID == LogicalTypeID::LIST; + } + case LogicalTypeID::UNION: + case LogicalTypeID::MAP: + case LogicalTypeID::STRUCT: { + return sourceTypeID == targetTypeID || targetTypeID == LogicalTypeID::UNION; + } + case LogicalTypeID::ARRAY: { + return targetTypeID == LogicalTypeID::LIST || targetTypeID == LogicalTypeID::ARRAY; + } + default: + return false; + } +} + +bool CastArrayHelper::isUnionSpecialCast(const LogicalType& srcType, const LogicalType& dstType) { + if (srcType.getLogicalTypeID() != LogicalTypeID::STRUCT || + dstType.getLogicalTypeID() != LogicalTypeID::UNION || + !StructType::hasField(srcType, "tag")) { + return false; + } + for (auto& field : StructType::getFields(srcType)) { + if (!UnionType::hasField(dstType, field.getName()) || + UnionType::getFieldType(dstType, field.getName()) != field.getType()) { + return false; + } + } + return true; +} + +bool CastArrayHelper::containsListToArray(const LogicalType& srcType, const LogicalType& dstType) { + auto srcTypeID = srcType.getLogicalTypeID(); + auto dstTypeID = dstType.getLogicalTypeID(); + + if ((srcTypeID == LogicalTypeID::LIST || 
srcTypeID == LogicalTypeID::ARRAY) && + dstTypeID == LogicalTypeID::ARRAY) { + return true; + } + + if (!isUnionSpecialCast(srcType, dstType) && + (srcTypeID == LogicalTypeID::UNION || dstTypeID == LogicalTypeID::UNION)) { + return false; + } + + if (checkCompatibleNestedTypes(srcTypeID, dstTypeID)) { + switch (srcType.getPhysicalType()) { + case PhysicalTypeID::LIST: { + return containsListToArray(ListType::getChildType(srcType), + ListType::getChildType(dstType)); + } + case PhysicalTypeID::ARRAY: { + return containsListToArray(ArrayType::getChildType(srcType), + ListType::getChildType(dstType)); + } + case PhysicalTypeID::STRUCT: { + auto srcFieldTypes = StructType::getFieldTypes(srcType); + auto dstFieldTypes = StructType::getFieldTypes(dstType); + if (srcFieldTypes.size() != dstFieldTypes.size()) { + throw ConversionException{ + stringFormat("Unsupported casting function from {} to {}.", srcType.toString(), + dstType.toString())}; + } + + for (auto i = 0u; i < srcFieldTypes.size(); i++) { + if (containsListToArray(*srcFieldTypes[i], *dstFieldTypes[i])) { + return true; + } + } + } break; + default: + return false; + } + } + return false; +} + +void CastArrayHelper::validateListEntry(ValueVector* inputVector, const LogicalType& resultType, + uint64_t pos) { + if (inputVector->isNull(pos)) { + return; + } + const auto& inputType = inputVector->dataType; + + switch (resultType.getPhysicalType()) { + case PhysicalTypeID::ARRAY: { + if (inputType.getPhysicalType() == PhysicalTypeID::LIST) { + auto listEntry = inputVector->getValue(pos); + if (listEntry.size != ArrayType::getNumElements(resultType)) { + throw ConversionException{ + stringFormat("Unsupported casting LIST with incorrect list entry to ARRAY. " + "Expected: {}, Actual: {}.", + ArrayType::getNumElements(resultType), + inputVector->getValue(pos).size)}; + } + auto inputChildVector = ListVector::getDataVector(inputVector); + for (auto i = listEntry.offset; i < listEntry.offset + listEntry.size; i++) { + validateListEntry(inputChildVector, ArrayType::getChildType(resultType), i); + } + } else if (inputType.getPhysicalType() == PhysicalTypeID::ARRAY) { + if (ArrayType::getNumElements(inputType) != ArrayType::getNumElements(resultType)) { + throw ConversionException( + stringFormat("Unsupported casting function from {} to {}.", + inputType.toString(), resultType.toString())); + } + auto listEntry = inputVector->getValue(pos); + auto inputChildVector = ListVector::getDataVector(inputVector); + for (auto i = listEntry.offset; i < listEntry.offset + listEntry.size; i++) { + validateListEntry(inputChildVector, ArrayType::getChildType(resultType), i); + } + } + } break; + case PhysicalTypeID::LIST: { + if (inputType.getPhysicalType() == PhysicalTypeID::LIST || + inputType.getPhysicalType() == PhysicalTypeID::ARRAY) { + auto listEntry = inputVector->getValue(pos); + auto inputChildVector = ListVector::getDataVector(inputVector); + for (auto i = listEntry.offset; i < listEntry.offset + listEntry.size; i++) { + validateListEntry(inputChildVector, ListType::getChildType(resultType), i); + } + } + } break; + case PhysicalTypeID::STRUCT: { + if (inputType.getPhysicalType() == PhysicalTypeID::STRUCT) { + auto fieldVectors = StructVector::getFieldVectors(inputVector); + auto fieldTypes = StructType::getFieldTypes(resultType); + + auto structEntry = inputVector->getValue(pos); + for (auto i = 0u; i < fieldVectors.size(); i++) { + validateListEntry(fieldVectors[i].get(), *fieldTypes[i], structEntry.pos); + } + } + } break; + default: { + return; + 
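// Editor's note: illustrative standalone sketch, not part of the lbug sources. It shows the
// rule enforced by validateListEntry above: when a LIST is cast to an ARRAY, every entry at
// every nesting level that targets an ARRAY must have exactly the element count declared by
// the target type, recursively. The types here (ListValue, dims) are hypothetical simplifications.
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

namespace sketch {

struct ListValue {
    std::vector<ListValue> children; // nested entries; an empty vector is a leaf in this sketch
};

// dims gives the fixed element count at each nesting level, e.g. {3, 2} = ARRAY[3] of ARRAY[2].
void validateAgainstArray(const ListValue& value, const std::vector<std::size_t>& dims,
    std::size_t level = 0) {
    if (level >= dims.size()) return; // below the typed depth nothing is checked here
    if (value.children.size() != dims[level]) {
        throw std::runtime_error("Expected " + std::to_string(dims[level]) + " elements, got " +
            std::to_string(value.children.size()));
    }
    for (const auto& child : value.children) {
        validateAgainstArray(child, dims, level + 1);
    }
}

} // namespace sketch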
} + } +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/cast_from_string_functions.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/cast_from_string_functions.cpp new file mode 100644 index 0000000000..c8cbc98627 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/cast_from_string_functions.cpp @@ -0,0 +1,1056 @@ +#include "function/cast/functions/cast_from_string_functions.h" + +#include "common/exception/parser.h" +#include "common/string_format.h" +#include "common/types/blob.h" +#include "function/list/functions/list_unique_function.h" +#include "utf8proc_wrapper.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +// ---------------------- cast String Helper ------------------------------ // +struct CastStringHelper { + template + static void cast(const char* input, uint64_t len, T& result, ValueVector* /*vector*/ = nullptr, + uint64_t /*rowToAdd*/ = 0, const CSVOption* /*option*/ = nullptr) { + simpleIntegerCast(input, len, result, LogicalTypeID::INT64); + } +}; + +template<> +inline void CastStringHelper::cast(const char* input, uint64_t len, int128_t& result, + ValueVector* /*vector*/, uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + simpleIntegerCast(input, len, result, LogicalTypeID::INT128); +} + +template<> +inline void CastStringHelper::cast(const char* input, uint64_t len, uint128_t& result, + ValueVector* /*vector*/, uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + simpleIntegerCast(input, len, result, LogicalTypeID::UINT128); +} + +template<> +inline void CastStringHelper::cast(const char* input, uint64_t len, int32_t& result, + ValueVector* /*vector*/, uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + simpleIntegerCast(input, len, result, LogicalTypeID::INT32); +} + +template<> +inline void CastStringHelper::cast(const char* input, uint64_t len, int16_t& result, + ValueVector* /*vector*/, uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + simpleIntegerCast(input, len, result, LogicalTypeID::INT16); +} + +template<> +inline void CastStringHelper::cast(const char* input, uint64_t len, int8_t& result, + ValueVector* /*vector*/, uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + simpleIntegerCast(input, len, result, LogicalTypeID::INT8); +} + +template<> +inline void CastStringHelper::cast(const char* input, uint64_t len, uint64_t& result, + ValueVector* /*vector*/, uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + simpleIntegerCast(input, len, result, LogicalTypeID::UINT64); +} + +template<> +inline void CastStringHelper::cast(const char* input, uint64_t len, uint32_t& result, + ValueVector* /*vector*/, uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + simpleIntegerCast(input, len, result, LogicalTypeID::UINT32); +} + +template<> +inline void CastStringHelper::cast(const char* input, uint64_t len, uint16_t& result, + ValueVector* /*vector*/, uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + simpleIntegerCast(input, len, result, LogicalTypeID::UINT16); +} + +template<> +inline void CastStringHelper::cast(const char* input, uint64_t len, uint8_t& result, + ValueVector* /*vector*/, uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + simpleIntegerCast(input, len, result, LogicalTypeID::UINT8); +} + +template<> +inline void CastStringHelper::cast(const char* input, uint64_t len, float& result, + ValueVector* /*vector*/, uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + doubleCast(input, len, result, LogicalTypeID::FLOAT); 
+} + +template<> +inline void CastStringHelper::cast(const char* input, uint64_t len, double& result, + ValueVector* /*vector*/, uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + doubleCast(input, len, result, LogicalTypeID::DOUBLE); +} + +template<> +inline void CastStringHelper::cast(const char* input, uint64_t len, bool& result, + ValueVector* /*vector*/, uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + castStringToBool(input, len, result); +} + +template<> +inline void CastStringHelper::cast(const char* input, uint64_t len, date_t& result, + ValueVector* /*vector*/, uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + result = Date::fromCString(input, len); +} + +template<> +inline void CastStringHelper::cast(const char* input, uint64_t len, timestamp_ms_t& result, + ValueVector* /*vector*/, uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + TryCastStringToTimestamp::cast(input, len, result, LogicalTypeID::TIMESTAMP_MS); +} + +template<> +inline void CastStringHelper::cast(const char* input, uint64_t len, timestamp_ns_t& result, + ValueVector* /*vector*/, uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + TryCastStringToTimestamp::cast(input, len, result, LogicalTypeID::TIMESTAMP_NS); +} + +template<> +inline void CastStringHelper::cast(const char* input, uint64_t len, timestamp_sec_t& result, + ValueVector* /*vector*/, uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + TryCastStringToTimestamp::cast(input, len, result, + LogicalTypeID::TIMESTAMP_SEC); +} + +template<> +inline void CastStringHelper::cast(const char* input, uint64_t len, timestamp_tz_t& result, + ValueVector* /*vector*/, uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + TryCastStringToTimestamp::cast(input, len, result, LogicalTypeID::TIMESTAMP_TZ); +} + +template<> +inline void CastStringHelper::cast(const char* input, uint64_t len, timestamp_t& result, + ValueVector* /*vector*/, uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + result = Timestamp::fromCString(input, len); +} + +template<> +inline void CastStringHelper::cast(const char* input, uint64_t len, interval_t& result, + ValueVector* /*vector*/, uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + result = Interval::fromCString(input, len); +} + +// ---------------------- cast String to Blob ------------------------------ // +template<> +void CastString::operation(const ku_string_t& input, blob_t& result, ValueVector* resultVector, + uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + result.value.len = Blob::getBlobSize(input); + if (!ku_string_t::isShortString(result.value.len)) { + auto overflowBuffer = StringVector::getInMemOverflowBuffer(resultVector); + auto overflowPtr = overflowBuffer->allocateSpace(result.value.len); + result.value.overflowPtr = reinterpret_cast(overflowPtr); + Blob::fromString(reinterpret_cast(input.getData()), input.len, overflowPtr); + memcpy(result.value.prefix, overflowPtr, ku_string_t::PREFIX_LENGTH); + } else { + Blob::fromString(reinterpret_cast(input.getData()), input.len, + result.value.prefix); + } +} + +template<> +void CastStringHelper::cast(const char* input, uint64_t len, blob_t& /*result*/, + ValueVector* vector, uint64_t rowToAdd, const CSVOption* /*option*/) { + // base case: blob + auto blobBuffer = std::make_unique(len); + auto blobLen = Blob::fromString(input, len, blobBuffer.get()); + BlobVector::addBlob(vector, rowToAdd, blobBuffer.get(), blobLen); +} + +//---------------------- cast String to UUID ------------------------------ // +template<> +void 
CastString::operation(const ku_string_t& input, ku_uuid_t& result, + ValueVector* /*result_vector*/, uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + result.value = UUID::fromString(input.getAsString()); +} + +// LCOV_EXCL_START +template<> +void CastStringHelper::cast(const char* input, uint64_t len, ku_uuid_t& result, + ValueVector* /*vector*/, uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + result.value = UUID::fromCString(input, len); +} +// LCOV_EXCL_STOP + +// ---------------------- cast String to nested types ------------------------------ // +static void skipWhitespace(const char*& input, const char* end) { + while (input < end) { + if (*input & 0x80) { + // We only skip ASCII white spaces there. + break; + } else { + KU_ASSERT(*input >= -1); + if (!isspace(*input)) { + break; + } + } + input++; + } +} + +static void trimRightWhitespace(const char* input, const char*& end) { + while (input < end && isspace(*(end - 1))) { + end--; + } +} + +static void trimQuotes(const char*& keyStart, const char*& keyEnd) { + // Skip quotations on struct keys. + if ((keyStart[0] == '\'' && (keyEnd - 1)[0] == '\'') || + (keyStart[0] == '\"' && (keyEnd - 1)[0] == '\"')) { + keyStart++; + keyEnd--; + } +} + +static bool skipToCloseQuotes(const char*& input, const char* end) { + auto ch = *input; + input++; // skip the first " ' + // TODO: escape char + while (input != end) { + if (*input == ch) { + return true; + } + input++; + } + return false; +} + +static bool skipToClose(const char*& input, const char* end, uint64_t& lvl, char target, + const CSVOption* option) { + input++; + while (input != end) { + if (*input == '\'') { + if (!skipToCloseQuotes(input, end)) { + return false; + } + } else if (*input == '{') { // must have closing brackets {, ] if they are not quoted + if (!skipToClose(input, end, lvl, '}', option)) { + return false; + } + } else if (*input == CopyConstants::DEFAULT_CSV_LIST_BEGIN_CHAR) { + if (!skipToClose(input, end, lvl, CopyConstants::DEFAULT_CSV_LIST_END_CHAR, option)) { + return false; + } + lvl++; // nested one more level + } else if (*input == target) { + if (target == CopyConstants::DEFAULT_CSV_LIST_END_CHAR) { + lvl--; + } + return true; + } + input++; + } + return false; // no corresponding closing bracket +} + +static bool isNull(std::string_view& str) { + auto start = str.data(); + auto end = start + str.length(); + skipWhitespace(start, end); + if (start == end) { + return true; + } + if (end - start >= 4 && (*start == 'N' || *start == 'n') && + (*(start + 1) == 'U' || *(start + 1) == 'u') && + (*(start + 2) == 'L' || *(start + 2) == 'l') && + (*(start + 3) == 'L' || *(start + 3) == 'l')) { + start += 4; + skipWhitespace(start, end); + if (start == end) { + return true; + } + } + return false; +} + +// ---------------------- cast String to List Helper ------------------------------ // +struct CountPartOperation { + uint64_t count = 0; + + static inline bool handleKey(const char* /*start*/, const char* /*end*/, + const CSVOption* /*config*/) { + return true; + } + inline void handleValue(const char* /*start*/, const char* /*end*/, + const CSVOption* /*config*/) { + count++; + } +}; + +struct SplitStringListOperation { + SplitStringListOperation(uint64_t& offset, ValueVector* resultVector) + : offset(offset), resultVector(resultVector) {} + + uint64_t& offset; + ValueVector* resultVector; + + void handleValue(const char* start, const char* end, const CSVOption* option) { + skipWhitespace(start, end); + trimRightWhitespace(start, end); + 
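// Editor's note: illustrative standalone sketch, not part of the lbug sources. It condenses the
// scanning technique used by skipToCloseQuotes/skipToClose above: walk the input once, treat
// quoted runs as opaque, track bracket depth, and only split on separators at depth zero.
// Whitespace trimming (skipWhitespace/trimRightWhitespace) and escape handling are omitted,
// and all names here are hypothetical.
#include <string>
#include <vector>

namespace sketch {

// Splits the body of a list literal such as "a, [b, c], 'x,y'" into three top-level parts.
std::vector<std::string> splitTopLevel(const std::string& body) {
    std::vector<std::string> parts;
    std::string current;
    int depth = 0;
    char quote = '\0';
    for (char ch : body) {
        if (quote != '\0') {                  // inside a quoted run: copy until the matching quote
            current += ch;
            if (ch == quote) quote = '\0';
        } else if (ch == '\'' || ch == '"') { // opening quote
            quote = ch;
            current += ch;
        } else if (ch == '[' || ch == '{') {  // nested list or struct: just track depth
            ++depth;
            current += ch;
        } else if (ch == ']' || ch == '}') {
            --depth;
            current += ch;
        } else if (ch == ',' && depth == 0) { // top-level separator
            parts.push_back(current);
            current.clear();
        } else {
            current += ch;
        }
    }
    if (!current.empty()) parts.push_back(current);
    return parts;
}

} // namespace sketch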
CastString::copyStringToVector(resultVector, offset, + std::string_view{start, (uint32_t)(end - start)}, option); + offset++; + } +}; + +template +static bool splitCStringList(const char* input, uint64_t len, T& state, const CSVOption* option) { + auto end = input + len; + uint64_t lvl = 1; + bool seenValue = false; + + // locate [ + skipWhitespace(input, end); + if (input == end || *input != CopyConstants::DEFAULT_CSV_LIST_BEGIN_CHAR) { + return false; + } + skipWhitespace(++input, end); + + bool justFinishedEntry = true; // true at start + auto startPtr = input; + while (input < end) { + auto ch = *input; + if (ch == CopyConstants::DEFAULT_CSV_LIST_BEGIN_CHAR) { + if (!skipToClose(input, end, ++lvl, CopyConstants::DEFAULT_CSV_LIST_END_CHAR, option)) { + return false; + } + } else if ((ch == '\'' || ch == '"') && justFinishedEntry) { + const char* prevInput = input; + if (!skipToCloseQuotes(input, end)) { + input = prevInput; + } + } else if (ch == '{') { + uint64_t struct_lvl = 0; + skipToClose(input, end, struct_lvl, '}', option); + } else if (ch == ',' || ch == CopyConstants::DEFAULT_CSV_LIST_END_CHAR) { // split + if (ch != CopyConstants::DEFAULT_CSV_LIST_END_CHAR || startPtr < input || seenValue) { + state.handleValue(startPtr, input, option); + seenValue = true; + } + if (ch == CopyConstants::DEFAULT_CSV_LIST_END_CHAR) { // last ] + lvl--; + break; + } + skipWhitespace(++input, end); + startPtr = input; + justFinishedEntry = true; + continue; + } + justFinishedEntry = false; + input++; + } + skipWhitespace(++input, end); + return (input == end && lvl == 0); +} + +template +static bool splitPossibleUnbracedList(std::string_view input, T& state, const CSVOption* option) { + input = StringUtils::ltrim(StringUtils::rtrim(input)); + auto split = StringUtils::smartSplit(input, ';'); + if (split.size() == 1 && input.front() == '[' && input.back() == ']') { + split = StringUtils::smartSplit(input.substr(1, input.size() - 2), ';'); + } + for (auto& i : split) { + state.handleValue(i.data(), i.data() + i.length(), option); + } + return true; +} + +template +static inline void startListCast(const char* input, uint64_t len, T split, const CSVOption* option, + ValueVector* vector) { + auto validList = option->allowUnbracedList ? + splitPossibleUnbracedList(std::string_view(input, len), split, option) : + splitCStringList(input, len, split, option); + if (!validList) { + throw ConversionException("Cast failed. " + std::string{input, (size_t)len} + + " is not in " + vector->dataType.toString() + " range."); + } +} + +// ---------------------- cast String to Array Helper ------------------------------ // +static void validateNumElementsInArray(uint64_t numElementsRead, const LogicalType& type) { + auto numElementsInArray = ArrayType::getNumElements(type); + if (numElementsRead != numElementsInArray) { + throw ConversionException(stringFormat( + "Each array should have fixed number of elements. 
Expected: {}, Actual: {}.", + numElementsInArray, numElementsRead)); + } +} + +// ---------------------- cast String to List/Array ------------------------------ // +template<> +void CastStringHelper::cast(const char* input, uint64_t len, list_entry_t& /*result*/, + ValueVector* vector, uint64_t rowToAdd, const CSVOption* option) { + auto logicalTypeID = vector->dataType.getLogicalTypeID(); + + // calculate the number of elements in array + CountPartOperation state; + if (option->allowUnbracedList) { + splitPossibleUnbracedList(std::string_view(input, len), state, option); + } else { + splitCStringList(input, len, state, option); + } + if (logicalTypeID == LogicalTypeID::ARRAY) { + validateNumElementsInArray(state.count, vector->dataType); + } + + auto list_entry = ListVector::addList(vector, state.count); + vector->setValue(rowToAdd, list_entry); + auto listDataVector = ListVector::getDataVector(vector); + + SplitStringListOperation split{list_entry.offset, listDataVector}; + startListCast(input, len, split, option, vector); +} + +template<> +void CastString::operation(const ku_string_t& input, list_entry_t& result, + ValueVector* resultVector, uint64_t rowToAdd, const CSVOption* option) { + CastStringHelper::cast(reinterpret_cast(input.getData()), input.len, result, + resultVector, rowToAdd, option); +} + +// ---------------------- cast String to Map ------------------------------ // +struct SplitStringMapOperation { + SplitStringMapOperation(uint64_t& offset, ValueVector* resultVector) + : offset{offset}, resultVector{resultVector} {} + + uint64_t& offset; + ValueVector* resultVector; + ValueSet uniqueKeys; + + // NOLINTNEXTLINE(readability-make-member-function-const): Semantically non-const. + bool handleKey(const char* start, const char* end, const CSVOption* option); + + void handleValue(const char* start, const char* end, const CSVOption* option); +}; + +bool SplitStringMapOperation::handleKey(const char* start, const char* end, + const CSVOption* option) { + trimRightWhitespace(start, end); + auto fieldVector = StructVector::getFieldVector(resultVector, 0).get(); + CastString::copyStringToVector(fieldVector, offset, + std::string_view{start, (uint32_t)(end - start)}, option); + if (fieldVector->isNull(offset)) { + throw common::ConversionException{"Map does not allow null as key."}; + } + auto val = common::Value::createDefaultValue(fieldVector->dataType); + val.copyFromColLayout(fieldVector->getData() + fieldVector->getNumBytesPerValue() * offset, + fieldVector); + auto uniqueKey = uniqueKeys.insert(val).second; + if (!uniqueKey) { + throw common::ConversionException{"Map does not allow duplicate keys."}; + } + return true; +} + +void SplitStringMapOperation::handleValue(const char* start, const char* end, + const CSVOption* option) { + trimRightWhitespace(start, end); + CastString::copyStringToVector(StructVector::getFieldVector(resultVector, 1).get(), offset++, + std::string_view{start, (uint32_t)(end - start)}, option); +} + +template +static bool parseKeyOrValue(const char*& input, const char* end, T& state, bool isKey, + bool& closeBracket, const CSVOption* option) { + auto start = input; + uint64_t lvl = 0; + + while (input < end) { + if (*input == '"' || *input == '\'') { + if (!skipToCloseQuotes(input, end)) { + return false; + } + } else if (*input == '{') { + if (!skipToClose(input, end, lvl, '}', option)) { + return false; + } + } else if (*input == CopyConstants::DEFAULT_CSV_LIST_BEGIN_CHAR) { + if (!skipToClose(input, end, lvl, 
CopyConstants::DEFAULT_CSV_LIST_END_CHAR, option)) { + return false; + } + } else if (isKey && *input == '=') { + return state.handleKey(start, input, option); + } else if (!isKey && (*input == ',' || *input == '}')) { + state.handleValue(start, input, option); + if (*input == '}') { + closeBracket = true; + } + return true; + } + input++; + } + return false; +} + +// Split map of format: {a=12,b=13} +template +static bool splitCStringMap(const char* input, uint64_t len, T& state, const CSVOption* option) { + auto end = input + len; + bool closeBracket = false; + + skipWhitespace(input, end); + if (input == end || *input != '{') { // start with { + return false; + } + skipWhitespace(++input, end); + if (input == end) { + return false; + } + if (*input == '}') { + skipWhitespace(++input, end); // empty + return input == end; + } + + while (input < end) { + if (!parseKeyOrValue(input, end, state, true, closeBracket, option)) { + return false; + } + skipWhitespace(++input, end); + if (!parseKeyOrValue(input, end, state, false, closeBracket, option)) { + return false; + } + skipWhitespace(++input, end); + if (closeBracket) { + return (input == end); + } + } + return false; +} + +template<> +void CastStringHelper::cast(const char* input, uint64_t len, map_entry_t& /*result*/, + ValueVector* vector, uint64_t rowToAdd, const CSVOption* option) { + // count the number of maps in map + CountPartOperation state; + splitCStringMap(input, len, state, option); + + auto list_entry = ListVector::addList(vector, state.count); + vector->setValue(rowToAdd, list_entry); + auto structVector = ListVector::getDataVector(vector); + + SplitStringMapOperation split{list_entry.offset, structVector}; + if (!splitCStringMap(input, len, split, option)) { + throw ConversionException("Cast failed. 
" + std::string{input, (size_t)len} + + " is not in " + vector->dataType.toString() + " range."); + } +} + +template<> +void CastString::operation(const ku_string_t& input, map_entry_t& result, ValueVector* resultVector, + uint64_t rowToAdd, const CSVOption* option) { + CastStringHelper::cast(reinterpret_cast(input.getData()), input.len, result, + resultVector, rowToAdd, option); +} + +// ---------------------- cast String to Struct ------------------------------ // +static bool parseStructFieldName(const char*& input, const char* end) { + while (input < end) { + if (*input == ':') { + return true; + } + input++; + } + return false; +} + +static bool parseStructFieldValue(const char*& input, const char* end, const CSVOption* option, + bool& closeBrack) { + uint64_t lvl = 0; + while (input < end) { + if (*input == '"' || *input == '\'') { + if (!skipToCloseQuotes(input, end)) { + return false; + } + } else if (*input == '{') { + if (!skipToClose(input, end, lvl, '}', option)) { + return false; + } + } else if (*input == CopyConstants::DEFAULT_CSV_LIST_BEGIN_CHAR) { + if (!skipToClose(input, end, ++lvl, CopyConstants::DEFAULT_CSV_LIST_END_CHAR, option)) { + return false; + } + } else if (*input == ',' || *input == '}') { + if (*input == '}') { + closeBrack = true; + } + return (lvl == 0); + } + input++; + } + return false; +} + +static bool tryCastStringToStruct(const char* input, uint64_t len, ValueVector* vector, + uint64_t rowToAdd, const CSVOption* option) { + // default values to NULL + auto fieldVectors = StructVector::getFieldVectors(vector); + for (auto& fieldVector : fieldVectors) { + fieldVector->setNull(rowToAdd, true); + } + + // check if start with { + auto end = input + len; + const auto& type = vector->dataType; + skipWhitespace(input, end); + if (input == end || *input != '{') { + return false; + } + skipWhitespace(++input, end); + + if (input == end) { // no closing bracket + return false; + } + if (*input == '}') { + skipWhitespace(++input, end); + return input == end; + } + + bool closeBracket = false; + while (input < end) { + auto keyStart = input; + if (!parseStructFieldName(input, end)) { // find key + return false; + } + auto keyEnd = input; + trimRightWhitespace(keyStart, keyEnd); + trimQuotes(keyStart, keyEnd); + auto fieldIdx = StructType::getFieldIdx(type, std::string{keyStart, keyEnd}); + if (fieldIdx == INVALID_STRUCT_FIELD_IDX) { + throw ParserException{"Invalid struct field name: " + std::string{keyStart, keyEnd}}; + } + + skipWhitespace(++input, end); + auto valStart = input; + if (!parseStructFieldValue(input, end, option, closeBracket)) { // find value + return false; + } + auto valEnd = input; + trimRightWhitespace(valStart, valEnd); + trimQuotes(valStart, valEnd); + skipWhitespace(++input, end); + + auto fieldVector = StructVector::getFieldVector(vector, fieldIdx).get(); + fieldVector->setNull(rowToAdd, false); + CastString::copyStringToVector(fieldVector, rowToAdd, + std::string_view{valStart, (uint32_t)(valEnd - valStart)}, option); + + if (closeBracket) { + return (input == end); + } + } + return false; +} + +template<> +void CastStringHelper::cast(const char* input, uint64_t len, struct_entry_t& /*result*/, + ValueVector* vector, uint64_t rowToAdd, const CSVOption* option) { + if (!tryCastStringToStruct(input, len, vector, rowToAdd, option)) { + throw ConversionException("Cast failed. 
" + std::string{input, (size_t)len} + + " is not in " + vector->dataType.toString() + " range."); + } +} + +template<> +void CastString::operation(const ku_string_t& input, struct_entry_t& result, + ValueVector* resultVector, uint64_t rowToAdd, const CSVOption* option) { + CastStringHelper::cast(reinterpret_cast(input.getData()), input.len, result, + resultVector, rowToAdd, option); +} + +// ---------------------- cast String to Union ------------------------------ // +template +static inline void testAndSetValue(ValueVector* vector, uint64_t rowToAdd, T result, bool success) { + if (success) { + vector->setValue(rowToAdd, result); + } +} + +static bool tryCastUnionField(ValueVector* vector, uint64_t rowToAdd, const char* input, + uint64_t len) { + auto& targetType = vector->dataType; + bool success = false; + switch (targetType.getLogicalTypeID()) { + case LogicalTypeID::BOOL: { + bool result = false; + success = function::tryCastToBool(input, len, result); + testAndSetValue(vector, rowToAdd, result, success); + } break; + case LogicalTypeID::INT128: { + int128_t result = 0; + success = function::trySimpleIntegerCast(input, len, result); + testAndSetValue(vector, rowToAdd, result, success); + } break; + case LogicalTypeID::UINT128: { + uint128_t result = 0; + success = function::trySimpleIntegerCast(input, len, result); + testAndSetValue(vector, rowToAdd, result, success); + } break; + case LogicalTypeID::INT64: { + int64_t result = 0; + success = function::trySimpleIntegerCast(input, len, result); + testAndSetValue(vector, rowToAdd, result, success); + } break; + case LogicalTypeID::INT32: { + int32_t result = 0; + success = function::trySimpleIntegerCast(input, len, result); + testAndSetValue(vector, rowToAdd, result, success); + } break; + case LogicalTypeID::INT16: { + int16_t result = 0; + success = function::trySimpleIntegerCast(input, len, result); + testAndSetValue(vector, rowToAdd, result, success); + } break; + case LogicalTypeID::INT8: { + int8_t result = 0; + success = function::trySimpleIntegerCast(input, len, result); + testAndSetValue(vector, rowToAdd, result, success); + } break; + case LogicalTypeID::UINT64: { + uint64_t result = 0; + success = function::trySimpleIntegerCast(input, len, result); + testAndSetValue(vector, rowToAdd, result, success); + } break; + case LogicalTypeID::UINT32: { + uint32_t result = 0; + success = function::trySimpleIntegerCast(input, len, result); + testAndSetValue(vector, rowToAdd, result, success); + } break; + case LogicalTypeID::UINT16: { + uint16_t result = 0; + success = function::trySimpleIntegerCast(input, len, result); + testAndSetValue(vector, rowToAdd, result, success); + } break; + case LogicalTypeID::UINT8: { + uint8_t result = 0; + success = function::trySimpleIntegerCast(input, len, result); + testAndSetValue(vector, rowToAdd, result, success); + } break; + case LogicalTypeID::DOUBLE: { + double result = 0; + success = function::tryDoubleCast(input, len, result); + testAndSetValue(vector, rowToAdd, result, success); + } break; + case LogicalTypeID::FLOAT: { + float result = 0; + success = function::tryDoubleCast(input, len, result); + testAndSetValue(vector, rowToAdd, result, success); + } break; + case LogicalTypeID::DECIMAL: { + switch (targetType.getPhysicalType()) { + case PhysicalTypeID::INT16: { + int16_t result = 0; + tryDecimalCast(input, len, result, DecimalType::getPrecision(targetType), + DecimalType::getScale(targetType)); + testAndSetValue(vector, rowToAdd, result, success); + } break; + case PhysicalTypeID::INT32: 
{ + int32_t result = 0; + tryDecimalCast(input, len, result, DecimalType::getPrecision(targetType), + DecimalType::getScale(targetType)); + testAndSetValue(vector, rowToAdd, result, success); + } break; + case PhysicalTypeID::INT64: { + int64_t result = 0; + tryDecimalCast(input, len, result, DecimalType::getPrecision(targetType), + DecimalType::getScale(targetType)); + testAndSetValue(vector, rowToAdd, result, success); + } break; + case PhysicalTypeID::INT128: { + int128_t result = 0; + tryDecimalCast(input, len, result, DecimalType::getPrecision(targetType), + DecimalType::getScale(targetType)); + testAndSetValue(vector, rowToAdd, result, success); + } break; + default: + KU_UNREACHABLE; + } + } break; + case LogicalTypeID::DATE: { + date_t result; + uint64_t pos = 0; + success = Date::tryConvertDate(input, len, pos, result); + testAndSetValue(vector, rowToAdd, result, success); + } break; + case LogicalTypeID::TIMESTAMP_NS: { + timestamp_ns_t result; + success = TryCastStringToTimestamp::tryCast(input, len, result); + testAndSetValue(vector, rowToAdd, result, success); + } break; + case LogicalTypeID::TIMESTAMP_MS: { + timestamp_ms_t result; + success = TryCastStringToTimestamp::tryCast(input, len, result); + testAndSetValue(vector, rowToAdd, result, success); + } break; + case LogicalTypeID::TIMESTAMP_SEC: { + timestamp_sec_t result; + success = TryCastStringToTimestamp::tryCast(input, len, result); + testAndSetValue(vector, rowToAdd, result, success); + } break; + case LogicalTypeID::TIMESTAMP_TZ: { + timestamp_tz_t result; + success = TryCastStringToTimestamp::tryCast(input, len, result); + testAndSetValue(vector, rowToAdd, result, success); + } break; + case LogicalTypeID::TIMESTAMP: { + timestamp_t result; + success = Timestamp::tryConvertTimestamp(input, len, result); + testAndSetValue(vector, rowToAdd, result, success); + } break; + case LogicalTypeID::STRING: { + if (!utf8proc::Utf8Proc::isValid(input, len)) { + throw ConversionException{"Invalid UTF8-encoded string."}; + } + StringVector::addString(vector, rowToAdd, input, len); + return true; + } + default: { + return false; + } + } + return success; +} + +template<> +void CastStringHelper::cast(const char* input, uint64_t len, union_entry_t& /*result*/, + ValueVector* vector, uint64_t rowToAdd, const CSVOption* /*option*/) { + auto& type = vector->dataType; + union_field_idx_t selectedFieldIdx = INVALID_STRUCT_FIELD_IDX; + + auto i = 0u; + for (; i < UnionType::getNumFields(type); i++) { + auto internalFieldIdx = UnionType::getInternalFieldIdx(i); + auto fieldVector = StructVector::getFieldVector(vector, internalFieldIdx).get(); + if (tryCastUnionField(fieldVector, rowToAdd, input, len)) { + fieldVector->setNull(rowToAdd, false /* isNull */); + selectedFieldIdx = i; + i++; + break; + } else { + fieldVector->setNull(rowToAdd, true /* isNull */); + } + } + for (; i < UnionType::getNumFields(type); i++) { + auto fieldVector = UnionVector::getValVector(vector, i); + fieldVector->setNull(rowToAdd, true /* isNull */); + } + + if (selectedFieldIdx == INVALID_STRUCT_FIELD_IDX) { + throw ConversionException{stringFormat("Could not convert to union type {}: {}.", + type.toString(), std::string{input, (size_t)len})}; + } + StructVector::getFieldVector(vector, UnionType::TAG_FIELD_IDX) + ->setValue(rowToAdd, selectedFieldIdx); + StructVector::getFieldVector(vector, UnionType::TAG_FIELD_IDX) + ->setNull(rowToAdd, false /* isNull */); +} + +template<> +void CastString::operation(const ku_string_t& input, union_entry_t& result, + 
ValueVector* resultVector, uint64_t rowToAdd, const CSVOption* CSVOption) { + CastStringHelper::cast(reinterpret_cast(input.getData()), input.len, result, + resultVector, rowToAdd, CSVOption); +} + +static void setVectorNull(ValueVector* vector, uint64_t vectorPos, std::string_view strVal, + const CSVOption* option) { + auto& type = vector->dataType; + switch (type.getLogicalTypeID()) { + case LogicalTypeID::STRING: { + if (std::any_of(option->nullStrings.begin(), option->nullStrings.end(), + [&](const std::string& nullStr) { return nullStr == strVal; })) { + vector->setNull(vectorPos, true /* isNull */); + return; + } + } break; + default: { + if (isNull(strVal)) { + vector->setNull(vectorPos, true /* isNull */); + return; + } + } break; + } + vector->setNull(vectorPos, false /* isNull */); +} + +void CastString::copyStringToVector(ValueVector* vector, uint64_t vectorPos, + std::string_view strVal, const CSVOption* option) { + auto& type = vector->dataType; + setVectorNull(vector, vectorPos, strVal, option); + if (vector->isNull(vectorPos)) { + return; + } + switch (type.getLogicalTypeID()) { + case LogicalTypeID::INT128: { + int128_t val = 0; + CastStringHelper::cast(strVal.data(), strVal.length(), val); + vector->setValue(vectorPos, val); + } break; + case LogicalTypeID::SERIAL: + case LogicalTypeID::INT64: { + int64_t val = 0; + CastStringHelper::cast(strVal.data(), strVal.length(), val); + vector->setValue(vectorPos, val); + } break; + case LogicalTypeID::INT32: { + int32_t val = 0; + CastStringHelper::cast(strVal.data(), strVal.length(), val); + vector->setValue(vectorPos, val); + } break; + case LogicalTypeID::INT16: { + int16_t val = 0; + CastStringHelper::cast(strVal.data(), strVal.length(), val); + vector->setValue(vectorPos, val); + } break; + case LogicalTypeID::INT8: { + int8_t val = 0; + CastStringHelper::cast(strVal.data(), strVal.length(), val); + vector->setValue(vectorPos, val); + } break; + case LogicalTypeID::UINT64: { + uint64_t val = 0; + CastStringHelper::cast(strVal.data(), strVal.length(), val); + vector->setValue(vectorPos, val); + } break; + case LogicalTypeID::UINT32: { + uint32_t val = 0; + CastStringHelper::cast(strVal.data(), strVal.length(), val); + vector->setValue(vectorPos, val); + } break; + case LogicalTypeID::UINT16: { + uint16_t val = 0; + CastStringHelper::cast(strVal.data(), strVal.length(), val); + vector->setValue(vectorPos, val); + } break; + case LogicalTypeID::UINT8: { + uint8_t val = 0; + CastStringHelper::cast(strVal.data(), strVal.length(), val); + vector->setValue(vectorPos, val); + } break; + case LogicalTypeID::FLOAT: { + float val = 0; + CastStringHelper::cast(strVal.data(), strVal.length(), val); + vector->setValue(vectorPos, val); + } break; + case LogicalTypeID::DECIMAL: { + switch (type.getPhysicalType()) { + case PhysicalTypeID::INT16: { + int16_t val = 0; + decimalCast(strVal.data(), strVal.length(), val, type); + vector->setValue(vectorPos, val); + } break; + case PhysicalTypeID::INT32: { + int32_t val = 0; + decimalCast(strVal.data(), strVal.length(), val, type); + vector->setValue(vectorPos, val); + } break; + case PhysicalTypeID::INT64: { + int64_t val = 0; + decimalCast(strVal.data(), strVal.length(), val, type); + vector->setValue(vectorPos, val); + } break; + case PhysicalTypeID::INT128: { + int128_t val = 0; + decimalCast(strVal.data(), strVal.length(), val, type); + vector->setValue(vectorPos, val); + } break; + default: + KU_UNREACHABLE; + } + } break; + case LogicalTypeID::DOUBLE: { + double val = 0; + 
CastStringHelper::cast(strVal.data(), strVal.length(), val); + vector->setValue(vectorPos, val); + } break; + case LogicalTypeID::BOOL: { + bool val = false; + CastStringHelper::cast(strVal.data(), strVal.length(), val); + vector->setValue(vectorPos, val); + } break; + case LogicalTypeID::BLOB: { + blob_t val; + CastStringHelper::cast(strVal.data(), strVal.length(), val, vector, vectorPos, option); + } break; + case LogicalTypeID::UUID: { + ku_uuid_t val{}; + CastStringHelper::cast(strVal.data(), strVal.length(), val); + vector->setValue(vectorPos, val.value); + } break; + case LogicalTypeID::STRING: { + if (!utf8proc::Utf8Proc::isValid(strVal.data(), strVal.length())) { + throw ConversionException{"Invalid UTF8-encoded string."}; + } + StringVector::addString(vector, vectorPos, strVal.data(), strVal.length()); + } break; + case LogicalTypeID::DATE: { + date_t val; + CastStringHelper::cast(strVal.data(), strVal.length(), val); + vector->setValue(vectorPos, val); + } break; + case LogicalTypeID::TIMESTAMP_NS: { + timestamp_ns_t val; + CastStringHelper::cast(strVal.data(), strVal.length(), val); + vector->setValue(vectorPos, val); + } break; + case LogicalTypeID::TIMESTAMP_MS: { + timestamp_ms_t val; + CastStringHelper::cast(strVal.data(), strVal.length(), val); + vector->setValue(vectorPos, val); + } break; + case LogicalTypeID::TIMESTAMP_SEC: { + timestamp_sec_t val; + CastStringHelper::cast(strVal.data(), strVal.length(), val); + vector->setValue(vectorPos, val); + } break; + case LogicalTypeID::TIMESTAMP_TZ: { + timestamp_tz_t val; + CastStringHelper::cast(strVal.data(), strVal.length(), val); + vector->setValue(vectorPos, val); + } break; + case LogicalTypeID::TIMESTAMP: { + timestamp_t val; + CastStringHelper::cast(strVal.data(), strVal.length(), val); + vector->setValue(vectorPos, val); + } break; + case LogicalTypeID::INTERVAL: { + interval_t val; + CastStringHelper::cast(strVal.data(), strVal.length(), val); + vector->setValue(vectorPos, val); + } break; + case LogicalTypeID::UINT128: { + uint128_t val = 0; + CastStringHelper::cast(strVal.data(), strVal.length(), val); + vector->setValue(vectorPos, val); + } break; + case LogicalTypeID::MAP: { + map_entry_t val; + CastStringHelper::cast(strVal.data(), strVal.length(), val, vector, vectorPos, option); + } break; + case LogicalTypeID::ARRAY: + case LogicalTypeID::LIST: { + list_entry_t val; + CastStringHelper::cast(strVal.data(), strVal.length(), val, vector, vectorPos, option); + } break; + case LogicalTypeID::STRUCT: { + struct_entry_t val{}; + CastStringHelper::cast(strVal.data(), strVal.length(), val, vector, vectorPos, option); + } break; + case LogicalTypeID::UNION: { + union_entry_t val{}; + CastStringHelper::cast(strVal.data(), strVal.length(), val, vector, vectorPos, option); + } break; + default: { + KU_UNREACHABLE; + } + } +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/cast_string_non_nested_functions.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/cast_string_non_nested_functions.cpp new file mode 100644 index 0000000000..ad64393447 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/cast_string_non_nested_functions.cpp @@ -0,0 +1,272 @@ +#include "function/cast/functions/cast_string_non_nested_functions.h" + +#include "common/constants.h" +#include "common/types/date_t.h" +#include "common/types/interval_t.h" +#include "common/types/timestamp_t.h" +#include "common/types/uuid.h" +#include "function/cast/functions/numeric_limits.h" +#include 
"re2.h" + +namespace lbug { +namespace function { + +bool tryCastToBool(const char* input, uint64_t len, bool& result) { + StringUtils::removeCStringWhiteSpaces(input, len); + + switch (len) { + case 1: { + char c = std::tolower(*input); + if (c == 't' || c == '1') { + result = true; + return true; + } else if (c == 'f' || c == '0') { + result = false; + return true; + } + return false; + } + case 4: { + auto t = std::tolower(input[0]); + auto r = std::tolower(input[1]); + auto u = std::tolower(input[2]); + auto e = std::tolower(input[3]); + if (t == 't' && r == 'r' && u == 'u' && e == 'e') { + result = true; + return true; + } + return false; + } + case 5: { + auto f = std::tolower(input[0]); + auto a = std::tolower(input[1]); + auto l = std::tolower(input[2]); + auto s = std::tolower(input[3]); + auto e = std::tolower(input[4]); + if (f == 'f' && a == 'a' && l == 'l' && s == 's' && e == 'e') { + result = false; + return true; + } + return false; + } + default: + return false; + } +} + +void castStringToBool(const char* input, uint64_t len, bool& result) { + if (!tryCastToBool(input, len, result)) { + throw ConversionException{ + stringFormat("Value {} is not a valid boolean", std::string{input, (size_t)len})}; + } +} + +template<> +bool TryCastStringToTimestamp::tryCast(const char* input, uint64_t len, + timestamp_t& result) { + if (!Timestamp::tryConvertTimestamp(input, len, result)) { + return false; + } + result = Timestamp::getEpochNanoSeconds(result); + return true; +} + +template<> +bool TryCastStringToTimestamp::tryCast(const char* input, uint64_t len, + timestamp_t& result) { + if (!Timestamp::tryConvertTimestamp(input, len, result)) { + return false; + } + result = Timestamp::getEpochMilliSeconds(result); + return true; +} + +template<> +bool TryCastStringToTimestamp::tryCast(const char* input, uint64_t len, + timestamp_t& result) { + if (!Timestamp::tryConvertTimestamp(input, len, result)) { + return false; + } + result = Timestamp::getEpochSeconds(result); + return true; +} + +static bool isDate(std::string_view str) { + return RE2::FullMatch(str, Date::regexPattern()); +} + +static bool isUUID(std::string_view str) { + return RE2::FullMatch(str, UUID::regexPattern()); +} + +static bool isInterval(std::string_view str) { + return RE2::FullMatch(str, Interval::regexPattern1()) || + RE2::FullMatch(str, Interval::regexPattern2()); +} + +static LogicalType inferMapOrStruct(std::string_view str) { + auto split = StringUtils::smartSplit(str.substr(1, str.size() - 2), ','); + bool isMap = true, isStruct = true; // Default match to map if both are true + for (auto& ele : split) { + if (StringUtils::smartSplit(ele, '=', 2).size() != 2) { + isMap = false; + } + if (StringUtils::smartSplit(ele, ':', 2).size() != 2) { + isStruct = false; + } + } + if (isMap) { + auto childKeyType = LogicalType::ANY(); + auto childValueType = LogicalType::ANY(); + for (auto& ele : split) { + auto split = StringUtils::smartSplit(ele, '=', 2); + auto& key = split[0]; + auto& value = split[1]; + childKeyType = + LogicalTypeUtils::combineTypes(childKeyType, inferMinimalTypeFromString(key)); + childValueType = + LogicalTypeUtils::combineTypes(childValueType, inferMinimalTypeFromString(value)); + } + return LogicalType::MAP(std::move(childKeyType), std::move(childValueType)); + } else if (isStruct) { + std::vector fields; + for (auto& ele : split) { + auto split = StringUtils::smartSplit(ele, ':', 2); + auto fieldKey = StringUtils::ltrim(StringUtils::rtrim(split[0])); + if (fieldKey.size() > 0 && 
fieldKey.front() == '\'') { + fieldKey = fieldKey.substr(1); + } + if (fieldKey.size() > 0 && fieldKey.back() == '\'') { + fieldKey = fieldKey.substr(0, fieldKey.size() - 1); + } + auto fieldType = inferMinimalTypeFromString(split[1]); + fields.emplace_back(std::string(fieldKey), std::move(fieldType)); + } + return LogicalType::STRUCT(std::move(fields)); + } else { + return LogicalType::STRING(); + } +} + +LogicalType inferMinimalTypeFromString(const std::string& str) { + return inferMinimalTypeFromString(std::string_view(str)); +} + +static RE2& boolPattern() { + static RE2 retval("(?i)(T|F|TRUE|FALSE)"); + return retval; +} +static RE2& intPattern() { + static RE2 retval("(-?0)|(-?[1-9]\\d*)"); + return retval; +} +static RE2& realPattern() { + static RE2 retval("(\\+|-)?(0|[1-9]\\d*)?\\.(\\d*)"); + return retval; +} + +bool isAnyType(std::string_view cpy) { + return cpy.size() == 0 || StringUtils::caseInsensitiveEquals(cpy, "NULL") || + StringUtils::caseInsensitiveEquals(cpy, "NAN"); +} + +bool isINF(std::string_view cpy) { + return StringUtils::caseInsensitiveEquals(cpy, "INF") || + StringUtils::caseInsensitiveEquals(cpy, "+INF") || + StringUtils::caseInsensitiveEquals(cpy, "-INF") || + StringUtils::caseInsensitiveEquals(cpy, "INFINITY") || + StringUtils::caseInsensitiveEquals(cpy, "+INFINITY") || + StringUtils::caseInsensitiveEquals(cpy, "-INFINITY"); +} + +LogicalType inferMinimalTypeFromString(std::string_view str) { + constexpr char array_begin = common::CopyConstants::DEFAULT_CSV_LIST_BEGIN_CHAR; + constexpr char array_end = common::CopyConstants::DEFAULT_CSV_LIST_END_CHAR; + auto cpy = StringUtils::ltrim(StringUtils::rtrim(str)); + // Check special double literals + if (isINF(cpy)) { + return LogicalType::DOUBLE(); + } + // Any + if (isAnyType(cpy)) { + return LogicalType::ANY(); + } + // Boolean + if (RE2::FullMatch(cpy, boolPattern())) { + return LogicalType::BOOL(); + } + // The reason we're not going to try to match to a minimal width integer + // is because if we're infering the type of integer from a sequence of + // increasing integers, we're bound to underestimate the width + // if we only sniff the first few elements; a rather common occurrence. + + // integer + if (RE2::FullMatch(cpy, intPattern())) { + if (cpy.size() >= 1 + NumericLimits::maxNumDigits()) { + return LogicalType::DOUBLE(); + } + int128_t int128val = 0; + uint128_t uint128val = 0; + if (trySimpleIntegerCast(cpy.data(), cpy.length(), int128val)) { + if (NumericLimits::isInBounds(int128val)) { + return LogicalType::INT64(); + } + KU_ASSERT(NumericLimits::isInBounds(int128val)); + return LogicalType::INT128(); + } else if (trySimpleIntegerCast(cpy.data(), cpy.length(), uint128val)) { + return LogicalType::UINT128(); + } + return LogicalType::STRING(); + } + // Real value checking + if (RE2::FullMatch(cpy, realPattern())) { + if (cpy[0] == '-') { + cpy = cpy.substr(1); + } + if (cpy.size() <= DECIMAL_PRECISION_LIMIT) { + auto decimalPoint = cpy.find('.'); + KU_ASSERT(decimalPoint != std::string::npos); + return LogicalType::DECIMAL(cpy.size() - 1, cpy.size() - decimalPoint - 1); + } else { + return LogicalType::DOUBLE(); + } + } + // date + if (isDate(cpy)) { + return LogicalType::DATE(); + } + // It might just be quicker to try cast to timestamp. 
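// Editor's note: illustrative standalone sketch, not part of the lbug sources. It condenses the
// inference order used by inferMinimalTypeFromString above: special/NULL-like input first, then
// boolean, integer, real, date, finally falling back to STRING. std::regex stands in for RE2
// purely to keep the sketch self-contained, and the patterns are simplified.
#include <regex>
#include <string>

namespace sketch {

enum class Inferred { ANY, BOOL, INT64, DECIMAL, DATE, STRING };

Inferred inferMinimalType(const std::string& s) {
    static const std::regex boolRe("(?:[TtFf]|[Tt][Rr][Uu][Ee]|[Ff][Aa][Ll][Ss][Ee])");
    static const std::regex intRe("-?(?:0|[1-9]\\d*)");
    static const std::regex realRe("[+-]?(?:0|[1-9]\\d*)?\\.\\d*");
    static const std::regex dateRe("\\d{4}-\\d{2}-\\d{2}");
    if (s.empty()) return Inferred::ANY;               // NULL-ish input stays untyped
    if (std::regex_match(s, boolRe)) return Inferred::BOOL;
    if (std::regex_match(s, intRe)) return Inferred::INT64;
    if (std::regex_match(s, realRe)) return Inferred::DECIMAL;
    if (std::regex_match(s, dateRe)) return Inferred::DATE;
    return Inferred::STRING;                           // everything else is left as text
}

} // namespace sketch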
+ timestamp_t tmp; + if (common::Timestamp::tryConvertTimestamp(cpy.data(), cpy.length(), tmp)) { + return LogicalType::TIMESTAMP(); + } + + // UUID + if (isUUID(cpy)) { + return LogicalType::UUID(); + } + + // interval checking + if (isInterval(cpy)) { + return LogicalType::INTERVAL(); + } + + // array_begin and array_end are constants + if (cpy.front() == array_begin && cpy.back() == array_end) { + auto split = StringUtils::smartSplit(cpy.substr(1, cpy.size() - 2), ','); + auto childType = LogicalType::ANY(); + for (auto& ele : split) { + childType = LogicalTypeUtils::combineTypes(childType, inferMinimalTypeFromString(ele)); + } + return LogicalType::LIST(std::move(childType)); + } + + if (cpy.front() == '{' && cpy.back() == '}') { + return inferMapOrStruct(cpy); + } + + return LogicalType::STRING(); +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/comparison_functions.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/comparison_functions.cpp new file mode 100644 index 0000000000..621856bad6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/comparison_functions.cpp @@ -0,0 +1,225 @@ +#include "function/comparison/comparison_functions.h" + +#include "common/types/int128_t.h" +#include "common/types/interval_t.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +template +static void executeNestedOperation(uint8_t& result, ValueVector* leftVector, + ValueVector* rightVector, uint64_t leftPos, uint64_t rightPos) { + switch (leftVector->dataType.getPhysicalType()) { + case PhysicalTypeID::BOOL: { + OP::operation(leftVector->getValue(leftPos), + rightVector->getValue(rightPos), result, nullptr /* left */, + nullptr /* right */); + } break; + case PhysicalTypeID::INT64: { + OP::operation(leftVector->getValue(leftPos), + rightVector->getValue(rightPos), result, nullptr /* left */, + nullptr /* right */); + } break; + case PhysicalTypeID::INT32: { + OP::operation(leftVector->getValue(leftPos), + rightVector->getValue(rightPos), result, nullptr /* left */, + nullptr /* right */); + } break; + case PhysicalTypeID::INT16: { + OP::operation(leftVector->getValue(leftPos), + rightVector->getValue(rightPos), result, nullptr /* left */, + nullptr /* right */); + } break; + case PhysicalTypeID::INT8: { + OP::operation(leftVector->getValue(leftPos), + rightVector->getValue(rightPos), result, nullptr /* left */, + nullptr /* right */); + } break; + case PhysicalTypeID::UINT64: { + OP::operation(leftVector->getValue(leftPos), + rightVector->getValue(rightPos), result, nullptr /* left */, + nullptr /* right */); + } break; + case PhysicalTypeID::UINT32: { + OP::operation(leftVector->getValue(leftPos), + rightVector->getValue(rightPos), result, nullptr /* left */, + nullptr /* right */); + } break; + case PhysicalTypeID::UINT16: { + OP::operation(leftVector->getValue(leftPos), + rightVector->getValue(rightPos), result, nullptr /* left */, + nullptr /* right */); + } break; + case PhysicalTypeID::UINT8: { + OP::operation(leftVector->getValue(leftPos), + rightVector->getValue(rightPos), result, nullptr /* left */, + nullptr /* right */); + } break; + case PhysicalTypeID::INT128: { + OP::operation(leftVector->getValue(leftPos), + rightVector->getValue(rightPos), result, nullptr /* left */, + nullptr /* right */); + } break; + case PhysicalTypeID::DOUBLE: { + OP::operation(leftVector->getValue(leftPos), + rightVector->getValue(rightPos), result, nullptr /* left */, + nullptr /* right */); + } break; + case 
PhysicalTypeID::FLOAT: { + OP::operation(leftVector->getValue(leftPos), rightVector->getValue(rightPos), + result, nullptr /* left */, nullptr /* right */); + } break; + case PhysicalTypeID::STRING: { + OP::operation(leftVector->getValue(leftPos), + rightVector->getValue(rightPos), result, nullptr /* left */, + nullptr /* right */); + } break; + case PhysicalTypeID::INTERVAL: { + OP::operation(leftVector->getValue(leftPos), + rightVector->getValue(rightPos), result, nullptr /* left */, + nullptr /* right */); + } break; + case PhysicalTypeID::INTERNAL_ID: { + OP::operation(leftVector->getValue(leftPos), + rightVector->getValue(rightPos), result, nullptr /* left */, + nullptr /* right */); + } break; + case PhysicalTypeID::ARRAY: + case PhysicalTypeID::LIST: { + OP::operation(leftVector->getValue(leftPos), + rightVector->getValue(rightPos), result, leftVector, rightVector); + } break; + case PhysicalTypeID::STRUCT: { + OP::operation(leftVector->getValue(leftPos), + rightVector->getValue(rightPos), result, leftVector, rightVector); + } break; + default: { + KU_UNREACHABLE; + } + } +} + +static void executeNestedEqual(uint8_t& result, ValueVector* leftVector, ValueVector* rightVector, + uint64_t leftPos, uint64_t rightPos) { + if (leftVector->isNull(leftPos) && rightVector->isNull(rightPos)) { + result = true; + } else if (leftVector->isNull(leftPos) != rightVector->isNull(rightPos)) { + result = false; + } else { + executeNestedOperation(result, leftVector, rightVector, leftPos, rightPos); + } +} + +template<> +void Equals::operation(const list_entry_t& left, const list_entry_t& right, uint8_t& result, + ValueVector* leftVector, ValueVector* rightVector) { + if (leftVector->dataType != rightVector->dataType || left.size != right.size) { + result = false; + return; + } + auto leftDataVector = ListVector::getDataVector(leftVector); + auto rightDataVector = ListVector::getDataVector(rightVector); + for (auto i = 0u; i < left.size; i++) { + auto leftPos = left.offset + i; + auto rightPos = right.offset + i; + executeNestedEqual(result, leftDataVector, rightDataVector, leftPos, rightPos); + if (!result) { + return; + } + } + result = true; +} + +template<> +void Equals::operation(const struct_entry_t& left, const struct_entry_t& right, uint8_t& result, + ValueVector* leftVector, ValueVector* rightVector) { + if (leftVector->dataType != rightVector->dataType) { + result = false; + return; + } + auto leftFields = StructVector::getFieldVectors(leftVector); + auto rightFields = StructVector::getFieldVectors(rightVector); + for (auto i = 0u; i < leftFields.size(); i++) { + auto leftField = leftFields[i].get(); + auto rightField = rightFields[i].get(); + executeNestedEqual(result, leftField, rightField, left.pos, right.pos); + if (!result) { + return; + } + } + result = true; + // For STRUCT type, we also need to check their field names + if (result || leftVector->dataType.getLogicalTypeID() == LogicalTypeID::STRUCT || + rightVector->dataType.getLogicalTypeID() == LogicalTypeID::STRUCT) { + auto leftTypeNames = StructType::getFieldNames(leftVector->dataType); + auto rightTypeNames = StructType::getFieldNames(rightVector->dataType); + for (auto i = 0u; i < leftTypeNames.size(); i++) { + if (leftTypeNames[i] != rightTypeNames[i]) { + result = false; + } + } + } +} + +static void executeNestedGreaterThan(uint8_t& isGreaterThan, uint8_t& isEqual, + ValueVector* leftDataVector, ValueVector* rightDataVector, uint64_t leftPos, + uint64_t rightPos) { + auto isLeftNull = leftDataVector->isNull(leftPos); + auto 
isRightNull = rightDataVector->isNull(rightPos); + if (isLeftNull || isRightNull) { + isGreaterThan = !isRightNull; + isEqual = (isLeftNull == isRightNull); + } else { + executeNestedOperation(isGreaterThan, leftDataVector, rightDataVector, leftPos, + rightPos); + if (!isGreaterThan) { + executeNestedOperation(isEqual, leftDataVector, rightDataVector, leftPos, + rightPos); + } else { + isEqual = false; + } + } +} + +template<> +void GreaterThan::operation(const list_entry_t& left, const list_entry_t& right, uint8_t& result, + ValueVector* leftVector, ValueVector* rightVector) { + KU_ASSERT(leftVector->dataType == rightVector->dataType); + auto leftDataVector = ListVector::getDataVector(leftVector); + auto rightDataVector = ListVector::getDataVector(rightVector); + auto commonLength = std::min(left.size, right.size); + uint8_t isEqual = 0; + for (auto i = 0u; i < commonLength; i++) { + auto leftPos = left.offset + i; + auto rightPos = right.offset + i; + executeNestedGreaterThan(result, isEqual, leftDataVector, rightDataVector, leftPos, + rightPos); + if (result || (!result && !isEqual)) { + return; + } + } + result = left.size > right.size; +} + +template<> +void GreaterThan::operation(const struct_entry_t& left, const struct_entry_t& right, + uint8_t& result, ValueVector* leftVector, ValueVector* rightVector) { + KU_ASSERT(leftVector->dataType == rightVector->dataType); + auto leftFields = StructVector::getFieldVectors(leftVector); + auto rightFields = StructVector::getFieldVectors(rightVector); + uint8_t isEqual = 0; + for (auto i = 0u; i < leftFields.size(); i++) { + auto leftField = leftFields[i].get(); + auto rightField = rightFields[i].get(); + executeNestedGreaterThan(result, isEqual, leftField, rightField, left.pos, right.pos); + if (result || (!result && !isEqual)) { + return; + } + } + result = false; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/date/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/function/date/CMakeLists.txt new file mode 100644 index 0000000000..b77489769f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/date/CMakeLists.txt @@ -0,0 +1,7 @@ +add_library(lbug_function_date + OBJECT + date_functions.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/date/date_functions.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/date/date_functions.cpp new file mode 100644 index 0000000000..258c444a95 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/date/date_functions.cpp @@ -0,0 +1,24 @@ +#include "function/date/date_functions.h" + +#include "function/function.h" +#include "transaction/transaction.h" + +namespace lbug { +namespace function { + +void CurrentDate::operation(common::date_t& result, void* dataPtr) { + auto clientContext = reinterpret_cast(dataPtr)->clientContext; + auto transaction = transaction::Transaction::Get(*clientContext); + auto currentTS = transaction->getCurrentTS(); + result = common::Timestamp::getDate(common::timestamp_tz_t(currentTS)); +} + +void CurrentTimestamp::operation(common::timestamp_tz_t& result, void* dataPtr) { + auto clientContext = reinterpret_cast(dataPtr)->clientContext; + auto transaction = transaction::Transaction::Get(*clientContext); + auto currentTS = transaction->getCurrentTS(); + result = common::timestamp_tz_t(currentTS); +} + +} // namespace function +} // namespace lbug diff --git 
a/graph-wasm/lbug-0.12.2/lbug-src/src/function/export/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/function/export/CMakeLists.txt new file mode 100644 index 0000000000..6d2ac8b837 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/export/CMakeLists.txt @@ -0,0 +1,8 @@ +add_library(lbug_function_export + OBJECT + export_csv_function.cpp + export_parquet_function.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/export/export_csv_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/export/export_csv_function.cpp new file mode 100644 index 0000000000..54d684bc72 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/export/export_csv_function.cpp @@ -0,0 +1,282 @@ +#include "common/file_system/virtual_file_system.h" +#include "common/serializer/buffer_writer.h" +#include "function/cast/vector_cast_functions.h" +#include "function/export/export_function.h" +#include "function/scalar_function.h" +#include "main/client_context.h" +#include "storage/buffer_manager/memory_manager.h" + +namespace lbug { +namespace function { + +using namespace common; + +struct ExportCSVBindData : public ExportFuncBindData { + CSVOption exportOption; + + ExportCSVBindData(std::vector names, std::string fileName, CSVOption exportOption) + : ExportFuncBindData{std::move(names), std::move(fileName)}, + exportOption{std::move(exportOption)} {} + + std::unique_ptr copy() const override { + auto bindData = + std::make_unique(columnNames, fileName, exportOption.copy()); + bindData->types = LogicalType::copy(types); + return bindData; + } +}; + +static std::string addEscapes(char toEscape, char escape, const std::string& val) { + uint64_t i = 0; + std::string escapedStr = ""; + auto found = val.find(toEscape); + + while (found != std::string::npos) { + while (i < found) { + escapedStr += val[i]; + i++; + } + escapedStr += escape; + found = val.find(toEscape, found + sizeof(escape)); + } + while (i < val.length()) { + escapedStr += val[i]; + i++; + } + return escapedStr; +} + +static bool requireQuotes(const ExportCSVBindData& exportCSVBindData, const uint8_t* str, + uint64_t len, std::atomic& parallelFlag) { + // Check if the string is equal to the null string. 
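+    // A value that is byte-identical to the NULL marker has to be quoted; otherwise it
+    // would be indistinguishable from a real NULL when the exported file is read back.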
+ if (len == strlen(ExportCSVConstants::DEFAULT_NULL_STR) && + memcmp(str, ExportCSVConstants::DEFAULT_NULL_STR, len) == 0) { + return true; + } + for (auto i = 0u; i < len; i++) { + if (str[i] == ExportCSVConstants::DEFAULT_CSV_NEWLINE[0] || + str[i] == ExportCSVConstants::DEFAULT_CSV_NEWLINE[1]) { + parallelFlag.store(false, std::memory_order_relaxed); + return true; + } + if (str[i] == exportCSVBindData.exportOption.quoteChar || + str[i] == exportCSVBindData.exportOption.delimiter) { + return true; + } + } + return false; +} + +static void writeString(BufferWriter* serializer, const ExportFuncBindData& bindData, + const uint8_t* strData, uint64_t strLen, bool forceQuote, std::atomic& parallelFlag) { + auto& exportCSVBindData = bindData.constCast(); + if (!forceQuote) { + forceQuote = requireQuotes(exportCSVBindData, strData, strLen, parallelFlag); + } + if (forceQuote) { + bool requiresEscape = false; + for (auto i = 0u; i < strLen; i++) { + if (strData[i] == exportCSVBindData.exportOption.quoteChar || + strData[i] == exportCSVBindData.exportOption.escapeChar) { + requiresEscape = true; + break; + } + } + + if (!requiresEscape) { + serializer->writeBufferData(exportCSVBindData.exportOption.quoteChar); + serializer->write(strData, strLen); + serializer->writeBufferData(exportCSVBindData.exportOption.quoteChar); + return; + } + + std::string strValToWrite = std::string(reinterpret_cast(strData), strLen); + strValToWrite = addEscapes(exportCSVBindData.exportOption.escapeChar, + exportCSVBindData.exportOption.escapeChar, strValToWrite); + if (exportCSVBindData.exportOption.escapeChar != exportCSVBindData.exportOption.quoteChar) { + strValToWrite = addEscapes(exportCSVBindData.exportOption.quoteChar, + exportCSVBindData.exportOption.escapeChar, strValToWrite); + } + serializer->writeBufferData(exportCSVBindData.exportOption.quoteChar); + serializer->writeBufferData(strValToWrite); + serializer->writeBufferData(exportCSVBindData.exportOption.quoteChar); + } else { + serializer->write(strData, strLen); + } +} + +struct ExportCSVSharedState : public ExportFuncSharedState { + std::mutex mtx; + std::unique_ptr fileInfo; + offset_t offset = 0; + + ExportCSVSharedState() = default; + + void init(main::ClientContext& context, const ExportFuncBindData& bindData) override { + fileInfo = VirtualFileSystem::GetUnsafe(context)->openFile(bindData.fileName, + FileOpenFlags(FileFlags::WRITE | FileFlags::CREATE_AND_TRUNCATE_IF_EXISTS), &context); + writeHeader(bindData); + } + + void writeHeader(const ExportFuncBindData& bindData) { + auto& exportCSVBindData = bindData.constCast(); + BufferWriter bufferedSerializer; + if (exportCSVBindData.exportOption.hasHeader) { + for (auto i = 0u; i < exportCSVBindData.columnNames.size(); i++) { + if (i != 0) { + bufferedSerializer.writeBufferData(exportCSVBindData.exportOption.delimiter); + } + auto& name = exportCSVBindData.columnNames[i]; + writeString(&bufferedSerializer, exportCSVBindData, + reinterpret_cast(name.c_str()), name.length(), + false /* forceQuote */, parallelFlag); + } + bufferedSerializer.writeBufferData(ExportCSVConstants::DEFAULT_CSV_NEWLINE[0]); + writeRows(bufferedSerializer.getBlobData(), bufferedSerializer.getSize()); + } + } + + void writeRows(const uint8_t* data, uint64_t size) { + std::lock_guard lck(mtx); + fileInfo->writeFile(data, size, offset); + offset += size; + } +}; + +struct ExportCSVLocalState final : public ExportFuncLocalState { + std::unique_ptr serializer; + std::unique_ptr unflatCastDataChunk; + std::unique_ptr 
flatCastDataChunk; + std::vector castVectors; + std::vector castFuncs; + + ExportCSVLocalState(main::ClientContext& context, const ExportFuncBindData& bindData, + std::vector isFlatVec) { + auto& exportCSVBindData = bindData.constCast(); + serializer = std::make_unique(); + auto numFlatVectors = std::count(isFlatVec.begin(), isFlatVec.end(), true /* isFlat */); + unflatCastDataChunk = std::make_unique(isFlatVec.size() - numFlatVectors); + flatCastDataChunk = std::make_unique(numFlatVectors, + DataChunkState::getSingleValueDataChunkState()); + uint64_t numInsertedFlatVector = 0; + castFuncs.resize(exportCSVBindData.types.size()); + for (auto i = 0u; i < exportCSVBindData.types.size(); i++) { + castFuncs[i] = function::CastFunction::bindCastFunction("cast", + exportCSVBindData.types[i], LogicalType::STRING()) + ->execFunc; + auto castVector = std::make_unique(LogicalTypeID::STRING, + storage::MemoryManager::Get(context)); + castVectors.push_back(castVector.get()); + if (isFlatVec[i]) { + flatCastDataChunk->insert(numInsertedFlatVector, std::move(castVector)); + numInsertedFlatVector++; + } else { + unflatCastDataChunk->insert(i - numInsertedFlatVector, std::move(castVector)); + } + } + } +}; + +static std::unique_ptr bindFunc(ExportFuncBindInput& bindInput) { + return std::make_unique(bindInput.columnNames, bindInput.filePath, + CSVReaderConfig::construct(bindInput.parsingOptions).option.copy()); +} + +static std::unique_ptr initLocalStateFunc(main::ClientContext& context, + const ExportFuncBindData& bindData, std::vector isFlatVec) { + return std::make_unique(context, bindData, isFlatVec); +} + +static std::shared_ptr createSharedStateFunc() { + return std::make_shared(); +} + +static void initSharedStateFunc(ExportFuncSharedState& sharedState, main::ClientContext& context, + const ExportFuncBindData& bindData) { + sharedState.init(context, bindData); +} + +static void writeRows(const ExportCSVBindData& exportCSVBindData, ExportCSVLocalState& localState, + ExportCSVSharedState& sharedState, std::vector> inputVectors) { + auto& exportCSVLocalState = localState.cast(); + auto& castVectors = localState.castVectors; + auto& serializer = localState.serializer; + for (auto i = 0u; i < inputVectors.size(); i++) { + auto vectorToCast = {inputVectors[i]}; + exportCSVLocalState.castFuncs[i](vectorToCast, + common::SelectionVector::fromValueVectors(vectorToCast), *castVectors[i], + castVectors[i]->getSelVectorPtr(), nullptr /* dataPtr */); + } + + uint64_t numRowsToWrite = 1; + for (auto& vectorToCast : inputVectors) { + if (!vectorToCast->state->isFlat()) { + numRowsToWrite = vectorToCast->state->getSelVector().getSelSize(); + break; + } + } + for (auto i = 0u; i < numRowsToWrite; i++) { + for (auto j = 0u; j < castVectors.size(); j++) { + if (j != 0) { + serializer->writeBufferData(exportCSVBindData.exportOption.delimiter); + } + auto vector = castVectors[j]; + auto pos = vector->state->isFlat() ? vector->state->getSelVector()[0] : + vector->state->getSelVector()[i]; + if (vector->isNull(pos)) { + // write null value + serializer->writeBufferData(ExportCSVConstants::DEFAULT_NULL_STR); + continue; + } + auto strValue = vector->getValue(pos); + // Note: we need blindly add quotes to LIST. 
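+                // (the cast-to-string form of a LIST, e.g. "[1,2,3]", can contain the
+                // delimiter and other special characters, so LIST values are always
+                // force-quoted regardless of their content)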
+ writeString(serializer.get(), exportCSVBindData, strValue.getData(), strValue.len, + ExportCSVConstants::DEFAULT_FORCE_QUOTE || + inputVectors[j]->dataType.getLogicalTypeID() == LogicalTypeID::LIST, + sharedState.parallelFlag); + } + serializer->writeBufferData(ExportCSVConstants::DEFAULT_CSV_NEWLINE[0]); + } +} + +static void sinkFunc(ExportFuncSharedState& sharedState, ExportFuncLocalState& localState, + const ExportFuncBindData& bindData, std::vector> inputVectors) { + auto& exportCSVLocalState = localState.cast(); + auto& exportCSVBindData = bindData.constCast(); + auto& exportCSVSharedState = sharedState.cast(); + writeRows(exportCSVBindData, exportCSVLocalState, exportCSVSharedState, + std::move(inputVectors)); + auto& serializer = exportCSVLocalState.serializer; + if (serializer->getSize() > ExportCSVConstants::DEFAULT_CSV_FLUSH_SIZE) { + exportCSVSharedState.writeRows(serializer->getBlobData(), serializer->getSize()); + serializer->clear(); + } +} + +static void combineFunc(ExportFuncSharedState& sharedState, ExportFuncLocalState& localState) { + auto& serializer = localState.cast().serializer; + auto& exportCSVSharedState = sharedState.cast(); + if (serializer->getSize() > 0) { + exportCSVSharedState.writeRows(serializer->getBlobData(), serializer->getSize()); + serializer->clear(); + } +} + +static void finalizeFunc(ExportFuncSharedState&) {} + +function_set ExportCSVFunction::getFunctionSet() { + function_set functionSet; + auto exportFunc = std::make_unique(name); + exportFunc->bind = bindFunc; + exportFunc->initLocalState = initLocalStateFunc; + exportFunc->createSharedState = createSharedStateFunc; + exportFunc->initSharedState = initSharedStateFunc; + exportFunc->sink = sinkFunc; + exportFunc->combine = combineFunc; + exportFunc->finalize = finalizeFunc; + functionSet.push_back(std::move(exportFunc)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/export/export_parquet_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/export/export_parquet_function.cpp new file mode 100644 index 0000000000..49d14366f4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/export/export_parquet_function.cpp @@ -0,0 +1,183 @@ +#include "common/exception/runtime.h" +#include "common/string_utils.h" +#include "common/system_config.h" +#include "function/export/export_function.h" +#include "main/client_context.h" +#include "parquet_types.h" +#include "processor/operator/persistent/writer/parquet/parquet_writer.h" +#include "processor/result/factorized_table.h" +#include "storage/buffer_manager/memory_manager.h" + +namespace lbug { +namespace function { + +using namespace common; +using namespace processor; + +struct ParquetOptions { + lbug_parquet::format::CompressionCodec::type codec = + lbug_parquet::format::CompressionCodec::SNAPPY; + + explicit ParquetOptions(case_insensitive_map_t parsingOptions) { + for (auto& [name, value] : parsingOptions) { + if (name == "COMPRESSION") { + setCompression(value); + } else { + throw common::RuntimeException{ + common::stringFormat("Unrecognized parquet option: {}.", name)}; + } + } + } + + void setCompression(common::Value& value) { + if (value.getDataType().getLogicalTypeID() != LogicalTypeID::STRING) { + throw common::RuntimeException{ + common::stringFormat("Parquet compression option expects a string value, got: {}.", + value.getDataType().toString())}; + } + auto strVal = common::StringUtils::getUpper(value.getValue()); + if (strVal == 
"UNCOMPRESSED") { + codec = lbug_parquet::format::CompressionCodec::UNCOMPRESSED; + } else if (strVal == "SNAPPY") { + codec = lbug_parquet::format::CompressionCodec::SNAPPY; + } else if (strVal == "ZSTD") { + codec = lbug_parquet::format::CompressionCodec::ZSTD; + } else if (strVal == "GZIP") { + codec = lbug_parquet::format::CompressionCodec::GZIP; + } else if (strVal == "LZ4_RAW") { + codec = lbug_parquet::format::CompressionCodec::LZ4_RAW; + } else { + throw common::RuntimeException{common::stringFormat( + "Unrecognized parquet compression option: {}.", value.toString())}; + } + } +}; + +struct ExportParquetBindData final : public ExportFuncBindData { + ParquetOptions parquetOptions; + + ExportParquetBindData(std::vector names, std::string fileName, + ParquetOptions parquetOptions) + : ExportFuncBindData{std::move(names), std::move(fileName)}, + parquetOptions{parquetOptions} {} + + std::unique_ptr copy() const override { + auto bindData = + std::make_unique(columnNames, fileName, parquetOptions); + bindData->types = LogicalType::copy(types); + return bindData; + } +}; + +struct ExportParquetLocalState final : public ExportFuncLocalState { + std::unique_ptr ft; + uint64_t numTuplesInFT; + storage::MemoryManager* mm; + + ExportParquetLocalState(const ExportFuncBindData& bindData, main::ClientContext& context, + std::vector isFlatVec) + : mm{storage::MemoryManager::Get(context)} { + auto tableSchema = FactorizedTableSchema(); + for (auto i = 0u; i < isFlatVec.size(); i++) { + auto columnSchema = + isFlatVec[i] ? + ColumnSchema(false, 0 /* dummyGroupPos */, + LogicalTypeUtils::getRowLayoutSize(bindData.types[i])) : + ColumnSchema(true, 1 /* dummyGroupPos */, (uint32_t)sizeof(overflow_value_t)); + tableSchema.appendColumn(std::move(columnSchema)); + } + ft = std::make_unique(mm, tableSchema.copy()); + numTuplesInFT = 0; + } +}; + +struct ExportParquetSharedState : public ExportFuncSharedState { + std::unique_ptr writer; + + ExportParquetSharedState() = default; + + void init(main::ClientContext& context, const ExportFuncBindData& bindData) override { + auto& exportParquetBindData = bindData.constCast(); + writer = std::make_unique(exportParquetBindData.fileName, + common::LogicalType::copy(exportParquetBindData.types), + exportParquetBindData.columnNames, exportParquetBindData.parquetOptions.codec, + &context); + } +}; + +static std::unique_ptr bindFunc(ExportFuncBindInput& bindInput) { + ParquetOptions parquetOptions{bindInput.parsingOptions}; + return std::make_unique(bindInput.columnNames, bindInput.filePath, + parquetOptions); +} + +static std::unique_ptr initLocalStateFunc(main::ClientContext& context, + const ExportFuncBindData& bindData, std::vector isFlatVec) { + return std::make_unique(bindData, context, isFlatVec); +} + +static std::shared_ptr createSharedStateFunc() { + return std::make_shared(); +} + +static void initSharedStateFunc(ExportFuncSharedState& sharedState, main::ClientContext& context, + const ExportFuncBindData& bindData) { + sharedState.init(context, bindData); +} + +static std::vector extractSharedPtr( + std::vector> inputVectors, uint64_t& numTuplesToAppend) { + std::vector vecs; + numTuplesToAppend = + inputVectors.size() > 0 ? 
inputVectors[0]->state->getSelVector().getSelSize() : 0; + for (auto& inputVector : inputVectors) { + if (!inputVector->state->isFlat()) { + numTuplesToAppend = inputVector->state->getSelVector().getSelSize(); + } + vecs.push_back(inputVector.get()); + } + return vecs; +} + +static void sinkFunc(ExportFuncSharedState& sharedState, ExportFuncLocalState& localState, + const ExportFuncBindData& /*bindData*/, + std::vector> inputVectors) { + auto& exportParquetLocalState = localState.cast(); + uint64_t numTuplesToAppend = 0; + // TODO(Ziyi): We should let factorizedTable::append return the numTuples appended. + exportParquetLocalState.ft->append(extractSharedPtr(inputVectors, numTuplesToAppend)); + exportParquetLocalState.numTuplesInFT += numTuplesToAppend; + if (exportParquetLocalState.numTuplesInFT > StorageConfig::NODE_GROUP_SIZE) { + auto& exportParquetSharedState = sharedState.cast(); + exportParquetSharedState.writer->flush(*exportParquetLocalState.ft); + exportParquetLocalState.numTuplesInFT = 0; + } +} + +static void combineFunc(ExportFuncSharedState& sharedState, ExportFuncLocalState& localState) { + auto& exportParquetSharedState = sharedState.cast(); + auto& exportParquetLocalState = localState.cast(); + exportParquetSharedState.writer->flush(*exportParquetLocalState.ft); +} + +static void finalizeFunc(ExportFuncSharedState& sharedState) { + auto& exportParquetSharedState = sharedState.cast(); + exportParquetSharedState.writer->finalize(); +} + +function_set ExportParquetFunction::getFunctionSet() { + function_set functionSet; + auto exportFunc = std::make_unique(name); + exportFunc->initLocalState = initLocalStateFunc; + exportFunc->createSharedState = createSharedStateFunc; + exportFunc->initSharedState = initSharedStateFunc; + exportFunc->sink = sinkFunc; + exportFunc->combine = combineFunc; + exportFunc->finalize = finalizeFunc; + exportFunc->bind = bindFunc; + functionSet.push_back(std::move(exportFunc)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/find_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/find_function.cpp new file mode 100644 index 0000000000..866eaf106d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/find_function.cpp @@ -0,0 +1,130 @@ +#include "function/string/functions/find_function.h" + +#include + +using namespace lbug::common; + +namespace lbug { +namespace function { + +template +int64_t Find::unalignedNeedleSizeFind(const uint8_t* haystack, uint32_t haystackLen, + const uint8_t* needle, uint32_t needleLen, uint32_t firstMatchCharOffset) { + if (needleLen > haystackLen) { + return -1; + } + // We perform unsigned integer comparisons to check for equality of the entire needle in a + // single comparison. This implementation is inspired by the memmem implementation of + // freebsd. 
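+    // Worked example (UNSIGNED = uint32_t, needle "abc", needleLen = 3): shift is 8, so
+    // the needle occupies the three high bytes of the register ('a'<<24 | 'b'<<16 | 'c'<<8).
+    // Each step below drops the left-most haystack byte and ORs the next byte in at bit
+    // position `shift`, so the window keeps the same byte layout as the needle and a single
+    // integer comparison tests all needleLen characters at once.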
+ UNSIGNED needleEntry = 0; + UNSIGNED haystackEntry = 0; + const UNSIGNED start = (sizeof(UNSIGNED) * 8) - 8; + const UNSIGNED shift = (sizeof(UNSIGNED) - needleLen) * 8; + for (auto i = 0u; i < needleLen; i++) { + needleEntry |= UNSIGNED(needle[i]) << UNSIGNED(start - i * 8); + haystackEntry |= UNSIGNED(haystack[i]) << UNSIGNED(start - i * 8); + } + for (auto offset = needleLen; offset < haystackLen; offset++) { + if (haystackEntry == needleEntry) { + return firstMatchCharOffset + offset - needleLen; + } + // We adjust the haystack entry by + // (1) removing the left-most character (shift by 8) + // (2) adding the next character (bitwise or, with potential shift) + // this shift is only necessary if the needle size is not aligned with the unsigned + // integer size (e.g. needle size 3, unsigned integer size 4, we need to shift by 1). + haystackEntry = (haystackEntry << 8) | ((UNSIGNED(haystack[offset])) << shift); + } + if (haystackEntry == needleEntry) { + return firstMatchCharOffset + haystackLen - needleLen; + } + return -1; +} + +template +int64_t Find::alignedNeedleSizeFind(const uint8_t* haystack, uint32_t haystackLen, + const uint8_t* needle, uint32_t firstMatchCharOffset) { + if (sizeof(UNSIGNED) > haystackLen) { + return -1; + } + auto needleVal = *((UNSIGNED*)needle); + for (auto offset = 0u; offset <= haystackLen - sizeof(UNSIGNED); offset++) { + auto haystackVal = *((UNSIGNED*)(haystack + offset)); + if (needleVal == haystackVal) { + return firstMatchCharOffset + offset; + } + } + return -1; +} + +int64_t Find::genericFind(const uint8_t* haystack, uint32_t haystackLen, const uint8_t* needle, + uint32_t needLen, uint32_t firstMatchCharOffset) { + if (needLen > haystackLen) { + return -1; + } + // This implementation is inspired by Raphael Javaux's faststrstr + // (https://github.com/RaphaelJ/fast_strstr) generic contains; note that we can't use strstr + // because we don't have null-terminated strings anymore we keep track of a shifting window + // sum of all characters with window size equal to needle_size this shifting sum is used to + // avoid calling into memcmp; we only need to call into memcmp when the window sum is equal + // to the needle sum when that happens, the characters are potentially the same and we call + // into memcmp to check if they are. + auto sumsDiff = 0u; + for (auto i = 0u; i < needLen; i++) { + sumsDiff += haystack[i]; + sumsDiff -= needle[i]; + } + auto offset = 0u; + while (true) { + if (sumsDiff == 0 && haystack[offset] == needle[0]) { + if (memcmp(haystack + offset, needle, needLen) == 0) { + return firstMatchCharOffset + offset; + } + } + if (offset >= haystackLen - needLen) { + return -1; + } + sumsDiff -= haystack[offset]; + sumsDiff += haystack[offset + needLen]; + offset++; + } +} + +// Returns the position of the first occurrence of needle in the haystack. If haystack doesn't +// contain needle, it returns -1. 
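+// Dispatch strategy: a single-character needle is resolved by the memchr call alone;
+// needles whose length matches a native integer width (2, 4 or 8 bytes) use one aligned
+// integer comparison per offset; other short needles (3, 5-7 bytes) use the packed
+// sliding-window comparison above; anything longer falls back to the rolling-sum
+// genericFind.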
+int64_t Find::find(const uint8_t* haystack, uint32_t haystackLen, const uint8_t* needle, + uint32_t needleLen) { + auto firstMatchCharPos = (uint8_t*)memchr(haystack, needle[0], haystackLen); + if (firstMatchCharPos == nullptr) { + return -1; + } + auto firstMatchCharOffset = firstMatchCharPos - haystack; + auto numCharsToMatch = haystackLen - firstMatchCharOffset; + switch (needleLen) { + case 1: + return firstMatchCharOffset; + case 2: + return alignedNeedleSizeFind(firstMatchCharPos, numCharsToMatch, needle, + firstMatchCharOffset); + case 3: + return unalignedNeedleSizeFind(firstMatchCharPos, numCharsToMatch, needle, 3, + firstMatchCharOffset); + case 4: + return alignedNeedleSizeFind(firstMatchCharPos, numCharsToMatch, needle, + firstMatchCharOffset); + case 5: + case 6: + case 7: + return unalignedNeedleSizeFind(firstMatchCharPos, numCharsToMatch, needle, + needleLen, firstMatchCharOffset); + case 8: + return alignedNeedleSizeFind(firstMatchCharPos, numCharsToMatch, needle, + firstMatchCharOffset); + default: + return genericFind(firstMatchCharPos, numCharsToMatch, needle, needleLen, + firstMatchCharOffset); + } +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/function.cpp new file mode 100644 index 0000000000..8765e1b90a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/function.cpp @@ -0,0 +1,18 @@ +#include "function/function.h" + +#include "binder/expression/expression_util.h" + +using namespace lbug::binder; +using namespace lbug::common; + +namespace lbug { +namespace function { + +std::unique_ptr FunctionBindData::getSimpleBindData( + const expression_vector& params, const LogicalType& resultType) { + auto paramTypes = ExpressionUtil::getDataTypes(params); + return std::make_unique(std::move(paramTypes), resultType.copy()); +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/function_collection.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/function_collection.cpp new file mode 100644 index 0000000000..f4790ad3e3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/function_collection.cpp @@ -0,0 +1,257 @@ +#include "function/function_collection.h" + +#include "function/aggregate/count.h" +#include "function/aggregate/count_star.h" +#include "function/arithmetic/vector_arithmetic_functions.h" +#include "function/array/vector_array_functions.h" +#include "function/blob/vector_blob_functions.h" +#include "function/cast/vector_cast_functions.h" +#include "function/comparison/vector_comparison_functions.h" +#include "function/date/vector_date_functions.h" +#include "function/export/export_function.h" +#include "function/hash/vector_hash_functions.h" +#include "function/internal_id/vector_internal_id_functions.h" +#include "function/interval/vector_interval_functions.h" +#include "function/list/vector_list_functions.h" +#include "function/map/vector_map_functions.h" +#include "function/path/vector_path_functions.h" +#include "function/schema/vector_node_rel_functions.h" +#include "function/sequence/sequence_functions.h" +#include "function/string/vector_string_functions.h" +#include "function/struct/vector_struct_functions.h" +#include "function/table/simple_table_function.h" +#include "function/table/standalone_call_function.h" +#include "function/timestamp/vector_timestamp_functions.h" +#include "function/union/vector_union_functions.h" +#include 
"function/utility/vector_utility_functions.h" +#include "function/uuid/vector_uuid_functions.h" +#include "processor/operator/persistent/reader/csv/parallel_csv_reader.h" +#include "processor/operator/persistent/reader/csv/serial_csv_reader.h" +#include "processor/operator/persistent/reader/npy/npy_reader.h" +#include "processor/operator/persistent/reader/parquet/parquet_reader.h" + +using namespace lbug::processor; + +namespace lbug { +namespace function { + +#define SCALAR_FUNCTION_BASE(_PARAM, _NAME) \ + { _PARAM::getFunctionSet, _NAME, CatalogEntryType::SCALAR_FUNCTION_ENTRY } +#define SCALAR_FUNCTION(_PARAM) SCALAR_FUNCTION_BASE(_PARAM, _PARAM::name) +#define SCALAR_FUNCTION_ALIAS(_PARAM) SCALAR_FUNCTION_BASE(_PARAM::alias, _PARAM::name) +#define REWRITE_FUNCTION_BASE(_PARAM, _NAME) \ + { _PARAM::getFunctionSet, _NAME, CatalogEntryType::REWRITE_FUNCTION_ENTRY } +#define REWRITE_FUNCTION(_PARAM) REWRITE_FUNCTION_BASE(_PARAM, _PARAM::name) +#define REWRITE_FUNCTION_ALIAS(_PARAM) REWRITE_FUNCTION_BASE(_PARAM::alias, _PARAM::name) +#define AGGREGATE_FUNCTION(_PARAM) \ + { _PARAM::getFunctionSet, _PARAM::name, CatalogEntryType::AGGREGATE_FUNCTION_ENTRY } +#define EXPORT_FUNCTION(_PARAM) \ + { _PARAM::getFunctionSet, _PARAM::name, CatalogEntryType::COPY_FUNCTION_ENTRY } +#define TABLE_FUNCTION(_PARAM) \ + { _PARAM::getFunctionSet, _PARAM::name, CatalogEntryType::TABLE_FUNCTION_ENTRY } +#define STANDALONE_TABLE_FUNCTION(_PARAM) \ + { _PARAM::getFunctionSet, _PARAM::name, CatalogEntryType::STANDALONE_TABLE_FUNCTION_ENTRY } +#define FINAL_FUNCTION \ + { nullptr, nullptr, CatalogEntryType::SCALAR_FUNCTION_ENTRY } + +FunctionCollection* FunctionCollection::getFunctions() { + static FunctionCollection functions[] = { + + // Arithmetic Functions + SCALAR_FUNCTION(AddFunction), SCALAR_FUNCTION(SubtractFunction), + SCALAR_FUNCTION(MultiplyFunction), SCALAR_FUNCTION(DivideFunction), + SCALAR_FUNCTION(ModuloFunction), SCALAR_FUNCTION(PowerFunction), + SCALAR_FUNCTION(AbsFunction), SCALAR_FUNCTION(AcosFunction), SCALAR_FUNCTION(AsinFunction), + SCALAR_FUNCTION(AtanFunction), SCALAR_FUNCTION(Atan2Function), + SCALAR_FUNCTION(BitwiseXorFunction), SCALAR_FUNCTION(BitwiseAndFunction), + SCALAR_FUNCTION(BitwiseOrFunction), SCALAR_FUNCTION(BitShiftLeftFunction), + SCALAR_FUNCTION(BitShiftRightFunction), SCALAR_FUNCTION(CbrtFunction), + SCALAR_FUNCTION(CeilFunction), SCALAR_FUNCTION_ALIAS(CeilingFunction), + SCALAR_FUNCTION(CosFunction), SCALAR_FUNCTION(CotFunction), + SCALAR_FUNCTION(DegreesFunction), SCALAR_FUNCTION(EvenFunction), + SCALAR_FUNCTION(FactorialFunction), SCALAR_FUNCTION(FloorFunction), + SCALAR_FUNCTION(GammaFunction), SCALAR_FUNCTION(LgammaFunction), + SCALAR_FUNCTION(LnFunction), SCALAR_FUNCTION(LogFunction), + SCALAR_FUNCTION_ALIAS(Log10Function), SCALAR_FUNCTION(Log2Function), + SCALAR_FUNCTION(NegateFunction), SCALAR_FUNCTION(PiFunction), + SCALAR_FUNCTION_ALIAS(PowFunction), SCALAR_FUNCTION(RadiansFunction), + SCALAR_FUNCTION(RoundFunction), SCALAR_FUNCTION(SinFunction), SCALAR_FUNCTION(SignFunction), + SCALAR_FUNCTION(SqrtFunction), SCALAR_FUNCTION(TanFunction), SCALAR_FUNCTION(RandFunction), + SCALAR_FUNCTION(SetSeedFunction), + + // String Functions + SCALAR_FUNCTION(ArrayExtractFunction), SCALAR_FUNCTION(ConcatFunction), + SCALAR_FUNCTION(ContainsFunction), SCALAR_FUNCTION(LowerFunction), + SCALAR_FUNCTION_ALIAS(ToLowerFunction), SCALAR_FUNCTION_ALIAS(LcaseFunction), + SCALAR_FUNCTION(LeftFunction), SCALAR_FUNCTION(LpadFunction), + SCALAR_FUNCTION(LtrimFunction), 
SCALAR_FUNCTION(StartsWithFunction), + SCALAR_FUNCTION_ALIAS(PrefixFunction), SCALAR_FUNCTION(RepeatFunction), + SCALAR_FUNCTION(ReverseFunction), SCALAR_FUNCTION(RightFunction), + SCALAR_FUNCTION(RpadFunction), SCALAR_FUNCTION(RtrimFunction), + SCALAR_FUNCTION(SubStrFunction), SCALAR_FUNCTION_ALIAS(SubstringFunction), + SCALAR_FUNCTION(EndsWithFunction), SCALAR_FUNCTION_ALIAS(SuffixFunction), + SCALAR_FUNCTION(TrimFunction), SCALAR_FUNCTION(UpperFunction), + SCALAR_FUNCTION_ALIAS(UCaseFunction), SCALAR_FUNCTION_ALIAS(ToUpperFunction), + SCALAR_FUNCTION(RegexpFullMatchFunction), SCALAR_FUNCTION(RegexpMatchesFunction), + SCALAR_FUNCTION(RegexpReplaceFunction), SCALAR_FUNCTION(RegexpExtractFunction), + SCALAR_FUNCTION(RegexpExtractAllFunction), SCALAR_FUNCTION(LevenshteinFunction), + SCALAR_FUNCTION(RegexpSplitToArrayFunction), SCALAR_FUNCTION(InitCapFunction), + SCALAR_FUNCTION(StringSplitFunction), SCALAR_FUNCTION_ALIAS(StrSplitFunction), + SCALAR_FUNCTION_ALIAS(StringToArrayFunction), SCALAR_FUNCTION(SplitPartFunction), + SCALAR_FUNCTION(InternalIDCreationFunction), SCALAR_FUNCTION(ConcatWSFunction), + + // Array Functions + SCALAR_FUNCTION(ArrayValueFunction), SCALAR_FUNCTION(ArrayCrossProductFunction), + SCALAR_FUNCTION(ArrayCosineSimilarityFunction), SCALAR_FUNCTION(ArrayDistanceFunction), + SCALAR_FUNCTION(ArraySquaredDistanceFunction), SCALAR_FUNCTION(ArrayInnerProductFunction), + SCALAR_FUNCTION(ArrayDotProductFunction), + + // List functions + SCALAR_FUNCTION(ListCreationFunction), SCALAR_FUNCTION(ListRangeFunction), + SCALAR_FUNCTION(ListExtractFunction), SCALAR_FUNCTION_ALIAS(ListElementFunction), + SCALAR_FUNCTION(ListConcatFunction), SCALAR_FUNCTION_ALIAS(ListCatFunction), + SCALAR_FUNCTION(ArrayConcatFunction), SCALAR_FUNCTION_ALIAS(ArrayCatFunction), + SCALAR_FUNCTION(ListAppendFunction), SCALAR_FUNCTION(ArrayAppendFunction), + SCALAR_FUNCTION_ALIAS(ArrayPushFrontFunction), SCALAR_FUNCTION(ListPrependFunction), + SCALAR_FUNCTION(ArrayPrependFunction), SCALAR_FUNCTION_ALIAS(ArrayPushBackFunction), + SCALAR_FUNCTION(ListPositionFunction), SCALAR_FUNCTION_ALIAS(ListIndexOfFunction), + SCALAR_FUNCTION(ArrayPositionFunction), SCALAR_FUNCTION_ALIAS(ArrayIndexOfFunction), + SCALAR_FUNCTION(ListContainsFunction), SCALAR_FUNCTION_ALIAS(ListHasFunction), + SCALAR_FUNCTION(ArrayContainsFunction), SCALAR_FUNCTION_ALIAS(ArrayHasFunction), + SCALAR_FUNCTION(ListSliceFunction), SCALAR_FUNCTION(ArraySliceFunction), + SCALAR_FUNCTION(ListSortFunction), SCALAR_FUNCTION(ListReverseSortFunction), + SCALAR_FUNCTION(ListSumFunction), SCALAR_FUNCTION(ListProductFunction), + SCALAR_FUNCTION(ListDistinctFunction), SCALAR_FUNCTION(ListUniqueFunction), + SCALAR_FUNCTION(ListAnyValueFunction), SCALAR_FUNCTION(ListReverseFunction), + SCALAR_FUNCTION(SizeFunction), SCALAR_FUNCTION(ListToStringFunction), + SCALAR_FUNCTION(ListTransformFunction), SCALAR_FUNCTION(ListFilterFunction), + SCALAR_FUNCTION(ListReduceFunction), SCALAR_FUNCTION(ListAnyFunction), + SCALAR_FUNCTION(ListAllFunction), SCALAR_FUNCTION(ListNoneFunction), + SCALAR_FUNCTION(ListSingleFunction), SCALAR_FUNCTION(ListHasAllFunction), + + // Cast functions + SCALAR_FUNCTION(CastToDateFunction), SCALAR_FUNCTION_ALIAS(DateFunction), + SCALAR_FUNCTION(CastToTimestampFunction), SCALAR_FUNCTION(CastToIntervalFunction), + SCALAR_FUNCTION_ALIAS(IntervalFunctionAlias), SCALAR_FUNCTION_ALIAS(DurationFunction), + SCALAR_FUNCTION(CastToStringFunction), SCALAR_FUNCTION_ALIAS(StringFunction), + SCALAR_FUNCTION(CastToBlobFunction), 
SCALAR_FUNCTION_ALIAS(BlobFunction), + SCALAR_FUNCTION(CastToUUIDFunction), SCALAR_FUNCTION_ALIAS(UUIDFunction), + SCALAR_FUNCTION(CastToDoubleFunction), SCALAR_FUNCTION(CastToFloatFunction), + SCALAR_FUNCTION(CastToSerialFunction), SCALAR_FUNCTION(CastToInt64Function), + SCALAR_FUNCTION(CastToInt32Function), SCALAR_FUNCTION(CastToInt16Function), + SCALAR_FUNCTION(CastToInt8Function), SCALAR_FUNCTION(CastToUInt64Function), + SCALAR_FUNCTION(CastToUInt32Function), SCALAR_FUNCTION(CastToUInt16Function), + SCALAR_FUNCTION(CastToUInt8Function), SCALAR_FUNCTION(CastToInt128Function), + SCALAR_FUNCTION(CastToUInt128Function), SCALAR_FUNCTION(CastToBoolFunction), + SCALAR_FUNCTION(CastAnyFunction), + + // Comparison functions + SCALAR_FUNCTION(EqualsFunction), SCALAR_FUNCTION(NotEqualsFunction), + SCALAR_FUNCTION(GreaterThanFunction), SCALAR_FUNCTION(GreaterThanEqualsFunction), + SCALAR_FUNCTION(LessThanFunction), SCALAR_FUNCTION(LessThanEqualsFunction), + + // Date functions + SCALAR_FUNCTION(DatePartFunction), SCALAR_FUNCTION_ALIAS(DatePartFunctionAlias), + SCALAR_FUNCTION(DateTruncFunction), SCALAR_FUNCTION_ALIAS(DateTruncFunctionAlias), + SCALAR_FUNCTION(DayNameFunction), SCALAR_FUNCTION(GreatestFunction), + SCALAR_FUNCTION(LastDayFunction), SCALAR_FUNCTION(LeastFunction), + SCALAR_FUNCTION(MakeDateFunction), SCALAR_FUNCTION(MonthNameFunction), + SCALAR_FUNCTION(CurrentDateFunction), + + // Timestamp functions + SCALAR_FUNCTION(CenturyFunction), SCALAR_FUNCTION(EpochMsFunction), + SCALAR_FUNCTION(ToTimestampFunction), SCALAR_FUNCTION(CurrentTimestampFunction), + SCALAR_FUNCTION(ToEpochMsFunction), + + // Interval functions + SCALAR_FUNCTION(ToYearsFunction), SCALAR_FUNCTION(ToMonthsFunction), + SCALAR_FUNCTION(ToDaysFunction), SCALAR_FUNCTION(ToHoursFunction), + SCALAR_FUNCTION(ToMinutesFunction), SCALAR_FUNCTION(ToSecondsFunction), + SCALAR_FUNCTION(ToMillisecondsFunction), SCALAR_FUNCTION(ToMicrosecondsFunction), + + // Blob functions + SCALAR_FUNCTION(OctetLengthFunctions), SCALAR_FUNCTION(EncodeFunctions), + SCALAR_FUNCTION(DecodeFunctions), + + // UUID functions + SCALAR_FUNCTION(GenRandomUUIDFunction), + + // Struct functions + SCALAR_FUNCTION(StructPackFunctions), SCALAR_FUNCTION(StructExtractFunctions), + REWRITE_FUNCTION(KeysFunctions), + + // Map functions + SCALAR_FUNCTION(MapCreationFunctions), SCALAR_FUNCTION(MapExtractFunctions), + SCALAR_FUNCTION_ALIAS(ElementAtFunctions), SCALAR_FUNCTION_ALIAS(CardinalityFunction), + SCALAR_FUNCTION(MapKeysFunctions), SCALAR_FUNCTION(MapValuesFunctions), + + // Union functions + SCALAR_FUNCTION(UnionValueFunction), SCALAR_FUNCTION(UnionTagFunction), + SCALAR_FUNCTION(UnionExtractFunction), + + // Node/rel functions + SCALAR_FUNCTION(OffsetFunction), REWRITE_FUNCTION(IDFunction), + REWRITE_FUNCTION(StartNodeFunction), REWRITE_FUNCTION(EndNodeFunction), + REWRITE_FUNCTION(LabelFunction), REWRITE_FUNCTION_ALIAS(LabelsFunction), + REWRITE_FUNCTION(CostFunction), + + // Path functions + SCALAR_FUNCTION(NodesFunction), SCALAR_FUNCTION(RelsFunction), + SCALAR_FUNCTION_ALIAS(RelationshipsFunction), SCALAR_FUNCTION(PropertiesFunction), + SCALAR_FUNCTION(IsTrailFunction), SCALAR_FUNCTION(IsACyclicFunction), + REWRITE_FUNCTION(LengthFunction), + + // Hash functions + SCALAR_FUNCTION(MD5Function), SCALAR_FUNCTION(SHA256Function), + SCALAR_FUNCTION(HashFunction), + + // Scalar utility functions + SCALAR_FUNCTION(CoalesceFunction), SCALAR_FUNCTION(IfNullFunction), + SCALAR_FUNCTION(ConstantOrNullFunction), SCALAR_FUNCTION(CountIfFunction), + 
SCALAR_FUNCTION(ErrorFunction), REWRITE_FUNCTION(NullIfFunction), + SCALAR_FUNCTION(TypeOfFunction), + + // Sequence functions + SCALAR_FUNCTION(CurrValFunction), SCALAR_FUNCTION(NextValFunction), + + // Aggregate functions + AGGREGATE_FUNCTION(CountStarFunction), AGGREGATE_FUNCTION(CountFunction), + AGGREGATE_FUNCTION(AggregateSumFunction), AGGREGATE_FUNCTION(AggregateAvgFunction), + AGGREGATE_FUNCTION(AggregateMinFunction), AGGREGATE_FUNCTION(AggregateMaxFunction), + AGGREGATE_FUNCTION(CollectFunction), + + // Table functions + TABLE_FUNCTION(CurrentSettingFunction), TABLE_FUNCTION(CatalogVersionFunction), + TABLE_FUNCTION(DBVersionFunction), TABLE_FUNCTION(ShowTablesFunction), + TABLE_FUNCTION(FreeSpaceInfoFunction), TABLE_FUNCTION(ShowWarningsFunction), + TABLE_FUNCTION(TableInfoFunction), TABLE_FUNCTION(ShowConnectionFunction), + TABLE_FUNCTION(StatsInfoFunction), TABLE_FUNCTION(StorageInfoFunction), + TABLE_FUNCTION(ShowAttachedDatabasesFunction), TABLE_FUNCTION(ShowSequencesFunction), + TABLE_FUNCTION(ShowFunctionsFunction), TABLE_FUNCTION(BMInfoFunction), + TABLE_FUNCTION(FileInfoFunction), TABLE_FUNCTION(ShowLoadedExtensionsFunction), + TABLE_FUNCTION(ShowOfficialExtensionsFunction), TABLE_FUNCTION(ShowIndexesFunction), + TABLE_FUNCTION(ShowProjectedGraphsFunction), TABLE_FUNCTION(ProjectedGraphInfoFunction), + TABLE_FUNCTION(ShowMacrosFunction), + + // Standalone Table functions + STANDALONE_TABLE_FUNCTION(LocalCacheArrayColumnFunction), + STANDALONE_TABLE_FUNCTION(ClearWarningsFunction), + STANDALONE_TABLE_FUNCTION(ProjectGraphNativeFunction), + STANDALONE_TABLE_FUNCTION(ProjectGraphCypherFunction), + STANDALONE_TABLE_FUNCTION(DropProjectedGraphFunction), + + // Scan functions + TABLE_FUNCTION(ParquetScanFunction), TABLE_FUNCTION(NpyScanFunction), + TABLE_FUNCTION(SerialCSVScan), TABLE_FUNCTION(ParallelCSVScan), + + // Export functions + EXPORT_FUNCTION(ExportCSVFunction), EXPORT_FUNCTION(ExportParquetFunction), + + // End of array + FINAL_FUNCTION}; + + return functions; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/CMakeLists.txt new file mode 100644 index 0000000000..9552f8d4f5 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/CMakeLists.txt @@ -0,0 +1,24 @@ +add_library(lbug_function_algorithm + OBJECT + asp_destinations.cpp + asp_paths.cpp + awsp_paths.cpp + bfs_graph.cpp + frontier_morsel.cpp + gds.cpp + gds_frontier.cpp + gds_state.cpp + gds_task.cpp + gds_utils.cpp + output_writer.cpp + rec_joins.cpp + ssp_destinations.cpp + ssp_paths.cpp + variable_length_path.cpp + wsp_destinations.cpp + wsp_paths.cpp + ) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/asp_destinations.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/asp_destinations.cpp new file mode 100644 index 0000000000..4cc221b101 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/asp_destinations.cpp @@ -0,0 +1,333 @@ +#include "binder/expression/node_expression.h" +#include "function/gds/gds_function_collection.h" +#include "function/gds/rec_joins.h" +#include "processor/execution_context.h" +#include "transaction/transaction.h" + +using namespace lbug::processor; +using namespace lbug::common; +using namespace lbug::binder; +using namespace lbug::storage; +using namespace lbug::graph; + +namespace lbug { +namespace function { + +using multiplicity_t = uint64_t; 
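+// The helpers below track, per node offset, how many distinct shortest paths from the
+// source reach that node. Tracking starts with a sparse per-table hash map and is
+// promoted to a dense array of atomic counters when the auxiliary state is switched to
+// dense, keeping memory proportional to the visited set while the search stays sparse.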
+ +class Multiplicities { +public: + virtual ~Multiplicities() = default; + + virtual void pinTableID(table_id_t tableID) = 0; + + virtual void increaseMultiplicity(offset_t offset, multiplicity_t multiplicity) = 0; + + virtual multiplicity_t getMultiplicity(offset_t offset) = 0; +}; + +class SparseMultiplicitiesReference final : public Multiplicities { +public: + explicit SparseMultiplicitiesReference(GDSSpareObjectManager& spareObjects) + : spareObjects{spareObjects} {} + + void pinTableID(table_id_t tableID) override { curData = spareObjects.getData(tableID); } + + void increaseMultiplicity(offset_t offset, multiplicity_t multiplicity) override { + KU_ASSERT(curData); + if (curData->contains(offset)) { + curData->at(offset) += multiplicity; + } else { + curData->insert({offset, multiplicity}); + } + } + + multiplicity_t getMultiplicity(offset_t offset) override { + KU_ASSERT(curData); + if (curData->contains(offset)) { + return curData->at(offset); + } + return 0; + } + +private: + GDSSpareObjectManager& spareObjects; + std::unordered_map* curData = nullptr; +}; + +class DenseMultiplicitiesReference final : public Multiplicities { +public: + explicit DenseMultiplicitiesReference( + GDSDenseObjectManager>& denseObjects) + : denseObjects(denseObjects) {} + + void pinTableID(table_id_t tableID) override { curData = denseObjects.getData(tableID); } + + void increaseMultiplicity(offset_t offset, multiplicity_t multiplicity) override { + KU_ASSERT(curData); + curData[offset].fetch_add(multiplicity); + } + + multiplicity_t getMultiplicity(offset_t offset) override { + KU_ASSERT(curData); + return curData[offset].load(std::memory_order_relaxed); + } + +private: + GDSDenseObjectManager>& denseObjects; + std::atomic* curData = nullptr; +}; + +class MultiplicitiesPair { +public: + explicit MultiplicitiesPair(const table_id_map_t& maxOffsetMap) + : maxOffsetMap{maxOffsetMap}, densityState{GDSDensityState::SPARSE}, + sparseObjects{maxOffsetMap} { + curSparseMultiplicities = std::make_unique(sparseObjects); + nextSparseMultiplicities = std::make_unique(sparseObjects); + denseObjects = GDSDenseObjectManager>(); + curDenseMultiplicities = std::make_unique(denseObjects); + nextDenseMultiplicities = std::make_unique(denseObjects); + } + + void pinCurTableID(table_id_t tableID) { + switch (densityState) { + case GDSDensityState::SPARSE: { + curSparseMultiplicities->pinTableID(tableID); + curMultiplicities = curSparseMultiplicities.get(); + } break; + case GDSDensityState::DENSE: { + curDenseMultiplicities->pinTableID(tableID); + curMultiplicities = curDenseMultiplicities.get(); + } break; + default: + KU_UNREACHABLE; + } + } + + void pinNextTableID(table_id_t tableID) { + switch (densityState) { + case GDSDensityState::SPARSE: { + nextSparseMultiplicities->pinTableID(tableID); + nextMultiplicities = nextSparseMultiplicities.get(); + } break; + case GDSDensityState::DENSE: { + nextDenseMultiplicities->pinTableID(tableID); + nextMultiplicities = nextDenseMultiplicities.get(); + } break; + default: + KU_UNREACHABLE; + } + } + + void increaseNextMultiplicity(offset_t offset, multiplicity_t multiplicity) { + nextMultiplicities->increaseMultiplicity(offset, multiplicity); + } + + multiplicity_t getCurrentMultiplicity(offset_t offset) const { + return curMultiplicities->getMultiplicity(offset); + } + Multiplicities* getCurrentMultiplicities() { return curMultiplicities; } + + void switchToDense(ExecutionContext* context) { + KU_ASSERT(densityState == GDSDensityState::SPARSE); + densityState = 
GDSDensityState::DENSE; + for (auto& [tableID, maxOffset] : maxOffsetMap) { + denseObjects.allocate(tableID, maxOffset, MemoryManager::Get(*context->clientContext)); + auto data = denseObjects.getData(tableID); + for (auto i = 0u; i < maxOffset; i++) { + data[i].store(0); + } + } + for (auto& [tableID, map] : sparseObjects.getData()) { + auto data = denseObjects.getData(tableID); + for (auto& [offset, multiplicity] : map) { + data[offset].store(multiplicity); + } + } + } + +private: + table_id_map_t maxOffsetMap; + GDSDensityState densityState; + GDSSpareObjectManager sparseObjects; + std::unique_ptr curSparseMultiplicities; + std::unique_ptr nextSparseMultiplicities; + GDSDenseObjectManager> denseObjects; + std::unique_ptr curDenseMultiplicities; + std::unique_ptr nextDenseMultiplicities; + + Multiplicities* curMultiplicities = nullptr; + Multiplicities* nextMultiplicities = nullptr; +}; + +class ASPDestinationsAuxiliaryState : public GDSAuxiliaryState { +public: + explicit ASPDestinationsAuxiliaryState(std::unique_ptr multiplicitiesPair) + : multiplicitiesPair{std::move(multiplicitiesPair)} {} + + MultiplicitiesPair* getMultiplicitiesPair() const { return multiplicitiesPair.get(); } + + void initSource(nodeID_t source) override { + multiplicitiesPair->pinNextTableID(source.tableID); + multiplicitiesPair->increaseNextMultiplicity(source.offset, 1); + } + + void beginFrontierCompute(table_id_t curTableID, table_id_t nextTableID) override { + multiplicitiesPair->pinCurTableID(curTableID); + multiplicitiesPair->pinNextTableID(nextTableID); + } + + void switchToDense(ExecutionContext* context, Graph*) override { + multiplicitiesPair->switchToDense(context); + } + +private: + std::unique_ptr multiplicitiesPair; +}; + +class ASPDestinationsOutputWriter : public RJOutputWriter { +public: + ASPDestinationsOutputWriter(main::ClientContext* context, NodeOffsetMaskMap* outputNodeMask, + nodeID_t sourceNodeID, Frontier* frontier, Multiplicities* multiplicities) + : RJOutputWriter{context, outputNodeMask, sourceNodeID}, frontier{frontier}, + multiplicities{multiplicities} { + lengthVector = createVector(LogicalType::UINT16()); + } + + void beginWritingInternal(table_id_t tableID) override { + frontier->pinTableID(tableID); + multiplicities->pinTableID(tableID); + } + + void write(FactorizedTable& fTable, table_id_t tableID, LimitCounter* counter) override { + auto& sparseFrontier = frontier->cast(); + for (auto [offset, _] : sparseFrontier.getCurrentData()) { + write(fTable, {offset, tableID}, counter); + } + } + + void write(FactorizedTable& fTable, nodeID_t dstNodeID, LimitCounter* counter) override { + if (!inOutputNodeMask(dstNodeID.offset)) { // Skip dst if it not is in scope. + return; + } + if (dstNodeID == sourceNodeID_) { // Skip writing source node. + return; + } + auto iter = frontier->getIteration(dstNodeID.offset); + if (iter == FRONTIER_UNVISITED) { // Skip if dst is not visited. 
+ return; + } + dstNodeIDVector->setValue(0, dstNodeID); + lengthVector->setValue(0, iter); + auto multiplicity = multiplicities->getMultiplicity(dstNodeID.offset); + for (auto i = 0u; i < multiplicity; ++i) { + fTable.append(vectors); + } + if (counter != nullptr) { + counter->increase(multiplicity); + } + } + + std::unique_ptr copy() override { + return std::make_unique(context, outputNodeMask, sourceNodeID_, + frontier, multiplicities); + } + +private: + std::unique_ptr lengthVector; + Frontier* frontier; + Multiplicities* multiplicities; +}; + +class ASPDestinationsEdgeCompute : public SPEdgeCompute { +public: + ASPDestinationsEdgeCompute(SPFrontierPair* frontierPair, MultiplicitiesPair* multiplicitiesPair) + : SPEdgeCompute{frontierPair}, multiplicitiesPair{multiplicitiesPair} {}; + + std::vector edgeCompute(nodeID_t boundNodeID, NbrScanState::Chunk& resultChunk, + bool) override { + std::vector activeNodes; + resultChunk.forEach([&](auto neighbors, auto, auto i) { + auto nbrNodeID = neighbors[i]; + auto nbrVal = frontierPair->getNextFrontierValue(nbrNodeID.offset); + // We should update the nbrID's multiplicity in 2 cases: 1) if nbrID is being visited + // for the first time, i.e., when its value in the pathLengths frontier is + // FRONTIER_UNVISITED. Or 2) if nbrID has already been visited but in this + // iteration, so it's value is curIter + 1. + auto shouldUpdate = + nbrVal == FRONTIER_UNVISITED || nbrVal == frontierPair->getCurrentIter(); + if (shouldUpdate) { + // This is safe because boundNodeID is in the current frontier, so its + // shortest paths multiplicity is guaranteed to not change in the current iteration. + auto boundMultiplicity = + multiplicitiesPair->getCurrentMultiplicity(boundNodeID.offset); + multiplicitiesPair->increaseNextMultiplicity(nbrNodeID.offset, boundMultiplicity); + } + if (nbrVal == FRONTIER_UNVISITED) { + activeNodes.push_back(nbrNodeID); + } + }); + return activeNodes; + } + + std::unique_ptr copy() override { + return std::make_unique(frontierPair, multiplicitiesPair); + } + +private: + MultiplicitiesPair* multiplicitiesPair; +}; + +// All shortest path algorithm. Only destinations are tracked (reachability query). 
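+// Each reachable destination is emitted once per shortest path that reaches it: the edge
+// compute accumulates shortest-path multiplicities frontier by frontier, and the output
+// writer repeats the output tuple `multiplicity` times.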
+class AllSPDestinationsAlgorithm final : public RJAlgorithm { +public: + std::string getFunctionName() const override { return AllSPDestinationsFunction::name; } + + expression_vector getResultColumns(const RJBindData& bindData) const override { + expression_vector columns; + columns.push_back(bindData.nodeInput->constCast().getInternalID()); + columns.push_back(bindData.nodeOutput->constCast().getInternalID()); + columns.push_back(bindData.lengthExpr); + return columns; + } + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } + +private: + std::unique_ptr getComputeState(ExecutionContext* context, const RJBindData&, + RecursiveExtendSharedState* sharedState) override { + auto clientContext = context->clientContext; + auto graph = sharedState->graph.get(); + auto multiplicitiesPair = std::make_unique( + graph->getMaxOffsetMap(transaction::Transaction::Get(*clientContext))); + auto frontier = DenseFrontier::getUnvisitedFrontier(context, graph); + auto frontierPair = std::make_unique(std::move(frontier)); + auto edgeCompute = std::make_unique(frontierPair.get(), + multiplicitiesPair.get()); + auto auxiliaryState = + std::make_unique(std::move(multiplicitiesPair)); + return std::make_unique(std::move(frontierPair), std::move(edgeCompute), + std::move(auxiliaryState)); + } + + std::unique_ptr getOutputWriter(ExecutionContext* context, const RJBindData&, + GDSComputeState& computeState, nodeID_t sourceNodeID, + RecursiveExtendSharedState* sharedState) override { + auto frontier = computeState.frontierPair->ptrCast()->getFrontier(); + auto multiplicities = computeState.auxiliaryState->ptrCast() + ->getMultiplicitiesPair() + ->getCurrentMultiplicities(); + return std::make_unique(context->clientContext, + sharedState->getOutputNodeMaskMap(), sourceNodeID, frontier, multiplicities); + } +}; + +std::unique_ptr AllSPDestinationsFunction::getAlgorithm() { + return std::make_unique(); +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/asp_paths.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/asp_paths.cpp new file mode 100644 index 0000000000..7a7c068d2a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/asp_paths.cpp @@ -0,0 +1,115 @@ +#include "binder/expression/node_expression.h" +#include "function/gds/auxiliary_state/path_auxiliary_state.h" +#include "function/gds/gds_function_collection.h" +#include "function/gds/rec_joins.h" +#include "processor/execution_context.h" +#include "transaction/transaction.h" + +using namespace lbug::binder; +using namespace lbug::common; +using namespace lbug::processor; + +namespace lbug { +namespace function { + +class ASPPathsEdgeCompute : public SPEdgeCompute { +public: + ASPPathsEdgeCompute(SPFrontierPair* frontiersPair, BFSGraphManager* bfsGraphManager) + : SPEdgeCompute{frontiersPair}, bfsGraphManager{bfsGraphManager} { + block = bfsGraphManager->getCurrentGraph()->addNewBlock(); + } + + std::vector edgeCompute(nodeID_t boundNodeID, graph::NbrScanState::Chunk& resultChunk, + bool fwdEdge) override { + std::vector activeNodes; + resultChunk.forEach([&](auto neighbors, auto propertyVectors, auto i) { + auto nbrNodeID = neighbors[i]; + auto iter = frontierPair->getNextFrontierValue(nbrNodeID.offset); + // We should update in 2 cases: 1) if nbrID is being visited + // for the first time, i.e., when its value in the pathLengths frontier is + // PathLengths::UNVISITED. 
Or 2) if nbrID has already been visited but in this + // iteration, so it's value is curIter + 1. + auto shouldUpdate = + iter == FRONTIER_UNVISITED || iter == frontierPair->getCurrentIter(); + if (shouldUpdate) { + if (!block->hasSpace()) { + block = bfsGraphManager->getCurrentGraph()->addNewBlock(); + } + auto edgeID = propertyVectors[0]->template getValue(i); + bfsGraphManager->getCurrentGraph()->addParent(frontierPair->getCurrentIter(), + boundNodeID, edgeID, nbrNodeID, fwdEdge, block); + } + if (iter == FRONTIER_UNVISITED) { + activeNodes.push_back(nbrNodeID); + } + }); + return activeNodes; + } + + std::unique_ptr copy() override { + return std::make_unique(frontierPair, bfsGraphManager); + } + +private: + BFSGraphManager* bfsGraphManager; + ObjectBlock* block = nullptr; +}; + +// All shortest path algorithm. Paths are tracked. +class AllSPPathsAlgorithm final : public RJAlgorithm { +public: + std::string getFunctionName() const override { return AllSPPathsFunction::name; } + + expression_vector getResultColumns(const RJBindData& bindData) const override { + expression_vector columns; + columns.push_back(bindData.nodeInput->constCast().getInternalID()); + columns.push_back(bindData.nodeOutput->constCast().getInternalID()); + columns.push_back(bindData.lengthExpr); + if (bindData.extendDirection == ExtendDirection::BOTH) { + columns.push_back(bindData.directionExpr); + } + columns.push_back(bindData.pathNodeIDsExpr); + columns.push_back(bindData.pathEdgeIDsExpr); + return columns; + } + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } + +private: + std::unique_ptr getComputeState(ExecutionContext* context, const RJBindData&, + RecursiveExtendSharedState* sharedState) override { + auto clientContext = context->clientContext; + auto mm = storage::MemoryManager::Get(*clientContext); + auto denseFrontier = + DenseFrontier::getUninitializedFrontier(context, sharedState->graph.get()); + auto frontierPair = std::make_unique(std::move(denseFrontier)); + auto bfsGraph = std::make_unique( + sharedState->graph->getMaxOffsetMap(transaction::Transaction::Get(*clientContext)), mm); + auto edgeCompute = + std::make_unique(frontierPair.get(), bfsGraph.get()); + auto auxiliaryState = std::make_unique(std::move(bfsGraph)); + return std::make_unique(std::move(frontierPair), std::move(edgeCompute), + std::move(auxiliaryState)); + } + + std::unique_ptr getOutputWriter(ExecutionContext* context, + const RJBindData& bindData, GDSComputeState& computeState, nodeID_t sourceNodeID, + RecursiveExtendSharedState* sharedState) override { + auto bfsGraph = computeState.auxiliaryState->ptrCast() + ->getBFSGraphManager() + ->getCurrentGraph(); + auto writerInfo = bindData.getPathWriterInfo(); + writerInfo.pathNodeMask = sharedState->getPathNodeMaskMap(); + return std::make_unique(context->clientContext, + sharedState->getOutputNodeMaskMap(), sourceNodeID, writerInfo, *bfsGraph); + } +}; + +std::unique_ptr AllSPPathsFunction::getAlgorithm() { + return std::make_unique(); +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/awsp_paths.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/awsp_paths.cpp new file mode 100644 index 0000000000..16d7cd62c5 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/awsp_paths.cpp @@ -0,0 +1,184 @@ +#include "binder/expression/node_expression.h" +#include "common/exception/interrupt.h" +#include "function/gds/auxiliary_state/path_auxiliary_state.h" +#include 
"function/gds/gds_function_collection.h" +#include "function/gds/rec_joins.h" +#include "function/gds/weight_utils.h" +#include "main/client_context.h" +#include "processor/execution_context.h" +#include "transaction/transaction.h" + +using namespace lbug::binder; +using namespace lbug::common; +using namespace lbug::processor; + +namespace lbug { +namespace function { + +template +class AWSPPathsEdgeCompute : public EdgeCompute { +public: + explicit AWSPPathsEdgeCompute(BFSGraphManager* bfsGraphManager) + : bfsGraphManager{bfsGraphManager} { + block = bfsGraphManager->getCurrentGraph()->addNewBlock(); + } + + std::vector edgeCompute(nodeID_t boundNodeID, graph::NbrScanState::Chunk& chunk, + bool fwdEdge) override { + std::vector result; + chunk.forEach([&](auto neighbors, auto propertyVectors, auto i) { + auto nbrNodeID = neighbors[i]; + auto edgeID = propertyVectors[0]->template getValue(i); + auto weight = propertyVectors[1]->template getValue(i); + WeightUtils::checkWeight(AllWeightedSPPathsFunction::name, weight); + if (!block->hasSpace()) { + block = bfsGraphManager->getCurrentGraph()->addNewBlock(); + } + if (bfsGraphManager->getCurrentGraph()->tryAddParentWithWeight(boundNodeID, edgeID, + nbrNodeID, fwdEdge, static_cast(weight), block)) { + result.push_back(nbrNodeID); + } + }); + return result; + } + + std::unique_ptr copy() override { + return std::make_unique>(bfsGraphManager); + } + +private: + BFSGraphManager* bfsGraphManager; + ObjectBlock* block = nullptr; +}; + +class AWSPPathsOutputWriter : public PathsOutputWriter { +public: + AWSPPathsOutputWriter(main::ClientContext* context, NodeOffsetMaskMap* outputNodeMask, + nodeID_t sourceNodeID, PathsOutputWriterInfo info, BaseBFSGraph& bfsGraph) + : PathsOutputWriter{context, outputNodeMask, sourceNodeID, info, bfsGraph} { + costVector = createVector(LogicalType::DOUBLE()); + } + + void writeInternal(FactorizedTable& fTable, nodeID_t dstNodeID, + LimitCounter* counter) override { + if (dstNodeID == sourceNodeID_) { // Skip writing + return; + } + auto firstParent = bfsGraph.getParentListHead(dstNodeID.offset); + if (firstParent == nullptr) { // Skip if dst is not visited. + return; + } + if (firstParent->getCost() == std::numeric_limits::max()) { + // Skip if dst is not visited. + return; + } + costVector->setValue(0, firstParent->getCost()); + std::vector curPath; + curPath.push_back(firstParent); + auto backtracking = false; + while (!curPath.empty()) { + if (context->interrupted()) { + throw InterruptException{}; + } + if (curPath[curPath.size() - 1]->getCost() == 0) { // Find source. Start writing path. + curPath.pop_back(); + writePath(curPath); + fTable.append(vectors); + if (updateCounterAndTerminate(counter)) { + return; + } + backtracking = true; + } + auto topIdx = curPath.size() - 1; + if (backtracking) { + auto next = curPath[topIdx]->getNextPtr(); + if (next != nullptr) { // Find next top node with the same cost. + KU_ASSERT(curPath[topIdx]->getCost() == next->getCost()); + curPath[topIdx] = next; + backtracking = false; + } else { // Move to next top. + curPath.pop_back(); + } + } else { // Forward track fill path. + auto parent = bfsGraph.getParentListHead(curPath[topIdx]->getNodeID()); + KU_ASSERT(parent != nullptr); + curPath.push_back(parent); + backtracking = false; + } + } + } + + std::unique_ptr copy() override { + return std::make_unique(context, outputNodeMask, sourceNodeID_, info, + bfsGraph); + } + +private: + std::unique_ptr costVector; +}; + +// All weighted shortest path algorithm. Paths are returned. 
+class AllWeightedSPPathsAlgorithm : public RJAlgorithm { +public: + std::string getFunctionName() const override { return AllWeightedSPPathsFunction::name; } + + // return srcNodeID, dstNodeID, length, [direction], pathNodeIDs, pathEdgeIDs, weight + expression_vector getResultColumns(const RJBindData& bindData) const override { + expression_vector columns; + columns.push_back(bindData.nodeInput->constCast().getInternalID()); + columns.push_back(bindData.nodeOutput->constCast().getInternalID()); + columns.push_back(bindData.lengthExpr); + if (bindData.extendDirection == ExtendDirection::BOTH) { + columns.push_back(bindData.directionExpr); + } + columns.push_back(bindData.pathNodeIDsExpr); + columns.push_back(bindData.pathEdgeIDsExpr); + columns.push_back(bindData.weightOutputExpr); + return columns; + } + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } + +private: + std::unique_ptr getComputeState(ExecutionContext* context, + const RJBindData& bindData, RecursiveExtendSharedState* sharedState) override { + auto clientContext = context->clientContext; + auto mm = storage::MemoryManager::Get(*clientContext); + auto graph = sharedState->graph.get(); + auto curDenseFrontier = DenseFrontier::getUninitializedFrontier(context, graph); + auto nextDenseFrontier = DenseFrontier::getUninitializedFrontier(context, graph); + auto frontierPair = std::make_unique( + std::move(curDenseFrontier), std::move(nextDenseFrontier)); + auto bfsGraph = std::make_unique( + sharedState->graph->getMaxOffsetMap(transaction::Transaction::Get(*clientContext)), mm); + std::unique_ptr gdsState; + WeightUtils::visit(AllWeightedSPPathsFunction::name, + bindData.weightPropertyExpr->getDataType(), [&](T) { + auto edgeCompute = std::make_unique>(bfsGraph.get()); + auto auxiliaryState = std::make_unique(std::move(bfsGraph)); + gdsState = std::make_unique(std::move(frontierPair), + std::move(edgeCompute), std::move(auxiliaryState)); + }); + return gdsState; + } + + std::unique_ptr getOutputWriter(ExecutionContext* context, + const RJBindData& bindData, GDSComputeState& computeState, nodeID_t sourceNodeID, + RecursiveExtendSharedState* sharedState) override { + auto bfsGraph = computeState.auxiliaryState->ptrCast() + ->getBFSGraphManager() + ->getCurrentGraph(); + auto writerInfo = bindData.getPathWriterInfo(); + return std::make_unique(context->clientContext, + sharedState->getOutputNodeMaskMap(), sourceNodeID, writerInfo, *bfsGraph); + } +}; + +std::unique_ptr AllWeightedSPPathsFunction::getAlgorithm() { + return std::make_unique(); +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/bfs_graph.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/bfs_graph.cpp new file mode 100644 index 0000000000..d588a3d154 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/bfs_graph.cpp @@ -0,0 +1,276 @@ +#include "function/gds/bfs_graph.h" + +#include "function/gds/gds_utils.h" +#include "processor/execution_context.h" + +using namespace lbug::common; +using namespace lbug::graph; +using namespace lbug::processor; + +namespace lbug { +namespace function { + +static constexpr uint64_t BFS_GRAPH_BLOCK_SIZE = (std::uint64_t)1 << 19; + +ObjectBlock* BaseBFSGraph::addNewBlock() { + std::unique_lock lck{mtx}; + auto memBlock = mm->allocateBuffer(false /* init to 0 */, BFS_GRAPH_BLOCK_SIZE); + blocks.push_back( + std::make_unique>(std::move(memBlock), BFS_GRAPH_BLOCK_SIZE)); + return blocks[blocks.size() - 1].get(); +} + +class 
BFSGraphInitVertexCompute : public VertexCompute { +public: + explicit BFSGraphInitVertexCompute(DenseBFSGraph& bfsGraph) : bfsGraph{bfsGraph} {} + + bool beginOnTable(table_id_t tableID) override { + bfsGraph.pinTableID(tableID); + return true; + } + + void vertexCompute(offset_t startOffset, offset_t endOffset, table_id_t) override { + for (auto i = startOffset; i < endOffset; ++i) { + bfsGraph.curData[i].store(nullptr); + } + } + + std::unique_ptr copy() override { + return std::make_unique(bfsGraph); + } + +private: + DenseBFSGraph& bfsGraph; +}; + +void DenseBFSGraph::init(ExecutionContext* context, Graph* graph) { + auto mm = storage::MemoryManager::Get(*context->clientContext); + for (auto& [tableID, maxOffset] : maxOffsetMap) { + denseObjects.allocate(tableID, maxOffset, mm); + } + auto vc = std::make_unique(*this); + GDSUtils::runVertexCompute(context, GDSDensityState::DENSE, graph, *vc); +} + +void DenseBFSGraph::pinTableID(table_id_t tableID) { + curData = denseObjects.getData(tableID); +} + +static ParentList* reserveParent(nodeID_t boundNodeID, relID_t edgeID, bool fwdEdge, + ObjectBlock* block) { + auto parent = block->reserveNext(); + parent->setNbrInfo(boundNodeID, edgeID, fwdEdge); + return parent; +} + +void DenseBFSGraph::addParent(uint16_t iter, nodeID_t boundNodeID, relID_t edgeID, + nodeID_t nbrNodeID, bool fwdEdge, ObjectBlock* block) { + auto parent = reserveParent(boundNodeID, edgeID, fwdEdge, block); + parent->setIter(iter); + // Since by default the parentPtr of each node is nullptr, that's what we start with. + ParentList* expected = nullptr; + while (!curData[nbrNodeID.offset].compare_exchange_strong(expected, parent)) {} + parent->setNextPtr(expected); +} + +void DenseBFSGraph::addSingleParent(uint16_t iter, nodeID_t boundNodeID, relID_t edgeID, + nodeID_t nbrNodeID, bool fwdEdge, ObjectBlock* block) { + auto parent = reserveParent(boundNodeID, edgeID, fwdEdge, block); + parent->setIter(iter); + ParentList* expected = nullptr; + if (curData[nbrNodeID.offset].compare_exchange_strong(expected, parent)) { + parent->setNextPtr(expected); + } else { + // Other thread has added the parent. Do NOT add parent and revert reserved slot. + block->revertLast(); + } +} + +static double getCost(ParentList* parentList) { + return parentList == nullptr ? std::numeric_limits::max() : parentList->getCost(); +} + +bool DenseBFSGraph::tryAddParentWithWeight(nodeID_t boundNodeID, relID_t edgeID, nodeID_t nbrNodeID, + bool fwdEdge, double weight, ObjectBlock* block) { + ParentList* expected = getParentListHead(nbrNodeID.offset); + auto parent = reserveParent(boundNodeID, edgeID, fwdEdge, block); + parent->setCost(getParentListHead(boundNodeID)->getCost() + weight); + while (true) { + if (parent->getCost() < getCost(expected)) { + // New parent has smaller cost, erase all existing parents and add new parent. + if (curData[nbrNodeID.offset].compare_exchange_strong(expected, parent)) { + parent->setNextPtr(nullptr); + return true; + } + } else if (parent->getCost() == getCost(expected) && expected->getEdgeID() != edgeID) { + // New parent has the same cost and comes from different edge, + // append new parent as after existing parents. 
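+ // Hypothetical illustration: if nbr already holds a parent P1 at cost 4.0 and this thread also
+ // reaches nbr at cost 4.0 through a different edge, the new entry is CAS-ed in as the list head
+ // and then linked to P1, so every equally cheap predecessor is kept for later path enumeration.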
+ if (curData[nbrNodeID.offset].compare_exchange_strong(expected, parent)) { + parent->setNextPtr(expected); + return true; + } + } else { + block->revertLast(); + return false; + } + } +} + +bool DenseBFSGraph::tryAddSingleParentWithWeight(nodeID_t boundNodeID, relID_t edgeID, + nodeID_t nbrNodeID, bool fwdEdge, double weight, ObjectBlock* block) { + ParentList* expected = getParentListHead(nbrNodeID.offset); + auto parent = reserveParent(boundNodeID, edgeID, fwdEdge, block); + parent->setCost(getParentListHead(boundNodeID)->getCost() + weight); + while (parent->getCost() < getCost(expected)) { + if (curData[nbrNodeID.offset].compare_exchange_strong(expected, parent)) { + // Since each node can have one parent, set next ptr to nullptr. + parent->setNextPtr(nullptr); + return true; + } + } + // Other thread has added the parent. Do NOT add parent and revert reserved slot. + block->revertLast(); + return false; +} + +ParentList* DenseBFSGraph::getParentListHead(offset_t offset) { + KU_ASSERT(curData); + return curData[offset].load(std::memory_order_relaxed); +} + +ParentList* DenseBFSGraph::getParentListHead(nodeID_t nodeID) { + return denseObjects.getData(nodeID.tableID)[nodeID.offset].load(std::memory_order_relaxed); +} + +void DenseBFSGraph::setParentList(offset_t offset, ParentList* parentList) { + KU_ASSERT(curData && getParentListHead(offset) == nullptr); + curData[offset].store(parentList, std::memory_order_relaxed); +} + +void SparseBFSGraph::pinTableID(table_id_t tableID) { + curData = sparseObjects.getData(tableID); +} + +void SparseBFSGraph::addParent(uint16_t iter, nodeID_t boundNodeID, relID_t edgeID, + nodeID_t nbrNodeID, bool fwdEdge, ObjectBlock* block) { + auto parent = reserveParent(boundNodeID, edgeID, fwdEdge, block); + parent->setIter(iter); + if (curData->contains(nbrNodeID.offset)) { + parent->setNextPtr(curData->at(nbrNodeID.offset)); + curData->at(nbrNodeID.offset) = parent; + } else { + parent->setNextPtr(nullptr); + curData->insert({nbrNodeID.offset, parent}); + } +} + +void SparseBFSGraph::addSingleParent(uint16_t iter, nodeID_t boundNodeID, relID_t edgeID, + nodeID_t nbrNodeID, bool fwdEdge, ObjectBlock* block) { + if (curData->contains(nbrNodeID.offset)) { + return; + } + auto parent = reserveParent(boundNodeID, edgeID, fwdEdge, block); + parent->setIter(iter); + parent->setNextPtr(nullptr); + curData->insert({nbrNodeID.offset, parent}); +} + +bool SparseBFSGraph::tryAddParentWithWeight(nodeID_t boundNodeID, relID_t edgeID, + nodeID_t nbrNodeID, bool fwdEdge, double weight, ObjectBlock* block) { + auto nbrParent = getParentListHead(nbrNodeID.offset); + auto nbrCost = getCost(nbrParent); + auto newCost = getParentListHead(boundNodeID)->getCost() + weight; + if (newCost < nbrCost) { + auto parent = reserveParent(boundNodeID, edgeID, fwdEdge, block); + parent->setCost(newCost); + parent->setNextPtr(nullptr); + curData->erase(nbrNodeID.offset); + curData->insert({nbrNodeID.offset, parent}); + return true; + } + // Append parent if newCost is the same as old cost. And the newCost comes from a different edge + // Otherwise, for cases like A->B->C, A->D->C, C->E. If ABD and ADC has the same cost, we will + // visit twice to E with the same cost and same edge. 
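+ // Put differently: with equal-cost paths A->B->C and A->D->C, node C is relaxed twice, so edge
+ // C->E is also traversed twice at the same cost; the edge-ID check keeps E from recording the
+ // same (cost, edge) parent twice, which would otherwise surface later as duplicate paths.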
+ if (newCost == nbrCost && nbrParent->getEdgeID() != edgeID) { + auto parent = reserveParent(boundNodeID, edgeID, fwdEdge, block); + parent->setCost(newCost); + if (curData->contains(nbrNodeID.offset)) { + parent->setNextPtr(curData->at(nbrNodeID.offset)); + curData->erase(nbrNodeID.offset); + } else { + parent->setNextPtr(nullptr); + } + curData->insert({nbrNodeID.offset, parent}); + return true; + } + return false; +} + +bool SparseBFSGraph::tryAddSingleParentWithWeight(nodeID_t boundNodeID, relID_t edgeID, + nodeID_t nbrNodeID, bool fwdEdge, double weight, ObjectBlock* block) { + auto nbrCost = getCost(getParentListHead(nbrNodeID.offset)); + auto newCost = getParentListHead(boundNodeID)->getCost() + weight; + if (newCost < nbrCost) { + auto parent = reserveParent(boundNodeID, edgeID, fwdEdge, block); + parent->setCost(newCost); + parent->setNextPtr(nullptr); + curData->erase(nbrNodeID.offset); + curData->insert({nbrNodeID.offset, parent}); + return true; + } + if (newCost == nbrCost) { + if (curData->contains(nbrNodeID.offset)) { + return false; + } + auto parent = reserveParent(boundNodeID, edgeID, fwdEdge, block); + parent->setCost(newCost); + parent->setNextPtr(nullptr); + curData->insert({nbrNodeID.offset, parent}); + } + return false; +} + +ParentList* SparseBFSGraph::getParentListHead(offset_t offset) { + KU_ASSERT(curData); + if (!curData->contains(offset)) { + return nullptr; + } + return curData->at(offset); +} + +ParentList* SparseBFSGraph::getParentListHead(nodeID_t nodeID) { + auto data = sparseObjects.getData(nodeID.tableID); + if (!data->contains(nodeID.offset)) { + return nullptr; + } + return data->at(nodeID.offset); +} + +void SparseBFSGraph::setParentList(offset_t offset, ParentList* parentList) { + KU_ASSERT(!curData->contains(offset)); + curData->insert({offset, parentList}); +} + +BFSGraphManager::BFSGraphManager(table_id_map_t maxOffsetMap, + storage::MemoryManager* mm) { + denseBFSGraph = std::make_unique(mm, maxOffsetMap); + sparseBFSGraph = std::make_unique(mm, maxOffsetMap); + curGraph = sparseBFSGraph.get(); +} + +void BFSGraphManager::switchToDense(ExecutionContext* context, Graph* graph) { + KU_ASSERT(state == GDSDensityState::SPARSE); + state = GDSDensityState::DENSE; + denseBFSGraph->init(context, graph); + denseBFSGraph->blocks = std::move(sparseBFSGraph->blocks); + for (auto& [tableID, map] : sparseBFSGraph->sparseObjects.getData()) { + denseBFSGraph->pinTableID(tableID); + for (auto& [offset, ptr] : map) { + denseBFSGraph->setParentList(offset, ptr); + } + } + curGraph = denseBFSGraph.get(); +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/frontier_morsel.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/frontier_morsel.cpp new file mode 100644 index 0000000000..2243446117 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/frontier_morsel.cpp @@ -0,0 +1,35 @@ +#include "function/gds/frontier_morsel.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +FrontierMorselDispatcher::FrontierMorselDispatcher(uint64_t maxThreads) + : maxOffset{INVALID_OFFSET}, maxThreads{maxThreads}, morselSize(UINT64_MAX) { + nextOffset.store(INVALID_OFFSET); +} + +void FrontierMorselDispatcher::init(offset_t _maxOffset) { + maxOffset = _maxOffset; + nextOffset.store(0u); + // Frontier size calculation: The ideal scenario is to have k^2 many morsels where k + // the number of maximum threads that could be working on this frontier. 
However, if + // that is too small then we default to MIN_FRONTIER_MORSEL_SIZE. + auto idealMorselSize = + maxOffset / std::max(MIN_NUMBER_OF_FRONTIER_MORSELS, maxThreads * maxThreads); + morselSize = std::max(MIN_FRONTIER_MORSEL_SIZE, idealMorselSize); +} + +bool FrontierMorselDispatcher::getNextRangeMorsel(FrontierMorsel& frontierMorsel) { + auto beginOffset = nextOffset.fetch_add(morselSize, std::memory_order_acq_rel); + if (beginOffset >= maxOffset) { + return false; + } + auto endOffset = beginOffset + morselSize > maxOffset ? maxOffset : beginOffset + morselSize; + frontierMorsel.init(beginOffset, endOffset); + return true; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/gds.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/gds.cpp new file mode 100644 index 0000000000..0266bae691 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/gds.cpp @@ -0,0 +1,281 @@ +#include "function/gds/gds.h" + +#include "binder/binder.h" +#include "binder/expression/rel_expression.h" +#include "binder/query/reading_clause/bound_table_function_call.h" +#include "catalog/catalog.h" +#include "catalog/catalog_entry/rel_group_catalog_entry.h" +#include "common/exception/binder.h" +#include "function/table/bind_input.h" +#include "graph/graph_entry_set.h" +#include "graph/on_disk_graph.h" +#include "parser/parser.h" +#include "planner/operator/logical_table_function_call.h" +#include "planner/operator/sip/logical_semi_masker.h" +#include "planner/planner.h" +#include "processor/operator/table_function_call.h" +#include "processor/plan_mapper.h" + +using namespace lbug::catalog; +using namespace lbug::common; +using namespace lbug::binder; +using namespace lbug::main; +using namespace lbug::graph; +using namespace lbug::processor; +using namespace lbug::planner; + +namespace lbug { +namespace function { + +void GDSFuncSharedState::setGraphNodeMask(std::unique_ptr maskMap) { + auto onDiskGraph = ku_dynamic_cast(graph.get()); + onDiskGraph->setNodeOffsetMask(maskMap.get()); + graphNodeMask = std::move(maskMap); +} + +static expression_vector getResultColumns(const std::string& cypher, ClientContext* context) { + auto parsedStatements = parser::Parser::parseQuery(cypher); + KU_ASSERT(parsedStatements.size() == 1); + auto binder = Binder(context); + auto boundStatement = binder.bind(*parsedStatements[0]); + return boundStatement->getStatementResult()->getColumns(); +} + +static void validateNodeProjected(const table_id_set_t& connectedNodeTableIDSet, + const table_id_set_t& projectedNodeIDSet, const std::string& relName, Catalog* catalog, + transaction::Transaction* transaction) { + for (auto id : connectedNodeTableIDSet) { + if (!projectedNodeIDSet.contains(id)) { + auto entryName = catalog->getTableCatalogEntry(transaction, id)->getName(); + throw BinderException( + stringFormat("{} is connected to {} but not projected.", entryName, relName)); + } + } +} + +static void validateRelSrcDstNodeAreProjected(const TableCatalogEntry& entry, + const table_id_set_t& projectedNodeIDSet, Catalog* catalog, + transaction::Transaction* transaction) { + auto& relEntry = entry.constCast(); + validateNodeProjected(relEntry.getSrcNodeTableIDSet(), projectedNodeIDSet, relEntry.getName(), + catalog, transaction); + validateNodeProjected(relEntry.getDstNodeTableIDSet(), projectedNodeIDSet, relEntry.getName(), + catalog, transaction); +} + +NativeGraphEntry GDSFunction::bindGraphEntry(ClientContext& context, const std::string& name) { + auto set = 
GraphEntrySet::Get(context); + set->validateGraphExist(name); + auto entry = set->getEntry(name); + if (entry->type != GraphEntryType::NATIVE) { + throw BinderException("AA"); + } + return bindGraphEntry(context, entry->cast()); +} + +static NativeGraphEntryTableInfo bindNodeEntry(ClientContext& context, const std::string& tableName, + const std::string& predicate) { + auto catalog = Catalog::Get(context); + auto transaction = transaction::Transaction::Get(context); + auto nodeEntry = catalog->getTableCatalogEntry(transaction, tableName); + if (nodeEntry->getType() != CatalogEntryType::NODE_TABLE_ENTRY) { + throw BinderException(stringFormat("{} is not a NODE table.", tableName)); + } + if (!predicate.empty()) { + auto cypher = stringFormat("MATCH (n:`{}`) RETURN n, {}", nodeEntry->getName(), predicate); + auto columns = getResultColumns(cypher, &context); + KU_ASSERT(columns.size() == 2); + return {nodeEntry, columns[0], columns[1]}; + } else { + auto cypher = stringFormat("MATCH (n:`{}`) RETURN n", nodeEntry->getName()); + auto columns = getResultColumns(cypher, &context); + KU_ASSERT(columns.size() == 1); + return {nodeEntry, columns[0], nullptr /* empty predicate */}; + } +} + +static NativeGraphEntryTableInfo bindRelEntry(ClientContext& context, const std::string& tableName, + const std::string& predicate) { + auto catalog = Catalog::Get(context); + auto transaction = transaction::Transaction::Get(context); + auto relEntry = catalog->getTableCatalogEntry(transaction, tableName); + if (relEntry->getType() != CatalogEntryType::REL_GROUP_ENTRY) { + throw BinderException( + stringFormat("{} has catalog entry type. REL entry was expected.", tableName)); + } + if (!predicate.empty()) { + auto cypher = + stringFormat("MATCH ()-[r:`{}`]->() RETURN r, {}", relEntry->getName(), predicate); + auto columns = getResultColumns(cypher, &context); + KU_ASSERT(columns.size() == 2); + return {relEntry, columns[0], columns[1]}; + } else { + auto cypher = stringFormat("MATCH ()-[r:`{}`]->() RETURN r", relEntry->getName()); + auto columns = getResultColumns(cypher, &context); + KU_ASSERT(columns.size() == 1); + return {relEntry, columns[0], nullptr /* empty predicate */}; + } +} + +NativeGraphEntry GDSFunction::bindGraphEntry(ClientContext& context, + const ParsedNativeGraphEntry& entry) { + auto catalog = Catalog::Get(context); + auto transaction = transaction::Transaction::Get(context); + auto result = NativeGraphEntry(); + table_id_set_t projectedNodeTableIDSet; + for (auto& nodeInfo : entry.nodeInfos) { + auto boundInfo = bindNodeEntry(context, nodeInfo.tableName, nodeInfo.predicate); + projectedNodeTableIDSet.insert(boundInfo.entry->getTableID()); + result.nodeInfos.push_back(std::move(boundInfo)); + } + for (auto& relInfo : entry.relInfos) { + if (catalog->containsTable(transaction, relInfo.tableName)) { + auto boundInfo = bindRelEntry(context, relInfo.tableName, relInfo.predicate); + validateRelSrcDstNodeAreProjected(*boundInfo.entry, projectedNodeTableIDSet, catalog, + transaction); + result.relInfos.push_back(std::move(boundInfo)); + } else { + throw BinderException(stringFormat("{} is not a REL table.", relInfo.tableName)); + } + } + return result; +} + +std::shared_ptr GDSFunction::bindRelOutput(const TableFuncBindInput& bindInput, + const std::vector& relEntries, + std::shared_ptr srcNode, std::shared_ptr dstNode, + const std::optional& name, const std::optional& yieldVariableIdx) { + std::string relColumnName = name.value_or(REL_COLUMN_NAME); + StringUtils::toLower(relColumnName); + if 
(!bindInput.yieldVariables.empty()) { + relColumnName = + bindColumnName(bindInput.yieldVariables[yieldVariableIdx.value_or(0)], relColumnName); + } + auto rel = bindInput.binder->createNonRecursiveQueryRel(relColumnName, relEntries, srcNode, + dstNode, RelDirectionType::SINGLE); + bindInput.binder->addToScope(REL_COLUMN_NAME, rel); + return rel; +} + +std::shared_ptr GDSFunction::bindNodeOutput(const TableFuncBindInput& bindInput, + const std::vector& nodeEntries, const std::optional& name, + const std::optional& yieldVariableIdx) { + std::string nodeColumnName = name.value_or(NODE_COLUMN_NAME); + StringUtils::toLower(nodeColumnName); + if (!bindInput.yieldVariables.empty()) { + nodeColumnName = + bindColumnName(bindInput.yieldVariables[yieldVariableIdx.value_or(0)], nodeColumnName); + } + auto node = bindInput.binder->createQueryNode(nodeColumnName, nodeEntries); + bindInput.binder->addToScope(nodeColumnName, node); + return node; +} + +std::string GDSFunction::bindColumnName(const parser::YieldVariable& yieldVariable, + std::string expressionName) { + if (yieldVariable.name != expressionName) { + throw common::BinderException{ + common::stringFormat("Unknown variable name: {}.", yieldVariable.name)}; + } + if (yieldVariable.hasAlias()) { + return yieldVariable.alias; + } + return expressionName; +} + +std::unique_ptr GDSFunction::initSharedState( + const TableFuncInitSharedStateInput& input) { + auto bindData = input.bindData->constPtrCast(); + auto graph = + std::make_unique(input.context->clientContext, bindData->graphEntry.copy()); + return std::make_unique(bindData->getResultTable(), std::move(graph)); +} + +std::vector> getNodeMaskPlanRoots(const GDSBindData& bindData, + Planner* planner) { + std::vector> nodeMaskPlanRoots; + for (auto& nodeInfo : bindData.graphEntry.nodeInfos) { + if (nodeInfo.predicate == nullptr) { + continue; + } + auto& node = nodeInfo.nodeOrRel->constCast(); + planner->getCardinliatyEstimatorUnsafe().init(node); + auto p = planner->getNodeSemiMaskPlan(SemiMaskTargetType::GDS_GRAPH_NODE, node, + nodeInfo.predicate); + nodeMaskPlanRoots.push_back(p.getLastOperator()); + } + return nodeMaskPlanRoots; +}; + +void GDSFunction::getLogicalPlan(Planner* planner, const BoundReadingClause& readingClause, + expression_vector predicates, LogicalPlan& plan) { + auto& call = readingClause.constCast(); + auto bindData = call.getBindData()->constPtrCast(); + auto op = std::make_shared(call.getTableFunc(), bindData->copy()); + for (auto root : getNodeMaskPlanRoots(*bindData, planner)) { + op->addChild(root); + } + op->computeFactorizedSchema(); + planner->planReadOp(std::move(op), predicates, plan); + + auto nodeOutput = bindData->output[0]->ptrCast(); + KU_ASSERT(nodeOutput != nullptr); + planner->getCardinliatyEstimatorUnsafe().init(*nodeOutput); + auto scanPlan = planner->getNodePropertyScanPlan(*nodeOutput); + if (scanPlan.isEmpty()) { + return; + } + expression_vector joinConditions; + joinConditions.push_back(nodeOutput->getInternalID()); + planner->appendHashJoin(joinConditions, JoinType::INNER, plan, scanPlan, plan); +} + +std::unique_ptr GDSFunction::getPhysicalPlan(PlanMapper* planMapper, + const LogicalOperator* logicalOp) { + auto logicalCall = logicalOp->constPtrCast(); + auto bindData = logicalCall->getBindData()->copy(); + auto columns = bindData->columns; + auto tableSchema = PlanMapper::createFlatFTableSchema(columns, *logicalCall->getSchema()); + auto table = std::make_shared( + storage::MemoryManager::Get(*planMapper->clientContext), tableSchema.copy()); + 
bindData->cast().setResultFTable(table); + auto info = TableFunctionCallInfo(); + info.function = logicalCall->getTableFunc(); + info.bindData = std::move(bindData); + auto initInput = + TableFuncInitSharedStateInput(info.bindData.get(), planMapper->executionContext); + auto sharedState = info.function.initSharedStateFunc(initInput); + auto printInfo = + std::make_unique(info.function.name, info.bindData->columns); + auto call = std::make_unique(std::move(info), sharedState, + planMapper->getOperatorID(), std::move(printInfo)); + if (logicalCall->getNumChildren() > 0u) { + const auto funcSharedState = sharedState->ptrCast(); + funcSharedState->setGraphNodeMask(std::make_unique()); + auto maskMap = funcSharedState->getGraphNodeMaskMap(); + planMapper->addOperatorMapping(logicalOp, call.get()); + for (auto logicalRoot : logicalCall->getChildren()) { + KU_ASSERT(logicalRoot->getNumChildren() == 1); + auto child = logicalRoot->getChild(0); + KU_ASSERT(child->getOperatorType() == LogicalOperatorType::SEMI_MASKER); + auto logicalSemiMasker = child->ptrCast(); + logicalSemiMasker->addTarget(logicalOp); + for (auto tableID : logicalSemiMasker->getNodeTableIDs()) { + maskMap->addMask(tableID, planMapper->createSemiMask(tableID)); + } + auto root = planMapper->mapOperator(logicalRoot.get()); + call->addChild(std::move(root)); + } + planMapper->eraseOperatorMapping(logicalOp); + } + planMapper->addOperatorMapping(logicalOp, call.get()); + physical_op_vector_t children; + auto dummySink = std::make_unique(std::move(call), planMapper->getOperatorID()); + dummySink->setDescriptor(std::make_unique(logicalCall->getSchema())); + children.push_back(std::move(dummySink)); + return planMapper->createFTableScanAligned(columns, logicalCall->getSchema(), table, + DEFAULT_VECTOR_CAPACITY, std::move(children)); +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/gds_frontier.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/gds_frontier.cpp new file mode 100644 index 0000000000..4aae9bacd2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/gds_frontier.cpp @@ -0,0 +1,423 @@ +#include "function/gds/gds_frontier.h" + +#include "function/gds/gds_utils.h" +#include "processor/execution_context.h" +#include "transaction/transaction.h" + +using namespace lbug::common; +using namespace lbug::graph; +using namespace lbug::processor; + +namespace lbug { +namespace function { + +void SparseFrontier::pinTableID(table_id_t tableID) { + curData = sparseObjects.getData(tableID); +} + +void SparseFrontier::addNode(nodeID_t nodeID, iteration_t iter) { + KU_ASSERT(curData); + addNode(nodeID.offset, iter); +} + +void SparseFrontier::addNode(offset_t offset, iteration_t iter) { + KU_ASSERT(curData); + if (!curData->contains(offset)) { + curData->insert({offset, iter}); + } else { + curData->at(offset) = iter; + } +} + +void SparseFrontier::addNodes(const std::vector& nodeIDs, iteration_t iter) { + KU_ASSERT(curData); + for (auto& nodeID : nodeIDs) { + addNode(nodeID.offset, iter); + } +} + +iteration_t SparseFrontier::getIteration(offset_t offset) const { + KU_ASSERT(curData); + if (!curData->contains(offset)) { + return FRONTIER_UNVISITED; + } + return curData->at(offset); +} + +void SparseFrontierReference::pinTableID(table_id_t tableID) { + curData = sparseObjects.getData(tableID); +} + +void SparseFrontierReference::addNode(offset_t offset, iteration_t iter) { + KU_ASSERT(curData); + if (!curData->contains(offset)) { + curData->insert({offset, 
iter}); + } else { + curData->at(offset) = iter; + } +} + +void SparseFrontierReference::addNode(nodeID_t nodeID, iteration_t iter) { + KU_ASSERT(curData); + addNode(nodeID.offset, iter); +} + +void SparseFrontierReference::addNodes(const std::vector& nodeIDs, iteration_t iter) { + KU_ASSERT(curData); + for (auto nodeID : nodeIDs) { + addNode(nodeID.offset, iter); + } +} + +iteration_t SparseFrontierReference::getIteration(offset_t offset) const { + KU_ASSERT(curData); + if (!curData->contains(offset)) { + return FRONTIER_UNVISITED; + } + return curData->at(offset); +} + +class DenseFrontierInitVertexCompute : public VertexCompute { +public: + DenseFrontierInitVertexCompute(DenseFrontier& frontier, iteration_t val) + : frontier{frontier}, val{val} {} + + bool beginOnTable(table_id_t tableID) override { + frontier.pinTableID(tableID); + return true; + } + + void vertexCompute(offset_t startOffset, offset_t endOffset, table_id_t) override { + for (auto i = startOffset; i < endOffset; ++i) { + frontier.addNode(i, val); + } + } + + std::unique_ptr copy() override { + return std::make_unique(frontier, val); + } + +private: + DenseFrontier& frontier; + iteration_t val; +}; + +void DenseFrontier::init(ExecutionContext* context, Graph* graph, iteration_t val) { + auto mm = storage::MemoryManager::Get(*context->clientContext); + for (const auto& [tableID, maxOffset] : nodeMaxOffsetMap) { + denseObjects.allocate(tableID, maxOffset, mm); + } + resetValue(context, graph, val); +} + +void DenseFrontier::resetValue(ExecutionContext* context, Graph* graph, iteration_t val) { + auto vc = DenseFrontierInitVertexCompute(*this, val); + GDSUtils::runVertexCompute(context, GDSDensityState::DENSE, graph, vc); +} + +void DenseFrontier::pinTableID(table_id_t tableID) { + curData = denseObjects.getData(tableID); +} + +void DenseFrontier::addNode(nodeID_t nodeID, iteration_t iter) { + KU_ASSERT(curData); + curData[nodeID.offset].store(iter, std::memory_order_relaxed); +} + +void DenseFrontier::addNode(offset_t offset, iteration_t iter) { + KU_ASSERT(curData); + curData[offset].store(iter, std::memory_order_relaxed); +} + +void DenseFrontier::addNodes(const std::vector& nodeIDs, iteration_t iter) { + KU_ASSERT(curData); + for (auto nodeID : nodeIDs) { + curData[nodeID.offset].store(iter, std::memory_order_relaxed); + } +} + +iteration_t DenseFrontier::getIteration(offset_t offset) const { + KU_ASSERT(curData); + return curData[offset].load(std::memory_order_relaxed); +} + +std::unique_ptr DenseFrontier::getUninitializedFrontier(ExecutionContext* context, + Graph* graph) { + auto transaction = transaction::Transaction::Get(*context->clientContext); + return std::make_unique(graph->getMaxOffsetMap(transaction)); +} + +std::unique_ptr DenseFrontier::getUnvisitedFrontier(ExecutionContext* context, + Graph* graph) { + auto transaction = transaction::Transaction::Get(*context->clientContext); + auto frontier = std::make_unique(graph->getMaxOffsetMap(transaction)); + frontier->init(context, graph, FRONTIER_UNVISITED); + return frontier; +} + +std::unique_ptr DenseFrontier::getVisitedFrontier(ExecutionContext* context, + Graph* graph) { + auto transaction = transaction::Transaction::Get(*context->clientContext); + auto frontier = std::make_unique(graph->getMaxOffsetMap(transaction)); + frontier->init(context, graph, FRONTIER_INITIAL_VISITED); + return frontier; +} + +std::unique_ptr DenseFrontier::getVisitedFrontier(ExecutionContext* context, + Graph* graph, NodeOffsetMaskMap* maskMap) { + if (maskMap == nullptr) { + return 
getVisitedFrontier(context, graph); + } + auto transaction = transaction::Transaction::Get(*context->clientContext); + auto frontier = std::make_unique(graph->getMaxOffsetMap(transaction)); + frontier->init(context, graph, FRONTIER_INITIAL_VISITED); + for (auto [tableID, numNodes] : graph->getMaxOffsetMap(transaction)) { + frontier->pinTableID(tableID); + if (maskMap->containsTableID(tableID)) { + auto mask = maskMap->getOffsetMask(tableID); + for (auto i = 0u; i < numNodes; ++i) { + if (!mask->isMasked(i)) { + frontier->curData[i].store(FRONTIER_UNVISITED); + } + } + } + } + return frontier; +} + +void DenseFrontierReference::pinTableID(table_id_t tableID) { + curData = denseObjects.getData(tableID); +} + +void DenseFrontierReference::addNode(nodeID_t nodeID, iteration_t iter) { + KU_ASSERT(curData); + curData[nodeID.offset].store(iter, std::memory_order_relaxed); +} + +void DenseFrontierReference::addNode(offset_t offset, iteration_t iter) { + KU_ASSERT(curData); + curData[offset].store(iter, std::memory_order_relaxed); +} + +void DenseFrontierReference::addNodes(const std::vector& nodeIDs, iteration_t iter) { + KU_ASSERT(curData); + for (auto nodeID : nodeIDs) { + curData[nodeID.offset].store(iter, std::memory_order_relaxed); + } +} + +iteration_t DenseFrontierReference::getIteration(offset_t offset) const { + KU_ASSERT(curData); + return curData[offset].load(std::memory_order_relaxed); +} + +void FrontierPair::beginNewIteration() { + std::unique_lock lck{mtx}; + curIter++; + hasActiveNodesForNextIter_.store(false); + beginNewIterationInternalNoLock(); +} + +void FrontierPair::beginFrontierComputeBetweenTables(table_id_t curTableID, + table_id_t nextTableID) { + pinCurrentFrontier(curTableID); + pinNextFrontier(nextTableID); +} + +void FrontierPair::pinCurrentFrontier(table_id_t tableID) { + currentFrontier->pinTableID(tableID); +} + +void FrontierPair::pinNextFrontier(table_id_t tableID) { + nextFrontier->pinTableID(tableID); +} + +void FrontierPair::addNodeToNextFrontier(nodeID_t nodeID) { + nextFrontier->addNode(nodeID, curIter); +} + +void FrontierPair::addNodeToNextFrontier(offset_t offset) { + nextFrontier->addNode(offset, curIter); +} + +void FrontierPair::addNodesToNextFrontier(const std::vector& nodeIDs) { + nextFrontier->addNodes(nodeIDs, curIter); +} + +iteration_t FrontierPair::getNextFrontierValue(offset_t offset) { + return nextFrontier->getIteration(offset); +} + +bool FrontierPair::isActiveOnCurrentFrontier(offset_t offset) { + return currentFrontier->getIteration(offset) == curIter - 1; +} + +Frontier* SPFrontierPair::getFrontier() { + switch (state) { + case GDSDensityState::SPARSE: { + return sparseFrontier.get(); + } + case GDSDensityState::DENSE: { + return denseFrontier.get(); + } + default: + KU_UNREACHABLE; + } +} + +SPFrontierPair::SPFrontierPair(std::unique_ptr denseFrontier) + : state{GDSDensityState::SPARSE}, denseFrontier{std::move(denseFrontier)} { + curDenseFrontier = std::make_unique(*this->denseFrontier); + nextDenseFrontier = std::make_unique(*this->denseFrontier); + sparseFrontier = std::make_unique(this->denseFrontier->nodeMaxOffsetMap); + curSparseFrontier = std::make_unique(*this->sparseFrontier); + nextSparseFrontier = std::make_unique(*this->sparseFrontier); + currentFrontier = curSparseFrontier.get(); + nextFrontier = nextSparseFrontier.get(); +} + +void SPFrontierPair::beginNewIterationInternalNoLock() { + switch (state) { + case GDSDensityState::SPARSE: { + std::swap(curSparseFrontier, nextSparseFrontier); + currentFrontier = 
curSparseFrontier.get(); + nextFrontier = nextSparseFrontier.get(); + } break; + case GDSDensityState::DENSE: { + std::swap(curDenseFrontier, nextDenseFrontier); + currentFrontier = curDenseFrontier.get(); + nextFrontier = nextDenseFrontier.get(); + } break; + default: + KU_UNREACHABLE; + } +} + +offset_t SPFrontierPair::getNumActiveNodesInCurrentFrontier(NodeOffsetMaskMap& mask) { + auto result = 0u; + for (auto& [tableID, maxNumNodes] : denseFrontier->nodeMaxOffsetMap) { + currentFrontier->pinTableID(tableID); + if (!mask.containsTableID(tableID)) { + continue; + } + auto offsetMask = mask.getOffsetMask(tableID); + for (auto offset = 0u; offset < maxNumNodes; ++offset) { + if (isActiveOnCurrentFrontier(offset)) { + result += offsetMask->isMasked(offset); + } + } + } + return result; +} + +std::unordered_set SPFrontierPair::getActiveNodesOnCurrentFrontier() { + KU_ASSERT(state == GDSDensityState::SPARSE); + std::unordered_set result; + for (auto& [offset, iter] : curSparseFrontier->getCurrentData()) { + if (iter != curIter - 1) { + continue; + } + result.insert(offset); + } + return result; +} + +void SPFrontierPair::switchToDense(ExecutionContext* context, graph::Graph* graph) { + KU_ASSERT(state == GDSDensityState::SPARSE); + state = GDSDensityState::DENSE; + denseFrontier->init(context, graph, FRONTIER_UNVISITED); + for (auto& [tableID, map] : sparseFrontier->sparseObjects.getData()) { + nextDenseFrontier->pinTableID(tableID); + for (auto [offset, iter] : map) { + nextDenseFrontier->curData[offset].store(iter); + } + } +} + +DenseSparseDynamicFrontierPair::DenseSparseDynamicFrontierPair( + std::unique_ptr curDenseFrontier, + std::unique_ptr nextDenseFrontier) + : state{GDSDensityState::SPARSE}, curDenseFrontier{std::move(curDenseFrontier)}, + nextDenseFrontier{std::move(nextDenseFrontier)} { + curSparseFrontier = std::make_unique(this->curDenseFrontier->nodeMaxOffsetMap); + nextSparseFrontier = + std::make_unique(this->nextDenseFrontier->nodeMaxOffsetMap); + currentFrontier = curSparseFrontier.get(); + nextFrontier = nextSparseFrontier.get(); +} + +void DenseSparseDynamicFrontierPair::beginNewIterationInternalNoLock() { + switch (state) { + case GDSDensityState::SPARSE: { + std::swap(curSparseFrontier, nextSparseFrontier); + currentFrontier = curSparseFrontier.get(); + nextFrontier = nextSparseFrontier.get(); + } break; + case GDSDensityState::DENSE: { + std::swap(curDenseFrontier, nextDenseFrontier); + currentFrontier = curDenseFrontier.get(); + nextFrontier = nextDenseFrontier.get(); + } break; + default: + KU_UNREACHABLE; + } +} + +std::unordered_set DenseSparseDynamicFrontierPair::getActiveNodesOnCurrentFrontier() { + KU_ASSERT(state == GDSDensityState::SPARSE); + std::unordered_set result; + for (auto& [offset, iter] : *curSparseFrontier->curData) { + if (iter != curIter - 1) { + continue; + } + result.insert(offset); + } + return result; +} + +void DenseSparseDynamicFrontierPair::switchToDense(ExecutionContext* context, Graph* graph) { + KU_ASSERT(state == GDSDensityState::SPARSE); + state = GDSDensityState::DENSE; + curDenseFrontier->init(context, graph, FRONTIER_UNVISITED); + nextDenseFrontier->init(context, graph, FRONTIER_UNVISITED); + for (auto& [tableID, map] : nextSparseFrontier->sparseObjects.getData()) { + nextDenseFrontier->pinTableID(tableID); + for (auto [offset, iter] : map) { + nextDenseFrontier->curData[offset].store(iter); + } + } +} + +DenseFrontierPair::DenseFrontierPair(std::unique_ptr curDenseFrontier, + std::unique_ptr nextDenseFrontier) + : 
curDenseFrontier{std::move(curDenseFrontier)}, + nextDenseFrontier{std::move(nextDenseFrontier)} { + currentFrontier = this->curDenseFrontier.get(); + nextFrontier = this->nextDenseFrontier.get(); +} + +void DenseFrontierPair::beginNewIterationInternalNoLock() { + std::swap(curDenseFrontier, nextDenseFrontier); + currentFrontier = curDenseFrontier.get(); + nextFrontier = nextDenseFrontier.get(); +} + +void DenseFrontierPair::resetValue(ExecutionContext* context, Graph* graph, iteration_t val) { + curDenseFrontier->resetValue(context, graph, val); + nextDenseFrontier->resetValue(context, graph, val); +} + +static constexpr uint64_t EARLY_TERM_NUM_NODES_THRESHOLD = 100; + +bool SPEdgeCompute::terminate(NodeOffsetMaskMap& maskMap) { + auto targetNumNodes = maskMap.getNumMaskedNode(); + if (targetNumNodes > EARLY_TERM_NUM_NODES_THRESHOLD) { + // Skip checking if it's unlikely to early terminate. + return false; + } + numNodesReached += frontierPair->getNumActiveNodesInCurrentFrontier(maskMap); + return numNodesReached == targetNumNodes; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/gds_state.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/gds_state.cpp new file mode 100644 index 0000000000..aecfa7608e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/gds_state.cpp @@ -0,0 +1,26 @@ +#include "function/gds/gds_state.h" + +namespace lbug { +namespace function { + +void GDSComputeState::initSource(common::nodeID_t sourceNodeID) const { + frontierPair->pinNextFrontier(sourceNodeID.tableID); + frontierPair->addNodeToNextFrontier(sourceNodeID); + frontierPair->setActiveNodesForNextIter(); + auxiliaryState->initSource(sourceNodeID); +} + +void GDSComputeState::beginFrontierCompute(common::table_id_t currTableID, + common::table_id_t nextTableID) const { + frontierPair->beginFrontierComputeBetweenTables(currTableID, nextTableID); + auxiliaryState->beginFrontierCompute(currTableID, nextTableID); +} + +void GDSComputeState::switchToDense(processor::ExecutionContext* context, + graph::Graph* graph) const { + frontierPair->switchToDense(context, graph); + auxiliaryState->switchToDense(context, graph); +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/gds_task.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/gds_task.cpp new file mode 100644 index 0000000000..b57527370d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/gds_task.cpp @@ -0,0 +1,151 @@ +#include "function/gds/gds_task.h" + +#include "catalog/catalog_entry/rel_group_catalog_entry.h" +#include "catalog/catalog_entry/table_catalog_entry.h" +#include "function/gds/frontier_morsel.h" +#include "graph/graph.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +table_id_t FrontierTaskInfo::getBoundTableID() const { + switch (direction) { + case ExtendDirection::FWD: + return srcTableID; + case ExtendDirection::BWD: + return dstTableID; + default: + KU_UNREACHABLE; + } +} + +table_id_t FrontierTaskInfo::getNbrTableID() const { + switch (direction) { + case ExtendDirection::FWD: + return dstTableID; + case ExtendDirection::BWD: + return srcTableID; + default: + KU_UNREACHABLE; + } +} + +oid_t FrontierTaskInfo::getRelTableID() const { + return relGroupEntry->constCast() + .getRelEntryInfo(srcTableID, dstTableID) + ->oid; +} + +void FrontierTask::run() { + FrontierMorsel morsel; + auto numActiveNodes = 0u; + auto graph = info.graph; + auto scanState 
= graph->prepareRelScan(*info.relGroupEntry, info.getRelTableID(), + info.getNbrTableID(), info.propertiesToScan); + auto ec = info.edgeCompute.copy(); + auto boundTableID = info.getBoundTableID(); + switch (info.direction) { + case ExtendDirection::FWD: { + while (sharedState->morselDispatcher.getNextRangeMorsel(morsel)) { + for (auto offset = morsel.getBeginOffset(); offset < morsel.getEndOffset(); ++offset) { + if (!sharedState->frontierPair.isActiveOnCurrentFrontier(offset)) { + continue; + } + nodeID_t nodeID = {offset, boundTableID}; + for (auto chunk : graph->scanFwd(nodeID, *scanState)) { + auto activeNodes = ec->edgeCompute(nodeID, chunk, true); + sharedState->frontierPair.addNodesToNextFrontier(activeNodes); + numActiveNodes += activeNodes.size(); + } + } + } + } break; + case ExtendDirection::BWD: { + while (sharedState->morselDispatcher.getNextRangeMorsel(morsel)) { + for (auto offset = morsel.getBeginOffset(); offset < morsel.getEndOffset(); ++offset) { + if (!sharedState->frontierPair.isActiveOnCurrentFrontier(offset)) { + continue; + } + nodeID_t nodeID = {offset, boundTableID}; + for (auto chunk : graph->scanBwd(nodeID, *scanState)) { + auto activeNodes = ec->edgeCompute(nodeID, chunk, false); + sharedState->frontierPair.addNodesToNextFrontier(activeNodes); + numActiveNodes += activeNodes.size(); + } + } + } + } break; + default: + KU_UNREACHABLE; + } + if (numActiveNodes) { + sharedState->frontierPair.setActiveNodesForNextIter(); + } +} + +void FrontierTask::runSparse() { + auto numActiveNodes = 0u; + auto graph = info.graph; + auto scanState = graph->prepareRelScan(*info.relGroupEntry, info.getRelTableID(), + info.getNbrTableID(), info.propertiesToScan); + auto ec = info.edgeCompute.copy(); + auto boundTableID = info.getBoundTableID(); + switch (info.direction) { + case ExtendDirection::FWD: { + for (const auto offset : sharedState->frontierPair.getActiveNodesOnCurrentFrontier()) { + auto nodeID = nodeID_t{offset, boundTableID}; + for (auto chunk : graph->scanFwd(nodeID, *scanState)) { + auto activeNodes = ec->edgeCompute(nodeID, chunk, true); + sharedState->frontierPair.addNodesToNextFrontier(activeNodes); + numActiveNodes += activeNodes.size(); + } + } + } break; + case ExtendDirection::BWD: { + for (auto& offset : sharedState->frontierPair.getActiveNodesOnCurrentFrontier()) { + auto nodeID = nodeID_t{offset, boundTableID}; + for (auto chunk : graph->scanBwd(nodeID, *scanState)) { + auto activeNodes = ec->edgeCompute(nodeID, chunk, false); + sharedState->frontierPair.addNodesToNextFrontier(activeNodes); + numActiveNodes += activeNodes.size(); + } + } + } break; + default: + KU_UNREACHABLE; + } + if (numActiveNodes) { + sharedState->frontierPair.setActiveNodesForNextIter(); + } +} + +void VertexComputeTask::run() { + FrontierMorsel morsel; + auto graph = info.graph; + auto localVc = info.vc.copy(); + if (info.hasPropertiesToScan()) { + auto scanState = graph->prepareVertexScan(info.tableEntry, info.propertiesToScan); + while (sharedState->morselDispatcher.getNextRangeMorsel(morsel)) { + for (auto chunk : + graph->scanVertices(morsel.getBeginOffset(), morsel.getEndOffset(), *scanState)) { + localVc->vertexCompute(chunk); + } + } + } else { + while (sharedState->morselDispatcher.getNextRangeMorsel(morsel)) { + localVc->vertexCompute(morsel.getBeginOffset(), morsel.getEndOffset(), + info.tableEntry->getTableID()); + } + } +} + +void VertexComputeTask::runSparse() { + KU_ASSERT(!info.hasPropertiesToScan()); + auto localVc = info.vc.copy(); + 
localVc->vertexCompute(info.tableEntry->getTableID()); +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/gds_utils.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/gds_utils.cpp new file mode 100644 index 0000000000..789fe38557 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/gds_utils.cpp @@ -0,0 +1,173 @@ +#include "function/gds/gds_utils.h" + +#include "binder/expression/property_expression.h" +#include "catalog/catalog_entry/table_catalog_entry.h" +#include "common/exception/interrupt.h" +#include "common/task_system/task_scheduler.h" +#include "function/gds/gds_task.h" +#include "graph/graph.h" +#include "graph/graph_entry.h" +#include "main/client_context.h" +#include "transaction/transaction.h" + +using namespace lbug::common; +using namespace lbug::catalog; +using namespace lbug::function; +using namespace lbug::processor; +using namespace lbug::graph; + +namespace lbug { +namespace function { + +static std::shared_ptr getFrontierTask(const main::ClientContext* context, + const GraphRelInfo& relInfo, Graph* graph, ExtendDirection extendDirection, + const GDSComputeState& computeState, std::vector propertiesToScan) { + auto info = FrontierTaskInfo(relInfo.srcTableID, relInfo.dstTableID, relInfo.relGroupEntry, + graph, extendDirection, *computeState.edgeCompute, std::move(propertiesToScan)); + computeState.beginFrontierCompute(info.getBoundTableID(), info.getNbrTableID()); + auto numThreads = context->getMaxNumThreadForExec(); + auto sharedState = + std::make_shared(numThreads, *computeState.frontierPair); + auto maxOffset = + graph->getMaxOffset(transaction::Transaction::Get(*context), info.getBoundTableID()); + sharedState->morselDispatcher.init(maxOffset); + return std::make_shared(numThreads, info, sharedState); +} + +static void scheduleFrontierTask(ExecutionContext* context, const GraphRelInfo& relInfo, + Graph* graph, ExtendDirection extendDirection, const GDSComputeState& computeState, + std::vector propertiesToScan) { + auto clientContext = context->clientContext; + auto task = getFrontierTask(clientContext, relInfo, graph, extendDirection, computeState, + std::move(propertiesToScan)); + if (computeState.frontierPair->getState() == GDSDensityState::SPARSE) { + task->runSparse(); + return; + } + + // GDSUtils::runFrontiersUntilConvergence is called from a GDSCall operator, which is + // already executed by a worker thread Tm of the task scheduler. So this function is + // executed by Tm. Because this function will monitor the task and wait for it to + // complete, running GDS algorithms effectively "loses" Tm. This can even lead to the + // query processor to halt, e.g., if there is a single worker thread in the system, and + // more generally decrease the number of worker threads by 1. Therefore, we instruct + // scheduleTaskAndWaitOrError to start a new thread by passing true as the last + // argument. 
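+ // Concretely: with a single worker thread, that thread would otherwise block here waiting on the
+ // frontier task it just submitted, leaving no thread free to execute the task's morsels.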
+ TaskScheduler::Get(*context->clientContext) + ->scheduleTaskAndWaitOrError(task, context, true /* launchNewWorkerThread */); +} + +static void runOneIteration(ExecutionContext* context, Graph* graph, + ExtendDirection extendDirection, const GDSComputeState& compState, + const std::vector& propertiesToScan) { + for (auto info : graph->getGraphEntry()->nodeInfos) { + for (const auto& relInfo : graph->getRelInfos(info.entry->getTableID())) { + if (context->clientContext->interrupted()) { + throw InterruptException{}; + } + switch (extendDirection) { + case ExtendDirection::FWD: { + scheduleFrontierTask(context, relInfo, graph, ExtendDirection::FWD, compState, + propertiesToScan); + } break; + case ExtendDirection::BWD: { + scheduleFrontierTask(context, relInfo, graph, ExtendDirection::BWD, compState, + propertiesToScan); + } break; + case ExtendDirection::BOTH: { + scheduleFrontierTask(context, relInfo, graph, ExtendDirection::FWD, compState, + propertiesToScan); + scheduleFrontierTask(context, relInfo, graph, ExtendDirection::BWD, compState, + propertiesToScan); + } break; + default: + KU_UNREACHABLE; + } + } + } +} + +void GDSUtils::runAlgorithmEdgeCompute(ExecutionContext* context, GDSComputeState& compState, + Graph* graph, ExtendDirection extendDirection, uint64_t maxIteration) { + auto frontierPair = compState.frontierPair.get(); + while (frontierPair->continueNextIter(maxIteration)) { + frontierPair->beginNewIteration(); + runOneIteration(context, graph, extendDirection, compState, {}); + } +} + +void GDSUtils::runFTSEdgeCompute(ExecutionContext* context, GDSComputeState& compState, + Graph* graph, ExtendDirection extendDirection, + const std::vector& propertiesToScan) { + compState.frontierPair->beginNewIteration(); + runOneIteration(context, graph, extendDirection, compState, propertiesToScan); +} + +void GDSUtils::runRecursiveJoinEdgeCompute(ExecutionContext* context, GDSComputeState& compState, + Graph* graph, ExtendDirection extendDirection, uint64_t maxIteration, + NodeOffsetMaskMap* outputNodeMask, const std::vector& propertiesToScan) { + auto frontierPair = compState.frontierPair.get(); + compState.edgeCompute->resetSingleThreadState(); + while (frontierPair->continueNextIter(maxIteration)) { + frontierPair->beginNewIteration(); + if (outputNodeMask != nullptr && compState.edgeCompute->terminate(*outputNodeMask)) { + break; + } + runOneIteration(context, graph, extendDirection, compState, propertiesToScan); + if (frontierPair->needSwitchToDense( + context->clientContext->getClientConfig()->sparseFrontierThreshold)) { + compState.switchToDense(context, graph); + } + } +} + +static void runVertexComputeInternal(const TableCatalogEntry* currentEntry, + GDSDensityState densityState, const Graph* graph, std::shared_ptr task, + ExecutionContext* context) { + if (densityState == GDSDensityState::SPARSE) { + task->runSparse(); + return; + } + auto maxOffset = graph->getMaxOffset(transaction::Transaction::Get(*context->clientContext), + currentEntry->getTableID()); + auto sharedState = task->getSharedState(); + sharedState->morselDispatcher.init(maxOffset); + TaskScheduler::Get(*context->clientContext) + ->scheduleTaskAndWaitOrError(task, context, true /* launchNewWorkerThread */); +} + +void GDSUtils::runVertexCompute(ExecutionContext* context, GDSDensityState densityState, + Graph* graph, VertexCompute& vc, const std::vector& propertiesToScan) { + auto maxThreads = context->clientContext->getMaxNumThreadForExec(); + auto sharedState = std::make_shared(maxThreads); + for (const auto& 
nodeInfo : graph->getGraphEntry()->nodeInfos) { + auto entry = nodeInfo.entry; + if (!vc.beginOnTable(entry->getTableID())) { + continue; + } + auto info = VertexComputeTaskInfo(vc, graph, entry, propertiesToScan); + auto task = std::make_shared(maxThreads, info, sharedState); + runVertexComputeInternal(entry, densityState, graph, task, context); + } +} + +void GDSUtils::runVertexCompute(ExecutionContext* context, GDSDensityState densityState, + Graph* graph, VertexCompute& vc) { + runVertexCompute(context, densityState, graph, vc, std::vector{}); +} + +void GDSUtils::runVertexCompute(ExecutionContext* context, GDSDensityState densityState, + Graph* graph, VertexCompute& vc, TableCatalogEntry* entry, + const std::vector& propertiesToScan) { + auto maxThreads = context->clientContext->getMaxNumThreadForExec(); + auto info = VertexComputeTaskInfo(vc, graph, entry, propertiesToScan); + auto sharedState = std::make_shared(maxThreads); + if (!vc.beginOnTable(entry->getTableID())) { + return; + } + auto task = std::make_shared(maxThreads, info, sharedState); + runVertexComputeInternal(entry, densityState, graph, task, context); +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/output_writer.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/output_writer.cpp new file mode 100644 index 0000000000..150d0db5a1 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/output_writer.cpp @@ -0,0 +1,396 @@ +#include "common/exception/interrupt.h" +#include "function/gds/rj_output_writer.h" +#include "main/client_context.h" +#include + +using namespace lbug::common; +using namespace lbug::processor; + +namespace lbug { +namespace function { + +RJOutputWriter::RJOutputWriter(main::ClientContext* context, NodeOffsetMaskMap* outputNodeMask, + nodeID_t sourceNodeID) + : context{context}, outputNodeMask{outputNodeMask}, sourceNodeID_{sourceNodeID} { + srcNodeIDVector = createVector(LogicalType::INTERNAL_ID()); + dstNodeIDVector = createVector(LogicalType::INTERNAL_ID()); + srcNodeIDVector->setValue(0, sourceNodeID); +} + +void RJOutputWriter::pinOutputNodeMask(table_id_t tableID) { + if (outputNodeMask != nullptr) { + outputNodeMask->pin(tableID); + } +} + +bool RJOutputWriter::inOutputNodeMask(common::offset_t offset) { + if (outputNodeMask == nullptr) { // No mask + return true; + } + auto mask = outputNodeMask->getPinnedMask(); + if (!mask->isEnabled()) { // No mask + return true; + } + return mask->isMasked(offset); +} + +std::unique_ptr RJOutputWriter::createVector(const LogicalType& type) { + auto vector = std::make_unique(type.copy(), storage::MemoryManager::Get(*context)); + vector->state = DataChunkState::getSingleValueDataChunkState(); + vectors.push_back(vector.get()); + return vector; +} + +PathsOutputWriter::PathsOutputWriter(main::ClientContext* context, + NodeOffsetMaskMap* outputNodeMask, nodeID_t sourceNodeID, PathsOutputWriterInfo info, + BaseBFSGraph& bfsGraph) + : RJOutputWriter{context, outputNodeMask, sourceNodeID}, info{info}, bfsGraph{bfsGraph} { + lengthVector = createVector(LogicalType::UINT16()); + if (info.writeEdgeDirection) { + directionVector = createVector(LogicalType::LIST(LogicalType::BOOL())); + } + if (info.writePath) { + pathNodeIDsVector = createVector(LogicalType::LIST(LogicalType::INTERNAL_ID())); + pathEdgeIDsVector = createVector(LogicalType::LIST(LogicalType::INTERNAL_ID())); + } +} + +static void addListEntry(ValueVector* vector, uint64_t length) { + vector->resetAuxiliaryBuffer(); + auto 
entry = ListVector::addList(vector, length); + KU_ASSERT(entry.offset == 0); + vector->setValue(0, entry); +} + +static ParentList* getTop(const std::vector& path) { + return path[path.size() - 1]; +} + +void PathsOutputWriter::write(FactorizedTable& fTable, table_id_t tableID, LimitCounter* counter) { + auto& sparseGraph = bfsGraph.cast(); + for (auto& [offset, _] : sparseGraph.getCurrentData()) { + write(fTable, {offset, tableID}, counter); + } + if (info.lowerBound == 0 && sourceNodeID_.tableID == tableID) { + write(fTable, sourceNodeID_, counter); + } +} + +void PathsOutputWriter::write(FactorizedTable& fTable, nodeID_t dstNodeID, LimitCounter* counter) { + if (!inOutputNodeMask(dstNodeID.offset)) { + return; + } + dstNodeIDVector->setValue(0, dstNodeID); + writeInternal(fTable, dstNodeID, counter); +} + +void PathsOutputWriter::dfsFast(ParentList* firstParent, FactorizedTable& fTable, + LimitCounter* counter) { + std::vector curPath; + curPath.push_back(firstParent); + auto backtracking = false; + while (!curPath.empty()) { + if (context->interrupted()) { + throw InterruptException{}; + } + auto top = curPath[curPath.size() - 1]; + auto topNodeID = top->getNodeID(); + if (top->getIter() == 1) { + writePath(curPath); + fTable.append(vectors); + if (updateCounterAndTerminate(counter)) { + return; + } + backtracking = true; + } + if (backtracking) { + auto next = getTop(curPath)->getNextPtr(); + if (isNextViable(next, curPath)) { + curPath[curPath.size() - 1] = next; + backtracking = false; + } else { + curPath.pop_back(); + } + } else { + auto parent = bfsGraph.getParentListHead(topNodeID); + while (parent->getIter() != top->getIter() - 1) { + parent = parent->getNextPtr(); + } + curPath.push_back(parent); + backtracking = false; + } + } +} + +void PathsOutputWriter::dfsSlow(ParentList* firstParent, FactorizedTable& fTable, + LimitCounter* counter) { + std::vector curPath; + curPath.push_back(firstParent); + auto backtracking = false; + while (!curPath.empty()) { + if (context->interrupted()) { + throw InterruptException{}; + } + if (getTop(curPath)->getIter() == 1) { + writePath(curPath); + fTable.append(vectors); + if (updateCounterAndTerminate(counter)) { + return; + } + backtracking = true; + } + if (backtracking) { + auto next = getTop(curPath)->getNextPtr(); + while (true) { + if (!isNextViable(next, curPath)) { + curPath.pop_back(); + break; + } + // Further check next against path node mask (predicate). + if (!checkPathNodeMask(next) || !checkReplaceTopSemantic(curPath, next)) { + next = next->getNextPtr(); + continue; + } + // Next is a valid path element. Push into stack and switch to forward track. + curPath[curPath.size() - 1] = next; + backtracking = false; + break; + } + } else { + auto top = getTop(curPath); + auto parent = bfsGraph.getParentListHead(top->getNodeID()); + while (true) { + if (parent == nullptr) { + // No more forward tracking candidates. Switch to backward tracking. + backtracking = true; + break; + } + if (parent->getIter() == top->getIter() - 1 && checkPathNodeMask(parent) && + checkAppendSemantic(curPath, parent)) { + // A forward tracking candidate should decrease the iteration by one and also + // pass node predicate checking. 
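// Minimal illustrative sketch of the enumeration idea behind dfsFast/dfsSlow, written
// against toy types (Parent, ParentStore) assumed for the example rather than the
// library's ParentList/BaseBFSGraph API: every hop moves to a parent entry recorded
// exactly one BFS iteration earlier, and a path is complete once an entry recorded at
// iteration 1 is reached. The production code performs this walk iteratively with an
// explicit stack plus a backtracking flag instead of recursing.
#include <cstdint>
#include <unordered_map>
#include <vector>

struct Parent {
    uint32_t node;  // parent node id
    uint16_t iter;  // BFS iteration at which this parent entry was recorded
};
using ParentStore = std::unordered_map<uint32_t, std::vector<Parent>>;

void enumerate(const ParentStore& parents, uint32_t node, uint16_t iter,
    std::vector<uint32_t>& path, std::vector<std::vector<uint32_t>>& out) {
    if (iter == 1) {                        // adjacent to the source: the path is complete
        out.push_back(path);
        return;
    }
    auto it = parents.find(node);
    if (it == parents.end()) {
        return;
    }
    for (const auto& p : it->second) {
        if (p.iter != iter - 1) {           // only parents one BFS level closer to the source
            continue;
        }
        path.push_back(p.node);
        enumerate(parents, p.node, p.iter, path, out);
        path.pop_back();
    }
}

int main() {
    // Source 0; nodes 1 and 2 reached at iteration 1; destination 3 reached at iteration 2.
    ParentStore parents{{3, {{1, 2}, {2, 2}}}, {1, {{0, 1}}}, {2, {{0, 1}}}};
    std::vector<uint32_t> path{3};
    std::vector<std::vector<uint32_t>> out;
    for (const auto& p : parents.at(3)) {   // start from every parent entry of the destination
        path.push_back(p.node);
        enumerate(parents, p.node, p.iter, path, out);
        path.pop_back();
    }
    // out == {{3, 1, 0}, {3, 2, 0}}: both two-hop paths from destination 3 back to source 0.
}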
+ curPath.push_back(parent); + backtracking = false; + break; + } + parent = parent->getNextPtr(); + } + } + } +} + +bool PathsOutputWriter::updateCounterAndTerminate(LimitCounter* counter) { + if (counter != nullptr) { + counter->increase(1); + return counter->exceedLimit(); + } + return false; +} + +ParentList* PathsOutputWriter::findFirstParent(offset_t dstOffset) const { + auto result = bfsGraph.getParentListHead(dstOffset); + if (!info.hasNodeMask() && info.semantic == PathSemantic::WALK) { + // Fast path when there is no node predicate or semantic check + return result; + } + while (result) { + // A valid parent should + // (1) satisfies path node semi mask (i.e. path node predicate) + // (2) since first parent has the largest iteration number which decides path length, we + // also need to check if path length is greater than lower bound. + if (checkPathNodeMask(result) && result->getIter() >= info.lowerBound) { + break; + } + result = result->getNextPtr(); + } + return result; +} + +// This code checks if we should switch from backtracking to forward-tracking, i.e., +// moving forward in the DFS logic to find paths. We switch from backtracking if: +bool PathsOutputWriter::isNextViable(ParentList* next, const std::vector& path) const { + if (next == nullptr) { + return false; + } + auto nextIter = next->getIter(); + // (1) if this is the first element in the stack (curPath.size() == 1), i.e., we + // are enumerating the parents of the destination, then we should switch to + // forward-tracking if the next parent has visited the destination at a length + // that's greater than or equal to the lower bound of the recursive join. Otherwise, we would + // enumerate paths that are smaller than the lower bound from the start element, so we can stop + // here.; OR + if (path.size() == 1) { + return nextIter >= info.lowerBound; + } + // (2) if this is not the first element in the stack, i.e., then we should switch + // to forward tracking only if the next parent of the top node in the stack has the + // same iter value as the current parent. That's because the levels/iter need to + // decrease by 1 each time we add a new node in the stack. 
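// Minimal illustrative sketch of the viability rule spelled out in (1) and (2) above,
// using plain values assumed for the example instead of ParentList pointers: stackSize
// is the current DFS stack depth, topIter the iteration of the entry being replaced,
// nextIter the iteration of the candidate, and lowerBound the requested minimum length.
#include <cstddef>
#include <cstdint>

bool nextCandidateIsViable(std::size_t stackSize, uint16_t topIter, uint16_t nextIter,
    uint16_t lowerBound) {
    if (stackSize == 1) {
        // Enumerating parents of the destination: this candidate decides the path length,
        // so it must not fall below the lower bound.
        return nextIter >= lowerBound;
    }
    // Deeper in the stack the candidate replaces the current top, so it has to sit on the
    // same BFS level; otherwise levels would no longer decrease by one per hop.
    return nextIter == topIter;
}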
+ if (nextIter == getTop(path)->getIter()) { + return true; + } + return false; +} + +bool PathsOutputWriter::checkPathNodeMask(ParentList* element) const { + if (!info.hasNodeMask() || element->getIter() == 1) { + return true; + } + return info.pathNodeMask->valid(element->getNodeID()); +} + +bool PathsOutputWriter::checkAppendSemantic(const std::vector& path, + ParentList* candidate) const { + switch (info.semantic) { + case PathSemantic::WALK: + return true; + case PathSemantic::TRAIL: + return isAppendTrail(path, candidate); + case PathSemantic::ACYCLIC: + return isAppendAcyclic(path, candidate); + default: + KU_UNREACHABLE; + } +} + +bool PathsOutputWriter::checkReplaceTopSemantic(const std::vector& path, + ParentList* candidate) const { + switch (info.semantic) { + case PathSemantic::WALK: + return true; + case PathSemantic::TRAIL: + return isReplaceTopTrail(path, candidate); + case PathSemantic::ACYCLIC: + return isReplaceTopAcyclic(path, candidate); + default: + KU_UNREACHABLE; + } +} + +bool PathsOutputWriter::isAppendTrail(const std::vector& path, + ParentList* candidate) const { + for (auto& element : path) { + if (candidate->getEdgeID() == element->getEdgeID()) { + return false; + } + } + return true; +} + +bool PathsOutputWriter::isAppendAcyclic(const std::vector& path, + ParentList* candidate) const { + // Skip dst for semantic checking + for (auto i = 1u; i < path.size() - 1; ++i) { + if (candidate->getNodeID() == path[i]->getNodeID()) { + return false; + } + } + return true; +} + +bool PathsOutputWriter::isReplaceTopTrail(const std::vector& path, + ParentList* candidate) const { + for (auto i = 0u; i < path.size() - 1; ++i) { + if (candidate->getEdgeID() == path[i]->getEdgeID()) { + return false; + } + } + return true; +} + +bool PathsOutputWriter::isReplaceTopAcyclic(const std::vector& path, + ParentList* candidate) const { + // Skip dst for semantic checking + for (auto i = 1u; i < path.size() - 1; ++i) { + if (candidate->getNodeID() == path[i]->getNodeID()) { + return false; + } + } + return true; +} + +static void setLength(ValueVector* vector, uint16_t length) { + KU_ASSERT(vector->dataType.getLogicalTypeID() == LogicalTypeID::UINT16); + vector->setValue(0, length); +} + +void PathsOutputWriter::beginWritePath(idx_t length) const { + KU_ASSERT(info.writePath); + addListEntry(pathNodeIDsVector.get(), length > 1 ? length - 1 : 0); + addListEntry(pathEdgeIDsVector.get(), length); + if (info.writeEdgeDirection) { + addListEntry(directionVector.get(), length); + } +} + +void PathsOutputWriter::writePath(const std::vector& path) const { + setLength(lengthVector.get(), path.size()); + if (!info.writePath) { + return; + } + beginWritePath(path.size()); + if (path.size() == 0) { + return; + } + if (!info.flipPath) { + // By default, write path in reverse direction because we append ParentList from dst to src. + writePathBwd(path); + } else { + // Write path in original direction because computation started from dst node. + // We want to present result in src->dst order. 
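// Minimal illustrative sketch of the direction handling above, using a plain vector of
// node ids assumed for the example instead of ParentList entries. The stacked chain runs
// from the written node back toward the BFS source; presenting it source-first is a
// reversal (what writePathBwd does slot by slot), while flipPath keeps the stored order
// because the BFS itself started from the query's destination. Note also the slot counts
// set up by beginWritePath: a chain of N entries yields N edges but only N - 1
// intermediate nodes, since source and destination are reported in separate columns.
#include <algorithm>
#include <cstdint>
#include <vector>

std::vector<uint32_t> presentPath(std::vector<uint32_t> storedChain, bool flipPath) {
    if (!flipPath) {
        // BFS ran from the query source, so reverse to present the path src -> dst.
        std::reverse(storedChain.begin(), storedChain.end());
    }
    // With flipPath set, the BFS ran from the query destination and the stored order is
    // already src -> dst, so it is kept as is.
    return storedChain;
}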
+ writePathFwd(path); + } +} + +void PathsOutputWriter::writePathFwd(const std::vector& path) const { + auto length = path.size(); + for (auto i = 0u; i < length - 1; i++) { + auto p = path[i]; + addNode(p->getNodeID(), i); + addEdge(p->getEdgeID(), p->isFwdEdge(), i); + } + auto lastPathElement = path[length - 1]; + addEdge(lastPathElement->getEdgeID(), lastPathElement->isFwdEdge(), length - 1); +} + +void PathsOutputWriter::writePathBwd(const std::vector& path) const { + auto length = path.size(); + for (auto i = 1u; i < length; i++) { + auto p = path[length - 1 - i]; + addNode(p->getNodeID(), i - 1); + addEdge(p->getEdgeID(), p->isFwdEdge(), i); + } + auto lastPathElement = path[length - 1]; + addEdge(lastPathElement->getEdgeID(), lastPathElement->isFwdEdge(), 0); +} + +void PathsOutputWriter::addEdge(relID_t edgeID, bool fwdEdge, sel_t pos) const { + ListVector::getDataVector(pathEdgeIDsVector.get())->setValue(pos, edgeID); + if (info.writeEdgeDirection) { + ListVector::getDataVector(directionVector.get())->setValue(pos, fwdEdge); + } +} + +void PathsOutputWriter::addNode(nodeID_t nodeID, sel_t pos) const { + ListVector::getDataVector(pathNodeIDsVector.get())->setValue(pos, nodeID); +} + +void SPPathsOutputWriter::writeInternal(FactorizedTable& fTable, nodeID_t dstNodeID, + LimitCounter* counter) { + auto firstParent = findFirstParent(dstNodeID.offset); + if (firstParent == nullptr) { + return; + } + if (dstNodeID == sourceNodeID_) { // Avoid writing source + KU_ASSERT(firstParent->getIter() == FRONTIER_INITIAL_VISITED); + return; + } + if (!info.hasNodeMask() && info.semantic == PathSemantic::WALK) { + dfsFast(firstParent, fTable, counter); + return; + } + dfsSlow(firstParent, fTable, counter); +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/rec_joins.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/rec_joins.cpp new file mode 100644 index 0000000000..d01ee757a3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/rec_joins.cpp @@ -0,0 +1,35 @@ +#include "function/gds/rec_joins.h" + +namespace lbug { +namespace function { + +RJBindData::RJBindData(const RJBindData& other) { + graphEntry = other.graphEntry.copy(); + nodeInput = other.nodeInput; + nodeOutput = other.nodeOutput; + lowerBound = other.lowerBound; + upperBound = other.upperBound; + semantic = other.semantic; + extendDirection = other.extendDirection; + flipPath = other.flipPath; + writePath = other.writePath; + directionExpr = other.directionExpr; + lengthExpr = other.lengthExpr; + pathNodeIDsExpr = other.pathNodeIDsExpr; + pathEdgeIDsExpr = other.pathEdgeIDsExpr; + weightPropertyExpr = other.weightPropertyExpr; + weightOutputExpr = other.weightOutputExpr; +} + +PathsOutputWriterInfo RJBindData::getPathWriterInfo() const { + auto info = PathsOutputWriterInfo(); + info.semantic = semantic; + info.lowerBound = lowerBound; + info.flipPath = flipPath; + info.writeEdgeDirection = writePath && extendDirection == common::ExtendDirection::BOTH; + info.writePath = writePath; + return info; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/ssp_destinations.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/ssp_destinations.cpp new file mode 100644 index 0000000000..1c6e788095 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/ssp_destinations.cpp @@ -0,0 +1,127 @@ +#include "binder/expression/node_expression.h" +#include "function/gds/gds_function_collection.h" 
+#include "function/gds/rec_joins.h" +#include "processor/execution_context.h" + +using namespace lbug::binder; +using namespace lbug::common; +using namespace lbug::processor; +using namespace lbug::graph; +using namespace lbug::main; + +namespace lbug { +namespace function { + +class SSPDestinationsOutputWriter : public RJOutputWriter { +public: + SSPDestinationsOutputWriter(ClientContext* context, NodeOffsetMaskMap* outputNodeMask, + nodeID_t sourceNodeID, Frontier* frontier) + : RJOutputWriter{context, outputNodeMask, sourceNodeID}, frontier{frontier} { + lengthVector = createVector(LogicalType::UINT16()); + } + + void beginWritingInternal(table_id_t tableID) override { frontier->pinTableID(tableID); } + + void write(FactorizedTable& fTable, table_id_t tableID, LimitCounter* counter) override { + auto& sparseFrontier = frontier->cast(); + for (auto [offset, _] : sparseFrontier.getCurrentData()) { + write(fTable, {offset, tableID}, counter); + } + } + + void write(FactorizedTable& fTable, nodeID_t dstNodeID, LimitCounter* counter) override { + if (!inOutputNodeMask(dstNodeID.offset)) { // Skip dst if it not is in scope. + return; + } + if (sourceNodeID_ == dstNodeID) { // Skip writing source node. + return; + } + auto iter = frontier->getIteration(dstNodeID.offset); + if (iter == FRONTIER_UNVISITED) { // Skip if dst is not visited. + return; + } + dstNodeIDVector->setValue(0, dstNodeID); + lengthVector->setValue(0, iter); + fTable.append(vectors); + if (counter != nullptr) { + counter->increase(1); + } + } + + std::unique_ptr copy() override { + return std::make_unique(context, outputNodeMask, sourceNodeID_, + frontier); + } + +private: + Frontier* frontier; + std::unique_ptr lengthVector; +}; + +class SSPDestinationsEdgeCompute : public SPEdgeCompute { +public: + explicit SSPDestinationsEdgeCompute(SPFrontierPair* frontierPair) + : SPEdgeCompute{frontierPair} {}; + + std::vector edgeCompute(nodeID_t, NbrScanState::Chunk& resultChunk, bool) override { + std::vector activeNodes; + resultChunk.forEach([&](auto neighbors, auto, auto i) { + auto nbrNode = neighbors[i]; + auto iter = frontierPair->getNextFrontierValue(nbrNode.offset); + if (iter == FRONTIER_UNVISITED) { + activeNodes.push_back(nbrNode); + } + }); + return activeNodes; + } + + std::unique_ptr copy() override { + return std::make_unique(frontierPair); + } +}; + +// Single shortest path algorithm. Only destinations are tracked (reachability query). +// If there are multiple path to a destination. Only one of the path is tracked. 
+class SingleSPDestinationsAlgorithm : public RJAlgorithm { +public: + std::string getFunctionName() const override { return SingleSPDestinationsFunction::name; } + + expression_vector getResultColumns(const RJBindData& bindData) const override { + expression_vector columns; + columns.push_back(bindData.nodeInput->constCast().getInternalID()); + columns.push_back(bindData.nodeOutput->constCast().getInternalID()); + columns.push_back(bindData.lengthExpr); + return columns; + } + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } + +private: + std::unique_ptr getComputeState(ExecutionContext* context, const RJBindData&, + RecursiveExtendSharedState* sharedState) override { + auto graph = sharedState->graph.get(); + auto denseFrontier = DenseFrontier::getUninitializedFrontier(context, graph); + auto frontierPair = std::make_unique(std::move(denseFrontier)); + auto edgeCompute = std::make_unique(frontierPair.get()); + auto auxiliaryState = std::make_unique(); + return std::make_unique(std::move(frontierPair), std::move(edgeCompute), + std::move(auxiliaryState)); + } + + std::unique_ptr getOutputWriter(ExecutionContext* context, const RJBindData&, + GDSComputeState& computeState, nodeID_t sourceNodeID, + RecursiveExtendSharedState* sharedState) override { + auto frontier = computeState.frontierPair->ptrCast()->getFrontier(); + return std::make_unique(context->clientContext, + sharedState->getOutputNodeMaskMap(), sourceNodeID, frontier); + } +}; + +std::unique_ptr SingleSPDestinationsFunction::getAlgorithm() { + return std::make_unique(); +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/ssp_paths.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/ssp_paths.cpp new file mode 100644 index 0000000000..96ecbe9ff8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/ssp_paths.cpp @@ -0,0 +1,109 @@ +#include "binder/expression/node_expression.h" +#include "function/gds/auxiliary_state/path_auxiliary_state.h" +#include "function/gds/gds_function_collection.h" +#include "function/gds/rec_joins.h" +// #include "main/client_context.h" +#include "processor/execution_context.h" +#include "transaction/transaction.h" + +using namespace lbug::binder; +using namespace lbug::common; +using namespace lbug::processor; + +namespace lbug { +namespace function { + +class SSPPathsEdgeCompute : public SPEdgeCompute { +public: + SSPPathsEdgeCompute(SPFrontierPair* frontierPair, BFSGraphManager* bfsGraphManager) + : SPEdgeCompute{frontierPair}, bfsGraphManager{bfsGraphManager} { + block = bfsGraphManager->getCurrentGraph()->addNewBlock(); + } + + std::vector edgeCompute(nodeID_t boundNodeID, graph::NbrScanState::Chunk& resultChunk, + bool isFwd) override { + std::vector activeNodes; + resultChunk.forEach([&](auto neighbors, auto propertyVectors, auto i) { + auto nbrNodeID = neighbors[i]; + auto iter = frontierPair->getNextFrontierValue(nbrNodeID.offset); + if (iter == FRONTIER_UNVISITED) { + if (!block->hasSpace()) { + block = bfsGraphManager->getCurrentGraph()->addNewBlock(); + } + auto edgeID = propertyVectors[0]->template getValue(i); + bfsGraphManager->getCurrentGraph()->addSingleParent(frontierPair->getCurrentIter(), + boundNodeID, edgeID, nbrNodeID, isFwd, block); + activeNodes.push_back(nbrNodeID); + } + }); + return activeNodes; + } + + std::unique_ptr copy() override { + return std::make_unique(frontierPair, bfsGraphManager); + } + +private: + BFSGraphManager* bfsGraphManager; + ObjectBlock* block = 
nullptr; +}; + +// Single shortest path algorithm. Paths are tracked. +// If there are multiple path to a destination. Only one of the path is tracked. +class SingleSPPathsAlgorithm : public RJAlgorithm { +public: + std::string getFunctionName() const override { return SingleSPPathsFunction::name; } + + expression_vector getResultColumns(const RJBindData& bindData) const override { + expression_vector columns; + columns.push_back(bindData.nodeInput->constCast().getInternalID()); + columns.push_back(bindData.nodeOutput->constCast().getInternalID()); + columns.push_back(bindData.lengthExpr); + if (bindData.extendDirection == ExtendDirection::BOTH) { + columns.push_back(bindData.directionExpr); + } + columns.push_back(bindData.pathNodeIDsExpr); + columns.push_back(bindData.pathEdgeIDsExpr); + return columns; + } + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } + +private: + std::unique_ptr getComputeState(ExecutionContext* context, const RJBindData&, + RecursiveExtendSharedState* sharedState) override { + auto clientContext = context->clientContext; + auto frontier = DenseFrontier::getUninitializedFrontier(context, sharedState->graph.get()); + auto frontierPair = std::make_unique(std::move(frontier)); + auto transaction = transaction::Transaction::Get(*context->clientContext); + auto bfsGraph = + std::make_unique(sharedState->graph->getMaxOffsetMap(transaction), + storage::MemoryManager::Get(*clientContext)); + auto edgeCompute = + std::make_unique(frontierPair.get(), bfsGraph.get()); + auto auxiliaryState = std::make_unique(std::move(bfsGraph)); + return std::make_unique(std::move(frontierPair), std::move(edgeCompute), + std::move(auxiliaryState)); + } + + std::unique_ptr getOutputWriter(ExecutionContext* context, + const RJBindData& bindData, GDSComputeState& computeState, nodeID_t sourceNodeID, + RecursiveExtendSharedState* sharedState) override { + auto bfsGraph = computeState.auxiliaryState->ptrCast() + ->getBFSGraphManager() + ->getCurrentGraph(); + auto writerInfo = bindData.getPathWriterInfo(); + writerInfo.pathNodeMask = sharedState->getPathNodeMaskMap(); + return std::make_unique(context->clientContext, + sharedState->getOutputNodeMaskMap(), sourceNodeID, writerInfo, *bfsGraph); + } +}; + +std::unique_ptr SingleSPPathsFunction::getAlgorithm() { + return std::make_unique(); +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/variable_length_path.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/variable_length_path.cpp new file mode 100644 index 0000000000..c3ad260bd8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/variable_length_path.cpp @@ -0,0 +1,155 @@ +#include "binder/expression/node_expression.h" +#include "function/gds/auxiliary_state/path_auxiliary_state.h" +#include "function/gds/gds_function_collection.h" +#include "function/gds/rec_joins.h" +#include "graph/graph.h" +#include "processor/execution_context.h" +#include "transaction/transaction.h" + +using namespace lbug::binder; +using namespace lbug::common; +using namespace lbug::processor; + +namespace lbug { +namespace function { + +class VarLenPathsOutputWriter final : public PathsOutputWriter { +public: + VarLenPathsOutputWriter(main::ClientContext* context, NodeOffsetMaskMap* outputNodeMask, + nodeID_t sourceNodeID, PathsOutputWriterInfo info, BaseBFSGraph& bfsGraph) + : PathsOutputWriter{context, outputNodeMask, sourceNodeID, info, bfsGraph} {} + + void writeInternal(FactorizedTable& fTable, 
nodeID_t dstNodeID, + LimitCounter* counter) override { + auto firstParent = findFirstParent(dstNodeID.offset); + if (firstParent == nullptr) { + if (sourceNodeID_ == dstNodeID && info.lowerBound == 0) { + // We still output a path from src to src if required path length is 0. + // e.g. MATCH (a)-[e*0..]-> + // "a" needs to be in the output + writePath({}); + fTable.append(vectors); + updateCounterAndTerminate(counter); + } + return; + } + if (firstParent->getIter() < info.lowerBound) { // Skip if lower bound is not met. + return; + } + if (!info.hasNodeMask() && info.semantic == PathSemantic::WALK) { + dfsFast(firstParent, fTable, counter); + return; + } + dfsSlow(firstParent, fTable, counter); + } + + std::unique_ptr copy() override { + return std::make_unique(context, outputNodeMask, sourceNodeID_, + info, bfsGraph); + } +}; + +class VarLenJoinsEdgeCompute : public EdgeCompute { +public: + VarLenJoinsEdgeCompute(DenseSparseDynamicFrontierPair* frontierPair, + BFSGraphManager* bfsGraphManager) + : frontierPair{frontierPair}, bfsGraphManager{bfsGraphManager} { + block = bfsGraphManager->getCurrentGraph()->addNewBlock(); + }; + + std::vector edgeCompute(nodeID_t boundNodeID, graph::NbrScanState::Chunk& chunk, + bool fwdEdge) override { + std::vector activeNodes; + chunk.forEach([&](auto neighbors, auto propertyVectors, auto i) { + // We should always update the nbrID in variable length joins + auto nbrNodeID = neighbors[i]; + auto edgeID = propertyVectors[0]->template getValue(i); + if (!block->hasSpace()) { + block = bfsGraphManager->getCurrentGraph()->addNewBlock(); + } + bfsGraphManager->getCurrentGraph()->addParent(frontierPair->getCurrentIter(), + boundNodeID, edgeID, nbrNodeID, fwdEdge, block); + activeNodes.push_back(nbrNodeID); + }); + return activeNodes; + } + + std::unique_ptr copy() override { + return std::make_unique(frontierPair, bfsGraphManager); + } + +private: + DenseSparseDynamicFrontierPair* frontierPair; + BFSGraphManager* bfsGraphManager; + ObjectBlock* block = nullptr; +}; + +/** + * Algorithm for parallel all shortest paths computation, so all shortest paths from a source to + * is returned for each destination. If paths are not returned, multiplicities indicate the + * number of paths to each destination. 
+ */ +class VarLenJoinsAlgorithm final : public RJAlgorithm { +public: + std::string getFunctionName() const override { return VarLenJoinsFunction::name; } + + // return srcNodeID, dstNodeID, length, [direction, pathNodeIDs, pathEdgeIDs] (if track path) + expression_vector getResultColumns(const RJBindData& bindData) const override { + expression_vector columns; + columns.push_back(bindData.nodeInput->constCast().getInternalID()); + columns.push_back(bindData.nodeOutput->constCast().getInternalID()); + columns.push_back(bindData.lengthExpr); + if (bindData.writePath) { + if (bindData.extendDirection == ExtendDirection::BOTH) { + columns.push_back(bindData.directionExpr); + } + columns.push_back(bindData.pathNodeIDsExpr); + columns.push_back(bindData.pathEdgeIDsExpr); + } + return columns; + } + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } + +private: + std::unique_ptr getComputeState(ExecutionContext* context, const RJBindData&, + RecursiveExtendSharedState* sharedState) override { + auto clientContext = context->clientContext; + auto transaction = transaction::Transaction::Get(*clientContext); + auto bfsGraph = + std::make_unique(sharedState->graph->getMaxOffsetMap(transaction), + storage::MemoryManager::Get(*clientContext)); + auto currentDenseFrontier = + DenseFrontier::getUninitializedFrontier(context, sharedState->graph.get()); + auto nextDenseFrontier = + DenseFrontier::getUninitializedFrontier(context, sharedState->graph.get()); + auto frontierPair = std::make_unique( + std::move(currentDenseFrontier), std::move(nextDenseFrontier)); + auto edgeCompute = + std::make_unique(frontierPair.get(), bfsGraph.get()); + auto auxiliaryState = std::make_unique(std::move(bfsGraph)); + return std::make_unique(std::move(frontierPair), std::move(edgeCompute), + std::move(auxiliaryState)); + } + + std::unique_ptr getOutputWriter(ExecutionContext* context, + const RJBindData& bindData, GDSComputeState& computeState, common::nodeID_t sourceNodeID, + processor::RecursiveExtendSharedState* sharedState) override { + auto bfsGraph = computeState.auxiliaryState->ptrCast() + ->getBFSGraphManager() + ->getCurrentGraph(); + auto writerInfo = bindData.getPathWriterInfo(); + writerInfo.pathNodeMask = sharedState->getPathNodeMaskMap(); + return std::make_unique(context->clientContext, + sharedState->getOutputNodeMaskMap(), sourceNodeID, writerInfo, *bfsGraph); + } +}; + +std::unique_ptr VarLenJoinsFunction::getAlgorithm() { + return std::make_unique(); +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/wsp_destinations.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/wsp_destinations.cpp new file mode 100644 index 0000000000..da44fac0dc --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/wsp_destinations.cpp @@ -0,0 +1,340 @@ +#include "binder/expression/node_expression.h" +#include "function/gds/gds_function_collection.h" +#include "function/gds/rec_joins.h" +#include "function/gds/weight_utils.h" +#include "processor/execution_context.h" +#include "transaction/transaction.h" + +using namespace lbug::binder; +using namespace lbug::common; +using namespace lbug::storage; +using namespace lbug::processor; + +namespace lbug { +namespace function { + +class Costs { +public: + virtual ~Costs() = default; + + virtual void pinTableID(table_id_t tableID) = 0; + + virtual void setCost(offset_t offset, double cost) = 0; + virtual bool tryReplaceWithMinCost(offset_t offset, double newCost) = 0; + + 
virtual double getCost(offset_t offset) = 0; +}; + +class SparseCostsReference : public Costs { +public: + explicit SparseCostsReference(GDSSpareObjectManager& sparseObjects) + : sparseObjects{sparseObjects} {} + + void pinTableID(table_id_t tableID) override { curData = sparseObjects.getData(tableID); } + + void setCost(offset_t offset, double cost) override { + KU_ASSERT(curData != nullptr); + if (curData->contains(offset)) { + curData->at(offset) = cost; + } else { + curData->insert({offset, cost}); + } + } + + bool tryReplaceWithMinCost(offset_t offset, double newCost) override { + auto curCost = getCost(offset); + if (newCost < curCost) { + setCost(offset, newCost); + return true; + } + return false; + } + + double getCost(offset_t offset) override { + KU_ASSERT(curData != nullptr); + if (curData->contains(offset)) { + return curData->at(offset); + } + return std::numeric_limits::max(); + } + +private: + std::unordered_map* curData = nullptr; + GDSSpareObjectManager& sparseObjects; +}; + +class DenseCostsReference : public Costs { +public: + explicit DenseCostsReference(GDSDenseObjectManager>& denseObjects) + : denseObjects{denseObjects} {} + + void pinTableID(table_id_t tableID) override { curData = denseObjects.getData(tableID); } + + void setCost(offset_t offset, double cost) override { + KU_ASSERT(curData != nullptr); + curData[offset].store(cost, std::memory_order_relaxed); + } + + bool tryReplaceWithMinCost(offset_t offset, double newCost) override { + auto curCost = getCost(offset); + while (newCost < curCost) { + if (curData[offset].compare_exchange_strong(curCost, newCost)) { + return true; + } + } + return false; + } + + double getCost(offset_t offset) override { + KU_ASSERT(curData != nullptr); + return curData[offset].load(std::memory_order_relaxed); + } + +private: + table_id_map_t nodeMaxOffsetMap; + std::atomic* curData = nullptr; + GDSDenseObjectManager>& denseObjects; +}; + +class CostsPair { +public: + explicit CostsPair(const table_id_map_t& maxOffsetMap) + : maxOffsetMap{maxOffsetMap}, densityState{GDSDensityState::SPARSE}, + sparseObjects{maxOffsetMap} { + curSparseCosts = std::make_unique(sparseObjects); + nextSparseCosts = std::make_unique(sparseObjects); + denseObjects = GDSDenseObjectManager>(); + curDenseCosts = std::make_unique(denseObjects); + nextDenseCosts = std::make_unique(denseObjects); + } + + Costs* getCurrentCosts() { return curCosts; } + + void pinCurTableID(table_id_t tableID) { + switch (densityState) { + case GDSDensityState::SPARSE: { + curSparseCosts->pinTableID(tableID); + curCosts = curSparseCosts.get(); + } break; + case GDSDensityState::DENSE: { + curDenseCosts->pinTableID(tableID); + curCosts = curDenseCosts.get(); + } break; + default: + KU_UNREACHABLE; + } + } + + void pinNextTableID(table_id_t tableID) { + switch (densityState) { + case GDSDensityState::SPARSE: { + nextSparseCosts->pinTableID(tableID); + nextCosts = nextSparseCosts.get(); + } break; + case GDSDensityState::DENSE: { + nextDenseCosts->pinTableID(tableID); + nextCosts = nextDenseCosts.get(); + } break; + default: + KU_UNREACHABLE; + } + } + + // CAS update nbrOffset if new path from boundOffset has a smaller cost. 
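// Minimal illustrative sketch of the lock-free "replace with minimum" pattern used by
// tryReplaceWithMinCost above, on a bare std::atomic<double> assumed for the example.
// On CAS failure the observed value is refreshed, so the loop exits as soon as another
// thread has already stored an equal or smaller cost.
#include <atomic>
#include <iostream>

bool tryReplaceWithMin(std::atomic<double>& slot, double candidate) {
    double observed = slot.load(std::memory_order_relaxed);
    while (candidate < observed) {
        if (slot.compare_exchange_weak(observed, candidate)) {
            return true;   // this thread lowered the stored cost
        }
    }
    return false;          // an equal or smaller cost was already present
}

int main() {
    std::atomic<double> cost{10.0};
    std::cout << tryReplaceWithMin(cost, 7.5) << ' ' << cost.load() << '\n';  // 1 7.5
    std::cout << tryReplaceWithMin(cost, 9.0) << ' ' << cost.load() << '\n';  // 0 7.5
}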
+ bool update(offset_t boundOffset, offset_t nbrOffset, double val) { + KU_ASSERT(curCosts && nextCosts); + auto newCost = curCosts->getCost(boundOffset) + val; + return nextCosts->tryReplaceWithMinCost(nbrOffset, newCost); + } + + void switchToDense(ExecutionContext* context) { + KU_ASSERT(densityState == GDSDensityState::SPARSE); + densityState = GDSDensityState::DENSE; + auto mm = MemoryManager::Get(*context->clientContext); + for (auto& [tableID, maxOffset] : maxOffsetMap) { + denseObjects.allocate(tableID, maxOffset, mm); + auto data = denseObjects.getData(tableID); + for (auto i = 0u; i < maxOffset; i++) { + data[i].store(std::numeric_limits::max()); + } + } + for (auto& [tableID, map] : sparseObjects.getData()) { + auto data = denseObjects.getData(tableID); + for (auto& [offset, cost] : map) { + data[offset].store(cost); + } + } + } + +private: + table_id_map_t maxOffsetMap; + GDSDensityState densityState; + GDSSpareObjectManager sparseObjects; + std::unique_ptr curSparseCosts; + std::unique_ptr nextSparseCosts; + GDSDenseObjectManager> denseObjects; + std::unique_ptr curDenseCosts; + std::unique_ptr nextDenseCosts; + + Costs* curCosts = nullptr; + Costs* nextCosts = nullptr; +}; + +template +class WSPDestinationsEdgeCompute : public EdgeCompute { +public: + explicit WSPDestinationsEdgeCompute(CostsPair* costsPair) : costsPair{costsPair} {} + + std::vector edgeCompute(nodeID_t boundNodeID, graph::NbrScanState::Chunk& chunk, + bool) override { + std::vector result; + chunk.forEach([&](auto neighbors, auto propertyVectors, auto i) { + auto nbrNodeID = neighbors[i]; + auto weight = propertyVectors[0]->template getValue(i); + WeightUtils::checkWeight(WeightedSPDestinationsFunction::name, weight); + if (costsPair->update(boundNodeID.offset, nbrNodeID.offset, + static_cast(weight))) { + result.push_back(nbrNodeID); + } + }); + return result; + } + + std::unique_ptr copy() override { + return std::make_unique>(costsPair); + } + +private: + CostsPair* costsPair; +}; + +class WSPDestinationsAuxiliaryState : public GDSAuxiliaryState { +public: + explicit WSPDestinationsAuxiliaryState(std::unique_ptr costsPair) + : costsPair{std::move(costsPair)} {} + + Costs* getCosts() { return costsPair->getCurrentCosts(); } + + void initSource(nodeID_t sourceNodeID) override { + costsPair->pinCurTableID(sourceNodeID.tableID); + costsPair->getCurrentCosts()->setCost(sourceNodeID.offset, 0); + } + + void beginFrontierCompute(table_id_t fromTableID, table_id_t toTableID) override { + costsPair->pinCurTableID(fromTableID); + costsPair->pinNextTableID(toTableID); + } + + void switchToDense(ExecutionContext* context, graph::Graph*) override { + costsPair->switchToDense(context); + } + +private: + std::unique_ptr costsPair; +}; + +class WSPDestinationsOutputWriter : public RJOutputWriter { +public: + WSPDestinationsOutputWriter(main::ClientContext* context, NodeOffsetMaskMap* outputNodeMask, + nodeID_t sourceNodeID, Costs* costs, const table_id_map_t& maxOffsetMap) + : RJOutputWriter{context, outputNodeMask, sourceNodeID}, costs{costs}, + maxOffsetMap{maxOffsetMap} { + costVector = createVector(LogicalType::DOUBLE()); + } + + void beginWritingInternal(table_id_t tableID) override { costs->pinTableID(tableID); } + + void write(FactorizedTable& fTable, table_id_t tableID, LimitCounter* counter) override { + for (auto i = 0u; i < maxOffsetMap.at(tableID); ++i) { + write(fTable, {i, tableID}, counter); + } + } + + void write(FactorizedTable& fTable, nodeID_t dstNodeID, LimitCounter* counter) override { + if 
(!inOutputNodeMask(dstNodeID.offset)) { // Skip dst if it not is in scope. + return; + } + if (dstNodeID == sourceNodeID_) { // Skip writing source node. + return; + } + dstNodeIDVector->setValue(0, dstNodeID); + auto cost = costs->getCost(dstNodeID.offset); + if (cost == std::numeric_limits::max()) { // Skip if dst is not visited. + return; + } + costVector->setValue(0, cost); + fTable.append(vectors); + if (counter != nullptr) { + counter->increase(1); + } + } + + std::unique_ptr copy() override { + return std::make_unique(context, outputNodeMask, sourceNodeID_, + costs, maxOffsetMap); + } + +private: + Costs* costs; + std::unique_ptr costVector; + table_id_map_t maxOffsetMap; +}; + +class WeightedSPDestinationsAlgorithm : public RJAlgorithm { +public: + std::string getFunctionName() const override { return WeightedSPDestinationsFunction::name; } + + // return srcNodeID, dstNodeID, weight + expression_vector getResultColumns(const RJBindData& bindData) const override { + expression_vector columns; + columns.push_back(bindData.nodeInput->constCast().getInternalID()); + columns.push_back(bindData.nodeOutput->constCast().getInternalID()); + columns.push_back(bindData.weightOutputExpr); + return columns; + } + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } + +private: + std::unique_ptr getComputeState(ExecutionContext* context, + const RJBindData& bindData, RecursiveExtendSharedState* sharedState) override { + auto clientContext = context->clientContext; + auto graph = sharedState->graph.get(); + auto curDenseFrontier = DenseFrontier::getUninitializedFrontier(context, graph); + auto nextDenseFrontier = DenseFrontier::getUninitializedFrontier(context, graph); + auto frontierPair = std::make_unique( + std::move(curDenseFrontier), std::move(nextDenseFrontier)); + auto costsPair = std::make_unique( + graph->getMaxOffsetMap(transaction::Transaction::Get(*clientContext))); + auto costPairPtr = costsPair.get(); + auto auxiliaryState = std::make_unique(std::move(costsPair)); + std::unique_ptr gdsState; + WeightUtils::visit(WeightedSPDestinationsFunction::name, + bindData.weightPropertyExpr->getDataType(), [&](T) { + auto edgeCompute = std::make_unique>(costPairPtr); + gdsState = std::make_unique(std::move(frontierPair), + std::move(edgeCompute), std::move(auxiliaryState)); + }); + return gdsState; + } + + std::unique_ptr getOutputWriter(ExecutionContext* context, const RJBindData&, + GDSComputeState& computeState, nodeID_t sourceNodeID, + RecursiveExtendSharedState* sharedState) override { + auto costs = + computeState.auxiliaryState->ptrCast()->getCosts(); + auto clientContext = context->clientContext; + return std::make_unique(clientContext, + sharedState->getOutputNodeMaskMap(), sourceNodeID, costs, + sharedState->graph->getMaxOffsetMap(transaction::Transaction::Get(*clientContext))); + } +}; + +std::unique_ptr WeightedSPDestinationsFunction::getAlgorithm() { + return std::make_unique(); +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/wsp_paths.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/wsp_paths.cpp new file mode 100644 index 0000000000..03eb73e015 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/gds/wsp_paths.cpp @@ -0,0 +1,155 @@ +#include "binder/binder.h" +#include "function/gds/auxiliary_state/path_auxiliary_state.h" +#include "function/gds/gds_function_collection.h" +#include "function/gds/rec_joins.h" +#include "function/gds/weight_utils.h" +#include 
"processor/execution_context.h" +#include "transaction/transaction.h" + +using namespace lbug::common; +using namespace lbug::storage; +using namespace lbug::processor; +using namespace lbug::binder; + +namespace lbug { +namespace function { + +template +class WSPPathsEdgeCompute : public EdgeCompute { +public: + explicit WSPPathsEdgeCompute(BFSGraphManager* bfsGraphManager) + : bfsGraphManager{bfsGraphManager} { + block = bfsGraphManager->getCurrentGraph()->addNewBlock(); + } + + std::vector edgeCompute(nodeID_t boundNodeID, graph::NbrScanState::Chunk& chunk, + bool fwdEdge) override { + std::vector result; + chunk.forEach([&](auto neighbors, auto propertyVectors, auto i) { + auto nbrNodeID = neighbors[i]; + auto edgeID = propertyVectors[0]->template getValue(i); + auto weight = propertyVectors[1]->template getValue(i); + WeightUtils::checkWeight(WeightedSPPathsFunction::name, weight); + if (!block->hasSpace()) { + block = bfsGraphManager->getCurrentGraph()->addNewBlock(); + } + if (bfsGraphManager->getCurrentGraph()->tryAddSingleParentWithWeight(boundNodeID, + edgeID, nbrNodeID, fwdEdge, static_cast(weight), block)) { + result.push_back(nbrNodeID); + } + }); + return result; + } + + std::unique_ptr copy() override { + return std::make_unique>(bfsGraphManager); + } + +private: + BFSGraphManager* bfsGraphManager; + ObjectBlock* block = nullptr; +}; + +class WSPPathsOutputWriter : public PathsOutputWriter { +public: + WSPPathsOutputWriter(main::ClientContext* context, NodeOffsetMaskMap* outputNodeMask, + nodeID_t sourceNodeID, PathsOutputWriterInfo info, BaseBFSGraph& bfsGraph) + : PathsOutputWriter{context, outputNodeMask, sourceNodeID, info, bfsGraph} { + costVector = createVector(LogicalType::DOUBLE()); + } + + void writeInternal(FactorizedTable& fTable, nodeID_t dstNodeID, + LimitCounter* counter) override { + if (dstNodeID == sourceNodeID_) { // Skip writing source node. + return; + } + auto parent = bfsGraph.getParentListHead(dstNodeID.offset); + if (parent == nullptr) { // Skip if dst is not visited. 
+ return; + } + costVector->setValue(0, parent->getCost()); + std::vector curPath; + curPath.push_back(parent); + while (parent->getCost() != 0) { + parent = bfsGraph.getParentListHead(parent->getNodeID()); + curPath.push_back(parent); + } + curPath.pop_back(); + writePath(curPath); + fTable.append(vectors); + updateCounterAndTerminate(counter); + } + + std::unique_ptr copy() override { + return std::make_unique(context, outputNodeMask, sourceNodeID_, info, + bfsGraph); + } + +private: + std::unique_ptr costVector; +}; + +class WeightedSPPathsAlgorithm : public RJAlgorithm { +public: + std::string getFunctionName() const override { return WeightedSPPathsFunction::name; } + + // return srcNodeID, dstNodeID, length, [direction], pathNodeIDs, pathEdgeIDs, weight + binder::expression_vector getResultColumns(const RJBindData& bindData) const override { + expression_vector columns; + columns.push_back(bindData.nodeInput->constCast().getInternalID()); + columns.push_back(bindData.nodeOutput->constCast().getInternalID()); + columns.push_back(bindData.lengthExpr); + if (bindData.extendDirection == ExtendDirection::BOTH) { + columns.push_back(bindData.directionExpr); + } + columns.push_back(bindData.pathNodeIDsExpr); + columns.push_back(bindData.pathEdgeIDsExpr); + columns.push_back(bindData.weightOutputExpr); + return columns; + } + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } + +private: + std::unique_ptr getComputeState(ExecutionContext* context, + const RJBindData& bindData, RecursiveExtendSharedState* sharedState) override { + auto clientContext = context->clientContext; + auto graph = sharedState->graph.get(); + auto curDenseFrontier = DenseFrontier::getUninitializedFrontier(context, graph); + auto nextDenseFrontier = DenseFrontier::getUninitializedFrontier(context, graph); + auto frontierPair = std::make_unique( + std::move(curDenseFrontier), std::move(nextDenseFrontier)); + auto bfsGraph = std::make_unique( + sharedState->graph->getMaxOffsetMap(transaction::Transaction::Get(*clientContext)), + MemoryManager::Get(*clientContext)); + std::unique_ptr gdsState; + WeightUtils::visit(WeightedSPPathsFunction::name, + bindData.weightPropertyExpr->getDataType(), [&](T) { + auto edgeCompute = std::make_unique>(bfsGraph.get()); + auto auxiliaryState = std::make_unique(std::move(bfsGraph)); + gdsState = std::make_unique(std::move(frontierPair), + std::move(edgeCompute), std::move(auxiliaryState)); + }); + return gdsState; + } + + std::unique_ptr getOutputWriter(ExecutionContext* context, + const RJBindData& bindData, GDSComputeState& computeState, nodeID_t sourceNodeID, + RecursiveExtendSharedState* sharedState) override { + auto bfsGraph = computeState.auxiliaryState->ptrCast() + ->getBFSGraphManager() + ->getCurrentGraph(); + auto writerInfo = bindData.getPathWriterInfo(); + return std::make_unique(context->clientContext, + sharedState->getOutputNodeMaskMap(), sourceNodeID, writerInfo, *bfsGraph); + } +}; + +std::unique_ptr WeightedSPPathsFunction::getAlgorithm() { + return std::make_unique(); +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/internal_id/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/function/internal_id/CMakeLists.txt new file mode 100644 index 0000000000..40d38818c3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/internal_id/CMakeLists.txt @@ -0,0 +1,7 @@ +add_library(lbug_function_internal_id + OBJECT + internal_id_creation_function.cpp) + +set(ALL_OBJECT_FILES + 
${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/internal_id/internal_id_creation_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/internal_id/internal_id_creation_function.cpp new file mode 100644 index 0000000000..62f865d0c5 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/internal_id/internal_id_creation_function.cpp @@ -0,0 +1,37 @@ +#include "common/type_utils.h" +#include "common/types/types.h" +#include "function/internal_id/vector_internal_id_functions.h" +#include "function/scalar_function.h" + +namespace lbug { +namespace function { + +using namespace common; + +struct InternalIDCreation { + template + static void operation(T& tableID, T& offset, internalID_t& result) { + result = internalID_t((offset_t)offset, (offset_t)tableID); + } +}; + +function_set InternalIDCreationFunction::getFunctionSet() { + function_set result; + function::scalar_func_exec_t execFunc; + for (auto typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { + common::TypeUtils::visit( + common::LogicalType(typeID), + [&](T) { + execFunc = + ScalarFunction::BinaryExecFunction; + }, + [](auto) { KU_UNREACHABLE; }); + result.push_back(std::make_unique(name, + std::vector{typeID, typeID}, LogicalTypeID::INTERNAL_ID, + execFunc)); + } + return result; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/CMakeLists.txt new file mode 100644 index 0000000000..aa8250b934 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/CMakeLists.txt @@ -0,0 +1,33 @@ +add_library(lbug_list_function + OBJECT + list_agg_function.cpp + list_any_value_function.cpp + list_any.cpp + list_all.cpp + list_append_function.cpp + list_concat_function.cpp + list_contains_function.cpp + list_creation.cpp + list_distinct_function.cpp + list_extract_function.cpp + list_range_function.cpp + list_reverse_function.cpp + list_slice_function.cpp + list_sort_function.cpp + list_to_string_function.cpp + list_unique_function.cpp + list_prepend_function.cpp + list_position_function.cpp + list_transform.cpp + list_filter.cpp + list_function_utils.cpp + list_reduce.cpp + list_none.cpp + list_single.cpp + size_function.cpp + quantifier_functions.cpp + list_has_all.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_agg_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_agg_function.cpp new file mode 100644 index 0000000000..2b33b27071 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_agg_function.cpp @@ -0,0 +1,77 @@ +#include "common/exception/binder.h" +#include "common/type_utils.h" +#include "function/list/vector_list_functions.h" +#include "function/scalar_function.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +template +static std::unique_ptr bindFuncListAggr(const ScalarBindFuncInput& input) { + auto scalarFunction = input.definition->ptrCast(); + const auto& resultType = ListType::getChildType(input.arguments[0]->dataType); + TypeUtils::visit( + resultType, + [&scalarFunction](T) { + scalarFunction->execFunc = + ScalarFunction::UnaryExecNestedTypeFunction; + }, + [&input, &resultType](auto) { + throw BinderException(stringFormat("Unsupported inner data type for {}: {}", + input.definition->name, LogicalTypeUtils::toString(resultType.getLogicalTypeID()))); + }); + return 
FunctionBindData::getSimpleBindData(input.arguments, resultType); +} + +struct ListSum { + template + static void operation(common::list_entry_t& input, T& result, common::ValueVector& inputVector, + common::ValueVector& /*resultVector*/) { + auto inputDataVector = common::ListVector::getDataVector(&inputVector); + result = 0; + for (auto i = 0u; i < input.size; i++) { + if (inputDataVector->isNull(input.offset + i)) { + continue; + } + result += inputDataVector->getValue(input.offset + i); + } + } +}; + +function_set ListSumFunction::getFunctionSet() { + function_set result; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::LIST}, LogicalTypeID::INT64); + function->bindFunc = bindFuncListAggr; + result.push_back(std::move(function)); + return result; +} + +struct ListProduct { + template + static void operation(common::list_entry_t& input, T& result, common::ValueVector& inputVector, + common::ValueVector& /*resultVector*/) { + auto inputDataVector = common::ListVector::getDataVector(&inputVector); + result = 1; + for (auto i = 0u; i < input.size; i++) { + if (inputDataVector->isNull(input.offset + i)) { + continue; + } + result *= inputDataVector->getValue(input.offset + i); + } + } +}; + +function_set ListProductFunction::getFunctionSet() { + function_set result; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::LIST}, LogicalTypeID::INT64); + function->bindFunc = bindFuncListAggr; + result.push_back(std::move(function)); + return result; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_all.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_all.cpp new file mode 100644 index 0000000000..9141c7b3bc --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_all.cpp @@ -0,0 +1,26 @@ +#include "function/list/vector_list_functions.h" +#include "function/scalar_function.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +bool allHandler(uint64_t numSelectedValues, uint64_t originalSize) { + return numSelectedValues == originalSize; +} + +function_set ListAllFunction::getFunctionSet() { + function_set result; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::LIST, LogicalTypeID::ANY}, LogicalTypeID::BOOL, + std::bind(execQuantifierFunc, allHandler, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4, std::placeholders::_5)); + function->bindFunc = bindQuantifierFunc; + function->isListLambda = true; + result.push_back(std::move(function)); + return result; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_any.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_any.cpp new file mode 100644 index 0000000000..aceb140526 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_any.cpp @@ -0,0 +1,26 @@ +#include "function/list/vector_list_functions.h" +#include "function/scalar_function.h" + +namespace lbug { +namespace function { + +using namespace lbug::common; + +bool anyHandler(uint64_t numSelectedValues, uint64_t /*originalSize*/) { + return numSelectedValues > 0; +} + +function_set ListAnyFunction::getFunctionSet() { + function_set result; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::LIST, LogicalTypeID::ANY}, LogicalTypeID::BOOL, + std::bind(execQuantifierFunc, anyHandler, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, 
std::placeholders::_4, std::placeholders::_5)); + function->bindFunc = bindQuantifierFunc; + function->isListLambda = true; + result.push_back(std::move(function)); + return result; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_any_value_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_any_value_function.cpp new file mode 100644 index 0000000000..03bac730bf --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_any_value_function.cpp @@ -0,0 +1,49 @@ +#include "common/type_utils.h" +#include "function/list/vector_list_functions.h" +#include "function/scalar_function.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +struct ListAnyValue { + template + static void operation(common::list_entry_t& input, T& result, common::ValueVector& inputVector, + common::ValueVector& resultVector) { + auto inputValues = common::ListVector::getListValues(&inputVector, input); + auto inputDataVector = common::ListVector::getDataVector(&inputVector); + auto numBytesPerValue = inputDataVector->getNumBytesPerValue(); + + for (auto i = 0u; i < input.size; i++) { + if (!(inputDataVector->isNull(input.offset + i))) { + resultVector.copyFromVectorData(reinterpret_cast(&result), + inputDataVector, inputValues); + break; + } + inputValues += numBytesPerValue; + } + } +}; + +static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { + auto scalarFunction = ku_dynamic_cast(input.definition); + const auto& resultType = ListType::getChildType(input.arguments[0]->dataType); + TypeUtils::visit(resultType.getPhysicalType(), [&scalarFunction](T) { + scalarFunction->execFunc = + ScalarFunction::UnaryExecNestedTypeFunction; + }); + return FunctionBindData::getSimpleBindData(input.arguments, resultType); +} + +function_set ListAnyValueFunction::getFunctionSet() { + function_set result; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::LIST}, LogicalTypeID::ANY); + function->bindFunc = bindFunc; + result.push_back(std::move(function)); + return result; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_append_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_append_function.cpp new file mode 100644 index 0000000000..281bc048ff --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_append_function.cpp @@ -0,0 +1,69 @@ +#include "common/exception/binder.h" +#include "common/exception/message.h" +#include "common/type_utils.h" +#include "common/types/types.h" +#include "function/function.h" +#include "function/list/functions/list_function_utils.h" +#include "function/list/vector_list_functions.h" +#include "function/scalar_function.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +struct ListAppend { + template + static void operation(common::list_entry_t& listEntry, T& value, common::list_entry_t& result, + common::ValueVector& listVector, common::ValueVector& valueVector, + common::ValueVector& resultVector) { + result = common::ListVector::addList(&resultVector, listEntry.size + 1); + auto listDataVector = common::ListVector::getDataVector(&listVector); + auto listPos = listEntry.offset; + auto resultDataVector = common::ListVector::getDataVector(&resultVector); + auto resultPos = result.offset; + for (auto i = 0u; i < listEntry.size; i++) { + resultDataVector->copyFromVectorData(resultPos++, listDataVector, listPos++); + 
} + resultDataVector->copyFromVectorData( + resultDataVector->getData() + resultPos * resultDataVector->getNumBytesPerValue(), + &valueVector, reinterpret_cast(&value)); + } +}; + +static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { + + std::vector types; + types.push_back(input.arguments[0]->getDataType().copy()); + types.push_back(input.arguments[1]->getDataType().copy()); + + using resolver = ListTypeResolver; + ListFunctionUtils::resolveTypes(input, types, resolver::anyEmpty, resolver::anyEmpty, + resolver::anyEmpty, resolver::finalResolver, resolver::bothNull, resolver::leftNull, + resolver::rightNull, resolver::finalResolver); + + if (types[0].getLogicalTypeID() != LogicalTypeID::ANY && + types[1] != ListType::getChildType(types[0])) { + throw BinderException(ExceptionMessage::listFunctionIncompatibleChildrenType( + ListAppendFunction::name, types[0].toString(), types[1].toString())); + } + + auto scalarFunction = input.definition->ptrCast(); + TypeUtils::visit(types[1].getPhysicalType(), [&scalarFunction](T) { + scalarFunction->execFunc = + ScalarFunction::BinaryExecListStructFunction; + }); + return std::make_unique(std::move(types), types[0].copy()); +} + +function_set ListAppendFunction::getFunctionSet() { + function_set result; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::LIST, LogicalTypeID::ANY}, LogicalTypeID::LIST); + function->bindFunc = bindFunc; + result.push_back(std::move(function)); + return result; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_concat_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_concat_function.cpp new file mode 100644 index 0000000000..9605234a2c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_concat_function.cpp @@ -0,0 +1,62 @@ +#include "function/list/functions/list_concat_function.h" + +#include "common/exception/binder.h" +#include "common/exception/message.h" +#include "function/list/functions/list_function_utils.h" +#include "function/list/vector_list_functions.h" +#include "function/scalar_function.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +void ListConcat::operation(common::list_entry_t& left, common::list_entry_t& right, + common::list_entry_t& result, common::ValueVector& leftVector, common::ValueVector& rightVector, + common::ValueVector& resultVector) { + result = common::ListVector::addList(&resultVector, left.size + right.size); + auto resultDataVector = common::ListVector::getDataVector(&resultVector); + auto resultPos = result.offset; + auto leftDataVector = common::ListVector::getDataVector(&leftVector); + auto leftPos = left.offset; + for (auto i = 0u; i < left.size; i++) { + resultDataVector->copyFromVectorData(resultPos++, leftDataVector, leftPos++); + } + auto rightDataVector = common::ListVector::getDataVector(&rightVector); + auto rightPos = right.offset; + for (auto i = 0u; i < right.size; i++) { + resultDataVector->copyFromVectorData(resultPos++, rightDataVector, rightPos++); + } +} + +std::unique_ptr ListConcatFunction::bindFunc(const ScalarBindFuncInput& input) { + std::vector types; + types.push_back(input.arguments[0]->getDataType().copy()); + types.push_back(input.arguments[1]->getDataType().copy()); + + using resolver = ListTypeResolver; + ListFunctionUtils::resolveTypes(input, types, resolver::leftEmpty, resolver::leftEmpty, + resolver::rightEmpty, resolver::finalResolver, resolver::bothNull, resolver::leftEmpty, + 
resolver::rightEmpty, resolver::finalResolver); + + if (types[0] != types[1]) { + throw BinderException(ExceptionMessage::listFunctionIncompatibleChildrenType(name, + types[0].toString(), types[1].toString())); + } + return std::make_unique(std::move(types), types[0].copy()); +} + +function_set ListConcatFunction::getFunctionSet() { + function_set result; + auto execFunc = ScalarFunction::BinaryExecListStructFunction; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::LIST, LogicalTypeID::LIST}, LogicalTypeID::LIST, + execFunc); + function->bindFunc = bindFunc; + result.push_back(std::move(function)); + return result; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_contains_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_contains_function.cpp new file mode 100644 index 0000000000..24064ce53e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_contains_function.cpp @@ -0,0 +1,68 @@ +#include "binder/expression/expression_util.h" +#include "common/exception/binder.h" +#include "common/type_utils.h" +#include "function/list/functions/list_position_function.h" +#include "function/list/vector_list_functions.h" +#include "function/scalar_function.h" + +using namespace lbug::common; +using namespace lbug::binder; + +namespace lbug { +namespace function { + +struct ListContains { + template + static void operation(common::list_entry_t& list, T& element, uint8_t& result, + common::ValueVector& listVector, common::ValueVector& elementVector, + common::ValueVector& resultVector) { + int64_t pos = 0; + ListPosition::operation(list, element, pos, listVector, elementVector, resultVector); + result = (pos != 0); + } +}; + +static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { + auto scalarFunction = input.definition->ptrCast(); + // for list_contains(list, input), we expect input and list child have the same type, if list + // is empty, we use in the input type. Otherwise, we use list child type because casting list + // is more expensive. 
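+    // Illustrative resolutions of the branch below (literals are hypothetical):
+    //   list_contains([1, 2, 3], 2)  -> child type INT64, parameters (INT64[], INT64)
+    //   list_contains([], 'a')       -> empty list, child type taken from the element: STRING
+    //   list_contains([], NULL)      -> child type still ANY, so it falls back to STRING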
+ std::vector paramTypes; + LogicalType childType; + auto listExpr = input.arguments[0]; + auto elementExpr = input.arguments[1]; + if (ExpressionUtil::isEmptyList(*listExpr)) { + childType = elementExpr->getDataType().copy(); + } else { + auto& listChildType = ListType::getChildType(listExpr->getDataType()); + auto& elementType = elementExpr->getDataType(); + if (!LogicalTypeUtils::tryGetMaxLogicalType(listChildType, elementType, childType)) { + throw BinderException( + stringFormat("Cannot compare {} and {} in list_contains function.", + listChildType.toString(), elementType.toString())); + } + } + if (childType.getLogicalTypeID() == LogicalTypeID::ANY) { + childType = LogicalType::STRING(); + } + auto listType = LogicalType::LIST(childType.copy()); + paramTypes.push_back(listType.copy()); + paramTypes.push_back(childType.copy()); + TypeUtils::visit(childType.getPhysicalType(), [&scalarFunction](T) { + scalarFunction->execFunc = + ScalarFunction::BinaryExecListStructFunction; + }); + return std::make_unique(std::move(paramTypes), LogicalType::BOOL()); +} + +function_set ListContainsFunction::getFunctionSet() { + function_set result; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::LIST, LogicalTypeID::ANY}, LogicalTypeID::BOOL); + function->bindFunc = bindFunc; + result.push_back(std::move(function)); + return result; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_creation.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_creation.cpp new file mode 100644 index 0000000000..a163a7d7be --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_creation.cpp @@ -0,0 +1,56 @@ +#include "binder/expression/expression_util.h" +#include "function/list/vector_list_functions.h" +#include "function/scalar_function.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +void ListCreationFunction::execFunc( + const std::vector>& parameters, + const std::vector& parameterSelVectors, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* /*dataPtr*/) { + result.resetAuxiliaryBuffer(); + for (auto selectedPos = 0u; selectedPos < resultSelVector->getSelSize(); ++selectedPos) { + auto pos = (*resultSelVector)[selectedPos]; + auto resultEntry = ListVector::addList(&result, parameters.size()); + result.setValue(pos, resultEntry); + auto resultDataVector = ListVector::getDataVector(&result); + auto resultPos = resultEntry.offset; + for (auto i = 0u; i < parameters.size(); i++) { + const auto& parameter = parameters[i]; + const auto& parameterSelVector = *parameterSelVectors[i]; + auto paramPos = parameter->state->isFlat() ? 
parameterSelVector[0] : pos; + resultDataVector->copyFromVectorData(resultPos++, parameter.get(), paramPos); + } + } +} + +static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { + LogicalType combinedType(LogicalTypeID::ANY); + binder::ExpressionUtil::tryCombineDataType(input.arguments, combinedType); + if (combinedType.getLogicalTypeID() == LogicalTypeID::ANY) { + combinedType = LogicalType::INT64(); + } + auto resultType = LogicalType::LIST(combinedType.copy()); + auto bindData = std::make_unique(std::move(resultType)); + for (auto& _ : input.arguments) { + (void)_; + bindData->paramTypes.push_back(combinedType.copy()); + } + return bindData; +} + +function_set ListCreationFunction::getFunctionSet() { + function_set result; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::ANY}, LogicalTypeID::LIST, execFunc); + function->bindFunc = bindFunc; + function->isVarLength = true; + result.push_back(std::move(function)); + return result; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_distinct_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_distinct_function.cpp new file mode 100644 index 0000000000..84e1b70cfa --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_distinct_function.cpp @@ -0,0 +1,43 @@ +#include "function/list/functions/list_unique_function.h" +#include "function/list/vector_list_functions.h" +#include "function/scalar_function.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +struct ListDistinct { + static void operation(common::list_entry_t& input, common::list_entry_t& result, + common::ValueVector& inputVector, common::ValueVector& resultVector) { + auto numUniqueValues = ListUnique::appendListElementsToValueSet(input, inputVector); + result = common::ListVector::addList(&resultVector, numUniqueValues); + auto resultDataVector = common::ListVector::getDataVector(&resultVector); + auto resultDataVectorBuffer = + common::ListVector::getListValuesWithOffset(&resultVector, result, 0 /* offset */); + ListUnique::appendListElementsToValueSet(input, inputVector, nullptr, + [&resultDataVector, &resultDataVectorBuffer](common::ValueVector& dataVector, + uint64_t pos) -> void { + resultDataVector->copyFromVectorData(resultDataVectorBuffer, &dataVector, + dataVector.getData() + pos * dataVector.getNumBytesPerValue()); + resultDataVectorBuffer += dataVector.getNumBytesPerValue(); + }); + } +}; + +static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { + return FunctionBindData::getSimpleBindData(input.arguments, input.arguments[0]->getDataType()); +} + +function_set ListDistinctFunction::getFunctionSet() { + function_set result; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::LIST}, LogicalTypeID::LIST, + ScalarFunction::UnaryExecNestedTypeFunction); + function->bindFunc = bindFunc; + result.push_back(std::move(function)); + return result; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_extract_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_extract_function.cpp new file mode 100644 index 0000000000..1efe243719 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_extract_function.cpp @@ -0,0 +1,55 @@ +#include "function/list/functions/list_extract_function.h" + +#include "function/list/vector_list_functions.h" +#include "function/scalar_function.h" + +using 
namespace lbug::common; + +namespace lbug { +namespace function { + +template +static void BinaryExecListExtractFunction( + const std::vector>& params, + const std::vector& paramSelVectors, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* dataPtr = nullptr) { + KU_ASSERT(params.size() == 2); + BinaryFunctionExecutor::executeSwitch(*params[0], paramSelVectors[0], *params[1], + paramSelVectors[1], result, resultSelVector, dataPtr); +} + +static std::unique_ptr ListExtractBindFunc(const ScalarBindFuncInput& input) { + const auto& resultType = ListType::getChildType(input.arguments[0]->dataType); + auto scalarFunction = input.definition->ptrCast(); + TypeUtils::visit(resultType.getPhysicalType(), [&scalarFunction](T) { + scalarFunction->execFunc = + BinaryExecListExtractFunction; + }); + std::vector paramTypes; + paramTypes.push_back(input.arguments[0]->getDataType().copy()); + paramTypes.push_back(LogicalType(input.definition->parameterTypeIDs[1])); + return std::make_unique(std::move(paramTypes), resultType.copy()); +} + +function_set ListExtractFunction::getFunctionSet() { + function_set result; + std::unique_ptr func; + func = std::make_unique(name, + std::vector{LogicalTypeID::LIST, LogicalTypeID::INT64}, LogicalTypeID::ANY); + func->bindFunc = ListExtractBindFunc; + result.push_back(std::move(func)); + func = std::make_unique(name, + std::vector{LogicalTypeID::STRING, LogicalTypeID::INT64}, + LogicalTypeID::STRING, + ScalarFunction::BinaryExecFunction); + result.push_back(std::move(func)); + func = std::make_unique(name, + std::vector{LogicalTypeID::ARRAY, LogicalTypeID::INT64}, LogicalTypeID::ANY); + func->bindFunc = ListExtractBindFunc; + result.push_back(std::move(func)); + return result; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_filter.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_filter.cpp new file mode 100644 index 0000000000..c0cf43dbbf --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_filter.cpp @@ -0,0 +1,134 @@ +#include "common/exception/binder.h" +#include "expression_evaluator/lambda_evaluator.h" +#include "expression_evaluator/list_slice_info.h" +#include "function/list/vector_list_functions.h" +#include "function/scalar_function.h" + +namespace lbug { +namespace function { + +using namespace lbug::common; + +static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { + if (input.arguments[1]->expressionType != ExpressionType::LAMBDA) { + throw BinderException(stringFormat( + "The second argument of LIST_FILTER should be a lambda expression but got {}.", + ExpressionTypeUtil::toString(input.arguments[1]->expressionType))); + } + std::vector paramTypes; + paramTypes.push_back(input.arguments[0]->getDataType().copy()); + paramTypes.push_back(input.arguments[1]->getDataType().copy()); + if (input.arguments[1]->getDataType() != LogicalType::BOOL()) { + throw BinderException(stringFormat( + "{} requires the result type of lambda expression be BOOL.", ListFilterFunction::name)); + } + return std::make_unique(std::move(paramTypes), + LogicalType::LIST(ListType::getChildType(input.arguments[0]->getDataType()).copy())); +} + +static void constEvaluateFilterResult(const common::ValueVector& inputVector, + const common::SelectionVector& listInputSelVector, common::ValueVector& result, + common::SelectionVector& resultSelVector, const common::ValueVector& filterVector, + const common::SelectionVector& filterSelVector, 
evaluator::ListSliceInfo* sliceInfo) { + auto srcDataVector = ListVector::getDataVector(&inputVector); + auto dstDataVector = ListVector::getDataVector(&result); + KU_ASSERT(!filterVector.isNull(filterSelVector[0])); + auto filterResult = filterVector.getValue(filterSelVector[0]); + + // resolve data vector + if (filterResult) { + for (sel_t i = 0; i < sliceInfo->getSliceSize(); ++i) { + const auto [_, dataOffset] = sliceInfo->getPos(i); + dstDataVector->copyFromVectorData(dataOffset, srcDataVector, dataOffset); + dstDataVector->setNull(dataOffset, srcDataVector->isNull(dataOffset)); + } + } + + // resolve list entries + if (sliceInfo->done()) { + for (uint64_t i = 0; i < listInputSelVector.getSelSize(); ++i) { + list_entry_t dstListEntry; + auto srcListEntry = inputVector.getValue(listInputSelVector[i]); + if (filterResult) { + dstListEntry = srcListEntry; + } else { + dstListEntry = {srcListEntry.offset, 0}; + } + result.setValue(resultSelVector[i], dstListEntry); + } + } +} + +static void evaluateFilterResult(const common::ValueVector& inputVector, + common::ValueVector& result, const common::ValueVector& filterVector, + [[maybe_unused]] const common::SelectionVector& filterSelVector, + evaluator::ListSliceInfo* sliceInfo) { + KU_ASSERT(filterSelVector.isUnfiltered()); + auto srcDataVector = ListVector::getDataVector(&inputVector); + auto dstDataVector = ListVector::getDataVector(&result); + + auto& resultDataOffset = sliceInfo->getResultSliceOffset(); + for (sel_t i = 0; i < sliceInfo->getSliceSize(); ++i) { + const auto [listEntryPos, dataOffset] = sliceInfo->getPos(i); + const auto listEntry = inputVector.getValue(listEntryPos); + if (dataOffset == listEntry.offset) { + result.setValue(listEntryPos, list_entry_t{resultDataOffset, 0}); + } + if (filterVector.getValue(i) && !filterVector.isNull(i)) { + // TODO(Royi) make the output pos respect resultSelVector + auto& resultListEntry = result.getValue(listEntryPos); + dstDataVector->copyFromVectorData(resultDataOffset, srcDataVector, dataOffset); + dstDataVector->setNull(resultDataOffset, srcDataVector->isNull(dataOffset)); + ++resultListEntry.size; + ++resultDataOffset; + } + } +} + +static void execFunc(const std::vector>& input, + const std::vector& inputSelVectors, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* bindData) { + auto listLambdaBindData = reinterpret_cast(bindData); + auto* sliceInfo = listLambdaBindData->sliceInfo; + const auto& inputVector = *input[0]; + + auto savedParamStates = + sliceInfo->overrideAndSaveParamStates(listLambdaBindData->lambdaParamEvaluators); + + listLambdaBindData->rootEvaluator->evaluate(); + KU_ASSERT(input.size() == 2); + auto& listInputSelVector = *inputSelVectors[0]; + auto& filterVector = *input[1]; + auto& filterSelVector = *inputSelVectors[1]; + + if (listLambdaBindData->lambdaParamEvaluators.empty()) { + constEvaluateFilterResult(inputVector, listInputSelVector, result, *resultSelVector, + filterVector, filterSelVector, sliceInfo); + } else { + evaluateFilterResult(inputVector, result, filterVector, filterSelVector, sliceInfo); + } + + if (listLambdaBindData->sliceInfo->done()) { + for (idx_t i = 0; i < inputSelVectors[0]->getSelSize(); ++i) { + const auto pos = (*inputSelVectors[0])[i]; + result.setNull(pos, inputVector.isNull(i)); + } + } + + sliceInfo->restoreParamStates(listLambdaBindData->lambdaParamEvaluators, + std::move(savedParamStates)); +} + +function_set ListFilterFunction::getFunctionSet() { + function_set result; + auto function = 
std::make_unique(name, + std::vector{LogicalTypeID::LIST, LogicalTypeID::ANY}, LogicalTypeID::LIST, + execFunc); + function->bindFunc = bindFunc; + function->isListLambda = true; + result.push_back(std::move(function)); + return result; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_function_utils.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_function_utils.cpp new file mode 100644 index 0000000000..99365ea316 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_function_utils.cpp @@ -0,0 +1,106 @@ +#include "function/list/functions/list_function_utils.h" + +#include "binder/expression/expression_util.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +void ListTypeResolver::anyEmpty(std::vector& types, + common::LogicalType& targetType) { + targetType = types[1].copy(); + if (targetType.getLogicalTypeID() == LogicalTypeID::ANY) { + targetType = LogicalType(LogicalTypeID::INT64); + } +} + +void ListTypeResolver::bothNull(std::vector& types, + common::LogicalType& targetType) { + (void)types; + targetType = LogicalType::INT64(); +} + +void ListTypeResolver::leftNull(std::vector& types, + common::LogicalType& targetType) { + targetType = types[1].copy(); +} + +void ListTypeResolver::rightNull(std::vector& types, + common::LogicalType& targetType) { + targetType = ListType::getChildType(types[0]).copy(); +} + +void ListTypeResolver::finalResolver(std::vector& types, + common::LogicalType& targetType) { + types[0] = LogicalType::LIST(targetType.copy()); + types[1] = targetType.copy(); +} + +void ListTypeResolver::leftEmpty(std::vector& types, + common::LogicalType& targetType) { + targetType = types[1].copy(); +} +void ListTypeResolver::rightEmpty(std::vector& types, + common::LogicalType& targetType) { + targetType = types[0].copy(); +} +void ListTypeResolver::bothNull(std::vector& types, + common::LogicalType& targetType) { + (void)types; + targetType = LogicalType::LIST(LogicalType::INT64()); +} +void ListTypeResolver::finalResolver(std::vector& types, + common::LogicalType& targetType) { + types[0] = targetType.copy(); + types[1] = targetType.copy(); +} + +void ListFunctionUtils::resolveEmptyList(const ScalarBindFuncInput& input, + std::vector& types, type_resolver bothEmpty, type_resolver leftEmpty, + type_resolver rightEmpty, type_resolver finalEmptyListResolver) { + + auto isArg0Empty = binder::ExpressionUtil::isEmptyList(*input.arguments[0]); + auto isArg1Empty = binder::ExpressionUtil::isEmptyList(*input.arguments[1]); + LogicalType targetType; + if (isArg0Empty && isArg1Empty) { + bothEmpty(types, targetType); + } else if (isArg0Empty) { + leftEmpty(types, targetType); + } else if (isArg1Empty) { + rightEmpty(types, targetType); + } else { + return; + } + finalEmptyListResolver(types, targetType); +} + +void ListFunctionUtils::resolveNulls(std::vector& types, + type_resolver bothNull, type_resolver leftNull, type_resolver rightNull, + type_resolver finalNullParamResolver) { + auto isArg0AnyType = types[0].getLogicalTypeID() == common::LogicalTypeID::ANY; + auto isArg1AnyType = types[1].getLogicalTypeID() == common::LogicalTypeID::ANY; + + common::LogicalType targetType; + if (isArg0AnyType && isArg1AnyType) { + bothNull(types, targetType); + } else if (isArg0AnyType) { + leftNull(types, targetType); + } else if (isArg1AnyType) { + rightNull(types, targetType); + } else { + return; + } + finalNullParamResolver(types, targetType); +} + +void 
ListFunctionUtils::resolveTypes(const ScalarBindFuncInput& input, + std::vector& types, type_resolver bothEmpty, type_resolver leftEmpty, + type_resolver rightEmpty, type_resolver finalEmptyListResolver, type_resolver bothNull, + type_resolver leftNull, type_resolver rightNull, type_resolver finalNullParamResolver) { + resolveEmptyList(input, types, bothEmpty, leftEmpty, rightEmpty, finalEmptyListResolver); + resolveNulls(types, bothNull, leftNull, rightNull, finalNullParamResolver); +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_has_all.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_has_all.cpp new file mode 100644 index 0000000000..4944264341 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_has_all.cpp @@ -0,0 +1,68 @@ +#include "common/exception/binder.h" +#include "common/exception/message.h" +#include "common/type_utils.h" +#include "function/list/functions/list_position_function.h" +#include "function/list/vector_list_functions.h" +#include "function/scalar_function.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +struct ListHasAll { + static void operation(common::list_entry_t& left, common::list_entry_t& right, uint8_t& result, + common::ValueVector& leftVector, common::ValueVector& rightVector, + common::ValueVector& resultVector) { + int64_t pos = 0; + auto rightDataVector = ListVector::getDataVector(&rightVector); + result = true; + for (auto i = 0u; i < right.size; i++) { + common::TypeUtils::visit(ListType::getChildType(rightVector.dataType).getPhysicalType(), + [&](T) { + if (rightDataVector->isNull(right.offset + i)) { + return; + } + ListPosition::operation(left, + *(T*)ListVector::getListValuesWithOffset(&rightVector, right, i), pos, + leftVector, *ListVector::getDataVector(&rightVector), resultVector); + result = (pos != 0); + }); + if (!result) { + return; + } + } + } +}; + +std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { + std::vector types; + for (auto& arg : input.arguments) { + if (arg->dataType == LogicalType::ANY()) { + types.push_back(LogicalType::LIST(LogicalType::INT64())); + } else { + types.push_back(arg->dataType.copy()); + } + } + if (types[0] != types[1]) { + throw common::BinderException(ExceptionMessage::listFunctionIncompatibleChildrenType( + ListHasAllFunction::name, input.arguments[0]->getDataType().toString(), + input.arguments[1]->getDataType().toString())); + } + return std::make_unique(std::move(types), LogicalType::BOOL()); +} + +function_set ListHasAllFunction::getFunctionSet() { + function_set result; + auto execFunc = ScalarFunction::BinaryExecListStructFunction; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::LIST, LogicalTypeID::LIST}, LogicalTypeID::BOOL, + execFunc); + function->bindFunc = bindFunc; + result.push_back(std::move(function)); + return result; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_none.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_none.cpp new file mode 100644 index 0000000000..3f138b4b61 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_none.cpp @@ -0,0 +1,26 @@ +#include "function/list/vector_list_functions.h" +#include "function/scalar_function.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +bool noneHandler(uint64_t numSelectedValues, uint64_t /*originalSize*/) { + return numSelectedValues == 0; 
+} + +function_set ListNoneFunction::getFunctionSet() { + function_set result; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::LIST, LogicalTypeID::ANY}, LogicalTypeID::BOOL, + std::bind(execQuantifierFunc, noneHandler, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4, std::placeholders::_5)); + function->bindFunc = bindQuantifierFunc; + function->isListLambda = true; + result.push_back(std::move(function)); + return result; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_position_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_position_function.cpp new file mode 100644 index 0000000000..11b0d4f699 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_position_function.cpp @@ -0,0 +1,32 @@ +#include "function/list/functions/list_position_function.h" + +#include "common/type_utils.h" +#include "function/list/vector_list_functions.h" +#include "function/scalar_function.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { + auto scalarFunction = input.definition->ptrCast(); + TypeUtils::visit(input.arguments[1]->getDataType().getPhysicalType(), + [&scalarFunction](T) { + scalarFunction->execFunc = ScalarFunction::BinaryExecListStructFunction; + }); + return FunctionBindData::getSimpleBindData(input.arguments, LogicalType::INT64()); +} + +function_set ListPositionFunction::getFunctionSet() { + function_set result; + auto func = std::make_unique(name, + std::vector{LogicalTypeID::LIST, LogicalTypeID::ANY}, LogicalTypeID::INT64); + func->bindFunc = bindFunc; + result.push_back(std::move(func)); + return result; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_prepend_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_prepend_function.cpp new file mode 100644 index 0000000000..7b3668d561 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_prepend_function.cpp @@ -0,0 +1,69 @@ +#include "common/exception/binder.h" +#include "common/exception/message.h" +#include "common/type_utils.h" +#include "common/types/types.h" +#include "function/list/functions/list_function_utils.h" +#include "function/list/vector_list_functions.h" +#include "function/scalar_function.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +struct ListPrepend { + template + static void operation(common::list_entry_t& listEntry, T& value, common::list_entry_t& result, + common::ValueVector& listVector, common::ValueVector& valueVector, + common::ValueVector& resultVector) { + result = common::ListVector::addList(&resultVector, listEntry.size + 1); + auto resultDataVector = common::ListVector::getDataVector(&resultVector); + resultDataVector->copyFromVectorData( + common::ListVector::getListValues(&resultVector, result), &valueVector, + reinterpret_cast(&value)); + auto resultPos = result.offset + 1; + auto listDataVector = common::ListVector::getDataVector(&listVector); + auto listPos = listEntry.offset; + for (auto i = 0u; i < listEntry.size; i++) { + resultDataVector->copyFromVectorData(resultPos++, listDataVector, listPos++); + } + } +}; + +static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { + + std::vector types; + types.push_back(input.arguments[0]->getDataType().copy()); + 
types.push_back(input.arguments[1]->getDataType().copy()); + + using resolver = ListTypeResolver; + ListFunctionUtils::resolveTypes(input, types, resolver::anyEmpty, resolver::anyEmpty, + resolver::anyEmpty, resolver::finalResolver, resolver::bothNull, resolver::leftNull, + resolver::rightNull, resolver::finalResolver); + + if (types[0].getLogicalTypeID() != LogicalTypeID::ANY && + types[1] != ListType::getChildType(types[0])) { + throw BinderException(ExceptionMessage::listFunctionIncompatibleChildrenType( + ListAppendFunction::name, types[0].toString(), types[1].toString())); + } + + auto scalarFunction = input.definition->ptrCast(); + TypeUtils::visit(types[1].getPhysicalType(), [&scalarFunction](T) { + scalarFunction->execFunc = ScalarFunction::BinaryExecListStructFunction; + }); + + return std::make_unique(std::move(types), types[0].copy()); +} + +function_set ListPrependFunction::getFunctionSet() { + function_set result; + auto func = std::make_unique(name, + std::vector{LogicalTypeID::LIST, LogicalTypeID::ANY}, LogicalTypeID::LIST); + func->bindFunc = bindFunc; + result.push_back(std::move(func)); + return result; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_range_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_range_function.cpp new file mode 100644 index 0000000000..78c874b589 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_range_function.cpp @@ -0,0 +1,106 @@ +#include "common/exception/runtime.h" +#include "common/type_utils.h" +#include "function/list/vector_list_functions.h" +#include "function/scalar_function.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +struct Range { + // range function: + // - include end + // - when start = end: there is only one element in result list + // - when end - start are of opposite sign of step, the result will be empty + // - default step = 1 + template + static void operation(T& end, list_entry_t& result, ValueVector& endVector, + ValueVector& resultVector) { + T step = 1; + T start = 0; + operation(start, end, step, result, endVector, resultVector); + } + + template + static void operation(T& start, T& end, list_entry_t& result, ValueVector& leftVector, + ValueVector& /*rightVector*/, ValueVector& resultVector) { + T step = 1; + operation(start, end, step, result, leftVector, resultVector); + } + + template + static void operation(T& start, T& end, T& step, list_entry_t& result, + ValueVector& /*inputVector*/, ValueVector& resultVector) { + if (step == 0) { + throw RuntimeException("Step of range cannot be 0."); + } + + // start, start + step, start + 2step, ..., end + T number = start; + auto size = ((end - start) * 1.0 / step); + size < 0 ? 
size = 0 : size = (int64_t)(size + 1); + + result = ListVector::addList(&resultVector, (int64_t)size); + auto resultDataVector = ListVector::getDataVector(&resultVector); + for (auto i = 0u; i < (int64_t)size; i++) { + resultDataVector->setValue(result.offset + i, number); + number += step; + } + } +}; + +static scalar_func_exec_t getBinaryExecFunc(const LogicalType& type) { + scalar_func_exec_t execFunc; + TypeUtils::visit( + type, + [&execFunc](T) { + execFunc = ScalarFunction::BinaryExecListStructFunction; + }, + [](auto) { KU_UNREACHABLE; }); + return execFunc; +} + +static scalar_func_exec_t getTernaryExecFunc(const LogicalType& type) { + scalar_func_exec_t execFunc; + TypeUtils::visit( + type, + [&execFunc](T) { + execFunc = ScalarFunction::TernaryExecListStructFunction; + }, + [](auto) { KU_UNREACHABLE; }); + return execFunc; +} + +static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { + auto type = LogicalType(input.definition->parameterTypeIDs[0]); + auto resultType = LogicalType::LIST(type.copy()); + auto bindData = std::make_unique(std::move(resultType)); + for (auto& _ : input.arguments) { + (void)_; + bindData->paramTypes.push_back(type.copy()); + } + return bindData; +} + +function_set ListRangeFunction::getFunctionSet() { + function_set result; + std::unique_ptr func; + for (auto typeID : LogicalTypeUtils::getIntegerTypeIDs()) { + // start, end + func = std::make_unique(name, std::vector{typeID, typeID}, + LogicalTypeID::LIST, getBinaryExecFunc(LogicalType{typeID})); + func->bindFunc = bindFunc; + result.push_back(std::move(func)); + // start, end, step + func = std::make_unique(name, + std::vector{typeID, typeID, typeID}, LogicalTypeID::LIST, + getTernaryExecFunc(LogicalType{typeID})); + func->bindFunc = bindFunc; + result.push_back(std::move(func)); + } + return result; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_reduce.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_reduce.cpp new file mode 100644 index 0000000000..a4793d7923 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_reduce.cpp @@ -0,0 +1,111 @@ +#include "common/exception/binder.h" +#include "common/exception/runtime.h" +#include "expression_evaluator/lambda_evaluator.h" +#include "expression_evaluator/list_slice_info.h" +#include "function/list/vector_list_functions.h" +#include "function/scalar_function.h" + +namespace lbug { +namespace function { + +using namespace lbug::common; + +static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { + if (input.arguments[1]->expressionType != ExpressionType::LAMBDA) { + throw BinderException(stringFormat( + "The second argument of LIST_REDUCE should be a lambda expression but got {}.", + ExpressionTypeUtil::toString(input.arguments[1]->expressionType))); + } + std::vector paramTypes; + paramTypes.push_back(input.arguments[0]->getDataType().copy()); + paramTypes.push_back(input.arguments[1]->getDataType().copy()); + return std::make_unique(std::move(paramTypes), + ListType::getChildType(input.arguments[0]->getDataType()).copy()); +} + +static void processDataEntry(offset_t curOffset, sel_t listEntryPos, common::ValueVector& result, + const common::ValueVector& inputVector, common::ValueVector& tmpResultVector, + const common::SelectionVector& tmpResultVectorSelVector, + const std::vector& params, const std::vector& paramIndices, + evaluator::ListLambdaBindData& bindData) { + common::ValueVector& inputDataVector = 
*ListVector::getDataVector(&inputVector); + const auto listEntry = inputVector.getValue(listEntryPos); + KU_ASSERT(listEntry.size > 0); + offset_t offsetInList = curOffset - listEntry.offset; + if (offsetInList == 0 && listEntry.size == 1) { + // if list size is 1 the reduce result is equal to the single value + result.copyFromVectorData(listEntryPos, &inputDataVector, listEntry.offset); + } else { + auto paramPos = params[0]->state->getSelVector()[0]; + auto tmpResultPos = tmpResultVectorSelVector[0]; + if (offsetInList < listEntry.size - 1) { + // continue reducing + for (auto i = 0u; i < params.size(); i++) { + if (0u == paramIndices[i] && 0u != offsetInList) { + params[i]->copyFromVectorData(paramPos, &tmpResultVector, tmpResultPos); + } else { + params[i]->copyFromVectorData(paramPos, &inputDataVector, + listEntry.offset + offsetInList + paramIndices[i]); + } + params[i]->state->getSelVectorUnsafe().setSelSize(1); + } + bindData.rootEvaluator->evaluate(); + } else { + // we are done reducing, copy the result from the intermediate result vector + result.copyFromVectorData(listEntryPos, &tmpResultVector, tmpResultPos); + } + } +} + +static void reduceSlice(evaluator::ListSliceInfo& sliceInfo, common::ValueVector& result, + const common::ValueVector& inputVector, common::ValueVector& tmpResultVector, + common::SelectionVector& tmpResultVectorSelVector, evaluator::ListLambdaBindData& bindData) { + const auto& paramIndices = bindData.paramIndices; + std::vector params(bindData.lambdaParamEvaluators.size()); + for (auto i = 0u; i < bindData.lambdaParamEvaluators.size(); i++) { + auto param = bindData.lambdaParamEvaluators[i]->resultVector.get(); + params[i] = param; + } + + for (sel_t i = 0; i < sliceInfo.getSliceSize(); ++i) { + const auto [listEntryPos, dataOffset] = sliceInfo.getPos(i); + processDataEntry(dataOffset, listEntryPos, result, inputVector, tmpResultVector, + tmpResultVectorSelVector, params, paramIndices, bindData); + } +} + +static void execFunc(const std::vector>& input, + const std::vector& inputSelVectors, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* bindData) { + KU_ASSERT(input.size() == 2); + auto listLambdaBindData = reinterpret_cast(bindData); + const auto* inputVector = input[0].get(); + reduceSlice(*listLambdaBindData->sliceInfo, result, *inputVector, *input[1].get(), + *inputSelVectors[1], *listLambdaBindData); + + if (listLambdaBindData->sliceInfo->done()) { + for (idx_t i = 0; i < inputSelVectors[0]->getSelSize(); ++i) { + const auto pos = (*inputSelVectors[0])[i]; + const auto resPos = (*resultSelVector)[i]; + if (inputVector->isNull(pos)) { + result.setNull(resPos, true); + } else if (inputVector->getValue(pos).size == 0) { + throw common::RuntimeException{"Cannot execute list_reduce on an empty list."}; + } + } + } +} + +function_set ListReduceFunction::getFunctionSet() { + function_set result; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::LIST, LogicalTypeID::ANY}, LogicalTypeID::LIST, + execFunc); + function->bindFunc = bindFunc; + function->isListLambda = true; + result.push_back(std::move(function)); + return result; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_reverse_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_reverse_function.cpp new file mode 100644 index 0000000000..0a591f0707 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_reverse_function.cpp @@ -0,0 +1,42 @@ 
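+// list_reverse keeps the input list entry (same offset and size) and copies the elements
+// into the result data vector in reverse order; see ListReverse::operation below.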
+#include "function/list/vector_list_functions.h" +#include "function/scalar_function.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +struct ListReverse { + static inline void operation(common::list_entry_t& input, common::list_entry_t& result, + common::ValueVector& inputVector, common::ValueVector& resultVector) { + auto inputDataVector = common::ListVector::getDataVector(&inputVector); + ListVector::resizeDataVector(&resultVector, ListVector::getDataVectorSize(&inputVector)); + auto resultDataVector = common::ListVector::getDataVector(&resultVector); + result = input; // reverse does not change + for (auto i = 0u; i < input.size; i++) { + auto pos = input.offset + i; + auto reversePos = input.offset + input.size - 1 - i; + resultDataVector->copyFromVectorData(reversePos, inputDataVector, pos); + } + } +}; + +static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { + auto scalarFunction = ku_dynamic_cast(input.definition); + const auto& resultType = input.arguments[0]->dataType; + scalarFunction->execFunc = + ScalarFunction::UnaryExecNestedTypeFunction; + return FunctionBindData::getSimpleBindData(input.arguments, resultType.copy()); +} + +function_set ListReverseFunction::getFunctionSet() { + function_set result; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::LIST}, LogicalTypeID::ANY); + function->bindFunc = bindFunc; + result.push_back(std::move(function)); + return result; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_single.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_single.cpp new file mode 100644 index 0000000000..e9d6c21cd1 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_single.cpp @@ -0,0 +1,26 @@ +#include "function/list/vector_list_functions.h" +#include "function/scalar_function.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +bool singleHandler(uint64_t numSelectedValues, uint64_t /*originalSize*/) { + return numSelectedValues == 1; +} + +function_set ListSingleFunction::getFunctionSet() { + function_set result; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::LIST, LogicalTypeID::ANY}, LogicalTypeID::BOOL, + std::bind(execQuantifierFunc, singleHandler, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4, std::placeholders::_5)); + function->bindFunc = bindQuantifierFunc; + function->isListLambda = true; + result.push_back(std::move(function)); + return result; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_slice_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_slice_function.cpp new file mode 100644 index 0000000000..a39159cf79 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_slice_function.cpp @@ -0,0 +1,113 @@ +#include "function/list/vector_list_functions.h" +#include "function/scalar_function.h" +#include "function/string/functions/substr_function.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +static void normalizeIndices(int64_t& startIdx, int64_t& endIdx, uint64_t size) { + if (startIdx < 0) { + startIdx = size + startIdx + 1; + } + if (startIdx <= 0) { + startIdx = 1; + } + if (endIdx < 0) { + endIdx = size + endIdx + 1; + } + if (endIdx > (int64_t)size) { + endIdx = size; + } + if (endIdx < startIdx) { + startIdx = 1; + endIdx = 0; + } +} + 
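+// Worked example of the 1-based, end-inclusive normalization above (values hypothetical):
+// for a 5-element list, begin = -2 and end = -1 normalize to startIdx = 4, endIdx = 5,
+// so list_slice([10, 20, 30, 40, 50], -2, -1) yields [40, 50]; if endIdx ends up below
+// startIdx, the indices collapse to (1, 0) and the resulting slice is empty.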
+struct ListSlice { + // Note: this function takes in a 1-based begin/end index (The index of the first value in the + // listEntry is 1). + static void operation(list_entry_t& listEntry, int64_t& begin, int64_t& end, + list_entry_t& result, ValueVector& listVector, ValueVector& resultVector) { + auto startIdx = begin; + auto endIdx = end; + normalizeIndices(startIdx, endIdx, listEntry.size); + result = ListVector::addList(&resultVector, endIdx - startIdx + 1); + auto srcDataVector = ListVector::getDataVector(&listVector); + auto srcPos = listEntry.offset + startIdx - 1; + auto dstDataVector = ListVector::getDataVector(&resultVector); + auto dstPos = result.offset; + for (; startIdx <= endIdx; startIdx++) { + dstDataVector->copyFromVectorData(dstPos++, srcDataVector, srcPos++); + } + } + + static void operation(ku_string_t& str, int64_t& begin, int64_t& end, ku_string_t& result, + ValueVector&, ValueVector& resultValueVector) { + auto startIdx = begin; + auto endIdx = end; + normalizeIndices(startIdx, endIdx, str.len); + SubStr::operation(str, startIdx, endIdx - startIdx + 1, result, resultValueVector); + } +}; + +static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { + KU_ASSERT(input.arguments.size() == 3); + std::vector paramTypes; + auto& arg0Type = input.arguments[0]->getDataType(); + LogicalType resultType; + switch (arg0Type.getLogicalTypeID()) { + case LogicalTypeID::ANY: { + paramTypes.push_back(LogicalType::STRING()); + resultType = LogicalType::STRING(); + } break; + case LogicalTypeID::ARRAY: { + paramTypes.push_back(arg0Type.copy()); + resultType = LogicalType::LIST(ArrayType::getChildType(arg0Type).copy()); + } break; + default: { + paramTypes.push_back(arg0Type.copy()); + resultType = arg0Type.copy(); + } + } + paramTypes.push_back(LogicalType(input.definition->parameterTypeIDs[1])); + paramTypes.push_back(LogicalType(input.definition->parameterTypeIDs[2])); + return std::make_unique(std::move(paramTypes), std::move(resultType)); +} + +function_set ListSliceFunction::getFunctionSet() { + function_set result; + std::unique_ptr func; + // List slice + func = std::make_unique(name, + std::vector{LogicalTypeID::LIST, LogicalTypeID::INT64, LogicalTypeID::INT64}, + LogicalTypeID::LIST, + ScalarFunction::TernaryExecListStructFunction); + func->bindFunc = bindFunc; + result.push_back(std::move(func)); + // Array slice + func = std::make_unique(name, + std::vector{LogicalTypeID::ARRAY, LogicalTypeID::INT64, + LogicalTypeID::INT64}, + LogicalTypeID::LIST, + ScalarFunction::TernaryExecListStructFunction); + func->bindFunc = bindFunc; + result.push_back(std::move(func)); + // Substr + func = std::make_unique(name, + std::vector{LogicalTypeID::STRING, LogicalTypeID::INT64, + LogicalTypeID::INT64}, + LogicalTypeID::STRING, + ScalarFunction::TernaryExecListStructFunction); + func->bindFunc = bindFunc; + result.push_back(std::move(func)); + return result; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_sort_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_sort_function.cpp new file mode 100644 index 0000000000..c0c2b6459f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_sort_function.cpp @@ -0,0 +1,95 @@ +#include "function/list/functions/list_sort_function.h" + +#include "common/exception/binder.h" +#include "common/exception/runtime.h" +#include "common/type_utils.h" +#include "function/list/functions/list_reverse_sort_function.h" +#include 
"function/list/vector_list_functions.h" +#include "function/scalar_function.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +template +static scalar_func_exec_t getListSortExecFunction(const binder::expression_vector& arguments) { + scalar_func_exec_t func; + if (arguments.size() == 1) { + func = ScalarFunction::UnaryExecNestedTypeFunction; + } else if (arguments.size() == 2) { + func = ScalarFunction::BinaryExecListStructFunction; + } else if (arguments.size() == 3) { + func = ScalarFunction::TernaryExecListStructFunction; + } else { + throw RuntimeException("Invalid number of arguments"); + } + return func; +} + +static std::unique_ptr ListSortBindFunc(ScalarBindFuncInput input) { + auto scalarFunction = input.definition->ptrCast(); + if (input.arguments[0]->dataType.getLogicalTypeID() == common::LogicalTypeID::ANY) { + throw BinderException(stringFormat("Cannot resolve recursive data type for expression {}", + input.arguments[0]->toString())); + } + common::TypeUtils::visit( + ListType::getChildType(input.arguments[0]->dataType).getPhysicalType(), + [&input, &scalarFunction](T) { + scalarFunction->execFunc = getListSortExecFunction>(input.arguments); + }, + [](auto) { KU_UNREACHABLE; }); + return FunctionBindData::getSimpleBindData(input.arguments, input.arguments[0]->getDataType()); +} + +static std::unique_ptr ListReverseSortBindFunc(const ScalarBindFuncInput& input) { + auto scalarFunction = input.definition->ptrCast(); + common::TypeUtils::visit( + ListType::getChildType(input.arguments[0]->dataType).getPhysicalType(), + [&input, &scalarFunction](T) { + scalarFunction->execFunc = getListSortExecFunction>(input.arguments); + }, + [](auto) { KU_UNREACHABLE; }); + return FunctionBindData::getSimpleBindData(input.arguments, input.arguments[0]->getDataType()); +} + +function_set ListSortFunction::getFunctionSet() { + function_set result; + std::unique_ptr func; + func = std::make_unique(name, std::vector{LogicalTypeID::LIST}, + LogicalTypeID::LIST); + func->bindFunc = ListSortBindFunc; + result.push_back(std::move(func)); + func = std::make_unique(name, + std::vector{LogicalTypeID::LIST, LogicalTypeID::STRING}, + LogicalTypeID::LIST); + func->bindFunc = ListSortBindFunc; + result.push_back(std::move(func)); + func = std::make_unique(name, + std::vector{LogicalTypeID::LIST, LogicalTypeID::STRING, + LogicalTypeID::STRING}, + LogicalTypeID::LIST); + func->bindFunc = ListSortBindFunc; + result.push_back(std::move(func)); + return result; +} + +function_set ListReverseSortFunction::getFunctionSet() { + function_set result; + std::unique_ptr func; + func = std::make_unique(name, std::vector{LogicalTypeID::LIST}, + LogicalTypeID::LIST); + func->bindFunc = ListReverseSortBindFunc; + result.push_back(std::move(func)); + func = std::make_unique(name, + std::vector{LogicalTypeID::LIST, LogicalTypeID::STRING}, + LogicalTypeID::LIST); + func->bindFunc = ListReverseSortBindFunc; + result.push_back(std::move(func)); + return result; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_to_string_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_to_string_function.cpp new file mode 100644 index 0000000000..777d7ae76a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_to_string_function.cpp @@ -0,0 +1,67 @@ +#include "common/type_utils.h" +#include "function/list/vector_list_functions.h" +#include "function/scalar_function.h" + +using namespace lbug::common; + 
+namespace lbug { +namespace function { + +struct ListToString { + static void operation(ku_string_t& delim, list_entry_t& input, common::ku_string_t& result, + common::ValueVector& inputVector, common::ValueVector& /*delimVector*/, + common::ValueVector& resultVector); +}; + +void ListToString::operation(ku_string_t& delim, list_entry_t& input, ku_string_t& result, + ValueVector& /*delimVector*/, ValueVector& inputVector, ValueVector& resultVector) { + std::string resultStr = ""; + bool outputDelim = false; + if (input.size != 0) { + auto dataVector = ListVector::getDataVector(&inputVector); + if (!dataVector->isNull(input.offset)) { + resultStr += TypeUtils::entryToString(dataVector->dataType, + ListVector::getListValuesWithOffset(&inputVector, input, 0 /* offset */), + dataVector); + outputDelim = true; + } + for (auto i = 1u; i < input.size; i++) { + if (dataVector->isNull(input.offset + i)) { + continue; + } + if (outputDelim) { + resultStr += delim.getAsString(); + } + outputDelim = true; + resultStr += TypeUtils::entryToString(dataVector->dataType, + ListVector::getListValuesWithOffset(&inputVector, input, i), dataVector); + } + } + StringVector::addString(&resultVector, result, resultStr); +} + +static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { + std::vector paramTypes; + paramTypes.push_back(LogicalType(input.definition->parameterTypeIDs[0])); + if (input.arguments[1]->getDataType().getLogicalTypeID() == LogicalTypeID::ANY) { + paramTypes.push_back(LogicalType::STRING()); + } else { + paramTypes.push_back(input.arguments[1]->getDataType().copy()); + } + return std::make_unique(std::move(paramTypes), LogicalType::STRING()); +} + +function_set ListToStringFunction::getFunctionSet() { + function_set result; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::STRING, LogicalTypeID::LIST}, + LogicalTypeID::STRING, + ScalarFunction::BinaryExecListStructFunction); + function->bindFunc = bindFunc; + result.push_back(std::move(function)); + return result; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_transform.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_transform.cpp new file mode 100644 index 0000000000..9fb64ff856 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_transform.cpp @@ -0,0 +1,90 @@ +#include "common/exception/binder.h" +#include "expression_evaluator/lambda_evaluator.h" +#include "expression_evaluator/list_slice_info.h" +#include "function/list/vector_list_functions.h" +#include "function/scalar_function.h" + +namespace lbug { +namespace function { + +using namespace common; + +static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { + if (input.arguments[1]->expressionType != ExpressionType::LAMBDA) { + throw BinderException(stringFormat( + "The second argument of LIST_TRANSFORM should be a lambda expression but got {}.", + ExpressionTypeUtil::toString(input.arguments[1]->expressionType))); + } + std::vector paramTypes; + paramTypes.push_back(input.arguments[0]->getDataType().copy()); + paramTypes.push_back(input.arguments[1]->getDataType().copy()); + return std::make_unique(std::move(paramTypes), + LogicalType::LIST(input.arguments[1]->getDataType().copy())); +} + +static void copyEvaluatedDataToResult(ValueVector& resultVector, + evaluator::ListLambdaBindData* bindData) { + auto& sliceInfo = *bindData->sliceInfo; + auto dstDataVector = ListVector::getDataVector(&resultVector); + auto rootResultVector = 
bindData->rootEvaluator->resultVector.get(); + for (sel_t i = 0; i < sliceInfo.getSliceSize(); ++i) { + const auto [listEntryPos, dataOffset] = sliceInfo.getPos(i); + const auto srcIdx = bindData->lambdaParamEvaluators.empty() ? 0 : i; + sel_t srcPos = rootResultVector->state->getSelVector()[srcIdx]; + dstDataVector->copyFromVectorData(dataOffset, rootResultVector, srcPos); + dstDataVector->setNull(dataOffset, rootResultVector->isNull(srcPos)); + } +} + +static void copyListEntriesToResult(const ValueVector& inputVector, + const SelectionVector& inputSelVector, ValueVector& result) { + for (uint64_t i = 0; i < inputSelVector.getSelSize(); ++i) { + auto pos = inputSelVector[i]; + result.setNull(pos, inputVector.isNull(pos)); + + auto inputList = inputVector.getValue(pos); + ListVector::addList(&result, inputList.size); + result.setValue(pos, inputList); + } +} + +static void execFunc(const std::vector>& input, + const std::vector& inputSelVectors, ValueVector& result, + SelectionVector* resultSelVector, void* bindData_) { + auto bindData = reinterpret_cast(bindData_); + auto* sliceInfo = bindData->sliceInfo; + auto savedParamStates = sliceInfo->overrideAndSaveParamStates(bindData->lambdaParamEvaluators); + + bindData->rootEvaluator->evaluate(); + copyEvaluatedDataToResult(result, bindData); + + auto& inputVector = *input[0]; + const auto& inputSelVector = *inputSelVectors[0]; + KU_ASSERT(input.size() == 2); + if (!bindData->lambdaParamEvaluators.empty()) { + if (sliceInfo->done()) { + ListVector::copyListEntryAndBufferMetaData(result, *resultSelVector, inputVector, + inputSelVector); + } + } else { + if (sliceInfo->done()) { + copyListEntriesToResult(inputVector, inputSelVector, result); + } + } + + sliceInfo->restoreParamStates(bindData->lambdaParamEvaluators, std::move(savedParamStates)); +} + +function_set ListTransformFunction::getFunctionSet() { + function_set result; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::LIST, LogicalTypeID::ANY}, LogicalTypeID::LIST, + execFunc); + function->bindFunc = bindFunc; + function->isListLambda = true; + result.push_back(std::move(function)); + return result; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_unique_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_unique_function.cpp new file mode 100644 index 0000000000..a823f9bb13 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/list_unique_function.cpp @@ -0,0 +1,58 @@ +#include "function/list/functions/list_unique_function.h" + +#include "function/list/vector_list_functions.h" +#include "function/scalar_function.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +uint64_t ListUnique::appendListElementsToValueSet(common::list_entry_t& input, + common::ValueVector& inputVector, duplicate_value_handler duplicateValHandler, + unique_value_handler uniqueValueHandler, null_value_handler nullValueHandler) { + ValueSet uniqueKeys; + auto dataVector = common::ListVector::getDataVector(&inputVector); + auto val = common::Value::createDefaultValue(dataVector->dataType); + for (auto i = 0u; i < input.size; i++) { + if (dataVector->isNull(input.offset + i)) { + if (nullValueHandler != nullptr) { + nullValueHandler(); + } + continue; + } + auto entryVal = common::ListVector::getListValuesWithOffset(&inputVector, input, i); + val.copyFromColLayout(entryVal, dataVector); + auto uniqueKey = uniqueKeys.insert(val).second; + if (duplicateValHandler != 
nullptr && !uniqueKey) { + duplicateValHandler( + common::TypeUtils::entryToString(dataVector->dataType, entryVal, dataVector)); + } + if (uniqueValueHandler != nullptr && uniqueKey) { + uniqueValueHandler(*dataVector, input.offset + i); + } + } + return uniqueKeys.size(); +} + +void ListUnique::operation(common::list_entry_t& input, int64_t& result, + common::ValueVector& inputVector, common::ValueVector& /*resultVector*/) { + result = appendListElementsToValueSet(input, inputVector); +} + +static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { + return FunctionBindData::getSimpleBindData(input.arguments, LogicalType::INT64()); +} + +function_set ListUniqueFunction::getFunctionSet() { + function_set result; + auto func = std::make_unique(name, + std::vector{LogicalTypeID::LIST}, LogicalTypeID::INT64, + ScalarFunction::UnaryExecNestedTypeFunction); + func->bindFunc = bindFunc; + result.push_back(std::move(func)); + return result; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/quantifier_functions.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/quantifier_functions.cpp new file mode 100644 index 0000000000..549399293c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/quantifier_functions.cpp @@ -0,0 +1,51 @@ +#include "common/types/types.h" +#include "common/vector/value_vector.h" +#include "expression_evaluator/lambda_evaluator.h" +#include "function/function.h" +#include "function/list/vector_list_functions.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +void execQuantifierFunc(quantifier_handler handler, + const std::vector>& input, + const std::vector& inputSelVectors, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* bindData) { + auto listLambdaBindData = reinterpret_cast(bindData); + auto& inputVector = *input[0]; + if (!listLambdaBindData->lambdaParamEvaluators.empty()) { + auto listSize = ListVector::getDataVectorSize(&inputVector); + auto lambdaParamVector = listLambdaBindData->lambdaParamEvaluators[0]->resultVector.get(); + lambdaParamVector->state->getSelVectorUnsafe().setSelSize(listSize); + } + auto& filterVector = *input[1]; + bool isConstantTrueExpr = listLambdaBindData->lambdaParamEvaluators.empty() && + filterVector.getValue(filterVector.state->getSelVector()[0]); + listLambdaBindData->rootEvaluator->evaluate(); + KU_ASSERT(input.size() == 2); + auto& listInputSelVector = *inputSelVectors[0]; + uint64_t numSelectedValues = 0; + for (auto i = 0u; i < listInputSelVector.getSelSize(); ++i) { + numSelectedValues = 0; + auto srcListEntry = inputVector.getValue(listInputSelVector[i]); + for (auto j = 0u; j < srcListEntry.size; j++) { + auto pos = srcListEntry.offset + j; + if (isConstantTrueExpr || filterVector.getValue(pos)) { + numSelectedValues++; + } + } + result.setValue((*resultSelVector)[i], handler(numSelectedValues, srcListEntry.size)); + } +} + +std::unique_ptr bindQuantifierFunc(const ScalarBindFuncInput& input) { + std::vector paramTypes; + paramTypes.push_back(input.arguments[0]->getDataType().copy()); + paramTypes.push_back(input.arguments[1]->getDataType().copy()); + return std::make_unique(std::move(paramTypes), common::LogicalType::BOOL()); +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/size_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/size_function.cpp new file mode 100644 index 0000000000..01a77cdd55 --- 
/dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/list/size_function.cpp @@ -0,0 +1,54 @@ +#include "function/list/functions/list_len_function.h" +#include "function/list/vector_list_functions.h" +#include "function/scalar_function.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +static std::unique_ptr sizeBindFunc(const ScalarBindFuncInput& input) { + auto scalarFunc = input.definition->constPtrCast(); + auto resultType = LogicalType(scalarFunc->returnTypeID); + if (input.definition->parameterTypeIDs[0] == common::LogicalTypeID::STRING) { + std::vector paramTypes; + paramTypes.push_back(LogicalType::STRING()); + return std::make_unique(std::move(paramTypes), resultType.copy()); + } else { + return FunctionBindData::getSimpleBindData(input.arguments, resultType); + } +} + +function_set SizeFunction::getFunctionSet() { + function_set result; + // size(list) + auto listFunc = std::make_unique(name, + std::vector{LogicalTypeID::LIST}, LogicalTypeID::INT64, + ScalarFunction::UnaryExecFunction); + listFunc->bindFunc = sizeBindFunc; + result.push_back(std::move(listFunc)); + // size(array) + auto arrayFunc = std::make_unique(name, + std::vector{ + LogicalTypeID::ARRAY, + }, + LogicalTypeID::INT64, ScalarFunction::UnaryExecFunction); + arrayFunc->bindFunc = sizeBindFunc; + result.push_back(std::move(arrayFunc)); + // size(map) + auto mapFunc = std::make_unique(name, + std::vector{LogicalTypeID::MAP}, LogicalTypeID::INT64, + ScalarFunction::UnaryExecFunction); + mapFunc->bindFunc = sizeBindFunc; + result.push_back(std::move(mapFunc)); + // size(string) + auto strFunc = + std::make_unique(name, std::vector{LogicalTypeID::STRING}, + LogicalTypeID::INT64, ScalarFunction::UnaryExecFunction); + strFunc->bindFunc = sizeBindFunc; + result.push_back(std::move(strFunc)); + return result; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/map/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/function/map/CMakeLists.txt new file mode 100644 index 0000000000..3a23f9eed4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/map/CMakeLists.txt @@ -0,0 +1,10 @@ +add_library(lbug_function_map + OBJECT + map_creation_function.cpp + map_extract_function.cpp + map_keys_function.cpp + map_values_function.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/map/map_creation_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/map/map_creation_function.cpp new file mode 100644 index 0000000000..264a778868 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/map/map_creation_function.cpp @@ -0,0 +1,31 @@ +#include "function/map/functions/map_creation_function.h" + +#include "function/map/vector_map_functions.h" +#include "function/scalar_function.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { + const auto& keyType = ListType::getChildType(input.arguments[0]->dataType); + const auto& valueType = ListType::getChildType(input.arguments[1]->dataType); + auto resultType = LogicalType::MAP(keyType.copy(), valueType.copy()); + return FunctionBindData::getSimpleBindData(input.arguments, resultType); +} + +function_set MapCreationFunctions::getFunctionSet() { + auto execFunc = ScalarFunction::BinaryExecWithBindData; + function_set functionSet; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::LIST, 
LogicalTypeID::LIST}, LogicalTypeID::MAP, + execFunc); + function->bindFunc = bindFunc; + functionSet.push_back(std::move(function)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/map/map_extract_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/map/map_extract_function.cpp new file mode 100644 index 0000000000..2a406fee74 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/map/map_extract_function.cpp @@ -0,0 +1,42 @@ +#include "function/map/functions/map_extract_function.h" + +#include "common/exception/runtime.h" +#include "common/type_utils.h" +#include "function/map/vector_map_functions.h" +#include "function/scalar_function.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +static void validateKeyType(const std::shared_ptr& mapExpression, + const std::shared_ptr& extractKeyExpression) { + const auto& mapKeyType = MapType::getKeyType(mapExpression->dataType); + if (mapKeyType != extractKeyExpression->dataType) { + throw RuntimeException("Unmatched map key type and extract key type"); + } +} + +static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { + validateKeyType(input.arguments[0], input.arguments[1]); + auto scalarFunction = ku_dynamic_cast(input.definition); + TypeUtils::visit(input.arguments[1]->getDataType().getPhysicalType(), [&](T) { + scalarFunction->execFunc = + ScalarFunction::BinaryExecListStructFunction; + }); + auto resultType = LogicalType::LIST(MapType::getValueType(input.arguments[0]->dataType).copy()); + return FunctionBindData::getSimpleBindData(input.arguments, resultType); +} + +function_set MapExtractFunctions::getFunctionSet() { + function_set functionSet; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::MAP, LogicalTypeID::ANY}, LogicalTypeID::LIST); + function->bindFunc = bindFunc; + functionSet.push_back(std::move(function)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/map/map_keys_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/map/map_keys_function.cpp new file mode 100644 index 0000000000..7b9bc89857 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/map/map_keys_function.cpp @@ -0,0 +1,28 @@ +#include "function/map/functions/map_keys_function.h" + +#include "function/map/vector_map_functions.h" +#include "function/scalar_function.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { + auto resultType = LogicalType::LIST(MapType::getKeyType(input.arguments[0]->dataType).copy()); + return FunctionBindData::getSimpleBindData(input.arguments, resultType); +} + +function_set MapKeysFunctions::getFunctionSet() { + auto execFunc = + ScalarFunction::UnaryExecNestedTypeFunction; + function_set functionSet; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::MAP}, LogicalTypeID::LIST, execFunc); + function->bindFunc = bindFunc; + functionSet.push_back(std::move(function)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/map/map_values_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/map/map_values_function.cpp new file mode 100644 index 0000000000..19a9960cab --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/map/map_values_function.cpp @@ -0,0 +1,28 @@ +#include 
"function/map/functions/map_values_function.h" + +#include "function/map/vector_map_functions.h" +#include "function/scalar_function.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { + auto resultType = LogicalType::LIST(MapType::getValueType(input.arguments[0]->dataType).copy()); + return FunctionBindData::getSimpleBindData(input.arguments, resultType); +} + +function_set MapValuesFunctions::getFunctionSet() { + auto execFunc = + ScalarFunction::UnaryExecNestedTypeFunction; + function_set functionSet; + auto function = std::make_unique(name, std::vector{LogicalTypeID::MAP}, + LogicalTypeID::LIST, execFunc); + function->bindFunc = bindFunc; + functionSet.push_back(std::move(function)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/path/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/function/path/CMakeLists.txt new file mode 100644 index 0000000000..ec3c1af50c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/path/CMakeLists.txt @@ -0,0 +1,11 @@ +add_library(lbug_function_path + OBJECT + length_function.cpp + nodes_function.cpp + properties_function.cpp + rels_function.cpp + semantic_function.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/path/length_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/path/length_function.cpp new file mode 100644 index 0000000000..ad31cbb1da --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/path/length_function.cpp @@ -0,0 +1,58 @@ +#include "binder/expression/expression_util.h" +#include "binder/expression/rel_expression.h" +#include "binder/expression_binder.h" +#include "common/types/value/value.h" +#include "function/arithmetic/vector_arithmetic_functions.h" +#include "function/path/vector_path_functions.h" +#include "function/rewrite_function.h" + +using namespace lbug::binder; +using namespace lbug::common; + +namespace lbug { +namespace function { + +static std::shared_ptr rewriteFunc(const RewriteFunctionBindInput& input) { + KU_ASSERT(input.arguments.size() == 1); + auto param = input.arguments[0].get(); + auto binder = input.expressionBinder; + if (param->expressionType == ExpressionType::PATH) { + int64_t numRels = 0u; + std::vector recursiveRels; + for (auto& child : param->getChildren()) { + if (ExpressionUtil::isRelPattern(*child)) { + numRels++; + } else if (ExpressionUtil::isRecursiveRelPattern(*child)) { + recursiveRels.push_back(child->constPtrCast()); + } + } + auto numRelsExpression = binder->createLiteralExpression(Value(numRels)); + if (recursiveRels.empty()) { + return numRelsExpression; + } + expression_vector children; + children.push_back(std::move(numRelsExpression)); + children.push_back(recursiveRels[0]->getLengthExpression()); + auto result = binder->bindScalarFunctionExpression(children, AddFunction::name); + for (auto i = 1u; i < recursiveRels.size(); ++i) { + children[0] = std::move(result); + children[1] = recursiveRels[i]->getLengthExpression(); + result = binder->bindScalarFunctionExpression(children, AddFunction::name); + } + return result; + } else if (ExpressionUtil::isRecursiveRelPattern(*param)) { + return param->constPtrCast()->getLengthExpression(); + } + KU_UNREACHABLE; +} + +function_set LengthFunction::getFunctionSet() { + function_set result; + auto function = std::make_unique(name, + 
std::vector{LogicalTypeID::RECURSIVE_REL}, rewriteFunc); + result.push_back(std::move(function)); + return result; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/path/nodes_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/path/nodes_function.cpp new file mode 100644 index 0000000000..28391985fb --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/path/nodes_function.cpp @@ -0,0 +1,32 @@ +#include "binder/expression/expression_util.h" +#include "function/path/vector_path_functions.h" +#include "function/scalar_function.h" +#include "function/struct/vector_struct_functions.h" + +using namespace lbug::common; +using namespace lbug::binder; + +namespace lbug { +namespace function { + +static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { + const auto& structType = input.arguments[0]->getDataType(); + auto fieldIdx = StructType::getFieldIdx(structType, InternalKeyword::NODES); + auto resultType = StructType::getField(structType, fieldIdx).getType().copy(); + auto bindData = std::make_unique(std::move(resultType), fieldIdx); + bindData->paramTypes = ExpressionUtil::getDataTypes(input.arguments); + return bindData; +} + +function_set NodesFunction::getFunctionSet() { + function_set functionSet; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::RECURSIVE_REL}, LogicalTypeID::ANY); + function->bindFunc = bindFunc; + function->compileFunc = StructExtractFunctions::compileFunc; + functionSet.push_back(std::move(function)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/path/properties_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/path/properties_function.cpp new file mode 100644 index 0000000000..6f5933f511 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/path/properties_function.cpp @@ -0,0 +1,71 @@ +#include "binder/expression/literal_expression.h" +#include "common/exception/binder.h" +#include "common/string_format.h" +#include "common/vector/value_vector.h" +#include "function/path/vector_path_functions.h" +#include "function/scalar_function.h" + +using namespace lbug::common; +using namespace lbug::binder; + +namespace lbug { +namespace function { + +static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { + if (input.arguments[1]->expressionType != ExpressionType::LITERAL) { + throw BinderException(stringFormat( + "Expected literal input as the second argument for {}().", PropertiesFunction::name)); + } + auto literalExpr = input.arguments[1]->constPtrCast(); + auto key = literalExpr->getValue().getValue(); + const auto& listType = input.arguments[0]->getDataType(); + const auto& childType = ListType::getChildType(listType); + struct_field_idx_t fieldIdx = 0; + if (childType.getLogicalTypeID() == LogicalTypeID::NODE || + childType.getLogicalTypeID() == LogicalTypeID::REL) { + fieldIdx = StructType::getFieldIdx(childType, key); + if (fieldIdx == INVALID_STRUCT_FIELD_IDX) { + throw BinderException(stringFormat("Invalid property name: {}.", key)); + } + } else { + throw BinderException( + stringFormat("Cannot extract properties from {}.", listType.toString())); + } + const auto& field = StructType::getField(childType, fieldIdx); + auto returnType = LogicalType::LIST(field.getType().copy()); + auto bindData = std::make_unique(std::move(returnType), fieldIdx); + bindData->paramTypes.push_back(input.arguments[0]->getDataType().copy()); + 
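// Note: the property key literal was already resolved to `fieldIdx` above, so only the declared
// STRING type is recorded for the second parameter here; the key is not needed again at runtime.
// compileFunc below rewires the result's data vector to the selected struct field, and execFunc
// then only copies the list entries and buffer metadata.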
bindData->paramTypes.push_back(LogicalType(input.definition->parameterTypeIDs[1])); + return bindData; +} + +static void compileFunc(FunctionBindData* bindData, + const std::vector>& parameters, + std::shared_ptr& result) { + KU_ASSERT(parameters[0]->dataType.getPhysicalType() == PhysicalTypeID::LIST); + auto& propertiesBindData = bindData->cast(); + auto fieldVector = StructVector::getFieldVector(ListVector::getDataVector(parameters[0].get()), + propertiesBindData.childIdx); + ListVector::setDataVector(result.get(), fieldVector); +} + +static void execFunc(const std::vector>& parameters, + const std::vector& parameterSelVectors, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* /*dataPtr*/) { + ListVector::copyListEntryAndBufferMetaData(result, *resultSelVector, *parameters[0], + *parameterSelVectors[0]); +} + +function_set PropertiesFunction::getFunctionSet() { + function_set functions; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::LIST, LogicalTypeID::STRING}, LogicalTypeID::ANY, + execFunc); + function->bindFunc = bindFunc; + function->compileFunc = compileFunc; + functions.push_back(std::move(function)); + return functions; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/path/rels_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/path/rels_function.cpp new file mode 100644 index 0000000000..d11e2fd1f9 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/path/rels_function.cpp @@ -0,0 +1,32 @@ +#include "binder/expression/expression_util.h" +#include "function/path/vector_path_functions.h" +#include "function/scalar_function.h" +#include "function/struct/vector_struct_functions.h" + +using namespace lbug::binder; +using namespace lbug::common; + +namespace lbug { +namespace function { + +static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { + const auto& structType = input.arguments[0]->getDataType(); + auto fieldIdx = StructType::getFieldIdx(structType, InternalKeyword::RELS); + auto resultType = StructType::getField(structType, fieldIdx).getType().copy(); + auto bindData = std::make_unique(std::move(resultType), fieldIdx); + bindData->paramTypes = binder::ExpressionUtil::getDataTypes(input.arguments); + return bindData; +} + +function_set RelsFunction::getFunctionSet() { + function_set functionSet; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::RECURSIVE_REL}, LogicalTypeID::ANY); + function->bindFunc = bindFunc; + function->compileFunc = StructExtractFunctions::compileFunc; + functionSet.push_back(std::move(function)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/path/semantic_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/path/semantic_function.cpp new file mode 100644 index 0000000000..2e541cdb66 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/path/semantic_function.cpp @@ -0,0 +1,59 @@ +#include "common/vector/value_vector.h" +#include "function/path/path_function_executor.h" +#include "function/path/vector_path_functions.h" +#include "function/scalar_function.h" + +using namespace lbug::binder; +using namespace lbug::common; + +namespace lbug { +namespace function { + +static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { + return FunctionBindData::getSimpleBindData(input.arguments, LogicalType::BOOL()); +} + +static void IsTrailExecFunc(const std::vector>& parameters, + const 
std::vector& parameterSelVectors, common::ValueVector& result, + common::SelectionVector*, void* /*dataPtr*/) { + UnaryPathExecutor::executeRelIDs(*parameters[0], *parameterSelVectors[0], result); +} + +static bool IsTrailSelectFunc(const std::vector>& parameters, + SelectionVector& selectionVector, void* /* dataPtr */) { + return UnaryPathExecutor::selectRelIDs(*parameters[0], selectionVector); +} + +function_set IsTrailFunction::getFunctionSet() { + function_set functionSet; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::RECURSIVE_REL}, LogicalTypeID::BOOL, + IsTrailExecFunc, IsTrailSelectFunc); + function->bindFunc = bindFunc; + functionSet.push_back(std::move(function)); + return functionSet; +} + +static void IsACyclicExecFunc(const std::vector>& parameters, + const std::vector& parameterSelVectors, common::ValueVector& result, + common::SelectionVector*, void* /*dataPtr*/) { + UnaryPathExecutor::executeNodeIDs(*parameters[0], *parameterSelVectors[0], result); +} + +static bool IsACyclicSelectFunc(const std::vector>& parameters, + SelectionVector& selectionVector, void* /* dataPtr */) { + return UnaryPathExecutor::selectNodeIDs(*parameters[0], selectionVector); +} + +function_set IsACyclicFunction::getFunctionSet() { + function_set functionSet; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::RECURSIVE_REL}, LogicalTypeID::BOOL, + IsACyclicExecFunc, IsACyclicSelectFunc); + function->bindFunc = bindFunc; + functionSet.push_back(std::move(function)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/pattern/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/function/pattern/CMakeLists.txt new file mode 100644 index 0000000000..6710235cb2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/pattern/CMakeLists.txt @@ -0,0 +1,10 @@ +add_library(lbug_function_pattern + OBJECT + cost_function.cpp + id_function.cpp + label_function.cpp + start_end_node_function.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/pattern/cost_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/pattern/cost_function.cpp new file mode 100644 index 0000000000..7bcd53a2f9 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/pattern/cost_function.cpp @@ -0,0 +1,33 @@ +#include "binder/expression/rel_expression.h" +#include "common/exception/binder.h" +#include "function/rewrite_function.h" +#include "function/schema/vector_node_rel_functions.h" + +using namespace lbug::common; +using namespace lbug::binder; + +namespace lbug { +namespace function { + +static std::shared_ptr rewriteFunc(const RewriteFunctionBindInput& input) { + KU_ASSERT(input.arguments.size() == 1); + auto param = input.arguments[0].get(); + KU_ASSERT(param->getDataType().getLogicalTypeID() == LogicalTypeID::RECURSIVE_REL); + auto recursiveInfo = param->ptrCast()->getRecursiveInfo(); + if (recursiveInfo->bindData->weightOutputExpr == nullptr) { + throw BinderException( + stringFormat("Cost function is not defined for {}", param->toString())); + } + return recursiveInfo->bindData->weightOutputExpr; +} + +function_set CostFunction::getFunctionSet() { + function_set functionSet; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::RECURSIVE_REL}, rewriteFunc); + functionSet.push_back(std::move(function)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git 
a/graph-wasm/lbug-0.12.2/lbug-src/src/function/pattern/id_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/pattern/id_function.cpp new file mode 100644 index 0000000000..0e3cccd4e5 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/pattern/id_function.cpp @@ -0,0 +1,45 @@ +#include "binder/expression/expression_util.h" +#include "binder/expression/node_expression.h" +#include "binder/expression/rel_expression.h" +#include "binder/expression_binder.h" +#include "function/rewrite_function.h" +#include "function/schema/vector_node_rel_functions.h" +#include "function/struct/vector_struct_functions.h" + +using namespace lbug::common; +using namespace lbug::binder; + +namespace lbug { +namespace function { + +static std::shared_ptr rewriteFunc(const RewriteFunctionBindInput& input) { + KU_ASSERT(input.arguments.size() == 1); + auto param = input.arguments[0].get(); + if (ExpressionUtil::isNodePattern(*param)) { + auto node = param->constPtrCast(); + return node->getInternalID(); + } + if (ExpressionUtil::isRelPattern(*param)) { + auto rel = param->constPtrCast(); + return rel->getPropertyExpression(InternalKeyword::ID); + } + // Bind as struct_extract(param, "_id") + auto extractKey = input.expressionBinder->createLiteralExpression(InternalKeyword::ID); + return input.expressionBinder->bindScalarFunctionExpression({input.arguments[0], extractKey}, + StructExtractFunctions::name); +} + +function_set IDFunction::getFunctionSet() { + function_set functionSet; + auto inputTypes = + std::vector{LogicalTypeID::NODE, LogicalTypeID::REL, LogicalTypeID::STRUCT}; + for (auto& inputType : inputTypes) { + auto function = std::make_unique(name, + std::vector{inputType}, rewriteFunc); + functionSet.push_back(std::move(function)); + } + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/pattern/label_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/pattern/label_function.cpp new file mode 100644 index 0000000000..4bf566fc33 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/pattern/label_function.cpp @@ -0,0 +1,150 @@ +#include "binder/expression/expression_util.h" +#include "binder/expression/node_expression.h" +#include "binder/expression/rel_expression.h" +#include "binder/expression/scalar_function_expression.h" +#include "binder/expression_binder.h" +#include "catalog/catalog_entry/rel_group_catalog_entry.h" +#include "catalog/catalog_entry/table_catalog_entry.h" +#include "function/binary_function_executor.h" +#include "function/list/functions/list_extract_function.h" +#include "function/rewrite_function.h" +#include "function/scalar_function.h" +#include "function/schema/vector_node_rel_functions.h" +#include "function/struct/vector_struct_functions.h" + +using namespace lbug::common; +using namespace lbug::binder; +using namespace lbug::catalog; + +namespace lbug { +namespace function { + +struct Label { + static void operation(internalID_t& left, list_entry_t& right, ku_string_t& result, + ValueVector& leftVector, ValueVector& rightVector, ValueVector& resultVector, + uint64_t resPos) { + KU_ASSERT(left.tableID < right.size); + ListExtract::operation(right, left.tableID + 1 /* listExtract requires 1-based index */, + result, rightVector, leftVector, resultVector, resPos); + } +}; + +static void execFunction(const std::vector>& params, + const std::vector& paramSelVectors, ValueVector& result, + SelectionVector* resultSelVector, void* dataPtr = nullptr) { + 
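// Label::operation above indexes the label list by table ID (converted to a 1-based position for
// ListExtract), so this exec function maps each internal ID to the name of its table. As an
// illustration with hypothetical table names: a pattern bound to tables {0: "User", 2: "City"}
// gets the literal list ["User", "", "City"] from getLabelsAsLiteral below, and an internalID
// with tableID == 2 extracts entry 3, i.e. "City".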
KU_ASSERT(params.size() == 2); + BinaryFunctionExecutor::executeSwitch(*params[0], paramSelVectors[0], *params[1], + paramSelVectors[1], result, resultSelVector, dataPtr); +} + +static std::shared_ptr getLabelsAsLiteral( + std::unordered_map map, ExpressionBinder* expressionBinder) { + table_id_t maxTableID = 0; + for (auto [id, name] : map) { + if (id > maxTableID) { + maxTableID = id; + } + } + std::vector> labels; + labels.resize(maxTableID + 1); + for (auto i = 0u; i < labels.size(); ++i) { + if (map.contains(i)) { + labels[i] = std::make_unique(LogicalType::STRING(), map.at(i)); + } else { + labels[i] = std::make_unique(LogicalType::STRING(), std::string("")); + } + } + auto labelsValue = Value(LogicalType::LIST(LogicalType::STRING()), std::move(labels)); + return expressionBinder->createLiteralExpression(labelsValue); +} + +static std::unordered_map getNodeTableIDToLabel( + std::vector entries) { + std::unordered_map map; + for (auto& entry : entries) { + map.insert({entry->getTableID(), entry->getName()}); + } + return map; +} + +static std::unordered_map getRelTableIDToLabel( + std::vector entries) { + std::unordered_map map; + for (auto& entry : entries) { + auto& relGroupEntry = entry->constCast(); + for (auto& relEntryInfo : relGroupEntry.getRelEntryInfos()) { + map.insert({relEntryInfo.oid, entry->getName()}); + } + } + return map; +} + +std::shared_ptr LabelFunction::rewriteFunc(const RewriteFunctionBindInput& input) { + KU_ASSERT(input.arguments.size() == 1); + auto argument = input.arguments[0].get(); + auto expressionBinder = input.expressionBinder; + if (ExpressionUtil::isNullLiteral(*argument)) { + return expressionBinder->createNullLiteralExpression(); + } + expression_vector children; + if (argument->expressionType == ExpressionType::VARIABLE) { + children.push_back(input.arguments[0]); + children.push_back(expressionBinder->createLiteralExpression(InternalKeyword::LABEL)); + return expressionBinder->bindScalarFunctionExpression(children, + StructExtractFunctions::name); + } + auto disableLiteralRewrite = expressionBinder->getConfig().disableLabelFunctionLiteralRewrite; + if (ExpressionUtil::isNodePattern(*argument)) { + auto& node = argument->constCast(); + if (!disableLiteralRewrite) { + if (node.isEmpty()) { + return expressionBinder->createLiteralExpression(""); + } + if (!node.isMultiLabeled()) { + auto label = node.getEntry(0)->getName(); + return expressionBinder->createLiteralExpression(label); + } + } + children.push_back(node.getInternalID()); + auto map = getNodeTableIDToLabel(node.getEntries()); + children.push_back(getLabelsAsLiteral(map, expressionBinder)); + } else if (ExpressionUtil::isRelPattern(*argument)) { + auto& rel = argument->constCast(); + if (!disableLiteralRewrite) { + if (rel.isEmpty()) { + return expressionBinder->createLiteralExpression(""); + } + if (!rel.isMultiLabeled()) { + auto label = rel.getEntry(0)->getName(); + return expressionBinder->createLiteralExpression(label); + } + } + children.push_back(rel.getInternalID()); + auto map = getRelTableIDToLabel(rel.getEntries()); + children.push_back(getLabelsAsLiteral(map, expressionBinder)); + } + KU_ASSERT(children.size() == 2); + auto function = std::make_unique(LabelFunction::name, + std::vector{LogicalTypeID::STRING, LogicalTypeID::INT64}, + LogicalTypeID::STRING, execFunction); + auto bindData = std::make_unique(LogicalType::STRING()); + auto uniqueName = ScalarFunctionExpression::getUniqueName(LabelFunction::name, children); + return std::make_shared(ExpressionType::FUNCTION, 
std::move(function), + std::move(bindData), std::move(children), uniqueName); +} + +function_set LabelFunction::getFunctionSet() { + function_set set; + auto inputTypes = + std::vector{LogicalTypeID::NODE, LogicalTypeID::REL, LogicalTypeID::STRUCT}; + for (auto& inputType : inputTypes) { + auto function = std::make_unique(name, + std::vector{inputType}, rewriteFunc); + set.push_back(std::move(function)); + } + return set; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/pattern/start_end_node_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/pattern/start_end_node_function.cpp new file mode 100644 index 0000000000..d559668d93 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/pattern/start_end_node_function.cpp @@ -0,0 +1,53 @@ +#include "binder/expression/expression_util.h" +#include "binder/expression/rel_expression.h" +#include "binder/expression_binder.h" +#include "function/rewrite_function.h" +#include "function/schema/vector_node_rel_functions.h" +#include "function/struct/vector_struct_functions.h" + +using namespace lbug::common; +using namespace lbug::binder; + +namespace lbug { +namespace function { + +static std::shared_ptr startRewriteFunc(const RewriteFunctionBindInput& input) { + KU_ASSERT(input.arguments.size() == 1); + auto param = input.arguments[0].get(); + if (ExpressionUtil::isRelPattern(*param)) { + return param->constCast().getSrcNode(); + } + auto extractKey = input.expressionBinder->createLiteralExpression(InternalKeyword::SRC); + return input.expressionBinder->bindScalarFunctionExpression({input.arguments[0], extractKey}, + StructExtractFunctions::name); +} + +function_set StartNodeFunction::getFunctionSet() { + function_set set; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::REL}, startRewriteFunc); + set.push_back(std::move(function)); + return set; +} + +static std::shared_ptr endRewriteFunc(const RewriteFunctionBindInput& input) { + KU_ASSERT(input.arguments.size() == 1); + auto param = input.arguments[0].get(); + if (ExpressionUtil::isRelPattern(*param)) { + return param->constCast().getDstNode(); + } + auto extractKey = input.expressionBinder->createLiteralExpression(InternalKeyword::DST); + return input.expressionBinder->bindScalarFunctionExpression({input.arguments[0], extractKey}, + StructExtractFunctions::name); +} + +function_set EndNodeFunction::getFunctionSet() { + function_set set; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::REL}, endRewriteFunc); + set.push_back(std::move(function)); + return set; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/scalar_macro_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/scalar_macro_function.cpp new file mode 100644 index 0000000000..69de6a7306 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/scalar_macro_function.cpp @@ -0,0 +1,72 @@ +#include "function/scalar_macro_function.h" + +#include "common/serializer/deserializer.h" +#include "common/serializer/serializer.h" +#include "common/string_format.h" +#include "common/string_utils.h" + +using namespace lbug::common; +using namespace lbug::parser; + +namespace lbug { +namespace function { + +macro_parameter_value_map ScalarMacroFunction::getDefaultParameterVals() const { + macro_parameter_value_map defaultArgsToReturn; + for (auto& defaultArg : defaultArgs) { + defaultArgsToReturn.emplace(defaultArg.first, defaultArg.second.get()); + } + return 
defaultArgsToReturn; +} + +std::unique_ptr ScalarMacroFunction::copy() const { + default_macro_args defaultArgsCopy; + for (auto& defaultArg : defaultArgs) { + defaultArgsCopy.emplace_back(defaultArg.first, defaultArg.second->copy()); + } + return std::make_unique(expression->copy(), positionalArgs, + std::move(defaultArgsCopy)); +} + +void ScalarMacroFunction::serialize(Serializer& serializer) const { + expression->serialize(serializer); + serializer.serializeVector(positionalArgs); + uint64_t vectorSize = defaultArgs.size(); + serializer.serializeValue(vectorSize); + for (auto& defaultArg : defaultArgs) { + serializer.serializeValue(defaultArg.first); + defaultArg.second->serialize(serializer); + } +} + +std::unique_ptr ScalarMacroFunction::deserialize(Deserializer& deserializer) { + auto expression = ParsedExpression::deserialize(deserializer); + std::vector positionalArgs; + deserializer.deserializeVector(positionalArgs); + default_macro_args defaultArgs; + uint64_t vectorSize = 0; + deserializer.deserializeValue(vectorSize); + defaultArgs.reserve(vectorSize); + for (auto i = 0u; i < vectorSize; i++) { + std::string key; + deserializer.deserializeValue(key); + auto val = ParsedExpression::deserialize(deserializer); + defaultArgs.emplace_back(std::move(key), std::move(val)); + } + return std::make_unique(std::move(expression), std::move(positionalArgs), + std::move(defaultArgs)); +} + +std::string ScalarMacroFunction::toCypher(const std::string& name) const { + std::vector paramStrings; + for (auto& param : positionalArgs) { + paramStrings.push_back(param); + } + for (auto& defaultParam : defaultArgs) { + paramStrings.push_back(defaultParam.first + ":=" + defaultParam.second->toString()); + } + return stringFormat("CREATE MACRO `{}` ({}) AS {};", name, StringUtils::join(paramStrings, ","), + expression->toString()); +} +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/sequence/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/function/sequence/CMakeLists.txt new file mode 100644 index 0000000000..893cf979eb --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/sequence/CMakeLists.txt @@ -0,0 +1,7 @@ +add_library(lbug_function_sequence + OBJECT + sequence_functions.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/sequence/sequence_functions.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/sequence/sequence_functions.cpp new file mode 100644 index 0000000000..c63de9d9af --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/sequence/sequence_functions.cpp @@ -0,0 +1,59 @@ +#include "function/sequence/sequence_functions.h" + +#include "catalog/catalog.h" +#include "catalog/catalog_entry/sequence_catalog_entry.h" +#include "function/scalar_function.h" +#include "main/client_context.h" +#include "transaction/transaction.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +struct CurrVal { + static void operation(ku_string_t& input, ValueVector& result, void* dataPtr) { + auto ctx = reinterpret_cast(dataPtr)->clientContext; + auto catalog = catalog::Catalog::Get(*ctx); + auto transaction = transaction::Transaction::Get(*ctx); + auto sequenceName = input.getAsString(); + auto sequenceEntry = + catalog->getSequenceEntry(transaction, sequenceName, ctx->useInternalCatalogEntry()); + result.setValue(0, sequenceEntry->currVal()); + } +}; + +struct NextVal { + static void operation(ku_string_t& input, ValueVector& 
result, void* dataPtr) { + auto ctx = reinterpret_cast(dataPtr)->clientContext; + auto cnt = reinterpret_cast(dataPtr)->count; + auto catalog = catalog::Catalog::Get(*ctx); + auto transaction = transaction::Transaction::Get(*ctx); + auto sequenceName = input.getAsString(); + auto sequenceEntry = + catalog->getSequenceEntry(transaction, sequenceName, ctx->useInternalCatalogEntry()); + sequenceEntry->nextKVal(transaction, cnt, result); + result.state->getSelVectorUnsafe().setSelSize(cnt); + } +}; + +function_set CurrValFunction::getFunctionSet() { + function_set functionSet; + functionSet.push_back(make_unique(name, + std::vector{LogicalTypeID::STRING}, LogicalTypeID::INT64, + ScalarFunction::UnarySequenceExecFunction)); + return functionSet; +} + +function_set NextValFunction::getFunctionSet() { + function_set functionSet; + auto func = make_unique(name, std::vector{LogicalTypeID::STRING}, + LogicalTypeID::INT64, + ScalarFunction::UnarySequenceExecFunction); + func->isReadOnly = false; + functionSet.push_back(std::move(func)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/string/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/function/string/CMakeLists.txt new file mode 100644 index 0000000000..83357e41d4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/string/CMakeLists.txt @@ -0,0 +1,13 @@ +add_library(lbug_string_function + OBJECT + concat_ws.cpp + string_split_function.cpp + init_cap_function.cpp + levenshtein_function.cpp + split_part.cpp + regex_full_match_function.cpp + regex_replace_function.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/string/concat_ws.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/string/concat_ws.cpp new file mode 100644 index 0000000000..a1f50bfb38 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/string/concat_ws.cpp @@ -0,0 +1,111 @@ +#include "common/exception/binder.h" +#include "function/string/vector_string_functions.h" + +namespace lbug { +namespace function { + +using namespace lbug::common; + +static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { + if (input.arguments.size() < 2) { + throw BinderException{stringFormat("concat_ws expects at least two parameters. Got: {}.", + input.arguments.size())}; + } + for (auto i = 0u; i < input.arguments.size(); i++) { + auto& argument = input.arguments[i]; + if (argument->getDataType().getLogicalTypeID() == LogicalTypeID::ANY) { + argument->cast(LogicalType::STRING()); + } + if (argument->getDataType() != LogicalType::STRING()) { + throw BinderException{stringFormat("concat_ws expects all string parameters. Got: {}.", + argument->getDataType().toString())}; + } + } + return FunctionBindData::getSimpleBindData(input.arguments, LogicalType::STRING()); +} + +using handle_separator_func_t = std::function; +using handle_element_func_t = std::function; + +static void iterateParams(const std::vector>& parameters, + const std::vector& parameterSelVectors, sel_t pos, + handle_separator_func_t handleSeparatorFunc, handle_element_func_t handleElementFunc) { + bool isPrevNull = false; + for (auto i = 1u; i < parameters.size(); i++) { + const auto& parameter = parameters[i]; + const auto& parameterSelVector = *parameterSelVectors[i]; + auto paramPos = parameter->state->isFlat() ? 
parameterSelVector[0] : pos; + if (parameter->isNull(paramPos)) { + isPrevNull = true; + continue; + } + if (i != 1u && !isPrevNull) { + handleSeparatorFunc(); + } + handleElementFunc(parameter->getValue(paramPos)); + isPrevNull = false; + } +} + +void execFunc(const std::vector>& parameters, + const std::vector& parameterSelVectors, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* /*dataPtr*/) { + result.resetAuxiliaryBuffer(); + for (auto selectedPos = 0u; selectedPos < resultSelVector->getSelSize(); ++selectedPos) { + auto pos = (*resultSelVector)[selectedPos]; + auto separatorPos = parameters[0]->state->isFlat() ? (*parameterSelVectors[0])[0] : pos; + if (parameters[0]->isNull(separatorPos)) { + result.setNull(pos, true /* isNull */); + continue; + } + auto separator = parameters[0]->getValue(separatorPos); + auto len = 0u; + bool isPrevNull = false; + iterateParams( + parameters, parameterSelVectors, pos, [&]() { len += separator.len; }, + [&](const ku_string_t& str) { len += str.len; }); + for (auto i = 1u; i < parameters.size(); i++) { + const auto& parameter = parameters[i]; + const auto& parameterSelVector = *parameterSelVectors[i]; + auto paramPos = parameter->state->isFlat() ? parameterSelVector[0] : pos; + if (parameter->isNull(paramPos)) { + isPrevNull = true; + continue; + } + if (i != 1u && !isPrevNull) {} + + isPrevNull = false; + } + common::ku_string_t resultStr; + StringVector::reserveString(&result, resultStr, len); + auto resultBuffer = resultStr.getData(); + iterateParams( + parameters, parameterSelVectors, pos, + [&]() { + memcpy((void*)resultBuffer, (void*)separator.getData(), separator.len); + resultBuffer += separator.len; + }, + [&](const ku_string_t& str) { + memcpy((void*)resultBuffer, (void*)str.getData(), str.len); + resultBuffer += str.len; + }); + memcpy(resultStr.prefix, resultStr.getData(), + std::min(resultStr.len, ku_string_t::PREFIX_LENGTH)); + KU_ASSERT(resultBuffer - resultStr.getData() == len); + result.setNull(pos, false /* isNull */); + result.setValue(pos, resultStr); + } +} + +function_set ConcatWSFunction::getFunctionSet() { + function_set functionSet; + auto func = make_unique(name, std::vector{LogicalTypeID::STRING}, + LogicalTypeID::STRING, execFunc); + func->bindFunc = bindFunc; + func->isVarLength = true; + functionSet.push_back(std::move(func)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/string/init_cap_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/string/init_cap_function.cpp new file mode 100644 index 0000000000..c34e1098b5 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/string/init_cap_function.cpp @@ -0,0 +1,27 @@ +#include "function/string/vector_string_functions.h" + +namespace lbug { +namespace function { + +using namespace lbug::common; + +struct InitCap { + static void operation(ku_string_t& operand, ku_string_t& result, ValueVector& resultVector) { + Lower::operation(operand, result, resultVector); + int originalSize = 0, newSize = 0; + BaseLowerUpperFunction::convertCharCase(reinterpret_cast(result.getDataUnsafe()), + reinterpret_cast(result.getData()), 0 /* charPos */, true /* toUpper */, + originalSize, newSize); + } +}; + +function_set InitCapFunction::getFunctionSet() { + function_set functionSet; + functionSet.emplace_back(make_unique(name, + std::vector{LogicalTypeID::STRING}, LogicalTypeID::STRING, + ScalarFunction::UnaryStringExecFunction)); + return functionSet; +} + +} // 
namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/string/levenshtein_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/string/levenshtein_function.cpp new file mode 100644 index 0000000000..956d370171 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/string/levenshtein_function.cpp @@ -0,0 +1,63 @@ +#include "function/string/vector_string_functions.h" + +namespace lbug { +namespace function { + +using namespace lbug::common; + +struct Levenshtein { +public: + static void operation(common::ku_string_t& left, common::ku_string_t& right, int64_t& result) { + // If one string is empty, the distance equals the length of the other string. + if (left.len == 0 || right.len == 0) { + result = left.len + right.len; + return; + } + + auto leftStr = left.getData(); + auto rightStr = right.getData(); + + std::vector distances0(right.len + 1, 0); + std::vector distances1(right.len + 1, 0); + + uint64_t substitutionCost = 0; + uint64_t insertionCost = 0; + uint64_t deletionCost = 0; + + for (auto i = 0u; i <= right.len; i++) { + distances0[i] = i; + } + + for (auto i = 0u; i < left.len; i++) { + distances1[0] = i + 1; + for (auto j = 0u; j < right.len; j++) { + deletionCost = distances0[j + 1] + 1; + insertionCost = distances1[j] + 1; + substitutionCost = distances0[j]; + + if (leftStr[i] != rightStr[j]) { + substitutionCost += 1; + } + distances1[j + 1] = + std::min(deletionCost, std::min(substitutionCost, insertionCost)); + } + // Copy distances1 (current row) to distances0 (previous row) for next iteration + // since data in distances1 is always invalidated, a swap without copy is more + // efficient. + distances0 = distances1; + } + result = distances0[right.len]; + } +}; + +function_set LevenshteinFunction::getFunctionSet() { + function_set functionSet; + functionSet.emplace_back(make_unique(name, + std::vector{LogicalTypeID::STRING, LogicalTypeID::STRING}, + LogicalTypeID::INT64, + ScalarFunction::BinaryExecFunction)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/string/regex_full_match_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/string/regex_full_match_function.cpp new file mode 100644 index 0000000000..3c6b96aeef --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/string/regex_full_match_function.cpp @@ -0,0 +1,73 @@ +#include "binder/expression/expression_util.h" +#include "expression_evaluator/expression_evaluator_utils.h" +#include "function/string/functions/base_regexp_function.h" +#include "function/string/vector_string_functions.h" +#include "re2.h" + +namespace lbug { +namespace function { + +using namespace common; + +struct RegexFullMatchBindData : public FunctionBindData { + regex::RE2 pattern; + + explicit RegexFullMatchBindData(common::logical_type_vec_t paramTypes, std::string patternInStr) + : FunctionBindData{std::move(paramTypes), common::LogicalType::BOOL()}, + pattern{patternInStr} {} + + std::unique_ptr copy() const override { + return std::make_unique(copyVector(paramTypes), pattern.pattern()); + } +}; + +struct RegexpFullMatch { + static void operation(common::ku_string_t& left, common::ku_string_t& right, uint8_t& result) { + result = RE2::FullMatch(left.getAsString(), + BaseRegexpOperation::parseCypherPattern(right.getAsString())); + } +}; + +struct RegexpFullMatchStaticPattern : BaseRegexpOperation { + static void operation(common::ku_string_t& left, common::ku_string_t& /*right*/, + 
uint8_t& result, common::ValueVector& /*leftValueVector*/, + common::ValueVector& /*rightValueVector*/, common::ValueVector& /*resultValueVector*/, + void* dataPtr) { + auto regexFullMatchBindData = reinterpret_cast(dataPtr); + result = RE2::FullMatch(left.getAsString(), regexFullMatchBindData->pattern); + } +}; + +static std::unique_ptr regexFullMatchBindFunc(const ScalarBindFuncInput& input) { + if (input.arguments[1]->expressionType == ExpressionType::LITERAL) { + auto value = evaluator::ExpressionEvaluatorUtils::evaluateConstantExpression( + input.arguments[1], input.context); + input.definition->ptrCast()->execFunc = + ScalarFunction::BinaryExecWithBindData; + input.definition->ptrCast()->selectFunc = + ScalarFunction::BinarySelectWithBindData; + auto patternInStr = value.getValue(); + return std::make_unique( + binder::ExpressionUtil::getDataTypes(input.arguments), + BaseRegexpOperation::parseCypherPattern(patternInStr)); + } else { + return FunctionBindData::getSimpleBindData(input.arguments, LogicalType::BOOL()); + } +} + +function_set RegexpFullMatchFunction::getFunctionSet() { + function_set functionSet; + auto scalarFunc = make_unique(name, + std::vector{LogicalTypeID::STRING, LogicalTypeID::STRING}, + LogicalTypeID::BOOL, + ScalarFunction::BinaryExecFunction, + ScalarFunction::BinarySelectFunction); + scalarFunc->bindFunc = regexFullMatchBindFunc; + functionSet.emplace_back(std::move(scalarFunc)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/string/regex_replace_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/string/regex_replace_function.cpp new file mode 100644 index 0000000000..70e6103c2d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/string/regex_replace_function.cpp @@ -0,0 +1,142 @@ +#include "binder/expression/expression_util.h" +#include "common/exception/binder.h" +#include "expression_evaluator/expression_evaluator_utils.h" +#include "function/string/functions/base_regexp_function.h" +#include "function/string/vector_string_functions.h" +#include "re2.h" + +namespace lbug { +namespace function { + +using namespace common; + +using re2_replace_func_t = + std::function; + +struct RegexReplaceBindData : public FunctionBindData { + re2_replace_func_t replaceFunc; + + RegexReplaceBindData(common::logical_type_vec_t paramTypes, re2_replace_func_t replaceFunc) + : FunctionBindData{std::move(paramTypes), common::LogicalType::STRING()}, + replaceFunc{std::move(replaceFunc)} {} + + std::unique_ptr copy() const override { + return std::make_unique(copyVector(paramTypes), replaceFunc); + } +}; + +struct RegexpReplace { + static void operation(common::ku_string_t& value, common::ku_string_t& pattern, + common::ku_string_t& replacement, common::ku_string_t& result, + common::ValueVector& resultValueVector, void* dataPtr) { + auto bindData = reinterpret_cast(dataPtr); + std::string resultStr = value.getAsString(); + RE2 re2Pattern{pattern.getAsString()}; + bindData->replaceFunc(&resultStr, re2Pattern, replacement.getAsString()); + BaseRegexpOperation::copyToLbugString(resultStr, result, resultValueVector); + } +}; + +struct RegexReplaceBindDataStaticPattern : public RegexReplaceBindData { + regex::RE2 pattern; + + RegexReplaceBindDataStaticPattern(common::logical_type_vec_t paramTypes, + re2_replace_func_t replaceFunc, std::string patternInStr) + : RegexReplaceBindData{std::move(paramTypes), std::move(replaceFunc)}, + pattern{patternInStr} {} + + std::unique_ptr copy() 
const override { + return std::make_unique(copyVector(paramTypes), + replaceFunc, pattern.pattern()); + } +}; + +struct RegexpReplaceStaticPattern { + static void operation(common::ku_string_t& value, common::ku_string_t& /*pattern*/, + common::ku_string_t& replacement, common::ku_string_t& result, + common::ValueVector& resultValueVector, void* dataPtr) { + auto bindData = reinterpret_cast(dataPtr); + auto resultStr = value.getAsString(); + bindData->replaceFunc(&resultStr, bindData->pattern, replacement.getAsString()); + BaseRegexpOperation::copyToLbugString(resultStr, result, resultValueVector); + } +}; + +static re2_replace_func_t bindReplaceFunc(const binder::expression_vector& expr) { + re2_replace_func_t result; + switch (expr.size()) { + case 3: { + result = RE2::Replace; + } break; + case 4: { + result = RE2::GlobalReplace; + } break; + default: + KU_UNREACHABLE; + } + return result; +} + +template +scalar_func_exec_t getExecFunc(const binder::expression_vector& expr) { + scalar_func_exec_t execFunc; + switch (expr.size()) { + case 3: { + execFunc = ScalarFunction::TernaryRegexExecFunction; + } break; + case 4: { + auto option = expr[3]; + binder::ExpressionUtil::validateExpressionType(*option, ExpressionType::LITERAL); + binder::ExpressionUtil::validateDataType(*option, LogicalType::STRING()); + auto optionVal = binder::ExpressionUtil::getLiteralValue(*option); + if (optionVal != RegexpReplaceFunction::GLOBAL_REPLACE_OPTION) { + throw common::BinderException{ + "regex_replace can only support global replace option: g."}; + } + execFunc = ScalarFunction::TernaryRegexExecFunction; + } break; + default: + KU_UNREACHABLE; + } + return execFunc; +} + +std::unique_ptr bindFunc(ScalarBindFuncInput input) { + auto definition = input.definition->ptrCast(); + re2_replace_func_t replaceFunc = bindReplaceFunc(input.arguments); + if (input.arguments[1]->expressionType == ExpressionType::LITERAL) { + definition->execFunc = getExecFunc(input.arguments); + auto value = evaluator::ExpressionEvaluatorUtils::evaluateConstantExpression( + input.arguments[1], input.context); + return std::make_unique( + binder::ExpressionUtil::getDataTypes(input.arguments), std::move(replaceFunc), + BaseRegexpOperation::parseCypherPattern(value.getValue())); + } else { + definition->execFunc = getExecFunc(input.arguments); + return std::make_unique( + binder::ExpressionUtil::getDataTypes(input.arguments), std::move(replaceFunc)); + } +} + +function_set RegexpReplaceFunction::getFunctionSet() { + function_set functionSet; + std::unique_ptr func; + func = std::make_unique(name, + std::vector{LogicalTypeID::STRING, LogicalTypeID::STRING, + LogicalTypeID::STRING, LogicalTypeID::STRING}, + LogicalTypeID::STRING); + func->bindFunc = bindFunc; + functionSet.emplace_back(std::move(func)); + func = std::make_unique(name, + std::vector{LogicalTypeID::STRING, LogicalTypeID::STRING, + LogicalTypeID::STRING}, + LogicalTypeID::STRING); + func->bindFunc = bindFunc; + functionSet.emplace_back(std::move(func)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/string/split_part.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/string/split_part.cpp new file mode 100644 index 0000000000..1ce825a682 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/string/split_part.cpp @@ -0,0 +1,37 @@ +#include "common/string_utils.h" +#include "function/string/vector_string_functions.h" + +namespace lbug { +namespace function { + +using namespace 
lbug::common; + +struct SplitPart { + static void operation(ku_string_t& strToSplit, ku_string_t& separator, int64_t idx, + ku_string_t& result, ValueVector& resultVector) { + auto splitStrVec = StringUtils::split(strToSplit.getAsString(), separator.getAsString()); + bool idxOutOfRange = idx <= 0 || (uint64_t)idx > splitStrVec.size(); + std::string resultStr = idxOutOfRange ? "" : splitStrVec[idx - 1]; + StringVector::addString(&resultVector, result, resultStr); + } +}; + +static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { + return FunctionBindData::getSimpleBindData(input.arguments, LogicalType::STRING()); +} + +function_set SplitPartFunction::getFunctionSet() { + function_set functionSet; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::STRING, LogicalTypeID::STRING, + LogicalTypeID::INT64}, + LogicalTypeID::STRING, + ScalarFunction::TernaryStringExecFunction); + function->bindFunc = bindFunc; + functionSet.emplace_back(std::move(function)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/string/string_split_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/string/string_split_function.cpp new file mode 100644 index 0000000000..0015e28385 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/string/string_split_function.cpp @@ -0,0 +1,38 @@ +#include "common/string_utils.h" +#include "function/string/vector_string_functions.h" + +namespace lbug { +namespace function { + +using namespace lbug::common; + +struct StringSplit { + static void operation(ku_string_t& strToSplit, ku_string_t& separator, list_entry_t& result, + ValueVector& resultVector) { + auto splitStrVec = StringUtils::split(strToSplit.getAsString(), separator.getAsString()); + result = ListVector::addList(&resultVector, splitStrVec.size()); + for (auto i = 0u; i < result.size; i++) { + ListVector::getDataVector(&resultVector)->setValue(result.offset + i, splitStrVec[i]); + } + } +}; + +static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { + return FunctionBindData::getSimpleBindData(input.arguments, + LogicalType::LIST(LogicalType::STRING())); +} + +function_set StringSplitFunction::getFunctionSet() { + function_set functionSet; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::STRING, LogicalTypeID::STRING}, + LogicalTypeID::LIST, + ScalarFunction::BinaryStringExecFunction); + function->bindFunc = bindFunc; + functionSet.emplace_back(std::move(function)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/struct/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/function/struct/CMakeLists.txt new file mode 100644 index 0000000000..9d7fc3fc54 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/struct/CMakeLists.txt @@ -0,0 +1,9 @@ +add_library(lbug_function_struct + OBJECT + struct_extract_function.cpp + struct_pack_function.cpp + keys_function.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/struct/keys_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/struct/keys_function.cpp new file mode 100644 index 0000000000..8a10992012 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/struct/keys_function.cpp @@ -0,0 +1,53 @@ +#include "binder/expression/expression_util.h" +#include "binder/expression/literal_expression.h" +#include 
"binder/expression/scalar_function_expression.h" +#include "binder/expression_binder.h" +#include "function/rewrite_function.h" +#include "function/struct/vector_struct_functions.h" + +using namespace lbug::common; +using namespace lbug::binder; + +namespace lbug { +namespace function { + +static std::shared_ptr rewriteFunc(const RewriteFunctionBindInput& input) { + KU_ASSERT(input.arguments.size() == 1); + auto argument = input.arguments[0].get(); + auto expressionBinder = input.expressionBinder; + if (ExpressionUtil::isNullLiteral(*argument)) { + return expressionBinder->createNullLiteralExpression(); + } + auto uniqueExpressionName = + ScalarFunctionExpression::getUniqueName(KeysFunctions::name, input.arguments); + const auto& resultType = LogicalType::LIST(LogicalType::STRING()); + auto fields = common::StructType::getFieldNames(input.arguments[0]->dataType); + std::vector> children; + for (auto field : fields) { + if (field == InternalKeyword::ID || field == InternalKeyword::LABEL || + field == InternalKeyword::SRC || field == InternalKeyword::DST) { + continue; + } + children.push_back(std::make_unique(field)); + } + auto resultExpr = std::make_shared( + Value{resultType.copy(), std::move(children)}, std::move(uniqueExpressionName)); + return resultExpr; +} + +static std::unique_ptr getKeysFunction(LogicalTypeID logicalTypeID) { + return std::make_unique(KeysFunctions::name, + std::vector{logicalTypeID}, rewriteFunc); +} + +function_set KeysFunctions::getFunctionSet() { + function_set functions; + auto inputTypeIDs = std::vector{LogicalTypeID::NODE, LogicalTypeID::REL}; + for (auto inputTypeID : inputTypeIDs) { + functions.push_back(getKeysFunction(inputTypeID)); + } + return functions; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/struct/struct_extract_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/struct/struct_extract_function.cpp new file mode 100644 index 0000000000..d834f67765 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/struct/struct_extract_function.cpp @@ -0,0 +1,61 @@ +#include "binder/expression/expression_util.h" +#include "binder/expression/literal_expression.h" +#include "common/exception/binder.h" +#include "function/scalar_function.h" +#include "function/struct/vector_struct_functions.h" + +using namespace lbug::common; +using namespace lbug::binder; + +namespace lbug { +namespace function { + +std::unique_ptr StructExtractFunctions::bindFunc( + const ScalarBindFuncInput& input) { + const auto& structType = input.arguments[0]->getDataType(); + if (input.arguments[1]->expressionType != ExpressionType::LITERAL) { + throw BinderException("Key name for struct/union extract must be STRING literal."); + } + auto key = + input.arguments[1]->constPtrCast()->getValue().getValue(); + auto fieldIdx = StructType::getFieldIdx(structType, key); + if (fieldIdx == INVALID_STRUCT_FIELD_IDX) { + throw BinderException(stringFormat("Invalid struct field name: {}.", key)); + } + auto paramTypes = ExpressionUtil::getDataTypes(input.arguments); + auto resultType = StructType::getField(structType, fieldIdx).getType().copy(); + auto bindData = std::make_unique(std::move(resultType), fieldIdx); + bindData->paramTypes.push_back(input.arguments[0]->getDataType().copy()); + bindData->paramTypes.push_back(LogicalType(input.definition->parameterTypeIDs[1])); + return bindData; +} + +void StructExtractFunctions::compileFunc(FunctionBindData* bindData, + const std::vector>& parameters, + std::shared_ptr& 
result) { + KU_ASSERT(parameters[0]->dataType.getPhysicalType() == PhysicalTypeID::STRUCT); + auto& structBindData = bindData->cast(); + result = StructVector::getFieldVector(parameters[0].get(), structBindData.childIdx); + result->state = parameters[0]->state; +} + +static std::unique_ptr getStructExtractFunction(LogicalTypeID logicalTypeID) { + auto function = std::make_unique(StructExtractFunctions::name, + std::vector{logicalTypeID, LogicalTypeID::STRING}, LogicalTypeID::ANY); + function->bindFunc = StructExtractFunctions::bindFunc; + function->compileFunc = StructExtractFunctions::compileFunc; + return function; +} + +function_set StructExtractFunctions::getFunctionSet() { + function_set functions; + auto inputTypeIDs = + std::vector{LogicalTypeID::STRUCT, LogicalTypeID::NODE, LogicalTypeID::REL}; + for (auto inputTypeID : inputTypeIDs) { + functions.push_back(getStructExtractFunction(inputTypeID)); + } + return functions; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/struct/struct_pack_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/struct/struct_pack_function.cpp new file mode 100644 index 0000000000..bea2928ab1 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/struct/struct_pack_function.cpp @@ -0,0 +1,140 @@ +#include "common/exception/binder.h" +#include "function/scalar_function.h" +#include "function/struct/vector_struct_functions.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { + std::vector fields; + if (input.arguments.size() > INVALID_STRUCT_FIELD_IDX - 1) { + throw BinderException(stringFormat("Too many fields in STRUCT literal (max {}, got {})", + INVALID_STRUCT_FIELD_IDX - 1, input.arguments.size())); + } + std::unordered_set fieldNameSet; + for (auto i = 0u; i < input.arguments.size(); i++) { + auto& argument = input.arguments[i]; + if (argument->getDataType().getLogicalTypeID() == LogicalTypeID::ANY) { + argument->cast(LogicalType::STRING()); + } + if (i >= input.optionalArguments.size()) { + throw BinderException( + stringFormat("Cannot infer field name for {}.", argument->toString())); + } + auto fieldName = input.optionalArguments[i]; + if (fieldNameSet.contains(fieldName)) { + throw BinderException(stringFormat("Found duplicate field {} in STRUCT.", fieldName)); + } else { + fieldNameSet.insert(fieldName); + } + fields.emplace_back(fieldName, argument->getDataType().copy()); + } + const auto resultType = LogicalType::STRUCT(std::move(fields)); + return FunctionBindData::getSimpleBindData(input.arguments, resultType); +} + +void StructPackFunctions::compileFunc(FunctionBindData* /*bindData*/, + const std::vector>& parameters, + std::shared_ptr& result) { + // Our goal is to make the state of the resultVector consistent with its children vectors. + // If the resultVector and inputVector are in different dataChunks, we should create a new + // child valueVector, which shares the state with the resultVector, instead of reusing the + // inputVector. 
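+    // In practice: a parameter that already shares the result's DataChunkState is simply
+    // referenced as the corresponding struct field vector below; a parameter with a different
+    // state keeps the pre-allocated field vector, and execFunc later copies its values into it.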
+ for (auto i = 0u; i < parameters.size(); i++) { + if (parameters[i]->state == result->state) { + StructVector::referenceVector(result.get(), i, parameters[i]); + } + } +} + +void StructPackFunctions::undirectedRelCompileFunc(FunctionBindData*, + const std::vector>& parameters, + std::shared_ptr& result) { + // Skip src and dst reference because we may change their state + for (auto i = 2u; i < parameters.size(); i++) { + if (parameters[i]->state == result->state) { + StructVector::referenceVector(result.get(), i, parameters[i]); + } + } +} + +static void copyParameterValueToStructFieldVector(const ValueVector* parameter, + ValueVector* structField, DataChunkState* structVectorState) { + // If the parameter is unFlat, then its state must be consistent with the result's state. + // Thus, we don't need to copy values to structFieldVector. + KU_ASSERT(parameter->state->isFlat()); + auto paramPos = parameter->state->getSelVector()[0]; + if (structVectorState->isFlat()) { + auto pos = structVectorState->getSelVector()[0]; + structField->copyFromVectorData(pos, parameter, paramPos); + } else { + for (auto i = 0u; i < structVectorState->getSelVector().getSelSize(); i++) { + auto pos = structVectorState->getSelVector()[i]; + structField->copyFromVectorData(pos, parameter, paramPos); + } + } +} + +void StructPackFunctions::execFunc( + const std::vector>& parameters, + const std::vector& parameterSelVectors, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* /*dataPtr*/) { + for (auto i = 0u; i < parameters.size(); i++) { + auto* parameter = parameters[i].get(); + auto* parameterSelVector = parameterSelVectors[i]; + if (parameterSelVector == resultSelVector) { + continue; + } + // If the parameter's state is inconsistent with the result's state, we need to copy the + // parameter's value to the corresponding child vector. + StructVector::getFieldVector(&result, i)->resetAuxiliaryBuffer(); + copyParameterValueToStructFieldVector(parameter, + StructVector::getFieldVector(&result, i).get(), result.state.get()); + } +} + +void StructPackFunctions::undirectedRelPackExecFunc( + const std::vector>& parameters, ValueVector& result, void*) { + KU_ASSERT(parameters.size() > 1); + // Force copy of the src and internal id child vectors because we might modify them later. + for (auto i = 0u; i < 2; i++) { + auto& parameter = parameters[i]; + auto fieldVector = StructVector::getFieldVector(&result, i).get(); + fieldVector->resetAuxiliaryBuffer(); + if (parameter->state->isFlat()) { + copyParameterValueToStructFieldVector(parameter.get(), fieldVector, result.state.get()); + } else { + for (auto j = 0u; j < result.state->getSelVector().getSelSize(); j++) { + auto pos = result.state->getSelVector()[j]; + fieldVector->copyFromVectorData(pos, parameter.get(), pos); + } + } + } + for (auto i = 2u; i < parameters.size(); i++) { + auto& parameter = parameters[i]; + if (parameter->state == result.state) { + continue; + } + // If the parameter's state is inconsistent with the result's state, we need to copy the + // parameter's value to the corresponding child vector. 
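+        // Fields 0 and 1 (src and internal id) were force-copied above; the remaining fields
+        // only need a copy when their state differs from the result's.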
+ StructVector::getFieldVector(&result, i)->resetAuxiliaryBuffer(); + copyParameterValueToStructFieldVector(parameter.get(), + StructVector::getFieldVector(&result, i).get(), result.state.get()); + } +} + +function_set StructPackFunctions::getFunctionSet() { + function_set functions; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::ANY}, LogicalTypeID::STRUCT, execFunc); + function->bindFunc = bindFunc; + function->compileFunc = compileFunc; + function->isVarLength = true; + functions.push_back(std::move(function)); + return functions; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/CMakeLists.txt new file mode 100644 index 0000000000..33eb35c4f3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/CMakeLists.txt @@ -0,0 +1,37 @@ +add_library(lbug_table_function + OBJECT + bind_data.cpp + bind_input.cpp + bm_info.cpp + cache_column.cpp + catalog_version.cpp + clear_warnings.cpp + current_setting.cpp + db_version.cpp + drop_project_graph.cpp + file_info.cpp + free_space_info.cpp + project_cypher_graph.cpp + project_native_graph.cpp + show_attached_databases.cpp + show_connection.cpp + show_functions.cpp + show_indexes.cpp + show_loaded_extensions.cpp + show_macros.cpp + show_official_extensions.cpp + show_projected_graphs.cpp + show_sequences.cpp + show_tables.cpp + show_warnings.cpp + stats_info.cpp + storage_info.cpp + simple_table_function.cpp + table_function.cpp + table_info.cpp + projected_graph_info.cpp + ) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/bind_data.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/bind_data.cpp new file mode 100644 index 0000000000..7db0a328bb --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/bind_data.cpp @@ -0,0 +1,28 @@ +#include "function/table/bind_data.h" + +#include "common/constants.h" + +namespace lbug { +namespace function { + +std::vector TableFuncBindData::getColumnSkips() const { + if (columnSkips.empty()) { // If not specified, all columns should be scanned. 
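+        // Build an all-false skip vector sized to the number of output columns.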
+        std::vector<bool> skips;
+        for (auto i = 0u; i < getNumColumns(); ++i) {
+            skips.push_back(false);
+        }
+        return skips;
+    }
+    return columnSkips;
+}
+
+bool TableFuncBindData::getIgnoreErrorsOption() const {
+    return common::CopyConstants::DEFAULT_IGNORE_ERRORS;
+}
+
+std::unique_ptr<TableFuncBindData> TableFuncBindData::copy() const {
+    return std::make_unique<TableFuncBindData>(*this);
+}
+
+} // namespace function
+} // namespace lbug
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/bind_input.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/bind_input.cpp
new file mode 100644
index 0000000000..8e9df821f5
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/bind_input.cpp
@@ -0,0 +1,31 @@
+#include "function/table/bind_input.h"
+
+#include "binder/expression/expression_util.h"
+#include "binder/expression/literal_expression.h"
+
+namespace lbug {
+namespace function {
+
+void TableFuncBindInput::addLiteralParam(common::Value value) {
+    params.push_back(std::make_shared<binder::LiteralExpression>(std::move(value), ""));
+}
+
+common::Value TableFuncBindInput::getValue(common::idx_t idx) const {
+    binder::ExpressionUtil::validateExpressionType(*params[idx], common::ExpressionType::LITERAL);
+    return params[idx]->constCast<binder::LiteralExpression>().getValue();
+}
+
+template<typename T>
+T TableFuncBindInput::getLiteralVal(common::idx_t idx) const {
+    return getValue(idx).getValue<T>();
+}
+
+template LBUG_API std::string TableFuncBindInput::getLiteralVal<std::string>(
+    common::idx_t idx) const;
+template LBUG_API int64_t TableFuncBindInput::getLiteralVal<int64_t>(common::idx_t idx) const;
+template LBUG_API uint64_t TableFuncBindInput::getLiteralVal<uint64_t>(common::idx_t idx) const;
+template LBUG_API uint32_t TableFuncBindInput::getLiteralVal<uint32_t>(common::idx_t idx) const;
+template LBUG_API uint8_t* TableFuncBindInput::getLiteralVal<uint8_t*>(common::idx_t idx) const;
+
+} // namespace function
+} // namespace lbug
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/bm_info.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/bm_info.cpp
new file mode 100644
index 0000000000..262d6a038b
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/bm_info.cpp
@@ -0,0 +1,58 @@
+#include "binder/binder.h"
+#include "function/table/bind_data.h"
+#include "function/table/simple_table_function.h"
+#include "main/client_context.h"
+#include "storage/buffer_manager/buffer_manager.h"
+#include "storage/buffer_manager/memory_manager.h"
+
+namespace lbug {
+namespace function {
+
+struct BMInfoBindData final : TableFuncBindData {
+    uint64_t memLimit;
+    uint64_t memUsage;
+
+    BMInfoBindData(uint64_t memLimit, uint64_t memUsage, binder::expression_vector columns)
+        : TableFuncBindData{std::move(columns), 1}, memLimit{memLimit}, memUsage{memUsage} {}
+
+    std::unique_ptr<TableFuncBindData> copy() const override {
+        return std::make_unique<BMInfoBindData>(memLimit, memUsage, columns);
+    }
+};
+
+static common::offset_t internalTableFunc(const TableFuncMorsel& /*morsel*/,
+    const TableFuncInput& input, common::DataChunk& output) {
+    KU_ASSERT(output.getNumValueVectors() == 2);
+    auto bmInfoBindData = input.bindData->constPtrCast<BMInfoBindData>();
+    output.getValueVectorMutable(0).setValue(0, bmInfoBindData->memLimit);
+    output.getValueVectorMutable(1).setValue(0, bmInfoBindData->memUsage);
+    return 1;
+}
+
+static std::unique_ptr<TableFuncBindData> bindFunc(const main::ClientContext* context,
+    const TableFuncBindInput* input) {
+    auto memLimit = storage::MemoryManager::Get(*context)->getBufferManager()->getMemoryLimit();
+    auto memUsage = storage::MemoryManager::Get(*context)->getBufferManager()->getUsedMemory();
+    std::vector<common::LogicalType> returnTypes;
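+    // bm_info() produces a single row with two UINT64 columns, mem_limit and mem_usage,
+    // both read from the buffer manager at bind time.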
returnTypes.emplace_back(common::LogicalType::UINT64()); + returnTypes.emplace_back(common::LogicalType::UINT64()); + auto returnColumnNames = std::vector{"mem_limit", "mem_usage"}; + returnColumnNames = + TableFunction::extractYieldVariables(returnColumnNames, input->yieldVariables); + auto columns = input->binder->createVariables(returnColumnNames, returnTypes); + return std::make_unique(memLimit, memUsage, columns); +} + +function_set BMInfoFunction::getFunctionSet() { + function_set functionSet; + auto function = std::make_unique(name, std::vector{}); + function->tableFunc = SimpleTableFunc::getTableFunc(internalTableFunc); + function->bindFunc = bindFunc; + function->initSharedStateFunc = SimpleTableFunc::initSharedState; + function->initLocalStateFunc = TableFunction::initEmptyLocalState; + functionSet.push_back(std::move(function)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/cache_column.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/cache_column.cpp new file mode 100644 index 0000000000..1eb886968f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/cache_column.cpp @@ -0,0 +1,209 @@ +#include "binder/binder.h" +#include "catalog/catalog.h" +#include "catalog/catalog_entry/table_catalog_entry.h" +#include "common/exception/binder.h" +#include "function/table/bind_data.h" +#include "function/table/bind_input.h" +#include "function/table/simple_table_function.h" +#include "processor/execution_context.h" +#include "storage/local_cached_column.h" +#include "storage/storage_manager.h" +#include "storage/table/list_chunk_data.h" +#include "storage/table/node_table.h" +#include "storage/table/table.h" +#include "transaction/transaction.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +struct CacheArrayColumnBindData final : TableFuncBindData { + catalog::TableCatalogEntry* tableEntry; + property_id_t propertyID; + + CacheArrayColumnBindData(catalog::TableCatalogEntry* tableEntry, property_id_t propertyID) + : tableEntry{tableEntry}, propertyID{propertyID} {} + + std::unique_ptr copy() const override { + return std::make_unique(tableEntry, propertyID); + } +}; + +static void validateArrayColumnType(const catalog::TableCatalogEntry* entry, + property_id_t propertyID) { + auto& type = entry->getProperty(propertyID).getType(); + if (type.getLogicalTypeID() != LogicalTypeID::ARRAY) { + throw BinderException{stringFormat("Column {} is not of the expected type {}.", + entry->getProperty(propertyID).getName(), + LogicalTypeUtils::toString(LogicalTypeID::ARRAY))}; + } +} + +static std::unique_ptr bindFunc(main::ClientContext* context, + const TableFuncBindInput* input) { + const auto tableName = input->getLiteralVal(0); + const auto columnName = input->getLiteralVal(1); + binder::Binder::validateTableExistence(*context, tableName); + const auto tableEntry = catalog::Catalog::Get(*context)->getTableCatalogEntry( + transaction::Transaction::Get(*context), tableName); + binder::Binder::validateNodeTableType(tableEntry); + binder::Binder::validateColumnExistence(tableEntry, columnName); + auto propertyID = tableEntry->getPropertyID(columnName); + validateArrayColumnType(tableEntry, propertyID); + return std::make_unique(tableEntry, propertyID); +} + +struct CacheArrayColumnSharedState final : public SimpleTableFuncSharedState { + explicit CacheArrayColumnSharedState(storage::NodeTable& table, + node_group_idx_t maxNodeGroupIdx, const 
CacheArrayColumnBindData& bindData) + : SimpleTableFuncSharedState{maxNodeGroupIdx, 1 /*maxMorselSize*/}, table{table} { + cachedColumn = std::make_unique(bindData.tableEntry->getTableID(), + bindData.propertyID); + cachedColumn->columnChunks.resize(maxNodeGroupIdx + 1); + } + + void merge(node_group_idx_t nodeGroupIdx, + std::unique_ptr columnChunkData) { + std::unique_lock lck{mtx}; + KU_ASSERT(cachedColumn->columnChunks.size() > nodeGroupIdx); + cachedColumn->columnChunks[nodeGroupIdx] = std::move(columnChunkData); + ++numNodeGroupsCached; + } + + std::mutex mtx; + storage::NodeTable& table; + std::unique_ptr cachedColumn; + std::atomic numNodeGroupsCached; +}; + +static std::unique_ptr initSharedState( + const TableFuncInitSharedStateInput& input) { + const auto bindData = input.bindData->constPtrCast(); + auto& table = storage::StorageManager::Get(*input.context->clientContext) + ->getTable(bindData->tableEntry->getTableID()) + ->cast(); + return std::make_unique(table, table.getNumCommittedNodeGroups(), + *bindData); +} + +struct CacheArrayColumnLocalState final : TableFuncLocalState { + CacheArrayColumnLocalState(const main::ClientContext& context, table_id_t tableID, + column_id_t columnID) + : dataChunk{2, std::make_shared()} { + auto& table = + storage::StorageManager::Get(context)->getTable(tableID)->cast(); + dataChunk.insert(0, std::make_shared(LogicalType::INTERNAL_ID())); + dataChunk.insert(1, + std::make_shared(table.getColumn(columnID).getDataType().copy())); + std::vector columnIDs; + columnIDs.push_back(columnID); + scanState = + std::make_unique(&dataChunk.getValueVectorMutable(0), + std::vector{&dataChunk.getValueVectorMutable(1)}, dataChunk.state); + scanState->source = storage::TableScanSource::COMMITTED; + scanState->setToTable(transaction::Transaction::Get(context), &table, columnIDs, {}); + } + + DataChunk dataChunk; + std::unique_ptr scanState; +}; + +static std::unique_ptr initLocalState( + const TableFuncInitLocalStateInput& input) { + const auto bindData = input.bindData.constPtrCast(); + auto tableID = bindData->tableEntry->getTableID(); + auto columnID = bindData->tableEntry->getColumnID(bindData->propertyID); + return std::make_unique(*input.clientContext, tableID, columnID); +} + +static void scanTableDataToChunk(const node_group_idx_t nodeGroupIdx, + storage::NodeTableScanState& scanState, storage::ColumnChunkData* data, + transaction::Transaction* transaction, storage::NodeTable& table) { + scanState.nodeGroupIdx = nodeGroupIdx; + table.initScanState(transaction, scanState); + + // We want to ensure that the offsets in the cached column match the offsets in the + // table + // To do this we write to the same offsets and set any non-selected (e.g. 
deleted) + // rows to null + data->getNullData()->resetToAllNull(); + while (table.scan(transaction, scanState)) { + const auto& selVector = scanState.outState->getSelVector(); + selVector.forEach([&](auto vectorIdx) { + const auto dataOffsetInGroup = + scanState.nodeIDVector->getValue(vectorIdx).offset - + storage::StorageUtils::getStartOffsetOfNodeGroup(nodeGroupIdx); + data->write(scanState.outputVectors[0], vectorIdx, dataOffsetInGroup); + }); + } +} + +static offset_t tableFunc(const TableFuncInput& input, TableFuncOutput&) { + auto& bindData = input.bindData->cast(); + const auto sharedState = input.sharedState->ptrCast(); + auto localState = input.localState->ptrCast(); + const auto morsel = sharedState->getMorsel(); + if (morsel.isInvalid()) { + return 0; + } + auto context = input.context->clientContext; + auto columnType = bindData.tableEntry->getProperty(bindData.propertyID).getType().copy(); + auto& table = sharedState->table; + auto& scanState = *localState->scanState; + for (auto i = morsel.startOffset; i < morsel.endOffset; i++) { + auto numRows = table.getNumTuplesInNodeGroup(i); + auto data = storage::ColumnChunkFactory::createColumnChunkData( + *storage::MemoryManager::Get(*context), columnType.copy(), false /*enableCompression*/, + numRows, storage::ResidencyState::IN_MEMORY, true /*hasNullData*/, + false /*initializeToZero*/); + if (columnType.getLogicalTypeID() == LogicalTypeID::ARRAY) { + auto arrayTypeInfo = columnType.getExtraTypeInfo()->constPtrCast(); + data->cast().getDataColumnChunk()->resize( + numRows * arrayTypeInfo->getNumElements()); + } + scanTableDataToChunk(i, scanState, data.get(), transaction::Transaction::Get(*context), + table); + sharedState->merge(i, std::move(data)); + } + return morsel.endOffset - morsel.startOffset; +} + +static double progressFunc(TableFuncSharedState* sharedState) { + const auto cacheColumnSharedState = sharedState->ptrCast(); + const auto numNodeGroupsCached = cacheColumnSharedState->numNodeGroupsCached.load(); + if (cacheColumnSharedState->numRows == 0) { + return 1.0; + } + if (numNodeGroupsCached == 0) { + return 0.0; + } + return static_cast(numNodeGroupsCached) / cacheColumnSharedState->numRows; +} + +static void finalizeFunc(const processor::ExecutionContext* context, + TableFuncSharedState* sharedState) { + auto transaction = transaction::Transaction::Get(*context->clientContext); + auto cacheColumnSharedState = sharedState->ptrCast(); + auto& localCacheManager = transaction->getLocalCacheManager(); + localCacheManager.put(std::move(cacheColumnSharedState->cachedColumn)); +} + +function_set LocalCacheArrayColumnFunction::getFunctionSet() { + function_set functionSet; + std::vector inputTypes = {LogicalTypeID::STRING, LogicalTypeID::STRING}; + auto func = std::make_unique(name, inputTypes); + func->bindFunc = bindFunc; + func->initSharedStateFunc = initSharedState; + func->initLocalStateFunc = initLocalState; + func->tableFunc = tableFunc; + func->finalizeFunc = finalizeFunc; + func->canParallelFunc = [] { return true; }; + func->progressFunc = progressFunc; + func->isReadOnly = false; + functionSet.push_back(std::move(func)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/catalog_version.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/catalog_version.cpp new file mode 100644 index 0000000000..60fa477a80 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/catalog_version.cpp @@ -0,0 +1,43 @@ +#include 
"binder/binder.h" +#include "catalog/catalog.h" +#include "function/table/bind_data.h" +#include "function/table/simple_table_function.h" +#include "main/client_context.h" +#include "processor/execution_context.h" + +namespace lbug { +namespace function { + +static common::offset_t internalTableFunc(const TableFuncMorsel& /*morsel*/, + const TableFuncInput& input, common::DataChunk& output) { + auto& outputVector = output.getValueVectorMutable(0); + auto pos = outputVector.state->getSelVector()[0]; + outputVector.setValue(pos, catalog::Catalog::Get(*input.context->clientContext)->getVersion()); + return 1; +} + +static std::unique_ptr bindFunc(const main::ClientContext*, + const TableFuncBindInput* input) { + std::vector returnColumnNames; + std::vector returnTypes; + returnColumnNames.emplace_back("version"); + returnTypes.emplace_back(common::LogicalType::INT64()); + returnColumnNames = + TableFunction::extractYieldVariables(returnColumnNames, input->yieldVariables); + auto columns = input->binder->createVariables(returnColumnNames, returnTypes); + return std::make_unique(std::move(columns), 1 /* one row result */); +} + +function_set CatalogVersionFunction::getFunctionSet() { + function_set functionSet; + auto function = std::make_unique(name, std::vector{}); + function->tableFunc = SimpleTableFunc::getTableFunc(internalTableFunc); + function->bindFunc = bindFunc; + function->initSharedStateFunc = SimpleTableFunc::initSharedState; + function->initLocalStateFunc = TableFunction::initEmptyLocalState; + functionSet.push_back(std::move(function)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/clear_warnings.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/clear_warnings.cpp new file mode 100644 index 0000000000..7615e0126f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/clear_warnings.cpp @@ -0,0 +1,37 @@ +#include "function/table/bind_data.h" +#include "function/table/standalone_call_function.h" +#include "function/table/table_function.h" +#include "processor/execution_context.h" +#include "processor/warning_context.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +static offset_t tableFunc(const TableFuncInput& input, TableFuncOutput&) { + auto warningContext = processor::WarningContext::Get(*input.context->clientContext); + warningContext->clearPopulatedWarnings(); + return 0; +} + +static std::unique_ptr bindFunc(const main::ClientContext*, + const TableFuncBindInput*) { + return std::make_unique(0); +} + +function_set ClearWarningsFunction::getFunctionSet() { + function_set functionSet; + auto func = std::make_unique(name, std::vector{}); + func->tableFunc = tableFunc; + func->bindFunc = bindFunc; + func->initSharedStateFunc = TableFunction::initEmptySharedState; + func->initLocalStateFunc = TableFunction::initEmptyLocalState; + func->canParallelFunc = []() { return false; }; + func->isReadOnly = false; + functionSet.push_back(std::move(func)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/current_setting.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/current_setting.cpp new file mode 100644 index 0000000000..50a137fd6e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/current_setting.cpp @@ -0,0 +1,58 @@ +#include "binder/binder.h" +#include "function/table/bind_data.h" +#include "function/table/bind_input.h" 
+#include "function/table/simple_table_function.h" +#include "main/client_context.h" + +using namespace lbug::common; +using namespace lbug::main; + +namespace lbug { +namespace function { + +struct CurrentSettingBindData final : TableFuncBindData { + std::string result; + + CurrentSettingBindData(std::string result, binder::expression_vector columns, + offset_t maxOffset) + : TableFuncBindData{std::move(columns), maxOffset}, result{std::move(result)} {} + + std::unique_ptr copy() const override { + return std::make_unique(result, columns, numRows); + } +}; + +static offset_t internalTableFunc(const TableFuncMorsel& /*morsel*/, const TableFuncInput& input, + common::DataChunk& output) { + auto currentSettingBindData = input.bindData->constPtrCast(); + const auto pos = output.state->getSelVector()[0]; + output.getValueVectorMutable(0).setValue(pos, currentSettingBindData->result); + return 1; +} + +static std::unique_ptr bindFunc(const ClientContext* context, + const TableFuncBindInput* input) { + auto optionName = input->getLiteralVal(0); + std::vector columnNames; + std::vector columnTypes; + columnNames.emplace_back(optionName); + columnTypes.push_back(LogicalType::STRING()); + columnNames = TableFunction::extractYieldVariables(columnNames, input->yieldVariables); + auto columns = input->binder->createVariables(columnNames, columnTypes); + return std::make_unique( + context->getCurrentSetting(optionName).toString(), columns, 1 /* one row result */); +} + +function_set CurrentSettingFunction::getFunctionSet() { + function_set functionSet; + auto function = std::make_unique(name, std::vector{LogicalTypeID::STRING}); + function->tableFunc = SimpleTableFunc::getTableFunc(internalTableFunc); + function->bindFunc = bindFunc; + function->initSharedStateFunc = SimpleTableFunc::initSharedState; + function->initLocalStateFunc = TableFunction::initEmptyLocalState; + functionSet.push_back(std::move(function)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/db_version.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/db_version.cpp new file mode 100644 index 0000000000..1b29c786e3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/db_version.cpp @@ -0,0 +1,44 @@ +#include "binder/binder.h" +#include "function/table/bind_data.h" +#include "function/table/bind_input.h" +#include "function/table/simple_table_function.h" + +using namespace lbug::common; +using namespace lbug::main; + +namespace lbug { +namespace function { + +static offset_t internalTableFunc(const TableFuncMorsel& /*morsel*/, + const TableFuncInput& /*input*/, DataChunk& output) { + auto& outputVector = output.getValueVectorMutable(0); + auto pos = output.state->getSelVector()[0]; + outputVector.setValue(pos, std::string(LBUG_VERSION)); + return 1; +} + +static std::unique_ptr bindFunc(const ClientContext*, + const TableFuncBindInput* input) { + std::vector returnColumnNames; + std::vector returnTypes; + returnColumnNames.emplace_back("version"); + returnTypes.emplace_back(LogicalType::STRING()); + returnColumnNames = + TableFunction::extractYieldVariables(returnColumnNames, input->yieldVariables); + auto columns = input->binder->createVariables(returnColumnNames, returnTypes); + return std::make_unique(std::move(columns), 1 /* one row result */); +} + +function_set DBVersionFunction::getFunctionSet() { + function_set functionSet; + auto function = std::make_unique(name, std::vector{}); + function->tableFunc = 
SimpleTableFunc::getTableFunc(internalTableFunc); + function->bindFunc = bindFunc; + function->initSharedStateFunc = SimpleTableFunc::initSharedState; + function->initLocalStateFunc = TableFunction::initEmptyLocalState; + functionSet.push_back(std::move(function)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/drop_project_graph.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/drop_project_graph.cpp new file mode 100644 index 0000000000..f2305afa34 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/drop_project_graph.cpp @@ -0,0 +1,51 @@ +#include "function/table/bind_data.h" +#include "function/table/bind_input.h" +#include "function/table/standalone_call_function.h" +#include "function/table/table_function.h" +#include "graph/graph_entry_set.h" +#include "processor/execution_context.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +struct DropProjectedGraphBindData final : TableFuncBindData { + std::string graphName; + + explicit DropProjectedGraphBindData(std::string graphName) + : TableFuncBindData{0}, graphName{std::move(graphName)} {} + + std::unique_ptr copy() const override { + return std::make_unique(graphName); + } +}; + +static offset_t tableFunc(const TableFuncInput& input, TableFuncOutput&) { + const auto bindData = ku_dynamic_cast(input.bindData); + auto graphEntrySet = graph::GraphEntrySet::Get(*input.context->clientContext); + graphEntrySet->validateGraphExist(bindData->graphName); + graphEntrySet->dropGraph(bindData->graphName); + return 0; +} + +static std::unique_ptr bindFunc(const main::ClientContext*, + const TableFuncBindInput* input) { + auto graphName = input->getLiteralVal(0 /* maxOffset */); + return std::make_unique(graphName); +} + +function_set DropProjectedGraphFunction::getFunctionSet() { + function_set functionSet; + auto func = std::make_unique(name, std::vector{LogicalTypeID::STRING}); + func->bindFunc = bindFunc; + func->tableFunc = tableFunc; + func->initSharedStateFunc = TableFunction::initEmptySharedState; + func->initLocalStateFunc = TableFunction::initEmptyLocalState; + func->canParallelFunc = []() { return false; }; + functionSet.push_back(std::move(func)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/file_info.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/file_info.cpp new file mode 100644 index 0000000000..13af0d0012 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/file_info.cpp @@ -0,0 +1,53 @@ +#include "binder/binder.h" +#include "function/table/bind_data.h" +#include "function/table/simple_table_function.h" +#include "main/client_context.h" +#include "storage/storage_manager.h" + +namespace lbug { +namespace function { + +struct FileInfoBindData final : TableFuncBindData { + uint64_t numPages; + + FileInfoBindData(uint64_t numPages, binder::expression_vector columns) + : TableFuncBindData{std::move(columns), 1}, numPages(numPages) {} + + std::unique_ptr copy() const override { + return std::make_unique(numPages, columns); + } +}; + +static common::offset_t internalTableFunc(const TableFuncMorsel& /*morsel*/, + const TableFuncInput& input, common::DataChunk& output) { + KU_ASSERT(output.getNumValueVectors() == 1); + auto fileInfoBindData = input.bindData->constPtrCast(); + output.getValueVectorMutable(0).setValue(0, fileInfoBindData->numPages); + return 1; +} + +static 
std::unique_ptr bindFunc(const main::ClientContext* context, + const TableFuncBindInput* input) { + auto numPages = storage::StorageManager::Get(*context)->getDataFH()->getNumPages(); + std::vector returnTypes; + returnTypes.emplace_back(common::LogicalType::UINT64()); + auto returnColumnNames = std::vector{"num_pages"}; + returnColumnNames = + TableFunction::extractYieldVariables(returnColumnNames, input->yieldVariables); + auto columns = input->binder->createVariables(returnColumnNames, returnTypes); + return std::make_unique(numPages, columns); +} + +function_set FileInfoFunction::getFunctionSet() { + function_set functionSet; + auto function = std::make_unique(name, std::vector{}); + function->tableFunc = SimpleTableFunc::getTableFunc(internalTableFunc); + function->bindFunc = bindFunc; + function->initSharedStateFunc = SimpleTableFunc::initSharedState; + function->initLocalStateFunc = TableFunction::initEmptyLocalState; + functionSet.push_back(std::move(function)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/free_space_info.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/free_space_info.cpp new file mode 100644 index 0000000000..6c29595594 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/free_space_info.cpp @@ -0,0 +1,57 @@ +#include "binder/binder.h" +#include "function/table/bind_data.h" +#include "function/table/simple_table_function.h" +#include "main/client_context.h" +#include "storage/page_manager.h" + +namespace lbug { +namespace function { + +struct FreeSpaceInfoBindData final : TableFuncBindData { + const main::ClientContext* ctx; + FreeSpaceInfoBindData(binder::expression_vector columns, common::row_idx_t numRows, + const main::ClientContext* ctx) + : TableFuncBindData{std::move(columns), numRows}, ctx{ctx} {} + + std::unique_ptr copy() const override { + return std::make_unique(columns, numRows, ctx); + } +}; + +static common::offset_t internalTableFunc(const TableFuncMorsel& morsel, + const TableFuncInput& input, common::DataChunk& output) { + const auto bindData = input.bindData->constPtrCast(); + const auto entries = storage::PageManager::Get(*bindData->ctx) + ->getFreeEntries(morsel.startOffset, morsel.endOffset); + for (common::row_idx_t i = 0; i < entries.size(); ++i) { + const auto& freeEntry = entries[i]; + output.getValueVectorMutable(0).setValue(i, freeEntry.startPageIdx); + output.getValueVectorMutable(1).setValue(i, freeEntry.numPages); + } + return entries.size(); +} + +static std::unique_ptr bindFunc(const main::ClientContext* context, + const TableFuncBindInput* input) { + std::vector columnNames = {"start_page_idx", "num_pages"}; + std::vector columnTypes; + columnTypes.push_back(common::LogicalType::UINT64()); + columnTypes.push_back(common::LogicalType::UINT64()); + auto columns = input->binder->createVariables(columnNames, columnTypes); + return std::make_unique(columns, + storage::PageManager::Get(*context)->getNumFreeEntries(), context); +} + +function_set FreeSpaceInfoFunction::getFunctionSet() { + function_set functionSet; + auto function = std::make_unique(name, std::vector{}); + function->tableFunc = SimpleTableFunc::getTableFunc(internalTableFunc); + function->bindFunc = bindFunc; + function->initSharedStateFunc = SimpleTableFunc::initSharedState; + function->initLocalStateFunc = TableFunction::initEmptyLocalState; + functionSet.push_back(std::move(function)); + return functionSet; +} + +} // namespace function +} // 
namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/project_cypher_graph.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/project_cypher_graph.cpp new file mode 100644 index 0000000000..4e134487d2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/project_cypher_graph.cpp @@ -0,0 +1,63 @@ +#include "binder/binder.h" +#include "function/table/bind_data.h" +#include "function/table/bind_input.h" +#include "function/table/standalone_call_function.h" +#include "function/table/table_function.h" +#include "graph/graph_entry_set.h" +#include "parser/parser.h" +#include "processor/execution_context.h" + +using namespace lbug::common; +using namespace lbug::graph; + +namespace lbug { +namespace function { + +struct ProjectGraphCypherBindData final : TableFuncBindData { + std::string graphName; + std::string cypherQuery; + + ProjectGraphCypherBindData(std::string graphName, std::string cypherQuery) + : graphName{std::move(graphName)}, cypherQuery{std::move(cypherQuery)} {} + + std::unique_ptr copy() const override { + return std::make_unique(graphName, cypherQuery); + } +}; + +static offset_t tableFunc(const TableFuncInput& input, TableFuncOutput&) { + const auto bindData = ku_dynamic_cast(input.bindData); + auto graphEntrySet = GraphEntrySet::Get(*input.context->clientContext); + graphEntrySet->validateGraphNotExist(bindData->graphName); + // bind graph entry to check if input is valid or not. Ignore bind result. + auto parsedStatements = parser::Parser::parseQuery(bindData->cypherQuery); + KU_ASSERT(parsedStatements.size() == 1); + auto binder = binder::Binder(input.context->clientContext); + binder.bind(*parsedStatements[0]); + auto entry = std::make_unique(bindData->cypherQuery); + graphEntrySet->addGraph(bindData->graphName, std::move(entry)); + return 0; +} + +static std::unique_ptr bindFunc(const main::ClientContext*, + const TableFuncBindInput* input) { + auto graphName = input->getLiteralVal(0); + auto cypherQuery = input->getLiteralVal(1); + return std::make_unique(graphName, cypherQuery); +} + +function_set ProjectGraphCypherFunction::getFunctionSet() { + function_set functionSet; + auto func = std::make_unique(name, + std::vector{LogicalTypeID::STRING, LogicalTypeID::STRING}); + func->bindFunc = bindFunc; + func->tableFunc = tableFunc; + func->initSharedStateFunc = TableFunction::initEmptySharedState; + func->initLocalStateFunc = TableFunction::initEmptyLocalState; + func->canParallelFunc = []() { return false; }; + functionSet.push_back(std::move(func)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/project_native_graph.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/project_native_graph.cpp new file mode 100644 index 0000000000..0aa9705df5 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/project_native_graph.cpp @@ -0,0 +1,98 @@ +#include "common/exception/binder.h" +#include "common/types/value/nested.h" +#include "function/gds/gds.h" +#include "function/table/bind_data.h" +#include "function/table/bind_input.h" +#include "function/table/standalone_call_function.h" +#include "graph/graph_entry_set.h" +#include "parser/parser.h" +#include "processor/execution_context.h" + +using namespace lbug::binder; +using namespace lbug::common; +using namespace lbug::catalog; +using namespace lbug::graph; + +namespace lbug { +namespace function { + +struct ProjectGraphNativeBindData final : TableFuncBindData { + std::string 
graphName; + std::vector nodeInfos; + std::vector relInfos; + + ProjectGraphNativeBindData(std::string graphName, + std::vector nodeInfos, + std::vector relInfos) + : TableFuncBindData{0}, graphName{std::move(graphName)}, nodeInfos{std::move(nodeInfos)}, + relInfos{std::move(relInfos)} {} + + std::unique_ptr copy() const override { + return std::make_unique(graphName, nodeInfos, relInfos); + } +}; + +static offset_t tableFunc(const TableFuncInput& input, TableFuncOutput&) { + const auto bindData = ku_dynamic_cast(input.bindData); + auto graphEntrySet = GraphEntrySet::Get(*input.context->clientContext); + graphEntrySet->validateGraphNotExist(bindData->graphName); + auto entry = std::make_unique(bindData->nodeInfos, bindData->relInfos); + // bind graph entry to check if input is valid or not. Ignore bind result. + GDSFunction::bindGraphEntry(*input.context->clientContext, *entry); + graphEntrySet->addGraph(bindData->graphName, std::move(entry)); + return 0; +} + +static std::string getStringVal(const Value& value) { + value.validateType(LogicalTypeID::STRING); + return value.getValue(); +} + +static std::vector extractGraphEntryTableInfos(const Value& value) { + std::vector infos; + switch (value.getDataType().getLogicalTypeID()) { + case LogicalTypeID::LIST: { + for (auto i = 0u; i < NestedVal::getChildrenSize(&value); ++i) { + auto tableName = getStringVal(*NestedVal::getChildVal(&value, i)); + infos.emplace_back(tableName, "" /* empty predicate */); + } + } break; + case LogicalTypeID::STRUCT: { + for (auto i = 0u; i < StructType::getNumFields(value.getDataType()); ++i) { + auto& field = StructType::getField(value.getDataType(), i); + auto tableName = field.getName(); + auto predicate = getStringVal(*NestedVal::getChildVal(&value, i)); + infos.emplace_back(tableName, predicate); + } + } break; + default: + throw BinderException( + stringFormat("Argument {} has data type {}. 
LIST or STRUCT was expected.", + value.toString(), value.getDataType().toString())); + } + return infos; +} + +static std::unique_ptr bindFunc(const main::ClientContext*, + const TableFuncBindInput* input) { + auto graphName = input->getLiteralVal(0); + auto nodeInfos = extractGraphEntryTableInfos(input->getValue(1)); + auto relInfos = extractGraphEntryTableInfos(input->getValue(2)); + return std::make_unique(graphName, nodeInfos, relInfos); +} + +function_set ProjectGraphNativeFunction::getFunctionSet() { + function_set functionSet; + auto func = std::make_unique(name, + std::vector{LogicalTypeID::STRING, LogicalTypeID::ANY, LogicalTypeID::ANY}); + func->bindFunc = bindFunc; + func->tableFunc = tableFunc; + func->initSharedStateFunc = TableFunction::initEmptySharedState; + func->initLocalStateFunc = TableFunction::initEmptyLocalState; + func->canParallelFunc = []() { return false; }; + functionSet.push_back(std::move(func)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/projected_graph_info.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/projected_graph_info.cpp new file mode 100644 index 0000000000..d6c356c189 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/projected_graph_info.cpp @@ -0,0 +1,170 @@ +#include "binder/binder.h" +#include "common/exception/binder.h" +#include "function/table/bind_data.h" +#include "function/table/bind_input.h" +#include "function/table/simple_table_function.h" +#include "graph/graph_entry_set.h" + +using namespace lbug::common; +using namespace lbug::main; + +namespace lbug { +namespace function { + +struct ProjectedGraphInfo { + virtual ~ProjectedGraphInfo() = default; + + template + const TARGET& constCast() const { + return common::ku_dynamic_cast(*this); + } + + virtual std::unique_ptr copy() const = 0; +}; + +struct ProjectedTableInfo { + std::string tableType; + std::string tableName; + std::string predicate; + + ProjectedTableInfo(std::string tableType, std::string tableName, std::string predicate) + : tableType{std::move(tableType)}, tableName{std::move(tableName)}, + predicate{std::move(predicate)} {} +}; + +struct NativeProjectedGraphInfo final : ProjectedGraphInfo { + std::vector tableInfo; + + explicit NativeProjectedGraphInfo(std::vector tableInfo) + : tableInfo{std::move(tableInfo)} {} + + std::unique_ptr copy() const override { + return std::make_unique(tableInfo); + } +}; + +struct CypherProjectedGraphInfo final : ProjectedGraphInfo { + std::string cypherQuery; + + explicit CypherProjectedGraphInfo(std::string cypherQuery) + : cypherQuery{std::move(cypherQuery)} {} + + std::unique_ptr copy() const override { + return std::make_unique(cypherQuery); + } +}; + +struct ProjectedGraphInfoBindData final : TableFuncBindData { + graph::GraphEntryType type; + std::unique_ptr info; + + ProjectedGraphInfoBindData(binder::expression_vector columns, graph::GraphEntryType type, + std::unique_ptr info) + : TableFuncBindData{std::move(columns), + type == graph::GraphEntryType::NATIVE ? 
+ info->constCast().tableInfo.size() : + 1}, + type{type}, info{std::move(info)} {} + + std::unique_ptr copy() const override { + return std::make_unique(columns, type, info->copy()); + } +}; + +static offset_t internalTableFunc(const TableFuncMorsel& morsel, const TableFuncInput& input, + DataChunk& output) { + auto projectedGraphData = input.bindData->constPtrCast(); + switch (projectedGraphData->type) { + case graph::GraphEntryType::NATIVE: { + auto morselSize = morsel.getMorselSize(); + auto nativeProjectedGraphInfo = + projectedGraphData->info->constCast(); + for (auto i = 0u; i < morselSize; i++) { + auto& tableInfo = nativeProjectedGraphInfo.tableInfo[i + morsel.startOffset]; + output.getValueVectorMutable(0).setValue(i, tableInfo.tableType); + output.getValueVectorMutable(1).setValue(i, tableInfo.tableName); + output.getValueVectorMutable(2).setValue(i, tableInfo.predicate); + } + return morselSize; + } + case graph::GraphEntryType::CYPHER: { + output.getValueVectorMutable(0).setValue(0, + projectedGraphData->info->constCast().cypherQuery); + return 1; + } + default: + KU_UNREACHABLE; + } +} + +static std::unique_ptr bindFunc(const ClientContext* context, + const TableFuncBindInput* input) { + std::vector returnColumnNames; + std::vector returnTypes; + auto graphName = input->getValue(0).toString(); + auto graphEntrySet = graph::GraphEntrySet::Get(*context); + if (!graphEntrySet->hasGraph(graphName)) { + throw BinderException(stringFormat("Graph {} does not exist.", graphName)); + } + auto graphEntry = graphEntrySet->getEntry(graphName); + switch (graphEntry->type) { + case graph::GraphEntryType::CYPHER: { + returnColumnNames.emplace_back("cypher statement"); + returnTypes.emplace_back(LogicalType::STRING()); + } break; + case graph::GraphEntryType::NATIVE: { + returnColumnNames.emplace_back("table type"); + returnTypes.emplace_back(LogicalType::STRING()); + returnColumnNames.emplace_back("table name"); + returnTypes.emplace_back(LogicalType::STRING()); + returnColumnNames.emplace_back("predicate"); + returnTypes.emplace_back(LogicalType::STRING()); + } break; + default: { + KU_UNREACHABLE; + } + } + returnColumnNames = + TableFunction::extractYieldVariables(returnColumnNames, input->yieldVariables); + auto columns = input->binder->createVariables(returnColumnNames, returnTypes); + std::unique_ptr projectedGraphInfo; + switch (graphEntry->type) { + case graph::GraphEntryType::CYPHER: { + auto& cypherGraphEntry = graphEntry->cast(); + projectedGraphInfo = + std::make_unique(cypherGraphEntry.cypherQuery); + } break; + case graph::GraphEntryType::NATIVE: { + auto& nativeGraphEntry = graphEntry->cast(); + std::vector tableInfo; + for (auto& nodeInfo : nativeGraphEntry.nodeInfos) { + tableInfo.emplace_back(TableTypeUtils::toString(TableType::NODE), nodeInfo.tableName, + nodeInfo.predicate); + } + for (auto& relInfo : nativeGraphEntry.relInfos) { + tableInfo.emplace_back(TableTypeUtils::toString(TableType::REL), relInfo.tableName, + relInfo.predicate); + } + projectedGraphInfo = std::make_unique(std::move(tableInfo)); + } break; + default: + KU_UNREACHABLE; + } + return std::make_unique(std::move(columns), graphEntry->type, + std::move(projectedGraphInfo)); +} + +function_set ProjectedGraphInfoFunction::getFunctionSet() { + function_set functionSet; + auto function = + std::make_unique(name, std::vector{LogicalTypeID::STRING}); + function->tableFunc = SimpleTableFunc::getTableFunc(internalTableFunc); + function->bindFunc = bindFunc; + function->initSharedStateFunc = 
SimpleTableFunc::initSharedState; + function->initLocalStateFunc = TableFunction::initEmptyLocalState; + functionSet.push_back(std::move(function)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/show_attached_databases.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/show_attached_databases.cpp new file mode 100644 index 0000000000..22174cb7cb --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/show_attached_databases.cpp @@ -0,0 +1,66 @@ +#include "binder/binder.h" +#include "function/table/bind_data.h" +#include "function/table/bind_input.h" +#include "function/table/simple_table_function.h" +#include "main/database_manager.h" + +using namespace lbug::common; +using namespace lbug::catalog; + +namespace lbug { +namespace function { + +struct ShowAttachedDatabasesBindData final : TableFuncBindData { + std::vector attachedDatabases; + + ShowAttachedDatabasesBindData(std::vector attachedDatabases, + binder::expression_vector columns, offset_t maxOffset) + : TableFuncBindData{std::move(columns), maxOffset}, + attachedDatabases{std::move(attachedDatabases)} {} + + std::unique_ptr copy() const override { + return std::make_unique(attachedDatabases, columns, numRows); + } +}; + +static offset_t internalTableFunc(const TableFuncMorsel& morsel, const TableFuncInput& input, + DataChunk& output) { + auto& attachedDatabases = + input.bindData->constPtrCast()->attachedDatabases; + auto numDatabasesToOutput = morsel.getMorselSize(); + for (auto i = 0u; i < numDatabasesToOutput; i++) { + const auto attachedDatabase = attachedDatabases[morsel.startOffset + i]; + output.getValueVectorMutable(0).setValue(i, attachedDatabase->getDBName()); + output.getValueVectorMutable(1).setValue(i, attachedDatabase->getDBType()); + } + return numDatabasesToOutput; +} + +static std::unique_ptr bindFunc(const main::ClientContext* context, + const TableFuncBindInput* input) { + std::vector columnNames; + std::vector columnTypes; + columnNames.emplace_back("name"); + columnTypes.emplace_back(LogicalType::STRING()); + columnNames.emplace_back("database type"); + columnTypes.emplace_back(LogicalType::STRING()); + auto attachedDatabases = main::DatabaseManager::Get(*context)->getAttachedDatabases(); + columnNames = TableFunction::extractYieldVariables(columnNames, input->yieldVariables); + auto columns = input->binder->createVariables(columnNames, columnTypes); + return std::make_unique(attachedDatabases, columns, + attachedDatabases.size()); +} + +function_set ShowAttachedDatabasesFunction::getFunctionSet() { + function_set functionSet; + auto function = std::make_unique(name, std::vector{}); + function->tableFunc = SimpleTableFunc::getTableFunc(internalTableFunc); + function->bindFunc = bindFunc; + function->initSharedStateFunc = SimpleTableFunc::initSharedState; + function->initLocalStateFunc = TableFunction::initEmptyLocalState; + functionSet.push_back(std::move(function)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/show_connection.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/show_connection.cpp new file mode 100644 index 0000000000..b98e9d3a4c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/show_connection.cpp @@ -0,0 +1,101 @@ +#include "binder/binder.h" +#include "catalog/catalog.h" +#include "catalog/catalog_entry/node_table_catalog_entry.h" +#include 
"catalog/catalog_entry/rel_group_catalog_entry.h" +#include "common/exception/binder.h" +#include "function/table/bind_data.h" +#include "function/table/bind_input.h" +#include "function/table/simple_table_function.h" +#include "transaction/transaction.h" + +using namespace lbug::catalog; +using namespace lbug::common; +using namespace lbug::main; + +namespace lbug { +namespace function { + +struct ShowConnectionBindData final : TableFuncBindData { + std::vector> srcDstEntries; + + ShowConnectionBindData( + std::vector> srcDstEntries, + binder::expression_vector columns, offset_t maxOffset) + : TableFuncBindData{std::move(columns), maxOffset}, + srcDstEntries{std::move(srcDstEntries)} {} + + std::unique_ptr copy() const override { + return std::make_unique(srcDstEntries, columns, numRows); + } +}; + +static void outputRelTableConnection(DataChunk& outputDataChunk, uint64_t outputPos, + const NodeTableCatalogEntry& srcEntry, const NodeTableCatalogEntry& dstEntry) { + // Write result to dataChunk + outputDataChunk.getValueVectorMutable(0).setValue(outputPos, srcEntry.getName()); + outputDataChunk.getValueVectorMutable(1).setValue(outputPos, dstEntry.getName()); + outputDataChunk.getValueVectorMutable(2).setValue(outputPos, srcEntry.getPrimaryKeyName()); + outputDataChunk.getValueVectorMutable(3).setValue(outputPos, dstEntry.getPrimaryKeyName()); +} + +static offset_t internalTableFunc(const TableFuncMorsel& morsel, const TableFuncInput& input, + DataChunk& output) { + const auto bindData = input.bindData->constPtrCast(); + auto i = 0u; + auto size = morsel.getMorselSize(); + for (; i < size; i++) { + auto [srcEntry, dstEntry] = bindData->srcDstEntries[i + morsel.startOffset]; + outputRelTableConnection(output, i, *srcEntry, *dstEntry); + } + return i; +} + +static std::unique_ptr bindFunc(const ClientContext* context, + const TableFuncBindInput* input) { + std::vector columnNames; + std::vector columnTypes; + columnNames.emplace_back("source table name"); + columnTypes.emplace_back(LogicalType::STRING()); + columnNames.emplace_back("destination table name"); + columnTypes.emplace_back(LogicalType::STRING()); + columnNames.emplace_back("source table primary key"); + columnTypes.emplace_back(LogicalType::STRING()); + columnNames.emplace_back("destination table primary key"); + columnTypes.emplace_back(LogicalType::STRING()); + const auto name = input->getLiteralVal(0); + const auto catalog = Catalog::Get(*context); + auto transaction = transaction::Transaction::Get(*context); + std::vector> srcDstEntries; + if (catalog->containsTable(transaction, name)) { + auto entry = catalog->getTableCatalogEntry(transaction, name); + if (entry->getType() != catalog::CatalogEntryType::REL_GROUP_ENTRY) { + throw BinderException{"Show connection can only be called on a rel table!"}; + } + for (auto& info : entry->ptrCast()->getRelEntryInfos()) { + auto srcEntry = catalog->getTableCatalogEntry(transaction, info.nodePair.srcTableID) + ->ptrCast(); + auto dstEntry = catalog->getTableCatalogEntry(transaction, info.nodePair.dstTableID) + ->ptrCast(); + srcDstEntries.emplace_back(srcEntry, dstEntry); + } + } else { + throw BinderException{"Show connection can only be called on a rel table!"}; + } + columnNames = TableFunction::extractYieldVariables(columnNames, input->yieldVariables); + auto columns = input->binder->createVariables(columnNames, columnTypes); + return std::make_unique(srcDstEntries, columns, srcDstEntries.size()); +} + +function_set ShowConnectionFunction::getFunctionSet() { + function_set 
functionSet; + auto function = std::make_unique(name, std::vector{LogicalTypeID::STRING}); + function->tableFunc = SimpleTableFunc::getTableFunc(internalTableFunc); + function->bindFunc = bindFunc; + function->initSharedStateFunc = SimpleTableFunc::initSharedState; + function->initLocalStateFunc = TableFunction::initEmptyLocalState; + functionSet.push_back(std::move(function)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/show_functions.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/show_functions.cpp new file mode 100644 index 0000000000..57b5b75007 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/show_functions.cpp @@ -0,0 +1,86 @@ +#include "binder/binder.h" +#include "catalog/catalog.h" +#include "function/table/bind_data.h" +#include "function/table/bind_input.h" +#include "function/table/simple_table_function.h" +#include "transaction/transaction.h" + +using namespace lbug::common; +using namespace lbug::catalog; + +namespace lbug { +namespace function { + +struct FunctionInfo { + std::string name; + std::string type; + std::string signature; + + FunctionInfo(std::string name, std::string type, std::string signature) + : name{std::move(name)}, type{std::move(type)}, signature{std::move(signature)} {} +}; + +struct ShowFunctionsBindData final : TableFuncBindData { + std::vector sequences; + + ShowFunctionsBindData(std::vector sequences, binder::expression_vector columns, + offset_t maxOffset) + : TableFuncBindData{std::move(columns), maxOffset}, sequences{std::move(sequences)} {} + + std::unique_ptr copy() const override { + return std::make_unique(sequences, columns, numRows); + } +}; + +static offset_t internalTableFunc(const TableFuncMorsel& morsel, const TableFuncInput& input, + DataChunk& output) { + auto sequences = input.bindData->constPtrCast()->sequences; + auto numSequencesToOutput = morsel.getMorselSize(); + for (auto i = 0u; i < numSequencesToOutput; i++) { + const auto functionInfo = sequences[morsel.startOffset + i]; + output.getValueVectorMutable(0).setValue(i, functionInfo.name); + output.getValueVectorMutable(1).setValue(i, functionInfo.type); + output.getValueVectorMutable(2).setValue(i, functionInfo.signature); + } + return numSequencesToOutput; +} + +static std::unique_ptr bindFunc(const main::ClientContext* context, + const TableFuncBindInput* input) { + std::vector columnNames; + std::vector columnTypes; + columnNames.emplace_back("name"); + columnTypes.emplace_back(LogicalType::STRING()); + columnNames.emplace_back("type"); + columnTypes.emplace_back(LogicalType::STRING()); + columnNames.emplace_back("signature"); + columnTypes.emplace_back(LogicalType::STRING()); + std::vector FunctionInfos; + for (const auto& entry : + Catalog::Get(*context)->getFunctionEntries(transaction::Transaction::Get(*context))) { + const auto& functionSet = entry->getFunctionSet(); + const auto type = FunctionEntryTypeUtils::toString(entry->getType()); + for (auto& function : functionSet) { + auto signature = function->signatureToString(); + FunctionInfos.emplace_back(entry->getName(), type, signature); + } + } + columnNames = TableFunction::extractYieldVariables(columnNames, input->yieldVariables); + auto columns = input->binder->createVariables(columnNames, columnTypes); + return std::make_unique(std::move(FunctionInfos), columns, + FunctionInfos.size()); +} + +function_set ShowFunctionsFunction::getFunctionSet() { + function_set functionSet; + auto function = 
std::make_unique(name, std::vector{}); + function->tableFunc = SimpleTableFunc::getTableFunc(internalTableFunc); + function->bindFunc = bindFunc; + function->initSharedStateFunc = SimpleTableFunc::initSharedState; + function->initLocalStateFunc = TableFunction::initEmptyLocalState; + functionSet.push_back(std::move(function)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/show_indexes.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/show_indexes.cpp new file mode 100644 index 0000000000..9bbcf4246e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/show_indexes.cpp @@ -0,0 +1,127 @@ +#include "binder/binder.h" +#include "catalog/catalog.h" +#include "catalog/catalog_entry/index_catalog_entry.h" +#include "function/table/bind_data.h" +#include "function/table/bind_input.h" +#include "function/table/simple_table_function.h" +#include "transaction/transaction.h" + +using namespace lbug::catalog; +using namespace lbug::common; +using namespace lbug::main; + +namespace lbug { +namespace function { + +struct IndexInfo { + std::string tableName; + std::string indexName; + std::string indexType; + std::vector properties; + bool dependencyLoaded; + std::string indexDefinition; + + IndexInfo(std::string tableName, std::string indexName, std::string indexType, + std::vector properties, bool dependencyLoaded, std::string indexDefinition) + : tableName{std::move(tableName)}, indexName{std::move(indexName)}, + indexType{std::move(indexType)}, properties{std::move(properties)}, + dependencyLoaded{dependencyLoaded}, indexDefinition{std::move(indexDefinition)} {} +}; + +struct ShowIndexesBindData final : TableFuncBindData { + std::vector indexesInfo; + + ShowIndexesBindData(std::vector indexesInfo, binder::expression_vector columns, + offset_t maxOffset) + : TableFuncBindData{std::move(columns), maxOffset}, indexesInfo{std::move(indexesInfo)} {} + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } +}; + +static offset_t internalTableFunc(const TableFuncMorsel& morsel, const TableFuncInput& input, + DataChunk& output) { + auto& indexesInfo = input.bindData->constPtrCast()->indexesInfo; + auto numTuplesToOutput = morsel.getMorselSize(); + auto& propertyVector = output.getValueVectorMutable(3); + auto propertyDataVec = ListVector::getDataVector(&propertyVector); + for (auto i = 0u; i < numTuplesToOutput; i++) { + auto indexInfo = indexesInfo[morsel.startOffset + i]; + output.getValueVectorMutable(0).setValue(i, indexInfo.tableName); + output.getValueVectorMutable(1).setValue(i, indexInfo.indexName); + output.getValueVectorMutable(2).setValue(i, indexInfo.indexType); + auto listEntry = ListVector::addList(&propertyVector, indexInfo.properties.size()); + for (auto j = 0u; j < indexInfo.properties.size(); j++) { + propertyDataVec->setValue(listEntry.offset + j, indexInfo.properties[j]); + } + propertyVector.setValue(i, listEntry); + output.getValueVectorMutable(4).setValue(i, indexInfo.dependencyLoaded); + output.getValueVectorMutable(5).setValue(i, indexInfo.indexDefinition); + } + return numTuplesToOutput; +} + +static binder::expression_vector bindColumns(const TableFuncBindInput& input) { + std::vector columnNames; + std::vector columnTypes; + columnNames.emplace_back("table_name"); + columnTypes.emplace_back(LogicalType::STRING()); + columnNames.emplace_back("index_name"); + columnTypes.emplace_back(LogicalType::STRING()); + 
columnNames.emplace_back("index_type"); + columnTypes.emplace_back(LogicalType::STRING()); + columnNames.emplace_back("property_names"); + columnTypes.emplace_back(LogicalType::LIST(LogicalType::STRING())); + columnNames.emplace_back("extension_loaded"); + columnTypes.emplace_back(LogicalType::BOOL()); + columnNames.emplace_back("index_definition"); + columnTypes.emplace_back(LogicalType::STRING()); + columnNames = TableFunction::extractYieldVariables(columnNames, input.yieldVariables); + return input.binder->createVariables(columnNames, columnTypes); +} + +static std::unique_ptr bindFunc(const main::ClientContext* context, + const TableFuncBindInput* input) { + std::vector indexesInfo; + auto catalog = Catalog::Get(*context); + auto transaction = transaction::Transaction::Get(*context); + auto indexEntries = catalog->getIndexEntries(transaction); + for (auto indexEntry : indexEntries) { + auto tableEntry = catalog->getTableCatalogEntry(transaction, indexEntry->getTableID()); + auto tableName = tableEntry->getName(); + auto indexName = indexEntry->getIndexName(); + auto indexType = indexEntry->getIndexType(); + auto properties = indexEntry->getPropertyIDs(); + std::vector propertyNames; + for (auto& property : properties) { + propertyNames.push_back(tableEntry->getProperty(property).getName()); + } + auto dependencyLoaded = indexEntry->isLoaded(); + std::string indexDefinition; + if (dependencyLoaded) { + auto& auxInfo = indexEntry->getAuxInfo(); + common::FileScanInfo exportFileInfo{}; + IndexToCypherInfo info{context, exportFileInfo}; + indexDefinition = auxInfo.toCypher(*indexEntry, info); + } + indexesInfo.emplace_back(std::move(tableName), std::move(indexName), std::move(indexType), + std::move(propertyNames), dependencyLoaded, std::move(indexDefinition)); + } + return std::make_unique(indexesInfo, bindColumns(*input), + indexesInfo.size()); +} + +function_set ShowIndexesFunction::getFunctionSet() { + function_set functionSet; + auto function = std::make_unique(name, std::vector{}); + function->tableFunc = SimpleTableFunc::getTableFunc(internalTableFunc); + function->bindFunc = bindFunc; + function->initSharedStateFunc = SimpleTableFunc::initSharedState; + function->initLocalStateFunc = TableFunction::initEmptyLocalState; + functionSet.push_back(std::move(function)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/show_loaded_extensions.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/show_loaded_extensions.cpp new file mode 100644 index 0000000000..1d61f26834 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/show_loaded_extensions.cpp @@ -0,0 +1,91 @@ +#include "binder/binder.h" +#include "extension/extension.h" +#include "extension/extension_manager.h" +#include "function/table/bind_data.h" +#include "function/table/simple_table_function.h" +#include "main/client_context.h" + +using namespace lbug::catalog; +using namespace lbug::common; +using namespace lbug::main; + +namespace lbug { +namespace function { + +struct LoadedExtensionInfo { + std::string name; + extension::ExtensionSource extensionSource; + std::string extensionPath; + + LoadedExtensionInfo(std::string name, extension::ExtensionSource extensionSource, + std::string extensionPath) + : name{std::move(name)}, extensionSource{extensionSource}, + extensionPath{std::move(extensionPath)} {} +}; + +struct ShowLoadedExtensionsBindData final : TableFuncBindData { + std::vector loadedExtensionInfo; + + 
ShowLoadedExtensionsBindData(std::vector loadedExtensionInfo, + binder::expression_vector columns, offset_t maxOffset) + : TableFuncBindData{std::move(columns), maxOffset}, + loadedExtensionInfo{std::move(loadedExtensionInfo)} {} + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } +}; + +static offset_t internalTableFunc(const TableFuncMorsel& morsel, const TableFuncInput& input, + DataChunk& output) { + auto& loadedExtensions = + input.bindData->constPtrCast()->loadedExtensionInfo; + auto numTuplesToOutput = morsel.getMorselSize(); + for (auto i = 0u; i < numTuplesToOutput; i++) { + auto loadedExtension = loadedExtensions[morsel.startOffset + i]; + output.getValueVectorMutable(0).setValue(i, loadedExtension.name); + output.getValueVectorMutable(1).setValue(i, + extension::ExtensionSourceUtils::toString(loadedExtension.extensionSource)); + output.getValueVectorMutable(2).setValue(i, loadedExtension.extensionPath); + } + return numTuplesToOutput; +} + +static binder::expression_vector bindColumns(const TableFuncBindInput& input) { + std::vector columnNames; + std::vector columnTypes; + columnNames.emplace_back("extension name"); + columnTypes.emplace_back(LogicalType::STRING()); + columnNames.emplace_back("extension source"); + columnTypes.emplace_back(LogicalType::STRING()); + columnNames.emplace_back("extension path"); + columnTypes.emplace_back(LogicalType::STRING()); + columnNames = TableFunction::extractYieldVariables(columnNames, input.yieldVariables); + return input.binder->createVariables(columnNames, columnTypes); +} + +static std::unique_ptr bindFunc(const main::ClientContext* context, + const TableFuncBindInput* input) { + auto loadedExtensions = extension::ExtensionManager::Get(*context)->getLoadedExtensions(); + std::vector loadedExtensionInfo; + for (auto& loadedExtension : loadedExtensions) { + loadedExtensionInfo.emplace_back(loadedExtension.getExtensionName(), + loadedExtension.getSource(), loadedExtension.getFullPath()); + } + return std::make_unique(loadedExtensionInfo, bindColumns(*input), + loadedExtensionInfo.size()); +} + +function_set ShowLoadedExtensionsFunction::getFunctionSet() { + function_set functionSet; + auto function = std::make_unique(name, std::vector{}); + function->tableFunc = SimpleTableFunc::getTableFunc(internalTableFunc); + function->bindFunc = bindFunc; + function->initSharedStateFunc = SimpleTableFunc::initSharedState; + function->initLocalStateFunc = TableFunction::initEmptyLocalState; + functionSet.push_back(std::move(function)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/show_macros.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/show_macros.cpp new file mode 100644 index 0000000000..2204257c96 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/show_macros.cpp @@ -0,0 +1,82 @@ +#include "binder/binder.h" +#include "catalog/catalog.h" +#include "catalog/catalog_entry/scalar_macro_catalog_entry.h" +#include "catalog/catalog_entry/table_catalog_entry.h" +#include "function/table/bind_data.h" +#include "function/table/simple_table_function.h" +#include "main/client_context.h" +#include "transaction/transaction.h" + +using namespace lbug::common; +using namespace lbug::catalog; + +namespace lbug { +namespace function { + +struct MacroInfo { + std::string name; + std::string definition; + + MacroInfo(std::string name, std::string definition) + : name{std::move(name)}, definition(std::move(definition)) 
{} +}; + +struct ShowMacrosBindData final : TableFuncBindData { + std::vector macros; + + ShowMacrosBindData(std::vector macros, binder::expression_vector columns, + row_idx_t numRows) + : TableFuncBindData{std::move(columns), numRows}, macros{std::move(macros)} {} + + std::unique_ptr copy() const override { + return std::make_unique(macros, columns, numRows); + } +}; + +static offset_t internalTableFunc(const TableFuncMorsel& morsel, const TableFuncInput& input, + DataChunk& output) { + const auto macros = input.bindData->constPtrCast()->macros; + const auto numMacrosToOutput = morsel.endOffset - morsel.startOffset; + for (auto i = 0u; i < numMacrosToOutput; i++) { + const auto tableInfo = macros[morsel.startOffset + i]; + output.getValueVectorMutable(0).setValue(i, tableInfo.name); + output.getValueVectorMutable(1).setValue(i, tableInfo.definition); + } + return numMacrosToOutput; +} + +static std::unique_ptr bindFunc(const main::ClientContext* context, + const TableFuncBindInput* input) { + std::vector columnNames; + std::vector columnTypes; + columnNames.emplace_back("name"); + columnTypes.emplace_back(LogicalType::STRING()); + columnNames.emplace_back("definition"); + columnTypes.emplace_back(LogicalType::STRING()); + std::vector macroInfos; + auto transaction = transaction::Transaction::Get(*context); + auto catalog = Catalog::Get(*context); + for (auto& entry : catalog->getMacroEntries(transaction)) { + std::string name = entry->getName(); + auto macroFunction = catalog->getScalarMacroFunction(transaction, name); + macroInfos.emplace_back(name, macroFunction->toCypher(name)); + } + columnNames = TableFunction::extractYieldVariables(columnNames, input->yieldVariables); + auto columns = input->binder->createVariables(columnNames, columnTypes); + return std::make_unique(std::move(macroInfos), std::move(columns), + macroInfos.size()); +} + +function_set ShowMacrosFunction::getFunctionSet() { + function_set functionSet; + auto function = std::make_unique(name, std::vector{}); + function->tableFunc = SimpleTableFunc::getTableFunc(internalTableFunc); + function->bindFunc = bindFunc; + function->initSharedStateFunc = SimpleTableFunc::initSharedState; + function->initLocalStateFunc = TableFunction::initEmptyLocalState; + functionSet.push_back(std::move(function)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/show_official_extensions.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/show_official_extensions.cpp new file mode 100644 index 0000000000..c0bdc165f9 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/show_official_extensions.cpp @@ -0,0 +1,65 @@ +#include "binder/binder.h" +#include "extension/extension.h" +#include "function/table/bind_data.h" +#include "function/table/bind_input.h" +#include "function/table/simple_table_function.h" + +using namespace lbug::catalog; +using namespace lbug::common; +using namespace lbug::main; + +namespace lbug { +namespace function { + +static constexpr std::pair extensions[] = { + {"ALGO", "Adds support for graph algorithms"}, + {"AZURE", "Adds support for reading from azure blob storage"}, + {"DELTA", "Adds support for reading from delta tables"}, + {"DUCKDB", "Adds support for reading from duckdb tables"}, + {"FTS", "Adds support for full-text search indexes"}, + {"HTTPFS", "Adds support for reading and writing files over a HTTP(S)/S3 filesystem"}, + {"ICEBERG", "Adds support for reading from iceberg tables"}, + {"JSON", "Adds 
support for JSON operations"}, {"LLM", "Adds support for LLM operations"}, + {"NEO4J", "Adds support for migrating nodes and rels from neo4j to lbug"}, + {"POSTGRES", "Adds support for reading from POSTGRES tables"}, + {"SQLITE", "Adds support for reading from SQLITE tables"}, + {"UNITY_CATALOG", "Adds support for scanning delta tables registered in unity catalog"}}; +static constexpr auto officialExtensions = std::to_array(extensions); + +static offset_t internalTableFunc(const TableFuncMorsel& morsel, const TableFuncInput& /*input*/, + DataChunk& output) { + auto numTuplesToOutput = morsel.getMorselSize(); + for (auto i = 0u; i < numTuplesToOutput; ++i) { + auto& [name, description] = officialExtensions[morsel.startOffset + i]; + output.getValueVectorMutable(0).setValue(i, name); + output.getValueVectorMutable(1).setValue(i, description); + } + return numTuplesToOutput; +} + +static std::unique_ptr bindFunc(const main::ClientContext* /*context*/, + const TableFuncBindInput* input) { + std::vector columnNames; + std::vector columnTypes; + columnNames.emplace_back("name"); + columnTypes.emplace_back(LogicalType::STRING()); + columnNames.emplace_back("description"); + columnTypes.emplace_back(LogicalType::STRING()); + columnNames = TableFunction::extractYieldVariables(columnNames, input->yieldVariables); + auto columns = input->binder->createVariables(columnNames, columnTypes); + return std::make_unique(std::move(columns), officialExtensions.size()); +} + +function_set ShowOfficialExtensionsFunction::getFunctionSet() { + function_set functionSet; + auto function = std::make_unique(name, std::vector{}); + function->tableFunc = SimpleTableFunc::getTableFunc(internalTableFunc); + function->bindFunc = bindFunc; + function->initSharedStateFunc = SimpleTableFunc::initSharedState; + function->initLocalStateFunc = TableFunction::initEmptyLocalState; + functionSet.push_back(std::move(function)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/show_projected_graphs.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/show_projected_graphs.cpp new file mode 100644 index 0000000000..b7f1dedc88 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/show_projected_graphs.cpp @@ -0,0 +1,78 @@ +#include "binder/binder.h" +#include "function/table/bind_data.h" +#include "function/table/bind_input.h" +#include "function/table/simple_table_function.h" +#include "graph/graph_entry_set.h" + +using namespace lbug::common; +using namespace lbug::main; + +namespace lbug { +namespace function { + +struct ProjectedGraphData { + std::string name; + std::string type; + + ProjectedGraphData(std::string name, std::string type) + : name{std::move(name)}, type{std::move(type)} {} +}; + +struct ShowProjectedGraphBindData : public TableFuncBindData { + std::vector projectedGraphData; + + ShowProjectedGraphBindData(std::vector projectedGraphData, + binder::expression_vector columns) + : TableFuncBindData{std::move(columns), projectedGraphData.size()}, + projectedGraphData{std::move(projectedGraphData)} {} + + std::unique_ptr copy() const override { + return std::make_unique(projectedGraphData, columns); + } +}; + +static offset_t internalTableFunc(const TableFuncMorsel& morsel, const TableFuncInput& input, + DataChunk& output) { + auto& projectedGraphData = + input.bindData->constPtrCast()->projectedGraphData; + auto numTablesToOutput = morsel.endOffset - morsel.startOffset; + for (auto i = 0u; i < numTablesToOutput; 
i++) { + auto graphData = projectedGraphData[morsel.startOffset + i]; + output.getValueVectorMutable(0).setValue(i, graphData.name); + output.getValueVectorMutable(1).setValue(i, graphData.type); + } + return numTablesToOutput; +} + +static std::unique_ptr bindFunc(const ClientContext* context, + const TableFuncBindInput* input) { + std::vector returnColumnNames; + std::vector returnTypes; + returnColumnNames.emplace_back("name"); + returnTypes.emplace_back(LogicalType::STRING()); + returnColumnNames.emplace_back("type"); + returnTypes.emplace_back(LogicalType::STRING()); + returnColumnNames = + TableFunction::extractYieldVariables(returnColumnNames, input->yieldVariables); + auto columns = input->binder->createVariables(returnColumnNames, returnTypes); + std::vector projectedGraphData; + for (auto& [name, entry] : graph::GraphEntrySet::Get(*context)->getNameToEntryMap()) { + projectedGraphData.emplace_back(name, graph::GraphEntryTypeUtils::toString(entry->type)); + } + return std::make_unique(std::move(projectedGraphData), + std::move(columns)); +} + +function_set ShowProjectedGraphsFunction::getFunctionSet() { + function_set functionSet; + auto function = std::make_unique(name, std::vector{}); + function->tableFunc = SimpleTableFunc::getTableFunc(internalTableFunc); + function->bindFunc = bindFunc; + function->initSharedStateFunc = SimpleTableFunc::initSharedState; + function->initLocalStateFunc = TableFunction::initEmptyLocalState; + functionSet.push_back(std::move(function)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/show_sequences.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/show_sequences.cpp new file mode 100644 index 0000000000..fa94887185 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/show_sequences.cpp @@ -0,0 +1,120 @@ +#include "binder/binder.h" +#include "catalog/catalog.h" +#include "catalog/catalog_entry/sequence_catalog_entry.h" +#include "function/table/bind_data.h" +#include "function/table/bind_input.h" +#include "function/table/simple_table_function.h" +#include "transaction/transaction.h" + +using namespace lbug::common; +using namespace lbug::catalog; + +namespace lbug { +namespace function { + +struct SequenceInfo { + std::string name; + std::string databaseName; + int64_t startValue; + int64_t increment; + int64_t minValue; + int64_t maxValue; + bool cycle; + + SequenceInfo(std::string name, std::string databaseName, int64_t startValue, int64_t increment, + int64_t minValue, int64_t maxValue, bool cycle) + : name{std::move(name)}, databaseName{std::move(databaseName)}, startValue{startValue}, + increment{increment}, minValue{minValue}, maxValue{maxValue}, cycle{cycle} {} +}; + +struct ShowSequencesBindData final : TableFuncBindData { + std::vector sequences; + + ShowSequencesBindData(std::vector sequences, binder::expression_vector columns, + offset_t maxOffset) + : TableFuncBindData{std::move(columns), maxOffset}, sequences{std::move(sequences)} {} + + std::unique_ptr copy() const override { + return std::make_unique(sequences, columns, numRows); + } +}; + +static offset_t internalTableFunc(const TableFuncMorsel& morsel, const TableFuncInput& input, + DataChunk& output) { + const auto sequences = input.bindData->constPtrCast()->sequences; + const auto numSequencesToOutput = morsel.endOffset - morsel.startOffset; + for (auto i = 0u; i < numSequencesToOutput; i++) { + const auto sequenceInfo = sequences[morsel.startOffset + i]; + 
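+        // Output column order matches the schema built in bindFunc below:
+        // 0 = name, 1 = database name, 2 = start value, 3 = increment,
+        // 4 = min value, 5 = max value, 6 = cycle.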
output.getValueVectorMutable(0).setValue(i, sequenceInfo.name); + output.getValueVectorMutable(1).setValue(i, sequenceInfo.databaseName); + output.getValueVectorMutable(2).setValue(i, sequenceInfo.startValue); + output.getValueVectorMutable(3).setValue(i, sequenceInfo.increment); + output.getValueVectorMutable(4).setValue(i, sequenceInfo.minValue); + output.getValueVectorMutable(5).setValue(i, sequenceInfo.maxValue); + output.getValueVectorMutable(6).setValue(i, sequenceInfo.cycle); + } + return numSequencesToOutput; +} + +static std::unique_ptr bindFunc(const main::ClientContext* context, + const TableFuncBindInput* input) { + std::vector columnNames; + std::vector columnTypes; + columnNames.emplace_back("name"); + columnTypes.emplace_back(LogicalType::STRING()); + columnNames.emplace_back("database name"); + columnTypes.emplace_back(LogicalType::STRING()); + columnNames.emplace_back("start value"); + columnTypes.emplace_back(LogicalType::INT64()); + columnNames.emplace_back("increment"); + columnTypes.emplace_back(LogicalType::INT64()); + columnNames.emplace_back("min value"); + columnTypes.emplace_back(LogicalType::INT64()); + columnNames.emplace_back("max value"); + columnTypes.emplace_back(LogicalType::INT64()); + columnNames.emplace_back("cycle"); + columnTypes.emplace_back(LogicalType::BOOL()); + std::vector sequenceInfos; + for (const auto& entry : + Catalog::Get(*context)->getSequenceEntries(transaction::Transaction::Get(*context))) { + const auto sequenceData = entry->getSequenceData(); + auto sequenceInfo = SequenceInfo{entry->getName(), LOCAL_DB_NAME, sequenceData.startValue, + sequenceData.increment, sequenceData.minValue, sequenceData.maxValue, + sequenceData.cycle}; + sequenceInfos.push_back(std::move(sequenceInfo)); + } + + // TODO: uncomment this when we can test it + // for (auto attachedDatabase : databaseManager->getAttachedDatabases()) { + // auto databaseName = attachedDatabase->getDBName(); + // auto databaseType = attachedDatabase->getDBType(); + // for (auto& entry : attachedDatabase->getCatalog()->getSequenceEntries(context->getTx())) + // { + // auto sequenceData = entry->getSequenceData(); + // auto sequenceInfo = + // SequenceInfo{entry->getName(), stringFormat("{}({})", databaseName, + // databaseType), + // sequenceData.startValue, sequenceData.increment, sequenceData.minValue, + // sequenceData.maxValue, sequenceData.cycle}; + // sequenceInfos.push_back(std::move(sequenceInfo)); + // } + // } + columnNames = TableFunction::extractYieldVariables(columnNames, input->yieldVariables); + auto columns = input->binder->createVariables(columnNames, columnTypes); + return std::make_unique(std::move(sequenceInfos), columns, + sequenceInfos.size()); +} + +function_set ShowSequencesFunction::getFunctionSet() { + function_set functionSet; + auto function = std::make_unique(name, std::vector{}); + function->tableFunc = SimpleTableFunc::getTableFunc(internalTableFunc); + function->bindFunc = bindFunc; + function->initSharedStateFunc = SimpleTableFunc::initSharedState; + function->initLocalStateFunc = TableFunction::initEmptyLocalState; + functionSet.push_back(std::move(function)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/show_tables.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/show_tables.cpp new file mode 100644 index 0000000000..913a514a33 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/show_tables.cpp @@ -0,0 +1,110 @@ +#include 
"binder/binder.h" +#include "catalog/catalog.h" +#include "catalog/catalog_entry/table_catalog_entry.h" +#include "function/table/bind_data.h" +#include "function/table/simple_table_function.h" +#include "main/client_context.h" +#include "main/database_manager.h" + +using namespace lbug::common; +using namespace lbug::catalog; + +namespace lbug { +namespace function { + +struct TableInfo { + std::string name; + table_id_t id; + std::string type; + std::string databaseName; + std::string comment; + + TableInfo(std::string name, table_id_t id, std::string type, std::string databaseName, + std::string comment) + : name{std::move(name)}, id{id}, type{std::move(type)}, + databaseName{std::move(databaseName)}, comment{std::move(comment)} {} +}; + +struct ShowTablesBindData final : TableFuncBindData { + std::vector tables; + + ShowTablesBindData(std::vector tables, binder::expression_vector columns, + row_idx_t numRows) + : TableFuncBindData{std::move(columns), numRows}, tables{std::move(tables)} {} + + std::unique_ptr copy() const override { + return std::make_unique(tables, columns, numRows); + } +}; + +static offset_t internalTableFunc(const TableFuncMorsel& morsel, const TableFuncInput& input, + DataChunk& output) { + const auto tables = input.bindData->constPtrCast()->tables; + const auto numTablesToOutput = morsel.endOffset - morsel.startOffset; + for (auto i = 0u; i < numTablesToOutput; i++) { + const auto tableInfo = tables[morsel.startOffset + i]; + output.getValueVectorMutable(0).setValue(i, tableInfo.id); + output.getValueVectorMutable(1).setValue(i, tableInfo.name); + output.getValueVectorMutable(2).setValue(i, tableInfo.type); + output.getValueVectorMutable(3).setValue(i, tableInfo.databaseName); + output.getValueVectorMutable(4).setValue(i, tableInfo.comment); + } + return numTablesToOutput; +} + +static std::unique_ptr bindFunc(const main::ClientContext* context, + const TableFuncBindInput* input) { + std::vector columnNames; + std::vector columnTypes; + columnNames.emplace_back("id"); + columnTypes.emplace_back(LogicalType::UINT64()); + columnNames.emplace_back("name"); + columnTypes.emplace_back(LogicalType::STRING()); + columnNames.emplace_back("type"); + columnTypes.emplace_back(LogicalType::STRING()); + columnNames.emplace_back("database name"); + columnTypes.emplace_back(LogicalType::STRING()); + columnNames.emplace_back("comment"); + columnTypes.emplace_back(LogicalType::STRING()); + std::vector tableInfos; + auto transaction = transaction::Transaction::Get(*context); + if (!context->hasDefaultDatabase()) { + auto catalog = Catalog::Get(*context); + for (auto& entry : + catalog->getTableEntries(transaction, context->useInternalCatalogEntry())) { + tableInfos.emplace_back(entry->getName(), entry->getTableID(), + TableTypeUtils::toString(entry->getTableType()), LOCAL_DB_NAME, + entry->getComment()); + } + } + + for (auto attachedDatabase : main::DatabaseManager::Get(*context)->getAttachedDatabases()) { + auto databaseName = attachedDatabase->getDBName(); + auto databaseType = attachedDatabase->getDBType(); + for (auto& entry : attachedDatabase->getCatalog()->getTableEntries(transaction, + context->useInternalCatalogEntry())) { + auto tableInfo = TableInfo{entry->getName(), entry->getTableID(), + TableTypeUtils::toString(entry->getTableType()), + stringFormat("{}({})", databaseName, databaseType), entry->getComment()}; + tableInfos.push_back(std::move(tableInfo)); + } + } + columnNames = TableFunction::extractYieldVariables(columnNames, input->yieldVariables); + auto columns = 
input->binder->createVariables(columnNames, columnTypes); + return std::make_unique(std::move(tableInfos), std::move(columns), + tableInfos.size()); +} + +function_set ShowTablesFunction::getFunctionSet() { + function_set functionSet; + auto function = std::make_unique(name, std::vector{}); + function->tableFunc = SimpleTableFunc::getTableFunc(internalTableFunc); + function->bindFunc = bindFunc; + function->initSharedStateFunc = SimpleTableFunc::initSharedState; + function->initLocalStateFunc = TableFunction::initEmptyLocalState; + functionSet.push_back(std::move(function)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/show_warnings.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/show_warnings.cpp new file mode 100644 index 0000000000..978842dddd --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/show_warnings.cpp @@ -0,0 +1,67 @@ +#include "binder/binder.h" +#include "function/table/bind_data.h" +#include "function/table/bind_input.h" +#include "function/table/simple_table_function.h" +#include "processor/warning_context.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +struct ShowWarningsBindData final : TableFuncBindData { + std::vector warnings; + + ShowWarningsBindData(std::vector warnings, + binder::expression_vector columns, offset_t maxOffset) + : TableFuncBindData{std::move(columns), maxOffset}, warnings{std::move(warnings)} {} + + std::unique_ptr copy() const override { + return std::make_unique(warnings, columns, numRows); + } +}; + +static offset_t internalTableFunc(const TableFuncMorsel& morsel, const TableFuncInput& input, + DataChunk& output) { + const auto warnings = input.bindData->constPtrCast()->warnings; + const auto numWarningsToOutput = morsel.endOffset - morsel.startOffset; + for (auto i = 0u; i < numWarningsToOutput; i++) { + const auto tableInfo = warnings[morsel.startOffset + i]; + output.getValueVectorMutable(0).setValue(i, tableInfo.queryID); + output.getValueVectorMutable(1).setValue(i, tableInfo.warning.message); + output.getValueVectorMutable(2).setValue(i, tableInfo.warning.filePath); + output.getValueVectorMutable(3).setValue(i, tableInfo.warning.lineNumber); + output.getValueVectorMutable(4).setValue(i, tableInfo.warning.skippedLineOrRecord); + } + return numWarningsToOutput; +} + +static std::unique_ptr bindFunc(const main::ClientContext* context, + const TableFuncBindInput* input) { + std::vector columnNames{WarningConstants::WARNING_TABLE_COLUMN_NAMES.begin(), + WarningConstants::WARNING_TABLE_COLUMN_NAMES.end()}; + std::vector columnTypes{WarningConstants::WARNING_TABLE_COLUMN_DATA_TYPES.begin(), + WarningConstants::WARNING_TABLE_COLUMN_DATA_TYPES.end()}; + std::vector warningInfos; + for (const auto& warning : processor::WarningContext::Get(*context)->getPopulatedWarnings()) { + warningInfos.emplace_back(warning); + } + columnNames = TableFunction::extractYieldVariables(columnNames, input->yieldVariables); + auto columns = input->binder->createVariables(columnNames, columnTypes); + return std::make_unique(std::move(warningInfos), columns, + warningInfos.size()); +} + +function_set ShowWarningsFunction::getFunctionSet() { + function_set functionSet; + auto function = std::make_unique(name, std::vector{}); + function->tableFunc = SimpleTableFunc::getTableFunc(internalTableFunc); + function->bindFunc = bindFunc; + function->initSharedStateFunc = SimpleTableFunc::initSharedState; + 
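+    // SimpleTableFunc::initSharedState hands out row-range morsels over the rows that
+    // bindFunc already materialized (see simple_table_function.cpp below), so each
+    // per-function internalTableFunc only copies values into the output chunk.
+    // A minimal sketch of that contract, using an illustrative, hypothetical
+    // single-column function (only the framework types and calls are real):
+    //
+    //   static offset_t helloTableFunc(const TableFuncMorsel& morsel,
+    //       const TableFuncInput& /*input*/, DataChunk& output) {
+    //       const auto numRows = morsel.getMorselSize();
+    //       for (auto i = 0u; i < numRows; i++) {
+    //           output.getValueVectorMutable(0).setValue(i, std::string("hello"));
+    //       }
+    //       return numRows; // number of tuples written for this morsel
+    //   }
+    //
+    // Assuming the usual CALL syntax for built-in table functions, the function
+    // registered here would be queried as: CALL show_warnings() RETURN *;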
function->initLocalStateFunc = TableFunction::initEmptyLocalState; + functionSet.push_back(std::move(function)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/simple_table_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/simple_table_function.cpp new file mode 100644 index 0000000000..18002d280c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/simple_table_function.cpp @@ -0,0 +1,39 @@ +#include "function/table/simple_table_function.h" + +#include "function/table/bind_data.h" + +namespace lbug { +namespace function { + +TableFuncMorsel SimpleTableFuncSharedState::getMorsel() { + std::lock_guard lck{mtx}; + KU_ASSERT(curRowIdx <= numRows); + if (curRowIdx == numRows) { + return TableFuncMorsel::createInvalidMorsel(); + } + const auto numValuesToOutput = std::min(maxMorselSize, numRows - curRowIdx); + curRowIdx += numValuesToOutput; + return {curRowIdx - numValuesToOutput, curRowIdx}; +} + +std::unique_ptr SimpleTableFunc::initSharedState( + const TableFuncInitSharedStateInput& input) { + return std::make_unique(input.bindData->numRows); +} + +common::offset_t tableFunc(simple_internal_table_func internalTableFunc, + const TableFuncInput& input, TableFuncOutput& output) { + const auto sharedState = input.sharedState->ptrCast(); + auto morsel = sharedState->getMorsel(); + if (!morsel.hasMoreToOutput()) { + return 0; + } + return internalTableFunc(morsel, input, output.dataChunk); +} + +table_func_t SimpleTableFunc::getTableFunc(simple_internal_table_func internalTableFunc) { + return std::bind(tableFunc, internalTableFunc, std::placeholders::_1, std::placeholders::_2); +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/stats_info.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/stats_info.cpp new file mode 100644 index 0000000000..ea9ce9464f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/stats_info.cpp @@ -0,0 +1,93 @@ +#include "binder/binder.h" +#include "catalog/catalog_entry/table_catalog_entry.h" +#include "common/exception/binder.h" +#include "function/table/bind_data.h" +#include "function/table/bind_input.h" +#include "function/table/simple_table_function.h" +#include "storage/storage_manager.h" +#include "storage/table/node_table.h" +#include "transaction/transaction.h" + +using namespace lbug::catalog; +using namespace lbug::common; +using namespace lbug::main; + +namespace lbug { +namespace function { + +struct StatsInfoBindData final : TableFuncBindData { + TableCatalogEntry* tableEntry; + storage::Table* table; + const ClientContext* context; + + StatsInfoBindData(binder::expression_vector columns, TableCatalogEntry* tableEntry, + storage::Table* table, const ClientContext* context) + : TableFuncBindData{std::move(columns), 1 /*numRows*/}, tableEntry{tableEntry}, + table{table}, context{context} {} + + std::unique_ptr copy() const override { + return std::make_unique(columns, tableEntry, table, context); + } +}; + +static offset_t internalTableFunc(const TableFuncMorsel& /*morsel*/, const TableFuncInput& input, + DataChunk& output) { + const auto bindData = input.bindData->constPtrCast(); + const auto table = bindData->table; + switch (table->getTableType()) { + case TableType::NODE: { + const auto& nodeTable = table->cast(); + const auto stats = nodeTable.getStats(transaction::Transaction::Get(*bindData->context)); + 
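+        // Column 0 reports the node table's cardinality; column i + 1 reports the
+        // distinct-value count tracked for property column i, in column order.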
output.getValueVectorMutable(0).setValue(0, stats.getTableCard()); + for (auto i = 0u; i < nodeTable.getNumColumns(); ++i) { + output.getValueVectorMutable(i + 1).setValue(0, stats.getNumDistinctValues(i)); + } + } break; + default: { + KU_UNREACHABLE; + } + } + return 1; +} + +static std::unique_ptr bindFunc(const ClientContext* context, + const TableFuncBindInput* input) { + const auto tableName = input->getLiteralVal(0); + const auto catalog = Catalog::Get(*context); + if (!catalog->containsTable(transaction::Transaction::Get(*context), tableName)) { + throw BinderException{"Table " + tableName + " does not exist!"}; + } + auto tableEntry = + catalog->getTableCatalogEntry(transaction::Transaction::Get(*context), tableName); + if (tableEntry->getTableType() != TableType::NODE) { + throw BinderException{ + "Stats from a non-node table " + tableName + " is not supported yet!"}; + } + + std::vector columnNames = {"cardinality"}; + std::vector columnTypes; + columnTypes.push_back(LogicalType::INT64()); + for (auto& propDef : tableEntry->getProperties()) { + columnNames.push_back(propDef.getName() + "_distinct_count"); + columnTypes.push_back(LogicalType::INT64()); + } + const auto storageManager = storage::StorageManager::Get(*context); + auto table = storageManager->getTable(tableEntry->getTableID()); + columnNames = TableFunction::extractYieldVariables(columnNames, input->yieldVariables); + auto columns = input->binder->createVariables(columnNames, columnTypes); + return std::make_unique(columns, tableEntry, table, context); +} + +function_set StatsInfoFunction::getFunctionSet() { + function_set functionSet; + auto function = std::make_unique(name, std::vector{LogicalTypeID::STRING}); + function->tableFunc = SimpleTableFunc::getTableFunc(internalTableFunc); + function->bindFunc = bindFunc; + function->initSharedStateFunc = SimpleTableFunc::initSharedState; + function->initLocalStateFunc = TableFunction::initEmptyLocalState; + functionSet.push_back(std::move(function)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/storage_info.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/storage_info.cpp new file mode 100644 index 0000000000..deaddb96fb --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/storage_info.cpp @@ -0,0 +1,348 @@ +#include "binder/binder.h" +#include "common/data_chunk/data_chunk_collection.h" +#include "common/exception/binder.h" +#include "common/type_utils.h" +#include "common/types/interval_t.h" +#include "common/types/ku_string.h" +#include "common/types/types.h" +#include "function/table/bind_data.h" +#include "function/table/bind_input.h" +#include "function/table/simple_table_function.h" +#include "main/client_context.h" +#include "processor/execution_context.h" +#include "storage/storage_manager.h" +#include "storage/table/list_chunk_data.h" +#include "storage/table/list_column.h" +#include "storage/table/node_table.h" +#include "storage/table/rel_table.h" +#include "storage/table/string_chunk_data.h" +#include "storage/table/string_column.h" +#include "storage/table/struct_chunk_data.h" +#include "storage/table/struct_column.h" +#include + +using namespace lbug::common; +using namespace lbug::catalog; +using namespace lbug::storage; +using namespace lbug::main; + +namespace lbug { +namespace function { + +struct StorageInfoLocalState final : TableFuncLocalState { + std::unique_ptr dataChunkCollection; + idx_t currChunkIdx; + + explicit 
StorageInfoLocalState(MemoryManager* mm) : currChunkIdx{0} { + dataChunkCollection = std::make_unique(mm); + } +}; + +struct StorageInfoBindData final : TableFuncBindData { + TableCatalogEntry* tableEntry; + const ClientContext* context; + + StorageInfoBindData(binder::expression_vector columns, TableCatalogEntry* tableEntry, + const ClientContext* context) + : TableFuncBindData{std::move(columns), 1 /*maxOffset*/}, tableEntry{tableEntry}, + context{context} {} + + std::unique_ptr copy() const override { + return std::make_unique(columns, tableEntry, context); + } +}; + +static std::unique_ptr initLocalState( + const TableFuncInitLocalStateInput& input) { + return std::make_unique(MemoryManager::Get(*input.clientContext)); +} + +struct StorageInfoOutputData { + node_group_idx_t nodeGroupIdx = INVALID_NODE_GROUP_IDX; + node_group_idx_t chunkIdx = INVALID_NODE_GROUP_IDX; + std::string tableType; + uint32_t columnIdx = INVALID_COLUMN_ID; + std::vector columns; +}; + +static void resetOutputIfNecessary(const StorageInfoLocalState* localState, + DataChunk& outputChunk) { + if (outputChunk.state->getSelVector().getSelSize() == DEFAULT_VECTOR_CAPACITY) { + localState->dataChunkCollection->append(outputChunk); + outputChunk.resetAuxiliaryBuffer(); + outputChunk.state->getSelVectorUnsafe().setSelSize(0); + } +} + +static void appendStorageInfoForChunkData(StorageInfoLocalState* localState, DataChunk& outputChunk, + StorageInfoOutputData& outputData, const Column& column, const ColumnChunkData& chunkData, + bool ignoreNull = false) { + resetOutputIfNecessary(localState, outputChunk); + auto vectorPos = outputChunk.state->getSelVector().getSelSize(); + auto residency = chunkData.getResidencyState(); + ColumnChunkMetadata metadata; + switch (residency) { + case ResidencyState::IN_MEMORY: { + metadata = chunkData.getMetadataToFlush(); + } break; + case ResidencyState::ON_DISK: { + metadata = chunkData.getMetadata(); + } break; + default: { + KU_UNREACHABLE; + } + } + auto& columnType = chunkData.getDataType(); + outputChunk.getValueVectorMutable(0).setValue(vectorPos, outputData.tableType); + outputChunk.getValueVectorMutable(1).setValue(vectorPos, outputData.nodeGroupIdx); + outputChunk.getValueVectorMutable(2).setValue(vectorPos, outputData.chunkIdx); + outputChunk.getValueVectorMutable(3).setValue(vectorPos, + ResidencyStateUtils::toString(residency)); + outputChunk.getValueVectorMutable(4).setValue(vectorPos, column.getName()); + outputChunk.getValueVectorMutable(5).setValue(vectorPos, columnType.toString()); + outputChunk.getValueVectorMutable(6).setValue(vectorPos, metadata.getStartPageIdx()); + outputChunk.getValueVectorMutable(7).setValue(vectorPos, metadata.getNumPages()); + outputChunk.getValueVectorMutable(8).setValue(vectorPos, metadata.numValues); + + auto customToString = [&](T) { + outputChunk.getValueVectorMutable(9).setValue(vectorPos, + std::to_string(metadata.compMeta.min.get())); + outputChunk.getValueVectorMutable(10).setValue(vectorPos, + std::to_string(metadata.compMeta.max.get())); + }; + auto physicalType = columnType.getPhysicalType(); + TypeUtils::visit( + physicalType, [&](ku_string_t) { customToString(uint32_t()); }, + [&](list_entry_t) { customToString(uint64_t()); }, + [&](internalID_t) { customToString(uint64_t()); }, + [&](T) + requires(std::integral || std::floating_point) + { + auto min = metadata.compMeta.min.get(); + auto max = metadata.compMeta.max.get(); + outputChunk.getValueVectorMutable(9).setValue(vectorPos, + TypeUtils::entryToString(columnType, 
(uint8_t*)&min, + &outputChunk.getValueVectorMutable(9))); + outputChunk.getValueVectorMutable(10).setValue(vectorPos, + TypeUtils::entryToString(columnType, (uint8_t*)&max, + &outputChunk.getValueVectorMutable(10))); + }, + // Types which don't support statistics. + // types not supported by TypeUtils::visit can + // also be ignored since we don't track statistics for them + [](int128_t) {}, [](struct_entry_t) {}, [](interval_t) {}, [](uint128_t) {}); + outputChunk.getValueVectorMutable(11).setValue(vectorPos, + metadata.compMeta.toString(physicalType)); + outputChunk.state->getSelVectorUnsafe().incrementSelSize(); + if (columnType.getPhysicalType() == PhysicalTypeID::INTERNAL_ID) { + ignoreNull = true; + } + if (!ignoreNull && chunkData.hasNullData()) { + appendStorageInfoForChunkData(localState, outputChunk, outputData, *column.getNullColumn(), + *chunkData.getNullData()); + } + switch (columnType.getPhysicalType()) { + case PhysicalTypeID::STRUCT: { + auto& structChunk = chunkData.cast(); + const auto& structColumn = ku_dynamic_cast(column); + auto numChildren = structChunk.getNumChildren(); + for (auto i = 0u; i < numChildren; i++) { + appendStorageInfoForChunkData(localState, outputChunk, outputData, + *structColumn.getChild(i), structChunk.getChild(i)); + } + } break; + case PhysicalTypeID::STRING: { + auto& stringChunk = chunkData.cast(); + auto& dictionaryChunk = stringChunk.getDictionaryChunk(); + const auto& stringColumn = ku_dynamic_cast(column); + appendStorageInfoForChunkData(localState, outputChunk, outputData, + *stringColumn.getIndexColumn(), *stringChunk.getIndexColumnChunk()); + appendStorageInfoForChunkData(localState, outputChunk, outputData, + *stringColumn.getDictionary().getDataColumn(), *dictionaryChunk.getStringDataChunk()); + appendStorageInfoForChunkData(localState, outputChunk, outputData, + *stringColumn.getDictionary().getOffsetColumn(), *dictionaryChunk.getOffsetChunk()); + } break; + case PhysicalTypeID::ARRAY: + case PhysicalTypeID::LIST: { + auto& listChunk = chunkData.cast(); + const auto& listColumn = ku_dynamic_cast(column); + appendStorageInfoForChunkData(localState, outputChunk, outputData, + *listColumn.getOffsetColumn(), *listChunk.getOffsetColumnChunk()); + appendStorageInfoForChunkData(localState, outputChunk, outputData, + *listColumn.getSizeColumn(), *listChunk.getSizeColumnChunk()); + appendStorageInfoForChunkData(localState, outputChunk, outputData, + *listColumn.getDataColumn(), *listChunk.getDataColumnChunk()); + } break; + default: { + // DO NOTHING. 
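+        // Only STRUCT, STRING, LIST and ARRAY chunks carry child chunk data to
+        // recurse into; every other physical type is fully described above.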
+ } + } +} + +static void appendStorageInfoForChunkedGroup(StorageInfoLocalState* localState, + DataChunk& outputChunk, StorageInfoOutputData& outputData, ChunkedNodeGroup* chunkedGroup) { + auto numColumns = chunkedGroup->getNumColumns(); + outputData.columnIdx = 0; + for (auto i = 0u; i < numColumns; i++) { + for (auto* segment : chunkedGroup->getColumnChunk(i).getSegments()) { + appendStorageInfoForChunkData(localState, outputChunk, outputData, + *outputData.columns[i], *segment); + } + } + if (chunkedGroup->getFormat() == NodeGroupDataFormat::CSR) { + auto& chunkedCSRGroup = chunkedGroup->cast(); + for (auto* segment : chunkedCSRGroup.getCSRHeader().offset->getSegments()) { + appendStorageInfoForChunkData(localState, outputChunk, outputData, + *outputData.columns[numColumns], *segment, true); + } + for (auto* segment : chunkedCSRGroup.getCSRHeader().length->getSegments()) { + appendStorageInfoForChunkData(localState, outputChunk, outputData, + *outputData.columns[numColumns + 1], *segment, true); + } + } +} + +static void appendStorageInfoForNodeGroup(StorageInfoLocalState* localState, DataChunk& outputChunk, + StorageInfoOutputData& outputData, NodeGroup* nodeGroup) { + auto numChunks = nodeGroup->getNumChunkedGroups(); + for (auto chunkIdx = 0ul; chunkIdx < numChunks; chunkIdx++) { + outputData.chunkIdx = chunkIdx; + appendStorageInfoForChunkedGroup(localState, outputChunk, outputData, + nodeGroup->getChunkedNodeGroup(chunkIdx)); + } + if (nodeGroup->getFormat() == NodeGroupDataFormat::CSR) { + auto& csrNodeGroup = nodeGroup->cast(); + auto persistentChunk = csrNodeGroup.getPersistentChunkedGroup(); + if (persistentChunk) { + outputData.chunkIdx = INVALID_NODE_GROUP_IDX; + appendStorageInfoForChunkedGroup(localState, outputChunk, outputData, + csrNodeGroup.getPersistentChunkedGroup()); + } + } +} + +static offset_t tableFunc(const TableFuncInput& input, TableFuncOutput& output) { + auto& dataChunk = output.dataChunk; + auto localState = ku_dynamic_cast(input.localState); + KU_ASSERT(dataChunk.state->getSelVector().isUnfiltered()); + auto storageManager = StorageManager::Get(*input.context->clientContext); + while (true) { + if (localState->currChunkIdx < localState->dataChunkCollection->getNumChunks()) { + // Copy from local state chunk. 
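+            // The first pass materializes storage info for the whole table into the
+            // local DataChunkCollection; after that, each call drains one buffered
+            // chunk at a time until currChunkIdx catches up with the collection.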
+ const auto& chunk = + localState->dataChunkCollection->getChunkUnsafe(localState->currChunkIdx); + const auto numValuesToOutput = chunk.state->getSelVector().getSelSize(); + for (auto columnIdx = 0u; columnIdx < dataChunk.getNumValueVectors(); columnIdx++) { + const auto& localVector = chunk.getValueVector(columnIdx); + auto& outputVector = dataChunk.getValueVectorMutable(columnIdx); + for (auto i = 0u; i < numValuesToOutput; i++) { + outputVector.copyFromVectorData(i, &localVector, i); + } + } + dataChunk.state->getSelVectorUnsafe().setToUnfiltered(numValuesToOutput); + localState->currChunkIdx++; + return numValuesToOutput; + } + auto morsel = input.sharedState->ptrCast()->getMorsel(); + if (!morsel.hasMoreToOutput()) { + return 0; + } + const auto bindData = input.bindData->constPtrCast(); + StorageInfoOutputData outputData; + node_group_idx_t numNodeGroups = 0; + switch (bindData->tableEntry->getTableType()) { + case TableType::NODE: { + outputData.tableType = "NODE"; + auto table = storageManager->getTable(bindData->tableEntry->getTableID()); + auto& nodeTable = table->cast(); + std::vector columns; + for (auto columnID = 0u; columnID < nodeTable.getNumColumns(); columnID++) { + columns.push_back(&nodeTable.getColumn(columnID)); + } + outputData.columns = std::move(columns); + numNodeGroups = nodeTable.getNumNodeGroups(); + for (auto i = 0ul; i < numNodeGroups; i++) { + outputData.nodeGroupIdx = i; + appendStorageInfoForNodeGroup(localState, dataChunk, outputData, + nodeTable.getNodeGroup(i)); + } + } break; + case TableType::REL: { + outputData.tableType = "REL"; + for (auto innerEntryInfo : + bindData->tableEntry->cast().getRelEntryInfos()) { + auto& relTable = storageManager->getTable(innerEntryInfo.oid)->cast(); + auto appendDirectedStorageInfo = [&](RelDataDirection direction) { + auto directedRelTableData = relTable.getDirectedTableData(direction); + std::vector columns; + for (auto columnID = 0u; columnID < relTable.getNumColumns(); columnID++) { + columns.push_back(directedRelTableData->getColumn(columnID)); + } + columns.push_back(directedRelTableData->getCSROffsetColumn()); + columns.push_back(directedRelTableData->getCSRLengthColumn()); + outputData.columns = std::move(columns); + numNodeGroups = directedRelTableData->getNumNodeGroups(); + for (auto i = 0ul; i < numNodeGroups; i++) { + outputData.nodeGroupIdx = i; + appendStorageInfoForNodeGroup(localState, dataChunk, outputData, + directedRelTableData->getNodeGroup(i)); + } + }; + for (auto direction : relTable.getStorageDirections()) { + appendDirectedStorageInfo(direction); + } + } + } break; + default: { + KU_UNREACHABLE; + } + } + localState->dataChunkCollection->append(dataChunk); + dataChunk.resetAuxiliaryBuffer(); + dataChunk.state->getSelVectorUnsafe().setSelSize(0); + } +} + +static std::unique_ptr bindFunc(const ClientContext* context, + const TableFuncBindInput* input) { + std::vector columnNames = {"table_type", "node_group_id", "node_chunk_id", + "residency", "column_name", "data_type", "start_page_idx", "num_pages", "num_values", "min", + "max", "compression"}; + std::vector columnTypes; + columnTypes.emplace_back(LogicalType::STRING()); + columnTypes.emplace_back(LogicalType::INT64()); + columnTypes.emplace_back(LogicalType::INT64()); + columnTypes.emplace_back(LogicalType::STRING()); + columnTypes.emplace_back(LogicalType::STRING()); + columnTypes.emplace_back(LogicalType::STRING()); + columnTypes.emplace_back(LogicalType::INT64()); + columnTypes.emplace_back(LogicalType::INT64()); + 
columnTypes.emplace_back(LogicalType::INT64()); + columnTypes.emplace_back(LogicalType::STRING()); + columnTypes.emplace_back(LogicalType::STRING()); + columnTypes.emplace_back(LogicalType::STRING()); + auto tableName = input->getLiteralVal(0); + auto catalog = Catalog::Get(*context); + if (!catalog->containsTable(transaction::Transaction::Get(*context), tableName)) { + throw BinderException{"Table " + tableName + " does not exist!"}; + } + auto tableEntry = + catalog->getTableCatalogEntry(transaction::Transaction::Get(*context), tableName); + columnNames = TableFunction::extractYieldVariables(columnNames, input->yieldVariables); + auto columns = input->binder->createVariables(columnNames, columnTypes); + return std::make_unique(columns, tableEntry, context); +} + +function_set StorageInfoFunction::getFunctionSet() { + function_set functionSet; + auto function = std::make_unique(name, std::vector{LogicalTypeID::STRING}); + function->tableFunc = tableFunc; + function->bindFunc = bindFunc; + function->initSharedStateFunc = SimpleTableFunc::initSharedState; + function->initLocalStateFunc = initLocalState; + functionSet.push_back(std::move(function)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/table_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/table_function.cpp new file mode 100644 index 0000000000..eefe8b5e0a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/table_function.cpp @@ -0,0 +1,115 @@ +#include "function/table/table_function.h" + +#include "common/exception/binder.h" +#include "parser/query/reading_clause/yield_variable.h" +#include "planner/operator/logical_table_function_call.h" +#include "planner/planner.h" +#include "processor/data_pos.h" +#include "processor/operator/table_function_call.h" +#include "processor/plan_mapper.h" + +using namespace lbug::common; +using namespace lbug::planner; +using namespace lbug::processor; + +namespace lbug { +namespace function { + +void TableFuncOutput::resetState() { + dataChunk.state->getSelVectorUnsafe().setSelSize(0); + dataChunk.resetAuxiliaryBuffer(); + for (auto i = 0u; i < dataChunk.getNumValueVectors(); i++) { + dataChunk.getValueVectorMutable(i).setAllNonNull(); + } +} + +void TableFuncOutput::setOutputSize(offset_t size) const { + dataChunk.state->getSelVectorUnsafe().setToUnfiltered(size); +} + +TableFunction::~TableFunction() = default; + +std::unique_ptr TableFunction::initEmptyLocalState( + const TableFuncInitLocalStateInput&) { + return std::make_unique(); +} + +std::unique_ptr TableFunction::initEmptySharedState( + const TableFuncInitSharedStateInput& /*input*/) { + return std::make_unique(); +} + +std::unique_ptr TableFunction::initSingleDataChunkScanOutput( + const TableFuncInitOutputInput& input) { + if (input.outColumnPositions.empty()) { + return std::make_unique(DataChunk{}); + } + auto state = input.resultSet.getDataChunk(input.outColumnPositions[0].dataChunkPos)->state; + auto dataChunk = DataChunk(input.outColumnPositions.size(), state); + for (auto i = 0u; i < input.outColumnPositions.size(); ++i) { + dataChunk.insert(i, input.resultSet.getValueVector(input.outColumnPositions[i])); + } + return std::make_unique(std::move(dataChunk)); +} + +std::vector TableFunction::extractYieldVariables(const std::vector& names, + const std::vector& yieldVariables) { + std::vector variableNames; + if (!yieldVariables.empty()) { + if (yieldVariables.size() < names.size()) { + throw 
BinderException{"Output variables must all appear in the yield clause."}; + } + if (yieldVariables.size() > names.size()) { + throw BinderException{"The number of variables in the yield clause exceeds the " + "number of output variables of the table function."}; + } + for (auto i = 0u; i < names.size(); i++) { + if (names[i] != yieldVariables[i].name) { + throw BinderException{stringFormat( + "Unknown table function output variable name: {}.", yieldVariables[i].name)}; + } + auto variableName = + yieldVariables[i].hasAlias() ? yieldVariables[i].alias : yieldVariables[i].name; + variableNames.push_back(variableName); + } + } else { + variableNames = names; + } + return variableNames; +} + +void TableFunction::getLogicalPlan(Planner* planner, + const binder::BoundReadingClause& boundReadingClause, binder::expression_vector predicates, + LogicalPlan& plan) { + auto op = planner->getTableFunctionCall(boundReadingClause); + planner->planReadOp(op, predicates, plan); +} + +std::unique_ptr TableFunction::getPhysicalPlan(PlanMapper* planMapper, + const LogicalOperator* logicalOp) { + std::vector outPosV; + auto& call = logicalOp->constCast(); + auto outSchema = call.getSchema(); + for (auto& expr : call.getBindData()->columns) { + outPosV.emplace_back(planMapper->getDataPos(*expr, *outSchema)); + } + auto info = TableFunctionCallInfo(); + info.function = call.getTableFunc(); + info.bindData = call.getBindData()->copy(); + info.outPosV = outPosV; + auto initInput = + TableFuncInitSharedStateInput(info.bindData.get(), planMapper->executionContext); + auto sharedState = info.function.initSharedStateFunc(initInput); + auto printInfo = std::make_unique(call.getTableFunc().name, + call.getBindData()->columns); + return std::make_unique(std::move(info), sharedState, + planMapper->getOperatorID(), std::move(printInfo)); +} + +offset_t TableFunction::emptyTableFunc(const TableFuncInput&, TableFuncOutput&) { + // DO NOTHING. 
+ return 0; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/table_info.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/table_info.cpp new file mode 100644 index 0000000000..a9425b59c5 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/table/table_info.cpp @@ -0,0 +1,224 @@ +#include "binder/binder.h" +#include "catalog/catalog.h" +#include "catalog/catalog_entry/node_table_catalog_entry.h" +#include "catalog/catalog_entry/rel_group_catalog_entry.h" +#include "common/enums/extend_direction_util.h" +#include "common/exception/catalog.h" +#include "common/string_utils.h" +#include "function/table/bind_data.h" +#include "function/table/bind_input.h" +#include "function/table/simple_table_function.h" +#include "main/client_context.h" +#include "main/database_manager.h" + +using namespace lbug::catalog; +using namespace lbug::common; + +namespace lbug { +namespace function { + +struct ExtraPropertyInfo { + virtual ~ExtraPropertyInfo() = default; + + template + TARGET* ptrCast() { + return common::ku_dynamic_cast(this); + } + + virtual std::unique_ptr copy() const = 0; +}; + +struct ExtraNodePropertyInfo : ExtraPropertyInfo { + bool isPrimaryKey; + + explicit ExtraNodePropertyInfo(bool isPrimaryKey) : isPrimaryKey{isPrimaryKey} {} + + std::unique_ptr copy() const override { + return std::make_unique(isPrimaryKey); + } +}; + +struct ExtraRelPropertyInfo : ExtraPropertyInfo { + std::string storageDirection; + + explicit ExtraRelPropertyInfo(std::string storageDirection) + : storageDirection{std::move(storageDirection)} {} + + std::unique_ptr copy() const override { + return std::make_unique(storageDirection); + } +}; + +struct PropertyInfo { + column_id_t propertyID = INVALID_COLUMN_ID; + std::string name; + std::string type; + std::string defaultVal; + std::unique_ptr extraInfo = nullptr; + + PropertyInfo() = default; + EXPLICIT_COPY_DEFAULT_MOVE(PropertyInfo); + +private: + PropertyInfo(const PropertyInfo& other) + : propertyID{other.propertyID}, name{other.name}, type{other.type}, + defaultVal{other.defaultVal} { + if (other.extraInfo) { + extraInfo = other.extraInfo->copy(); + } + } +}; + +struct TableInfoBindData final : TableFuncBindData { + CatalogEntryType type; + std::vector infos; + + TableInfoBindData(CatalogEntryType type, std::vector infos, + binder::expression_vector columns) + : TableFuncBindData{std::move(columns), infos.size()}, type{type}, infos{std::move(infos)} { + } + + std::unique_ptr copy() const override { + return std::make_unique(type, copyVector(infos), columns); + } +}; + +static offset_t internalTableFunc(const TableFuncMorsel& morsel, const TableFuncInput& input, + DataChunk& output) { + auto bindData = input.bindData->constPtrCast(); + auto i = 0u; + auto size = morsel.getMorselSize(); + for (; i < size; i++) { + auto& info = bindData->infos[morsel.startOffset + i]; + output.getValueVectorMutable(0).setValue(i, info.propertyID); + output.getValueVectorMutable(1).setValue(i, info.name); + output.getValueVectorMutable(2).setValue(i, info.type); + output.getValueVectorMutable(3).setValue(i, info.defaultVal); + switch (bindData->type) { + case CatalogEntryType::NODE_TABLE_ENTRY: { + auto extraInfo = info.extraInfo->ptrCast(); + output.getValueVectorMutable(4).setValue(i, extraInfo->isPrimaryKey); + } break; + case CatalogEntryType::REL_GROUP_ENTRY: { + auto extraInfo = info.extraInfo->ptrCast(); + output.getValueVectorMutable(4).setValue(i, extraInfo->storageDirection); + } break; + 
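+        // Column 4 is schema-dependent: bindFunc below adds a "primary key" BOOL
+        // column for node tables and a "storage_direction" STRING column for rel
+        // groups; foreign tables get no extra column.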
default: + break; + } + } + return i; +} + +static PropertyInfo getInfo(const binder::PropertyDefinition& def) { + auto info = PropertyInfo(); + info.name = def.getName(); + info.type = def.getType().toString(); + info.defaultVal = def.getDefaultExpressionName(); + return info; +} + +static std::vector getForeignPropertyInfos(TableCatalogEntry* entry) { + std::vector infos; + for (auto& def : entry->getProperties()) { + auto info = getInfo(def); + info.propertyID = entry->getPropertyID(def.getName()); + infos.push_back(std::move(info)); + } + return infos; +} + +static std::vector getNodePropertyInfos(NodeTableCatalogEntry* entry) { + std::vector infos; + auto primaryKeyName = entry->getPrimaryKeyName(); + for (auto& def : entry->getProperties()) { + auto info = getInfo(def); + info.propertyID = entry->getPropertyID(def.getName()); + info.extraInfo = std::make_unique(primaryKeyName == def.getName()); + infos.push_back(std::move(info)); + } + return infos; +} + +static std::vector getRelPropertyInfos(RelGroupCatalogEntry* entry) { + std::vector infos; + for (auto& def : entry->getProperties()) { + if (def.getName() == InternalKeyword::ID) { + continue; + } + auto info = getInfo(def); + info.propertyID = entry->getPropertyID(def.getName()); + info.extraInfo = std::make_unique( + ExtendDirectionUtil::toString(entry->getStorageDirection())); + infos.push_back(std::move(info)); + } + return infos; +} + +static std::unique_ptr bindFunc(const main::ClientContext* context, + const TableFuncBindInput* input) { + std::vector columnNames; + std::vector columnTypes; + columnNames.emplace_back("property id"); + columnTypes.push_back(LogicalType::INT32()); + columnNames.emplace_back("name"); + columnTypes.push_back(LogicalType::STRING()); + columnNames.emplace_back("type"); + columnTypes.push_back(LogicalType::STRING()); + columnNames.emplace_back("default expression"); + columnTypes.push_back(LogicalType::STRING()); + auto name = common::StringUtils::split(input->getLiteralVal(0), "."); + + std::vector infos; + CatalogEntryType type = CatalogEntryType::DUMMY_ENTRY; + auto transaction = transaction::Transaction::Get(*context); + if (name.size() == 1) { + auto tableName = name[0]; + auto catalog = Catalog::Get(*context); + if (catalog->containsTable(transaction, tableName)) { + auto entry = catalog->getTableCatalogEntry(transaction, tableName); + switch (entry->getType()) { + case CatalogEntryType::NODE_TABLE_ENTRY: { + columnNames.emplace_back("primary key"); + columnTypes.push_back(LogicalType::BOOL()); + infos = getNodePropertyInfos(entry->ptrCast()); + type = CatalogEntryType::NODE_TABLE_ENTRY; + } break; + case CatalogEntryType::REL_GROUP_ENTRY: { + columnNames.emplace_back("storage_direction"); + columnTypes.push_back(LogicalType::STRING()); + infos = getRelPropertyInfos(entry->ptrCast()); + type = CatalogEntryType::REL_GROUP_ENTRY; + } break; + default: + KU_UNREACHABLE; + } + } else { + throw CatalogException(stringFormat("{} does not exist in catalog.", tableName)); + } + } else { + auto dbName = name[0]; + auto tableName = name[1]; + auto db = main::DatabaseManager::Get(*context)->getAttachedDatabase(dbName); + auto entry = db->getCatalog()->getTableCatalogEntry(transaction, tableName); + infos = getForeignPropertyInfos(entry); + type = CatalogEntryType::FOREIGN_TABLE_ENTRY; + } + columnNames = TableFunction::extractYieldVariables(columnNames, input->yieldVariables); + auto columns = input->binder->createVariables(columnNames, columnTypes); + return std::make_unique(type, std::move(infos), 
columns);
+}
+
+function_set TableInfoFunction::getFunctionSet() {
+    function_set functionSet;
+    auto function = std::make_unique<TableFunction>(name, std::vector{LogicalTypeID::STRING});
+    function->tableFunc = SimpleTableFunc::getTableFunc(internalTableFunc);
+    function->bindFunc = bindFunc;
+    function->initSharedStateFunc = SimpleTableFunc::initSharedState;
+    function->initLocalStateFunc = TableFunction::initEmptyLocalState;
+    functionSet.push_back(std::move(function));
+    return functionSet;
+}
+
+} // namespace function
+} // namespace lbug
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/timestamp/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/function/timestamp/CMakeLists.txt
new file mode 100644
index 0000000000..a91398b11b
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/timestamp/CMakeLists.txt
@@ -0,0 +1,7 @@
+add_library(lbug_timestamp_function
+        OBJECT
+        to_epoch_ms.cpp)
+
+set(ALL_OBJECT_FILES
+        ${ALL_OBJECT_FILES} $<TARGET_OBJECTS:lbug_timestamp_function>
+        PARENT_SCOPE)
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/timestamp/to_epoch_ms.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/timestamp/to_epoch_ms.cpp
new file mode 100644
index 0000000000..b75c0abb4c
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/timestamp/to_epoch_ms.cpp
@@ -0,0 +1,26 @@
+#include "function/arithmetic/divide.h"
+#include "function/scalar_function.h"
+#include "function/timestamp/vector_timestamp_functions.h"
+
+namespace lbug {
+namespace function {
+
+using namespace lbug::common;
+
+struct ToEpochMs {
+    static void operation(common::timestamp_t& input, int64_t& result) {
+        function::Divide::operation(input.value, Interval::MICROS_PER_MSEC, result);
+    }
+};
+
+function_set ToEpochMsFunction::getFunctionSet() {
+    function_set functionSet;
+    auto function = std::make_unique<ScalarFunction>(name,
+        std::vector{LogicalTypeID::TIMESTAMP}, LogicalTypeID::INT64,
+        ScalarFunction::UnaryExecFunction<timestamp_t, int64_t, ToEpochMs>);
+    functionSet.emplace_back(std::move(function));
+    return functionSet;
+}
+
+} // namespace function
+} // namespace lbug
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/union/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/function/union/CMakeLists.txt
new file mode 100644
index 0000000000..2d54ec741a
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/union/CMakeLists.txt
@@ -0,0 +1,9 @@
+add_library(lbug_function_union
+        OBJECT
+        union_extract_function.cpp
+        union_tag_function.cpp
+        union_value_function.cpp)
+
+set(ALL_OBJECT_FILES
+        ${ALL_OBJECT_FILES} $<TARGET_OBJECTS:lbug_function_union>
+        PARENT_SCOPE)
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/union/union_extract_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/union/union_extract_function.cpp
new file mode 100644
index 0000000000..e52af666fa
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/union/union_extract_function.cpp
@@ -0,0 +1,22 @@
+#include "function/scalar_function.h"
+#include "function/struct/vector_struct_functions.h"
+#include "function/union/vector_union_functions.h"
+
+using namespace lbug::common;
+
+namespace lbug {
+namespace function {
+
+function_set UnionExtractFunction::getFunctionSet() {
+    function_set functionSet;
+    auto function = std::make_unique<ScalarFunction>(name,
+        std::vector{LogicalTypeID::UNION, LogicalTypeID::STRING},
+        LogicalTypeID::ANY);
+    function->bindFunc = StructExtractFunctions::bindFunc;
+    function->compileFunc = StructExtractFunctions::compileFunc;
+    functionSet.push_back(std::move(function));
+    return functionSet;
+}
+
+} // namespace function
+} // namespace lbug
diff --git
a/graph-wasm/lbug-0.12.2/lbug-src/src/function/union/union_tag_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/union/union_tag_function.cpp new file mode 100644 index 0000000000..a624629367 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/union/union_tag_function.cpp @@ -0,0 +1,25 @@ +#include "function/scalar_function.h" +#include "function/union/functions/union_tag.h" +#include "function/union/vector_union_functions.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { + return FunctionBindData::getSimpleBindData(input.arguments, LogicalType::STRING()); +} + +function_set UnionTagFunction::getFunctionSet() { + function_set functionSet; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::UNION}, LogicalTypeID::STRING, + ScalarFunction::UnaryExecNestedTypeFunction); + function->bindFunc = bindFunc; + functionSet.push_back(std::move(function)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/union/union_value_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/union/union_value_function.cpp new file mode 100644 index 0000000000..157420c6b5 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/union/union_value_function.cpp @@ -0,0 +1,46 @@ +#include "function/scalar_function.h" +#include "function/union/vector_union_functions.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { + KU_ASSERT(input.arguments.size() == 1); + std::vector fields; + if (input.arguments[0]->getDataType().getLogicalTypeID() == common::LogicalTypeID::ANY) { + input.arguments[0]->cast(LogicalType::STRING()); + } + fields.emplace_back(input.arguments[0]->getAlias(), input.arguments[0]->getDataType().copy()); + auto resultType = LogicalType::UNION(std::move(fields)); + return FunctionBindData::getSimpleBindData(input.arguments, resultType); +} + +static void execFunc(const std::vector>&, + const std::vector&, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* /*dataPtr*/) { + UnionVector::setTagField(result, *resultSelVector, UnionType::TAG_FIELD_IDX); +} + +static void valueCompileFunc(FunctionBindData* /*bindData*/, + const std::vector>& parameters, + std::shared_ptr& result) { + KU_ASSERT(parameters.size() == 1); + result->setState(parameters[0]->state); + UnionVector::getTagVector(result.get())->setState(parameters[0]->state); + UnionVector::referenceVector(result.get(), UnionType::TAG_FIELD_IDX, parameters[0]); +} + +function_set UnionValueFunction::getFunctionSet() { + function_set functionSet; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::ANY}, LogicalTypeID::UNION, execFunc); + function->bindFunc = bindFunc; + function->compileFunc = valueCompileFunc; + functionSet.push_back(std::move(function)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/utility/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/function/utility/CMakeLists.txt new file mode 100644 index 0000000000..712f692e34 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/utility/CMakeLists.txt @@ -0,0 +1,14 @@ +add_library(lbug_utility_function + OBJECT + coalesce.cpp + md5.cpp + sha256.cpp + constant_or_null.cpp + count_if.cpp + error.cpp + nullif.cpp + typeof.cpp) + 
+set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/utility/coalesce.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/utility/coalesce.cpp new file mode 100644 index 0000000000..146ca347e7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/utility/coalesce.cpp @@ -0,0 +1,99 @@ +#include "binder/expression/expression_util.h" +#include "common/exception/binder.h" +#include "function/scalar_function.h" +#include "function/utility/vector_utility_functions.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { + if (input.arguments.empty()) { + throw BinderException("COALESCE requires at least one argument"); + } + LogicalType resultType(LogicalTypeID::ANY); + binder::ExpressionUtil::tryCombineDataType(input.arguments, resultType); + if (resultType.getLogicalTypeID() == LogicalTypeID::ANY) { + resultType = LogicalType::STRING(); + } + auto bindData = std::make_unique(resultType.copy()); + for (auto& _ : input.arguments) { + (void)_; + bindData->paramTypes.push_back(resultType.copy()); + } + return bindData; +} + +static void execFunc(const std::vector>& params, + const std::vector& paramSelVectors, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* /*dataPtr*/) { + result.resetAuxiliaryBuffer(); + for (auto i = 0u; i < resultSelVector->getSelSize(); ++i) { + auto resultPos = (*resultSelVector)[i]; + auto isNull = true; + for (size_t i = 0; i < params.size(); ++i) { + const auto& param = *params[i]; + const auto& paramSelVector = *paramSelVectors[i]; + auto paramPos = param.state->isFlat() ? paramSelVector[0] : resultPos; + if (!param.isNull(paramPos)) { + result.copyFromVectorData(resultPos, ¶m, paramPos); + isNull = false; + break; + } + } + result.setNull(resultPos, isNull); + } +} + +static bool selectFunc(const std::vector>& params, + SelectionVector& selVector, void* /* dataPtr */) { + KU_ASSERT(!params.empty()); + auto unFlatVectorIdx = 0u; + for (auto i = 0u; i < params.size(); ++i) { + if (!params[i]->state->isFlat()) { + unFlatVectorIdx = i; + break; + } + } + auto numSelectedValues = 0u; + auto selectedPositionsBuffer = selVector.getMutableBuffer(); + for (auto i = 0u; i < params[unFlatVectorIdx]->state->getSelVector().getSelSize(); ++i) { + auto resultPos = params[unFlatVectorIdx]->state->getSelVector()[i]; + auto resultValue = false; + for (auto& param : params) { + auto paramPos = param->state->isFlat() ? 
param->state->getSelVector()[0] : resultPos; + if (!param->isNull(paramPos)) { + resultValue = param->getValue(paramPos); + break; + } + } + selectedPositionsBuffer[numSelectedValues] = resultPos; + numSelectedValues += resultValue; + } + selVector.setSelSize(numSelectedValues); + return numSelectedValues > 0; +} + +function_set CoalesceFunction::getFunctionSet() { + function_set functionSet; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::ANY}, LogicalTypeID::ANY, execFunc, selectFunc); + function->bindFunc = bindFunc; + function->isVarLength = true; + functionSet.push_back(std::move(function)); + return functionSet; +} + +function_set IfNullFunction::getFunctionSet() { + function_set functionSet; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::ANY, LogicalTypeID::ANY}, LogicalTypeID::ANY, + execFunc, selectFunc); + function->bindFunc = bindFunc; + functionSet.push_back(std::move(function)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/utility/constant_or_null.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/utility/constant_or_null.cpp new file mode 100644 index 0000000000..28a4dc8bac --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/utility/constant_or_null.cpp @@ -0,0 +1,81 @@ +#include "function/scalar_function.h" +#include "function/utility/vector_utility_functions.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { + logical_type_vec_t paramTypes; + for (auto& argument : input.arguments) { + if (argument->getDataType().getLogicalTypeID() == LogicalTypeID::ANY) { + paramTypes.push_back(LogicalType::STRING()); + } else { + paramTypes.push_back(argument->getDataType().copy()); + } + } + auto bindData = std::make_unique(paramTypes[0].copy()); + bindData->paramTypes = std::move(paramTypes); + return bindData; +} + +static void execFunc(const std::vector>& params, + const std::vector& paramSelVectors, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* /*dataPtr*/) { + KU_ASSERT(params.size() == 2); + result.resetAuxiliaryBuffer(); + for (auto i = 0u; i < resultSelVector->getSelSize(); ++i) { + auto resultPos = (*resultSelVector)[i]; + auto firstParamPos = params[0]->state->isFlat() ? (*paramSelVectors[0])[0] : resultPos; + auto secondParamPos = params[1]->state->isFlat() ? (*paramSelVectors[1])[0] : resultPos; + if (params[1]->isNull(secondParamPos) || params[0]->isNull(firstParamPos)) { + result.setNull(resultPos, true); + } else { + result.setNull(resultPos, false); + result.copyFromVectorData(resultPos, params[0].get(), firstParamPos); + } + } +} + +static bool selectFunc(const std::vector>& params, + SelectionVector& selVector, void* /* dataPtr */) { + KU_ASSERT(params.size() == 2); + auto unFlatVectorIdx = 0u; + for (auto i = 0u; i < params.size(); ++i) { + if (!params[i]->state->isFlat()) { + unFlatVectorIdx = i; + break; + } + } + auto numSelectedValues = 0u; + auto selectedPositionsBuffer = selVector.getMutableBuffer(); + for (auto i = 0u; i < params[unFlatVectorIdx]->state->getSelVector().getSelSize(); ++i) { + auto resultPos = params[unFlatVectorIdx]->state->getSelVector()[i]; + auto resultValue = false; + auto firstParamPos = + params[0]->state->isFlat() ? params[0]->state->getSelVector()[0] : resultPos; + auto secondParamPos = + params[1]->state->isFlat() ? 
params[1]->state->getSelVector()[0] : resultPos; + if (!params[1]->isNull(secondParamPos) && !params[0]->isNull(firstParamPos)) { + resultValue = params[0]->getValue(firstParamPos); + } + selectedPositionsBuffer[numSelectedValues] = resultPos; + numSelectedValues += resultValue; + } + selVector.setSelSize(numSelectedValues); + return numSelectedValues > 0; +} + +function_set ConstantOrNullFunction::getFunctionSet() { + function_set functionSet; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::ANY, LogicalTypeID::ANY}, LogicalTypeID::ANY, + execFunc, selectFunc); + function->bindFunc = bindFunc; + functionSet.push_back(std::move(function)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/utility/count_if.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/utility/count_if.cpp new file mode 100644 index 0000000000..22463f95b7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/utility/count_if.cpp @@ -0,0 +1,41 @@ +#include "common/type_utils.h" +#include "function/scalar_function.h" +#include "function/utility/vector_utility_functions.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +struct CountIf { + template + static inline void operation(T& input, uint8_t& result) { + if (input != 0) { + result = 1; + } else { + result = 0; + } + } +}; + +function_set CountIfFunction::getFunctionSet() { + function_set functionSet; + auto operandTypeIDs = LogicalTypeUtils::getNumericalLogicalTypeIDs(); + operandTypeIDs.push_back(LogicalTypeID::BOOL); + scalar_func_exec_t execFunc; + for (auto operandTypeID : operandTypeIDs) { + TypeUtils::visit( + LogicalType(operandTypeID), + [&execFunc]( + T) { execFunc = ScalarFunction::UnaryExecFunction; }, + [&execFunc]( + bool) { execFunc = ScalarFunction::UnaryExecFunction; }, + [](auto) { KU_UNREACHABLE; }); + functionSet.push_back(std::make_unique(name, + std::vector{operandTypeID}, LogicalTypeID::UINT8, execFunc)); + } + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/utility/error.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/utility/error.cpp new file mode 100644 index 0000000000..5ac7b6f039 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/utility/error.cpp @@ -0,0 +1,28 @@ +#include "common/exception/runtime.h" +#include "function/scalar_function.h" +#include "function/utility/vector_utility_functions.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +struct Error { + static void operation(ku_string_t& input, int32_t& result) { + result = 0; + throw RuntimeException(input.getAsString()); + } +}; + +function_set ErrorFunction::getFunctionSet() { + function_set functionSet; + functionSet.push_back( + std::make_unique(name, std::vector{LogicalTypeID::STRING}, + LogicalTypeID::INT32, ScalarFunction::UnaryExecFunction)); + // int32_t is just a dummy resultType for error(), since this function throws an exception + // instead of returns any result + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/utility/md5.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/utility/md5.cpp new file mode 100644 index 0000000000..7d1c334162 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/utility/md5.cpp @@ -0,0 +1,28 @@ +#include "common/md5.h" + +#include "function/hash/vector_hash_functions.h" +#include 
"function/scalar_function.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +struct MD5Operator { + static void operation(ku_string_t& operand, ku_string_t& result, ValueVector& resultVector) { + MD5 hasher; + hasher.addToMD5(reinterpret_cast(operand.getData()), operand.len); + StringVector::addString(&resultVector, result, std::string(hasher.finishMD5())); + } +}; + +function_set MD5Function::getFunctionSet() { + function_set functionSet; + functionSet.push_back(std::make_unique(name, + std::vector{LogicalTypeID::STRING}, LogicalTypeID::STRING, + ScalarFunction::UnaryStringExecFunction)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/utility/nullif.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/utility/nullif.cpp new file mode 100644 index 0000000000..8021d7bdda --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/utility/nullif.cpp @@ -0,0 +1,38 @@ +#include "binder/expression/case_expression.h" +#include "binder/expression/scalar_function_expression.h" +#include "binder/expression_binder.h" +#include "function/rewrite_function.h" +#include "function/utility/vector_utility_functions.h" + +using namespace lbug::binder; +using namespace lbug::common; + +namespace lbug { +namespace function { + +static std::shared_ptr rewriteFunc(const RewriteFunctionBindInput& input) { + KU_ASSERT(input.arguments.size() == 2); + auto uniqueExpressionName = + ScalarFunctionExpression::getUniqueName(NullIfFunction::name, input.arguments); + const auto& resultType = input.arguments[0]->getDataType(); + auto caseExpression = std::make_shared(resultType.copy(), input.arguments[0], + uniqueExpressionName); + auto binder = input.expressionBinder; + auto whenExpression = binder->bindComparisonExpression(ExpressionType::EQUALS, input.arguments); + auto thenExpression = binder->createNullLiteralExpression(); + thenExpression = binder->implicitCastIfNecessary(thenExpression, resultType.copy()); + caseExpression->addCaseAlternative(whenExpression, thenExpression); + return caseExpression; +} + +function_set NullIfFunction::getFunctionSet() { + function_set functionSet; + for (auto typeID : LogicalTypeUtils::getAllValidLogicTypeIDs()) { + functionSet.push_back(std::make_unique(name, + std::vector{typeID, typeID}, rewriteFunc)); + } + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/utility/sha256.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/utility/sha256.cpp new file mode 100644 index 0000000000..a5b293a08e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/utility/sha256.cpp @@ -0,0 +1,32 @@ +#include "common/sha256.h" + +#include "function/hash/vector_hash_functions.h" +#include "function/scalar_function.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +struct SHA256Operator { + static void operation(ku_string_t& operand, ku_string_t& result, ValueVector& resultVector) { + StringVector::reserveString(&resultVector, result, SHA256::SHA256_HASH_LENGTH_TEXT); + SHA256 hasher; + hasher.addString(operand.getAsString()); + hasher.finishSHA256(reinterpret_cast(result.getDataUnsafe())); + if (!ku_string_t::isShortString(result.len)) { + memcpy(result.prefix, result.getData(), ku_string_t::PREFIX_LENGTH); + } + } +}; + +function_set SHA256Function::getFunctionSet() { + function_set functionSet; + functionSet.push_back(std::make_unique(name, + 
std::vector{LogicalTypeID::STRING}, LogicalTypeID::STRING,
+        ScalarFunction::UnaryStringExecFunction<ku_string_t, ku_string_t, SHA256Operator>));
+    return functionSet;
+}
+
+} // namespace function
+} // namespace lbug
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/utility/typeof.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/utility/typeof.cpp
new file mode 100644
index 0000000000..a3d7747d84
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/utility/typeof.cpp
@@ -0,0 +1,43 @@
+#include "function/scalar_function.h"
+#include "function/utility/function_string_bind_data.h"
+#include "function/utility/vector_utility_functions.h"
+
+using namespace lbug::common;
+
+namespace lbug {
+namespace function {
+
+static std::unique_ptr<FunctionBindData> bindFunc(const ScalarBindFuncInput& input) {
+    std::unique_ptr<FunctionBindData> bindData;
+    if (input.arguments[0]->getDataType().getLogicalTypeID() == LogicalTypeID::ANY) {
+        bindData = std::make_unique<FunctionStringBindData>("NULL");
+        bindData->paramTypes.push_back(LogicalType::STRING());
+    } else {
+        bindData =
+            std::make_unique<FunctionStringBindData>(input.arguments[0]->getDataType().toString());
+    }
+    return bindData;
+}
+
+static void execFunc(const std::vector<std::shared_ptr<common::ValueVector>>&,
+    const std::vector<common::SelectionVector*>&, common::ValueVector& result,
+    common::SelectionVector* resultSelVector, void* dataPtr) {
+    result.resetAuxiliaryBuffer();
+    auto typeData = reinterpret_cast<FunctionStringBindData*>(dataPtr);
+    for (auto i = 0u; i < resultSelVector->getSelSize(); ++i) {
+        auto resultPos = (*resultSelVector)[i];
+        StringVector::addString(&result, resultPos, typeData->str);
+    }
+}
+
+function_set TypeOfFunction::getFunctionSet() {
+    function_set functionSet;
+    auto function = std::make_unique<ScalarFunction>(name,
+        std::vector{LogicalTypeID::ANY}, LogicalTypeID::STRING, execFunc);
+    function->bindFunc = bindFunc;
+    functionSet.push_back(std::move(function));
+    return functionSet;
+}
+
+} // namespace function
+} // namespace lbug
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/uuid/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/function/uuid/CMakeLists.txt
new file mode 100644
index 0000000000..be61343925
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/uuid/CMakeLists.txt
@@ -0,0 +1,7 @@
+add_library(lbug_function_uuid
+        OBJECT
+        gen_random_uuid.cpp)
+
+set(ALL_OBJECT_FILES
+        ${ALL_OBJECT_FILES} $<TARGET_OBJECTS:lbug_function_uuid>
+        PARENT_SCOPE)
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/uuid/gen_random_uuid.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/uuid/gen_random_uuid.cpp
new file mode 100644
index 0000000000..2dd0cbb481
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/uuid/gen_random_uuid.cpp
@@ -0,0 +1,15 @@
+#include "function/uuid/functions/gen_random_uuid.h"
+
+#include "common/random_engine.h"
+#include "function/function.h"
+
+namespace lbug {
+namespace function {
+
+void GenRandomUUID::operation(common::ku_uuid_t& input, void* dataPtr) {
+    auto clientContext = static_cast<FunctionBindData*>(dataPtr)->clientContext;
+    input = common::UUID::generateRandomUUID(common::RandomEngine::Get(*clientContext));
+}
+
+} // namespace function
+} // namespace lbug
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/vector_arithmetic_functions.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/vector_arithmetic_functions.cpp
new file mode 100644
index 0000000000..b5c3b9993e
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/vector_arithmetic_functions.cpp
@@ -0,0 +1,875 @@
+#include "function/arithmetic/vector_arithmetic_functions.h"
+
+#include "common/exception/overflow.h"
+#include "common/exception/runtime.h"
+#include "common/type_utils.h"
+#include
"common/types/date_t.h" +#include "common/types/int128_t.h" +#include "common/types/interval_t.h" +#include "common/types/timestamp_t.h" +#include "function/arithmetic/abs.h" +#include "function/arithmetic/add.h" +#include "function/arithmetic/arithmetic_functions.h" +#include "function/arithmetic/divide.h" +#include "function/arithmetic/modulo.h" +#include "function/arithmetic/multiply.h" +#include "function/arithmetic/negate.h" +#include "function/arithmetic/subtract.h" +#include "function/cast/functions/numeric_limits.h" +#include "function/list/functions/list_concat_function.h" +#include "function/list/vector_list_functions.h" +#include "function/scalar_function.h" +#include "function/string/vector_string_functions.h" + +using namespace lbug::common; +using std::max; +using std::min; + +namespace lbug { +namespace function { + +struct DecimalFunction { + + static std::unique_ptr bindAddFunc(ScalarBindFuncInput input); + + static std::unique_ptr bindSubtractFunc(ScalarBindFuncInput input); + + static std::unique_ptr bindMultiplyFunc(ScalarBindFuncInput input); + + static std::unique_ptr bindDivideFunc(ScalarBindFuncInput input); + + static std::unique_ptr bindModuloFunc(ScalarBindFuncInput input); + + static std::unique_ptr bindNegateFunc(ScalarBindFuncInput input); + + static std::unique_ptr bindAbsFunc(ScalarBindFuncInput input); + + static std::unique_ptr bindFloorFunc(ScalarBindFuncInput input); + + static std::unique_ptr bindCeilFunc(ScalarBindFuncInput input); +}; + +template +static std::unique_ptr getUnaryFunction(std::string name, + LogicalTypeID operandTypeID) { + function::scalar_func_exec_t execFunc; + common::TypeUtils::visit( + LogicalType(operandTypeID), + [&](T) { execFunc = ScalarFunction::UnaryExecFunction; }, + [](auto) { KU_UNREACHABLE; }); + return std::make_unique(std::move(name), + std::vector{operandTypeID}, operandTypeID, execFunc); +} + +template +static std::unique_ptr getUnaryFunction(std::string name, + LogicalTypeID operandTypeID, LogicalTypeID resultTypeID) { + return std::make_unique(std::move(name), + std::vector{operandTypeID}, resultTypeID, + ScalarFunction::UnaryExecFunction); +} + +template +static std::unique_ptr getBinaryFunction(std::string name, + common::LogicalTypeID operandTypeID) { + function::scalar_func_exec_t execFunc; + common::TypeUtils::visit( + common::LogicalType(operandTypeID), + [&]( + T) { execFunc = ScalarFunction::BinaryExecFunction; }, + [](auto) { KU_UNREACHABLE; }); + return std::make_unique(std::move(name), + std::vector{operandTypeID, operandTypeID}, operandTypeID, execFunc); +} + +template +static std::unique_ptr getBinaryFunction(std::string name, + LogicalTypeID operandTypeID, LogicalTypeID resultTypeID) { + return std::make_unique(std::move(name), + std::vector{operandTypeID, operandTypeID}, resultTypeID, + ScalarFunction::BinaryExecFunction); +} + +function_set AddFunction::getFunctionSet() { + function_set result; + for (auto typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { + result.push_back(getBinaryFunction(name, typeID)); + } + + // decimal + decimal -> decimal + std::unique_ptr func; + func = std::make_unique(name, + std::vector{LogicalTypeID::DECIMAL, LogicalTypeID::DECIMAL}, + LogicalTypeID::DECIMAL); + func->bindFunc = DecimalFunction::bindAddFunc; + result.push_back(std::move(func)); + // list + list -> list + func = std::make_unique(name, + std::vector{LogicalTypeID::LIST, LogicalTypeID::LIST}, LogicalTypeID::LIST, + ScalarFunction::BinaryExecListStructFunction); + func->bindFunc = 
ListConcatFunction::bindFunc; + result.push_back(std::move(func)); + // string + string -> string + result.push_back(std::make_unique(name, + std::vector{LogicalTypeID::STRING, LogicalTypeID::STRING}, + LogicalTypeID::STRING, ConcatFunction::execFunc)); + // interval + interval → interval + result.push_back(getBinaryFunction(name, LogicalTypeID::INTERVAL, + LogicalTypeID::INTERVAL)); + // date + int → date + result.push_back(make_unique(name, + std::vector{LogicalTypeID::DATE, LogicalTypeID::INT64}, LogicalTypeID::DATE, + ScalarFunction::BinaryExecFunction)); + // int + date → date + result.push_back(make_unique(name, + std::vector{LogicalTypeID::INT64, LogicalTypeID::DATE}, LogicalTypeID::DATE, + ScalarFunction::BinaryExecFunction)); + // date + interval → date + result.push_back(make_unique(name, + std::vector{LogicalTypeID::DATE, LogicalTypeID::INTERVAL}, + LogicalTypeID::DATE, ScalarFunction::BinaryExecFunction)); + // interval + date → date + result.push_back(make_unique(name, + std::vector{LogicalTypeID::INTERVAL, LogicalTypeID::DATE}, + LogicalTypeID::DATE, ScalarFunction::BinaryExecFunction)); + // timestamp + interval → timestamp + result.push_back(make_unique(name, + std::vector{LogicalTypeID::TIMESTAMP, LogicalTypeID::INTERVAL}, + LogicalTypeID::TIMESTAMP, + ScalarFunction::BinaryExecFunction)); + // interval + timestamp → timestamp + result.push_back(make_unique(name, + std::vector{LogicalTypeID::INTERVAL, LogicalTypeID::TIMESTAMP}, + LogicalTypeID::TIMESTAMP, + ScalarFunction::BinaryExecFunction)); + return result; +} + +function_set SubtractFunction::getFunctionSet() { + function_set result; + for (auto typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { + result.push_back(getBinaryFunction(name, typeID)); + } + // decimal - decimal -> decimal + auto func = std::make_unique(name, + std::vector{LogicalTypeID::DECIMAL, LogicalTypeID::DECIMAL}, + LogicalTypeID::DECIMAL); + func->bindFunc = DecimalFunction::bindSubtractFunc; + result.push_back(std::move(func)); + // date - date → int64 + result.push_back(getBinaryFunction(name, LogicalTypeID::DATE, + LogicalTypeID::INT64)); + // date - integer → date + result.push_back(make_unique(name, + std::vector{LogicalTypeID::DATE, LogicalTypeID::INT64}, LogicalTypeID::DATE, + ScalarFunction::BinaryExecFunction)); + // date - interval → date + result.push_back(make_unique(name, + std::vector{LogicalTypeID::DATE, LogicalTypeID::INTERVAL}, + LogicalTypeID::DATE, + ScalarFunction::BinaryExecFunction)); + // timestamp - timestamp → interval + result.push_back(getBinaryFunction(name, + LogicalTypeID::TIMESTAMP, LogicalTypeID::INTERVAL)); + // timestamp - interval → timestamp + result.push_back(make_unique(name, + std::vector{LogicalTypeID::TIMESTAMP, LogicalTypeID::INTERVAL}, + LogicalTypeID::TIMESTAMP, + ScalarFunction::BinaryExecFunction)); + // interval - interval → interval + result.push_back(getBinaryFunction(name, + LogicalTypeID::INTERVAL, LogicalTypeID::INTERVAL)); + return result; +} + +function_set MultiplyFunction::getFunctionSet() { + function_set result; + for (auto typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { + result.push_back(getBinaryFunction(name, typeID)); + } + // decimal * decimal -> decimal + auto func = std::make_unique(name, + std::vector{LogicalTypeID::DECIMAL, LogicalTypeID::DECIMAL}, + LogicalTypeID::DECIMAL); + func->bindFunc = DecimalFunction::bindMultiplyFunc; + result.push_back(std::move(func)); + return result; +} + +function_set DivideFunction::getFunctionSet() { + function_set result; + for 
(auto typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { + result.push_back(getBinaryFunction(name, typeID)); + } + // interval / int → interval + result.push_back(make_unique(name, + std::vector{LogicalTypeID::INTERVAL, LogicalTypeID::INT64}, + LogicalTypeID::INTERVAL, + ScalarFunction::BinaryExecFunction)); + // decimal / decimal -> decimal + // drop to double division for now + // result.push_back(make_unique(name, + // std::vector{LogicalTypeID::DECIMAL, LogicalTypeID::DECIMAL}, + // LogicalTypeID::DECIMAL, nullptr, nullptr, DecimalFunction::bindDivideFunc)); + return result; +} + +function_set ModuloFunction::getFunctionSet() { + function_set result; + for (auto typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { + result.push_back(getBinaryFunction(name, typeID)); + } + // decimal % decimal -> decimal + auto func = std::make_unique(name, + std::vector{LogicalTypeID::DECIMAL, LogicalTypeID::DECIMAL}, + LogicalTypeID::DECIMAL); + func->bindFunc = DecimalFunction::bindModuloFunc; + result.push_back(std::move(func)); + return result; +} + +function_set PowerFunction::getFunctionSet() { + function_set result; + // double ^ double -> double + result.push_back( + getBinaryFunction(name, LogicalTypeID::DOUBLE, LogicalTypeID::DOUBLE)); + return result; +} + +function_set NegateFunction::getFunctionSet() { + function_set result; + for (auto& typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { + result.push_back(getUnaryFunction(name, typeID)); + } + // floor(decimal) -> decimal + auto func = std::make_unique(name, + std::vector{LogicalTypeID::DECIMAL}, LogicalTypeID::DECIMAL); + func->bindFunc = DecimalFunction::bindNegateFunc; + result.push_back(std::move(func)); + return result; +} + +function_set AbsFunction::getFunctionSet() { + function_set result; + for (auto& typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { + result.push_back(getUnaryFunction(name, typeID)); + } + auto func = std::make_unique(name, + std::vector{LogicalTypeID::DECIMAL}, LogicalTypeID::DECIMAL); + func->bindFunc = DecimalFunction::bindAbsFunc; + result.push_back(std::move(func)); + return result; +} + +function_set FloorFunction::getFunctionSet() { + function_set result; + for (auto& typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { + result.push_back(getUnaryFunction(name, typeID)); + } + auto func = std::make_unique(name, + std::vector{LogicalTypeID::DECIMAL}, LogicalTypeID::DECIMAL); + func->bindFunc = DecimalFunction::bindFloorFunc; + result.push_back(std::move(func)); + return result; +} + +function_set CeilFunction::getFunctionSet() { + function_set result; + for (auto& typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { + result.push_back(getUnaryFunction(name, typeID)); + } + auto func = std::make_unique(name, + std::vector{LogicalTypeID::DECIMAL}, LogicalTypeID::DECIMAL); + func->bindFunc = DecimalFunction::bindCeilFunc; + result.push_back(std::move(func)); + return result; +} + +function_set SinFunction::getFunctionSet() { + function_set result; + result.push_back( + getUnaryFunction(name, LogicalTypeID::DOUBLE, LogicalTypeID::DOUBLE)); + return result; +} + +function_set CosFunction::getFunctionSet() { + function_set result; + result.push_back( + getUnaryFunction(name, LogicalTypeID::DOUBLE, LogicalTypeID::DOUBLE)); + return result; +} + +function_set TanFunction::getFunctionSet() { + function_set result; + result.push_back( + getUnaryFunction(name, LogicalTypeID::DOUBLE, LogicalTypeID::DOUBLE)); + return result; +} + +function_set CotFunction::getFunctionSet() { 
+ function_set result; + result.push_back( + getUnaryFunction(name, LogicalTypeID::DOUBLE, LogicalTypeID::DOUBLE)); + return result; +} + +function_set AsinFunction::getFunctionSet() { + function_set result; + result.push_back( + getUnaryFunction(name, LogicalTypeID::DOUBLE, LogicalTypeID::DOUBLE)); + return result; +} + +function_set AcosFunction::getFunctionSet() { + function_set result; + result.push_back( + getUnaryFunction(name, LogicalTypeID::DOUBLE, LogicalTypeID::DOUBLE)); + return result; +} + +function_set AtanFunction::getFunctionSet() { + function_set result; + result.push_back( + getUnaryFunction(name, LogicalTypeID::DOUBLE, LogicalTypeID::DOUBLE)); + return result; +} + +function_set FactorialFunction::getFunctionSet() { + function_set result; + result.push_back( + make_unique(name, std::vector{LogicalTypeID::INT64}, + LogicalTypeID::INT64, ScalarFunction::UnaryExecFunction)); + return result; +} + +function_set SqrtFunction::getFunctionSet() { + function_set result; + result.push_back( + getUnaryFunction(name, LogicalTypeID::DOUBLE, LogicalTypeID::DOUBLE)); + return result; +} + +function_set CbrtFunction::getFunctionSet() { + function_set result; + result.push_back( + getUnaryFunction(name, LogicalTypeID::DOUBLE, LogicalTypeID::DOUBLE)); + return result; +} + +function_set GammaFunction::getFunctionSet() { + function_set result; + result.push_back( + getUnaryFunction(name, LogicalTypeID::DOUBLE, LogicalTypeID::DOUBLE)); + return result; +} + +function_set LgammaFunction::getFunctionSet() { + function_set result; + result.push_back( + getUnaryFunction(name, LogicalTypeID::DOUBLE, LogicalTypeID::DOUBLE)); + return result; +} + +function_set LnFunction::getFunctionSet() { + function_set result; + result.push_back( + getUnaryFunction(name, LogicalTypeID::DOUBLE, LogicalTypeID::DOUBLE)); + return result; +} + +function_set LogFunction::getFunctionSet() { + function_set result; + result.push_back( + getUnaryFunction(name, LogicalTypeID::DOUBLE, LogicalTypeID::DOUBLE)); + return result; +} + +function_set Log2Function::getFunctionSet() { + function_set result; + result.push_back( + getUnaryFunction(name, LogicalTypeID::DOUBLE, LogicalTypeID::DOUBLE)); + return result; +} + +function_set DegreesFunction::getFunctionSet() { + function_set result; + result.push_back( + getUnaryFunction(name, LogicalTypeID::DOUBLE, LogicalTypeID::DOUBLE)); + return result; +} + +function_set RadiansFunction::getFunctionSet() { + function_set result; + result.push_back( + getUnaryFunction(name, LogicalTypeID::DOUBLE, LogicalTypeID::DOUBLE)); + return result; +} + +function_set EvenFunction::getFunctionSet() { + function_set result; + result.push_back( + getUnaryFunction(name, LogicalTypeID::DOUBLE, LogicalTypeID::DOUBLE)); + return result; +} + +function_set SignFunction::getFunctionSet() { + function_set result; + result.push_back( + getUnaryFunction(name, LogicalTypeID::INT64, LogicalTypeID::INT64)); + result.push_back( + getUnaryFunction(name, LogicalTypeID::DOUBLE, LogicalTypeID::INT64)); + result.push_back( + getUnaryFunction(name, LogicalTypeID::FLOAT, LogicalTypeID::INT64)); + return result; +} + +function_set Atan2Function::getFunctionSet() { + function_set result; + result.push_back( + getBinaryFunction(name, LogicalTypeID::DOUBLE, LogicalTypeID::DOUBLE)); + return result; +} + +function_set RoundFunction::getFunctionSet() { + function_set result; + result.push_back(make_unique(name, + std::vector{LogicalTypeID::DOUBLE, LogicalTypeID::INT64}, + LogicalTypeID::DOUBLE, 
ScalarFunction::BinaryExecFunction)); + return result; +} + +function_set BitwiseXorFunction::getFunctionSet() { + function_set result; + result.push_back( + getBinaryFunction(name, LogicalTypeID::INT64, LogicalTypeID::INT64)); + return result; +} + +function_set BitwiseAndFunction::getFunctionSet() { + function_set result; + result.push_back( + getBinaryFunction(name, LogicalTypeID::INT64, LogicalTypeID::INT64)); + return result; +} + +function_set BitwiseOrFunction::getFunctionSet() { + function_set result; + result.push_back( + getBinaryFunction(name, LogicalTypeID::INT64, LogicalTypeID::INT64)); + return result; +} + +function_set BitShiftLeftFunction::getFunctionSet() { + function_set result; + result.push_back( + getBinaryFunction(name, LogicalTypeID::INT64, LogicalTypeID::INT64)); + return result; +} + +function_set BitShiftRightFunction::getFunctionSet() { + function_set result; + result.push_back(getBinaryFunction(name, LogicalTypeID::INT64, + LogicalTypeID::INT64)); + return result; +} + +function_set PiFunction::getFunctionSet() { + function_set result; + result.push_back(make_unique(name, std::vector{}, + LogicalTypeID::DOUBLE, ScalarFunction::NullaryExecFunction)); + return result; +} + +using param_get_func_t = std::function(int, int, int, int)>; + +// Following param func rules are from +// https://learn.microsoft.com/en-us/sql/t-sql/data-types/precision-scale-and-length-transact-sql +// todo: Figure out which param rules we should use + +struct DecimalAdd { + static constexpr bool matchToOutputLogicalType = true; + // whether or not the input and output logical types + // are expected to be equivalent. If so, the bind function + // should specify that the input be casted to the output type before execution + template + static inline void operation(A& left, B& right, R& result, + common::ValueVector& resultValueVector) { + constexpr auto pow10s = pow10Sequence(); + auto precision = DecimalType::getPrecision(resultValueVector.dataType); + if ((right > 0 && pow10s[precision] - right <= left) || + (right < 0 && -pow10s[precision] - right >= left)) { + throw OverflowException("Decimal Addition result is out of range"); + } + result = left + right; + } + + static std::pair resultingParams(int p1, int p2, int s1, int s2) { + auto p = min(DECIMAL_PRECISION_LIMIT, max(s1, s2) + max(p1 - s1, p2 - s2) + 1); + auto s = min(p, max(s1, s2)); + if (max(p1 - s1, p2 - s2) < min(DECIMAL_PRECISION_LIMIT, p) - s) { + s = min(p, DECIMAL_PRECISION_LIMIT) - max(p1 - s1, p2 - s2); + } + return {p, s}; + } +}; + +struct DecimalSubtract { + static constexpr bool matchToOutputLogicalType = true; + template + static inline void operation(A& left, B& right, R& result, + common::ValueVector& resultValueVector) { + constexpr auto pow10s = pow10Sequence(); + auto precision = DecimalType::getPrecision(resultValueVector.dataType); + if ((right > 0 && -pow10s[precision] + right >= left) || + (right < 0 && pow10s[precision] + right <= left)) { + throw OverflowException("Decimal Subtraction result is out of range"); + } + result = left - right; + } + + static std::pair resultingParams(int p1, int p2, int s1, int s2) { + auto p = min(DECIMAL_PRECISION_LIMIT, max(s1, s2) + max(p1 - s1, p2 - s2) + 1); + auto s = min(p, max(s1, s2)); + if (max(p1 - s1, p2 - s2) < min(DECIMAL_PRECISION_LIMIT, p) - s) { + s = min(p, DECIMAL_PRECISION_LIMIT) - max(p1 - s1, p2 - s2); + } + return {p, s}; + } +}; + +struct DecimalMultiply { + static constexpr bool matchToOutputLogicalType = false; + template + static inline void 
operation(A& left, B& right, R& result, + common::ValueVector& resultValueVector) { + constexpr auto pow10s = pow10Sequence(); + auto precision = DecimalType::getPrecision(resultValueVector.dataType); + result = (R)left * (R)right; + // no need to divide by any scale given resultingParams and matchToOutput + if (result <= -pow10s[precision] || result >= pow10s[precision]) { + [[unlikely]] throw OverflowException("Decimal Multiplication Result is out of range"); + } + } + + static std::pair resultingParams(int p1, int p2, int s1, int s2) { + if (p1 + p2 + 1 > DECIMAL_PRECISION_LIMIT) { + throw OverflowException( + "Resulting precision of decimal multiplication greater than 38"); + } + auto p = p1 + p2 + 1; + auto s = s1 + s2; + return {p, s}; + } +}; + +struct DecimalDivide { + static constexpr bool matchToOutputLogicalType = true; + template + static inline void operation(A& left, B& right, R& result, + common::ValueVector& resultValueVector) { + constexpr auto pow10s = pow10Sequence(); + auto precision = DecimalType::getPrecision(resultValueVector.dataType); + auto scale = DecimalType::getScale(resultValueVector.dataType); + if (right == 0) { + throw RuntimeException("Divide by zero."); + } + if (-pow10s[precision - scale] >= left || pow10s[precision - scale] <= left) { + throw OverflowException("Overflow encountered when attempting to divide decimals"); + // happens too often; let's just drop to double division for now, which is in line with + // what DuckDB does right now + } + result = (left * pow10s[scale]) / right; + } + + static std::pair resultingParams(int p1, int p2, int s1, int s2) { + auto p = min(DECIMAL_PRECISION_LIMIT, p1 - s1 + s2 + max(6, s1 + p2 + 1)); + auto s = min(p, max(6, s1 + p2 + 1)); // todo: complete rules + return {p, s}; + } +}; + +struct DecimalModulo { + static constexpr bool matchToOutputLogicalType = true; + template + static inline void operation(A& left, B& right, R& result, common::ValueVector&) { + if (right == 0) { + throw RuntimeException("Modulo by zero."); + } + result = left % right; + } + + static std::pair resultingParams(int p1, int p2, int s1, int s2) { + auto p = min(DECIMAL_PRECISION_LIMIT, min(p1 - s1, p2 - s2) + max(s1, s2)); + auto s = min(p, max(s1, s2)); + return {p, s}; + } +}; + +struct DecimalNegate { + static constexpr bool matchToOutputLogicalType = true; + template + static inline void operation(A& input, R& result, common::ValueVector&, common::ValueVector&) { + result = -input; + } + + static std::pair resultingParams(int p, int s) { return {p, s}; } +}; + +struct DecimalAbs { + static constexpr bool matchToOutputLogicalType = true; + template + static inline void operation(A& input, R& result, common::ValueVector&, common::ValueVector&) { + result = input; + if (result < 0) { + result = -result; + } + } + + static std::pair resultingParams(int p, int s) { return {p, s}; } +}; + +struct DecimalFloor { + static constexpr bool matchToOutputLogicalType = false; + template + static inline void operation(A& input, R& result, common::ValueVector& inputVector, + common::ValueVector&) { + constexpr auto pow10s = pow10Sequence(); + auto scale = DecimalType::getScale(inputVector.dataType); + if (input < 0) { + // round to larger absolute value + result = (R)input - + (input % pow10s[scale] == 0 ? 
0 : pow10s[scale] + (R)(input % pow10s[scale])); + } else { + // round to smaller absolute value + result = (R)input - (R)(input % pow10s[scale]); + } + result = result / pow10s[scale]; + } + + static std::pair resultingParams(int p, int) { return {p, 0}; } +}; + +struct DecimalCeil { + static constexpr bool matchToOutputLogicalType = false; + template + static inline void operation(A& input, R& result, common::ValueVector& inputVector, + common::ValueVector&) { + constexpr auto pow10s = pow10Sequence(); + auto scale = DecimalType::getScale(inputVector.dataType); + if (input < 0) { + // round to larger absolute value + result = (R)input - (R)(input % pow10s[scale]); + } else { + // round to smaller absolute value + result = (R)input + + (input % pow10s[scale] == 0 ? 0 : pow10s[scale] - (R)(input % pow10s[scale])); + } + result = result / pow10s[scale]; + } + + static std::pair resultingParams(int p, int) { return {p, 0}; } +}; + +template +static void getBinaryExecutionHelperB(const LogicalType& typeR, scalar_func_exec_t& result) { + // here to assist in getting scalar_func_exec_t for genericBinaryArithmeticFunc + switch (typeR.getPhysicalType()) { + case PhysicalTypeID::INT16: + result = ScalarFunction::BinaryStringExecFunction; + break; + case PhysicalTypeID::INT32: + result = ScalarFunction::BinaryStringExecFunction; + break; + case PhysicalTypeID::INT64: + result = ScalarFunction::BinaryStringExecFunction; + break; + case PhysicalTypeID::INT128: + result = ScalarFunction::BinaryStringExecFunction; + break; + default: + KU_UNREACHABLE; + } +} + +template +static void getBinaryExecutionHelperA(const LogicalType& typeB, const LogicalType& typeR, + scalar_func_exec_t& result) { + // here to assist in getting scalar_func_exec_t for genericBinaryArithmeticFunc + switch (typeB.getPhysicalType()) { + case PhysicalTypeID::INT16: + getBinaryExecutionHelperB(typeR, result); + break; + case PhysicalTypeID::INT32: + getBinaryExecutionHelperB(typeR, result); + break; + case PhysicalTypeID::INT64: + getBinaryExecutionHelperB(typeR, result); + break; + case PhysicalTypeID::INT128: + getBinaryExecutionHelperB(typeR, result); + break; + default: + KU_UNREACHABLE; + } +} + +template +static std::unique_ptr genericBinaryArithmeticFunc( + const binder::expression_vector& arguments, Function* func) { + auto asScalar = ku_dynamic_cast(func); + KU_ASSERT(asScalar != nullptr); + auto argADataType = arguments[0]->getDataType().copy(); + auto argBDataType = arguments[1]->getDataType().copy(); + if (argADataType.getLogicalTypeID() != LogicalTypeID::DECIMAL) { + argADataType = argBDataType.copy(); + } + if (argBDataType.getLogicalTypeID() != LogicalTypeID::DECIMAL) { + argBDataType = argADataType.copy(); + } + auto precision1 = DecimalType::getPrecision(argADataType); + auto precision2 = DecimalType::getPrecision(argBDataType); + auto scale1 = DecimalType::getScale(argADataType); + auto scale2 = DecimalType::getScale(argBDataType); + auto params = FUNC::resultingParams(precision1, precision2, scale1, scale2); + auto resultingType = LogicalType::DECIMAL(params.first, params.second); + auto argumentAType = + FUNC::matchToOutputLogicalType ? resultingType.copy() : argADataType.copy(); + auto argumentBType = + FUNC::matchToOutputLogicalType ? 
resultingType.copy() : argBDataType.copy(); + if constexpr (FUNC::matchToOutputLogicalType) { + common::TypeUtils::visit( + resultingType.getPhysicalType(), + [&](T) { + asScalar->execFunc = ScalarFunction::BinaryStringExecFunction; + }, + [](auto) { KU_UNREACHABLE; }); + } else { + common::TypeUtils::visit( + argumentAType.getPhysicalType(), + [&](T) { + getBinaryExecutionHelperA(argumentBType, resultingType, + asScalar->execFunc); + }, + [](auto) { KU_UNREACHABLE; }); + } + std::vector resVec; + resVec.push_back(std::move(argumentAType)); + resVec.push_back(std::move(argumentBType)); + resVec.push_back(resultingType.copy()); + return std::make_unique(std::move(resVec), std::move(resultingType)); +} + +template +static void getUnaryExecutionHelper(const LogicalType& resultType, scalar_func_exec_t& result) { + switch (resultType.getPhysicalType()) { + case PhysicalTypeID::INT16: + result = ScalarFunction::UnaryExecNestedTypeFunction; + break; + case PhysicalTypeID::INT32: + result = ScalarFunction::UnaryExecNestedTypeFunction; + break; + case PhysicalTypeID::INT64: + result = ScalarFunction::UnaryExecNestedTypeFunction; + break; + case PhysicalTypeID::INT128: + result = ScalarFunction::UnaryExecNestedTypeFunction; + break; + default: + KU_UNREACHABLE; + } +} + +template +static std::unique_ptr genericUnaryArithmeticFunc( + const binder::expression_vector& arguments, Function* func) { + auto asScalar = ku_dynamic_cast(func); + KU_ASSERT(asScalar != nullptr); + auto argPrecision = DecimalType::getPrecision(arguments[0]->getDataType()); + auto argScale = DecimalType::getScale(arguments[0]->getDataType()); + auto params = FUNC::resultingParams(argPrecision, argScale); + auto resultingType = LogicalType::DECIMAL(params.first, params.second); + auto argumentType = + FUNC::matchToOutputLogicalType ? 
resultingType.copy() : arguments[0]->getDataType().copy(); + if constexpr (FUNC::matchToOutputLogicalType) { + switch (resultingType.getPhysicalType()) { + case PhysicalTypeID::INT16: + asScalar->execFunc = + ScalarFunction::UnaryExecNestedTypeFunction; + break; + case PhysicalTypeID::INT32: + asScalar->execFunc = + ScalarFunction::UnaryExecNestedTypeFunction; + break; + case PhysicalTypeID::INT64: + asScalar->execFunc = + ScalarFunction::UnaryExecNestedTypeFunction; + break; + case PhysicalTypeID::INT128: + asScalar->execFunc = + ScalarFunction::UnaryExecNestedTypeFunction; + break; + default: + KU_UNREACHABLE; + } + } else { + switch (argumentType.getPhysicalType()) { + case PhysicalTypeID::INT16: + getUnaryExecutionHelper(resultingType, asScalar->execFunc); + break; + case PhysicalTypeID::INT32: + getUnaryExecutionHelper(resultingType, asScalar->execFunc); + break; + case PhysicalTypeID::INT64: + getUnaryExecutionHelper(resultingType, asScalar->execFunc); + break; + case PhysicalTypeID::INT128: + getUnaryExecutionHelper(resultingType, asScalar->execFunc); + break; + default: + KU_UNREACHABLE; + } + } + std::vector argTypes; + argTypes.push_back(std::move(argumentType)); + return std::make_unique(std::move(argTypes), std::move(resultingType)); +} + +std::unique_ptr DecimalFunction::bindAddFunc(ScalarBindFuncInput input) { + return genericBinaryArithmeticFunc(input.arguments, input.definition); +} + +std::unique_ptr DecimalFunction::bindSubtractFunc(ScalarBindFuncInput input) { + return genericBinaryArithmeticFunc(input.arguments, input.definition); +} + +std::unique_ptr DecimalFunction::bindMultiplyFunc(ScalarBindFuncInput input) { + return genericBinaryArithmeticFunc(input.arguments, input.definition); +} + +std::unique_ptr DecimalFunction::bindDivideFunc(ScalarBindFuncInput input) { + return genericBinaryArithmeticFunc(input.arguments, input.definition); +} + +std::unique_ptr DecimalFunction::bindModuloFunc(ScalarBindFuncInput input) { + return genericBinaryArithmeticFunc(input.arguments, input.definition); +} + +std::unique_ptr DecimalFunction::bindNegateFunc(ScalarBindFuncInput input) { + return genericUnaryArithmeticFunc(input.arguments, input.definition); +} + +std::unique_ptr DecimalFunction::bindAbsFunc(ScalarBindFuncInput input) { + return genericUnaryArithmeticFunc(input.arguments, input.definition); +} + +std::unique_ptr DecimalFunction::bindFloorFunc(ScalarBindFuncInput input) { + return genericUnaryArithmeticFunc(input.arguments, input.definition); +} + +std::unique_ptr DecimalFunction::bindCeilFunc(ScalarBindFuncInput input) { + return genericUnaryArithmeticFunc(input.arguments, input.definition); +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/vector_blob_functions.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/vector_blob_functions.cpp new file mode 100644 index 0000000000..95aa60258d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/vector_blob_functions.cpp @@ -0,0 +1,38 @@ +#include "function/blob/vector_blob_functions.h" + +#include "function/blob/functions/decode_function.h" +#include "function/blob/functions/encode_function.h" +#include "function/blob/functions/octet_length_function.h" +#include "function/scalar_function.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +function_set OctetLengthFunctions::getFunctionSet() { + function_set definitions; + definitions.push_back( + make_unique(name, std::vector{LogicalTypeID::BLOB}, + LogicalTypeID::INT64, 
ScalarFunction::UnaryExecFunction)); + return definitions; +} + +function_set EncodeFunctions::getFunctionSet() { + function_set definitions; + definitions.push_back(make_unique(name, + std::vector{LogicalTypeID::STRING}, LogicalTypeID::BLOB, + ScalarFunction::UnaryStringExecFunction)); + return definitions; +} + +function_set DecodeFunctions::getFunctionSet() { + function_set definitions; + definitions.push_back(make_unique(name, + std::vector{LogicalTypeID::BLOB}, LogicalTypeID::STRING, + ScalarFunction::UnaryStringExecFunction)); + return definitions; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/vector_boolean_functions.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/vector_boolean_functions.cpp new file mode 100644 index 0000000000..74be067259 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/vector_boolean_functions.cpp @@ -0,0 +1,124 @@ +#include "function/boolean/vector_boolean_functions.h" + +#include "common/exception/runtime.h" +#include "function/boolean/boolean_functions.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +void VectorBooleanFunction::bindExecFunction(ExpressionType expressionType, + const binder::expression_vector& children, scalar_func_exec_t& func) { + if (ExpressionTypeUtil::isBinary(expressionType)) { + bindBinaryExecFunction(expressionType, children, func); + } else { + KU_ASSERT(ExpressionTypeUtil::isUnary(expressionType)); + bindUnaryExecFunction(expressionType, children, func); + } +} + +void VectorBooleanFunction::bindSelectFunction(ExpressionType expressionType, + const binder::expression_vector& children, scalar_func_select_t& func) { + if (ExpressionTypeUtil::isBinary(expressionType)) { + bindBinarySelectFunction(expressionType, children, func); + } else { + KU_ASSERT(ExpressionTypeUtil::isUnary(expressionType)); + bindUnarySelectFunction(expressionType, children, func); + } +} + +void VectorBooleanFunction::bindBinaryExecFunction(ExpressionType expressionType, + const binder::expression_vector& children, scalar_func_exec_t& func) { + KU_ASSERT(children.size() == 2); + const auto& leftType = children[0]->dataType; + const auto& rightType = children[1]->dataType; + (void)leftType; + (void)rightType; + KU_ASSERT(leftType.getLogicalTypeID() == LogicalTypeID::BOOL && + rightType.getLogicalTypeID() == LogicalTypeID::BOOL); + switch (expressionType) { + case ExpressionType::AND: { + func = &BinaryBooleanExecFunction; + return; + } + case ExpressionType::OR: { + func = &BinaryBooleanExecFunction; + return; + } + case ExpressionType::XOR: { + func = &BinaryBooleanExecFunction; + return; + } + default: + throw RuntimeException("Invalid expression type " + + ExpressionTypeUtil::toString(expressionType) + + " for VectorBooleanFunctions::bindBinaryExecFunction."); + } +} + +void VectorBooleanFunction::bindBinarySelectFunction(ExpressionType expressionType, + const binder::expression_vector& children, scalar_func_select_t& func) { + KU_ASSERT(children.size() == 2); + const auto& leftType = children[0]->dataType; + const auto& rightType = children[1]->dataType; + (void)leftType; + (void)rightType; + KU_ASSERT(leftType.getLogicalTypeID() == LogicalTypeID::BOOL && + rightType.getLogicalTypeID() == LogicalTypeID::BOOL); + switch (expressionType) { + case ExpressionType::AND: { + func = &BinaryBooleanSelectFunction; + return; + } + case ExpressionType::OR: { + func = &BinaryBooleanSelectFunction; + return; + } + case ExpressionType::XOR: { + func = 
&BinaryBooleanSelectFunction; + return; + } + default: + throw RuntimeException("Invalid expression type " + + ExpressionTypeUtil::toString(expressionType) + + " for VectorBooleanFunctions::bindBinarySelectFunction."); + } +} + +void VectorBooleanFunction::bindUnaryExecFunction(ExpressionType expressionType, + const binder::expression_vector& children, scalar_func_exec_t& func) { + KU_ASSERT( + children.size() == 1 && children[0]->dataType.getLogicalTypeID() == LogicalTypeID::BOOL); + (void)children; + switch (expressionType) { + case ExpressionType::NOT: { + func = &UnaryBooleanExecFunction; + return; + } + default: + throw RuntimeException("Invalid expression type " + + ExpressionTypeUtil::toString(expressionType) + + " for VectorBooleanFunctions::bindUnaryExecFunction."); + } +} + +void VectorBooleanFunction::bindUnarySelectFunction(ExpressionType expressionType, + const binder::expression_vector& children, scalar_func_select_t& func) { + KU_ASSERT( + children.size() == 1 && children[0]->dataType.getLogicalTypeID() == LogicalTypeID::BOOL); + (void)children; + switch (expressionType) { + case ExpressionType::NOT: { + func = &UnaryBooleanSelectFunction; + return; + } + default: + throw RuntimeException("Invalid expression type " + + ExpressionTypeUtil::toString(expressionType) + + " for VectorBooleanFunctions::bindUnaryExecFunction."); + } +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/vector_cast_functions.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/vector_cast_functions.cpp new file mode 100644 index 0000000000..0c388b7325 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/vector_cast_functions.cpp @@ -0,0 +1,1206 @@ +#include "function/cast/vector_cast_functions.h" + +#include "binder/expression/expression_util.h" +#include "binder/expression/literal_expression.h" +#include "catalog/catalog.h" +#include "common/exception/binder.h" +#include "common/exception/conversion.h" +#include "function/built_in_function_utils.h" +#include "function/cast/cast_union_bind_data.h" +#include "function/cast/functions/cast_array.h" +#include "function/cast/functions/cast_decimal.h" +#include "function/cast/functions/cast_from_string_functions.h" +#include "function/cast/functions/cast_functions.h" +#include "transaction/transaction.h" + +using namespace lbug::common; +using namespace lbug::binder; + +namespace lbug { +namespace function { + +struct CastChildFunctionExecutor { + template + static void executeSwitch(common::ValueVector& operand, common::SelectionVector*, + common::ValueVector& result, common::SelectionVector*, void* dataPtr) { + auto& bindData = *reinterpret_cast(dataPtr); + for (auto i = 0u; i < bindData.numOfEntries; i++) { + result.setNull(i, operand.isNull(i)); + if (!result.isNull(i)) { + OP_WRAPPER::template operation((void*)(&operand), + i, (void*)(&result), i, dataPtr); + } + } + } +}; + +static union_field_idx_t findUnionMinCostTag(const LogicalType&, const LogicalType&); + +static void resolveNestedVector(std::shared_ptr inputVector, ValueVector* resultVector, + uint64_t numOfEntries, CastFunctionBindData* dataPtr) { + const auto* inputType = &inputVector->dataType; + const auto* resultType = &resultVector->dataType; + while (true) { + if ((inputType->getPhysicalType() == PhysicalTypeID::LIST || + inputType->getPhysicalType() == PhysicalTypeID::ARRAY) && + (resultType->getPhysicalType() == PhysicalTypeID::LIST || + resultType->getPhysicalType() == PhysicalTypeID::ARRAY)) { + // copy data and 
nullmask from input + memcpy(resultVector->getData(), inputVector->getData(), + numOfEntries * resultVector->getNumBytesPerValue()); + resultVector->setNullFromBits(inputVector->getNullMask().getData(), 0, 0, numOfEntries); + + numOfEntries = ListVector::getDataVectorSize(inputVector.get()); + ListVector::resizeDataVector(resultVector, numOfEntries); + + inputVector = ListVector::getSharedDataVector(inputVector.get()); + resultVector = ListVector::getDataVector(resultVector); + inputType = &inputVector->dataType; + resultType = &resultVector->dataType; + } else if ((inputType->getLogicalTypeID() == LogicalTypeID::STRUCT && + resultType->getLogicalTypeID() == LogicalTypeID::STRUCT) || + CastArrayHelper::isUnionSpecialCast(*inputType, *resultType)) { + // Check if struct type can be cast + auto errorMsg = stringFormat("Unsupported casting function from {} to {}.", + inputType->toString(), resultType->toString()); + // Check if two structs have the same number of fields + if (StructType::getNumFields(*inputType) != StructType::getNumFields(*resultType)) { + throw ConversionException{errorMsg}; + } + + // Check if two structs have the same field names + auto inputTypeNames = StructType::getFieldNames(*inputType); + auto resultTypeNames = StructType::getFieldNames(*resultType); + + for (auto i = 0u; i < inputTypeNames.size(); i++) { + if (StringUtils::caseInsensitiveEquals(inputTypeNames[i], resultTypeNames[i])) { + continue; + } + throw ConversionException{errorMsg}; + } + + // copy data and nullmask from input + memcpy(resultVector->getData(), inputVector->getData(), + numOfEntries * resultVector->getNumBytesPerValue()); + resultVector->setNullFromBits(inputVector->getNullMask().getData(), 0, 0, numOfEntries); + + auto inputFieldVectors = StructVector::getFieldVectors(inputVector.get()); + auto resultFieldVectors = StructVector::getFieldVectors(resultVector); + for (auto i = 0u; i < inputFieldVectors.size(); i++) { + resolveNestedVector(inputFieldVectors[i], resultFieldVectors[i].get(), numOfEntries, + dataPtr); + } + return; + } else if (resultType->getLogicalTypeID() == LogicalTypeID::UNION) { + if (inputType->getLogicalTypeID() == LogicalTypeID::UNION) { + auto numFieldsSrc = UnionType::getNumFields(*inputType); + std::vector tagMap(numFieldsSrc); + for (auto i = 0u; i < numFieldsSrc; ++i) { + const auto& fieldName = UnionType::getFieldName(*inputType, i); + if (!UnionType::hasField(*resultType, fieldName)) { + throw ConversionException{stringFormat( + "Cannot cast from {} to {}, target type is missing field '{}'.", + inputType->toString(), resultType->toString(), fieldName)}; + } + const auto& fieldTypeSrc = UnionType::getFieldType(*inputType, i); + const auto& fieldTypeDst = UnionType::getFieldType(*resultType, fieldName); + if (!CastFunction::hasImplicitCast(fieldTypeSrc, fieldTypeDst)) { + throw ConversionException{ + stringFormat("Unsupported casting function from {} to {}.", + fieldTypeSrc.toString(), fieldTypeDst.toString())}; + } + auto dstTag = UnionType::getFieldIdx(*resultType, fieldName); + tagMap[i] = dstTag; + auto srcValVector = UnionVector::getSharedValVector(inputVector.get(), i); + auto resValVector = UnionVector::getValVector(resultVector, dstTag); + resolveNestedVector(srcValVector, resValVector, numOfEntries, dataPtr); + } + auto srcTagVector = UnionVector::getTagVector(inputVector.get()); + auto resTagVector = UnionVector::getTagVector(resultVector); + for (auto i = 0u; i < numOfEntries; ++i) { + auto srcTag = srcTagVector->getValue(i); + resTagVector->setValue(i, 
tagMap[srcTag]); + } + return; + } else { + auto minCostTag = findUnionMinCostTag(*inputType, *resultType); + auto tagVector = UnionVector::getTagVector(resultVector); + for (auto i = 0u; i < numOfEntries; ++i) { + tagVector->setValue(i, minCostTag); + } + resultVector = UnionVector::getValVector(resultVector, minCostTag); + resultType = &UnionType::getFieldType(*resultType, minCostTag); + } + } else { + break; + } + } + // non-nested types + if (inputType->getLogicalTypeID() != resultType->getLogicalTypeID()) { + auto func = CastFunction::bindCastFunction("CAST", *inputType, + *resultType) + ->execFunc; + std::vector> childParams{inputVector}; + dataPtr->numOfEntries = numOfEntries; + func(childParams, SelectionVector::fromValueVectors(childParams), *resultVector, + resultVector->getSelVectorPtr(), (void*)dataPtr); + } else { + for (auto i = 0u; i < numOfEntries; i++) { + resultVector->copyFromVectorData(i, inputVector.get(), i); + } + } +} + +static void nestedTypesCastExecFunction( + const std::vector>& params, + const std::vector& paramSelVectors, common::ValueVector& result, + common::SelectionVector* resultSelVector, void*) { + KU_ASSERT(params.size() == 1); + result.resetAuxiliaryBuffer(); + const auto& inputVector = params[0]; + const auto* inputVectorSelVector = paramSelVectors[0]; + + // check if all selected list entries have the required fixed list size + if (CastArrayHelper::containsListToArray(inputVector->dataType, result.dataType)) { + for (auto i = 0u; i < inputVectorSelVector->getSelSize(); i++) { + auto pos = (*inputVectorSelVector)[i]; + CastArrayHelper::validateListEntry(inputVector.get(), result.dataType, pos); + } + }; + + auto& selVector = *inputVectorSelVector; + auto bindData = CastFunctionBindData(result.dataType.copy()); + bindData.numOfEntries = selVector[selVector.getSelSize() - 1] + 1; + resolveNestedVector(inputVector, &result, bindData.numOfEntries, &bindData); + if (inputVector->state->isFlat()) { + resultSelVector->setToFiltered(); + (*resultSelVector)[0] = (*inputVectorSelVector)[0]; + } +} + +static bool hasImplicitCastList(const LogicalType& srcType, const LogicalType& dstType) { + return CastFunction::hasImplicitCast(ListType::getChildType(srcType), + ListType::getChildType(dstType)); +} + +static bool hasImplicitCastArray(const LogicalType& srcType, const LogicalType& dstType) { + if (ArrayType::getNumElements(srcType) != ArrayType::getNumElements(dstType)) { + return false; + } + return CastFunction::hasImplicitCast(ArrayType::getChildType(srcType), + ArrayType::getChildType(dstType)); +} + +static bool hasImplicitCastArrayToList(const LogicalType& srcType, const LogicalType& dstType) { + return CastFunction::hasImplicitCast(ArrayType::getChildType(srcType), + ListType::getChildType(dstType)); +} + +static bool hasImplicitCastListToArray(const LogicalType& srcType, const LogicalType& dstType) { + return CastFunction::hasImplicitCast(ListType::getChildType(srcType), + ArrayType::getChildType(dstType)); +} + +static bool hasImplicitCastStruct(const LogicalType& srcType, const LogicalType& dstType) { + const auto& srcFields = StructType::getFields(srcType); + const auto& dstFields = StructType::getFields(dstType); + if (srcFields.size() != dstFields.size()) { + return false; + } + for (auto i = 0u; i < srcFields.size(); i++) { + if (srcFields[i].getName() != dstFields[i].getName()) { + return false; + } + if (!CastFunction::hasImplicitCast(srcFields[i].getType(), dstFields[i].getType())) { + return false; + } + } + return true; +} + +static bool 
hasImplicitCastUnion(const LogicalType& srcType, const LogicalType& dstType) { + if (srcType.getLogicalTypeID() == LogicalTypeID::UNION) { + auto numFieldsSrc = UnionType::getNumFields(srcType); + for (auto i = 0u; i < numFieldsSrc; ++i) { + const auto& fieldName = UnionType::getFieldName(srcType, i); + const auto& fieldType = UnionType::getFieldType(srcType, i); + if (!UnionType::hasField(dstType, fieldName) || + !CastFunction::hasImplicitCast(fieldType, + UnionType::getFieldType(dstType, fieldName))) { + return false; + } + } + return true; + } else { + auto numFields = UnionType::getNumFields(dstType); + for (auto i = 0u; i < numFields; ++i) { + const auto& fieldType = UnionType::getFieldType(dstType, i); + if (CastFunction::hasImplicitCast(srcType, fieldType)) { + return true; + } + } + return false; + } +} + +static bool hasImplicitCastMap(const LogicalType& srcType, const LogicalType& dstType) { + const auto& srcKeyType = MapType::getKeyType(srcType); + const auto& srcValueType = MapType::getValueType(srcType); + const auto& dstKeyType = MapType::getKeyType(dstType); + const auto& dstValueType = MapType::getValueType(dstType); + return CastFunction::hasImplicitCast(srcKeyType, dstKeyType) && + CastFunction::hasImplicitCast(srcValueType, dstValueType); +} + +bool CastFunction::hasImplicitCast(const LogicalType& srcType, const LogicalType& dstType) { + if (LogicalTypeUtils::isNested(srcType) && LogicalTypeUtils::isNested(dstType)) { + if (srcType.getLogicalTypeID() == LogicalTypeID::ARRAY && + dstType.getLogicalTypeID() == LogicalTypeID::LIST) { + return hasImplicitCastArrayToList(srcType, dstType); + } + if (srcType.getLogicalTypeID() == LogicalTypeID::LIST && + dstType.getLogicalTypeID() == LogicalTypeID::ARRAY) { + return hasImplicitCastListToArray(srcType, dstType); + } + if (srcType.getLogicalTypeID() != dstType.getLogicalTypeID()) { + return false; + } + switch (srcType.getLogicalTypeID()) { + case LogicalTypeID::LIST: + return hasImplicitCastList(srcType, dstType); + case LogicalTypeID::ARRAY: + return hasImplicitCastArray(srcType, dstType); + case LogicalTypeID::STRUCT: + return hasImplicitCastStruct(srcType, dstType); + case LogicalTypeID::UNION: + return hasImplicitCastUnion(srcType, dstType); + case LogicalTypeID::MAP: + return hasImplicitCastMap(srcType, dstType); + default: + // LCOV_EXCL_START + KU_UNREACHABLE; + // LCOV_EXCL_END + } + } else if (dstType.getLogicalTypeID() == LogicalTypeID::UNION) { + return hasImplicitCastUnion(srcType, dstType); + } + if (BuiltInFunctionsUtils::getCastCost(srcType.getLogicalTypeID(), + dstType.getLogicalTypeID()) != UNDEFINED_CAST_COST) { + return true; + } + // TODO(Jiamin): there are still other special cases + // We allow cast between any numerical types + if (LogicalTypeUtils::isNumerical(srcType) && LogicalTypeUtils::isNumerical(dstType)) { + return true; + } + return false; +} + +template +static std::unique_ptr bindCastFromStringFunction(const std::string& functionName, + const LogicalType& targetType) { + scalar_func_exec_t execFunc; + switch (targetType.getLogicalTypeID()) { + case LogicalTypeID::DATE: { + execFunc = + ScalarFunction::UnaryCastStringExecFunction; + } break; + case LogicalTypeID::TIMESTAMP_SEC: { + execFunc = ScalarFunction::UnaryCastStringExecFunction; + } break; + case LogicalTypeID::TIMESTAMP_MS: { + execFunc = ScalarFunction::UnaryCastStringExecFunction; + } break; + case LogicalTypeID::TIMESTAMP_NS: { + execFunc = ScalarFunction::UnaryCastStringExecFunction; + } break; + case LogicalTypeID::TIMESTAMP_TZ: { 
+ execFunc = ScalarFunction::UnaryCastStringExecFunction; + } break; + case LogicalTypeID::TIMESTAMP: { + execFunc = ScalarFunction::UnaryCastStringExecFunction; + } break; + case LogicalTypeID::INTERVAL: { + execFunc = ScalarFunction::UnaryCastStringExecFunction; + } break; + case LogicalTypeID::BLOB: { + execFunc = + ScalarFunction::UnaryCastStringExecFunction; + } break; + case LogicalTypeID::UUID: { + execFunc = ScalarFunction::UnaryCastStringExecFunction; + } break; + case LogicalTypeID::STRING: { + execFunc = + ScalarFunction::UnaryCastExecFunction; + } break; + case LogicalTypeID::BOOL: { + execFunc = + ScalarFunction::UnaryCastStringExecFunction; + } break; + case LogicalTypeID::DOUBLE: { + execFunc = + ScalarFunction::UnaryCastStringExecFunction; + } break; + case LogicalTypeID::FLOAT: { + execFunc = + ScalarFunction::UnaryCastStringExecFunction; + } break; + case LogicalTypeID::DECIMAL: { + switch (targetType.getPhysicalType()) { + case PhysicalTypeID::INT16: + execFunc = + ScalarFunction::UnaryExecNestedTypeFunction; + break; + case PhysicalTypeID::INT32: + execFunc = + ScalarFunction::UnaryExecNestedTypeFunction; + break; + case PhysicalTypeID::INT64: + execFunc = + ScalarFunction::UnaryExecNestedTypeFunction; + break; + case PhysicalTypeID::INT128: + execFunc = + ScalarFunction::UnaryExecNestedTypeFunction; + break; + default: + KU_UNREACHABLE; + } + } break; + case LogicalTypeID::INT128: { + execFunc = ScalarFunction::UnaryCastStringExecFunction; + } break; + case LogicalTypeID::UINT128: { + execFunc = ScalarFunction::UnaryCastStringExecFunction; + } break; + case LogicalTypeID::SERIAL: + case LogicalTypeID::INT64: { + execFunc = + ScalarFunction::UnaryCastStringExecFunction; + } break; + case LogicalTypeID::INT32: { + execFunc = + ScalarFunction::UnaryCastStringExecFunction; + } break; + case LogicalTypeID::INT16: { + execFunc = + ScalarFunction::UnaryCastStringExecFunction; + } break; + case LogicalTypeID::INT8: { + execFunc = + ScalarFunction::UnaryCastStringExecFunction; + } break; + case LogicalTypeID::UINT64: { + execFunc = ScalarFunction::UnaryCastStringExecFunction; + } break; + case LogicalTypeID::UINT32: { + execFunc = ScalarFunction::UnaryCastStringExecFunction; + } break; + case LogicalTypeID::UINT16: { + execFunc = ScalarFunction::UnaryCastStringExecFunction; + } break; + case LogicalTypeID::UINT8: { + execFunc = + ScalarFunction::UnaryCastStringExecFunction; + } break; + case LogicalTypeID::ARRAY: + case LogicalTypeID::LIST: { + execFunc = ScalarFunction::UnaryCastStringExecFunction; + } break; + case LogicalTypeID::MAP: { + execFunc = ScalarFunction::UnaryCastStringExecFunction; + } break; + case LogicalTypeID::STRUCT: { + execFunc = ScalarFunction::UnaryCastStringExecFunction; + } break; + case LogicalTypeID::UNION: { + execFunc = ScalarFunction::UnaryCastStringExecFunction; + } break; + default: + throw ConversionException{ + stringFormat("Unsupported casting function from STRING to {}.", targetType.toString())}; + } + return std::make_unique(functionName, + std::vector{LogicalTypeID::STRING}, targetType.getLogicalTypeID(), execFunc); +} + +template +static std::unique_ptr bindCastToStringFunction(const std::string& functionName, + const LogicalType& sourceType) { + scalar_func_exec_t func; + switch (sourceType.getLogicalTypeID()) { + case LogicalTypeID::BOOL: { + func = ScalarFunction::UnaryCastExecFunction; + } break; + case LogicalTypeID::SERIAL: + case LogicalTypeID::INT64: { + func = ScalarFunction::UnaryCastExecFunction; + } break; + case 
LogicalTypeID::INT32: { + func = ScalarFunction::UnaryCastExecFunction; + } break; + case LogicalTypeID::INT16: { + func = ScalarFunction::UnaryCastExecFunction; + } break; + case LogicalTypeID::INT8: { + func = ScalarFunction::UnaryCastExecFunction; + } break; + case LogicalTypeID::UINT64: { + func = ScalarFunction::UnaryCastExecFunction; + } break; + case LogicalTypeID::UINT32: { + func = ScalarFunction::UnaryCastExecFunction; + } break; + case LogicalTypeID::UINT16: { + func = ScalarFunction::UnaryCastExecFunction; + } break; + case LogicalTypeID::INT128: { + func = ScalarFunction::UnaryCastExecFunction; + } break; + case LogicalTypeID::UINT128: { + func = + ScalarFunction::UnaryCastExecFunction; + } break; + case LogicalTypeID::UINT8: { + func = ScalarFunction::UnaryCastExecFunction; + } break; + case LogicalTypeID::DOUBLE: { + func = ScalarFunction::UnaryCastExecFunction; + } break; + case LogicalTypeID::FLOAT: { + func = ScalarFunction::UnaryCastExecFunction; + } break; + case LogicalTypeID::DECIMAL: { + switch (sourceType.getPhysicalType()) { + case PhysicalTypeID::INT16: + func = ScalarFunction::UnaryExecNestedTypeFunction; + break; + case PhysicalTypeID::INT32: + func = ScalarFunction::UnaryExecNestedTypeFunction; + break; + case PhysicalTypeID::INT64: + func = ScalarFunction::UnaryExecNestedTypeFunction; + break; + case PhysicalTypeID::INT128: + func = + ScalarFunction::UnaryExecNestedTypeFunction; + break; + default: + KU_UNREACHABLE; + } + } break; + case LogicalTypeID::DATE: { + func = ScalarFunction::UnaryCastExecFunction; + } break; + case LogicalTypeID::TIMESTAMP_NS: { + func = ScalarFunction::UnaryCastExecFunction; + } break; + case LogicalTypeID::TIMESTAMP_MS: { + func = ScalarFunction::UnaryCastExecFunction; + } break; + case LogicalTypeID::TIMESTAMP_SEC: { + func = ScalarFunction::UnaryCastExecFunction; + } break; + case LogicalTypeID::TIMESTAMP_TZ: { + func = ScalarFunction::UnaryCastExecFunction; + } break; + case LogicalTypeID::TIMESTAMP: { + func = + ScalarFunction::UnaryCastExecFunction; + } break; + case LogicalTypeID::INTERVAL: { + func = + ScalarFunction::UnaryCastExecFunction; + } break; + case LogicalTypeID::INTERNAL_ID: { + func = ScalarFunction::UnaryCastExecFunction; + } break; + case LogicalTypeID::BLOB: { + func = ScalarFunction::UnaryCastExecFunction; + } break; + case LogicalTypeID::UUID: { + func = + ScalarFunction::UnaryCastExecFunction; + } break; + case LogicalTypeID::ARRAY: + case LogicalTypeID::LIST: { + func = ScalarFunction::UnaryCastExecFunction; + } break; + case LogicalTypeID::MAP: { + func = + ScalarFunction::UnaryCastExecFunction; + } break; + case LogicalTypeID::NODE: { + func = ScalarFunction::UnaryCastExecFunction; + } break; + case LogicalTypeID::REL: { + func = ScalarFunction::UnaryCastExecFunction; + } break; + case LogicalTypeID::RECURSIVE_REL: + case LogicalTypeID::STRUCT: { + func = ScalarFunction::UnaryCastExecFunction; + } break; + case LogicalTypeID::UNION: { + func = ScalarFunction::UnaryCastExecFunction; + } break; + default: + KU_UNREACHABLE; + } + return std::make_unique(functionName, + std::vector{sourceType.getLogicalTypeID()}, LogicalTypeID::STRING, func); +} + +template +static std::unique_ptr bindCastToDecimalFunction(const std::string& functionName, + const LogicalType& sourceType, const LogicalType& targetType) { + scalar_func_exec_t func; + if (sourceType.getLogicalTypeID() == LogicalTypeID::DECIMAL) { + TypeUtils::visit( + sourceType, + [&](T) { + func = ScalarFunction::UnaryCastExecFunction; + }, + [&](auto) { 
KU_UNREACHABLE; }); + } else { + TypeUtils::visit( + sourceType, + [&](T) { + func = ScalarFunction::UnaryCastExecFunction; + }, + [&](auto) { KU_UNREACHABLE; }); + } + return std::make_unique(functionName, + std::vector{sourceType.getLogicalTypeID()}, targetType.getLogicalTypeID(), + func); +} + +template +static std::unique_ptr bindCastToNumericFunction(const std::string& functionName, + const LogicalType& sourceType, const LogicalType& targetType) { + scalar_func_exec_t func; + switch (sourceType.getLogicalTypeID()) { + case LogicalTypeID::INT8: { + func = ScalarFunction::UnaryExecFunction; + } break; + case LogicalTypeID::INT16: { + func = ScalarFunction::UnaryExecFunction; + } break; + case LogicalTypeID::INT32: { + func = ScalarFunction::UnaryExecFunction; + } break; + case LogicalTypeID::SERIAL: + case LogicalTypeID::INT64: { + func = ScalarFunction::UnaryExecFunction; + } break; + case LogicalTypeID::UINT8: { + func = ScalarFunction::UnaryExecFunction; + } break; + case LogicalTypeID::UINT16: { + func = ScalarFunction::UnaryExecFunction; + } break; + case LogicalTypeID::UINT32: { + func = ScalarFunction::UnaryExecFunction; + } break; + case LogicalTypeID::UINT64: { + func = ScalarFunction::UnaryExecFunction; + } break; + case LogicalTypeID::INT128: { + func = ScalarFunction::UnaryExecFunction; + } break; + case LogicalTypeID::UINT128: { + func = ScalarFunction::UnaryExecFunction; + } break; + case LogicalTypeID::FLOAT: { + func = ScalarFunction::UnaryExecFunction; + } break; + case LogicalTypeID::DOUBLE: { + func = ScalarFunction::UnaryExecFunction; + } break; + case LogicalTypeID::DECIMAL: { + switch (sourceType.getPhysicalType()) { + // note: this cannot handle decimal -> decimal casting. + case PhysicalTypeID::INT16: + func = ScalarFunction::UnaryExecNestedTypeFunction; + break; + case PhysicalTypeID::INT32: + func = ScalarFunction::UnaryExecNestedTypeFunction; + break; + case PhysicalTypeID::INT64: + func = ScalarFunction::UnaryExecNestedTypeFunction; + break; + case PhysicalTypeID::INT128: + func = ScalarFunction::UnaryExecNestedTypeFunction; + break; + default: + KU_UNREACHABLE; + } + } break; + default: + throw ConversionException{stringFormat("Unsupported casting function from {} to {}.", + sourceType.toString(), targetType.toString())}; + } + return std::make_unique(functionName, + std::vector{sourceType.getLogicalTypeID()}, targetType.getLogicalTypeID(), + func); +} + +static union_field_idx_t findUnionMinCostTag(const LogicalType& sourceType, + const LogicalType& unionType) { + uint32_t minCastCost = UNDEFINED_CAST_COST; + union_field_idx_t minCostTag = 0; + auto numFields = UnionType::getNumFields(unionType); + for (auto i = 0u; i < numFields; ++i) { + const auto& fieldType = UnionType::getFieldType(unionType, i); + if (CastFunction::hasImplicitCast(sourceType, fieldType)) { + uint32_t castCost = BuiltInFunctionsUtils::getCastCost(sourceType.getLogicalTypeID(), + fieldType.getLogicalTypeID()); + if (castCost < minCastCost) { + minCastCost = castCost; + minCostTag = i; + } + } + } + if (minCastCost == UNDEFINED_CAST_COST) { + throw ConversionException{ + stringFormat("Cannot cast from {} to {}, target type has no compatible field.", + sourceType.toString(), unionType.toString())}; + } + return minCostTag; +} + +static std::unique_ptr bindCastToUnionFunction(const std::string& functionName, + const LogicalType& sourceType, const LogicalType& targetType) { + auto minCostTag = findUnionMinCostTag(sourceType, targetType); + const auto& innerType = 
common::UnionType::getFieldType(targetType, minCostTag); + CastToUnionBindData::inner_func_t innerFunc; + if (sourceType == innerType) { + innerFunc = [](ValueVector* inputVector, ValueVector& valVector, SelectionVector*, + uint64_t inputPos, uint64_t resultPos) { + valVector.copyFromVectorData(inputPos, inputVector, resultPos); + }; + } else { + std::shared_ptr innerCast = + CastFunction::bindCastFunction("CAST", sourceType, innerType); + innerFunc = [innerCast](ValueVector* inputVector, ValueVector& valVector, + SelectionVector* selVector, uint64_t, uint64_t) { + // Can we just use inputPos / resultPos and not the entire sel vector? + auto input = std::shared_ptr(inputVector, [](ValueVector*) {}); + innerCast->execFunc({input}, {selVector}, valVector, selVector, nullptr /* dataPtr */); + }; + } + auto castFunc = std::make_unique(functionName, + std::vector{sourceType.getLogicalTypeID()}, targetType.getLogicalTypeID(), + ScalarFunction::UnaryCastExecFunction); + castFunc->bindFunc = [minCostTag, innerFunc, &targetType](const ScalarBindFuncInput&) { + return std::make_unique(minCostTag, innerFunc, targetType.copy()); + }; + return castFunc; +} + +static std::unique_ptr bindCastBetweenNested(const std::string& functionName, + const LogicalType& sourceType, const LogicalType& targetType) { + // todo: compile time checking of nested types + if (CastArrayHelper::checkCompatibleNestedTypes(sourceType.getLogicalTypeID(), + targetType.getLogicalTypeID())) { + return std::make_unique(functionName, + std::vector{sourceType.getLogicalTypeID()}, + targetType.getLogicalTypeID(), nestedTypesCastExecFunction); + } + throw ConversionException{stringFormat("Unsupported casting function from {} to {}.", + LogicalTypeUtils::toString(sourceType.getLogicalTypeID()), + LogicalTypeUtils::toString(targetType.getLogicalTypeID()))}; +} + +template +static std::unique_ptr bindCastToDateFunction(const std::string& functionName, + const LogicalType& sourceType, const LogicalType& dstType) { + scalar_func_exec_t func; + switch (sourceType.getLogicalTypeID()) { + case LogicalTypeID::TIMESTAMP_MS: + func = ScalarFunction::UnaryExecFunction; + break; + case LogicalTypeID::TIMESTAMP_NS: + func = ScalarFunction::UnaryExecFunction; + break; + case LogicalTypeID::TIMESTAMP_SEC: + func = ScalarFunction::UnaryExecFunction; + break; + case LogicalTypeID::TIMESTAMP_TZ: + case LogicalTypeID::TIMESTAMP: + func = ScalarFunction::UnaryExecFunction; + break; + // LCOV_EXCL_START + default: + throw ConversionException{stringFormat("Unsupported casting function from {} to {}.", + sourceType.toString(), dstType.toString())}; + // LCOV_EXCL_END + } + return std::make_unique(functionName, + std::vector{sourceType.getLogicalTypeID()}, LogicalTypeID::DATE, func); +} + +template +static std::unique_ptr bindCastToTimestampFunction(const std::string& functionName, + const LogicalType& sourceType, const LogicalType& dstType) { + scalar_func_exec_t func; + switch (sourceType.getLogicalTypeID()) { + case LogicalTypeID::DATE: { + func = ScalarFunction::UnaryExecFunction; + } break; + case LogicalTypeID::TIMESTAMP_MS: { + func = ScalarFunction::UnaryExecFunction; + } break; + case LogicalTypeID::TIMESTAMP_NS: { + func = ScalarFunction::UnaryExecFunction; + } break; + case LogicalTypeID::TIMESTAMP_SEC: { + func = ScalarFunction::UnaryExecFunction; + } break; + case LogicalTypeID::TIMESTAMP_TZ: + case LogicalTypeID::TIMESTAMP: { + func = ScalarFunction::UnaryExecFunction; + } break; + default: + throw ConversionException{stringFormat("Unsupported 
casting function from {} to {}.", + sourceType.toString(), dstType.toString())}; + } + return std::make_unique(functionName, + std::vector{sourceType.getLogicalTypeID()}, LogicalTypeID::TIMESTAMP, func); +} + +template +static std::unique_ptr bindCastBetweenDecimalFunction( + const std::string& functionName, const LogicalType& sourceType) { + scalar_func_exec_t func; + switch (sourceType.getPhysicalType()) { + case PhysicalTypeID::INT16: + func = ScalarFunction::UnaryExecNestedTypeFunction; + break; + case PhysicalTypeID::INT32: + func = ScalarFunction::UnaryExecNestedTypeFunction; + break; + case PhysicalTypeID::INT64: + func = ScalarFunction::UnaryExecNestedTypeFunction; + break; + case PhysicalTypeID::INT128: + func = ScalarFunction::UnaryExecNestedTypeFunction; + break; + default: + KU_UNREACHABLE; + } + return std::make_unique(functionName, + std::vector{LogicalTypeID::DECIMAL}, LogicalTypeID::DECIMAL, func); +} + +template +std::unique_ptr CastFunction::bindCastFunction(const std::string& functionName, + const LogicalType& sourceType, const LogicalType& targetType) { + auto sourceTypeID = sourceType.getLogicalTypeID(); + auto targetTypeID = targetType.getLogicalTypeID(); + if (sourceTypeID == LogicalTypeID::STRING) { + return bindCastFromStringFunction(functionName, targetType); + } + switch (targetTypeID) { + case LogicalTypeID::STRING: { + return bindCastToStringFunction(functionName, sourceType); + } + case LogicalTypeID::DOUBLE: { + return bindCastToNumericFunction(functionName, sourceType, + targetType); + } + case LogicalTypeID::FLOAT: { + return bindCastToNumericFunction(functionName, sourceType, + targetType); + } + case LogicalTypeID::DECIMAL: { + std::unique_ptr scalarFunc; + TypeUtils::visit( + targetType.getPhysicalType(), + [&](T) { + scalarFunc = + bindCastToDecimalFunction(functionName, sourceType, targetType); + }, + [](auto) { KU_UNREACHABLE; }); + return scalarFunc; + } + case LogicalTypeID::INT128: { + return bindCastToNumericFunction(functionName, sourceType, + targetType); + } + case LogicalTypeID::UINT128: { + return bindCastToNumericFunction(functionName, + sourceType, targetType); + } + case LogicalTypeID::SERIAL: { + return bindCastToNumericFunction(functionName, sourceType, + targetType); + } + case LogicalTypeID::INT64: { + return bindCastToNumericFunction(functionName, sourceType, + targetType); + } + case LogicalTypeID::INT32: { + return bindCastToNumericFunction(functionName, sourceType, + targetType); + } + case LogicalTypeID::INT16: { + return bindCastToNumericFunction(functionName, sourceType, + targetType); + } + case LogicalTypeID::INT8: { + return bindCastToNumericFunction(functionName, sourceType, + targetType); + } + case LogicalTypeID::UINT64: { + return bindCastToNumericFunction(functionName, sourceType, + targetType); + } + case LogicalTypeID::UINT32: { + return bindCastToNumericFunction(functionName, sourceType, + targetType); + } + case LogicalTypeID::UINT16: { + return bindCastToNumericFunction(functionName, sourceType, + targetType); + } + case LogicalTypeID::UINT8: { + return bindCastToNumericFunction(functionName, sourceType, + targetType); + } + case LogicalTypeID::DATE: { + return bindCastToDateFunction(functionName, sourceType, targetType); + } + case LogicalTypeID::TIMESTAMP_NS: { + return bindCastToTimestampFunction(functionName, sourceType, + targetType); + } + case LogicalTypeID::TIMESTAMP_MS: { + return bindCastToTimestampFunction(functionName, sourceType, + targetType); + } + case LogicalTypeID::TIMESTAMP_SEC: { + return 
bindCastToTimestampFunction(functionName, sourceType, + targetType); + } + case LogicalTypeID::TIMESTAMP_TZ: + case LogicalTypeID::TIMESTAMP: { + return bindCastToTimestampFunction(functionName, sourceType, + targetType); + } + case LogicalTypeID::UNION: { + if (sourceType.getLogicalTypeID() != LogicalTypeID::UNION && + !CastArrayHelper::isUnionSpecialCast(sourceType, targetType)) { + return bindCastToUnionFunction(functionName, sourceType, targetType); + } + [[fallthrough]]; + } + case LogicalTypeID::LIST: + case LogicalTypeID::ARRAY: + case LogicalTypeID::MAP: + case LogicalTypeID::STRUCT: { + return bindCastBetweenNested(functionName, sourceType, targetType); + } + default: { + throw ConversionException(stringFormat("Unsupported casting function from {} to {}.", + sourceType.toString(), targetType.toString())); + } + } +} + +function_set CastToDateFunction::getFunctionSet() { + function_set result; + result.push_back( + CastFunction::bindCastFunction(name, LogicalType::STRING(), LogicalType::DATE())); + return result; +} + +function_set CastToTimestampFunction::getFunctionSet() { + function_set result; + result.push_back( + CastFunction::bindCastFunction(name, LogicalType::STRING(), LogicalType::TIMESTAMP())); + return result; +} + +function_set CastToIntervalFunction::getFunctionSet() { + function_set result; + result.push_back( + CastFunction::bindCastFunction(name, LogicalType::STRING(), LogicalType::INTERVAL())); + return result; +} + +static std::unique_ptr toStringBindFunc(ScalarBindFuncInput input) { + return FunctionBindData::getSimpleBindData(input.arguments, LogicalType::STRING()); +} + +function_set CastToStringFunction::getFunctionSet() { + function_set result; + result.reserve(LogicalTypeUtils::getAllValidLogicTypes().size()); + for (auto& type : LogicalTypeUtils::getAllValidLogicTypes()) { + auto function = CastFunction::bindCastFunction(name, type, LogicalType::STRING()); + function->bindFunc = toStringBindFunc; + result.push_back(std::move(function)); + } + return result; +} + +function_set CastToBlobFunction::getFunctionSet() { + function_set result; + result.push_back( + CastFunction::bindCastFunction(name, LogicalType::STRING(), LogicalType::BLOB())); + return result; +} + +function_set CastToUUIDFunction::getFunctionSet() { + function_set result; + result.push_back( + CastFunction::bindCastFunction(name, LogicalType::STRING(), LogicalType::UUID())); + return result; +} + +function_set CastToBoolFunction::getFunctionSet() { + function_set result; + result.push_back( + CastFunction::bindCastFunction(name, LogicalType::STRING(), LogicalType::BOOL())); + return result; +} + +function_set CastToDoubleFunction::getFunctionSet() { + function_set result; + for (auto typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { + result.push_back( + CastFunction::bindCastFunction(name, LogicalType(typeID), LogicalType::DOUBLE())); + } + result.push_back( + CastFunction::bindCastFunction(name, LogicalType::STRING(), LogicalType::DOUBLE())); + return result; +} + +function_set CastToFloatFunction::getFunctionSet() { + function_set result; + for (auto typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { + result.push_back( + CastFunction::bindCastFunction(name, LogicalType(typeID), LogicalType::FLOAT())); + } + result.push_back( + CastFunction::bindCastFunction(name, LogicalType::STRING(), LogicalType::FLOAT())); + return result; +} + +function_set CastToInt128Function::getFunctionSet() { + function_set result; + for (auto typeID : 
LogicalTypeUtils::getNumericalLogicalTypeIDs()) { + result.push_back( + CastFunction::bindCastFunction(name, LogicalType(typeID), LogicalType::INT128())); + } + result.push_back( + CastFunction::bindCastFunction(name, LogicalType::STRING(), LogicalType::INT128())); + return result; +} + +function_set CastToSerialFunction::getFunctionSet() { + function_set result; + for (auto typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { + result.push_back( + CastFunction::bindCastFunction(name, LogicalType(typeID), LogicalType::SERIAL())); + } + result.push_back( + CastFunction::bindCastFunction(name, LogicalType::STRING(), LogicalType::SERIAL())); + return result; +} + +function_set CastToInt64Function::getFunctionSet() { + function_set result; + for (auto typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { + result.push_back( + CastFunction::bindCastFunction(name, LogicalType(typeID), LogicalType::INT64())); + } + result.push_back( + CastFunction::bindCastFunction(name, LogicalType::STRING(), LogicalType::INT64())); + return result; +} + +function_set CastToInt32Function::getFunctionSet() { + function_set result; + for (auto typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { + result.push_back( + CastFunction::bindCastFunction(name, LogicalType(typeID), LogicalType::INT32())); + } + result.push_back( + CastFunction::bindCastFunction(name, LogicalType::STRING(), LogicalType::INT32())); + return result; +} + +function_set CastToInt16Function::getFunctionSet() { + function_set result; + for (auto typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { + result.push_back( + CastFunction::bindCastFunction(name, LogicalType(typeID), LogicalType::INT16())); + } + result.push_back( + CastFunction::bindCastFunction(name, LogicalType::STRING(), LogicalType::INT16())); + return result; +} + +function_set CastToInt8Function::getFunctionSet() { + function_set result; + for (auto typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { + result.push_back( + CastFunction::bindCastFunction(name, LogicalType(typeID), LogicalType::INT8())); + } + result.push_back( + CastFunction::bindCastFunction(name, LogicalType::STRING(), LogicalType::INT8())); + return result; +} + +function_set CastToUInt128Function::getFunctionSet() { + function_set result; + for (auto typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { + result.push_back( + CastFunction::bindCastFunction(name, LogicalType(typeID), LogicalType::UINT128())); + } + result.push_back( + CastFunction::bindCastFunction(name, LogicalType::STRING(), LogicalType::UINT128())); + return result; +} + +function_set CastToUInt64Function::getFunctionSet() { + function_set result; + for (auto typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { + result.push_back( + CastFunction::bindCastFunction(name, LogicalType(typeID), LogicalType::UINT64())); + } + result.push_back( + CastFunction::bindCastFunction(name, LogicalType::STRING(), LogicalType::UINT64())); + return result; +} + +function_set CastToUInt32Function::getFunctionSet() { + function_set result; + for (auto typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { + result.push_back( + CastFunction::bindCastFunction(name, LogicalType(typeID), LogicalType::UINT32())); + } + result.push_back( + CastFunction::bindCastFunction(name, LogicalType::STRING(), LogicalType::UINT32())); + return result; +} + +function_set CastToUInt16Function::getFunctionSet() { + function_set result; + for (auto typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { + result.push_back( + 
CastFunction::bindCastFunction(name, LogicalType(typeID), LogicalType::UINT16())); + } + result.push_back( + CastFunction::bindCastFunction(name, LogicalType::STRING(), LogicalType::UINT16())); + return result; +} + +function_set CastToUInt8Function::getFunctionSet() { + function_set result; + for (auto typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { + result.push_back( + CastFunction::bindCastFunction(name, LogicalType(typeID), LogicalType::UINT8())); + } + result.push_back( + CastFunction::bindCastFunction(name, LogicalType::STRING(), LogicalType::UINT8())); + return result; +} + +static std::unique_ptr castBindFunc(ScalarBindFuncInput input) { + KU_ASSERT(input.arguments.size() == 2); + // Bind target type. + if (input.arguments[1]->expressionType != ExpressionType::LITERAL) { + throw BinderException(stringFormat("Second parameter of CAST function must be a literal.")); + } + auto literalExpr = input.arguments[1]->constPtrCast(); + auto targetTypeStr = literalExpr->getValue().getValue(); + auto func = input.definition->ptrCast(); + func->name = "CAST_TO_" + targetTypeStr; + auto targetType = LogicalType::convertFromString(targetTypeStr, input.context); + if (!LogicalType::isBuiltInType(targetTypeStr)) { + std::vector typeVec; + typeVec.push_back(input.arguments[0]->getDataType().copy()); + try { + auto entry = + catalog::Catalog::Get(*input.context) + ->getFunctionEntry(transaction::Transaction::Get(*input.context), func->name); + auto match = BuiltInFunctionsUtils::matchFunction(func->name, typeVec, + entry->ptrCast()); + func->execFunc = match->constPtrCast()->execFunc; + return std::make_unique(targetType.copy()); + } catch (...) { // NOLINT + // If there's no user defined casting function for the corresponding user defined type, + // we use the default casting function. + } + } + // For STRUCT type, we will need to check its field name in later stage + // Otherwise, there will be bug for: RETURN cast({'a': 12, 'b': 12} AS struct(c int64, d + // int64)); being allowed. + if (targetType == input.arguments[0]->getDataType() && + targetType.getLogicalTypeID() != LogicalTypeID::STRUCT) { // No need to cast. + return nullptr; + } + if (ExpressionUtil::canCastStatically(*input.arguments[0], targetType) && + targetType.getLogicalTypeID() != LogicalTypeID::STRUCT) { + input.arguments[0]->cast(targetType); + return nullptr; + } + // TODO(Xiyang): Can we unify the binding of casting function with other scalar functions? 
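+    // When no user-defined cast matched above and the argument could not be
+    // cast statically, fall back to the generic cast binder below: bind an exec
+    // function for sourceType -> targetType and, if the bound function carries
+    // its own bindFunc (casts to UNION, for example, attach a CastToUnionBindData
+    // through it), delegate to that bindFunc instead of the plain bind data.
+    //
+    // Hypothetical usage sketch (not part of this function; shown only to
+    // illustrate the generic path the next statement takes):
+    //
+    //   auto fn = CastFunction::bindCastFunction("CAST_TO_INT64",
+    //       LogicalType::STRING(), LogicalType::INT64());
+    //   // fn->execFunc now parses STRING values into INT64 vector-at-a-time.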
+ auto res = + CastFunction::bindCastFunction(func->name, input.arguments[0]->getDataType(), targetType); + func->execFunc = res->execFunc; + if (res->bindFunc) { + return res->bindFunc(input); + } + return std::make_unique(targetType.copy()); +} + +function_set CastAnyFunction::getFunctionSet() { + function_set result; + auto func = std::make_unique(name, + std::vector{LogicalTypeID::ANY, LogicalTypeID::STRING}, LogicalTypeID::ANY); + func->bindFunc = castBindFunc; + result.push_back(std::move(func)); + return result; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/vector_date_functions.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/vector_date_functions.cpp new file mode 100644 index 0000000000..998b2ea905 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/vector_date_functions.cpp @@ -0,0 +1,123 @@ +#include "function/date/vector_date_functions.h" + +#include "function/date/date_functions.h" +#include "function/scalar_function.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +function_set DatePartFunction::getFunctionSet() { + function_set result; + result.push_back(make_unique(name, + std::vector{LogicalTypeID::STRING, LogicalTypeID::DATE}, + LogicalTypeID::INT64, + ScalarFunction::BinaryExecFunction)); + result.push_back(make_unique(name, + std::vector{LogicalTypeID::STRING, LogicalTypeID::TIMESTAMP}, + LogicalTypeID::INT64, + ScalarFunction::BinaryExecFunction)); + result.push_back(make_unique(name, + std::vector{LogicalTypeID::STRING, LogicalTypeID::INTERVAL}, + LogicalTypeID::INT64, + ScalarFunction::BinaryExecFunction)); + return result; +} + +function_set DateTruncFunction::getFunctionSet() { + function_set result; + result.push_back(make_unique(name, + std::vector{LogicalTypeID::STRING, LogicalTypeID::DATE}, LogicalTypeID::DATE, + ScalarFunction::BinaryExecFunction)); + result.push_back(make_unique(name, + std::vector{LogicalTypeID::STRING, LogicalTypeID::TIMESTAMP}, + LogicalTypeID::TIMESTAMP, + ScalarFunction::BinaryExecFunction)); + return result; +} + +function_set DayNameFunction::getFunctionSet() { + function_set result; + result.push_back(make_unique(name, + std::vector{LogicalTypeID::DATE}, LogicalTypeID::STRING, + ScalarFunction::UnaryExecFunction)); + result.push_back(make_unique(name, + std::vector{LogicalTypeID::TIMESTAMP}, LogicalTypeID::STRING, + ScalarFunction::UnaryExecFunction)); + return result; +} + +function_set GreatestFunction::getFunctionSet() { + function_set result; + result.push_back(make_unique(name, + std::vector{LogicalTypeID::DATE, LogicalTypeID::DATE}, LogicalTypeID::DATE, + ScalarFunction::BinaryExecFunction)); + result.push_back(make_unique(name, + std::vector{LogicalTypeID::TIMESTAMP, LogicalTypeID::TIMESTAMP}, + LogicalTypeID::TIMESTAMP, + ScalarFunction::BinaryExecFunction)); + return result; +} + +function_set LastDayFunction::getFunctionSet() { + function_set result; + result.push_back( + make_unique(name, std::vector{LogicalTypeID::DATE}, + LogicalTypeID::DATE, ScalarFunction::UnaryExecFunction)); + result.push_back( + make_unique(name, std::vector{LogicalTypeID::TIMESTAMP}, + LogicalTypeID::DATE, ScalarFunction::UnaryExecFunction)); + return result; +} + +function_set LeastFunction::getFunctionSet() { + function_set result; + result.push_back(make_unique(name, + std::vector{LogicalTypeID::DATE, LogicalTypeID::DATE}, LogicalTypeID::DATE, + ScalarFunction::BinaryExecFunction)); + result.push_back(make_unique(name, + 
std::vector{LogicalTypeID::TIMESTAMP, LogicalTypeID::TIMESTAMP}, + LogicalTypeID::TIMESTAMP, + ScalarFunction::BinaryExecFunction)); + return result; +} + +function_set MakeDateFunction::getFunctionSet() { + function_set result; + result.push_back(make_unique(name, + std::vector{LogicalTypeID::INT64, LogicalTypeID::INT64, + LogicalTypeID::INT64}, + LogicalTypeID::DATE, + ScalarFunction::TernaryExecFunction)); + return result; +} + +function_set MonthNameFunction::getFunctionSet() { + function_set result; + result.push_back(make_unique(name, + std::vector{LogicalTypeID::DATE}, LogicalTypeID::STRING, + ScalarFunction::UnaryExecFunction)); + result.push_back(make_unique(name, + std::vector{LogicalTypeID::TIMESTAMP}, LogicalTypeID::STRING, + ScalarFunction::UnaryExecFunction)); + return result; +} + +function_set CurrentDateFunction::getFunctionSet() { + function_set result; + result.push_back(make_unique(name, std::vector{}, + LogicalTypeID::DATE, ScalarFunction::NullaryAuxilaryExecFunction)); + return result; +} + +function_set CurrentTimestampFunction::getFunctionSet() { + function_set result; + result.push_back( + make_unique(name, std::vector{}, LogicalTypeID::TIMESTAMP, + ScalarFunction::NullaryAuxilaryExecFunction)); + return result; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/vector_hash_functions.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/vector_hash_functions.cpp new file mode 100644 index 0000000000..2a9eb268b9 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/vector_hash_functions.cpp @@ -0,0 +1,256 @@ +#include "function/hash/vector_hash_functions.h" + +#include "common/data_chunk/sel_vector.h" +#include "common/system_config.h" +#include "common/type_utils.h" +#include "function/hash/hash_functions.h" +#include "function/scalar_function.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +template +static void executeOnValue(const ValueVector& operand, sel_t operandPos, ValueVector& result, + sel_t resultPos) { + Hash::operation(operand.getValue(operandPos), + result.getValue(resultPos)); +} + +template +void UnaryHashFunctionExecutor::execute(const ValueVector& operand, + const SelectionView& operandSelectVec, ValueVector& result, + const SelectionView& resultSelectVec) { + auto resultValues = (RESULT_TYPE*)result.getData(); + if (operand.hasNoNullsGuarantee()) { + if (operandSelectVec.isUnfiltered()) { + for (auto i = 0u; i < operandSelectVec.getSelSize(); i++) { + auto resultPos = resultSelectVec[i]; + executeOnValue(operand, i, result, resultPos); + } + } else { + for (auto i = 0u; i < operandSelectVec.getSelSize(); i++) { + auto operandPos = operandSelectVec[i]; + auto resultPos = resultSelectVec[i]; + Hash::operation(operand.getValue(operandPos), + resultValues[resultPos]); + } + } + } else { + if (operandSelectVec.isUnfiltered()) { + for (auto i = 0u; i < operandSelectVec.getSelSize(); i++) { + auto resultPos = resultSelectVec[i]; + if (!operand.isNull(i)) { + Hash::operation(operand.getValue(i), resultValues[resultPos]); + } else { + result.setValue(resultPos, NULL_HASH); + } + } + } else { + for (auto i = 0u; i < operandSelectVec.getSelSize(); i++) { + auto operandPos = operandSelectVec[i]; + auto resultPos = resultSelectVec[i]; + if (!operand.isNull(operandPos)) { + Hash::operation(operand.getValue(operandPos), + resultValues[resultPos]); + } else { + result.setValue(resultPos, NULL_HASH); + } + } + } + } +} + +template +static void executeOnValue(const 
common::ValueVector& left, common::sel_t leftPos, + const common::ValueVector& right, common::sel_t rightPos, common::ValueVector& result, + common::sel_t resultPos) { + FUNC::operation(left.getValue(leftPos), right.getValue(rightPos), + result.getValue(resultPos)); +} + +static void validateSelState(const SelectionView& leftSelVec, const SelectionView& rightSelVec, + const SelectionView& resultSelVec) { + auto leftSelSize = leftSelVec.getSelSize(); + auto rightSelSize = rightSelVec.getSelSize(); + auto resultSelSize = resultSelVec.getSelSize(); + (void)resultSelSize; + if (leftSelSize > 1 && rightSelSize > 1) { + KU_ASSERT(leftSelSize == rightSelSize); + KU_ASSERT(leftSelSize == resultSelSize); + } else if (leftSelSize > 1) { + KU_ASSERT(leftSelSize == resultSelSize); + } else if (rightSelSize > 1) { + KU_ASSERT(rightSelSize == resultSelSize); + } +} + +template +void BinaryHashFunctionExecutor::execute(const common::ValueVector& left, + const SelectionView& leftSelVec, const common::ValueVector& right, + const SelectionView& rightSelVec, common::ValueVector& result, + const SelectionView& resultSelVec) { + validateSelState(leftSelVec, rightSelVec, resultSelVec); + result.resetAuxiliaryBuffer(); + if (leftSelVec.getSelSize() != 1 && rightSelVec.getSelSize() != 1) { + for (auto i = 0u; i < leftSelVec.getSelSize(); i++) { + auto leftPos = leftSelVec[i]; + auto rightPos = rightSelVec[i]; + auto resultPos = resultSelVec[i]; + executeOnValue(left, leftPos, right, rightPos, + result, resultPos); + } + } else if (leftSelVec.getSelSize() == 1) { + auto leftPos = leftSelVec[0]; + for (auto i = 0u; i < rightSelVec.getSelSize(); i++) { + auto rightPos = rightSelVec[i]; + auto resultPos = resultSelVec[i]; + executeOnValue(left, leftPos, right, rightPos, + result, resultPos); + } + } else { + auto rightPos = rightSelVec[0]; + for (auto i = 0u; i < leftSelVec.getSelSize(); i++) { + auto leftPos = leftSelVec[i]; + auto resultPos = resultSelVec[i]; + executeOnValue(left, leftPos, right, rightPos, + result, resultPos); + } + } +} + +static std::unique_ptr computeDataVecHash(const ValueVector& operand) { + auto hashVector = std::make_unique(LogicalType::LIST(LogicalType::HASH())); + auto numValuesInDataVec = ListVector::getDataVectorSize(&operand); + ListVector::resizeDataVector(hashVector.get(), numValuesInDataVec); + // TODO(Ziyi): Allow selection size to be greater than default vector capacity, so we don't have + // to chunk the selectionVector. 
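+    // The list's underlying data vector may hold more values than
+    // DEFAULT_VECTOR_CAPACITY, so the loop below hashes it in capacity-sized
+    // chunks: a filtered SelectionVector is refilled with the next run of
+    // absolute offsets, computeHash writes one hash per element into
+    // hashVector's data vector, and finalizeDataVecHash later folds those
+    // element hashes into a single hash per list entry with combineHashScalar.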
+ SelectionVector selectionVector{DEFAULT_VECTOR_CAPACITY}; + selectionVector.setToFiltered(); + auto numValuesComputed = 0u; + uint64_t numValuesToComputeHash = 0; + while (numValuesComputed < numValuesInDataVec) { + numValuesToComputeHash = + std::min(DEFAULT_VECTOR_CAPACITY, numValuesInDataVec - numValuesComputed); + for (auto i = 0u; i < numValuesToComputeHash; i++) { + selectionVector[i] = numValuesComputed; + numValuesComputed++; + } + selectionVector.setSelSize(numValuesToComputeHash); + VectorHashFunction::computeHash(*ListVector::getDataVector(&operand), selectionVector, + *ListVector::getDataVector(hashVector.get()), selectionVector); + } + return hashVector; +} + +static void finalizeDataVecHash(const ValueVector& operand, const SelectionView& operandSelVec, + ValueVector& result, const SelectionView& resultSelVec, ValueVector& tmpHashVec) { + for (auto i = 0u; i < operandSelVec.getSelSize(); i++) { + auto pos = operandSelVec[i]; + auto resultPos = resultSelVec[i]; + auto entry = operand.getValue(pos); + if (operand.isNull(pos)) { + result.setValue(resultPos, NULL_HASH); + } else { + auto hashValue = NULL_HASH; + for (auto j = 0u; j < entry.size; j++) { + hashValue = combineHashScalar(hashValue, + ListVector::getDataVector(&tmpHashVec)->getValue(entry.offset + j)); + } + result.setValue(resultPos, hashValue); + } + } +} + +static void computeListVectorHash(const ValueVector& operand, const SelectionView& operandSelectVec, + ValueVector& result, const SelectionView& resultSelectVec) { + auto dataVecHash = computeDataVecHash(operand); + finalizeDataVecHash(operand, operandSelectVec, result, resultSelectVec, *dataVecHash); +} + +static void computeStructVecHash(const ValueVector& operand, const SelectionView& operandSelVec, + ValueVector& result, const SelectionView& resultSelVec) { + switch (operand.dataType.getLogicalTypeID()) { + case LogicalTypeID::NODE: { + KU_ASSERT(0 == common::StructType::getFieldIdx(operand.dataType, InternalKeyword::ID)); + UnaryHashFunctionExecutor::execute( + *StructVector::getFieldVector(&operand, 0), operandSelVec, result, resultSelVec); + } break; + case LogicalTypeID::REL: { + KU_ASSERT(3 == StructType::getFieldIdx(operand.dataType, InternalKeyword::ID)); + UnaryHashFunctionExecutor::execute( + *StructVector::getFieldVector(&operand, 3), operandSelVec, result, resultSelVec); + } break; + case LogicalTypeID::RECURSIVE_REL: + case LogicalTypeID::UNION: + case LogicalTypeID::STRUCT: { + VectorHashFunction::computeHash(*StructVector::getFieldVector(&operand, 0 /* idx */), + operandSelVec, result, resultSelVec); + auto tmpHashVector = std::make_unique(LogicalType::HASH()); + SelectionView tmpSel(resultSelVec.getSelSize()); + for (auto i = 1u; i < StructType::getNumFields(operand.dataType); i++) { + auto fieldVector = StructVector::getFieldVector(&operand, i); + VectorHashFunction::computeHash(*fieldVector, operandSelVec, *tmpHashVector, tmpSel); + VectorHashFunction::combineHash(*tmpHashVector, tmpSel, result, resultSelVec, result, + resultSelVec); + } + } break; + default: + KU_UNREACHABLE; + } +} + +void VectorHashFunction::computeHash(const ValueVector& operand, + const SelectionView& operandSelectVec, ValueVector& result, + const SelectionView& resultSelectVec) { + result.state = operand.state; + KU_ASSERT(result.dataType.getLogicalTypeID() == LogicalType::HASH().getLogicalTypeID()); + TypeUtils::visit( + operand.dataType.getPhysicalType(), + [&](T) { + UnaryHashFunctionExecutor::execute(operand, operandSelectVec, result, + resultSelectVec); + }, + 
[&](struct_entry_t) { + computeStructVecHash(operand, operandSelectVec, result, resultSelectVec); + }, + [&](list_entry_t) { + computeListVectorHash(operand, operandSelectVec, result, resultSelectVec); + }, + [&operand](auto) { + // LCOV_EXCL_START + throw RuntimeException("Cannot hash data type " + operand.dataType.toString()); + // LCOV_EXCL_STOP + }); +} + +void VectorHashFunction::combineHash(const ValueVector& left, const SelectionView& leftSelVec, + const ValueVector& right, const SelectionView& rightSelVec, ValueVector& result, + const SelectionView& resultSelVec) { + KU_ASSERT(left.dataType.getLogicalTypeID() == LogicalType::HASH().getLogicalTypeID()); + KU_ASSERT(left.dataType.getLogicalTypeID() == right.dataType.getLogicalTypeID()); + KU_ASSERT(left.dataType.getLogicalTypeID() == result.dataType.getLogicalTypeID()); + BinaryHashFunctionExecutor::execute(left, leftSelVec, + right, rightSelVec, result, resultSelVec); +} + +static void HashExecFunc(const std::vector>& params, + const std::vector& paramSelVectors, common::ValueVector& result, + common::SelectionVector*, void* /*dataPtr*/ = nullptr) { + KU_ASSERT(params.size() == 1); + // TODO(Ziyi): evaluators should resolve the state for result vector. + result.state = params[0]->state; + VectorHashFunction::computeHash(*params[0], *paramSelVectors[0], result, + result.state->getSelVectorUnsafe()); +} + +function_set HashFunction::getFunctionSet() { + function_set functionSet; + functionSet.push_back(std::make_unique(name, + std::vector{LogicalTypeID::ANY}, LogicalTypeID::UINT64, HashExecFunc)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/vector_node_rel_functions.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/vector_node_rel_functions.cpp new file mode 100644 index 0000000000..5423f87101 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/vector_node_rel_functions.cpp @@ -0,0 +1,29 @@ +#include "function/schema/vector_node_rel_functions.h" + +#include "common/vector/value_vector.h" +#include "function/scalar_function.h" +#include "function/schema/offset_functions.h" +#include "function/unary_function_executor.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +static void execFunc(const std::vector>& params, + const std::vector& paramSelVectors, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* /*dataPtr*/ = nullptr) { + KU_ASSERT(params.size() == 1); + UnaryFunctionExecutor::execute(*params[0], paramSelVectors[0], + result, resultSelVector); +} + +function_set OffsetFunction::getFunctionSet() { + function_set functionSet; + functionSet.push_back(make_unique(name, + std::vector{LogicalTypeID::INTERNAL_ID}, LogicalTypeID::INT64, execFunc)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/vector_null_functions.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/vector_null_functions.cpp new file mode 100644 index 0000000000..5ef0d3e71c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/vector_null_functions.cpp @@ -0,0 +1,50 @@ +#include "function/null/vector_null_functions.h" + +#include "common/exception/runtime.h" +#include "function/null/null_functions.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +void VectorNullFunction::bindExecFunction(ExpressionType expressionType, + const binder::expression_vector& /*children*/, scalar_func_exec_t& func) 
{ + switch (expressionType) { + case ExpressionType::IS_NULL: { + func = UnaryNullExecFunction; + return; + } + case ExpressionType::IS_NOT_NULL: { + func = UnaryNullExecFunction; + return; + } + default: + throw RuntimeException("Invalid expression type " + + ExpressionTypeUtil::toString(expressionType) + + "for VectorNullOperations::bindUnaryExecFunction."); + } +} + +void VectorNullFunction::bindSelectFunction(ExpressionType expressionType, + const binder::expression_vector& children, scalar_func_select_t& func) { + KU_ASSERT(children.size() == 1); + (void)children; + switch (expressionType) { + case ExpressionType::IS_NULL: { + func = UnaryNullSelectFunction; + return; + } + case ExpressionType::IS_NOT_NULL: { + func = UnaryNullSelectFunction; + return; + } + default: + throw RuntimeException("Invalid expression type " + + ExpressionTypeUtil::toString(expressionType) + + "for VectorNullOperations::bindUnarySelectFunction."); + } +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/vector_string_functions.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/vector_string_functions.cpp new file mode 100644 index 0000000000..2643298e07 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/vector_string_functions.cpp @@ -0,0 +1,304 @@ +#include "function/string/vector_string_functions.h" + +#include "function/string/functions/array_extract_function.h" +#include "function/string/functions/contains_function.h" +#include "function/string/functions/ends_with_function.h" +#include "function/string/functions/left_operation.h" +#include "function/string/functions/lpad_function.h" +#include "function/string/functions/regexp_extract_all_function.h" +#include "function/string/functions/regexp_extract_function.h" +#include "function/string/functions/regexp_matches_function.h" +#include "function/string/functions/regexp_split_to_array_function.h" +#include "function/string/functions/repeat_function.h" +#include "function/string/functions/right_function.h" +#include "function/string/functions/rpad_function.h" +#include "function/string/functions/starts_with_function.h" +#include "function/string/functions/substr_function.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +void BaseLowerUpperFunction::operation(ku_string_t& input, ku_string_t& result, + ValueVector& resultValueVector, bool isUpper) { + uint32_t resultLen = getResultLen((char*)input.getData(), input.len, isUpper); + result.len = resultLen; + if (resultLen <= ku_string_t::SHORT_STR_LENGTH) { + convertCase((char*)result.prefix, input.len, (char*)input.getData(), isUpper); + } else { + StringVector::reserveString(&resultValueVector, result, resultLen); + auto buffer = reinterpret_cast(result.overflowPtr); + convertCase(buffer, input.len, (char*)input.getData(), isUpper); + memcpy(result.prefix, buffer, ku_string_t::PREFIX_LENGTH); + } +} + +void BaseStrOperation::operation(ku_string_t& input, ku_string_t& result, + ValueVector& resultValueVector, uint32_t (*strOperation)(char* data, uint32_t len)) { + if (input.len <= ku_string_t::SHORT_STR_LENGTH) { + memcpy(result.prefix, input.prefix, input.len); + result.len = strOperation((char*)result.prefix, input.len); + } else { + StringVector::reserveString(&resultValueVector, result, input.len); + auto buffer = reinterpret_cast(result.overflowPtr); + memcpy(buffer, input.getData(), input.len); + result.len = strOperation(buffer, input.len); + memcpy(result.prefix, buffer, + result.len < 
ku_string_t::PREFIX_LENGTH ? result.len : ku_string_t::PREFIX_LENGTH); + } +} + +void Repeat::operation(ku_string_t& left, int64_t& right, ku_string_t& result, + ValueVector& resultValueVector) { + result.len = left.len * right; + if (result.len <= ku_string_t::SHORT_STR_LENGTH) { + repeatStr((char*)result.prefix, left.getAsString(), right); + } else { + StringVector::reserveString(&resultValueVector, result, result.len); + auto buffer = reinterpret_cast(result.overflowPtr); + repeatStr(buffer, left.getAsString(), right); + memcpy(result.prefix, buffer, ku_string_t::PREFIX_LENGTH); + } +} + +void Reverse::operation(ku_string_t& input, ku_string_t& result, ValueVector& resultValueVector) { + bool isAscii = true; + std::string inputStr = input.getAsString(); + for (uint32_t i = 0; i < input.len; i++) { + if (inputStr[i] & 0x80) { + isAscii = false; + break; + } + } + if (isAscii) { + BaseStrOperation::operation(input, result, resultValueVector, reverseStr); + } else { + result.len = input.len; + if (result.len > ku_string_t::SHORT_STR_LENGTH) { + StringVector::reserveString(&resultValueVector, result, input.len); + } + auto resultBuffer = result.len <= ku_string_t::SHORT_STR_LENGTH ? + reinterpret_cast(result.prefix) : + reinterpret_cast(result.overflowPtr); + utf8proc::utf8proc_grapheme_callback(inputStr.c_str(), input.len, + [&](size_t start, size_t end) { + memcpy(resultBuffer + input.len - end, input.getData() + start, end - start); + return true; + }); + if (result.len > ku_string_t::SHORT_STR_LENGTH) { + memcpy(result.prefix, resultBuffer, ku_string_t::PREFIX_LENGTH); + } + } +} + +function_set ArrayExtractFunction::getFunctionSet() { + function_set functionSet; + functionSet.emplace_back(make_unique(name, + std::vector{LogicalTypeID::STRING, LogicalTypeID::INT64}, + LogicalTypeID::STRING, + ScalarFunction::BinaryExecFunction)); + return functionSet; +} + +void ConcatFunction::execFunc(const std::vector>& parameters, + const std::vector& parameterSelVectors, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* /*dataPtr*/) { + result.resetAuxiliaryBuffer(); + for (auto selectedPos = 0u; selectedPos < resultSelVector->getSelSize(); ++selectedPos) { + auto pos = (*resultSelVector)[selectedPos]; + auto strLen = 0u; + for (auto i = 0u; i < parameters.size(); i++) { + const auto& parameter = *parameters[i]; + const auto& parameterSelVector = parameterSelVectors[i]; + auto paramPos = (*parameterSelVector)[parameter.state->isFlat() ? 0 : selectedPos]; + if (!parameter.isNull(paramPos)) { + strLen += parameter.getValue(paramPos).len; + } + } + auto& resultStr = result.getValue(pos); + StringVector::reserveString(&result, resultStr, strLen); + auto dstData = strLen <= ku_string_t::SHORT_STR_LENGTH ? + resultStr.prefix : + reinterpret_cast(resultStr.overflowPtr); + for (auto i = 0u; i < parameters.size(); i++) { + const auto& parameter = *parameters[i]; + const auto& parameterSelVector = parameterSelVectors[i]; + auto paramPos = (*parameterSelVector)[parameter.state->isFlat() ? 
0 : selectedPos]; + if (!parameter.isNull(paramPos)) { + auto srcStr = parameter.getValue(paramPos); + memcpy(dstData, srcStr.getData(), srcStr.len); + dstData += srcStr.len; + } + } + if (strLen > ku_string_t::SHORT_STR_LENGTH) { + memcpy(resultStr.prefix, resultStr.getData(), ku_string_t::PREFIX_LENGTH); + } + } +} + +function_set ConcatFunction::getFunctionSet() { + function_set functionSet; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::STRING}, LogicalTypeID::STRING, execFunc); + function->isVarLength = true; + functionSet.emplace_back(std::move(function)); + return functionSet; +} + +function_set ContainsFunction::getFunctionSet() { + function_set functionSet; + functionSet.emplace_back(make_unique(name, + std::vector{LogicalTypeID::STRING, LogicalTypeID::STRING}, + LogicalTypeID::BOOL, + ScalarFunction::BinaryExecFunction, + ScalarFunction::BinarySelectFunction)); + return functionSet; +} + +function_set EndsWithFunction::getFunctionSet() { + function_set functionSet; + functionSet.emplace_back(make_unique(name, + std::vector{LogicalTypeID::STRING, LogicalTypeID::STRING}, + LogicalTypeID::BOOL, + ScalarFunction::BinaryExecFunction, + ScalarFunction::BinarySelectFunction)); + return functionSet; +} + +function_set LeftFunction::getFunctionSet() { + function_set functionSet; + functionSet.emplace_back(make_unique(name, + std::vector{LogicalTypeID::STRING, LogicalTypeID::INT64}, + LogicalTypeID::STRING, + ScalarFunction::BinaryStringExecFunction)); + return functionSet; +} + +function_set LpadFunction::getFunctionSet() { + function_set functionSet; + functionSet.emplace_back(make_unique(name, + std::vector{LogicalTypeID::STRING, LogicalTypeID::INT64, + LogicalTypeID::STRING}, + LogicalTypeID::STRING, + ScalarFunction::TernaryStringExecFunction)); + return functionSet; +} + +function_set RepeatFunction::getFunctionSet() { + function_set functionSet; + functionSet.emplace_back(make_unique(name, + std::vector{LogicalTypeID::STRING, LogicalTypeID::INT64}, + LogicalTypeID::STRING, + ScalarFunction::BinaryStringExecFunction)); + return functionSet; +} + +function_set RightFunction::getFunctionSet() { + function_set functionSet; + functionSet.emplace_back(make_unique(name, + std::vector{LogicalTypeID::STRING, LogicalTypeID::INT64}, + LogicalTypeID::STRING, + ScalarFunction::BinaryStringExecFunction)); + return functionSet; +} + +function_set RpadFunction::getFunctionSet() { + function_set functionSet; + functionSet.emplace_back(make_unique(name, + std::vector{LogicalTypeID::STRING, LogicalTypeID::INT64, + LogicalTypeID::STRING}, + LogicalTypeID::STRING, + ScalarFunction::TernaryStringExecFunction)); + return functionSet; +} + +function_set StartsWithFunction::getFunctionSet() { + function_set functionSet; + functionSet.emplace_back(make_unique(name, + std::vector{LogicalTypeID::STRING, LogicalTypeID::STRING}, + LogicalTypeID::BOOL, + ScalarFunction::BinaryExecFunction, + ScalarFunction::BinarySelectFunction)); + return functionSet; +} + +function_set SubStrFunction::getFunctionSet() { + function_set functionSet; + functionSet.emplace_back(make_unique(name, + std::vector{LogicalTypeID::STRING, LogicalTypeID::INT64, + LogicalTypeID::INT64}, + LogicalTypeID::STRING, + ScalarFunction::TernaryStringExecFunction)); + return functionSet; +} + +function_set RegexpMatchesFunction::getFunctionSet() { + function_set functionSet; + functionSet.emplace_back(make_unique(name, + std::vector{LogicalTypeID::STRING, LogicalTypeID::STRING}, + LogicalTypeID::BOOL, + 
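+        // An exec and a select variant are both registered so that a filter can evaluate the
+        // regexp match straight into a selection vector instead of materialising a BOOL column.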
ScalarFunction::BinaryExecFunction, + ScalarFunction::BinarySelectFunction)); + return functionSet; +} + +function_set RegexpExtractFunction::getFunctionSet() { + function_set functionSet; + functionSet.emplace_back(make_unique(name, + std::vector{LogicalTypeID::STRING, LogicalTypeID::STRING}, + LogicalTypeID::STRING, + ScalarFunction::BinaryStringExecFunction)); + functionSet.emplace_back(make_unique(name, + std::vector{LogicalTypeID::STRING, LogicalTypeID::STRING, + LogicalTypeID::INT64}, + LogicalTypeID::STRING, + ScalarFunction::TernaryStringExecFunction)); + return functionSet; +} + +static std::unique_ptr bindFunc(const ScalarBindFuncInput /* input */ + &) { + return std::make_unique(LogicalType::LIST(LogicalType::STRING())); +} + +function_set RegexpExtractAllFunction::getFunctionSet() { + function_set functionSet; + std::unique_ptr func; + func = std::make_unique(name, + std::vector{LogicalTypeID::STRING, LogicalTypeID::STRING}, + LogicalTypeID::LIST, + ScalarFunction::BinaryStringExecFunction); + func->bindFunc = bindFunc; + functionSet.emplace_back(std::move(func)); + func = std::make_unique(name, + std::vector{LogicalTypeID::STRING, LogicalTypeID::STRING, + LogicalTypeID::INT64}, + LogicalTypeID::LIST, + ScalarFunction::TernaryStringExecFunction); + func->bindFunc = bindFunc; + functionSet.emplace_back(std::move(func)); + return functionSet; +} + +function_set RegexpSplitToArrayFunction::getFunctionSet() { + function_set functionSet; + auto func = std::make_unique(name, + std::vector{LogicalTypeID::STRING, LogicalTypeID::STRING}, + LogicalTypeID::LIST, + ScalarFunction::BinaryStringExecFunction); + func->bindFunc = bindFunc; + functionSet.emplace_back(std::move(func)); + return functionSet; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/vector_timestamp_functions.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/vector_timestamp_functions.cpp new file mode 100644 index 0000000000..9935487acf --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/vector_timestamp_functions.cpp @@ -0,0 +1,36 @@ +#include "function/timestamp/vector_timestamp_functions.h" + +#include "function/scalar_function.h" +#include "function/timestamp/timestamp_function.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +function_set CenturyFunction::getFunctionSet() { + function_set result; + result.push_back(make_unique(name, + std::vector{LogicalTypeID::TIMESTAMP}, LogicalTypeID::INT64, + ScalarFunction::UnaryExecFunction)); + return result; +} + +function_set EpochMsFunction::getFunctionSet() { + function_set result; + result.push_back(make_unique(name, + std::vector{LogicalTypeID::INT64}, LogicalTypeID::TIMESTAMP, + ScalarFunction::UnaryExecFunction)); + return result; +} + +function_set ToTimestampFunction::getFunctionSet() { + function_set result; + result.push_back(make_unique(name, + std::vector{LogicalTypeID::DOUBLE}, LogicalTypeID::TIMESTAMP, + ScalarFunction::UnaryExecFunction)); + return result; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/function/vector_uuid_functions.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/function/vector_uuid_functions.cpp new file mode 100644 index 0000000000..dd4b6895f4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/function/vector_uuid_functions.cpp @@ -0,0 +1,20 @@ +#include "function/uuid/vector_uuid_functions.h" + +#include "function/scalar_function.h" +#include "function/uuid/functions/gen_random_uuid.h" 
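+
+// Registers gen_random_uuid() as a nullary scalar function returning a UUID value; the
+// "Auxilary" exec variant also hands the function's bind data to the operation, presumably
+// where it obtains its randomness source.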
+ +using namespace lbug::common; + +namespace lbug { +namespace function { + +function_set GenRandomUUIDFunction::getFunctionSet() { + function_set definitions; + definitions.push_back( + make_unique(name, std::vector{}, LogicalTypeID::UUID, + ScalarFunction::NullaryAuxilaryExecFunction)); + return definitions; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/graph/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/graph/CMakeLists.txt new file mode 100644 index 0000000000..4f4341af51 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/graph/CMakeLists.txt @@ -0,0 +1,11 @@ +add_library(lbug_graph + OBJECT + graph.cpp + graph_entry.cpp + graph_entry_set.cpp + on_disk_graph.cpp + parsed_graph_entry.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/graph/graph.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/graph/graph.cpp new file mode 100644 index 0000000000..081e31f0d4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/graph/graph.cpp @@ -0,0 +1,18 @@ +#include "graph/graph.h" + +#include "common/system_config.h" + +namespace lbug::graph { +NbrScanState::Chunk::Chunk(std::span nbrNodes, + common::SelectionVector& selVector, + std::span> propertyVectors) + : nbrNodes{nbrNodes}, selVector{selVector}, propertyVectors{propertyVectors} { + KU_ASSERT(nbrNodes.size() == common::DEFAULT_VECTOR_CAPACITY); +} + +VertexScanState::Chunk::Chunk(std::span nodeIDs, + std::span> propertyVectors) + : nodeIDs{nodeIDs}, propertyVectors{propertyVectors} { + KU_ASSERT(nodeIDs.size() <= common::DEFAULT_VECTOR_CAPACITY); +} +} // namespace lbug::graph diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/graph/graph_entry.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/graph/graph_entry.cpp new file mode 100644 index 0000000000..fc08c58119 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/graph/graph_entry.cpp @@ -0,0 +1,59 @@ +#include "graph/graph_entry.h" + +#include "common/exception/runtime.h" + +using namespace lbug::planner; +using namespace lbug::binder; +using namespace lbug::common; +using namespace lbug::catalog; + +namespace lbug { +namespace graph { + +NativeGraphEntry::NativeGraphEntry(std::vector nodeEntries, + std::vector relEntries) { + for (auto& entry : nodeEntries) { + nodeInfos.emplace_back(entry); + } + for (auto& entry : relEntries) { + relInfos.emplace_back(entry); + } +} + +std::vector NativeGraphEntry::getNodeTableIDs() const { + std::vector result; + for (auto& info : nodeInfos) { + result.push_back(info.entry->getTableID()); + } + return result; +} + +std::vector NativeGraphEntry::getRelEntries() const { + std::vector result; + for (auto& info : relInfos) { + result.push_back(info.entry); + } + return result; +} + +std::vector NativeGraphEntry::getNodeEntries() const { + std::vector result; + for (auto& info : nodeInfos) { + result.push_back(info.entry); + } + return result; +} + +const NativeGraphEntryTableInfo& NativeGraphEntry::getRelInfo(table_id_t tableID) const { + for (auto& info : relInfos) { + if (info.entry->getTableID() == tableID) { + return info; + } + } + // LCOV_EXCL_START + throw RuntimeException(stringFormat("Cannot find rel table with id {}", tableID)); + // LCOV_EXCL_STOP +} + +} // namespace graph +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/graph/graph_entry_set.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/graph/graph_entry_set.cpp new file mode 100644 index 0000000000..049216647d --- /dev/null +++ 
b/graph-wasm/lbug-0.12.2/lbug-src/src/graph/graph_entry_set.cpp @@ -0,0 +1,29 @@ +#include "graph/graph_entry_set.h" + +#include "common/exception/runtime.h" +#include "common/string_format.h" +#include "main/client_context.h" + +using namespace lbug::common; + +namespace lbug { +namespace graph { + +void GraphEntrySet::validateGraphNotExist(const std::string& name) const { + if (hasGraph(name)) { + throw RuntimeException(stringFormat("Projected graph {} already exists.", name)); + } +} + +void GraphEntrySet::validateGraphExist(const std::string& name) const { + if (!hasGraph(name)) { + throw RuntimeException(stringFormat("Projected graph {} does not exists.", name)); + } +} + +GraphEntrySet* GraphEntrySet::Get(const main::ClientContext& context) { + return context.graphEntrySet.get(); +} + +} // namespace graph +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/graph/on_disk_graph.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/graph/on_disk_graph.cpp new file mode 100644 index 0000000000..5ca7f4762c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/graph/on_disk_graph.cpp @@ -0,0 +1,350 @@ +#include "graph/on_disk_graph.h" + +#include "binder/expression/expression_util.h" +#include "binder/expression/property_expression.h" +#include "binder/expression_visitor.h" +#include "catalog/catalog_entry/node_table_catalog_entry.h" +#include "common/assert.h" +#include "common/cast.h" +#include "common/data_chunk/data_chunk_state.h" +#include "common/enums/rel_direction.h" +#include "common/types/types.h" +#include "common/vector/value_vector.h" +#include "expression_evaluator/expression_evaluator.h" +#include "graph/graph.h" +#include "planner/operator/schema.h" +#include "processor/expression_mapper.h" +#include "storage/local_storage/local_rel_table.h" +#include "storage/local_storage/local_storage.h" +#include "storage/storage_manager.h" +#include "storage/storage_utils.h" +#include "storage/table/node_table.h" +#include "storage/table/rel_table.h" + +using namespace lbug::catalog; +using namespace lbug::storage; +using namespace lbug::main; +using namespace lbug::common; +using namespace lbug::planner; +using namespace lbug::processor; +using namespace lbug::binder; + +namespace lbug { +namespace graph { + +static std::vector getColumnIDs(const expression_vector& propertyExprs, + const TableCatalogEntry& relEntry, const std::vector& propertyColumnIDs) { + auto columnIDs = std::vector{NBR_ID_COLUMN_ID}; + for (auto columnID : propertyColumnIDs) { + columnIDs.push_back(columnID); + } + for (const auto& expr : propertyExprs) { + auto& property = expr->constCast(); + if (property.hasProperty(relEntry.getTableID())) { + columnIDs.push_back(relEntry.getColumnID(property.getPropertyName())); + } else { + columnIDs.push_back(INVALID_COLUMN_ID); + } + } + return columnIDs; +} + +static expression_vector getProperties(std::shared_ptr expr) { + if (expr == nullptr) { + return expression_vector{}; + } + auto collector = PropertyExprCollector(); + collector.visit(std::move(expr)); + return ExpressionUtil::removeDuplication(collector.getPropertyExprs()); +} + +// We generate an empty schema with one group even if exprs is empty because we always need to +// scan edgeID and nbrNodeID which will need the state of empty data chunk. 
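+// That single group also supplies the shared DataChunkState that the neighbour-ID vector and
+// the property vectors are attached to in OnDiskGraphNbrScanState below.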
+static Schema getSchema(const expression_vector& exprs) { + auto schema = Schema(); + schema.createGroup(); + for (auto expr : exprs) { + schema.insertToGroupAndScope(expr, 0); + } + return schema; +} + +static ResultSet getResultSet(Schema* schema, MemoryManager* mm) { + auto descriptor = ResultSetDescriptor(schema); + return ResultSet(&descriptor, mm); +} + +static std::unique_ptr getValueVector(const LogicalType& type, MemoryManager* mm, + std::shared_ptr state) { + auto vector = std::make_unique(type.copy(), mm); + vector->state = std::move(state); + return vector; +} + +OnDiskGraphNbrScanState::OnDiskGraphNbrScanState(ClientContext* context, + const TableCatalogEntry& entry, oid_t relTableID, std::shared_ptr predicate) + : OnDiskGraphNbrScanState{context, entry, relTableID, std::move(predicate), {}} {} + +OnDiskGraphNbrScanState::OnDiskGraphNbrScanState(ClientContext* context, + const TableCatalogEntry& entry, oid_t relTableID, std::shared_ptr predicate, + std::vector relProperties, bool randomLookup) { + auto predicateProps = getProperties(predicate); + auto schema = getSchema(predicateProps); + auto mm = MemoryManager::Get(*context); + auto resultSet = getResultSet(&schema, mm); + KU_ASSERT(resultSet.dataChunks.size() == 1); + auto state = resultSet.getDataChunk(0)->state; + srcNodeIDVector = getValueVector(LogicalType::INTERNAL_ID(), mm, state); + srcNodeIDVector->state = DataChunkState::getSingleValueDataChunkState(); + dstNodeIDVector = getValueVector(LogicalType::INTERNAL_ID(), mm, state); + propertyVectors.valueVectors.resize(relProperties.size()); + // TODO(bmwinger): If there are both a predicate and a custom edgePropertyIndex, they will + // currently be scanned twice. The propertyVector could simply be one of the vectors used + // for the predicate. 
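+    // Resolve each requested rel property to its column ID up front and give it an output
+    // vector that shares the scan's data chunk state.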
+ std::vector relPropertyColumnIDs; + relPropertyColumnIDs.resize(relProperties.size()); + for (auto i = 0u; i < relProperties.size(); ++i) { + auto propertyName = relProperties[i]; + auto& property = entry.getProperty(propertyName); + relPropertyColumnIDs[i] = entry.getColumnID(propertyName); + KU_ASSERT(relPropertyColumnIDs[i] != INVALID_COLUMN_ID); + propertyVectors.valueVectors[i] = getValueVector(property.getType(), mm, state); + } + if (predicate != nullptr) { + auto mapper = ExpressionMapper(&schema); + relPredicateEvaluator = mapper.getEvaluator(predicate); + relPredicateEvaluator->init(resultSet, context); + } + auto table = StorageManager::Get(*context)->getTable(relTableID)->ptrCast(); + for (auto dataDirection : entry.constCast().getRelDataDirections()) { + auto columnIDs = getColumnIDs(predicateProps, entry, relPropertyColumnIDs); + std::vector outVectors{dstNodeIDVector.get()}; + for (auto i = 0u; i < propertyVectors.getNumValueVectors(); i++) { + outVectors.push_back(&propertyVectors.getValueVectorMutable(i)); + } + for (auto& property : predicateProps) { + auto pos = DataPos(schema.getExpressionPos(*property)); + outVectors.push_back(resultSet.getValueVector(pos).get()); + } + auto scanState = std::make_unique(*MemoryManager::Get(*context), + srcNodeIDVector.get(), outVectors, dstNodeIDVector->state, randomLookup); + scanState->setToTable(transaction::Transaction::Get(*context), table, columnIDs, {}, + dataDirection); + directedIterators.emplace_back(context, table, std::move(scanState)); + } +} + +OnDiskGraph::OnDiskGraph(ClientContext* context, NativeGraphEntry entry) + : context{context}, graphEntry{std::move(entry)} { + auto storage = StorageManager::Get(*context); + for (const auto& nodeInfo : graphEntry.nodeInfos) { + auto id = nodeInfo.entry->getTableID(); + nodeIDToNodeTable.insert({id, storage->getTable(id)->ptrCast()}); + } + for (auto& relInfo : graphEntry.relInfos) { + auto relGroupEntry = relInfo.entry->ptrCast(); + for (auto& relEntryInfo : relGroupEntry->getRelEntryInfos()) { + auto srcTableID = relEntryInfo.nodePair.srcTableID; + auto dstTableID = relEntryInfo.nodePair.dstTableID; + if (!nodeIDToNodeTable.contains(srcTableID)) { + continue; + } + if (!nodeIDToNodeTable.contains(dstTableID)) { + continue; + } + relInfos.emplace_back(srcTableID, dstTableID, relGroupEntry, relEntryInfo.oid); + } + } +} + +table_id_map_t OnDiskGraph::getMaxOffsetMap(transaction::Transaction* transaction) const { + table_id_map_t result; + for (auto tableID : getNodeTableIDs()) { + result[tableID] = getMaxOffset(transaction, tableID); + } + return result; +} + +offset_t OnDiskGraph::getMaxOffset(transaction::Transaction* transaction, table_id_t id) const { + KU_ASSERT(nodeIDToNodeTable.contains(id)); + return nodeIDToNodeTable.at(id)->getNumTotalRows(transaction); +} + +offset_t OnDiskGraph::getNumNodes(transaction::Transaction* transaction) const { + offset_t numNodes = 0u; + for (auto id : getNodeTableIDs()) { + if (nodeOffsetMaskMap != nullptr && nodeOffsetMaskMap->containsTableID(id)) { + numNodes += nodeOffsetMaskMap->getOffsetMask(id)->getNumMaskedNodes(); + } else { + numNodes += getMaxOffset(transaction, id); + } + } + return numNodes; +} + +std::vector OnDiskGraph::getRelInfos(table_id_t srcTableID) { + std::vector result; + for (auto& info : relInfos) { + if (info.srcTableID == srcTableID) { + result.push_back(info); + } + } + return result; +} + +// TODO(Xiyang): since now we need to provide nbr info at prepare stage. 
It no longer make sense to +// have scanFwd&scanBwd. The direction has already been decided in this function. +std::unique_ptr OnDiskGraph::prepareRelScan(const TableCatalogEntry& entry, + oid_t relTableID, table_id_t nbrTableID, std::vector relProperties, + bool randomLookup) { + auto& info = graphEntry.getRelInfo(entry.getTableID()); + auto state = std::make_unique(context, entry, relTableID, + info.predicate, relProperties, randomLookup); + if (nodeOffsetMaskMap != nullptr && nodeOffsetMaskMap->containsTableID(nbrTableID)) { + state->nbrNodeMask = nodeOffsetMaskMap->getOffsetMask(nbrTableID); + } + return state; +} + +Graph::EdgeIterator OnDiskGraph::scanFwd(nodeID_t nodeID, NbrScanState& state) { + auto& onDiskScanState = ku_dynamic_cast(state); + onDiskScanState.srcNodeIDVector->setValue(0, nodeID); + onDiskScanState.dstNodeIDVector->state->getSelVectorUnsafe().setSelSize(0); + onDiskScanState.startScan(RelDataDirection::FWD); + return EdgeIterator(&onDiskScanState); +} + +Graph::EdgeIterator OnDiskGraph::scanBwd(nodeID_t nodeID, NbrScanState& state) { + auto& onDiskScanState = ku_dynamic_cast(state); + onDiskScanState.srcNodeIDVector->setValue(0, nodeID); + onDiskScanState.dstNodeIDVector->state->getSelVectorUnsafe().setSelSize(0); + onDiskScanState.startScan(RelDataDirection::BWD); + return EdgeIterator(&onDiskScanState); +} + +Graph::VertexIterator OnDiskGraph::scanVertices(offset_t beginOffset, offset_t endOffsetExclusive, + VertexScanState& state) { + auto& onDiskVertexScanState = ku_dynamic_cast(state); + onDiskVertexScanState.startScan(beginOffset, endOffsetExclusive); + return VertexIterator(&state); +} + +std::unique_ptr OnDiskGraph::prepareVertexScan(TableCatalogEntry* tableEntry, + const std::vector& propertiesToScan) { + return std::make_unique(*context, tableEntry, propertiesToScan); +} + +bool OnDiskGraphNbrScanState::InnerIterator::next(evaluator::ExpressionEvaluator* predicate, + SemiMask* nbrNodeMask_) { + bool hasAtLeastOneSelectedValue = false; + do { + restoreSelVector(*tableScanState->outState); + if (!relTable->scan(transaction::Transaction::Get(*context), *tableScanState)) { + return false; + } + saveSelVector(*tableScanState->outState); + hasAtLeastOneSelectedValue = tableScanState->outState->getSelVector().getSelSize() > 0; + if (predicate != nullptr) { + hasAtLeastOneSelectedValue = + predicate->select(tableScanState->outState->getSelVectorUnsafe(), + !tableScanState->outState->isFlat()); + } + if (nbrNodeMask_ != nullptr) { + auto selectedSize = 0u; + auto buffer = tableScanState->outState->getSelVectorUnsafe().getMutableBuffer(); + for (auto i = 0u; i < tableScanState->outState->getSelSize(); ++i) { + auto pos = tableScanState->outState->getSelVector()[i]; + buffer[selectedSize] = pos; + auto nbrNodeID = tableScanState->outputVectors[0]->getValue(pos); + selectedSize += nbrNodeMask_->isMasked(nbrNodeID.offset); + } + tableScanState->outState->getSelVectorUnsafe().setToFiltered(selectedSize); + hasAtLeastOneSelectedValue = selectedSize > 0; + } + } while (!hasAtLeastOneSelectedValue); + return true; +} + +OnDiskGraphNbrScanState::InnerIterator::InnerIterator(const ClientContext* context, + RelTable* relTable, std::unique_ptr tableScanState) + : context{context}, relTable{relTable}, tableScanState{std::move(tableScanState)} {} + +void OnDiskGraphNbrScanState::InnerIterator::initScan() const { + relTable->initScanState(transaction::Transaction::Get(*context), *tableScanState); +} + +void OnDiskGraphNbrScanState::startScan(RelDataDirection direction) { + auto 
idx = RelDirectionUtils::relDirectionToKeyIdx(direction); + KU_ASSERT(idx < directedIterators.size() && directedIterators[idx].getDirection() == direction); + currentIter = &directedIterators[idx]; + currentIter->initScan(); +} + +bool OnDiskGraphNbrScanState::next() { + KU_ASSERT(currentIter != nullptr); + if (currentIter->next(relPredicateEvaluator.get(), nbrNodeMask)) { + return true; + } + return false; +} + +OnDiskGraphVertexScanState::OnDiskGraphVertexScanState(ClientContext& context, + const TableCatalogEntry* tableEntry, const std::vector& propertyNames) + : context{context}, nodeTable{ku_dynamic_cast( + *StorageManager::Get(context)->getTable(tableEntry->getTableID()))}, + numNodesToScan{0}, currentOffset{0}, endOffsetExclusive{0} { + std::vector propertyColumnIDs; + propertyColumnIDs.reserve(propertyNames.size()); + std::vector types; + for (const auto& property : propertyNames) { + auto columnID = tableEntry->getColumnID(property); + propertyColumnIDs.push_back(columnID); + types.push_back(tableEntry->getProperty(property).getType().copy()); + } + propertyVectors = Table::constructDataChunk(MemoryManager::Get(context), std::move(types)); + nodeIDVector = std::make_unique(LogicalType::INTERNAL_ID(), + MemoryManager::Get(context), propertyVectors.state); + std::vector outVectors; + for (auto i = 0u; i < propertyVectors.getNumValueVectors(); i++) { + outVectors.push_back(&propertyVectors.getValueVectorMutable(i)); + } + tableScanState = + std::make_unique(nodeIDVector.get(), outVectors, propertyVectors.state); + auto table = StorageManager::Get(context)->getTable(tableEntry->getTableID()); + tableScanState->setToTable(transaction::Transaction::Get(context), table, propertyColumnIDs); +} + +void OnDiskGraphVertexScanState::startScan(offset_t beginOffset, offset_t endOffsetExclusive) { + numNodesToScan = 0; + this->currentOffset = beginOffset; + this->endOffsetExclusive = endOffsetExclusive; + tableScanState->nodeIDVector->getSelVectorPtr()->setToUnfiltered(0); + for (auto& vector : tableScanState->outputVectors) { + vector->resetAuxiliaryBuffer(); + } + nodeTable.initScanState(transaction::Transaction::Get(context), *tableScanState, + nodeTable.getTableID(), beginOffset); +} + +bool OnDiskGraphVertexScanState::next() { + if (currentOffset >= endOffsetExclusive) { + return false; + } + startScan(currentOffset, endOffsetExclusive); + + auto startOffsetOfNextGroup = + StorageUtils::getStartOffsetOfNodeGroup(tableScanState->nodeGroupIdx + 1); + auto transaction = transaction::Transaction::Get(context); + auto endOffset = std::min(endOffsetExclusive, + tableScanState->source == TableScanSource::COMMITTED ? 
+ startOffsetOfNextGroup : + startOffsetOfNextGroup + transaction->getUncommittedOffset( + tableScanState->table->getTableID(), currentOffset)); + numNodesToScan = std::min(endOffset - currentOffset, DEFAULT_VECTOR_CAPACITY); + auto result = tableScanState->scanNext(transaction, currentOffset, numNodesToScan); + currentOffset += result.numRows; + return result != NODE_GROUP_SCAN_EMPTY_RESULT; +} + +} // namespace graph +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/graph/parsed_graph_entry.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/graph/parsed_graph_entry.cpp new file mode 100644 index 0000000000..a6713393d6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/graph/parsed_graph_entry.cpp @@ -0,0 +1,20 @@ +#include "graph/parsed_graph_entry.h" + +using namespace lbug::common; + +namespace lbug { +namespace graph { + +std::string GraphEntryTypeUtils::toString(GraphEntryType type) { + switch (type) { + case GraphEntryType::NATIVE: + return "NATIVE"; + case GraphEntryType::CYPHER: + return "CYPHER"; + default: + KU_UNREACHABLE; + } +} + +} // namespace graph +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/binder.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/binder.h new file mode 100644 index 0000000000..0b305c20d3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/binder.h @@ -0,0 +1,332 @@ +#pragma once + +#include "binder/binder_scope.h" +#include "binder/expression_binder.h" +#include "binder/query/bound_regular_query.h" +#include "binder/query/query_graph.h" +#include "catalog/catalog_entry/table_catalog_entry.h" +#include "common/copier_config/file_scan_info.h" +#include "parser/ddl/parsed_property_definition.h" +#include "parser/query/graph_pattern/pattern_element.h" + +namespace lbug { +namespace extension { +class BinderExtension; +} + +namespace parser { +class ProjectionBody; +class ReturnClause; +class WithClause; +class UpdatingClause; +class ReadingClause; +class QueryPart; +class SingleQuery; +struct CreateTableInfo; +struct BaseScanSource; +struct JoinHintNode; +class Statement; +struct YieldVariable; +} // namespace parser + +namespace catalog { +class NodeTableCatalogEntry; +class RelGroupCatalogEntry; +class Catalog; +} // namespace catalog + +namespace main { +class ClientContext; +class Database; +} // namespace main + +namespace function { +struct TableFunction; +} // namespace function + +namespace transaction { +class Transaction; +} // namespace transaction + +namespace binder { +struct BoundBaseScanSource; +struct BoundCreateTableInfo; +struct BoundInsertInfo; +struct BoundSetPropertyInfo; +struct BoundDeleteInfo; +class BoundWithClause; +class BoundReturnClause; +struct ExportedTableData; +struct BoundJoinHintNode; +struct BoundCopyFromInfo; +struct BoundTableScanInfo; + +// BinderScope keeps track of expressions in scope and their aliases. 
We maintain the order of +// expressions in + +class Binder { + friend class ExpressionBinder; + +public: + explicit Binder(main::ClientContext* clientContext, + std::vector binderExtensions = {}) + : lastExpressionId{0}, scope{}, expressionBinder{this, clientContext}, + clientContext{clientContext}, binderExtensions{std::move(binderExtensions)} {} + + LBUG_API std::unique_ptr bind(const parser::Statement& statement); + + LBUG_API std::shared_ptr createVariable(const std::string& name, + const common::LogicalType& dataType); + LBUG_API std::shared_ptr createInvisibleVariable(const std::string& name, + const common::LogicalType& dataType) const; + LBUG_API expression_vector createVariables(const std::vector& names, + const std::vector& types); + LBUG_API expression_vector createInvisibleVariables(const std::vector& names, + const std::vector& types) const; + + std::shared_ptr bindWhereExpression( + const parser::ParsedExpression& parsedExpression); + + std::shared_ptr createVariable(std::string_view name, common::LogicalTypeID typeID); + std::shared_ptr createVariable(const std::string& name, + common::LogicalTypeID logicalTypeID); + + /*** bind DDL ***/ + BoundCreateTableInfo bindCreateTableInfo(const parser::CreateTableInfo* info); + BoundCreateTableInfo bindCreateNodeTableInfo(const parser::CreateTableInfo* info); + BoundCreateTableInfo bindCreateRelTableGroupInfo(const parser::CreateTableInfo* info); + std::unique_ptr bindCreateTable(const parser::Statement& statement); + std::unique_ptr bindCreateTableAs(const parser::Statement& createTable); + std::unique_ptr bindCreateType(const parser::Statement& statement) const; + std::unique_ptr bindCreateSequence(const parser::Statement& statement) const; + + static std::unique_ptr bindDrop(const parser::Statement& statement); + std::unique_ptr bindAlter(const parser::Statement& statement); + std::unique_ptr bindRenameTable(const parser::Statement& statement) const; + std::unique_ptr bindAddProperty(const parser::Statement& statement); + std::unique_ptr bindDropProperty(const parser::Statement& statement) const; + std::unique_ptr bindRenameProperty(const parser::Statement& statement) const; + std::unique_ptr bindCommentOn(const parser::Statement& statement) const; + std::unique_ptr bindAlterFromToConnection( + const parser::Statement& statement) const; + + std::vector bindPropertyDefinitions( + const std::vector& parsedDefinitions, + const std::string& tableName); + + std::unique_ptr resolvePropertyDefault( + parser::ParsedExpression* parsedDefault, const common::LogicalType& type, + const std::string& tableName, const std::string& propertyName); + + /*** bind copy ***/ + BoundCopyFromInfo bindCopyNodeFromInfo(std::string tableName, + const std::vector& properties, const parser::BaseScanSource* source, + const parser::options_t& parsingOptions, + const std::vector& expectedColumnNames, + const std::vector& expectedColumnTypes, bool byColumn); + BoundCopyFromInfo bindCopyRelFromInfo(std::string tableName, + const std::vector& properties, const parser::BaseScanSource* source, + const parser::options_t& parsingOptions, + const std::vector& expectedColumnNames, + const std::vector& expectedColumnTypes, + const catalog::NodeTableCatalogEntry* fromTable, + const catalog::NodeTableCatalogEntry* toTable); + std::unique_ptr bindCopyFromClause(const parser::Statement& statement); + std::unique_ptr bindCopyNodeFrom(const parser::Statement& statement, + catalog::NodeTableCatalogEntry& nodeEntry); + std::unique_ptr bindCopyRelFrom(const parser::Statement& 
statement, + catalog::RelGroupCatalogEntry& relGroupEntry, const std::string& fromTableName, + const std::string& toTableName); + std::unique_ptr bindLegacyCopyRelGroupFrom(const parser::Statement& copyFrom); + + std::unique_ptr bindCopyToClause(const parser::Statement& statement); + + std::unique_ptr bindExportDatabaseClause(const parser::Statement& statement); + std::unique_ptr bindImportDatabaseClause(const parser::Statement& statement); + + static std::unique_ptr bindAttachDatabase(const parser::Statement& statement); + static std::unique_ptr bindDetachDatabase(const parser::Statement& statement); + static std::unique_ptr bindUseDatabase(const parser::Statement& statement); + std::unique_ptr bindExtensionClause(const parser::Statement& statement); + + /*** bind scan source ***/ + std::unique_ptr bindScanSource(const parser::BaseScanSource* source, + const parser::options_t& options, const std::vector& columnNames, + const std::vector& columnTypes); + std::unique_ptr bindFileScanSource( + const parser::BaseScanSource& scanSource, const parser::options_t& options, + const std::vector& columnNames, + const std::vector& columnTypes); + std::unique_ptr bindQueryScanSource( + const parser::BaseScanSource& scanSource, const parser::options_t& options, + const std::vector& columnNames, + const std::vector& columnTypes); + std::unique_ptr bindObjectScanSource( + const parser::BaseScanSource& scanSource, const parser::options_t& options, + const std::vector& columnNames, + const std::vector& columnTypes); + std::unique_ptr bindParameterScanSource( + const parser::BaseScanSource& scanSource, const parser::options_t& options, + const std::vector& columnNames, + const std::vector& columnTypes); + std::unique_ptr bindTableFuncScanSource( + const parser::BaseScanSource& scanSource, const parser::options_t& options, + const std::vector& columnNames, + const std::vector& columnTypes); + + common::case_insensitive_map_t bindParsingOptions( + const parser::options_t& parsingOptions); + common::FileTypeInfo bindFileTypeInfo(const std::vector& filePaths) const; + std::vector bindFilePaths(const std::vector& filePaths) const; + + /*** bind query ***/ + std::unique_ptr bindQuery(const parser::Statement& statement); + NormalizedSingleQuery bindSingleQuery(const parser::SingleQuery& singleQuery); + NormalizedQueryPart bindQueryPart(const parser::QueryPart& queryPart); + + /*** bind standalone call ***/ + std::unique_ptr bindStandaloneCall(const parser::Statement& statement); + + /*** bind standalone call function ***/ + std::unique_ptr bindStandaloneCallFunction(const parser::Statement& statement); + + /*** bind table function ***/ + BoundTableScanInfo bindTableFunc(const std::string& tableFuncName, + const parser::ParsedExpression& expr, std::vector yieldVariables); + + /*** bind create macro ***/ + std::unique_ptr bindCreateMacro(const parser::Statement& statement) const; + + /*** bind transaction ***/ + static std::unique_ptr bindTransaction(const parser::Statement& statement); + + /*** bind extension ***/ + std::unique_ptr bindExtension(const parser::Statement& statement); + + /*** bind explain ***/ + std::unique_ptr bindExplain(const parser::Statement& statement); + + /*** bind reading clause ***/ + std::unique_ptr bindReadingClause( + const parser::ReadingClause& readingClause); + std::unique_ptr bindMatchClause(const parser::ReadingClause& readingClause); + std::shared_ptr bindJoinHint( + const QueryGraphCollection& queryGraphCollection, const parser::JoinHintNode& joinHintNode); + std::shared_ptr 
bindJoinNode(const parser::JoinHintNode& joinHintNode); + void rewriteMatchPattern(BoundGraphPattern& boundGraphPattern); + std::unique_ptr bindUnwindClause( + const parser::ReadingClause& readingClause); + std::unique_ptr bindInQueryCall(const parser::ReadingClause& readingClause); + std::unique_ptr bindLoadFrom(const parser::ReadingClause& readingClause); + + /*** bind updating clause ***/ + std::unique_ptr bindUpdatingClause( + const parser::UpdatingClause& updatingClause); + std::unique_ptr bindInsertClause( + const parser::UpdatingClause& updatingClause); + std::unique_ptr bindMergeClause( + const parser::UpdatingClause& updatingClause); + std::unique_ptr bindSetClause( + const parser::UpdatingClause& updatingClause); + std::unique_ptr bindDeleteClause( + const parser::UpdatingClause& updatingClause); + + std::vector bindInsertInfos(QueryGraphCollection& queryGraphCollection, + const std::unordered_set& patternsInScope_); + void bindInsertNode(std::shared_ptr node, std::vector& infos); + void bindInsertRel(std::shared_ptr rel, std::vector& infos); + expression_vector bindInsertColumnDataExprs( + const common::case_insensitive_map_t>& propertyDataExprs, + const std::vector& propertyDefinitions); + + BoundSetPropertyInfo bindSetPropertyInfo(const parser::ParsedExpression* column, + const parser::ParsedExpression* columnData); + expression_pair bindSetItem(const parser::ParsedExpression* column, + const parser::ParsedExpression* columnData); + + /*** bind projection clause ***/ + BoundWithClause bindWithClause(const parser::WithClause& withClause); + BoundReturnClause bindReturnClause(const parser::ReturnClause& returnClause); + + std::pair> bindProjectionList( + const parser::ProjectionBody& projectionBody); + BoundProjectionBody bindProjectionBody(const parser::ProjectionBody& projectionBody, + const expression_vector& projectionExprs, const std::vector& aliases); + + expression_vector bindOrderByExpressions( + const std::vector>& parsedExprs); + std::shared_ptr bindSkipLimitExpression(const parser::ParsedExpression& expression); + + /*** bind graph pattern ***/ + BoundGraphPattern bindGraphPattern(const std::vector& graphPattern); + + QueryGraph bindPatternElement(const parser::PatternElement& patternElement); + std::shared_ptr createPath(const std::string& pathName, + const expression_vector& children); + + std::shared_ptr bindQueryRel(const parser::RelPattern& relPattern, + const std::shared_ptr& leftNode, + const std::shared_ptr& rightNode, QueryGraph& queryGraph); + std::shared_ptr createNonRecursiveQueryRel(const std::string& parsedName, + const std::vector& entries, + std::shared_ptr srcNode, std::shared_ptr dstNode, + RelDirectionType directionType); + std::shared_ptr createRecursiveQueryRel(const parser::RelPattern& relPattern, + const std::vector& entries, + std::shared_ptr srcNode, std::shared_ptr dstNode, + RelDirectionType directionType); + expression_vector bindRecursivePatternNodeProjectionList( + const parser::RecursiveRelPatternInfo& info, const NodeOrRelExpression& expr); + expression_vector bindRecursivePatternRelProjectionList( + const parser::RecursiveRelPatternInfo& info, const NodeOrRelExpression& expr); + std::pair bindVariableLengthRelBound(const parser::RelPattern& relPattern); + + std::shared_ptr bindQueryNode(const parser::NodePattern& nodePattern, + QueryGraph& queryGraph); + std::shared_ptr createQueryNode(const parser::NodePattern& nodePattern); + LBUG_API std::shared_ptr createQueryNode(const std::string& parsedName, + const std::vector& entries); + + 
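+    /*** usage sketch: assuming a parsed statement `stmt` and a main::ClientContext* `ctx` ***/
+    // Binder binder{ctx};
+    // auto bound = binder.bind(stmt); // returns a bound statement ready for planning
+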
/*** bind table entries ***/ + std::vector bindNodeTableEntries( + const std::vector& tableNames) const; + std::vector bindRelGroupEntries( + const std::vector& tableNames) const; + catalog::TableCatalogEntry* bindNodeTableEntry(const std::string& name) const; + std::vector bindRelPropertyDefinitions(const parser::CreateTableInfo& info); + + /*** validations ***/ + LBUG_API static void validateTableExistence(const main::ClientContext& context, + const std::string& tableName); + LBUG_API static void validateNodeTableType(const catalog::TableCatalogEntry* entry); + LBUG_API static void validateColumnExistence(const catalog::TableCatalogEntry* entry, + const std::string& columnName); + + /*** helpers ***/ + std::string getUniqueExpressionName(const std::string& name); + + static bool reservedInColumnName(const std::string& name); + static bool reservedInPropertyLookup(const std::string& name); + + void addToScope(const std::vector& names, const expression_vector& exprs); + LBUG_API void addToScope(const std::string& name, std::shared_ptr expr); + BinderScope saveScope() const; + void restoreScope(BinderScope prevScope); + void replaceExpressionInScope(const std::string& oldName, const std::string& newName, + std::shared_ptr expression); + + function::TableFunction getScanFunction(const common::FileTypeInfo& typeInfo, + const common::FileScanInfo& fileScanInfo) const; + + ExpressionBinder* getExpressionBinder() { return &expressionBinder; } + +private: + common::idx_t lastExpressionId; + BinderScope scope; + ExpressionBinder expressionBinder; + main::ClientContext* clientContext; + std::vector binderExtensions; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/binder_scope.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/binder_scope.h new file mode 100644 index 0000000000..1ad85b5e57 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/binder_scope.h @@ -0,0 +1,71 @@ +#pragma once + +#include "binder/expression/expression.h" +#include "binder/expression/node_expression.h" +#include "common/case_insensitive_map.h" + +namespace lbug { +namespace binder { + +class BinderScope { +public: + BinderScope() = default; + EXPLICIT_COPY_DEFAULT_MOVE(BinderScope); + + bool empty() const { return expressions.empty(); } + bool contains(const std::string& varName) const { return nameToExprIdx.contains(varName); } + std::shared_ptr getExpression(const std::string& varName) const { + KU_ASSERT(nameToExprIdx.contains(varName)); + return expressions[nameToExprIdx.at(varName)]; + } + expression_vector getExpressions() const { return expressions; } + void addExpression(const std::string& varName, std::shared_ptr expression); + void replaceExpression(const std::string& oldName, const std::string& newName, + std::shared_ptr expression); + + void memorizeTableEntries(const std::string& name, + std::vector entries) { + memorizedNodeNameToEntries.insert({name, entries}); + } + bool hasMemorizedTableIDs(const std::string& name) const { + return memorizedNodeNameToEntries.contains(name); + } + std::vector getMemorizedTableEntries(const std::string& name) { + KU_ASSERT(memorizedNodeNameToEntries.contains(name)); + return memorizedNodeNameToEntries.at(name); + } + + void addNodeReplacement(std::shared_ptr node) { + nodeReplacement.insert({node->getVariableName(), node}); + } + bool hasNodeReplacement(const std::string& name) const { + return nodeReplacement.contains(name); + } + std::shared_ptr getNodeReplacement(const std::string& 
name) const { + KU_ASSERT(hasNodeReplacement(name)); + return nodeReplacement.at(name); + } + + void clear(); + +private: + BinderScope(const BinderScope& other) + : expressions{other.expressions}, nameToExprIdx{other.nameToExprIdx}, + memorizedNodeNameToEntries{other.memorizedNodeNameToEntries} {} + +private: + // Expressions in scope. Order should be preserved. + expression_vector expressions; + common::case_insensitive_map_t nameToExprIdx; + // A node might be popped out of scope. But we may need to retain its table ID information. + // E.g. MATCH (a:person) WITH collect(a) AS list_a UNWIND list_a AS new_a MATCH (new_a)-[]->() + // It will be more performant if we can retain the information that new_a has label person. + common::case_insensitive_map_t> + memorizedNodeNameToEntries; + // A node pattern may not always be bound as a node expression, e.g. in the above query, + // (new_a) is bound as a variable rather than node expression. + common::case_insensitive_map_t> nodeReplacement; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_attach_database.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_attach_database.h new file mode 100644 index 0000000000..f09a41c1d7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_attach_database.h @@ -0,0 +1,23 @@ +#pragma once + +#include "binder/bound_statement.h" +#include "bound_attach_info.h" + +namespace lbug { +namespace binder { + +class BoundAttachDatabase final : public BoundStatement { +public: + explicit BoundAttachDatabase(binder::AttachInfo attachInfo) + : BoundStatement{common::StatementType::ATTACH_DATABASE, + BoundStatementResult::createSingleStringColumnResult()}, + attachInfo{std::move(attachInfo)} {} + + AttachInfo getAttachInfo() const { return attachInfo; } + +private: + AttachInfo attachInfo; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_attach_info.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_attach_info.h new file mode 100644 index 0000000000..fb565aa7b2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_attach_info.h @@ -0,0 +1,23 @@ +#pragma once + +#include "common/case_insensitive_map.h" +#include "common/types/value/value.h" + +namespace lbug { +namespace binder { + +struct LBUG_API AttachOption { + common::case_insensitive_map_t options; +}; + +struct LBUG_API AttachInfo { + AttachInfo(std::string dbPath, std::string dbAlias, std::string dbType, AttachOption options) + : dbPath{std::move(dbPath)}, dbAlias{std::move(dbAlias)}, dbType{std::move(dbType)}, + options{std::move(options)} {} + + std::string dbPath, dbAlias, dbType; + AttachOption options; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_create_macro.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_create_macro.h new file mode 100644 index 0000000000..8b38f325e6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_create_macro.h @@ -0,0 +1,29 @@ +#pragma once + +#include "binder/bound_statement.h" +#include "function/scalar_macro_function.h" + +namespace lbug { +namespace binder { + +class BoundCreateMacro final : public BoundStatement { + static constexpr common::StatementType type_ = common::StatementType::CREATE_MACRO; + +public: + explicit BoundCreateMacro(std::string macroName, + std::unique_ptr macro) + : BoundStatement{type_, + 
BoundStatementResult::createSingleStringColumnResult("result" /* columnName */)}, + macroName{std::move(macroName)}, macro{std::move(macro)} {} + + std::string getMacroName() const { return macroName; } + + std::unique_ptr getMacro() const { return macro->copy(); } + +private: + std::string macroName; + std::unique_ptr macro; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_database_statement.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_database_statement.h new file mode 100644 index 0000000000..8801867ab4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_database_statement.h @@ -0,0 +1,21 @@ +#pragma once + +#include "binder/bound_statement.h" + +namespace lbug { +namespace binder { + +class BoundDatabaseStatement : public BoundStatement { +public: + explicit BoundDatabaseStatement(common::StatementType statementType, std::string dbName) + : BoundStatement{statementType, BoundStatementResult::createSingleStringColumnResult()}, + dbName{std::move(dbName)} {} + + std::string getDBName() const { return dbName; } + +private: + std::string dbName; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_detach_database.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_detach_database.h new file mode 100644 index 0000000000..bbe3b90af1 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_detach_database.h @@ -0,0 +1,15 @@ +#pragma once + +#include "binder/bound_database_statement.h" + +namespace lbug { +namespace binder { + +class BoundDetachDatabase final : public BoundDatabaseStatement { +public: + explicit BoundDetachDatabase(std::string dbName) + : BoundDatabaseStatement{common::StatementType::DETACH_DATABASE, std::move(dbName)} {} +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_explain.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_explain.h new file mode 100644 index 0000000000..eadb4e97e3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_explain.h @@ -0,0 +1,29 @@ +#pragma once + +#include "binder/bound_statement.h" +#include "common/enums/explain_type.h" + +namespace lbug { +namespace binder { + +class BoundExplain final : public BoundStatement { + static constexpr common::StatementType type_ = common::StatementType::EXPLAIN; + +public: + explicit BoundExplain(std::unique_ptr statementToExplain, + common::ExplainType explainType) + : BoundStatement{type_, BoundStatementResult::createSingleStringColumnResult( + "explain result" /* columnName */)}, + statementToExplain{std::move(statementToExplain)}, explainType{explainType} {} + + BoundStatement* getStatementToExplain() const { return statementToExplain.get(); } + + common::ExplainType getExplainType() const { return explainType; } + +private: + std::unique_ptr statementToExplain; + common::ExplainType explainType; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_export_database.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_export_database.h new file mode 100644 index 0000000000..0711b2fbc4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_export_database.h @@ -0,0 +1,51 @@ +#pragma once +#include "binder/binder.h" +#include "binder/bound_statement.h" +#include "binder/query/bound_regular_query.h" +#include 
"common/copier_config/file_scan_info.h" + +namespace lbug { +namespace binder { + +struct ExportedTableData { + std::string tableName; + std::string fileName; + std::unique_ptr regularQuery; + std::vector columnNames; + std::vector columnTypes; + + const std::vector& getColumnTypesRef() const { return columnTypes; } + const BoundRegularQuery* getRegularQuery() const { return regularQuery.get(); } +}; + +class BoundExportDatabase final : public BoundStatement { + static constexpr common::StatementType type_ = common::StatementType::EXPORT_DATABASE; + +public: + BoundExportDatabase(std::string filePath, common::FileTypeInfo fileTypeInfo, + std::vector exportData, + common::case_insensitive_map_t csvOption, bool schemaOnly) + : BoundStatement{type_, BoundStatementResult::createSingleStringColumnResult()}, + exportData(std::move(exportData)), + boundFileInfo(std::move(fileTypeInfo), std::vector{std::move(filePath)}), + schemaOnly{schemaOnly} { + boundFileInfo.options = std::move(csvOption); + } + + std::string getFilePath() const { return boundFileInfo.filePaths[0]; } + common::FileType getFileType() const { return boundFileInfo.fileTypeInfo.fileType; } + common::case_insensitive_map_t getExportOptions() const { + return boundFileInfo.options; + } + const common::FileScanInfo* getBoundFileInfo() const { return &boundFileInfo; } + const std::vector* getExportData() const { return &exportData; } + bool exportSchemaOnly() const { return schemaOnly; } + +private: + std::vector exportData; + common::FileScanInfo boundFileInfo; + bool schemaOnly; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_extension_statement.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_extension_statement.h new file mode 100644 index 0000000000..ad2e58acfd --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_extension_statement.h @@ -0,0 +1,26 @@ +#pragma once + +#include "bound_statement.h" +#include "extension/extension_action.h" + +namespace lbug { +namespace binder { + +using namespace lbug::extension; + +class BoundExtensionStatement final : public BoundStatement { + static constexpr common::StatementType type_ = common::StatementType::EXTENSION; + +public: + explicit BoundExtensionStatement(std::unique_ptr info) + : BoundStatement{type_, BoundStatementResult::createSingleStringColumnResult()}, + info{std::move(info)} {} + + std::unique_ptr getAuxInfo() const { return info->copy(); } + +private: + std::unique_ptr info; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_import_database.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_import_database.h new file mode 100644 index 0000000000..2b23edccea --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_import_database.h @@ -0,0 +1,29 @@ +#pragma once +#include "binder/bound_statement.h" + +namespace lbug { +namespace binder { + +class BoundImportDatabase final : public BoundStatement { +public: + BoundImportDatabase(std::string filePath, std::string query, std::string indexQuery) + : BoundStatement{common::StatementType::IMPORT_DATABASE, + BoundStatementResult::createSingleStringColumnResult()}, + filePath{std::move(filePath)}, query{std::move(query)}, + indexQuery{std::move(indexQuery)} {} + + std::string getFilePath() const { return filePath; } + std::string getQuery() const { return query; } + std::string getIndexQuery() const { return indexQuery; } + 
+private: + std::string filePath; + // We concatenate queries based on the schema.cypher, copy.cypher, and macro.cypher files + // generated by exporting the database, resulting in queries such as "create node table xxx; + // create rel table xxx; copy from xxxx;". + std::string query; + std::string indexQuery; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_scan_source.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_scan_source.h new file mode 100644 index 0000000000..77e00720bd --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_scan_source.h @@ -0,0 +1,85 @@ +#pragma once + +#include "binder/bound_statement.h" +#include "binder/copy/bound_query_scan_info.h" +#include "bound_table_scan_info.h" +#include "common/enums/scan_source_type.h" +#include "function/table/scan_file_function.h" + +namespace lbug { +namespace binder { + +struct BoundBaseScanSource { + common::ScanSourceType type; + + explicit BoundBaseScanSource(common::ScanSourceType type) : type{type} {} + virtual ~BoundBaseScanSource() = default; + + virtual expression_vector getColumns() = 0; + virtual expression_vector getWarningColumns() const { return expression_vector{}; }; + virtual bool getIgnoreErrorsOption() const { + return common::CopyConstants::DEFAULT_IGNORE_ERRORS; + }; + virtual common::column_id_t getNumWarningDataColumns() const { return 0; } + + virtual std::unique_ptr copy() const = 0; + + template + const TARGET& constCast() const { + return common::ku_dynamic_cast(*this); + } + +protected: + BoundBaseScanSource(const BoundBaseScanSource& other) : type{other.type} {} +}; + +struct BoundTableScanSource final : BoundBaseScanSource { + BoundTableScanInfo info; + + explicit BoundTableScanSource(common::ScanSourceType type, BoundTableScanInfo info) + : BoundBaseScanSource{type}, info{std::move(info)} {} + BoundTableScanSource(const BoundTableScanSource& other) + : BoundBaseScanSource{other}, info{other.info.copy()} {} + + expression_vector getColumns() override { return info.bindData->columns; } + expression_vector getWarningColumns() const override; + bool getIgnoreErrorsOption() const override; + common::column_id_t getNumWarningDataColumns() const override { + switch (type) { + case common::ScanSourceType::FILE: + return info.bindData->constPtrCast()->numWarningDataColumns; + default: + return 0; + } + } + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } +}; + +struct BoundQueryScanSource final : BoundBaseScanSource { + // Use shared ptr to avoid copy BoundStatement. + // We should consider implement a copy constructor though. 
+ std::shared_ptr statement; + BoundQueryScanSourceInfo info; + + explicit BoundQueryScanSource(std::shared_ptr statement, + BoundQueryScanSourceInfo info) + : BoundBaseScanSource{common::ScanSourceType::QUERY}, statement{std::move(statement)}, + info(std::move(info)) {} + BoundQueryScanSource(const BoundQueryScanSource& other) = default; + + bool getIgnoreErrorsOption() const override; + + expression_vector getColumns() override { + return statement->getStatementResult()->getColumns(); + } + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_standalone_call.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_standalone_call.h new file mode 100644 index 0000000000..17eccb14fd --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_standalone_call.h @@ -0,0 +1,30 @@ +#pragma once + +#include "binder/bound_statement.h" +#include "binder/expression/expression.h" + +namespace lbug { +namespace main { +struct Option; +} +namespace binder { + +class BoundStandaloneCall final : public BoundStatement { + static constexpr common::StatementType type_ = common::StatementType::STANDALONE_CALL; + +public: + BoundStandaloneCall(const main::Option* option, std::shared_ptr optionValue) + : BoundStatement{type_, BoundStatementResult::createEmptyResult()}, option{option}, + optionValue{std::move(optionValue)} {} + + const main::Option* getOption() const { return option; } + + std::shared_ptr getOptionValue() const { return optionValue; } + +private: + const main::Option* option; + std::shared_ptr optionValue; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_standalone_call_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_standalone_call_function.h new file mode 100644 index 0000000000..e3344cfc76 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_standalone_call_function.h @@ -0,0 +1,27 @@ +#pragma once + +#include "binder/bound_statement.h" +#include "bound_table_scan_info.h" + +namespace lbug { +namespace binder { + +class BoundStandaloneCallFunction final : public BoundStatement { + static constexpr common::StatementType statementType = + common::StatementType::STANDALONE_CALL_FUNCTION; + +public: + explicit BoundStandaloneCallFunction(BoundTableScanInfo info) + : BoundStatement{statementType, BoundStatementResult::createEmptyResult()}, + info{std::move(info)} {} + + const function::TableFunction& getTableFunction() const { return info.func; } + + const function::TableFuncBindData* getBindData() const { return info.bindData.get(); } + +private: + BoundTableScanInfo info; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_statement.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_statement.h new file mode 100644 index 0000000000..d242e60508 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_statement.h @@ -0,0 +1,42 @@ +#pragma once + +#include "bound_statement_result.h" +#include "common/copy_constructors.h" +#include "common/enums/statement_type.h" + +namespace lbug { +namespace binder { + +class BoundStatement { +public: + BoundStatement(common::StatementType statementType, BoundStatementResult statementResult) + : statementType{statementType}, statementResult{std::move(statementResult)} {} + 
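+    // Illustrative (assumed, mirroring the subclasses in this directory): most statements
+    // pair their type with a canned result shape, e.g.
+    //   BoundStatement{common::StatementType::DROP,
+    //       BoundStatementResult::createSingleStringColumnResult()};
+    // while pure side-effect statements use BoundStatementResult::createEmptyResult().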
DELETE_COPY_DEFAULT_MOVE(BoundStatement); + + virtual ~BoundStatement() = default; + + common::StatementType getStatementType() const { return statementType; } + + const BoundStatementResult* getStatementResult() const { return &statementResult; } + std::shared_ptr getSingleColumnExpr() const { + return statementResult.getSingleColumnExpr(); + } + + BoundStatementResult* getStatementResultUnsafe() { return &statementResult; } + + template + const TARGET& constCast() const { + return common::ku_dynamic_cast(*this); + } + template + TARGET& cast() { + return common::ku_dynamic_cast(*this); + } + +private: + common::StatementType statementType; + BoundStatementResult statementResult; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_statement_result.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_statement_result.h new file mode 100644 index 0000000000..4af1dcafd1 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_statement_result.h @@ -0,0 +1,53 @@ +#pragma once + +#include "binder/expression/expression.h" + +namespace lbug { +namespace binder { + +class BoundStatementResult { +public: + BoundStatementResult() = default; + explicit BoundStatementResult(expression_vector columns, std::vector columnNames) + : columns{std::move(columns)}, columnNames{std::move(columnNames)} {} + EXPLICIT_COPY_DEFAULT_MOVE(BoundStatementResult); + + static BoundStatementResult createEmptyResult() { return BoundStatementResult(); } + + static BoundStatementResult createSingleStringColumnResult( + const std::string& columnName = "result"); + + void addColumn(const std::string& columnName, std::shared_ptr column) { + columns.push_back(std::move(column)); + columnNames.push_back(columnName); + } + expression_vector getColumns() const { return columns; } + std::vector getColumnNames() const { return columnNames; } + std::vector getColumnTypes() const { + std::vector columnTypes; + for (auto& column : columns) { + columnTypes.push_back(column->getDataType().copy()); + } + return columnTypes; + } + + std::shared_ptr getSingleColumnExpr() const { + KU_ASSERT(columns.size() == 1); + return columns[0]; + } + +private: + BoundStatementResult(const BoundStatementResult& other) + : columns{other.columns}, columnNames{other.columnNames} {} + +private: + expression_vector columns; + // ColumnNames might be different from column.toString() because the same column might have + // different aliases, e.g. RETURN id AS a, id AS b + // For both columns we currently refer to the same id expr object so we cannot resolve column + // name properly from expression object. + std::vector columnNames; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_statement_rewriter.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_statement_rewriter.h new file mode 100644 index 0000000000..475c66938b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_statement_rewriter.h @@ -0,0 +1,15 @@ +#pragma once + +#include "bound_statement.h" + +namespace lbug { +namespace binder { + +// Perform semantic rewrite over bound statement. 
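+// (Illustrative, assumed: rewrite passes mutate the bound tree in place, e.g. through the
+// unsafe hooks of BoundStatementVisitor, normalizing the query before planning.)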
+class BoundStatementRewriter { +public: + static void rewrite(BoundStatement& boundStatement, main::ClientContext& clientContext); +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_statement_visitor.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_statement_visitor.h new file mode 100644 index 0000000000..4fffccec37 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_statement_visitor.h @@ -0,0 +1,65 @@ +#pragma once + +#include "binder/query/normalized_single_query.h" +#include "bound_statement.h" + +namespace lbug { +namespace binder { + +class LBUG_API BoundStatementVisitor { +public: + BoundStatementVisitor() = default; + virtual ~BoundStatementVisitor() = default; + + void visit(const BoundStatement& statement); + // Unsafe visitors are implemented on-demand. We may reuse safe visitor inside unsafe visitor + // if no other class need to overwrite an unsafe visitor. + void visitUnsafe(BoundStatement& statement); + + virtual void visitSingleQuery(const NormalizedSingleQuery& singleQuery); + +protected: + virtual void visitCreateSequence(const BoundStatement&) {} + virtual void visitCreateTable(const BoundStatement&) {} + virtual void visitDrop(const BoundStatement&) {} + virtual void visitCreateType(const BoundStatement&) {} + virtual void visitAlter(const BoundStatement&) {} + virtual void visitCopyFrom(const BoundStatement&); + virtual void visitCopyTo(const BoundStatement&); + virtual void visitExportDatabase(const BoundStatement&) {} + virtual void visitImportDatabase(const BoundStatement&) {} + virtual void visitStandaloneCall(const BoundStatement&) {} + virtual void visitExplain(const BoundStatement&); + virtual void visitCreateMacro(const BoundStatement&) {} + virtual void visitTransaction(const BoundStatement&) {} + virtual void visitExtension(const BoundStatement&) {} + + virtual void visitRegularQuery(const BoundStatement& statement); + virtual void visitRegularQueryUnsafe(BoundStatement& statement); + virtual void visitSingleQueryUnsafe(NormalizedSingleQuery& singleQuery); + virtual void visitQueryPart(const NormalizedQueryPart& queryPart); + virtual void visitQueryPartUnsafe(NormalizedQueryPart& queryPart); + void visitReadingClause(const BoundReadingClause& readingClause); + void visitReadingClauseUnsafe(BoundReadingClause& readingClause); + virtual void visitMatch(const BoundReadingClause&) {} + virtual void visitMatchUnsafe(BoundReadingClause&) {} + virtual void visitUnwind(const BoundReadingClause& /*readingClause*/) {} + virtual void visitTableFunctionCall(const BoundReadingClause&) {} + virtual void visitLoadFrom(const BoundReadingClause& /*statement*/) {} + void visitUpdatingClause(const BoundUpdatingClause& updatingClause); + virtual void visitSet(const BoundUpdatingClause& /*updatingClause*/) {} + virtual void visitDelete(const BoundUpdatingClause& /* updatingClause*/) {} + virtual void visitInsert(const BoundUpdatingClause& /* updatingClause*/) {} + virtual void visitMerge(const BoundUpdatingClause& /* updatingClause*/) {} + + virtual void visitProjectionBody(const BoundProjectionBody& /* projectionBody*/) {} + virtual void visitProjectionBodyPredicate(const std::shared_ptr& /* predicate*/) {} + virtual void visitAttachDatabase(const BoundStatement&) {} + virtual void visitDetachDatabase(const BoundStatement&) {} + virtual void visitUseDatabase(const BoundStatement&) {} + virtual void visitStandaloneCallFunction(const BoundStatement&) {} + virtual void 
visitExtensionClause(const BoundStatement&) {} +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_table_scan_info.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_table_scan_info.h new file mode 100644 index 0000000000..b865f94dea --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_table_scan_info.h @@ -0,0 +1,24 @@ +#pragma once + +#include "function/table/bind_data.h" +#include "function/table/table_function.h" + +namespace lbug { +namespace binder { + +struct BoundTableScanInfo { + function::TableFunction func; + std::unique_ptr bindData; + + BoundTableScanInfo(function::TableFunction func, + std::unique_ptr bindData) + : func{std::move(func)}, bindData{std::move(bindData)} {} + EXPLICIT_COPY_DEFAULT_MOVE(BoundTableScanInfo); + +private: + BoundTableScanInfo(const BoundTableScanInfo& other) + : func{other.func}, bindData{other.bindData->copy()} {} +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_transaction_statement.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_transaction_statement.h new file mode 100644 index 0000000000..376dc53c95 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_transaction_statement.h @@ -0,0 +1,24 @@ +#pragma once + +#include "bound_statement.h" +#include "transaction/transaction_action.h" + +namespace lbug { +namespace binder { + +class BoundTransactionStatement final : public BoundStatement { + static constexpr common::StatementType statementType_ = common::StatementType::TRANSACTION; + +public: + explicit BoundTransactionStatement(transaction::TransactionAction transactionAction) + : BoundStatement{statementType_, BoundStatementResult::createEmptyResult()}, + transactionAction{transactionAction} {} + + transaction::TransactionAction getTransactionAction() const { return transactionAction; } + +private: + transaction::TransactionAction transactionAction; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_use_database.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_use_database.h new file mode 100644 index 0000000000..ac260172b0 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/bound_use_database.h @@ -0,0 +1,17 @@ +#pragma once + +#include "binder/bound_database_statement.h" + +namespace lbug { +namespace binder { + +class BoundUseDatabase final : public BoundDatabaseStatement { + static constexpr common::StatementType type_ = common::StatementType::USE_DATABASE; + +public: + explicit BoundUseDatabase(std::string dbName) + : BoundDatabaseStatement{type_, std::move(dbName)} {} +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/copy/bound_copy_from.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/copy/bound_copy_from.h new file mode 100644 index 0000000000..2e4af22d0e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/copy/bound_copy_from.h @@ -0,0 +1,101 @@ +#pragma once + +#include "binder/bound_scan_source.h" +#include "binder/expression/expression.h" +#include "common/enums/column_evaluate_type.h" +#include "common/enums/table_type.h" +#include "index_look_up_info.h" + +namespace lbug { +namespace binder { + +struct ExtraBoundCopyFromInfo { + virtual ~ExtraBoundCopyFromInfo() = default; + virtual std::unique_ptr copy() const = 0; + + template + const 
TARGET& constCast() const { + return common::ku_dynamic_cast(*this); + } +}; + +struct LBUG_API BoundCopyFromInfo { + // Name of table to copy into. + std::string tableName; + // Type of table. + common::TableType tableType; + // Data source. + std::unique_ptr source; + // Row offset. + std::shared_ptr offset; + expression_vector columnExprs; + std::vector columnEvaluateTypes; + std::unique_ptr extraInfo; + + BoundCopyFromInfo(std::string tableName, common::TableType tableType, + std::unique_ptr source, std::shared_ptr offset, + expression_vector columnExprs, std::vector columnEvaluateTypes, + std::unique_ptr extraInfo) + : tableName{std::move(tableName)}, tableType{tableType}, source{std::move(source)}, + offset{std::move(offset)}, columnExprs{std::move(columnExprs)}, + columnEvaluateTypes{std::move(columnEvaluateTypes)}, extraInfo{std::move(extraInfo)} {} + + EXPLICIT_COPY_DEFAULT_MOVE(BoundCopyFromInfo); + + expression_vector getSourceColumns() const { + return source ? source->getColumns() : expression_vector{}; + } + expression_vector getWarningColumns() const { + return offset ? source->getWarningColumns() : expression_vector{}; + } + + bool getIgnoreErrorsOption() const { return source ? source->getIgnoreErrorsOption() : false; } + +private: + BoundCopyFromInfo(const BoundCopyFromInfo& other) + : tableName{other.tableName}, tableType{other.tableType}, offset{other.offset}, + columnExprs{other.columnExprs}, columnEvaluateTypes{other.columnEvaluateTypes} { + source = other.source ? other.source->copy() : nullptr; + if (other.extraInfo) { + extraInfo = other.extraInfo->copy(); + } + } +}; + +struct ExtraBoundCopyRelInfo final : ExtraBoundCopyFromInfo { + std::string fromTableName; + std::string toTableName; + // We process internal ID column as offset (INT64) column until partitioner. In partitioner, + // we need to manually change offset(INT64) type to internal ID type. 
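+    // (Illustrative, assumed: during COPY of a rel table the FROM/TO key columns are first
+    // resolved to INT64 node offsets via IndexLookupInfo; the partitioner later rewrites
+    // those offsets into INTERNAL_ID values that also carry the node table id.)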
+ std::vector internalIDColumnIndices; + std::vector infos; + + ExtraBoundCopyRelInfo(std::string fromTableName, std::string toTableName, + std::vector internalIDColumnIndices, std::vector infos) + : fromTableName{std::move(fromTableName)}, toTableName{std::move(toTableName)}, + internalIDColumnIndices{std::move(internalIDColumnIndices)}, infos{std::move(infos)} {} + ExtraBoundCopyRelInfo(const ExtraBoundCopyRelInfo& other) + : fromTableName{other.fromTableName}, toTableName{other.toTableName}, + internalIDColumnIndices{other.internalIDColumnIndices}, infos{other.infos} {} + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } +}; + +class BoundCopyFrom final : public BoundStatement { + static constexpr common::StatementType statementType_ = common::StatementType::COPY_FROM; + +public: + explicit BoundCopyFrom(BoundCopyFromInfo info) + : BoundStatement{statementType_, BoundStatementResult::createSingleStringColumnResult()}, + info{std::move(info)} {} + + const BoundCopyFromInfo* getInfo() const { return &info; } + +private: + BoundCopyFromInfo info; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/copy/bound_copy_to.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/copy/bound_copy_to.h new file mode 100644 index 0000000000..2e1a0058dd --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/copy/bound_copy_to.h @@ -0,0 +1,32 @@ +#pragma once + +#include "binder/bound_statement.h" +#include "function/export/export_function.h" + +namespace lbug { +namespace binder { + +class BoundCopyTo final : public BoundStatement { + static constexpr common::StatementType type_ = common::StatementType::COPY_TO; + +public: + BoundCopyTo(std::unique_ptr bindData, + function::ExportFunction exportFunc, std::unique_ptr query) + : BoundStatement{type_, BoundStatementResult::createEmptyResult()}, + bindData{std::move(bindData)}, exportFunc{std::move(exportFunc)}, + query{std::move(query)} {} + + std::unique_ptr getBindData() const { return bindData->copy(); } + + function::ExportFunction getExportFunc() const { return exportFunc; } + + const BoundStatement* getRegularQuery() const { return query.get(); } + +private: + std::unique_ptr bindData; + function::ExportFunction exportFunc; + std::unique_ptr query; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/copy/bound_query_scan_info.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/copy/bound_query_scan_info.h new file mode 100644 index 0000000000..23dc9c1cf6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/copy/bound_query_scan_info.h @@ -0,0 +1,16 @@ +#pragma once + +#include "common/case_insensitive_map.h" +#include "common/types/value/value.h" +namespace lbug { +namespace binder { + +struct BoundQueryScanSourceInfo { + common::case_insensitive_map_t options; + + explicit BoundQueryScanSourceInfo(common::case_insensitive_map_t options) + : options{std::move(options)} {} +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/copy/index_look_up_info.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/copy/index_look_up_info.h new file mode 100644 index 0000000000..bcb0e3701c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/copy/index_look_up_info.h @@ -0,0 +1,22 @@ +#pragma once + +#include "binder/expression/expression.h" + +namespace lbug { +namespace binder { + +struct 
IndexLookupInfo { + common::table_id_t nodeTableID; + std::shared_ptr offset; // output + std::shared_ptr key; // input + expression_vector warningExprs; + + IndexLookupInfo(common::table_id_t nodeTableID, std::shared_ptr offset, + std::shared_ptr key, expression_vector warningExprs = {}) + : nodeTableID{nodeTableID}, offset{std::move(offset)}, key{std::move(key)}, + warningExprs(std::move(warningExprs)) {} + IndexLookupInfo(const IndexLookupInfo& other) = default; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/ddl/bound_alter.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/ddl/bound_alter.h new file mode 100644 index 0000000000..ecc1753702 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/ddl/bound_alter.h @@ -0,0 +1,24 @@ +#pragma once + +#include "binder/bound_statement.h" +#include "bound_alter_info.h" + +namespace lbug { +namespace binder { + +class BoundAlter final : public BoundStatement { + static constexpr common::StatementType type_ = common::StatementType::ALTER; + +public: + explicit BoundAlter(BoundAlterInfo info) + : BoundStatement{type_, BoundStatementResult::createSingleStringColumnResult()}, + info{std::move(info)} {} + + const BoundAlterInfo& getInfo() const { return info; } + +private: + BoundAlterInfo info; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/ddl/bound_alter_info.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/ddl/bound_alter_info.h new file mode 100644 index 0000000000..88a41a54a9 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/ddl/bound_alter_info.h @@ -0,0 +1,125 @@ +#pragma once + +#include "binder/ddl/property_definition.h" +#include "binder/expression/expression.h" +#include "common/enums/alter_type.h" +#include "common/enums/conflict_action.h" + +namespace lbug { +namespace binder { + +struct BoundExtraAlterInfo { + virtual ~BoundExtraAlterInfo() = default; + + template + const TARGET* constPtrCast() const { + return common::ku_dynamic_cast(this); + } + template + const TARGET& constCast() const { + return common::ku_dynamic_cast(*this); + } + template + TARGET& cast() { + return common::ku_dynamic_cast(*this); + } + + virtual std::unique_ptr copy() const = 0; +}; + +struct BoundAlterInfo { + common::AlterType alterType; + std::string tableName; + std::unique_ptr extraInfo; + common::ConflictAction onConflict; + + BoundAlterInfo(common::AlterType alterType, std::string tableName, + std::unique_ptr extraInfo, + common::ConflictAction onConflict = common::ConflictAction::ON_CONFLICT_THROW) + : alterType{alterType}, tableName{std::move(tableName)}, extraInfo{std::move(extraInfo)}, + onConflict{onConflict} {} + EXPLICIT_COPY_DEFAULT_MOVE(BoundAlterInfo); + + std::string toString() const; + +private: + BoundAlterInfo(const BoundAlterInfo& other) + : alterType{other.alterType}, tableName{other.tableName}, + extraInfo{other.extraInfo->copy()}, onConflict{other.onConflict} {} +}; + +struct BoundExtraRenameTableInfo final : BoundExtraAlterInfo { + std::string newName; + + explicit BoundExtraRenameTableInfo(std::string newName) : newName{std::move(newName)} {} + BoundExtraRenameTableInfo(const BoundExtraRenameTableInfo& other) : newName{other.newName} {} + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } +}; + +struct BoundExtraAddPropertyInfo final : BoundExtraAlterInfo { + PropertyDefinition propertyDefinition; + std::shared_ptr 
boundDefault; + + BoundExtraAddPropertyInfo(const PropertyDefinition& definition, + std::shared_ptr boundDefault) + : propertyDefinition{definition.copy()}, boundDefault{std::move(boundDefault)} {} + BoundExtraAddPropertyInfo(const BoundExtraAddPropertyInfo& other) + : propertyDefinition{other.propertyDefinition.copy()}, boundDefault{other.boundDefault} {} + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } +}; + +struct BoundExtraDropPropertyInfo final : BoundExtraAlterInfo { + std::string propertyName; + + explicit BoundExtraDropPropertyInfo(std::string propertyName) + : propertyName{std::move(propertyName)} {} + BoundExtraDropPropertyInfo(const BoundExtraDropPropertyInfo& other) + : propertyName{other.propertyName} {} + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } +}; + +struct BoundExtraRenamePropertyInfo final : BoundExtraAlterInfo { + std::string newName; + std::string oldName; + + BoundExtraRenamePropertyInfo(std::string newName, std::string oldName) + : newName{std::move(newName)}, oldName{std::move(oldName)} {} + BoundExtraRenamePropertyInfo(const BoundExtraRenamePropertyInfo& other) + : newName{other.newName}, oldName{other.oldName} {} + std::unique_ptr copy() const override { + return std::make_unique(*this); + } +}; + +struct BoundExtraCommentInfo final : BoundExtraAlterInfo { + std::string comment; + + explicit BoundExtraCommentInfo(std::string comment) : comment{std::move(comment)} {} + BoundExtraCommentInfo(const BoundExtraCommentInfo& other) : comment{other.comment} {} + std::unique_ptr copy() const override { + return std::make_unique(*this); + } +}; + +struct BoundExtraAlterFromToConnection final : BoundExtraAlterInfo { + common::table_id_t fromTableID; + common::table_id_t toTableID; + + BoundExtraAlterFromToConnection(common::table_id_t fromTableID, common::table_id_t toTableID) + : fromTableID{fromTableID}, toTableID{toTableID} {} + std::unique_ptr copy() const override { + return std::make_unique(*this); + } +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/ddl/bound_create_sequence.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/ddl/bound_create_sequence.h new file mode 100644 index 0000000000..3cad5467a2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/ddl/bound_create_sequence.h @@ -0,0 +1,23 @@ +#pragma once + +#include "binder/bound_statement.h" +#include "bound_create_sequence_info.h" +namespace lbug { +namespace binder { + +class BoundCreateSequence final : public BoundStatement { + static constexpr common::StatementType type_ = common::StatementType::CREATE_SEQUENCE; + +public: + explicit BoundCreateSequence(BoundCreateSequenceInfo info) + : BoundStatement{type_, BoundStatementResult::createSingleStringColumnResult()}, + info{std::move(info)} {} + + const BoundCreateSequenceInfo& getInfo() const { return info; } + +private: + BoundCreateSequenceInfo info; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/ddl/bound_create_sequence_info.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/ddl/bound_create_sequence_info.h new file mode 100644 index 0000000000..a1e9ca2664 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/ddl/bound_create_sequence_info.h @@ -0,0 +1,38 @@ +#pragma once + +#include + +#include "common/copy_constructors.h" +#include "common/enums/conflict_action.h" + +namespace lbug { +namespace binder { + 
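+// Illustrative mapping (assumed, for orientation only): a statement such as
+//   CREATE SEQUENCE seq START 10 INCREMENT 2 MINVALUE 0 MAXVALUE 100 CYCLE
+// would bind to BoundCreateSequenceInfo{"seq", 10, 2, 0, 100, true /*cycle*/,
+//   common::ConflictAction::ON_CONFLICT_THROW, false /*isInternal*/}.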
+struct BoundCreateSequenceInfo { + std::string sequenceName; + int64_t startWith; + int64_t increment; + int64_t minValue; + int64_t maxValue; + bool cycle; + common::ConflictAction onConflict; + bool hasParent = false; + bool isInternal; + + BoundCreateSequenceInfo(std::string sequenceName, int64_t startWith, int64_t increment, + int64_t minValue, int64_t maxValue, bool cycle, common::ConflictAction onConflict, + bool isInternal) + : sequenceName{std::move(sequenceName)}, startWith{startWith}, increment{increment}, + minValue{minValue}, maxValue{maxValue}, cycle{cycle}, onConflict{onConflict}, + isInternal{isInternal} {} + EXPLICIT_COPY_DEFAULT_MOVE(BoundCreateSequenceInfo); + +private: + BoundCreateSequenceInfo(const BoundCreateSequenceInfo& other) + : sequenceName{other.sequenceName}, startWith{other.startWith}, increment{other.increment}, + minValue{other.minValue}, maxValue{other.maxValue}, cycle{other.cycle}, + onConflict{other.onConflict}, hasParent{other.hasParent}, isInternal{other.isInternal} {} +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/ddl/bound_create_table.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/ddl/bound_create_table.h new file mode 100644 index 0000000000..e279b1d1d6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/ddl/bound_create_table.h @@ -0,0 +1,32 @@ +#pragma once + +#include "binder/bound_statement.h" +#include "binder/copy/bound_copy_from.h" +#include "bound_create_table_info.h" + +namespace lbug { +namespace binder { + +class BoundCreateTable final : public BoundStatement { + static constexpr common::StatementType type_ = common::StatementType::CREATE_TABLE; + +public: + explicit BoundCreateTable(BoundCreateTableInfo info, BoundStatementResult result) + : BoundStatement{type_, std::move(result)}, info{std::move(info)} {} + + const BoundCreateTableInfo& getInfo() const { return info; } + + void setCopyInfo(BoundCopyFromInfo copyInfo_) { copyInfo = std::move(copyInfo_); } + bool hasCopyInfo() const { return copyInfo.has_value(); } + const BoundCopyFromInfo& getCopyInfo() const { + KU_ASSERT(copyInfo.has_value()); + return copyInfo.value(); + } + +private: + BoundCreateTableInfo info; + std::optional copyInfo; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/ddl/bound_create_table_info.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/ddl/bound_create_table_info.h new file mode 100644 index 0000000000..03a86d911d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/ddl/bound_create_table_info.h @@ -0,0 +1,112 @@ +#pragma once + +#include "catalog/catalog_entry/catalog_entry_type.h" +#include "catalog/catalog_entry/node_table_id_pair.h" +#include "common/enums/conflict_action.h" +#include "common/enums/extend_direction.h" +#include "common/enums/rel_multiplicity.h" +#include "property_definition.h" + +namespace lbug { +namespace common { +enum class RelMultiplicity : uint8_t; +} +namespace binder { +struct BoundExtraCreateCatalogEntryInfo { + virtual ~BoundExtraCreateCatalogEntryInfo() = default; + + template + const TARGET* constPtrCast() const { + return common::ku_dynamic_cast(this); + } + + template + TARGET* ptrCast() { + return common::ku_dynamic_cast(this); + } + + virtual inline std::unique_ptr copy() const = 0; +}; + +struct BoundCreateTableInfo { + catalog::CatalogEntryType type = catalog::CatalogEntryType::DUMMY_ENTRY; + std::string tableName; + 
common::ConflictAction onConflict = common::ConflictAction::INVALID; + std::unique_ptr extraInfo; + bool isInternal = false; + bool hasParent = false; + + BoundCreateTableInfo() = default; + BoundCreateTableInfo(catalog::CatalogEntryType type, std::string tableName, + common::ConflictAction onConflict, + std::unique_ptr extraInfo, bool isInternal, + bool hasParent = false) + : type{type}, tableName{std::move(tableName)}, onConflict{onConflict}, + extraInfo{std::move(extraInfo)}, isInternal{isInternal}, hasParent{hasParent} {} + EXPLICIT_COPY_DEFAULT_MOVE(BoundCreateTableInfo); + + std::string toString() const; + +private: + BoundCreateTableInfo(const BoundCreateTableInfo& other) + : type{other.type}, tableName{other.tableName}, onConflict{other.onConflict}, + extraInfo{other.extraInfo->copy()}, isInternal{other.isInternal}, + hasParent{other.hasParent} {} +}; + +struct LBUG_API BoundExtraCreateTableInfo : BoundExtraCreateCatalogEntryInfo { + std::vector propertyDefinitions; + + explicit BoundExtraCreateTableInfo(std::vector propertyDefinitions) + : propertyDefinitions{std::move(propertyDefinitions)} {} + + BoundExtraCreateTableInfo(const BoundExtraCreateTableInfo& other) + : BoundExtraCreateTableInfo{copyVector(other.propertyDefinitions)} {} + BoundExtraCreateTableInfo& operator=(const BoundExtraCreateTableInfo&) = delete; + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } +}; + +struct BoundExtraCreateNodeTableInfo final : BoundExtraCreateTableInfo { + std::string primaryKeyName; + + BoundExtraCreateNodeTableInfo(std::string primaryKeyName, + std::vector definitions) + : BoundExtraCreateTableInfo{std::move(definitions)}, + primaryKeyName{std::move(primaryKeyName)} {} + BoundExtraCreateNodeTableInfo(const BoundExtraCreateNodeTableInfo& other) + : BoundExtraCreateTableInfo{copyVector(other.propertyDefinitions)}, + primaryKeyName{other.primaryKeyName} {} + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } +}; + +struct BoundExtraCreateRelTableGroupInfo final : BoundExtraCreateTableInfo { + common::RelMultiplicity srcMultiplicity; + common::RelMultiplicity dstMultiplicity; + common::ExtendDirection storageDirection; + std::vector nodePairs; + + explicit BoundExtraCreateRelTableGroupInfo(std::vector definitions, + common::RelMultiplicity srcMultiplicity, common::RelMultiplicity dstMultiplicity, + common::ExtendDirection storageDirection, std::vector nodePairs) + : BoundExtraCreateTableInfo{std::move(definitions)}, srcMultiplicity{srcMultiplicity}, + dstMultiplicity{dstMultiplicity}, storageDirection{storageDirection}, + nodePairs{std::move(nodePairs)} {} + + BoundExtraCreateRelTableGroupInfo(const BoundExtraCreateRelTableGroupInfo& other) + : BoundExtraCreateTableInfo{copyVector(other.propertyDefinitions)}, + srcMultiplicity{other.srcMultiplicity}, dstMultiplicity{other.dstMultiplicity}, + storageDirection{other.storageDirection}, nodePairs{other.nodePairs} {} + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/ddl/bound_create_type.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/ddl/bound_create_type.h new file mode 100644 index 0000000000..7bf1cc5561 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/ddl/bound_create_type.h @@ -0,0 +1,26 @@ +#pragma once + +#include "binder/bound_statement.h" + +namespace lbug { +namespace binder { + +class BoundCreateType final : 
public BoundStatement { + static constexpr common::StatementType type_ = common::StatementType::CREATE_TYPE; + +public: + explicit BoundCreateType(std::string name, common::LogicalType type) + : BoundStatement{type_, BoundStatementResult::createSingleStringColumnResult()}, + name{std::move(name)}, type{std::move(type)} {} + + std::string getName() const { return name; }; + + const common::LogicalType& getType() const { return type; } + +private: + std::string name; + common::LogicalType type; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/ddl/bound_drop.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/ddl/bound_drop.h new file mode 100644 index 0000000000..b6bde72f35 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/ddl/bound_drop.h @@ -0,0 +1,24 @@ +#pragma once + +#include "binder/bound_statement.h" +#include "parser/ddl/drop_info.h" + +namespace lbug { +namespace binder { + +class BoundDrop final : public BoundStatement { + static constexpr common::StatementType type_ = common::StatementType::DROP; + +public: + explicit BoundDrop(parser::DropInfo dropInfo) + : BoundStatement{type_, BoundStatementResult::createSingleStringColumnResult()}, + dropInfo{std::move(dropInfo)} {} + + const parser::DropInfo& getDropInfo() const { return dropInfo; }; + +private: + parser::DropInfo dropInfo; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/ddl/property_definition.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/ddl/property_definition.h new file mode 100644 index 0000000000..154fac041f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/ddl/property_definition.h @@ -0,0 +1,47 @@ +#pragma once + +#include "common/types/types.h" +#include "parser/expression/parsed_expression.h" + +namespace lbug { +namespace binder { + +struct LBUG_API ColumnDefinition { + std::string name; + common::LogicalType type; + + ColumnDefinition() = default; + ColumnDefinition(std::string name, common::LogicalType type) + : name{std::move(name)}, type{std::move(type)} {} + EXPLICIT_COPY_DEFAULT_MOVE(ColumnDefinition); + +private: + ColumnDefinition(const ColumnDefinition& other) : name{other.name}, type{other.type.copy()} {} +}; + +struct LBUG_API PropertyDefinition { + ColumnDefinition columnDefinition; + std::unique_ptr defaultExpr; + + PropertyDefinition() = default; + explicit PropertyDefinition(ColumnDefinition columnDefinition); + PropertyDefinition(ColumnDefinition columnDefinition, + std::unique_ptr defaultExpr) + : columnDefinition{std::move(columnDefinition)}, defaultExpr{std::move(defaultExpr)} {} + EXPLICIT_COPY_DEFAULT_MOVE(PropertyDefinition); + + std::string getName() const { return columnDefinition.name; } + const common::LogicalType& getType() const { return columnDefinition.type; } + std::string getDefaultExpressionName() const { return defaultExpr->getRawName(); } + void rename(const std::string& newName) { columnDefinition.name = newName; } + + void serialize(common::Serializer& serializer) const; + static PropertyDefinition deserialize(common::Deserializer& deserializer); + +private: + PropertyDefinition(const PropertyDefinition& other) + : columnDefinition{other.columnDefinition.copy()}, defaultExpr{other.defaultExpr->copy()} {} +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/aggregate_function_expression.h 
b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/aggregate_function_expression.h new file mode 100644 index 0000000000..a8331988b7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/aggregate_function_expression.h @@ -0,0 +1,36 @@ +#pragma once + +#include "expression.h" +#include "function/aggregate_function.h" + +namespace lbug { +namespace binder { + +class AggregateFunctionExpression final : public Expression { + static constexpr common::ExpressionType expressionType_ = + common::ExpressionType::AGGREGATE_FUNCTION; + +public: + AggregateFunctionExpression(function::AggregateFunction function, + std::unique_ptr bindData, expression_vector children, + std::string uniqueName) + : Expression{expressionType_, bindData->resultType.copy(), std::move(children), + std::move(uniqueName)}, + function{std::move(function)}, bindData{std::move(bindData)} {} + + const function::AggregateFunction& getFunction() const { return function; } + function::FunctionBindData* getBindData() const { return bindData.get(); } + bool isDistinct() const { return function.isDistinct; } + + std::string toStringInternal() const override; + + static std::string getUniqueName(const std::string& functionName, + const expression_vector& children, bool isDistinct); + +private: + function::AggregateFunction function; + std::unique_ptr bindData; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/case_expression.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/case_expression.h new file mode 100644 index 0000000000..aa00cba841 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/case_expression.h @@ -0,0 +1,44 @@ +#pragma once + +#include "expression.h" + +namespace lbug { +namespace binder { + +struct CaseAlternative { + std::shared_ptr whenExpression; + std::shared_ptr thenExpression; + + CaseAlternative(std::shared_ptr whenExpression, + std::shared_ptr thenExpression) + : whenExpression{std::move(whenExpression)}, thenExpression{std::move(thenExpression)} {} +}; + +class CaseExpression final : public Expression { + static constexpr common::ExpressionType expressionType_ = common::ExpressionType::CASE_ELSE; + +public: + CaseExpression(common::LogicalType dataType, std::shared_ptr elseExpression, + const std::string& name) + : Expression{expressionType_, std::move(dataType), name}, + elseExpression{std::move(elseExpression)} {} + + void addCaseAlternative(std::shared_ptr when, std::shared_ptr then) { + caseAlternatives.push_back(make_unique(std::move(when), std::move(then))); + } + common::idx_t getNumCaseAlternatives() const { return caseAlternatives.size(); } + CaseAlternative* getCaseAlternative(common::idx_t idx) const { + return caseAlternatives[idx].get(); + } + + std::shared_ptr getElseExpression() const { return elseExpression; } + + std::string toStringInternal() const override; + +private: + std::vector> caseAlternatives; + std::shared_ptr elseExpression; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/expression.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/expression.h new file mode 100644 index 0000000000..968cc1a627 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/expression.h @@ -0,0 +1,132 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "common/assert.h" +#include "common/cast.h" 
+#include "common/copy_constructors.h" +#include "common/enums/expression_type.h" +#include "common/types/types.h" + +namespace lbug { +namespace binder { + +class Expression; +using expression_vector = std::vector>; +using expression_pair = std::pair, std::shared_ptr>; + +struct ExpressionHasher; +struct ExpressionEquality; +using expression_set = + std::unordered_set, ExpressionHasher, ExpressionEquality>; +template +using expression_map = + std::unordered_map, T, ExpressionHasher, ExpressionEquality>; + +class LBUG_API Expression : public std::enable_shared_from_this { + friend class ExpressionChildrenCollector; + +public: + Expression(common::ExpressionType expressionType, common::LogicalType dataType, + expression_vector children, std::string uniqueName) + : expressionType{expressionType}, dataType{std::move(dataType)}, + uniqueName{std::move(uniqueName)}, children{std::move(children)} {} + // Create binary expression. + Expression(common::ExpressionType expressionType, common::LogicalType dataType, + const std::shared_ptr& left, const std::shared_ptr& right, + std::string uniqueName) + : Expression{expressionType, std::move(dataType), expression_vector{left, right}, + std::move(uniqueName)} {} + // Create unary expression. + Expression(common::ExpressionType expressionType, common::LogicalType dataType, + const std::shared_ptr& child, std::string uniqueName) + : Expression{expressionType, std::move(dataType), expression_vector{child}, + std::move(uniqueName)} {} + // Create leaf expression + Expression(common::ExpressionType expressionType, common::LogicalType dataType, + std::string uniqueName) + : Expression{expressionType, std::move(dataType), expression_vector{}, + std::move(uniqueName)} {} + DELETE_COPY_DEFAULT_MOVE(Expression); + virtual ~Expression(); + + void setUniqueName(const std::string& name) { uniqueName = name; } + std::string getUniqueName() const { + KU_ASSERT(!uniqueName.empty()); + return uniqueName; + } + + virtual void cast(const common::LogicalType& type); + const common::LogicalType& getDataType() const { return dataType; } + + void setAlias(const std::string& newAlias) { alias = newAlias; } + bool hasAlias() const { return !alias.empty(); } + std::string getAlias() const { return alias; } + + common::idx_t getNumChildren() const { return children.size(); } + std::shared_ptr getChild(common::idx_t idx) const { + KU_ASSERT(idx < children.size()); + return children[idx]; + } + expression_vector getChildren() const { return children; } + void setChild(common::idx_t idx, std::shared_ptr child) { + KU_ASSERT(idx < children.size()); + children[idx] = std::move(child); + } + + expression_vector splitOnAND(); + + bool operator==(const Expression& rhs) const { return uniqueName == rhs.uniqueName; } + + std::string toString() const { return hasAlias() ? alias : toStringInternal(); } + + template + TARGET& cast() { + return common::ku_dynamic_cast(*this); + } + template + TARGET* ptrCast() { + return common::ku_dynamic_cast(this); + } + template + const TARGET& constCast() const { + return common::ku_dynamic_cast(*this); + } + template + const TARGET* constPtrCast() const { + return common::ku_dynamic_cast(this); + } + +protected: + virtual std::string toStringInternal() const = 0; + +public: + common::ExpressionType expressionType; + common::LogicalType dataType; + +protected: + // Name that serves as the unique identifier. 
+ std::string uniqueName; + std::string alias; + expression_vector children; +}; + +struct ExpressionHasher { + std::size_t operator()(const std::shared_ptr& expression) const { + return std::hash{}(expression->getUniqueName()); + } +}; + +struct ExpressionEquality { + bool operator()(const std::shared_ptr& left, + const std::shared_ptr& right) const { + return left->getUniqueName() == right->getUniqueName(); + } +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/expression_util.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/expression_util.h new file mode 100644 index 0000000000..4766215eee --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/expression_util.h @@ -0,0 +1,84 @@ +#pragma once + +#include "common/types/types.h" +#include "common/types/value/value.h" +#include "expression.h" + +namespace lbug { +namespace binder { + +struct LBUG_API ExpressionUtil { + static expression_vector getExpressionsWithDataType(const expression_vector& expressions, + common::LogicalTypeID dataTypeID); + + static uint32_t find(const Expression* target, const expression_vector& expressions); + + // Print as a1,a2,a3,... + static std::string toString(const expression_vector& expressions); + static std::string toStringOrdered(const expression_vector& expressions); + // Print as a1=a2, a3=a4,... + static std::string toString(const std::vector& expressionPairs); + // Print as a1=a2 + static std::string toString(const expression_pair& expressionPair); + static std::string getUniqueName(const expression_vector& expressions); + + static expression_vector excludeExpression(const expression_vector& exprs, + const Expression& exprToExclude); + static expression_vector excludeExpressions(const expression_vector& expressions, + const expression_vector& expressionsToExclude); + + static common::logical_type_vec_t getDataTypes(const expression_vector& expressions); + + static expression_vector removeDuplication(const expression_vector& expressions); + + static bool isEmptyPattern(const Expression& expression); + static bool isNodePattern(const Expression& expression); + static bool isRelPattern(const Expression& expression); + static bool isRecursiveRelPattern(const Expression& expression); + static bool isNullLiteral(const Expression& expression); + static bool isBoolLiteral(const Expression& expression); + static bool isFalseLiteral(const Expression& expression); + static bool isEmptyList(const Expression& expression); + + static void validateExpressionType(const Expression& expr, common::ExpressionType expectedType); + static void validateExpressionType(const Expression& expr, + std::vector expectedType); + + // Validate data type. + static void validateDataType(const Expression& expr, const common::LogicalType& expectedType); + // Validate recursive data type top level (used when child type is unknown). + static void validateDataType(const Expression& expr, common::LogicalTypeID expectedTypeID); + static void validateDataType(const Expression& expr, + const std::vector& expectedTypeIDs); + template + static T getLiteralValue(const Expression& expr); + + static bool tryCombineDataType(const expression_vector& expressions, + common::LogicalType& result); + + // Check If we can directly assign a new data type to an expression. + // This mostly happen when a literal is an empty list. By default, we assign its data type to + // INT64[] but it can be cast to any other list type at compile time. 
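+    // (Illustrative, assumed: an empty list literal bound next to a STRING[] operand starts
+    // out as INT64[] and is statically re-typed to STRING[] so the two types line up.)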
+ static bool canCastStatically(const Expression& expr, const common::LogicalType& targetType); + + static bool canEvaluateAsLiteral(const Expression& expr); + static common::Value evaluateAsLiteralValue(const Expression& expr); + static uint64_t evaluateAsSkipLimit(const Expression& expr); + + template + using validate_param_func = void (*)(T); + + template + static T getExpressionVal(const Expression& expr, const common::Value& value, + const common::LogicalType& targetType, validate_param_func validateParamFunc = nullptr); + + template + static T evaluateLiteral(main::ClientContext* context, std::shared_ptr expression, + const common::LogicalType& type, validate_param_func validateParamFunc = nullptr); + + static std::shared_ptr applyImplicitCastingIfNecessary(main::ClientContext* context, + std::shared_ptr expr, common::LogicalType targetType); +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/lambda_expression.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/lambda_expression.h new file mode 100644 index 0000000000..7fb344d2c6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/lambda_expression.h @@ -0,0 +1,36 @@ +#pragma once + +#include "expression.h" +#include "parser/expression/parsed_expression.h" + +namespace lbug { +namespace binder { + +class LambdaExpression final : public Expression { + static constexpr common::ExpressionType type_ = common::ExpressionType::LAMBDA; + +public: + LambdaExpression(std::unique_ptr parsedLambdaExpr, + std::string uniqueName) + : Expression{type_, common::LogicalType::ANY(), std::move(uniqueName)}, + parsedLambdaExpr{std::move(parsedLambdaExpr)} {} + + void cast(const common::LogicalType& type_) override { + KU_ASSERT(dataType.getLogicalTypeID() == common::LogicalTypeID::ANY); + dataType = type_.copy(); + } + + parser::ParsedExpression* getParsedLambdaExpr() const { return parsedLambdaExpr.get(); } + + void setFunctionExpr(std::shared_ptr expr) { functionExpr = std::move(expr); } + std::shared_ptr getFunctionExpr() const { return functionExpr; } + + std::string toStringInternal() const override { return parsedLambdaExpr->toString(); } + +private: + std::unique_ptr parsedLambdaExpr; + std::shared_ptr functionExpr; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/literal_expression.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/literal_expression.h new file mode 100644 index 0000000000..abad7511da --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/literal_expression.h @@ -0,0 +1,29 @@ +#pragma once + +#include "common/types/value/value.h" +#include "expression.h" + +namespace lbug { +namespace binder { + +class LBUG_API LiteralExpression final : public Expression { + static constexpr common::ExpressionType type_ = common::ExpressionType::LITERAL; + +public: + LiteralExpression(common::Value value, const std::string& uniqueName) + : Expression{type_, value.getDataType().copy(), uniqueName}, value{std::move(value)} {} + + bool isNull() const { return value.isNull(); } + + void cast(const common::LogicalType& type) override; + + common::Value getValue() const { return value; } + + std::string toStringInternal() const override { return value.toString(); } + +public: + common::Value value; +}; + +} // namespace binder +} // namespace lbug diff --git 
a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/node_expression.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/node_expression.h new file mode 100644 index 0000000000..8269468fc3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/node_expression.h @@ -0,0 +1,33 @@ +#pragma once + +#include "node_rel_expression.h" + +namespace lbug { +namespace binder { + +class LBUG_API NodeExpression final : public NodeOrRelExpression { +public: + NodeExpression(common::LogicalType dataType, std::string uniqueName, std::string variableName, + std::vector entries) + : NodeOrRelExpression{std::move(dataType), std::move(uniqueName), std::move(variableName), + std::move(entries)} {} + + ~NodeExpression() override; + + bool isMultiLabeled() const override { return entries.size() > 1; } + + void setInternalID(std::shared_ptr expr) { internalID = std::move(expr); } + std::shared_ptr getInternalID() const override { + KU_ASSERT(internalID != nullptr); + return internalID; + } + + // Get the primary key property expression for a given table ID. + std::shared_ptr getPrimaryKey(common::table_id_t tableID) const; + +private: + std::shared_ptr internalID; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/node_rel_expression.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/node_rel_expression.h new file mode 100644 index 0000000000..74e9c54c4a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/node_rel_expression.h @@ -0,0 +1,95 @@ +#pragma once + +#include "common/case_insensitive_map.h" +#include "expression.h" +#include "property_expression.h" + +namespace lbug { +namespace catalog { +class TableCatalogEntry; +} +namespace binder { + +class LBUG_API NodeOrRelExpression : public Expression { + static constexpr common::ExpressionType expressionType_ = common::ExpressionType::PATTERN; + +public: + NodeOrRelExpression(common::LogicalType dataType, std::string uniqueName, + std::string variableName, std::vector entries) + : Expression{expressionType_, std::move(dataType), std::move(uniqueName)}, + variableName(std::move(variableName)), entries{std::move(entries)} {} + + void setDataType(common::LogicalType dataType) { this->dataType = std::move(dataType); } + + std::string getVariableName() const { return variableName; } + + bool isEmpty() const { return entries.empty(); } + virtual bool isMultiLabeled() const = 0; + + common::table_id_vector_t getTableIDs() const; + common::table_id_set_t getTableIDsSet() const; + + // Table entries + common::idx_t getNumEntries() const { return entries.size(); } + const std::vector& getEntries() const { return entries; } + catalog::TableCatalogEntry* getEntry(common::idx_t idx) const { return entries[idx]; } + void setEntries(std::vector entries_) { + entries = std::move(entries_); + } + void addEntries(const std::vector& entries_); + + // Property expressions + void addPropertyExpression(std::shared_ptr property); + bool hasPropertyExpression(const std::string& propertyName) const { + return propertyNameToIdx.contains(propertyName); + } + std::vector> getPropertyExpressions() const { + return propertyExprs; + } + std::shared_ptr getPropertyExpression( + const std::string& propertyName) const { + KU_ASSERT(propertyNameToIdx.contains(propertyName)); + return propertyExprs[propertyNameToIdx.at(propertyName)]; + } + virtual std::shared_ptr getInternalID() const = 0; + + // Label expression + void 
setLabelExpression(std::shared_ptr expression) { + labelExpression = std::move(expression); + } + std::shared_ptr getLabelExpression() const { return labelExpression; } + + // Property data expressions + void addPropertyDataExpr(std::string propertyName, std::shared_ptr expr) { + propertyDataExprs.insert({propertyName, expr}); + } + const common::case_insensitive_map_t>& + getPropertyDataExprRef() const { + return propertyDataExprs; + } + bool hasPropertyDataExpr(const std::string& propertyName) const { + return propertyDataExprs.contains(propertyName); + } + std::shared_ptr getPropertyDataExpr(const std::string& propertyName) const { + KU_ASSERT(propertyDataExprs.contains(propertyName)); + return propertyDataExprs.at(propertyName); + } + + std::string toStringInternal() const final { return variableName; } + +protected: + std::string variableName; + // A pattern may bind to multiple tables. + std::vector entries; + // Index over propertyExprs on property name. + common::case_insensitive_map_t propertyNameToIdx; + // Property expressions with order (aligned with catalog). + std::vector> propertyExprs; + // Label expression + std::shared_ptr labelExpression; + // Property data expressions specified by user in the form of "{propertyName : data}" + common::case_insensitive_map_t> propertyDataExprs; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/parameter_expression.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/parameter_expression.h new file mode 100644 index 0000000000..87eaa9660a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/parameter_expression.h @@ -0,0 +1,31 @@ +#pragma once + +#include "common/types/value/value.h" +#include "expression.h" + +namespace lbug { +namespace binder { + +class LBUG_API ParameterExpression final : public Expression { + static constexpr common::ExpressionType expressionType = common::ExpressionType::PARAMETER; + +public: + explicit ParameterExpression(const std::string& parameterName, common::Value value) + : Expression{expressionType, value.getDataType().copy(), createUniqueName(parameterName)}, + parameterName(parameterName), value{std::move(value)} {} + + void cast(const common::LogicalType& type) override; + + common::Value getValue() const { return value; } + +private: + std::string toStringInternal() const override { return "$" + parameterName; } + static std::string createUniqueName(const std::string& input) { return "$" + input; } + +private: + std::string parameterName; + common::Value value; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/path_expression.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/path_expression.h new file mode 100644 index 0000000000..55c940e0fc --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/path_expression.h @@ -0,0 +1,30 @@ +#pragma once + +#include "binder/expression/expression.h" + +namespace lbug { +namespace binder { + +class PathExpression final : public Expression { +public: + PathExpression(common::LogicalType dataType, std::string uniqueName, std::string variableName, + common::LogicalType nodeType, common::LogicalType relType, expression_vector children) + : Expression{common::ExpressionType::PATH, std::move(dataType), std::move(children), + std::move(uniqueName)}, + variableName{std::move(variableName)}, nodeType{std::move(nodeType)}, + 
relType{std::move(relType)} {} + + std::string getVariableName() const { return variableName; } + const common::LogicalType& getNodeType() const { return nodeType; } + const common::LogicalType& getRelType() const { return relType; } + + std::string toStringInternal() const override { return variableName; } + +private: + std::string variableName; + common::LogicalType nodeType; + common::LogicalType relType; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/property_expression.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/property_expression.h new file mode 100644 index 0000000000..9671dd57b2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/property_expression.h @@ -0,0 +1,76 @@ +#pragma once + +#include "common/constants.h" +#include "expression.h" + +namespace lbug { +namespace catalog { +class TableCatalogEntry; +} +namespace binder { + +struct SingleLabelPropertyInfo { + bool exists; + bool isPrimaryKey; + + explicit SingleLabelPropertyInfo(bool exists, bool isPrimaryKey) + : exists{exists}, isPrimaryKey{isPrimaryKey} {} + EXPLICIT_COPY_DEFAULT_MOVE(SingleLabelPropertyInfo); + +private: + SingleLabelPropertyInfo(const SingleLabelPropertyInfo& other) + : exists{other.exists}, isPrimaryKey{other.isPrimaryKey} {} +}; + +class LBUG_API PropertyExpression final : public Expression { + static constexpr common::ExpressionType expressionType_ = common::ExpressionType::PROPERTY; + +public: + PropertyExpression(common::LogicalType dataType, std::string propertyName, + std::string uniqueVarName, std::string rawVariableName, + common::table_id_map_t infos) + : Expression{expressionType_, std::move(dataType), uniqueVarName + "." + propertyName}, + propertyName{std::move(propertyName)}, uniqueVarName{std::move(uniqueVarName)}, + rawVariableName{std::move(rawVariableName)}, infos{std::move(infos)} {} + + PropertyExpression(const PropertyExpression& other) + : Expression{expressionType_, other.dataType.copy(), other.uniqueName}, + propertyName{other.propertyName}, uniqueVarName{other.uniqueVarName}, + rawVariableName{other.rawVariableName}, infos{copyUnorderedMap(other.infos)} {} + + // If this property is primary key on all tables. + bool isPrimaryKey() const; + // If this property is primary key for given table. + bool isPrimaryKey(common::table_id_t tableID) const; + + std::string getPropertyName() const { return propertyName; } + std::string getVariableName() const { return uniqueVarName; } + std::string getRawVariableName() const { return rawVariableName; } + + // If this property exists for given table. + bool hasProperty(common::table_id_t tableID) const; + + // common::column_id_t getColumnID(const catalog::TableCatalogEntry& entry) const; + bool isSingleLabel() const { return infos.size() == 1; } + common::table_id_t getSingleTableID() const { return infos.begin()->first; } + + bool isInternalID() const { return getPropertyName() == common::InternalKeyword::ID; } + + std::string toStringInternal() const override { return rawVariableName + "." + propertyName; } + + std::unique_ptr copy() const { + return std::make_unique(*this); + } + +private: + std::string propertyName; + // unique identifier references to a node/rel table. + std::string uniqueVarName; + // printable identifier references to a node/rel table. + std::string rawVariableName; + // The same property name may have different info on each table. 
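// Illustrative sketch (not part of the upstream lbug sources): a property access such as
// `a.age` against a multi-labeled pattern like `MATCH (a:Person:Student)` binds to a single
// PropertyExpression whose unique name is "<uniqueVarName>.age" and whose per-table info
// records, for every table the variable may bind to, whether the property exists there and
// whether it is that table's primary key. A standard-library analog of that bookkeeping
// (the table ids and Info struct are stand-ins, not lbug types):
//
//     #include <cstdint>
//     #include <unordered_map>
//
//     struct Info { bool exists; bool isPrimaryKey; };
//     std::unordered_map<uint64_t, Info> infos;
//     infos[/*Person*/ 0]  = {true, false};   // Person.age exists, but is not the primary key
//     infos[/*Student*/ 1] = {false, false};  // Student has no "age" column
//
//     // isPrimaryKey() over the whole expression only holds if it holds for every table;
//     // hasProperty(tableID) consults the per-table entry, mirroring the accessors above.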
+ common::table_id_map_t infos; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/rel_expression.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/rel_expression.h new file mode 100644 index 0000000000..788aa0dbf7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/rel_expression.h @@ -0,0 +1,123 @@ +#pragma once + +#include "common/constants.h" +#include "common/enums/extend_direction.h" +#include "common/enums/query_rel_type.h" +#include "function/gds/rec_joins.h" +#include "node_expression.h" + +namespace lbug { +namespace binder { + +enum class RelDirectionType : uint8_t { + SINGLE = 0, + BOTH = 1, + UNKNOWN = 2, +}; + +class RelExpression; + +struct RecursiveInfo { + /* + * E.g. [e*1..2 (r, n) | WHERE n.age > 10 AND r.year = 2012 ] + * node = n + * nodeCopy = n (see comment below) + * rel = r + * predicates = [n.age > 10, r.year = 2012] + * */ + std::shared_ptr node = nullptr; + // NodeCopy has the same fields as node but a different unique name. + // We use nodeCopy to plan recursive plan because boundNode&nbrNode cannot be the same. + std::shared_ptr nodeCopy = nullptr; + std::shared_ptr rel = nullptr; + // Predicates + std::shared_ptr nodePredicate = nullptr; + std::shared_ptr relPredicate = nullptr; + // Projection list + expression_vector nodeProjectionList; + expression_vector relProjectionList; + // Function information + std::unique_ptr function; + std::unique_ptr bindData; +}; + +class LBUG_API RelExpression final : public NodeOrRelExpression { +public: + RelExpression(common::LogicalType dataType, std::string uniqueName, std::string variableName, + std::vector entries, std::shared_ptr srcNode, + std::shared_ptr dstNode, RelDirectionType directionType, + common::QueryRelType relType) + : NodeOrRelExpression{std::move(dataType), std::move(uniqueName), std::move(variableName), + std::move(entries)}, + srcNode{std::move(srcNode)}, dstNode{std::move(dstNode)}, directionType{directionType}, + relType{relType} {} + + bool isRecursive() const { + return dataType.getLogicalTypeID() == common::LogicalTypeID::RECURSIVE_REL; + } + + bool isMultiLabeled() const override; + bool isBoundByMultiLabeledNode() const { + return srcNode->isMultiLabeled() || dstNode->isMultiLabeled(); + } + + std::shared_ptr getSrcNode() const { return srcNode; } + std::string getSrcNodeName() const { return srcNode->getUniqueName(); } + void setDstNode(std::shared_ptr node) { dstNode = std::move(node); } + std::shared_ptr getDstNode() const { return dstNode; } + std::string getDstNodeName() const { return dstNode->getUniqueName(); } + + void setLeftNode(std::shared_ptr node) { leftNode = std::move(node); } + std::shared_ptr getLeftNode() const { return leftNode; } + void setRightNode(std::shared_ptr node) { rightNode = std::move(node); } + std::shared_ptr getRightNode() const { return rightNode; } + + common::QueryRelType getRelType() const { return relType; } + + void setDirectionExpr(std::shared_ptr expr) { directionExpr = std::move(expr); } + bool hasDirectionExpr() const { return directionExpr != nullptr; } + std::shared_ptr getDirectionExpr() const { return directionExpr; } + RelDirectionType getDirectionType() const { return directionType; } + + std::shared_ptr getInternalID() const override { + return getPropertyExpression(common::InternalKeyword::ID); + } + + void setRecursiveInfo(std::unique_ptr recursiveInfo_) { + recursiveInfo = std::move(recursiveInfo_); + } + const RecursiveInfo* 
getRecursiveInfo() const { return recursiveInfo.get(); } + std::shared_ptr getLengthExpression() const { + KU_ASSERT(recursiveInfo != nullptr); + return recursiveInfo->bindData->lengthExpr; + } + + bool isSelfLoop() const { return *srcNode == *dstNode; } + + std::string detailsToString() const; + + // if multiple tables match the pattern + // returns the intersection of available extend directions for all matched tables + std::vector getExtendDirections() const; + + std::vector getInnerRelTableIDs() const; + +private: + // Start node if a directed arrow is given. Left node otherwise. + std::shared_ptr srcNode; + // End node if a directed arrow is given. Right node otherwise. + std::shared_ptr dstNode; + std::shared_ptr leftNode; + std::shared_ptr rightNode; + // Whether relationship is directed. + RelDirectionType directionType; + // Direction expr is nullptr when direction type is SINGLE + std::shared_ptr directionExpr; + // Whether relationship type is recursive. + common::QueryRelType relType; + // Null if relationship type is non-recursive. + std::unique_ptr recursiveInfo; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/scalar_function_expression.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/scalar_function_expression.h new file mode 100644 index 0000000000..c714831355 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/scalar_function_expression.h @@ -0,0 +1,33 @@ +#pragma once + +#include "expression.h" +#include "function/scalar_function.h" + +namespace lbug { +namespace binder { + +class ScalarFunctionExpression final : public Expression { +public: + ScalarFunctionExpression(common::ExpressionType expressionType, + std::unique_ptr function, + std::unique_ptr bindData, expression_vector children, + std::string uniqueName) + : Expression{expressionType, bindData->resultType.copy(), std::move(children), + std::move(uniqueName)}, + function{std::move(function)}, bindData{std::move(bindData)} {} + + const function::ScalarFunction& getFunction() const { return *function; } + function::FunctionBindData* getBindData() const { return bindData.get(); } + + std::string toStringInternal() const override; + + static std::string getUniqueName(const std::string& functionName, + const expression_vector& children); + +private: + std::unique_ptr function; + std::unique_ptr bindData; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/subquery_expression.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/subquery_expression.h new file mode 100644 index 0000000000..28d1c2a157 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/subquery_expression.h @@ -0,0 +1,55 @@ +#pragma once + +#include "binder/query/query_graph.h" +#include "binder/query/reading_clause/bound_join_hint.h" +#include "common/enums/subquery_type.h" +#include "expression.h" + +namespace lbug { +namespace binder { + +class SubqueryExpression final : public Expression { + static constexpr common::ExpressionType expressionType_ = common::ExpressionType::SUBQUERY; + +public: + SubqueryExpression(common::SubqueryType subqueryType, common::LogicalType dataType, + QueryGraphCollection queryGraphCollection, std::string uniqueName, std::string rawName) + : Expression{expressionType_, std::move(dataType), std::move(uniqueName)}, + subqueryType{subqueryType}, 
queryGraphCollection{std::move(queryGraphCollection)}, + rawName{std::move(rawName)} {} + + common::SubqueryType getSubqueryType() const { return subqueryType; } + + const QueryGraphCollection* getQueryGraphCollection() const { return &queryGraphCollection; } + + void setWhereExpression(std::shared_ptr expression) { + whereExpression = std::move(expression); + } + bool hasWhereExpression() const { return whereExpression != nullptr; } + std::shared_ptr getWhereExpression() const { return whereExpression; } + expression_vector getPredicatesSplitOnAnd() const { + return hasWhereExpression() ? whereExpression->splitOnAND() : expression_vector{}; + } + + void setCountStarExpr(std::shared_ptr expr) { countStarExpr = std::move(expr); } + std::shared_ptr getCountStarExpr() const { return countStarExpr; } + void setProjectionExpr(std::shared_ptr expr) { projectionExpr = std::move(expr); } + std::shared_ptr getProjectionExpr() const { return projectionExpr; } + + void setHint(std::shared_ptr root) { hintRoot = std::move(root); } + std::shared_ptr getHint() const { return hintRoot; } + + std::string toStringInternal() const override { return rawName; } + +private: + common::SubqueryType subqueryType; + QueryGraphCollection queryGraphCollection; + std::shared_ptr whereExpression; + std::shared_ptr countStarExpr; + std::shared_ptr projectionExpr; + std::shared_ptr hintRoot; + std::string rawName; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/variable_expression.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/variable_expression.h new file mode 100644 index 0000000000..7c56b26273 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression/variable_expression.h @@ -0,0 +1,28 @@ +#pragma once + +#include "expression.h" + +namespace lbug { +namespace binder { + +class VariableExpression final : public Expression { + static constexpr common::ExpressionType expressionType_ = common::ExpressionType::VARIABLE; + +public: + VariableExpression(common::LogicalType dataType, std::string uniqueName, + std::string variableName) + : Expression{expressionType_, std::move(dataType), std::move(uniqueName)}, + variableName{std::move(variableName)} {} + + std::string getVariableName() const { return variableName; } + + void cast(const common::LogicalType& type) override; + + std::string toStringInternal() const override { return variableName; } + +private: + std::string variableName; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression_binder.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression_binder.h new file mode 100644 index 0000000000..1f41266696 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression_binder.h @@ -0,0 +1,155 @@ +#pragma once + +#include "binder/expression/expression.h" +#include "common/types/value/value.h" +#include "parser/expression/parsed_expression.h" + +namespace lbug { +namespace main { +class ClientContext; +} + +namespace function { +struct Function; +} + +namespace binder { + +class Binder; +struct CaseAlternative; + +struct ExpressionBinderConfig { + // If a property is not in projection list but required in order by after aggregation, + // we need to bind it as struct extraction because node/rel must have been evaluated as + // struct during aggregate + // e.g. 
RETURN a, COUNT(*) ORDER BY a.ID + bool bindOrderByAfterAggregate = false; + // If a node is single labeled, we rewrite its label function as string literal. This however, + // should be applied to recursive pattern predicate because if path is of length <= 1, there + // is no intermediate node and thus the predicate should be a noop. If we try to evaluate, it + // may lead to empty result. + // e.g. [* (r, n | WHERE label(n)='dummy') ] + bool disableLabelFunctionLiteralRewrite = false; +}; + +class ExpressionBinder { + friend class Binder; + +public: + ExpressionBinder(Binder* queryBinder, main::ClientContext* context) + : binder{queryBinder}, context{context} {} + + std::shared_ptr bindExpression(const parser::ParsedExpression& parsedExpression); + + // TODO(Xiyang): move to an expression rewriter + LBUG_API std::shared_ptr foldExpression( + const std::shared_ptr& expression) const; + + // Boolean expressions. + std::shared_ptr bindBooleanExpression( + const parser::ParsedExpression& parsedExpression); + std::shared_ptr bindBooleanExpression(common::ExpressionType expressionType, + const expression_vector& children); + + std::shared_ptr combineBooleanExpressions(common::ExpressionType expressionType, + std::shared_ptr left, std::shared_ptr right); + // Comparison expressions. + std::shared_ptr bindComparisonExpression( + const parser::ParsedExpression& parsedExpression); + std::shared_ptr bindComparisonExpression(common::ExpressionType expressionType, + const expression_vector& children); + std::shared_ptr createEqualityComparisonExpression(std::shared_ptr left, + std::shared_ptr right); + // Null operator expressions. + std::shared_ptr bindNullOperatorExpression( + const parser::ParsedExpression& parsedExpression); + std::shared_ptr bindNullOperatorExpression(common::ExpressionType expressionType, + const expression_vector& children); + + // Property expressions. + expression_vector bindPropertyStarExpression(const parser::ParsedExpression& parsedExpression); + static expression_vector bindNodeOrRelPropertyStarExpression(const Expression& child); + expression_vector bindStructPropertyStarExpression(const std::shared_ptr& child); + std::shared_ptr bindPropertyExpression( + const parser::ParsedExpression& parsedExpression); + static std::shared_ptr bindNodeOrRelPropertyExpression(const Expression& child, + const std::string& propertyName); + std::shared_ptr bindStructPropertyExpression(std::shared_ptr child, + const std::string& propertyName); + // Function expressions. + std::shared_ptr bindFunctionExpression(const parser::ParsedExpression& expr); + void bindLambdaExpression(const Expression& lambdaInput, Expression& lambdaExpr) const; + std::shared_ptr bindLambdaExpression( + const parser::ParsedExpression& parsedExpr) const; + + std::shared_ptr bindScalarFunctionExpression( + const parser::ParsedExpression& parsedExpression, const std::string& functionName); + std::shared_ptr bindScalarFunctionExpression(const expression_vector& children, + const std::string& functionName, + std::vector optionalArguments = std::vector{}); + std::shared_ptr bindRewriteFunctionExpression(const parser::ParsedExpression& expr); + std::shared_ptr bindAggregateFunctionExpression( + const parser::ParsedExpression& parsedExpression, const std::string& functionName, + bool isDistinct); + std::shared_ptr bindMacroExpression( + const parser::ParsedExpression& parsedExpression, const std::string& macroName); + + // Parameter expressions. 
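// Illustrative sketch (not part of the upstream lbug sources): a typical caller hands the
// binder a parsed expression, lets it fold constants, and splits the bound result into
// conjuncts for planning. `binder` and `parsedWhere` below are assumed locals, not lbug API:
//
//     std::shared_ptr<Expression> pred = binder.bindExpression(*parsedWhere);
//     pred = binder.foldExpression(pred);       // constant folding, e.g. 1 + 1 becomes 2
//     for (auto& conjunct : pred->splitOnAND()) {
//         // one entry per top-level AND term, as consumed by the bound reading clauses below
//     }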
+ std::shared_ptr bindParameterExpression( + const parser::ParsedExpression& parsedExpression); + // Literal expressions. + std::shared_ptr bindLiteralExpression( + const parser::ParsedExpression& parsedExpression) const; + std::shared_ptr createLiteralExpression(const common::Value& value) const; + std::shared_ptr createLiteralExpression(const std::string& strVal) const; + std::shared_ptr createNullLiteralExpression() const; + std::shared_ptr createNullLiteralExpression(const common::Value& value) const; + // Variable expressions. + std::shared_ptr bindVariableExpression( + const parser::ParsedExpression& parsedExpression) const; + std::shared_ptr bindVariableExpression(const std::string& varName) const; + std::shared_ptr createVariableExpression(common::LogicalType logicalType, + std::string_view name) const; + std::shared_ptr createVariableExpression(common::LogicalType logicalType, + std::string name) const; + // Subquery expressions. + std::shared_ptr bindSubqueryExpression(const parser::ParsedExpression& parsedExpr); + // Case expressions. + std::shared_ptr bindCaseExpression( + const parser::ParsedExpression& parsedExpression); + + /****** cast *****/ + LBUG_API std::shared_ptr implicitCastIfNecessary( + const std::shared_ptr& expression, const common::LogicalType& targetType); + // Use implicitCast to cast to types you have obtained through known implicit casting rules. + // Use forceCast to cast to types you have obtained through other means, for example, + // through a maxLogicalType function + std::shared_ptr implicitCast(const std::shared_ptr& expression, + const common::LogicalType& targetType); + std::shared_ptr forceCast(const std::shared_ptr& expression, + const common::LogicalType& targetType); + + // Parameter + void addParameter(const std::string& name, std::shared_ptr value); + const std::unordered_set& getUnknownParameters() const { + return unknownParameters; + } + const std::unordered_map>& + getKnownParameters() const { + return knownParameters; + } + + std::string getUniqueName(const std::string& name) const; + + const ExpressionBinderConfig& getConfig() { return config; } + +private: + Binder* binder; + main::ClientContext* context; + std::unordered_set unknownParameters; + std::unordered_map> knownParameters; + ExpressionBinderConfig config; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression_visitor.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression_visitor.h new file mode 100644 index 0000000000..6be80adc1a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/expression_visitor.h @@ -0,0 +1,101 @@ +#pragma once + +#include "binder/expression/expression.h" + +namespace lbug { +namespace binder { + +class ExpressionChildrenCollector { +public: + static expression_vector collectChildren(const Expression& expression); + +private: + static expression_vector collectCaseChildren(const Expression& expression); + + static expression_vector collectSubqueryChildren(const Expression& expression); + + static expression_vector collectNodeChildren(const Expression& expression); + + static expression_vector collectRelChildren(const Expression& expression); +}; + +class ExpressionVisitor { +public: + virtual ~ExpressionVisitor() = default; + + void visit(std::shared_ptr expr); + + static bool isRandom(const Expression& expression); + +protected: + void visitSwitch(std::shared_ptr expr); + virtual void visitFunctionExpr(std::shared_ptr) {} + virtual void 
visitAggFunctionExpr(std::shared_ptr) {} + virtual void visitPropertyExpr(std::shared_ptr) {} + virtual void visitLiteralExpr(std::shared_ptr) {} + virtual void visitVariableExpr(std::shared_ptr) {} + virtual void visitPathExpr(std::shared_ptr) {} + virtual void visitNodeRelExpr(std::shared_ptr) {} + virtual void visitParamExpr(std::shared_ptr) {} + virtual void visitSubqueryExpr(std::shared_ptr) {} + virtual void visitCaseExpr(std::shared_ptr) {} + virtual void visitGraphExpr(std::shared_ptr) {} + virtual void visitLambdaExpr(std::shared_ptr) {} + + virtual void visitChildren(const Expression& expr); + void visitCaseExprChildren(const Expression& expr); +}; + +// Do not collect subquery expression recursively. Caller should handle recursive subquery instead. +class SubqueryExprCollector final : public ExpressionVisitor { +public: + bool hasSubquery() const { return !exprs.empty(); } + expression_vector getSubqueryExprs() const { return exprs; } + +protected: + void visitSubqueryExpr(std::shared_ptr expr) override { exprs.push_back(expr); } + +private: + expression_vector exprs; +}; + +class DependentVarNameCollector final : public ExpressionVisitor { +public: + std::unordered_set getVarNames() const { return varNames; } + +protected: + void visitSubqueryExpr(std::shared_ptr expr) override; + void visitPropertyExpr(std::shared_ptr expr) override; + void visitNodeRelExpr(std::shared_ptr expr) override; + void visitVariableExpr(std::shared_ptr expr) override; + +private: + std::unordered_set varNames; +}; + +class PropertyExprCollector final : public ExpressionVisitor { +public: + expression_vector getPropertyExprs() const { return expressions; } + +protected: + void visitSubqueryExpr(std::shared_ptr expr) override; + void visitPropertyExpr(std::shared_ptr expr) override; + void visitNodeRelExpr(std::shared_ptr expr) override; + +private: + expression_vector expressions; +}; + +class ConstantExpressionVisitor { +public: + static bool needFold(const Expression& expr); + static bool isConstant(const Expression& expr); + +private: + static bool visitFunction(const Expression& expr); + static bool visitCase(const Expression& expr); + static bool visitChildren(const Expression& expr); +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/bound_regular_query.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/bound_regular_query.h new file mode 100644 index 0000000000..d5293ff505 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/bound_regular_query.h @@ -0,0 +1,33 @@ +#pragma once + +#include "binder/bound_statement.h" +#include "normalized_single_query.h" + +namespace lbug { +namespace binder { + +class BoundRegularQuery final : public BoundStatement { + static constexpr common::StatementType type_ = common::StatementType::QUERY; + +public: + explicit BoundRegularQuery(std::vector isUnionAll, BoundStatementResult statementResult) + : BoundStatement{type_, std::move(statementResult)}, isUnionAll{std::move(isUnionAll)} {} + + void addSingleQuery(NormalizedSingleQuery singleQuery) { + singleQueries.push_back(std::move(singleQuery)); + } + common::idx_t getNumSingleQueries() const { return singleQueries.size(); } + NormalizedSingleQuery* getSingleQueryUnsafe(common::idx_t idx) { return &singleQueries[idx]; } + const NormalizedSingleQuery* getSingleQuery(common::idx_t idx) const { + return &singleQueries[idx]; + } + + bool getIsUnionAll(common::idx_t idx) const { return isUnionAll[idx]; } + +private: 
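// Illustrative sketch (not part of the upstream lbug sources): the collectors declared in
// expression_visitor.h above are used by walking an expression tree once and then reading the
// gathered results; `predicate` is an assumed std::shared_ptr<Expression>, not lbug API:
//
//     PropertyExprCollector propCollector;
//     propCollector.visit(predicate);                   // recurses via visitChildren()
//     for (auto& prop : propCollector.getPropertyExprs()) {
//         // e.g. decide which node/rel properties must be scanned for this predicate
//     }
//
//     DependentVarNameCollector varCollector;
//     varCollector.visit(predicate);
//     bool usesA = varCollector.getVarNames().contains("a");  // assumes C++20 contains()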
+ std::vector singleQueries; + std::vector isUnionAll; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/normalized_query_part.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/normalized_query_part.h new file mode 100644 index 0000000000..62c912242d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/normalized_query_part.h @@ -0,0 +1,64 @@ +#pragma once + +#include + +#include "binder/query/reading_clause/bound_reading_clause.h" +#include "binder/query/return_with_clause/bound_projection_body.h" +#include "binder/query/updating_clause/bound_updating_clause.h" + +namespace lbug { +namespace binder { + +class NormalizedQueryPart { + friend class NormalizedQueryPartMatchRewriter; + +public: + NormalizedQueryPart() = default; + DELETE_COPY_DEFAULT_MOVE(NormalizedQueryPart); + + void addReadingClause(std::unique_ptr boundReadingClause) { + readingClauses.push_back(std::move(boundReadingClause)); + } + bool hasReadingClause() const { return !readingClauses.empty(); } + uint32_t getNumReadingClause() const { return readingClauses.size(); } + BoundReadingClause* getReadingClause(uint32_t idx) const { return readingClauses[idx].get(); } + + void addUpdatingClause(std::unique_ptr boundUpdatingClause) { + updatingClauses.push_back(std::move(boundUpdatingClause)); + } + bool hasUpdatingClause() const { return !updatingClauses.empty(); } + uint32_t getNumUpdatingClause() const { return updatingClauses.size(); } + BoundUpdatingClause* getUpdatingClause(uint32_t idx) const { + return updatingClauses[idx].get(); + } + + void setProjectionBody(BoundProjectionBody boundProjectionBody) { + projectionBody = std::move(boundProjectionBody); + } + bool hasProjectionBody() const { return projectionBody.has_value(); } + BoundProjectionBody* getProjectionBodyUnsafe() { + KU_ASSERT(projectionBody.has_value()); + return &projectionBody.value(); + } + const BoundProjectionBody* getProjectionBody() const { + KU_ASSERT(projectionBody.has_value()); + return &projectionBody.value(); + } + + bool hasProjectionBodyPredicate() const { return projectionBodyPredicate != nullptr; } + std::shared_ptr getProjectionBodyPredicate() const { + return projectionBodyPredicate; + } + void setProjectionBodyPredicate(const std::shared_ptr& predicate) { + projectionBodyPredicate = predicate; + } + +private: + std::vector> readingClauses; + std::vector> updatingClauses; + std::optional projectionBody; + std::shared_ptr projectionBodyPredicate; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/normalized_single_query.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/normalized_single_query.h new file mode 100644 index 0000000000..e0d09cfa54 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/normalized_single_query.h @@ -0,0 +1,30 @@ +#pragma once + +#include "binder/bound_statement_result.h" +#include "normalized_query_part.h" + +namespace lbug { +namespace binder { + +class NormalizedSingleQuery { +public: + NormalizedSingleQuery() = default; + DELETE_COPY_DEFAULT_MOVE(NormalizedSingleQuery); + + void appendQueryPart(NormalizedQueryPart queryPart) { + queryParts.push_back(std::move(queryPart)); + } + common::idx_t getNumQueryParts() const { return queryParts.size(); } + NormalizedQueryPart* getQueryPartUnsafe(common::idx_t idx) { return &queryParts[idx]; } + const NormalizedQueryPart* getQueryPart(common::idx_t idx) const { 
return &queryParts[idx]; } + + void setStatementResult(BoundStatementResult result) { statementResult = std::move(result); } + const BoundStatementResult* getStatementResult() const { return &statementResult; } + +private: + std::vector queryParts; + BoundStatementResult statementResult; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/query_graph.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/query_graph.h new file mode 100644 index 0000000000..38937c7849 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/query_graph.h @@ -0,0 +1,178 @@ +#pragma once + +#include +#include + +#include "binder/expression/rel_expression.h" + +namespace lbug { +namespace binder { + +constexpr static uint8_t MAX_NUM_QUERY_VARIABLES = 64; + +class QueryGraph; +struct SubqueryGraph; +struct SubqueryGraphHasher; +using subquery_graph_set_t = std::unordered_set; +template +using subquery_graph_V_map_t = std::unordered_map; + +// hash on node bitset if subgraph has no rel +struct SubqueryGraphHasher { + std::size_t operator()(const SubqueryGraph& key) const; +}; + +struct SubqueryGraph { + const QueryGraph& queryGraph; + std::bitset queryNodesSelector; + std::bitset queryRelsSelector; + + explicit SubqueryGraph(const QueryGraph& queryGraph) : queryGraph{queryGraph} {} + + void addQueryNode(common::idx_t nodePos) { queryNodesSelector[nodePos] = true; } + void addQueryRel(common::idx_t relPos) { queryRelsSelector[relPos] = true; } + void addSubqueryGraph(const SubqueryGraph& other) { + queryRelsSelector |= other.queryRelsSelector; + queryNodesSelector |= other.queryNodesSelector; + } + + common::idx_t getNumQueryRels() const { return queryRelsSelector.count(); } + common::idx_t getTotalNumVariables() const { + return queryNodesSelector.count() + queryRelsSelector.count(); + } + bool isSingleRel() const { + return queryRelsSelector.count() == 1 && queryNodesSelector.count() == 0; + } + + bool containAllVariables(const std::unordered_set& variables) const; + + std::unordered_set getNodeNbrPositions() const; + std::unordered_set getRelNbrPositions() const; + subquery_graph_set_t getNbrSubgraphs(uint32_t size) const; + std::vector getConnectedNodePos(const SubqueryGraph& nbr) const; + + // E.g. query graph (a)-[e1]->(b) and subgraph (a)-[e1], although (b) is not in subgraph, we + // return both (a) and (b) regardless of node selector. See needPruneJoin() in + // join_order_enumerator.cpp for its use case. + std::unordered_set getNodePositionsIgnoringNodeSelector() const; + + std::vector getNbrNodeIndices() const; + + bool operator==(const SubqueryGraph& other) const { + return queryRelsSelector == other.queryRelsSelector && + queryNodesSelector == other.queryNodesSelector; + } + +private: + subquery_graph_set_t getBaseNbrSubgraph() const; + subquery_graph_set_t getNextNbrSubgraphs(const SubqueryGraph& prevNbr) const; +}; + +// QueryGraph represents a connected pattern specified in MATCH clause. 
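// Illustrative sketch (not part of the upstream lbug sources): SubqueryGraph above tracks
// which of the (at most MAX_NUM_QUERY_VARIABLES = 64) pattern nodes and rels a candidate
// join covers, so unioning two partial plans is just a bitwise OR of their selectors.
// A self-contained std::bitset analog:
//
//     #include <bitset>
//
//     std::bitset<64> nodesA, relsA, nodesB, relsB;
//     nodesA.set(0); relsA.set(0);          // plan A covers node 0 and rel 0
//     nodesB.set(1); relsB.set(0);          // plan B covers node 1 and the same rel 0
//     nodesA |= nodesB;                     // addSubqueryGraph(): node selectors OR'ed
//     relsA  |= relsB;                      //                     rel selectors OR'ed
//     auto totalVars = nodesA.count() + relsA.count();  // getTotalNumVariables() == 3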
+class QueryGraph { +public: + QueryGraph() = default; + QueryGraph(const QueryGraph& other) + : queryNodeNameToPosMap{other.queryNodeNameToPosMap}, + queryRelNameToPosMap{other.queryRelNameToPosMap}, queryNodes{other.queryNodes}, + queryRels{other.queryRels} {} + + EXPLICIT_COPY_DEFAULT_MOVE(QueryGraph); + + bool isEmpty() const; + + std::vector> getAllPatterns() const; + + common::idx_t getNumQueryNodes() const { return queryNodes.size(); } + bool containsQueryNode(const std::string& queryNodeName) const { + return queryNodeNameToPosMap.contains(queryNodeName); + } + std::vector> getQueryNodes() const { return queryNodes; } + std::shared_ptr getQueryNode(const std::string& queryNodeName) const { + return queryNodes[getQueryNodeIdx(queryNodeName)]; + } + std::vector> getQueryNodes( + const std::vector& nodePoses) const { + std::vector> result; + result.reserve(nodePoses.size()); + for (auto nodePos : nodePoses) { + result.push_back(queryNodes[nodePos]); + } + return result; + } + std::shared_ptr getQueryNode(common::idx_t nodePos) const { + return queryNodes[nodePos]; + } + common::idx_t getQueryNodeIdx(const NodeExpression& node) const { + return getQueryNodeIdx(node.getUniqueName()); + } + common::idx_t getQueryNodeIdx(const std::string& queryNodeName) const { + return queryNodeNameToPosMap.at(queryNodeName); + } + void addQueryNode(std::shared_ptr queryNode); + + common::idx_t getNumQueryRels() const { return queryRels.size(); } + bool containsQueryRel(const std::string& queryRelName) const { + return queryRelNameToPosMap.contains(queryRelName); + } + std::vector> getQueryRels() const { return queryRels; } + std::shared_ptr getQueryRel(const std::string& queryRelName) const { + return queryRels.at(queryRelNameToPosMap.at(queryRelName)); + } + std::shared_ptr getQueryRel(common::idx_t relPos) const { + return queryRels[relPos]; + } + common::idx_t getQueryRelIdx(const std::string& queryRelName) const { + return queryRelNameToPosMap.at(queryRelName); + } + void addQueryRel(std::shared_ptr queryRel); + + bool canProjectExpression(const std::shared_ptr& expression) const; + + bool isConnected(const QueryGraph& other) const; + + void merge(const QueryGraph& other); + +private: + std::unordered_map queryNodeNameToPosMap; + std::unordered_map queryRelNameToPosMap; + std::vector> queryNodes; + std::vector> queryRels; +}; + +// QueryGraphCollection represents a pattern (a set of connected components) specified in MATCH +// clause. 
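// Illustrative sketch (not part of the upstream lbug sources): a MATCH pattern with two
// disconnected components, e.g. `MATCH (a)-[e]->(b), (c)`, ends up as two QueryGraphs inside
// the QueryGraphCollection declared just below. Here `a`, `b`, `c`, `e` stand for
// already-bound NodeExpression/RelExpression shared_ptrs, not lbug API:
//
//     QueryGraph g1, g2;
//     g1.addQueryNode(a); g1.addQueryNode(b); g1.addQueryRel(e);   // (a)-[e]->(b)
//     g2.addQueryNode(c);                                          // (c) shares no variable
//
//     QueryGraphCollection collection;
//     collection.addAndMergeQueryGraphIfConnected(std::move(g1));
//     collection.addAndMergeQueryGraphIfConnected(std::move(g2));  // stays a separate graph
//     collection.finalize();
//     // collection.getNumQueryGraphs() == 2: one per connected component of the pattern.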
+class QueryGraphCollection { +public: + QueryGraphCollection() = default; + DELETE_COPY_DEFAULT_MOVE(QueryGraphCollection); + + void merge(const QueryGraphCollection& other); + void addAndMergeQueryGraphIfConnected(QueryGraph queryGraphToAdd); + void finalize(); + + common::idx_t getNumQueryGraphs() const { return queryGraphs.size(); } + QueryGraph* getQueryGraphUnsafe(common::idx_t idx) { return &queryGraphs[idx]; } + const QueryGraph* getQueryGraph(common::idx_t idx) const { return &queryGraphs[idx]; } + + bool contains(const std::string& name) const; + LBUG_API std::vector> getQueryNodes() const; + LBUG_API std::vector> getQueryRels() const; + +private: + std::vector mergeGraphs(common::idx_t baseGraphIdx); + +private: + std::vector queryGraphs; +}; + +struct BoundGraphPattern { + QueryGraphCollection queryGraphCollection; + std::shared_ptr where; + + BoundGraphPattern() = default; + DELETE_COPY_DEFAULT_MOVE(BoundGraphPattern); +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/query_graph_label_analyzer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/query_graph_label_analyzer.h new file mode 100644 index 0000000000..ed665d02bb --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/query_graph_label_analyzer.h @@ -0,0 +1,25 @@ +#pragma once + +#include "query_graph.h" + +namespace lbug { +namespace binder { + +class QueryGraphLabelAnalyzer { +public: + explicit QueryGraphLabelAnalyzer(const main::ClientContext& clientContext, bool throwOnViolate) + : throwOnViolate{throwOnViolate}, clientContext{clientContext} {} + + void pruneLabel(QueryGraph& graph) const; + +private: + void pruneNode(const QueryGraph& graph, NodeExpression& node) const; + void pruneRel(RelExpression& rel) const; + +private: + bool throwOnViolate; + const main::ClientContext& clientContext; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/reading_clause/bound_join_hint.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/reading_clause/bound_join_hint.h new file mode 100644 index 0000000000..f1be0a416e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/reading_clause/bound_join_hint.h @@ -0,0 +1,26 @@ +#pragma once + +#include "binder/expression/expression.h" + +namespace lbug { +namespace binder { + +struct BoundJoinHintNode { + std::shared_ptr nodeOrRel; + std::vector> children; + + BoundJoinHintNode() = default; + explicit BoundJoinHintNode(std::shared_ptr nodeOrRel) + : nodeOrRel{std::move(nodeOrRel)} {} + + void addChild(std::shared_ptr child) { + children.push_back(std::move(child)); + } + + bool isLeaf() const { return children.empty(); } + bool isBinary() const { return children.size() == 2; } + bool isMultiWay() const { return children.size() > 2; } +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/reading_clause/bound_load_from.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/reading_clause/bound_load_from.h new file mode 100644 index 0000000000..2a8c3522a4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/reading_clause/bound_load_from.h @@ -0,0 +1,23 @@ +#pragma once + +#include "binder/bound_table_scan_info.h" +#include "bound_reading_clause.h" + +namespace lbug { +namespace binder { + +class BoundLoadFrom final : public BoundReadingClause { + static constexpr common::ClauseType clauseType_ 
= common::ClauseType::LOAD_FROM; + +public: + explicit BoundLoadFrom(BoundTableScanInfo info) + : BoundReadingClause{clauseType_}, info{std::move(info)} {} + + const BoundTableScanInfo* getInfo() const { return &info; } + +private: + BoundTableScanInfo info; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/reading_clause/bound_match_clause.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/reading_clause/bound_match_clause.h new file mode 100644 index 0000000000..a785148dbe --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/reading_clause/bound_match_clause.h @@ -0,0 +1,34 @@ +#pragma once + +#include "binder/query/query_graph.h" +#include "bound_join_hint.h" +#include "bound_reading_clause.h" + +namespace lbug { +namespace binder { + +class LBUG_API BoundMatchClause final : public BoundReadingClause { + static constexpr common::ClauseType clauseType_ = common::ClauseType::MATCH; + +public: + BoundMatchClause(QueryGraphCollection collection, common::MatchClauseType matchClauseType) + : BoundReadingClause{clauseType_}, collection{std::move(collection)}, + matchClauseType{matchClauseType} {} + + QueryGraphCollection* getQueryGraphCollectionUnsafe() { return &collection; } + const QueryGraphCollection* getQueryGraphCollection() const { return &collection; } + + common::MatchClauseType getMatchClauseType() const { return matchClauseType; } + + void setHint(std::shared_ptr root) { hintRoot = std::move(root); } + bool hasHint() const { return hintRoot != nullptr; } + std::shared_ptr getHint() const { return hintRoot; } + +private: + QueryGraphCollection collection; + common::MatchClauseType matchClauseType; + std::shared_ptr hintRoot; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/reading_clause/bound_reading_clause.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/reading_clause/bound_reading_clause.h new file mode 100644 index 0000000000..407c000431 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/reading_clause/bound_reading_clause.h @@ -0,0 +1,48 @@ +#pragma once + +#include "binder/expression/expression.h" +#include "common/enums/clause_type.h" + +namespace lbug { +namespace binder { + +class LBUG_API BoundReadingClause { +public: + explicit BoundReadingClause(common::ClauseType clauseType) : clauseType{clauseType} {} + DELETE_COPY_DEFAULT_MOVE(BoundReadingClause); + virtual ~BoundReadingClause() = default; + + common::ClauseType getClauseType() const { return clauseType; } + + void setPredicate(std::shared_ptr predicate_) { predicate = std::move(predicate_); } + bool hasPredicate() const { return predicate != nullptr; } + std::shared_ptr getPredicate() const { return predicate; } + expression_vector getConjunctivePredicates() const { + return hasPredicate() ? 
predicate->splitOnAND() : expression_vector{}; + } + + template + TARGET& cast() { + return common::ku_dynamic_cast(*this); + } + template + const TARGET& constCast() const { + return common::ku_dynamic_cast(*this); + } + template + TARGET* ptrCast() const { + return common::ku_dynamic_cast(this); + } + template + const TARGET* constPtrCast() const { + return common::ku_dynamic_cast(this); + } + +private: + common::ClauseType clauseType; + // Predicate in WHERE clause + std::shared_ptr predicate; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/reading_clause/bound_table_function_call.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/reading_clause/bound_table_function_call.h new file mode 100644 index 0000000000..7e6d6bfb82 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/reading_clause/bound_table_function_call.h @@ -0,0 +1,25 @@ +#pragma once + +#include "binder/bound_table_scan_info.h" +#include "binder/query/reading_clause/bound_reading_clause.h" +#include "function/table/table_function.h" + +namespace lbug { +namespace binder { + +class LBUG_API BoundTableFunctionCall : public BoundReadingClause { + static constexpr common::ClauseType clauseType_ = common::ClauseType::TABLE_FUNCTION_CALL; + +public: + explicit BoundTableFunctionCall(BoundTableScanInfo info) + : BoundReadingClause{clauseType_}, info{std::move(info)} {} + + const function::TableFunction& getTableFunc() const { return info.func; } + const function::TableFuncBindData* getBindData() const { return info.bindData.get(); } + +private: + BoundTableScanInfo info; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/reading_clause/bound_unwind_clause.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/reading_clause/bound_unwind_clause.h new file mode 100644 index 0000000000..ee94daa258 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/reading_clause/bound_unwind_clause.h @@ -0,0 +1,27 @@ +#pragma once + +#include "binder/expression/expression.h" +#include "bound_reading_clause.h" + +namespace lbug { +namespace binder { + +class BoundUnwindClause final : public BoundReadingClause { +public: + BoundUnwindClause(std::shared_ptr inExpr, std::shared_ptr outExpr, + std::shared_ptr idExpr) + : BoundReadingClause{common::ClauseType::UNWIND}, inExpr{std::move(inExpr)}, + outExpr{std::move(outExpr)}, idExpr{std::move(idExpr)} {} + + std::shared_ptr getInExpr() const { return inExpr; } + std::shared_ptr getOutExpr() const { return outExpr; } + std::shared_ptr getIDExpr() const { return idExpr; } + +private: + std::shared_ptr inExpr; + std::shared_ptr outExpr; + std::shared_ptr idExpr; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/return_with_clause/bound_projection_body.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/return_with_clause/bound_projection_body.h new file mode 100644 index 0000000000..49cdece68f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/return_with_clause/bound_projection_body.h @@ -0,0 +1,72 @@ +#pragma once + +#include "binder/expression/expression.h" + +namespace lbug { +namespace binder { + +class BoundProjectionBody { + static constexpr uint64_t INVALID_NUMBER = UINT64_MAX; + +public: + explicit BoundProjectionBody(bool distinct) + : distinct{distinct}, skipNumber{nullptr}, 
limitNumber{nullptr} {} + EXPLICIT_COPY_DEFAULT_MOVE(BoundProjectionBody); + + bool isDistinct() const { return distinct; } + + void setProjectionExpressions(expression_vector expressions) { + projectionExpressions = std::move(expressions); + } + expression_vector getProjectionExpressions() const { return projectionExpressions; } + + void setGroupByExpressions(expression_vector expressions) { + groupByExpressions = std::move(expressions); + } + expression_vector getGroupByExpressions() const { return groupByExpressions; } + + void setAggregateExpressions(expression_vector expressions) { + aggregateExpressions = std::move(expressions); + } + bool hasAggregateExpressions() const { return !aggregateExpressions.empty(); } + expression_vector getAggregateExpressions() const { return aggregateExpressions; } + + void setOrderByExpressions(expression_vector expressions, std::vector sortOrders) { + orderByExpressions = std::move(expressions); + isAscOrders = std::move(sortOrders); + } + bool hasOrderByExpressions() const { return !orderByExpressions.empty(); } + const expression_vector& getOrderByExpressions() const { return orderByExpressions; } + const std::vector& getSortingOrders() const { return isAscOrders; } + + void setSkipNumber(std::shared_ptr number) { skipNumber = std::move(number); } + bool hasSkip() const { return skipNumber != nullptr; } + std::shared_ptr getSkipNumber() const { return skipNumber; } + + void setLimitNumber(std::shared_ptr number) { limitNumber = std::move(number); } + bool hasLimit() const { return limitNumber != nullptr; } + std::shared_ptr getLimitNumber() const { return limitNumber; } + + bool hasSkipOrLimit() const { return hasSkip() || hasLimit(); } + +private: + BoundProjectionBody(const BoundProjectionBody& other) + : distinct{other.distinct}, projectionExpressions{other.projectionExpressions}, + groupByExpressions{other.groupByExpressions}, + aggregateExpressions{other.aggregateExpressions}, + orderByExpressions{other.orderByExpressions}, isAscOrders{other.isAscOrders}, + skipNumber{other.skipNumber}, limitNumber{other.limitNumber} {} + +private: + bool distinct; + expression_vector projectionExpressions; + expression_vector groupByExpressions; + expression_vector aggregateExpressions; + expression_vector orderByExpressions; + std::vector isAscOrders; + std::shared_ptr skipNumber; + std::shared_ptr limitNumber; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/return_with_clause/bound_return_clause.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/return_with_clause/bound_return_clause.h new file mode 100644 index 0000000000..8e6c1a7300 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/return_with_clause/bound_return_clause.h @@ -0,0 +1,28 @@ +#pragma once + +#include "binder/bound_statement_result.h" +#include "bound_projection_body.h" + +namespace lbug { +namespace binder { + +class BoundReturnClause { +public: + explicit BoundReturnClause(BoundProjectionBody projectionBody) + : projectionBody{std::move(projectionBody)} {} + BoundReturnClause(BoundProjectionBody projectionBody, BoundStatementResult statementResult) + : projectionBody{std::move(projectionBody)}, statementResult{std::move(statementResult)} {} + DELETE_COPY_DEFAULT_MOVE(BoundReturnClause); + virtual ~BoundReturnClause() = default; + + inline const BoundProjectionBody* getProjectionBody() const { return &projectionBody; } + + inline const BoundStatementResult* getStatementResult() const { 
return &statementResult; } + +protected: + BoundProjectionBody projectionBody; + BoundStatementResult statementResult; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/return_with_clause/bound_with_clause.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/return_with_clause/bound_with_clause.h new file mode 100644 index 0000000000..f361c4952b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/return_with_clause/bound_with_clause.h @@ -0,0 +1,24 @@ +#pragma once + +#include "bound_return_clause.h" + +namespace lbug { +namespace binder { + +class BoundWithClause final : public BoundReturnClause { +public: + explicit BoundWithClause(BoundProjectionBody projectionBody) + : BoundReturnClause{std::move(projectionBody)} {} + + void setWhereExpression(std::shared_ptr expression) { + whereExpression = std::move(expression); + } + bool hasWhereExpression() const { return whereExpression != nullptr; } + std::shared_ptr getWhereExpression() const { return whereExpression; } + +private: + std::shared_ptr whereExpression; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/updating_clause/bound_delete_clause.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/updating_clause/bound_delete_clause.h new file mode 100644 index 0000000000..36ade5a9ce --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/updating_clause/bound_delete_clause.h @@ -0,0 +1,42 @@ +#pragma once + +#include "bound_delete_info.h" +#include "bound_updating_clause.h" + +namespace lbug { +namespace binder { + +class BoundDeleteClause final : public BoundUpdatingClause { +public: + BoundDeleteClause() : BoundUpdatingClause{common::ClauseType::DELETE_} {}; + + void addInfo(BoundDeleteInfo info) { infos.push_back(std::move(info)); } + + bool hasNodeInfo() const { + return hasInfo( + [](const BoundDeleteInfo& info) { return info.tableType == common::TableType::NODE; }); + } + std::vector getNodeInfos() const { + return getInfos( + [](const BoundDeleteInfo& info) { return info.tableType == common::TableType::NODE; }); + } + bool hasRelInfo() const { + return hasInfo( + [](const BoundDeleteInfo& info) { return info.tableType == common::TableType::REL; }); + } + std::vector getRelInfos() const { + return getInfos( + [](const BoundDeleteInfo& info) { return info.tableType == common::TableType::REL; }); + } + +private: + bool hasInfo(const std::function& check) const; + std::vector getInfos( + const std::function& check) const; + +private: + std::vector infos; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/updating_clause/bound_delete_info.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/updating_clause/bound_delete_info.h new file mode 100644 index 0000000000..e554448f53 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/updating_clause/bound_delete_info.h @@ -0,0 +1,28 @@ +#pragma once + +#include "binder/expression/expression.h" +#include "common/enums/delete_type.h" +#include "common/enums/table_type.h" + +namespace lbug { +namespace binder { + +struct BoundDeleteInfo { + common::DeleteNodeType deleteType; + common::TableType tableType; + std::shared_ptr pattern; + + BoundDeleteInfo(common::DeleteNodeType deleteType, common::TableType tableType, + std::shared_ptr pattern) + : deleteType{deleteType}, tableType{tableType}, 
+    EXPLICIT_COPY_DEFAULT_MOVE(BoundDeleteInfo);
+
+    std::string toString() const { return "Delete " + pattern->toString(); }
+
+private:
+    BoundDeleteInfo(const BoundDeleteInfo& other)
+        : deleteType{other.deleteType}, tableType{other.tableType}, pattern{other.pattern} {}
+};
+
+} // namespace binder
+} // namespace lbug
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/updating_clause/bound_insert_clause.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/updating_clause/bound_insert_clause.h
new file mode 100644
index 0000000000..186a5c5675
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/updating_clause/bound_insert_clause.h
@@ -0,0 +1,43 @@
+#pragma once
+
+#include "bound_insert_info.h"
+#include "bound_updating_clause.h"
+
+namespace lbug {
+namespace binder {
+
+class BoundInsertClause final : public BoundUpdatingClause {
+public:
+    explicit BoundInsertClause(std::vector<BoundInsertInfo> infos)
+        : BoundUpdatingClause{common::ClauseType::INSERT}, infos{std::move(infos)} {}
+
+    const std::vector<BoundInsertInfo>& getInfos() const { return infos; }
+
+    bool hasNodeInfo() const {
+        return hasInfo(
+            [](const BoundInsertInfo& info) { return info.tableType == common::TableType::NODE; });
+    }
+    std::vector<const BoundInsertInfo*> getNodeInfos() const {
+        return getInfos(
+            [](const BoundInsertInfo& info) { return info.tableType == common::TableType::NODE; });
+    }
+    bool hasRelInfo() const {
+        return hasInfo(
+            [](const BoundInsertInfo& info) { return info.tableType == common::TableType::REL; });
+    }
+    std::vector<const BoundInsertInfo*> getRelInfos() const {
+        return getInfos(
+            [](const BoundInsertInfo& info) { return info.tableType == common::TableType::REL; });
+    }
+
+private:
+    bool hasInfo(const std::function<bool(const BoundInsertInfo&)>& check) const;
+    std::vector<const BoundInsertInfo*> getInfos(
+        const std::function<bool(const BoundInsertInfo&)>& check) const;
+
+private:
+    std::vector<BoundInsertInfo> infos;
+};
+
+} // namespace binder
+} // namespace lbug
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/updating_clause/bound_insert_info.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/updating_clause/bound_insert_info.h
new file mode 100644
index 0000000000..97e2f23d68
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/updating_clause/bound_insert_info.h
@@ -0,0 +1,29 @@
+#pragma once
+
+#include "binder/expression/expression.h"
+#include "common/enums/conflict_action.h"
+#include "common/enums/table_type.h"
+
+namespace lbug {
+namespace binder {
+
+struct BoundInsertInfo {
+    common::TableType tableType;
+    std::shared_ptr<Expression> pattern;
+    expression_vector columnExprs;
+    expression_vector columnDataExprs;
+    common::ConflictAction conflictAction;
+
+    BoundInsertInfo(common::TableType tableType, std::shared_ptr<Expression> pattern)
+        : tableType{tableType}, pattern{std::move(pattern)},
+          conflictAction{common::ConflictAction::ON_CONFLICT_THROW} {}
+    EXPLICIT_COPY_DEFAULT_MOVE(BoundInsertInfo);
+
+private:
+    BoundInsertInfo(const BoundInsertInfo& other)
+        : tableType{other.tableType}, pattern{other.pattern}, columnExprs{other.columnExprs},
+          columnDataExprs{other.columnDataExprs}, conflictAction{other.conflictAction} {}
+};
+
+} // namespace binder
+} // namespace lbug
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/updating_clause/bound_merge_clause.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/updating_clause/bound_merge_clause.h
new file mode 100644
index 0000000000..c0d930bb64
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/updating_clause/bound_merge_clause.h
@@ -0,0 +1,141 @@
+#pragma once
+
+#include "binder/query/query_graph.h"
+#include "bound_insert_info.h"
+#include "bound_set_info.h"
+#include "bound_updating_clause.h"
+
+namespace lbug {
+namespace binder {
+
+class BoundMergeClause final : public BoundUpdatingClause {
+    static constexpr common::ClauseType type_ = common::ClauseType::MERGE;
+
+public:
+    BoundMergeClause(expression_vector columnDataExprs, std::shared_ptr<Expression> existenceMark,
+        std::shared_ptr<Expression> distinctMark, QueryGraphCollection queryGraphCollection,
+        std::shared_ptr<Expression> predicate, std::vector<BoundInsertInfo> insertInfos)
+        : BoundUpdatingClause{type_}, columnDataExprs{std::move(columnDataExprs)},
+          existenceMark{std::move(existenceMark)}, distinctMark{std::move(distinctMark)},
+          queryGraphCollection{std::move(queryGraphCollection)}, predicate{std::move(predicate)},
+          insertInfos{std::move(insertInfos)} {}
+
+    expression_vector getColumnDataExprs() const { return columnDataExprs; }
+
+    std::shared_ptr<Expression> getExistenceMark() const { return existenceMark; }
+    std::shared_ptr<Expression> getDistinctMark() const { return distinctMark; }
+
+    const QueryGraphCollection* getQueryGraphCollection() const { return &queryGraphCollection; }
+    bool hasPredicate() const { return predicate != nullptr; }
+    std::shared_ptr<Expression> getPredicate() const { return predicate; }
+
+    const std::vector<BoundInsertInfo>& getInsertInfosRef() const { return insertInfos; }
+    const std::vector<BoundSetPropertyInfo>& getOnMatchSetInfosRef() const {
+        return onMatchSetPropertyInfos;
+    }
+    const std::vector<BoundSetPropertyInfo>& getOnCreateSetInfosRef() const {
+        return onCreateSetPropertyInfos;
+    }
+
+    bool hasInsertNodeInfo() const {
+        return hasInsertInfo(
+            [](const BoundInsertInfo& info) { return info.tableType == common::TableType::NODE; });
+    }
+    std::vector<const BoundInsertInfo*> getInsertNodeInfos() const {
+        return getInsertInfos(
+            [](const BoundInsertInfo& info) { return info.tableType == common::TableType::NODE; });
+    }
+    bool hasInsertRelInfo() const {
+        return hasInsertInfo(
+            [](const BoundInsertInfo& info) { return info.tableType == common::TableType::REL; });
+    }
+    std::vector<const BoundInsertInfo*> getInsertRelInfos() const {
+        return getInsertInfos(
+            [](const BoundInsertInfo& info) { return info.tableType == common::TableType::REL; });
+    }
+
+    bool hasOnMatchSetNodeInfo() const {
+        return hasOnMatchSetInfo([](const BoundSetPropertyInfo& info) {
+            return info.tableType == common::TableType::NODE;
+        });
+    }
+    std::vector<const BoundSetPropertyInfo*> getOnMatchSetNodeInfos() const {
+        return getOnMatchSetInfos([](const BoundSetPropertyInfo& info) {
+            return info.tableType == common::TableType::NODE;
+        });
+    }
+    bool hasOnMatchSetRelInfo() const {
+        return hasOnMatchSetInfo([](const BoundSetPropertyInfo& info) {
+            return info.tableType == common::TableType::REL;
+        });
+    }
+    std::vector<const BoundSetPropertyInfo*> getOnMatchSetRelInfos() const {
+        return getOnMatchSetInfos([](const BoundSetPropertyInfo& info) {
+            return info.tableType == common::TableType::REL;
+        });
+    }
+
+    bool hasOnCreateSetNodeInfo() const {
+        return hasOnCreateSetInfo([](const BoundSetPropertyInfo& info) {
+            return info.tableType == common::TableType::NODE;
+        });
+    }
+    std::vector<const BoundSetPropertyInfo*> getOnCreateSetNodeInfos() const {
+        return getOnCreateSetInfos([](const BoundSetPropertyInfo& info) {
+            return info.tableType == common::TableType::NODE;
+        });
+    }
+    bool hasOnCreateSetRelInfo() const {
+        return hasOnCreateSetInfo([](const BoundSetPropertyInfo& info) {
+            return info.tableType == common::TableType::REL;
+        });
+    }
+    std::vector<const BoundSetPropertyInfo*> getOnCreateSetRelInfos() const {
+        return getOnCreateSetInfos([](const BoundSetPropertyInfo& info) {
+            return info.tableType == common::TableType::REL;
+        });
+    }
+
+    void addOnMatchSetPropertyInfo(BoundSetPropertyInfo setPropertyInfo) {
+        onMatchSetPropertyInfos.push_back(std::move(setPropertyInfo));
+    }
+    void addOnCreateSetPropertyInfo(BoundSetPropertyInfo setPropertyInfo) {
+        onCreateSetPropertyInfos.push_back(std::move(setPropertyInfo));
+    }
+
+private:
+    bool hasInsertInfo(const std::function<bool(const BoundInsertInfo&)>& check) const;
+    std::vector<const BoundInsertInfo*> getInsertInfos(
+        const std::function<bool(const BoundInsertInfo&)>& check) const;
+
+    bool hasOnMatchSetInfo(
+        const std::function<bool(const BoundSetPropertyInfo&)>& check) const;
+    std::vector<const BoundSetPropertyInfo*> getOnMatchSetInfos(
+        const std::function<bool(const BoundSetPropertyInfo&)>& check) const;
+
+    bool hasOnCreateSetInfo(
+        const std::function<bool(const BoundSetPropertyInfo&)>& check) const;
+    std::vector<const BoundSetPropertyInfo*> getOnCreateSetInfos(
+        const std::function<bool(const BoundSetPropertyInfo&)>& check) const;
+
+private:
+    // Capture user input column (right-hand-side) values in MERGE clause
+    // E.g. UNWIND [1,2,3] AS x MERGE (a {id:2, rank:x})
+    // this field should be {2, x}
+    expression_vector columnDataExprs;
+    // Internal marks
+    std::shared_ptr<Expression> existenceMark;
+    std::shared_ptr<Expression> distinctMark;
+    // Pattern to match.
+    QueryGraphCollection queryGraphCollection;
+    std::shared_ptr<Expression> predicate;
+    // Pattern to create on match failure.
+    std::vector<BoundInsertInfo> insertInfos;
+    // Update on match
+    std::vector<BoundSetPropertyInfo> onMatchSetPropertyInfos;
+    // Update on create
+    std::vector<BoundSetPropertyInfo> onCreateSetPropertyInfos;
+};
+
+} // namespace binder
+} // namespace lbug
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/updating_clause/bound_set_clause.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/updating_clause/bound_set_clause.h
new file mode 100644
index 0000000000..a24b503079
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/updating_clause/bound_set_clause.h
@@ -0,0 +1,47 @@
+#pragma once
+
+#include "bound_set_info.h"
+#include "bound_updating_clause.h"
+
+namespace lbug {
+namespace binder {
+
+class BoundSetClause final : public BoundUpdatingClause {
+public:
+    BoundSetClause() : BoundUpdatingClause{common::ClauseType::SET} {}
+
+    void addInfo(BoundSetPropertyInfo info) { infos.push_back(std::move(info)); }
+    const std::vector<BoundSetPropertyInfo>& getInfos() const { return infos; }
+
+    bool hasNodeInfo() const {
+        return hasInfo([](const BoundSetPropertyInfo& info) {
+            return info.tableType == common::TableType::NODE;
+        });
+    }
+    std::vector<const BoundSetPropertyInfo*> getNodeInfos() const {
+        return getInfos([](const BoundSetPropertyInfo& info) {
+            return info.tableType == common::TableType::NODE;
+        });
+    }
+    bool hasRelInfo() const {
+        return hasInfo([](const BoundSetPropertyInfo& info) {
+            return info.tableType == common::TableType::REL;
+        });
+    }
+    std::vector<const BoundSetPropertyInfo*> getRelInfos() const {
+        return getInfos([](const BoundSetPropertyInfo& info) {
+            return info.tableType == common::TableType::REL;
+        });
+    }
+
+private:
+    bool hasInfo(const std::function<bool(const BoundSetPropertyInfo&)>& check) const;
+    std::vector<const BoundSetPropertyInfo*> getInfos(
+        const std::function<bool(const BoundSetPropertyInfo&)>& check) const;
+
+private:
+    std::vector<BoundSetPropertyInfo> infos;
+};
+
+} // namespace binder
+} // namespace lbug
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/updating_clause/bound_set_info.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/updating_clause/bound_set_info.h
new file mode 100644
index 0000000000..32f948b9d0
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/updating_clause/bound_set_info.h
@@ -0,0 +1,28 @@
+#pragma once
+
+#include "binder/expression/expression.h"
+#include "common/enums/table_type.h"
+
+namespace lbug {
+namespace binder {
+
+struct BoundSetPropertyInfo {
+    common::TableType tableType;
+    std::shared_ptr<Expression> pattern;
+    std::shared_ptr<Expression> column;
+    std::shared_ptr<Expression> columnData;
+
+    BoundSetPropertyInfo(common::TableType tableType, std::shared_ptr<Expression> pattern,
+        std::shared_ptr<Expression> column, std::shared_ptr<Expression> columnData)
+        : tableType{tableType}, pattern{std::move(pattern)}, column{std::move(column)},
+          columnData{std::move(columnData)} {}
+    EXPLICIT_COPY_DEFAULT_MOVE(BoundSetPropertyInfo);
+
+private:
+    BoundSetPropertyInfo(const BoundSetPropertyInfo& other)
+        : tableType{other.tableType}, pattern{other.pattern}, column{other.column},
+          columnData{other.columnData} {}
+};
+
+} // namespace binder
+} // namespace lbug
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/updating_clause/bound_updating_clause.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/updating_clause/bound_updating_clause.h
new file mode 100644
index 0000000000..cedf6c0ba9
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/query/updating_clause/bound_updating_clause.h
@@ -0,0 +1,30 @@
+#pragma once
+
+#include "common/cast.h"
+#include "common/enums/clause_type.h"
+
+namespace lbug {
+namespace binder {
+
+class BoundUpdatingClause {
+public:
+    explicit BoundUpdatingClause(common::ClauseType clauseType) : clauseType{clauseType} {}
+    virtual ~BoundUpdatingClause() = default;
+
+    common::ClauseType getClauseType() const { return clauseType; }
+
+    template<class TARGET>
+    TARGET& cast() const {
+        return common::ku_dynamic_cast(*this);
+    }
+    template<class TARGET>
+    const TARGET& constCast() const {
+        return common::ku_dynamic_cast(*this);
+    }
+
+private:
+    common::ClauseType clauseType;
+};
+
+} // namespace binder
+} // namespace lbug
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/rewriter/match_clause_pattern_label_rewriter.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/rewriter/match_clause_pattern_label_rewriter.h
new file mode 100644
index 0000000000..4f2ba424c1
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/rewriter/match_clause_pattern_label_rewriter.h
@@ -0,0 +1,21 @@
+#pragma once
+
+#include "binder/bound_statement_visitor.h"
+#include "binder/query/query_graph_label_analyzer.h"
+
+namespace lbug {
+namespace binder {
+
+class MatchClausePatternLabelRewriter final : public BoundStatementVisitor {
+public:
+    explicit MatchClausePatternLabelRewriter(const main::ClientContext& clientContext)
+        : analyzer{clientContext, false /* throwOnViolate */} {}
+
+    void visitMatchUnsafe(BoundReadingClause& readingClause) override;
+
+private:
+    QueryGraphLabelAnalyzer analyzer;
+};
+
+} // namespace binder
+} // namespace lbug
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/rewriter/normalized_query_part_match_rewriter.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/rewriter/normalized_query_part_match_rewriter.h
new file mode 100644
index 0000000000..4d42d95bd6
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/rewriter/normalized_query_part_match_rewriter.h
@@ -0,0 +1,31 @@
+#pragma once
+
+#include "binder/bound_statement_visitor.h"
+
+namespace lbug {
+namespace main {
+class ClientContext;
+}
+namespace binder {
+
+// Merge consecutive MATCH patterns in a query part. E.g.
+// MATCH (a) WHERE a.ID = 0
+// MATCH (b) WHERE b.ID = 1
+// MATCH (a)-[]->(b)
+// will be rewritten as
+// MATCH (a)-[]->(b) WHERE a.ID = 0 AND b.ID = 1
+// This rewrite does not apply to MATCH with HINT or OPTIONAL MATCH.
+class NormalizedQueryPartMatchRewriter final : public BoundStatementVisitor {
+public:
+    explicit NormalizedQueryPartMatchRewriter(main::ClientContext* clientContext)
+        : clientContext{clientContext} {}
+
+private:
+    void visitQueryPartUnsafe(NormalizedQueryPart& queryPart) override;
+
+private:
+    main::ClientContext* clientContext;
+};
+
+} // namespace binder
+} // namespace lbug
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/rewriter/with_clause_projection_rewriter.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/rewriter/with_clause_projection_rewriter.h
new file mode 100644
index 0000000000..3ef6ad1999
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/rewriter/with_clause_projection_rewriter.h
@@ -0,0 +1,20 @@
+#pragma once
+
+#include "binder/bound_statement_visitor.h"
+
+namespace lbug {
+namespace binder {
+
+// WithClauseProjectionRewriter first analyzes the properties that need to be scanned for each query
+// and then rewrites node/rel expressions in the WITH clause as their properties, so we avoid eagerly
+// evaluating node/rel in the WITH clause projection. E.g.
+// MATCH (a) WITH a MATCH (a)->(b);
+// will be rewritten as
+// MATCH (a) WITH a._id MATCH (a)->(b);
+class WithClauseProjectionRewriter final : public BoundStatementVisitor {
+public:
+    void visitSingleQueryUnsafe(NormalizedSingleQuery& singleQuery) override;
+};
+
+} // namespace binder
+} // namespace lbug
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/visitor/confidential_statement_analyzer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/visitor/confidential_statement_analyzer.h
new file mode 100644
index 0000000000..d0a19f6ebc
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/visitor/confidential_statement_analyzer.h
@@ -0,0 +1,20 @@
+#pragma once
+
+#include "binder/bound_statement_visitor.h"
+
+namespace lbug {
+namespace binder {
+
+class ConfidentialStatementAnalyzer final : public BoundStatementVisitor {
+public:
+    bool isConfidential() const { return confidentialStatement; }
+
+private:
+    void visitStandaloneCall(const BoundStatement& boundStatement) override;
+
+private:
+    bool confidentialStatement = false;
+};
+
+} // namespace binder
+} // namespace lbug
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/visitor/default_type_solver.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/visitor/default_type_solver.h
new file mode 100644
index 0000000000..381dd181bb
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/visitor/default_type_solver.h
@@ -0,0 +1,21 @@
+#pragma once
+
+#include "binder/bound_statement_visitor.h"
+
+namespace lbug {
+namespace binder {
+
+// Assign a default data type (STRING) for expressions with ANY data type for a given statement.
+// E.g. RETURN NULL; Expression NULL can be resolved as any type based on its semantics.
+// We don't iterate over all expressions because
+// - predicates must have been resolved to BOOL type
+// - lhs expressions for update must have been resolved to column type
+// So we only need to resolve expressions that appear in the projection clause. This
+// assumption might change as we add more features.
+class DefaultTypeSolver final : public BoundStatementVisitor { +private: + void visitProjectionBody(const BoundProjectionBody& projectionBody) override; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/visitor/property_collector.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/visitor/property_collector.h new file mode 100644 index 0000000000..d9a52f3621 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/binder/visitor/property_collector.h @@ -0,0 +1,42 @@ +#pragma once + +#include "binder/bound_statement_visitor.h" + +namespace lbug { +namespace binder { + +// Collect all property expressions for a given statement. +class LBUG_API PropertyCollector final : public BoundStatementVisitor { +public: + expression_vector getProperties() const; + + // Skip collecting node/rel properties if they are in WITH projection list. + // See with_clause_projection_rewriter for more details. + void visitSingleQuerySkipNodeRel(const NormalizedSingleQuery& singleQuery); + +private: + void visitQueryPartSkipNodeRel(const NormalizedQueryPart& queryPart); + + void visitMatch(const BoundReadingClause& readingClause) override; + void visitUnwind(const BoundReadingClause& readingClause) override; + void visitLoadFrom(const BoundReadingClause& readingClause) override; + void visitTableFunctionCall(const BoundReadingClause&) override; + + void visitSet(const BoundUpdatingClause& updatingClause) override; + void visitDelete(const BoundUpdatingClause& updatingClause) override; + void visitInsert(const BoundUpdatingClause& updatingClause) override; + void visitMerge(const BoundUpdatingClause& updatingClause) override; + + void visitProjectionBodySkipNodeRel(const BoundProjectionBody& projectionBody); + void visitProjectionBody(const BoundProjectionBody& projectionBody) override; + void visitProjectionBodyPredicate(const std::shared_ptr& predicate) override; + + void collectProperties(const std::shared_ptr& expression); + void collectPropertiesSkipNodeRel(const std::shared_ptr& expression); + +private: + expression_set properties; +}; + +} // namespace binder +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/c_api/helpers.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/c_api/helpers.h new file mode 100644 index 0000000000..21b386ab45 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/c_api/helpers.h @@ -0,0 +1,15 @@ +#pragma once + +#include +#include +#ifdef _WIN32 +#include + +#include + +time_t convertTmToTime(struct tm tm); + +int32_t convertTimeToTm(time_t time, struct tm* out_tm); +#endif + +char* convertToOwnedCString(const std::string& str); diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/c_api/lbug.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/c_api/lbug.h new file mode 100644 index 0000000000..c5d6cbac10 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/c_api/lbug.h @@ -0,0 +1,1588 @@ +#pragma once +#include +#include +#include +#ifdef _WIN32 +#include +#endif + +/* Export header from common/api.h */ +// Helpers +#if defined _WIN32 || defined __CYGWIN__ +#define LBUG_HELPER_DLL_IMPORT __declspec(dllimport) +#define LBUG_HELPER_DLL_EXPORT __declspec(dllexport) +#define LBUG_HELPER_DLL_LOCAL +#define LBUG_HELPER_DEPRECATED __declspec(deprecated) +#else +#define LBUG_HELPER_DLL_IMPORT __attribute__((visibility("default"))) +#define LBUG_HELPER_DLL_EXPORT __attribute__((visibility("default"))) +#define LBUG_HELPER_DLL_LOCAL __attribute__((visibility("hidden"))) 
+#define LBUG_HELPER_DEPRECATED __attribute__((__deprecated__)) +#endif + +#ifdef LBUG_STATIC_DEFINE +#define LBUG_API +#define LBUG_NO_EXPORT +#else +#ifndef LBUG_API +#ifdef LBUG_EXPORTS +/* We are building this library */ +#define LBUG_API LBUG_HELPER_DLL_EXPORT +#else +/* We are using this library */ +#define LBUG_API LBUG_HELPER_DLL_IMPORT +#endif +#endif + +#endif + +#ifndef LBUG_DEPRECATED +#define LBUG_DEPRECATED LBUG_HELPER_DEPRECATED +#endif + +#ifndef LBUG_DEPRECATED_EXPORT +#define LBUG_DEPRECATED_EXPORT LBUG_API LBUG_DEPRECATED +#endif +/* end export header */ + +// The Arrow C data interface. +// https://arrow.apache.org/docs/format/CDataInterface.html + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +#ifndef ARROW_C_DATA_INTERFACE +#define ARROW_C_DATA_INTERFACE + +#define ARROW_FLAG_DICTIONARY_ORDERED 1 +#define ARROW_FLAG_NULLABLE 2 +#define ARROW_FLAG_MAP_KEYS_SORTED 4 + +struct ArrowSchema { + // Array type description + const char* format; + const char* name; + const char* metadata; + int64_t flags; + int64_t n_children; + struct ArrowSchema** children; + struct ArrowSchema* dictionary; + + // Release callback + void (*release)(struct ArrowSchema*); + // Opaque producer-specific data + void* private_data; +}; + +struct ArrowArray { + // Array data description + int64_t length; + int64_t null_count; + int64_t offset; + int64_t n_buffers; + int64_t n_children; + const void** buffers; + struct ArrowArray** children; + struct ArrowArray* dictionary; + + // Release callback + void (*release)(struct ArrowArray*); + // Opaque producer-specific data + void* private_data; +}; + +#endif // ARROW_C_DATA_INTERFACE + +#ifdef __cplusplus +} +#endif + +#ifdef __cplusplus +#define LBUG_C_API extern "C" LBUG_API +#else +#define LBUG_C_API LBUG_API +#endif + +/** + * @brief Stores runtime configuration for creating or opening a Database + */ +typedef struct { + // bufferPoolSize Max size of the buffer pool in bytes. + // The larger the buffer pool, the more data from the database files is kept in memory, + // reducing the amount of File I/O + uint64_t buffer_pool_size; + // The maximum number of threads to use during query execution + uint64_t max_num_threads; + // Whether or not to compress data on-disk for supported types + bool enable_compression; + // If true, open the database in read-only mode. No write transaction is allowed on the Database + // object. If false, open the database read-write. + bool read_only; + // The maximum size of the database in bytes. Note that this is introduced temporarily for now + // to get around with the default 8TB mmap address space limit under some environment. This + // will be removed once we implemente a better solution later. The value is default to 1 << 43 + // (8TB) under 64-bit environment and 1GB under 32-bit one (see `DEFAULT_VM_REGION_MAX_SIZE`). + uint64_t max_db_size; + // If true, the database will automatically checkpoint when the size of + // the WAL file exceeds the checkpoint threshold. + bool auto_checkpoint; + // The threshold of the WAL file size in bytes. When the size of the + // WAL file exceeds this threshold, the database will checkpoint if auto_checkpoint is true. + uint64_t checkpoint_threshold; + +#if defined(__APPLE__) + // The thread quality of service (QoS) for the worker threads. + // This works for Swift bindings on Apple platforms only. + uint32_t thread_qos; +#endif +} lbug_system_config; + +/** + * @brief lbug_database manages all database components. 
+ */ +typedef struct { + void* _database; +} lbug_database; + +/** + * @brief lbug_connection is used to interact with a Database instance. Each connection is + * thread-safe. Multiple connections can connect to the same Database instance in a multi-threaded + * environment. + */ +typedef struct { + void* _connection; +} lbug_connection; + +/** + * @brief lbug_prepared_statement is a parameterized query which can avoid planning the same query + * for repeated execution. + */ +typedef struct { + void* _prepared_statement; + void* _bound_values; +} lbug_prepared_statement; + +/** + * @brief lbug_query_result stores the result of a query. + */ +typedef struct { + void* _query_result; + bool _is_owned_by_cpp; +} lbug_query_result; + +/** + * @brief lbug_flat_tuple stores a vector of values. + */ +typedef struct { + void* _flat_tuple; + bool _is_owned_by_cpp; +} lbug_flat_tuple; + +/** + * @brief lbug_logical_type is the lbug internal representation of data types. + */ +typedef struct { + void* _data_type; +} lbug_logical_type; + +/** + * @brief lbug_value is used to represent a value with any lbug internal dataType. + */ +typedef struct { + void* _value; + bool _is_owned_by_cpp; +} lbug_value; + +/** + * @brief lbug internal internal_id type which stores the table_id and offset of a node/rel. + */ +typedef struct { + uint64_t table_id; + uint64_t offset; +} lbug_internal_id_t; + +/** + * @brief lbug internal date type which stores the number of days since 1970-01-01 00:00:00 UTC. + */ +typedef struct { + // Days since 1970-01-01 00:00:00 UTC. + int32_t days; +} lbug_date_t; + +/** + * @brief lbug internal timestamp_ns type which stores the number of nanoseconds since 1970-01-01 + * 00:00:00 UTC. + */ +typedef struct { + // Nanoseconds since 1970-01-01 00:00:00 UTC. + int64_t value; +} lbug_timestamp_ns_t; + +/** + * @brief lbug internal timestamp_ms type which stores the number of milliseconds since 1970-01-01 + * 00:00:00 UTC. + */ +typedef struct { + // Milliseconds since 1970-01-01 00:00:00 UTC. + int64_t value; +} lbug_timestamp_ms_t; + +/** + * @brief lbug internal timestamp_sec_t type which stores the number of seconds since 1970-01-01 + * 00:00:00 UTC. + */ +typedef struct { + // Seconds since 1970-01-01 00:00:00 UTC. + int64_t value; +} lbug_timestamp_sec_t; + +/** + * @brief lbug internal timestamp_tz type which stores the number of microseconds since 1970-01-01 + * with timezone 00:00:00 UTC. + */ +typedef struct { + // Microseconds since 1970-01-01 00:00:00 UTC. + int64_t value; +} lbug_timestamp_tz_t; + +/** + * @brief lbug internal timestamp type which stores the number of microseconds since 1970-01-01 + * 00:00:00 UTC. + */ +typedef struct { + // Microseconds since 1970-01-01 00:00:00 UTC. + int64_t value; +} lbug_timestamp_t; + +/** + * @brief lbug internal interval type which stores the months, days and microseconds. + */ +typedef struct { + int32_t months; + int32_t days; + int64_t micros; +} lbug_interval_t; + +/** + * @brief lbug_query_summary stores the execution time, plan, compiling time and query options of a + * query. + */ +typedef struct { + void* _query_summary; +} lbug_query_summary; + +typedef struct { + uint64_t low; + int64_t high; +} lbug_int128_t; + +/** + * @brief enum class for lbug internal dataTypes. + */ +typedef enum { + LBUG_ANY = 0, + LBUG_NODE = 10, + LBUG_REL = 11, + LBUG_RECURSIVE_REL = 12, + // SERIAL is a special data type that is used to represent a sequence of INT64 values that are + // incremented by 1 starting from 0. 
+ LBUG_SERIAL = 13, + // fixed size types + LBUG_BOOL = 22, + LBUG_INT64 = 23, + LBUG_INT32 = 24, + LBUG_INT16 = 25, + LBUG_INT8 = 26, + LBUG_UINT64 = 27, + LBUG_UINT32 = 28, + LBUG_UINT16 = 29, + LBUG_UINT8 = 30, + LBUG_INT128 = 31, + LBUG_DOUBLE = 32, + LBUG_FLOAT = 33, + LBUG_DATE = 34, + LBUG_TIMESTAMP = 35, + LBUG_TIMESTAMP_SEC = 36, + LBUG_TIMESTAMP_MS = 37, + LBUG_TIMESTAMP_NS = 38, + LBUG_TIMESTAMP_TZ = 39, + LBUG_INTERVAL = 40, + LBUG_DECIMAL = 41, + LBUG_INTERNAL_ID = 42, + // variable size types + LBUG_STRING = 50, + LBUG_BLOB = 51, + LBUG_LIST = 52, + LBUG_ARRAY = 53, + LBUG_STRUCT = 54, + LBUG_MAP = 55, + LBUG_UNION = 56, + LBUG_POINTER = 58, + LBUG_UUID = 59 +} lbug_data_type_id; + +/** + * @brief enum class for lbug function return state. + */ +typedef enum { LbugSuccess = 0, LbugError = 1 } lbug_state; + +// Database +/** + * @brief Allocates memory and creates a lbug database instance at database_path with + * bufferPoolSize=buffer_pool_size. Caller is responsible for calling lbug_database_destroy() to + * release the allocated memory. + * @param database_path The path to the database. + * @param system_config The runtime configuration for creating or opening the database. + * @param[out] out_database The output parameter that will hold the database instance. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_database_init(const char* database_path, + lbug_system_config system_config, lbug_database* out_database); +/** + * @brief Destroys the lbug database instance and frees the allocated memory. + * @param database The database instance to destroy. + */ +LBUG_C_API void lbug_database_destroy(lbug_database* database); + +LBUG_C_API lbug_system_config lbug_default_system_config(); + +// Connection +/** + * @brief Allocates memory and creates a connection to the database. Caller is responsible for + * calling lbug_connection_destroy() to release the allocated memory. + * @param database The database instance to connect to. + * @param[out] out_connection The output parameter that will hold the connection instance. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_connection_init(lbug_database* database, + lbug_connection* out_connection); +/** + * @brief Destroys the connection instance and frees the allocated memory. + * @param connection The connection instance to destroy. + */ +LBUG_C_API void lbug_connection_destroy(lbug_connection* connection); +/** + * @brief Sets the maximum number of threads to use for executing queries. + * @param connection The connection instance to set max number of threads for execution. + * @param num_threads The maximum number of threads to use for executing queries. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_connection_set_max_num_thread_for_exec(lbug_connection* connection, + uint64_t num_threads); + +/** + * @brief Returns the maximum number of threads of the connection to use for executing queries. + * @param connection The connection instance to return max number of threads for execution. + * @param[out] out_result The output parameter that will hold the maximum number of threads to use + * for executing queries. + * @return The state indicating the success or failure of the operation. 
+ */ +LBUG_C_API lbug_state lbug_connection_get_max_num_thread_for_exec(lbug_connection* connection, + uint64_t* out_result); +/** + * @brief Executes the given query and returns the result. + * @param connection The connection instance to execute the query. + * @param query The query to execute. + * @param[out] out_query_result The output parameter that will hold the result of the query. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_connection_query(lbug_connection* connection, const char* query, + lbug_query_result* out_query_result); +/** + * @brief Prepares the given query and returns the prepared statement. + * @param connection The connection instance to prepare the query. + * @param query The query to prepare. + * @param[out] out_prepared_statement The output parameter that will hold the prepared statement. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_connection_prepare(lbug_connection* connection, const char* query, + lbug_prepared_statement* out_prepared_statement); +/** + * @brief Executes the prepared_statement using connection. + * @param connection The connection instance to execute the prepared_statement. + * @param prepared_statement The prepared statement to execute. + * @param[out] out_query_result The output parameter that will hold the result of the query. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_connection_execute(lbug_connection* connection, + lbug_prepared_statement* prepared_statement, lbug_query_result* out_query_result); +/** + * @brief Interrupts the current query execution in the connection. + * @param connection The connection instance to interrupt. + */ +LBUG_C_API void lbug_connection_interrupt(lbug_connection* connection); +/** + * @brief Sets query timeout value in milliseconds for the connection. + * @param connection The connection instance to set query timeout value. + * @param timeout_in_ms The timeout value in milliseconds. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_connection_set_query_timeout(lbug_connection* connection, + uint64_t timeout_in_ms); + +// PreparedStatement +/** + * @brief Destroys the prepared statement instance and frees the allocated memory. + * @param prepared_statement The prepared statement instance to destroy. + */ +LBUG_C_API void lbug_prepared_statement_destroy(lbug_prepared_statement* prepared_statement); +/** + * @return the query is prepared successfully or not. + */ +LBUG_C_API bool lbug_prepared_statement_is_success(lbug_prepared_statement* prepared_statement); +/** + * @brief Returns the error message if the prepared statement is not prepared successfully. + * The caller is responsible for freeing the returned string with `lbug_destroy_string`. + * @param prepared_statement The prepared statement instance. + * @return the error message if the statement is not prepared successfully or null + * if the statement is prepared successfully. + */ +LBUG_C_API char* lbug_prepared_statement_get_error_message( + lbug_prepared_statement* prepared_statement); +/** + * @brief Binds the given boolean value to the given parameter name in the prepared statement. + * @param prepared_statement The prepared statement instance to bind the value. + * @param param_name The parameter name to bind the value. + * @param value The boolean value to bind. 
+ * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_prepared_statement_bind_bool(lbug_prepared_statement* prepared_statement, + const char* param_name, bool value); +/** + * @brief Binds the given int64_t value to the given parameter name in the prepared statement. + * @param prepared_statement The prepared statement instance to bind the value. + * @param param_name The parameter name to bind the value. + * @param value The int64_t value to bind. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_prepared_statement_bind_int64( + lbug_prepared_statement* prepared_statement, const char* param_name, int64_t value); +/** + * @brief Binds the given int32_t value to the given parameter name in the prepared statement. + * @param prepared_statement The prepared statement instance to bind the value. + * @param param_name The parameter name to bind the value. + * @param value The int32_t value to bind. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_prepared_statement_bind_int32( + lbug_prepared_statement* prepared_statement, const char* param_name, int32_t value); +/** + * @brief Binds the given int16_t value to the given parameter name in the prepared statement. + * @param prepared_statement The prepared statement instance to bind the value. + * @param param_name The parameter name to bind the value. + * @param value The int16_t value to bind. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_prepared_statement_bind_int16( + lbug_prepared_statement* prepared_statement, const char* param_name, int16_t value); +/** + * @brief Binds the given int8_t value to the given parameter name in the prepared statement. + * @param prepared_statement The prepared statement instance to bind the value. + * @param param_name The parameter name to bind the value. + * @param value The int8_t value to bind. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_prepared_statement_bind_int8(lbug_prepared_statement* prepared_statement, + const char* param_name, int8_t value); +/** + * @brief Binds the given uint64_t value to the given parameter name in the prepared statement. + * @param prepared_statement The prepared statement instance to bind the value. + * @param param_name The parameter name to bind the value. + * @param value The uint64_t value to bind. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_prepared_statement_bind_uint64( + lbug_prepared_statement* prepared_statement, const char* param_name, uint64_t value); +/** + * @brief Binds the given uint32_t value to the given parameter name in the prepared statement. + * @param prepared_statement The prepared statement instance to bind the value. + * @param param_name The parameter name to bind the value. + * @param value The uint32_t value to bind. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_prepared_statement_bind_uint32( + lbug_prepared_statement* prepared_statement, const char* param_name, uint32_t value); +/** + * @brief Binds the given uint16_t value to the given parameter name in the prepared statement. + * @param prepared_statement The prepared statement instance to bind the value. + * @param param_name The parameter name to bind the value. 
+ * @param value The uint16_t value to bind. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_prepared_statement_bind_uint16( + lbug_prepared_statement* prepared_statement, const char* param_name, uint16_t value); +/** + * @brief Binds the given int8_t value to the given parameter name in the prepared statement. + * @param prepared_statement The prepared statement instance to bind the value. + * @param param_name The parameter name to bind the value. + * @param value The int8_t value to bind. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_prepared_statement_bind_uint8( + lbug_prepared_statement* prepared_statement, const char* param_name, uint8_t value); + +/** + * @brief Binds the given double value to the given parameter name in the prepared statement. + * @param prepared_statement The prepared statement instance to bind the value. + * @param param_name The parameter name to bind the value. + * @param value The double value to bind. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_prepared_statement_bind_double( + lbug_prepared_statement* prepared_statement, const char* param_name, double value); +/** + * @brief Binds the given float value to the given parameter name in the prepared statement. + * @param prepared_statement The prepared statement instance to bind the value. + * @param param_name The parameter name to bind the value. + * @param value The float value to bind. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_prepared_statement_bind_float( + lbug_prepared_statement* prepared_statement, const char* param_name, float value); +/** + * @brief Binds the given date value to the given parameter name in the prepared statement. + * @param prepared_statement The prepared statement instance to bind the value. + * @param param_name The parameter name to bind the value. + * @param value The date value to bind. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_prepared_statement_bind_date(lbug_prepared_statement* prepared_statement, + const char* param_name, lbug_date_t value); +/** + * @brief Binds the given timestamp_ns value to the given parameter name in the prepared statement. + * @param prepared_statement The prepared statement instance to bind the value. + * @param param_name The parameter name to bind the value. + * @param value The timestamp_ns value to bind. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_prepared_statement_bind_timestamp_ns( + lbug_prepared_statement* prepared_statement, const char* param_name, lbug_timestamp_ns_t value); +/** + * @brief Binds the given timestamp_sec value to the given parameter name in the prepared statement. + * @param prepared_statement The prepared statement instance to bind the value. + * @param param_name The parameter name to bind the value. + * @param value The timestamp_sec value to bind. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_prepared_statement_bind_timestamp_sec( + lbug_prepared_statement* prepared_statement, const char* param_name, + lbug_timestamp_sec_t value); +/** + * @brief Binds the given timestamp_tz value to the given parameter name in the prepared statement. 
+ * @param prepared_statement The prepared statement instance to bind the value. + * @param param_name The parameter name to bind the value. + * @param value The timestamp_tz value to bind. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_prepared_statement_bind_timestamp_tz( + lbug_prepared_statement* prepared_statement, const char* param_name, lbug_timestamp_tz_t value); +/** + * @brief Binds the given timestamp_ms value to the given parameter name in the prepared statement. + * @param prepared_statement The prepared statement instance to bind the value. + * @param param_name The parameter name to bind the value. + * @param value The timestamp_ms value to bind. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_prepared_statement_bind_timestamp_ms( + lbug_prepared_statement* prepared_statement, const char* param_name, lbug_timestamp_ms_t value); +/** + * @brief Binds the given timestamp value to the given parameter name in the prepared statement. + * @param prepared_statement The prepared statement instance to bind the value. + * @param param_name The parameter name to bind the value. + * @param value The timestamp value to bind. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_prepared_statement_bind_timestamp( + lbug_prepared_statement* prepared_statement, const char* param_name, lbug_timestamp_t value); +/** + * @brief Binds the given interval value to the given parameter name in the prepared statement. + * @param prepared_statement The prepared statement instance to bind the value. + * @param param_name The parameter name to bind the value. + * @param value The interval value to bind. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_prepared_statement_bind_interval( + lbug_prepared_statement* prepared_statement, const char* param_name, lbug_interval_t value); +/** + * @brief Binds the given string value to the given parameter name in the prepared statement. + * @param prepared_statement The prepared statement instance to bind the value. + * @param param_name The parameter name to bind the value. + * @param value The string value to bind. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_prepared_statement_bind_string( + lbug_prepared_statement* prepared_statement, const char* param_name, const char* value); +/** + * @brief Binds the given lbug value to the given parameter name in the prepared statement. + * @param prepared_statement The prepared statement instance to bind the value. + * @param param_name The parameter name to bind the value. + * @param value The lbug value to bind. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_prepared_statement_bind_value( + lbug_prepared_statement* prepared_statement, const char* param_name, lbug_value* value); + +// QueryResult +/** + * @brief Destroys the given query result instance. + * @param query_result The query result instance to destroy. + */ +LBUG_C_API void lbug_query_result_destroy(lbug_query_result* query_result); +/** + * @brief Returns true if the query is executed successful, false otherwise. + * @param query_result The query result instance to check. 
+ */ +LBUG_C_API bool lbug_query_result_is_success(lbug_query_result* query_result); +/** + * @brief Returns the error message if the query is failed. + * The caller is responsible for freeing the returned string with `lbug_destroy_string`. + * @param query_result The query result instance to check and return error message. + * @return The error message if the query has failed, or null if the query is successful. + */ +LBUG_C_API char* lbug_query_result_get_error_message(lbug_query_result* query_result); +/** + * @brief Returns the number of columns in the query result. + * @param query_result The query result instance to return. + */ +LBUG_C_API uint64_t lbug_query_result_get_num_columns(lbug_query_result* query_result); +/** + * @brief Returns the column name at the given index. + * @param query_result The query result instance to return. + * @param index The index of the column to return name. + * @param[out] out_column_name The output parameter that will hold the column name. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_query_result_get_column_name(lbug_query_result* query_result, + uint64_t index, char** out_column_name); +/** + * @brief Returns the data type of the column at the given index. + * @param query_result The query result instance to return. + * @param index The index of the column to return data type. + * @param[out] out_column_data_type The output parameter that will hold the column data type. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_query_result_get_column_data_type(lbug_query_result* query_result, + uint64_t index, lbug_logical_type* out_column_data_type); +/** + * @brief Returns the number of tuples in the query result. + * @param query_result The query result instance to return. + */ +LBUG_C_API uint64_t lbug_query_result_get_num_tuples(lbug_query_result* query_result); +/** + * @brief Returns the query summary of the query result. + * @param query_result The query result instance to return. + * @param[out] out_query_summary The output parameter that will hold the query summary. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_query_result_get_query_summary(lbug_query_result* query_result, + lbug_query_summary* out_query_summary); +/** + * @brief Returns true if we have not consumed all tuples in the query result, false otherwise. + * @param query_result The query result instance to check. + */ +LBUG_C_API bool lbug_query_result_has_next(lbug_query_result* query_result); +/** + * @brief Returns the next tuple in the query result. Throws an exception if there is no more tuple. + * Note that to reduce resource allocation, all calls to lbug_query_result_get_next() reuse the same + * FlatTuple object. Since its contents will be overwritten, please complete processing a FlatTuple + * or make a copy of its data before calling lbug_query_result_get_next() again. + * @param query_result The query result instance to return. + * @param[out] out_flat_tuple The output parameter that will hold the next tuple. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_query_result_get_next(lbug_query_result* query_result, + lbug_flat_tuple* out_flat_tuple); +/** + * @brief Returns true if we have not consumed all query results, false otherwise. 
Use this function + * for loop results of multiple query statements + * @param query_result The query result instance to check. + */ +LBUG_C_API bool lbug_query_result_has_next_query_result(lbug_query_result* query_result); +/** + * @brief Returns the next query result. Use this function to loop multiple query statements' + * results. + * @param query_result The query result instance to return. + * @param[out] out_next_query_result The output parameter that will hold the next query result. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_query_result_get_next_query_result(lbug_query_result* query_result, + lbug_query_result* out_next_query_result); + +/** + * @brief Returns the query result as a string. + * @param query_result The query result instance to return. + * @return The query result as a string. + */ +LBUG_C_API char* lbug_query_result_to_string(lbug_query_result* query_result); +/** + * @brief Resets the iterator of the query result to the beginning of the query result. + * @param query_result The query result instance to reset iterator. + */ +LBUG_C_API void lbug_query_result_reset_iterator(lbug_query_result* query_result); + +/** + * @brief Returns the query result's schema as ArrowSchema. + * @param query_result The query result instance to return. + * @param[out] out_schema The output parameter that will hold the datatypes of the columns as an + * arrow schema. + * @return The state indicating the success or failure of the operation. + * + * It is the caller's responsibility to call the release function to release the underlying data + */ +LBUG_C_API lbug_state lbug_query_result_get_arrow_schema(lbug_query_result* query_result, + struct ArrowSchema* out_schema); + +/** + * @brief Returns the next chunk of the query result as ArrowArray. + * @param query_result The query result instance to return. + * @param chunk_size The number of tuples to return in the chunk. + * @param[out] out_arrow_array The output parameter that will hold the arrow array representation of + * the query result. The arrow array internally stores an arrow struct with fields for each of the + * columns. + * @return The state indicating the success or failure of the operation. + * + * It is the caller's responsibility to call the release function to release the underlying data + */ +LBUG_C_API lbug_state lbug_query_result_get_next_arrow_chunk(lbug_query_result* query_result, + int64_t chunk_size, struct ArrowArray* out_arrow_array); + +// FlatTuple +/** + * @brief Destroys the given flat tuple instance. + * @param flat_tuple The flat tuple instance to destroy. + */ +LBUG_C_API void lbug_flat_tuple_destroy(lbug_flat_tuple* flat_tuple); +/** + * @brief Returns the value at index of the flat tuple. + * @param flat_tuple The flat tuple instance to return. + * @param index The index of the value to return. + * @param[out] out_value The output parameter that will hold the value at index. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_flat_tuple_get_value(lbug_flat_tuple* flat_tuple, uint64_t index, + lbug_value* out_value); +/** + * @brief Converts the flat tuple to a string. + * @param flat_tuple The flat tuple instance to convert. + * @return The flat tuple as a string. + */ +LBUG_C_API char* lbug_flat_tuple_to_string(lbug_flat_tuple* flat_tuple); + +// DataType +// TODO(Chang): Refactor the datatype constructor to follow the cpp way of creating dataTypes. 
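For orientation, here is a minimal usage sketch of the C API declared above, assuming the header is reachable as "lbug.h"; the database path and the Cypher query text are illustrative placeholders, and lbug_destroy_string is the string-freeing helper the doc comments above refer to. Error handling is abbreviated.

#include <stdio.h>
#include "lbug.h" /* assumed include path for this header */

static void example(void) {
    lbug_database db;
    lbug_connection conn;
    lbug_query_result result;
    lbug_flat_tuple tuple;

    /* Open (or create) a database with the default runtime configuration. */
    lbug_system_config config = lbug_default_system_config();
    if (lbug_database_init("./example-db", config, &db) != LbugSuccess) {
        return;
    }
    /* Each connection is bound to one database instance and is thread-safe. */
    if (lbug_connection_init(&db, &conn) != LbugSuccess) {
        lbug_database_destroy(&db);
        return;
    }

    /* Run a query and iterate its tuples. */
    if (lbug_connection_query(&conn, "MATCH (n) RETURN n LIMIT 5;", &result) == LbugSuccess) {
        if (lbug_query_result_is_success(&result)) {
            /* get_next() reuses one FlatTuple, so consume each row before the next call. */
            while (lbug_query_result_has_next(&result)) {
                if (lbug_query_result_get_next(&result, &tuple) != LbugSuccess) {
                    break;
                }
                char* row = lbug_flat_tuple_to_string(&tuple);
                printf("%s", row);
                lbug_destroy_string(row);
                lbug_flat_tuple_destroy(&tuple);
            }
        } else {
            char* msg = lbug_query_result_get_error_message(&result);
            fprintf(stderr, "query failed: %s\n", msg);
            lbug_destroy_string(msg);
        }
        lbug_query_result_destroy(&result);
    }

    /* Destroy in reverse order of creation. */
    lbug_connection_destroy(&conn);
    lbug_database_destroy(&db);
}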
+/** + * @brief Creates a data type instance with the given id, childType and num_elements_in_array. + * Caller is responsible for destroying the returned data type instance. + * @param id The enum type id of the datatype to create. + * @param child_type The child type of the datatype to create(only used for nested dataTypes). + * @param num_elements_in_array The number of elements in the array(only used for ARRAY). + * @param[out] out_type The output parameter that will hold the data type instance. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API void lbug_data_type_create(lbug_data_type_id id, lbug_logical_type* child_type, + uint64_t num_elements_in_array, lbug_logical_type* out_type); +/** + * @brief Creates a new data type instance by cloning the given data type instance. + * @param data_type The data type instance to clone. + * @param[out] out_type The output parameter that will hold the cloned data type instance. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API void lbug_data_type_clone(lbug_logical_type* data_type, lbug_logical_type* out_type); +/** + * @brief Destroys the given data type instance. + * @param data_type The data type instance to destroy. + */ +LBUG_C_API void lbug_data_type_destroy(lbug_logical_type* data_type); +/** + * @brief Returns true if the given data type is equal to the other data type, false otherwise. + * @param data_type1 The first data type instance to compare. + * @param data_type2 The second data type instance to compare. + */ +LBUG_C_API bool lbug_data_type_equals(lbug_logical_type* data_type1, lbug_logical_type* data_type2); +/** + * @brief Returns the enum type id of the given data type. + * @param data_type The data type instance to return. + */ +LBUG_C_API lbug_data_type_id lbug_data_type_get_id(lbug_logical_type* data_type); +/** + * @brief Returns the number of elements for array. + * @param data_type The data type instance to return. + * @param[out] out_result The output parameter that will hold the number of elements in the array. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_data_type_get_num_elements_in_array(lbug_logical_type* data_type, + uint64_t* out_result); + +// Value +/** + * @brief Creates a NULL value of ANY type. Caller is responsible for destroying the returned value. + */ +LBUG_C_API lbug_value* lbug_value_create_null(); +/** + * @brief Creates a value of the given data type. Caller is responsible for destroying the + * returned value. + * @param data_type The data type of the value to create. + */ +LBUG_C_API lbug_value* lbug_value_create_null_with_data_type(lbug_logical_type* data_type); +/** + * @brief Returns true if the given value is NULL, false otherwise. + * @param value The value instance to check. + */ +LBUG_C_API bool lbug_value_is_null(lbug_value* value); +/** + * @brief Sets the given value to NULL or not. + * @param value The value instance to set. + * @param is_null True if sets the value to NULL, false otherwise. + */ +LBUG_C_API void lbug_value_set_null(lbug_value* value, bool is_null); +/** + * @brief Creates a value of the given data type with default non-NULL value. Caller is responsible + * for destroying the returned value. + * @param data_type The data type of the value to create. + */ +LBUG_C_API lbug_value* lbug_value_create_default(lbug_logical_type* data_type); +/** + * @brief Creates a value with boolean type and the given bool value. 
Caller is responsible for + * destroying the returned value. + * @param val_ The bool value of the value to create. + */ +LBUG_C_API lbug_value* lbug_value_create_bool(bool val_); +/** + * @brief Creates a value with int8 type and the given int8 value. Caller is responsible for + * destroying the returned value. + * @param val_ The int8 value of the value to create. + */ +LBUG_C_API lbug_value* lbug_value_create_int8(int8_t val_); +/** + * @brief Creates a value with int16 type and the given int16 value. Caller is responsible for + * destroying the returned value. + * @param val_ The int16 value of the value to create. + */ +LBUG_C_API lbug_value* lbug_value_create_int16(int16_t val_); +/** + * @brief Creates a value with int32 type and the given int32 value. Caller is responsible for + * destroying the returned value. + * @param val_ The int32 value of the value to create. + */ +LBUG_C_API lbug_value* lbug_value_create_int32(int32_t val_); +/** + * @brief Creates a value with int64 type and the given int64 value. Caller is responsible for + * destroying the returned value. + * @param val_ The int64 value of the value to create. + */ +LBUG_C_API lbug_value* lbug_value_create_int64(int64_t val_); +/** + * @brief Creates a value with uint8 type and the given uint8 value. Caller is responsible for + * destroying the returned value. + * @param val_ The uint8 value of the value to create. + */ +LBUG_C_API lbug_value* lbug_value_create_uint8(uint8_t val_); +/** + * @brief Creates a value with uint16 type and the given uint16 value. Caller is responsible for + * destroying the returned value. + * @param val_ The uint16 value of the value to create. + */ +LBUG_C_API lbug_value* lbug_value_create_uint16(uint16_t val_); +/** + * @brief Creates a value with uint32 type and the given uint32 value. Caller is responsible for + * destroying the returned value. + * @param val_ The uint32 value of the value to create. + */ +LBUG_C_API lbug_value* lbug_value_create_uint32(uint32_t val_); +/** + * @brief Creates a value with uint64 type and the given uint64 value. Caller is responsible for + * destroying the returned value. + * @param val_ The uint64 value of the value to create. + */ +LBUG_C_API lbug_value* lbug_value_create_uint64(uint64_t val_); +/** + * @brief Creates a value with int128 type and the given int128 value. Caller is responsible for + * destroying the returned value. + * @param val_ The int128 value of the value to create. + */ +LBUG_C_API lbug_value* lbug_value_create_int128(lbug_int128_t val_); +/** + * @brief Creates a value with float type and the given float value. Caller is responsible for + * destroying the returned value. + * @param val_ The float value of the value to create. + */ +LBUG_C_API lbug_value* lbug_value_create_float(float val_); +/** + * @brief Creates a value with double type and the given double value. Caller is responsible for + * destroying the returned value. + * @param val_ The double value of the value to create. + */ +LBUG_C_API lbug_value* lbug_value_create_double(double val_); +/** + * @brief Creates a value with internal_id type and the given internal_id value. Caller is + * responsible for destroying the returned value. + * @param val_ The internal_id value of the value to create. + */ +LBUG_C_API lbug_value* lbug_value_create_internal_id(lbug_internal_id_t val_); +/** + * @brief Creates a value with date type and the given date value. Caller is responsible for + * destroying the returned value. + * @param val_ The date value of the value to create. 
+ */ +LBUG_C_API lbug_value* lbug_value_create_date(lbug_date_t val_); +/** + * @brief Creates a value with timestamp_ns type and the given timestamp value. Caller is + * responsible for destroying the returned value. + * @param val_ The timestamp_ns value of the value to create. + */ +LBUG_C_API lbug_value* lbug_value_create_timestamp_ns(lbug_timestamp_ns_t val_); +/** + * @brief Creates a value with timestamp_ms type and the given timestamp value. Caller is + * responsible for destroying the returned value. + * @param val_ The timestamp_ms value of the value to create. + */ +LBUG_C_API lbug_value* lbug_value_create_timestamp_ms(lbug_timestamp_ms_t val_); +/** + * @brief Creates a value with timestamp_sec type and the given timestamp value. Caller is + * responsible for destroying the returned value. + * @param val_ The timestamp_sec value of the value to create. + */ +LBUG_C_API lbug_value* lbug_value_create_timestamp_sec(lbug_timestamp_sec_t val_); +/** + * @brief Creates a value with timestamp_tz type and the given timestamp value. Caller is + * responsible for destroying the returned value. + * @param val_ The timestamp_tz value of the value to create. + */ +LBUG_C_API lbug_value* lbug_value_create_timestamp_tz(lbug_timestamp_tz_t val_); +/** + * @brief Creates a value with timestamp type and the given timestamp value. Caller is responsible + * for destroying the returned value. + * @param val_ The timestamp value of the value to create. + */ +LBUG_C_API lbug_value* lbug_value_create_timestamp(lbug_timestamp_t val_); +/** + * @brief Creates a value with interval type and the given interval value. Caller is responsible + * for destroying the returned value. + * @param val_ The interval value of the value to create. + */ +LBUG_C_API lbug_value* lbug_value_create_interval(lbug_interval_t val_); +/** + * @brief Creates a value with string type and the given string value. Caller is responsible for + * destroying the returned value. + * @param val_ The string value of the value to create. + */ +LBUG_C_API lbug_value* lbug_value_create_string(const char* val_); +/** + * @brief Creates a list value with the given number of elements and the given elements. + * The caller needs to make sure that all elements have the same type. + * The elements are copied into the list value, so destroying the elements after creating the list + * value is safe. + * Caller is responsible for destroying the returned value. + * @param num_elements The number of elements in the list. + * @param elements The elements of the list. + * @param[out] out_value The output parameter that will hold a pointer to the created list value. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_value_create_list(uint64_t num_elements, lbug_value** elements, + lbug_value** out_value); +/** + * @brief Creates a struct value with the given number of fields and the given field names and + * values. The caller needs to make sure that all field names are unique. + * The field names and values are copied into the struct value, so destroying the field names and + * values after creating the struct value is safe. + * Caller is responsible for destroying the returned value. + * @param num_fields The number of fields in the struct. + * @param field_names The field names of the struct. + * @param field_values The field values of the struct. + * @param[out] out_value The output parameter that will hold a pointer to the created struct value. 
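+ *
+ * Usage sketch (illustrative only; every call below is declared in this header):
+ * @code
+ * const char* names[2] = {"name", "age"};
+ * lbug_value* fields[2] = {lbug_value_create_string("Alice"), lbug_value_create_int64(30)};
+ * lbug_value* person = NULL;
+ * lbug_state st = lbug_value_create_struct(2, names, fields, &person);
+ * // The field names and values are copied, so the inputs can be destroyed right away.
+ * lbug_value_destroy(fields[0]);
+ * lbug_value_destroy(fields[1]);
+ * if (person) lbug_value_destroy(person);
+ * @endcode
+ *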
+ * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_value_create_struct(uint64_t num_fields, const char** field_names, + lbug_value** field_values, lbug_value** out_value); +/** + * @brief Creates a map value with the given number of fields and the given keys and values. The + * caller needs to make sure that all keys are unique, and all keys and values have the same type. + * The keys and values are copied into the map value, so destroying the keys and values after + * creating the map value is safe. + * Caller is responsible for destroying the returned value. + * @param num_fields The number of fields in the map. + * @param keys The keys of the map. + * @param values The values of the map. + * @param[out] out_value The output parameter that will hold a pointer to the created map value. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_value_create_map(uint64_t num_fields, lbug_value** keys, + lbug_value** values, lbug_value** out_value); +/** + * @brief Creates a new value based on the given value. Caller is responsible for destroying the + * returned value. + * @param value The value to create from. + */ +LBUG_C_API lbug_value* lbug_value_clone(lbug_value* value); +/** + * @brief Copies the other value to the value. + * @param value The value to copy to. + * @param other The value to copy from. + */ +LBUG_C_API void lbug_value_copy(lbug_value* value, lbug_value* other); +/** + * @brief Destroys the value. + * @param value The value to destroy. + */ +LBUG_C_API void lbug_value_destroy(lbug_value* value); +/** + * @brief Returns the number of elements per list of the given value. The value must be of type + * ARRAY. + * @param value The ARRAY value to get list size. + * @param[out] out_result The output parameter that will hold the number of elements per list. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_value_get_list_size(lbug_value* value, uint64_t* out_result); +/** + * @brief Returns the element at index of the given value. The value must be of type LIST. + * @param value The LIST value to return. + * @param index The index of the element to return. + * @param[out] out_value The output parameter that will hold the element at index. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_value_get_list_element(lbug_value* value, uint64_t index, + lbug_value* out_value); +/** + * @brief Returns the number of fields of the given struct value. The value must be of type STRUCT. + * @param value The STRUCT value to get number of fields. + * @param[out] out_result The output parameter that will hold the number of fields. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_value_get_struct_num_fields(lbug_value* value, uint64_t* out_result); +/** + * @brief Returns the field name at index of the given struct value. The value must be of physical + * type STRUCT (STRUCT, NODE, REL, RECURSIVE_REL, UNION). + * @param value The STRUCT value to get field name. + * @param index The index of the field name to return. + * @param[out] out_result The output parameter that will hold the field name at index. + * @return The state indicating the success or failure of the operation. 
+ */ +LBUG_C_API lbug_state lbug_value_get_struct_field_name(lbug_value* value, uint64_t index, + char** out_result); +/** + * @brief Returns the field value at index of the given struct value. The value must be of physical + * type STRUCT (STRUCT, NODE, REL, RECURSIVE_REL, UNION). + * @param value The STRUCT value to get field value. + * @param index The index of the field value to return. + * @param[out] out_value The output parameter that will hold the field value at index. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_value_get_struct_field_value(lbug_value* value, uint64_t index, + lbug_value* out_value); + +/** + * @brief Returns the size of the given map value. The value must be of type MAP. + * @param value The MAP value to get size. + * @param[out] out_result The output parameter that will hold the size of the map. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_value_get_map_size(lbug_value* value, uint64_t* out_result); +/** + * @brief Returns the key at index of the given map value. The value must be of physical + * type MAP. + * @param value The MAP value to get key. + * @param index The index of the field name to return. + * @param[out] out_key The output parameter that will hold the key at index. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_value_get_map_key(lbug_value* value, uint64_t index, + lbug_value* out_key); +/** + * @brief Returns the field value at index of the given map value. The value must be of physical + * type MAP. + * @param value The MAP value to get field value. + * @param index The index of the field value to return. + * @param[out] out_value The output parameter that will hold the field value at index. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_value_get_map_value(lbug_value* value, uint64_t index, + lbug_value* out_value); +/** + * @brief Returns the list of nodes for recursive rel value. The value must be of type + * RECURSIVE_REL. + * @param value The RECURSIVE_REL value to return. + * @param[out] out_value The output parameter that will hold the list of nodes. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_value_get_recursive_rel_node_list(lbug_value* value, + lbug_value* out_value); + +/** + * @brief Returns the list of rels for recursive rel value. The value must be of type RECURSIVE_REL. + * @param value The RECURSIVE_REL value to return. + * @param[out] out_value The output parameter that will hold the list of rels. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_value_get_recursive_rel_rel_list(lbug_value* value, + lbug_value* out_value); +/** + * @brief Returns internal type of the given value. + * @param value The value to return. + * @param[out] out_type The output parameter that will hold the internal type of the value. + */ +LBUG_C_API void lbug_value_get_data_type(lbug_value* value, lbug_logical_type* out_type); +/** + * @brief Returns the boolean value of the given value. The value must be of type BOOL. + * @param value The value to return. + * @param[out] out_result The output parameter that will hold the boolean value. + * @return The state indicating the success or failure of the operation. 
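+ *
+ * Typed-getter sketch (illustrative only; assumes the caller-allocated lbug_logical_type
+ * out-parameter style used throughout this header):
+ * @code
+ * lbug_logical_type t;
+ * lbug_value_get_data_type(value, &t);
+ * lbug_data_type_id id = lbug_data_type_get_id(&t);
+ * lbug_data_type_destroy(&t);
+ * // Dispatch on id; for the BOOL id, read the payload:
+ * bool b = false;
+ * lbug_state st = lbug_value_get_bool(value, &b); // check st for success
+ * @endcode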
+ */ +LBUG_C_API lbug_state lbug_value_get_bool(lbug_value* value, bool* out_result); +/** + * @brief Returns the int8 value of the given value. The value must be of type INT8. + * @param value The value to return. + * @param[out] out_result The output parameter that will hold the int8 value. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_value_get_int8(lbug_value* value, int8_t* out_result); +/** + * @brief Returns the int16 value of the given value. The value must be of type INT16. + * @param value The value to return. + * @param[out] out_result The output parameter that will hold the int16 value. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_value_get_int16(lbug_value* value, int16_t* out_result); +/** + * @brief Returns the int32 value of the given value. The value must be of type INT32. + * @param value The value to return. + * @param[out] out_result The output parameter that will hold the int32 value. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_value_get_int32(lbug_value* value, int32_t* out_result); +/** + * @brief Returns the int64 value of the given value. The value must be of type INT64 or SERIAL. + * @param value The value to return. + * @param[out] out_result The output parameter that will hold the int64 value. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_value_get_int64(lbug_value* value, int64_t* out_result); +/** + * @brief Returns the uint8 value of the given value. The value must be of type UINT8. + * @param value The value to return. + * @param[out] out_result The output parameter that will hold the uint8 value. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_value_get_uint8(lbug_value* value, uint8_t* out_result); +/** + * @brief Returns the uint16 value of the given value. The value must be of type UINT16. + * @param value The value to return. + * @param[out] out_result The output parameter that will hold the uint16 value. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_value_get_uint16(lbug_value* value, uint16_t* out_result); +/** + * @brief Returns the uint32 value of the given value. The value must be of type UINT32. + * @param value The value to return. + * @param[out] out_result The output parameter that will hold the uint32 value. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_value_get_uint32(lbug_value* value, uint32_t* out_result); +/** + * @brief Returns the uint64 value of the given value. The value must be of type UINT64. + * @param value The value to return. + * @param[out] out_result The output parameter that will hold the uint64 value. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_value_get_uint64(lbug_value* value, uint64_t* out_result); +/** + * @brief Returns the int128 value of the given value. The value must be of type INT128. + * @param value The value to return. + * @param[out] out_result The output parameter that will hold the int128 value. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_value_get_int128(lbug_value* value, lbug_int128_t* out_result); +/** + * @brief convert a string to int128 value. 
+ * @param str The string to convert. + * @param[out] out_result The output parameter that will hold the int128 value. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_int128_t_from_string(const char* str, lbug_int128_t* out_result); +/** + * @brief convert int128 to corresponding string. + * @param val The int128 value to convert. + * @param[out] out_result The output parameter that will hold the string value. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_int128_t_to_string(lbug_int128_t val, char** out_result); +/** + * @brief Returns the float value of the given value. The value must be of type FLOAT. + * @param value The value to return. + * @param[out] out_result The output parameter that will hold the float value. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_value_get_float(lbug_value* value, float* out_result); +/** + * @brief Returns the double value of the given value. The value must be of type DOUBLE. + * @param value The value to return. + * @param[out] out_result The output parameter that will hold the double value. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_value_get_double(lbug_value* value, double* out_result); +/** + * @brief Returns the internal id value of the given value. The value must be of type INTERNAL_ID. + * @param value The value to return. + * @param[out] out_result The output parameter that will hold the internal id value. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_value_get_internal_id(lbug_value* value, lbug_internal_id_t* out_result); +/** + * @brief Returns the date value of the given value. The value must be of type DATE. + * @param value The value to return. + * @param[out] out_result The output parameter that will hold the date value. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_value_get_date(lbug_value* value, lbug_date_t* out_result); +/** + * @brief Returns the timestamp value of the given value. The value must be of type TIMESTAMP. + * @param value The value to return. + * @param[out] out_result The output parameter that will hold the timestamp value. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_value_get_timestamp(lbug_value* value, lbug_timestamp_t* out_result); +/** + * @brief Returns the timestamp_ns value of the given value. The value must be of type TIMESTAMP_NS. + * @param value The value to return. + * @param[out] out_result The output parameter that will hold the timestamp_ns value. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_value_get_timestamp_ns(lbug_value* value, + lbug_timestamp_ns_t* out_result); +/** + * @brief Returns the timestamp_ms value of the given value. The value must be of type TIMESTAMP_MS. + * @param value The value to return. + * @param[out] out_result The output parameter that will hold the timestamp_ms value. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_value_get_timestamp_ms(lbug_value* value, + lbug_timestamp_ms_t* out_result); +/** + * @brief Returns the timestamp_sec value of the given value. The value must be of type + * TIMESTAMP_SEC. 
+ * @param value The value to return.
+ * @param[out] out_result The output parameter that will hold the timestamp_sec value.
+ * @return The state indicating the success or failure of the operation.
+ */
+LBUG_C_API lbug_state lbug_value_get_timestamp_sec(lbug_value* value,
+    lbug_timestamp_sec_t* out_result);
+/**
+ * @brief Returns the timestamp_tz value of the given value. The value must be of type TIMESTAMP_TZ.
+ * @param value The value to return.
+ * @param[out] out_result The output parameter that will hold the timestamp_tz value.
+ * @return The state indicating the success or failure of the operation.
+ */
+LBUG_C_API lbug_state lbug_value_get_timestamp_tz(lbug_value* value,
+    lbug_timestamp_tz_t* out_result);
+/**
+ * @brief Returns the interval value of the given value. The value must be of type INTERVAL.
+ * @param value The value to return.
+ * @param[out] out_result The output parameter that will hold the interval value.
+ * @return The state indicating the success or failure of the operation.
+ */
+LBUG_C_API lbug_state lbug_value_get_interval(lbug_value* value, lbug_interval_t* out_result);
+/**
+ * @brief Returns the decimal value of the given value as a string. The value must be of type
+ * DECIMAL.
+ * @param value The value to return.
+ * @param[out] out_result The output parameter that will hold the decimal value.
+ * @return The state indicating the success or failure of the operation.
+ */
+LBUG_C_API lbug_state lbug_value_get_decimal_as_string(lbug_value* value, char** out_result);
+/**
+ * @brief Returns the string value of the given value. The value must be of type STRING.
+ * @param value The value to return.
+ * @param[out] out_result The output parameter that will hold the string value.
+ * @return The state indicating the success or failure of the operation.
+ */
+LBUG_C_API lbug_state lbug_value_get_string(lbug_value* value, char** out_result);
+/**
+ * @brief Returns the blob value of the given value. The value must be of type BLOB.
+ * @param value The value to return.
+ * @param[out] out_result The output parameter that will hold the blob value.
+ * @param[out] out_length The output parameter that will hold the length of the blob.
+ * @return The state indicating the success or failure of the operation.
+ * @note The caller is responsible for freeing the returned memory using `lbug_destroy_blob`.
+ */
+LBUG_C_API lbug_state lbug_value_get_blob(lbug_value* value, uint8_t** out_result,
+    uint64_t* out_length);
+/**
+ * @brief Returns the uuid value of the given value as a string. The value must be of type UUID.
+ * @param value The value to return.
+ * @param[out] out_result The output parameter that will hold the uuid value.
+ * @return The state indicating the success or failure of the operation.
+ */
+LBUG_C_API lbug_state lbug_value_get_uuid(lbug_value* value, char** out_result);
+/**
+ * @brief Converts the given value to a string.
+ * @param value The value to convert.
+ * @return The value as a string.
+ */
+LBUG_C_API char* lbug_value_to_string(lbug_value* value);
+/**
+ * @brief Returns the internal id value of the given node value as a lbug value.
+ * @param node_val The node value to return.
+ * @param[out] out_value The output parameter that will hold the internal id value.
+ * @return The state indicating the success or failure of the operation.
+ */
+LBUG_C_API lbug_state lbug_node_val_get_id_val(lbug_value* node_val, lbug_value* out_value);
+/**
+ * @brief Returns the label value of the given node value as a lbug value.
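+ *
+ * Node access sketch (illustrative only; the property accessors are declared just below):
+ * @code
+ * uint64_t num_props = 0;
+ * lbug_node_val_get_property_size(node_val, &num_props); // check the returned lbug_state
+ * for (uint64_t i = 0; i < num_props; ++i) {
+ *     char* prop_name = NULL;
+ *     lbug_node_val_get_property_name_at(node_val, i, &prop_name);
+ *     // ... use prop_name ...
+ *     lbug_destroy_string(prop_name);
+ * }
+ * @endcode
+ *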
+ * @param node_val The node value to return. + * @param[out] out_value The output parameter that will hold the label value. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_node_val_get_label_val(lbug_value* node_val, lbug_value* out_value); +/** + * @brief Returns the number of properties of the given node value. + * @param node_val The node value to return. + * @param[out] out_value The output parameter that will hold the number of properties. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_node_val_get_property_size(lbug_value* node_val, uint64_t* out_value); +/** + * @brief Returns the property name of the given node value at the given index. + * @param node_val The node value to return. + * @param index The index of the property. + * @param[out] out_result The output parameter that will hold the property name at index. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_node_val_get_property_name_at(lbug_value* node_val, uint64_t index, + char** out_result); +/** + * @brief Returns the property value of the given node value at the given index. + * @param node_val The node value to return. + * @param index The index of the property. + * @param[out] out_value The output parameter that will hold the property value at index. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_node_val_get_property_value_at(lbug_value* node_val, uint64_t index, + lbug_value* out_value); +/** + * @brief Converts the given node value to string. + * @param node_val The node value to convert. + * @param[out] out_result The output parameter that will hold the node value as a string. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_node_val_to_string(lbug_value* node_val, char** out_result); +/** + * @brief Returns the internal id value of the rel value as a lbug value. + * @param rel_val The rel value to return. + * @param[out] out_value The output parameter that will hold the internal id value. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_rel_val_get_id_val(lbug_value* rel_val, lbug_value* out_value); +/** + * @brief Returns the internal id value of the source node of the given rel value as a lbug value. + * @param rel_val The rel value to return. + * @param[out] out_value The output parameter that will hold the internal id value. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_rel_val_get_src_id_val(lbug_value* rel_val, lbug_value* out_value); +/** + * @brief Returns the internal id value of the destination node of the given rel value as a lbug + * value. + * @param rel_val The rel value to return. + * @param[out] out_value The output parameter that will hold the internal id value. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_rel_val_get_dst_id_val(lbug_value* rel_val, lbug_value* out_value); +/** + * @brief Returns the label value of the given rel value. + * @param rel_val The rel value to return. + * @param[out] out_value The output parameter that will hold the label value. + * @return The state indicating the success or failure of the operation. 
+ */ +LBUG_C_API lbug_state lbug_rel_val_get_label_val(lbug_value* rel_val, lbug_value* out_value); +/** + * @brief Returns the number of properties of the given rel value. + * @param rel_val The rel value to return. + * @param[out] out_value The output parameter that will hold the number of properties. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_rel_val_get_property_size(lbug_value* rel_val, uint64_t* out_value); +/** + * @brief Returns the property name of the given rel value at the given index. + * @param rel_val The rel value to return. + * @param index The index of the property. + * @param[out] out_result The output parameter that will hold the property name at index. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_rel_val_get_property_name_at(lbug_value* rel_val, uint64_t index, + char** out_result); +/** + * @brief Returns the property of the given rel value at the given index as lbug value. + * @param rel_val The rel value to return. + * @param index The index of the property. + * @param[out] out_value The output parameter that will hold the property value at index. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_rel_val_get_property_value_at(lbug_value* rel_val, uint64_t index, + lbug_value* out_value); +/** + * @brief Converts the given rel value to string. + * @param rel_val The rel value to convert. + * @param[out] out_result The output parameter that will hold the rel value as a string. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_rel_val_to_string(lbug_value* rel_val, char** out_result); +/** + * @brief Destroys any string created by the Lbug C API, including both the error message and the + * values returned by the API functions. This function is provided to avoid the inconsistency + * between the memory allocation and deallocation across different libraries and is preferred over + * using the standard C free function. + * @param str The string to destroy. + */ +LBUG_C_API void lbug_destroy_string(char* str); +/** + * @brief Destroys any blob created by the Lbug C API. This function is provided to avoid the + * inconsistency between the memory allocation and deallocation across different libraries and + * is preferred over using the standard C free function. + * @param blob The blob to destroy. + */ +LBUG_C_API void lbug_destroy_blob(uint8_t* blob); + +// QuerySummary +/** + * @brief Destroys the given query summary. + * @param query_summary The query summary to destroy. + */ +LBUG_C_API void lbug_query_summary_destroy(lbug_query_summary* query_summary); +/** + * @brief Returns the compilation time of the given query summary in milliseconds. + * @param query_summary The query summary to get compilation time. + */ +LBUG_C_API double lbug_query_summary_get_compiling_time(lbug_query_summary* query_summary); +/** + * @brief Returns the execution time of the given query summary in milliseconds. + * @param query_summary The query summary to get execution time. + */ +LBUG_C_API double lbug_query_summary_get_execution_time(lbug_query_summary* query_summary); + +// Utility functions +/** + * @brief Convert timestamp_ns to corresponding tm struct. + * @param timestamp The timestamp_ns value to convert. + * @param[out] out_result The output parameter that will hold the tm struct. 
+ * @return The state indicating the success or failure of the operation.
+ */
+LBUG_C_API lbug_state lbug_timestamp_ns_to_tm(lbug_timestamp_ns_t timestamp, struct tm* out_result);
+/**
+ * @brief Convert timestamp_ms to corresponding tm struct.
+ * @param timestamp The timestamp_ms value to convert.
+ * @param[out] out_result The output parameter that will hold the tm struct.
+ * @return The state indicating the success or failure of the operation.
+ */
+LBUG_C_API lbug_state lbug_timestamp_ms_to_tm(lbug_timestamp_ms_t timestamp, struct tm* out_result);
+/**
+ * @brief Convert timestamp_sec to corresponding tm struct.
+ * @param timestamp The timestamp_sec value to convert.
+ * @param[out] out_result The output parameter that will hold the tm struct.
+ * @return The state indicating the success or failure of the operation.
+ */
+LBUG_C_API lbug_state lbug_timestamp_sec_to_tm(lbug_timestamp_sec_t timestamp,
+    struct tm* out_result);
+/**
+ * @brief Convert timestamp_tz to corresponding tm struct.
+ * @param timestamp The timestamp_tz value to convert.
+ * @param[out] out_result The output parameter that will hold the tm struct.
+ * @return The state indicating the success or failure of the operation.
+ */
+LBUG_C_API lbug_state lbug_timestamp_tz_to_tm(lbug_timestamp_tz_t timestamp, struct tm* out_result);
+/**
+ * @brief Convert timestamp to corresponding tm struct.
+ * @param timestamp The timestamp value to convert.
+ * @param[out] out_result The output parameter that will hold the tm struct.
+ * @return The state indicating the success or failure of the operation.
+ */
+LBUG_C_API lbug_state lbug_timestamp_to_tm(lbug_timestamp_t timestamp, struct tm* out_result);
+/**
+ * @brief Convert tm struct to timestamp_ns value.
+ * @param tm The tm struct to convert.
+ * @param[out] out_result The output parameter that will hold the timestamp_ns value.
+ * @return The state indicating the success or failure of the operation.
+ */
+LBUG_C_API lbug_state lbug_timestamp_ns_from_tm(struct tm tm, lbug_timestamp_ns_t* out_result);
+/**
+ * @brief Convert tm struct to timestamp_ms value.
+ * @param tm The tm struct to convert.
+ * @param[out] out_result The output parameter that will hold the timestamp_ms value.
+ * @return The state indicating the success or failure of the operation.
+ */
+LBUG_C_API lbug_state lbug_timestamp_ms_from_tm(struct tm tm, lbug_timestamp_ms_t* out_result);
+/**
+ * @brief Convert tm struct to timestamp_sec value.
+ * @param tm The tm struct to convert.
+ * @param[out] out_result The output parameter that will hold the timestamp_sec value.
+ * @return The state indicating the success or failure of the operation.
+ */
+LBUG_C_API lbug_state lbug_timestamp_sec_from_tm(struct tm tm, lbug_timestamp_sec_t* out_result);
+/**
+ * @brief Convert tm struct to timestamp_tz value.
+ * @param tm The tm struct to convert.
+ * @param[out] out_result The output parameter that will hold the timestamp_tz value.
+ * @return The state indicating the success or failure of the operation.
+ */
+LBUG_C_API lbug_state lbug_timestamp_tz_from_tm(struct tm tm, lbug_timestamp_tz_t* out_result);
+/**
+ * @brief Convert tm struct to timestamp value.
+ * @param tm The tm struct to convert.
+ * @param[out] out_result The output parameter that will hold the timestamp value.
+ * @return The state indicating the success or failure of the operation.
+ */ +LBUG_C_API lbug_state lbug_timestamp_from_tm(struct tm tm, lbug_timestamp_t* out_result); +/** + * @brief Convert date to corresponding string. + * @param date The date value to convert. + * @param[out] out_result The output parameter that will hold the string value. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_date_to_string(lbug_date_t date, char** out_result); +/** + * @brief Convert a string to date value. + * @param str The string to convert. + * @param[out] out_result The output parameter that will hold the date value. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_date_from_string(const char* str, lbug_date_t* out_result); +/** + * @brief Convert date to corresponding tm struct. + * @param date The date value to convert. + * @param[out] out_result The output parameter that will hold the tm struct. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_date_to_tm(lbug_date_t date, struct tm* out_result); +/** + * @brief Convert tm struct to date value. + * @param tm The tm struct to convert. + * @param[out] out_result The output parameter that will hold the date value. + * @return The state indicating the success or failure of the operation. + */ +LBUG_C_API lbug_state lbug_date_from_tm(struct tm tm, lbug_date_t* out_result); +/** + * @brief Convert interval to corresponding difftime value in seconds. + * @param interval The interval value to convert. + * @param[out] out_result The output parameter that will hold the difftime value. + */ +LBUG_C_API void lbug_interval_to_difftime(lbug_interval_t interval, double* out_result); +/** + * @brief Convert difftime value in seconds to interval. + * @param difftime The difftime value to convert. + * @param[out] out_result The output parameter that will hold the interval value. + */ +LBUG_C_API void lbug_interval_from_difftime(double difftime, lbug_interval_t* out_result); + +// Version +/** + * @brief Returns the version of the Lbug library. + */ +LBUG_C_API char* lbug_get_version(); + +/** + * @brief Returns the storage version of the Lbug library. 
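+ *
+ * Usage sketch (illustrative only):
+ * @code
+ * char* version = lbug_get_version();
+ * uint64_t storage_version = lbug_get_storage_version();
+ * // ... compare against an expected storage version ...
+ * lbug_destroy_string(version);
+ * @endcode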
+ */ +LBUG_C_API uint64_t lbug_get_storage_version(); +#undef LBUG_C_API diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog.h new file mode 100644 index 0000000000..dd29b83f6c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog.h @@ -0,0 +1,250 @@ +#pragma once + +#include "catalog/catalog_entry/function_catalog_entry.h" +#include "catalog/catalog_entry/scalar_macro_catalog_entry.h" +#include "catalog/catalog_set.h" +#include "common/cast.h" +#include "function/function.h" + +namespace lbug::main { +struct DBConfig; +} // namespace lbug::main + +namespace lbug { +namespace main { +class AttachedLbugDatabase; +} // namespace main + +namespace binder { +struct BoundAlterInfo; +struct BoundCreateTableInfo; +struct BoundCreateSequenceInfo; +} // namespace binder + +namespace common { +class VirtualFileSystem; +} // namespace common + +namespace function { +struct ScalarMacroFunction; +} // namespace function + +namespace storage { +class WAL; +} // namespace storage + +namespace transaction { +class Transaction; +} // namespace transaction + +namespace catalog { +class TableCatalogEntry; +class NodeTableCatalogEntry; +class RelGroupCatalogEntry; +class FunctionCatalogEntry; +class SequenceCatalogEntry; +class IndexCatalogEntry; + +template +concept TableCatalogEntryType = + std::is_same_v || std::is_same_v; + +class LBUG_API Catalog { + friend class main::AttachedLbugDatabase; + +public: + Catalog(); + virtual ~Catalog() = default; + + static Catalog* Get(const main::ClientContext& context); + + // ----------------------------- Tables ---------------------------- + + // Check if table entry exists. + bool containsTable(const transaction::Transaction* transaction, const std::string& tableName, + bool useInternal = true) const; + bool containsTable(const transaction::Transaction* transaction, common::table_id_t tableID, + bool useInternal = true) const; + // Get table entry with name. + TableCatalogEntry* getTableCatalogEntry(const transaction::Transaction* transaction, + const std::string& tableName, bool useInternal = true) const; + // Get table entry with id. + TableCatalogEntry* getTableCatalogEntry(const transaction::Transaction* transaction, + common::table_id_t tableID) const; + // Get all node table entries. + std::vector getNodeTableEntries( + const transaction::Transaction* transaction, bool useInternal = true) const; + // Get all rel table entries. + std::vector getRelGroupEntries( + const transaction::Transaction* transaction, bool useInternal = true) const; + // Get all table entries. + std::vector getTableEntries(const transaction::Transaction* transaction, + bool useInternal = true) const; + + // Create table catalog entry. + CatalogEntry* createTableEntry(transaction::Transaction* transaction, + const binder::BoundCreateTableInfo& info); + // Drop table entry and all indices within the table. + void dropTableEntryAndIndex(transaction::Transaction* transaction, const std::string& name); + // Drop table entry with id. + void dropTableEntry(transaction::Transaction* transaction, common::table_id_t tableID); + // Drop table entry. + void dropTableEntry(transaction::Transaction* transaction, const TableCatalogEntry* entry); + // Alter table entry. + void alterTableEntry(transaction::Transaction* transaction, const binder::BoundAlterInfo& info); + + // ----------------------------- Sequences ---------------------------- + + // Check if sequence entry exists. 
+ bool containsSequence(const transaction::Transaction* transaction, + const std::string& name) const; + // Get sequence entry with name. + SequenceCatalogEntry* getSequenceEntry(const transaction::Transaction* transaction, + const std::string& sequenceName, bool useInternalSeq = true) const; + // Get sequence entry with id. + SequenceCatalogEntry* getSequenceEntry(const transaction::Transaction* transaction, + common::sequence_id_t sequenceID) const; + // Get all sequence entries. + std::vector getSequenceEntries( + const transaction::Transaction* transaction) const; + + // Create sequence entry. + common::sequence_id_t createSequence(transaction::Transaction* transaction, + const binder::BoundCreateSequenceInfo& info); + // Drop sequence entry with name. + void dropSequence(transaction::Transaction* transaction, const std::string& name); + // Drop sequence entry with id. + void dropSequence(transaction::Transaction* transaction, common::sequence_id_t sequenceID); + + // ----------------------------- Types ---------------------------- + + // Check if type entry exists. + bool containsType(const transaction::Transaction* transaction, const std::string& name) const; + // Get type entry with name. + common::LogicalType getType(const transaction::Transaction*, const std::string& name) const; + + // Create type entry. + void createType(transaction::Transaction* transaction, std::string name, + common::LogicalType type); + + // ----------------------------- Indexes ---------------------------- + + // Check if index exists for given table and name + bool containsIndex(const transaction::Transaction* transaction, common::table_id_t tableID, + const std::string& indexName) const; + // Check if index exists for given table and property + bool containsIndex(const transaction::Transaction* transaction, common::table_id_t tableID, + common::property_id_t propertyID) const; + // Check if there is any unloaded index for given table and property + bool containsUnloadedIndex(const transaction::Transaction* transaction, + common::table_id_t tableID, common::property_id_t propertyID) const; + // Get index entry with name. + IndexCatalogEntry* getIndex(const transaction::Transaction* transaction, + common::table_id_t tableID, const std::string& indexName) const; + // Get all index entries. + std::vector getIndexEntries( + const transaction::Transaction* transaction) const; + // Get all index entries for given table + std::vector getIndexEntries(const transaction::Transaction* transaction, + common::table_id_t tableID) const; + + // Create index entry. + void createIndex(transaction::Transaction* transaction, + std::unique_ptr indexCatalogEntry); + // Drop all index entries within a table. + void dropAllIndexes(transaction::Transaction* transaction, common::table_id_t tableID); + // Drop index entry with name. + void dropIndex(transaction::Transaction* transaction, common::table_id_t tableID, + const std::string& indexName) const; + void dropIndex(transaction::Transaction* transaction, common::oid_t indexOID); + + // ----------------------------- Functions ---------------------------- + + // Check if function exists. + bool containsFunction(const transaction::Transaction* transaction, const std::string& name, + bool useInternal = false) const; + // Get function entry by name. + // Note we cannot cast to FunctionEntry here because result could also be a MacroEntry. 
+ CatalogEntry* getFunctionEntry(const transaction::Transaction* transaction, + const std::string& name, bool useInternal = false) const; + // Get all function entries. + std::vector getFunctionEntries( + const transaction::Transaction* transaction) const; + + // Get all macro entries. + std::vector getMacroEntries( + const transaction::Transaction* transaction) const; + + // Add function with name. + void addFunction(transaction::Transaction* transaction, CatalogEntryType entryType, + std::string name, function::function_set functionSet, bool isInternal = false); + // Drop function with name. + void dropFunction(transaction::Transaction* transaction, const std::string& name); + + // ----------------------------- Macro ---------------------------- + + // Check if macro entry exists. + bool containsMacro(const transaction::Transaction* transaction, + const std::string& macroName) const; + void addScalarMacroFunction(transaction::Transaction* transaction, std::string name, + std::unique_ptr macro); + ScalarMacroCatalogEntry* getScalarMacroCatalogEntry(const transaction::Transaction* transaction, + lbug::common::oid_t MacroID) const; + void dropMacroEntry(transaction::Transaction* transaction, const lbug::common::oid_t macroID); + void dropMacroEntry(transaction::Transaction* transaction, + const ScalarMacroCatalogEntry* entry); + function::ScalarMacroFunction* getScalarMacroFunction( + const transaction::Transaction* transaction, const std::string& name) const; + std::vector getMacroNames(const transaction::Transaction* transaction) const; + void dropMacro(transaction::Transaction* transaction, std::string& name); + + void incrementVersion() { version++; } + uint64_t getVersion() const { return version; } + bool changedSinceLastCheckpoint() const { return version != 0; } + void resetVersion() { version = 0; } + + void serialize(common::Serializer& ser) const; + void deserialize(common::Deserializer& deSer); + + template + TARGET* ptrCast() { + return common::ku_dynamic_cast(this); + } + +private: + void initCatalogSets(); + void registerBuiltInFunctions(); + + CatalogEntry* createNodeTableEntry(transaction::Transaction* transaction, + const binder::BoundCreateTableInfo& info); + CatalogEntry* createRelGroupEntry(transaction::Transaction* transaction, + const binder::BoundCreateTableInfo& info); + + void createSerialSequence(transaction::Transaction* transaction, const TableCatalogEntry* entry, + bool isInternal); + void dropSerialSequence(transaction::Transaction* transaction, const TableCatalogEntry* entry); + + template + std::vector getTableEntries(const transaction::Transaction* transaction, bool useInternal, + CatalogEntryType entryType) const; + +protected: + std::unique_ptr tables; + +private: + std::unique_ptr sequences; + std::unique_ptr functions; + std::unique_ptr types; + std::unique_ptr indexes; + std::unique_ptr macros; + std::unique_ptr internalTables; + std::unique_ptr internalSequences; + std::unique_ptr internalFunctions; + + // incremented whenever a change is made to the catalog + // reset to 0 at the end of each checkpoint + uint64_t version; +}; + +} // namespace catalog +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog_entry/catalog_entry.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog_entry/catalog_entry.h new file mode 100644 index 0000000000..248c90373a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog_entry/catalog_entry.h @@ -0,0 +1,114 @@ +#pragma once + +#include + 
+#include "catalog_entry_type.h" +#include "common/assert.h" +#include "common/copy_constructors.h" +#include "common/serializer/serializer.h" +#include "common/types/types.h" + +namespace lbug { +namespace main { +class ClientContext; +} // namespace main + +namespace catalog { + +struct LBUG_API ToCypherInfo { + virtual ~ToCypherInfo() = default; + + template + const TARGET& constCast() const { + return common::ku_dynamic_cast(*this); + } +}; + +class LBUG_API CatalogEntry { +public: + //===--------------------------------------------------------------------===// + // constructor & destructor + //===--------------------------------------------------------------------===// + CatalogEntry() : CatalogEntry{CatalogEntryType::DUMMY_ENTRY, ""} {} + CatalogEntry(CatalogEntryType type, std::string name) + : type{type}, name{std::move(name)}, oid{common::INVALID_OID}, + timestamp{common::INVALID_TRANSACTION} {} + DELETE_COPY_DEFAULT_MOVE(CatalogEntry); + virtual ~CatalogEntry() = default; + + //===--------------------------------------------------------------------===// + // getter & setter + //===--------------------------------------------------------------------===// + CatalogEntryType getType() const { return type; } + void rename(std::string name_) { this->name = std::move(name_); } + std::string getName() const { return name; } + common::transaction_t getTimestamp() const { return timestamp; } + void setTimestamp(common::transaction_t timestamp_) { this->timestamp = timestamp_; } + bool isDeleted() const { return deleted; } + void setDeleted(bool deleted_) { this->deleted = deleted_; } + bool hasParent() const { return hasParent_; } + void setHasParent(bool hasParent) { hasParent_ = hasParent; } + void setOID(common::oid_t oid) { this->oid = oid; } + common::oid_t getOID() const { return oid; } + CatalogEntry* getPrev() const { + KU_ASSERT(prev); + return prev.get(); + } + std::unique_ptr movePrev() { + if (this->prev) { + this->prev->setNext(nullptr); + } + return std::move(prev); + } + void setPrev(std::unique_ptr prev_) { + this->prev = std::move(prev_); + if (this->prev) { + this->prev->setNext(this); + } + } + CatalogEntry* getNext() const { return next; } + void setNext(CatalogEntry* next_) { this->next = next_; } + + //===--------------------------------------------------------------------===// + // serialization & deserialization + //===--------------------------------------------------------------------===// + virtual void serialize(common::Serializer& serializer) const; + static std::unique_ptr deserialize(common::Deserializer& deserializer); + + virtual std::string toCypher(const ToCypherInfo& /*info*/) const { KU_UNREACHABLE; } + + template + TARGET& cast() { + return common::ku_dynamic_cast(*this); + } + template + const TARGET& constCast() const { + return common::ku_dynamic_cast(*this); + } + template + const TARGET* constPtrCast() const { + return common::ku_dynamic_cast(this); + } + template + TARGET* ptrCast() { + return common::ku_dynamic_cast(this); + } + +protected: + virtual void copyFrom(const CatalogEntry& other); + +protected: + CatalogEntryType type; + std::string name; + common::oid_t oid; + common::transaction_t timestamp; + bool deleted = false; + bool hasParent_ = false; + // Older versions. + std::unique_ptr prev; + // Newer versions. 
+ CatalogEntry* next = nullptr; +}; + +} // namespace catalog +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog_entry/catalog_entry_type.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog_entry/catalog_entry_type.h new file mode 100644 index 0000000000..8ad76675cb --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog_entry/catalog_entry_type.h @@ -0,0 +1,42 @@ +#pragma once + +#include +#include + +namespace lbug { +namespace catalog { + +enum class CatalogEntryType : uint8_t { + // Table entries + NODE_TABLE_ENTRY = 0, + REL_GROUP_ENTRY = 2, + FOREIGN_TABLE_ENTRY = 4, + // Macro entries + SCALAR_MACRO_ENTRY = 10, + // Function entries + AGGREGATE_FUNCTION_ENTRY = 20, + SCALAR_FUNCTION_ENTRY = 21, + REWRITE_FUNCTION_ENTRY = 22, + TABLE_FUNCTION_ENTRY = 23, + COPY_FUNCTION_ENTRY = 25, + STANDALONE_TABLE_FUNCTION_ENTRY = 26, + // Sequence entries + SEQUENCE_ENTRY = 40, + // UDT entries + TYPE_ENTRY = 41, + // Index entries + INDEX_ENTRY = 42, + // Dummy entry + DUMMY_ENTRY = 100, +}; + +struct CatalogEntryTypeUtils { + static std::string toString(CatalogEntryType type); +}; + +struct FunctionEntryTypeUtils { + static std::string toString(CatalogEntryType type); +}; + +} // namespace catalog +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog_entry/dummy_catalog_entry.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog_entry/dummy_catalog_entry.h new file mode 100644 index 0000000000..3b2d1aac9d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog_entry/dummy_catalog_entry.h @@ -0,0 +1,22 @@ +#pragma once + +#include "catalog/catalog_entry/catalog_entry.h" + +namespace lbug { +namespace catalog { + +class DummyCatalogEntry final : public CatalogEntry { +public: + explicit DummyCatalogEntry(std::string name, common::oid_t oid) + : CatalogEntry{CatalogEntryType::DUMMY_ENTRY, std::move(name)} { + setDeleted(true); + setTimestamp(0); + setOID(oid); + } + + void serialize(common::Serializer& /*serializer*/) const override { KU_UNREACHABLE; } + std::string toCypher(const ToCypherInfo& /*info*/) const override { KU_UNREACHABLE; } +}; + +} // namespace catalog +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog_entry/function_catalog_entry.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog_entry/function_catalog_entry.h new file mode 100644 index 0000000000..1bd5ed85ef --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog_entry/function_catalog_entry.h @@ -0,0 +1,35 @@ +#pragma once + +#include "catalog_entry.h" +#include "function/function.h" + +namespace lbug { +namespace catalog { + +class LBUG_API FunctionCatalogEntry : public CatalogEntry { +public: + //===--------------------------------------------------------------------===// + // constructors + //===--------------------------------------------------------------------===// + FunctionCatalogEntry() = default; + FunctionCatalogEntry(CatalogEntryType entryType, std::string name, + function::function_set functionSet); + + //===--------------------------------------------------------------------===// + // getters & setters + //===--------------------------------------------------------------------===// + const function::function_set& getFunctionSet() const { return functionSet; } + + //===--------------------------------------------------------------------===// + // serialization & deserialization + 
//===--------------------------------------------------------------------===// + // We always register functions while initializing the catalog, so we don't have to + // serialize functions. + void serialize(common::Serializer& /*serializer*/) const override { KU_UNREACHABLE; } + +protected: + function::function_set functionSet; +}; + +} // namespace catalog +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog_entry/index_catalog_entry.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog_entry/index_catalog_entry.h new file mode 100644 index 0000000000..0781680190 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog_entry/index_catalog_entry.h @@ -0,0 +1,116 @@ +#pragma once + +#include "catalog_entry.h" +#include "common/copier_config/file_scan_info.h" +#include "common/serializer/buffer_reader.h" +#include "common/serializer/deserializer.h" +#include "table_catalog_entry.h" + +namespace lbug::common { +struct BufferReader; +} +namespace lbug::common { +class BufferWriter; +} +namespace lbug { +namespace catalog { + +struct LBUG_API IndexToCypherInfo : ToCypherInfo { + const main::ClientContext* context; + const common::FileScanInfo& exportFileInfo; + + IndexToCypherInfo(const main::ClientContext* context, + const common::FileScanInfo& exportFileInfo) + : context{context}, exportFileInfo{exportFileInfo} {} +}; + +class IndexCatalogEntry; +struct LBUG_API IndexAuxInfo { + virtual ~IndexAuxInfo() = default; + virtual std::shared_ptr serialize() const; + + virtual std::unique_ptr copy() = 0; + + template + TARGET& cast() { + return dynamic_cast(*this); + } + template + const TARGET& cast() const { + return dynamic_cast(*this); + } + + virtual std::string toCypher(const IndexCatalogEntry& indexEntry, + const ToCypherInfo& info) const = 0; + + virtual TableCatalogEntry* getTableEntryToExport(const main::ClientContext* /*context*/) const { + return nullptr; + } +}; + +class LBUG_API IndexCatalogEntry final : public CatalogEntry { +public: + static std::string getInternalIndexName(common::table_id_t tableID, std::string indexName) { + return common::stringFormat("{}_{}", tableID, std::move(indexName)); + } + + IndexCatalogEntry(std::string type, common::table_id_t tableID, std::string indexName, + std::vector properties, std::unique_ptr auxInfo) + : CatalogEntry{CatalogEntryType::INDEX_ENTRY, + common::stringFormat("{}_{}", tableID, indexName)}, + type{std::move(type)}, tableID{tableID}, indexName{std::move(indexName)}, + propertyIDs{std::move(properties)}, auxInfo{std::move(auxInfo)} {} + + std::string getIndexType() const { return type; } + + common::table_id_t getTableID() const { return tableID; } + + std::string getIndexName() const { return indexName; } + + std::vector getPropertyIDs() const { return propertyIDs; } + bool containsPropertyID(common::property_id_t propertyID) const; + + // When serializing index entries to disk, we first write the fields of the base class, + // followed by the size (in bytes) of the auxiliary data and its content. + void serialize(common::Serializer& serializer) const override; + // During deserialization of index entries from disk, we first read the base class + // (IndexCatalogEntry). The auxiliary data is stored in auxBuffer, with its size in + // auxBufferSize. Once the extension is loaded, the corresponding indexes are reconstructed + // using the auxBuffer. 
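+    // Rough on-disk layout implied by the comments above (an illustration, not normative):
+    //   [index entry fields][auxBufferSize : uint64][auxBuffer : auxBufferSize bytes]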
+ static std::unique_ptr deserialize(common::Deserializer& deserializer); + + std::string toCypher(const ToCypherInfo& info) const override { + return isLoaded() ? auxInfo->toCypher(*this, info) : ""; + } + + void copyFrom(const CatalogEntry& other) override; + + std::unique_ptr getAuxBufferReader() const; + + void setAuxInfo(std::unique_ptr auxInfo_); + const IndexAuxInfo& getAuxInfo() const { return *auxInfo; } + IndexAuxInfo& getAuxInfoUnsafe() { return *auxInfo; } + + bool isLoaded() const { return auxBuffer == nullptr; } + + TableCatalogEntry* getTableEntryToExport(main::ClientContext* context) const { + return isLoaded() ? auxInfo->getTableEntryToExport(context) : nullptr; + } + + std::unique_ptr copy() const { + return std::make_unique(type, tableID, indexName, propertyIDs, + auxInfo->copy()); + } + +protected: + std::string type; + common::table_id_t tableID = common::INVALID_TABLE_ID; + std::string indexName; + std::vector propertyIDs; + std::unique_ptr auxBuffer = nullptr; + std::unique_ptr auxInfo; + uint64_t auxBufferSize = 0; +}; + +} // namespace catalog +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog_entry/node_table_catalog_entry.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog_entry/node_table_catalog_entry.h new file mode 100644 index 0000000000..c7de5bac40 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog_entry/node_table_catalog_entry.h @@ -0,0 +1,50 @@ +#pragma once + +#include "table_catalog_entry.h" + +namespace lbug { +namespace transaction { +class Transaction; +} // namespace transaction + +namespace catalog { + +class Catalog; +class LBUG_API NodeTableCatalogEntry final : public TableCatalogEntry { + static constexpr CatalogEntryType entryType_ = CatalogEntryType::NODE_TABLE_ENTRY; + +public: + NodeTableCatalogEntry() = default; + NodeTableCatalogEntry(std::string name, std::string primaryKeyName) + : TableCatalogEntry{entryType_, std::move(name)}, + primaryKeyName{std::move(primaryKeyName)} {} + + bool isParent(common::table_id_t /*tableID*/) override { return false; } + common::TableType getTableType() const override { return common::TableType::NODE; } + + std::string getPrimaryKeyName() const { return primaryKeyName; } + common::property_id_t getPrimaryKeyID() const { + return propertyCollection.getPropertyID(primaryKeyName); + } + const binder::PropertyDefinition& getPrimaryKeyDefinition() const { + return getProperty(primaryKeyName); + } + + void renameProperty(const std::string& propertyName, const std::string& newName) override; + + void serialize(common::Serializer& serializer) const override; + static std::unique_ptr deserialize(common::Deserializer& deserializer); + + std::unique_ptr copy() const override; + std::string toCypher(const ToCypherInfo& info) const override; + +private: + std::unique_ptr getBoundExtraCreateInfo( + transaction::Transaction* transaction) const override; + +private: + std::string primaryKeyName; +}; + +} // namespace catalog +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog_entry/node_table_id_pair.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog_entry/node_table_id_pair.h new file mode 100644 index 0000000000..d006cd53f1 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog_entry/node_table_id_pair.h @@ -0,0 +1,38 @@ +#pragma once + +#include "common/types/types.h" + +namespace lbug { +namespace catalog { + +struct NodeTableIDPair { + common::table_id_t 
srcTableID = common::INVALID_TABLE_ID; + common::table_id_t dstTableID = common::INVALID_TABLE_ID; + + NodeTableIDPair() = default; + NodeTableIDPair(common::table_id_t srcTableID, common::table_id_t dstTableID) + : srcTableID{srcTableID}, dstTableID{dstTableID} {} + + void serialize(common::Serializer& serializer) const; + static NodeTableIDPair deserialize(common::Deserializer& deser); +}; + +struct NodeTableIDPairHash { + std::size_t operator()(const NodeTableIDPair& np) const { + std::size_t h1 = std::hash{}(np.srcTableID); + std::size_t h2 = std::hash{}(np.dstTableID); + return h1 ^ (h2 << 1); + } +}; + +struct NodeTableIDPairEqual { + bool operator()(const NodeTableIDPair& lhs, const NodeTableIDPair& rhs) const { + return lhs.srcTableID == rhs.srcTableID && lhs.dstTableID == rhs.dstTableID; + } +}; + +using node_table_id_pair_set_t = + std::unordered_set; + +} // namespace catalog +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog_entry/rel_group_catalog_entry.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog_entry/rel_group_catalog_entry.h new file mode 100644 index 0000000000..4349bf8884 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog_entry/rel_group_catalog_entry.h @@ -0,0 +1,103 @@ +#pragma once + +#include "catalog/catalog_entry/table_catalog_entry.h" +#include "common/enums/extend_direction.h" +#include "common/enums/rel_direction.h" +#include "common/enums/rel_multiplicity.h" +#include "node_table_id_pair.h" + +namespace lbug { +namespace catalog { + +struct RelGroupToCypherInfo final : ToCypherInfo { + const main::ClientContext* context; + + explicit RelGroupToCypherInfo(const main::ClientContext* context) : context{context} {} +}; + +struct RelTableCatalogInfo { + NodeTableIDPair nodePair; + common::oid_t oid = common::INVALID_OID; + + RelTableCatalogInfo() = default; + RelTableCatalogInfo(NodeTableIDPair nodePair, common::oid_t oid) + : nodePair{nodePair}, oid{oid} {} + + void serialize(common::Serializer& ser) const; + static RelTableCatalogInfo deserialize(common::Deserializer& deser); +}; + +class LBUG_API RelGroupCatalogEntry final : public TableCatalogEntry { + static constexpr CatalogEntryType type_ = CatalogEntryType::REL_GROUP_ENTRY; + +public: + RelGroupCatalogEntry() = default; + RelGroupCatalogEntry(std::string tableName, common::RelMultiplicity srcMultiplicity, + common::RelMultiplicity dstMultiplicity, common::ExtendDirection storageDirection, + std::vector relTableInfos) + : TableCatalogEntry{type_, std::move(tableName)}, srcMultiplicity{srcMultiplicity}, + dstMultiplicity{dstMultiplicity}, storageDirection{storageDirection}, + relTableInfos{std::move(relTableInfos)} { + propertyCollection = + PropertyDefinitionCollection{1}; // Skip NBR_NODE_ID column as the first one. + } + + bool isParent(common::table_id_t tableID) override; + common::TableType getTableType() const override { return common::TableType::REL; } + + common::RelMultiplicity getMultiplicity(common::RelDataDirection direction) const { + return direction == common::RelDataDirection::FWD ? 
dstMultiplicity : srcMultiplicity; + } + bool isSingleMultiplicity(common::RelDataDirection direction) const { + return getMultiplicity(direction) == common::RelMultiplicity::ONE; + } + + common::ExtendDirection getStorageDirection() const { return storageDirection; } + + common::idx_t getNumRelTables() const { return relTableInfos.size(); } + const std::vector& getRelEntryInfos() const { return relTableInfos; } + const RelTableCatalogInfo& getSingleRelEntryInfo() const; + bool hasRelEntryInfo(common::table_id_t srcTableID, common::table_id_t dstTableID) const { + return getRelEntryInfo(srcTableID, dstTableID) != nullptr; + } + const RelTableCatalogInfo* getRelEntryInfo(common::table_id_t srcTableID, + common::table_id_t dstTableID) const; + + std::unordered_set getSrcNodeTableIDSet() const; + std::unordered_set getDstNodeTableIDSet() const; + std::unordered_set getBoundNodeTableIDSet( + common::RelDataDirection direction) const { + return direction == common::RelDataDirection::FWD ? getSrcNodeTableIDSet() : + getDstNodeTableIDSet(); + } + std::unordered_set getNbrNodeTableIDSet( + common::RelDataDirection direction) const { + return direction == common::RelDataDirection::FWD ? getDstNodeTableIDSet() : + getSrcNodeTableIDSet(); + } + + std::vector getRelDataDirections() const; + + void addFromToConnection(common::table_id_t srcTableID, common::table_id_t dstTableID, + common::oid_t oid); + void dropFromToConnection(common::table_id_t srcTableID, common::table_id_t dstTableID); + void serialize(common::Serializer& serializer) const override; + static std::unique_ptr deserialize(common::Deserializer& deserializer); + std::string toCypher(const ToCypherInfo& info) const override; + + std::unique_ptr copy() const override; + +protected: + std::unique_ptr getBoundExtraCreateInfo( + transaction::Transaction*) const override; + +private: + common::RelMultiplicity srcMultiplicity = common::RelMultiplicity::MANY; + common::RelMultiplicity dstMultiplicity = common::RelMultiplicity::MANY; + // TODO(Guodong): Avoid using extend direction for storage direction + common::ExtendDirection storageDirection = common::ExtendDirection::BOTH; + std::vector relTableInfos; +}; + +} // namespace catalog +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog_entry/scalar_macro_catalog_entry.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog_entry/scalar_macro_catalog_entry.h new file mode 100644 index 0000000000..1fdd9c2c9c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog_entry/scalar_macro_catalog_entry.h @@ -0,0 +1,35 @@ +#pragma once + +#include "catalog_entry.h" +#include "function/scalar_macro_function.h" + +namespace lbug { +namespace catalog { + +class ScalarMacroCatalogEntry final : public CatalogEntry { +public: + //===--------------------------------------------------------------------===// + // constructors + //===--------------------------------------------------------------------===// + ScalarMacroCatalogEntry() = default; + ScalarMacroCatalogEntry(std::string name, + std::unique_ptr macroFunction); + + //===--------------------------------------------------------------------===// + // getter & setter + //===--------------------------------------------------------------------===// + function::ScalarMacroFunction* getMacroFunction() const { return macroFunction.get(); } + + //===--------------------------------------------------------------------===// + // serialization & deserialization + 
//===--------------------------------------------------------------------===// + void serialize(common::Serializer& serializer) const override; + static std::unique_ptr deserialize(common::Deserializer& deserializer); + std::string toCypher(const ToCypherInfo& info) const override; + +private: + std::unique_ptr macroFunction; +}; + +} // namespace catalog +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog_entry/sequence_catalog_entry.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog_entry/sequence_catalog_entry.h new file mode 100644 index 0000000000..337aaeaa61 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog_entry/sequence_catalog_entry.h @@ -0,0 +1,95 @@ +#pragma once + +#include + +#include "binder/ddl/bound_create_sequence_info.h" +#include "catalog_entry.h" +#include "common/vector/value_vector.h" + +namespace lbug { +namespace common { +class ValueVector; +} + +namespace binder { +struct BoundExtraCreateCatalogEntryInfo; +struct BoundAlterInfo; +} // namespace binder + +namespace transaction { +class Transaction; +} // namespace transaction + +namespace catalog { + +struct SequenceRollbackData { + uint64_t usageCount; + int64_t currVal; +}; + +struct SequenceData { + SequenceData() = default; + explicit SequenceData(const binder::BoundCreateSequenceInfo& info) + : usageCount{0}, currVal{info.startWith}, increment{info.increment}, + startValue{info.startWith}, minValue{info.minValue}, maxValue{info.maxValue}, + cycle{info.cycle} {} + + uint64_t usageCount; + int64_t currVal; + int64_t increment; + int64_t startValue; + int64_t minValue; + int64_t maxValue; + bool cycle; +}; + +class CatalogSet; +class LBUG_API SequenceCatalogEntry final : public CatalogEntry { +public: + //===--------------------------------------------------------------------===// + // constructors + //===--------------------------------------------------------------------===// + SequenceCatalogEntry() : sequenceData{} {} + explicit SequenceCatalogEntry(const binder::BoundCreateSequenceInfo& sequenceInfo) + : CatalogEntry{CatalogEntryType::SEQUENCE_ENTRY, sequenceInfo.sequenceName}, + sequenceData{SequenceData(sequenceInfo)} {} + + //===--------------------------------------------------------------------===// + // getter & setter + //===--------------------------------------------------------------------===// + SequenceData getSequenceData(); + + //===--------------------------------------------------------------------===// + // sequence functions + //===--------------------------------------------------------------------===// + int64_t currVal(); + void nextKVal(transaction::Transaction* transaction, const uint64_t& count); + void nextKVal(transaction::Transaction* transaction, const uint64_t& count, + common::ValueVector& resultVector); + void rollbackVal(const uint64_t& usageCount, const int64_t& currVal); + + //===--------------------------------------------------------------------===// + // serialization & deserialization + //===--------------------------------------------------------------------===// + void serialize(common::Serializer& serializer) const override; + static std::unique_ptr deserialize(common::Deserializer& deserializer); + + std::string toCypher(const ToCypherInfo& info) const override; + + binder::BoundCreateSequenceInfo getBoundCreateSequenceInfo(bool isInternal) const; + + static std::string getSerialName(const std::string& tableName, + const std::string& propertyName) { + return 
std::string(tableName).append("_").append(propertyName).append("_").append("serial"); + } + +private: + void nextValNoLock(); + +private: + std::mutex mtx; + SequenceData sequenceData; +}; + +} // namespace catalog +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog_entry/table_catalog_entry.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog_entry/table_catalog_entry.h new file mode 100644 index 0000000000..797bbf39fc --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog_entry/table_catalog_entry.h @@ -0,0 +1,96 @@ +#pragma once + +#include + +#include "binder/ddl/bound_alter_info.h" +#include "binder/ddl/bound_create_table_info.h" +#include "catalog/catalog_entry/catalog_entry.h" +#include "catalog/property_definition_collection.h" +#include "common/enums/table_type.h" +#include "common/types/types.h" +#include "function/table/table_function.h" + +namespace lbug { +namespace binder { +struct BoundExtraCreateCatalogEntryInfo; +} // namespace binder + +namespace transaction { +class Transaction; +} // namespace transaction + +namespace catalog { + +class CatalogSet; +class Catalog; +class LBUG_API TableCatalogEntry : public CatalogEntry { +public: + TableCatalogEntry() = default; + TableCatalogEntry(CatalogEntryType catalogType, std::string name) + : CatalogEntry{catalogType, std::move(name)} {} + TableCatalogEntry& operator=(const TableCatalogEntry&) = delete; + + common::table_id_t getTableID() const { return oid; } + + virtual std::unique_ptr alter(common::transaction_t timestamp, + const binder::BoundAlterInfo& alterInfo, CatalogSet* tables) const; + + virtual bool isParent(common::table_id_t /*tableID*/) { return false; }; + virtual common::TableType getTableType() const = 0; + + std::string getComment() const { return comment; } + void setComment(std::string newComment) { comment = std::move(newComment); } + + virtual function::TableFunction getScanFunction() { KU_UNREACHABLE; } + + common::column_id_t getMaxColumnID() const; + void vacuumColumnIDs(common::column_id_t nextColumnID); + std::vector getProperties() const { + return propertyCollection.getDefinitions(); + } + common::idx_t getNumProperties() const { return propertyCollection.size(); } + bool containsProperty(const std::string& propertyName) const; + common::property_id_t getPropertyID(const std::string& propertyName) const; + const binder::PropertyDefinition& getProperty(const std::string& propertyName) const; + const binder::PropertyDefinition& getProperty(common::idx_t idx) const; + virtual common::column_id_t getColumnID(const std::string& propertyName) const; + common::column_id_t getColumnID(common::idx_t idx) const; + void addProperty(const binder::PropertyDefinition& propertyDefinition); + void dropProperty(const std::string& propertyName); + virtual void renameProperty(const std::string& propertyName, const std::string& newName); + + void serialize(common::Serializer& serializer) const override; + static std::unique_ptr deserialize(common::Deserializer& deserializer, + CatalogEntryType type); + virtual std::unique_ptr copy() const = 0; + + binder::BoundCreateTableInfo getBoundCreateTableInfo(transaction::Transaction* transaction, + bool isInternal) const; + +protected: + void copyFrom(const CatalogEntry& other) override; + virtual std::unique_ptr getBoundExtraCreateInfo( + transaction::Transaction* transaction) const = 0; + +protected: + std::string comment; + PropertyDefinitionCollection propertyCollection; +}; + +struct 
TableCatalogEntryHasher { + std::size_t operator()(TableCatalogEntry* entry) const { + return std::hash{}(entry->getTableID()); + } +}; + +struct TableCatalogEntryEquality { + bool operator()(TableCatalogEntry* left, TableCatalogEntry* right) const { + return left->getTableID() == right->getTableID(); + } +}; + +using table_catalog_entry_set_t = + std::unordered_set; + +} // namespace catalog +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog_entry/type_catalog_entry.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog_entry/type_catalog_entry.h new file mode 100644 index 0000000000..6cd347e80d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog_entry/type_catalog_entry.h @@ -0,0 +1,33 @@ +#pragma once + +#include "catalog_entry.h" + +namespace lbug { +namespace catalog { + +class TypeCatalogEntry : public CatalogEntry { +public: + //===--------------------------------------------------------------------===// + // constructors + //===--------------------------------------------------------------------===// + TypeCatalogEntry() = default; + TypeCatalogEntry(std::string name, common::LogicalType type) + : CatalogEntry{CatalogEntryType::TYPE_ENTRY, std::move(name)}, type{std::move(type)} {} + + //===--------------------------------------------------------------------===// + // getter & setter + //===--------------------------------------------------------------------===// + const common::LogicalType& getLogicalType() const { return type; } + + //===--------------------------------------------------------------------===// + // serialization & deserialization + //===--------------------------------------------------------------------===// + void serialize(common::Serializer& serializer) const override; + static std::unique_ptr deserialize(common::Deserializer& deserializer); + +private: + common::LogicalType type; +}; + +} // namespace catalog +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog_set.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog_set.h new file mode 100644 index 0000000000..431d4664f0 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/catalog_set.h @@ -0,0 +1,92 @@ +#pragma once + +#include +#include + +#include "catalog_entry/catalog_entry.h" +#include "common/case_insensitive_map.h" + +namespace lbug { +namespace binder { +struct BoundAlterInfo; +} // namespace binder + +namespace storage { +class UndoBuffer; +} // namespace storage + +namespace transaction { +class Transaction; +} // namespace transaction + +using CatalogEntrySet = common::case_insensitive_map_t; + +namespace catalog { +class LBUG_API CatalogSet { + friend class storage::UndoBuffer; + +public: + CatalogSet() = default; + explicit CatalogSet(bool isInternal); + bool containsEntry(const transaction::Transaction* transaction, const std::string& name); + CatalogEntry* getEntry(const transaction::Transaction* transaction, const std::string& name); + common::oid_t createEntry(transaction::Transaction* transaction, + std::unique_ptr entry); + void dropEntry(transaction::Transaction* transaction, const std::string& name, + common::oid_t oid); + + void alterTableEntry(transaction::Transaction* transaction, + const binder::BoundAlterInfo& alterInfo); + + CatalogEntrySet getEntries(const transaction::Transaction* transaction); + CatalogEntry* getEntryOfOID(const transaction::Transaction* transaction, common::oid_t oid); + + void serialize(common::Serializer 
serializer) const; + static std::unique_ptr deserialize(common::Deserializer& deserializer); + + common::oid_t getNextOID() { + std::unique_lock lck{mtx}; + return nextOID++; + } + + common::oid_t getNextOIDNoLock() { return nextOID++; } + +private: + bool containsEntryNoLock(const transaction::Transaction* transaction, + const std::string& name) const; + CatalogEntry* getEntryNoLock(const transaction::Transaction* transaction, + const std::string& name) const; + CatalogEntry* createEntryNoLock(const transaction::Transaction* transaction, + std::unique_ptr entry); + CatalogEntry* dropEntryNoLock(const transaction::Transaction* transaction, + const std::string& name, common::oid_t oid); + + void validateExistNoLock(const transaction::Transaction* transaction, + const std::string& name) const; + void validateNotExistNoLock(const transaction::Transaction* transaction, + const std::string& name) const; + + void emplaceNoLock(std::unique_ptr entry); + void eraseNoLock(const std::string& name); + + static std::unique_ptr createDummyEntryNoLock(std::string name, + common::oid_t oid); + + static CatalogEntry* traverseVersionChainsForTransactionNoLock( + const transaction::Transaction* transaction, CatalogEntry* currentEntry); + static CatalogEntry* getCommittedEntryNoLock(CatalogEntry* entry); + bool isInternal() const { return nextOID >= INTERNAL_CATALOG_SET_START_OID; } + +public: + // To ensure the uniqueness of the OID and avoid conflict with user tables/sequence, we make the + // start OID of the internal catalog set to be 2^63. + static constexpr common::oid_t INTERNAL_CATALOG_SET_START_OID = 1LL << 63; + +private: + std::shared_mutex mtx; + common::oid_t nextOID = 0; + common::case_insensitive_map_t> entries; +}; + +} // namespace catalog +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/property_definition_collection.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/property_definition_collection.h new file mode 100644 index 0000000000..d2f836659f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/catalog/property_definition_collection.h @@ -0,0 +1,53 @@ +#pragma once + +#include "binder/ddl/property_definition.h" +#include "common/case_insensitive_map.h" + +namespace lbug { +namespace catalog { + +class LBUG_API PropertyDefinitionCollection { +public: + PropertyDefinitionCollection() : nextColumnID{0}, nextPropertyID{0} {} + explicit PropertyDefinitionCollection(common::column_id_t nextColumnID) + : nextColumnID{nextColumnID}, nextPropertyID{0} {} + EXPLICIT_COPY_DEFAULT_MOVE(PropertyDefinitionCollection); + + common::idx_t size() const { return definitions.size(); } + + bool contains(const std::string& name) const { return nameToPropertyIDMap.contains(name); } + + std::vector getDefinitions() const; + const binder::PropertyDefinition& getDefinition(const std::string& name) const; + const binder::PropertyDefinition& getDefinition(common::idx_t idx) const; + common::column_id_t getMaxColumnID() const; + common::column_id_t getColumnID(const std::string& name) const; + common::column_id_t getColumnID(common::property_id_t propertyID) const; + common::property_id_t getPropertyID(const std::string& name) const; + void vacuumColumnIDs(common::column_id_t nextColumnID); + + void add(const binder::PropertyDefinition& definition); + void drop(const std::string& name); + void rename(const std::string& name, const std::string& newName); + + std::string toCypher() const; + + void serialize(common::Serializer& serializer) const; + static 
PropertyDefinitionCollection deserialize(common::Deserializer& deserializer); + +private: + PropertyDefinitionCollection(const PropertyDefinitionCollection& other) + : nextColumnID{other.nextColumnID}, nextPropertyID{other.nextPropertyID}, + definitions{copyMap(other.definitions)}, columnIDs{other.columnIDs}, + nameToPropertyIDMap{other.nameToPropertyIDMap} {} + +private: + common::column_id_t nextColumnID; + common::property_id_t nextPropertyID; + std::map definitions; + std::unordered_map columnIDs; + common::case_insensitive_map_t nameToPropertyIDMap; +}; + +} // namespace catalog +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/api.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/api.h new file mode 100644 index 0000000000..1ad3e3d3a7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/api.h @@ -0,0 +1,36 @@ +#pragma once + +// Helpers +#if defined _WIN32 || defined __CYGWIN__ +#define LBUG_HELPER_DLL_IMPORT __declspec(dllimport) +#define LBUG_HELPER_DLL_EXPORT __declspec(dllexport) +#define LBUG_HELPER_DLL_LOCAL +#define LBUG_HELPER_DEPRECATED __declspec(deprecated) +#else +#define LBUG_HELPER_DLL_IMPORT __attribute__((visibility("default"))) +#define LBUG_HELPER_DLL_EXPORT __attribute__((visibility("default"))) +#define LBUG_HELPER_DLL_LOCAL __attribute__((visibility("hidden"))) +#define LBUG_HELPER_DEPRECATED __attribute__((__deprecated__)) +#endif + +#ifdef LBUG_STATIC_DEFINE +#define LBUG_API +#else +#ifndef LBUG_API +#ifdef LBUG_EXPORTS +/* We are building this library */ +#define LBUG_API LBUG_HELPER_DLL_EXPORT +#else +/* We are using this library */ +#define LBUG_API LBUG_HELPER_DLL_IMPORT +#endif +#endif +#endif + +#ifndef LBUG_DEPRECATED +#define LBUG_DEPRECATED LBUG_HELPER_DEPRECATED +#endif + +#ifndef LBUG_DEPRECATED_EXPORT +#define LBUG_DEPRECATED_EXPORT LBUG_API LBUG_DEPRECATED +#endif diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/array_utils.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/array_utils.h new file mode 100644 index 0000000000..c2c3498f45 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/array_utils.h @@ -0,0 +1,16 @@ +#pragma once + +#include +#include +#include + +namespace lbug::common { +template +constexpr std::array arrayConcat(const std::array& arr1, + const std::array& arr2) { + std::array ret{}; + std::copy_n(arr1.cbegin(), arr1.size(), ret.begin()); + std::copy_n(arr2.cbegin(), arr2.size(), ret.begin() + arr1.size()); + return ret; +} +} // namespace lbug::common diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/arrow/arrow.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/arrow/arrow.h new file mode 100644 index 0000000000..40d9917628 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/arrow/arrow.h @@ -0,0 +1,74 @@ +#pragma once + +// The Arrow C data interface. 
+// https://arrow.apache.org/docs/format/CDataInterface.html + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +#ifndef ARROW_C_DATA_INTERFACE +#define ARROW_C_DATA_INTERFACE + +#define ARROW_FLAG_DICTIONARY_ORDERED 1 +#define ARROW_FLAG_NULLABLE 2 +#define ARROW_FLAG_MAP_KEYS_SORTED 4 + +struct ArrowSchema { + // Array type description + const char* format; + const char* name; + const char* metadata; + int64_t flags; + int64_t n_children; + struct ArrowSchema** children; + struct ArrowSchema* dictionary; + + // Release callback + void (*release)(struct ArrowSchema*); + // Opaque producer-specific data + void* private_data; +}; + +struct ArrowArray { + // Array data description + int64_t length; + int64_t null_count; + int64_t offset; + int64_t n_buffers; + int64_t n_children; + const void** buffers; + struct ArrowArray** children; + struct ArrowArray* dictionary; + + // Release callback + void (*release)(struct ArrowArray*); + // Opaque producer-specific data + void* private_data; +}; + +#endif // ARROW_C_DATA_INTERFACE + +#ifdef __cplusplus +} +#endif + +struct ArrowSchemaWrapper : public ArrowSchema { + ArrowSchemaWrapper() : ArrowSchema{} { release = nullptr; } + ~ArrowSchemaWrapper() { + if (release) { + release(this); + } + } +}; + +struct ArrowArrayWrapper : public ArrowArray { + ArrowArrayWrapper() : ArrowArray{} { release = nullptr; } + ~ArrowArrayWrapper() { + if (release) { + release(this); + } + } +}; diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/arrow/arrow_buffer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/arrow/arrow_buffer.h new file mode 100644 index 0000000000..9a3ac8300f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/arrow/arrow_buffer.h @@ -0,0 +1,83 @@ +#pragma once + +#include "common/utils.h" + +struct ArrowSchema; + +namespace lbug { +namespace common { + +struct ArrowBuffer { + ArrowBuffer() : dataptr(nullptr), count(0), capacity(0) {} + ~ArrowBuffer() { + if (!dataptr) { + return; + } + free(dataptr); + dataptr = nullptr; + count = 0; + capacity = 0; + } + // disable copy constructors + ArrowBuffer(const ArrowBuffer& other) = delete; + ArrowBuffer& operator=(const ArrowBuffer&) = delete; + //! 
enable move constructors + ArrowBuffer(ArrowBuffer&& other) noexcept { + std::swap(dataptr, other.dataptr); + std::swap(count, other.count); + std::swap(capacity, other.capacity); + } + ArrowBuffer& operator=(ArrowBuffer&& other) noexcept { + std::swap(dataptr, other.dataptr); + std::swap(count, other.count); + std::swap(capacity, other.capacity); + return *this; + } + + void reserve(uint64_t bytes) { // NOLINT + auto new_capacity = nextPowerOfTwo(bytes); + if (new_capacity <= capacity) { + return; + } + reserveInternal(new_capacity); + } + + void resize(uint64_t bytes) { // NOLINT + reserve(bytes); + count = bytes; + } + + void resize(uint64_t bytes, uint8_t value) { // NOLINT + reserve(bytes); + for (uint64_t i = count; i < bytes; i++) { + dataptr[i] = value; + } + count = bytes; + } + + uint64_t size() { // NOLINT + return count; + } + + uint8_t* data() { // NOLINT + return dataptr; + } + +private: + void reserveInternal(uint64_t bytes) { + if (dataptr) { + dataptr = (uint8_t*)realloc(dataptr, bytes); + } else { + dataptr = (uint8_t*)malloc(bytes); + } + capacity = bytes; + } + +private: + uint8_t* dataptr = nullptr; + uint64_t count = 0; + uint64_t capacity = 0; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/arrow/arrow_converter.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/arrow/arrow_converter.h new file mode 100644 index 0000000000..f35e638a67 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/arrow/arrow_converter.h @@ -0,0 +1,48 @@ +#pragma once + +#include +#include + +#include "common/arrow/arrow.h" +#include "common/arrow/arrow_nullmask_tree.h" + +struct ArrowSchema; + +namespace lbug { +namespace common { + +struct ArrowSchemaHolder { + std::vector children; + std::vector childrenPtrs; + std::vector> nestedChildren; + std::vector> nestedChildrenPtr; + std::vector> ownedTypeNames; + std::vector> ownedMetadatas; +}; + +class ArrowConverter { +public: + static std::unique_ptr toArrowSchema(const std::vector& dataTypes, + const std::vector& columnNames, bool fallbackExtensionTypes); + + static LogicalType fromArrowSchema(const ArrowSchema* schema); + static void fromArrowArray(const ArrowSchema* schema, const ArrowArray* array, + ValueVector& outputVector, ArrowNullMaskTree* mask, uint64_t srcOffset, uint64_t dstOffset, + uint64_t count); + static void fromArrowArray(const ArrowSchema* schema, const ArrowArray* array, + ValueVector& outputVector); + +private: + static void initializeChild(ArrowSchema& child, const std::string& name = ""); + static void setArrowFormatForStruct(ArrowSchemaHolder& rootHolder, ArrowSchema& child, + const LogicalType& dataType, bool fallbackExtensionTypes); + static void setArrowFormatForUnion(ArrowSchemaHolder& rootHolder, ArrowSchema& child, + const LogicalType& dataType, bool fallbackExtensionTypes); + static void setArrowFormatForInternalID(ArrowSchemaHolder& rootHolder, ArrowSchema& child, + const LogicalType& dataType, bool fallbackExtensionTypes); + static void setArrowFormat(ArrowSchemaHolder& rootHolder, ArrowSchema& child, + const LogicalType& dataType, bool fallbackExtensionTypes); +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/arrow/arrow_nullmask_tree.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/arrow/arrow_nullmask_tree.h new file mode 100644 index 0000000000..85f0ade76b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/arrow/arrow_nullmask_tree.h 
@@ -0,0 +1,42 @@ +#pragma once + +#include "common/arrow/arrow.h" +#include "common/null_mask.h" +#include "common/vector/value_vector.h" + +namespace lbug { +namespace common { + +class ArrowNullMaskTree { +public: + ArrowNullMaskTree(const ArrowSchema* schema, const ArrowArray* array, uint64_t srcOffset, + uint64_t count, const NullMask* parentMask = nullptr); + + void copyToValueVector(ValueVector* vec, uint64_t dstOffset, uint64_t count); + bool isNull(int64_t idx) { return mask->isNull(idx + offset); } + ArrowNullMaskTree* getChild(int idx) { return &(*children)[idx]; } + ArrowNullMaskTree* getDictionary() { return dictionary.get(); } + ArrowNullMaskTree offsetBy(int64_t offset); + +private: + bool copyFromBuffer(const void* buffer, uint64_t srcOffset, uint64_t count); + bool applyParentBitmap(const NullMask* buffer); + + template + void scanListPushDown(const ArrowSchema* schema, const ArrowArray* array, uint64_t srcOffset, + uint64_t count); + + void scanArrayPushDown(const ArrowSchema* schema, const ArrowArray* array, uint64_t srcOffset, + uint64_t count); + + void scanStructPushDown(const ArrowSchema* schema, const ArrowArray* array, uint64_t srcOffset, + uint64_t count); + + int64_t offset; + std::shared_ptr mask; + std::shared_ptr> children; + std::shared_ptr dictionary; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/arrow/arrow_result_config.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/arrow/arrow_result_config.h new file mode 100644 index 0000000000..b101ec6eae --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/arrow/arrow_result_config.h @@ -0,0 +1,19 @@ +#pragma once + +#include + +namespace lbug { +namespace common { + +struct ArrowResultConfig { + int64_t chunkSize; + + ArrowResultConfig() : chunkSize(DEFAULT_CHUNK_SIZE) {} + explicit ArrowResultConfig(int64_t chunkSize) : chunkSize(chunkSize) {} + +private: + static constexpr int64_t DEFAULT_CHUNK_SIZE = 1000; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/arrow/arrow_row_batch.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/arrow/arrow_row_batch.h new file mode 100644 index 0000000000..195232bec0 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/arrow/arrow_row_batch.h @@ -0,0 +1,90 @@ +#pragma once + +#include +#include + +#include "common/arrow/arrow.h" +#include "common/arrow/arrow_buffer.h" +#include "common/types/types.h" + +struct ArrowSchema; + +namespace lbug { +namespace processor { +class FlatTuple; +} + +namespace common { +class Value; + +// An Arrow Vector(i.e., Array) is defined by a few pieces of metadata and data: +// 1) a logical data type; +// 2) a sequence of buffers: validity bitmaps, data buffer, overflow(optional), children(optional). +// 3) a length as a 64-bit signed integer; +// 4) a null count as a 64-bit signed integer; +// 5) an optional dictionary for dictionary-encoded arrays. +// See https://arrow.apache.org/docs/format/Columnar.html for more details. 
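// Editorial sketch (not part of the patch): to make points 1)-5) above
// concrete, the snippet below hand-builds a 3-element int64 array through the
// C data interface structs declared in "common/arrow/arrow.h" earlier in this
// diff. Buffer 0 is the LSB-ordered validity bitmap and buffer 1 the data
// buffer. Every identifier here is hypothetical, and the code assumes the
// standard integer types are already visible via the includes at the top of
// this file.
namespace arrow_layout_example {
static const uint8_t kValidity[1] = {0b00000101}; // rows 0 and 2 valid, row 1 null
static const int64_t kValues[3] = {7, 0, 42};     // the null slot still occupies space
static const void* kBuffers[2] = {kValidity, kValues};

// Release callback for statically owned buffers: just mark the array released.
static void releaseNoop(ArrowArray* array) {
    array->release = nullptr;
}

inline ArrowArray makeExampleInt64Array() {
    ArrowArray array{};
    array.length = 3;           // 3) length
    array.null_count = 1;       // 4) null count
    array.offset = 0;
    array.n_buffers = 2;        // 2) validity bitmap + data buffer
    array.n_children = 0;       // flat type, so no child arrays
    array.buffers = kBuffers;
    array.children = nullptr;
    array.dictionary = nullptr; // 5) no dictionary encoding here
    array.release = releaseNoop;
    return array;
}
} // namespace arrow_layout_example
// The matching ArrowSchema would carry format "l" (int64), which covers
// 1) the logical data type.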
+ +static inline uint64_t getNumBytesForBits(uint64_t numBits) { + return (numBits + 7) / 8; +} + +struct ArrowVector { + ArrowBuffer data; + ArrowBuffer validity; + ArrowBuffer overflow; + + int64_t numValues = 0; + int64_t capacity = 0; + int64_t numNulls = 0; + + std::vector> childData; + + // The arrow array C API data, only set after Finalize + std::unique_ptr array; + std::array buffers = {{nullptr, nullptr, nullptr}}; + std::vector childPointers; +}; + +// An arrow data chunk consisting of N rows in columnar format. +class ArrowRowBatch { +public: + ArrowRowBatch(const std::vector& types, std::int64_t capacity, + bool fallbackExtensionTypes); + + void append(const processor::FlatTuple& tuple); + std::int64_t size() const { return numTuples; } + ArrowArray toArray(const std::vector& types); + +private: + static void appendValue(ArrowVector* vector, const Value& value, bool fallbackExtensionTypes); + + static ArrowArray* convertVectorToArray(ArrowVector& vector, const LogicalType& type, + bool fallbackExtensionTypes); + static ArrowArray* convertStructVectorToArray(ArrowVector& vector, const LogicalType& type, + bool fallbackExtensionTypes); + static ArrowArray* convertInternalIDVectorToArray(ArrowVector& vector, const LogicalType& type, + bool fallbackExtensionTypes); + + static void copyNonNullValue(ArrowVector* vector, const Value& value, std::int64_t pos, + bool fallbackExtensionTypes); + static void copyNullValue(ArrowVector* vector, const Value& value, std::int64_t pos); + + template + static void templateCopyNonNullValue(ArrowVector* vector, const Value& value, std::int64_t pos, + bool fallbackExtensionTypes); + template + static void templateCopyNullValue(ArrowVector* vector, std::int64_t pos); + static void copyNullValueUnion(ArrowVector* vector, const Value& value, std::int64_t pos); + template + static ArrowArray* templateCreateArray(ArrowVector& vector, const LogicalType& type, + bool fallbackExtensionTypes); + +private: + std::vector> vectors; + std::int64_t numTuples; + bool fallbackExtensionTypes = false; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/assert.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/assert.h new file mode 100644 index 0000000000..4cf26fd8a1 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/assert.h @@ -0,0 +1,36 @@ +#pragma once + +#include "common/exception/internal.h" +#include "common/string_format.h" + +namespace lbug { +namespace common { + +[[noreturn]] inline void kuAssertFailureInternal(const char* condition_name, const char* file, + int linenr) { + // LCOV_EXCL_START + throw InternalException(stringFormat("Assertion failed in file \"{}\" on line {}: {}", file, + linenr, condition_name)); + // LCOV_EXCL_STOP +} + +#define KU_ASSERT_UNCONDITIONAL(condition) \ + static_cast(condition) ? 
\ + void(0) : \ + lbug::common::kuAssertFailureInternal(#condition, __FILE__, __LINE__) + +#if defined(LBUG_RUNTIME_CHECKS) || !defined(NDEBUG) +#define RUNTIME_CHECK(code) code +#define KU_ASSERT(condition) KU_ASSERT_UNCONDITIONAL(condition) +#else +#define KU_ASSERT(condition) void(0) +#define RUNTIME_CHECK(code) void(0) +#endif + +#define KU_UNREACHABLE \ + /* LCOV_EXCL_START */ [[unlikely]] lbug::common::kuAssertFailureInternal("KU_UNREACHABLE", \ + __FILE__, __LINE__) /* LCOV_EXCL_STOP */ +#define KU_UNUSED(expr) (void)(expr) + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/case_insensitive_map.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/case_insensitive_map.h new file mode 100644 index 0000000000..6406fc2cc8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/case_insensitive_map.h @@ -0,0 +1,29 @@ +#pragma once + +#include +#include +#include +#include + +#include "common/api.h" + +namespace lbug { +namespace common { + +struct CaseInsensitiveStringHashFunction { + LBUG_API uint64_t operator()(const std::string& str) const; +}; + +struct CaseInsensitiveStringEquality { + LBUG_API bool operator()(const std::string& lhs, const std::string& rhs) const; +}; + +template +using case_insensitive_map_t = std::unordered_map; + +using case_insensitve_set_t = std::unordered_set; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/cast.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/cast.h new file mode 100644 index 0000000000..b92a991f2a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/cast.h @@ -0,0 +1,38 @@ +#pragma once + +#include + +#include "common/assert.h" + +namespace lbug { +namespace common { + +template +TO ku_dynamic_cast(FROM* old) { +#if defined(LBUG_RUNTIME_CHECKS) || !defined(NDEBUG) + static_assert(std::is_pointer()); + TO newVal = dynamic_cast(old); + KU_ASSERT(newVal != nullptr); + return newVal; +#else + return reinterpret_cast(old); +#endif +} + +template +TO ku_dynamic_cast(FROM& old) { +#if defined(LBUG_RUNTIME_CHECKS) || !defined(NDEBUG) + static_assert(std::is_reference()); + try { + TO newVal = dynamic_cast(old); + return newVal; + } catch (std::bad_cast& e) { + KU_ASSERT(false); + } +#else + return reinterpret_cast(old); +#endif +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/checksum.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/checksum.h new file mode 100644 index 0000000000..bd59326f17 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/checksum.h @@ -0,0 +1,10 @@ +#pragma once + +#include +#include +namespace lbug::common { + +//! 
Compute a checksum over a buffer of size size +uint64_t checksum(uint8_t* buffer, size_t size); + +} // namespace lbug::common diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/concurrent_vector.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/concurrent_vector.h new file mode 100644 index 0000000000..e61dd6f08b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/concurrent_vector.h @@ -0,0 +1,108 @@ +#pragma once + +#include +#include +#include +#include + +#include "common/assert.h" +#include "common/system_config.h" + +namespace lbug { +namespace common { + +template +// Vector which doesn't move when resizing +// The initial size is fixed, and new elements are added in fixed sized blocks which are indexed and +// the indices are chained in a linked list. Currently only one thread can write concurrently, but +// any number of threads can read, even when the vector is being written to. +// +// Accessing elements which existed when the vector was created is as fast as possible, requiring +// just one comparison and one pointer reads, and accessing new elements is still reasonably fast, +// usually requiring reading just two pointers, with a small amount of arithmetic, or maybe more if +// an extremely large number of elements has been added +// (access cost increases every BLOCK_SIZE * INDEX_SIZE elements). +class ConcurrentVector { +public: + explicit ConcurrentVector(uint64_t initialNumElements, uint64_t initialBlockSize) + : numElements{initialNumElements}, initialBlock{std::make_unique(initialBlockSize)}, + initialBlockSize{initialBlockSize}, firstIndex{nullptr} {} + // resize never deallocates memory + // Not thread-safe + // It could be made to be thread-safe by storing the size atomically and doing compare and swap + // when adding new indices and blocks + void resize(uint64_t newSize) { + while (newSize > initialBlockSize + blocks.size() * BLOCK_SIZE) { + auto newBlock = std::make_unique(); + if (indices.empty()) { + auto index = std::make_unique(); + index->blocks[0] = newBlock.get(); + index->numBlocks = 1; + firstIndex = index.get(); + indices.push_back(std::move(index)); + } else if (indices.back()->numBlocks < INDEX_SIZE) { + auto& index = indices.back(); + index->blocks[index->numBlocks] = newBlock.get(); + index->numBlocks++; + } else { + KU_ASSERT(indices.back()->numBlocks == INDEX_SIZE); + auto index = std::make_unique(); + index->blocks[0] = newBlock.get(); + index->numBlocks = 1; + indices.back()->nextIndex = index.get(); + indices.push_back(std::move(index)); + } + blocks.push_back(std::move(newBlock)); + } + numElements = newSize; + } + + void push_back(T&& value) { + auto index = numElements; + resize(numElements + 1); + (*this)[index] = std::move(value); + } + + T& operator[](uint64_t elemPos) { + if (elemPos < initialBlockSize) { + KU_ASSERT(initialBlock); + return initialBlock[elemPos]; + } else { + auto blockNum = (elemPos - initialBlockSize) / BLOCK_SIZE; + auto posInBlock = (elemPos - initialBlockSize) % BLOCK_SIZE; + auto indexNum = blockNum / INDEX_SIZE; + BlockIndex* index = firstIndex; + KU_ASSERT(index != nullptr); + while (indexNum > 0) { + KU_ASSERT(index->nextIndex != nullptr); + index = index->nextIndex; + indexNum--; + } + KU_ASSERT(index->blocks[blockNum % INDEX_SIZE] != nullptr); + return index->blocks[blockNum % INDEX_SIZE]->data[posInBlock]; + } + } + + uint64_t size() { return numElements; } + +private: + uint64_t numElements; + std::unique_ptr initialBlock; + uint64_t initialBlockSize; + struct Block { + 
std::array data; + }; + struct BlockIndex { + BlockIndex() : nextIndex{nullptr}, blocks{}, numBlocks{0} {} + BlockIndex* nextIndex; + std::array blocks; + uint64_t numBlocks; + }; + BlockIndex* firstIndex; + std::vector> blocks; + std::vector> indices; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/constants.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/constants.h new file mode 100644 index 0000000000..d9b3d58757 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/constants.h @@ -0,0 +1,223 @@ +#pragma once + +#include +#include +#include + +#include "common/array_utils.h" +#include "common/types/types.h" + +namespace lbug { +namespace common { + +extern const char* LBUG_VERSION; + +constexpr double DEFAULT_HT_LOAD_FACTOR = 1.5; + +// This is the default thread sleep time we use when a thread, +// e.g., a worker thread is in TaskScheduler, needs to block. +constexpr uint64_t THREAD_SLEEP_TIME_WHEN_WAITING_IN_MICROS = 500; + +constexpr uint64_t DEFAULT_CHECKPOINT_WAIT_TIMEOUT_IN_MICROS = 5000000; + +// Note that some places use std::bit_ceil to calculate resizes, +// which won't work for values other than 2. If this is changed, those will need to be updated +constexpr uint64_t CHUNK_RESIZE_RATIO = 2; + +struct InternalKeyword { + static constexpr char ANONYMOUS[] = ""; + static constexpr char ID[] = "_ID"; + static constexpr char LABEL[] = "_LABEL"; + static constexpr char SRC[] = "_SRC"; + static constexpr char DST[] = "_DST"; + static constexpr char DIRECTION[] = "_DIRECTION"; + static constexpr char LENGTH[] = "_LENGTH"; + static constexpr char NODES[] = "_NODES"; + static constexpr char RELS[] = "_RELS"; + static constexpr char STAR[] = "*"; + static constexpr char PLACE_HOLDER[] = "_PLACE_HOLDER"; + static constexpr char MAP_KEY[] = "KEY"; + static constexpr char MAP_VALUE[] = "VALUE"; + + static constexpr std::string_view ROW_OFFSET = "_row_offset"; + static constexpr std::string_view SRC_OFFSET = "_src_offset"; + static constexpr std::string_view DST_OFFSET = "_dst_offset"; +}; + +enum PageSizeClass : uint8_t { + REGULAR_PAGE = 0, + TEMP_PAGE = 1, +}; + +struct BufferPoolConstants { + // If a user does not specify a max size for BM, we by default set the max size of BM to + // maxPhyMemSize * DEFAULT_PHY_MEM_SIZE_RATIO_FOR_BM. + static constexpr double DEFAULT_PHY_MEM_SIZE_RATIO_FOR_BM = 0.8; +// The default max size for a VMRegion. +#ifdef __32BIT__ + static constexpr uint64_t DEFAULT_VM_REGION_MAX_SIZE = (uint64_t)1 << 30; // (1GB) +#elif defined(__ANDROID__) + static constexpr uint64_t DEFAULT_VM_REGION_MAX_SIZE = (uint64_t)1 << 38; // (256GB) +#else + static constexpr uint64_t DEFAULT_VM_REGION_MAX_SIZE = static_cast(1) << 43; // (8TB) +#endif +}; + +struct StorageConstants { + static constexpr page_idx_t DB_HEADER_PAGE_IDX = 0; + static constexpr char WAL_FILE_SUFFIX[] = "wal"; + static constexpr char SHADOWING_SUFFIX[] = "shadow"; + static constexpr char TEMP_FILE_SUFFIX[] = "tmp"; + + // The number of pages that we add at one time when we need to grow a file. 
+ static constexpr uint64_t PAGE_GROUP_SIZE_LOG2 = 10; + static constexpr uint64_t PAGE_GROUP_SIZE = static_cast(1) << PAGE_GROUP_SIZE_LOG2; + static constexpr uint64_t PAGE_IDX_IN_GROUP_MASK = + (static_cast(1) << PAGE_GROUP_SIZE_LOG2) - 1; + + static constexpr double PACKED_CSR_DENSITY = 0.8; + static constexpr double LEAF_HIGH_CSR_DENSITY = 1.0; + + static constexpr uint64_t MAX_NUM_ROWS_IN_TABLE = static_cast(1) << 62; +}; + +struct TableOptionConstants { + static constexpr char REL_STORAGE_DIRECTION_OPTION[] = "STORAGE_DIRECTION"; +}; + +// Hash Index Configurations +struct HashIndexConstants { + static constexpr uint16_t SLOT_CAPACITY_BYTES = 256; + static constexpr uint64_t NUM_HASH_INDEXES_LOG2 = 8; + static constexpr uint64_t NUM_HASH_INDEXES = 1 << NUM_HASH_INDEXES_LOG2; +}; + +struct CopyConstants { + // Initial size of buffer for CSV Reader. + static constexpr uint64_t INITIAL_BUFFER_SIZE = 16384; + // This means that we will usually read the entirety of the contents of the file we need for a + // block in one read request. It is also very small, which means we can parallelize small files + // efficiently. + static constexpr uint64_t PARALLEL_BLOCK_SIZE = INITIAL_BUFFER_SIZE / 2; + + static constexpr const char* IGNORE_ERRORS_OPTION_NAME = "IGNORE_ERRORS"; + + static constexpr const char* FROM_OPTION_NAME = "FROM"; + static constexpr const char* TO_OPTION_NAME = "TO"; + + static constexpr const char* BOOL_CSV_PARSING_OPTIONS[] = {"HEADER", "PARALLEL", + "LIST_UNBRACED", "AUTODETECT", "AUTO_DETECT", CopyConstants::IGNORE_ERRORS_OPTION_NAME}; + static constexpr bool DEFAULT_CSV_HAS_HEADER = false; + static constexpr bool DEFAULT_CSV_PARALLEL = true; + + // Default configuration for csv file parsing + static constexpr const char* STRING_CSV_PARSING_OPTIONS[] = {"ESCAPE", "DELIM", "DELIMITER", + "QUOTE"}; + static constexpr char DEFAULT_CSV_ESCAPE_CHAR = '"'; + static constexpr char DEFAULT_CSV_DELIMITER = ','; + static constexpr bool DEFAULT_CSV_ALLOW_UNBRACED_LIST = false; + static constexpr char DEFAULT_CSV_QUOTE_CHAR = '"'; + static constexpr char DEFAULT_CSV_LIST_BEGIN_CHAR = '['; + static constexpr char DEFAULT_CSV_LIST_END_CHAR = ']'; + static constexpr bool DEFAULT_IGNORE_ERRORS = false; + static constexpr bool DEFAULT_CSV_AUTO_DETECT = true; + static constexpr bool DEFAULT_CSV_SET_DIALECT = false; + static constexpr std::array DEFAULT_CSV_DELIMITER_SEARCH_SPACE = {',', ';', '\t', '|'}; + static constexpr std::array DEFAULT_CSV_QUOTE_SEARCH_SPACE = {'"', '\''}; + static constexpr std::array DEFAULT_CSV_ESCAPE_SEARCH_SPACE = {'"', '\\', '\''}; + static constexpr std::array DEFAULT_CSV_NULL_STRINGS = {""}; + + static constexpr const char* INT_CSV_PARSING_OPTIONS[] = {"SKIP", "SAMPLE_SIZE"}; + static constexpr uint64_t DEFAULT_CSV_SKIP_NUM = 0; + static constexpr uint64_t DEFAULT_CSV_TYPE_DEDUCTION_SAMPLE_SIZE = 256; + + static constexpr const char* LIST_CSV_PARSING_OPTIONS[] = {"NULL_STRINGS"}; + + // metadata columns used to populate CSV warnings + static constexpr std::array SHARED_WARNING_DATA_COLUMN_NAMES = {"blockIdx", "offsetInBlock", + "startByteOffset", "endByteOffset"}; + static constexpr std::array SHARED_WARNING_DATA_COLUMN_TYPES = {LogicalTypeID::UINT64, + LogicalTypeID::UINT32, LogicalTypeID::UINT64, LogicalTypeID::UINT64}; + static constexpr column_id_t SHARED_WARNING_DATA_NUM_COLUMNS = + SHARED_WARNING_DATA_COLUMN_NAMES.size(); + + static constexpr std::array CSV_SPECIFIC_WARNING_DATA_COLUMN_NAMES = {"fileIdx"}; + static constexpr std::array 
CSV_SPECIFIC_WARNING_DATA_COLUMN_TYPES = {LogicalTypeID::UINT32}; + + static constexpr std::array CSV_WARNING_DATA_COLUMN_NAMES = + arrayConcat(SHARED_WARNING_DATA_COLUMN_NAMES, CSV_SPECIFIC_WARNING_DATA_COLUMN_NAMES); + static constexpr std::array CSV_WARNING_DATA_COLUMN_TYPES = + arrayConcat(SHARED_WARNING_DATA_COLUMN_TYPES, CSV_SPECIFIC_WARNING_DATA_COLUMN_TYPES); + static constexpr column_id_t CSV_WARNING_DATA_NUM_COLUMNS = + CSV_WARNING_DATA_COLUMN_NAMES.size(); + static_assert(CSV_WARNING_DATA_NUM_COLUMNS == CSV_WARNING_DATA_COLUMN_TYPES.size()); + + static constexpr column_id_t MAX_NUM_WARNING_DATA_COLUMNS = CSV_WARNING_DATA_NUM_COLUMNS; +}; + +struct PlannerKnobs { + static constexpr double NON_EQUALITY_PREDICATE_SELECTIVITY = 0.1; + static constexpr double EQUALITY_PREDICATE_SELECTIVITY = 0.01; + static constexpr uint64_t BUILD_PENALTY = 2; + // Avoid doing probe to build SIP if we have to accumulate a probe side that is much bigger than + // build side. Also avoid doing build to probe SIP if probe side is not much bigger than build. + static constexpr uint64_t SIP_RATIO = 5; +}; + +struct OrderByConstants { + static constexpr uint64_t NUM_BYTES_FOR_PAYLOAD_IDX = 8; + static constexpr uint64_t MIN_LIMIT_RATIO_TO_REDUCE = 2; +}; + +struct ParquetConstants { + static constexpr uint64_t PARQUET_DEFINE_VALID = 65535; + static constexpr const char* PARQUET_MAGIC_WORDS = "PAR1"; + // We limit the uncompressed page size to 100MB. + // The max size in Parquet is 2GB, but we choose a more conservative limit. + static constexpr uint64_t MAX_UNCOMPRESSED_PAGE_SIZE = 100000000; + // Dictionary pages must be below 2GB. Unlike data pages, there's only one dictionary page. + // For this reason we go with a much higher, but still a conservative upper bound of 1GB. + static constexpr uint64_t MAX_UNCOMPRESSED_DICT_PAGE_SIZE = 1e9; + // The maximum size a key entry in an RLE page takes. + static constexpr uint64_t MAX_DICTIONARY_KEY_SIZE = sizeof(uint32_t); + // The size of encoding the string length. 
+ static constexpr uint64_t STRING_LENGTH_SIZE = sizeof(uint32_t); + static constexpr uint64_t MAX_STRING_STATISTICS_SIZE = 10000; + static constexpr uint64_t PARQUET_INTERVAL_SIZE = 12; + static constexpr uint64_t PARQUET_UUID_SIZE = 16; +}; + +struct ExportCSVConstants { + static constexpr const char* DEFAULT_CSV_NEWLINE = "\n\r"; + static constexpr const char* DEFAULT_NULL_STR = ""; + static constexpr bool DEFAULT_FORCE_QUOTE = false; + static constexpr uint64_t DEFAULT_CSV_FLUSH_SIZE = 4096 * 8; +}; + +struct PortDBConstants { + static constexpr char INDEX_FILE_NAME[] = "index.cypher"; + static constexpr char SCHEMA_FILE_NAME[] = "schema.cypher"; + static constexpr char COPY_FILE_NAME[] = "copy.cypher"; + static constexpr const char* SCHEMA_ONLY_OPTION = "SCHEMA_ONLY"; + static constexpr const char* EXPORT_FORMAT_OPTION = "FORMAT"; + static constexpr const char* DEFAULT_EXPORT_FORMAT_OPTION = "PARQUET"; +}; + +struct WarningConstants { + static constexpr std::array WARNING_TABLE_COLUMN_NAMES{"query_id", "message", "file_path", + "line_number", "skipped_line_or_record"}; + static constexpr std::array WARNING_TABLE_COLUMN_DATA_TYPES{LogicalTypeID::UINT64, + LogicalTypeID::STRING, LogicalTypeID::STRING, LogicalTypeID::UINT64, LogicalTypeID::STRING}; + static constexpr uint64_t WARNING_TABLE_NUM_COLUMNS = WARNING_TABLE_COLUMN_NAMES.size(); + + static_assert(WARNING_TABLE_COLUMN_DATA_TYPES.size() == WARNING_TABLE_NUM_COLUMNS); +}; + +static constexpr char ATTACHED_LBUG_DB_TYPE[] = "LBUG"; + +static constexpr char LOCAL_DB_NAME[] = "local(lbug)"; + +constexpr auto DECIMAL_PRECISION_LIMIT = 38; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/copier_config/csv_reader_config.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/copier_config/csv_reader_config.h new file mode 100644 index 0000000000..5f80780367 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/copier_config/csv_reader_config.h @@ -0,0 +1,112 @@ +#pragma once + +#include "common/case_insensitive_map.h" +#include "common/constants.h" +#include "common/copy_constructors.h" +#include "common/types/value/value.h" + +namespace lbug { +namespace common { + +struct CSVOption { + // TODO(Xiyang): Add newline character option and delimiter can be a string. + char escapeChar; + char delimiter; + char quoteChar; + bool hasHeader; + uint64_t skipNum; + uint64_t sampleSize; + bool allowUnbracedList; + bool ignoreErrors; + + bool autoDetection; + // These fields aim to identify whether the options are set by user, or set by default. 
+ bool setEscape; + bool setDelim; + bool setQuote; + bool setHeader; + std::vector nullStrings; + + CSVOption() + : escapeChar{CopyConstants::DEFAULT_CSV_ESCAPE_CHAR}, + delimiter{CopyConstants::DEFAULT_CSV_DELIMITER}, + quoteChar{CopyConstants::DEFAULT_CSV_QUOTE_CHAR}, + hasHeader{CopyConstants::DEFAULT_CSV_HAS_HEADER}, + skipNum{CopyConstants::DEFAULT_CSV_SKIP_NUM}, + sampleSize{CopyConstants::DEFAULT_CSV_TYPE_DEDUCTION_SAMPLE_SIZE}, + allowUnbracedList{CopyConstants::DEFAULT_CSV_ALLOW_UNBRACED_LIST}, + ignoreErrors(CopyConstants::DEFAULT_IGNORE_ERRORS), + autoDetection{CopyConstants::DEFAULT_CSV_AUTO_DETECT}, + setEscape{CopyConstants::DEFAULT_CSV_SET_DIALECT}, + setDelim{CopyConstants::DEFAULT_CSV_SET_DIALECT}, + setQuote{CopyConstants::DEFAULT_CSV_SET_DIALECT}, + setHeader{CopyConstants::DEFAULT_CSV_SET_DIALECT}, + nullStrings{CopyConstants::DEFAULT_CSV_NULL_STRINGS[0]} {} + + EXPLICIT_COPY_DEFAULT_MOVE(CSVOption); + + // TODO: COPY FROM and COPY TO should support transform special options, like '\'. + std::unordered_map toOptionsMap(const bool& parallel) const { + std::unordered_map result; + result["parallel"] = parallel ? "true" : "false"; + if (setHeader) { + result["header"] = hasHeader ? "true" : "false"; + } + if (setEscape) { + result["escape"] = stringFormat("'\\{}'", escapeChar); + } + if (setDelim) { + result["delim"] = stringFormat("'{}'", delimiter); + } + if (setQuote) { + result["quote"] = stringFormat("'\\{}'", quoteChar); + } + if (autoDetection != CopyConstants::DEFAULT_CSV_AUTO_DETECT) { + result["auto_detect"] = autoDetection ? "true" : "false"; + } + return result; + } + + static std::string toCypher(const std::unordered_map& options) { + if (options.empty()) { + return ""; + } + std::string result = ""; + for (const auto& [key, value] : options) { + if (!result.empty()) { + result += ", "; + } + result += key + "=" + value; + } + return "(" + result + ")"; + } + + // Explicit copy constructor + CSVOption(const CSVOption& other) + : escapeChar{other.escapeChar}, delimiter{other.delimiter}, quoteChar{other.quoteChar}, + hasHeader{other.hasHeader}, skipNum{other.skipNum}, + sampleSize{other.sampleSize == 0 ? 
+ CopyConstants::DEFAULT_CSV_TYPE_DEDUCTION_SAMPLE_SIZE : + other.sampleSize}, // Set to DEFAULT_CSV_TYPE_DEDUCTION_SAMPLE_SIZE if + // sampleSize is 0 + allowUnbracedList{other.allowUnbracedList}, ignoreErrors{other.ignoreErrors}, + autoDetection{other.autoDetection}, setEscape{other.setEscape}, setDelim{other.setDelim}, + setQuote{other.setQuote}, setHeader{other.setHeader}, nullStrings{other.nullStrings} {} +}; + +struct CSVReaderConfig { + CSVOption option; + bool parallel; + + CSVReaderConfig() : option{}, parallel{CopyConstants::DEFAULT_CSV_PARALLEL} {} + EXPLICIT_COPY_DEFAULT_MOVE(CSVReaderConfig); + + static CSVReaderConfig construct(const case_insensitive_map_t& options); + +private: + CSVReaderConfig(const CSVReaderConfig& other) + : option{other.option.copy()}, parallel{other.parallel} {} +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/copier_config/file_scan_info.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/copier_config/file_scan_info.h new file mode 100644 index 0000000000..5f4790b0e4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/copier_config/file_scan_info.h @@ -0,0 +1,65 @@ +#pragma once + +#include +#include + +#include "common/case_insensitive_map.h" +#include "common/copy_constructors.h" +#include "common/types/value/value.h" + +namespace lbug { +namespace common { + +enum class FileType : uint8_t { + UNKNOWN = 0, + CSV = 1, + PARQUET = 2, + NPY = 3, +}; + +struct FileTypeInfo { + FileType fileType = FileType::UNKNOWN; + std::string fileTypeStr; +}; + +struct FileTypeUtils { + static FileType getFileTypeFromExtension(std::string_view extension); + static std::string toString(FileType fileType); + static FileType fromString(std::string fileType); +}; + +struct FileScanInfo { + static constexpr const char* FILE_FORMAT_OPTION_NAME = "FILE_FORMAT"; + + FileTypeInfo fileTypeInfo; + std::vector filePaths; + case_insensitive_map_t options; + + FileScanInfo() : fileTypeInfo{FileType::UNKNOWN, ""} {} + FileScanInfo(FileTypeInfo fileTypeInfo, std::vector filePaths) + : fileTypeInfo{std::move(fileTypeInfo)}, filePaths{std::move(filePaths)} {} + EXPLICIT_COPY_DEFAULT_MOVE(FileScanInfo); + + uint32_t getNumFiles() const { return filePaths.size(); } + std::string getFilePath(idx_t fileIdx) const { + KU_ASSERT(fileIdx < getNumFiles()); + return filePaths[fileIdx]; + } + + template + T getOption(std::string optionName, T defaultValue) const { + const auto optionIt = options.find(optionName); + if (optionIt != options.end()) { + return optionIt->second.getValue(); + } else { + return defaultValue; + } + } + +private: + FileScanInfo(const FileScanInfo& other) + : fileTypeInfo{other.fileTypeInfo}, filePaths{other.filePaths}, options{other.options} {} +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/copy_constructors.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/copy_constructors.h new file mode 100644 index 0000000000..98d93c98e8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/copy_constructors.h @@ -0,0 +1,117 @@ +#pragma once +#include +#include +#include +#include +// This file defines many macros for controlling copy constructors and move constructors on classes. + +// NOLINTBEGIN(bugprone-macro-parentheses): Although this is a good check in general, here, we +// cannot add parantheses around the arguments, for it would be invalid syntax. 
+#define DELETE_COPY_CONSTRUCT(Object) Object(const Object& other) = delete
+#define DELETE_COPY_ASSN(Object) Object& operator=(const Object& other) = delete
+
+#define DELETE_MOVE_CONSTRUCT(Object) Object(Object&& other) = delete
+#define DELETE_MOVE_ASSN(Object) Object& operator=(Object&& other) = delete
+
+#define DELETE_BOTH_COPY(Object) \
+    DELETE_COPY_CONSTRUCT(Object); \
+    DELETE_COPY_ASSN(Object)
+
+#define DELETE_BOTH_MOVE(Object) \
+    DELETE_MOVE_CONSTRUCT(Object); \
+    DELETE_MOVE_ASSN(Object)
+
+#define DEFAULT_MOVE_CONSTRUCT(Object) Object(Object&& other) = default
+#define DEFAULT_MOVE_ASSN(Object) Object& operator=(Object&& other) = default
+
+#define DEFAULT_BOTH_MOVE(Object) \
+    DEFAULT_MOVE_CONSTRUCT(Object); \
+    DEFAULT_MOVE_ASSN(Object)
+
+#define EXPLICIT_COPY_METHOD(Object) \
+    Object copy() const { \
+        return *this; \
+    }
+
+// EXPLICIT_COPY_DEFAULT_MOVE should be the default choice. It expects a PRIVATE copy constructor to
+// be defined, which will be used by an explicit `copy()` method. For instance:
+//
+// private:
+//   MyClass(const MyClass& other) : field(other.field.copy()) {}
+//
+// public:
+//   EXPLICIT_COPY_DEFAULT_MOVE(MyClass);
+//
+// Now:
+//
+// MyClass o1;
+// MyClass o2 = o1;        // Compile error, copy assignment deleted.
+// MyClass o2 = o1.copy(); // OK.
+// MyClass o2(o1);         // Compile error, copy constructor is private.
+#define EXPLICIT_COPY_DEFAULT_MOVE(Object) \
+    DELETE_COPY_ASSN(Object); \
+    DEFAULT_BOTH_MOVE(Object); \
+    EXPLICIT_COPY_METHOD(Object)
+
+// DELETE_COPY_DEFAULT_MOVE should be used for objects that, for whatever reason, should never be
+// copied, but can be moved.
+#define DELETE_COPY_DEFAULT_MOVE(Object) \
+    DELETE_BOTH_COPY(Object); \
+    DEFAULT_BOTH_MOVE(Object)
+
+// DELETE_COPY_AND_MOVE exists solely for explicitness, when an object can be neither moved nor
+// copied. Any object containing a lock cannot be moved or copied.
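// Illustrative sketch (not part of the patch): a movable-but-not-copyable type built with
// DELETE_COPY_DEFAULT_MOVE, as described above. The type name is hypothetical.
#include <memory>

struct UniqueHolder {
    UniqueHolder() = default;
    DELETE_COPY_DEFAULT_MOVE(UniqueHolder);
    std::unique_ptr<int> value;
};

// UniqueHolder a;
// UniqueHolder b = std::move(a); // OK: move construction/assignment are defaulted
// UniqueHolder c = b;            // compile error: copying is deleted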
+#define DELETE_COPY_AND_MOVE(Object) \
+    DELETE_BOTH_COPY(Object); \
+    DELETE_BOTH_MOVE(Object)
+// NOLINTEND(bugprone-macro-parentheses):
+
+template<typename T>
+static std::vector<T> copyVector(const std::vector<T>& objects) {
+    std::vector<T> result;
+    result.reserve(objects.size());
+    for (auto& object : objects) {
+        result.push_back(object.copy());
+    }
+    return result;
+}
+
+template<typename T>
+static std::vector<std::shared_ptr<T>> copyVector(const std::vector<std::shared_ptr<T>>& objects) {
+    std::vector<std::shared_ptr<T>> result;
+    result.reserve(objects.size());
+    for (auto& object : objects) {
+        T& ob = *object;
+        result.push_back(ob.copy());
+    }
+    return result;
+}
+
+template<typename T>
+static std::vector<std::unique_ptr<T>> copyVector(const std::vector<std::unique_ptr<T>>& objects) {
+    std::vector<std::unique_ptr<T>> result;
+    result.reserve(objects.size());
+    for (auto& object : objects) {
+        T& ob = *object;
+        result.push_back(ob.copy());
+    }
+    return result;
+}
+
+template<typename K, typename V>
+static std::unordered_map<K, V> copyUnorderedMap(const std::unordered_map<K, V>& objects) {
+    std::unordered_map<K, V> result;
+    for (auto& [k, v] : objects) {
+        result.insert({k, v.copy()});
+    }
+    return result;
+}
+
+template<typename K, typename V>
+static std::map<K, V> copyMap(const std::map<K, V>& objects) {
+    std::map<K, V> result;
+    for (auto& [k, v] : objects) {
+        result.insert({k, v.copy()});
+    }
+    return result;
+}
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/counter.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/counter.h
new file mode 100644
index 0000000000..50f35dde84
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/counter.h
@@ -0,0 +1,26 @@
+#pragma once
+
+#include <atomic>
+
+#include "types/types.h"
+
+namespace lbug {
+namespace common {
+
+class LimitCounter {
+public:
+    explicit LimitCounter(common::offset_t limitNumber) : limitNumber{limitNumber} {
+        counter.store(0);
+    }
+
+    void increase(common::offset_t number) { counter.fetch_add(number); }
+
+    bool exceedLimit() const { return counter.load() >= limitNumber; }
+
+private:
+    common::offset_t limitNumber;
+    std::atomic<common::offset_t> counter;
+};
+
+} // namespace common
+} // namespace lbug
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/data_chunk/data_chunk.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/data_chunk/data_chunk.h
new file mode 100644
index 0000000000..9d0423191b
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/data_chunk/data_chunk.h
@@ -0,0 +1,49 @@
+#pragma once
+
+#include <memory>
+#include <vector>
+
+#include "common/copy_constructors.h"
+#include "common/data_chunk/data_chunk_state.h"
+#include "common/vector/value_vector.h"
+
+namespace lbug {
+namespace common {
+
+// A DataChunk represents tuples as a set of value vectors and a selector array.
+// The data chunk represents a subset of a relation i.e., a set of tuples as
+// lists of the same length. It is appended into DataChunks and passed as intermediate
+// representations between operators.
+// A data chunk further contains a DataChunkState, which keeps the data chunk's size, selector, and
+// currIdx (used when flattening and implies the value vector only contains the elements at currIdx
+// of each value vector).
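// Illustrative sketch (not part of the patch): LimitCounter, declared in counter.h above, lets
// several worker threads cooperatively stop once a LIMIT has been satisfied. The batch size of
// 2048 is hypothetical.
#include "common/counter.h"

static void produceUntilLimit(lbug::common::LimitCounter& counter) {
    while (!counter.exceedLimit()) {
        // ... produce up to 2048 tuples ...
        counter.increase(2048);
    }
}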
+class LBUG_API DataChunk { +public: + DataChunk() : DataChunk{0} {} + explicit DataChunk(uint32_t numValueVectors) + : DataChunk(numValueVectors, std::make_shared()){}; + + DataChunk(uint32_t numValueVectors, const std::shared_ptr& state) + : valueVectors(numValueVectors), state{state} {}; + DELETE_COPY_DEFAULT_MOVE(DataChunk); + + void insert(uint32_t pos, std::shared_ptr valueVector); + + void resetAuxiliaryBuffer(); + + uint32_t getNumValueVectors() const { return valueVectors.size(); } + + const ValueVector& getValueVector(uint64_t valueVectorPos) const { + return *valueVectors[valueVectorPos]; + } + ValueVector& getValueVectorMutable(uint64_t valueVectorPos) const { + return *valueVectors[valueVectorPos]; + } + +public: + std::vector> valueVectors; + std::shared_ptr state; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/data_chunk/data_chunk_collection.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/data_chunk/data_chunk_collection.h new file mode 100644 index 0000000000..fde1db8e4a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/data_chunk/data_chunk_collection.h @@ -0,0 +1,39 @@ +#pragma once + +#include "common/data_chunk/data_chunk.h" + +namespace lbug { +namespace common { + +// TODO(Guodong): Should rework this to use ColumnChunk. +class DataChunkCollection { +public: + explicit DataChunkCollection(storage::MemoryManager* mm); + DELETE_COPY_DEFAULT_MOVE(DataChunkCollection); + + void append(DataChunk& chunk); + const std::vector& getChunks() const { return chunks; } + std::vector& getChunksUnsafe() { return chunks; } + uint64_t getNumChunks() const { return chunks.size(); } + const DataChunk& getChunk(uint64_t idx) const { + KU_ASSERT(idx < chunks.size()); + return chunks[idx]; + } + DataChunk& getChunkUnsafe(uint64_t idx) { + KU_ASSERT(idx < chunks.size()); + return chunks[idx]; + } + +private: + void allocateChunk(const DataChunk& chunk); + + void initTypes(const DataChunk& chunk); + +private: + storage::MemoryManager* mm; + std::vector types; + std::vector chunks; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/data_chunk/data_chunk_state.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/data_chunk/data_chunk_state.h new file mode 100644 index 0000000000..cbc6b5ea9a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/data_chunk/data_chunk_state.h @@ -0,0 +1,44 @@ +#pragma once + +#include "common/data_chunk/sel_vector.h" + +namespace lbug { +namespace common { + +// F stands for Factorization +enum class FStateType : uint8_t { + FLAT = 0, + UNFLAT = 1, +}; + +class LBUG_API DataChunkState { +public: + DataChunkState(); + explicit DataChunkState(sel_t capacity) : fStateType{FStateType::UNFLAT} { + selVector = std::make_shared(capacity); + } + + // returns a dataChunkState for vectors holding a single value. 
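// Illustrative sketch (not part of the patch): a DataChunkState is shared by every vector in a
// DataChunk; setting it to FLAT tells downstream operators to treat the chunk as a single current
// tuple rather than an unflat run of tuples. The capacity of 2048 is a hypothetical chunk capacity.
#include <memory>
#include "common/data_chunk/data_chunk_state.h"

static void exampleState() {
    auto state = std::make_shared<lbug::common::DataChunkState>(2048 /*capacity*/);
    state->initOriginalAndSelectedSize(100); // 100 tuples selected
    state->setToFlat();                      // downstream operators now see a single tuple
    bool flat = state->isFlat();             // true
    (void)flat;
}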
+ static std::shared_ptr getSingleValueDataChunkState(); + + void initOriginalAndSelectedSize(uint64_t size) { selVector->setSelSize(size); } + bool isFlat() const { return fStateType == FStateType::FLAT; } + void setToFlat() { fStateType = FStateType::FLAT; } + void setToUnflat() { fStateType = FStateType::UNFLAT; } + + const SelectionVector& getSelVector() const { return *selVector; } + sel_t getSelSize() const { return selVector->getSelSize(); } + SelectionVector& getSelVectorUnsafe() { return *selVector; } + std::shared_ptr getSelVectorShared() { return selVector; } + void setSelVector(std::shared_ptr selVector_) { + this->selVector = std::move(selVector_); + } + +private: + std::shared_ptr selVector; + // TODO: We should get rid of `fStateType` and merge DataChunkState with SelectionVector. + FStateType fStateType; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/data_chunk/sel_vector.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/data_chunk/sel_vector.h new file mode 100644 index 0000000000..411e0d8844 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/data_chunk/sel_vector.h @@ -0,0 +1,174 @@ +#pragma once + +#include + +#include + +#include "common/types/types.h" +#include + +namespace lbug { +namespace common { + +class ValueVector; + +// A lightweight, immutable view over a SelectionVector, or a subsequence of a selection vector +// SelectionVectors are also SelectionViews so that you can pass a SelectionVector to functions +// which take a SelectionView& +class SelectionView { +protected: + // In DYNAMIC mode, selectedPositions points to a mutable buffer that can be modified through + // getMutableBuffer In STATIC mode, selectedPositions points to somewhere in + // INCREMENTAL_SELECTED_POS + // Note that the vector is considered unfiltered only if it is both STATIC and the first + // selected position is 0 + enum class State { + DYNAMIC, + STATIC, + }; + +public: + // STATIC selectionView over 0..selectedSize + explicit SelectionView(sel_t selectedSize); + + template + void forEach(Func&& func) const { + if (state == State::DYNAMIC) { + for (size_t i = 0; i < selectedSize; i++) { + func(selectedPositions[i]); + } + } else { + const auto start = selectedPositions[0]; + for (size_t i = start; i < start + selectedSize; i++) { + func(i); + } + } + } + + template + void forEachBreakWhenFalse(Func&& func) const { + if (state == State::DYNAMIC) { + for (size_t i = 0; i < selectedSize; i++) { + if (!func(selectedPositions[i])) { + break; + } + } + } else { + const auto start = selectedPositions[0]; + for (size_t i = start; i < start + selectedSize; i++) { + if (!func(i)) { + break; + } + } + } + } + + sel_t getSelSize() const { return selectedSize; } + + sel_t operator[](sel_t index) const { + KU_ASSERT(index < selectedSize); + return selectedPositions[index]; + } + + bool isUnfiltered() const { return state == State::STATIC && selectedPositions[0] == 0; } + bool isStatic() const { return state == State::STATIC; } + + std::span getSelectedPositions() const { + return std::span(selectedPositions, selectedSize); + } + +protected: + static SelectionView slice(std::span selectedPositions, State state) { + return SelectionView(selectedPositions, state); + } + + // Intended to be used only as a subsequence of a SelectionVector in SelectionVector::slice + explicit SelectionView(std::span selectedPositions, State state) + : selectedPositions{selectedPositions.data()}, 
selectedSize{selectedPositions.size()}, + state{state} {} + +protected: + const sel_t* selectedPositions; + sel_t selectedSize; + State state; +}; + +class SelectionVector : public SelectionView { +public: + explicit SelectionVector(sel_t capacity) + : SelectionView{std::span(), State::STATIC}, + selectedPositionsBuffer{std::make_unique(capacity)}, capacity{capacity} { + setToUnfiltered(); + } + + // This View should be considered invalid if the SelectionVector it was created from has been + // modified + SelectionView slice(sel_t startIndex, sel_t selectedSize) const { + return SelectionView::slice(getSelectedPositions().subspan(startIndex, selectedSize), + state); + } + + SelectionVector(); + + LBUG_API void setToUnfiltered(); + LBUG_API void setToUnfiltered(sel_t size); + void setRange(sel_t startPos, sel_t size) { + KU_ASSERT(startPos + size <= capacity); + selectedPositions = selectedPositionsBuffer.get(); + for (auto i = 0u; i < size; ++i) { + selectedPositionsBuffer[i] = startPos + i; + } + selectedSize = size; + state = State::DYNAMIC; + } + + // Set to filtered is not very accurate. It sets selectedPositions to a mutable array. + void setToFiltered() { + selectedPositions = selectedPositionsBuffer.get(); + state = State::DYNAMIC; + } + void setToFiltered(sel_t size) { + KU_ASSERT(size <= capacity && selectedPositionsBuffer); + setToFiltered(); + selectedSize = size; + } + + // Copies the data in selectedPositions into selectedPositionsBuffer + void makeDynamic() { + memcpy(selectedPositionsBuffer.get(), selectedPositions, selectedSize * sizeof(sel_t)); + state = State::DYNAMIC; + selectedPositions = selectedPositionsBuffer.get(); + } + + std::span getMutableBuffer() const { + return std::span(selectedPositionsBuffer.get(), capacity); + } + + void setSelSize(sel_t size) { + KU_ASSERT(size <= capacity); + selectedSize = size; + } + void incrementSelSize(sel_t increment = 1) { + KU_ASSERT(selectedSize < capacity); + selectedSize += increment; + } + + sel_t operator[](sel_t index) const { + KU_ASSERT(index < capacity); + return const_cast(selectedPositions[index]); + } + sel_t& operator[](sel_t index) { + KU_ASSERT(index < capacity); + return const_cast(selectedPositions[index]); + } + + static std::vector fromValueVectors( + const std::vector>& vec); + +private: + std::unique_ptr selectedPositionsBuffer; + sel_t capacity; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/database_lifecycle_manager.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/database_lifecycle_manager.h new file mode 100644 index 0000000000..a57f6020a6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/database_lifecycle_manager.h @@ -0,0 +1,10 @@ +#pragma once + +namespace lbug { +namespace common { +struct DatabaseLifeCycleManager { + bool isDatabaseClosed = false; + void checkDatabaseClosedOrThrow() const; +}; +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/accumulate_type.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/accumulate_type.h new file mode 100644 index 0000000000..5bff679c1c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/accumulate_type.h @@ -0,0 +1,19 @@ +#pragma once + +#include +#include + +namespace lbug { +namespace common { + +enum class AccumulateType : uint8_t { + REGULAR = 0, + OPTIONAL_ = 1, +}; + +struct AccumulateTypeUtil { + static std::string toString(AccumulateType type); +}; + +} // 
namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/alter_type.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/alter_type.h new file mode 100644 index 0000000000..7f9c3d9623 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/alter_type.h @@ -0,0 +1,22 @@ +#pragma once + +#include + +namespace lbug { +namespace common { + +enum class AlterType : uint8_t { + RENAME = 0, + + ADD_PROPERTY = 10, + DROP_PROPERTY = 11, + RENAME_PROPERTY = 12, + ADD_FROM_TO_CONNECTION = 13, + DROP_FROM_TO_CONNECTION = 14, + + COMMENT = 201, + INVALID = 255 +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/clause_type.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/clause_type.h new file mode 100644 index 0000000000..87e4bf8349 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/clause_type.h @@ -0,0 +1,30 @@ +#pragma once + +#include + +namespace lbug { +namespace common { + +enum class ClauseType : uint8_t { + // updating clause + SET = 0, + DELETE_ = 1, // winnt.h defines DELETE as a macro, so we use DELETE_ instead of DELETE. + INSERT = 2, + MERGE = 3, + + // reading clause + MATCH = 10, + UNWIND = 11, + IN_QUERY_CALL = 12, + TABLE_FUNCTION_CALL = 13, + GDS_CALL = 14, + LOAD_FROM = 15, +}; + +enum class MatchClauseType : uint8_t { + MATCH = 0, + OPTIONAL_MATCH = 1, +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/column_evaluate_type.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/column_evaluate_type.h new file mode 100644 index 0000000000..e39a17f318 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/column_evaluate_type.h @@ -0,0 +1,15 @@ +#pragma once + +#include + +namespace lbug { +namespace common { + +enum class ColumnEvaluateType : uint8_t { + REFERENCE = 0, + DEFAULT = 1, + CAST = 2, +}; + +} +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/conflict_action.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/conflict_action.h new file mode 100644 index 0000000000..ca0ef6b6a3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/conflict_action.h @@ -0,0 +1,20 @@ +#pragma once + +#include +#include + +namespace lbug { +namespace common { + +enum class ConflictAction : uint8_t { + ON_CONFLICT_THROW = 0, + ON_CONFLICT_DO_NOTHING = 1, + INVALID = 255, +}; + +struct ConflictActionUtil { + static std::string toString(ConflictAction conflictAction); +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/delete_type.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/delete_type.h new file mode 100644 index 0000000000..e0f160eda6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/delete_type.h @@ -0,0 +1,14 @@ +#pragma once + +#include + +namespace lbug { +namespace common { + +enum class DeleteNodeType : uint8_t { + DELETE = 0, + DETACH_DELETE = 1, +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/drop_type.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/drop_type.h new file mode 100644 index 0000000000..2998ef4ce4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/drop_type.h @@ -0,0 +1,20 @@ +#pragma once + +#include 
+#include + +namespace lbug { +namespace common { + +enum class DropType : uint8_t { + TABLE = 0, + SEQUENCE = 1, + MACRO = 2, +}; + +struct DropTypeUtils { + static std::string toString(DropType type); +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/explain_type.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/explain_type.h new file mode 100644 index 0000000000..726deeb0d1 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/explain_type.h @@ -0,0 +1,15 @@ +#pragma once + +#include + +namespace lbug { +namespace common { + +enum class ExplainType : uint8_t { + PROFILE = 0, + LOGICAL_PLAN = 1, + PHYSICAL_PLAN = 2, +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/expression_type.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/expression_type.h new file mode 100644 index 0000000000..0997abce47 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/expression_type.h @@ -0,0 +1,74 @@ +#pragma once + +#include +#include + +#include "common/api.h" + +namespace lbug { +namespace common { + +enum class ExpressionType : uint8_t { + // Boolean Connection Expressions + OR = 0, + XOR = 1, + AND = 2, + NOT = 3, + + // Comparison Expressions + EQUALS = 10, + NOT_EQUALS = 11, + GREATER_THAN = 12, + GREATER_THAN_EQUALS = 13, + LESS_THAN = 14, + LESS_THAN_EQUALS = 15, + + // Null Operator Expressions + IS_NULL = 50, + IS_NOT_NULL = 51, + + PROPERTY = 60, + + LITERAL = 70, + + STAR = 80, + + VARIABLE = 90, + PATH = 91, + PATTERN = 92, // Node & Rel pattern + + PARAMETER = 100, + + // At parsing stage, both aggregate and scalar functions have type FUNCTION. + // After binding, only scalar function have type FUNCTION. + FUNCTION = 110, + + AGGREGATE_FUNCTION = 130, + + SUBQUERY = 190, + + CASE_ELSE = 200, + + GRAPH = 210, + + LAMBDA = 220, + + // NOTE: this enum has type uint8_t so don't assign over 255. 
+ INVALID = 255, +}; + +struct ExpressionTypeUtil { + static bool isUnary(ExpressionType type); + static bool isBinary(ExpressionType type); + static bool isBoolean(ExpressionType type); + static bool isComparison(ExpressionType type); + static bool isNullOperator(ExpressionType type); + + static ExpressionType reverseComparisonDirection(ExpressionType type); + + static LBUG_API std::string toString(ExpressionType type); + static std::string toParsableString(ExpressionType type); +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/extend_direction.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/extend_direction.h new file mode 100644 index 0000000000..ee8ab2127e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/extend_direction.h @@ -0,0 +1,11 @@ +#pragma once + +#include + +namespace lbug { +namespace common { + +enum class ExtendDirection : uint8_t { FWD = 0, BWD = 1, BOTH = 2 }; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/extend_direction_util.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/extend_direction_util.h new file mode 100644 index 0000000000..69df2093d3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/extend_direction_util.h @@ -0,0 +1,22 @@ +#pragma once + +#include "common/assert.h" +#include "common/enums/extend_direction.h" +#include "common/enums/rel_direction.h" + +namespace lbug { +namespace common { + +class ExtendDirectionUtil { +public: + static RelDataDirection getRelDataDirection(ExtendDirection direction) { + KU_ASSERT(direction != ExtendDirection::BOTH); + return direction == ExtendDirection::FWD ? RelDataDirection::FWD : RelDataDirection::BWD; + } + + static ExtendDirection fromString(const std::string& str); + static std::string toString(ExtendDirection direction); +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/join_type.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/join_type.h new file mode 100644 index 0000000000..4751925c53 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/join_type.h @@ -0,0 +1,16 @@ +#pragma once + +#include + +namespace lbug { +namespace common { + +enum class JoinType : uint8_t { + INNER = 0, + LEFT = 1, + MARK = 2, + COUNT = 3, +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/path_semantic.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/path_semantic.h new file mode 100644 index 0000000000..803a337a13 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/path_semantic.h @@ -0,0 +1,21 @@ +#pragma once + +#include +#include + +namespace lbug { +namespace common { + +enum class PathSemantic : uint8_t { + WALK = 0, + TRAIL = 1, + ACYCLIC = 2, +}; + +struct PathSemanticUtils { + static PathSemantic fromString(const std::string& str); + static std::string toString(PathSemantic semantic); +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/query_rel_type.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/query_rel_type.h new file mode 100644 index 0000000000..ed1d5f1723 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/query_rel_type.h @@ -0,0 +1,40 @@ +#pragma once + +#include +#include + +#include 
"path_semantic.h" + +namespace lbug { +namespace function { +class RJAlgorithm; +} + +namespace common { + +enum class QueryRelType : uint8_t { + NON_RECURSIVE = 0, + VARIABLE_LENGTH_WALK = 1, + VARIABLE_LENGTH_TRAIL = 2, + VARIABLE_LENGTH_ACYCLIC = 3, + SHORTEST = 4, + ALL_SHORTEST = 5, + WEIGHTED_SHORTEST = 6, + ALL_WEIGHTED_SHORTEST = 7, +}; + +struct QueryRelTypeUtils { + static bool isRecursive(QueryRelType type) { return type != QueryRelType::NON_RECURSIVE; } + + static bool isWeighted(QueryRelType type) { + return type == QueryRelType::WEIGHTED_SHORTEST || + type == QueryRelType::ALL_WEIGHTED_SHORTEST; + } + + static PathSemantic getPathSemantic(QueryRelType queryRelType); + + static std::unique_ptr getFunction(QueryRelType type); +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/rel_direction.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/rel_direction.h new file mode 100644 index 0000000000..064d809cc7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/rel_direction.h @@ -0,0 +1,24 @@ +#pragma once + +#include +#include + +#include "common/types/types.h" + +namespace lbug { +namespace common { + +enum class RelDataDirection : uint8_t { FWD = 0, BWD = 1, INVALID = 255 }; +static constexpr idx_t NUM_REL_DIRECTIONS = 2; + +struct RelDirectionUtils { + static RelDataDirection getOppositeDirection(RelDataDirection direction); + + static std::string relDirectionToString(RelDataDirection direction); + static idx_t relDirectionToKeyIdx(RelDataDirection direction); + static table_id_t getNbrTableID(RelDataDirection direction, table_id_t srcTableID, + table_id_t dstTableID); +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/rel_multiplicity.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/rel_multiplicity.h new file mode 100644 index 0000000000..dc3bcdd684 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/rel_multiplicity.h @@ -0,0 +1,18 @@ +#pragma once + +#include +#include + +namespace lbug { +namespace common { + +enum class RelMultiplicity : uint8_t { MANY, ONE }; + +struct RelMultiplicityUtils { + static RelMultiplicity getFwd(const std::string& str); + static RelMultiplicity getBwd(const std::string& str); + static std::string toString(RelMultiplicity multiplicity); +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/scan_source_type.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/scan_source_type.h new file mode 100644 index 0000000000..43958cca61 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/scan_source_type.h @@ -0,0 +1,24 @@ +#pragma once + +#include +#include + +namespace lbug { +namespace common { + +enum class ScanSourceType : uint8_t { + EMPTY = 0, + FILE = 1, + OBJECT = 2, + QUERY = 3, + TABLE_FUNC = 4, + PARAM = 5, +}; + +class ScanSourceTypeUtils { +public: + static std::string toString(ScanSourceType type); +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/statement_type.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/statement_type.h new file mode 100644 index 0000000000..038c8a4b1d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/statement_type.h @@ -0,0 +1,32 @@ +#pragma once + +#include + +namespace lbug { +namespace common { + +enum 
class StatementType : uint8_t { + QUERY = 0, + CREATE_TABLE = 1, + DROP = 2, + ALTER = 3, + COPY_TO = 19, + COPY_FROM = 20, + STANDALONE_CALL = 21, + STANDALONE_CALL_FUNCTION = 22, + EXPLAIN = 23, + CREATE_MACRO = 24, + TRANSACTION = 30, + EXTENSION = 31, + EXPORT_DATABASE = 32, + IMPORT_DATABASE = 33, + ATTACH_DATABASE = 34, + DETACH_DATABASE = 35, + USE_DATABASE = 36, + CREATE_SEQUENCE = 37, + CREATE_TYPE = 39, + EXTENSION_CLAUSE = 40, +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/subquery_type.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/subquery_type.h new file mode 100644 index 0000000000..81f1bdda3c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/subquery_type.h @@ -0,0 +1,14 @@ +#pragma once + +#include + +namespace lbug { +namespace common { + +enum class SubqueryType : uint8_t { + COUNT = 1, + EXISTS = 2, +}; + +} +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/table_type.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/table_type.h new file mode 100644 index 0000000000..c1aaae971b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/table_type.h @@ -0,0 +1,23 @@ +#pragma once + +#include +#include + +#include "common/api.h" + +namespace lbug { +namespace common { + +enum class TableType : uint8_t { + UNKNOWN = 0, + NODE = 1, + REL = 2, + FOREIGN = 5, +}; + +struct LBUG_API TableTypeUtils { + static std::string toString(TableType tableType); +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/zone_map_check_result.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/zone_map_check_result.h new file mode 100644 index 0000000000..647e5e9dd7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/enums/zone_map_check_result.h @@ -0,0 +1,14 @@ +#pragma once + +#include + +namespace lbug { +namespace common { + +enum class ZoneMapCheckResult : uint8_t { + ALWAYS_SCAN = 0, + SKIP_SCAN = 1, +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/binder.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/binder.h new file mode 100644 index 0000000000..161d0cef3f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/binder.h @@ -0,0 +1,15 @@ +#pragma once + +#include "common/api.h" +#include "exception.h" + +namespace lbug { +namespace common { + +class LBUG_API BinderException : public Exception { +public: + explicit BinderException(const std::string& msg) : Exception("Binder exception: " + msg){}; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/buffer_manager.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/buffer_manager.h new file mode 100644 index 0000000000..007ee300e4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/buffer_manager.h @@ -0,0 +1,16 @@ +#pragma once + +#include "common/api.h" +#include "exception.h" + +namespace lbug { +namespace common { + +class LBUG_API BufferManagerException : public Exception { +public: + explicit BufferManagerException(const std::string& msg) + : Exception("Buffer manager exception: " + msg){}; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/catalog.h 
b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/catalog.h new file mode 100644 index 0000000000..e83ee694d9 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/catalog.h @@ -0,0 +1,15 @@ +#pragma once + +#include "common/api.h" +#include "exception.h" + +namespace lbug { +namespace common { + +class LBUG_API CatalogException : public Exception { +public: + explicit CatalogException(const std::string& msg) : Exception("Catalog exception: " + msg){}; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/checkpoint.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/checkpoint.h new file mode 100644 index 0000000000..6603c2d77a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/checkpoint.h @@ -0,0 +1,15 @@ +#pragma once + +#include "common/api.h" +#include "exception.h" + +namespace lbug { +namespace common { + +class LBUG_API CheckpointException : public Exception { +public: + explicit CheckpointException(const std::exception& e) : Exception(e.what()){}; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/connection.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/connection.h new file mode 100644 index 0000000000..afc92beb41 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/connection.h @@ -0,0 +1,16 @@ +#pragma once + +#include "common/api.h" +#include "exception.h" + +namespace lbug { +namespace common { + +class LBUG_API ConnectionException : public Exception { +public: + explicit ConnectionException(const std::string& msg) + : Exception("Connection exception: " + msg){}; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/conversion.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/conversion.h new file mode 100644 index 0000000000..d857e39a1d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/conversion.h @@ -0,0 +1,16 @@ +#pragma once + +#include "common/api.h" +#include "exception.h" + +namespace lbug { +namespace common { + +class LBUG_API ConversionException : public Exception { +public: + explicit ConversionException(const std::string& msg) + : Exception("Conversion exception: " + msg) {} +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/copy.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/copy.h new file mode 100644 index 0000000000..1756690da8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/copy.h @@ -0,0 +1,15 @@ +#pragma once + +#include "common/api.h" +#include "exception.h" + +namespace lbug { +namespace common { + +class LBUG_API CopyException : public Exception { +public: + explicit CopyException(const std::string& msg) : Exception("Copy exception: " + msg){}; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/exception.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/exception.h new file mode 100644 index 0000000000..2e49d6c7e0 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/exception.h @@ -0,0 +1,23 @@ +#pragma once + +#include +#include + +#include "common/api.h" + +namespace lbug { +namespace common { + +class LBUG_API Exception : public 
std::exception { +public: + explicit Exception(std::string msg); + +public: + const char* what() const noexcept override { return exception_message_.c_str(); } + +private: + std::string exception_message_; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/extension.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/extension.h new file mode 100644 index 0000000000..52b187b983 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/extension.h @@ -0,0 +1,15 @@ +#pragma once + +#include "exception.h" + +namespace lbug { +namespace common { + +class LBUG_API ExtensionException : public Exception { +public: + explicit ExtensionException(const std::string& msg) + : Exception("Extension exception: " + msg) {} +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/internal.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/internal.h new file mode 100644 index 0000000000..acf5b6fdf1 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/internal.h @@ -0,0 +1,15 @@ +#pragma once + +#include "common/api.h" +#include "exception.h" + +namespace lbug { +namespace common { + +class LBUG_API InternalException : public Exception { +public: + explicit InternalException(const std::string& msg) : Exception(msg){}; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/interrupt.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/interrupt.h new file mode 100644 index 0000000000..7df12870fb --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/interrupt.h @@ -0,0 +1,15 @@ +#pragma once + +#include "common/api.h" +#include "exception.h" + +namespace lbug { +namespace common { + +class LBUG_API InterruptException : public Exception { +public: + explicit InterruptException() : Exception("Interrupted."){}; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/io.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/io.h new file mode 100644 index 0000000000..6795882363 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/io.h @@ -0,0 +1,14 @@ +#pragma once + +#include "exception.h" + +namespace lbug { +namespace common { + +class LBUG_API IOException : public Exception { +public: + explicit IOException(const std::string& msg) : Exception("IO exception: " + msg) {} +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/message.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/message.h new file mode 100644 index 0000000000..dd00120252 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/message.h @@ -0,0 +1,34 @@ +#pragma once + +#include +#include + +namespace lbug { +namespace common { + +// Add exception only if you need to throw it in multiple places. +struct ExceptionMessage { + // Primary key. + static std::string duplicatePKException(const std::string& pkString); + static std::string nonExistentPKException(const std::string& pkString); + static std::string invalidPKType(const std::string& type); + static std::string nullPKException(); + // Long string. 
+ static std::string overLargeStringPKValueException(uint64_t length); + static std::string overLargeStringValueException(uint64_t length); + // Foreign key. + static std::string violateDeleteNodeWithConnectedEdgesConstraint(const std::string& tableName, + const std::string& offset, const std::string& direction); + static std::string violateRelMultiplicityConstraint(const std::string& tableName, + const std::string& offset, const std::string& direction); + // Binding exception + static std::string variableNotInScope(const std::string& varName); + static std::string listFunctionIncompatibleChildrenType(const std::string& functionName, + const std::string& leftType, const std::string& rightType); + // Skip limit exception + static std::string invalidSkipLimitParam(const std::string& exprName, + const std::string& skipOrLimit); +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/not_implemented.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/not_implemented.h new file mode 100644 index 0000000000..793edba30e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/not_implemented.h @@ -0,0 +1,15 @@ +#pragma once + +#include "common/api.h" +#include "exception.h" + +namespace lbug { +namespace common { + +class LBUG_API NotImplementedException : public Exception { +public: + explicit NotImplementedException(const std::string& msg) : Exception(msg){}; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/overflow.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/overflow.h new file mode 100644 index 0000000000..1ff0b9fd63 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/overflow.h @@ -0,0 +1,15 @@ +#pragma once + +#include "common/api.h" +#include "exception.h" + +namespace lbug { +namespace common { + +class LBUG_API OverflowException : public Exception { +public: + explicit OverflowException(const std::string& msg) : Exception("Overflow exception: " + msg) {} +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/parser.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/parser.h new file mode 100644 index 0000000000..8a6c0ed0b6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/parser.h @@ -0,0 +1,17 @@ +#pragma once + +#include "common/api.h" +#include "exception.h" + +namespace lbug { +namespace common { + +class LBUG_API ParserException : public Exception { +public: + static constexpr const char* ERROR_PREFIX = "Parser exception: "; + + explicit ParserException(const std::string& msg) : Exception(ERROR_PREFIX + msg){}; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/runtime.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/runtime.h new file mode 100644 index 0000000000..fcb2530599 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/runtime.h @@ -0,0 +1,15 @@ +#pragma once + +#include "common/api.h" +#include "exception.h" + +namespace lbug { +namespace common { + +class LBUG_API RuntimeException : public Exception { +public: + explicit RuntimeException(const std::string& msg) : Exception("Runtime exception: " + msg){}; +}; + +} // namespace common +} // namespace lbug diff --git 
a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/storage.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/storage.h new file mode 100644 index 0000000000..eafedc76cb --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/storage.h @@ -0,0 +1,15 @@ +#pragma once + +#include "common/api.h" +#include "exception.h" + +namespace lbug { +namespace common { + +class LBUG_API StorageException : public Exception { +public: + explicit StorageException(const std::string& msg) : Exception("Storage exception: " + msg){}; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/test.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/test.h new file mode 100644 index 0000000000..046413a8bb --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/test.h @@ -0,0 +1,14 @@ +#pragma once + +#include "exception.h" + +namespace lbug { +namespace common { + +class TestException : public Exception { +public: + explicit TestException(const std::string& msg) : Exception("Test exception: " + msg){}; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/transaction_manager.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/transaction_manager.h new file mode 100644 index 0000000000..4496c2535e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/exception/transaction_manager.h @@ -0,0 +1,15 @@ +#pragma once + +#include "common/api.h" +#include "exception.h" + +namespace lbug { +namespace common { + +class LBUG_API TransactionManagerException : public Exception { +public: + explicit TransactionManagerException(const std::string& msg) : Exception(msg){}; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/file_system/compressed_file_system.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/file_system/compressed_file_system.h new file mode 100644 index 0000000000..204b1738ce --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/file_system/compressed_file_system.h @@ -0,0 +1,84 @@ +#pragma once + +#include "common/file_system/file_system.h" +#include "common/types/types.h" + +namespace lbug { +namespace common { + +struct StreamData { + bool refresh = false; + std::unique_ptr inputBuf; + std::unique_ptr outputBuf; + uint8_t* inputBufStart = nullptr; + uint8_t* inputBufEnd = nullptr; + uint8_t* outputBufStart = nullptr; + uint8_t* outputBufEnd = nullptr; + common::idx_t inputBufSize = 0; + common::idx_t outputBufSize = 0; +}; + +struct CompressedFileInfo; + +struct StreamWrapper { + virtual ~StreamWrapper() = default; + virtual void initialize(CompressedFileInfo& file) = 0; + virtual bool read(StreamData& stream_data) = 0; + virtual void close() = 0; +}; + +class CompressedFileSystem : public FileSystem { +public: + virtual std::unique_ptr openCompressedFile(std::unique_ptr fileInfo) = 0; + virtual std::unique_ptr createStream() = 0; + virtual idx_t getInputBufSize() = 0; + virtual idx_t getOutputBufSize() = 0; + + bool canPerformSeek() const override { return false; } + +protected: + std::vector glob(main::ClientContext* /*context*/, + const std::string& /*path*/) const override { + KU_UNREACHABLE; + } + + void readFromFile(FileInfo& /*fileInfo*/, void* /*buffer*/, uint64_t /*numBytes*/, + uint64_t /*position*/) const override; + + int64_t readFile(FileInfo& fileInfo, void* buf, 
size_t numBytes) const override; + + void writeFile(FileInfo& /*fileInfo*/, const uint8_t* /*buffer*/, uint64_t /*numBytes*/, + uint64_t /*offset*/) const override { + KU_UNREACHABLE; + } + + void reset(FileInfo& fileInfo) override; + + int64_t seek(FileInfo& /*fileInfo*/, uint64_t /*offset*/, int /*whence*/) const override { + KU_UNREACHABLE; + } + + uint64_t getFileSize(const FileInfo& fileInfo) const override; + + void syncFile(const FileInfo& fileInfo) const override; +}; + +struct CompressedFileInfo : public FileInfo { + CompressedFileSystem& compressedFS; + std::unique_ptr childFileInfo; + StreamData streamData; + idx_t currentPos = 0; + std::unique_ptr stream_wrapper; + + CompressedFileInfo(CompressedFileSystem& compressedFS, std::unique_ptr childFileInfo) + : FileInfo{childFileInfo->path, &compressedFS}, compressedFS{compressedFS}, + childFileInfo{std::move(childFileInfo)} {} + ~CompressedFileInfo() override { close(); } + + void initialize(); + int64_t readData(void* buffer, size_t numBytes); + void close(); +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/file_system/file_info.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/file_system/file_info.h new file mode 100644 index 0000000000..8c3926f1a8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/file_system/file_info.h @@ -0,0 +1,67 @@ +#pragma once + +#include +#include + +#include "common/api.h" +#include "common/cast.h" +#include "function/table/table_function.h" + +namespace lbug { +namespace common { + +class FileSystem; + +struct LBUG_API FileInfo { + FileInfo(std::string path, FileSystem* fileSystem) + : path{std::move(path)}, fileSystem{fileSystem} {} + + virtual ~FileInfo() = default; + + uint64_t getFileSize() const; + + void readFromFile(void* buffer, uint64_t numBytes, uint64_t position); + + int64_t readFile(void* buf, size_t nbyte); + + void writeFile(const uint8_t* buffer, uint64_t numBytes, uint64_t offset); + + void syncFile() const; + + int64_t seek(uint64_t offset, int whence); + + void reset(); + + void truncate(uint64_t size); + + bool canPerformSeek() const; + + virtual function::TableFunction getHandleFunction() const { KU_UNREACHABLE; } + + template + TARGET* ptrCast() { + return common::ku_dynamic_cast(this); + } + + template + const TARGET* constPtrCast() const { + return common::ku_dynamic_cast(this); + } + + template + const TARGET& constCast() const { + return common::ku_dynamic_cast(*this); + } + + template + TARGET& cast() { + return common::ku_dynamic_cast(*this); + } + + const std::string path; + + FileSystem* fileSystem; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/file_system/file_system.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/file_system/file_system.h new file mode 100644 index 0000000000..6a83bb88a1 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/file_system/file_system.h @@ -0,0 +1,134 @@ +#pragma once + +#include +#include + +#include "common/assert.h" +#include "common/cast.h" +#include "file_info.h" + +namespace lbug { +namespace main { +class ClientContext; +} // namespace main + +namespace common { + +enum class FileLockType : uint8_t { NO_LOCK = 0, READ_LOCK = 1, WRITE_LOCK = 2 }; + +enum class FileCompressionType : uint8_t { AUTO_DETECT = 0, UNCOMPRESSED = 1, GZIP = 2, ZSTD = 3 }; + +struct FileFlags { + static constexpr uint8_t READ_ONLY = 1 << 0; + static constexpr uint8_t WRITE = 1 << 
1; + // Create file if not exists, can only be used together with WRITE + static constexpr uint8_t CREATE_IF_NOT_EXISTS = 1 << 3; + // Always create a new file. If a file exists, the file is truncated. Cannot be used together + // with CREATE_IF_NOT_EXISTS. + static constexpr uint8_t CREATE_AND_TRUNCATE_IF_EXISTS = 1 << 4; + // Temporary file that is not persisted to disk. + static constexpr uint8_t TEMPORARY = 1 << 5; +#ifdef _WIN32 + // Only used in windows to open files in binary mode. + static constexpr uint8_t BINARY = 1 << 5; +#endif +}; + +struct FileOpenFlags { + int flags; + FileLockType lockType = FileLockType::NO_LOCK; + FileCompressionType compressionType = FileCompressionType::AUTO_DETECT; + + explicit FileOpenFlags(int flags) : flags{flags} {} + explicit FileOpenFlags(int flags, FileLockType lockType) : flags{flags}, lockType{lockType} {} +}; + +class LBUG_API FileSystem { + friend struct FileInfo; + +public: + FileSystem() = default; + + explicit FileSystem(std::string homeDir) : dbPath(std::move(homeDir)) {} + + virtual ~FileSystem() = default; + + virtual std::unique_ptr openFile(const std::string& /*path*/, FileOpenFlags /*flags*/, + main::ClientContext* /*context*/ = nullptr) { + KU_UNREACHABLE; + } + + virtual std::vector glob(main::ClientContext* /*context*/, + const std::string& /*path*/) const { + KU_UNREACHABLE; + } + + virtual void overwriteFile(const std::string& from, const std::string& to); + + virtual void copyFile(const std::string& from, const std::string& to); + + virtual void createDir(const std::string& dir) const; + + virtual void removeFileIfExists(const std::string& path, + const main::ClientContext* context = nullptr); + + virtual bool fileOrPathExists(const std::string& path, main::ClientContext* context = nullptr); + + virtual std::string expandPath(main::ClientContext* context, const std::string& path) const; + + static std::string joinPath(const std::string& base, const std::string& part); + + static std::string getFileExtension(const std::filesystem::path& path); + + static bool isCompressedFile(const std::filesystem::path& path); + + static std::string getFileName(const std::filesystem::path& path); + + virtual bool canHandleFile(const std::string_view /*path*/) const { KU_UNREACHABLE; } + + virtual void syncFile(const FileInfo& fileInfo) const = 0; + + virtual bool canPerformSeek() const { return true; } + + virtual bool handleFileViaFunction(const std::string& /*path*/) const { return false; } + + virtual function::TableFunction getHandleFunction(const std::string& /*path*/) const { + KU_UNREACHABLE; + } + + template + TARGET* ptrCast() { + return common::ku_dynamic_cast(this); + } + + template + const TARGET* constPtrCast() const { + return common::ku_dynamic_cast(this); + } + + virtual void cleanUP(main::ClientContext* /*context*/) {} + +protected: + virtual void readFromFile(FileInfo& fileInfo, void* buffer, uint64_t numBytes, + uint64_t position) const = 0; + + virtual int64_t readFile(FileInfo& fileInfo, void* buf, size_t numBytes) const = 0; + + virtual void writeFile(FileInfo& fileInfo, const uint8_t* buffer, uint64_t numBytes, + uint64_t offset) const; + + virtual int64_t seek(FileInfo& fileInfo, uint64_t offset, int whence) const = 0; + + virtual void reset(FileInfo& fileInfo); + + virtual void truncate(FileInfo& fileInfo, uint64_t size) const; + + virtual uint64_t getFileSize(const FileInfo& fileInfo) const = 0; + + static bool isGZIPCompressed(const std::filesystem::path& path); + + std::string dbPath; +}; + +} // namespace common +} 
// namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/file_system/gzip_file_system.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/file_system/gzip_file_system.h new file mode 100644 index 0000000000..5e402a2198 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/file_system/gzip_file_system.h @@ -0,0 +1,40 @@ +#pragma once + +#include "compressed_file_system.h" + +namespace lbug { +namespace common { + +class GZipFileSystem : public CompressedFileSystem { +public: + static constexpr const idx_t BUFFER_SIZE = 1u << 15; // 32 KB + static constexpr const uint8_t GZIP_COMPRESSION_DEFLATE = 0x08; + static constexpr const uint8_t GZIP_FLAG_ASCII = 0x1; + static constexpr const uint8_t GZIP_FLAG_MULTIPART = 0x2; + static constexpr const uint8_t GZIP_FLAG_EXTRA = 0x4; + static constexpr const uint8_t GZIP_FLAG_NAME = 0x8; + static constexpr const uint8_t GZIP_FLAG_COMMENT = 0x10; + static constexpr const uint8_t GZIP_FLAG_ENCRYPT = 0x20; + static constexpr const uint8_t GZIP_HEADER_MINSIZE = 10; + static constexpr const idx_t GZIP_HEADER_MAXSIZE = 1u << 15; + static constexpr const uint8_t GZIP_FOOTER_SIZE = 8; + static constexpr const unsigned char GZIP_FLAG_UNSUPPORTED = + GZIP_FLAG_ASCII | GZIP_FLAG_MULTIPART | GZIP_FLAG_COMMENT | GZIP_FLAG_ENCRYPT; + +public: + std::unique_ptr openCompressedFile(std::unique_ptr fileInfo) override; + + std::unique_ptr createStream() override; + idx_t getInputBufSize() override { return BUFFER_SIZE; } + idx_t getOutputBufSize() override { return BUFFER_SIZE; } +}; + +struct GZIPFileInfo : public CompressedFileInfo { + GZIPFileInfo(CompressedFileSystem& compressedFS, std::unique_ptr childFileInfo) + : CompressedFileInfo{compressedFS, std::move(childFileInfo)} { + initialize(); + } +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/file_system/local_file_system.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/file_system/local_file_system.h new file mode 100644 index 0000000000..62f37d505a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/file_system/local_file_system.h @@ -0,0 +1,74 @@ +#pragma once + +#include + +#include "file_system.h" + +namespace lbug { +namespace common { + +struct LocalFileInfo final : FileInfo { +#ifdef _WIN32 + LocalFileInfo(std::string path, const void* handle, FileSystem* fileSystem) + : FileInfo{std::move(path), fileSystem}, handle{handle} {} +#else + LocalFileInfo(std::string path, const int fd, FileSystem* fileSystem) + : FileInfo{std::move(path), fileSystem}, fd{fd} {} +#endif + + ~LocalFileInfo() override; + +#ifdef _WIN32 + const void* handle; +#else + const int fd; +#endif +}; + +class LBUG_API LocalFileSystem final : public FileSystem { +public: + explicit LocalFileSystem(std::string homeDir) : FileSystem(std::move(homeDir)) {} + + std::unique_ptr openFile(const std::string& path, FileOpenFlags flags, + main::ClientContext* context = nullptr) override; + + std::vector glob(main::ClientContext* context, + const std::string& path) const override; + + void overwriteFile(const std::string& from, const std::string& to) override; + + void copyFile(const std::string& from, const std::string& to) override; + + void createDir(const std::string& dir) const override; + + void removeFileIfExists(const std::string& path, + const main::ClientContext* context = nullptr) override; + + bool fileOrPathExists(const std::string& path, main::ClientContext* context = nullptr) override; + + std::string 
expandPath(main::ClientContext* context, const std::string& path) const override; + + void syncFile(const FileInfo& fileInfo) const override; + + static bool isLocalPath(const std::string& path); + + static bool fileExists(const std::string& filename); + +protected: + void readFromFile(FileInfo& fileInfo, void* buffer, uint64_t numBytes, + uint64_t position) const override; + + int64_t readFile(FileInfo& fileInfo, void* buf, size_t nbyte) const override; + + void writeFile(FileInfo& fileInfo, const uint8_t* buffer, uint64_t numBytes, + uint64_t offset) const override; + + int64_t seek(FileInfo& fileInfo, uint64_t offset, int whence) const override; + + void truncate(FileInfo& fileInfo, uint64_t size) const override; + + uint64_t getFileSize(const FileInfo& fileInfo) const override; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/file_system/virtual_file_system.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/file_system/virtual_file_system.h new file mode 100644 index 0000000000..97aa5df540 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/file_system/virtual_file_system.h @@ -0,0 +1,86 @@ +#pragma once + +#include +#include +#include + +#include "compressed_file_system.h" +#include "file_system.h" + +namespace lbug { +namespace main { +class Database; +} + +namespace storage { +class BufferManager; +}; +namespace common { + +class LBUG_API VirtualFileSystem final : public FileSystem { + friend class storage::BufferManager; + +public: + VirtualFileSystem(); + explicit VirtualFileSystem(std::string homeDir); + + ~VirtualFileSystem() override; + + void registerFileSystem(std::unique_ptr fileSystem); + + std::unique_ptr openFile(const std::string& path, FileOpenFlags flags, + main::ClientContext* context = nullptr) override; + + std::vector glob(main::ClientContext* context, + const std::string& path) const override; + + void overwriteFile(const std::string& from, const std::string& to) override; + + void createDir(const std::string& dir) const override; + + void removeFileIfExists(const std::string& path, + const main::ClientContext* context = nullptr) override; + + bool fileOrPathExists(const std::string& path, main::ClientContext* context = nullptr) override; + + std::string expandPath(main::ClientContext* context, const std::string& path) const override; + + void syncFile(const FileInfo& fileInfo) const override; + + void cleanUP(main::ClientContext* context) override; + + bool handleFileViaFunction(const std::string& path) const override; + + function::TableFunction getHandleFunction(const std::string& path) const override; + + static VirtualFileSystem* GetUnsafe(const main::ClientContext& context); + +protected: + void readFromFile(FileInfo& fileInfo, void* buffer, uint64_t numBytes, + uint64_t position) const override; + + int64_t readFile(FileInfo& fileInfo, void* buf, size_t nbyte) const override; + + void writeFile(FileInfo& fileInfo, const uint8_t* buffer, uint64_t numBytes, + uint64_t offset) const override; + + int64_t seek(FileInfo& fileInfo, uint64_t offset, int whence) const override; + + void truncate(FileInfo& fileInfo, uint64_t size) const override; + + uint64_t getFileSize(const FileInfo& fileInfo) const override; + +private: + FileSystem* findFileSystem(const std::string& path) const; + + static FileCompressionType autoDetectCompressionType(const std::string& path); + +private: + std::vector> subSystems; + std::unique_ptr defaultFS; + std::unordered_map> + compressedFileSystem; 
+}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/finally_wrapper.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/finally_wrapper.h new file mode 100644 index 0000000000..3136a32583 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/finally_wrapper.h @@ -0,0 +1,14 @@ +#pragma once + +#include + +namespace lbug::common { +// RAII wrapper that calls an enclosed function when this class goes out of scope +// Should be used for any cleanup code that must be executed even if exceptions occur +template +struct FinallyWrapper { + explicit FinallyWrapper(Func&& func) : func(func) {} + ~FinallyWrapper() { func(); } + Func func; +}; +} // namespace lbug::common diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/in_mem_overflow_buffer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/in_mem_overflow_buffer.h new file mode 100644 index 0000000000..8fa2c8e526 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/in_mem_overflow_buffer.h @@ -0,0 +1,78 @@ +#pragma once + +#include +#include +#include + +#include "common/api.h" +#include "common/copy_constructors.h" + +namespace lbug { +namespace storage { +class MemoryBuffer; +class MemoryManager; +} // namespace storage + +namespace common { + +struct LBUG_API BufferBlock { +public: + explicit BufferBlock(std::unique_ptr block); + ~BufferBlock(); + + uint64_t size() const; + uint8_t* data() const; + +public: + uint64_t currentOffset; + std::unique_ptr block; + + void resetCurrentOffset() { currentOffset = 0; } +}; + +class LBUG_API InMemOverflowBuffer { + +public: + explicit InMemOverflowBuffer(storage::MemoryManager* memoryManager) + : memoryManager{memoryManager} {}; + + DEFAULT_BOTH_MOVE(InMemOverflowBuffer); + + uint8_t* allocateSpace(uint64_t size); + + void merge(InMemOverflowBuffer& other) { + move(begin(other.blocks), end(other.blocks), back_inserter(blocks)); + // We clear the other InMemOverflowBuffer's block because when it is deconstructed, + // InMemOverflowBuffer's deconstructed tries to free these pages by calling + // memoryManager->freeBlock, but it should not because this InMemOverflowBuffer still + // needs them. + other.blocks.clear(); + } + + // Releases all memory accumulated for string overflows so far and re-initializes its state to + // an empty buffer. If there is a large string that used point to any of these overflow buffers + // they will error. + void resetBuffer(); + + // Manually set the underlying memory buffer to evicted to avoid double free + void preventDestruction(); + + storage::MemoryManager* getMemoryManager() { return memoryManager; } + +private: + bool requireNewBlock(uint64_t sizeToAllocate) { + return blocks.empty() || + (currentBlock()->currentOffset + sizeToAllocate) > currentBlock()->size(); + } + + void allocateNewBlock(uint64_t size); + + BufferBlock* currentBlock() { return blocks.back().get(); } + +private: + std::vector> blocks; + storage::MemoryManager* memoryManager; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/mask.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/mask.h new file mode 100644 index 0000000000..3c4855019d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/mask.h @@ -0,0 +1,91 @@ +#pragma once + +#include "common/types/types.h" + +namespace lbug { +namespace common { + +// Note that this class is NOT thread-safe. 
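Aside: the FinallyWrapper introduced earlier in this patch is the simplest of these utilities. A minimal usage sketch, illustrative only and not part of the patch (the include path, function name, and file-handling code are assumptions):

#include <cstdio>
#include <stdexcept>

#include "common/finally_wrapper.h"

// Hypothetical helper: the cleanup lambda runs when `closeFile` goes out of scope,
// on both normal return and exception, which is exactly what FinallyWrapper is for.
void parseFile(const char* path) {
    std::FILE* fp = std::fopen(path, "rb");
    if (fp == nullptr) {
        throw std::runtime_error("cannot open file");
    }
    lbug::common::FinallyWrapper closeFile([fp] { std::fclose(fp); });
    // ... parsing work that may throw ...
}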
+class SemiMask { +public: + explicit SemiMask(offset_t maxOffset) : maxOffset{maxOffset}, enabled{false} {} + + virtual ~SemiMask() = default; + + virtual void mask(offset_t nodeOffset) = 0; + virtual void maskRange(offset_t startNodeOffset, offset_t endNodeOffset) = 0; + + virtual bool isMasked(offset_t startNodeOffset) = 0; + + // include&exclude + virtual offset_vec_t range(uint32_t start, uint32_t end) = 0; + + virtual uint64_t getNumMaskedNodes() const = 0; + + virtual offset_vec_t collectMaskedNodes(uint64_t size) const = 0; + + offset_t getMaxOffset() const { return maxOffset; } + + bool isEnabled() const { return enabled; } + void enable() { enabled = true; } + +private: + offset_t maxOffset; + bool enabled; +}; + +struct SemiMaskUtil { + LBUG_API static std::unique_ptr createMask(offset_t maxOffset); +}; + +class NodeOffsetMaskMap { +public: + NodeOffsetMaskMap() = default; + + offset_t getNumMaskedNode() const; + + void addMask(table_id_t tableID, std::unique_ptr mask) { + KU_ASSERT(!maskMap.contains(tableID)); + maskMap.insert({tableID, std::move(mask)}); + } + + table_id_map_t getMasks() const { + table_id_map_t result; + for (auto& [tableID, mask] : maskMap) { + result.emplace(tableID, mask.get()); + } + return result; + } + + bool containsTableID(table_id_t tableID) const { return maskMap.contains(tableID); } + SemiMask* getOffsetMask(table_id_t tableID) const { + KU_ASSERT(containsTableID(tableID)); + return maskMap.at(tableID).get(); + } + + void pin(table_id_t tableID) { + if (maskMap.contains(tableID)) { + pinnedMask = maskMap.at(tableID).get(); + } else { + pinnedMask = nullptr; + } + } + bool hasPinnedMask() const { return pinnedMask != nullptr; } + SemiMask* getPinnedMask() const { return pinnedMask; } + + bool valid(offset_t offset) const { + KU_ASSERT(pinnedMask != nullptr); + return pinnedMask->isMasked(offset); + } + bool valid(nodeID_t nodeID) const { + KU_ASSERT(maskMap.contains(nodeID.tableID)); + return maskMap.at(nodeID.tableID)->isMasked(nodeID.offset); + } + +private: + table_id_map_t> maskMap; + SemiMask* pinnedMask = nullptr; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/md5.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/md5.h new file mode 100644 index 0000000000..858a3da11c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/md5.h @@ -0,0 +1,97 @@ +#pragma once + +/* +** This code taken from the SQLite test library (can be found at +** https://www.sqlite.org/sqllogictest/doc/trunk/about.wiki). +** Originally found on the internet. The original header comment follows this comment. +** The code has been refactored, but the algorithm stays the same. +*/ +/* + * This code implements the MD5 message-digest algorithm. + * The algorithm is due to Ron Rivest. This code was + * written by Colin Plumb in 1993, no copyright is claimed. + * This code is in the public domain; do with it what you wish. + * + * Equivalent code is available from RSA Data Security, Inc. + * This code has been tested against that, and is equivalent, + * except that you don't need to include two pages of legalese + * with every copy. + * + * To compute the message digest of a chunk of bytes, declare an + * MD5Context structure, pass it to MD5Init, call MD5Update as + * needed on buffers full of bytes, and then call MD5Final, which + * will fill a supplied 16-byte array with the digest. 
+ */ + +#include + +namespace lbug { +namespace common { + +class MD5 { + struct Context { + int isInit; + uint32_t buf[4]; + uint32_t bits[2]; + unsigned char in[64]; + }; + typedef struct Context MD5Context; + + // Status of an MD5 hash. - changed from static global variables to private members + MD5Context ctx{}; + int isInit = 0; + char zResult[34] = ""; + + // Note: this code is harmless on little-endian machines. + void byteReverse(unsigned char* buf, unsigned longs); + + // The core of the MD5 algorithm, this alters an existing MD5 hash to + // reflect the addition of 16 longwords of new data. MD5Update blocks + // the data and converts bytes into longwords for this routine. + void MD5Transform(uint32_t buf[4], const uint32_t in[16]); + + // Start MD5 accumulation. Set bit count to 0 and buffer to mysterious + // initialization constants. + void MD5Init(); + + // Update context to reflect the concatenation of another buffer full + // of bytes. + void MD5Update(const unsigned char* buf, unsigned int len); + + // Final wrapup - pad to 64-byte boundary with the bit pattern + // 1 0* (64-bit count of bits processed, MSB-first) + void MD5Final(unsigned char digest[16]); + + // Convert a digest into base-16. digest should be declared as + // "unsigned char digest[16]" in the calling function. The MD5 + // digest is stored in the first 16 bytes. zBuf should + // be "char zBuf[33]". + static void DigestToBase16(const unsigned char* digest, char* zBuf); + +public: + // Add additional text to the current MD5 hash. + // note: original name changed from md5_add + void addToMD5(const char* z, uint32_t len) { + if (!isInit) { + MD5Init(); + isInit = 1; + } + MD5Update((unsigned char*)z, len); + } + + // Compute the final signature. Reset the hash generator in preparation + // for the next round. + // note: original name changed from md5_finish + const char* finishMD5() { + if (isInit) { + unsigned char digest[16]; + MD5Final(digest); + isInit = 0; + DigestToBase16(digest, zResult); + } + return zResult; + } +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/metric.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/metric.h new file mode 100644 index 0000000000..d7e78366d2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/metric.h @@ -0,0 +1,52 @@ +#pragma once + +#include "common/timer.h" + +namespace lbug { +namespace common { + +/** + * Note that metrics are not thread safe. 
+ */ +class Metric { + +public: + explicit Metric(bool enabled) : enabled{enabled} {} + + virtual ~Metric() = default; + +public: + bool enabled; +}; + +class TimeMetric : public Metric { + +public: + explicit TimeMetric(bool enable); + + void start(); + void stop(); + + double getElapsedTimeMS() const; + +public: + double accumulatedTime; + bool isStarted; + Timer timer; +}; + +class NumericMetric : public Metric { + +public: + explicit NumericMetric(bool enable); + + void increase(uint64_t value); + + void incrementByOne(); + +public: + uint64_t accumulatedValue; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/mpsc_queue.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/mpsc_queue.h new file mode 100644 index 0000000000..8c8957977c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/mpsc_queue.h @@ -0,0 +1,107 @@ +#pragma once + +// Based on the following design: +// https://www.1024cores.net/home/lock-free-algorithms/queues/non-intrusive-mpsc-node-based-queue + +#include + +#include "common/assert.h" +#include "common/copy_constructors.h" + +namespace lbug { +namespace common { + +// Producers are completely wait-free. +template +class MPSCQueue { + struct Node { + T data; + std::atomic next; + + explicit Node(T data) : data(std::move(data)), next(nullptr) {} + }; + +public: + MPSCQueue() : head(nullptr), tail(nullptr), _approxSize(0) { + // Allocate a dummy element. + Node* stub = new Node(T()); + head = stub; + + // Ordering doesn't matter. + tail.store(stub, std::memory_order_relaxed); + } + DELETE_BOTH_COPY(MPSCQueue); + MPSCQueue(MPSCQueue&& other) + : head(other.head), tail(other.tail.exchange(nullptr, std::memory_order_relaxed)), + _approxSize(other._approxSize.load(std::memory_order_relaxed)) { + other.head = nullptr; + } + // If this method existed, it wouldn't be atomic, and so would be rather error-prone. Maybe + // there's a valid future use case. + DELETE_MOVE_ASSN(MPSCQueue); + + // NOTE: It is NOT guaranteed that the result of a push() is accessible to a thread that calls + // pop() after the push(), because of implementation details. See the body of the function for + // details. + void push(T elem) { + Node* node = new Node(std::move(elem)); + _approxSize.fetch_add(1, std::memory_order_relaxed); + // ORDERING: must acquire any updates to prev before modifying it, and release our updates + // to node for other producers. + Node* prev = tail.exchange(node, std::memory_order_acq_rel); + // NOTE: If the thread is suspended here, then ALL FUTURE push() calls will be INACCESSIBLE + // by pop() calls until the next line runs. In order to guarantee that a push() is visible + // to a thread that calls pop(), ALL push() calls must have completed. + // ORDERING: must make updates visible to consumers. + prev->next.store(node, std::memory_order_release); + } + + // NOTE: It is NOT safe to call pop() from multiple threads without synchronization. + bool pop(T& elem) { + // ORDERING: Acquire any updates made by producers. + // Note that head is accessed only by the single consumer, so accesses to it need not be + // synchronized. + Node* next = head->next.load(std::memory_order_acquire); + if (next == nullptr) { + return false; + } + // Free the old element. + delete head; + head = next; + elem = std::move(head->data); + _approxSize.fetch_sub(1, std::memory_order_relaxed); + // Now the current head has dummy data in it again (i.e., whatever was leftover after the + // move()). 
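For illustration, a usage sketch of the MPSCQueue defined above (it assumes the template header stripped from this extract reads template<typename T>, and it respects the visibility caveat in the push() comment by joining all producers before draining):

#include <thread>
#include <vector>

#include "common/mpsc_queue.h"

int main() {
    lbug::common::MPSCQueue<int> queue;
    std::vector<std::thread> producers;
    for (int t = 0; t < 4; t++) {
        producers.emplace_back([&queue, t] {
            for (int i = 0; i < 1000; i++) {
                queue.push(t * 1000 + i); // wait-free for producers
            }
        });
    }
    // Join first: a push() is only guaranteed to be visible to pop() once all pushes complete.
    for (auto& p : producers) {
        p.join();
    }
    int value = 0;
    int drained = 0;
    while (queue.pop(value)) { // single consumer: only this thread calls pop()
        drained++;
    }
    return drained == 4000 ? 0 : 1;
}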
+ return true; + } + + // Return an approximation of the number of elements in the queue. + // Due to implementation details, this number must not be relied on. However, it can be used to + // get a rough estimate for the size of the queue. + size_t approxSize() const { return _approxSize.load(std::memory_order_relaxed); } + + // Drain the queue. All operations on the queue MUST have finished. I.e., there must be NO + // push() or pop() operations in progress of any kind. + ~MPSCQueue() { + // If we were moved out of, return. + if (!head) { + return; + } + + T dummy; + while (pop(dummy)) {} + KU_ASSERT(head == tail.load(std::memory_order_relaxed)); + delete head; + } + +private: + // Head is always present, but always has dummy data. This ensures that it is always easy to + // append to the list, without branching in the methods. + Node* head; + std::atomic tail; + + std::atomic _approxSize; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/mutex.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/mutex.h new file mode 100644 index 0000000000..b5ede2b97e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/mutex.h @@ -0,0 +1,61 @@ +#pragma once + +#include +#include + +#include "common/copy_constructors.h" + +namespace lbug { +namespace common { + +template +class MutexGuard { + template + friend class Mutex; + MutexGuard(T& data, std::unique_lock lck) : data(&data), lck(std::move(lck)) {} + +public: + DELETE_COPY_DEFAULT_MOVE(MutexGuard); + + T* operator->() & { return data; } + T& operator*() & { return *data; } + T* get() & { return data; } + + // Must not call these operators on a temporary MutexGuard! + // Guards _must_ be held while accessing the inner data. + T* operator->() && = delete; + T& operator*() && = delete; + T* get() && = delete; + +private: + T* data; + std::unique_lock lck; +}; + +template +class Mutex { +public: + Mutex() : data() {} + explicit Mutex(T data) : data(std::move(data)) {} + DELETE_COPY_AND_MOVE(Mutex); + + MutexGuard lock() { + std::unique_lock lck{mtx}; + return MutexGuard(data, std::move(lck)); + } + + std::optional> try_lock() { + if (!mtx.try_lock()) { + return std::nullopt; + } + std::unique_lock lck{mtx, std::adopt_lock}; + return MutexGuard(data, std::move(lck)); + } + +private: + T data; + std::mutex mtx; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/null_buffer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/null_buffer.h new file mode 100644 index 0000000000..cb82ddc5b3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/null_buffer.h @@ -0,0 +1,39 @@ +#pragma once + +#include +#include + +namespace lbug { +namespace common { + +class NullBuffer { + +public: + constexpr static const uint64_t NUM_NULL_MASKS_PER_BYTE = 8; + + static inline bool isNull(const uint8_t* nullBytes, uint64_t valueIdx) { + return nullBytes[valueIdx / NUM_NULL_MASKS_PER_BYTE] & + (1 << (valueIdx % NUM_NULL_MASKS_PER_BYTE)); + } + + static inline void setNull(uint8_t* nullBytes, uint64_t valueIdx) { + nullBytes[valueIdx / NUM_NULL_MASKS_PER_BYTE] |= + (1 << (valueIdx % NUM_NULL_MASKS_PER_BYTE)); + } + + static inline void setNoNull(uint8_t* nullBytes, uint64_t valueIdx) { + nullBytes[valueIdx / NUM_NULL_MASKS_PER_BYTE] &= + ~(1 << (valueIdx % NUM_NULL_MASKS_PER_BYTE)); + } + + static inline uint64_t getNumBytesForNullValues(uint64_t numValues) { + return (numValues + NUM_NULL_MASKS_PER_BYTE - 1) / 
NUM_NULL_MASKS_PER_BYTE; + } + + static inline void initNullBytes(uint8_t* nullBytes, uint64_t numValues) { + memset(nullBytes, 0 /* value */, getNumBytesForNullValues(numValues)); + } +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/null_mask.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/null_mask.h new file mode 100644 index 0000000000..5395e47089 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/null_mask.h @@ -0,0 +1,194 @@ +#pragma once + +#include +#include + +#include "common/assert.h" +#include + +namespace lbug { +namespace common { + +class ArrowNullMaskTree; +class Serializer; +class Deserializer; + +constexpr uint64_t NULL_BITMASKS_WITH_SINGLE_ONE[64] = {0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80, + 0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0x8000, 0x10000, 0x20000, 0x40000, 0x80000, + 0x100000, 0x200000, 0x400000, 0x800000, 0x1000000, 0x2000000, 0x4000000, 0x8000000, 0x10000000, + 0x20000000, 0x40000000, 0x80000000, 0x100000000, 0x200000000, 0x400000000, 0x800000000, + 0x1000000000, 0x2000000000, 0x4000000000, 0x8000000000, 0x10000000000, 0x20000000000, + 0x40000000000, 0x80000000000, 0x100000000000, 0x200000000000, 0x400000000000, 0x800000000000, + 0x1000000000000, 0x2000000000000, 0x4000000000000, 0x8000000000000, 0x10000000000000, + 0x20000000000000, 0x40000000000000, 0x80000000000000, 0x100000000000000, 0x200000000000000, + 0x400000000000000, 0x800000000000000, 0x1000000000000000, 0x2000000000000000, + 0x4000000000000000, 0x8000000000000000}; +constexpr uint64_t NULL_BITMASKS_WITH_SINGLE_ZERO[64] = {0xfffffffffffffffe, 0xfffffffffffffffd, + 0xfffffffffffffffb, 0xfffffffffffffff7, 0xffffffffffffffef, 0xffffffffffffffdf, + 0xffffffffffffffbf, 0xffffffffffffff7f, 0xfffffffffffffeff, 0xfffffffffffffdff, + 0xfffffffffffffbff, 0xfffffffffffff7ff, 0xffffffffffffefff, 0xffffffffffffdfff, + 0xffffffffffffbfff, 0xffffffffffff7fff, 0xfffffffffffeffff, 0xfffffffffffdffff, + 0xfffffffffffbffff, 0xfffffffffff7ffff, 0xffffffffffefffff, 0xffffffffffdfffff, + 0xffffffffffbfffff, 0xffffffffff7fffff, 0xfffffffffeffffff, 0xfffffffffdffffff, + 0xfffffffffbffffff, 0xfffffffff7ffffff, 0xffffffffefffffff, 0xffffffffdfffffff, + 0xffffffffbfffffff, 0xffffffff7fffffff, 0xfffffffeffffffff, 0xfffffffdffffffff, + 0xfffffffbffffffff, 0xfffffff7ffffffff, 0xffffffefffffffff, 0xffffffdfffffffff, + 0xffffffbfffffffff, 0xffffff7fffffffff, 0xfffffeffffffffff, 0xfffffdffffffffff, + 0xfffffbffffffffff, 0xfffff7ffffffffff, 0xffffefffffffffff, 0xffffdfffffffffff, + 0xffffbfffffffffff, 0xffff7fffffffffff, 0xfffeffffffffffff, 0xfffdffffffffffff, + 0xfffbffffffffffff, 0xfff7ffffffffffff, 0xffefffffffffffff, 0xffdfffffffffffff, + 0xffbfffffffffffff, 0xff7fffffffffffff, 0xfeffffffffffffff, 0xfdffffffffffffff, + 0xfbffffffffffffff, 0xf7ffffffffffffff, 0xefffffffffffffff, 0xdfffffffffffffff, + 0xbfffffffffffffff, 0x7fffffffffffffff}; + +const uint64_t NULL_LOWER_MASKS[65] = {0x0, 0x1, 0x3, 0x7, 0xf, 0x1f, 0x3f, 0x7f, 0xff, 0x1ff, + 0x3ff, 0x7ff, 0xfff, 0x1fff, 0x3fff, 0x7fff, 0xffff, 0x1ffff, 0x3ffff, 0x7ffff, 0xfffff, + 0x1fffff, 0x3fffff, 0x7fffff, 0xffffff, 0x1ffffff, 0x3ffffff, 0x7ffffff, 0xfffffff, 0x1fffffff, + 0x3fffffff, 0x7fffffff, 0xffffffff, 0x1ffffffff, 0x3ffffffff, 0x7ffffffff, 0xfffffffff, + 0x1fffffffff, 0x3fffffffff, 0x7fffffffff, 0xffffffffff, 0x1ffffffffff, 0x3ffffffffff, + 0x7ffffffffff, 0xfffffffffff, 0x1fffffffffff, 0x3fffffffffff, 0x7fffffffffff, 0xffffffffffff, + 0x1ffffffffffff, 
0x3ffffffffffff, 0x7ffffffffffff, 0xfffffffffffff, 0x1fffffffffffff, + 0x3fffffffffffff, 0x7fffffffffffff, 0xffffffffffffff, 0x1ffffffffffffff, 0x3ffffffffffffff, + 0x7ffffffffffffff, 0xfffffffffffffff, 0x1fffffffffffffff, 0x3fffffffffffffff, + 0x7fffffffffffffff, 0xffffffffffffffff}; +const uint64_t NULL_HIGH_MASKS[65] = {0x0, 0x8000000000000000, 0xc000000000000000, + 0xe000000000000000, 0xf000000000000000, 0xf800000000000000, 0xfc00000000000000, + 0xfe00000000000000, 0xff00000000000000, 0xff80000000000000, 0xffc0000000000000, + 0xffe0000000000000, 0xfff0000000000000, 0xfff8000000000000, 0xfffc000000000000, + 0xfffe000000000000, 0xffff000000000000, 0xffff800000000000, 0xffffc00000000000, + 0xffffe00000000000, 0xfffff00000000000, 0xfffff80000000000, 0xfffffc0000000000, + 0xfffffe0000000000, 0xffffff0000000000, 0xffffff8000000000, 0xffffffc000000000, + 0xffffffe000000000, 0xfffffff000000000, 0xfffffff800000000, 0xfffffffc00000000, + 0xfffffffe00000000, 0xffffffff00000000, 0xffffffff80000000, 0xffffffffc0000000, + 0xffffffffe0000000, 0xfffffffff0000000, 0xfffffffff8000000, 0xfffffffffc000000, + 0xfffffffffe000000, 0xffffffffff000000, 0xffffffffff800000, 0xffffffffffc00000, + 0xffffffffffe00000, 0xfffffffffff00000, 0xfffffffffff80000, 0xfffffffffffc0000, + 0xfffffffffffe0000, 0xffffffffffff0000, 0xffffffffffff8000, 0xffffffffffffc000, + 0xffffffffffffe000, 0xfffffffffffff000, 0xfffffffffffff800, 0xfffffffffffffc00, + 0xfffffffffffffe00, 0xffffffffffffff00, 0xffffffffffffff80, 0xffffffffffffffc0, + 0xffffffffffffffe0, 0xfffffffffffffff0, 0xfffffffffffffff8, 0xfffffffffffffffc, + 0xfffffffffffffffe, 0xffffffffffffffff}; + +class LBUG_API NullMask { +public: + static constexpr uint64_t NO_NULL_ENTRY = 0; + static constexpr uint64_t ALL_NULL_ENTRY = ~uint64_t(NO_NULL_ENTRY); + static constexpr uint64_t NUM_BITS_PER_NULL_ENTRY_LOG2 = 6; + static constexpr uint64_t NUM_BITS_PER_NULL_ENTRY = (uint64_t)1 << NUM_BITS_PER_NULL_ENTRY_LOG2; + static constexpr uint64_t NUM_BYTES_PER_NULL_ENTRY = NUM_BITS_PER_NULL_ENTRY >> 3; + + // For creating a managed null mask + explicit NullMask(uint64_t capacity) : mayContainNulls{false} { + auto numNullEntries = (capacity + NUM_BITS_PER_NULL_ENTRY - 1) / NUM_BITS_PER_NULL_ENTRY; + buffer = std::make_unique(numNullEntries); + data = std::span(buffer.get(), numNullEntries); + std::fill(data.begin(), data.end(), NO_NULL_ENTRY); + } + + // For creating a null mask using existing data + explicit NullMask(std::span nullData, bool mayContainNulls) + : data{nullData}, buffer{}, mayContainNulls{mayContainNulls} {} + + inline void setAllNonNull() { + if (!mayContainNulls) { + return; + } + std::fill(data.begin(), data.end(), NO_NULL_ENTRY); + mayContainNulls = false; + } + inline void setAllNull() { + std::fill(data.begin(), data.end(), ALL_NULL_ENTRY); + mayContainNulls = true; + } + + inline bool hasNoNullsGuarantee() const { return !mayContainNulls; } + uint64_t countNulls() const; + + static void setNull(uint64_t* nullEntries, uint32_t pos, bool isNull); + inline void setNull(uint32_t pos, bool isNull) { + KU_ASSERT(pos < getNumNullBits(data)); + setNull(data.data(), pos, isNull); + if (isNull) { + mayContainNulls = true; + } + } + + static inline bool isNull(const uint64_t* nullEntries, uint32_t pos) { + auto [entryPos, bitPosInEntry] = getNullEntryAndBitPos(pos); + return nullEntries[entryPos] & NULL_BITMASKS_WITH_SINGLE_ONE[bitPosInEntry]; + } + + static uint64_t getNumNullBits(std::span data) { + return data.size() * NullMask::NUM_BITS_PER_NULL_ENTRY; + } + + inline 
bool isNull(uint32_t pos) const { + KU_ASSERT(pos < getNumNullBits(data)); + return isNull(data.data(), pos); + } + + // const because updates to the data must set mayContainNulls if any value + // becomes non-null + // Modifying the underlying data should be done with setNull or copyFromNullData + inline const uint64_t* getData() const { return data.data(); } + + static inline uint64_t getNumNullEntries(uint64_t numNullBits) { + return (numNullBits >> NUM_BITS_PER_NULL_ENTRY_LOG2) + + ((numNullBits - (numNullBits << NUM_BITS_PER_NULL_ENTRY_LOG2)) == 0 ? 0 : 1); + } + + // Copies bitpacked null flags from one buffer to another, starting at an arbitrary bit + // offset and preserving adjacent bits. + // + // returns true if we have copied a nullBit with value 1 (indicates a null value) to + // dstNullEntries. + static bool copyNullMask(const uint64_t* srcNullEntries, uint64_t srcOffset, + uint64_t* dstNullEntries, uint64_t dstOffset, uint64_t numBitsToCopy, bool invert = false); + + inline bool copyFrom(const NullMask& nullMask, uint64_t srcOffset, uint64_t dstOffset, + uint64_t numBitsToCopy, bool invert = false) { + if (nullMask.hasNoNullsGuarantee()) { + setNullFromRange(dstOffset, numBitsToCopy, invert); + return invert; + } else { + return copyFromNullBits(nullMask.getData(), srcOffset, dstOffset, numBitsToCopy, + invert); + } + } + bool copyFromNullBits(const uint64_t* srcNullEntries, uint64_t srcOffset, uint64_t dstOffset, + uint64_t numBitsToCopy, bool invert = false); + + // Sets the given number of bits to null (if isNull is true) or non-null (if isNull is false), + // starting at the offset + static void setNullRange(uint64_t* nullEntries, uint64_t offset, uint64_t numBitsToSet, + bool isNull); + + void setNullFromRange(uint64_t offset, uint64_t numBitsToSet, bool isNull); + + void resize(uint64_t capacity); + + void operator|=(const NullMask& other); + + // Fast calculation of the minimum and maximum null values + // (essentially just three states, all null, all non-null and some null) + static std::pair getMinMax(const uint64_t* nullEntries, uint64_t offset, + uint64_t numValues); + +private: + static inline std::pair getNullEntryAndBitPos(uint64_t pos) { + auto nullEntryPos = pos >> NUM_BITS_PER_NULL_ENTRY_LOG2; + return std::make_pair(nullEntryPos, + pos - (nullEntryPos << NullMask::NUM_BITS_PER_NULL_ENTRY_LOG2)); + } + + static bool copyUnaligned(const uint64_t* srcNullEntries, uint64_t srcOffset, + uint64_t* dstNullEntries, uint64_t dstOffset, uint64_t numBitsToCopy, bool invert = false); + +private: + std::span data; + std::unique_ptr buffer; + bool mayContainNulls; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/numeric_utils.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/numeric_utils.h new file mode 100644 index 0000000000..60dbc1a5ba --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/numeric_utils.h @@ -0,0 +1,75 @@ +#pragma once + +#include + +#include "common/types/int128_t.h" +#include "common/types/types.h" +#include +#include + +namespace lbug { +namespace common { +namespace numeric_utils { + +template +concept IsIntegral = std::integral || std::same_as, int128_t>; + +template +concept IsSigned = std::same_as || std::numeric_limits::is_signed; + +template +concept IsUnSigned = std::numeric_limits::is_unsigned; + +template +struct MakeSigned { + using type = std::make_signed_t; +}; + +template<> +struct MakeSigned { + using type = int128_t; +}; + +template +using 
MakeSignedT = typename MakeSigned::type; + +template +struct MakeUnSigned { + using type = std::make_unsigned_t; +}; + +template<> +struct MakeUnSigned { + // currently evaluates to int128_t as we don't have an uint128_t type + using type = int128_t; +}; + +template +using MakeUnSignedT = typename MakeUnSigned::type; + +template +decltype(auto) makeValueSigned(T value) { + return static_cast>(value); +} + +template +decltype(auto) makeValueUnSigned(T value) { + return static_cast>(value); +} + +template +constexpr int bitWidth(T x) { + return std::bit_width(x); +} + +template<> +constexpr int bitWidth(int128_t x) { + if (x.high != 0) { + constexpr size_t BITS_PER_BYTE = 8; + return sizeof(x.low) * BITS_PER_BYTE + std::bit_width(makeValueUnSigned(x.high)); + } + return std::bit_width(x.low); +} +} // namespace numeric_utils +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/profiler.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/profiler.h new file mode 100644 index 0000000000..5688d39a44 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/profiler.h @@ -0,0 +1,34 @@ +#pragma once + +#include +#include +#include +#include + +#include "common/metric.h" + +namespace lbug { +namespace common { + +class Profiler { + +public: + TimeMetric* registerTimeMetric(const std::string& key); + + NumericMetric* registerNumericMetric(const std::string& key); + + double sumAllTimeMetricsWithKey(const std::string& key); + + uint64_t sumAllNumericMetricsWithKey(const std::string& key); + +private: + void addMetric(const std::string& key, std::unique_ptr metric); + +public: + std::mutex mtx; + bool enabled = false; + std::unordered_map>> metrics; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/random_engine.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/random_engine.h new file mode 100644 index 0000000000..3cccbba931 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/random_engine.h @@ -0,0 +1,40 @@ +#pragma once + +#include + +#include "common/api.h" +#include "pcg_random.hpp" + +namespace lbug { + +namespace main { +class ClientContext; +} + +namespace common { + +struct RandomState { + pcg32 pcg; + + RandomState() {} +}; + +class LBUG_API RandomEngine { +public: + RandomEngine(); + RandomEngine(uint64_t seed, uint64_t stream); + + void setSeed(uint64_t seed); + + uint32_t nextRandomInteger(); + uint32_t nextRandomInteger(uint32_t upper); + + static RandomEngine* Get(const main::ClientContext& context); + +private: + std::mutex mtx; + RandomState randomState; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/roaring_mask.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/roaring_mask.h new file mode 100644 index 0000000000..c4b367d402 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/roaring_mask.h @@ -0,0 +1,54 @@ +#pragma once + +#include "common/mask.h" +#include "roaring.hh" + +namespace lbug { +namespace common { + +class Roaring32BitmapSemiMask final : public SemiMask { +public: + explicit Roaring32BitmapSemiMask(offset_t maxOffset) + : SemiMask(maxOffset), roaring(std::make_shared()) {} + + void mask(offset_t nodeOffset) override { roaring->add(nodeOffset); } + void maskRange(offset_t startNodeOffset, offset_t endNodeOffset) override { + roaring->addRange(startNodeOffset, endNodeOffset); + } + + bool isMasked(offset_t startNodeOffset) 
override { return roaring->contains(startNodeOffset); } + + uint64_t getNumMaskedNodes() const override { return roaring->cardinality(); } + + offset_vec_t collectMaskedNodes(uint64_t size) const override; + + // include&exclude + offset_vec_t range(uint32_t start, uint32_t end) override; + + std::shared_ptr roaring; +}; + +class Roaring64BitmapSemiMask final : public SemiMask { +public: + explicit Roaring64BitmapSemiMask(offset_t maxOffset) + : SemiMask(maxOffset), roaring(std::make_shared()) {} + + void mask(offset_t nodeOffset) override { roaring->add(nodeOffset); } + void maskRange(offset_t startNodeOffset, offset_t endNodeOffset) override { + roaring->addRange(startNodeOffset, endNodeOffset); + } + + bool isMasked(offset_t startNodeOffset) override { return roaring->contains(startNodeOffset); } + + uint64_t getNumMaskedNodes() const override { return roaring->cardinality(); } + + offset_vec_t collectMaskedNodes(uint64_t size) const override; + + // include&exclude + offset_vec_t range(uint32_t start, uint32_t end) override; + + std::shared_ptr roaring; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/serializer/buffer_reader.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/serializer/buffer_reader.h new file mode 100644 index 0000000000..d6b73aef0c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/serializer/buffer_reader.h @@ -0,0 +1,26 @@ +#pragma once + +#include + +#include "common/serializer/reader.h" + +namespace lbug { +namespace common { + +struct BufferReader final : Reader { + BufferReader(uint8_t* data, size_t dataSize) : data(data), dataSize(dataSize), readSize(0) {} + + void read(uint8_t* outputData, uint64_t size) override { + memcpy(outputData, data + readSize, size); + readSize += size; + } + + bool finished() override { return readSize >= dataSize; } + + uint8_t* data; + size_t dataSize; + size_t readSize; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/serializer/buffer_writer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/serializer/buffer_writer.h new file mode 100644 index 0000000000..794b9d8905 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/serializer/buffer_writer.h @@ -0,0 +1,65 @@ +#pragma once + +#include +#include + +#include "common/api.h" +#include "common/serializer/writer.h" + +namespace lbug { +namespace common { + +static constexpr uint64_t SERIALIZER_DEFAULT_SIZE = 1024; + +struct BinaryData { + std::unique_ptr data; + uint64_t size = 0; +}; + +class LBUG_API BufferWriter : public Writer { +public: + // Serializes to a buffer allocated by the serializer, will expand when + // writing past the initial threshold. + explicit BufferWriter(uint64_t maximumSize = SERIALIZER_DEFAULT_SIZE); + + // Retrieves the data after the writing has been completed. + BinaryData getData() { return std::move(blob); } + + uint64_t getSize() const override { return blob.size; } + + uint8_t* getBlobData() const { return blob.data.get(); } + + void clear() override { blob.size = 0; } + void flush() override { + // DO NOTHING: BufferedWriter does not need to flush. + } + void sync() override { + // DO NOTHING: BufferedWriter does not need to sync. 
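A usage sketch for the semi-mask API implemented by the roaring-bitmap classes above (illustrative; it assumes SemiMaskUtil::createMask returns std::unique_ptr<SemiMask> and that maskRange treats the end offset as exclusive):

#include <cassert>

#include "common/mask.h"

int main() {
    using namespace lbug::common;
    // createMask picks a 32- or 64-bit roaring bitmap implementation based on maxOffset.
    auto mask = SemiMaskUtil::createMask(1000000 /* maxOffset */);
    mask->enable();
    mask->mask(42);
    mask->maskRange(100, 200); // assumed start-inclusive, end-exclusive
    assert(mask->isEnabled());
    assert(mask->isMasked(42));
    assert(mask->isMasked(150));
    assert(!mask->isMasked(999));
    return 0;
}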
+ } + + template + void write(T element) { + static_assert(std::is_trivially_destructible(), + "Write element must be trivially destructible"); + write(reinterpret_cast(&element), sizeof(T)); + } + + void write(const uint8_t* buffer, uint64_t len) final; + + void writeBufferData(const std::string& str) { + write(reinterpret_cast(str.c_str()), str.size()); + } + + void writeBufferData(const char& ch) { + write(reinterpret_cast(&ch), sizeof(char)); + } + +private: + uint64_t maximumSize; + uint8_t* data; + + BinaryData blob; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/serializer/buffered_file.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/serializer/buffered_file.h new file mode 100644 index 0000000000..c1cf0a8c12 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/serializer/buffered_file.h @@ -0,0 +1,70 @@ +#pragma once + +#include + +#include "common/serializer/reader.h" +#include "common/serializer/writer.h" + +namespace lbug { +namespace common { + +struct FileInfo; + +class BufferedFileWriter final : public Writer { +public: + explicit BufferedFileWriter(FileInfo& fileInfo); + ~BufferedFileWriter() override; + + void write(const uint8_t* data, uint64_t size) override; + + void clear() override; + void flush() override; + void sync() override; + + // Note: this function resets the next file offset to be written. Make sure the buffer is empty. + void setFileOffset(uint64_t fileOffset) { this->fileOffset = fileOffset; } + uint64_t getFileOffset() const { return fileOffset; } + void resetOffsets() { + fileOffset = 0; + bufferOffset = 0; + } + + uint64_t getSize() const override; + +protected: + std::unique_ptr buffer; + uint64_t fileOffset, bufferOffset; + FileInfo& fileInfo; +}; + +class BufferedFileReader final : public Reader { +public: + explicit BufferedFileReader(FileInfo& fileInfo); + + // Note: this function resets the next file offset to read. 
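As a round-trip sketch for BufferWriter and BufferReader (illustrative; it assumes write<T> appends the raw bytes of trivially destructible values and that getBlobData()/getSize() expose the accumulated buffer):

#include <cassert>
#include <cstdint>

#include "common/serializer/buffer_reader.h"
#include "common/serializer/buffer_writer.h"

int main() {
    using namespace lbug::common;
    BufferWriter writer;
    writer.write<uint64_t>(3); // element count
    writer.write<uint32_t>(7);
    writer.write<uint32_t>(11);
    writer.write<uint32_t>(13);

    // Read the same bytes back through the Reader interface.
    BufferReader reader(writer.getBlobData(), writer.getSize());
    uint64_t count = 0;
    reader.read(reinterpret_cast<uint8_t*>(&count), sizeof(count));
    assert(count == 3);
    uint32_t first = 0;
    reader.read(reinterpret_cast<uint8_t*>(&first), sizeof(first));
    assert(first == 7);
    return 0;
}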
+ void resetReadOffset(uint64_t fileOffset) { + this->fileOffset = fileOffset; + bufferOffset = 0; + bufferSize = 0; + } + + void read(uint8_t* data, uint64_t size) override; + + bool finished() override; + + uint64_t getReadOffset() const { return fileOffset - bufferSize + bufferOffset; } + FileInfo* getFileInfo() const { return &fileInfo; } + +private: + void readNextPage(); + +private: + std::unique_ptr buffer; + uint64_t fileOffset, bufferOffset; + FileInfo& fileInfo; + uint64_t fileSize; + uint64_t bufferSize; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/serializer/deserializer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/serializer/deserializer.h new file mode 100644 index 0000000000..07283a2634 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/serializer/deserializer.h @@ -0,0 +1,148 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "common/assert.h" +#include "common/serializer/reader.h" + +namespace lbug { +namespace common { + +class LBUG_API Deserializer { +public: + explicit Deserializer(std::unique_ptr reader) : reader(std::move(reader)) {} + + bool finished() const { return reader->finished(); } + + template + requires std::is_trivially_destructible_v || std::is_same_v + void deserializeValue(T& value) { + reader->read(reinterpret_cast(&value), sizeof(T)); + } + + void read(uint8_t* data, uint64_t size) const { reader->read(data, size); } + + Reader* getReader() const { return reader.get(); } + + void validateDebuggingInfo(std::string& value, const std::string& expectedVal); + + template + void deserializeOptionalValue(std::unique_ptr& value) { + bool isNull = false; + deserializeValue(isNull); + if (!isNull) { + value = T::deserialize(*this); + } + } + + template + void deserializeMap(std::map& values) { + uint64_t mapSize = 0; + deserializeValue(mapSize); + for (auto i = 0u; i < mapSize; i++) { + T1 key; + deserializeValue(key); + auto val = T2::deserialize(*this); + values.emplace(key, std::move(val)); + } + } + + template + void deserializeUnorderedMap(std::unordered_map& values) { + uint64_t mapSize = 0; + deserializeValue(mapSize); + for (auto i = 0u; i < mapSize; i++) { + T1 key; + deserializeValue(key); + T2 val; + deserializeValue(val); + values.emplace(key, std::move(val)); + } + } + + template + void deserializeUnorderedMapOfPtrs(std::unordered_map>& values) { + uint64_t mapSize = 0; + deserializeValue(mapSize); + values.reserve(mapSize); + for (auto i = 0u; i < mapSize; i++) { + T1 key; + deserializeValue(key); + auto val = T2::deserialize(*this); + values.emplace(key, std::move(val)); + } + } + + template + void deserializeVector(std::vector& values) { + uint64_t vectorSize = 0; + deserializeValue(vectorSize); + values.resize(vectorSize); + for (auto& value : values) { + if constexpr (requires(Deserializer& deser) { T::deserialize(deser); }) { + value = T::deserialize(*this); + } else { + deserializeValue(value); + } + } + } + + template + void deserializeArray(std::array& values) { + KU_ASSERT(values.size() == ARRAY_SIZE); + for (auto& value : values) { + if constexpr (requires(Deserializer& deser) { T::deserialize(deser); }) { + value = T::deserialize(*this); + } else { + deserializeValue(value); + } + } + } + + template + void deserializeVectorOfPtrs(std::vector>& values) { + uint64_t vectorSize = 0; + deserializeValue(vectorSize); + values.resize(vectorSize); + for (auto i = 0u; i < vectorSize; i++) { + 
values[i] = T::deserialize(*this); + } + } + + template + void deserializeVectorOfPtrs(std::vector>& values, + std::function(Deserializer&)> deserializeFunc) { + uint64_t vectorSize = 0; + deserializeValue(vectorSize); + values.resize(vectorSize); + for (auto i = 0u; i < vectorSize; i++) { + values[i] = deserializeFunc(*this); + } + } + + template + void deserializeUnorderedSet(std::unordered_set& values) { + uint64_t setSize = 0; + deserializeValue(setSize); + for (auto i = 0u; i < setSize; i++) { + T value; + deserializeValue(value); + values.insert(value); + } + } + +private: + std::unique_ptr reader; +}; + +template<> +void Deserializer::deserializeValue(std::string& value); + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/serializer/in_mem_file_writer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/serializer/in_mem_file_writer.h new file mode 100644 index 0000000000..617d987655 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/serializer/in_mem_file_writer.h @@ -0,0 +1,66 @@ +#pragma once + +#include + +#include "common/serializer/writer.h" +#include "storage/buffer_manager/memory_manager.h" + +namespace lbug::storage { +struct PageRange; +class ShadowFile; +class PageAllocator; +} // namespace lbug::storage + +namespace lbug { +namespace common { +class BufferedFileWriter; + +class InMemFileWriter final : public Writer { +public: + explicit InMemFileWriter(storage::MemoryManager& mm); + + void write(const uint8_t* data, uint64_t size) override; + + std::span getPage(page_idx_t pageIdx) const { + KU_ASSERT(pageIdx < pages.size()); + return pages[pageIdx]->getBuffer(); + } + + storage::PageRange flush(storage::PageAllocator& pageAllocator, + storage::ShadowFile& shadowFile) const; + void flush(storage::PageRange allocatedPages, storage::FileHandle* fileHandle, + storage::ShadowFile& shadowFile) const; + + page_idx_t getNumPagesToFlush() const { return pages.size(); } + + static uint64_t getPageSize(); + void flush(Writer& writer) const; + + uint64_t getSize() const override { + uint64_t size = pages.size() > 1 ? LBUG_PAGE_SIZE * (pages.size() - 1) : 0; + return size + pageOffset; + } + + void clear() override { + pages.clear(); + pageOffset = 0; + } + + void flush() override { + // DO NOTHING: InMemWriter does not need to flush. + } + void sync() override { + // DO NOTHING: InMemWriter does not need to sync. 
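A sketch of driving the Deserializer above over an in-memory buffer (illustrative; it assumes the constructor parameter stripped from this extract is std::unique_ptr<Reader>, and that deserializeVector expects a uint64_t count followed by the raw elements, as the code above suggests):

#include <cassert>
#include <cstdint>
#include <memory>
#include <vector>

#include "common/serializer/buffer_reader.h"
#include "common/serializer/buffer_writer.h"
#include "common/serializer/deserializer.h"

int main() {
    using namespace lbug::common;
    // Hand-write the wire format deserializeVector expects: count, then elements.
    BufferWriter writer;
    writer.write<uint64_t>(2);
    writer.write<uint32_t>(5);
    writer.write<uint32_t>(9);

    Deserializer deser(std::make_unique<BufferReader>(writer.getBlobData(), writer.getSize()));
    std::vector<uint32_t> values;
    deser.deserializeVector(values);
    assert(values.size() == 2 && values[0] == 5 && values[1] == 9);
    return 0;
}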
+ } + +private: + bool needNewBuffer(uint64_t size) const; + +private: + std::vector> pages; + storage::MemoryManager& mm; + uint64_t pageOffset; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/serializer/reader.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/serializer/reader.h new file mode 100644 index 0000000000..a6d91a76ec --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/serializer/reader.h @@ -0,0 +1,26 @@ +#pragma once + +#include + +#include "common/cast.h" + +namespace lbug { +namespace common { + +class Reader { +public: + virtual void read(uint8_t* data, uint64_t size) = 0; + virtual ~Reader() = default; + + virtual bool finished() = 0; + virtual void onObjectBegin() {}; + virtual void onObjectEnd() {}; + + template + TARGET* cast() { + return common::ku_dynamic_cast(this); + } +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/serializer/serializer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/serializer/serializer.h new file mode 100644 index 0000000000..c9d4f5f83e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/serializer/serializer.h @@ -0,0 +1,126 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "common/api.h" +#include "common/serializer/writer.h" + +namespace lbug { +namespace common { + +class LBUG_API Serializer { +public: + explicit Serializer(std::shared_ptr writer) : writer(std::move(writer)) {} + + template + requires std::is_trivially_destructible_v || std::is_same_v + void serializeValue(const T& value) { + writer->write(reinterpret_cast(&value), sizeof(T)); + } + + // Alias for serializeValue + template + void write(const T& value) { + serializeValue(value); + } + + void writeDebuggingInfo(const std::string& value); + + void write(const uint8_t* value, uint64_t len) { writer->write(value, len); } + + template + void serializeOptionalValue(const std::unique_ptr& value) { + serializeValue(value == nullptr); + if (value != nullptr) { + value->serialize(*this); + } + } + + template + void serializeMap(const std::map& values) { + uint64_t mapSize = values.size(); + serializeValue(mapSize); + for (auto& value : values) { + serializeValue(value.first); + value.second.serialize(*this); + } + } + + template + void serializeUnorderedMap(const std::unordered_map& values) { + uint64_t mapSize = values.size(); + serializeValue(mapSize); + for (auto& value : values) { + serializeValue(value.first); + serializeValue(value.second); + } + } + + template + void serializeUnorderedMapOfPtrs(const std::unordered_map>& values) { + uint64_t mapSize = values.size(); + serializeValue(mapSize); + for (auto& value : values) { + serializeValue(value.first); + value.second->serialize(*this); + } + } + + template + void serializeVector(const std::vector& values) { + uint64_t vectorSize = values.size(); + serializeValue(vectorSize); + for (auto& value : values) { + if constexpr (requires(Serializer& ser) { value.serialize(ser); }) { + value.serialize(*this); + } else { + serializeValue(value); + } + } + } + + template + void serializeArray(const std::array& values) { + for (auto& value : values) { + if constexpr (requires(Serializer& ser) { value.serialize(ser); }) { + value.serialize(*this); + } else { + serializeValue(value); + } + } + } + + template + void serializeVectorOfPtrs(const std::vector>& values) { + uint64_t vectorSize = values.size(); + 
serializeValue(vectorSize); + for (auto& value : values) { + value->serialize(*this); + } + } + + template + void serializeUnorderedSet(const std::unordered_set& values) { + uint64_t setSize = values.size(); + serializeValue(setSize); + for (const auto& value : values) { + serializeValue(value); + } + } + + Writer* getWriter() const { return writer.get(); } + +private: + std::shared_ptr writer; +}; + +template<> +void Serializer::serializeValue(const std::string& value); + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/serializer/writer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/serializer/writer.h new file mode 100644 index 0000000000..2c6014d550 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/serializer/writer.h @@ -0,0 +1,32 @@ +#pragma once + +#include + +namespace lbug { +namespace common { + +class Writer { +public: + virtual void write(const uint8_t* data, uint64_t size) = 0; + virtual ~Writer() = default; + + virtual uint64_t getSize() const = 0; + + virtual void clear() = 0; + virtual void flush() = 0; + virtual void sync() = 0; + virtual void onObjectBegin() {}; + virtual void onObjectEnd() {}; + + template + const TARGET& cast() const { + return dynamic_cast(*this); + } + template + TARGET& cast() { + return dynamic_cast(*this); + } +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/sha256.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/sha256.h new file mode 100644 index 0000000000..f1e2bbaa39 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/sha256.h @@ -0,0 +1,29 @@ +#pragma once + +#include + +#include "mbedtls/sha256.h" + +namespace lbug { +namespace common { + +class SHA256 { +public: + static constexpr size_t SHA256_HASH_LENGTH_BYTES = 32; + static constexpr size_t SHA256_HASH_LENGTH_TEXT = 64; + +public: + SHA256(); + ~SHA256(); + void addString(const std::string& str); + void finishSHA256(char* out); + static void toBase16(const char* in, char* out, size_t len); + +private: + typedef mbedtls_sha256_context SHA256Context; + + SHA256Context shaContext; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/static_vector.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/static_vector.h new file mode 100644 index 0000000000..a83a0a5b94 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/static_vector.h @@ -0,0 +1,88 @@ +#pragma once + +#include +#include +#include + +#include "common/assert.h" +#include "common/copy_constructors.h" + +namespace lbug { +namespace common { + +template +class MaybeUninit { +public: + T& assumeInit() { return *ptr(); } + const T& assumeInit() const { return *ptr(); } + T* ptr() { return reinterpret_cast(&data); } + const T* ptr() const { return reinterpret_cast(&data); } + +private: + alignas(T) std::byte data[sizeof(T)]; +}; + +template +class StaticVector { + StaticVector(const StaticVector& other) : len(other.len) { + std::uninitialized_copy(other.begin(), other.end(), begin()); + } + +public: + StaticVector() : len(0){}; + StaticVector(StaticVector&& other) : len(other.len) { + std::uninitialized_move(other.begin(), other.end(), begin()); + other.len = 0; + } + DELETE_COPY_ASSN(StaticVector); + EXPLICIT_COPY_METHOD(StaticVector); + StaticVector& operator=(StaticVector&& other) { + if (&other != this) { + clear(); + len = other.len; + std::uninitialized_move(other.begin(), 
other.end(), begin()); + other.len = 0; + } + return *this; + } + ~StaticVector() { clear(); } + + T& operator[](size_t i) { + KU_ASSERT(i < len); + return items[i].assumeInit(); + } + const T& operator[](size_t i) const { + KU_ASSERT(i < len); + return items[i].assumeInit(); + } + void push_back(T elem) { + KU_ASSERT(len < N); + new (items[len].ptr()) T(std::move(elem)); + len++; + } + T pop_back() { + KU_ASSERT(len > 0); + len--; + return std::move(items[len].assumeInit()); + } + T* begin() { return items[0].ptr(); } + const T* begin() const { return items[0].ptr(); } + T* end() { return items[len].ptr(); } + const T* end() const { return items[len].ptr(); } + + void clear() { + std::destroy(begin(), end()); + len = 0; + } + + bool empty() const { return len == 0; } + bool full() const { return len == N; } + size_t size() const { return len; } + +private: + MaybeUninit items[N]; + size_t len; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/string_format.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/string_format.h new file mode 100644 index 0000000000..f6d8ca2686 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/string_format.h @@ -0,0 +1,120 @@ +#pragma once + +#include +#include +#if USE_STD_FORMAT +#include +#else +#include "common/exception/internal.h" +#endif + +namespace lbug { +namespace common { + +#if USE_STD_FORMAT + +template +inline std::string stringFormat(std::format_string format, Args&&... args) { + return std::format(format, std::forward(args)...); +} + +#else + +namespace string_format_detail { +#define MAP_STD_TO_STRING(typ) \ + inline std::string map(typ v) { \ + return std::to_string(v); \ + } + +MAP_STD_TO_STRING(short) +MAP_STD_TO_STRING(unsigned short) +MAP_STD_TO_STRING(int) +MAP_STD_TO_STRING(unsigned int) +MAP_STD_TO_STRING(long) +MAP_STD_TO_STRING(unsigned long) +MAP_STD_TO_STRING(long long) +MAP_STD_TO_STRING(unsigned long long) +MAP_STD_TO_STRING(float) +MAP_STD_TO_STRING(double) +#undef MAP_STD_TO_STRING + +#define MAP_SELF(typ) \ + inline typ map(typ v) { \ + return v; \ + } +MAP_SELF(const char*); +// Also covers std::string +MAP_SELF(std::string_view) + +// Chars are mapped to themselves, but signed char and unsigned char (which are used for int8_t and +// uint8_t respectively), need to be cast to be properly output as integers. This is consistent with +// fmt's behavior. +MAP_SELF(char) +inline std::string map(signed char v) { + return std::to_string(int(v)); +} +inline std::string map(unsigned char v) { + return std::to_string(unsigned(v)); +} +#undef MAP_SELF + +template +inline void stringFormatHelper(std::string& ret, std::string_view format, Args&&... args) { + size_t bracket = format.find('{'); + if (bracket == std::string_view::npos) { + ret += format; + return; + } + ret += format.substr(0, bracket); + if (format.substr(bracket, 4) == "{{}}") { + // Escaped {}. + ret += "{}"; + return stringFormatHelper(ret, format.substr(bracket + 4), std::forward(args)...); + } else if (format.substr(bracket, 2) == "{}") { + // Formatted {}. + throw InternalException("Not enough values for string_format."); + } + // Something else. + ret.push_back('{'); + return stringFormatHelper(ret, format.substr(bracket + 1), std::forward(args)...); +} + +template +inline void stringFormatHelper(std::string& ret, std::string_view format, Arg&& arg, + Args&&... 
args) { + size_t bracket = format.find('{'); + if (bracket == std::string_view::npos) { + throw InternalException("Too many values for string_format."); + } + ret += format.substr(0, bracket); + if (format.substr(bracket, 4) == "{{}}") { + // Escaped {}. + ret += "{}"; + return stringFormatHelper(ret, format.substr(bracket + 4), std::forward(arg), + std::forward(args)...); + } else if (format.substr(bracket, 2) == "{}") { + // Formatted {}. + ret += map(arg); + return stringFormatHelper(ret, format.substr(bracket + 2), std::forward(args)...); + } + // Something else. + ret.push_back('{'); + return stringFormatHelper(ret, format.substr(bracket + 1), std::forward(arg), + std::forward(args)...); +} +} // namespace string_format_detail + +// Formats `args` according to `format`. Accepts {} for formatting the argument and {{}} for +// a literal {}. Formatting is done with std::ostream::operator<<. +template +inline std::string stringFormat(std::string_view format, Args&&... args) { + std::string ret; + ret.reserve(32); // Optimistic pre-allocation. + string_format_detail::stringFormatHelper(ret, format, std::forward(args)...); + return ret; +} + +#endif + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/string_utils.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/string_utils.h new file mode 100644 index 0000000000..2e14a5a53c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/string_utils.h @@ -0,0 +1,125 @@ +#pragma once + +#include +#include +#include + +#include "common/api.h" +#include "function/cast/functions/numeric_limits.h" +#include + +namespace lbug { +namespace common { + +class LBUG_API StringUtils { +public: + static std::vector splitComma(const std::string& input); + + // Does not split within [], {}, or (). 
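A usage sketch for the stringFormat helper defined above (illustrative; both the std::format branch and the fallback should produce the same output for these integer arguments):

#include <iostream>
#include <string>

#include "common/string_format.h"

int main() {
    using lbug::common::stringFormat;
    // {} consumes the next argument; {{}} is the escape for a literal "{}".
    std::string msg = stringFormat("scanned {} of {} node groups ({{}} stays literal)", 3, 8);
    std::cout << msg << '\n'; // scanned 3 of 8 node groups ({} stays literal)
    return 0;
}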
+ // can specify maximum number of elements to split + static std::vector smartSplit(std::string_view input, char splitChar, + uint64_t maxNumEle = function::NumericLimits::maximum()); + + static std::vector split(const std::string& input, const std::string& delimiter, + bool ignoreEmptyStringParts = true); + static std::vector splitBySpace(const std::string& input); + + static std::string getUpper(const std::string& input); + static std::string getUpper(const std::string_view& input); + static std::string getLower(const std::string& input); + static void toLower(std::string& input); + static void toUpper(std::string& input); + + static bool isSpace(char c) { + return c == ' ' || c == '\t' || c == '\n' || c == '\v' || c == '\f' || c == '\r'; + } + static bool characterIsNewLine(char c) { return c == '\n' || c == '\r'; } + static bool CharacterIsDigit(char c) { return c >= '0' && c <= '9'; } + + static std::string ltrim(const std::string& input) { + auto s = input; + s.erase(s.begin(), + find_if(s.begin(), s.end(), [](unsigned char ch) { return !isspace(ch); })); + return s; + } + static std::string_view ltrim(std::string_view input) { + auto begin = 0u; + while (begin < input.size() && isspace(input[begin])) { + begin++; + } + return input.substr(begin); + } + static std::string rtrim(const std::string& input) { + auto s = input; + s.erase(find_if(s.rbegin(), s.rend(), [](unsigned char ch) { return !isspace(ch); }).base(), + s.end()); + return s; + } + static std::string_view rtrim(std::string_view input) { + auto end = input.size(); + while (end > 0 && isSpace(input[end - 1])) { + end--; + } + return input.substr(0, end); + } + static std::string ltrimNewlines(const std::string& input); + static std::string rtrimNewlines(const std::string& input); + + static void removeWhiteSpaces(std::string& str) { + std::regex whiteSpacePattern{"\\s"}; + str = std::regex_replace(str, whiteSpacePattern, ""); + } + + static void removeCStringWhiteSpaces(const char*& input, uint64_t& len); + + static void replaceAll(std::string& str, const std::string& search, + const std::string& replacement); + + static std::string extractStringBetween(const std::string& input, char delimiterStart, + char delimiterEnd, bool includeDelimiter = false); + + static uint64_t caseInsensitiveHash(const std::string& str); + + static bool caseInsensitiveEquals(std::string_view left, std::string_view right); + + // join multiple strings into one string. 
Components are concatenated by the given separator + static std::string join(const std::vector& input, const std::string& separator); + static std::string join(const std::span input, + const std::string& separator); + + // join multiple items of container with given size, transformed to string + // using function, into one string using the given separator + template + static std::string join(const C& input, S count, const std::string& separator, Func f); + + static constexpr uint8_t asciiToLowerCaseMap[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, + 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, + 60, 61, 62, 63, 64, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, + 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 91, 92, 93, 94, 95, 96, 97, 98, 99, + 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, + 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, + 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, + 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, + 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, + 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, + 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, + 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, + 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255}; + + static std::string encodeURL(const std::string& input, bool encodeSlash = false); + + // container hash function for strings which lets you hash both string_view and string + // references + struct string_hash { + using hash_type = std::hash; + using is_transparent = void; + + std::size_t operator()(const char* str) const { return hash_type{}(str); } + std::size_t operator()(std::string_view str) const { return hash_type{}(str); } + std::size_t operator()(std::string const& str) const { return hash_type{}(str); } + }; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/system_message.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/system_message.h new file mode 100644 index 0000000000..bed980f433 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/system_message.h @@ -0,0 +1,29 @@ +#pragma once + +#include +#include +#include + +#include "common/api.h" + +namespace lbug { +namespace common { + +inline std::string systemErrMessage(int code) { + // System errors are unexpected. For anything expected, we should catch it explicitly and + // provide a better error message to the user. 
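A short usage sketch for the StringUtils helpers above (illustrative; the element and return types stripped from this extract are assumed to be std::string, and split is assumed to drop empty parts by default per its ignoreEmptyStringParts parameter):

#include <cassert>

#include "common/string_utils.h"

int main() {
    using lbug::common::StringUtils;
    auto parts = StringUtils::split("a,b,,c", ","); // {"a", "b", "c"} with empties dropped
    assert(parts.size() == 3);
    auto joined = StringUtils::join(parts, "|"); // "a|b|c"
    assert(StringUtils::caseInsensitiveEquals(joined, "A|B|C"));
    return 0;
}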
+ // LCOV_EXCL_START + return std::system_category().message(code); + // LCOV_EXCL_STOP +} + +inline std::string posixErrMessage() { + // LCOV_EXCL_START + return systemErrMessage(errno); + // LCOV_EXCL_STOP +} + +LBUG_API std::string dlErrMessage(); + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/task_system/progress_bar.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/task_system/progress_bar.h new file mode 100644 index 0000000000..7e5e37fd3f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/task_system/progress_bar.h @@ -0,0 +1,61 @@ +#pragma once + +#include +#include + +#include "common/api.h" +#include "progress_bar_display.h" + +namespace lbug { +namespace main { +class ClientContext; +} +namespace common { + +typedef std::unique_ptr (*progress_bar_display_create_func_t)(); + +/** + * @brief Progress bar for tracking the progress of a pipeline. Prints the progress of each query + * pipeline and the overall progress. + */ +class ProgressBar { +public: + explicit ProgressBar(bool enableProgressBar); + + static std::shared_ptr DefaultProgressBarDisplay(); + + void addPipeline(); + + void finishPipeline(uint64_t queryID); + + void endProgress(uint64_t queryID); + + void startProgress(uint64_t queryID); + + void toggleProgressBarPrinting(bool enable); + + LBUG_API void updateProgress(uint64_t queryID, double curPipelineProgress); + + void setDisplay(std::shared_ptr progressBarDipslay); + + std::shared_ptr getDisplay() { return display; } + + bool getProgressBarPrinting() const { return trackProgress; } + + LBUG_API static ProgressBar* Get(const main::ClientContext& context); + +private: + void resetProgressBar(uint64_t queryID); + + void updateDisplay(uint64_t queryID, double curPipelineProgress); + +private: + uint32_t numPipelines; + uint32_t numPipelinesFinished; + std::mutex progressBarLock; + bool trackProgress; + std::shared_ptr display; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/task_system/progress_bar_display.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/task_system/progress_bar_display.h new file mode 100644 index 0000000000..fbe202950c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/task_system/progress_bar_display.h @@ -0,0 +1,38 @@ +#pragma once + +#include + +#include + +namespace lbug { +namespace common { + +/** + * @brief Interface for displaying progress of a pipeline and a query. + */ +class ProgressBarDisplay { +public: + ProgressBarDisplay() : pipelineProgress{0}, numPipelines{0}, numPipelinesFinished{0} {}; + + virtual ~ProgressBarDisplay() = default; + + // Update the progress of the pipeline and the number of finished pipelines. queryID is used to + // identify the query when we track progress of multiple queries asynchronously + // This function should work even if called concurrently by multiple threads + virtual void updateProgress(uint64_t queryID, double newPipelineProgress, + uint32_t newNumPipelinesFinished) = 0; + + // Finish the progress display. 
queryID is used to identify the query when we track progress of + // multiple queries asynchronously + virtual void finishProgress(uint64_t queryID) = 0; + + void setNumPipelines(uint32_t newNumPipelines) { numPipelines = newNumPipelines; }; + +protected: + std::atomic pipelineProgress; + uint32_t numPipelines; + std::atomic numPipelinesFinished; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/task_system/task.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/task_system/task.h new file mode 100644 index 0000000000..fd6c37edbc --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/task_system/task.h @@ -0,0 +1,107 @@ +#pragma once + +#include +#include +#include + +#include "common/api.h" + +namespace lbug { +namespace common { + +using lock_t = std::unique_lock; + +/** + * Task represents a task that can be executed by multiple threads in the TaskScheduler. Task is a + * virtual class. Users of TaskScheduler need to extend the Task class and implement at + * least a virtual run() function. Users can assume that before T.run() is called, a worker thread W + * has grabbed task T from the TaskScheduler's queue and registered itself to T. They can also + * assume that after run() is called W will deregister itself from T. When deregistering, if W is + * the last worker to finish on T, i.e., once W finishes, T will be completed, the + * finalize() will be called. So the run() and finalize() calls are separate + * calls and if there is some state from the run() function execution that will be needed by + * finalize, users should save it somewhere that can be accessed in + * finalize(). See ProcessorTask for an example of this. + */ +class LBUG_API Task { + friend class TaskScheduler; + +public: + explicit Task(uint64_t maxNumThreads) + : parent{nullptr}, maxNumThreads{maxNumThreads}, numThreadsFinished{0}, + numThreadsRegistered{0}, exceptionsPtr{nullptr}, ID{UINT64_MAX} {} + + virtual ~Task() = default; + virtual void run() = 0; + // This function is called from inside deRegisterThreadAndFinalizeTaskIfNecessary() only + // once by the last registered worker that is completing this task. So the task lock is + // already acquired. So do not attempt to acquire the task lock inside. If needed we can + // make the deregister function release the lock before calling finalize and + // drop this assumption. + virtual void finalize() {} + // If task should terminate all subsequent tasks. 
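// A minimal sketch of the lifecycle described in the comment above (names are
// illustrative, not part of the vendored sources): every worker that registers on the
// task executes run() once, and the last worker to deregister triggers finalize() under
// the task lock, so state that finalize() needs is accumulated in members during run().
#include <atomic>
#include <cstdint>

class CountRunsTask : public lbug::common::Task {
public:
    explicit CountRunsTask(uint64_t maxNumThreads) : Task{maxNumThreads} {}

    void run() override {
        // Executed once by each worker thread that managed to register on this task.
        runsObserved.fetch_add(1, std::memory_order_relaxed);
    }

    void finalize() override {
        // Called exactly once, by the last worker; the task lock is already held here.
        totalRuns = runsObserved.load(std::memory_order_relaxed);
    }

    std::atomic<uint64_t> runsObserved{0};
    uint64_t totalRuns = 0;
};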
+ virtual bool terminate() { return false; } + + void addChildTask(std::unique_ptr child) { + child->parent = this; + children.push_back(std::move(child)); + } + + bool isCompletedSuccessfully() { + lock_t lck{taskMtx}; + return isCompletedNoLock() && !hasExceptionNoLock(); + } + + bool isCompletedNoLock() const { + return numThreadsRegistered > 0 && numThreadsFinished == numThreadsRegistered; + } + + void setSingleThreadedTask() { maxNumThreads = 1; } + + bool registerThread(); + + void deRegisterThreadAndFinalizeTask(); + + void setException(const std::exception_ptr& exceptionPtr) { + lock_t lck{taskMtx}; + setExceptionNoLock(exceptionPtr); + } + + bool hasException() { + lock_t lck{taskMtx}; + return exceptionsPtr != nullptr; + } + + std::exception_ptr getExceptionPtr() { + lock_t lck{taskMtx}; + return exceptionsPtr; + } + +private: + bool canRegisterNoLock() const { + return 0 == numThreadsFinished && maxNumThreads > numThreadsRegistered; + } + + bool hasExceptionNoLock() const { return exceptionsPtr != nullptr; } + + void setExceptionNoLock(const std::exception_ptr& exceptionPtr) { + if (exceptionsPtr == nullptr) { + exceptionsPtr = exceptionPtr; + } + } + +public: + Task* parent; + std::vector> + children; // Dependency tasks that needs to be executed first. + +protected: + std::mutex taskMtx; + std::condition_variable cv; + uint64_t maxNumThreads, numThreadsFinished, numThreadsRegistered; + std::exception_ptr exceptionsPtr; + uint64_t ID; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/task_system/task_scheduler.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/task_system/task_scheduler.h new file mode 100644 index 0000000000..8b8fbd6ee8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/task_system/task_scheduler.h @@ -0,0 +1,112 @@ +#pragma once +#include + +#ifndef __SINGLE_THREADED__ +#include +#include +#endif + +#include "common/task_system/task.h" +#include "processor/execution_context.h" + +namespace lbug { +namespace common { + +struct ScheduledTask { + ScheduledTask(std::shared_ptr task, uint64_t ID) : task{std::move(task)}, ID{ID} {}; + std::shared_ptr task; + uint64_t ID; +}; + +/** + * TaskScheduler is a library that manages a set of worker threads that can execute tasks that are + * put into a task queue. Each task accepts a maximum number of threads. Users of TaskScheduler + * schedule tasks to be executed by calling schedule functions, e.g., pushTaskIntoQueue or + * scheduleTaskAndWaitOrError. New tasks are put at the end of the queue. Workers grab the first + * task from the beginning of the queue that they can register themselves to work on. Any task that + * is completed is removed automatically from the queue. If there is a task that raises an + * exception, the worker threads catch it and store it with the tasks. The user thread that is + * waiting on the completion of the task (or tasks) will throw the exception (the user thread could + * be waiting on a tasks through a function that waits, e.g., scheduleTaskAndWaitOrError. + * + * Currently there is one way the TaskScheduler can be used: + * Schedule one task T and wait for T to finish or error if there was an exception raised by + * one of the threads working on T that errored. This is simply done by the call: + * scheduleTaskAndWaitOrError(T); + * + * TaskScheduler guarantees that workers will register themselves to tasks in FIFO order. 
However + * this does not guarantee that the tasks will be completed in FIFO order: a long running task + * that is not accepting more registration can stay in the queue for an unlimited time until + * completion. + */ +#ifndef __SINGLE_THREADED__ +class LBUG_API TaskScheduler { +public: +#if defined(__APPLE__) + explicit TaskScheduler(uint64_t numWorkerThreads, uint32_t threadQos); +#else + explicit TaskScheduler(uint64_t numWorkerThreads); +#endif + ~TaskScheduler(); + + // Schedules the dependencies of the given task and finally the task one after another (so + // not concurrently), and throws an exception if any of the tasks errors. Regardless of + // whether or not the given task or one of its dependencies errors, when this function + // returns, no task related to the given task will be in the task queue. Further no worker + // thread will be working on the given task. + void scheduleTaskAndWaitOrError(const std::shared_ptr& task, + processor::ExecutionContext* context, bool launchNewWorkerThread = false); + + static TaskScheduler* Get(const main::ClientContext& context); + +private: + // Functions to launch worker threads and for the worker threads to use to grab task from queue. + void runWorkerThread(); + + std::shared_ptr pushTaskIntoQueue(const std::shared_ptr& task); + + void removeErroringTask(uint64_t scheduledTaskID); + + std::shared_ptr getTaskAndRegister(); + static void runTask(Task* task); + +private: + std::deque> taskQueue; + bool stopWorkerThreads; + std::vector workerThreads; + std::mutex taskSchedulerMtx; + std::condition_variable cv; + uint64_t nextScheduledTaskID; +#if defined(__APPLE__) + uint32_t threadQos; // Thread quality of service for worker threads. +#endif +}; +#else +// Single-threaded version of TaskScheduler +class TaskScheduler { +public: + explicit TaskScheduler(uint64_t numWorkerThreads); + ~TaskScheduler(); + + void scheduleTaskAndWaitOrError(const std::shared_ptr& task, + processor::ExecutionContext* context, bool launchNewWorkerThread = false); + + static TaskScheduler* Get(const main::ClientContext& context); + +private: + std::shared_ptr pushTaskIntoQueue(const std::shared_ptr& task); + + void removeErroringTask(uint64_t scheduledTaskID); + + std::shared_ptr getTaskAndRegister(); + static void runTask(Task* task); + +private: + std::deque> taskQueue; + bool stopWorkerThreads; + std::mutex taskSchedulerMtx; + uint64_t nextScheduledTaskID; +}; +#endif +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/task_system/terminal_progress_bar_display.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/task_system/terminal_progress_bar_display.h new file mode 100644 index 0000000000..7cd2387c65 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/task_system/terminal_progress_bar_display.h @@ -0,0 +1,33 @@ +#pragma once + +#include + +#include "progress_bar_display.h" + +namespace lbug { +namespace common { + +/** + * @brief A class that displays a progress bar in the terminal. 
+ */ +class TerminalProgressBarDisplay final : public ProgressBarDisplay { +public: + void updateProgress(uint64_t queryID, double newPipelineProgress, + uint32_t newNumPipelinesFinished) override; + + void finishProgress(uint64_t queryID) override; + +private: + void setGreenFont() const { std::cerr << "\033[1;32m"; } + + void setDefaultFont() const { std::cerr << "\033[0m"; } + + void printProgressBar(); + +private: + bool printing = false; + std::atomic currentlyPrintingProgress; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/timer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/timer.h new file mode 100644 index 0000000000..7f609d9186 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/timer.h @@ -0,0 +1,48 @@ +#pragma once + +#include +#include + +#include "common/assert.h" +#include "exception/exception.h" + +namespace lbug { +namespace common { + +class Timer { + +public: + void start() { + finished = false; + startTime = std::chrono::high_resolution_clock::now(); + } + + void stop() { + stopTime = std::chrono::high_resolution_clock::now(); + finished = true; + } + + double getDuration() const { + if (finished) { + auto duration = stopTime - startTime; + return (double)std::chrono::duration_cast(duration).count(); + } + throw Exception("Timer is still running."); + } + + uint64_t getElapsedTimeInMS() const { + auto now = std::chrono::high_resolution_clock::now(); + auto duration = now - startTime; + auto count = std::chrono::duration_cast(duration).count(); + KU_ASSERT(count >= 0); + return count; + } + +private: + std::chrono::time_point startTime; + std::chrono::time_point stopTime; + bool finished = false; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/type_utils.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/type_utils.h new file mode 100644 index 0000000000..f5574f5997 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/type_utils.h @@ -0,0 +1,333 @@ +#pragma once + +#include + +#include "common/assert.h" +#include "common/types/blob.h" +#include "common/types/date_t.h" +#include "common/types/int128_t.h" +#include "common/types/interval_t.h" +#include "common/types/ku_string.h" +#include "common/types/timestamp_t.h" +#include "common/types/types.h" +#include "common/types/uint128_t.h" +#include "common/types/uuid.h" +#include "common/vector/value_vector.h" + +namespace lbug { +namespace common { + +class ValueVector; + +template +struct overload : Funcs... { + explicit overload(Funcs... funcs) : Funcs(funcs)... {} + using Funcs::operator()...; +}; + +class TypeUtils { +public: + template + static void paramPackForEachHelper(const Func& func, std::index_sequence, + Types&&... values) { + ((func(indices, values)), ...); + } + + template + static void paramPackForEach(const Func& func, Types&&... 
values) { + paramPackForEachHelper(func, std::index_sequence_for(), + std::forward(values)...); + } + + static std::string entryToString(const LogicalType& dataType, const uint8_t* value, + ValueVector* vector); + + template + static inline std::string toString(const T& val, void* /*valueVector*/ = nullptr) { + if constexpr (std::is_same_v) { + return val; + } else if constexpr (std::is_same_v) { + return val.getAsString(); + } else { + static_assert(std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value); + return std::to_string(val); + } + } + static std::string nodeToString(const struct_entry_t& val, ValueVector* vector); + static std::string relToString(const struct_entry_t& val, ValueVector* vector); + + static inline void encodeOverflowPtr(uint64_t& overflowPtr, page_idx_t pageIdx, + uint32_t pageOffset) { + memcpy(&overflowPtr, &pageIdx, 4); + memcpy(((uint8_t*)&overflowPtr) + 4, &pageOffset, 4); + } + static inline void decodeOverflowPtr(uint64_t overflowPtr, page_idx_t& pageIdx, + uint32_t& pageOffset) { + pageIdx = 0; + memcpy(&pageIdx, &overflowPtr, 4); + memcpy(&pageOffset, ((uint8_t*)&overflowPtr) + 4, 4); + } + + template + static inline constexpr common::PhysicalTypeID getPhysicalTypeIDForType() { + if constexpr (std::is_same_v) { + return common::PhysicalTypeID::INT64; + } else if constexpr (std::is_same_v) { + return common::PhysicalTypeID::INT32; + } else if constexpr (std::is_same_v) { + return common::PhysicalTypeID::INT16; + } else if constexpr (std::is_same_v) { + return common::PhysicalTypeID::INT8; + } else if constexpr (std::is_same_v) { + return common::PhysicalTypeID::UINT64; + } else if constexpr (std::is_same_v) { + return common::PhysicalTypeID::UINT32; + } else if constexpr (std::is_same_v) { + return common::PhysicalTypeID::UINT16; + } else if constexpr (std::is_same_v) { + return common::PhysicalTypeID::UINT8; + } else if constexpr (std::is_same_v) { + return common::PhysicalTypeID::FLOAT; + } else if constexpr (std::is_same_v) { + return common::PhysicalTypeID::DOUBLE; + } else if constexpr (std::is_same_v) { + return common::PhysicalTypeID::INT128; + } else if constexpr (std::is_same_v) { + return common::PhysicalTypeID::INTERVAL; + } else if constexpr (std::is_same_v) { + return common::PhysicalTypeID::UINT128; + } else if constexpr (std::same_as || std::same_as || + std::same_as) { + return common::PhysicalTypeID::STRING; + } else { + KU_UNREACHABLE; + } + } + + /* + * TypeUtils::visit can be used to call generic code on all or some Logical and Physical type + * variants with access to type information. + * + * E.g. + * + * std::string result; + * visit(dataType, [&](T) { + * if constexpr(std::is_same_v()) { + * result = vector->getValue(0).getAsString(); + * } else if (std::integral) { + * result = std::to_string(vector->getValue(0)); + * } else { + * KU_UNREACHABLE; + * } + * }); + * + * or + * std::string result; + * visit(dataType, + * [&](ku_string_t) { + * result = vector->getValue(0); + * }, + * [&](T) { + * result = std::to_string(vector->getValue(0)); + * }, + * [](auto) { KU_UNREACHABLE; } + * ); + * + * Note that when multiple functions are provided, at least one function must match all data + * types. 
+ * + * Also note that implicit conversions may occur with the multi-function variant + * if you don't include a generic auto function to cover types which aren't explicitly included. + * See https://en.cppreference.com/w/cpp/utility/variant/visit + */ + template + static inline auto visit(const LogicalType& dataType, Fs... funcs) { + // Note: arguments are used only for type deduction and have no meaningful value. + // They should be optimized out by the compiler + auto func = overload(funcs...); + switch (dataType.getLogicalTypeID()) { + /* NOLINTBEGIN(bugprone-branch-clone)*/ + case LogicalTypeID::INT8: + return func(int8_t()); + case LogicalTypeID::UINT8: + return func(uint8_t()); + case LogicalTypeID::INT16: + return func(int16_t()); + case LogicalTypeID::UINT16: + return func(uint16_t()); + case LogicalTypeID::INT32: + return func(int32_t()); + case LogicalTypeID::UINT32: + return func(uint32_t()); + case LogicalTypeID::SERIAL: + case LogicalTypeID::INT64: + return func(int64_t()); + case LogicalTypeID::UINT64: + return func(uint64_t()); + case LogicalTypeID::BOOL: + return func(bool()); + case LogicalTypeID::INT128: + return func(int128_t()); + case LogicalTypeID::DOUBLE: + return func(double()); + case LogicalTypeID::FLOAT: + return func(float()); + case LogicalTypeID::DECIMAL: + switch (dataType.getPhysicalType()) { + case PhysicalTypeID::INT16: + return func(int16_t()); + case PhysicalTypeID::INT32: + return func(int32_t()); + case PhysicalTypeID::INT64: + return func(int64_t()); + case PhysicalTypeID::INT128: + return func(int128_t()); + default: + KU_UNREACHABLE; + } + case LogicalTypeID::INTERVAL: + return func(interval_t()); + case LogicalTypeID::INTERNAL_ID: + return func(internalID_t()); + case LogicalTypeID::UINT128: + return func(uint128_t()); + case LogicalTypeID::STRING: + return func(ku_string_t()); + case LogicalTypeID::DATE: + return func(date_t()); + case LogicalTypeID::TIMESTAMP_NS: + return func(timestamp_ns_t()); + case LogicalTypeID::TIMESTAMP_MS: + return func(timestamp_ms_t()); + case LogicalTypeID::TIMESTAMP_SEC: + return func(timestamp_sec_t()); + case LogicalTypeID::TIMESTAMP_TZ: + return func(timestamp_tz_t()); + case LogicalTypeID::TIMESTAMP: + return func(timestamp_t()); + case LogicalTypeID::BLOB: + return func(blob_t()); + case LogicalTypeID::UUID: + return func(ku_uuid_t()); + case LogicalTypeID::ARRAY: + case LogicalTypeID::LIST: + return func(list_entry_t()); + case LogicalTypeID::MAP: + return func(map_entry_t()); + case LogicalTypeID::NODE: + case LogicalTypeID::REL: + case LogicalTypeID::RECURSIVE_REL: + case LogicalTypeID::STRUCT: + return func(struct_entry_t()); + case LogicalTypeID::UNION: + return func(union_entry_t()); + /* NOLINTEND(bugprone-branch-clone)*/ + default: + // Unsupported type + KU_UNREACHABLE; + } + } + + template + static inline auto visit(PhysicalTypeID dataType, Fs&&... funcs) { + // Note: arguments are used only for type deduction and have no meaningful value. 
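// A compilable sketch of the single-lambda form of visit() described in the comment
// above (the function name and labels are illustrative; assumes C++20, which the use of
// std::same_as elsewhere in this header already implies). The lambda receives a
// default-constructed dummy value whose static type identifies the logical type, so the
// useful work happens inside the if constexpr branches.
#include <concepts>
#include <string>

inline std::string describeLogicalType(const lbug::common::LogicalType& dataType) {
    using namespace lbug::common;
    return TypeUtils::visit(dataType, []<typename T>(T) -> std::string {
        if constexpr (std::same_as<T, ku_string_t>) {
            return "string";
        } else if constexpr (std::integral<T> || std::floating_point<T>) {
            return "numeric";
        } else {
            return "other"; // e.g. temporal types, int128/uint128, nested struct entries
        }
    });
}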
+ // They should be optimized out by the compiler + auto func = overload(funcs...); + switch (dataType) { + /* NOLINTBEGIN(bugprone-branch-clone)*/ + case PhysicalTypeID::INT8: + return func(int8_t()); + case PhysicalTypeID::UINT8: + return func(uint8_t()); + case PhysicalTypeID::INT16: + return func(int16_t()); + case PhysicalTypeID::UINT16: + return func(uint16_t()); + case PhysicalTypeID::INT32: + return func(int32_t()); + case PhysicalTypeID::UINT32: + return func(uint32_t()); + case PhysicalTypeID::INT64: + return func(int64_t()); + case PhysicalTypeID::UINT64: + return func(uint64_t()); + case PhysicalTypeID::BOOL: + return func(bool()); + case PhysicalTypeID::INT128: + return func(int128_t()); + case PhysicalTypeID::DOUBLE: + return func(double()); + case PhysicalTypeID::FLOAT: + return func(float()); + case PhysicalTypeID::INTERVAL: + return func(interval_t()); + case PhysicalTypeID::INTERNAL_ID: + return func(internalID_t()); + case PhysicalTypeID::UINT128: + return func(uint128_t()); + case PhysicalTypeID::STRING: + return func(ku_string_t()); + case PhysicalTypeID::ARRAY: + case PhysicalTypeID::LIST: + return func(list_entry_t()); + case PhysicalTypeID::STRUCT: + return func(struct_entry_t()); + /* NOLINTEND(bugprone-branch-clone)*/ + case PhysicalTypeID::ANY: + case PhysicalTypeID::POINTER: + case PhysicalTypeID::ALP_EXCEPTION_DOUBLE: + case PhysicalTypeID::ALP_EXCEPTION_FLOAT: + // Unsupported type + KU_UNREACHABLE; + // Needed for return type deduction to work + return func(uint8_t()); + default: + KU_UNREACHABLE; + } + } +}; + +// Forward declaration of template specializations. +template<> +std::string TypeUtils::toString(const int128_t& val, void* valueVector); +template<> +std::string TypeUtils::toString(const uint128_t& val, void* valueVector); +template<> +std::string TypeUtils::toString(const bool& val, void* valueVector); +template<> +std::string TypeUtils::toString(const internalID_t& val, void* valueVector); +template<> +std::string TypeUtils::toString(const date_t& val, void* valueVector); +template<> +std::string TypeUtils::toString(const timestamp_ns_t& val, void* valueVector); +template<> +std::string TypeUtils::toString(const timestamp_ms_t& val, void* valueVector); +template<> +std::string TypeUtils::toString(const timestamp_sec_t& val, void* valueVector); +template<> +std::string TypeUtils::toString(const timestamp_tz_t& val, void* valueVector); +template<> +std::string TypeUtils::toString(const timestamp_t& val, void* valueVector); +template<> +std::string TypeUtils::toString(const interval_t& val, void* valueVector); +template<> +std::string TypeUtils::toString(const ku_string_t& val, void* valueVector); +template<> +std::string TypeUtils::toString(const blob_t& val, void* valueVector); +template<> +std::string TypeUtils::toString(const ku_uuid_t& val, void* valueVector); +template<> +std::string TypeUtils::toString(const list_entry_t& val, void* valueVector); +template<> +std::string TypeUtils::toString(const map_entry_t& val, void* valueVector); +template<> +std::string TypeUtils::toString(const struct_entry_t& val, void* valueVector); +template<> +std::string TypeUtils::toString(const union_entry_t& val, void* valueVector); + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/blob.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/blob.h new file mode 100644 index 0000000000..9bbb6cc899 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/blob.h @@ -0,0 +1,52 
@@ +#pragma once + +#include "common/types/ku_string.h" + +namespace lbug { +namespace common { + +struct blob_t { + ku_string_t value; +}; + +struct HexFormatConstants { + // map of integer -> hex value. + static constexpr const char* HEX_TABLE = "0123456789ABCDEF"; + // reverse map of byte -> integer value, or -1 for invalid hex values. + static const int HEX_MAP[256]; + static constexpr const uint64_t NUM_BYTES_TO_SHIFT_FOR_FIRST_BYTE = 4; + static constexpr const uint64_t SECOND_BYTE_MASK = 0x0F; + static constexpr const char PREFIX[] = "\\x"; + static constexpr const uint64_t PREFIX_LENGTH = 2; + static constexpr const uint64_t FIRST_BYTE_POS = PREFIX_LENGTH; + static constexpr const uint64_t SECOND_BYTES_POS = PREFIX_LENGTH + 1; + static constexpr const uint64_t LENGTH = 4; +}; + +struct Blob { + static std::string toString(const uint8_t* value, uint64_t len); + + static inline std::string toString(const blob_t& blob) { + return toString(blob.value.getData(), blob.value.len); + } + + static uint64_t getBlobSize(const ku_string_t& blob); + + static uint64_t fromString(const char* str, uint64_t length, uint8_t* resultBuffer); + + template + static inline T getValue(const blob_t& data) { + return *reinterpret_cast(data.value.getData()); + } + template + // NOLINTNEXTLINE(readability-non-const-parameter): Would cast away qualifiers. + static inline T getValue(char* data) { + return *reinterpret_cast(data); + } + +private: + static void validateHexCode(const uint8_t* blobStr, uint64_t length, uint64_t curPos); +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/cast_helpers.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/cast_helpers.h new file mode 100644 index 0000000000..0f4c5454ec --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/cast_helpers.h @@ -0,0 +1,297 @@ +#pragma once + +#include + +#include "common/types/date_t.h" +#include "interval_t.h" + +namespace lbug { +namespace common { + +// This is copied from third_party/fmt/include/fmt/format.h and format-inl.h. +static const char digits[] = "0001020304050607080910111213141516171819" + "2021222324252627282930313233343536373839" + "4041424344454647484950515253545556575859" + "6061626364656667686970717273747576777879" + "8081828384858687888990919293949596979899"; + +//! NumericHelper is a static class that holds helper functions for integers/doubles +class NumericHelper { +public: + // Formats value in reverse and returns a pointer to the beginning. + template + static char* FormatUnsigned(T value, char* ptr) { + while (value >= 100) { + // Integer division is slow so do it for a group of two digits instead + // of for every digit. The idea comes from the talk by Alexandrescu + // "Three Optimization Tips for C++". 
+ auto index = static_cast((value % 100) * 2); + value /= 100; + *--ptr = digits[index + 1]; + *--ptr = digits[index]; + } + if (value < 10) { + *--ptr = static_cast('0' + value); + return ptr; + } + auto index = static_cast(value * 2); + *--ptr = digits[index + 1]; + *--ptr = digits[index]; + return ptr; + } + static int getUnsignedInt64Length(uint64_t value) { + if (value >= 10000000000ULL) { + if (value >= 1000000000000000ULL) { + int length = 16; + length += value >= 10000000000000000ULL; + length += value >= 100000000000000000ULL; + length += value >= 1000000000000000000ULL; + length += value >= 10000000000000000000ULL; + return length; + } else { + int length = 11; + length += value >= 100000000000ULL; + length += value >= 1000000000000ULL; + length += value >= 10000000000000ULL; + length += value >= 100000000000000ULL; + return length; + } + } else { + if (value >= 100000ULL) { + int length = 6; + length += value >= 1000000ULL; + length += value >= 10000000ULL; + length += value >= 100000000ULL; + length += value >= 1000000000ULL; + return length; + } else { + int length = 1; + length += value >= 10ULL; + length += value >= 100ULL; + length += value >= 1000ULL; + length += value >= 10000ULL; + return length; + } + } + } +}; + +struct DateToStringCast { + static uint64_t Length(int32_t date[], uint64_t& yearLength, bool& addBC) { + // format is YYYY-MM-DD with optional (BC) at the end + // regular length is 10 + uint64_t length = 6; + yearLength = 4; + addBC = false; + if (date[0] <= 0) { + // add (BC) suffix + length += strlen(Date::BC_SUFFIX); + date[0] = -date[0] + 1; + addBC = true; + } + + // potentially add extra characters depending on length of year + yearLength += date[0] >= 10000; + yearLength += date[0] >= 100000; + yearLength += date[0] >= 1000000; + yearLength += date[0] >= 10000000; + length += yearLength; + return length; + } + + static void Format(char* data, int32_t date[], uint64_t yearLen, bool addBC) { + // now we write the string, first write the year + auto endptr = data + yearLen; + endptr = NumericHelper::FormatUnsigned(date[0], endptr); + // add optional leading zeros + while (endptr > data) { + *--endptr = '0'; + } + // now write the month and day + auto ptr = data + yearLen; + for (int i = 1; i <= 2; i++) { + ptr[0] = '-'; + if (date[i] < 10) { + ptr[1] = '0'; + ptr[2] = '0' + date[i]; + } else { + auto index = static_cast(date[i] * 2); + ptr[1] = digits[index]; + ptr[2] = digits[index + 1]; + } + ptr += 3; + } + // optionally add BC to the end of the date + if (addBC) { + memcpy(ptr, Date::BC_SUFFIX, // NOLINT(bugprone-not-null-terminated-result): no need to + // put null terminator + strlen(Date::BC_SUFFIX)); + } + } +}; + +struct TimeToStringCast { + // Format microseconds to a buffer of length 6. 
Returns the number of trailing zeros + static int32_t FormatMicros(uint32_t microseconds, char micro_buffer[]) { + char* endptr = micro_buffer + 6; + endptr = NumericHelper::FormatUnsigned(microseconds, endptr); + while (endptr > micro_buffer) { + *--endptr = '0'; + } + uint64_t trailing_zeros = 0; + for (uint64_t i = 5; i > 0; i--) { + if (micro_buffer[i] != '0') { + break; + } + trailing_zeros++; + } + return trailing_zeros; + } + + static uint64_t Length(int32_t time[], char micro_buffer[]) { + // format is HH:MM:DD.MS + // microseconds come after the time with a period separator + uint64_t length = 0; + if (time[3] == 0) { + // no microseconds + // format is HH:MM:DD + length = 8; + } else { + length = 15; + // for microseconds, we truncate any trailing zeros (i.e. "90000" becomes ".9") + // first write the microseconds to the microsecond buffer + // we write backwards and pad with zeros to the left + // now we figure out how many digits we need to include by looking backwards + // and checking how many zeros we encounter + length -= FormatMicros(time[3], micro_buffer); + } + return length; + } + + static void FormatTwoDigits(char* ptr, int32_t value) { + if (value < 10) { + ptr[0] = '0'; + ptr[1] = '0' + value; + } else { + auto index = static_cast(value * 2); + ptr[0] = digits[index]; + ptr[1] = digits[index + 1]; + } + } + + static void Format(char* data, uint64_t length, int32_t time[], char micro_buffer[]) { + // first write hour, month and day + auto ptr = data; + ptr[2] = ':'; + ptr[5] = ':'; + for (int i = 0; i <= 2; i++) { + FormatTwoDigits(ptr, time[i]); + ptr += 3; + } + if (length > 8) { + // write the micro seconds at the end + data[8] = '.'; + memcpy(data + 9, micro_buffer, length - 9); + } + } +}; + +struct IntervalToStringCast { + static void FormatSignedNumber(int64_t value, char buffer[], uint64_t& length) { + int sign = -(value < 0); + uint64_t unsigned_value = (value ^ sign) - sign; + length += NumericHelper::getUnsignedInt64Length(unsigned_value) - sign; + auto endptr = buffer + length; + endptr = NumericHelper::FormatUnsigned(unsigned_value, endptr); + if (sign) { + *--endptr = '-'; + } + } + + static void FormatTwoDigits(int64_t value, char buffer[], uint64_t& length) { + TimeToStringCast::FormatTwoDigits(buffer + length, value); + length += 2; + } + + static void FormatIntervalValue(int32_t value, char buffer[], uint64_t& length, + const char* name, uint64_t name_len) { + if (value == 0) { + return; + } + if (length != 0) { + // space if there is already something in the buffer + buffer[length++] = ' '; + } + FormatSignedNumber(value, buffer, length); + // append the name together with a potential "s" (for plurals) + memcpy(buffer + length, name, name_len); + length += name_len; + if (value != 1) { + buffer[length++] = 's'; + } + } + + //! Formats an interval to a buffer, the buffer should be >=70 characters + //! years: 17 characters (max value: "-2147483647 years") + //! months: 9 (max value: "12 months") + //! days: 16 characters (max value: "-2147483647 days") + //! time: 24 characters (max value: -2562047788:00:00.123456) + //! spaces between all characters (+3 characters) + //! Total: 70 characters + //! 
Returns the length of the interval + static uint64_t Format(interval_t interval, char buffer[]) { + uint64_t length = 0; + if (interval.months != 0) { + int32_t years = interval.months / 12; + int32_t months = interval.months - years * 12; + // format the years and months + FormatIntervalValue(years, buffer, length, " year", 5); + FormatIntervalValue(months, buffer, length, " month", 6); + } + if (interval.days != 0) { + // format the days + FormatIntervalValue(interval.days, buffer, length, " day", 4); + } + if (interval.micros != 0) { + if (length != 0) { + // space if there is already something in the buffer + buffer[length++] = ' '; + } + int64_t micros = interval.micros; + if (micros < 0) { + // negative time: append negative sign + buffer[length++] = '-'; + micros = -micros; + } + int64_t hour = micros / Interval::MICROS_PER_HOUR; + micros -= hour * Interval::MICROS_PER_HOUR; + int64_t min = micros / Interval::MICROS_PER_MINUTE; + micros -= min * Interval::MICROS_PER_MINUTE; + int64_t sec = micros / Interval::MICROS_PER_SEC; + micros -= sec * Interval::MICROS_PER_SEC; + + if (hour < 10) { + buffer[length++] = '0'; + } + FormatSignedNumber(hour, buffer, length); + buffer[length++] = ':'; + FormatTwoDigits(min, buffer, length); + buffer[length++] = ':'; + FormatTwoDigits(sec, buffer, length); + if (micros != 0) { + buffer[length++] = '.'; + auto trailing_zeros = TimeToStringCast::FormatMicros(micros, buffer + length); + length += 6 - trailing_zeros; + } + } else if (length == 0) { + // empty interval: default to 00:00:00 + strcpy(buffer, "00:00:00"); // NOLINT(clang-analyzer-security.insecureAPI.strcpy): + // safety guaranteed by Length(). + return 8; + } + return length; + } +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/date_t.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/date_t.h new file mode 100644 index 0000000000..bbdd1d049e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/date_t.h @@ -0,0 +1,121 @@ +#pragma once + +#include "interval_t.h" + +namespace lbug { + +namespace regex { +class RE2; +} + +namespace common { + +struct timestamp_t; + +// System representation of dates as the number of days since 1970-01-01. +struct LBUG_API date_t { + int32_t days; + + date_t(); + explicit date_t(int32_t days_p); + + // Comparison operators with date_t. + bool operator==(const date_t& rhs) const; + bool operator!=(const date_t& rhs) const; + bool operator<=(const date_t& rhs) const; + bool operator<(const date_t& rhs) const; + bool operator>(const date_t& rhs) const; + bool operator>=(const date_t& rhs) const; + + // Comparison operators with timestamp_t. 
+ bool operator==(const timestamp_t& rhs) const; + bool operator!=(const timestamp_t& rhs) const; + bool operator<(const timestamp_t& rhs) const; + bool operator<=(const timestamp_t& rhs) const; + bool operator>(const timestamp_t& rhs) const; + bool operator>=(const timestamp_t& rhs) const; + + // arithmetic operators + date_t operator+(const int32_t& day) const; + date_t operator-(const int32_t& day) const; + + date_t operator+(const interval_t& interval) const; + date_t operator-(const interval_t& interval) const; + + int64_t operator-(const date_t& rhs) const; +}; + +inline date_t operator+(int64_t i, const date_t date) { + return date + i; +} + +// Note: Aside from some minor changes, this implementation is copied from DuckDB's source code: +// https://github.com/duckdb/duckdb/blob/master/src/include/duckdb/common/types/date.hpp. +// https://github.com/duckdb/duckdb/blob/master/src/common/types/date.cpp. +// For example, instead of using their idx_t type to refer to indices, we directly use uint64_t, +// which is the actual type of idx_t (so we say uint64_t len instead of idx_t len). When more +// functionality is needed, we should first consult these DuckDB links. +class Date { +public: + LBUG_API static const int32_t NORMAL_DAYS[13]; + LBUG_API static const int32_t CUMULATIVE_DAYS[13]; + LBUG_API static const int32_t LEAP_DAYS[13]; + LBUG_API static const int32_t CUMULATIVE_LEAP_DAYS[13]; + LBUG_API static const int32_t CUMULATIVE_YEAR_DAYS[401]; + LBUG_API static const int8_t MONTH_PER_DAY_OF_YEAR[365]; + LBUG_API static const int8_t LEAP_MONTH_PER_DAY_OF_YEAR[366]; + + LBUG_API constexpr static const int32_t MIN_YEAR = -290307; + LBUG_API constexpr static const int32_t MAX_YEAR = 294247; + LBUG_API constexpr static const int32_t EPOCH_YEAR = 1970; + + LBUG_API constexpr static const int32_t YEAR_INTERVAL = 400; + LBUG_API constexpr static const int32_t DAYS_PER_YEAR_INTERVAL = 146097; + constexpr static const char* BC_SUFFIX = " (BC)"; + + // Convert a string in the format "YYYY-MM-DD" to a date object + LBUG_API static date_t fromCString(const char* str, uint64_t len); + // Convert a date object to a string in the format "YYYY-MM-DD" + LBUG_API static std::string toString(date_t date); + // Try to convert text in a buffer to a date; returns true if parsing was successful + LBUG_API static bool tryConvertDate(const char* buf, uint64_t len, uint64_t& pos, + date_t& result, bool allowTrailing = false); + + // private: + // Returns true if (year) is a leap year, and false otherwise + LBUG_API static bool isLeapYear(int32_t year); + // Returns true if the specified (year, month, day) combination is a valid + // date + LBUG_API static bool isValid(int32_t year, int32_t month, int32_t day); + // Extract the year, month and day from a given date object + LBUG_API static void convert(date_t date, int32_t& out_year, int32_t& out_month, + int32_t& out_day); + // Create a Date object from a specified (year, month, day) combination + LBUG_API static date_t fromDate(int32_t year, int32_t month, int32_t day); + + // Helper function to parse two digits from a string (e.g. 
"30" -> 30, "03" -> 3, "3" -> 3) + LBUG_API static bool parseDoubleDigit(const char* buf, uint64_t len, uint64_t& pos, + int32_t& result); + + LBUG_API static int32_t monthDays(int32_t year, int32_t month); + + LBUG_API static std::string getDayName(date_t date); + + LBUG_API static std::string getMonthName(date_t date); + + LBUG_API static date_t getLastDay(date_t date); + + LBUG_API static int32_t getDatePart(DatePartSpecifier specifier, date_t date); + + LBUG_API static date_t trunc(DatePartSpecifier specifier, date_t date); + + LBUG_API static int64_t getEpochNanoSeconds(const date_t& date); + + LBUG_API static const regex::RE2& regexPattern(); + +private: + static void extractYearOffset(int32_t& n, int32_t& year, int32_t& year_offset); +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/dtime_t.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/dtime_t.h new file mode 100644 index 0000000000..524f72f412 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/dtime_t.h @@ -0,0 +1,67 @@ +#pragma once + +#include +#include + +#include "common/api.h" + +namespace lbug { +namespace common { + +// Type used to represent time (microseconds) +struct LBUG_API dtime_t { + int64_t micros; + + dtime_t(); + explicit dtime_t(int64_t micros_p); + dtime_t& operator=(int64_t micros_p); + + // explicit conversion + explicit operator int64_t() const; + explicit operator double() const; + + // comparison operators + bool operator==(const dtime_t& rhs) const; + bool operator!=(const dtime_t& rhs) const; + bool operator<=(const dtime_t& rhs) const; + bool operator<(const dtime_t& rhs) const; + bool operator>(const dtime_t& rhs) const; + bool operator>=(const dtime_t& rhs) const; +}; + +// Note: Aside from some minor changes, this implementation is copied from DuckDB's source code: +// https://github.com/duckdb/duckdb/blob/master/src/include/duckdb/common/types/time.hpp. +// https://github.com/duckdb/duckdb/blob/master/src/common/types/time.cpp. +// For example, instead of using their idx_t type to refer to indices, we directly use uint64_t, +// which is the actual type of idx_t (so we say uint64_t len instead of idx_t len). When more +// functionality is needed, we should first consult these DuckDB links. 
+class Time { +public: + // Convert a string in the format "hh:mm:ss" to a time object + LBUG_API static dtime_t fromCString(const char* buf, uint64_t len); + LBUG_API static bool tryConvertInterval(const char* buf, uint64_t len, uint64_t& pos, + dtime_t& result); + LBUG_API static bool tryConvertTime(const char* buf, uint64_t len, uint64_t& pos, + dtime_t& result); + + // Convert a time object to a string in the format "hh:mm:ss" + LBUG_API static std::string toString(dtime_t time); + + LBUG_API static dtime_t fromTime(int32_t hour, int32_t minute, int32_t second, + int32_t microseconds = 0); + + // Extract the time from a given timestamp object + LBUG_API static void convert(dtime_t time, int32_t& out_hour, int32_t& out_min, + int32_t& out_sec, int32_t& out_micros); + + LBUG_API static bool isValid(int32_t hour, int32_t minute, int32_t second, + int32_t milliseconds); + +private: + static bool tryConvertInternal(const char* buf, uint64_t len, uint64_t& pos, dtime_t& result); + static dtime_t fromTimeInternal(int32_t hour, int32_t minute, int32_t second, + int32_t microseconds = 0); +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/int128_t.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/int128_t.h new file mode 100644 index 0000000000..d64a8d4f4d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/int128_t.h @@ -0,0 +1,230 @@ +// ========================================================================================= +// This int128 implementtaion got + +// ========================================================================================= +#pragma once + +#include +#include +#include + +#include "common/api.h" +#include "common/exception/overflow.h" + +namespace lbug { +namespace common { + +struct LBUG_API int128_t; +struct uint128_t; + +// System representation for int128_t. 
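// A minimal usage sketch (function name is illustrative): int128_t converts implicitly
// from the native integer types, and the operators plus the Int128_t helpers declared
// below provide arithmetic and string conversion beyond the 64-bit range.
#include <cstdint>
#include <limits>
#include <string>

inline std::string pastInt64Max() {
    using namespace lbug::common;
    int128_t big = std::numeric_limits<int64_t>::max();  // implicit widening from int64_t
    int128_t bigger = big + 1;                            // no wrap-around at 2^63
    return Int128_t::toString(bigger);                    // expected "9223372036854775808"
}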
+struct LBUG_API int128_t { + uint64_t low; + int64_t high; + + int128_t() noexcept = default; + int128_t(int64_t value); // NOLINT: Allow implicit conversion from numeric values + int128_t(int32_t value); // NOLINT: Allow implicit conversion from numeric values + int128_t(int16_t value); // NOLINT: Allow implicit conversion from numeric values + int128_t(int8_t value); // NOLINT: Allow implicit conversion from numeric values + int128_t(uint64_t value); // NOLINT: Allow implicit conversion from numeric values + int128_t(uint32_t value); // NOLINT: Allow implicit conversion from numeric values + int128_t(uint16_t value); // NOLINT: Allow implicit conversion from numeric values + int128_t(uint8_t value); // NOLINT: Allow implicit conversion from numeric values + int128_t(double value); // NOLINT: Allow implicit conversion from numeric values + int128_t(float value); // NOLINT: Allow implicit conversion from numeric values + + constexpr int128_t(uint64_t low, int64_t high) noexcept : low(low), high(high) {} + + constexpr int128_t(const int128_t&) noexcept = default; + constexpr int128_t(int128_t&&) noexcept = default; + int128_t& operator=(const int128_t&) noexcept = default; + int128_t& operator=(int128_t&&) noexcept = default; + + int128_t operator-() const; + + // inplace arithmetic operators + int128_t& operator+=(const int128_t& rhs); + int128_t& operator*=(const int128_t& rhs); + int128_t& operator|=(const int128_t& rhs); + int128_t& operator&=(const int128_t& rhs); + + // cast operators + explicit operator int64_t() const; + explicit operator int32_t() const; + explicit operator int16_t() const; + explicit operator int8_t() const; + explicit operator uint64_t() const; + explicit operator uint32_t() const; + explicit operator uint16_t() const; + explicit operator uint8_t() const; + explicit operator double() const; + explicit operator float() const; + + explicit operator uint128_t() const; +}; + +// arithmetic operators +LBUG_API int128_t operator+(const int128_t& lhs, const int128_t& rhs); +LBUG_API int128_t operator-(const int128_t& lhs, const int128_t& rhs); +LBUG_API int128_t operator*(const int128_t& lhs, const int128_t& rhs); +LBUG_API int128_t operator/(const int128_t& lhs, const int128_t& rhs); +LBUG_API int128_t operator%(const int128_t& lhs, const int128_t& rhs); +LBUG_API int128_t operator^(const int128_t& lhs, const int128_t& rhs); +LBUG_API int128_t operator&(const int128_t& lhs, const int128_t& rhs); +LBUG_API int128_t operator~(const int128_t& val); +LBUG_API int128_t operator|(const int128_t& lhs, const int128_t& rhs); +LBUG_API int128_t operator<<(const int128_t& lhs, int amount); +LBUG_API int128_t operator>>(const int128_t& lhs, int amount); + +// comparison operators +LBUG_API bool operator==(const int128_t& lhs, const int128_t& rhs); +LBUG_API bool operator!=(const int128_t& lhs, const int128_t& rhs); +LBUG_API bool operator>(const int128_t& lhs, const int128_t& rhs); +LBUG_API bool operator>=(const int128_t& lhs, const int128_t& rhs); +LBUG_API bool operator<(const int128_t& lhs, const int128_t& rhs); +LBUG_API bool operator<=(const int128_t& lhs, const int128_t& rhs); + +class Int128_t { +public: + static std::string toString(int128_t input); + + template + static bool tryCast(int128_t input, T& result); + + template + static T cast(int128_t input) { + T result; + tryCast(input, result); + return result; + } + + template + static bool tryCastTo(T value, int128_t& result); + + template + static int128_t castTo(T value) { + int128_t result{}; + if (!tryCastTo(value, 
result)) { + throw common::OverflowException("INT128 is out of range"); + } + return result; + } + + // negate + static void negateInPlace(int128_t& input) { + if (input.high == INT64_MIN && input.low == 0) { + throw common::OverflowException("INT128 is out of range: cannot negate INT128_MIN"); + } + input.low = UINT64_MAX + 1 - input.low; + input.high = -input.high - 1 + (input.low == 0); + } + + static int128_t negate(int128_t input) { + negateInPlace(input); + return input; + } + + static bool tryMultiply(int128_t lhs, int128_t rhs, int128_t& result); + + static int128_t Add(int128_t lhs, int128_t rhs); + static int128_t Sub(int128_t lhs, int128_t rhs); + static int128_t Mul(int128_t lhs, int128_t rhs); + static int128_t Div(int128_t lhs, int128_t rhs); + static int128_t Mod(int128_t lhs, int128_t rhs); + static int128_t Xor(int128_t lhs, int128_t rhs); + static int128_t LeftShift(int128_t lhs, int amount); + static int128_t RightShift(int128_t lhs, int amount); + static int128_t BinaryAnd(int128_t lhs, int128_t rhs); + static int128_t BinaryOr(int128_t lhs, int128_t rhs); + static int128_t BinaryNot(int128_t val); + + static int128_t divMod(int128_t lhs, int128_t rhs, int128_t& remainder); + static int128_t divModPositive(int128_t lhs, uint64_t rhs, uint64_t& remainder); + + static bool addInPlace(int128_t& lhs, int128_t rhs); + static bool subInPlace(int128_t& lhs, int128_t rhs); + + // comparison operators + static bool equals(int128_t lhs, int128_t rhs) { + return lhs.low == rhs.low && lhs.high == rhs.high; + } + + static bool notEquals(int128_t lhs, int128_t rhs) { + return lhs.low != rhs.low || lhs.high != rhs.high; + } + + static bool greaterThan(int128_t lhs, int128_t rhs) { + return (lhs.high > rhs.high) || (lhs.high == rhs.high && lhs.low > rhs.low); + } + + static bool greaterThanOrEquals(int128_t lhs, int128_t rhs) { + return (lhs.high > rhs.high) || (lhs.high == rhs.high && lhs.low >= rhs.low); + } + + static bool lessThan(int128_t lhs, int128_t rhs) { + return (lhs.high < rhs.high) || (lhs.high == rhs.high && lhs.low < rhs.low); + } + + static bool lessThanOrEquals(int128_t lhs, int128_t rhs) { + return (lhs.high < rhs.high) || (lhs.high == rhs.high && lhs.low <= rhs.low); + } +}; + +template<> +bool Int128_t::tryCast(int128_t input, int8_t& result); +template<> +bool Int128_t::tryCast(int128_t input, int16_t& result); +template<> +bool Int128_t::tryCast(int128_t input, int32_t& result); +template<> +bool Int128_t::tryCast(int128_t input, int64_t& result); +template<> +bool Int128_t::tryCast(int128_t input, uint8_t& result); +template<> +bool Int128_t::tryCast(int128_t input, uint16_t& result); +template<> +bool Int128_t::tryCast(int128_t input, uint32_t& result); +template<> +bool Int128_t::tryCast(int128_t input, uint64_t& result); +template<> +bool Int128_t::tryCast(int128_t input, uint128_t& result); // signed to unsigned +template<> +bool Int128_t::tryCast(int128_t input, float& result); +template<> +bool Int128_t::tryCast(int128_t input, double& result); +template<> +bool Int128_t::tryCast(int128_t input, long double& result); + +template<> +bool Int128_t::tryCastTo(int8_t value, int128_t& result); +template<> +bool Int128_t::tryCastTo(int16_t value, int128_t& result); +template<> +bool Int128_t::tryCastTo(int32_t value, int128_t& result); +template<> +bool Int128_t::tryCastTo(int64_t value, int128_t& result); +template<> +bool Int128_t::tryCastTo(uint8_t value, int128_t& result); +template<> +bool Int128_t::tryCastTo(uint16_t value, int128_t& result); +template<> 
+bool Int128_t::tryCastTo(uint32_t value, int128_t& result); +template<> +bool Int128_t::tryCastTo(uint64_t value, int128_t& result); +template<> +bool Int128_t::tryCastTo(int128_t value, int128_t& result); +template<> +bool Int128_t::tryCastTo(float value, int128_t& result); +template<> +bool Int128_t::tryCastTo(double value, int128_t& result); +template<> +bool Int128_t::tryCastTo(long double value, int128_t& result); + +} // namespace common +} // namespace lbug + +template<> +struct std::hash { + std::size_t operator()(const lbug::common::int128_t& v) const noexcept; +}; diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/internal_id_util.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/internal_id_util.h new file mode 100644 index 0000000000..bbb197afa3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/internal_id_util.h @@ -0,0 +1,18 @@ +#pragma once + +#include "common/types/types.h" +#include "function/hash/hash_functions.h" + +namespace lbug { +namespace common { + +using internal_id_set_t = std::unordered_set; +using node_id_set_t = internal_id_set_t; +using rel_id_set_t = internal_id_set_t; +template +using internal_id_map_t = std::unordered_map; +template +using node_id_map_t = internal_id_map_t; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/interval_t.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/interval_t.h new file mode 100644 index 0000000000..3ee8c7de77 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/interval_t.h @@ -0,0 +1,116 @@ +#pragma once + +#include +#include + +#include "common/api.h" + +namespace lbug { + +namespace regex { +class RE2; +} + +namespace common { + +struct timestamp_t; +struct date_t; + +enum class DatePartSpecifier : uint8_t { + YEAR, + MONTH, + DAY, + DECADE, + CENTURY, + MILLENNIUM, + QUARTER, + MICROSECOND, + MILLISECOND, + SECOND, + MINUTE, + HOUR, + WEEK, +}; + +struct LBUG_API interval_t { + int32_t months = 0; + int32_t days = 0; + int64_t micros = 0; + + interval_t(); + interval_t(int32_t months_p, int32_t days_p, int64_t micros_p); + + // comparator operators + bool operator==(const interval_t& rhs) const; + bool operator!=(const interval_t& rhs) const; + + bool operator>(const interval_t& rhs) const; + bool operator<=(const interval_t& rhs) const; + bool operator<(const interval_t& rhs) const; + bool operator>=(const interval_t& rhs) const; + + // arithmetic operators + interval_t operator+(const interval_t& rhs) const; + timestamp_t operator+(const timestamp_t& rhs) const; + date_t operator+(const date_t& rhs) const; + interval_t operator-(const interval_t& rhs) const; + + interval_t operator/(const uint64_t& rhs) const; +}; + +// Note: Aside from some minor changes, this implementation is copied from DuckDB's source code: +// https://github.com/duckdb/duckdb/blob/master/src/include/duckdb/common/types/interval.hpp. +// https://github.com/duckdb/duckdb/blob/master/src/common/types/interval.cpp. +// When more functionality is needed, we should first consult these DuckDB links. +// The Interval class is a static class that holds helper functions for the Interval type. 
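// A minimal usage sketch (function name is illustrative): interval_t keeps months, days
// and microseconds as separate fields; addition is expected to combine them field-wise,
// while ordering follows the normalization used by the Interval helpers below, where a
// month is ordered as 30 days.
inline bool fortyDaysOutlastsAMonth() {
    using namespace lbug::common;
    interval_t oneMonth{1, 0, 0};           // months, days, micros
    interval_t fortyDays{0, 40, 0};
    interval_t sum = oneMonth + fortyDays;  // expected {1, 40, 0}
    (void)sum;
    return fortyDays > oneMonth;            // expected true: 40 days vs. a 30-day month
}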
+class Interval { +public: + static constexpr const int32_t MONTHS_PER_MILLENIUM = 12000; + static constexpr const int32_t MONTHS_PER_CENTURY = 1200; + static constexpr const int32_t MONTHS_PER_DECADE = 120; + static constexpr const int32_t MONTHS_PER_YEAR = 12; + static constexpr const int32_t MONTHS_PER_QUARTER = 3; + static constexpr const int32_t DAYS_PER_WEEK = 7; + //! only used for interval comparison/ordering purposes, in which case a month counts as 30 days + static constexpr const int64_t DAYS_PER_MONTH = 30; + static constexpr const int64_t DAYS_PER_YEAR = 365; + static constexpr const int64_t MSECS_PER_SEC = 1000; + static constexpr const int32_t SECS_PER_MINUTE = 60; + static constexpr const int32_t MINS_PER_HOUR = 60; + static constexpr const int32_t HOURS_PER_DAY = 24; + static constexpr const int32_t SECS_PER_HOUR = SECS_PER_MINUTE * MINS_PER_HOUR; + static constexpr const int32_t SECS_PER_DAY = SECS_PER_HOUR * HOURS_PER_DAY; + static constexpr const int32_t SECS_PER_WEEK = SECS_PER_DAY * DAYS_PER_WEEK; + + static constexpr const int64_t MICROS_PER_MSEC = 1000; + static constexpr const int64_t MICROS_PER_SEC = MICROS_PER_MSEC * MSECS_PER_SEC; + static constexpr const int64_t MICROS_PER_MINUTE = MICROS_PER_SEC * SECS_PER_MINUTE; + static constexpr const int64_t MICROS_PER_HOUR = MICROS_PER_MINUTE * MINS_PER_HOUR; + static constexpr const int64_t MICROS_PER_DAY = MICROS_PER_HOUR * HOURS_PER_DAY; + static constexpr const int64_t MICROS_PER_WEEK = MICROS_PER_DAY * DAYS_PER_WEEK; + static constexpr const int64_t MICROS_PER_MONTH = MICROS_PER_DAY * DAYS_PER_MONTH; + + static constexpr const int64_t NANOS_PER_MICRO = 1000; + static constexpr const int64_t NANOS_PER_MSEC = NANOS_PER_MICRO * MICROS_PER_MSEC; + static constexpr const int64_t NANOS_PER_SEC = NANOS_PER_MSEC * MSECS_PER_SEC; + static constexpr const int64_t NANOS_PER_MINUTE = NANOS_PER_SEC * SECS_PER_MINUTE; + static constexpr const int64_t NANOS_PER_HOUR = NANOS_PER_MINUTE * MINS_PER_HOUR; + static constexpr const int64_t NANOS_PER_DAY = NANOS_PER_HOUR * HOURS_PER_DAY; + static constexpr const int64_t NANOS_PER_WEEK = NANOS_PER_DAY * DAYS_PER_WEEK; + + LBUG_API static void addition(interval_t& result, uint64_t number, std::string specifierStr); + LBUG_API static interval_t fromCString(const char* str, uint64_t len); + LBUG_API static std::string toString(interval_t interval); + LBUG_API static bool greaterThan(const interval_t& left, const interval_t& right); + LBUG_API static void normalizeIntervalEntries(interval_t input, int64_t& months, int64_t& days, + int64_t& micros); + LBUG_API static void tryGetDatePartSpecifier(std::string specifier, DatePartSpecifier& result); + LBUG_API static int32_t getIntervalPart(DatePartSpecifier specifier, interval_t timestamp); + LBUG_API static int64_t getMicro(const interval_t& val); + LBUG_API static int64_t getNanoseconds(const interval_t& val); + LBUG_API static const regex::RE2& regexPattern1(); + LBUG_API static const regex::RE2& regexPattern2(); +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/ku_list.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/ku_list.h new file mode 100644 index 0000000000..249d7eb4b3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/ku_list.h @@ -0,0 +1,23 @@ +#pragma once + +#include "types.h" + +namespace lbug { +namespace common { + +struct ku_list_t { + ku_list_t() : size{0}, overflowPtr{0} {} + ku_list_t(uint64_t size, uint64_t 
overflowPtr) : size{size}, overflowPtr{overflowPtr} {} + + void set(const uint8_t* values, const LogicalType& dataType) const; + +private: + void set(const std::vector& parameters, LogicalTypeID childTypeId); + +public: + uint64_t size; + uint64_t overflowPtr; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/ku_string.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/ku_string.h new file mode 100644 index 0000000000..6d067d7cfa --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/ku_string.h @@ -0,0 +1,91 @@ +#pragma once + +#include +#include +#include + +#include "common/api.h" + +namespace lbug { +namespace common { + +struct LBUG_API ku_string_t { + + static constexpr uint64_t PREFIX_LENGTH = 4; + static constexpr uint64_t INLINED_SUFFIX_LENGTH = 8; + static constexpr uint64_t SHORT_STR_LENGTH = PREFIX_LENGTH + INLINED_SUFFIX_LENGTH; + + uint32_t len; + uint8_t prefix[PREFIX_LENGTH]; + union { + uint8_t data[INLINED_SUFFIX_LENGTH]; + uint64_t overflowPtr; + }; + + ku_string_t() : len{0}, prefix{}, overflowPtr{0} {} + ku_string_t(const char* value, uint64_t length); + + static bool isShortString(uint32_t len) { return len <= SHORT_STR_LENGTH; } + + const uint8_t* getData() const { + return isShortString(len) ? prefix : reinterpret_cast(overflowPtr); + } + + uint8_t* getDataUnsafe() { + return isShortString(len) ? prefix : reinterpret_cast(overflowPtr); + } + + // These functions do *NOT* allocate/resize the overflow buffer, it only copies the content and + // set the length. + void set(const std::string& value); + void set(const char* value, uint64_t length); + void set(const ku_string_t& value); + void setShortString(const char* value, uint64_t length) { + this->len = length; + memcpy(prefix, value, length); + } + void setLongString(const char* value, uint64_t length) { + this->len = length; + memcpy(prefix, value, PREFIX_LENGTH); + memcpy(reinterpret_cast(overflowPtr), value, length); + } + void setShortString(const ku_string_t& value) { + this->len = value.len; + memcpy(prefix, value.prefix, value.len); + } + void setLongString(const ku_string_t& value) { + this->len = value.len; + memcpy(prefix, value.prefix, PREFIX_LENGTH); + memcpy(reinterpret_cast(overflowPtr), reinterpret_cast(value.overflowPtr), + value.len); + } + + void setFromRawStr(const char* value, uint64_t length) { + this->len = length; + if (isShortString(length)) { + setShortString(value, length); + } else { + memcpy(prefix, value, PREFIX_LENGTH); + overflowPtr = reinterpret_cast(value); + } + } + + std::string getAsShortString() const; + std::string getAsString() const; + std::string_view getAsStringView() const; + + bool operator==(const ku_string_t& rhs) const; + + inline bool operator!=(const ku_string_t& rhs) const { return !(*this == rhs); } + + bool operator>(const ku_string_t& rhs) const; + + inline bool operator>=(const ku_string_t& rhs) const { return (*this > rhs) || (*this == rhs); } + + inline bool operator<(const ku_string_t& rhs) const { return !(*this >= rhs); } + + inline bool operator<=(const ku_string_t& rhs) const { return !(*this > rhs); } +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/timestamp_t.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/timestamp_t.h new file mode 100644 index 0000000000..10af1e191c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/timestamp_t.h @@ -0,0 
+1,125 @@ +#pragma once + +#include "date_t.h" +#include "dtime_t.h" + +namespace lbug { +namespace common { + +// Type used to represent timestamps (value is in microseconds since 1970-01-01) +struct LBUG_API timestamp_t { + int64_t value = 0; + + timestamp_t(); + explicit timestamp_t(int64_t value_p); + timestamp_t& operator=(int64_t value_p); + + // explicit conversion + explicit operator int64_t() const; + + // Comparison operators with timestamp_t. + bool operator==(const timestamp_t& rhs) const; + bool operator!=(const timestamp_t& rhs) const; + bool operator<=(const timestamp_t& rhs) const; + bool operator<(const timestamp_t& rhs) const; + bool operator>(const timestamp_t& rhs) const; + bool operator>=(const timestamp_t& rhs) const; + + // Comparison operators with date_t. + bool operator==(const date_t& rhs) const; + bool operator!=(const date_t& rhs) const; + bool operator<(const date_t& rhs) const; + bool operator<=(const date_t& rhs) const; + bool operator>(const date_t& rhs) const; + bool operator>=(const date_t& rhs) const; + + // arithmetic operator + timestamp_t operator+(const interval_t& interval) const; + timestamp_t operator-(const interval_t& interval) const; + + interval_t operator-(const timestamp_t& rhs) const; +}; + +struct timestamp_tz_t : public timestamp_t { // NO LINT + using timestamp_t::timestamp_t; +}; +struct timestamp_ns_t : public timestamp_t { // NO LINT + using timestamp_t::timestamp_t; +}; +struct timestamp_ms_t : public timestamp_t { // NO LINT + using timestamp_t::timestamp_t; +}; +struct timestamp_sec_t : public timestamp_t { // NO LINT + using timestamp_t::timestamp_t; +}; + +// Note: Aside from some minor changes, this implementation is copied from DuckDB's source code: +// https://github.com/duckdb/duckdb/blob/master/src/include/duckdb/common/types/timestamp.hpp. +// https://github.com/duckdb/duckdb/blob/master/src/common/types/timestamp.cpp. +// For example, instead of using their idx_t type to refer to indices, we directly use uint64_t, +// which is the actual type of idx_t (so we say uint64_t len instead of idx_t len). When more +// functionality is needed, we should first consult these DuckDB links. + +// The Timestamp class is a static class that holds helper functions for the Timestamp type. +// timestamp/datetime uses 64 bits, high 32 bits for date and low 32 bits for time +class Timestamp { +public: + LBUG_API static timestamp_t fromCString(const char* str, uint64_t len); + + // Convert a timestamp object to a std::string in the format "YYYY-MM-DD hh:mm:ss". + LBUG_API static std::string toString(timestamp_t timestamp); + + // Date header is in the format: %Y%m%d. + LBUG_API static std::string getDateHeader(const timestamp_t& timestamp); + + // Timestamp header is in the format: %Y%m%dT%H%M%SZ. + LBUG_API static std::string getDateTimeHeader(const timestamp_t& timestamp); + + LBUG_API static date_t getDate(timestamp_t timestamp); + + LBUG_API static dtime_t getTime(timestamp_t timestamp); + + // Create a Timestamp object from a specified (date, time) combination. + LBUG_API static timestamp_t fromDateTime(date_t date, dtime_t time); + + LBUG_API static bool tryConvertTimestamp(const char* str, uint64_t len, timestamp_t& result); + + // Extract the date and time from a given timestamp object. + LBUG_API static void convert(timestamp_t timestamp, date_t& out_date, dtime_t& out_time); + + // Create a Timestamp object from the specified epochMs. 
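    // Illustrative note, not library code: since timestamp_t.value stores
    // microseconds since 1970-01-01, the epoch factory functions below differ
    // only in the scale factor applied before storing, conceptually:
    //   fromEpochSeconds(s)      ~ timestamp_t{s * 1000000}
    //   fromEpochMilliSeconds(m) ~ timestamp_t{m * 1000}
    //   fromEpochMicroSeconds(u) ~ timestamp_t{u}
    //   fromEpochNanoSeconds(n)  ~ timestamp_t{n / 1000}
    // (exact overflow and rounding behaviour is up to the implementation).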
+ LBUG_API static timestamp_t fromEpochMicroSeconds(int64_t epochMs); + + // Create a Timestamp object from the specified epochMs. + LBUG_API static timestamp_t fromEpochMilliSeconds(int64_t ms); + + // Create a Timestamp object from the specified epochSec. + LBUG_API static timestamp_t fromEpochSeconds(int64_t sec); + + // Create a Timestamp object from the specified epochNs. + LBUG_API static timestamp_t fromEpochNanoSeconds(int64_t ns); + + LBUG_API static int32_t getTimestampPart(DatePartSpecifier specifier, timestamp_t timestamp); + + LBUG_API static timestamp_t trunc(DatePartSpecifier specifier, timestamp_t date); + + LBUG_API static int64_t getEpochNanoSeconds(const timestamp_t& timestamp); + + LBUG_API static int64_t getEpochMilliSeconds(const timestamp_t& timestamp); + + LBUG_API static int64_t getEpochSeconds(const timestamp_t& timestamp); + + LBUG_API static bool tryParseUTCOffset(const char* str, uint64_t& pos, uint64_t len, + int& hour_offset, int& minute_offset); + + static std::string getTimestampConversionExceptionMsg(const char* str, uint64_t len, + const std::string& typeID = "TIMESTAMP") { + return "Error occurred during parsing " + typeID + ". Given: \"" + std::string(str, len) + + "\". Expected format: (YYYY-MM-DD hh:mm:ss[.zzzzzz][+-TT[:tt]])"; + } + + LBUG_API static timestamp_t getCurrentTimestamp(); +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/types.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/types.h new file mode 100644 index 0000000000..03e9f47cf7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/types.h @@ -0,0 +1,677 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "common/api.h" +#include "common/cast.h" +#include "common/copy_constructors.h" +#include "common/types/interval_t.h" + +namespace lbug { +namespace main { +class ClientContext; +} +namespace processor { +class ParquetReader; +} +namespace catalog { +class NodeTableCatalogEntry; +} +namespace common { + +class Serializer; +class Deserializer; +struct FileInfo; + +using sel_t = uint64_t; +constexpr sel_t INVALID_SEL = UINT64_MAX; +using hash_t = uint64_t; +using page_idx_t = uint32_t; +using frame_idx_t = page_idx_t; +using page_offset_t = uint32_t; +constexpr page_idx_t INVALID_PAGE_IDX = UINT32_MAX; +using file_idx_t = uint32_t; +constexpr file_idx_t INVALID_FILE_IDX = UINT32_MAX; +using page_group_idx_t = uint32_t; +using frame_group_idx_t = page_group_idx_t; +using column_id_t = uint32_t; +using property_id_t = uint32_t; +constexpr column_id_t INVALID_COLUMN_ID = UINT32_MAX; +constexpr column_id_t ROW_IDX_COLUMN_ID = INVALID_COLUMN_ID - 1; +using idx_t = uint32_t; +constexpr idx_t INVALID_IDX = UINT32_MAX; +using block_idx_t = uint64_t; +constexpr block_idx_t INVALID_BLOCK_IDX = UINT64_MAX; +using struct_field_idx_t = uint16_t; +using union_field_idx_t = struct_field_idx_t; +constexpr struct_field_idx_t INVALID_STRUCT_FIELD_IDX = UINT16_MAX; +using row_idx_t = uint64_t; +constexpr row_idx_t INVALID_ROW_IDX = UINT64_MAX; +constexpr uint32_t UNDEFINED_CAST_COST = UINT32_MAX; +using node_group_idx_t = uint64_t; +constexpr node_group_idx_t INVALID_NODE_GROUP_IDX = UINT64_MAX; +using partition_idx_t = uint64_t; +constexpr partition_idx_t INVALID_PARTITION_IDX = UINT64_MAX; +using length_t = uint64_t; +constexpr length_t INVALID_LENGTH = UINT64_MAX; +using list_size_t = uint32_t; +using sequence_id_t = uint64_t; +using oid_t = uint64_t; +constexpr oid_t 
INVALID_OID = UINT64_MAX; + +using transaction_t = uint64_t; +constexpr transaction_t INVALID_TRANSACTION = UINT64_MAX; +using executor_id_t = uint64_t; +using executor_info = std::unordered_map; + +// table id type alias +using table_id_t = oid_t; +using table_id_vector_t = std::vector; +using table_id_set_t = std::unordered_set; +template +using table_id_map_t = std::unordered_map; +constexpr table_id_t INVALID_TABLE_ID = INVALID_OID; +// offset type alias +using offset_t = uint64_t; +constexpr offset_t INVALID_OFFSET = UINT64_MAX; +// internal id type alias +struct internalID_t; +using nodeID_t = internalID_t; +using relID_t = internalID_t; + +using cardinality_t = uint64_t; +constexpr offset_t INVALID_LIMIT = UINT64_MAX; +using offset_vec_t = std::vector; +// System representation for internalID. +struct LBUG_API internalID_t { + offset_t offset; + table_id_t tableID; + + internalID_t(); + internalID_t(offset_t offset, table_id_t tableID); + + // comparison operators + bool operator==(const internalID_t& rhs) const; + bool operator!=(const internalID_t& rhs) const; + bool operator>(const internalID_t& rhs) const; + bool operator>=(const internalID_t& rhs) const; + bool operator<(const internalID_t& rhs) const; + bool operator<=(const internalID_t& rhs) const; +}; + +// System representation for a variable-sized overflow value. +struct overflow_value_t { + // the size of the overflow buffer can be calculated as: + // numElements * sizeof(Element) + nullMap(4 bytes alignment) + uint64_t numElements = 0; + uint8_t* value = nullptr; +}; + +struct list_entry_t { + offset_t offset; + list_size_t size; + + constexpr list_entry_t() : offset{INVALID_OFFSET}, size{UINT32_MAX} {} + constexpr list_entry_t(offset_t offset, list_size_t size) : offset{offset}, size{size} {} +}; + +struct struct_entry_t { + int64_t pos; +}; + +struct map_entry_t { + list_entry_t entry; +}; + +struct union_entry_t { + struct_entry_t entry; +}; + +struct int128_t; +struct uint128_t; +struct ku_string_t; + +template +concept SignedIntegerTypes = + std::is_same_v || std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v; + +template +concept UnsignedIntegerTypes = + std::is_same_v || std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v; + +template +concept IntegerTypes = SignedIntegerTypes || UnsignedIntegerTypes; + +template +concept FloatingPointTypes = std::is_same_v || std::is_same_v; + +template +concept NumericTypes = IntegerTypes || std::floating_point; + +template +concept ComparableTypes = NumericTypes || std::is_same_v || + std::is_same_v || std::is_same_v; + +template +concept HashablePrimitive = + ((std::integral && !std::is_same_v) || std::floating_point || + std::is_same_v || std::is_same_v); +template +concept IndexHashable = ((std::integral && !std::is_same_v) || std::floating_point || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::same_as); + +template +concept HashableNonNestedTypes = + (std::integral || std::floating_point || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v); + +template +concept HashableNestedTypes = + (std::is_same_v || std::is_same_v); + +template +concept HashableTypes = (HashableNestedTypes || HashableNonNestedTypes); + +enum class LogicalTypeID : uint8_t { + ANY = 0, + NODE = 10, + REL = 11, + RECURSIVE_REL = 12, + // SERIAL is a special data type that is used to represent a sequence of INT64 values that are + // incremented by 1 starting from 0. 
+ SERIAL = 13, + + BOOL = 22, + INT64 = 23, + INT32 = 24, + INT16 = 25, + INT8 = 26, + UINT64 = 27, + UINT32 = 28, + UINT16 = 29, + UINT8 = 30, + INT128 = 31, + DOUBLE = 32, + FLOAT = 33, + DATE = 34, + TIMESTAMP = 35, + TIMESTAMP_SEC = 36, + TIMESTAMP_MS = 37, + TIMESTAMP_NS = 38, + TIMESTAMP_TZ = 39, + INTERVAL = 40, + DECIMAL = 41, + INTERNAL_ID = 42, + UINT128 = 43, + + STRING = 50, + BLOB = 51, + + LIST = 52, + ARRAY = 53, + STRUCT = 54, + MAP = 55, + UNION = 56, + POINTER = 58, + + UUID = 59, + +}; + +enum class PhysicalTypeID : uint8_t { + // Fixed size types. + ANY = 0, + BOOL = 1, + INT64 = 2, + INT32 = 3, + INT16 = 4, + INT8 = 5, + UINT64 = 6, + UINT32 = 7, + UINT16 = 8, + UINT8 = 9, + INT128 = 10, + DOUBLE = 11, + FLOAT = 12, + INTERVAL = 13, + INTERNAL_ID = 14, + ALP_EXCEPTION_FLOAT = 15, + ALP_EXCEPTION_DOUBLE = 16, + UINT128 = 17, + + // Variable size types. + STRING = 20, + LIST = 22, + ARRAY = 23, + STRUCT = 24, + POINTER = 25, +}; + +class ExtraTypeInfo; +class StructField; +class StructTypeInfo; + +enum class TypeCategory : uint8_t { INTERNAL = 0, UDT = 1 }; + +class LogicalType { + friend struct LogicalTypeUtils; + friend struct DecimalType; + friend struct StructType; + friend struct ListType; + friend struct ArrayType; + + LBUG_API LogicalType(const LogicalType& other); + +public: + LogicalType() : typeID{LogicalTypeID::ANY}, extraTypeInfo{nullptr} { + physicalType = getPhysicalType(this->typeID); + }; + explicit LBUG_API LogicalType(LogicalTypeID typeID, TypeCategory info = TypeCategory::INTERNAL); + EXPLICIT_COPY_DEFAULT_MOVE(LogicalType); + + LBUG_API bool operator==(const LogicalType& other) const; + LBUG_API bool operator!=(const LogicalType& other) const; + + LBUG_API std::string toString() const; + static bool isBuiltInType(const std::string& str); + static LogicalType convertFromString(const std::string& str, main::ClientContext* context); + + LogicalTypeID getLogicalTypeID() const { return typeID; } + bool containsAny() const; + bool isInternalType() const { return category == TypeCategory::INTERNAL; } + + PhysicalTypeID getPhysicalType() const { return physicalType; } + LBUG_API static PhysicalTypeID getPhysicalType(LogicalTypeID logicalType, + const std::unique_ptr& extraTypeInfo = nullptr); + + void setExtraTypeInfo(std::unique_ptr typeInfo) { + extraTypeInfo = std::move(typeInfo); + } + + const ExtraTypeInfo* getExtraTypeInfo() const { return extraTypeInfo.get(); } + + void serialize(Serializer& serializer) const; + + static LogicalType deserialize(Deserializer& deserializer); + + LBUG_API static std::vector copy(const std::vector& types); + LBUG_API static std::vector copy(const std::vector& types); + + static LogicalType ANY() { return LogicalType(LogicalTypeID::ANY); } + + // NOTE: avoid using this if possible, this is a temporary hack for passing internal types + // TODO(Royi) remove this when float compression no longer relies on this or ColumnChunkData + // takes physical types instead of logical types + static LogicalType ANY(PhysicalTypeID physicalType) { + auto ret = LogicalType(LogicalTypeID::ANY); + ret.physicalType = physicalType; + return ret; + } + + static LogicalType BOOL() { return LogicalType(LogicalTypeID::BOOL); } + static LogicalType HASH() { return LogicalType(LogicalTypeID::UINT64); } + static LogicalType INT64() { return LogicalType(LogicalTypeID::INT64); } + static LogicalType INT32() { return LogicalType(LogicalTypeID::INT32); } + static LogicalType INT16() { return LogicalType(LogicalTypeID::INT16); } + static LogicalType INT8() 
{ return LogicalType(LogicalTypeID::INT8); } + static LogicalType UINT64() { return LogicalType(LogicalTypeID::UINT64); } + static LogicalType UINT32() { return LogicalType(LogicalTypeID::UINT32); } + static LogicalType UINT16() { return LogicalType(LogicalTypeID::UINT16); } + static LogicalType UINT8() { return LogicalType(LogicalTypeID::UINT8); } + static LogicalType INT128() { return LogicalType(LogicalTypeID::INT128); } + static LogicalType DOUBLE() { return LogicalType(LogicalTypeID::DOUBLE); } + static LogicalType FLOAT() { return LogicalType(LogicalTypeID::FLOAT); } + static LogicalType DATE() { return LogicalType(LogicalTypeID::DATE); } + static LogicalType TIMESTAMP_NS() { return LogicalType(LogicalTypeID::TIMESTAMP_NS); } + static LogicalType TIMESTAMP_MS() { return LogicalType(LogicalTypeID::TIMESTAMP_MS); } + static LogicalType TIMESTAMP_SEC() { return LogicalType(LogicalTypeID::TIMESTAMP_SEC); } + static LogicalType TIMESTAMP_TZ() { return LogicalType(LogicalTypeID::TIMESTAMP_TZ); } + static LogicalType TIMESTAMP() { return LogicalType(LogicalTypeID::TIMESTAMP); } + static LogicalType INTERVAL() { return LogicalType(LogicalTypeID::INTERVAL); } + static LBUG_API LogicalType DECIMAL(uint32_t precision, uint32_t scale); + static LogicalType INTERNAL_ID() { return LogicalType(LogicalTypeID::INTERNAL_ID); } + static LogicalType UINT128() { return LogicalType(LogicalTypeID::UINT128); }; + static LogicalType SERIAL() { return LogicalType(LogicalTypeID::SERIAL); } + static LogicalType STRING() { return LogicalType(LogicalTypeID::STRING); } + static LogicalType BLOB() { return LogicalType(LogicalTypeID::BLOB); } + static LogicalType UUID() { return LogicalType(LogicalTypeID::UUID); } + static LogicalType POINTER() { return LogicalType(LogicalTypeID::POINTER); } + static LBUG_API LogicalType STRUCT(std::vector&& fields); + + static LBUG_API LogicalType RECURSIVE_REL(std::vector&& fields); + + static LBUG_API LogicalType NODE(std::vector&& fields); + + static LBUG_API LogicalType REL(std::vector&& fields); + + static LBUG_API LogicalType UNION(std::vector&& fields); + + static LBUG_API LogicalType LIST(LogicalType childType); + template + static inline LogicalType LIST(T&& childType) { + return LogicalType::LIST(LogicalType(std::forward(childType))); + } + + static LBUG_API LogicalType MAP(LogicalType keyType, LogicalType valueType); + template + static LogicalType MAP(T&& keyType, T&& valueType) { + return LogicalType::MAP(LogicalType(std::forward(keyType)), + LogicalType(std::forward(valueType))); + } + + static LBUG_API LogicalType ARRAY(LogicalType childType, uint64_t numElements); + template + static LogicalType ARRAY(T&& childType, uint64_t numElements) { + return LogicalType::ARRAY(LogicalType(std::forward(childType)), numElements); + } + +private: + friend struct CAPIHelper; + friend struct JavaAPIHelper; + friend class lbug::processor::ParquetReader; + explicit LogicalType(LogicalTypeID typeID, std::unique_ptr extraTypeInfo); + +private: + LogicalTypeID typeID; + PhysicalTypeID physicalType; + std::unique_ptr extraTypeInfo; + TypeCategory category = TypeCategory::INTERNAL; +}; + +class LBUG_API ExtraTypeInfo { +public: + virtual ~ExtraTypeInfo() = default; + + void serialize(Serializer& serializer) const { serializeInternal(serializer); } + + virtual bool containsAny() const = 0; + + virtual bool operator==(const ExtraTypeInfo& other) const = 0; + + virtual std::unique_ptr copy() const = 0; + + template + const TARGET* constPtrCast() const { + return 
common::ku_dynamic_cast(this); + } + +protected: + virtual void serializeInternal(Serializer& serializer) const = 0; +}; + +class LBUG_API UDTTypeInfo : public ExtraTypeInfo { +public: + explicit UDTTypeInfo(std::string typeName) : typeName{std::move(typeName)} {} + + std::string getTypeName() const { return typeName; } + + bool containsAny() const override { return false; } + + bool operator==(const ExtraTypeInfo& other) const override; + + std::unique_ptr copy() const override; + + static std::unique_ptr deserialize(Deserializer& deserializer); + +private: + void serializeInternal(Serializer& serializer) const override; + +private: + std::string typeName; +}; + +class DecimalTypeInfo final : public ExtraTypeInfo { +public: + explicit DecimalTypeInfo(uint32_t precision = 18, uint32_t scale = 3) + : precision(precision), scale(scale) {} + + uint32_t getPrecision() const { return precision; } + uint32_t getScale() const { return scale; } + + bool containsAny() const override { return false; } + + bool operator==(const ExtraTypeInfo& other) const override; + + std::unique_ptr copy() const override; + + static std::unique_ptr deserialize(Deserializer& deserializer); + +protected: + void serializeInternal(Serializer& serializer) const override; + + uint32_t precision, scale; +}; + +class LBUG_API ListTypeInfo : public ExtraTypeInfo { +public: + ListTypeInfo() = default; + explicit ListTypeInfo(LogicalType childType) : childType{std::move(childType)} {} + + const LogicalType& getChildType() const { return childType; } + + bool containsAny() const override; + + bool operator==(const ExtraTypeInfo& other) const override; + + std::unique_ptr copy() const override; + + static std::unique_ptr deserialize(Deserializer& deserializer); + +protected: + void serializeInternal(Serializer& serializer) const override; + +protected: + LogicalType childType; +}; + +class LBUG_API ArrayTypeInfo final : public ListTypeInfo { +public: + ArrayTypeInfo() : numElements{0} {}; + explicit ArrayTypeInfo(LogicalType childType, uint64_t numElements) + : ListTypeInfo{std::move(childType)}, numElements{numElements} {} + + uint64_t getNumElements() const { return numElements; } + + bool operator==(const ExtraTypeInfo& other) const override; + + static std::unique_ptr deserialize(Deserializer& deserializer); + + std::unique_ptr copy() const override; + +private: + void serializeInternal(Serializer& serializer) const override; + +private: + uint64_t numElements; +}; + +class StructField { +public: + StructField() : type{LogicalType()} {} + StructField(std::string name, LogicalType type) + : name{std::move(name)}, type{std::move(type)} {}; + + DELETE_COPY_DEFAULT_MOVE(StructField); + + std::string getName() const { return name; } + + const LogicalType& getType() const { return type; } + + bool containsAny() const; + + bool operator==(const StructField& other) const; + bool operator!=(const StructField& other) const { return !(*this == other); } + + void serialize(Serializer& serializer) const; + + static StructField deserialize(Deserializer& deserializer); + + StructField copy() const; + +private: + std::string name; + LogicalType type; +}; + +class StructTypeInfo final : public ExtraTypeInfo { +public: + StructTypeInfo() = default; + explicit StructTypeInfo(std::vector&& fields); + StructTypeInfo(const std::vector& fieldNames, + const std::vector& fieldTypes); + + bool hasField(const std::string& fieldName) const; + struct_field_idx_t getStructFieldIdx(std::string fieldName) const; + const StructField& 
getStructField(struct_field_idx_t idx) const; + const StructField& getStructField(const std::string& fieldName) const; + const std::vector& getStructFields() const; + + const LogicalType& getChildType(struct_field_idx_t idx) const; + std::vector getChildrenTypes() const; + // can't be a vector of refs since that can't be for-each looped through + std::vector getChildrenNames() const; + + bool containsAny() const override; + + bool operator==(const ExtraTypeInfo& other) const override; + + static std::unique_ptr deserialize(Deserializer& deserializer); + std::unique_ptr copy() const override; + +private: + void serializeInternal(Serializer& serializer) const override; + +private: + std::vector fields; + std::unordered_map fieldNameToIdxMap; +}; + +using logical_type_vec_t = std::vector; + +struct LBUG_API DecimalType { + static uint32_t getPrecision(const LogicalType& type); + static uint32_t getScale(const LogicalType& type); + static std::string insertDecimalPoint(const std::string& value, uint32_t posFromEnd); +}; + +struct LBUG_API ListType { + static const LogicalType& getChildType(const LogicalType& type); +}; + +struct LBUG_API ArrayType { + static const LogicalType& getChildType(const LogicalType& type); + static uint64_t getNumElements(const LogicalType& type); +}; + +struct LBUG_API StructType { + static std::vector getFieldTypes(const LogicalType& type); + // since the field types isn't stored as a vector of LogicalTypes, we can't return vector<>& + + static const LogicalType& getFieldType(const LogicalType& type, struct_field_idx_t idx); + + static const LogicalType& getFieldType(const LogicalType& type, const std::string& key); + + static std::vector getFieldNames(const LogicalType& type); + + static uint64_t getNumFields(const LogicalType& type); + + static const std::vector& getFields(const LogicalType& type); + + static bool hasField(const LogicalType& type, const std::string& key); + + static const StructField& getField(const LogicalType& type, struct_field_idx_t idx); + + static const StructField& getField(const LogicalType& type, const std::string& key); + + static struct_field_idx_t getFieldIdx(const LogicalType& type, const std::string& key); +}; + +struct LBUG_API MapType { + static const LogicalType& getKeyType(const LogicalType& type); + + static const LogicalType& getValueType(const LogicalType& type); +}; + +struct LBUG_API UnionType { + static constexpr union_field_idx_t TAG_FIELD_IDX = 0; + + static constexpr auto TAG_FIELD_TYPE = LogicalTypeID::UINT16; + + static constexpr char TAG_FIELD_NAME[] = "tag"; + + static union_field_idx_t getInternalFieldIdx(union_field_idx_t idx); + + static std::string getFieldName(const LogicalType& type, union_field_idx_t idx); + + static const LogicalType& getFieldType(const LogicalType& type, union_field_idx_t idx); + + static const LogicalType& getFieldType(const LogicalType& type, const std::string& key); + + static uint64_t getNumFields(const LogicalType& type); + + static bool hasField(const LogicalType& type, const std::string& key); + + static union_field_idx_t getFieldIdx(const LogicalType& type, const std::string& key); +}; + +struct PhysicalTypeUtils { + static std::string toString(PhysicalTypeID physicalType); + static uint32_t getFixedTypeSize(PhysicalTypeID physicalType); +}; + +struct LBUG_API LogicalTypeUtils { + static std::string toString(LogicalTypeID dataTypeID); + static std::string toString(const std::vector& dataTypes); + static std::string toString(const std::vector& dataTypeIDs); + static uint32_t 
getRowLayoutSize(const LogicalType& logicalType); + static bool isDate(const LogicalType& dataType); + static bool isDate(const LogicalTypeID& dataType); + static bool isTimestamp(const LogicalType& dataType); + static bool isTimestamp(const LogicalTypeID& dataType); + static bool isUnsigned(const LogicalType& dataType); + static bool isUnsigned(const LogicalTypeID& dataType); + static bool isIntegral(const LogicalType& dataType); + static bool isIntegral(const LogicalTypeID& dataType); + static bool isNumerical(const LogicalType& dataType); + static bool isNumerical(const LogicalTypeID& dataType); + static bool isFloatingPoint(const LogicalTypeID& dataType); + static bool isNested(const LogicalType& dataType); + static bool isNested(LogicalTypeID logicalTypeID); + static std::vector getAllValidComparableLogicalTypes(); + static std::vector getNumericalLogicalTypeIDs(); + static std::vector getIntegerTypeIDs(); + static std::vector getFloatingPointTypeIDs(); + static std::vector getAllValidLogicTypeIDs(); + static std::vector getAllValidLogicTypes(); + static bool tryGetMaxLogicalType(const LogicalType& left, const LogicalType& right, + LogicalType& result); + static bool tryGetMaxLogicalType(const std::vector& types, LogicalType& result); + + // Differs from tryGetMaxLogicalType because it treats string as a maximal type, instead of a + // minimal type. as such, it will always succeed. + // Also combines structs by the union of their fields. As such, currently, it is not guaranteed + // for casting to work from input types to resulting types. Ideally this changes + static LogicalType combineTypes(const LogicalType& left, const LogicalType& right); + static LogicalType combineTypes(const std::vector& types); + + // makes a copy of the type with any occurences of ANY replaced with replacement + static LogicalType purgeAny(const LogicalType& type, const LogicalType& replacement); + +private: + static bool tryGetMaxLogicalTypeID(const LogicalTypeID& left, const LogicalTypeID& right, + LogicalTypeID& result); +}; + +enum class FileVersionType : uint8_t { ORIGINAL = 0, WAL_VERSION = 1 }; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/uint128_t.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/uint128_t.h new file mode 100644 index 0000000000..6b5eeea36f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/uint128_t.h @@ -0,0 +1,221 @@ +#pragma once + +#include +#include + +#include "common/api.h" +#include "common/exception/overflow.h" +#include "common/types/types.h" + +namespace lbug { +namespace common { + +struct int128_t; + +struct LBUG_API uint128_t { + uint64_t low; + uint64_t high; + + uint128_t() noexcept = default; + uint128_t(int64_t value); // NOLINT: Allow implicit conversion from numeric values + uint128_t(int32_t value); // NOLINT: Allow implicit conversion from numeric values + uint128_t(int16_t value); // NOLINT: Allow implicit conversion from numeric values + uint128_t(int8_t value); // NOLINT: Allow implicit conversion from numeric values + uint128_t(uint64_t value); // NOLINT: Allow implicit conversion from numeric values + uint128_t(uint32_t value); // NOLINT: Allow implicit conversion from numeric values + uint128_t(uint16_t value); // NOLINT: Allow implicit conversion from numeric values + uint128_t(uint8_t value); // NOLINT: Allow implicit conversion from numeric values + uint128_t(double value); // NOLINT: Allow implicit conversion from numeric values + 
uint128_t(float value); // NOLINT: Allow implicit conversion from numeric values + + constexpr uint128_t(uint64_t low, uint64_t high) noexcept : low(low), high(high) {} + + constexpr uint128_t(const uint128_t&) noexcept = default; + constexpr uint128_t(uint128_t&&) noexcept = default; + uint128_t& operator=(const uint128_t&) noexcept = default; + uint128_t& operator=(uint128_t&&) noexcept = default; + + uint128_t operator-() const; + + // inplace arithmetic operators + uint128_t& operator+=(const uint128_t& rhs); + uint128_t& operator*=(const uint128_t& rhs); + uint128_t& operator|=(const uint128_t& rhs); + uint128_t& operator&=(const uint128_t& rhs); + + // cast operators + explicit operator int64_t() const; + explicit operator int32_t() const; + explicit operator int16_t() const; + explicit operator int8_t() const; + explicit operator uint64_t() const; + explicit operator uint32_t() const; + explicit operator uint16_t() const; + explicit operator uint8_t() const; + explicit operator double() const; + explicit operator float() const; + + operator int128_t() const; // NOLINT: Allow implicit conversion from uint128 to int128 +}; + +// arithmetic operators +LBUG_API uint128_t operator+(const uint128_t& lhs, const uint128_t& rhs); +LBUG_API uint128_t operator-(const uint128_t& lhs, const uint128_t& rhs); +LBUG_API uint128_t operator*(const uint128_t& lhs, const uint128_t& rhs); +LBUG_API uint128_t operator/(const uint128_t& lhs, const uint128_t& rhs); +LBUG_API uint128_t operator%(const uint128_t& lhs, const uint128_t& rhs); +LBUG_API uint128_t operator^(const uint128_t& lhs, const uint128_t& rhs); +LBUG_API uint128_t operator&(const uint128_t& lhs, const uint128_t& rhs); +LBUG_API uint128_t operator~(const uint128_t& val); +LBUG_API uint128_t operator|(const uint128_t& lhs, const uint128_t& rhs); +LBUG_API uint128_t operator<<(const uint128_t& lhs, int amount); +LBUG_API uint128_t operator>>(const uint128_t& lhs, int amount); + +// comparison operators +LBUG_API bool operator==(const uint128_t& lhs, const uint128_t& rhs); +LBUG_API bool operator!=(const uint128_t& lhs, const uint128_t& rhs); +LBUG_API bool operator>(const uint128_t& lhs, const uint128_t& rhs); +LBUG_API bool operator>=(const uint128_t& lhs, const uint128_t& rhs); +LBUG_API bool operator<(const uint128_t& lhs, const uint128_t& rhs); +LBUG_API bool operator<=(const uint128_t& lhs, const uint128_t& rhs); + +class UInt128_t { +public: + static std::string toString(uint128_t input); + + template + static bool tryCast(uint128_t input, T& result); + + template + static T cast(uint128_t input) { + T result; + tryCast(input, result); + return result; + } + + template + static bool tryCastTo(T value, uint128_t& result); + + template + static uint128_t castTo(T value) { + uint128_t result{}; + if (!tryCastTo(value, result)) { + throw common::OverflowException("UINT128 is out of range"); + } + return result; + } + + // negate (required by function/arithmetic/negate.h) + static void negateInPlace(uint128_t& input) { + input.low = UINT64_MAX + 1 - input.low; + input.high = -input.high - 1 + (input.low == 0); + } + + static uint128_t negate(uint128_t input) { + negateInPlace(input); + return input; + } + + static bool tryMultiply(uint128_t lhs, uint128_t rhs, uint128_t& result); + + static uint128_t Add(uint128_t lhs, uint128_t rhs); + static uint128_t Sub(uint128_t lhs, uint128_t rhs); + static uint128_t Mul(uint128_t lhs, uint128_t rhs); + static uint128_t Div(uint128_t lhs, uint128_t rhs); + static uint128_t Mod(uint128_t lhs, uint128_t 
rhs); + static uint128_t Xor(uint128_t lhs, uint128_t rhs); + static uint128_t LeftShift(uint128_t lhs, int amount); + static uint128_t RightShift(uint128_t lhs, int amount); + static uint128_t BinaryAnd(uint128_t lhs, uint128_t rhs); + static uint128_t BinaryOr(uint128_t lhs, uint128_t rhs); + static uint128_t BinaryNot(uint128_t val); + + static uint128_t divMod(uint128_t lhs, uint128_t rhs, uint128_t& remainder); + static uint128_t divModPositive(uint128_t lhs, uint64_t rhs, uint64_t& remainder); + + static bool addInPlace(uint128_t& lhs, uint128_t rhs); + static bool subInPlace(uint128_t& lhs, uint128_t rhs); + + // comparison operators + static bool equals(uint128_t lhs, uint128_t rhs) { + return lhs.low == rhs.low && lhs.high == rhs.high; + } + + static bool notEquals(uint128_t lhs, uint128_t rhs) { + return lhs.low != rhs.low || lhs.high != rhs.high; + } + + static bool greaterThan(uint128_t lhs, uint128_t rhs) { + return (lhs.high > rhs.high) || (lhs.high == rhs.high && lhs.low > rhs.low); + } + + static bool greaterThanOrEquals(uint128_t lhs, uint128_t rhs) { + return (lhs.high > rhs.high) || (lhs.high == rhs.high && lhs.low >= rhs.low); + } + + static bool lessThan(uint128_t lhs, uint128_t rhs) { + return (lhs.high < rhs.high) || (lhs.high == rhs.high && lhs.low < rhs.low); + } + + static bool lessThanOrEquals(uint128_t lhs, uint128_t rhs) { + return (lhs.high < rhs.high) || (lhs.high == rhs.high && lhs.low <= rhs.low); + } +}; + +template<> +bool UInt128_t::tryCast(uint128_t input, int8_t& result); +template<> +bool UInt128_t::tryCast(uint128_t input, int16_t& result); +template<> +bool UInt128_t::tryCast(uint128_t input, int32_t& result); +template<> +bool UInt128_t::tryCast(uint128_t input, int64_t& result); +template<> +bool UInt128_t::tryCast(uint128_t input, uint8_t& result); +template<> +bool UInt128_t::tryCast(uint128_t input, uint16_t& result); +template<> +bool UInt128_t::tryCast(uint128_t input, uint32_t& result); +template<> +bool UInt128_t::tryCast(uint128_t input, uint64_t& result); +template<> +bool UInt128_t::tryCast(uint128_t input, int128_t& result); // unsigned to signed +template<> +bool UInt128_t::tryCast(uint128_t input, float& result); +template<> +bool UInt128_t::tryCast(uint128_t input, double& result); +template<> +bool UInt128_t::tryCast(uint128_t input, long double& result); + +template<> +bool UInt128_t::tryCastTo(int8_t value, uint128_t& result); +template<> +bool UInt128_t::tryCastTo(int16_t value, uint128_t& result); +template<> +bool UInt128_t::tryCastTo(int32_t value, uint128_t& result); +template<> +bool UInt128_t::tryCastTo(int64_t value, uint128_t& result); +template<> +bool UInt128_t::tryCastTo(uint8_t value, uint128_t& result); +template<> +bool UInt128_t::tryCastTo(uint16_t value, uint128_t& result); +template<> +bool UInt128_t::tryCastTo(uint32_t value, uint128_t& result); +template<> +bool UInt128_t::tryCastTo(uint64_t value, uint128_t& result); +template<> +bool UInt128_t::tryCastTo(uint128_t value, uint128_t& result); +template<> +bool UInt128_t::tryCastTo(float value, uint128_t& result); +template<> +bool UInt128_t::tryCastTo(double value, uint128_t& result); +template<> +bool UInt128_t::tryCastTo(long double value, uint128_t& result); + +} // namespace common +} // namespace lbug + +template<> +struct std::hash { + std::size_t operator()(const lbug::common::uint128_t& v) const noexcept; +}; diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/uuid.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/uuid.h new 
file mode 100644 index 0000000000..7c46731a9d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/uuid.h @@ -0,0 +1,40 @@ +#pragma once + +#include "int128_t.h" + +namespace lbug { + +namespace regex { +class RE2; +} + +namespace common { + +class RandomEngine; + +// Note: uuid_t is a reserved keyword in MSVC, we have to use ku_uuid_t instead. +struct ku_uuid_t { + int128_t value; +}; + +struct UUID { + static constexpr const uint8_t UUID_STRING_LENGTH = 36; + static constexpr const char HEX_DIGITS[] = "0123456789abcdef"; + static void byteToHex(char byteVal, char* buf, uint64_t& pos); + static unsigned char hex2Char(char ch); + static bool isHex(char ch); + static bool fromString(std::string str, int128_t& result); + + static int128_t fromString(std::string str); + static int128_t fromCString(const char* str, uint64_t len); + static void toString(int128_t input, char* buf); + static std::string toString(int128_t input); + static std::string toString(ku_uuid_t val); + + static ku_uuid_t generateRandomUUID(RandomEngine* engine); + + static const regex::RE2& regexPattern(); +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/value/nested.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/value/nested.h new file mode 100644 index 0000000000..fbafce903e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/value/nested.h @@ -0,0 +1,20 @@ +#pragma once + +#include + +#include "common/api.h" + +namespace lbug { +namespace common { + +class Value; + +class NestedVal { +public: + LBUG_API static uint32_t getChildrenSize(const Value* val); + + LBUG_API static Value* getChildVal(const Value* val, uint32_t idx); +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/value/node.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/value/node.h new file mode 100644 index 0000000000..f3abccad12 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/value/node.h @@ -0,0 +1,63 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "common/api.h" + +namespace lbug { +namespace common { + +class Value; + +/** + * @brief NodeVal represents a node in the graph and stores the nodeID, label and properties of that + * node. + */ +class NodeVal { +public: + /** + * @return all properties of the NodeVal. + * @note this function copies all the properties into a vector, which is not efficient. use + * `getPropertyName` and `getPropertyVal` instead if possible. + */ + LBUG_API static std::vector>> getProperties( + const Value* val); + /** + * @return number of properties of the RelVal. + */ + LBUG_API static uint64_t getNumProperties(const Value* val); + + /** + * @return the name of the property at the given index. + */ + LBUG_API static std::string getPropertyName(const Value* val, uint64_t index); + + /** + * @return the value of the property at the given index. + */ + LBUG_API static Value* getPropertyVal(const Value* val, uint64_t index); + /** + * @return the nodeID as a Value. + */ + LBUG_API static Value* getNodeIDVal(const Value* val); + /** + * @return the name of the node as a Value. + */ + LBUG_API static Value* getLabelVal(const Value* val); + /** + * @return the current node values in string format. + */ + LBUG_API static std::string toString(const Value* val); + +private: + static void throwIfNotNode(const Value* val); + // 2 offsets for id and label. 
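    // Illustrative usage sketch, not library code, assuming `nodeValue` is a
    // `const Value*` holding a NODE value; the per-index accessors declared
    // above avoid the copying done by getProperties():
    //   for (uint64_t i = 0; i < NodeVal::getNumProperties(nodeValue); ++i) {
    //       std::string name = NodeVal::getPropertyName(nodeValue, i);
    //       Value* property = NodeVal::getPropertyVal(nodeValue, i);
    //       // e.g. print `name` together with property->toString()
    //   }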
+ static constexpr uint64_t OFFSET = 2; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/value/recursive_rel.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/value/recursive_rel.h new file mode 100644 index 0000000000..0fb913902f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/value/recursive_rel.h @@ -0,0 +1,31 @@ +#pragma once + +#include "common/api.h" + +namespace lbug { +namespace common { + +class Value; + +/** + * @brief RecursiveRelVal represents a path in the graph and stores the corresponding rels and nodes + * of that path. + */ +class RecursiveRelVal { +public: + /** + * @return the list of nodes in the recursive rel as a Value. + */ + LBUG_API static Value* getNodes(const Value* val); + + /** + * @return the list of rels in the recursive rel as a Value. + */ + LBUG_API static Value* getRels(const Value* val); + +private: + static void throwIfNotRecursiveRel(const Value* val); +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/value/rel.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/value/rel.h new file mode 100644 index 0000000000..9aadf767ca --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/value/rel.h @@ -0,0 +1,69 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "common/api.h" + +namespace lbug { +namespace common { + +class Value; + +/** + * @brief RelVal represents a rel in the graph and stores the relID, src/dst nodes and properties of + * that rel. + */ +class RelVal { +public: + /** + * @return all properties of the RelVal. + * @note this function copies all the properties into a vector, which is not efficient. use + * `getPropertyName` and `getPropertyVal` instead if possible. + */ + LBUG_API static std::vector>> getProperties( + const Value* val); + /** + * @return number of properties of the RelVal. + */ + LBUG_API static uint64_t getNumProperties(const Value* val); + /** + * @return the name of the property at the given index. + */ + LBUG_API static std::string getPropertyName(const Value* val, uint64_t index); + /** + * @return the value of the property at the given index. + */ + LBUG_API static Value* getPropertyVal(const Value* val, uint64_t index); + /** + * @return the src nodeID value of the RelVal in Value. + */ + LBUG_API static Value* getSrcNodeIDVal(const Value* val); + /** + * @return the dst nodeID value of the RelVal in Value. + */ + LBUG_API static Value* getDstNodeIDVal(const Value* val); + /** + * @return the internal ID value of the RelVal in Value. + */ + LBUG_API static Value* getIDVal(const Value* val); + /** + * @return the label value of the RelVal. + */ + LBUG_API static Value* getLabelVal(const Value* val); + /** + * @return the value of the RelVal in string format. + */ + LBUG_API static std::string toString(const Value* val); + +private: + static void throwIfNotRel(const Value* val); + // 4 offset for id, label, src, dst. 
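    // Illustrative usage sketch, not library code, assuming `relValue` is a
    // `const Value*` holding a REL value; the ID accessors above return Value
    // wrappers around internalID_t, which a caller typically unwraps:
    //   nodeID_t src = RelVal::getSrcNodeIDVal(relValue)->getValue<nodeID_t>();
    //   nodeID_t dst = RelVal::getDstNodeIDVal(relValue)->getValue<nodeID_t>();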
+ static constexpr uint64_t OFFSET = 4; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/value/value.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/value/value.h new file mode 100644 index 0000000000..393d4c83bf --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/types/value/value.h @@ -0,0 +1,899 @@ +#pragma once + +#include +#include + +#include "common/api.h" +#include "common/types/date_t.h" +#include "common/types/int128_t.h" +#include "common/types/interval_t.h" +#include "common/types/ku_list.h" +#include "common/types/timestamp_t.h" +#include "common/types/uint128_t.h" +#include "common/types/uuid.h" + +namespace lbug { +namespace common { + +class NodeVal; +class RelVal; +struct FileInfo; +class NestedVal; +class RecursiveRelVal; +class ArrowRowBatch; +class ValueVector; +class Serializer; +class Deserializer; + +class Value { + friend class NodeVal; + friend class RelVal; + friend class NestedVal; + friend class RecursiveRelVal; + friend class ArrowRowBatch; + friend class ValueVector; + +public: + /** + * @return a NULL value of ANY type. + */ + LBUG_API static Value createNullValue(); + /** + * @param dataType the type of the NULL value. + * @return a NULL value of the given type. + */ + LBUG_API static Value createNullValue(const LogicalType& dataType); + /** + * @param dataType the type of the non-NULL value. + * @return a default non-NULL value of the given type. + */ + LBUG_API static Value createDefaultValue(const LogicalType& dataType); + /** + * @param val_ the boolean value to set. + */ + LBUG_API explicit Value(bool val_); + /** + * @param val_ the int8_t value to set. + */ + LBUG_API explicit Value(int8_t val_); + /** + * @param val_ the int16_t value to set. + */ + LBUG_API explicit Value(int16_t val_); + /** + * @param val_ the int32_t value to set. + */ + LBUG_API explicit Value(int32_t val_); + /** + * @param val_ the int64_t value to set. + */ + LBUG_API explicit Value(int64_t val_); + /** + * @param val_ the uint8_t value to set. + */ + LBUG_API explicit Value(uint8_t val_); + /** + * @param val_ the uint16_t value to set. + */ + LBUG_API explicit Value(uint16_t val_); + /** + * @param val_ the uint32_t value to set. + */ + LBUG_API explicit Value(uint32_t val_); + /** + * @param val_ the uint64_t value to set. + */ + LBUG_API explicit Value(uint64_t val_); + /** + * @param val_ the int128_t value to set. + */ + LBUG_API explicit Value(int128_t val_); + /** + * @param val_ the UUID value to set. + */ + LBUG_API explicit Value(ku_uuid_t val_); + /** + * @param val_ the double value to set. + */ + LBUG_API explicit Value(double val_); + /** + * @param val_ the float value to set. + */ + LBUG_API explicit Value(float val_); + /** + * @param val_ the date value to set. + */ + LBUG_API explicit Value(date_t val_); + /** + * @param val_ the timestamp_ns value to set. + */ + LBUG_API explicit Value(timestamp_ns_t val_); + /** + * @param val_ the timestamp_ms value to set. + */ + LBUG_API explicit Value(timestamp_ms_t val_); + /** + * @param val_ the timestamp_sec value to set. + */ + LBUG_API explicit Value(timestamp_sec_t val_); + /** + * @param val_ the timestamp_tz value to set. + */ + LBUG_API explicit Value(timestamp_tz_t val_); + /** + * @param val_ the timestamp value to set. + */ + LBUG_API explicit Value(timestamp_t val_); + /** + * @param val_ the interval value to set. 
+ */ + LBUG_API explicit Value(interval_t val_); + /** + * @param val_ the internalID value to set. + */ + LBUG_API explicit Value(internalID_t val_); + /** + * @param val_ the uint128_t value to set. + */ + LBUG_API explicit Value(uint128_t val_); + /** + * @param val_ the string value to set. + */ + LBUG_API explicit Value(const char* val_); + /** + * @param val_ the string value to set. + */ + LBUG_API explicit Value(const std::string& val_); + /** + * @param val_ the uint8_t* value to set. + */ + LBUG_API explicit Value(uint8_t* val_); + /** + * @param type the logical type of the value. + * @param val_ the string value to set. + */ + LBUG_API explicit Value(LogicalType type, std::string val_); + /** + * @param dataType the logical type of the value. + * @param children a vector of children values. + */ + LBUG_API explicit Value(LogicalType dataType, std::vector> children); + /** + * @param other the value to copy from. + */ + LBUG_API Value(const Value& other); + + /** + * @param other the value to move from. + */ + LBUG_API Value(Value&& other) = default; + LBUG_API Value& operator=(Value&& other) = default; + LBUG_API bool operator==(const Value& rhs) const; + + /** + * @brief Sets the data type of the Value. + * @param dataType_ the data type to set to. + */ + LBUG_API void setDataType(const LogicalType& dataType_); + /** + * @return the dataType of the value. + */ + LBUG_API const LogicalType& getDataType() const; + /** + * @brief Sets the null flag of the Value. + * @param flag null value flag to set. + */ + LBUG_API void setNull(bool flag); + /** + * @brief Sets the null flag of the Value to true. + */ + LBUG_API void setNull(); + /** + * @return whether the Value is null or not. + */ + LBUG_API bool isNull() const; + /** + * @brief Copies from the row layout value. + * @param value value to copy from. + */ + LBUG_API void copyFromRowLayout(const uint8_t* value); + /** + * @brief Copies from the col layout value. + * @param value value to copy from. + */ + LBUG_API void copyFromColLayout(const uint8_t* value, ValueVector* vec = nullptr); + /** + * @brief Copies from the other. + * @param other value to copy from. + */ + LBUG_API void copyValueFrom(const Value& other); + /** + * @return the value of the given type. + */ + template + T getValue() const { + throw std::runtime_error("Unimplemented template for Value::getValue()"); + } + /** + * @return a reference to the value of the given type. + */ + template + T& getValueReference() { + throw std::runtime_error("Unimplemented template for Value::getValueReference()"); + } + /** + * @return a Value object based on value. + */ + template + static Value createValue(T /*value*/) { + throw std::runtime_error("Unimplemented template for Value::createValue()"); + } + + /** + * @return a copy of the current value. + */ + LBUG_API std::unique_ptr copy() const; + /** + * @return the current value in string format. 
+ */ + LBUG_API std::string toString() const; + + LBUG_API void serialize(Serializer& serializer) const; + + LBUG_API static std::unique_ptr deserialize(Deserializer& deserializer); + + LBUG_API void validateType(common::LogicalTypeID targetTypeID) const; + + bool hasNoneNullChildren() const; + bool allowTypeChange() const; + + uint64_t computeHash() const; + + uint32_t getChildrenSize() const { return childrenSize; } + +private: + Value(); + explicit Value(const LogicalType& dataType); + + void resizeChildrenVector(uint64_t size, const LogicalType& childType); + void copyFromRowLayoutList(const ku_list_t& list, const LogicalType& childType); + void copyFromColLayoutList(const list_entry_t& list, ValueVector* vec); + void copyFromRowLayoutStruct(const uint8_t* kuStruct); + void copyFromColLayoutStruct(const struct_entry_t& structEntry, ValueVector* vec); + void copyFromUnion(const uint8_t* kuUnion); + + std::string mapToString() const; + std::string listToString() const; + std::string structToString() const; + std::string nodeToString() const; + std::string relToString() const; + std::string decimalToString() const; + +public: + union Val { + constexpr Val() : booleanVal{false} {} + bool booleanVal; + int128_t int128Val; + int64_t int64Val; + int32_t int32Val; + int16_t int16Val; + int8_t int8Val; + uint64_t uint64Val; + uint32_t uint32Val; + uint16_t uint16Val; + uint8_t uint8Val; + double doubleVal; + float floatVal; + // TODO(Ziyi): Should we remove the val suffix from all values in Val? Looks redundant. + uint8_t* pointer; + interval_t intervalVal; + internalID_t internalIDVal; + uint128_t uint128Val; + } val; + std::string strVal; + +private: + LogicalType dataType; + bool isNull_; + + // Note: ALWAYS use childrenSize over children.size(). We do NOT resize children when + // iterating with nested value. So children.size() reflects the capacity() rather the actual + // size. + std::vector> children; + uint32_t childrenSize; +}; + +/** + * @return boolean value. + */ +template<> +inline bool Value::getValue() const { + KU_ASSERT(dataType.getPhysicalType() == PhysicalTypeID::BOOL); + return val.booleanVal; +} + +/** + * @return int8 value. + */ +template<> +inline int8_t Value::getValue() const { + KU_ASSERT(dataType.getPhysicalType() == PhysicalTypeID::INT8); + return val.int8Val; +} + +/** + * @return int16 value. + */ +template<> +inline int16_t Value::getValue() const { + KU_ASSERT(dataType.getPhysicalType() == PhysicalTypeID::INT16); + return val.int16Val; +} + +/** + * @return int32 value. + */ +template<> +inline int32_t Value::getValue() const { + KU_ASSERT(dataType.getPhysicalType() == PhysicalTypeID::INT32); + return val.int32Val; +} + +/** + * @return int64 value. + */ +template<> +inline int64_t Value::getValue() const { + KU_ASSERT(dataType.getPhysicalType() == PhysicalTypeID::INT64); + return val.int64Val; +} + +/** + * @return uint64 value. + */ +template<> +inline uint64_t Value::getValue() const { + KU_ASSERT(dataType.getPhysicalType() == PhysicalTypeID::UINT64); + return val.uint64Val; +} + +/** + * @return uint32 value. + */ +template<> +inline uint32_t Value::getValue() const { + KU_ASSERT(dataType.getPhysicalType() == PhysicalTypeID::UINT32); + return val.uint32Val; +} + +/** + * @return uint16 value. + */ +template<> +inline uint16_t Value::getValue() const { + KU_ASSERT(dataType.getPhysicalType() == PhysicalTypeID::UINT16); + return val.uint16Val; +} + +/** + * @return uint8 value. 
+ */ +template<> +inline uint8_t Value::getValue() const { + KU_ASSERT(dataType.getPhysicalType() == PhysicalTypeID::UINT8); + return val.uint8Val; +} + +/** + * @return int128 value. + */ +template<> +inline int128_t Value::getValue() const { + KU_ASSERT(dataType.getPhysicalType() == PhysicalTypeID::INT128); + return val.int128Val; +} + +/** + * @return float value. + */ +template<> +inline float Value::getValue() const { + KU_ASSERT(dataType.getPhysicalType() == PhysicalTypeID::FLOAT); + return val.floatVal; +} + +/** + * @return double value. + */ +template<> +inline double Value::getValue() const { + KU_ASSERT(dataType.getPhysicalType() == PhysicalTypeID::DOUBLE); + return val.doubleVal; +} + +/** + * @return date_t value. + */ +template<> +inline date_t Value::getValue() const { + KU_ASSERT(dataType.getLogicalTypeID() == LogicalTypeID::DATE); + return date_t{val.int32Val}; +} + +/** + * @return timestamp_t value. + */ +template<> +inline timestamp_t Value::getValue() const { + KU_ASSERT(dataType.getLogicalTypeID() == LogicalTypeID::TIMESTAMP); + return timestamp_t{val.int64Val}; +} + +/** + * @return timestamp_ns_t value. + */ +template<> +inline timestamp_ns_t Value::getValue() const { + KU_ASSERT(dataType.getLogicalTypeID() == LogicalTypeID::TIMESTAMP_NS); + return timestamp_ns_t{val.int64Val}; +} + +/** + * @return timestamp_ms_t value. + */ +template<> +inline timestamp_ms_t Value::getValue() const { + KU_ASSERT(dataType.getLogicalTypeID() == LogicalTypeID::TIMESTAMP_MS); + return timestamp_ms_t{val.int64Val}; +} + +/** + * @return timestamp_sec_t value. + */ +template<> +inline timestamp_sec_t Value::getValue() const { + KU_ASSERT(dataType.getLogicalTypeID() == LogicalTypeID::TIMESTAMP_SEC); + return timestamp_sec_t{val.int64Val}; +} + +/** + * @return timestamp_tz_t value. + */ +template<> +inline timestamp_tz_t Value::getValue() const { + KU_ASSERT(dataType.getLogicalTypeID() == LogicalTypeID::TIMESTAMP_TZ); + return timestamp_tz_t{val.int64Val}; +} + +/** + * @return interval_t value. + */ +template<> +inline interval_t Value::getValue() const { + KU_ASSERT(dataType.getLogicalTypeID() == LogicalTypeID::INTERVAL); + return val.intervalVal; +} + +/** + * @return internal_t value. + */ +template<> +inline internalID_t Value::getValue() const { + KU_ASSERT(dataType.getLogicalTypeID() == LogicalTypeID::INTERNAL_ID); + return val.internalIDVal; +} + +/** + * @return uint128 value. + */ +template<> +inline uint128_t Value::getValue() const { + KU_ASSERT(dataType.getPhysicalType() == PhysicalTypeID::UINT128); + return val.uint128Val; +} + +/** + * @return string value. + */ +template<> +inline std::string Value::getValue() const { + KU_ASSERT(dataType.getLogicalTypeID() == LogicalTypeID::STRING || + dataType.getLogicalTypeID() == LogicalTypeID::BLOB || + dataType.getLogicalTypeID() == LogicalTypeID::UUID); + return strVal; +} + +/** + * @return uint8_t* value. + */ +template<> +inline uint8_t* Value::getValue() const { + KU_ASSERT(dataType.getLogicalTypeID() == LogicalTypeID::POINTER); + return val.pointer; +} + +/** + * @return the reference to the boolean value. + */ +template<> +inline bool& Value::getValueReference() { + KU_ASSERT(dataType.getPhysicalType() == PhysicalTypeID::BOOL); + return val.booleanVal; +} + +/** + * @return the reference to the int8 value. + */ +template<> +inline int8_t& Value::getValueReference() { + KU_ASSERT(dataType.getPhysicalType() == PhysicalTypeID::INT8); + return val.int8Val; +} + +/** + * @return the reference to the int16 value. 
+ */ +template<> +inline int16_t& Value::getValueReference() { + KU_ASSERT(dataType.getPhysicalType() == PhysicalTypeID::INT16); + return val.int16Val; +} + +/** + * @return the reference to the int32 value. + */ +template<> +inline int32_t& Value::getValueReference() { + KU_ASSERT(dataType.getPhysicalType() == PhysicalTypeID::INT32); + return val.int32Val; +} + +/** + * @return the reference to the int64 value. + */ +template<> +inline int64_t& Value::getValueReference() { + KU_ASSERT(dataType.getPhysicalType() == PhysicalTypeID::INT64); + return val.int64Val; +} + +/** + * @return the reference to the uint8 value. + */ +template<> +inline uint8_t& Value::getValueReference() { + KU_ASSERT(dataType.getPhysicalType() == PhysicalTypeID::UINT8); + return val.uint8Val; +} + +/** + * @return the reference to the uint16 value. + */ +template<> +inline uint16_t& Value::getValueReference() { + KU_ASSERT(dataType.getPhysicalType() == PhysicalTypeID::UINT16); + return val.uint16Val; +} + +/** + * @return the reference to the uint32 value. + */ +template<> +inline uint32_t& Value::getValueReference() { + KU_ASSERT(dataType.getPhysicalType() == PhysicalTypeID::UINT32); + return val.uint32Val; +} + +/** + * @return the reference to the uint64 value. + */ +template<> +inline uint64_t& Value::getValueReference() { + KU_ASSERT(dataType.getPhysicalType() == PhysicalTypeID::UINT64); + return val.uint64Val; +} + +/** + * @return the reference to the int128 value. + */ +template<> +inline int128_t& Value::getValueReference() { + KU_ASSERT(dataType.getPhysicalType() == PhysicalTypeID::INT128); + return val.int128Val; +} + +/** + * @return the reference to the float value. + */ +template<> +inline float& Value::getValueReference() { + KU_ASSERT(dataType.getPhysicalType() == PhysicalTypeID::FLOAT); + return val.floatVal; +} + +/** + * @return the reference to the double value. + */ +template<> +inline double& Value::getValueReference() { + KU_ASSERT(dataType.getPhysicalType() == PhysicalTypeID::DOUBLE); + return val.doubleVal; +} + +/** + * @return the reference to the date value. + */ +template<> +inline date_t& Value::getValueReference() { + KU_ASSERT(dataType.getLogicalTypeID() == LogicalTypeID::DATE); + return *reinterpret_cast(&val.int32Val); +} + +/** + * @return the reference to the timestamp value. + */ +template<> +inline timestamp_t& Value::getValueReference() { + KU_ASSERT(dataType.getLogicalTypeID() == LogicalTypeID::TIMESTAMP); + return *reinterpret_cast(&val.int64Val); +} + +/** + * @return the reference to the timestamp_ms value. + */ +template<> +inline timestamp_ms_t& Value::getValueReference() { + KU_ASSERT(dataType.getLogicalTypeID() == LogicalTypeID::TIMESTAMP_MS); + return *reinterpret_cast(&val.int64Val); +} + +/** + * @return the reference to the timestamp_ns value. + */ +template<> +inline timestamp_ns_t& Value::getValueReference() { + KU_ASSERT(dataType.getLogicalTypeID() == LogicalTypeID::TIMESTAMP_NS); + return *reinterpret_cast(&val.int64Val); +} + +/** + * @return the reference to the timestamp_sec value. + */ +template<> +inline timestamp_sec_t& Value::getValueReference() { + KU_ASSERT(dataType.getLogicalTypeID() == LogicalTypeID::TIMESTAMP_SEC); + return *reinterpret_cast(&val.int64Val); +} + +/** + * @return the reference to the timestamp_tz value. 
+ */ +template<> +inline timestamp_tz_t& Value::getValueReference() { + KU_ASSERT(dataType.getLogicalTypeID() == LogicalTypeID::TIMESTAMP_TZ); + return *reinterpret_cast(&val.int64Val); +} + +/** + * @return the reference to the interval value. + */ +template<> +inline interval_t& Value::getValueReference() { + KU_ASSERT(dataType.getLogicalTypeID() == LogicalTypeID::INTERVAL); + return val.intervalVal; +} + +/** + * @return the reference to the uint128 value. + */ +template<> +inline uint128_t& Value::getValueReference() { + KU_ASSERT(dataType.getPhysicalType() == PhysicalTypeID::UINT128); + return val.uint128Val; +} + +/** + * @return the reference to the internal_id value. + */ +template<> +inline nodeID_t& Value::getValueReference() { + KU_ASSERT(dataType.getLogicalTypeID() == LogicalTypeID::INTERNAL_ID); + return val.internalIDVal; +} + +/** + * @return the reference to the string value. + */ +template<> +inline std::string& Value::getValueReference() { + KU_ASSERT(dataType.getLogicalTypeID() == LogicalTypeID::STRING); + return strVal; +} + +/** + * @return the reference to the uint8_t* value. + */ +template<> +inline uint8_t*& Value::getValueReference() { + KU_ASSERT(dataType.getLogicalTypeID() == LogicalTypeID::POINTER); + return val.pointer; +} + +/** + * @param val the boolean value + * @return a Value with BOOL type and val value. + */ +template<> +inline Value Value::createValue(bool val) { + return Value(val); +} + +template<> +inline Value Value::createValue(int8_t val) { + return Value(val); +} + +/** + * @param val the int16 value + * @return a Value with INT16 type and val value. + */ +template<> +inline Value Value::createValue(int16_t val) { + return Value(val); +} + +/** + * @param val the int32 value + * @return a Value with INT32 type and val value. + */ +template<> +inline Value Value::createValue(int32_t val) { + return Value(val); +} + +/** + * @param val the int64 value + * @return a Value with INT64 type and val value. + */ +template<> +inline Value Value::createValue(int64_t val) { + return Value(val); +} + +/** + * @param val the uint8 value + * @return a Value with UINT8 type and val value. + */ +template<> +inline Value Value::createValue(uint8_t val) { + return Value(val); +} + +/** + * @param val the uint16 value + * @return a Value with UINT16 type and val value. + */ +template<> +inline Value Value::createValue(uint16_t val) { + return Value(val); +} + +/** + * @param val the uint32 value + * @return a Value with UINT32 type and val value. + */ +template<> +inline Value Value::createValue(uint32_t val) { + return Value(val); +} + +/** + * @param val the uint64 value + * @return a Value with UINT64 type and val value. + */ +template<> +inline Value Value::createValue(uint64_t val) { + return Value(val); +} + +/** + * @param val the int128_t value + * @return a Value with INT128 type and val value. + */ +template<> +inline Value Value::createValue(int128_t val) { + return Value(val); +} + +/** + * @param val the double value + * @return a Value with DOUBLE type and val value. + */ +template<> +inline Value Value::createValue(double val) { + return Value(val); +} + +/** + * @param val the date_t value + * @return a Value with DATE type and val value. + */ +template<> +inline Value Value::createValue(date_t val) { + return Value(val); +} + +/** + * @param val the timestamp_t value + * @return a Value with TIMESTAMP type and val value. 
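+ * A minimal usage sketch (illustrative; assumes timestamp_t is default-constructible):
+ *   lbug::common::timestamp_t rawTs{};
+ *   auto v = lbug::common::Value::createValue(rawTs); // yields a TIMESTAMP-typed Value
+ *   auto back = v.getValue<lbug::common::timestamp_t>();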
+ */ +template<> +inline Value Value::createValue(timestamp_t val) { + return Value(val); +} + +/** + * @param val the interval_t value + * @return a Value with INTERVAL type and val value. + */ +template<> +inline Value Value::createValue(interval_t val) { + return Value(val); +} + +/** + * @param val the uint128_t value + * @return a Value with UINT128 type and val value. + */ +template<> +inline Value Value::createValue(uint128_t val) { + return Value(val); +} + +/** + * @param val the nodeID_t value + * @return a Value with NODE_ID type and val value. + */ +template<> +inline Value Value::createValue(nodeID_t val) { + return Value(val); +} + +/** + * @param val the string value + * @return a Value with type and val value. + */ +template<> +inline Value Value::createValue(std::string val) { + return Value(LogicalType::STRING(), std::move(val)); +} + +/** + * @param value the string value + * @return a Value with STRING type and val value. + */ +template<> +inline Value Value::createValue(const char* value) { + return Value(LogicalType::STRING(), std::string(value)); +} + +/** + * @param val the uint8_t* val + * @return a Value with POINTER type and val val. + */ +template<> +inline Value Value::createValue(uint8_t* val) { + return Value(val); +} + +/** + * @param val the uuid_t* val + * @return a Value with UUID type and val val. + */ +template<> +inline Value Value::createValue(ku_uuid_t val) { + return Value(val); +} + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/uniq_lock.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/uniq_lock.h new file mode 100644 index 0000000000..38974b1fa2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/uniq_lock.h @@ -0,0 +1,27 @@ +#pragma once + +#include + +namespace lbug { +namespace common { + +struct UniqLock { + UniqLock() {} + explicit UniqLock(std::mutex& mtx) : lck{mtx} {} + + UniqLock(const UniqLock&) = delete; + UniqLock& operator=(const UniqLock&) = delete; + + UniqLock(UniqLock&& other) noexcept { std::swap(lck, other.lck); } + UniqLock& operator=(UniqLock&& other) noexcept { + std::swap(lck, other.lck); + return *this; + } + bool isLocked() const { return lck.owns_lock(); } + +private: + std::unique_lock lck; +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/utils.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/utils.h new file mode 100644 index 0000000000..022d1dd6d8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/utils.h @@ -0,0 +1,103 @@ +#pragma once + +#include +#include +#include + +#include "common/assert.h" +#include "common/numeric_utils.h" +#include "common/types/int128_t.h" +#include + +namespace lbug { +namespace common { + +class BitmaskUtils { +public: + template + requires std::integral + static T all1sMaskForLeastSignificantBits(uint32_t numBits) { + KU_ASSERT(numBits <= 64); + using U = numeric_utils::MakeUnSignedT; + return (T)(numBits == (sizeof(U) * 8) ? 
std::numeric_limits::max() : + static_cast(((U)1 << numBits) - 1)); + } + + // constructs all 1s mask while avoiding overflow/underflow for int128 + template + requires std::same_as, int128_t> + static T all1sMaskForLeastSignificantBits(uint32_t numBits) { + static constexpr uint8_t numBitsInT = sizeof(T) * 8; + + // use ~T(1) instead of ~T(0) to avoid sign-bit filling + const T fullMask = ~(T(1) << (numBitsInT - 1)); + + const size_t numBitsToDiscard = (numBitsInT - 1 - numBits); + return (fullMask >> numBitsToDiscard); + } +}; + +uint64_t nextPowerOfTwo(uint64_t v); +uint64_t prevPowerOfTwo(uint64_t v); + +bool isLittleEndian(); + +template +constexpr T ceilDiv(T a, T b) { + return (a / b) + (a % b != 0); +} + +template +constexpr To safeIntegerConversion(From val) { + KU_ASSERT(static_cast(val) == val); + return val; +} + +template +bool containsValue(const Container& container, const T& value) { + return std::find(container.begin(), container.end(), value) != container.end(); +} + +template +constexpr T countBits(T) { + constexpr T bitsPerByte = 8; + return sizeof(T) * bitsPerByte; +} + +template +struct CountZeros { + static constexpr idx_t Leading(T value_in) { return std::countl_zero(value_in); } + static constexpr idx_t Trailing(T value_in) { return std::countr_zero(value_in); } +}; + +template<> +struct CountZeros { + static constexpr idx_t Leading(int128_t value) { + const uint64_t upper = static_cast(value.high); + const uint64_t lower = value.low; + + if (upper) { + return CountZeros::Leading(upper); + } + if (lower) { + return 64 + CountZeros::Leading(lower); + } + return 128; + } + + static constexpr idx_t Trailing(int128_t value) { + const uint64_t upper = static_cast(value.high); + const uint64_t lower = value.low; + + if (lower) { + return CountZeros::Trailing(lower); + } + if (upper) { + return 64 + CountZeros::Trailing(upper); + } + return 128; + } +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/vector/auxiliary_buffer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/vector/auxiliary_buffer.h new file mode 100644 index 0000000000..f1c7859007 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/vector/auxiliary_buffer.h @@ -0,0 +1,105 @@ +#pragma once + +#include "common/api.h" +#include "common/in_mem_overflow_buffer.h" +#include "common/types/types.h" + +namespace lbug { +namespace common { + +class ValueVector; + +// AuxiliaryBuffer holds data which is only used by the targeting dataType. 
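+// A minimal sketch of how a caller reaches the type-specific buffer (illustrative; the concrete
+// LogicalType -> buffer mapping is resolved by AuxiliaryBufferFactory at the end of this header,
+// and the variable names below are assumed):
+//   auto buffer = AuxiliaryBufferFactory::getAuxiliaryBuffer(stringType, memoryManager);
+//   auto* overflow = buffer->cast<StringAuxiliaryBuffer>().allocateOverflow(64 /* bytes */);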
+class LBUG_API AuxiliaryBuffer { +public: + virtual ~AuxiliaryBuffer() = default; + + template + TARGET& cast() { + return common::ku_dynamic_cast(*this); + } + + template + const TARGET& constCast() const { + return common::ku_dynamic_cast(*this); + } +}; + +class StringAuxiliaryBuffer : public AuxiliaryBuffer { +public: + explicit StringAuxiliaryBuffer(storage::MemoryManager* memoryManager) { + inMemOverflowBuffer = std::make_unique(memoryManager); + } + + InMemOverflowBuffer* getOverflowBuffer() const { return inMemOverflowBuffer.get(); } + uint8_t* allocateOverflow(uint64_t size) { return inMemOverflowBuffer->allocateSpace(size); } + void resetOverflowBuffer() const { inMemOverflowBuffer->resetBuffer(); } + +private: + std::unique_ptr inMemOverflowBuffer; +}; + +class LBUG_API StructAuxiliaryBuffer : public AuxiliaryBuffer { +public: + StructAuxiliaryBuffer(const LogicalType& type, storage::MemoryManager* memoryManager); + + void referenceChildVector(idx_t idx, std::shared_ptr vectorToReference) { + childrenVectors[idx] = std::move(vectorToReference); + } + const std::vector>& getFieldVectors() const { + return childrenVectors; + } + std::shared_ptr getFieldVectorShared(idx_t idx) const { + return childrenVectors[idx]; + } + ValueVector* getFieldVectorPtr(idx_t idx) const { return childrenVectors[idx].get(); } + +private: + std::vector> childrenVectors; +}; + +// ListVector layout: +// To store a list value in the valueVector, we could use two separate vectors. +// 1. A vector(called offset vector) for the list offsets and length(called list_entry_t): This +// vector contains the starting indices and length for each list within the data vector. +// 2. A data vector(called dataVector) to store the actual list elements: This vector holds the +// actual elements of the lists in a flat, continuous storage. Each list would be represented as a +// contiguous subsequence of elements in this vector. 
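+// A minimal sketch of how the two vectors cooperate when writing a list value (illustrative; it
+// uses the ListVector helpers declared in value_vector.h, and listVector/pos are assumed names):
+//   auto entry = ListVector::addList(listVector, 3 /* listSize */); // reserves 3 slots in dataVector
+//   auto* dataVector = ListVector::getDataVector(listVector);
+//   for (auto i = 0u; i < entry.size; i++) {
+//       dataVector->setValue<int64_t>(entry.offset + i, i);         // the actual list elements
+//   }
+//   listVector->setValue(pos, entry);                               // {offset, size} goes into the offset vector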
+class LBUG_API ListAuxiliaryBuffer : public AuxiliaryBuffer { + friend class ListVector; + +public: + ListAuxiliaryBuffer(const LogicalType& dataVectorType, storage::MemoryManager* memoryManager); + + void setDataVector(std::shared_ptr vector) { dataVector = std::move(vector); } + ValueVector* getDataVector() const { return dataVector.get(); } + std::shared_ptr getSharedDataVector() const { return dataVector; } + + list_entry_t addList(list_size_t listSize); + + uint64_t getSize() const { return size; } + + void resetSize() { size = 0; } + + void resize(uint64_t numValues); + +private: + void resizeDataVector(ValueVector* dataVector); + + void resizeStructDataVector(ValueVector* dataVector); + +private: + uint64_t capacity; + uint64_t size; + + std::shared_ptr dataVector; +}; + +class AuxiliaryBufferFactory { +public: + static std::unique_ptr getAuxiliaryBuffer(LogicalType& type, + storage::MemoryManager* memoryManager); +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/vector/value_vector.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/vector/value_vector.h new file mode 100644 index 0000000000..fc695d0027 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/vector/value_vector.h @@ -0,0 +1,352 @@ +#pragma once + +#include +#include + +#include "common/assert.h" +#include "common/cast.h" +#include "common/copy_constructors.h" +#include "common/data_chunk/data_chunk_state.h" +#include "common/null_mask.h" +#include "common/types/ku_string.h" +#include "common/vector/auxiliary_buffer.h" + +namespace lbug { +namespace common { + +class Value; + +//! A Vector represents values of the same data type. +//! The capacity of a ValueVector is either 1 (sequence) or DEFAULT_VECTOR_CAPACITY. +class LBUG_API ValueVector { + friend class ListVector; + friend class ListAuxiliaryBuffer; + friend class StructVector; + friend class StringVector; + friend class ArrowColumnVector; + +public: + explicit ValueVector(LogicalType dataType, storage::MemoryManager* memoryManager = nullptr, + std::shared_ptr dataChunkState = nullptr); + explicit ValueVector(LogicalTypeID dataTypeID, storage::MemoryManager* memoryManager = nullptr) + : ValueVector(LogicalType(dataTypeID), memoryManager) { + KU_ASSERT(dataTypeID != LogicalTypeID::LIST); + } + + DELETE_COPY_AND_MOVE(ValueVector); + ~ValueVector() = default; + + template + std::optional firstNonNull() const { + sel_t selectedSize = state->getSelSize(); + if (selectedSize == 0) { + return std::nullopt; + } + if (hasNoNullsGuarantee()) { + return getValue(state->getSelVector()[0]); + } else { + for (size_t i = 0; i < selectedSize; i++) { + auto pos = state->getSelVector()[i]; + if (!isNull(pos)) { + return std::make_optional(getValue(pos)); + } + } + } + return std::nullopt; + } + + template + void forEachNonNull(Func&& func) const { + if (hasNoNullsGuarantee()) { + state->getSelVector().forEach(func); + } else { + state->getSelVector().forEach([&](auto i) { + if (!isNull(i)) { + func(i); + } + }); + } + } + + uint32_t countNonNull() const; + + void setState(const std::shared_ptr& state_); + + void setAllNull() { nullMask.setAllNull(); } + void setAllNonNull() { nullMask.setAllNonNull(); } + // On return true, there are no null. On return false, there may or may not be nulls. 
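+    // e.g. (illustrative) callers treat it purely as a fast path:
+    //   if (vector.hasNoNullsGuarantee()) { /* skip per-position isNull() checks */ }
+    //   else { /* fall back to checking isNull(pos) for each selected position */ }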
+ bool hasNoNullsGuarantee() const { return nullMask.hasNoNullsGuarantee(); } + void setNullRange(uint32_t startPos, uint32_t len, bool value) { + nullMask.setNullFromRange(startPos, len, value); + } + const NullMask& getNullMask() const { return nullMask; } + void setNull(uint32_t pos, bool isNull); + uint8_t isNull(uint32_t pos) const { return nullMask.isNull(pos); } + void setAsSingleNullEntry() { + state->getSelVectorUnsafe().setSelSize(1); + setNull(state->getSelVector()[0], true); + } + + bool setNullFromBits(const uint64_t* srcNullEntries, uint64_t srcOffset, uint64_t dstOffset, + uint64_t numBitsToCopy, bool invert = false); + + uint32_t getNumBytesPerValue() const { return numBytesPerValue; } + + // TODO(Guodong): Rename this to getValueRef + template + const T& getValue(uint32_t pos) const { + return ((T*)valueBuffer.get())[pos]; + } + template + T& getValue(uint32_t pos) { + return ((T*)valueBuffer.get())[pos]; + } + template + void setValue(uint32_t pos, T val); + // copyFromRowData assumes rowData is non-NULL. + void copyFromRowData(uint32_t pos, const uint8_t* rowData); + // copyToRowData assumes srcVectorData is non-NULL. + void copyToRowData(uint32_t pos, uint8_t* rowData, + InMemOverflowBuffer* rowOverflowBuffer) const; + // copyFromVectorData assumes srcVectorData is non-NULL. + void copyFromVectorData(uint8_t* dstData, const ValueVector* srcVector, + const uint8_t* srcVectorData); + void copyFromVectorData(uint64_t dstPos, const ValueVector* srcVector, uint64_t srcPos); + void copyFromValue(uint64_t pos, const Value& value); + + std::unique_ptr getAsValue(uint64_t pos) const; + + uint8_t* getData() const { return valueBuffer.get(); } + + offset_t readNodeOffset(uint32_t pos) const { + KU_ASSERT(dataType.getLogicalTypeID() == LogicalTypeID::INTERNAL_ID); + return getValue(pos).offset; + } + + void resetAuxiliaryBuffer(); + + // If there is still non-null values after discarding, return true. Otherwise, return false. + // For an unflat vector, its selection vector is also updated to the resultSelVector. + static bool discardNull(ValueVector& vector); + + void serialize(Serializer& ser) const; + static std::unique_ptr deSerialize(Deserializer& deSer, storage::MemoryManager* mm, + std::shared_ptr dataChunkState); + + SelectionVector* getSelVectorPtr() const { + return state ? 
&state->getSelVectorUnsafe() : nullptr; + } + +private: + uint32_t getDataTypeSize(const LogicalType& type); + void initializeValueBuffer(); + +public: + LogicalType dataType; + std::shared_ptr state; + +private: + std::unique_ptr valueBuffer; + NullMask nullMask; + uint32_t numBytesPerValue; + std::unique_ptr auxiliaryBuffer; +}; + +class LBUG_API StringVector { +public: + static inline InMemOverflowBuffer* getInMemOverflowBuffer(ValueVector* vector) { + KU_ASSERT(vector->dataType.getPhysicalType() == PhysicalTypeID::STRING); + return ku_dynamic_cast(vector->auxiliaryBuffer.get()) + ->getOverflowBuffer(); + } + + static void addString(ValueVector* vector, uint32_t vectorPos, ku_string_t& srcStr); + static void addString(ValueVector* vector, uint32_t vectorPos, const char* srcStr, + uint64_t length); + static void addString(ValueVector* vector, uint32_t vectorPos, std::string_view srcStr); + // Add empty string with space reserved for the provided size + // Returned value can be modified to set the string contents + static ku_string_t& reserveString(ValueVector* vector, uint32_t vectorPos, uint64_t length); + static void reserveString(ValueVector* vector, ku_string_t& dstStr, uint64_t length); + static void addString(ValueVector* vector, ku_string_t& dstStr, ku_string_t& srcStr); + static void addString(ValueVector* vector, ku_string_t& dstStr, const char* srcStr, + uint64_t length); + static void addString(lbug::common::ValueVector* vector, ku_string_t& dstStr, + const std::string& srcStr); + static void copyToRowData(const ValueVector* vector, uint32_t pos, uint8_t* rowData, + InMemOverflowBuffer* rowOverflowBuffer); +}; + +struct LBUG_API BlobVector { + static void addBlob(ValueVector* vector, uint32_t pos, const char* data, uint32_t length) { + StringVector::addString(vector, pos, data, length); + } // namespace common + static void addBlob(ValueVector* vector, uint32_t pos, const uint8_t* data, uint64_t length) { + StringVector::addString(vector, pos, reinterpret_cast(data), length); + } +}; // namespace lbug + +// ListVector is used for both LIST and ARRAY physical type +class LBUG_API ListVector { +public: + static const ListAuxiliaryBuffer& getAuxBuffer(const ValueVector& vector) { + return vector.auxiliaryBuffer->constCast(); + } + static ListAuxiliaryBuffer& getAuxBufferUnsafe(const ValueVector& vector) { + return vector.auxiliaryBuffer->cast(); + } + // If you call setDataVector during initialize, there must be a followed up + // copyListEntryAndBufferMetaData at runtime. 
+ // TODO(Xiyang): try to merge setDataVector & copyListEntryAndBufferMetaData + static void setDataVector(const ValueVector* vector, std::shared_ptr dataVector) { + KU_ASSERT(validateType(*vector)); + auto& listBuffer = getAuxBufferUnsafe(*vector); + listBuffer.setDataVector(std::move(dataVector)); + } + static void copyListEntryAndBufferMetaData(ValueVector& vector, + const SelectionVector& selVector, const ValueVector& other, + const SelectionVector& otherSelVector); + static ValueVector* getDataVector(const ValueVector* vector) { + KU_ASSERT(validateType(*vector)); + return getAuxBuffer(*vector).getDataVector(); + } + static std::shared_ptr getSharedDataVector(const ValueVector* vector) { + KU_ASSERT(validateType(*vector)); + return getAuxBuffer(*vector).getSharedDataVector(); + } + static uint64_t getDataVectorSize(const ValueVector* vector) { + KU_ASSERT(validateType(*vector)); + return getAuxBuffer(*vector).getSize(); + } + static uint8_t* getListValues(const ValueVector* vector, const list_entry_t& listEntry) { + KU_ASSERT(validateType(*vector)); + auto dataVector = getDataVector(vector); + return dataVector->getData() + dataVector->getNumBytesPerValue() * listEntry.offset; + } + static uint8_t* getListValuesWithOffset(const ValueVector* vector, + const list_entry_t& listEntry, offset_t elementOffsetInList) { + KU_ASSERT(validateType(*vector)); + return getListValues(vector, listEntry) + + elementOffsetInList * getDataVector(vector)->getNumBytesPerValue(); + } + static list_entry_t addList(ValueVector* vector, uint64_t listSize) { + KU_ASSERT(validateType(*vector)); + return getAuxBufferUnsafe(*vector).addList(listSize); + } + static void resizeDataVector(ValueVector* vector, uint64_t numValues) { + KU_ASSERT(validateType(*vector)); + getAuxBufferUnsafe(*vector).resize(numValues); + } + + static void copyFromRowData(ValueVector* vector, uint32_t pos, const uint8_t* rowData); + static void copyToRowData(const ValueVector* vector, uint32_t pos, uint8_t* rowData, + InMemOverflowBuffer* rowOverflowBuffer); + static void copyFromVectorData(ValueVector* dstVector, uint8_t* dstData, + const ValueVector* srcVector, const uint8_t* srcData); + static void appendDataVector(ValueVector* dstVector, ValueVector* srcDataVector, + uint64_t numValuesToAppend); + static void sliceDataVector(ValueVector* vectorToSlice, uint64_t offset, uint64_t numValues); + +private: + static bool validateType(const ValueVector& vector) { + switch (vector.dataType.getPhysicalType()) { + case PhysicalTypeID::LIST: + case PhysicalTypeID::ARRAY: + return true; + default: + return false; + } + } +}; + +class StructVector { +public: + static const std::vector>& getFieldVectors( + const ValueVector* vector) { + return ku_dynamic_cast(vector->auxiliaryBuffer.get()) + ->getFieldVectors(); + } + + static std::shared_ptr getFieldVector(const ValueVector* vector, + struct_field_idx_t idx) { + return ku_dynamic_cast(vector->auxiliaryBuffer.get()) + ->getFieldVectorShared(idx); + } + + static ValueVector* getFieldVectorRaw(const ValueVector& vector, const std::string& fieldName) { + auto idx = StructType::getFieldIdx(vector.dataType, fieldName); + return ku_dynamic_cast(vector.auxiliaryBuffer.get()) + ->getFieldVectorPtr(idx); + } + + static void referenceVector(ValueVector* vector, struct_field_idx_t idx, + std::shared_ptr vectorToReference) { + ku_dynamic_cast(vector->auxiliaryBuffer.get()) + ->referenceChildVector(idx, std::move(vectorToReference)); + } + + static void copyFromRowData(ValueVector* vector, uint32_t pos, const 
uint8_t* rowData); + static void copyToRowData(const ValueVector* vector, uint32_t pos, uint8_t* rowData, + InMemOverflowBuffer* rowOverflowBuffer); + static void copyFromVectorData(ValueVector* dstVector, const uint8_t* dstData, + const ValueVector* srcVector, const uint8_t* srcData); +}; + +class UnionVector { +public: + static inline ValueVector* getTagVector(const ValueVector* vector) { + KU_ASSERT(vector->dataType.getLogicalTypeID() == LogicalTypeID::UNION); + return StructVector::getFieldVector(vector, UnionType::TAG_FIELD_IDX).get(); + } + + static inline ValueVector* getValVector(const ValueVector* vector, union_field_idx_t fieldIdx) { + KU_ASSERT(vector->dataType.getLogicalTypeID() == LogicalTypeID::UNION); + return StructVector::getFieldVector(vector, UnionType::getInternalFieldIdx(fieldIdx)).get(); + } + + static inline std::shared_ptr getSharedValVector(const ValueVector* vector, + union_field_idx_t fieldIdx) { + KU_ASSERT(vector->dataType.getLogicalTypeID() == LogicalTypeID::UNION); + return StructVector::getFieldVector(vector, UnionType::getInternalFieldIdx(fieldIdx)); + } + + static inline void referenceVector(ValueVector* vector, union_field_idx_t fieldIdx, + std::shared_ptr vectorToReference) { + StructVector::referenceVector(vector, UnionType::getInternalFieldIdx(fieldIdx), + std::move(vectorToReference)); + } + + static inline void setTagField(ValueVector& vector, SelectionVector& sel, + union_field_idx_t tag) { + KU_ASSERT(vector.dataType.getLogicalTypeID() == LogicalTypeID::UNION); + for (auto i = 0u; i < sel.getSelSize(); i++) { + vector.setValue(sel[i], tag); + } + } +}; + +class MapVector { +public: + static inline ValueVector* getKeyVector(const ValueVector* vector) { + return StructVector::getFieldVector(ListVector::getDataVector(vector), 0 /* keyVectorPos */) + .get(); + } + + static inline ValueVector* getValueVector(const ValueVector* vector) { + return StructVector::getFieldVector(ListVector::getDataVector(vector), 1 /* valVectorPos */) + .get(); + } + + static inline uint8_t* getMapKeys(const ValueVector* vector, const list_entry_t& listEntry) { + auto keyVector = getKeyVector(vector); + return keyVector->getData() + keyVector->getNumBytesPerValue() * listEntry.offset; + } + + static inline uint8_t* getMapValues(const ValueVector* vector, const list_entry_t& listEntry) { + auto valueVector = getValueVector(vector); + return valueVector->getData() + valueVector->getNumBytesPerValue() * listEntry.offset; + } +}; + +} // namespace common +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/windows_utils.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/windows_utils.h new file mode 100644 index 0000000000..bf2466d707 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/common/windows_utils.h @@ -0,0 +1,18 @@ +#pragma once + +#if defined(_WIN32) +#include + +#include "windows.h" + +namespace lbug { +namespace common { + +struct WindowsUtils { + static std::wstring utf8ToUnicode(const char* input); + static std::string unicodeToUTF8(LPCWSTR input); +}; + +} // namespace common +} // namespace lbug +#endif diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/expression_evaluator/case_evaluator.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/expression_evaluator/case_evaluator.h new file mode 100644 index 0000000000..90b2b27fcd --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/expression_evaluator/case_evaluator.h @@ -0,0 +1,79 @@ +#pragma once + +#include + +#include "binder/expression/expression.h" 
+#include "common/system_config.h" +#include "expression_evaluator.h" + +namespace lbug { +namespace main { +class ClientContext; +} + +namespace evaluator { + +struct CaseAlternativeEvaluator { + std::unique_ptr whenEvaluator; + std::unique_ptr thenEvaluator; + std::unique_ptr whenSelVector; + + CaseAlternativeEvaluator(std::unique_ptr whenEvaluator, + std::unique_ptr thenEvaluator) + : whenEvaluator{std::move(whenEvaluator)}, thenEvaluator{std::move(thenEvaluator)} {} + EXPLICIT_COPY_DEFAULT_MOVE(CaseAlternativeEvaluator); + + void init(const processor::ResultSet& resultSet, main::ClientContext* clientContext); + +private: + CaseAlternativeEvaluator(const CaseAlternativeEvaluator& other) + : whenEvaluator{other.whenEvaluator->copy()}, thenEvaluator{other.thenEvaluator->copy()} {} +}; + +class CaseExpressionEvaluator : public ExpressionEvaluator { + static constexpr EvaluatorType type_ = EvaluatorType::CASE_ELSE; + +public: + CaseExpressionEvaluator(std::shared_ptr expression, + std::vector alternativeEvaluators, + std::unique_ptr elseEvaluator) + : ExpressionEvaluator{type_, std::move(expression)}, + alternativeEvaluators{std::move(alternativeEvaluators)}, + elseEvaluator{std::move(elseEvaluator)} {} + + const std::vector& getAlternativeEvaluators() const { + return alternativeEvaluators; + } + ExpressionEvaluator* getElseEvaluator() const { return elseEvaluator.get(); } + + void init(const processor::ResultSet& resultSet, main::ClientContext* clientContext) override; + + void evaluate() override; + + bool selectInternal(common::SelectionVector& selVector) override; + + std::unique_ptr copy() override { + return std::make_unique(expression, + copyVector(alternativeEvaluators), elseEvaluator->copy()); + } + +protected: + void resolveResultVector(const processor::ResultSet& resultSet, + storage::MemoryManager* memoryManager) override; + +private: + void fillSelected(const common::SelectionVector& selVector, common::ValueVector* srcVector); + + void fillAll(common::ValueVector* srcVector); + + void fillEntry(common::sel_t resultPos, common::ValueVector* srcVector); + +private: + std::vector alternativeEvaluators; + std::unique_ptr elseEvaluator; + + std::bitset filledMask; +}; + +} // namespace evaluator +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/expression_evaluator/expression_evaluator.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/expression_evaluator/expression_evaluator.h new file mode 100644 index 0000000000..a538035ef6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/expression_evaluator/expression_evaluator.h @@ -0,0 +1,121 @@ +#pragma once + +#include "processor/result/result_set.h" + +namespace lbug { +namespace binder { +class Expression; +} +namespace evaluator { + +struct EvaluatorLocalState { + main::ClientContext* clientContext = nullptr; +}; + +enum class EvaluatorType : uint8_t { + CASE_ELSE = 0, + FUNCTION = 1, + LAMBDA_PARAM = 2, + LIST_LAMBDA = 3, + LITERAL = 4, + PATH = 5, + NODE_REL = 6, + REFERENCE = 8, +}; + +class ExpressionEvaluator; +using evaluator_vector_t = std::vector>; + +class ExpressionEvaluator { +public: + explicit ExpressionEvaluator(EvaluatorType type, std::shared_ptr expression) + : type{type}, expression{std::move(expression)} {}; + ExpressionEvaluator(EvaluatorType type, std::shared_ptr expression, + bool isResultFlat) + : type{type}, expression{std::move(expression)}, isResultFlat_{isResultFlat} {} + ExpressionEvaluator(EvaluatorType type, std::shared_ptr expression, + evaluator_vector_t children) + 
: type{type}, expression{std::move(expression)}, children{std::move(children)} {} + ExpressionEvaluator(const ExpressionEvaluator& other) + : type{other.type}, expression{other.expression}, isResultFlat_{other.isResultFlat_}, + children{copyVector(other.children)} {} + virtual ~ExpressionEvaluator() = default; + + EvaluatorType getEvaluatorType() const { return type; } + + std::shared_ptr getExpression() const { return expression; } + bool isResultFlat() const { return isResultFlat_; } + + const evaluator_vector_t& getChildren() const { return children; } + + virtual void init(const processor::ResultSet& resultSet, main::ClientContext* clientContext); + + virtual void evaluate() = 0; + // Evaluate and duplicate result for count times. This is a fast path we implemented for + // bulk-insert when evaluate default values. A default value should be + // - a constant (after folding); or + // - a nextVal() function for serial column + virtual void evaluate(common::sel_t count); + + bool select(common::SelectionVector& selVector, bool shouldSetSelVectorToFiltered); + + virtual std::unique_ptr copy() = 0; + + template + const TARGET& constCast() const { + return common::ku_dynamic_cast(*this); + } + template + TARGET& cast() { + return common::ku_dynamic_cast(*this); + } + template + TARGET* ptrCast() { + return common::ku_dynamic_cast(this); + } + +protected: + virtual void resolveResultVector(const processor::ResultSet& resultSet, + storage::MemoryManager* memoryManager) = 0; + + void resolveResultStateFromChildren(const std::vector& inputEvaluators); + + virtual bool selectInternal(common::SelectionVector& selVector) = 0; + + bool updateSelectedPos(common::SelectionVector& selVector) const { + auto& resultSelVector = resultVector->state->getSelVector(); + if (resultSelVector.getSelSize() > 1) { + auto numSelectedValues = 0u; + for (auto i = 0u; i < resultSelVector.getSelSize(); ++i) { + auto pos = resultSelVector[i]; + auto selectedPosBuffer = selVector.getMutableBuffer(); + selectedPosBuffer[numSelectedValues] = pos; + numSelectedValues += + resultVector->isNull(pos) ? 0 : resultVector->getValue(pos); + } + selVector.setSelSize(numSelectedValues); + return numSelectedValues > 0; + } else { + // If result state is flat (i.e. all children are flat), we shouldn't try to update + // selectedPos because we don't know which one is leading, i.e. the one being selected + // by filter. + // So we forget about selectedPos and directly return true/false. This doesn't change + // the correctness, because when all children are flat the check is done on tuple. + auto pos = resultVector->state->getSelVector()[0]; + return resultVector->isNull(pos) ? 
0 : resultVector->getValue(pos); + } + } + +public: + std::shared_ptr resultVector; + +protected: + EvaluatorType type; + std::shared_ptr expression; + bool isResultFlat_ = true; + evaluator_vector_t children; + EvaluatorLocalState localState; +}; + +} // namespace evaluator +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/expression_evaluator/expression_evaluator_utils.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/expression_evaluator/expression_evaluator_utils.h new file mode 100644 index 0000000000..8916428228 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/expression_evaluator/expression_evaluator_utils.h @@ -0,0 +1,15 @@ +#pragma once + +#include "binder/expression/expression.h" +#include "common/types/value/value.h" + +namespace lbug { +namespace evaluator { + +struct ExpressionEvaluatorUtils { + static LBUG_API common::Value evaluateConstantExpression( + std::shared_ptr expression, main::ClientContext* clientContext); +}; + +} // namespace evaluator +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/expression_evaluator/expression_evaluator_visitor.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/expression_evaluator/expression_evaluator_visitor.h new file mode 100644 index 0000000000..67f6d9fecc --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/expression_evaluator/expression_evaluator_visitor.h @@ -0,0 +1,43 @@ +#pragma once + +#include "expression_evaluator.h" + +namespace lbug { +namespace evaluator { + +class ExpressionEvaluatorVisitor { +public: + virtual ~ExpressionEvaluatorVisitor() = default; + +protected: + void visitSwitch(ExpressionEvaluator* evaluator); + + virtual void visitCase(ExpressionEvaluator*) {} + virtual void visitFunction(ExpressionEvaluator*) {} + virtual void visitLambdaParam(ExpressionEvaluator*) {} + virtual void visitListLambda(ExpressionEvaluator*) {} + virtual void visitLiteral(ExpressionEvaluator*) {} + virtual void visitPath(ExpressionEvaluator*) {} + virtual void visitReference(ExpressionEvaluator*) {} + // NOTE: If one decides to overwrite pattern evaluator visitor, make sure we differentiate + // pattern evaluator and undirected rel evaluator. 
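+    // e.g. (illustrative) a handler for pattern evaluators could branch on the concrete type:
+    //   if (auto* undirected = dynamic_cast<UndirectedRelExpressionEvaluator*>(evaluator)) {
+    //       /* undirected rel handling */
+    //   } else {
+    //       /* plain node/rel pattern handling */
+    //   }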
+ void visitPattern(ExpressionEvaluator*) {} +}; + +class LambdaParamEvaluatorCollector final : public ExpressionEvaluatorVisitor { +public: + void visit(ExpressionEvaluator* evaluator); + + std::vector getEvaluators() const { return evaluators; } + +protected: + void visitLambdaParam(ExpressionEvaluator* evaluator) override { + evaluators.push_back(evaluator); + } + +private: + std::vector evaluators; +}; + +} // namespace evaluator +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/expression_evaluator/function_evaluator.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/expression_evaluator/function_evaluator.h new file mode 100644 index 0000000000..371b5b3911 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/expression_evaluator/function_evaluator.h @@ -0,0 +1,38 @@ +#pragma once + +#include "expression_evaluator.h" +#include "function/scalar_function.h" + +namespace lbug { +namespace evaluator { + +class FunctionExpressionEvaluator : public ExpressionEvaluator { + static constexpr EvaluatorType type_ = EvaluatorType::FUNCTION; + +public: + FunctionExpressionEvaluator(std::shared_ptr expression, + std::vector> children); + + void evaluate() override; + void evaluate(common::sel_t count) override; + + bool selectInternal(common::SelectionVector& selVector) override; + + std::unique_ptr copy() override { + return std::make_unique(expression, copyVector(children)); + } + +protected: + void resolveResultVector(const processor::ResultSet& resultSet, + storage::MemoryManager* memoryManager) override; + + void runExecFunc(void* dataPtr = nullptr); + +private: + std::vector> parameters; + std::unique_ptr function; + std::unique_ptr bindData; +}; + +} // namespace evaluator +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/expression_evaluator/lambda_evaluator.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/expression_evaluator/lambda_evaluator.h new file mode 100644 index 0000000000..7750b0648d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/expression_evaluator/lambda_evaluator.h @@ -0,0 +1,101 @@ +#pragma once + +#include "binder/expression/scalar_function_expression.h" +#include "expression_evaluator.h" + +namespace lbug { +namespace evaluator { + +class ListSliceInfo; +class ListEntryTracker; + +enum class ListLambdaType : uint8_t { + LIST_TRANSFORM = 0, + LIST_FILTER = 1, + LIST_REDUCE = 2, + DEFAULT = 3 +}; + +class LambdaParamEvaluator : public ExpressionEvaluator { + static constexpr EvaluatorType type_ = EvaluatorType::LAMBDA_PARAM; + +public: + explicit LambdaParamEvaluator(std::shared_ptr expression) + : ExpressionEvaluator{type_, std::move(expression), false /* isResultFlat */} {} + + void evaluate() override {} + + bool selectInternal(common::SelectionVector&) override { KU_UNREACHABLE; } + + std::unique_ptr copy() override { + return std::make_unique(expression); + } + + std::string getVarName() { return this->getExpression()->toString(); } + +protected: + void resolveResultVector(const processor::ResultSet&, storage::MemoryManager*) override {} +}; + +struct ListLambdaBindData { + std::vector lambdaParamEvaluators; + std::vector paramIndices; + ExpressionEvaluator* rootEvaluator = nullptr; + ListSliceInfo* sliceInfo = nullptr; +}; + +// E.g. 
for function list_transform([0,1,2], x->x+1) +// ListLambdaEvaluator has one child that is the evaluator of [0,1,2] +// lambdaRootEvaluator is the evaluator of x+1 +// lambdaParamEvaluator is the evaluator of x +class ListLambdaEvaluator : public ExpressionEvaluator { + static constexpr EvaluatorType type_ = EvaluatorType::LIST_LAMBDA; + static ListLambdaType checkListLambdaTypeWithFunctionName(std::string functionName); + +public: + ListLambdaEvaluator(std::shared_ptr expression, evaluator_vector_t children) + : ExpressionEvaluator{type_, expression, std::move(children)}, memoryManager(nullptr) { + execFunc = expression->constCast().getFunction().execFunc; + listLambdaType = checkListLambdaTypeWithFunctionName( + expression->constCast().getFunction().name); + } + + void setLambdaRootEvaluator(std::unique_ptr evaluator) { + lambdaRootEvaluator = std::move(evaluator); + } + + void init(const processor::ResultSet& resultSet, main::ClientContext* clientContext) override; + + void evaluate() override; + + bool selectInternal(common::SelectionVector& selVector) override; + + std::unique_ptr copy() override { + auto result = std::make_unique(expression, copyVector(children)); + result->setLambdaRootEvaluator(lambdaRootEvaluator->copy()); + return result; + } + + std::vector getParamIndices(); + +protected: + void resolveResultVector(const processor::ResultSet& resultSet, + storage::MemoryManager* memoryManager) override; + +private: + void evaluateInternal(); + + function::scalar_func_exec_t execFunc; + ListLambdaBindData bindData; + +private: + std::unique_ptr lambdaRootEvaluator; + std::vector lambdaParamEvaluators; + std::vector> params; + ListLambdaType listLambdaType; + + storage::MemoryManager* memoryManager; +}; + +} // namespace evaluator +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/expression_evaluator/list_slice_info.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/expression_evaluator/list_slice_info.h new file mode 100644 index 0000000000..62ac1b5bd4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/expression_evaluator/list_slice_info.h @@ -0,0 +1,111 @@ +#pragma once + +#include "common/vector/value_vector.h" + +namespace lbug::evaluator { + +class LambdaParamEvaluator; + +class ListEntryTracker { +public: + explicit ListEntryTracker(common::ValueVector* listVector); + + common::offset_t getCurDataOffset() const { return getCurListEntry().offset + offsetInList; } + common::offset_t getNextDataOffset(); + common::list_entry_t getCurListEntry() const { + return listVector->getValue(getListEntryPos()); + } + common::idx_t getListEntryPos() const { return listEntries[listEntryIdx]; } + + bool done() const { return listEntryIdx >= listEntries.size(); } + +private: + void updateListEntry(); + + common::ValueVector* listVector; + common::idx_t listEntryIdx; + common::offset_t offsetInList; + + // selected pos of each list entry + std::vector listEntries; +}; + +/** + * List data vectors can their number of elements exceed DEFAULT_VECTOR_CAPACITY + * However, most expression evaluators can only process elements in batches of size + * DEFAULT_VECTOR_CAPACITY + * This means that in order for lambda evaluators to work it must pass in its data in slices of size + * DEFAULT_VECTOR_CAPACITY for processing by child evaluators + * + * A consequence of this is that some lists may have their data vectors split into different slices + * and thus it is unreasonable to have execFuncs operate on a list-by-list basis. 
Instead, they + * should operate on each data vector entry individually. + * + * Instead, any lambda execFunc should follow this pattern using the ListSliceInfo struct + * void execFunc(...) { + * auto& sliceInfo = *listLambdaBindData->sliceInfo; + * // loop through each data vector entry in the slice + * for (sel_t i = 0; i < sliceInfo.getSliceSize(); ++i) { + * // dataOffset: the offset of the current entry in the input data vector + * // listEntryPos: the pos of the list entry containing to the data entry in the list vector + * const auto [listEntryPos, dataOffset] = sliceInfo.getPos(i); + * doSomething(listEntryPos, dataOffset); + * } + * + * // do any final processing required on each list entry vector + * // only do this once for all slices + * if (sliceInfo.done()) { + * for (uint64_t i = 0; i < inputSelVector.getSelSize(); ++i) { + * auto pos = inputSelVector[i]; + * doSomething(inputVector, pos); + * } + * } + */ +class ListSliceInfo { +public: + explicit ListSliceInfo(common::ValueVector* listVector) + : resultSliceOffset(0), listEntryTracker(listVector), + sliceDataState(std::make_shared()), + sliceListEntryState(std::make_shared()) { + sliceDataState->setToUnflat(); + sliceDataState->getSelVectorUnsafe().setToFiltered(); + sliceListEntryState->setToUnflat(); + sliceListEntryState->getSelVectorUnsafe().setToFiltered(); + } + + void nextSlice(); + + std::vector> overrideAndSaveParamStates( + std::span lambdaParamEvaluators); + static void restoreParamStates(std::span lambdaParamEvaluators, + std::vector> savedStates); + + // use in cases (like list filter) where the output data offset may not correspond to the input + // data offset + common::offset_t& getResultSliceOffset() { return resultSliceOffset; } + + bool done() const; + + common::sel_t getSliceSize() const { + KU_ASSERT(sliceDataState->getSelSize() == sliceListEntryState->getSelSize()); + return sliceDataState->getSelSize(); + } + + // returns {list entry pos, data pos} + std::pair getPos(common::idx_t i) const { + return {sliceListEntryState->getSelVector()[i], sliceDataState->getSelVector()[i]}; + } + +private: + void updateSelVector(); + + // offset/size refer to the data vector + common::offset_t resultSliceOffset; + + ListEntryTracker listEntryTracker; + + std::shared_ptr sliceDataState; + std::shared_ptr sliceListEntryState; +}; + +} // namespace lbug::evaluator diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/expression_evaluator/literal_evaluator.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/expression_evaluator/literal_evaluator.h new file mode 100644 index 0000000000..7fd5740166 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/expression_evaluator/literal_evaluator.h @@ -0,0 +1,38 @@ +#pragma once + +#include "common/types/value/value.h" +#include "expression_evaluator.h" + +namespace lbug { +namespace evaluator { + +class LiteralExpressionEvaluator : public ExpressionEvaluator { + static constexpr EvaluatorType type_ = EvaluatorType::LITERAL; + +public: + LiteralExpressionEvaluator(std::shared_ptr expression, common::Value value) + : ExpressionEvaluator{type_, std::move(expression), true /* isResultFlat */}, + value{std::move(value)} {} + + void evaluate() override; + + void evaluate(common::sel_t count) override; + + bool selectInternal(common::SelectionVector& selVector) override; + + std::unique_ptr copy() override { + return std::make_unique(expression, value); + } + +protected: + void resolveResultVector(const processor::ResultSet& resultSet, + storage::MemoryManager* 
memoryManager) override; + +private: + common::Value value; + std::shared_ptr flatState; + std::shared_ptr unFlatState; +}; + +} // namespace evaluator +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/expression_evaluator/path_evaluator.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/expression_evaluator/path_evaluator.h new file mode 100644 index 0000000000..0ab902e704 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/expression_evaluator/path_evaluator.h @@ -0,0 +1,69 @@ +#pragma once + +#include "binder/expression/expression.h" +#include "expression_evaluator.h" + +namespace lbug { +namespace main { +class ClientContext; +} + +namespace evaluator { + +class PathExpressionEvaluator final : public ExpressionEvaluator { + static constexpr EvaluatorType type_ = EvaluatorType::PATH; + +public: + PathExpressionEvaluator(std::shared_ptr expression, + evaluator_vector_t children) + : ExpressionEvaluator{type_, std::move(expression), std::move(children)}, + resultNodesVector(nullptr), resultRelsVector(nullptr) {} + + void init(const processor::ResultSet& resultSet, main::ClientContext* clientContext) override; + + void evaluate() override; + + bool selectInternal(common::SelectionVector& /*selVector*/) override { KU_UNREACHABLE; } + + std::unique_ptr copy() override { + return make_unique(expression, copyVector(children)); + } + +private: + struct InputVectors { + // input can either be NODE, REL or RECURSIVE_REL + common::ValueVector* input = nullptr; + // nodesInput is LIST[NODE] for RECURSIVE_REL input and nullptr otherwise + common::ValueVector* nodesInput = nullptr; + // nodesDataInput is NODE for RECURSIVE_REL and nullptr otherwise + common::ValueVector* nodesDataInput = nullptr; + // relsInput is LIST[REL] for RECURSIVE_REL input and nullptr otherwise + common::ValueVector* relsInput = nullptr; + // relsDataInput is REL for RECURSIVE_REL input and nullptr otherwise + common::ValueVector* relsDataInput = nullptr; + + std::vector nodeFieldVectors; + std::vector relFieldVectors; + }; + + void resolveResultVector(const processor::ResultSet& resultSet, + storage::MemoryManager* memoryManager) override; + + void copyNodes(common::sel_t resultPos, bool isEmptyRels); + uint64_t copyRels(common::sel_t resultPos); + + void copyFieldVectors(common::offset_t inputVectorPos, + const std::vector& inputFieldVectors, + common::offset_t& resultVectorPos, + const std::vector& resultFieldVectors); + +private: + std::vector> inputVectorsPerChild; + common::ValueVector* resultNodesVector; // LIST[NODE] + common::ValueVector* resultRelsVector; // LIST[REL] + std::vector resultNodesFieldVectors; + std::vector resultRelsFieldVectors; +}; + +} // namespace evaluator +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/expression_evaluator/pattern_evaluator.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/expression_evaluator/pattern_evaluator.h new file mode 100644 index 0000000000..e4d5ed3ff3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/expression_evaluator/pattern_evaluator.h @@ -0,0 +1,61 @@ +#pragma once + +#include "binder/expression/expression.h" +#include "expression_evaluator.h" + +namespace lbug { +namespace evaluator { + +class PatternExpressionEvaluator : public ExpressionEvaluator { + static constexpr EvaluatorType type_ = EvaluatorType::NODE_REL; + +public: + PatternExpressionEvaluator(std::shared_ptr pattern, + evaluator_vector_t children) + : ExpressionEvaluator{type_, std::move(pattern), std::move(children)}, 
idVector{nullptr} {} + + void evaluate() override; + + bool selectInternal(common::SelectionVector&) override { KU_UNREACHABLE; } + + std::unique_ptr copy() override { + return std::make_unique(expression, copyVector(children)); + } + +protected: + void resolveResultVector(const processor::ResultSet& resultSet, + storage::MemoryManager* memoryManager) override; + + virtual void initFurther(const processor::ResultSet& resultSet); + +protected: + common::ValueVector* idVector; + std::vector> parameters; +}; + +class UndirectedRelExpressionEvaluator final : public PatternExpressionEvaluator { +public: + UndirectedRelExpressionEvaluator(std::shared_ptr pattern, + evaluator_vector_t children, std::unique_ptr directionEvaluator) + : PatternExpressionEvaluator{std::move(pattern), std::move(children)}, srcIDVector{nullptr}, + dstIDVector{nullptr}, directionVector{nullptr}, + directionEvaluator{std::move(directionEvaluator)} {} + + void evaluate() override; + + void initFurther(const processor::ResultSet& resultSet) override; + + std::unique_ptr copy() override { + return std::make_unique(expression, copyVector(children), + directionEvaluator->copy()); + } + +private: + common::ValueVector* srcIDVector; + common::ValueVector* dstIDVector; + common::ValueVector* directionVector; + std::unique_ptr directionEvaluator; +}; + +} // namespace evaluator +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/expression_evaluator/reference_evaluator.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/expression_evaluator/reference_evaluator.h new file mode 100644 index 0000000000..57ce5b2bf2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/expression_evaluator/reference_evaluator.h @@ -0,0 +1,39 @@ +#pragma once + +#include "expression_evaluator.h" + +namespace lbug { +namespace main { +class ClientContext; +} + +namespace evaluator { + +class ReferenceExpressionEvaluator : public ExpressionEvaluator { + static constexpr EvaluatorType type_ = EvaluatorType::REFERENCE; + +public: + ReferenceExpressionEvaluator(std::shared_ptr expression, bool isResultFlat, + const processor::DataPos& dataPos) + : ExpressionEvaluator{type_, std::move(expression), isResultFlat}, dataPos{dataPos} {} + + void evaluate() override {} + + bool selectInternal(common::SelectionVector& selVector) override; + + std::unique_ptr copy() override { + return std::make_unique(expression, isResultFlat_, dataPos); + } + +protected: + void resolveResultVector(const processor::ResultSet& resultSet, + storage::MemoryManager*) override { + resultVector = resultSet.getValueVector(dataPos); + } + +private: + processor::DataPos dataPos; +}; + +} // namespace evaluator +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/binder_extension.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/binder_extension.h new file mode 100644 index 0000000000..ea06223f45 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/binder_extension.h @@ -0,0 +1,19 @@ +#pragma once + +#include "binder/bound_statement.h" +#include "parser/statement.h" + +namespace lbug { +namespace extension { + +class LBUG_API BinderExtension { +public: + BinderExtension() {} + + virtual ~BinderExtension() = default; + + virtual std::unique_ptr bind(const parser::Statement& statement) = 0; +}; + +} // namespace extension +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/bound_extension_clause.h 
b/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/bound_extension_clause.h new file mode 100644 index 0000000000..8d38366673 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/bound_extension_clause.h @@ -0,0 +1,24 @@ +#pragma once + +#include "binder/bound_statement.h" + +namespace lbug { +namespace extension { + +class BoundExtensionClause : public binder::BoundStatement { + static constexpr common::StatementType type_ = common::StatementType::EXTENSION_CLAUSE; + +public: + explicit BoundExtensionClause(std::string statementName) + : BoundStatement{type_, binder::BoundStatementResult::createSingleStringColumnResult( + "result" /* columnName */)}, + statementName{std::move(statementName)} {} + + std::string getStatementName() const { return statementName; } + +private: + std::string statementName; +}; + +} // namespace extension +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/catalog_extension.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/catalog_extension.h new file mode 100644 index 0000000000..6ec87b03b7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/catalog_extension.h @@ -0,0 +1,18 @@ +#pragma once + +#include "catalog/catalog.h" + +namespace lbug { +namespace extension { + +class LBUG_API CatalogExtension : public catalog::Catalog { +public: + CatalogExtension() : Catalog() {} + + virtual void init() = 0; + + void invalidateCache(); +}; + +} // namespace extension +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/extension.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/extension.h new file mode 100644 index 0000000000..aac882de25 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/extension.h @@ -0,0 +1,196 @@ +#pragma once + +#include "catalog/catalog.h" +#include "catalog/catalog_entry/catalog_entry_type.h" +#include "common/api.h" +#include "main/database.h" +#include "transaction/transaction.h" + +#define ADD_EXTENSION_OPTION(OPTION) \ + db->addExtensionOption(OPTION::NAME, OPTION::TYPE, OPTION::getDefaultValue()) + +#define ADD_CONFIDENTIAL_EXTENSION_OPTION(OPTION) \ + db->addExtensionOption(OPTION::NAME, OPTION::TYPE, OPTION::getDefaultValue(), true) + +namespace lbug::storage { +struct IndexType; +} +namespace lbug { +namespace function { +struct TableFunction; +} // namespace function + +namespace extension { + +typedef void (*ext_init_func_t)(main::ClientContext*); +typedef const char* (*ext_name_func_t)(); +using ext_load_func_t = ext_init_func_t; +typedef void (*ext_install_func_t)(const std::string&, main::ClientContext&); + +std::string getPlatform(); + +class LBUG_API Extension { +public: + virtual ~Extension() = default; +}; + +struct ExtensionRepoInfo { + std::string hostPath; + std::string hostURL; + std::string repoURL; +}; + +enum class ExtensionSource : uint8_t { OFFICIAL, USER, STATIC_LINKED }; + +struct ExtensionSourceUtils { + static std::string toString(ExtensionSource source); +}; + +template +void addFunc(main::Database& database, std::string name, catalog::CatalogEntryType functionType, + bool isInternal = false) { + auto catalog = database.getCatalog(); + if (catalog->containsFunction(&transaction::DUMMY_TRANSACTION, name, isInternal)) { + return; + } + catalog->addFunction(&transaction::DUMMY_TRANSACTION, functionType, std::move(name), + T::getFunctionSet(), isInternal); +} + +struct LBUG_API ExtensionUtils { + static constexpr const char* OFFICIAL_EXTENSION_REPO = 
"http://extension.ladybugdb.com/"; + static constexpr const char* EXTENSION_FILE_SUFFIX = "lbug_extension"; + + static constexpr const char* EXTENSION_FILE_REPO_PATH = "{}v{}/{}/{}/{}"; + + static constexpr const char* SHARED_LIB_REPO = "{}v{}/{}/common/{}"; + + static constexpr const char* EXTENSION_FILE_NAME = "lib{}.{}"; + + static constexpr const char* OFFICIAL_EXTENSION[] = {"HTTPFS", "POSTGRES", "DUCKDB", "JSON", + "SQLITE", "FTS", "DELTA", "ICEBERG", "AZURE", "UNITY_CATALOG", "VECTOR", "NEO4J", "ALGO", + "LLM"}; + + static constexpr const char* EXTENSION_LOADER_SUFFIX = "_loader"; + + static constexpr const char* EXTENSION_INSTALLER_SUFFIX = "_installer"; + + static ExtensionRepoInfo getExtensionLibRepoInfo(const std::string& extensionName, + const std::string& extensionRepo); + + static ExtensionRepoInfo getExtensionLoaderRepoInfo(const std::string& extensionName, + const std::string& extensionRepo); + + static ExtensionRepoInfo getExtensionInstallerRepoInfo(const std::string& extensionName, + const std::string& extensionRepo); + + static ExtensionRepoInfo getSharedLibRepoInfo(const std::string& fileName, + const std::string& extensionRepo); + + static std::string getExtensionFileName(const std::string& name); + + static std::string getLocalPathForExtensionLib(main::ClientContext* context, + const std::string& extensionName); + + static std::string getLocalPathForExtensionLoader(main::ClientContext* context, + const std::string& extensionName); + + static std::string getLocalPathForExtensionInstaller(main::ClientContext* context, + const std::string& extensionName); + + static std::string getLocalDirForExtension(main::ClientContext* context, + const std::string& extensionName); + + static std::string appendLibSuffix(const std::string& libName); + + static std::string getLocalPathForSharedLib(main::ClientContext* context, + const std::string& libName); + + static std::string getLocalPathForSharedLib(main::ClientContext* context); + + static bool isOfficialExtension(const std::string& extension); + + template + static void addTableFunc(main::Database& database) { + addFunc(database, T::name, catalog::CatalogEntryType::TABLE_FUNCTION_ENTRY); + } + + template + static void addTableFuncAlias(main::Database& database) { + addFunc(database, T::name, + catalog::CatalogEntryType::TABLE_FUNCTION_ENTRY); + } + + template + static void addStandaloneTableFunc(main::Database& database) { + addFunc(database, T::name, catalog::CatalogEntryType::STANDALONE_TABLE_FUNCTION_ENTRY, + false /* isInternal */); + } + template + static void addInternalStandaloneTableFunc(main::Database& database) { + addFunc(database, T::name, catalog::CatalogEntryType::STANDALONE_TABLE_FUNCTION_ENTRY, + true /* isInternal */); + } + + template + static void addScalarFunc(main::Database& database) { + addFunc(database, T::name, catalog::CatalogEntryType::SCALAR_FUNCTION_ENTRY); + } + + template + static void addScalarFuncAlias(main::Database& database) { + addFunc(database, T::name, + catalog::CatalogEntryType::SCALAR_FUNCTION_ENTRY); + } + + template + static void addExportFunc(main::Database& database) { + addFunc(database, T::name, catalog::CatalogEntryType::COPY_FUNCTION_ENTRY); + } + + static void registerIndexType(main::Database& database, storage::IndexType type); +}; + +class LBUG_API ExtensionLibLoader { +public: + static constexpr const char* EXTENSION_LOAD_FUNC_NAME = "load"; + + static constexpr const char* EXTENSION_INIT_FUNC_NAME = "init"; + + static constexpr const char* EXTENSION_NAME_FUNC_NAME = "name"; + 
+ static constexpr const char* EXTENSION_INSTALL_FUNC_NAME = "install"; + +public: + ExtensionLibLoader(const std::string& extensionName, const std::string& path); + + ext_load_func_t getLoadFunc(); + + ext_init_func_t getInitFunc(); + + ext_name_func_t getNameFunc(); + + ext_install_func_t getInstallFunc(); + + void unload(); + +private: + void* getDynamicLibFunc(const std::string& funcName); + +private: + std::string extensionName; + void* libHdl; +}; + +#ifdef _WIN32 +std::wstring utf8ToUnicode(const char* input); + +void* dlopen(const char* file, int /*mode*/); + +void* dlsym(void* handle, const char* name); + +void dlclose(void* handle); +#endif + +} // namespace extension +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/extension_action.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/extension_action.h new file mode 100644 index 0000000000..ec29750bcf --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/extension_action.h @@ -0,0 +1,48 @@ +#pragma once + +#include +#include + +namespace lbug { +namespace extension { + +enum class ExtensionAction : uint8_t { + INSTALL = 0, + LOAD = 1, + UNINSTALL = 2, +}; + +struct ExtensionAuxInfo { + ExtensionAction action; + std::string path; + + ExtensionAuxInfo(ExtensionAction action, std::string path) + : action{action}, path{std::move(path)} {} + + virtual ~ExtensionAuxInfo() = default; + + template + const TARGET& contCast() const { + return dynamic_cast(*this); + } + + virtual std::unique_ptr copy() { + return std::make_unique(*this); + } +}; + +struct InstallExtensionAuxInfo : public ExtensionAuxInfo { + std::string extensionRepo; + bool forceInstall; + + explicit InstallExtensionAuxInfo(std::string extensionRepo, std::string path, bool forceInstall) + : ExtensionAuxInfo{ExtensionAction::INSTALL, std::move(path)}, + extensionRepo{std::move(extensionRepo)}, forceInstall{forceInstall} {} + + std::unique_ptr copy() override { + return std::make_unique(*this); + } +}; + +} // namespace extension +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/extension_installer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/extension_installer.h new file mode 100644 index 0000000000..c7610f70a3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/extension_installer.h @@ -0,0 +1,45 @@ +#pragma once + +#include + +#include "common/api.h" +#include "extension.h" + +namespace lbug { +namespace main { +class ClientContext; +} +namespace extension { + +struct InstallExtensionInfo { + std::string name; + std::string repo; + bool forceInstall; + + InstallExtensionInfo(std::string name, std::string repo, bool forceInstall) + : name{std::move(name)}, repo{std::move(repo)}, forceInstall{forceInstall} {} +}; + +class LBUG_API ExtensionInstaller { +public: + ExtensionInstaller(const InstallExtensionInfo& info, main::ClientContext& context) + : info{info}, context{context} {} + + virtual ~ExtensionInstaller() = default; + + virtual bool install(); + +protected: + void tryDownloadExtensionFile(const ExtensionRepoInfo& info, const std::string& localFilePath); + +private: + bool installExtension(); + void installDependencies(); + +protected: + const InstallExtensionInfo& info; + main::ClientContext& context; +}; + +} // namespace extension +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/extension_loader.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/extension_loader.h new file mode 
100644 index 0000000000..14f72f5cf2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/extension_loader.h @@ -0,0 +1,26 @@ +#pragma once + +#include + +#include "common/api.h" + +namespace lbug { +namespace main { +class ClientContext; +} +namespace extension { + +class LBUG_API ExtensionLoader { +public: + explicit ExtensionLoader(std::string extensionName) : extensionName{std::move(extensionName)} {} + + virtual ~ExtensionLoader() = default; + + virtual void loadDependency(main::ClientContext* context) = 0; + +protected: + std::string extensionName; +}; + +} // namespace extension +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/extension_manager.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/extension_manager.h new file mode 100644 index 0000000000..4bee80cd5f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/extension_manager.h @@ -0,0 +1,50 @@ +#pragma once + +#include "loaded_extension.h" +#include "storage/storage_extension.h" + +namespace lbug { +namespace main {} +namespace extension { + +struct ExtensionEntry { + const char* name; + const char* extensionName; +}; + +class ExtensionManager { +public: + void loadExtension(const std::string& path, main::ClientContext* context); + + LBUG_API std::string toCypher(); + + LBUG_API void addExtensionOption(std::string name, common::LogicalTypeID type, + common::Value defaultValue, bool isConfidential); + + const main::ExtensionOption* getExtensionOption(std::string name) const; + + LBUG_API void registerStorageExtension(std::string name, + std::unique_ptr storageExtension); + + std::vector getStorageExtensions(); + + LBUG_API const std::vector& getLoadedExtensions() const { + return loadedExtensions; + } + + static std::optional lookupExtensionsByFunctionName( + std::string_view functionName); + static std::optional lookupExtensionsByTypeName(std::string_view typeName); + + void autoLoadLinkedExtensions(main::ClientContext* context); + + LBUG_API static ExtensionManager* Get(const main::ClientContext& context); + +private: + std::vector loadedExtensions; + std::unordered_map extensionOptions; + common::case_insensitive_map_t> storageExtensions; +}; + +} // namespace extension +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/extension_statement.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/extension_statement.h new file mode 100644 index 0000000000..6f8b42737e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/extension_statement.h @@ -0,0 +1,22 @@ +#pragma once + +#include "parser/statement.h" + +namespace lbug { +namespace extension { + +class ExtensionStatement : public parser::Statement { + static constexpr common::StatementType type_ = common::StatementType::EXTENSION_CLAUSE; + +public: + explicit ExtensionStatement(std::string statementName) + : parser::Statement{type_}, statementName{std::move(statementName)} {} + + std::string getStatementName() const { return statementName; } + +private: + std::string statementName; +}; + +} // namespace extension +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/loaded_extension.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/loaded_extension.h new file mode 100644 index 0000000000..c4508e7433 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/loaded_extension.h @@ -0,0 +1,31 @@ +#pragma once + +#include + +#include "extension.h" + +namespace lbug { 
+namespace extension { + +class LoadedExtension { + +public: + LoadedExtension(std::string extensionName, std::string fullPath, ExtensionSource source) + : extensionName{std::move(extensionName)}, fullPath{std::move(fullPath)}, source{source} {} + + std::string getExtensionName() const { return extensionName; } + + std::string getFullPath() const { return fullPath; } + + ExtensionSource getSource() const { return source; } + + std::string toCypher(); + +private: + std::string extensionName; + std::string fullPath; + ExtensionSource source; +}; + +} // namespace extension +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/logical_extension_clause.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/logical_extension_clause.h new file mode 100644 index 0000000000..3210c08b2e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/logical_extension_clause.h @@ -0,0 +1,26 @@ +#pragma once + +#include "planner/operator/simple/logical_simple.h" + +namespace lbug { +namespace extension { + +class LogicalExtensionClause : public planner::LogicalSimple { + static constexpr planner::LogicalOperatorType type_ = + planner::LogicalOperatorType::EXTENSION_CLAUSE; + +public: + explicit LogicalExtensionClause(std::string statementName) + : LogicalSimple{type_}, statementName{std::move(statementName)} {} + + void computeFactorizedSchema() override { createEmptySchema(); } + void computeFlatSchema() override { createEmptySchema(); } + + std::string getStatementName() const { return statementName; } + +private: + std::string statementName; +}; + +} // namespace extension +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/mapper_extension.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/mapper_extension.h new file mode 100644 index 0000000000..86cf925456 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/mapper_extension.h @@ -0,0 +1,21 @@ +#pragma once + +#include "planner/operator/logical_operator.h" +#include "processor/operator/physical_operator.h" + +namespace lbug { +namespace extension { + +class LBUG_API MapperExtension { +public: + MapperExtension() {} + + virtual ~MapperExtension() = default; + + virtual std::unique_ptr map( + const planner::LogicalOperator* logicalOperator, main::ClientContext* context, + uint32_t operatorID) = 0; +}; + +} // namespace extension +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/planner_extension.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/planner_extension.h new file mode 100644 index 0000000000..50e447c47a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/planner_extension.h @@ -0,0 +1,21 @@ +#pragma once + +#include "binder/bound_statement.h" +#include "planner/planner.h" + +namespace lbug { +namespace extension { + +class PlannerExtension { + +public: + PlannerExtension() {} + + virtual ~PlannerExtension() = default; + + virtual std::shared_ptr plan( + const binder::BoundStatement& boundStatement) = 0; +}; + +} // namespace extension +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/transformer_extension.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/transformer_extension.h new file mode 100644 index 0000000000..2d839c90ce --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/extension/transformer_extension.h @@ -0,0 +1,24 @@ +#pragma once + +#include + +#include "parser/statement.h" + +namespace antlr4 
{ +class ParserRuleContext; +} + +namespace lbug { +namespace extension { + +class LBUG_API TransformerExtension { +public: + TransformerExtension() {} + + virtual ~TransformerExtension() = default; + + virtual std::unique_ptr transform(antlr4::ParserRuleContext* context) = 0; +}; + +} // namespace extension +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/aggregate/avg.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/aggregate/avg.h new file mode 100644 index 0000000000..80b476b82c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/aggregate/avg.h @@ -0,0 +1,101 @@ +#pragma once + +#include "common/in_mem_overflow_buffer.h" +#include "common/types/int128_t.h" +#include "common/types/uint128_t.h" +#include "function/aggregate_function.h" +#include "function/arithmetic/add.h" + +namespace lbug { +namespace function { + +template +struct AvgState : public AggregateStateWithNull { + uint32_t getStateSize() const override { return sizeof(*this); } + void writeToVector(common::ValueVector* outputVector, uint64_t pos) override { + outputVector->setValue(pos, avg); + } + + void finalize() + requires common::IntegerTypes + { + using ResultType = std::conditional, common::Int128_t, + common::UInt128_t>::type; + if (!isNull) { + avg = ResultType::template cast(sum) / + ResultType::template cast(count); + } + } + + void finalize() + requires common::FloatingPointTypes + { + if (!isNull) { + avg = sum / count; + } + } + + T sum{}; + uint64_t count = 0; + double avg = 0; +}; + +template +struct AvgFunction { + + static std::unique_ptr initialize() { + return std::make_unique>(); + } + + static void updateAll(uint8_t* state_, common::ValueVector* input, uint64_t multiplicity, + common::InMemOverflowBuffer* /*overflowBuffer*/) { + auto* state = reinterpret_cast*>(state_); + KU_ASSERT(!input->state->isFlat()); + input->forEachNonNull( + [&](auto pos) { updateSingleValue(state, input, pos, multiplicity); }); + } + + static void updatePos(uint8_t* state_, common::ValueVector* input, uint64_t multiplicity, + uint32_t pos, common::InMemOverflowBuffer* /*overflowBuffer*/) { + updateSingleValue(reinterpret_cast*>(state_), input, pos, + multiplicity); + } + + static void updateSingleValue(AvgState* state, common::ValueVector* input, + uint32_t pos, uint64_t multiplicity) { + INPUT_TYPE val = input->getValue(pos); + for (auto i = 0u; i < multiplicity; ++i) { + if (state->isNull) { + state->sum = (RESULT_TYPE)val; + state->isNull = false; + } else { + Add::operation(state->sum, val, state->sum); + } + } + state->count += multiplicity; + } + + static void combine(uint8_t* state_, uint8_t* otherState_, + common::InMemOverflowBuffer* /*overflowBuffer*/) { + auto* otherState = reinterpret_cast*>(otherState_); + if (otherState->isNull) { + return; + } + auto* state = reinterpret_cast*>(state_); + if (state->isNull) { + state->sum = otherState->sum; + state->isNull = false; + } else { + Add::operation(state->sum, otherState->sum, state->sum); + } + state->count = state->count + otherState->count; + } + + static void finalize(uint8_t* state_) { + auto* state = reinterpret_cast*>(state_); + state->finalize(); + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/aggregate/base_count.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/aggregate/base_count.h new file mode 100644 index 0000000000..0cf99be70e --- /dev/null +++ 
b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/aggregate/base_count.h @@ -0,0 +1,36 @@ +#pragma once + +#include "function/aggregate_function.h" + +namespace lbug { +namespace function { + +struct BaseCountFunction { + + struct CountState : public AggregateState { + inline uint32_t getStateSize() const override { return sizeof(*this); } + inline void writeToVector(common::ValueVector* outputVector, uint64_t pos) override { + memcpy(outputVector->getData() + pos * outputVector->getNumBytesPerValue(), + reinterpret_cast(&count), outputVector->getNumBytesPerValue()); + } + + uint64_t count = 0; + }; + + static std::unique_ptr initialize() { + auto state = std::make_unique(); + return state; + } + + static void combine(uint8_t* state_, uint8_t* otherState_, + common::InMemOverflowBuffer* /*overflowBuffer*/) { + auto state = reinterpret_cast(state_); + auto otherState = reinterpret_cast(otherState_); + state->count += otherState->count; + } + + static void finalize(uint8_t* /*state_*/) {} +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/aggregate/count.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/aggregate/count.h new file mode 100644 index 0000000000..b65d80ba4d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/aggregate/count.h @@ -0,0 +1,26 @@ +#pragma once + +#include "base_count.h" + +namespace lbug { +namespace function { + +struct CountFunction : public BaseCountFunction { + static constexpr const char* name = "COUNT"; + + static void updateAll(uint8_t* state_, common::ValueVector* input, uint64_t multiplicity, + common::InMemOverflowBuffer* overflowBuffer); + + // NOLINTNEXTLINE(readability-non-const-parameter): Would cast away qualifiers. 
+ static inline void updatePos(uint8_t* state_, common::ValueVector* /*input*/, + uint64_t multiplicity, uint32_t /*pos*/, common::InMemOverflowBuffer* /*overflowBuffer*/) { + reinterpret_cast(state_)->count += multiplicity; + } + + static void paramRewriteFunc(binder::expression_vector& arguments); + + static function_set getFunctionSet(); +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/aggregate/count_star.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/aggregate/count_star.h new file mode 100644 index 0000000000..194800d6be --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/aggregate/count_star.h @@ -0,0 +1,21 @@ +#pragma once + +#include "base_count.h" + +namespace lbug { +namespace function { + +struct CountStarFunction : public BaseCountFunction { + static constexpr const char* name = "COUNT_STAR"; + + static void updateAll(uint8_t* state_, common::ValueVector* input, uint64_t multiplicity, + common::InMemOverflowBuffer* /*overflowBuffer*/); + + static void updatePos(uint8_t* state_, common::ValueVector* input, uint64_t multiplicity, + uint32_t /*pos*/, common::InMemOverflowBuffer* /*overflowBuffer*/); + + static function_set getFunctionSet(); +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/aggregate/min_max.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/aggregate/min_max.h new file mode 100644 index 0000000000..2b4315c4e7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/aggregate/min_max.h @@ -0,0 +1,92 @@ +#pragma once + +#include "common/in_mem_overflow_buffer.h" +#include "function/aggregate_function.h" + +namespace lbug { +namespace function { + +template +struct MinMaxFunction { + + struct MinMaxState : public AggregateStateWithNull { + uint32_t getStateSize() const override { return sizeof(*this); } + void writeToVector(common::ValueVector* outputVector, uint64_t pos) override { + outputVector->setValue(pos, val); + } + void setVal(const T& val_, common::InMemOverflowBuffer* /*overflowBuffer*/) { val = val_; } + + T val{}; + }; + + static std::unique_ptr initialize() { return std::make_unique(); } + + template + static void updateAll(uint8_t* state_, common::ValueVector* input, uint64_t /*multiplicity*/, + common::InMemOverflowBuffer* overflowBuffer) { + KU_ASSERT(!input->state->isFlat()); + auto* state = reinterpret_cast(state_); + input->forEachNonNull( + [&](auto pos) { updateSingleValue(state, input, pos, overflowBuffer); }); + } + + template + static inline void updatePos(uint8_t* state_, common::ValueVector* input, + uint64_t /*multiplicity*/, uint32_t pos, common::InMemOverflowBuffer* overflowBuffer) { + updateSingleValue(reinterpret_cast(state_), input, pos, overflowBuffer); + } + + template + static void updateSingleValue(MinMaxState* state, common::ValueVector* input, uint32_t pos, + common::InMemOverflowBuffer* overflowBuffer) { + T val = input->getValue(pos); + if (state->isNull) { + state->setVal(val, overflowBuffer); + state->isNull = false; + } else { + uint8_t compare_result = 0; + OP::template operation(val, state->val, compare_result, nullptr /* leftVector */, + nullptr /* rightVector */); + if (compare_result) { + state->setVal(val, overflowBuffer); + } + } + } + + template + static void combine(uint8_t* state_, uint8_t* otherState_, + common::InMemOverflowBuffer* overflowBuffer) { + auto* otherState = reinterpret_cast(otherState_); + if 
(otherState->isNull) { + return; + } + auto* state = reinterpret_cast(state_); + if (state->isNull) { + state->setVal(otherState->val, overflowBuffer); + state->isNull = false; + } else { + uint8_t compareResult = 0; + OP::template operation(otherState->val, state->val, compareResult, + nullptr /* leftVector */, nullptr /* rightVector */); + if (compareResult) { + state->setVal(otherState->val, overflowBuffer); + } + } + } + + static void finalize(uint8_t* /*state_*/) {} +}; + +template<> +void MinMaxFunction::MinMaxState::setVal(const common::ku_string_t& val_, + common::InMemOverflowBuffer* overflowBuffer) { + // We only need to allocate memory if the new val_ is a long string and is longer + // than the current val. + if (val_.len > common::ku_string_t::SHORT_STR_LENGTH && val_.len > val.len) { + val.overflowPtr = reinterpret_cast(overflowBuffer->allocateSpace(val_.len)); + } + val.set(val_); +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/aggregate/sum.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/aggregate/sum.h new file mode 100644 index 0000000000..a0e1694ab5 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/aggregate/sum.h @@ -0,0 +1,71 @@ +#pragma once + +#include "function/aggregate_function.h" +#include "function/arithmetic/add.h" + +namespace lbug { +namespace function { + +template +struct SumState : public AggregateStateWithNull { + uint32_t getStateSize() const override { return sizeof(*this); } + void writeToVector(common::ValueVector* outputVector, uint64_t pos) override { + outputVector->setValue(pos, sum); + } + + RESULT_TYPE sum{}; +}; + +template +struct SumFunction { + static std::unique_ptr initialize() { + return std::make_unique>(); + } + + static void updateAll(uint8_t* state_, common::ValueVector* input, uint64_t multiplicity, + common::InMemOverflowBuffer* /*overflowBuffer*/) { + KU_ASSERT(!input->state->isFlat()); + auto* state = reinterpret_cast*>(state_); + input->forEachNonNull( + [&](auto pos) { updateSingleValue(state, input, pos, multiplicity); }); + } + + static void updatePos(uint8_t* state_, common::ValueVector* input, uint64_t multiplicity, + uint32_t pos, common::InMemOverflowBuffer* /*overflowBuffer*/) { + auto* state = reinterpret_cast*>(state_); + updateSingleValue(state, input, pos, multiplicity); + } + + static void updateSingleValue(SumState* state, common::ValueVector* input, + uint32_t pos, uint64_t multiplicity) { + INPUT_TYPE val = input->getValue(pos); + for (auto j = 0u; j < multiplicity; ++j) { + if (state->isNull) { + state->sum = val; + state->isNull = false; + } else { + Add::operation(state->sum, val, state->sum); + } + } + } + + static void combine(uint8_t* state_, uint8_t* otherState_, + common::InMemOverflowBuffer* /*overflowBuffer*/) { + auto* otherState = reinterpret_cast*>(otherState_); + if (otherState->isNull) { + return; + } + auto* state = reinterpret_cast*>(state_); + if (state->isNull) { + state->sum = otherState->sum; + state->isNull = false; + } else { + Add::operation(state->sum, otherState->sum, state->sum); + } + } + + static void finalize(uint8_t* /*state_*/) {} +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/aggregate_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/aggregate_function.h new file mode 100644 index 0000000000..5479240a26 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/aggregate_function.h 
@@ -0,0 +1,146 @@ +#pragma once + +#include +#include + +#include "common/in_mem_overflow_buffer.h" +#include "common/vector/value_vector.h" +#include "function/function.h" + +namespace lbug { +namespace function { + +struct AggregateState { + virtual uint32_t getStateSize() const = 0; + virtual void writeToVector(common::ValueVector* outputVector, uint64_t pos) = 0; + virtual ~AggregateState() = default; + template + const TARGET& constCast() const { + return common::ku_dynamic_cast(*this); + } +}; + +struct AggregateStateWithNull : public AggregateState { + bool isNull = true; +}; + +using param_rewrite_function_t = std::function; +using aggr_initialize_function_t = std::function()>; +using aggr_update_all_function_t = std::function; +using aggr_update_pos_function_t = std::function; +using aggr_combine_function_t = std::function; +using aggr_finalize_function_t = std::function; + +struct AggregateFunction final : public ScalarOrAggregateFunction { + bool isDistinct; + bool needToHandleNulls = false; + aggr_initialize_function_t initializeFunc; + aggr_update_all_function_t updateAllFunc; + aggr_update_pos_function_t updatePosFunc; + aggr_combine_function_t combineFunc; + aggr_finalize_function_t finalizeFunc; + std::unique_ptr initialNullAggregateState; + // Rewrite aggregate on NODE/REL, e.g. COUNT(a) -> COUNT(a._id) + param_rewrite_function_t paramRewriteFunc; + + AggregateFunction(std::string name, std::vector parameterTypeIDs, + common::LogicalTypeID returnTypeID, aggr_initialize_function_t initializeFunc, + aggr_update_all_function_t updateAllFunc, aggr_update_pos_function_t updatePosFunc, + aggr_combine_function_t combineFunc, aggr_finalize_function_t finalizeFunc, bool isDistinct, + scalar_bind_func bindFunc = nullptr, param_rewrite_function_t paramRewriteFunc = nullptr) + : ScalarOrAggregateFunction{std::move(name), std::move(parameterTypeIDs), returnTypeID, + std::move(bindFunc)}, + isDistinct{isDistinct}, initializeFunc{std::move(initializeFunc)}, + updateAllFunc{std::move(updateAllFunc)}, updatePosFunc{std::move(updatePosFunc)}, + combineFunc{std::move(combineFunc)}, finalizeFunc{std::move(finalizeFunc)}, + paramRewriteFunc{std::move(paramRewriteFunc)} { + initialNullAggregateState = createInitialNullAggregateState(); + } + + EXPLICIT_COPY_DEFAULT_MOVE(AggregateFunction); + + common::idx_t getAggregateStateSize() const { + return initialNullAggregateState->getStateSize(); + } + + // NOLINTNEXTLINE(readability-make-member-function-const): Returns a non-const pointer. 
+ AggregateState* getInitialNullAggregateState() { return initialNullAggregateState.get(); } + + std::unique_ptr createInitialNullAggregateState() const { + return initializeFunc(); + } + + void updateAllState(uint8_t* state, common::ValueVector* input, uint64_t multiplicity, + common::InMemOverflowBuffer* overflowBuffer) const { + return updateAllFunc(state, input, multiplicity, overflowBuffer); + } + + void updatePosState(uint8_t* state, common::ValueVector* input, uint64_t multiplicity, + uint32_t pos, common::InMemOverflowBuffer* overflowBuffer) const { + return updatePosFunc(state, input, multiplicity, pos, overflowBuffer); + } + + void combineState(uint8_t* state, uint8_t* otherState, + common::InMemOverflowBuffer* overflowBuffer) const { + return combineFunc(state, otherState, overflowBuffer); + } + + void finalizeState(uint8_t* state) const { return finalizeFunc(state); } + + bool isFunctionDistinct() const { return isDistinct; } + +private: + AggregateFunction(const AggregateFunction& other); +}; + +struct AggregateFunctionUtils { + template + static std::unique_ptr getAggFunc(std::string name, + common::LogicalTypeID inputType, common::LogicalTypeID resultType, bool isDistinct, + param_rewrite_function_t paramRewriteFunc = nullptr) { + return std::make_unique(std::move(name), + std::vector{inputType}, resultType, T::initialize, T::updateAll, + T::updatePos, T::combine, T::finalize, isDistinct, nullptr /* bindFunc */, + paramRewriteFunc); + } + + template class FunctionType> + static void appendSumOrAvgFuncs(std::string name, common::LogicalTypeID inputType, + function_set& result); +}; + +struct AggregateSumFunction { + static constexpr const char* name = "SUM"; + + static function_set getFunctionSet(); +}; + +struct AggregateAvgFunction { + static constexpr const char* name = "AVG"; + + static function_set getFunctionSet(); +}; + +struct AggregateMinFunction { + static constexpr const char* name = "MIN"; + + static function_set getFunctionSet(); +}; + +struct AggregateMaxFunction { + static constexpr const char* name = "MAX"; + + static function_set getFunctionSet(); +}; + +struct CollectFunction { + static constexpr const char* name = "COLLECT"; + + static function_set getFunctionSet(); +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/arithmetic/abs.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/arithmetic/abs.h new file mode 100644 index 0000000000..8f3aef62bf --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/arithmetic/abs.h @@ -0,0 +1,39 @@ +#pragma once + +#include + +#include "common/types/int128_t.h" +#include "common/types/types.h" +#include "common/types/uint128_t.h" + +namespace lbug { +namespace function { + +struct Abs { + template + static inline void operation(T& input, T& result) { + if constexpr (common::UnsignedIntegerTypes) { + result = input; + } else { + result = std::abs(input); + } + } +}; + +template<> +void Abs::operation(int8_t& input, int8_t& result); + +template<> +void Abs::operation(int16_t& input, int16_t& result); + +template<> +void Abs::operation(int32_t& input, int32_t& result); + +template<> +void Abs::operation(int64_t& input, int64_t& result); + +template<> +void Abs::operation(common::int128_t& input, common::int128_t& result); + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/arithmetic/add.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/arithmetic/add.h new file mode 
100644 index 0000000000..51300ebb7f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/arithmetic/add.h @@ -0,0 +1,40 @@ +#pragma once + +#include + +namespace lbug { +namespace function { + +struct Add { + template + static inline void operation(A& left, B& right, R& result) { + result = left + right; + } +}; + +template<> +void Add::operation(uint8_t& left, uint8_t& right, uint8_t& result); + +template<> +void Add::operation(uint16_t& left, uint16_t& right, uint16_t& result); + +template<> +void Add::operation(uint32_t& left, uint32_t& right, uint32_t& result); + +template<> +void Add::operation(uint64_t& left, uint64_t& right, uint64_t& result); + +template<> +void Add::operation(int8_t& left, int8_t& right, int8_t& result); + +template<> +void Add::operation(int16_t& left, int16_t& right, int16_t& result); + +template<> +void Add::operation(int32_t& left, int32_t& right, int32_t& result); + +template<> +void Add::operation(int64_t& left, int64_t& right, int64_t& result); + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/arithmetic/arithmetic_functions.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/arithmetic/arithmetic_functions.h new file mode 100644 index 0000000000..47301ea7cc --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/arithmetic/arithmetic_functions.h @@ -0,0 +1,245 @@ +#pragma once + +#include + +#include "common/types/int128_t.h" +#include "common/types/uint128_t.h" + +namespace lbug { +namespace function { + +struct Power { + template + static inline void operation(A& left, B& right, R& result) { + result = pow(left, right); + } +}; + +struct Floor { + template + static inline void operation(T& input, T& result) { + result = floor(input); + } +}; + +template<> +inline void Floor::operation(common::int128_t& input, common::int128_t& result) { + result = input; +} + +template<> +inline void Floor::operation(common::uint128_t& input, common::uint128_t& result) { + result = input; +} + +struct Ceil { + template + static inline void operation(T& input, T& result) { + result = ceil(input); + } +}; + +template<> +inline void Ceil::operation(common::int128_t& input, common::int128_t& result) { + result = input; +} + +template<> +inline void Ceil::operation(common::uint128_t& input, common::uint128_t& result) { + result = input; +} + +struct Sin { + template + static inline void operation(T& input, double& result) { + result = sin(input); + } +}; + +struct Cos { + template + static inline void operation(T& input, double& result) { + result = cos(input); + } +}; + +struct Tan { + template + static inline void operation(T& input, double& result) { + result = tan(input); + } +}; + +struct Cot { + template + static inline void operation(T& input, double& result) { + double tanValue = 0; + Tan::operation(input, tanValue); + result = 1 / tanValue; + } +}; + +struct Asin { + template + static inline void operation(T& input, double& result) { + result = asin(input); + } +}; + +struct Acos { + template + static inline void operation(T& input, double& result) { + result = acos(input); + } +}; + +struct Atan { + template + static inline void operation(T& input, double& result) { + result = atan(input); + } +}; + +struct Even { + template + static inline void operation(T& input, double& result) { + result = input >= 0 ? ceil(input) : floor(input); + // Note: c++ doesn't support double % integer, so we have to use the following code to check + // whether result is odd or even. 
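    // Editor's illustration, derived from the check below and not part of the original
    // sources: EVEN rounds away from zero to the next even number. For input 2.5 the
    // ceil above yields 3, the check detects an odd value and the result becomes 4;
    // for -2.5 the floor yields -3 and the result becomes -4; an already even integer
    // such as 4.0 is returned unchanged.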
+ if (std::floor(result / 2) * 2 != result) { + result += (input >= 0 ? 1 : -1); + } + } +}; + +struct Factorial { + static inline void operation(int64_t& input, int64_t& result) { + result = 1; + for (int64_t i = 2; i <= input; i++) { + result *= i; + } + } +}; + +struct Sign { + template + static inline void operation(T& input, int64_t& result) { + result = (input > 0) - (input < 0); + } +}; + +struct Sqrt { + template + static inline void operation(T& input, double& result) { + result = sqrt(input); + } +}; + +struct Cbrt { + template + static inline void operation(T& input, double& result) { + result = cbrt(input); + } +}; + +struct Gamma { + template + static inline void operation(T& input, T& result) { + result = tgamma(input); + } +}; + +struct Lgamma { + template + static inline void operation(T& input, double& result) { + result = + lgamma(input); // NOLINT(concurrency-mt-unsafe): We don't use the thread-unsafe signgam. + } +}; + +struct Ln { + template + static inline void operation(T& input, double& result) { + result = log(input); + } +}; + +struct Log { + template + static inline void operation(T& input, double& result) { + result = log10(input); + } +}; + +struct Log2 { + template + static inline void operation(T& input, double& result) { + result = log2(input); + } +}; + +struct Degrees { + template + static inline void operation(T& input, double& result) { + result = input * 180 / M_PI; + } +}; + +struct Radians { + template + static inline void operation(T& input, double& result) { + result = input * M_PI / 180; + } +}; + +struct Atan2 { + template + static inline void operation(A& left, B& right, double& result) { + result = atan2(left, right); + } +}; + +struct Round { + template + static inline void operation(A& left, B& right, double& result) { + auto multiplier = pow(10, right); + result = round(left * multiplier) / multiplier; + } +}; + +struct BitwiseXor { + static inline void operation(int64_t& left, int64_t& right, int64_t& result) { + result = left ^ right; + } +}; + +struct BitwiseAnd { + static inline void operation(int64_t& left, int64_t& right, int64_t& result) { + result = left & right; + } +}; + +struct BitwiseOr { + static inline void operation(int64_t& left, int64_t& right, int64_t& result) { + result = left | right; + } +}; + +struct BitShiftLeft { + static inline void operation(int64_t& left, int64_t& right, int64_t& result) { + result = left << right; + } +}; + +struct BitShiftRight { + static inline void operation(int64_t& left, int64_t& right, int64_t& result) { + result = left >> right; + } +}; + +struct Pi { + static inline void operation(double& result) { result = M_PI; } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/arithmetic/divide.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/arithmetic/divide.h new file mode 100644 index 0000000000..19111004f8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/arithmetic/divide.h @@ -0,0 +1,40 @@ +#pragma once + +#include + +namespace lbug { +namespace function { + +struct Divide { + template + static inline void operation(A& left, B& right, R& result) { + result = left / right; + } +}; + +template<> +void Divide::operation(uint8_t& left, uint8_t& right, uint8_t& result); + +template<> +void Divide::operation(uint16_t& left, uint16_t& right, uint16_t& result); + +template<> +void Divide::operation(uint32_t& left, uint32_t& right, uint32_t& result); + +template<> +void Divide::operation(uint64_t& left, 
uint64_t& right, uint64_t& result); + +template<> +void Divide::operation(int8_t& left, int8_t& right, int8_t& result); + +template<> +void Divide::operation(int16_t& left, int16_t& right, int16_t& result); + +template<> +void Divide::operation(int32_t& left, int32_t& right, int32_t& result); + +template<> +void Divide::operation(int64_t& left, int64_t& right, int64_t& result); + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/arithmetic/modulo.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/arithmetic/modulo.h new file mode 100644 index 0000000000..0595c9304d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/arithmetic/modulo.h @@ -0,0 +1,50 @@ +#pragma once + +#include + +#include "common/types/int128_t.h" +#include "common/types/uint128_t.h" + +namespace lbug { +namespace function { + +struct Modulo { + template + static inline void operation(A& left, B& right, R& result) { + result = fmod(left, right); + } +}; + +template<> +void Modulo::operation(uint8_t& left, uint8_t& right, uint8_t& result); + +template<> +void Modulo::operation(uint16_t& left, uint16_t& right, uint16_t& result); + +template<> +void Modulo::operation(uint32_t& left, uint32_t& right, uint32_t& result); + +template<> +void Modulo::operation(uint64_t& left, uint64_t& right, uint64_t& result); + +template<> +void Modulo::operation(int8_t& left, int8_t& right, int8_t& result); + +template<> +void Modulo::operation(int16_t& left, int16_t& right, int16_t& result); + +template<> +void Modulo::operation(int32_t& left, int32_t& right, int32_t& result); + +template<> +void Modulo::operation(int64_t& left, int64_t& right, int64_t& result); + +template<> +void Modulo::operation(common::int128_t& left, common::int128_t& right, common::int128_t& result); + +template<> +void Modulo::operation(common::uint128_t& left, common::uint128_t& right, + common::uint128_t& result); + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/arithmetic/multiply.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/arithmetic/multiply.h new file mode 100644 index 0000000000..4062be69d3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/arithmetic/multiply.h @@ -0,0 +1,40 @@ +#pragma once + +#include + +namespace lbug { +namespace function { + +struct Multiply { + template + static inline void operation(A& left, B& right, R& result) { + result = left * right; + } +}; + +template<> +void Multiply::operation(uint8_t& left, uint8_t& right, uint8_t& result); + +template<> +void Multiply::operation(uint16_t& left, uint16_t& right, uint16_t& result); + +template<> +void Multiply::operation(uint32_t& left, uint32_t& right, uint32_t& result); + +template<> +void Multiply::operation(uint64_t& left, uint64_t& right, uint64_t& result); + +template<> +void Multiply::operation(int8_t& left, int8_t& right, int8_t& result); + +template<> +void Multiply::operation(int16_t& left, int16_t& right, int16_t& result); + +template<> +void Multiply::operation(int32_t& left, int32_t& right, int32_t& result); + +template<> +void Multiply::operation(int64_t& left, int64_t& right, int64_t& result); + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/arithmetic/negate.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/arithmetic/negate.h new file mode 100644 index 0000000000..cc2d55cc7b --- /dev/null +++ 
b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/arithmetic/negate.h @@ -0,0 +1,28 @@ +#pragma once + +#include + +namespace lbug { +namespace function { + +struct Negate { + template + static inline void operation(T& input, T& result) { + result = -input; + } +}; + +template<> +void Negate::operation(int8_t& input, int8_t& result); + +template<> +void Negate::operation(int16_t& input, int16_t& result); + +template<> +void Negate::operation(int32_t& input, int32_t& result); + +template<> +void Negate::operation(int64_t& input, int64_t& result); + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/arithmetic/subtract.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/arithmetic/subtract.h new file mode 100644 index 0000000000..1fff6a2197 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/arithmetic/subtract.h @@ -0,0 +1,40 @@ +#pragma once + +#include + +namespace lbug { +namespace function { + +struct Subtract { + template + static inline void operation(A& left, B& right, R& result) { + result = left - right; + } +}; + +template<> +void Subtract::operation(uint8_t& left, uint8_t& right, uint8_t& result); + +template<> +void Subtract::operation(uint16_t& left, uint16_t& right, uint16_t& result); + +template<> +void Subtract::operation(uint32_t& left, uint32_t& right, uint32_t& result); + +template<> +void Subtract::operation(uint64_t& left, uint64_t& right, uint64_t& result); + +template<> +void Subtract::operation(int8_t& left, int8_t& right, int8_t& result); + +template<> +void Subtract::operation(int16_t& left, int16_t& right, int16_t& result); + +template<> +void Subtract::operation(int32_t& left, int32_t& right, int32_t& result); + +template<> +void Subtract::operation(int64_t& left, int64_t& right, int64_t& result); + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/arithmetic/vector_arithmetic_functions.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/arithmetic/vector_arithmetic_functions.h new file mode 100644 index 0000000000..3a7d2f1d8c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/arithmetic/vector_arithmetic_functions.h @@ -0,0 +1,263 @@ +#pragma once + +#include "function/function.h" + +namespace lbug { +namespace function { + +struct AddFunction { + static constexpr const char* name = "+"; + + static function_set getFunctionSet(); +}; + +struct SubtractFunction { + static constexpr const char* name = "-"; + + static function_set getFunctionSet(); +}; + +struct MultiplyFunction { + static constexpr const char* name = "*"; + + static function_set getFunctionSet(); +}; + +struct DivideFunction { + static constexpr const char* name = "/"; + + static function_set getFunctionSet(); +}; + +struct ModuloFunction { + static constexpr const char* name = "%"; + + static function_set getFunctionSet(); +}; + +struct PowerFunction { + static constexpr const char* name = "^"; + + static function_set getFunctionSet(); +}; + +struct PowFunction { + using alias = PowerFunction; + + static constexpr const char* name = "POW"; +}; + +struct AbsFunction { + static constexpr const char* name = "ABS"; + + static function_set getFunctionSet(); +}; + +struct AcosFunction { + static constexpr const char* name = "ACOS"; + + static function_set getFunctionSet(); +}; + +struct AsinFunction { + static constexpr const char* name = "ASIN"; + + static function_set getFunctionSet(); +}; + +struct AtanFunction { + static 
constexpr const char* name = "ATAN"; + + static function_set getFunctionSet(); +}; + +struct Atan2Function { + static constexpr const char* name = "ATAN2"; + + static function_set getFunctionSet(); +}; + +struct BitwiseXorFunction { + static constexpr const char* name = "BITWISE_XOR"; + + static function_set getFunctionSet(); +}; + +struct BitwiseAndFunction { + static constexpr const char* name = "BITWISE_AND"; + + static function_set getFunctionSet(); +}; + +struct BitwiseOrFunction { + static constexpr const char* name = "BITWISE_OR"; + + static function_set getFunctionSet(); +}; + +struct BitShiftLeftFunction { + static constexpr const char* name = "BITSHIFT_LEFT"; + + static function_set getFunctionSet(); +}; + +struct BitShiftRightFunction { + static constexpr const char* name = "BITSHIFT_RIGHT"; + + static function_set getFunctionSet(); +}; + +struct CbrtFunction { + static constexpr const char* name = "CBRT"; + + static function_set getFunctionSet(); +}; + +struct CeilFunction { + static constexpr const char* name = "CEIL"; + + static function_set getFunctionSet(); +}; + +struct CeilingFunction { + using alias = CeilFunction; + + static constexpr const char* name = "CEILING"; +}; + +struct CosFunction { + static constexpr const char* name = "COS"; + + static function_set getFunctionSet(); +}; + +struct CotFunction { + static constexpr const char* name = "COT"; + + static function_set getFunctionSet(); +}; + +struct DegreesFunction { + static constexpr const char* name = "DEGREES"; + + static function_set getFunctionSet(); +}; + +struct EvenFunction { + static constexpr const char* name = "EVEN"; + + static function_set getFunctionSet(); +}; + +struct FactorialFunction { + static constexpr const char* name = "FACTORIAL"; + + static function_set getFunctionSet(); +}; + +struct FloorFunction { + static constexpr const char* name = "FLOOR"; + + static function_set getFunctionSet(); +}; + +struct GammaFunction { + static constexpr const char* name = "GAMMA"; + + static function_set getFunctionSet(); +}; + +struct LgammaFunction { + static constexpr const char* name = "LGAMMA"; + + static function_set getFunctionSet(); +}; + +struct LnFunction { + static constexpr const char* name = "LN"; + + static function_set getFunctionSet(); +}; + +struct LogFunction { + static constexpr const char* name = "LOG"; + + static constexpr const char* alias = "LOG10"; + + static function_set getFunctionSet(); +}; + +struct Log10Function { + using alias = LogFunction; + + static constexpr const char* name = "LOG10"; +}; + +struct Log2Function { + static constexpr const char* name = "LOG2"; + + static function_set getFunctionSet(); +}; + +struct NegateFunction { + static constexpr const char* name = "NEGATE"; + + static function_set getFunctionSet(); +}; + +struct PiFunction { + static constexpr const char* name = "PI"; + + static function_set getFunctionSet(); +}; + +struct RandFunction { + static constexpr const char* name = "RANDOM"; + + static function_set getFunctionSet(); +}; + +struct SetSeedFunction { + static constexpr const char* name = "SETSEED"; + + static function_set getFunctionSet(); +}; + +struct RadiansFunction { + static constexpr const char* name = "RADIANS"; + + static function_set getFunctionSet(); +}; + +struct RoundFunction { + static constexpr const char* name = "ROUND"; + + static function_set getFunctionSet(); +}; + +struct SinFunction { + static constexpr const char* name = "SIN"; + + static function_set getFunctionSet(); +}; + +struct SignFunction { + static constexpr const char* 
name = "SIGN"; + + static function_set getFunctionSet(); +}; + +struct SqrtFunction { + static constexpr const char* name = "SQRT"; + + static function_set getFunctionSet(); +}; + +struct TanFunction { + static constexpr const char* name = "TAN"; + + static function_set getFunctionSet(); +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/array/functions/array_cosine_similarity.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/array/functions/array_cosine_similarity.h new file mode 100644 index 0000000000..849546efdc --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/array/functions/array_cosine_similarity.h @@ -0,0 +1,31 @@ +#pragma once + +#include "math.h" + +#include "common/vector/value_vector.h" +#include + +namespace lbug { +namespace function { + +struct ArrayCosineSimilarity { + template + static inline void operation(common::list_entry_t& left, common::list_entry_t& right, T& result, + common::ValueVector& leftVector, common::ValueVector& rightVector, + common::ValueVector& /*resultVector*/) { + auto leftElements = (T*)common::ListVector::getListValues(&leftVector, left); + auto rightElements = (T*)common::ListVector::getListValues(&rightVector, right); + KU_ASSERT(left.size == right.size); + simsimd_distance_t tmpResult = 0.0; + static_assert(std::is_same_v || std::is_same_v); + if constexpr (std::is_same_v) { + simsimd_cos_f32(leftElements, rightElements, left.size, &tmpResult); + } else { + simsimd_cos_f64(leftElements, rightElements, left.size, &tmpResult); + } + result = 1.0 - tmpResult; + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/array/functions/array_cross_product.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/array/functions/array_cross_product.h new file mode 100644 index 0000000000..7536ebc6d4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/array/functions/array_cross_product.h @@ -0,0 +1,24 @@ +#pragma once + +#include "common/vector/value_vector.h" + +namespace lbug { +namespace function { + +template +struct ArrayCrossProduct { + static inline void operation(common::list_entry_t& left, common::list_entry_t& right, + common::list_entry_t& result, common::ValueVector& leftVector, + common::ValueVector& rightVector, common::ValueVector& resultVector) { + auto leftElements = (T*)common::ListVector::getListValues(&leftVector, left); + auto rightElements = (T*)common::ListVector::getListValues(&rightVector, right); + result = common::ListVector::addList(&resultVector, left.size); + auto resultElements = (T*)common::ListVector::getListValues(&resultVector, result); + resultElements[0] = leftElements[1] * rightElements[2] - leftElements[2] * rightElements[1]; + resultElements[1] = leftElements[2] * rightElements[0] - leftElements[0] * rightElements[2]; + resultElements[2] = leftElements[0] * rightElements[1] - leftElements[1] * rightElements[0]; + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/array/functions/array_distance.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/array/functions/array_distance.h new file mode 100644 index 0000000000..88af4c5722 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/array/functions/array_distance.h @@ -0,0 +1,23 @@ +#pragma once + +#include "math.h" + +#include "common/vector/value_vector.h" +#include 
"function/array/functions/array_squared_distance.h" + +namespace lbug { +namespace function { + +// Euclidean distance between two arrays. +struct ArrayDistance { + template + static inline void operation(common::list_entry_t& left, common::list_entry_t& right, T& result, + common::ValueVector& leftVector, common::ValueVector& rightVector, + common::ValueVector& resultVector) { + ArraySquaredDistance::operation(left, right, result, leftVector, rightVector, resultVector); + result = std::sqrt(result); + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/array/functions/array_inner_product.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/array/functions/array_inner_product.h new file mode 100644 index 0000000000..0649320f23 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/array/functions/array_inner_product.h @@ -0,0 +1,29 @@ +#pragma once + +#include "common/vector/value_vector.h" +#include + +namespace lbug { +namespace function { + +struct ArrayInnerProduct { + template + static inline void operation(common::list_entry_t& left, common::list_entry_t& right, T& result, + common::ValueVector& leftVector, common::ValueVector& rightVector, + common::ValueVector& /*resultVector*/) { + auto leftElements = (T*)common::ListVector::getListValues(&leftVector, left); + auto rightElements = (T*)common::ListVector::getListValues(&rightVector, right); + KU_ASSERT(left.size == right.size); + simsimd_distance_t tmpResult = 0.0; + static_assert(std::is_same_v || std::is_same_v); + if constexpr (std::is_same_v) { + simsimd_dot_f32(leftElements, rightElements, left.size, &tmpResult); + } else { + simsimd_dot_f64(leftElements, rightElements, left.size, &tmpResult); + } + result = tmpResult; + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/array/functions/array_squared_distance.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/array/functions/array_squared_distance.h new file mode 100644 index 0000000000..c4afd310bc --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/array/functions/array_squared_distance.h @@ -0,0 +1,29 @@ +#pragma once + +#include "common/vector/value_vector.h" +#include + +namespace lbug { +namespace function { + +struct ArraySquaredDistance { + template + static inline void operation(common::list_entry_t& left, common::list_entry_t& right, T& result, + common::ValueVector& leftVector, common::ValueVector& rightVector, + common::ValueVector& /*resultVector*/) { + auto leftElements = (T*)common::ListVector::getListValues(&leftVector, left); + auto rightElements = (T*)common::ListVector::getListValues(&rightVector, right); + KU_ASSERT(left.size == right.size); + simsimd_distance_t tmpResult = 0.0; + static_assert(std::is_same_v || std::is_same_v); + if constexpr (std::is_same_v) { + simsimd_l2sq_f32(leftElements, rightElements, left.size, &tmpResult); + } else { + simsimd_l2sq_f64(leftElements, rightElements, left.size, &tmpResult); + } + result = tmpResult; + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/array/vector_array_functions.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/array/vector_array_functions.h new file mode 100644 index 0000000000..5e8aa1b8d0 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/array/vector_array_functions.h @@ -0,0 +1,106 @@ +#pragma once + +#include 
"function/function.h" +#include "function/list/vector_list_functions.h" + +namespace lbug { +namespace function { + +struct ArrayValueFunction { + static constexpr const char* name = "ARRAY_VALUE"; + + static function_set getFunctionSet(); +}; + +struct ArrayCrossProductFunction { + static constexpr const char* name = "ARRAY_CROSS_PRODUCT"; + + static function_set getFunctionSet(); +}; + +struct ArrayCosineSimilarityFunction { + static constexpr const char* name = "ARRAY_COSINE_SIMILARITY"; + + static function_set getFunctionSet(); +}; + +struct ArrayDistanceFunction { + static constexpr const char* name = "ARRAY_DISTANCE"; + + static function_set getFunctionSet(); +}; + +struct ArraySquaredDistanceFunction { + static constexpr const char* name = "ARRAY_SQUARED_DISTANCE"; + + static function_set getFunctionSet(); +}; + +struct ArrayInnerProductFunction { + static constexpr const char* name = "ARRAY_INNER_PRODUCT"; + + static function_set getFunctionSet(); +}; + +struct ArrayDotProductFunction { + static constexpr const char* name = "ARRAY_DOT_PRODUCT"; + + static function_set getFunctionSet(); +}; + +struct ArrayConcatFunction : public ListConcatFunction { + static constexpr const char* name = "ARRAY_CONCAT"; +}; + +struct ArrayCatFunction { + using alias = ArrayConcatFunction; + + static constexpr const char* name = "ARRAY_CAT"; +}; + +struct ArrayAppendFunction : public ListAppendFunction { + static constexpr const char* name = "ARRAY_APPEND"; +}; + +struct ArrayPushBackFunction { + using alias = ArrayAppendFunction; + + static constexpr const char* name = "ARRAY_PUSH_BACK"; +}; + +struct ArrayPrependFunction : public ListPrependFunction { + static constexpr const char* name = "ARRAY_PREPEND"; +}; + +struct ArrayPushFrontFunction { + using alias = ArrayPrependFunction; + + static constexpr const char* name = "ARRAY_PUSH_FRONT"; +}; + +struct ArrayPositionFunction : public ListPositionFunction { + static constexpr const char* name = "ARRAY_POSITION"; +}; + +struct ArrayIndexOfFunction { + using alias = ArrayPositionFunction; + + static constexpr const char* name = "ARRAY_INDEXOF"; +}; + +struct ArrayContainsFunction : public ListContainsFunction { + static constexpr const char* name = "ARRAY_CONTAINS"; +}; + +struct ArrayHasFunction { + using alias = ArrayContainsFunction; + + static constexpr const char* name = "ARRAY_HAS"; +}; + +struct ArraySliceFunction : public ListSliceFunction { + static constexpr const char* name = "ARRAY_SLICE"; +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/binary_function_executor.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/binary_function_executor.h new file mode 100644 index 0000000000..6a63ce7e88 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/binary_function_executor.h @@ -0,0 +1,329 @@ +#pragma once + +#include "common/vector/value_vector.h" + +namespace lbug { +namespace function { + +/** + * Binary operator assumes function with null returns null. This does NOT applies to binary boolean + * operations (e.g. AND, OR, XOR). 
+ */ + +struct BinaryFunctionWrapper { + template + static inline void operation(LEFT_TYPE& left, RIGHT_TYPE& right, RESULT_TYPE& result, + common::ValueVector* /*leftValueVector*/, common::ValueVector* /*rightValueVector*/, + common::ValueVector* /*resultValueVector*/, uint64_t /*resultPos*/, void* /*dataPtr*/) { + OP::operation(left, right, result); + } +}; + +struct BinaryListStructFunctionWrapper { + template + static void operation(LEFT_TYPE& left, RIGHT_TYPE& right, RESULT_TYPE& result, + common::ValueVector* leftValueVector, common::ValueVector* rightValueVector, + common::ValueVector* resultValueVector, uint64_t /*resultPos*/, void* /*dataPtr*/) { + OP::operation(left, right, result, *leftValueVector, *rightValueVector, *resultValueVector); + } +}; + +struct BinaryMapCreationFunctionWrapper { + template + static void operation(LEFT_TYPE& left, RIGHT_TYPE& right, RESULT_TYPE& result, + common::ValueVector* leftValueVector, common::ValueVector* rightValueVector, + common::ValueVector* resultValueVector, uint64_t /*resultPos*/, void* dataPtr) { + OP::operation(left, right, result, *leftValueVector, *rightValueVector, *resultValueVector, + dataPtr); + } +}; + +struct BinaryListExtractFunctionWrapper { + template + static inline void operation(LEFT_TYPE& left, RIGHT_TYPE& right, RESULT_TYPE& result, + common::ValueVector* leftValueVector, common::ValueVector* rightValueVector, + common::ValueVector* resultValueVector, uint64_t resultPos, void* /*dataPtr*/) { + OP::operation(left, right, result, *leftValueVector, *rightValueVector, *resultValueVector, + resultPos); + } +}; + +struct BinaryStringFunctionWrapper { + template + static inline void operation(LEFT_TYPE& left, RIGHT_TYPE& right, RESULT_TYPE& result, + common::ValueVector* /*leftValueVector*/, common::ValueVector* /*rightValueVector*/, + common::ValueVector* resultValueVector, uint64_t /*resultPos*/, void* /*dataPtr*/) { + OP::operation(left, right, result, *resultValueVector); + } +}; + +struct BinaryComparisonFunctionWrapper { + template + static inline void operation(LEFT_TYPE& left, RIGHT_TYPE& right, RESULT_TYPE& result, + common::ValueVector* leftValueVector, common::ValueVector* rightValueVector, + common::ValueVector* /*resultValueVector*/, uint64_t /*resultPos*/, void* /*dataPtr*/) { + OP::operation(left, right, result, leftValueVector, rightValueVector); + } +}; + +struct BinaryUDFFunctionWrapper { + template + static inline void operation(LEFT_TYPE& left, RIGHT_TYPE& right, RESULT_TYPE& result, + common::ValueVector* /*leftValueVector*/, common::ValueVector* /*rightValueVector*/, + common::ValueVector* /*resultValueVector*/, uint64_t /*resultPos*/, void* dataPtr) { + OP::operation(left, right, result, dataPtr); + } +}; + +struct BinarySelectWithBindDataWrapper { + template + static void operation(LEFT_TYPE& left, RIGHT_TYPE& right, uint8_t& result, + common::ValueVector* leftValueVector, common::ValueVector* rightValueVector, + void* dataPtr) { + OP::operation(left, right, result, *leftValueVector, *rightValueVector, *leftValueVector, + dataPtr); + } +}; + +struct BinaryFunctionExecutor { + + template + static inline void executeOnValue(common::ValueVector& left, common::ValueVector& right, + common::ValueVector& resultValueVector, uint64_t lPos, uint64_t rPos, uint64_t resPos, + void* dataPtr) { + OP_WRAPPER::template operation( + ((LEFT_TYPE*)left.getData())[lPos], ((RIGHT_TYPE*)right.getData())[rPos], + ((RESULT_TYPE*)resultValueVector.getData())[resPos], &left, &right, &resultValueVector, + resPos, dataPtr); + } + 
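+    // Illustrative sketch of the wrapper indirection above (not part of the upstream header;
+    // `Add` is a hypothetical functor used only for orientation): given
+    //     struct Add { static void operation(int64_t& l, int64_t& r, int64_t& res) { res = l + r; } };
+    // a call like executeOnValue<int64_t, int64_t, int64_t, Add, BinaryFunctionWrapper>(...)
+    // reduces to Add::operation(leftData[lPos], rightData[rPos], resultData[resPos]); the other
+    // wrappers differ only in which extra arguments (value vectors, result position, dataPtr)
+    // they forward to OP::operation.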
+ static inline std::tuple getSelectedPositions( + common::SelectionVector* leftSelVector, common::SelectionVector* rightSelVector, + common::SelectionVector* resultSelVector, common::sel_t selPos, bool leftFlat, + bool rightFlat) { + common::sel_t lPos = (*leftSelVector)[leftFlat ? 0 : selPos]; + common::sel_t rPos = (*rightSelVector)[rightFlat ? 0 : selPos]; + common::sel_t resPos = (*resultSelVector)[leftFlat && rightFlat ? 0 : selPos]; + return {lPos, rPos, resPos}; + } + + template + static void executeOnSelectedValues(common::ValueVector& left, + common::SelectionVector* leftSelVector, common::ValueVector& right, + common::SelectionVector* rightSelVector, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* dataPtr) { + const bool leftFlat = left.state->isFlat(); + const bool rightFlat = right.state->isFlat(); + + const bool allNullsGuaranteed = (rightFlat && right.isNull((*rightSelVector)[0])) || + (leftFlat && left.isNull((*leftSelVector)[0])); + if (allNullsGuaranteed) { + result.setAllNull(); + } else { + const bool noNullsGuaranteed = (leftFlat || left.hasNoNullsGuarantee()) && + (rightFlat || right.hasNoNullsGuarantee()); + if (noNullsGuaranteed) { + result.setAllNonNull(); + } + + const auto numSelectedValues = + leftFlat ? rightSelVector->getSelSize() : leftSelVector->getSelSize(); + for (common::sel_t selPos = 0; selPos < numSelectedValues; ++selPos) { + auto [lPos, rPos, resPos] = getSelectedPositions(leftSelVector, rightSelVector, + resultSelVector, selPos, leftFlat, rightFlat); + if (noNullsGuaranteed) { + executeOnValue(left, + right, result, lPos, rPos, resPos, dataPtr); + } else { + result.setNull(resPos, left.isNull(lPos) || right.isNull(rPos)); + if (!result.isNull(resPos)) { + executeOnValue(left, + right, result, lPos, rPos, resPos, dataPtr); + } + } + } + } + } + + template + static void executeSwitch(common::ValueVector& left, common::SelectionVector* leftSelVector, + common::ValueVector& right, common::SelectionVector* rightSelVector, + common::ValueVector& result, common::SelectionVector* resultSelVector, void* dataPtr) { + result.resetAuxiliaryBuffer(); + executeOnSelectedValues(left, + leftSelVector, right, rightSelVector, result, resultSelVector, dataPtr); + } + + template + static void execute(common::ValueVector& left, common::SelectionVector* leftSelVector, + common::ValueVector& right, common::SelectionVector* rightSelVector, + common::ValueVector& result, common::SelectionVector* resultSelVector) { + executeSwitch(left, + leftSelVector, right, rightSelVector, result, resultSelVector, nullptr /* dataPtr */); + } + + struct BinarySelectWrapper { + template + static inline void operation(LEFT_TYPE& left, RIGHT_TYPE& right, uint8_t& result, + common::ValueVector* /*leftValueVector*/, common::ValueVector* /*rightValueVector*/, + void* /*dataPtr*/) { + OP::operation(left, right, result); + } + }; + + struct BinaryComparisonSelectWrapper { + template + static inline void operation(LEFT_TYPE& left, RIGHT_TYPE& right, uint8_t& result, + common::ValueVector* leftValueVector, common::ValueVector* rightValueVector, + void* /*dataPtr*/) { + OP::operation(left, right, result, leftValueVector, rightValueVector); + } + }; + + template + static void selectOnValue(common::ValueVector& left, common::ValueVector& right, uint64_t lPos, + uint64_t rPos, uint64_t resPos, uint64_t& numSelectedValues, + std::span selectedPositionsBuffer, void* dataPtr) { + uint8_t resultValue = 0; + SELECT_WRAPPER::template operation( + 
((LEFT_TYPE*)left.getData())[lPos], ((RIGHT_TYPE*)right.getData())[rPos], resultValue, + &left, &right, dataPtr); + selectedPositionsBuffer[numSelectedValues] = resPos; + numSelectedValues += (resultValue == true); + } + + template + static uint64_t selectBothFlat(common::ValueVector& left, common::ValueVector& right, + void* dataPtr) { + auto lPos = left.state->getSelVector()[0]; + auto rPos = right.state->getSelVector()[0]; + uint8_t resultValue = 0; + if (!left.isNull(lPos) && !right.isNull(rPos)) { + SELECT_WRAPPER::template operation( + ((LEFT_TYPE*)left.getData())[lPos], ((RIGHT_TYPE*)right.getData())[rPos], + resultValue, &left, &right, dataPtr); + } + return resultValue == true; + } + + template + static bool selectFlatUnFlat(common::ValueVector& left, common::ValueVector& right, + common::SelectionVector& selVector, void* dataPtr) { + auto lPos = left.state->getSelVector()[0]; + uint64_t numSelectedValues = 0; + auto selectedPositionsBuffer = selVector.getMutableBuffer(); + auto& rightSelVector = right.state->getSelVector(); + if (left.isNull(lPos)) { + return numSelectedValues; + } else if (right.hasNoNullsGuarantee()) { + rightSelVector.forEach([&](auto i) { + selectOnValue(left, right, lPos, i, i, + numSelectedValues, selectedPositionsBuffer, dataPtr); + }); + } else { + rightSelVector.forEach([&](auto i) { + if (!right.isNull(i)) { + selectOnValue(left, right, lPos, i, + i, numSelectedValues, selectedPositionsBuffer, dataPtr); + } + }); + } + selVector.setSelSize(numSelectedValues); + return numSelectedValues > 0; + } + + template + static bool selectUnFlatFlat(common::ValueVector& left, common::ValueVector& right, + common::SelectionVector& selVector, void* dataPtr) { + auto rPos = right.state->getSelVector()[0]; + uint64_t numSelectedValues = 0; + auto selectedPositionsBuffer = selVector.getMutableBuffer(); + auto& leftSelVector = left.state->getSelVector(); + if (right.isNull(rPos)) { + return numSelectedValues; + } else if (left.hasNoNullsGuarantee()) { + leftSelVector.forEach([&](auto i) { + selectOnValue(left, right, i, rPos, i, + numSelectedValues, selectedPositionsBuffer, dataPtr); + }); + } else { + leftSelVector.forEach([&](auto i) { + if (!left.isNull(i)) { + selectOnValue(left, right, i, rPos, + i, numSelectedValues, selectedPositionsBuffer, dataPtr); + } + }); + } + selVector.setSelSize(numSelectedValues); + return numSelectedValues > 0; + } + + // Right, left, and result vectors share the same selectedPositions. 
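+    // In other words, both operands are assumed to come from the same unflat state, so a single
+    // index i drawn from the left selection vector addresses the left value, the right value and
+    // the emitted selected position, and both sides are null-checked at i before OP is evaluated.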
+ template + static bool selectBothUnFlat(common::ValueVector& left, common::ValueVector& right, + common::SelectionVector& selVector, void* dataPtr) { + uint64_t numSelectedValues = 0; + auto selectedPositionsBuffer = selVector.getMutableBuffer(); + auto& leftSelVector = left.state->getSelVector(); + if (left.hasNoNullsGuarantee() && right.hasNoNullsGuarantee()) { + leftSelVector.forEach([&](auto i) { + selectOnValue(left, right, i, i, i, + numSelectedValues, selectedPositionsBuffer, dataPtr); + }); + } else { + leftSelVector.forEach([&](auto i) { + auto isNull = left.isNull(i) || right.isNull(i); + if (!isNull) { + selectOnValue(left, right, i, i, i, + numSelectedValues, selectedPositionsBuffer, dataPtr); + } + }); + } + selVector.setSelSize(numSelectedValues); + return numSelectedValues > 0; + } + + // BOOLEAN (AND, OR, XOR) + template + static bool select(common::ValueVector& left, common::ValueVector& right, + common::SelectionVector& selVector, void* dataPtr) { + if (left.state->isFlat() && right.state->isFlat()) { + return selectBothFlat(left, right, dataPtr); + } else if (left.state->isFlat() && !right.state->isFlat()) { + return selectFlatUnFlat(left, right, selVector, + dataPtr); + } else if (!left.state->isFlat() && right.state->isFlat()) { + return selectUnFlatFlat(left, right, selVector, + dataPtr); + } else { + return selectBothUnFlat(left, right, selVector, + dataPtr); + } + } + + // COMPARISON (GT, GTE, LT, LTE, EQ, NEQ) + template + static bool selectComparison(common::ValueVector& left, common::ValueVector& right, + common::SelectionVector& selVector, void* dataPtr) { + if (left.state->isFlat() && right.state->isFlat()) { + return selectBothFlat(left, + right, dataPtr); + } else if (left.state->isFlat() && !right.state->isFlat()) { + return selectFlatUnFlat( + left, right, selVector, dataPtr); + } else if (!left.state->isFlat() && right.state->isFlat()) { + return selectUnFlatFlat( + left, right, selVector, dataPtr); + } else { + return selectBothUnFlat( + left, right, selVector, dataPtr); + } + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/blob/functions/decode_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/blob/functions/decode_function.h new file mode 100644 index 0000000000..db8599afc3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/blob/functions/decode_function.h @@ -0,0 +1,25 @@ +#pragma once + +#include "common/exception/runtime.h" +#include "common/types/blob.h" +#include "common/vector/value_vector.h" +#include "utf8proc_wrapper.h" + +namespace lbug { +namespace function { + +struct Decode { + static inline void operation(common::blob_t& input, common::ku_string_t& result, + common::ValueVector& resultVector) { + if (utf8proc::Utf8Proc::analyze(reinterpret_cast(input.value.getData()), + input.value.len) == utf8proc::UnicodeType::INVALID) { + throw common::RuntimeException( + "Failure in decode: could not convert blob to UTF8 string, " + "the blob contained invalid UTF8 characters"); + } + common::StringVector::addString(&resultVector, result, input.value); + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/blob/functions/encode_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/blob/functions/encode_function.h new file mode 100644 index 0000000000..206513ee91 --- /dev/null +++ 
b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/blob/functions/encode_function.h @@ -0,0 +1,17 @@ +#pragma once + +#include "common/types/blob.h" +#include "common/vector/value_vector.h" + +namespace lbug { +namespace function { + +struct Encode { + static inline void operation(common::ku_string_t& input, common::blob_t& result, + common::ValueVector& resultVector) { + common::StringVector::addString(&resultVector, result.value, input); + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/blob/functions/octet_length_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/blob/functions/octet_length_function.h new file mode 100644 index 0000000000..d98d45230f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/blob/functions/octet_length_function.h @@ -0,0 +1,15 @@ +#pragma once + +#include "common/types/blob.h" + +namespace lbug { +namespace function { + +struct OctetLength { + static inline void operation(common::blob_t& input, int64_t& result) { + result = input.value.len; + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/blob/vector_blob_functions.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/blob/vector_blob_functions.h new file mode 100644 index 0000000000..501e056684 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/blob/vector_blob_functions.h @@ -0,0 +1,27 @@ +#pragma once + +#include "function/function.h" + +namespace lbug { +namespace function { + +struct OctetLengthFunctions { + static constexpr const char* name = "OCTET_LENGTH"; + + static function_set getFunctionSet(); +}; + +struct EncodeFunctions { + static constexpr const char* name = "ENCODE"; + + static function_set getFunctionSet(); +}; + +struct DecodeFunctions { + static constexpr const char* name = "DECODE"; + + static function_set getFunctionSet(); +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/boolean/boolean_function_executor.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/boolean/boolean_function_executor.h new file mode 100644 index 0000000000..39bd8e0cf2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/boolean/boolean_function_executor.h @@ -0,0 +1,342 @@ +#pragma once + +#include "boolean_functions.h" +#include "common/vector/value_vector.h" + +namespace lbug { +namespace function { + +/** + * Binary boolean function requires special executor implementation because it's truth table + * handles null differently (e.g. NULL OR TRUE = TRUE). Note that unary boolean operation (currently + * only NOT) does not require special implementation because NOT NULL = NULL. 
+ */ +struct BinaryBooleanFunctionExecutor { + + template + static inline void executeOnValueNoNull(common::ValueVector& left, common::ValueVector& right, + common::ValueVector& result, uint64_t lPos, uint64_t rPos, uint64_t resPos) { + auto resValues = (uint8_t*)result.getData(); + FUNC::operation(left.getValue(lPos), right.getValue(rPos), + resValues[resPos], false /* isLeftNull */, false /* isRightNull */); + result.setNull(resPos, false /* isNull */); + } + + template + static inline void executeOnValue(common::ValueVector& left, common::ValueVector& right, + common::ValueVector& result, uint64_t lPos, uint64_t rPos, uint64_t resPos) { + auto resValues = (uint8_t*)result.getData(); + FUNC::operation(left.getValue(lPos), right.getValue(rPos), + resValues[resPos], left.isNull(lPos), right.isNull(rPos)); + result.setNull(resPos, result.getValue(resPos) == NULL_BOOL); + } + + template + static inline void executeBothFlat(common::ValueVector& left, + common::SelectionVector* leftSelVector, common::ValueVector& right, + common::SelectionVector* rightSelVector, common::ValueVector& result, + common::SelectionVector* resultSelVector) { + auto lPos = (*leftSelVector)[0]; + auto rPos = (*rightSelVector)[0]; + auto resPos = (*resultSelVector)[0]; + executeOnValue(left, right, result, lPos, rPos, resPos); + } + + template + static void executeFlatUnFlat(common::ValueVector& left, common::SelectionVector* leftSelVector, + common::ValueVector& right, common::SelectionVector* rightSelVector, + common::ValueVector& result, common::SelectionVector*) { + auto lPos = (*leftSelVector)[0]; + if (rightSelVector->isUnfiltered()) { + if (right.hasNoNullsGuarantee() && !left.isNull(lPos)) { + for (auto i = 0u; i < rightSelVector->getSelSize(); ++i) { + executeOnValueNoNull(left, right, result, lPos, i, i); + } + } else { + for (auto i = 0u; i < rightSelVector->getSelSize(); ++i) { + executeOnValue(left, right, result, lPos, i, i); + } + } + } else { + if (right.hasNoNullsGuarantee() && !left.isNull(lPos)) { + for (auto i = 0u; i < rightSelVector->getSelSize(); ++i) { + auto rPos = (*rightSelVector)[i]; + executeOnValueNoNull(left, right, result, lPos, rPos, rPos); + } + } else { + for (auto i = 0u; i < rightSelVector->getSelSize(); ++i) { + auto rPos = (*rightSelVector)[i]; + executeOnValue(left, right, result, lPos, rPos, rPos); + } + } + } + } + + template + static void executeUnFlatFlat(common::ValueVector& left, common::SelectionVector* leftSelVector, + common::ValueVector& right, common::SelectionVector* rightSelVector, + common::ValueVector& result, common::SelectionVector*) { + auto rPos = (*rightSelVector)[0]; + if (leftSelVector->isUnfiltered()) { + if (left.hasNoNullsGuarantee() && !right.isNull(rPos)) { + for (auto i = 0u; i < leftSelVector->getSelSize(); ++i) { + executeOnValueNoNull(left, right, result, i, rPos, i); + } + } else { + for (auto i = 0u; i < leftSelVector->getSelSize(); ++i) { + executeOnValue(left, right, result, i, rPos, i); + } + } + } else { + if (left.hasNoNullsGuarantee() && !right.isNull(rPos)) { + for (auto i = 0u; i < leftSelVector->getSelSize(); ++i) { + auto lPos = (*leftSelVector)[i]; + executeOnValueNoNull(left, right, result, lPos, rPos, lPos); + } + } else { + for (auto i = 0u; i < leftSelVector->getSelSize(); ++i) { + auto lPos = (*leftSelVector)[i]; + executeOnValue(left, right, result, lPos, rPos, lPos); + } + } + } + } + + template + static void executeBothUnFlat(common::ValueVector& left, common::SelectionVector* leftSelVector, + common::ValueVector& right, 
[[maybe_unused]] common::SelectionVector* rightSelVector, + common::ValueVector& result, common::SelectionVector*) { + KU_ASSERT(leftSelVector == rightSelVector); + if (leftSelVector->isUnfiltered()) { + if (left.hasNoNullsGuarantee() && right.hasNoNullsGuarantee()) { + for (auto i = 0u; i < leftSelVector->getSelSize(); ++i) { + executeOnValueNoNull(left, right, result, i, i, i); + } + } else { + for (auto i = 0u; i < leftSelVector->getSelSize(); ++i) { + executeOnValue(left, right, result, i, i, i); + } + } + } else { + if (left.hasNoNullsGuarantee() && right.hasNoNullsGuarantee()) { + for (auto i = 0u; i < leftSelVector->getSelSize(); ++i) { + auto pos = (*leftSelVector)[i]; + executeOnValueNoNull(left, right, result, pos, pos, pos); + } + } else { + for (auto i = 0u; i < leftSelVector->getSelSize(); ++i) { + auto pos = (*leftSelVector)[i]; + executeOnValue(left, right, result, pos, pos, pos); + } + } + } + } + + template + static void execute(common::ValueVector& left, common::SelectionVector* leftSelVector, + common::ValueVector& right, common::SelectionVector* rightSelVector, + common::ValueVector& result, common::SelectionVector* resultSelVector) { + KU_ASSERT(left.dataType.getLogicalTypeID() == common::LogicalTypeID::BOOL && + right.dataType.getLogicalTypeID() == common::LogicalTypeID::BOOL && + result.dataType.getLogicalTypeID() == common::LogicalTypeID::BOOL); + if (left.state->isFlat() && right.state->isFlat()) { + executeBothFlat(left, leftSelVector, right, rightSelVector, result, + resultSelVector); + } else if (left.state->isFlat() && !right.state->isFlat()) { + executeFlatUnFlat(left, leftSelVector, right, rightSelVector, result, + resultSelVector); + } else if (!left.state->isFlat() && right.state->isFlat()) { + executeUnFlatFlat(left, leftSelVector, right, rightSelVector, result, + resultSelVector); + } else { + executeBothUnFlat(left, leftSelVector, right, rightSelVector, result, + resultSelVector); + } + } + + template + static void selectOnValue(common::ValueVector& left, common::ValueVector& right, uint64_t lPos, + uint64_t rPos, uint64_t resPos, uint64_t& numSelectedValues, + std::span selectedPositionsBuffer) { + uint8_t resultValue = 0; + FUNC::operation(left.getValue(lPos), right.getValue(rPos), resultValue, + left.isNull(lPos), right.isNull(rPos)); + selectedPositionsBuffer[numSelectedValues] = resPos; + numSelectedValues += (resultValue == true); + } + + template + static bool selectBothFlat(common::ValueVector& left, common::ValueVector& right) { + auto lPos = left.state->getSelVector()[0]; + auto rPos = right.state->getSelVector()[0]; + uint8_t resultValue = 0; + FUNC::operation(left.getValue(lPos), right.getValue(rPos), resultValue, + (bool)left.isNull(lPos), (bool)right.isNull(rPos)); + return resultValue == true; + } + + template + static bool selectFlatUnFlat(common::ValueVector& left, common::ValueVector& right, + common::SelectionVector& selVector) { + auto lPos = left.state->getSelVector()[0]; + uint64_t numSelectedValues = 0; + auto selectedPositionsBuffer = selVector.getMutableBuffer(); + auto& rightSelVector = right.state->getSelVector(); + if (rightSelVector.isUnfiltered()) { + for (auto i = 0u; i < rightSelVector.getSelSize(); ++i) { + selectOnValue(left, right, lPos, i, i, numSelectedValues, + selectedPositionsBuffer); + } + } else { + for (auto i = 0u; i < rightSelVector.getSelSize(); ++i) { + auto rPos = right.state->getSelVector()[i]; + selectOnValue(left, right, lPos, rPos, rPos, numSelectedValues, + selectedPositionsBuffer); + } + } + 
selVector.setSelSize(numSelectedValues); + return numSelectedValues > 0; + } + + template + static bool selectUnFlatFlat(common::ValueVector& left, common::ValueVector& right, + common::SelectionVector& selVector) { + auto rPos = right.state->getSelVector()[0]; + uint64_t numSelectedValues = 0; + auto selectedPositionsBuffer = selVector.getMutableBuffer(); + auto& leftSelVector = left.state->getSelVector(); + if (leftSelVector.isUnfiltered()) { + for (auto i = 0u; i < leftSelVector.getSelSize(); ++i) { + selectOnValue(left, right, i, rPos, i, numSelectedValues, + selectedPositionsBuffer); + } + } else { + for (auto i = 0u; i < leftSelVector.getSelSize(); ++i) { + auto lPos = left.state->getSelVector()[i]; + selectOnValue(left, right, lPos, rPos, lPos, numSelectedValues, + selectedPositionsBuffer); + } + } + selVector.setSelSize(numSelectedValues); + return numSelectedValues > 0; + } + + template + static bool selectBothUnFlat(common::ValueVector& left, common::ValueVector& right, + common::SelectionVector& selVector) { + uint64_t numSelectedValues = 0; + auto selectedPositionsBuffer = selVector.getMutableBuffer(); + auto& leftSelVector = left.state->getSelVector(); + if (leftSelVector.isUnfiltered()) { + for (auto i = 0u; i < leftSelVector.getSelSize(); ++i) { + selectOnValue(left, right, i, i, i, numSelectedValues, + selectedPositionsBuffer); + } + } else { + for (auto i = 0u; i < leftSelVector.getSelSize(); ++i) { + auto pos = left.state->getSelVector()[i]; + selectOnValue(left, right, pos, pos, pos, numSelectedValues, + selectedPositionsBuffer); + } + } + selVector.setSelSize(numSelectedValues); + return numSelectedValues > 0; + } + + template + static bool select(common::ValueVector& left, common::ValueVector& right, + common::SelectionVector& selVector) { + KU_ASSERT(left.dataType.getLogicalTypeID() == common::LogicalTypeID::BOOL && + right.dataType.getLogicalTypeID() == common::LogicalTypeID::BOOL); + if (left.state->isFlat() && right.state->isFlat()) { + return selectBothFlat(left, right); + } else if (left.state->isFlat() && !right.state->isFlat()) { + return selectFlatUnFlat(left, right, selVector); + } else if (!left.state->isFlat() && right.state->isFlat()) { + return selectUnFlatFlat(left, right, selVector); + } else { + return selectBothUnFlat(left, right, selVector); + } + } +}; + +struct UnaryBooleanOperationExecutor { + + template + static inline void executeOnValue(common::ValueVector& operand, uint64_t operandPos, + common::ValueVector& result, uint64_t resultPos) { + auto resultValues = (uint8_t*)result.getData(); + FUNC::operation(operand.getValue(operandPos), operand.isNull(operandPos), + resultValues[resultPos]); + result.setNull(resultPos, result.getValue(resultPos) == NULL_BOOL); + } + + template + static void executeSwitch(common::ValueVector& operand, + common::SelectionVector* operandSelVector, common::ValueVector& result, + common::SelectionVector* resultSelVector) { + result.resetAuxiliaryBuffer(); + if (operand.state->isFlat()) { + auto pos = (*operandSelVector)[0]; + auto resultPos = (*resultSelVector)[0]; + executeOnValue(operand, pos, result, resultPos); + } else { + if (operandSelVector->isUnfiltered()) { + for (auto i = 0u; i < operandSelVector->getSelSize(); i++) { + executeOnValue(operand, i, result, i); + } + } else { + for (auto i = 0u; i < operandSelVector->getSelSize(); i++) { + auto pos = (*operandSelVector)[i]; + executeOnValue(operand, pos, result, pos); + } + } + } + } + + template + static inline void execute(common::ValueVector& operand, + 
common::SelectionVector* operandSelVector, common::ValueVector& result, + common::SelectionVector* resultSelVector) { + executeSwitch(operand, operandSelVector, result, resultSelVector); + } + + template + static inline void selectOnValue(common::ValueVector& operand, uint64_t operandPos, + uint64_t& numSelectedValues, std::span selectedPositionsBuffer) { + uint8_t resultValue = 0; + FUNC::operation(operand.getValue(operandPos), operand.isNull(operandPos), + resultValue); + selectedPositionsBuffer[numSelectedValues] = operandPos; + numSelectedValues += resultValue == true; + } + + template + static bool select(common::ValueVector& operand, common::SelectionVector& selVector) { + if (operand.state->isFlat()) { + auto pos = operand.state->getSelVector()[0]; + uint8_t resultValue = 0; + FUNC::operation(operand.getValue(pos), operand.isNull(pos), resultValue); + return resultValue == true; + } else { + auto& operandSelVector = operand.state->getSelVector(); + uint64_t numSelectedValues = 0; + auto selectedPositionBuffer = selVector.getMutableBuffer(); + if (operandSelVector.isUnfiltered()) { + for (auto i = 0ul; i < operandSelVector.getSelSize(); i++) { + selectOnValue(operand, i, numSelectedValues, selectedPositionBuffer); + } + } else { + for (auto i = 0ul; i < operandSelVector.getSelSize(); i++) { + auto pos = operand.state->getSelVector()[i]; + selectOnValue(operand, pos, numSelectedValues, selectedPositionBuffer); + } + } + selVector.setSelSize(numSelectedValues); + return numSelectedValues > 0; + } + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/boolean/boolean_functions.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/boolean/boolean_functions.h new file mode 100644 index 0000000000..73d637832e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/boolean/boolean_functions.h @@ -0,0 +1,121 @@ +#pragma once + +#include + +namespace lbug { +namespace function { + +/** + * The boolean operators (AND, OR, XOR, NOT) works a little differently from other operators. While + * other operators can operate on only non null operands, boolean operators can operate even with + * null operands in certain cases, for instance, Null OR True = True. Hence, the result value of + * the boolean operator can be True, False or Null. To accommodate for this, the dataType of + * result is uint8_t (that can have more than 2 values) rather than bool. In case, the result is + * computed to be Null based on the operands, we set result = NULL_BOOL, which should rightly be + * interpreted by operator executors as NULL and not as True. + * */ + +/** + * IMPORTANT: Not to be used outside the context of boolean operators. 
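+ * For example (per the truth tables below): Or::operation(left, /*right=*/true, result,
+ * /*isLeftNull=*/true, /*isRightNull=*/false) writes true, whereas And::operation with the same
+ * arguments writes NULL_BOOL, which the executors above translate back into a NULL result.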
+ * */ +const uint8_t NULL_BOOL = 2; + +/** + * AND operator Truth table: + * + * left isLeftNull right isRightNull result + * ------ ------------ ------- ------------- -------- + * T F T F 1 + * T F F F 0 + * F F T F 0 + * F F F F 0 + * - T T F 2 + * - T F F 0 + * T F - T 2 + * F F - T 0 + * - T - T 2 + * */ +struct And { + static inline void operation(bool left, bool right, uint8_t& result, bool isLeftNull, + bool isRightNull) { + if ((!left && !isLeftNull) || (!right && !isRightNull)) { + result = false; + } else if (isLeftNull || isRightNull) { + result = NULL_BOOL; + } else { + result = true; + } + } +}; + +/** + * OR operator Truth table: + * + * left isLeftNull right isRightNull result + * ------ ------------ ------- ------------- -------- + * T F T F 1 + * T F F F 1 + * F F T F 1 + * F F F F 0 + * - T T F 1 + * - T F F 2 + * T F - T 1 + * F F - T 2 + * - T - T 2 + * */ +struct Or { + static inline void operation(bool left, bool right, uint8_t& result, bool isLeftNull, + bool isRightNull) { + if ((left && !isLeftNull) || (right && !isRightNull)) { + result = true; + } else if (isLeftNull || isRightNull) { + result = NULL_BOOL; + } else { + result = false; + } + } +}; + +/** + * XOR operator Truth table: + * + * left isLeftNull right isRightNull result + * ------ ------------ ------- ------------- -------- + * T F T F 0 + * T F F F 1 + * F F T F 1 + * F F F F 0 + * - T T F 2 + * - T F F 2 + * T F - T 2 + * F F - T 2 + * - T - T 2 + * */ +struct Xor { + static inline void operation(bool left, bool right, uint8_t& result, bool isLeftNull, + bool isRightNull) { + if (isLeftNull || isRightNull) { + result = NULL_BOOL; + } else { + result = left ^ right; + } + } +}; + +/** + * NOT operator Truth table: + * + * operand isNull right + * --------- ------------ ------- + * T F 0 + * F F 1 + * - T 2 + * */ +struct Not { + static inline void operation(bool operand, bool isNull, uint8_t& result) { + result = isNull ? 
NULL_BOOL : !operand; + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/boolean/vector_boolean_functions.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/boolean/vector_boolean_functions.h new file mode 100644 index 0000000000..b27592b7b9 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/boolean/vector_boolean_functions.h @@ -0,0 +1,68 @@ +#pragma once + +#include "boolean_function_executor.h" +#include "function/scalar_function.h" + +namespace lbug { +namespace function { + +class VectorBooleanFunction { +public: + static void bindExecFunction(common::ExpressionType expressionType, + const binder::expression_vector& children, scalar_func_exec_t& func); + + static void bindSelectFunction(common::ExpressionType expressionType, + const binder::expression_vector& children, scalar_func_select_t& func); + +private: + template + static void BinaryBooleanExecFunction( + const std::vector>& params, + const std::vector& paramSelVectors, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* /*dataPtr*/ = nullptr) { + KU_ASSERT(params.size() == 2); + BinaryBooleanFunctionExecutor::execute(*params[0], paramSelVectors[0], *params[1], + paramSelVectors[1], result, resultSelVector); + } + + template + static bool BinaryBooleanSelectFunction( + const std::vector>& params, + common::SelectionVector& selVector, void* /*dataPtr*/) { + KU_ASSERT(params.size() == 2); + return BinaryBooleanFunctionExecutor::select(*params[0], *params[1], selVector); + } + + template + static void UnaryBooleanExecFunction( + const std::vector>& params, + const std::vector& paramSelVectors, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* /*dataPtr*/ = nullptr) { + KU_ASSERT(params.size() == 1); + UnaryBooleanOperationExecutor::execute(*params[0], paramSelVectors[0], result, + resultSelVector); + } + + template + static bool UnaryBooleanSelectFunction( + const std::vector>& params, + common::SelectionVector& selVector, void* /*dataPtr*/) { + KU_ASSERT(params.size() == 1); + return UnaryBooleanOperationExecutor::select(*params[0], selVector); + } + + static void bindBinaryExecFunction(common::ExpressionType expressionType, + const binder::expression_vector& children, scalar_func_exec_t& func); + + static void bindBinarySelectFunction(common::ExpressionType expressionType, + const binder::expression_vector& children, scalar_func_select_t& func); + + static void bindUnaryExecFunction(common::ExpressionType expressionType, + const binder::expression_vector& children, scalar_func_exec_t& func); + + static void bindUnarySelectFunction(common::ExpressionType expressionType, + const binder::expression_vector& children, scalar_func_select_t& func); +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/built_in_function_utils.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/built_in_function_utils.h new file mode 100644 index 0000000000..2b41914dd8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/built_in_function_utils.h @@ -0,0 +1,100 @@ +#pragma once + +#include "aggregate_function.h" +#include "catalog/catalog_entry/catalog_entry_type.h" +#include "function.h" + +namespace lbug { +namespace transaction { +class Transaction; +} // namespace transaction + +namespace catalog { +class FunctionCatalogEntry; +} // namespace catalog + +namespace function { + +class BuiltInFunctionsUtils { 
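+    // Rough outline, inferred from the helpers declared below: matchFunction scores each overload
+    // in the catalog entry's function_set with getFunctionCost (which accumulates getCastCost over
+    // the input types), getBestMatch keeps the cheapest candidate, and a failed match is reported
+    // via getFunctionMatchFailureMsg.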
+public: + // TODO(Ziyi): We should have a unified interface for matching table, aggregate and scalar + // functions. + static LBUG_API Function* matchFunction(const std::string& name, + const catalog::FunctionCatalogEntry* catalogEntry) { + return matchFunction(name, {}, catalogEntry); + } + static LBUG_API Function* matchFunction(const std::string& name, + const std::vector& inputTypes, + const catalog::FunctionCatalogEntry* functionEntry); + + static AggregateFunction* matchAggregateFunction(const std::string& name, + const std::vector& inputTypes, bool isDistinct, + const catalog::FunctionCatalogEntry* functionEntry); + + static LBUG_API uint32_t getCastCost(common::LogicalTypeID inputTypeID, + common::LogicalTypeID targetTypeID); + + static LBUG_API std::string getFunctionMatchFailureMsg(const std::string name, + const std::vector& inputTypes, const std::string& supportedInputs, + bool isDistinct = false); + +private: + // TODO(Xiyang): move casting cost related functions to binder. + static uint32_t getTargetTypeCost(common::LogicalTypeID typeID); + + static uint32_t castInt64(common::LogicalTypeID targetTypeID); + + static uint32_t castInt32(common::LogicalTypeID targetTypeID); + + static uint32_t castInt16(common::LogicalTypeID targetTypeID); + + static uint32_t castInt8(common::LogicalTypeID targetTypeID); + + static uint32_t castUInt64(common::LogicalTypeID targetTypeID); + + static uint32_t castUInt32(common::LogicalTypeID targetTypeID); + + static uint32_t castUInt16(common::LogicalTypeID targetTypeID); + + static uint32_t castUInt8(common::LogicalTypeID targetTypeID); + + static uint32_t castInt128(common::LogicalTypeID targetTypeID); + + static uint32_t castDouble(common::LogicalTypeID targetTypeID); + + static uint32_t castFloat(common::LogicalTypeID targetTypeID); + + static uint32_t castDecimal(common::LogicalTypeID targetTypeID); + + static uint32_t castDate(common::LogicalTypeID targetTypeID); + + static uint32_t castSerial(common::LogicalTypeID targetTypeID); + + static uint32_t castTimestamp(common::LogicalTypeID targetTypeID); + + static uint32_t castFromString(common::LogicalTypeID inputTypeID); + + static uint32_t castUUID(common::LogicalTypeID targetTypeID); + + static uint32_t castList(common::LogicalTypeID targetTypeID); + + static uint32_t castArray(common::LogicalTypeID targetTypeID); + + static Function* getBestMatch(std::vector& functions); + + static uint32_t getFunctionCost(const std::vector& inputTypes, + Function* function, catalog::CatalogEntryType type); + static uint32_t matchParameters(const std::vector& inputTypes, + const std::vector& targetTypeIDs); + static uint32_t matchVarLengthParameters(const std::vector& inputTypes, + common::LogicalTypeID targetTypeID); + static uint32_t getAggregateFunctionCost(const std::vector& inputTypes, + bool isDistinct, AggregateFunction* function); + + static void validateSpecialCases(std::vector& candidateFunctions, + const std::string& name, const std::vector& inputTypes, + const function::function_set& set); +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/cast/cast_function_bind_data.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/cast/cast_function_bind_data.h new file mode 100644 index 0000000000..58e9d93703 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/cast/cast_function_bind_data.h @@ -0,0 +1,28 @@ +#pragma once + +#include "common/copier_config/csv_reader_config.h" +#include "function/function.h" + 
+namespace lbug { +namespace function { + +struct CastFunctionBindData : public FunctionBindData { + // We don't allow configuring delimiters, ... in CAST function. + // For performance purpose, we generate a default option object during binding time. + common::CSVOption option; + // TODO(Mahn): the following field should be removed once we refactor fixed list. + uint64_t numOfEntries; + + explicit CastFunctionBindData(common::LogicalType dataType) + : FunctionBindData{std::move(dataType)}, numOfEntries{0} {} + + inline std::unique_ptr copy() const override { + auto result = std::make_unique(resultType.copy()); + result->numOfEntries = numOfEntries; + result->option = option.copy(); + return result; + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/cast/cast_union_bind_data.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/cast/cast_union_bind_data.h new file mode 100644 index 0000000000..4fce9f2217 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/cast/cast_union_bind_data.h @@ -0,0 +1,28 @@ +#pragma once + +#include "common/data_chunk/sel_vector.h" +#include "common/types/types.h" +#include "function/function.h" + +namespace lbug { +namespace function { + +struct CastToUnionBindData : public FunctionBindData { + using inner_func_t = std::function; + + common::union_field_idx_t targetTag; + inner_func_t innerFunc; + + CastToUnionBindData(common::union_field_idx_t targetTag, inner_func_t innerFunc, + common::LogicalType dataType) + : FunctionBindData{std::move(dataType)}, targetTag{targetTag}, + innerFunc{std::move(innerFunc)} {} + + std::unique_ptr copy() const override { + return std::make_unique(targetTag, innerFunc, resultType.copy()); + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/cast/functions/cast_array.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/cast/functions/cast_array.h new file mode 100644 index 0000000000..ca547678da --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/cast/functions/cast_array.h @@ -0,0 +1,23 @@ +#pragma once + +#include "common/types/types.h" +#include "common/vector/value_vector.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +struct CastArrayHelper { + static bool checkCompatibleNestedTypes(LogicalTypeID sourceTypeID, LogicalTypeID targetTypeID); + + static bool isUnionSpecialCast(const LogicalType& srcType, const LogicalType& dstType); + + static bool containsListToArray(const LogicalType& srcType, const LogicalType& dstType); + + static void validateListEntry(ValueVector* inputVector, const LogicalType& resultType, + uint64_t pos); +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/cast/functions/cast_decimal.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/cast/functions/cast_decimal.h new file mode 100644 index 0000000000..52de3b6ff4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/cast/functions/cast_decimal.h @@ -0,0 +1,155 @@ +#pragma once + +#include +#include + +#include "common/exception/overflow.h" +#include "common/string_format.h" +#include "common/type_utils.h" +#include "common/types/int128_t.h" +#include "common/types/types.h" +#include "common/vector/value_vector.h" +#include "function/cast/functions/cast_string_non_nested_functions.h" +#include "function/cast/functions/numeric_limits.h" + +namespace 
lbug { +namespace function { + +template +struct pickDecimalPhysicalType { + static constexpr bool AISFLOAT = std::is_floating_point::value; + static constexpr bool BISFLOAT = std::is_floating_point::value; + using RES = + typename std::conditional<(AISFLOAT ? false : (BISFLOAT ? true : sizeof(A) > sizeof(B))), A, + B>::type; +}; + +struct CastDecimalTo { + template + static void operation(SRC& input, DST& output, const ValueVector& inputVec, + ValueVector& outputVec) { + using T = typename pickDecimalPhysicalType::RES; + constexpr auto pow10s = pow10Sequence(); + auto scale = DecimalType::getScale(inputVec.dataType); + if constexpr (std::is_floating_point::value) { + output = (DST)input / (DST)pow10s[scale]; + } else { + auto roundconst = (input < 0 ? -5 : 5); + auto tmp = ((scale > 0 ? pow10s[scale - 1] * roundconst : 0) + input) / pow10s[scale]; + if (tmp < NumericLimits::minimum() || tmp > NumericLimits::maximum()) { + throw OverflowException(stringFormat("Cast Failed: {} is not in {} range", + DecimalType::insertDecimalPoint(TypeUtils::toString(input), scale), + outputVec.dataType.toString())); + } + output = (DST)tmp; + } + } +}; + +struct CastToDecimal { + template + static void operation(SRC& input, DST& output, const ValueVector&, + const ValueVector& outputVec) { + constexpr auto pow10s = pow10Sequence(); + auto precision = DecimalType::getPrecision(outputVec.dataType); + auto scale = DecimalType::getScale(outputVec.dataType); + if constexpr (std::is_floating_point::value) { + auto roundconst = (input < 0 ? -0.5 : 0.5); + output = (DST)((double)pow10s[scale] * input + roundconst); + } else { + output = (DST)(pow10s[scale] * input); + } + if (output <= -pow10s[precision] || output >= pow10s[precision]) { + throw OverflowException(stringFormat("To Decimal Cast Failed: {} is not in {} range", + TypeUtils::toString(input), outputVec.dataType.toString())); + } + } +}; + +struct CastBetweenDecimal { + template + static void operation(SRC& input, DST& output, const ValueVector& inputVec, + const ValueVector& outputVec) { + using T = typename pickDecimalPhysicalType::RES; + constexpr auto pow10s = pow10Sequence(); + auto outputPrecision = DecimalType::getPrecision(outputVec.dataType); + auto inputScale = DecimalType::getScale(inputVec.dataType); + auto outputScale = DecimalType::getScale(outputVec.dataType); + if (inputScale == outputScale) { + output = (DST)input; + } else if (inputScale < outputScale) { + output = (DST)(pow10s[outputScale - inputScale] * input); + } else { + auto roundconst = (input < 0 ? 
-5 : 5); + output = (DST)((pow10s[inputScale - outputScale - 1] * roundconst + input) / + pow10s[inputScale - outputScale]); + } + if (pow10s[outputPrecision] <= output || -pow10s[outputPrecision] >= output) { + throw OverflowException(stringFormat( + "Decimal Cast Failed: input {} is not in range of {}", + DecimalType::insertDecimalPoint(TypeUtils::toString(input, nullptr), inputScale), + outputVec.dataType.toString())); + } + } +}; + +// DECIMAL TO STRING SPECIALIZATION +template<> +inline void CastDecimalTo::operation(int16_t& input, ku_string_t& output, + const ValueVector& inputVec, ValueVector& resultVector) { + auto scale = DecimalType::getScale(inputVec.dataType); + auto str = DecimalType::insertDecimalPoint(std::to_string(input), scale); + common::StringVector::addString(&resultVector, output, str); +} + +template<> +inline void CastDecimalTo::operation(int32_t& input, ku_string_t& output, + const ValueVector& inputVec, ValueVector& resultVector) { + auto scale = DecimalType::getScale(inputVec.dataType); + auto str = DecimalType::insertDecimalPoint(std::to_string(input), scale); + common::StringVector::addString(&resultVector, output, str); +} + +template<> +inline void CastDecimalTo::operation(int64_t& input, ku_string_t& output, + const ValueVector& inputVec, ValueVector& resultVector) { + auto scale = DecimalType::getScale(inputVec.dataType); + auto str = DecimalType::insertDecimalPoint(std::to_string(input), scale); + common::StringVector::addString(&resultVector, output, str); +} + +template<> +inline void CastDecimalTo::operation(common::int128_t& input, ku_string_t& output, + const ValueVector& inputVec, ValueVector& resultVector) { + auto scale = DecimalType::getScale(inputVec.dataType); + auto str = DecimalType::insertDecimalPoint(common::Int128_t::toString(input), scale); + common::StringVector::addString(&resultVector, output, str); +} + +// STRING TO DECIMAL SPECIALIZATION +template<> +inline void CastToDecimal::operation(ku_string_t& input, int16_t& output, const ValueVector&, + const ValueVector& outputVec) { + decimalCast((const char*)input.getData(), input.len, output, outputVec.dataType); +} + +template<> +inline void CastToDecimal::operation(ku_string_t& input, int32_t& output, const ValueVector&, + const ValueVector& outputVec) { + decimalCast((const char*)input.getData(), input.len, output, outputVec.dataType); +} + +template<> +inline void CastToDecimal::operation(ku_string_t& input, int64_t& output, const ValueVector&, + const ValueVector& outputVec) { + decimalCast((const char*)input.getData(), input.len, output, outputVec.dataType); +} + +template<> +inline void CastToDecimal::operation(ku_string_t& input, common::int128_t& output, + const ValueVector&, const ValueVector& outputVec) { + decimalCast((const char*)input.getData(), input.len, output, outputVec.dataType); +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/cast/functions/cast_from_string_functions.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/cast/functions/cast_from_string_functions.h new file mode 100644 index 0000000000..e9ee1350de --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/cast/functions/cast_from_string_functions.h @@ -0,0 +1,190 @@ +#pragma once + +#include "cast_string_non_nested_functions.h" +#include "common/copier_config/csv_reader_config.h" +#include "common/type_utils.h" +#include "common/types/blob.h" +#include "common/types/uuid.h" +#include "common/vector/value_vector.h" + 
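+// Usage sketch (hypothetical caller, for orientation only): the primary template below is the
+// INT64 base case and every other non-nested type is a full specialization, e.g.
+//     ku_string_t s = ...;            // holds "42"
+//     int32_t v = 0;
+//     CastString::operation(s, v);    // -> simpleIntegerCast(..., LogicalTypeID::INT32)
+// The resultVector / rowToAdd / CSVOption parameters are unused by the inline numeric and
+// temporal specializations; the blob, UUID and nested-type (LIST/MAP/STRUCT/UNION) overloads
+// declared at the end of this file are defined out of line.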
+using namespace lbug::common; + +namespace lbug { +namespace function { + +struct LBUG_API CastString { + static void copyStringToVector(ValueVector* vector, uint64_t vectorPos, std::string_view strVal, + const CSVOption* option); + + template + static inline bool tryCast(const ku_string_t& input, T& result) { + // try cast for signed integer types (not including int128) + return trySimpleIntegerCast(reinterpret_cast(input.getData()), + input.len, result); + } + + template + static inline void operation(const ku_string_t& input, T& result, + ValueVector* /*resultVector*/ = nullptr, uint64_t /*rowToAdd*/ = 0, + const CSVOption* /*option*/ = nullptr) { + // base case: int64 + simpleIntegerCast(reinterpret_cast(input.getData()), input.len, + result, LogicalTypeID::INT64); + } +}; + +template<> +inline void CastString::operation(const ku_string_t& input, int128_t& result, + ValueVector* /*resultVector*/, uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + simpleIntegerCast(reinterpret_cast(input.getData()), input.len, result, + LogicalTypeID::INT128); +} + +template<> +inline void CastString::operation(const ku_string_t& input, uint128_t& result, + ValueVector* /*resultVector*/, uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + simpleIntegerCast(reinterpret_cast(input.getData()), input.len, + result, LogicalTypeID::UINT128); +} + +template<> +inline void CastString::operation(const ku_string_t& input, int32_t& result, + ValueVector* /*resultVector*/, uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + simpleIntegerCast(reinterpret_cast(input.getData()), input.len, result, + LogicalTypeID::INT32); +} + +template<> +inline void CastString::operation(const ku_string_t& input, int16_t& result, + ValueVector* /*resultVector*/, uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + simpleIntegerCast(reinterpret_cast(input.getData()), input.len, result, + LogicalTypeID::INT16); +} + +template<> +inline void CastString::operation(const ku_string_t& input, int8_t& result, + ValueVector* /*resultVector*/, uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + simpleIntegerCast(reinterpret_cast(input.getData()), input.len, result, + LogicalTypeID::INT8); +} + +template<> +inline void CastString::operation(const ku_string_t& input, uint64_t& result, + ValueVector* /*resultVector*/, uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + simpleIntegerCast(reinterpret_cast(input.getData()), input.len, + result, LogicalTypeID::UINT64); +} + +template<> +inline void CastString::operation(const ku_string_t& input, uint32_t& result, + ValueVector* /*resultVector*/, uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + simpleIntegerCast(reinterpret_cast(input.getData()), input.len, + result, LogicalTypeID::UINT32); +} + +template<> +inline void CastString::operation(const ku_string_t& input, uint16_t& result, + ValueVector* /*resultVector*/, uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + simpleIntegerCast(reinterpret_cast(input.getData()), input.len, + result, LogicalTypeID::UINT16); +} + +template<> +inline void CastString::operation(const ku_string_t& input, uint8_t& result, + ValueVector* /*resultVector*/, uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + simpleIntegerCast(reinterpret_cast(input.getData()), input.len, + result, LogicalTypeID::UINT8); +} + +template<> +inline void CastString::operation(const ku_string_t& input, float& result, + ValueVector* /*resultVector*/, uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + doubleCast(reinterpret_cast(input.getData()), 
input.len, result, + LogicalTypeID::FLOAT); +} + +template<> +inline void CastString::operation(const ku_string_t& input, double& result, + ValueVector* /*resultVector*/, uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + doubleCast(reinterpret_cast(input.getData()), input.len, result, + LogicalTypeID::DOUBLE); +} + +template<> +inline void CastString::operation(const ku_string_t& input, date_t& result, + ValueVector* /*resultVector*/, uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + result = Date::fromCString((const char*)input.getData(), input.len); +} + +template<> +inline void CastString::operation(const ku_string_t& input, timestamp_t& result, + ValueVector* /*resultVector*/, uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + result = Timestamp::fromCString((const char*)input.getData(), input.len); +} + +template<> +inline void CastString::operation(const ku_string_t& input, timestamp_ns_t& result, + ValueVector* /*resultVector*/, uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + TryCastStringToTimestamp::cast((const char*)input.getData(), input.len, result, + LogicalTypeID::TIMESTAMP_NS); +} + +template<> +inline void CastString::operation(const ku_string_t& input, timestamp_ms_t& result, + ValueVector* /*resultVector*/, uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + TryCastStringToTimestamp::cast((const char*)input.getData(), input.len, result, + LogicalTypeID::TIMESTAMP_MS); +} + +template<> +inline void CastString::operation(const ku_string_t& input, timestamp_sec_t& result, + ValueVector* /*resultVector*/, uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + TryCastStringToTimestamp::cast((const char*)input.getData(), input.len, result, + LogicalTypeID::TIMESTAMP_SEC); +} + +template<> +inline void CastString::operation(const ku_string_t& input, timestamp_tz_t& result, + ValueVector* /*resultVector*/, uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + TryCastStringToTimestamp::cast((const char*)input.getData(), input.len, result, + LogicalTypeID::TIMESTAMP_TZ); +} + +template<> +inline void CastString::operation(const ku_string_t& input, interval_t& result, + ValueVector* /*resultVector*/, uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + result = Interval::fromCString((const char*)input.getData(), input.len); +} + +template<> +inline void CastString::operation(const ku_string_t& input, bool& result, + ValueVector* /*resultVector*/, uint64_t /*rowToAdd*/, const CSVOption* /*option*/) { + castStringToBool(reinterpret_cast(input.getData()), input.len, result); +} + +template<> +void CastString::operation(const ku_string_t& input, blob_t& result, ValueVector* resultVector, + uint64_t rowToAdd, const CSVOption* option); + +template<> +void CastString::operation(const ku_string_t& input, ku_uuid_t& result, ValueVector* result_vector, + uint64_t rowToAdd, const CSVOption* option); + +template<> +void CastString::operation(const ku_string_t& input, list_entry_t& result, + ValueVector* resultVector, uint64_t rowToAdd, const CSVOption* option); + +template<> +void CastString::operation(const ku_string_t& input, map_entry_t& result, ValueVector* resultVector, + uint64_t rowToAdd, const CSVOption* option); + +template<> +void CastString::operation(const ku_string_t& input, struct_entry_t& result, + ValueVector* resultVector, uint64_t rowToAdd, const CSVOption* option); + +template<> +void CastString::operation(const ku_string_t& input, union_entry_t& result, + ValueVector* resultVector, uint64_t rowToAdd, const CSVOption* option); + +} // 
namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/cast/functions/cast_functions.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/cast/functions/cast_functions.h new file mode 100644 index 0000000000..da3564908b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/cast/functions/cast_functions.h @@ -0,0 +1,415 @@ +#pragma once + +#include "common/exception/overflow.h" +#include "common/string_format.h" +#include "common/type_utils.h" +#include "common/types/int128_t.h" +#include "common/types/uint128_t.h" +#include "common/vector/value_vector.h" +#include "function/cast/cast_union_bind_data.h" +#include "function/cast/functions/numeric_cast.h" + +namespace lbug { +namespace function { + +struct CastToString { + template + static inline void operation(T& input, common::ku_string_t& result, + common::ValueVector& inputVector, common::ValueVector& resultVector) { + auto str = common::TypeUtils::toString(input, (void*)&inputVector); + common::StringVector::addString(&resultVector, result, str); + } +}; + +struct CastNodeToString { + static inline void operation(common::struct_entry_t& input, common::ku_string_t& result, + common::ValueVector& inputVector, common::ValueVector& resultVector) { + auto str = common::TypeUtils::nodeToString(input, &inputVector); + common::StringVector::addString(&resultVector, result, str); + } +}; + +struct CastRelToString { + static inline void operation(common::struct_entry_t& input, common::ku_string_t& result, + common::ValueVector& inputVector, common::ValueVector& resultVector) { + auto str = common::TypeUtils::relToString(input, &inputVector); + common::StringVector::addString(&resultVector, result, str); + } +}; + +struct CastToUnion { + static inline void operation(common::ValueVector& inputVector, + common::ValueVector& resultVector, uint64_t inputPos, uint64_t resultPos, void* dataPtr) { + const auto& bindData = *reinterpret_cast(dataPtr); + auto& tagVector = *common::UnionVector::getTagVector(&resultVector); + auto& valVector = *common::UnionVector::getValVector(&resultVector, bindData.targetTag); + tagVector.setValue(resultPos, bindData.targetTag); + bindData.innerFunc(&inputVector, valVector, inputVector.getSelVectorPtr(), inputPos, + resultPos); + } +}; + +struct CastDateToTimestamp { + template + static inline void operation(common::date_t& input, T& result) { + // base case: timestamp + result = common::Timestamp::fromDateTime(input, common::dtime_t{}); + } +}; + +template<> +inline void CastDateToTimestamp::operation(common::date_t& input, common::timestamp_ns_t& result) { + operation(input, result); + result = common::timestamp_ns_t{common::Timestamp::getEpochNanoSeconds(result)}; +} + +template<> +inline void CastDateToTimestamp::operation(common::date_t& input, common::timestamp_ms_t& result) { + operation(input, result); + result.value /= common::Interval::MICROS_PER_MSEC; +} + +template<> +inline void CastDateToTimestamp::operation(common::date_t& input, common::timestamp_sec_t& result) { + operation(input, result); + result.value /= common::Interval::MICROS_PER_SEC; +} + +struct CastToDate { + template + static inline void operation(T& input, common::date_t& result); +}; + +template<> +inline void CastToDate::operation(common::timestamp_t& input, common::date_t& result) { + result = common::Timestamp::getDate(input); +} + +template<> +inline void CastToDate::operation(common::timestamp_ns_t& input, common::date_t& result) { + auto tmp = 
common::Timestamp::fromEpochNanoSeconds(input.value); + operation(tmp, result); +} + +template<> +inline void CastToDate::operation(common::timestamp_ms_t& input, common::date_t& result) { + auto tmp = common::Timestamp::fromEpochMilliSeconds(input.value); + operation(tmp, result); +} + +template<> +inline void CastToDate::operation(common::timestamp_sec_t& input, common::date_t& result) { + auto tmp = common::Timestamp::fromEpochSeconds(input.value); + operation(tmp, result); +} + +struct CastToDouble { + template + static inline void operation(T& input, double& result) { + if (!tryCastWithOverflowCheck(input, result)) { + throw common::OverflowException{common::stringFormat( + "Value {} is not within DOUBLE range", common::TypeUtils::toString(input))}; + } + } +}; + +template<> +inline void CastToDouble::operation(common::int128_t& input, double& result) { + if (!common::Int128_t::tryCast(input, result)) { // LCOV_EXCL_START + throw common::OverflowException{common::stringFormat("Value {} is not within DOUBLE range", + common::TypeUtils::toString(input))}; + } // LCOV_EXCL_STOP +} + +struct CastToFloat { + template + static inline void operation(T& input, float& result) { + if (!tryCastWithOverflowCheck(input, result)) { + throw common::OverflowException{common::stringFormat( + "Value {} is not within FLOAT range", common::TypeUtils::toString(input))}; + } + } +}; + +template<> +inline void CastToFloat::operation(common::int128_t& input, float& result) { + if (!common::Int128_t::tryCast(input, result)) { // LCOV_EXCL_START + throw common::OverflowException{common::stringFormat("Value {} is not within FLOAT range", + common::TypeUtils::toString(input))}; + }; // LCOV_EXCL_STOP +} + +struct CastToInt128 { + template + static inline void operation(T& input, common::int128_t& result) { + common::Int128_t::tryCastTo(input, result); + } +}; + +struct CastToUInt128 { + template + static inline void operation(T& input, common::uint128_t& result) { + common::UInt128_t::tryCastTo(input, result); + } +}; + +template<> +inline void CastToInt128::operation(common::uint128_t& input, common::int128_t& result) { + result = (common::int128_t)input; +} + +template<> +inline void CastToUInt128::operation(common::int128_t& input, common::uint128_t& result) { + result = (common::uint128_t)input; +} + +struct CastToInt64 { + template + static inline void operation(T& input, int64_t& result) { + if (!tryCastWithOverflowCheck(input, result)) { + throw common::OverflowException{common::stringFormat( + "Value {} is not within INT64 range", common::TypeUtils::toString(input))}; + } + } +}; + +template<> +inline void CastToInt64::operation(common::int128_t& input, int64_t& result) { + if (!common::Int128_t::tryCast(input, result)) { + throw common::OverflowException{common::stringFormat("Value {} is not within INT64 range", + common::TypeUtils::toString(input))}; + }; +} + +struct CastToSerial { + template + static inline void operation(T& input, int64_t& result) { + if (!tryCastWithOverflowCheck(input, result)) { + throw common::OverflowException{common::stringFormat( + "Value {} is not within INT64 range", common::TypeUtils::toString(input))}; + } + } +}; + +template<> +inline void CastToSerial::operation(common::int128_t& input, int64_t& result) { + if (!common::Int128_t::tryCast(input, result)) { + throw common::OverflowException{common::stringFormat("Value {} is not within INT64 range", + common::TypeUtils::toString(input))}; + }; +} + +struct CastToInt32 { + template + static inline void operation(T& input, 
int32_t& result) { + if (!tryCastWithOverflowCheck(input, result)) { + throw common::OverflowException{common::stringFormat( + "Value {} is not within INT32 range", common::TypeUtils::toString(input))}; + } + } +}; + +template<> +inline void CastToInt32::operation(common::int128_t& input, int32_t& result) { + if (!common::Int128_t::tryCast(input, result)) { + throw common::OverflowException{common::stringFormat("Value {} is not within INT32 range", + common::TypeUtils::toString(input))}; + }; +} + +struct CastToInt16 { + template + static inline void operation(T& input, int16_t& result) { + if (!tryCastWithOverflowCheck(input, result)) { + throw common::OverflowException{common::stringFormat( + "Value {} is not within INT16 range", common::TypeUtils::toString(input))}; + } + } +}; + +template<> +inline void CastToInt16::operation(common::int128_t& input, int16_t& result) { + if (!common::Int128_t::tryCast(input, result)) { + throw common::OverflowException{common::stringFormat("Value {} is not within INT16 range", + common::TypeUtils::toString(input))}; + }; +} + +struct CastToInt8 { + template + static inline void operation(T& input, int8_t& result) { + if (!tryCastWithOverflowCheck(input, result)) { + throw common::OverflowException{common::stringFormat( + "Value {} is not within INT8 range", common::TypeUtils::toString(input))}; + } + } +}; + +template<> +inline void CastToInt8::operation(common::int128_t& input, int8_t& result) { + if (!common::Int128_t::tryCast(input, result)) { + throw common::OverflowException{common::stringFormat("Value {} is not within INT8 range", + common::TypeUtils::toString(input))}; + }; +} + +struct CastToUInt64 { + template + static inline void operation(T& input, uint64_t& result) { + if (!tryCastWithOverflowCheck(input, result)) { + throw common::OverflowException{common::stringFormat( + "Value {} is not within UINT64 range", common::TypeUtils::toString(input))}; + } + } +}; + +template<> +inline void CastToUInt64::operation(common::int128_t& input, uint64_t& result) { + if (!common::Int128_t::tryCast(input, result)) { + throw common::OverflowException{common::stringFormat("Value {} is not within UINT64 range", + common::TypeUtils::toString(input))}; + }; +} + +struct CastToUInt32 { + template + static inline void operation(T& input, uint32_t& result) { + if (!tryCastWithOverflowCheck(input, result)) { + throw common::OverflowException{common::stringFormat( + "Value {} is not within UINT32 range", common::TypeUtils::toString(input))}; + } + } +}; + +template<> +inline void CastToUInt32::operation(common::int128_t& input, uint32_t& result) { + if (!common::Int128_t::tryCast(input, result)) { + throw common::OverflowException{common::stringFormat("Value {} is not within UINT32 range", + common::TypeUtils::toString(input))}; + }; +} + +struct CastToUInt16 { + template + static inline void operation(T& input, uint16_t& result) { + if (!tryCastWithOverflowCheck(input, result)) { + throw common::OverflowException{common::stringFormat( + "Value {} is not within UINT16 range", common::TypeUtils::toString(input))}; + } + } +}; + +template<> +inline void CastToUInt16::operation(common::int128_t& input, uint16_t& result) { + if (!common::Int128_t::tryCast(input, result)) { + throw common::OverflowException{common::stringFormat("Value {} is not within UINT16 range", + common::TypeUtils::toString(input))}; + }; +} + +struct CastToUInt8 { + template + static inline void operation(T& input, uint8_t& result) { + if (!tryCastWithOverflowCheck(input, result)) { + throw 
common::OverflowException{common::stringFormat( + "Value {} is not within UINT8 range", common::TypeUtils::toString(input))}; + } + } +}; + +template<> +inline void CastToUInt8::operation(common::int128_t& input, uint8_t& result) { + if (!common::Int128_t::tryCast(input, result)) { + throw common::OverflowException{common::stringFormat("Value {} is not within UINT8 range", + common::TypeUtils::toString(input))}; + }; +} + +struct CastBetweenTimestamp { + template + static void operation(const SRC_TYPE& input, DST_TYPE& result) { + // base case: same type + result.value = input.value; + } +}; + +template<> +inline void CastBetweenTimestamp::operation(const common::timestamp_t& input, + common::timestamp_ns_t& output) { + output.value = common::Timestamp::getEpochNanoSeconds(input); +} + +template<> +inline void CastBetweenTimestamp::operation(const common::timestamp_t& input, + common::timestamp_ms_t& output) { + output.value = common::Timestamp::getEpochMilliSeconds(input); +} + +template<> +inline void CastBetweenTimestamp::operation(const common::timestamp_t& input, + common::timestamp_sec_t& output) { + output.value = common::Timestamp::getEpochSeconds(input); +} + +template<> +inline void CastBetweenTimestamp::operation(const common::timestamp_ms_t& input, + common::timestamp_t& output) { + output = common::Timestamp::fromEpochMilliSeconds(input.value); +} + +template<> +inline void CastBetweenTimestamp::operation(const common::timestamp_ms_t& input, + common::timestamp_ns_t& output) { + operation(input, output); + operation(output, output); +} + +template<> +inline void CastBetweenTimestamp::operation(const common::timestamp_ms_t& input, + common::timestamp_sec_t& output) { + operation(input, output); + operation(output, output); +} + +template<> +inline void CastBetweenTimestamp::operation(const common::timestamp_ns_t& input, + common::timestamp_t& output) { + output = common::Timestamp::fromEpochNanoSeconds(input.value); +} + +template<> +inline void CastBetweenTimestamp::operation(const common::timestamp_ns_t& input, + common::timestamp_ms_t& output) { + operation(input, output); + operation(output, output); +} + +template<> +inline void CastBetweenTimestamp::operation(const common::timestamp_ns_t& input, + common::timestamp_sec_t& output) { + operation(input, output); + operation(output, output); +} + +template<> +inline void CastBetweenTimestamp::operation(const common::timestamp_sec_t& input, + common::timestamp_t& output) { + output = common::Timestamp::fromEpochSeconds(input.value); +} + +template<> +inline void CastBetweenTimestamp::operation(const common::timestamp_sec_t& input, + common::timestamp_ns_t& output) { + operation(input, output); + operation(output, output); +} + +template<> +inline void CastBetweenTimestamp::operation(const common::timestamp_sec_t& input, + common::timestamp_ms_t& output) { + operation(input, output); + operation(output, output); +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/cast/functions/cast_string_non_nested_functions.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/cast/functions/cast_string_non_nested_functions.h new file mode 100644 index 0000000000..70900bc0e5 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/cast/functions/cast_string_non_nested_functions.h @@ -0,0 +1,400 @@ +#pragma once + +#include "common/constants.h" +#include "common/exception/conversion.h" +#include "common/string_format.h" +#include "common/string_utils.h" 
+#include "common/types/int128_t.h" +#include "common/types/timestamp_t.h" +#include "common/types/types.h" +#include "common/types/uint128_t.h" +#include "fast_float.h" +#include "function/cast/functions/numeric_limits.h" + +using namespace lbug::common; + +namespace lbug { +namespace function { + +bool isAnyType(std::string_view cpy); + +LogicalType LBUG_API inferMinimalTypeFromString(const std::string& str); +LogicalType LBUG_API inferMinimalTypeFromString(std::string_view str); +// Infer the type that the string represents. +// Note: minimal integer width is int64 +// Used for sniffing + +// cast string to numerical +template +struct IntegerCastData {}; + +template +struct IntegerCastData { + using Result = T; + Result result; +}; + +template<> +struct IntegerCastData { + int128_t result = 0; + int64_t intermediate = 0; + uint8_t digits = 0; + bool decimal = false; + + bool flush() { + if (digits == 0 && intermediate == 0) { + return true; + } + if (result.low != 0 || result.high != 0) { + if (digits > DECIMAL_PRECISION_LIMIT) { + return false; + } + if (!Int128_t::tryMultiply(result, function::pow10Sequence()[digits], + result)) { + return false; + } + } + if (!Int128_t::addInPlace(result, int128_t(intermediate))) { + return false; + } + digits = 0; + intermediate = 0; + return true; + } +}; + +template<> +struct IntegerCastData { + uint128_t result = 0; + uint64_t intermediate = 0; + uint8_t digits = 0; + bool decimal = false; + + bool flush() { + if (digits == 0 && intermediate == 0) { + return true; + } + if (result.low != 0 || result.high != 0) { + if (digits > DECIMAL_PRECISION_LIMIT) { + return false; + } + if (!UInt128_t::tryMultiply(result, function::pow10Sequence()[digits], + result)) { + return false; + } + } + if (!UInt128_t::addInPlace(result, uint128_t(intermediate))) { + return false; + } + digits = 0; + intermediate = 0; + return true; + } +}; + +template +struct IntegerCastOperation {}; + +template +struct IntegerCastOperation { + using CastData = IntegerCastData; + + template + static bool handleDigit(CastData& state, uint8_t digit) { + using result_t = typename CastData::Result; + if constexpr (NEGATIVE) { + if (state.result < ((std::numeric_limits::min() + digit) / 10)) { + return false; + } + state.result = state.result * 10 - digit; + } else { + if (state.result > ((std::numeric_limits::max() - digit) / 10)) { + return false; + } + state.result = state.result * 10 + digit; + } + return true; + } + + // TODO(Kebing): handle decimals + static bool finalize(CastData& /*state*/) { return true; } +}; + +// cast string to bool +bool tryCastToBool(const char* input, uint64_t len, bool& result); +void LBUG_API castStringToBool(const char* input, uint64_t len, bool& result); + +template<> +struct IntegerCastOperation { + using CastData = IntegerCastData; + + template + static bool handleDigit(CastData& result, uint8_t digit) { + if constexpr (NEGATIVE) { + if (result.intermediate < (NumericLimits::minimum() + digit) / 10) { + if (!result.flush()) { + return false; + } + } + result.intermediate *= 10; + result.intermediate -= digit; + } else { + if (result.intermediate > (std::numeric_limits::max() - digit) / 10) { + if (!result.flush()) { + return false; + } + } + result.intermediate *= 10; + result.intermediate += digit; + } + result.digits++; + return true; + } + + static bool finalize(CastData& result) { return result.flush(); } +}; + +template<> +struct IntegerCastOperation { + using CastData = IntegerCastData; + + template + static bool handleDigit(CastData& result, 
uint8_t digit) { + if constexpr (NEGATIVE) { + if (result.intermediate < digit / 10) { + if (!result.flush()) { + return false; + } + } + result.intermediate *= 10; + result.intermediate -= digit; + } else { + if (result.intermediate > (std::numeric_limits::max() - digit) / 10) { + if (!result.flush()) { + return false; + } + } + result.intermediate *= 10; + result.intermediate += digit; + } + result.digits++; + return true; + } + + static bool finalize(CastData& result) { return result.flush(); } +}; + +// cast string to bool +bool tryCastToBool(const char* input, uint64_t len, bool& result); +void LBUG_API castStringToBool(const char* input, uint64_t len, bool& result); + +// cast to numerical values +// TODO(Kebing): support exponent + decimal +template +inline bool integerCastLoop(const char* input, uint64_t len, IntegerCastData& result) { + using OP = IntegerCastOperation; + auto start_pos = 0u; + if (NEGATIVE) { + start_pos = 1; + } + auto pos = start_pos; + while (pos < len) { + if (!StringUtils::CharacterIsDigit(input[pos])) { + return false; + } + uint8_t digit = input[pos++] - '0'; + if (!OP::template handleDigit(result, digit)) { + return false; + } + } // append all digits to result + if (!OP::finalize(result)) { + return false; + } + return pos > start_pos; // false if no digits "" or "-" +} + +template +inline bool tryIntegerCast(const char* input, uint64_t& len, IntegerCastData& result) { + StringUtils::removeCStringWhiteSpaces(input, len); + if (len == 0) { + return false; + } + + // negative + if (*input == '-') { + if constexpr (!IS_SIGNED) { // unsigned if not -0 + uint64_t pos = 1; + while (pos < len) { + if (input[pos++] != '0') { + return false; + } + } + } + // decimal separator is default to "." + return integerCastLoop(input, len, result); + } + + // not allow leading 0 + if (len > 1 && *input == '0') { + return false; + } + return integerCastLoop(input, len, result); +} + +template +inline bool trySimpleIntegerCast(const char* input, uint64_t len, T& result) { + IntegerCastData data{}; + data.result = 0; + if (tryIntegerCast(input, len, data)) { + result = data.result; + return true; + } + return false; +} + +template +inline void simpleIntegerCast(const char* input, uint64_t len, T& result, LogicalTypeID typeID) { + if (!trySimpleIntegerCast(input, len, result)) { + throw ConversionException(stringFormat("Cast failed. Could not convert \"{}\" to {}.", + std::string{input, (size_t)len}, LogicalTypeUtils::toString(typeID))); + } +} + +template +inline bool tryDoubleCast(const char* input, uint64_t len, T& result) { + StringUtils::removeCStringWhiteSpaces(input, len); + if (len == 0) { + return false; + } + // not allow leading 0 + if (len > 1 && *input == '0') { + if (StringUtils::CharacterIsDigit(input[1])) { + return false; + } + } + auto end = input + len; + auto parse_result = lbug_fast_float::from_chars(input, end, result); + if (parse_result.ec != std::errc()) { + return false; + } + return parse_result.ptr == end; +} + +template +inline void doubleCast(const char* input, uint64_t len, T& result, + LogicalTypeID typeID = LogicalTypeID::ANY) { + if (!tryDoubleCast(input, len, result)) { + throw ConversionException(stringFormat("Cast failed. 
{} is not in {} range.", + std::string{input, (size_t)len}, LogicalTypeUtils::toString(typeID))); + } +} + +// ---------------------- try cast String to Timestamp -------------------- // +struct TryCastStringToTimestamp { + template + static bool tryCast(const char* input, uint64_t len, timestamp_t& result); + + template + static void cast(const char* input, uint64_t len, timestamp_t& result, LogicalTypeID typeID) { + if (!tryCast(input, len, result)) { + throw ConversionException(Timestamp::getTimestampConversionExceptionMsg(input, len, + LogicalTypeUtils::toString(typeID))); + } + } +}; + +template<> +bool TryCastStringToTimestamp::tryCast(const char* input, uint64_t len, + timestamp_t& result); + +template<> +bool TryCastStringToTimestamp::tryCast(const char* input, uint64_t len, + timestamp_t& result); + +template<> +bool TryCastStringToTimestamp::tryCast(const char* input, uint64_t len, + timestamp_t& result); + +template<> +bool inline TryCastStringToTimestamp::tryCast(const char* input, uint64_t len, + timestamp_t& result) { + return Timestamp::tryConvertTimestamp(input, len, result); +} + +// ---------------------- cast String to Decimal -------------------- // + +template +bool tryDecimalCast(const char* input, uint64_t len, T& result, uint32_t precision, + uint32_t scale) { + constexpr auto pow10s = pow10Sequence(); + using CAST_OP = IntegerCastOperation; + using CAST_DATA = IntegerCastData; + StringUtils::removeCStringWhiteSpaces(input, len); + if (len == 0) { + return false; + } + + bool negativeFlag = input[0] == '-'; + if (negativeFlag) { + input++; + len -= 1; + } + + CAST_DATA res; + res.result = 0; + auto pos = 0u; + auto periodPos = len - 1u; + while (pos < len) { + auto chr = input[pos]; + if (input[pos] == '.') { + periodPos = pos; + } else if (pos > periodPos && pos - periodPos > scale) { + // we've parsed the digit limit + break; + } else if (!StringUtils::CharacterIsDigit(chr) || + !CAST_OP::template handleDigit(res, chr - '0')) { + return false; + } + pos++; + } + if (pos < len) { + // then we parsed the digit limit, so round the final digit + if (!StringUtils::CharacterIsDigit(input[pos])) { + return false; + } + if (!CAST_OP::finalize(res)) { + return false; + } + // then determine rounding + if (input[pos] >= '5') { + res.result += 1; + } + } + while (pos - periodPos < scale + 1) { + // trailing 0's + if (!CAST_OP::template handleDigit(res, 0)) { + return false; + } + pos++; + } + if (!CAST_OP::finalize(res)) { + return false; + } + if (res.result >= pow10s[precision]) { + return false; + } + result = negativeFlag ? -res.result : res.result; + return true; +} + +template +void decimalCast(const char* input, uint64_t len, T& result, const LogicalType& type) { + if (!tryDecimalCast(input, len, result, DecimalType::getPrecision(type), + DecimalType::getScale(type))) { + throw ConversionException(stringFormat("Cast failed. 
{} is not in {} range.", + std::string{input, (size_t)len}, type.toString())); + } +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/cast/functions/numeric_cast.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/cast/functions/numeric_cast.h new file mode 100644 index 0000000000..b7bddc6d5e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/cast/functions/numeric_cast.h @@ -0,0 +1,164 @@ +#pragma once + +#include +#include + +#include "numeric_limits.h" + +namespace lbug { +namespace function { + +template +inline bool tryCastWithOverflowCheck(SRC value, DST& result) { + if (NumericLimits::isSigned() != NumericLimits::isSigned()) { + if (NumericLimits::isSigned()) { + if (NumericLimits::maxNumDigits() > NumericLimits::maxNumDigits()) { + if (value < 0 || value > (SRC)NumericLimits::maximum()) { + return false; + } + } else { + if (value < 0) { + return false; + } + } + result = (DST)value; + return true; + } else { + // unsigned to signed conversion + if (NumericLimits::maxNumDigits() >= NumericLimits::maxNumDigits()) { + if (value <= (SRC)NumericLimits::maximum()) { + result = (DST)value; + return true; + } + return false; + } else { + result = (DST)value; + return true; + } + } + } else { + // same sign conversion + if (NumericLimits::maxNumDigits() >= NumericLimits::maxNumDigits()) { + result = (DST)value; + return true; + } else { + if (value < SRC(NumericLimits::minimum()) || + value > SRC(NumericLimits::maximum())) { + return false; + } + result = (DST)value; + return true; + } + } +} + +template +inline bool tryCastWithOverflowCheckFloat(SRC value, T& result, SRC min, SRC max) { + if (!(value >= min && value < max)) { + return false; + } + // PG FLOAT => INT casts use statistical rounding. 
+ result = std::nearbyint(value); + return true; +} + +template<> +inline bool tryCastWithOverflowCheck(float value, int8_t& result) { + return tryCastWithOverflowCheckFloat(value, result, -128.0f, 128.0f); +} + +template<> +inline bool tryCastWithOverflowCheck(float value, int16_t& result) { + return tryCastWithOverflowCheckFloat(value, result, -32768.0f, 32768.0f); +} + +template<> +inline bool tryCastWithOverflowCheck(float value, int32_t& result) { + return tryCastWithOverflowCheckFloat(value, result, -2147483648.0f, + 2147483648.0f); +} + +template<> +inline bool tryCastWithOverflowCheck(float value, int64_t& result) { + return tryCastWithOverflowCheckFloat(value, result, -9223372036854775808.0f, + 9223372036854775808.0f); +} + +template<> +inline bool tryCastWithOverflowCheck(float value, uint8_t& result) { + return tryCastWithOverflowCheckFloat(value, result, 0.0f, 256.0f); +} + +template<> +inline bool tryCastWithOverflowCheck(float value, uint16_t& result) { + return tryCastWithOverflowCheckFloat(value, result, 0.0f, 65536.0f); +} + +template<> +inline bool tryCastWithOverflowCheck(float value, uint32_t& result) { + return tryCastWithOverflowCheckFloat(value, result, 0.0f, 4294967296.0f); +} + +template<> +inline bool tryCastWithOverflowCheck(float value, uint64_t& result) { + return tryCastWithOverflowCheckFloat(value, result, 0.0f, + 18446744073709551616.0f); +} + +template<> +inline bool tryCastWithOverflowCheck(double value, int8_t& result) { + return tryCastWithOverflowCheckFloat(value, result, -128.0, 128.0); +} + +template<> +inline bool tryCastWithOverflowCheck(double value, int16_t& result) { + return tryCastWithOverflowCheckFloat(value, result, -32768.0, 32768.0); +} + +template<> +inline bool tryCastWithOverflowCheck(double value, int32_t& result) { + return tryCastWithOverflowCheckFloat(value, result, -2147483648.0, + 2147483648.0); +} + +template<> +inline bool tryCastWithOverflowCheck(double value, int64_t& result) { + return tryCastWithOverflowCheckFloat(value, result, -9223372036854775808.0, + 9223372036854775808.0); +} + +template<> +inline bool tryCastWithOverflowCheck(double value, uint8_t& result) { + return tryCastWithOverflowCheckFloat(value, result, 0.0, 256.0); +} + +template<> +inline bool tryCastWithOverflowCheck(double value, uint16_t& result) { + return tryCastWithOverflowCheckFloat(value, result, 0.0, 65536.0); +} + +template<> +inline bool tryCastWithOverflowCheck(double value, uint32_t& result) { + return tryCastWithOverflowCheckFloat(value, result, 0.0, 4294967296.0); +} + +template<> +inline bool tryCastWithOverflowCheck(double value, uint64_t& result) { + return tryCastWithOverflowCheckFloat(value, result, 0.0, + 18446744073709551615.0); +} + +template<> +inline bool tryCastWithOverflowCheck(float input, double& result) { + result = double(input); + return true; +} + +template<> +inline bool tryCastWithOverflowCheck(double input, float& result) { + result = float(input); + return true; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/cast/functions/numeric_limits.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/cast/functions/numeric_limits.h new file mode 100644 index 0000000000..c2ef07c2e0 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/cast/functions/numeric_limits.h @@ -0,0 +1,208 @@ +#pragma once + +#include +#include +#include + +#include "common/types/int128_t.h" +#include "common/types/uint128_t.h" + +namespace lbug { +namespace function { 
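+
+// Illustrative usage (a sketch, not upstream documentation): the overflow checks in
+// numeric_cast.h and the string/decimal cast loops consult these limits, e.g.
+//   static_assert(NumericLimits<int32_t>::maxNumDigits() == 10);
+//   static_assert(!NumericLimits<uint64_t>::isSigned());
+// and pow10Sequence<T>() supplies the powers of ten used when scaling decimal digits.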
+ +template +struct NumericLimits { + static constexpr T minimum() { return std::numeric_limits::lowest(); } + static constexpr T maximum() { return std::numeric_limits::max(); } + static constexpr bool isSigned() { return std::is_signed::value; } + template + static bool isInBounds(V val) { + return minimum() <= val && val <= maximum(); + } + static constexpr uint64_t maxNumDigits(); +}; + +template<> +struct NumericLimits { + static constexpr common::int128_t minimum() { + return {0, std::numeric_limits::lowest()}; + } + static constexpr common::int128_t maximum() { + return {std::numeric_limits::max(), std::numeric_limits::max()}; + } + template + static bool isInBounds(V val) { + return minimum() <= val && val <= maximum(); + } + static constexpr bool isSigned() { return true; } + static constexpr uint64_t maxNumDigits() { return 39; } +}; + +template<> +struct NumericLimits { + static constexpr common::uint128_t minimum() { return {0, 0}; } + static constexpr common::uint128_t maximum() { + return {std::numeric_limits::max(), std::numeric_limits::max()}; + } + template + static bool isInBounds(V val) { + return minimum() <= val && val <= maximum(); + } + static constexpr bool isSigned() { return false; } + static constexpr uint64_t maxNumDigits() { return 39; } +}; + +template<> +constexpr uint64_t NumericLimits::maxNumDigits() { + return 3; +} + +template<> +constexpr uint64_t NumericLimits::maxNumDigits() { + return 5; +} + +template<> +constexpr uint64_t NumericLimits::maxNumDigits() { + return 10; +} + +template<> +constexpr uint64_t NumericLimits::maxNumDigits() { + return 19; +} + +template<> +constexpr uint64_t NumericLimits::maxNumDigits() { + return 3; +} + +template<> +constexpr uint64_t NumericLimits::maxNumDigits() { + return 5; +} + +template<> +constexpr uint64_t NumericLimits::maxNumDigits() { + return 10; +} + +template<> +constexpr uint64_t NumericLimits::maxNumDigits() { + return 20; +} + +template<> +constexpr uint64_t NumericLimits::maxNumDigits() { + return 127; +} + +template<> +constexpr uint64_t NumericLimits::maxNumDigits() { + return 250; +} + +template +static constexpr std::array::maxNumDigits()> pow10Sequence() { + std::array::maxNumDigits()> retval{}; + retval[0] = 1; + for (auto i = 1u; i < NumericLimits::maxNumDigits(); i++) { + retval[i] = retval[i - 1] * 10; + } + return retval; +} + +template<> +constexpr std::array::maxNumDigits()> +pow10Sequence() { + return { + common::int128_t(1UL, 0LL), + common::int128_t(10UL, 0LL), + common::int128_t(100UL, 0LL), + common::int128_t(1000UL, 0LL), + common::int128_t(10000UL, 0LL), + common::int128_t(100000UL, 0LL), + common::int128_t(1000000UL, 0LL), + common::int128_t(10000000UL, 0LL), + common::int128_t(100000000UL, 0LL), + common::int128_t(1000000000UL, 0LL), + common::int128_t(10000000000UL, 0LL), + common::int128_t(100000000000UL, 0LL), + common::int128_t(1000000000000UL, 0LL), + common::int128_t(10000000000000UL, 0LL), + common::int128_t(100000000000000UL, 0LL), + common::int128_t(1000000000000000UL, 0LL), + common::int128_t(10000000000000000UL, 0LL), + common::int128_t(100000000000000000UL, 0LL), + common::int128_t(1000000000000000000UL, 0LL), + common::int128_t(10000000000000000000UL, 0LL), + common::int128_t(7766279631452241920UL, 5LL), + common::int128_t(3875820019684212736UL, 54LL), + common::int128_t(1864712049423024128UL, 542LL), + common::int128_t(200376420520689664UL, 5421LL), + common::int128_t(2003764205206896640UL, 54210LL), + common::int128_t(1590897978359414784UL, 542101LL), + 
common::int128_t(15908979783594147840UL, 5421010LL), + common::int128_t(11515845246265065472UL, 54210108LL), + common::int128_t(4477988020393345024UL, 542101086LL), + common::int128_t(7886392056514347008UL, 5421010862LL), + common::int128_t(5076944270305263616UL, 54210108624LL), + common::int128_t(13875954555633532928UL, 542101086242LL), + common::int128_t(9632337040368467968UL, 5421010862427LL), + common::int128_t(4089650035136921600UL, 54210108624275LL), + common::int128_t(4003012203950112768UL, 542101086242752LL), + common::int128_t(3136633892082024448UL, 5421010862427522LL), + common::int128_t(12919594847110692864UL, 54210108624275221LL), + common::int128_t(68739955140067328UL, 542101086242752217LL), + common::int128_t(687399551400673280UL, 5421010862427522170LL), + }; // Couldn't find a clean way to do this +} + +template<> +constexpr std::array::maxNumDigits()> +pow10Sequence() { + return { + common::uint128_t(1UL, 0UL), + common::uint128_t(10UL, 0UL), + common::uint128_t(100UL, 0UL), + common::uint128_t(1000UL, 0UL), + common::uint128_t(10000UL, 0UL), + common::uint128_t(100000UL, 0UL), + common::uint128_t(1000000UL, 0UL), + common::uint128_t(10000000UL, 0UL), + common::uint128_t(100000000UL, 0UL), + common::uint128_t(1000000000UL, 0UL), + common::uint128_t(10000000000UL, 0UL), + common::uint128_t(100000000000UL, 0UL), + common::uint128_t(1000000000000UL, 0UL), + common::uint128_t(10000000000000UL, 0UL), + common::uint128_t(100000000000000UL, 0UL), + common::uint128_t(1000000000000000UL, 0UL), + common::uint128_t(10000000000000000UL, 0UL), + common::uint128_t(100000000000000000UL, 0UL), + common::uint128_t(1000000000000000000UL, 0UL), + common::uint128_t(10000000000000000000UL, 0UL), + common::uint128_t(7766279631452241920UL, 5UL), + common::uint128_t(3875820019684212736UL, 54UL), + common::uint128_t(1864712049423024128UL, 542UL), + common::uint128_t(200376420520689664UL, 5421UL), + common::uint128_t(2003764205206896640UL, 54210UL), + common::uint128_t(1590897978359414784UL, 542101UL), + common::uint128_t(15908979783594147840UL, 5421010UL), + common::uint128_t(11515845246265065472UL, 54210108UL), + common::uint128_t(4477988020393345024UL, 542101086UL), + common::uint128_t(7886392056514347008UL, 5421010862UL), + common::uint128_t(5076944270305263616UL, 54210108624UL), + common::uint128_t(13875954555633532928UL, 542101086242UL), + common::uint128_t(9632337040368467968UL, 5421010862427UL), + common::uint128_t(4089650035136921600UL, 54210108624275UL), + common::uint128_t(4003012203950112768UL, 542101086242752UL), + common::uint128_t(3136633892082024448UL, 5421010862427522UL), + common::uint128_t(12919594847110692864UL, 54210108624275221UL), + common::uint128_t(68739955140067328UL, 542101086242752217UL), + common::uint128_t(687399551400673280UL, 5421010862427522170UL), + }; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/cast/vector_cast_functions.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/cast/vector_cast_functions.h new file mode 100644 index 0000000000..c2dac1caa5 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/cast/vector_cast_functions.h @@ -0,0 +1,194 @@ +#pragma once + +#include "function/scalar_function.h" + +namespace lbug { +namespace function { + +/** + * In the system we define explicit cast and implicit cast. + * Explicit casts are performed from user function calls, e.g. date(), string(). + * Implicit casts are added internally. 
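+ * For example (illustrative): date('2024-01-15') in a query invokes an explicit cast,
+ * while comparing an INT32 column against an INT64 literal relies on the binder
+ * inserting an implicit cast (see hasImplicitCast and bindCastFunction below).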
+ */ + +struct CastChildFunctionExecutor; + +template +concept CastExecutor = + std::is_same_v || std::is_same_v; + +struct CastFunction { + // This function is only used by expression binder when implicit cast is needed. + // The expression binder should consider reusing the existing matchFunction() API. + static bool hasImplicitCast(const common::LogicalType& srcType, + const common::LogicalType& dstType); + + template + static std::unique_ptr bindCastFunction(const std::string& functionName, + const common::LogicalType& sourceType, const common::LogicalType& targetType); +}; + +struct CastToDateFunction { + static constexpr const char* name = "TO_DATE"; + + static function_set getFunctionSet(); +}; + +struct DateFunction { + using alias = CastToDateFunction; + + static constexpr const char* name = "DATE"; +}; + +struct CastToTimestampFunction { + static constexpr const char* name = "TIMESTAMP"; + + static function_set getFunctionSet(); +}; + +struct CastToIntervalFunction { + static constexpr const char* name = "TO_INTERVAL"; + + static function_set getFunctionSet(); +}; + +struct IntervalFunctionAlias { + using alias = CastToIntervalFunction; + + static constexpr const char* name = "INTERVAL"; +}; + +struct DurationFunction { + using alias = CastToIntervalFunction; + + static constexpr const char* name = "DURATION"; +}; + +struct CastToStringFunction { + static constexpr const char* name = "TO_STRING"; + + static function_set getFunctionSet(); +}; + +struct StringFunction { + using alias = CastToStringFunction; + + static constexpr const char* name = "STRING"; +}; + +struct CastToBlobFunction { + static constexpr const char* name = "TO_BLOB"; + + static function_set getFunctionSet(); +}; + +struct BlobFunction { + using alias = CastToBlobFunction; + + static constexpr const char* name = "BLOB"; +}; + +struct CastToUUIDFunction { + static constexpr const char* name = "TO_UUID"; + + static function_set getFunctionSet(); +}; + +struct UUIDFunction { + using alias = CastToUUIDFunction; + + static constexpr const char* name = "UUID"; +}; + +struct CastToBoolFunction { + static constexpr const char* name = "TO_BOOL"; + + static function_set getFunctionSet(); +}; + +struct CastToDoubleFunction { + static constexpr const char* name = "TO_DOUBLE"; + + static function_set getFunctionSet(); +}; + +struct CastToFloatFunction { + static constexpr const char* name = "TO_FLOAT"; + + static function_set getFunctionSet(); +}; + +struct CastToSerialFunction { + static constexpr const char* name = "TO_SERIAL"; + + static function_set getFunctionSet(); +}; + +struct CastToInt128Function { + static constexpr const char* name = "TO_INT128"; + + static function_set getFunctionSet(); +}; + +struct CastToInt64Function { + static constexpr const char* name = "TO_INT64"; + + static function_set getFunctionSet(); +}; + +struct CastToInt32Function { + static constexpr const char* name = "TO_INT32"; + + static function_set getFunctionSet(); +}; + +struct CastToInt16Function { + static constexpr const char* name = "TO_INT16"; + + static function_set getFunctionSet(); +}; + +struct CastToInt8Function { + static constexpr const char* name = "TO_INT8"; + + static function_set getFunctionSet(); +}; + +struct CastToUInt128Function { + static constexpr const char* name = "TO_UINT128"; + + static function_set getFunctionSet(); +}; + +struct CastToUInt64Function { + static constexpr const char* name = "TO_UINT64"; + + static function_set getFunctionSet(); +}; + +struct CastToUInt32Function { + static constexpr const char* 
name = "TO_UINT32"; + + static function_set getFunctionSet(); +}; + +struct CastToUInt16Function { + static constexpr const char* name = "TO_UINT16"; + + static function_set getFunctionSet(); +}; + +struct CastToUInt8Function { + static constexpr const char* name = "TO_UINT8"; + + static function_set getFunctionSet(); +}; + +struct CastAnyFunction { + static constexpr const char* name = "CAST"; + + static function_set getFunctionSet(); +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/comparison/comparison_functions.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/comparison/comparison_functions.h new file mode 100644 index 0000000000..ec6a4df4ce --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/comparison/comparison_functions.h @@ -0,0 +1,120 @@ +#pragma once + +#include "common/vector/value_vector.h" + +namespace lbug { +namespace function { + +struct Equals { + template + static inline void operation(const A& left, const B& right, uint8_t& result, + common::ValueVector* /*leftVector*/, common::ValueVector* /*rightVector*/) { + result = left == right; + } + + template + static bool operation(const T& left, const T& right) { + uint8_t result = 0; + operation(left, right, result, nullptr, nullptr); + return result; + } +}; + +struct NotEquals { + template + static inline void operation(const A& left, const B& right, uint8_t& result, + common::ValueVector* leftVector, common::ValueVector* rightVector) { + Equals::operation(left, right, result, leftVector, rightVector); + result = !result; + } + + template + static bool operation(const T& left, const T& right) { + uint8_t result = 0; + operation(left, right, result, nullptr, nullptr); + return result; + } +}; + +struct GreaterThan { + template + static inline void operation(const A& left, const B& right, uint8_t& result, + common::ValueVector* /*leftVector*/, common::ValueVector* /*rightVector*/) { + result = left > right; + } + + template + static bool operation(const T& left, const T& right) { + uint8_t result = 0; + operation(left, right, result, nullptr, nullptr); + return result; + } +}; + +struct GreaterThanEquals { + template + static inline void operation(const A& left, const B& right, uint8_t& result, + common::ValueVector* leftVector, common::ValueVector* rightVector) { + uint8_t isGreater = 0; + uint8_t isEqual = 0; + GreaterThan::operation(left, right, isGreater, leftVector, rightVector); + Equals::operation(left, right, isEqual, leftVector, rightVector); + result = isGreater || isEqual; + } + + template + static bool operation(const T& left, const T& right) { + uint8_t result = 0; + operation(left, right, result, nullptr, nullptr); + return result; + } +}; + +struct LessThan { + template + static inline void operation(const A& left, const B& right, uint8_t& result, + common::ValueVector* leftVector, common::ValueVector* rightVector) { + GreaterThanEquals::operation(left, right, result, leftVector, rightVector); + result = !result; + } + + template + static bool operation(const T& left, const T& right) { + uint8_t result = 0; + operation(left, right, result, nullptr, nullptr); + return result; + } +}; + +struct LessThanEquals { + template + static inline void operation(const A& left, const B& right, uint8_t& result, + common::ValueVector* leftVector, common::ValueVector* rightVector) { + GreaterThan::operation(left, right, result, leftVector, rightVector); + result = !result; + } + + template + static bool operation(const T& left, 
const T& right) { + uint8_t result = 0; + operation(left, right, result, nullptr, nullptr); + return result; + } +}; + +// specialization for equal and greater than. +template<> +void Equals::operation(const common::list_entry_t& left, const common::list_entry_t& right, + uint8_t& result, common::ValueVector* leftVector, common::ValueVector* rightVector); +template<> +void Equals::operation(const common::struct_entry_t& left, const common::struct_entry_t& right, + uint8_t& result, common::ValueVector* leftVector, common::ValueVector* rightVector); +template<> +void GreaterThan::operation(const common::list_entry_t& left, const common::list_entry_t& right, + uint8_t& result, common::ValueVector* leftVector, common::ValueVector* rightVector); +template<> +void GreaterThan::operation(const common::struct_entry_t& left, const common::struct_entry_t& right, + uint8_t& result, common::ValueVector* leftVector, common::ValueVector* rightVector); + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/comparison/vector_comparison_functions.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/comparison/vector_comparison_functions.h new file mode 100644 index 0000000000..54024f1d48 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/comparison/vector_comparison_functions.h @@ -0,0 +1,272 @@ +#pragma once + +#include "common/exception/runtime.h" +#include "common/types/int128_t.h" +#include "common/types/interval_t.h" +#include "common/types/uint128_t.h" +#include "comparison_functions.h" +#include "function/scalar_function.h" + +namespace lbug { +namespace function { + +struct ComparisonFunction { + template + static function_set getFunctionSet(const std::string& name) { + function_set functionSet; + for (auto& comparableType : common::LogicalTypeUtils::getAllValidLogicTypeIDs()) { + functionSet.push_back(getFunction(name, comparableType, comparableType)); + } + functionSet.push_back(getDecimalCompare(name)); + return functionSet; + } + +private: + template + static void BinaryComparisonExecFunction( + const std::vector>& params, + const std::vector& paramSelVectors, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* dataPtr = nullptr) { + KU_ASSERT(params.size() == 2); + BinaryFunctionExecutor::executeSwitch(*params[0], paramSelVectors[0], *params[1], + paramSelVectors[1], result, resultSelVector, dataPtr); + } + + template + static bool BinaryComparisonSelectFunction( + const std::vector>& params, + common::SelectionVector& selVector, void* dataPtr = nullptr) { + KU_ASSERT(params.size() == 2); + return BinaryFunctionExecutor::selectComparison(*params[0], + *params[1], selVector, dataPtr); + } + + template + static std::unique_ptr getFunction(const std::string& name, + common::LogicalTypeID leftType, common::LogicalTypeID rightType) { + auto leftPhysical = common::LogicalType::getPhysicalType(leftType); + auto rightPhysical = common::LogicalType::getPhysicalType(rightType); + scalar_func_exec_t execFunc; + getExecFunc(leftPhysical, rightPhysical, execFunc); + scalar_func_select_t selectFunc; + getSelectFunc(leftPhysical, rightPhysical, selectFunc); + return std::make_unique(name, + std::vector{leftType, rightType}, common::LogicalTypeID::BOOL, + execFunc, selectFunc); + } + + template + static std::unique_ptr bindDecimalCompare(ScalarBindFuncInput bindInput) { + auto func = bindInput.definition->ptrCast(); + // assumes input types are identical + auto physicalType = 
bindInput.arguments[0]->dataType.getPhysicalType(); + getExecFunc(physicalType, physicalType, func->execFunc); + getSelectFunc(physicalType, physicalType, func->selectFunc); + return nullptr; + } + + template + static std::unique_ptr getDecimalCompare(const std::string& name) { + scalar_bind_func bindFunc = bindDecimalCompare; + auto func = std::make_unique(name, + std::vector{common::LogicalTypeID::DECIMAL, + common::LogicalTypeID::DECIMAL}, + common::LogicalTypeID::BOOL); // necessary because decimal physical type is not known + // from the ID + func->bindFunc = bindFunc; + return func; + } + + // When comparing two values, we guarantee that they must have the same dataType. So we only + // need to switch the physical type to get the corresponding exec function. + template + static void getExecFunc(common::PhysicalTypeID leftType, common::PhysicalTypeID rightType, + scalar_func_exec_t& func) { + switch (leftType) { + case common::PhysicalTypeID::INT64: { + func = BinaryComparisonExecFunction; + } break; + case common::PhysicalTypeID::INT32: { + func = BinaryComparisonExecFunction; + } break; + case common::PhysicalTypeID::INT16: { + func = BinaryComparisonExecFunction; + } break; + case common::PhysicalTypeID::INT8: { + func = BinaryComparisonExecFunction; + } break; + case common::PhysicalTypeID::UINT64: { + func = BinaryComparisonExecFunction; + } break; + case common::PhysicalTypeID::UINT32: { + func = BinaryComparisonExecFunction; + } break; + case common::PhysicalTypeID::UINT16: { + func = BinaryComparisonExecFunction; + } break; + case common::PhysicalTypeID::UINT8: { + func = BinaryComparisonExecFunction; + } break; + case common::PhysicalTypeID::INT128: { + func = BinaryComparisonExecFunction; + } break; + case common::PhysicalTypeID::DOUBLE: { + func = BinaryComparisonExecFunction; + } break; + case common::PhysicalTypeID::FLOAT: { + func = BinaryComparisonExecFunction; + } break; + case common::PhysicalTypeID::BOOL: { + func = BinaryComparisonExecFunction; + } break; + case common::PhysicalTypeID::STRING: { + func = BinaryComparisonExecFunction; + } break; + case common::PhysicalTypeID::INTERNAL_ID: { + func = BinaryComparisonExecFunction; + } break; + case common::PhysicalTypeID::UINT128: { + func = + BinaryComparisonExecFunction; + } break; + case common::PhysicalTypeID::INTERVAL: { + func = + BinaryComparisonExecFunction; + } break; + case common::PhysicalTypeID::ARRAY: + case common::PhysicalTypeID::LIST: { + func = BinaryComparisonExecFunction; + } break; + case common::PhysicalTypeID::STRUCT: { + func = BinaryComparisonExecFunction; + } break; + default: + throw common::RuntimeException( + "Invalid input data types(" + common::PhysicalTypeUtils::toString(leftType) + "," + + common::PhysicalTypeUtils::toString(rightType) + ") for getExecFunc."); + } + } + + template + static void getSelectFunc(common::PhysicalTypeID leftTypeID, common::PhysicalTypeID rightTypeID, + scalar_func_select_t& func) { + KU_ASSERT(leftTypeID == rightTypeID); + switch (leftTypeID) { + case common::PhysicalTypeID::INT64: { + func = BinaryComparisonSelectFunction; + } break; + case common::PhysicalTypeID::INT32: { + func = BinaryComparisonSelectFunction; + } break; + case common::PhysicalTypeID::INT16: { + func = BinaryComparisonSelectFunction; + } break; + case common::PhysicalTypeID::INT8: { + func = BinaryComparisonSelectFunction; + } break; + case common::PhysicalTypeID::UINT64: { + func = BinaryComparisonSelectFunction; + } break; + case common::PhysicalTypeID::UINT32: { + func = 
BinaryComparisonSelectFunction; + } break; + case common::PhysicalTypeID::UINT16: { + func = BinaryComparisonSelectFunction; + } break; + case common::PhysicalTypeID::UINT8: { + func = BinaryComparisonSelectFunction; + } break; + case common::PhysicalTypeID::INT128: { + func = BinaryComparisonSelectFunction; + } break; + case common::PhysicalTypeID::DOUBLE: { + func = BinaryComparisonSelectFunction; + } break; + case common::PhysicalTypeID::FLOAT: { + func = BinaryComparisonSelectFunction; + } break; + case common::PhysicalTypeID::BOOL: { + func = BinaryComparisonSelectFunction; + } break; + case common::PhysicalTypeID::STRING: { + func = BinaryComparisonSelectFunction; + } break; + case common::PhysicalTypeID::INTERNAL_ID: { + func = BinaryComparisonSelectFunction; + } break; + case common::PhysicalTypeID::UINT128: { + func = BinaryComparisonSelectFunction; + } break; + case common::PhysicalTypeID::INTERVAL: { + func = BinaryComparisonSelectFunction; + } break; + case common::PhysicalTypeID::ARRAY: + case common::PhysicalTypeID::LIST: { + func = BinaryComparisonSelectFunction; + } break; + case common::PhysicalTypeID::STRUCT: { + func = BinaryComparisonSelectFunction; + } break; + default: + throw common::RuntimeException( + "Invalid input data types(" + common::PhysicalTypeUtils::toString(leftTypeID) + + "," + common::PhysicalTypeUtils::toString(rightTypeID) + ") for getSelectFunc."); + } + } +}; + +struct EqualsFunction { + static constexpr const char* name = "EQUALS"; + + static function_set getFunctionSet() { + return ComparisonFunction::getFunctionSet(name); + } +}; + +struct NotEqualsFunction { + static constexpr const char* name = "NOT_EQUALS"; + + static function_set getFunctionSet() { + return ComparisonFunction::getFunctionSet(name); + } +}; + +struct GreaterThanFunction { + static constexpr const char* name = "GREATER_THAN"; + + static function_set getFunctionSet() { + return ComparisonFunction::getFunctionSet(name); + } +}; + +struct GreaterThanEqualsFunction { + static constexpr const char* name = "GREATER_THAN_EQUALS"; + + static function_set getFunctionSet() { + return ComparisonFunction::getFunctionSet(name); + } +}; + +struct LessThanFunction { + static constexpr const char* name = "LESS_THAN"; + + static function_set getFunctionSet() { + return ComparisonFunction::getFunctionSet(name); + } +}; + +struct LessThanEqualsFunction { + static constexpr const char* name = "LESS_THAN_EQUALS"; + + static function_set getFunctionSet() { + return ComparisonFunction::getFunctionSet(name); + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/const_function_executor.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/const_function_executor.h new file mode 100644 index 0000000000..25e04f415e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/const_function_executor.h @@ -0,0 +1,21 @@ +#pragma once + +#include "common/vector/value_vector.h" + +namespace lbug { +namespace function { + +struct ConstFunctionExecutor { + + template + static void execute(common::ValueVector& result, common::SelectionVector& sel) { + KU_ASSERT(result.state->isFlat()); + auto resultValues = (RESULT_TYPE*)result.getData(); + auto idx = sel[0]; + KU_ASSERT(idx == 0); + OP::operation(resultValues[idx]); + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/date/date_functions.h 
b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/date/date_functions.h new file mode 100644 index 0000000000..62689f6ad5 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/date/date_functions.h @@ -0,0 +1,162 @@ +#pragma once + +#include "common/assert.h" +#include "common/types/date_t.h" +#include "common/types/ku_string.h" +#include "common/types/timestamp_t.h" + +namespace lbug { +namespace function { + +struct DayName { + template + static inline void operation(T& /*input*/, common::ku_string_t& /*result*/) { + KU_UNREACHABLE; + } +}; + +template<> +inline void DayName::operation(common::date_t& input, common::ku_string_t& result) { + std::string dayName = common::Date::getDayName(input); + result.set(dayName); +} + +template<> +inline void DayName::operation(common::timestamp_t& input, common::ku_string_t& result) { + common::dtime_t time{}; + common::date_t date{}; + common::Timestamp::convert(input, date, time); + std::string dayName = common::Date::getDayName(date); + result.set(dayName); +} + +struct MonthName { + template + static inline void operation(T& /*input*/, common::ku_string_t& /*result*/) { + KU_UNREACHABLE; + } +}; + +template<> +inline void MonthName::operation(common::date_t& input, common::ku_string_t& result) { + std::string monthName = common::Date::getMonthName(input); + result.set(monthName); +} + +template<> +inline void MonthName::operation(common::timestamp_t& input, common::ku_string_t& result) { + common::dtime_t time{}; + common::date_t date{}; + common::Timestamp::convert(input, date, time); + std::string monthName = common::Date::getMonthName(date); + result.set(monthName); +} + +struct LastDay { + template + + static inline void operation(T& /*input*/, common::date_t& /*result*/) { + KU_UNREACHABLE; + } +}; + +template<> +inline void LastDay::operation(common::date_t& input, common::date_t& result) { + result = common::Date::getLastDay(input); +} + +template<> +inline void LastDay::operation(common::timestamp_t& input, common::date_t& result) { + common::date_t date{}; + common::dtime_t time{}; + common::Timestamp::convert(input, date, time); + result = common::Date::getLastDay(date); +} + +struct DatePart { + template + static inline void operation(LEFT_TYPE& /*partSpecifier*/, RIGHT_TYPE& /*input*/, + int64_t& /*result*/) { + KU_UNREACHABLE; + } +}; + +template<> +inline void DatePart::operation(common::ku_string_t& partSpecifier, common::date_t& input, + int64_t& result) { + common::DatePartSpecifier specifier{}; + common::Interval::tryGetDatePartSpecifier(partSpecifier.getAsString(), specifier); + result = common::Date::getDatePart(specifier, input); +} + +template<> +inline void DatePart::operation(common::ku_string_t& partSpecifier, common::timestamp_t& input, + int64_t& result) { + common::DatePartSpecifier specifier{}; + common::Interval::tryGetDatePartSpecifier(partSpecifier.getAsString(), specifier); + result = common::Timestamp::getTimestampPart(specifier, input); +} + +template<> +inline void DatePart::operation(common::ku_string_t& partSpecifier, common::interval_t& input, + int64_t& result) { + common::DatePartSpecifier specifier{}; + common::Interval::tryGetDatePartSpecifier(partSpecifier.getAsString(), specifier); + result = common::Interval::getIntervalPart(specifier, input); +} + +struct DateTrunc { + template + static inline void operation(LEFT_TYPE& /*partSpecifier*/, RIGHT_TYPE& /*input*/, + RIGHT_TYPE& /*result*/) { + KU_UNREACHABLE; + } +}; + +template<> +inline void 
DateTrunc::operation(common::ku_string_t& partSpecifier, common::date_t& input, + common::date_t& result) { + common::DatePartSpecifier specifier{}; + common::Interval::tryGetDatePartSpecifier(partSpecifier.getAsString(), specifier); + result = common::Date::trunc(specifier, input); +} + +template<> +inline void DateTrunc::operation(common::ku_string_t& partSpecifier, common::timestamp_t& input, + common::timestamp_t& result) { + common::DatePartSpecifier specifier{}; + common::Interval::tryGetDatePartSpecifier(partSpecifier.getAsString(), specifier); + result = common::Timestamp::trunc(specifier, input); +} + +struct Greatest { + template + static inline void operation(T& left, T& right, T& result) { + result = left > right ? left : right; + } +}; + +struct Least { + template + static inline void operation(T& left, T& right, T& result) { + result = left > right ? right : left; + } +}; + +struct MakeDate { + static inline void operation(int64_t& year, int64_t& month, int64_t& day, + common::date_t& result) { + result = common::Date::fromDate(year, month, day); + } +}; + +struct CurrentDate { + static void operation(common::date_t& result, void* dataPtr); +}; + +struct CurrentTimestamp { + static void operation(common::timestamp_tz_t& result, void* dataPtr); +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/date/vector_date_functions.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/date/vector_date_functions.h new file mode 100644 index 0000000000..292257b1d1 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/date/vector_date_functions.h @@ -0,0 +1,81 @@ +#pragma once + +#include "function/function.h" + +namespace lbug { +namespace function { + +struct DatePartFunction { + static constexpr const char* name = "DATE_PART"; + + static function_set getFunctionSet(); +}; + +struct DatePartFunctionAlias { + using alias = DatePartFunction; + + static constexpr const char* name = "DATEPART"; +}; + +struct DateTruncFunction { + static constexpr const char* name = "DATE_TRUNC"; + + static function_set getFunctionSet(); +}; + +struct DateTruncFunctionAlias { + using alias = DateTruncFunction; + + static constexpr const char* name = "DATETRUNC"; +}; + +struct DayNameFunction { + static constexpr const char* name = "DAYNAME"; + + static function_set getFunctionSet(); +}; + +struct GreatestFunction { + static constexpr const char* name = "GREATEST"; + + static function_set getFunctionSet(); +}; + +struct LastDayFunction { + static constexpr const char* name = "LAST_DAY"; + + static function_set getFunctionSet(); +}; + +struct LeastFunction { + static constexpr const char* name = "LEAST"; + + static function_set getFunctionSet(); +}; + +struct MakeDateFunction { + static constexpr const char* name = "MAKE_DATE"; + + static function_set getFunctionSet(); +}; + +struct MonthNameFunction { + static constexpr const char* name = "MONTHNAME"; + + static function_set getFunctionSet(); +}; + +struct CurrentDateFunction { + static constexpr const char* name = "CURRENT_DATE"; + + static function_set getFunctionSet(); +}; + +struct CurrentTimestampFunction { + static constexpr const char* name = "CURRENT_TIMESTAMP"; + + static function_set getFunctionSet(); +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/export/export_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/export/export_function.h new file mode 100644 index 
0000000000..7099d82bef --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/export/export_function.h @@ -0,0 +1,108 @@ +#pragma once + +#include + +#include "common/case_insensitive_map.h" +#include "common/types/value/value.h" +#include "function/function.h" + +namespace lbug { +namespace function { + +struct ExportFuncLocalState { + virtual ~ExportFuncLocalState() = default; + + template + TARGET& cast() { + return common::ku_dynamic_cast(*this); + } +}; + +struct ExportFuncBindData; + +struct ExportFuncSharedState { + virtual ~ExportFuncSharedState() = default; + + template + TARGET& cast() { + return common::ku_dynamic_cast(*this); + } + + virtual void init(main::ClientContext& context, const ExportFuncBindData& bindData) = 0; + + std::atomic parallelFlag = true; +}; + +struct ExportFuncBindData { + std::vector columnNames; + std::vector types; + std::string fileName; + + ExportFuncBindData(std::vector columnNames, std::string fileName) + : columnNames{std::move(columnNames)}, fileName{std::move(fileName)} {} + + virtual ~ExportFuncBindData() = default; + + void setDataType(std::vector types_) { types = std::move(types_); } + + template + const TARGET& constCast() const { + return common::ku_dynamic_cast(*this); + } + + virtual std::unique_ptr copy() const = 0; +}; + +struct ExportFuncBindInput { + std::vector columnNames; + std::string filePath; + common::case_insensitive_map_t parsingOptions; +}; + +using export_bind_t = + std::function(function::ExportFuncBindInput& bindInput)>; +using export_init_local_t = std::function( + main::ClientContext&, const ExportFuncBindData&, std::vector)>; +using export_create_shared_t = std::function()>; +using export_init_shared_t = + std::function; +using export_sink_t = std::function>)>; +using export_combine_t = std::function; +using export_finalize_t = std::function; + +struct LBUG_API ExportFunction : public Function { + ExportFunction() = default; + explicit ExportFunction(std::string name) : Function{std::move(name), {}} {} + + ExportFunction(std::string name, export_init_local_t initLocal, + export_create_shared_t createShared, export_init_shared_t initShared, + export_sink_t copyToSink, export_combine_t copyToCombine, export_finalize_t copyToFinalize) + : Function{std::move(name), {}}, initLocalState{std::move(initLocal)}, + createSharedState{std::move(createShared)}, initSharedState{std::move(initShared)}, + sink{std::move(copyToSink)}, combine{std::move(copyToCombine)}, + finalize{std::move(copyToFinalize)} {} + + export_bind_t bind; + export_init_local_t initLocalState; + export_create_shared_t createSharedState; + export_init_shared_t initSharedState; + export_sink_t sink; + export_combine_t combine; + export_finalize_t finalize; +}; + +struct ExportCSVFunction : public ExportFunction { + static constexpr const char* name = "COPY_CSV"; + + static function_set getFunctionSet(); +}; + +struct ExportParquetFunction : public ExportFunction { + static constexpr const char* name = "COPY_PARQUET"; + + static function_set getFunctionSet(); +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/function.h new file mode 100644 index 0000000000..0c25a21a59 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/function.h @@ -0,0 +1,108 @@ +#pragma once + +#include "binder/expression/expression.h" +#include "common/api.h" + +namespace lbug { + +namespace main { +class ClientContext; +} 
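// --------------------------------------------------------------------------
// Reviewer sketch (illustrative, not part of this patch): the ExportFunction
// callback set declared in export_function.h above is intended to be driven
// by a COPY TO pipeline roughly in the order below. `func`, `bindInput`,
// `context`, `types` and `chunk` are placeholders, not symbols from this
// patch, and the exact callback signatures (elided by this rendering) are
// assumptions.
//
//   auto bindData = func.bind(bindInput);                 // once, at bind time
//   bindData->setDataType(std::move(types));
//   auto shared = func.createSharedState();               // once per export
//   func.initSharedState(*shared, context, *bindData);
//   auto local = func.initLocalState(context, *bindData, /*vectors*/ {}); // per thread
//   func.sink(*shared, *local, *bindData, chunk);         // once per input chunk
//   func.combine(*shared, *local);                        // once per thread
//   func.finalize(*shared);                               // once, at the end
// --------------------------------------------------------------------------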
+ +namespace function { + +struct LBUG_API FunctionBindData { + std::vector paramTypes; + common::LogicalType resultType; + // TODO: the following two fields should be moved to FunctionLocalState. + main::ClientContext* clientContext; + int64_t count; + + explicit FunctionBindData(common::LogicalType dataType) + : resultType{std::move(dataType)}, clientContext{nullptr}, count{1} {} + FunctionBindData(std::vector paramTypes, common::LogicalType resultType) + : paramTypes{std::move(paramTypes)}, resultType{std::move(resultType)}, + clientContext{nullptr}, count{1} {} + DELETE_COPY_AND_MOVE(FunctionBindData); + virtual ~FunctionBindData() = default; + + static std::unique_ptr getSimpleBindData( + const binder::expression_vector& params, const common::LogicalType& resultType); + + template + TARGET& cast() { + return common::ku_dynamic_cast(*this); + } + + virtual std::unique_ptr copy() const { + return std::make_unique(common::LogicalType::copy(paramTypes), + resultType.copy()); + } +}; + +struct Function; +using function_set = std::vector>; + +struct ScalarBindFuncInput { + const binder::expression_vector& arguments; + Function* definition; + main::ClientContext* context; + std::vector optionalArguments; + + ScalarBindFuncInput(const binder::expression_vector& arguments, Function* definition, + main::ClientContext* context, std::vector optionalArguments) + : arguments{arguments}, definition{definition}, context{context}, + optionalArguments{std::move(optionalArguments)} {} +}; + +using scalar_bind_func = + std::function(const ScalarBindFuncInput& bindInput)>; + +struct LBUG_API Function { + std::string name; + std::vector parameterTypeIDs; + bool isReadOnly = true; + + Function() : isReadOnly{true} {}; + Function(std::string name, std::vector parameterTypeIDs) + : name{std::move(name)}, parameterTypeIDs{std::move(parameterTypeIDs)} {} + Function(const Function&) = default; + + virtual ~Function() = default; + + virtual std::string signatureToString() const { + return common::LogicalTypeUtils::toString(parameterTypeIDs); + } + + template + const TARGET* constPtrCast() const { + return common::ku_dynamic_cast(this); + } + template + TARGET* ptrCast() { + return common::ku_dynamic_cast(this); + } +}; + +struct ScalarOrAggregateFunction : Function { + common::LogicalTypeID returnTypeID = common::LogicalTypeID::ANY; + scalar_bind_func bindFunc = nullptr; + + ScalarOrAggregateFunction() : Function{} {} + ScalarOrAggregateFunction(std::string name, std::vector parameterTypeIDs, + common::LogicalTypeID returnTypeID) + : Function{std::move(name), std::move(parameterTypeIDs)}, returnTypeID{returnTypeID} {} + ScalarOrAggregateFunction(std::string name, std::vector parameterTypeIDs, + common::LogicalTypeID returnTypeID, scalar_bind_func bindFunc) + : Function{std::move(name), std::move(parameterTypeIDs)}, returnTypeID{returnTypeID}, + bindFunc{std::move(bindFunc)} {} + + std::string signatureToString() const override { + auto result = Function::signatureToString(); + result += " -> " + common::LogicalTypeUtils::toString(returnTypeID); + return result; + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/function_collection.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/function_collection.h new file mode 100644 index 0000000000..0f82645a47 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/function_collection.h @@ -0,0 +1,24 @@ +#pragma once + +#include "catalog/catalog_entry/catalog_entry_type.h" 
+#include "function.h" + +using namespace lbug::catalog; + +namespace lbug { +namespace function { + +using get_function_set_fun = std::function; + +struct FunctionCollection { + get_function_set_fun getFunctionSetFunc; + + const char* name; + + CatalogEntryType catalogEntryType; + + static FunctionCollection* getFunctions(); +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/auxiliary_state/gds_auxilary_state.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/auxiliary_state/gds_auxilary_state.h new file mode 100644 index 0000000000..da3894d480 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/auxiliary_state/gds_auxilary_state.h @@ -0,0 +1,48 @@ +#pragma once + +#include "common/cast.h" +#include "common/types/types.h" + +namespace lbug { +namespace graph { +class Graph; +} + +namespace processor { +struct ExecutionContext; +} + +namespace function { + +// Maintain algorithm specific data structures +class GDSAuxiliaryState { +public: + GDSAuxiliaryState() = default; + virtual ~GDSAuxiliaryState() = default; + + // Initialize state for source node. + virtual void initSource(common::nodeID_t) {} + // Initialize state before extending from `fromTable` to `toTable`. + // Normally you want to pin data structures on `toTableID`. + virtual void beginFrontierCompute(common::table_id_t fromTableID, + common::table_id_t toTableID) = 0; + + virtual void switchToDense(processor::ExecutionContext* context, graph::Graph* graph) = 0; + + template + TARGET* ptrCast() { + return common::ku_dynamic_cast(this); + } +}; + +class EmptyGDSAuxiliaryState : public GDSAuxiliaryState { +public: + EmptyGDSAuxiliaryState() = default; + + void beginFrontierCompute(common::table_id_t, common::table_id_t) override {} + + void switchToDense(processor::ExecutionContext*, graph::Graph*) override {} +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/auxiliary_state/path_auxiliary_state.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/auxiliary_state/path_auxiliary_state.h new file mode 100644 index 0000000000..f5e38fa30a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/auxiliary_state/path_auxiliary_state.h @@ -0,0 +1,55 @@ +#pragma once + +#include "function/gds/bfs_graph.h" +#include "gds_auxilary_state.h" + +namespace lbug { +namespace function { + +class PathAuxiliaryState : public GDSAuxiliaryState { +public: + explicit PathAuxiliaryState(std::unique_ptr bfsGraphManager) + : bfsGraphManager{std::move(bfsGraphManager)} {} + + BFSGraphManager* getBFSGraphManager() { return bfsGraphManager.get(); } + + void beginFrontierCompute(common::table_id_t, common::table_id_t toTableID) override { + bfsGraphManager->getCurrentGraph()->pinTableID(toTableID); + } + + void switchToDense(processor::ExecutionContext* context, graph::Graph* graph) override { + bfsGraphManager->switchToDense(context, graph); + } + +private: + std::unique_ptr bfsGraphManager; +}; + +class WSPPathsAuxiliaryState : public GDSAuxiliaryState { +public: + explicit WSPPathsAuxiliaryState(std::unique_ptr bfsGraphManager) + : bfsGraphManager{std::move(bfsGraphManager)} {} + + BFSGraphManager* getBFSGraphManager() { return bfsGraphManager.get(); } + + void initSource(common::nodeID_t sourceNodeID) override { + sourceParent.setCost(0); + bfsGraphManager->getCurrentGraph()->pinTableID(sourceNodeID.tableID); + 
bfsGraphManager->getCurrentGraph()->setParentList(sourceNodeID.offset, &sourceParent); + } + + void beginFrontierCompute(common::table_id_t, common::table_id_t toTableID) override { + bfsGraphManager->getCurrentGraph()->pinTableID(toTableID); + } + + void switchToDense(processor::ExecutionContext* context, graph::Graph* graph) override { + bfsGraphManager->switchToDense(context, graph); + } + +private: + std::unique_ptr bfsGraphManager; + ParentList sourceParent; +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/bfs_graph.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/bfs_graph.h new file mode 100644 index 0000000000..ca51f2a306 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/bfs_graph.h @@ -0,0 +1,183 @@ +#pragma once + +#include "density_state.h" +#include "gds_object_manager.h" +#include "graph/graph.h" + +namespace lbug { +namespace storage { +class MemoryManager; +} +namespace processor { +struct ExecutionContext; +} +namespace function { + +// TODO(Xiyang): optimize if edgeID is not needed. +class ParentList { +public: + void setNbrInfo(common::nodeID_t nodeID_, common::relID_t edgeID_, bool isFwd_) { + nodeID = nodeID_; + edgeID = edgeID_; + isFwd = isFwd_; + } + common::nodeID_t getNodeID() const { return nodeID; } + common::relID_t getEdgeID() const { return edgeID; } + bool isFwdEdge() const { return isFwd; } + + void setNextPtr(ParentList* ptr) { next.store(ptr, std::memory_order_relaxed); } + ParentList* getNextPtr() { return next.load(std::memory_order_relaxed); } + + void setIter(uint16_t iter_) { iter = iter_; } + uint16_t getIter() const { return iter; } + + void setCost(double cost_) { cost = cost_; } + double getCost() const { return cost; } + +private: + common::nodeID_t nodeID; + common::relID_t edgeID; + bool isFwd = true; + + uint16_t iter = UINT16_MAX; + double cost = std::numeric_limits::max(); + // Next pointer + std::atomic next; +}; + +class BaseBFSGraph { + friend class BFSGraphManager; + +public: + explicit BaseBFSGraph(storage::MemoryManager* mm) : mm{mm} {} + virtual ~BaseBFSGraph() = default; + + // This function should be called by a worker thread Ti to grab a block of memory that + // Ti owns and writes to. + ObjectBlock* addNewBlock(); + + virtual void pinTableID(common::table_id_t tableID) = 0; + + // Used to track path for all shortest path & variable length path. + virtual void addParent(uint16_t iter, common::nodeID_t boundNodeID, common::relID_t edgeID, + common::nodeID_t nbrNodeID, bool fwdEdge, ObjectBlock* block) = 0; + // Used to track path for single shortest path. Assume each offset has at most one parent. + virtual void addSingleParent(uint16_t iter, common::nodeID_t boundNodeID, + common::relID_t edgeID, common::nodeID_t nbrNodeID, bool fwdEdge, + ObjectBlock* block) = 0; + // Used to track path for all weighted shortest path. + virtual bool tryAddParentWithWeight(common::nodeID_t boundNodeID, common::relID_t edgeID, + common::nodeID_t nbrNodeID, bool fwdEdge, double weight, + ObjectBlock* block) = 0; + // Used to track path for single weighted shortest path. Assume each offset has at most one + // parent. 
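// --------------------------------------------------------------------------
// Reviewer sketch (illustrative, not part of this header): a worker thread
// recording path parents during an extension step owns one ObjectBlock at a
// time and asks the BFS graph for a fresh one when it fills up. `bfsGraph`,
// `iter`, `boundNodeID`, `edgeID` and `nbrNodeID` are placeholders, and the
// ObjectBlock element type is elided by this rendering.
//
//   auto* block = bfsGraph.addNewBlock();
//   // per traversed edge:
//   if (!block->hasSpace()) {
//       block = bfsGraph.addNewBlock();
//   }
//   bfsGraph.addParent(iter, boundNodeID, edgeID, nbrNodeID, /*fwdEdge=*/true, block);
// --------------------------------------------------------------------------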
+ virtual bool tryAddSingleParentWithWeight(common::nodeID_t boundNodeID, common::relID_t edgeID, + common::nodeID_t nbrNodeID, bool fwdEdge, double weight, + ObjectBlock* block) = 0; + + virtual ParentList* getParentListHead(common::offset_t offset) = 0; + virtual ParentList* getParentListHead(common::nodeID_t nodeID) = 0; + + virtual void setParentList(common::offset_t offset, ParentList* parentList) = 0; + + template + TARGET& cast() { + return common::ku_dynamic_cast(*this); + } + +protected: + std::mutex mtx; + storage::MemoryManager* mm; + std::vector>> blocks; +}; + +class DenseBFSGraph : public BaseBFSGraph { + friend class BFSGraphManager; + friend class BFSGraphInitVertexCompute; + +public: + DenseBFSGraph(storage::MemoryManager* mm, common::table_id_map_t maxOffsetMap) + : BaseBFSGraph{mm}, maxOffsetMap{std::move(maxOffsetMap)} {} + + void init(processor::ExecutionContext* context, graph::Graph* graph); + + void pinTableID(common::table_id_t tableID) override; + + void addParent(uint16_t iter, common::nodeID_t boundNodeID, common::relID_t edgeID, + common::nodeID_t nbrNodeID, bool fwdEdge, ObjectBlock* block) override; + void addSingleParent(uint16_t iter, common::nodeID_t boundNodeID, common::relID_t edgeID, + common::nodeID_t nbrNodeID, bool fwdEdge, ObjectBlock* block) override; + bool tryAddParentWithWeight(common::nodeID_t boundNodeID, common::relID_t edgeID, + common::nodeID_t nbrNodeID, bool fwdEdge, double weight, + ObjectBlock* block) override; + bool tryAddSingleParentWithWeight(common::nodeID_t boundNodeID, common::relID_t edgeID, + common::nodeID_t nbrNodeID, bool fwdEdge, double weight, + ObjectBlock* block) override; + + ParentList* getParentListHead(common::offset_t offset) override; + ParentList* getParentListHead(common::nodeID_t nodeID) override; + + void setParentList(common::offset_t offset, ParentList* parentList) override; + +private: + common::table_id_map_t maxOffsetMap; + GDSDenseObjectManager> denseObjects; + std::atomic* curData = nullptr; +}; + +class SparseBFSGraph : public BaseBFSGraph { + friend class BFSGraphManager; + +public: + explicit SparseBFSGraph(storage::MemoryManager* mm, + common::table_id_map_t maxOffsetMap) + : BaseBFSGraph{mm}, sparseObjects{maxOffsetMap} {} + + void pinTableID(common::table_id_t tableID) override; + + void addParent(uint16_t iter, common::nodeID_t boundNodeID, common::relID_t edgeID, + common::nodeID_t nbrNodeID, bool fwdEdge, ObjectBlock* block) override; + void addSingleParent(uint16_t iter, common::nodeID_t boundNodeID, common::relID_t edgeID, + common::nodeID_t nbrNodeID, bool fwdEdge, ObjectBlock* block) override; + bool tryAddParentWithWeight(common::nodeID_t boundNodeID, common::relID_t edgeID, + common::nodeID_t nbrNodeID, bool fwdEdge, double weight, + ObjectBlock* block) override; + bool tryAddSingleParentWithWeight(common::nodeID_t boundNodeID, common::relID_t edgeID, + common::nodeID_t nbrNodeID, bool fwdEdge, double weight, + ObjectBlock* block) override; + + ParentList* getParentListHead(common::offset_t offset) override; + ParentList* getParentListHead(common::nodeID_t nodeID) override; + + void setParentList(common::offset_t offset, ParentList* parentList) override; + + const std::unordered_map& getCurrentData() const { + return *curData; + } + +private: + GDSSpareObjectManager sparseObjects; + std::unordered_map* curData = nullptr; +}; + +class BFSGraphManager { +public: + BFSGraphManager(common::table_id_map_t maxOffsetMap, + storage::MemoryManager* mm); + + BaseBFSGraph* getCurrentGraph() const { + 
KU_ASSERT(curGraph); + return curGraph; + } + + void switchToDense(processor::ExecutionContext* context, graph::Graph* graph); + +private: + GDSDensityState state = GDSDensityState::SPARSE; + std::unique_ptr denseBFSGraph; + std::unique_ptr sparseBFSGraph; + BaseBFSGraph* curGraph = nullptr; +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/compute.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/compute.h new file mode 100644 index 0000000000..41a00a85e4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/compute.h @@ -0,0 +1,61 @@ +#pragma once + +#include "common/mask.h" +#include "common/types/types.h" +#include "graph/graph.h" + +namespace lbug { +namespace function { + +/** + * Base interface for algorithms that can be implemented in Pregel-like vertex-centric manner or + * more specifically Ligra's edgeCompute (called edgeUpdate in Ligra paper) function. Intended to be + * passed to the helper functions in GDSUtils that parallelize such Pregel-like computations. + */ +class EdgeCompute { +public: + virtual ~EdgeCompute() = default; + + // Does any work that is needed while extending the (boundNodeID, nbrNodeID, edgeID) edge. + // boundNodeID is the nodeID that is in the current frontier and currently executing. + // Returns a list of neighbors which should be put in the next frontier. + // So if the implementing class has access to the next frontier as a field, + // **do not** call setActive. Helper functions in GDSUtils will do that work. + virtual std::vector edgeCompute(common::nodeID_t boundNodeID, + graph::NbrScanState::Chunk& results, bool fwdEdge) = 0; + + virtual void resetSingleThreadState() {} + + virtual bool terminate(common::NodeOffsetMaskMap&) { return false; } + + virtual std::unique_ptr copy() = 0; +}; + +class VertexCompute { +public: + virtual ~VertexCompute() = default; + + // This function is called once on the "main" copy of VertexCompute in the + // GDSUtils::runVertexCompute function. runVertexCompute loops through + // each node table T on the graph on which vertexCompute should run and then before + // parallelizing the computation on T calls this function. + virtual bool beginOnTable(common::table_id_t) { return true; } + + // This function is called by each worker thread T on each node in the morsel that T grabs. + // Does any vertex-centric work that is needed while running on the curNodeID. This function + // should itself do the work of checking if any work should be done on the vertex or not. Note + // that this contrasts with how EdgeCompute::edgeCompute() should be implemented, where the + // GDSUtils helper functions call isActive on nodes to check if any work should be done for + // the edges of a node. Instead, here GDSUtils helper functions for VertexCompute blindly run + // the function on each node in a graph. + virtual void vertexCompute(const graph::VertexScanState::Chunk&) {} + virtual void vertexCompute(common::offset_t, common::offset_t, common::table_id_t) {} + // This function assumes the number of nodes is small (sparse) and morsel driven parallelism + // is not necessary. It should not be used in parallel computations. 
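// --------------------------------------------------------------------------
// Reviewer sketch (illustrative, not part of this header): a minimal
// EdgeCompute that simply activates every scanned neighbour could look
// roughly like this. It assumes the returned vector holds common::nodeID_t
// and that NbrScanState::Chunk exposes a forEach visitor over
// (nbrNodeID, edgeID) pairs; neither detail is visible in this rendering.
//
//   class VisitAllEdgeCompute final : public EdgeCompute {
//   public:
//       std::vector<common::nodeID_t> edgeCompute(common::nodeID_t /*boundNodeID*/,
//           graph::NbrScanState::Chunk& chunk, bool /*fwdEdge*/) override {
//           std::vector<common::nodeID_t> activated;
//           chunk.forEach([&](auto nbrNodeID, auto /*edgeID*/) { activated.push_back(nbrNodeID); });
//           return activated;
//       }
//       std::unique_ptr<EdgeCompute> copy() override {
//           return std::make_unique<VisitAllEdgeCompute>(*this);
//       }
//   };
// --------------------------------------------------------------------------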
+ virtual void vertexCompute(common::table_id_t) {} + + virtual std::unique_ptr copy() = 0; +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/density_state.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/density_state.h new file mode 100644 index 0000000000..7b6be23279 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/density_state.h @@ -0,0 +1,14 @@ +#pragma once + +#include + +namespace lbug { +namespace function { + +enum class GDSDensityState : uint8_t { + SPARSE = 0, + DENSE = 1, +}; + +} +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/frontier_morsel.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/frontier_morsel.h new file mode 100644 index 0000000000..a1c43fa8dc --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/frontier_morsel.h @@ -0,0 +1,49 @@ +#pragma once + +#include + +#include "common/types/types.h" + +namespace lbug { +namespace function { + +class FrontierMorsel { +public: + FrontierMorsel() = default; + + common::offset_t getBeginOffset() const { return beginOffset; } + common::offset_t getEndOffset() const { return endOffset; } + + void init(common::offset_t beginOffset_, common::offset_t endOffset_) { + beginOffset = beginOffset_; + endOffset = endOffset_; + } + +private: + common::offset_t beginOffset = common::INVALID_OFFSET; + common::offset_t endOffset = common::INVALID_OFFSET; +}; + +class LBUG_API FrontierMorselDispatcher { + static constexpr uint64_t MIN_FRONTIER_MORSEL_SIZE = 512; + // Note: MIN_NUMBER_OF_FRONTIER_MORSELS is the minimum number of morsels we aim to have but we + // can have fewer than this. See the beginFrontierComputeBetweenTables to see the actual + // morselSize computation for details. 
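// --------------------------------------------------------------------------
// Reviewer sketch (illustrative, not part of this header): worker threads
// consume the dispatcher declared here in a simple claim-and-process loop.
// `maxThreads` and `maxOffset` are placeholders for values the caller owns.
//
//   FrontierMorselDispatcher dispatcher{maxThreads};
//   dispatcher.init(maxOffset);
//   FrontierMorsel morsel;
//   while (dispatcher.getNextRangeMorsel(morsel)) {
//       for (auto offset = morsel.getBeginOffset(); offset < morsel.getEndOffset(); ++offset) {
//           // process the node at `offset`
//       }
//   }
// --------------------------------------------------------------------------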
+ static constexpr uint64_t MIN_NUMBER_OF_FRONTIER_MORSELS = 128; + +public: + explicit FrontierMorselDispatcher(uint64_t maxThreads); + + void init(common::offset_t _maxOffset); + + bool getNextRangeMorsel(FrontierMorsel& frontierMorsel); + +private: + common::offset_t maxOffset; + std::atomic nextOffset; + uint64_t maxThreads; + uint64_t morselSize; +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/gds.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/gds.h new file mode 100644 index 0000000000..5db219bace --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/gds.h @@ -0,0 +1,104 @@ +#pragma once + +#include "binder/expression/node_expression.h" +#include "common/mask.h" +#include "function/table/bind_data.h" +#include "graph/graph.h" +#include "graph/graph_entry.h" +#include "graph/parsed_graph_entry.h" +#include "processor/result/factorized_table_pool.h" + +namespace lbug { + +namespace main { +class ClientContext; +} + +namespace function { + +struct LBUG_API GDSConfig { + virtual ~GDSConfig() = default; + + template + const TARGET& constCast() const { + return *common::ku_dynamic_cast(this); + } +}; + +struct LBUG_API GDSBindData : public TableFuncBindData { + graph::NativeGraphEntry graphEntry; + binder::expression_vector output; + + GDSBindData(binder::expression_vector columns, graph::NativeGraphEntry graphEntry, + binder::expression_vector output) + : TableFuncBindData{std::move(columns)}, graphEntry{graphEntry.copy()}, + output{std::move(output)} {} + + GDSBindData(const GDSBindData& other) + : TableFuncBindData{other}, graphEntry{other.graphEntry.copy()}, output{other.output}, + resultTable{other.resultTable} {} + + void setResultFTable(std::shared_ptr table) { + resultTable = std::move(table); + } + std::shared_ptr getResultTable() const { return resultTable; } + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } + +private: + std::shared_ptr resultTable; +}; + +struct LBUG_API GDSFuncSharedState : public TableFuncSharedState { + std::unique_ptr graph; + + GDSFuncSharedState(std::shared_ptr fTable, + std::unique_ptr graph) + : TableFuncSharedState{}, graph{std::move(graph)}, factorizedTablePool{std::move(fTable)} {} + + void setGraphNodeMask(std::unique_ptr maskMap); + common::NodeOffsetMaskMap* getGraphNodeMaskMap() const { return graphNodeMask.get(); } + +public: + processor::FactorizedTablePool factorizedTablePool; + +private: + std::unique_ptr graphNodeMask = nullptr; +}; + +// Base class for every graph data science algorithm. 
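// --------------------------------------------------------------------------
// Reviewer sketch (illustrative, not part of this header): the bind data and
// shared state above pass results around through a shared FactorizedTable.
// Table construction is elided; `table`, `graph` and `mm` are placeholders,
// and the pool's claim/return methods are assumed from their use elsewhere in
// this patch (GDSResultVertexCompute).
//
//   bindData.setResultFTable(table);                      // at bind/plan time
//   GDSFuncSharedState shared{table, std::move(graph)};
//   auto* localFT = shared.factorizedTablePool.claimLocalTable(mm);  // per worker
//   // ... append result tuples to localFT ...
//   shared.factorizedTablePool.returnLocalTable(localFT);
// --------------------------------------------------------------------------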
+class LBUG_API GDSFunction { + static constexpr char NODE_COLUMN_NAME[] = "node"; + static constexpr char REL_COLUMN_NAME[] = "rel"; + +public: + static graph::NativeGraphEntry bindGraphEntry(main::ClientContext& context, + const std::string& name); + static graph::NativeGraphEntry bindGraphEntry(main::ClientContext& context, + const graph::ParsedNativeGraphEntry& parsedGraphEntry); + static std::shared_ptr bindRelOutput(const TableFuncBindInput& bindInput, + const std::vector& relEntries, + std::shared_ptr srcNode, + std::shared_ptr dstNode, + const std::optional& name = std::nullopt, + const std::optional& yieldVariableIdx = std::nullopt); + static std::shared_ptr bindNodeOutput(const TableFuncBindInput& bindInput, + const std::vector& nodeEntries, + const std::optional& name = std::nullopt, + const std::optional& yieldVariableIdx = std::nullopt); + static std::string bindColumnName(const parser::YieldVariable& yieldVariable, + std::string expressionName); + + static std::unique_ptr initSharedState( + const TableFuncInitSharedStateInput& input); + static void getLogicalPlan(planner::Planner* planner, + const binder::BoundReadingClause& readingClause, binder::expression_vector predicates, + planner::LogicalPlan& plan); + static std::unique_ptr getPhysicalPlan( + processor::PlanMapper* planMapper, const planner::LogicalOperator* logicalOp); +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/gds_frontier.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/gds_frontier.h new file mode 100644 index 0000000000..3ad85187f4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/gds_frontier.h @@ -0,0 +1,317 @@ +#pragma once + +#include +#include + +#include "compute.h" +#include "density_state.h" +#include "gds_object_manager.h" + +namespace lbug { +namespace processor { +struct ExecutionContext; +} +namespace function { + +using iteration_t = uint16_t; +static constexpr iteration_t FRONTIER_UNVISITED = UINT16_MAX; +static constexpr iteration_t FRONTIER_INITIAL_VISITED = 0; + +// Base frontier implementation. +// A frontier keeps track of the existence of node. Instead of using boolean, we assign an iteration +// number to each node. A node with iteration number "i", meaning it is visited in the i-th +// iteration. +class LBUG_API Frontier { +public: + virtual ~Frontier() = default; + + virtual void pinTableID(common::table_id_t tableID) = 0; + + virtual void addNode(common::nodeID_t nodeID, iteration_t iter) = 0; + virtual void addNode(common::offset_t offset, iteration_t iter) = 0; + virtual void addNodes(const std::vector& nodeIDs, iteration_t iter) = 0; + + virtual iteration_t getIteration(common::offset_t offset) const = 0; + + template + TARGET& cast() { + return common::ku_dynamic_cast(*this); + } +}; + +// Sparse frontier implementation assuming the number of nodes is small. 
+// Use an STL hash map to maintain node offset-> iteration number +class LBUG_API SparseFrontier : public Frontier { + friend class SparseFrontierReference; + friend class SPFrontierPair; + friend class DenseSparseDynamicFrontierPair; + +public: + explicit SparseFrontier(const common::table_id_map_t& nodeMaxOffsetMap) + : sparseObjects{nodeMaxOffsetMap} {} + + void pinTableID(common::table_id_t tableID) override; + + void addNode(common::nodeID_t nodeID, iteration_t iter) override; + void addNode(common::offset_t offset, iteration_t iter) override; + void addNodes(const std::vector& nodeIDs, iteration_t iter) override; + + iteration_t getIteration(common::offset_t offset) const override; + + uint64_t size() const { return sparseObjects.size(); } + + const std::unordered_map& getCurrentData() const { + return *curData; + } + +private: + GDSSpareObjectManager sparseObjects; + std::unordered_map* curData = nullptr; +}; + +// Sparse frontier implementation that refers to the data owned by another sparse frontier. +// This should be used only for shortest-path type of algorithms where a node is guaranteed +// to be visited only once. See SPFrontierPair for its usage. +class SparseFrontierReference : public Frontier { +public: + explicit SparseFrontierReference(SparseFrontier& frontier) + : sparseObjects{frontier.sparseObjects} {} + + void pinTableID(common::table_id_t tableID) override; + + void addNode(common::nodeID_t nodeID, iteration_t iter) override; + void addNode(common::offset_t offset, iteration_t iter) override; + void addNodes(const std::vector& nodeIDs, iteration_t iter) override; + + iteration_t getIteration(common::offset_t offset) const override; + + const std::unordered_map& getCurrentData() const { + return *curData; + } + +private: + GDSSpareObjectManager& sparseObjects; + std::unordered_map* curData = nullptr; +}; + +// Dense frontier implementation assuming the number of nodes is large. +// Use an array of iteration number. The array is allocated to max offset +class LBUG_API DenseFrontier : public Frontier { + friend class SparseFrontier; + friend class DenseFrontierReference; + friend class SPFrontierPair; + friend class DenseSparseDynamicFrontierPair; + +public: + explicit DenseFrontier(const common::table_id_map_t& nodeMaxOffsetMap) + : nodeMaxOffsetMap{nodeMaxOffsetMap} {} + DenseFrontier(const DenseFrontier& other) = delete; + DenseFrontier(const DenseFrontier&& other) = delete; + + // Allocate memory and initialize. + void init(processor::ExecutionContext* context, graph::Graph* graph, iteration_t val); + void resetValue(processor::ExecutionContext* context, graph::Graph* graph, iteration_t val); + + void pinTableID(common::table_id_t tableID) override; + + void addNode(common::nodeID_t nodeID, iteration_t iter) override; + void addNode(common::offset_t offset, iteration_t iter) override; + void addNodes(const std::vector& nodeIDs, iteration_t iter) override; + + iteration_t getIteration(common::offset_t offset) const override; + + // Get frontier without initialization. + static std::unique_ptr getUninitializedFrontier( + processor::ExecutionContext* context, graph::Graph* graph); + // Get frontier initialized to UNVISITED. + static std::unique_ptr getUnvisitedFrontier(processor::ExecutionContext* context, + graph::Graph* graph); + // Get frontier initialized to INITIAL_VISITED. 
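// --------------------------------------------------------------------------
// Reviewer sketch (illustrative, not part of this header): regardless of the
// sparse or dense implementation, a frontier is used through the same three
// calls. `frontier`, `tableID`, `nodeID` and `offset` are placeholders.
//
//   frontier.pinTableID(tableID);
//   frontier.addNode(nodeID, /*iter=*/1);   // mark as visited in iteration 1
//   bool visited = frontier.getIteration(offset) != FRONTIER_UNVISITED;
// --------------------------------------------------------------------------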
+ static std::unique_ptr getVisitedFrontier(processor::ExecutionContext* context, + graph::Graph* graph); + // Init frontier to 0 according to mask + static std::unique_ptr getVisitedFrontier(processor::ExecutionContext* context, + graph::Graph* graph, common::NodeOffsetMaskMap* maskMap); + +private: + common::table_id_map_t nodeMaxOffsetMap; + GDSDenseObjectManager> denseObjects; + std::atomic* curData = nullptr; +}; + +// Dense frontier implementation that refers to the data owned by another dense frontier. +// Should be used in the same case as SparseFrontierReference +class DenseFrontierReference : public Frontier { + friend class SPFrontierPair; + +public: + explicit DenseFrontierReference(const DenseFrontier& denseFrontier) + : denseObjects{denseFrontier.denseObjects} {} + + void pinTableID(common::table_id_t tableID) override; + + void addNode(common::nodeID_t nodeID, iteration_t iter) override; + void addNode(common::offset_t offset, iteration_t iter) override; + void addNodes(const std::vector& nodeIDs, iteration_t iter) override; + + iteration_t getIteration(common::offset_t offset) const override; + +private: + const GDSDenseObjectManager>& denseObjects; + std::atomic* curData = nullptr; +}; + +class LBUG_API FrontierPair { +public: + FrontierPair() { hasActiveNodesForNextIter_.store(false); } + virtual ~FrontierPair() = default; + + void resetCurrentIter() { curIter = 0; } + iteration_t getCurrentIter() const { return curIter; } + + void setActiveNodesForNextIter() { hasActiveNodesForNextIter_.store(true); } + + bool continueNextIter(uint16_t maxIter) { + return hasActiveNodesForNextIter_.load(std::memory_order_relaxed) && + getCurrentIter() < maxIter; + } + + // Initialize state for new iteration. + void beginNewIteration(); + void pinCurrentFrontier(common::table_id_t tableID); + void pinNextFrontier(common::table_id_t tableID); + // Pin current & next frontier + void beginFrontierComputeBetweenTables(common::table_id_t curTableID, + common::table_id_t nextTableID); + + // Write to next frontier + void addNodeToNextFrontier(common::nodeID_t nodeID); + void addNodeToNextFrontier(common::offset_t offset); + void addNodesToNextFrontier(const std::vector& nodeIDs); + + iteration_t getNextFrontierValue(common::offset_t offset); + bool isActiveOnCurrentFrontier(common::offset_t offset); + virtual std::unordered_set getActiveNodesOnCurrentFrontier() = 0; + + virtual GDSDensityState getState() const = 0; + virtual bool needSwitchToDense(uint64_t threshold) const = 0; + virtual void switchToDense(processor::ExecutionContext* context, graph::Graph* graph) = 0; + + template + TARGET* ptrCast() { + return common::ku_dynamic_cast(this); + } + +protected: + virtual void beginNewIterationInternalNoLock() = 0; + +protected: + std::mutex mtx; + // curIter is the iteration number of the algorithm and starts from 0. + iteration_t curIter = 0; + std::atomic hasActiveNodesForNextIter_; + Frontier* currentFrontier = nullptr; + Frontier* nextFrontier = nullptr; +}; + +// Shortest path (excluding weighted shortest path )frontier implementation. Different from other +// recursive algorithms, shortest path has the guarantee that a node will not be visited repeatedly +// in different iteration. So we make current/next frontier reference writes to the same frontier. +class SPFrontierPair : public FrontierPair { +public: + explicit SPFrontierPair(std::unique_ptr denseFrontier); + + // Get sparse or dense frontier based on state. + // No need to specify current or next because there is only one frontier. 
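// --------------------------------------------------------------------------
// Reviewer sketch (illustrative, not part of this header): a frontier pair is
// normally driven by a loop of the shape below; the actual loop is expected
// to live in the GDSUtils helpers. `sourceNodeID`, `srcTableID`, `dstTableID`
// and `maxIteration` are placeholders.
//
//   frontierPair.addNodeToNextFrontier(sourceNodeID);
//   frontierPair.setActiveNodesForNextIter();
//   while (frontierPair.continueNextIter(maxIteration)) {
//       frontierPair.beginNewIteration();
//       frontierPair.beginFrontierComputeBetweenTables(srcTableID, dstTableID);
//       // run edge compute over nodes active in the current frontier; whenever a
//       // neighbour is activated, call addNodeToNextFrontier(...) followed by
//       // setActiveNodesForNextIter().
//   }
// --------------------------------------------------------------------------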
+ Frontier* getFrontier(); + + void beginNewIterationInternalNoLock() override; + + // Get number of active nodes in current frontier. Used for shortest path early termination. + common::offset_t getNumActiveNodesInCurrentFrontier(common::NodeOffsetMaskMap& mask); + + std::unordered_set getActiveNodesOnCurrentFrontier() override; + + GDSDensityState getState() const override { return state; } + bool needSwitchToDense(uint64_t threshold) const override { + return state == GDSDensityState::SPARSE && sparseFrontier->size() > threshold; + } + void switchToDense(processor::ExecutionContext* context, graph::Graph* graph) override; + +private: + GDSDensityState state; + std::unique_ptr denseFrontier; + std::unique_ptr curDenseFrontier = nullptr; + std::unique_ptr nextDenseFrontier = nullptr; + std::unique_ptr sparseFrontier; + std::unique_ptr curSparseFrontier = nullptr; + std::unique_ptr nextSparseFrontier = nullptr; +}; + +// Frontier pair implementation that switches from sparse to dense adaptively. +class LBUG_API DenseSparseDynamicFrontierPair : public FrontierPair { +public: + DenseSparseDynamicFrontierPair(std::unique_ptr curDenseFrontier, + std::unique_ptr nextDenseFrontier); + + void beginNewIterationInternalNoLock() override; + + std::unordered_set getActiveNodesOnCurrentFrontier() override; + + GDSDensityState getState() const override { return state; } + bool needSwitchToDense(uint64_t threshold) const override { + return state == GDSDensityState::SPARSE && nextSparseFrontier->size() > threshold; + } + void switchToDense(processor::ExecutionContext* context, graph::Graph* graph) override; + +private: + GDSDensityState state; + std::unique_ptr curDenseFrontier = nullptr; + std::unique_ptr nextDenseFrontier = nullptr; + std::unique_ptr curSparseFrontier = nullptr; + std::unique_ptr nextSparseFrontier = nullptr; +}; + +// Frontier pair implementation that only uses dense frontier. This is mostly used in +// algorithms like wcc, scc where algorithms touch all nodes in the graph. +class LBUG_API DenseFrontierPair : public FrontierPair { +public: + DenseFrontierPair(std::unique_ptr curDenseFrontier, + std::unique_ptr nextDenseFrontier); + + void beginNewIterationInternalNoLock() override; + + std::unordered_set getActiveNodesOnCurrentFrontier() override { + KU_UNREACHABLE; + } + + void resetValue(processor::ExecutionContext* context, graph::Graph* graph, iteration_t val); + + GDSDensityState getState() const override { return GDSDensityState::DENSE; } + bool needSwitchToDense(uint64_t) const override { return false; } + void switchToDense(processor::ExecutionContext*, graph::Graph*) override { + // Do nothing. 
+ } + +private: + std::shared_ptr curDenseFrontier; + std::shared_ptr nextDenseFrontier; +}; + +class SPEdgeCompute : public EdgeCompute { +public: + explicit SPEdgeCompute(SPFrontierPair* frontierPair) + : frontierPair{frontierPair}, numNodesReached{0} {} + + void resetSingleThreadState() override { numNodesReached = 0; } + + bool terminate(common::NodeOffsetMaskMap& maskMap) override; + +protected: + SPFrontierPair* frontierPair; + // States that should be only modified with single thread + common::offset_t numNodesReached; +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/gds_function_collection.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/gds_function_collection.h new file mode 100644 index 0000000000..d443cb73b2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/gds_function_collection.h @@ -0,0 +1,57 @@ +#pragma once + +#include "function/gds/rec_joins.h" + +namespace lbug { +namespace function { + +struct VarLenJoinsFunction { + static constexpr const char* name = "VAR_LEN_JOINS"; + + static std::unique_ptr getAlgorithm(); +}; + +struct AllSPDestinationsFunction { + static constexpr const char* name = "ALL_SP_DESTINATIONS"; + + static std::unique_ptr getAlgorithm(); +}; + +struct AllSPPathsFunction { + static constexpr const char* name = "ALL_SP_PATHS"; + + static std::unique_ptr getAlgorithm(); +}; + +struct SingleSPDestinationsFunction { + static constexpr const char* name = "SINGLE_SP_DESTINATIONS"; + + static std::unique_ptr getAlgorithm(); +}; + +struct SingleSPPathsFunction { + static constexpr const char* name = "SINGLE_SP_PATHS"; + + static std::unique_ptr getAlgorithm(); +}; + +struct WeightedSPDestinationsFunction { + static constexpr const char* name = "WEIGHTED_SP_DESTINATIONS"; + + static std::unique_ptr getAlgorithm(); +}; + +struct WeightedSPPathsFunction { + static constexpr const char* name = "WEIGHTED_SP_PATHS"; + + static std::unique_ptr getAlgorithm(); +}; + +struct AllWeightedSPPathsFunction { + static constexpr const char* name = "ALL_WEIGHTED_SP_PATHS"; + + static std::unique_ptr getAlgorithm(); +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/gds_object_manager.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/gds_object_manager.h new file mode 100644 index 0000000000..744e866483 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/gds_object_manager.h @@ -0,0 +1,250 @@ +#pragma once + +#include +#include + +#include "storage/buffer_manager/memory_manager.h" +#include "storage/buffer_manager/mm_allocator.h" + +namespace lbug { +namespace function { + +// ObjectBlock represents a pre-allocated amount of memory that can hold up to maxElements objects +// ObjectBlock should be accessed by a single thread. 
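// --------------------------------------------------------------------------
// Reviewer sketch (illustrative, not part of this header): the block type
// declared just below follows a single-writer append pattern. The template
// argument and the buffer type are elided by this rendering and are
// assumptions; `mm` and `sizeInBytes` are placeholders.
//
//   ObjectBlock<ParentList> block{mm->allocateBuffer(false, sizeInBytes), sizeInBytes};
//   if (block.hasSpace()) {
//       ParentList* entry = block.reserveNext();
//       entry->setCost(0.0);
//       // revertLast() undoes the reservation if the entry ends up unused.
//   }
// --------------------------------------------------------------------------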
+template +class ObjectBlock { +public: + ObjectBlock(std::unique_ptr block, uint64_t sizeInBytes) + : block{std::move(block)} { + maxElements.store(sizeInBytes / (sizeof(T)), std::memory_order_relaxed); + nextPosToWrite.store(0, std::memory_order_relaxed); + } + + T* reserveNext() { return getData() + nextPosToWrite.fetch_add(1, std::memory_order_relaxed); } + void revertLast() { nextPosToWrite.fetch_sub(1, std::memory_order_relaxed); } + + bool hasSpace() const { + return nextPosToWrite.load(std::memory_order_relaxed) < + maxElements.load(std::memory_order_relaxed); + } + +private: + T* getData() const { return reinterpret_cast(block->getData()); } + +private: + std::unique_ptr block; + std::atomic maxElements; + std::atomic nextPosToWrite; +}; + +// Pre-allocated array of objects. +template +class ObjectArray { +public: + ObjectArray() : size{0} {} + ObjectArray(const common::offset_t size, storage::MemoryManager* mm, + bool initializeToZero = false) + : size{size}, mm{mm} { + allocate(size, mm, initializeToZero); + } + + void allocate(const common::offset_t size, storage::MemoryManager* mm, bool initializeToZero) { + allocation = mm->allocateBuffer(initializeToZero, size * sizeof(T)); + data = std::span(reinterpret_cast(allocation->getData()), size); + this->size = size; + } + + common::offset_t getSize() const { return size; } + + void reallocate(const common::offset_t newSize, storage::MemoryManager* mm) { + if (newSize > size) { + allocate(newSize, mm, false /* initializeToZero */); + } + } + + void set(const common::offset_t pos, const T value) { + KU_ASSERT_UNCONDITIONAL(pos < size); + data[pos] = value; + } + + const T& get(const common::offset_t pos) const { + KU_ASSERT_UNCONDITIONAL(pos < size); + return data[pos]; + } + + T& getUnsafe(const common::offset_t pos) { + KU_ASSERT_UNCONDITIONAL(pos < size); + return data[pos]; + } + +private: + template + friend class AtomicObjectArray; + common::offset_t size; + std::span data; + std::unique_ptr allocation; + storage::MemoryManager* mm = nullptr; +}; + +// Pre-allocated array of atomic objects. +template +class AtomicObjectArray { +public: + AtomicObjectArray() = default; + AtomicObjectArray(const common::offset_t size, storage::MemoryManager* mm, + bool initializeToZero = false) + : array{ObjectArray>(size, mm, initializeToZero)} {} + + common::offset_t getSize() const { return array.size; } + + void reallocate(const common::offset_t newSize, storage::MemoryManager* mm) { + array.reallocate(newSize, mm); + } + + void set(common::offset_t pos, const T& value, + std::memory_order order = std::memory_order_seq_cst) { + KU_ASSERT_UNCONDITIONAL(pos < array.size); + array.data[pos].store(value, order); + } + + T get(const common::offset_t pos, std::memory_order order = std::memory_order_seq_cst) { + KU_ASSERT_UNCONDITIONAL(pos < array.size); + return array.data[pos].load(order); + } + + void fetchAdd(common::offset_t pos, const T& value, + std::memory_order order = std::memory_order_seq_cst) { + KU_ASSERT_UNCONDITIONAL(pos < array.size); + array.data[pos].fetch_add(value, order); + } + + bool compareExchangeMax(const common::offset_t src, const common::offset_t dest, + std::memory_order order = std::memory_order_seq_cst) { + auto srcValue = get(src, order); + auto dstValue = get(dest, order); + // From https://en.cppreference.com/w/cpp/std::atomic/std::atomic/compare_exchange: + // When a compare-and-exchange is in a loop, the weak version will yield better performance + // on some platforms. 
+ while (dstValue < srcValue) { + if (array.data[dest].compare_exchange_weak(dstValue, srcValue)) { + return true; + } + } + return false; + } + +private: + ObjectArray> array; +}; + +template +class ku_vector_t { +public: + explicit ku_vector_t(storage::MemoryManager* mm) : vec(storage::MmAllocator(mm)) {} + ku_vector_t(storage::MemoryManager* mm, std::size_t size) + : vec(size, storage::MmAllocator(mm)) {} + + void reserve(std::size_t size) { vec.reserve(size); } + + void resize(std::size_t size) { vec.resize(size); } + + void push_back(const T& value) { vec.push_back(value); } + + void push_back(T&& value) { vec.push_back(std::move(value)); } + + bool empty() { return vec.empty(); } + + auto begin() { return vec.begin(); } + + auto end() { return vec.end(); } + + auto begin() const { return vec.begin(); } + + auto end() const { return vec.end(); } + + template + void emplace_back(Args&&... args) { + vec.emplace_back(std::forward(args)...); + } + + void pop_back() { vec.pop_back(); } + + void clear() { vec.clear(); } + + std::size_t size() const { return vec.size(); } + + T& operator[](std::size_t index) { return vec[index]; } + const T& operator[](std::size_t index) const { return vec[index]; } + + T& at(std::size_t index) { return vec.at(index); } + const T& at(std::size_t index) const { return vec.at(index); } + +private: + std::vector> vec; +}; + +// ObjectArraysMap represents a pre-allocated amount of object per tableID. +template +class GDSDenseObjectManager { +public: + void allocate(common::table_id_t tableID, common::offset_t maxOffset, + storage::MemoryManager* mm) { + auto buffer = mm->allocateBuffer(false, maxOffset * sizeof(T)); + bufferPerTable.insert({tableID, std::move(buffer)}); + } + + T* getData(common::table_id_t tableID) const { + KU_ASSERT(bufferPerTable.contains(tableID)); + return reinterpret_cast(bufferPerTable.at(tableID)->getData()); + } + +private: + common::table_id_map_t> bufferPerTable; +}; + +template +class GDSSpareObjectManager { +public: + explicit GDSSpareObjectManager( + const common::table_id_map_t& nodeMaxOffsetMap) { + for (auto& [tableID, _] : nodeMaxOffsetMap) { + allocate(tableID); + } + } + + void allocate(common::table_id_t tableID) { + KU_ASSERT(!mapPerTable.contains(tableID)); + mapPerTable.insert({tableID, {}}); + } + + const common::table_id_map_t>& getData() { + return mapPerTable; + } + + std::unordered_map* getMap(common::table_id_t tableID) { + KU_ASSERT(mapPerTable.contains(tableID)); + return &mapPerTable.at(tableID); + } + + std::unordered_map* getData(common::table_id_t tableID) { + if (!mapPerTable.contains(tableID)) { + mapPerTable.insert({tableID, {}}); + } + KU_ASSERT(mapPerTable.contains(tableID)); + return &mapPerTable.at(tableID); + } + + uint64_t size() const { + uint64_t result = 0; + for (auto [_, map] : mapPerTable) { + result += map.size(); + } + return result; + } + +private: + common::table_id_map_t> mapPerTable; +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/gds_state.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/gds_state.h new file mode 100644 index 0000000000..1d536dbdc1 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/gds_state.h @@ -0,0 +1,37 @@ +#pragma once + +#include "auxiliary_state/gds_auxilary_state.h" +#include "gds_frontier.h" + +namespace lbug { +namespace function { + +struct GDSComputeState { + std::shared_ptr frontierPair = nullptr; + std::unique_ptr edgeCompute = nullptr; + 
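// --------------------------------------------------------------------------
// Reviewer sketch (illustrative, not part of this header): the two object
// managers declared in gds_object_manager.h above trade memory for lookup
// cost. Template arguments are elided by this rendering and are assumptions;
// `nodeMaxOffsetMap`, `tableID`, `maxOffset`, `offset` and `mm` are
// placeholders.
//
//   GDSSpareObjectManager<iteration_t> sparse{nodeMaxOffsetMap};
//   (*sparse.getData(tableID))[offset] = 1;            // hash map per table
//
//   GDSDenseObjectManager<std::atomic<iteration_t>> dense;
//   dense.allocate(tableID, maxOffset, mm);
//   auto* data = dense.getData(tableID);               // flat array of length maxOffset
//   data[offset].store(1, std::memory_order_relaxed);
// --------------------------------------------------------------------------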
std::unique_ptr auxiliaryState = nullptr; + + GDSComputeState(std::shared_ptr frontierPair, + std::unique_ptr edgeCompute, std::unique_ptr auxiliaryState) + : frontierPair{std::move(frontierPair)}, edgeCompute{std::move(edgeCompute)}, + auxiliaryState{std::move(auxiliaryState)} {} + + void initSource(common::nodeID_t sourceNodeID) const; + // When performing computations on multi-label graphs, it is beneficial to fix a single + // node table of nodes in the current frontier and a single node table of nodes for the next + // frontier. That is because algorithms will perform extensions using a single relationship + // table at a time, and each relationship table R is between a single source node table S and + // a single destination node table T. Therefore, during execution the algorithm will need to + // check only the active S nodes in current frontier and update the active statuses of only the + // T nodes in the next frontier. The information that the algorithm is beginning and S-to-T + // extensions are be given to the data structures of the computation, e.g., FrontierPairs and + // RJOutputs, to possibly avoid them doing lookups of S and T-related data structures, + // e.g., maps, internally. + void beginFrontierCompute(common::table_id_t currTableID, common::table_id_t nextTableID) const; + + // Switch all data structures (frontierPair & auxiliaryState) to dense version. + void switchToDense(processor::ExecutionContext* context, graph::Graph* graph) const; +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/gds_task.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/gds_task.h new file mode 100644 index 0000000000..bc947eda52 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/gds_task.h @@ -0,0 +1,106 @@ +#pragma once + +#include + +#include "common/enums/extend_direction.h" +#include "common/task_system/task.h" +#include "frontier_morsel.h" +#include "function/gds/gds_frontier.h" +#include "graph/graph.h" + +namespace lbug { +namespace function { + +struct FrontierTaskInfo { + common::table_id_t srcTableID; + common::table_id_t dstTableID; + catalog::TableCatalogEntry* relGroupEntry = nullptr; + graph::Graph* graph; + common::ExtendDirection direction; + EdgeCompute& edgeCompute; + std::vector propertiesToScan; + + FrontierTaskInfo(common::table_id_t srcTableID, common::table_id_t dstTableID, + catalog::TableCatalogEntry* relGroupEntry, graph::Graph* graph, + common::ExtendDirection direction, EdgeCompute& edgeCompute, + std::vector propertiesToScan) + : srcTableID{srcTableID}, dstTableID{dstTableID}, relGroupEntry{relGroupEntry}, + graph{graph}, direction{direction}, edgeCompute{edgeCompute}, + propertiesToScan{std::move(propertiesToScan)} {} + FrontierTaskInfo(const FrontierTaskInfo& other) + : srcTableID{other.srcTableID}, dstTableID{other.dstTableID}, + relGroupEntry{other.relGroupEntry}, graph{other.graph}, direction{other.direction}, + edgeCompute{other.edgeCompute}, propertiesToScan{other.propertiesToScan} {} + + common::table_id_t getBoundTableID() const; + common::table_id_t getNbrTableID() const; + common::oid_t getRelTableID() const; +}; + +struct FrontierTaskSharedState { + FrontierMorselDispatcher morselDispatcher; + FrontierPair& frontierPair; + + FrontierTaskSharedState(uint64_t maxNumThreads, FrontierPair& frontierPair) + : morselDispatcher{maxNumThreads}, frontierPair{frontierPair} {} + DELETE_COPY_AND_MOVE(FrontierTaskSharedState); +}; + +class FrontierTask : 
public common::Task { +public: + FrontierTask(uint64_t maxNumThreads, const FrontierTaskInfo& info, + std::shared_ptr sharedState) + : Task{maxNumThreads}, info{info}, sharedState{std::move(sharedState)} {} + + void run() override; + + void runSparse(); + +private: + FrontierTaskInfo info; + std::shared_ptr sharedState; +}; + +struct VertexComputeTaskSharedState { + FrontierMorselDispatcher morselDispatcher; + + explicit VertexComputeTaskSharedState(uint64_t maxNumThreads) + : morselDispatcher{maxNumThreads} {} +}; + +struct VertexComputeTaskInfo { + VertexCompute& vc; + graph::Graph* graph; + catalog::TableCatalogEntry* tableEntry; + std::vector propertiesToScan; + + VertexComputeTaskInfo(VertexCompute& vc, graph::Graph* graph, + catalog::TableCatalogEntry* tableEntry, std::vector propertiesToScan) + : vc{vc}, graph{graph}, tableEntry{tableEntry}, + propertiesToScan{std::move(propertiesToScan)} {} + VertexComputeTaskInfo(const VertexComputeTaskInfo& other) + : vc{other.vc}, graph{other.graph}, tableEntry{other.tableEntry}, + propertiesToScan{other.propertiesToScan} {} + + bool hasPropertiesToScan() const { return !propertiesToScan.empty(); } +}; + +class VertexComputeTask : public common::Task { +public: + VertexComputeTask(uint64_t maxNumThreads, const VertexComputeTaskInfo& info, + std::shared_ptr sharedState) + : common::Task{maxNumThreads}, info{info}, sharedState{std::move(sharedState)} {}; + + VertexComputeTaskSharedState* getSharedState() const { return sharedState.get(); } + + void run() override; + + void runSparse(); + +private: + VertexComputeTaskInfo info; + std::shared_ptr sharedState; +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/gds_utils.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/gds_utils.h new file mode 100644 index 0000000000..b6a3eecf1f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/gds_utils.h @@ -0,0 +1,39 @@ +#pragma once + +#include "catalog/catalog_entry/table_catalog_entry.h" +#include "common/enums/extend_direction.h" +#include "gds_state.h" + +namespace lbug { +namespace function { + +class LBUG_API GDSUtils { +public: + // Run edge compute for graph algorithms + static void runAlgorithmEdgeCompute(processor::ExecutionContext* context, + GDSComputeState& compState, graph::Graph* graph, common::ExtendDirection extendDirection, + uint64_t maxIteration); + // Run edge compute for full text search + static void runFTSEdgeCompute(processor::ExecutionContext* context, GDSComputeState& compState, + graph::Graph* graph, common::ExtendDirection extendDirection, + const std::vector& propertiesToScan); + // Run edge compute for recursive join. 
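// --------------------------------------------------------------------------
// Reviewer sketch (illustrative, not part of this header): a typical
// algorithm entry point assembles a GDSComputeState and hands it to the
// helper above. `frontierPair`, `edgeCompute`, `auxState`, `sourceNodeID`,
// `context`, `graph` and `upperBound` are placeholders.
//
//   GDSComputeState state{std::move(frontierPair), std::move(edgeCompute),
//       std::move(auxState)};
//   state.initSource(sourceNodeID);
//   GDSUtils::runAlgorithmEdgeCompute(context, state, graph,
//       common::ExtendDirection::FWD, /*maxIteration=*/upperBound);
// --------------------------------------------------------------------------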
+ static void runRecursiveJoinEdgeCompute(processor::ExecutionContext* context, + GDSComputeState& compState, graph::Graph* graph, common::ExtendDirection extendDirection, + uint64_t maxIteration, common::NodeOffsetMaskMap* outputNodeMask, + const std::vector& propertiesToScan); + + // Run vertex compute without property scan + static void runVertexCompute(processor::ExecutionContext* context, GDSDensityState densityState, + graph::Graph* graph, VertexCompute& vc); + // Run vertex compute with property scan + static void runVertexCompute(processor::ExecutionContext* context, GDSDensityState densityState, + graph::Graph* graph, VertexCompute& vc, const std::vector& propertiesToScan); + // Run vertex compute on specific table with property scan + static void runVertexCompute(processor::ExecutionContext* context, GDSDensityState densityState, + graph::Graph* graph, VertexCompute& vc, catalog::TableCatalogEntry* entry, + const std::vector& propertiesToScan); +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/gds_vertex_compute.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/gds_vertex_compute.h new file mode 100644 index 0000000000..2a45fbac23 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/gds_vertex_compute.h @@ -0,0 +1,61 @@ +#pragma once + +#include "function/gds/compute.h" +#include "function/gds/gds.h" + +namespace lbug { +namespace function { + +class GDSVertexCompute : public VertexCompute { +public: + explicit GDSVertexCompute(common::NodeOffsetMaskMap* nodeMask) : nodeMask{nodeMask} {} + + bool beginOnTable(common::table_id_t tableID) override { + if (nodeMask != nullptr) { + nodeMask->pin(tableID); + } + beginOnTableInternal(tableID); + return true; + } + +protected: + bool skip(common::offset_t offset) { + if (nodeMask != nullptr && nodeMask->hasPinnedMask()) { + return !nodeMask->valid(offset); + } + return false; + } + + virtual void beginOnTableInternal(common::table_id_t tableID) = 0; + +protected: + common::NodeOffsetMaskMap* nodeMask; +}; + +class GDSResultVertexCompute : public GDSVertexCompute { +public: + GDSResultVertexCompute(storage::MemoryManager* mm, GDSFuncSharedState* sharedState) + : GDSVertexCompute{sharedState->getGraphNodeMaskMap()}, sharedState{sharedState}, mm{mm} { + localFT = sharedState->factorizedTablePool.claimLocalTable(mm); + } + ~GDSResultVertexCompute() override { + sharedState->factorizedTablePool.returnLocalTable(localFT); + } + +protected: + std::unique_ptr createVector(const common::LogicalType& type) { + auto vector = std::make_unique(type.copy(), mm); + vector->state = common::DataChunkState::getSingleValueDataChunkState(); + vectors.push_back(vector.get()); + return vector; + } + +protected: + GDSFuncSharedState* sharedState; + storage::MemoryManager* mm; + processor::FactorizedTable* localFT = nullptr; + std::vector vectors; +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/rec_joins.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/rec_joins.h new file mode 100644 index 0000000000..3ce606dffb --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/rec_joins.h @@ -0,0 +1,61 @@ +#pragma once + +#include "binder/expression/expression.h" +#include "common/enums/extend_direction.h" +#include "common/enums/path_semantic.h" +#include "function/gds/gds_state.h" +#include "graph/graph_entry.h" +#include 
"processor/operator/recursive_extend_shared_state.h" +#include "rj_output_writer.h" + +namespace lbug { +namespace function { + +struct RJBindData { + graph::NativeGraphEntry graphEntry; + + std::shared_ptr nodeInput = nullptr; + std::shared_ptr nodeOutput = nullptr; + // For any form of shortest path lower bound must always be 1. + // If lowerBound equals to 0, an empty path with source node only will be returned. + uint16_t lowerBound = 0; + uint16_t upperBound = 0; + common::PathSemantic semantic = common::PathSemantic::WALK; + + common::ExtendDirection extendDirection = common::ExtendDirection::FWD; + + bool flipPath = false; // See PathsOutputWriterInfo::flipPath for comments. + bool writePath = true; + + std::shared_ptr directionExpr = nullptr; + std::shared_ptr lengthExpr = nullptr; + std::shared_ptr pathNodeIDsExpr = nullptr; + std::shared_ptr pathEdgeIDsExpr = nullptr; + + std::shared_ptr weightPropertyExpr = nullptr; + std::shared_ptr weightOutputExpr = nullptr; + + explicit RJBindData(graph::NativeGraphEntry graphEntry) : graphEntry{std::move(graphEntry)} {} + RJBindData(const RJBindData& other); + + PathsOutputWriterInfo getPathWriterInfo() const; +}; + +class RJAlgorithm { +public: + virtual ~RJAlgorithm() = default; + + virtual std::string getFunctionName() const = 0; + virtual binder::expression_vector getResultColumns(const RJBindData& bindData) const = 0; + + virtual std::unique_ptr getComputeState(processor::ExecutionContext* context, + const RJBindData& bindData, processor::RecursiveExtendSharedState* sharedState) = 0; + virtual std::unique_ptr getOutputWriter(processor::ExecutionContext* context, + const RJBindData& bindData, GDSComputeState& computeState, common::nodeID_t sourceNodeID, + processor::RecursiveExtendSharedState* sharedState) = 0; + + virtual std::unique_ptr copy() const = 0; +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/rj_output_writer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/rj_output_writer.h new file mode 100644 index 0000000000..0aba7834ff --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/rj_output_writer.h @@ -0,0 +1,135 @@ +#pragma once + +#include "bfs_graph.h" +#include "common/counter.h" +#include "common/enums/path_semantic.h" +#include "common/mask.h" +#include "common/types/types.h" +#include "processor/result/factorized_table.h" + +namespace lbug { +namespace function { + +class RJOutputWriter { +public: + RJOutputWriter(main::ClientContext* context, common::NodeOffsetMaskMap* outputNodeMask, + common::nodeID_t sourceNodeID); + virtual ~RJOutputWriter() = default; + + void beginWriting(common::table_id_t tableID) { + pinOutputNodeMask(tableID); + beginWritingInternal(tableID); + } + virtual void beginWritingInternal(common::table_id_t tableID) = 0; + // Write + virtual void write(processor::FactorizedTable& fTable, common::table_id_t tableID, + common::LimitCounter* counter) = 0; + virtual void write(processor::FactorizedTable& fTable, common::nodeID_t dstNodeID, + common::LimitCounter* counter) = 0; + + bool inOutputNodeMask(common::offset_t offset); + + virtual std::unique_ptr copy() = 0; + +protected: + std::unique_ptr createVector(const common::LogicalType& type); + + void pinOutputNodeMask(common::table_id_t tableID); + +protected: + main::ClientContext* context; + common::NodeOffsetMaskMap* outputNodeMask; + common::nodeID_t sourceNodeID_; + + std::vector vectors; + std::unique_ptr srcNodeIDVector; + 
std::unique_ptr dstNodeIDVector; +}; + +struct PathsOutputWriterInfo { + // Semantic + common::PathSemantic semantic = common::PathSemantic::WALK; + // Range + uint16_t lowerBound = 0; + // Direction + bool flipPath = false; + bool writeEdgeDirection = false; + bool writePath = false; + // Node predicate mask + common::NodeOffsetMaskMap* pathNodeMask = nullptr; + + bool hasNodeMask() const { return pathNodeMask != nullptr; } +}; + +class PathsOutputWriter : public RJOutputWriter { +public: + PathsOutputWriter(main::ClientContext* context, common::NodeOffsetMaskMap* outputNodeMask, + common::nodeID_t sourceNodeID, PathsOutputWriterInfo info, BaseBFSGraph& bfsGraph); + + void beginWritingInternal(common::table_id_t tableID) override { bfsGraph.pinTableID(tableID); } + + void write(processor::FactorizedTable& fTable, common::table_id_t tableID, + common::LimitCounter* counter) override; + void write(processor::FactorizedTable& fTable, common::nodeID_t dstNodeID, + common::LimitCounter* counter) override; + +protected: + virtual void writeInternal(processor::FactorizedTable& fTable, common::nodeID_t dstNodeID, + common::LimitCounter* counter) = 0; + // Fast path when there is no node predicate or semantic check + void dfsFast(ParentList* firstParent, processor::FactorizedTable& fTable, + common::LimitCounter* counter); + // Slow path to check node predicate or semantic. + void dfsSlow(ParentList* firstParent, processor::FactorizedTable& fTable, + common::LimitCounter* counter); + + bool updateCounterAndTerminate(common::LimitCounter* counter); + + ParentList* findFirstParent(common::offset_t dstOffset) const; + + bool checkPathNodeMask(ParentList* element) const; + // Check semantics + bool checkAppendSemantic(const std::vector& path, ParentList* candidate) const; + bool checkReplaceTopSemantic(const std::vector& path, ParentList* candidate) const; + bool isAppendTrail(const std::vector& path, ParentList* candidate) const; + bool isAppendAcyclic(const std::vector& path, ParentList* candidate) const; + bool isReplaceTopTrail(const std::vector& path, ParentList* candidate) const; + bool isReplaceTopAcyclic(const std::vector& path, ParentList* candidate) const; + + bool isNextViable(ParentList* next, const std::vector& path) const; + + void beginWritePath(common::idx_t length) const; + void writePath(const std::vector& path) const; + void writePathFwd(const std::vector& path) const; + void writePathBwd(const std::vector& path) const; + + void addEdge(common::relID_t edgeID, bool fwdEdge, common::sel_t pos) const; + void addNode(common::nodeID_t nodeID, common::sel_t pos) const; + +protected: + PathsOutputWriterInfo info; + BaseBFSGraph& bfsGraph; + + std::unique_ptr directionVector = nullptr; + std::unique_ptr lengthVector = nullptr; + std::unique_ptr pathNodeIDsVector = nullptr; + std::unique_ptr pathEdgeIDsVector = nullptr; +}; + +class SPPathsOutputWriter : public PathsOutputWriter { +public: + SPPathsOutputWriter(main::ClientContext* context, common::NodeOffsetMaskMap* outputNodeMask, + common::nodeID_t sourceNodeID, PathsOutputWriterInfo info, BaseBFSGraph& bfsGraph) + : PathsOutputWriter{context, outputNodeMask, sourceNodeID, info, bfsGraph} {} + + void writeInternal(processor::FactorizedTable& fTable, common::nodeID_t dstNodeID, + common::LimitCounter* counter) override; + + std::unique_ptr copy() override { + return std::make_unique(context, outputNodeMask, sourceNodeID_, info, + bfsGraph); + } +}; + +} // namespace function +} // namespace lbug diff --git 
a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/weight_utils.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/weight_utils.h new file mode 100644 index 0000000000..90856ece6c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/gds/weight_utils.h @@ -0,0 +1,101 @@ +#pragma once + +#include "common/exception/runtime.h" +#include "common/type_utils.h" +#include "common/types/types.h" + +namespace lbug { +namespace function { + +struct WeightUtils { + template + static auto visit(const std::string& fcn, const common::LogicalType& dataType, Fs... funcs); + + template + static auto visit(const std::string& fcn, const common::LogicalTypeID& dataType, Fs... funcs); + + template + static void checkWeight(const std::string& fcn, T weight); +}; + +template +auto WeightUtils::visit(const std::string& fcn, const common::LogicalType& dataType, Fs... funcs) { + auto func = common::overload(funcs...); + switch (dataType.getLogicalTypeID()) { + /* NOLINTBEGIN(bugprone-branch-clone)*/ + case common::LogicalTypeID::INT8: + return func(int8_t()); + case common::LogicalTypeID::UINT8: + return func(uint8_t()); + case common::LogicalTypeID::INT16: + return func(int16_t()); + case common::LogicalTypeID::UINT16: + return func(uint16_t()); + case common::LogicalTypeID::INT32: + return func(int32_t()); + case common::LogicalTypeID::UINT32: + return func(uint32_t()); + case common::LogicalTypeID::INT64: + return func(int64_t()); + case common::LogicalTypeID::UINT64: + return func(uint64_t()); + case common::LogicalTypeID::DOUBLE: + return func(double()); + case common::LogicalTypeID::FLOAT: + return func(float()); + /* NOLINTEND(bugprone-branch-clone)*/ + default: + break; + } + // LCOV_EXCL_START + throw common::RuntimeException( + common::stringFormat("{} weight type is not supported for {}.", dataType.toString(), fcn)); + // LCOV_EXCL_STOP +} + +template +auto WeightUtils::visit(const std::string& fcn, const common::LogicalTypeID& dataType, + Fs... funcs) { + auto func = common::overload(funcs...); + switch (dataType) { + /* NOLINTBEGIN(bugprone-branch-clone)*/ + case common::LogicalTypeID::INT8: + return func(int8_t()); + case common::LogicalTypeID::UINT8: + return func(uint8_t()); + case common::LogicalTypeID::INT16: + return func(int16_t()); + case common::LogicalTypeID::UINT16: + return func(uint16_t()); + case common::LogicalTypeID::INT32: + return func(int32_t()); + case common::LogicalTypeID::UINT32: + return func(uint32_t()); + case common::LogicalTypeID::INT64: + return func(int64_t()); + case common::LogicalTypeID::UINT64: + return func(uint64_t()); + case common::LogicalTypeID::DOUBLE: + return func(double()); + case common::LogicalTypeID::FLOAT: + return func(float()); + /* NOLINTEND(bugprone-branch-clone)*/ + default: + break; + } + // LCOV_EXCL_START + throw common::RuntimeException(common::stringFormat("{} weight type is not supported for {}.", + common::LogicalType(dataType).toString(), fcn)); + // LCOV_EXCL_STOP +} + +template +void WeightUtils::checkWeight(const std::string& fcn, T weight) { + if (weight < 0) { + [[unlikely]] throw common::RuntimeException(common::stringFormat( + "Found negative weight {}. 
This is not a supported weight for {}", weight, fcn)); + } +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/hash/hash_functions.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/hash/hash_functions.h new file mode 100644 index 0000000000..d2fb586658 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/hash/hash_functions.h @@ -0,0 +1,187 @@ +#pragma once + +#include +#include +#include + +#include "common/exception/runtime.h" +#include "common/types/int128_t.h" +#include "common/types/interval_t.h" +#include "common/types/ku_string.h" +#include "common/types/types.h" +#include "common/types/uint128_t.h" + +namespace lbug { +namespace function { + +constexpr const uint64_t NULL_HASH = UINT64_MAX; + +inline common::hash_t murmurhash64(uint64_t x) { + // taken from https://nullprogram.com/blog/2018/07/31. + x ^= x >> 32; + x *= 0xd6e8feb86659fd93U; + x ^= x >> 32; + x *= 0xd6e8feb86659fd93U; + x ^= x >> 32; + return x; +} + +inline common::hash_t combineHashScalar(const common::hash_t a, const common::hash_t b) { + return (a * UINT64_C(0xbf58476d1ce4e5b9)) ^ b; +} + +struct Hash { + template + static void operation(const T& /*key*/, common::hash_t& /*result*/) { + // LCOV_EXCL_START + throw common::RuntimeException( + "Hash type: " + std::string(typeid(T).name()) + " is not supported."); + // LCOV_EXCL_STOP + } + + template + static void operation(const T& key, bool isNull, common::hash_t& result) { + if (isNull) { + result = NULL_HASH; + return; + } + operation(key, result); + } +}; + +struct CombineHash { + static inline void operation(const common::hash_t& left, const common::hash_t& right, + common::hash_t& result) { + result = combineHashScalar(left, right); + } +}; + +template<> +inline void Hash::operation(const common::internalID_t& key, common::hash_t& result) { + result = murmurhash64(key.offset) ^ murmurhash64(key.tableID); +} + +template<> +inline void Hash::operation(const bool& key, common::hash_t& result) { + result = murmurhash64(key); +} + +template<> +inline void Hash::operation(const uint8_t& key, common::hash_t& result) { + result = murmurhash64(key); +} + +template<> +inline void Hash::operation(const uint16_t& key, common::hash_t& result) { + result = murmurhash64(key); +} + +template<> +inline void Hash::operation(const uint32_t& key, common::hash_t& result) { + result = murmurhash64(key); +} + +template<> +inline void Hash::operation(const uint64_t& key, common::hash_t& result) { + result = murmurhash64(key); +} + +template<> +inline void Hash::operation(const int64_t& key, common::hash_t& result) { + result = murmurhash64(key); +} + +template<> +inline void Hash::operation(const int32_t& key, common::hash_t& result) { + result = murmurhash64(key); +} + +template<> +inline void Hash::operation(const int16_t& key, common::hash_t& result) { + result = murmurhash64(key); +} + +template<> +inline void Hash::operation(const int8_t& key, common::hash_t& result) { + result = murmurhash64(key); +} + +template<> +inline void Hash::operation(const common::int128_t& key, common::hash_t& result) { + result = murmurhash64(key.low) ^ murmurhash64(key.high); +} + +template<> +inline void Hash::operation(const common::uint128_t& key, common::hash_t& result) { + result = murmurhash64(key.low) ^ murmurhash64(key.high); +} + +template<> +inline void Hash::operation(const double& key, common::hash_t& result) { + // 0 and -0 are not byte-equivalent, but should have the same hash + if (key == 0) { 
+ result = murmurhash64(0); + } else { + result = murmurhash64(*reinterpret_cast(&key)); + } +} + +template<> +inline void Hash::operation(const float& key, common::hash_t& result) { + // 0 and -0 are not byte-equivalent, but should have the same hash + if (key == 0) { + result = murmurhash64(0); + } else { + result = murmurhash64(*reinterpret_cast(&key)); + } +} + +template<> +inline void Hash::operation(const std::string_view& key, common::hash_t& result) { + common::hash_t hashValue = 0; + auto data64 = reinterpret_cast(key.data()); + for (size_t i = 0u; i < key.size() / 8; i++) { + auto blockHash = lbug::function::murmurhash64(*(data64 + i)); + hashValue = lbug::function::combineHashScalar(hashValue, blockHash); + } + uint64_t last = 0; + for (size_t i = 0u; i < key.size() % 8; i++) { + last |= static_cast(key[key.size() / 8 * 8 + i]) << i * 8; + } + hashValue = lbug::function::combineHashScalar(hashValue, lbug::function::murmurhash64(last)); + result = hashValue; +} + +template<> +inline void Hash::operation(const std::string& key, common::hash_t& result) { + Hash::operation(std::string_view(key), result); +} + +template<> +inline void Hash::operation(const common::ku_string_t& key, common::hash_t& result) { + Hash::operation(key.getAsStringView(), result); +} + +template<> +inline void Hash::operation(const common::interval_t& key, common::hash_t& result) { + result = combineHashScalar(murmurhash64(key.months), + combineHashScalar(murmurhash64(key.days), murmurhash64(key.micros))); +} + +template<> +inline void Hash::operation(const std::unordered_set& key, common::hash_t& result) { + for (auto&& s : key) { + result ^= std::hash()(s); + } +} + +struct InternalIDHasher { + std::size_t operator()(const common::internalID_t& internalID) const { + common::hash_t result = 0; + function::Hash::operation(internalID, result); + return result; + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/hash/vector_hash_functions.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/hash/vector_hash_functions.h new file mode 100644 index 0000000000..f8c2b5788e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/hash/vector_hash_functions.h @@ -0,0 +1,53 @@ +#pragma once + +#include "common/vector/value_vector.h" +#include "function/function.h" + +namespace lbug { +namespace function { + +struct UnaryHashFunctionExecutor { + template + static void execute(const common::ValueVector& operand, + const common::SelectionView& operandSelectVec, common::ValueVector& result, + const common::SelectionView& resultSelectVec); +}; + +struct BinaryHashFunctionExecutor { + template + static void execute(const common::ValueVector& left, const common::SelectionView& leftSelVec, + const common::ValueVector& right, const common::SelectionView& rightSelVec, + common::ValueVector& result, const common::SelectionView& resultSelVec); +}; + +struct VectorHashFunction { + static void computeHash(const common::ValueVector& operand, + const common::SelectionView& operandSelectVec, common::ValueVector& result, + const common::SelectionView& resultSelectVec); + + static void combineHash(const common::ValueVector& left, + const common::SelectionView& leftSelVec, const common::ValueVector& right, + const common::SelectionView& rightSelVec, common::ValueVector& result, + const common::SelectionView& resultSelVec); +}; + +struct MD5Function { + static constexpr const char* name = "MD5"; + + static function_set getFunctionSet(); +}; + +struct 
SHA256Function { + static constexpr const char* name = "SHA256"; + + static function_set getFunctionSet(); +}; + +struct HashFunction { + static constexpr const char* name = "HASH"; + + static function_set getFunctionSet(); +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/internal_id/vector_internal_id_functions.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/internal_id/vector_internal_id_functions.h new file mode 100644 index 0000000000..e4d01c5988 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/internal_id/vector_internal_id_functions.h @@ -0,0 +1,15 @@ +#pragma once + +#include "function/function.h" + +namespace lbug { +namespace function { + +struct InternalIDCreationFunction { + static constexpr const char* name = "internal_id"; + + static function_set getFunctionSet(); +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/interval/interval_functions.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/interval/interval_functions.h new file mode 100644 index 0000000000..673b8e41c7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/interval/interval_functions.h @@ -0,0 +1,65 @@ +#pragma once + +#include "common/types/interval_t.h" + +namespace lbug { +namespace function { + +struct ToYears { + static inline void operation(int64_t& input, common::interval_t& result) { + result.days = result.micros = 0; + result.months = input * common::Interval::MONTHS_PER_YEAR; + } +}; + +struct ToMonths { + static inline void operation(int64_t& input, common::interval_t& result) { + result.days = result.micros = 0; + result.months = input; + } +}; + +struct ToDays { + static inline void operation(int64_t& input, common::interval_t& result) { + result.micros = result.months = 0; + result.days = input; + } +}; + +struct ToHours { + static inline void operation(int64_t& input, common::interval_t& result) { + result.months = result.days = 0; + result.micros = input * common::Interval::MICROS_PER_HOUR; + } +}; + +struct ToMinutes { + static inline void operation(int64_t& input, common::interval_t& result) { + result.months = result.days = 0; + result.micros = input * common::Interval::MICROS_PER_MINUTE; + } +}; + +struct ToSeconds { + static inline void operation(int64_t& input, common::interval_t& result) { + result.months = result.days = 0; + result.micros = input * common::Interval::MICROS_PER_SEC; + } +}; + +struct ToMilliseconds { + static inline void operation(int64_t& input, common::interval_t& result) { + result.months = result.days = 0; + result.micros = input * common::Interval::MICROS_PER_MSEC; + } +}; + +struct ToMicroseconds { + static inline void operation(int64_t& input, common::interval_t& result) { + result.months = result.days = 0; + result.micros = input; + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/interval/vector_interval_functions.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/interval/vector_interval_functions.h new file mode 100644 index 0000000000..0c74e17bc0 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/interval/vector_interval_functions.h @@ -0,0 +1,87 @@ +#pragma once + +#include "function/scalar_function.h" +#include "interval_functions.h" + +namespace lbug { +namespace function { + +struct IntervalFunction { +public: + template + static function_set 
getUnaryIntervalFunction(std::string funcName) { + function_set result; + result.push_back(std::make_unique(funcName, + std::vector{common::LogicalTypeID::INT64}, + common::LogicalTypeID::INTERVAL, + ScalarFunction::UnaryExecFunction)); + return result; + } +}; + +struct ToYearsFunction { + static constexpr const char* name = "TO_YEARS"; + + static function_set getFunctionSet() { + return IntervalFunction::getUnaryIntervalFunction(name); + } +}; + +struct ToMonthsFunction { + static constexpr const char* name = "TO_MONTHS"; + + static function_set getFunctionSet() { + return IntervalFunction::getUnaryIntervalFunction(name); + } +}; + +struct ToDaysFunction { + static constexpr const char* name = "TO_DAYS"; + + static function_set getFunctionSet() { + return IntervalFunction::getUnaryIntervalFunction(name); + } +}; + +struct ToHoursFunction { + static constexpr const char* name = "TO_HOURS"; + + static function_set getFunctionSet() { + return IntervalFunction::getUnaryIntervalFunction(name); + } +}; + +struct ToMinutesFunction { + static constexpr const char* name = "TO_MINUTES"; + + static function_set getFunctionSet() { + return IntervalFunction::getUnaryIntervalFunction(name); + } +}; + +struct ToSecondsFunction { + static constexpr const char* name = "TO_SECONDS"; + + static function_set getFunctionSet() { + return IntervalFunction::getUnaryIntervalFunction(name); + } +}; + +struct ToMillisecondsFunction { + static constexpr const char* name = "TO_MILLISECONDS"; + + static function_set getFunctionSet() { + return IntervalFunction::getUnaryIntervalFunction(name); + } +}; + +struct ToMicrosecondsFunction { + static constexpr const char* name = "TO_MICROSECONDS"; + + static function_set getFunctionSet() { + return IntervalFunction::getUnaryIntervalFunction(name); + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/list/functions/base_list_sort_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/list/functions/base_list_sort_function.h new file mode 100644 index 0000000000..45cfb2cf57 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/list/functions/base_list_sort_function.h @@ -0,0 +1,103 @@ +#pragma once + +#include "common/exception/runtime.h" +#include "common/string_utils.h" +#include "common/vector/value_vector.h" + +namespace lbug { +namespace function { + +struct BaseListSortOperation { +public: + static inline bool isAscOrder(const std::string& sortOrder) { + std::string upperSortOrder = common::StringUtils::getUpper(sortOrder); + if (upperSortOrder == "ASC") { + return true; + } else if (upperSortOrder == "DESC") { + return false; + } else { + throw common::RuntimeException("Invalid sortOrder"); + } + } + + static inline bool isNullFirst(const std::string& nullOrder) { + std::string upperNullOrder = common::StringUtils::getUpper(nullOrder); + if (upperNullOrder == "NULLS FIRST") { + return true; + } else if (upperNullOrder == "NULLS LAST") { + return false; + } else { + throw common::RuntimeException("Invalid nullOrder"); + } + } + + template + static void sortValues(common::list_entry_t& input, common::list_entry_t& result, + common::ValueVector& inputVector, common::ValueVector& resultVector, bool ascOrder, + bool nullFirst) { + // TODO(Ziyi) - Replace this sort implementation with radix_sort implementation: + // https://github.com/kuzudb/kuzu/issues/1536. 
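+ // The implementation below proceeds in four steps:
+ //   1. count the NULL entries of the input list;
+ //   2. allocate the result list and, when nullFirst is set, mark its first nullCount
+ //      slots as NULL (otherwise the trailing nullCount slots are marked NULL later);
+ //   3. copy the non-NULL values into the remaining slots;
+ //   4. std::sort the non-NULL range with std::less or std::greater depending on ascOrder.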
+ auto inputDataVector = common::ListVector::getDataVector(&inputVector); + auto inputPos = input.offset; + + // Calculate null count. + auto nullCount = 0; + for (auto i = 0u; i < input.size; i++) { + if (inputDataVector->isNull(input.offset + i)) { + nullCount += 1; + } + } + + result = common::ListVector::addList(&resultVector, input.size); + auto resultDataVector = common::ListVector::getDataVector(&resultVector); + auto resultPos = result.offset; + + // Add nulls first. + if (nullFirst) { + setVectorRangeToNull(*resultDataVector, result.offset, 0, nullCount); + resultPos += nullCount; + } + + // Add actual data. + for (auto i = 0u; i < input.size; i++) { + if (inputDataVector->isNull(inputPos)) { + inputPos++; + continue; + } + resultDataVector->copyFromVectorData(resultPos++, inputDataVector, inputPos++); + } + + // Add nulls in the end. + if (!nullFirst) { + setVectorRangeToNull(*resultDataVector, result.offset, input.size - nullCount, + input.size); + } + + // Determine the starting and ending position of the data to be sorted. + auto sortStart = nullCount; + auto sortEnd = input.size; + if (!nullFirst) { + sortStart = 0; + sortEnd = input.size - nullCount; + } + + // Sort the data based on order. + auto sortingValues = + reinterpret_cast(common::ListVector::getListValues(&resultVector, result)); + if (ascOrder) { + std::sort(sortingValues + sortStart, sortingValues + sortEnd, std::less{}); + } else { + std::sort(sortingValues + sortStart, sortingValues + sortEnd, std::greater{}); + } + } + + static void setVectorRangeToNull(common::ValueVector& vector, uint64_t offset, + uint64_t startPos, uint64_t endPos) { + for (auto i = startPos; i < endPos; i++) { + vector.setNull(offset + i, true); + } + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/list/functions/list_concat_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/list/functions/list_concat_function.h new file mode 100644 index 0000000000..4a6992d8f3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/list/functions/list_concat_function.h @@ -0,0 +1,17 @@ +#pragma once + +#include "common/types/types.h" +#include "common/vector/value_vector.h" + +namespace lbug { +namespace function { + +struct ListConcat { +public: + static void operation(common::list_entry_t& left, common::list_entry_t& right, + common::list_entry_t& result, common::ValueVector& leftVector, + common::ValueVector& rightVector, common::ValueVector& resultVector); +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/list/functions/list_extract_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/list/functions/list_extract_function.h new file mode 100644 index 0000000000..862f98c867 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/list/functions/list_extract_function.h @@ -0,0 +1,54 @@ +#pragma once + +#include "common/exception/runtime.h" +#include "common/type_utils.h" +#include "common/types/ku_string.h" +#include "common/vector/value_vector.h" +#include "function/string/functions/array_extract_function.h" + +namespace lbug { +namespace function { + +struct ListExtract { +public: + // Note: this function takes in a 1-based position (The index of the first value in the list + // is 1). 
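+ // Illustrative example (not taken from upstream docs), based on the bounds checks below:
+ //   for a list [10, 20, 30], position 1 yields 10 and position -1 yields 30 (negative
+ //   positions count from the end); position 0 and out-of-range positions raise a
+ //   RuntimeException.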
+ template + static inline void operation(common::list_entry_t& listEntry, int64_t pos, T& result, + common::ValueVector& listVector, common::ValueVector& /*posVector*/, + common::ValueVector& resultVector, uint64_t resPos) { + if (pos == 0) { + throw common::RuntimeException("List extract takes 1-based position."); + } + if ((pos > 0 && pos > listEntry.size) || (pos < 0 && pos < -(int64_t)listEntry.size)) { + throw common::RuntimeException( + common::stringFormat("list_extract(list, index): index={} is out of range.", + common::TypeUtils::toString(pos))); + } + if (pos > 0) { + pos--; + } else { + pos = listEntry.size + pos; + } + auto listDataVector = common::ListVector::getDataVector(&listVector); + resultVector.setNull(resPos, listDataVector->isNull(listEntry.offset + pos)); + if (!resultVector.isNull(resPos)) { + auto listValues = + common::ListVector::getListValuesWithOffset(&listVector, listEntry, pos); + resultVector.copyFromVectorData(reinterpret_cast(&result), listDataVector, + listValues); + } + } + + static inline void operation(common::ku_string_t& str, int64_t& idx, + common::ku_string_t& result) { + if (str.len < idx) { + result.set("", 0); + } else { + ArrayExtract::operation(str, idx, result); + } + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/list/functions/list_function_utils.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/list/functions/list_function_utils.h new file mode 100644 index 0000000000..2230d4c454 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/list/functions/list_function_utils.h @@ -0,0 +1,58 @@ +#pragma once + +#include "function/function.h" + +namespace lbug { +namespace function { + +enum class ListOp { Append, Prepend, Concat }; + +template +struct ListTypeResolver; + +template<> +struct ListTypeResolver { + static void anyEmpty(std::vector& types, common::LogicalType& targetType); + static void bothNull(std::vector& types, common::LogicalType& targetType); + static void leftNull(std::vector& types, common::LogicalType& targetType); + static void rightNull(std::vector& types, common::LogicalType& targetType); + static void finalResolver(std::vector& types, + common::LogicalType& targetType); +}; + +template<> +struct ListTypeResolver + : ListTypeResolver { /*Prepend empty list resolution follows the same logic as + the Append operation*/ +}; + +template<> +struct ListTypeResolver { + static void leftEmpty(std::vector& types, common::LogicalType& targetType); + static void rightEmpty(std::vector& types, + common::LogicalType& targetType); + static void bothNull(std::vector& types, common::LogicalType& targetType); + static void finalResolver(std::vector& types, + common::LogicalType& targetType); +}; + +struct ListFunctionUtils { +public: + using type_resolver = std::function& types, + common::LogicalType& targetType)>; + + static void resolveEmptyList(const ScalarBindFuncInput& input, + std::vector& types, type_resolver bothEmpty, type_resolver leftEmpty, + type_resolver rightEmpty, type_resolver finalEmptyListResolver); + + static void resolveNulls(std::vector& types, type_resolver bothNull, + type_resolver leftNull, type_resolver rightNull, type_resolver finalNullParamResolver); + + static void resolveTypes(const ScalarBindFuncInput& input, + std::vector& types, type_resolver bothEmpty, type_resolver leftEmpty, + type_resolver rightEmpty, type_resolver finalEmptyListResolver, type_resolver bothNull, + type_resolver leftNull, type_resolver 
rightNull, type_resolver finalNullParamResolver); +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/list/functions/list_len_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/list/functions/list_len_function.h new file mode 100644 index 0000000000..c94be3accf --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/list/functions/list_len_function.h @@ -0,0 +1,41 @@ +#pragma once + +#include + +#include "common/types/ku_string.h" +#include "utf8proc.h" + +namespace lbug { +namespace function { + +struct ListLen { +public: + template + static void operation(T& input, int64_t& result) { + result = input.size; + } +}; + +template<> +inline void ListLen::operation(common::ku_string_t& input, int64_t& result) { + auto totalByteLength = input.len; + auto inputString = input.getAsString(); + for (auto i = 0u; i < totalByteLength; i++) { + if (inputString[i] & 0x80) { + int64_t length = 0; + // Use grapheme iterator to identify bytes of utf8 char and increment once for each + // char. + utf8proc::utf8proc_grapheme_callback(inputString.c_str(), totalByteLength, + [&](size_t /*start*/, size_t /*end*/) { + length++; + return true; + }); + result = length; + return; + } + } + result = totalByteLength; +} + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/list/functions/list_position_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/list/functions/list_position_function.h new file mode 100644 index 0000000000..34c3feab91 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/list/functions/list_position_function.h @@ -0,0 +1,36 @@ +#pragma once + +#include "common/vector/value_vector.h" +#include "function/comparison/comparison_functions.h" + +namespace lbug { +namespace function { + +struct ListPosition { + // Note: this function takes in a 1-based element (The index of the first element in the list + // is 1). 
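+ // Illustrative example (not taken from upstream docs), based on the loop below:
+ //   LIST_POSITION([3, 5, 7], 5) yields 2; a missing element, or an element whose type
+ //   does not match the list's child type, yields 0.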
+ template + static void operation(common::list_entry_t& list, T& element, int64_t& result, + common::ValueVector& listVector, common::ValueVector& elementVector, + common::ValueVector& /*resultVector*/) { + if (common::ListType::getChildType(listVector.dataType) != elementVector.dataType) { + result = 0; + return; + } + auto listElements = + reinterpret_cast(common::ListVector::getListValues(&listVector, list)); + uint8_t comparisonResult = 0; + for (auto i = 0u; i < list.size; i++) { + Equals::operation(listElements[i], element, comparisonResult, + common::ListVector::getDataVector(&listVector), &elementVector); + if (comparisonResult) { + result = i + 1; + return; + } + } + result = 0; + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/list/functions/list_reverse_sort_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/list/functions/list_reverse_sort_function.h new file mode 100644 index 0000000000..88c63c2689 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/list/functions/list_reverse_sort_function.h @@ -0,0 +1,33 @@ +#pragma once + +#include "base_list_sort_function.h" +#include "common/vector/value_vector.h" + +namespace lbug { +namespace function { + +template +struct ListReverseSort : BaseListSortOperation { + static inline void operation(common::list_entry_t& input, common::list_entry_t& result, + common::ValueVector& inputVector, common::ValueVector& resultVector) { + sortValues(input, result, inputVector, resultVector, false /* ascOrder */, + true /* nullFirst */); + } + + static inline void operation(common::list_entry_t& input, common::ku_string_t& nullOrder, + common::list_entry_t& result, common::ValueVector& inputVector, + common::ValueVector& /*valueVector*/, common::ValueVector& resultVector) { + sortValues(input, result, inputVector, resultVector, false /* ascOrder */, + isNullFirst(nullOrder.getAsString()) /* nullFirst */); + } + + static inline void operation(common::list_entry_t& /*input*/, + common::ku_string_t& /*sortOrder*/, common::ku_string_t& /*nullOrder*/, + common::list_entry_t& /*result*/, common::ValueVector& /*inputVector*/, + common::ValueVector& /*resultVector*/) { + throw common::RuntimeException("Invalid number of arguments"); + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/list/functions/list_sort_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/list/functions/list_sort_function.h new file mode 100644 index 0000000000..1cebd3d6d1 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/list/functions/list_sort_function.h @@ -0,0 +1,33 @@ +#pragma once + +#include "base_list_sort_function.h" +#include "common/vector/value_vector.h" + +namespace lbug { +namespace function { + +template +struct ListSort : BaseListSortOperation { + static void operation(common::list_entry_t& input, common::list_entry_t& result, + common::ValueVector& inputVector, common::ValueVector& resultVector) { + sortValues(input, result, inputVector, resultVector, true /* ascOrder */, + true /* nullFirst */); + } + + static void operation(common::list_entry_t& input, common::ku_string_t& sortOrder, + common::list_entry_t& result, common::ValueVector& inputVector, + common::ValueVector& /*valueVector*/, common::ValueVector& resultVector) { + sortValues(input, result, inputVector, resultVector, isAscOrder(sortOrder.getAsString()), + true /* nullFirst */); + } + + static 
void operation(common::list_entry_t& input, common::ku_string_t& sortOrder, + common::ku_string_t& nullOrder, common::list_entry_t& result, + common::ValueVector& inputVector, common::ValueVector& resultVector) { + sortValues(input, result, inputVector, resultVector, isAscOrder(sortOrder.getAsString()), + isNullFirst(nullOrder.getAsString())); + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/list/functions/list_unique_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/list/functions/list_unique_function.h new file mode 100644 index 0000000000..02e057e5ca --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/list/functions/list_unique_function.h @@ -0,0 +1,35 @@ +#pragma once + +#include "common/type_utils.h" +#include "common/types/value/value.h" +#include "common/vector/value_vector.h" + +namespace lbug { +namespace function { + +struct ValueHashFunction { + uint64_t operator()(const common::Value& value) const { return (uint64_t)value.computeHash(); } +}; + +struct ValueEquality { + bool operator()(const common::Value& a, const common::Value& b) const { return a == b; } +}; + +using ValueSet = std::unordered_set; + +using duplicate_value_handler = std::function; +using unique_value_handler = std::function; +using null_value_handler = std::function; + +struct ListUnique { + static uint64_t appendListElementsToValueSet(common::list_entry_t& input, + common::ValueVector& inputVector, duplicate_value_handler duplicateValHandler = nullptr, + unique_value_handler uniqueValueHandler = nullptr, + null_value_handler nullValueHandler = nullptr); + + static void operation(common::list_entry_t& input, int64_t& result, + common::ValueVector& inputVector, common::ValueVector& resultVector); +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/list/vector_list_functions.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/list/vector_list_functions.h new file mode 100644 index 0000000000..54ecc6cdf7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/list/vector_list_functions.h @@ -0,0 +1,216 @@ +#pragma once + +#include "common/vector/value_vector.h" +#include "function/function.h" + +namespace lbug { +namespace function { + +struct ListCreationFunction { + static constexpr const char* name = "LIST_CREATION"; + + static function_set getFunctionSet(); + static void execFunc(const std::vector>& parameters, + const std::vector& parameterSelVectors, + common::ValueVector& result, common::SelectionVector* resultSelVector, + void* /*dataPtr*/ = nullptr); +}; + +struct ListRangeFunction { + static constexpr const char* name = "RANGE"; + + static function_set getFunctionSet(); +}; + +struct SizeFunction { + static constexpr const char* name = "SIZE"; + + static function_set getFunctionSet(); +}; + +struct CardinalityFunction { + using alias = SizeFunction; + + static constexpr const char* name = "CARDINALITY"; +}; + +struct ListExtractFunction { + static constexpr const char* name = "LIST_EXTRACT"; + + static function_set getFunctionSet(); +}; + +struct ListElementFunction { + using alias = ListExtractFunction; + + static constexpr const char* name = "LIST_ELEMENT"; +}; + +struct ListConcatFunction { + static constexpr const char* name = "LIST_CONCAT"; + + static function_set getFunctionSet(); + static std::unique_ptr bindFunc(const ScalarBindFuncInput& input); +}; + +struct ListCatFunction { + using alias = 
ListConcatFunction; + + static constexpr const char* name = "LIST_CAT"; +}; + +struct ListAppendFunction { + static constexpr const char* name = "LIST_APPEND"; + + static function_set getFunctionSet(); +}; + +struct ListPrependFunction { + static constexpr const char* name = "LIST_PREPEND"; + + static function_set getFunctionSet(); +}; + +struct ListPositionFunction { + static constexpr const char* name = "LIST_POSITION"; + + static function_set getFunctionSet(); +}; + +struct ListIndexOfFunction { + using alias = ListPositionFunction; + + static constexpr const char* name = "LIST_INDEXOF"; +}; + +struct ListContainsFunction { + static constexpr const char* name = "LIST_CONTAINS"; + + static function_set getFunctionSet(); +}; + +struct ListHasFunction { + using alias = ListContainsFunction; + + static constexpr const char* name = "LIST_HAS"; +}; + +struct ListSliceFunction { + static constexpr const char* name = "LIST_SLICE"; + + static function_set getFunctionSet(); +}; + +struct ListSortFunction { + static constexpr const char* name = "LIST_SORT"; + + static function_set getFunctionSet(); +}; + +struct ListReverseSortFunction { + static constexpr const char* name = "LIST_REVERSE_SORT"; + + static function_set getFunctionSet(); +}; + +struct ListSumFunction { + static constexpr const char* name = "LIST_SUM"; + + static function_set getFunctionSet(); +}; + +struct ListProductFunction { + static constexpr const char* name = "LIST_PRODUCT"; + + static function_set getFunctionSet(); +}; + +struct ListDistinctFunction { + static constexpr const char* name = "LIST_DISTINCT"; + + static function_set getFunctionSet(); +}; + +struct ListUniqueFunction { + static constexpr const char* name = "LIST_UNIQUE"; + + static function_set getFunctionSet(); +}; + +struct ListAnyValueFunction { + static constexpr const char* name = "LIST_ANY_VALUE"; + + static function_set getFunctionSet(); +}; + +struct ListReverseFunction { + static constexpr const char* name = "LIST_REVERSE"; + + static function_set getFunctionSet(); +}; + +struct ListToStringFunction { + static constexpr const char* name = "LIST_TO_STRING"; + + static function_set getFunctionSet(); +}; + +struct ListTransformFunction { + static constexpr const char* name = "LIST_TRANSFORM"; + + static function_set getFunctionSet(); +}; + +struct ListFilterFunction { + static constexpr const char* name = "LIST_FILTER"; + + static function_set getFunctionSet(); +}; + +struct ListReduceFunction { + static constexpr const char* name = "LIST_REDUCE"; + + static function_set getFunctionSet(); +}; + +using quantifier_handler = std::function; + +void execQuantifierFunc(quantifier_handler handler, + const std::vector>& input, + const std::vector& inputSelVectors, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* bindData); + +std::unique_ptr bindQuantifierFunc(const ScalarBindFuncInput& input); + +struct ListAnyFunction { + static constexpr const char* name = "ANY"; + + static function_set getFunctionSet(); +}; + +struct ListAllFunction { + static constexpr const char* name = "ALL"; + + static function_set getFunctionSet(); +}; + +struct ListNoneFunction { + static constexpr const char* name = "None"; + + static function_set getFunctionSet(); +}; + +struct ListSingleFunction { + static constexpr const char* name = "Single"; + + static function_set getFunctionSet(); +}; + +struct ListHasAllFunction { + static constexpr const char* name = "LIST_HAS_ALL"; + + static function_set getFunctionSet(); +}; + +} // namespace function +} // 
namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/map/functions/base_map_extract_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/map/functions/base_map_extract_function.h new file mode 100644 index 0000000000..99a8cb9286 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/map/functions/base_map_extract_function.h @@ -0,0 +1,23 @@ +#pragma once + +#include "common/vector/value_vector.h" + +namespace lbug { +namespace function { + +struct BaseMapExtract { + static void operation(common::list_entry_t& resultEntry, common::ValueVector& resultVector, + uint8_t* srcValues, common::ValueVector* srcVector, uint64_t numValuesToCopy) { + resultEntry = common::ListVector::addList(&resultVector, numValuesToCopy); + auto dstValues = common::ListVector::getListValues(&resultVector, resultEntry); + auto dstDataVector = common::ListVector::getDataVector(&resultVector); + for (auto i = 0u; i < numValuesToCopy; i++) { + dstDataVector->copyFromVectorData(dstValues, srcVector, srcValues); + dstValues += dstDataVector->getNumBytesPerValue(); + srcValues += srcVector->getNumBytesPerValue(); + } + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/map/functions/map_creation_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/map/functions/map_creation_function.h new file mode 100644 index 0000000000..2875264653 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/map/functions/map_creation_function.h @@ -0,0 +1,58 @@ +#pragma once + +#include "common/exception/runtime.h" +#include "common/vector/value_vector.h" +#include "function/list/functions/list_unique_function.h" +#include "main/client_context.h" + +namespace lbug { +namespace function { + +static void duplicateValueHandler(const std::string& key) { + throw common::RuntimeException{common::stringFormat("Found duplicate key: {} in map.", key)}; +} + +static void nullValueHandler() { + throw common::RuntimeException("Null value key is not allowed in map."); +} + +static void validateKeys(common::list_entry_t& keyEntry, common::ValueVector& keyVector) { + ListUnique::appendListElementsToValueSet(keyEntry, keyVector, duplicateValueHandler, + nullptr /* uniqueValueHandler */, nullValueHandler); +} + +struct MapCreation { + static void operation(common::list_entry_t& keyEntry, common::list_entry_t& valueEntry, + common::list_entry_t& resultEntry, common::ValueVector& keyVector, + common::ValueVector& valueVector, common::ValueVector& resultVector, void* dataPtr) { + if (keyEntry.size != valueEntry.size) { + throw common::RuntimeException{"Unaligned key list and value list."}; + } + if (!reinterpret_cast(dataPtr) + ->clientContext->getClientConfig() + ->disableMapKeyCheck) { + validateKeys(keyEntry, keyVector); + } + resultEntry = common::ListVector::addList(&resultVector, keyEntry.size); + auto resultStructVector = common::ListVector::getDataVector(&resultVector); + copyListEntry(resultEntry, + common::StructVector::getFieldVector(resultStructVector, 0 /* keyVector */).get(), + keyEntry, &keyVector); + copyListEntry(resultEntry, + common::StructVector::getFieldVector(resultStructVector, 1 /* valueVector */).get(), + valueEntry, &valueVector); + } + + static void copyListEntry(common::list_entry_t& resultEntry, common::ValueVector* resultVector, + common::list_entry_t& srcEntry, common::ValueVector* srcVector) { + auto resultPos = resultEntry.offset; + auto srcDataVector = 
common::ListVector::getDataVector(srcVector); + auto srcPos = srcEntry.offset; + for (auto i = 0u; i < srcEntry.size; i++) { + resultVector->copyFromVectorData(resultPos++, srcDataVector, srcPos++); + } + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/map/functions/map_extract_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/map/functions/map_extract_function.h new file mode 100644 index 0000000000..b59cd3520a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/map/functions/map_extract_function.h @@ -0,0 +1,39 @@ +#pragma once + +#include "common/vector/value_vector.h" +#include "function/comparison/comparison_functions.h" + +namespace lbug { +namespace function { + +struct MapExtract { + template + static void operation(common::list_entry_t& listEntry, T& key, + common::list_entry_t& resultEntry, common::ValueVector& listVector, + common::ValueVector& keyVector, common::ValueVector& resultVector) { + auto mapKeyVector = common::MapVector::getKeyVector(&listVector); + auto mapKeyValues = common::MapVector::getMapKeys(&listVector, listEntry); + auto mapValVector = common::MapVector::getValueVector(&listVector); + auto mapValPos = listEntry.offset; + common::offset_vec_t mapValPoses; + uint8_t comparisonResult = 0; + for (auto i = 0u; i < listEntry.size; i++) { + Equals::operation(*reinterpret_cast(mapKeyValues), key, comparisonResult, + mapKeyVector, &keyVector); + if (comparisonResult) { + mapValPoses.push_back(mapValPos); + } + mapKeyValues += mapKeyVector->getNumBytesPerValue(); + mapValPos++; + } + resultEntry = common::ListVector::addList(&resultVector, mapValPoses.size()); + auto resultOffset = resultEntry.offset; + for (auto& valPos : mapValPoses) { + common::ListVector::getDataVector(&resultVector) + ->copyFromVectorData(resultOffset++, mapValVector, valPos); + } + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/map/functions/map_keys_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/map/functions/map_keys_function.h new file mode 100644 index 0000000000..11c4101df3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/map/functions/map_keys_function.h @@ -0,0 +1,19 @@ +#pragma once + +#include "function/map/functions/base_map_extract_function.h" + +namespace lbug { +namespace function { + +struct MapKeys : public BaseMapExtract { + static void operation(common::list_entry_t& listEntry, common::list_entry_t& resultEntry, + common::ValueVector& listVector, common::ValueVector& resultVector) { + auto mapKeyVector = common::MapVector::getKeyVector(&listVector); + auto mapKeyValues = common::MapVector::getMapKeys(&listVector, listEntry); + BaseMapExtract::operation(resultEntry, resultVector, mapKeyValues, mapKeyVector, + listEntry.size); + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/map/functions/map_values_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/map/functions/map_values_function.h new file mode 100644 index 0000000000..0d5f3f0a79 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/map/functions/map_values_function.h @@ -0,0 +1,20 @@ +#pragma once + +#include "common/vector/value_vector.h" +#include "function/map/functions/base_map_extract_function.h" + +namespace lbug { +namespace function { + +struct MapValues : public BaseMapExtract { + static void 
operation(common::list_entry_t& listEntry, common::list_entry_t& resultEntry, + common::ValueVector& listVector, common::ValueVector& resultVector) { + auto mapValueVector = common::MapVector::getValueVector(&listVector); + auto mapValueValues = common::MapVector::getMapValues(&listVector, listEntry); + BaseMapExtract::operation(resultEntry, resultVector, mapValueValues, mapValueVector, + listEntry.size); + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/map/vector_map_functions.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/map/vector_map_functions.h new file mode 100644 index 0000000000..54b3f9c678 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/map/vector_map_functions.h @@ -0,0 +1,39 @@ +#pragma once + +#include "function/function.h" + +namespace lbug { +namespace function { + +struct MapCreationFunctions { + static constexpr const char* name = "MAP"; + + static function_set getFunctionSet(); +}; + +struct MapExtractFunctions { + static constexpr const char* name = "MAP_EXTRACT"; + + static function_set getFunctionSet(); +}; + +struct ElementAtFunctions { + using alias = MapExtractFunctions; + + static constexpr const char* name = "ELEMENT_AT"; +}; + +struct MapKeysFunctions { + static constexpr const char* name = "MAP_KEYS"; + + static function_set getFunctionSet(); +}; + +struct MapValuesFunctions { + static constexpr const char* name = "MAP_VALUES"; + + static function_set getFunctionSet(); +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/null/null_function_executor.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/null/null_function_executor.h new file mode 100644 index 0000000000..8e2ef81087 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/null/null_function_executor.h @@ -0,0 +1,69 @@ +#pragma once + +#include "common/vector/value_vector.h" + +namespace lbug { +namespace function { + +struct NullOperationExecutor { + + template + static void execute(common::ValueVector& operand, common::SelectionVector& operandSelVector, + common::ValueVector& result) { + KU_ASSERT(result.dataType.getLogicalTypeID() == common::LogicalTypeID::BOOL); + auto resultValues = (uint8_t*)result.getData(); + if (operand.state->isFlat()) { + auto pos = operandSelVector[0]; + auto resultPos = result.state->getSelVector()[0]; + FUNC::operation(operand.getValue(pos), (bool)operand.isNull(pos), + resultValues[resultPos]); + } else { + if (operandSelVector.isUnfiltered()) { + for (auto i = 0u; i < operandSelVector.getSelSize(); i++) { + FUNC::operation(operand.getValue(i), (bool)operand.isNull(i), + resultValues[i]); + } + } else { + for (auto i = 0u; i < operandSelVector.getSelSize(); i++) { + auto pos = operandSelVector[i]; + FUNC::operation(operand.getValue(pos), (bool)operand.isNull(pos), + resultValues[pos]); + } + } + } + } + + template + static bool select(common::ValueVector& operand, common::SelectionVector& selVector, + void* /*dataPtr*/) { + auto& operandSelVector = operand.state->getSelVector(); + if (operand.state->isFlat()) { + auto pos = operandSelVector[0]; + uint8_t resultValue = 0; + FUNC::operation(operand.getValue(pos), operand.isNull(pos), resultValue); + return resultValue == true; + } else { + uint64_t numSelectedValues = 0; + auto buffer = selVector.getMutableBuffer(); + for (auto i = 0ul; i < operandSelVector.getSelSize(); i++) { + auto pos = operandSelVector[i]; + 
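+ // selectOnValue (defined below) writes pos into the buffer unconditionally and only
+ // advances numSelectedValues when the predicate holds, i.e. a branch-free compaction
+ // of the selection vector.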
selectOnValue(operand, pos, numSelectedValues, buffer); + } + selVector.setSelSize(numSelectedValues); + return numSelectedValues > 0; + } + } + + template + static void selectOnValue(common::ValueVector& operand, uint64_t operandPos, + uint64_t& numSelectedValues, std::span selectedPositionsBuffer) { + uint8_t resultValue = 0; + FUNC::operation(operand.getValue(operandPos), operand.isNull(operandPos), + resultValue); + selectedPositionsBuffer[numSelectedValues] = operandPos; + numSelectedValues += resultValue == true; + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/null/null_functions.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/null/null_functions.h new file mode 100644 index 0000000000..94efa83689 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/null/null_functions.h @@ -0,0 +1,23 @@ +#pragma once + +#include + +namespace lbug { +namespace function { + +struct IsNull { + template + static inline void operation(T /*value*/, bool isNull, uint8_t& result) { + result = isNull; + } +}; + +struct IsNotNull { + template + static inline void operation(T /*value*/, bool isNull, uint8_t& result) { + result = !isNull; + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/null/vector_null_functions.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/null/vector_null_functions.h new file mode 100644 index 0000000000..9347d54414 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/null/vector_null_functions.h @@ -0,0 +1,37 @@ +#pragma once + +#include "function/scalar_function.h" +#include "null_function_executor.h" + +namespace lbug { +namespace function { + +class VectorNullFunction { +public: + static void bindExecFunction(common::ExpressionType expressionType, + const binder::expression_vector& children, scalar_func_exec_t& func); + + static void bindSelectFunction(common::ExpressionType expressionType, + const binder::expression_vector& children, scalar_func_select_t& func); + +private: + template + static void UnaryNullExecFunction( + const std::vector>& params, + const std::vector& paramSelVectors, common::ValueVector& result, + common::SelectionVector*, void* /*dataPtr*/ = nullptr) { + KU_ASSERT(params.size() == 1); + NullOperationExecutor::execute(*params[0], *paramSelVectors[0], result); + } + + template + static bool UnaryNullSelectFunction( + const std::vector>& params, + common::SelectionVector& selVector, void* dataPtr) { + KU_ASSERT(params.size() == 1); + return NullOperationExecutor::select(*params[0], selVector, dataPtr); + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/path/path_function_executor.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/path/path_function_executor.h new file mode 100644 index 0000000000..f47cecf34a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/path/path_function_executor.h @@ -0,0 +1,132 @@ +#pragma once + +#include "common/constants.h" +#include "common/vector/value_vector.h" +#include "function/hash/hash_functions.h" + +namespace lbug { +namespace function { + +static bool isAllInternalIDDistinct(common::ValueVector* dataVector, common::offset_t startOffset, + uint64_t size, std::unordered_set& internalIDSet) { + internalIDSet.clear(); + for (auto i = 0u; i < size; ++i) { + auto& internalID = dataVector->getValue(startOffset + i); + if 
(internalIDSet.contains(internalID)) { + return false; + } + internalIDSet.insert(internalID); + } + return true; +} + +// Note: this executor is only used for isTrail and isAcyclic. So we add some ad-hoc optimization +// into executor, e.g. internalIDSet. A more general implementation can be done once needed. But +// pay attention to the performance drop. Depends on how bad it becomes, we may want to implement +// customized executors. +struct UnaryPathExecutor { + static void executeNodeIDs(common::ValueVector& input, common::SelectionVector& inputSelVector, + common::ValueVector& result) { + auto nodesFieldIdx = 0; + KU_ASSERT(nodesFieldIdx == + common::StructType::getFieldIdx(input.dataType, common::InternalKeyword::NODES)); + auto nodesVector = common::StructVector::getFieldVector(&input, nodesFieldIdx).get(); + auto internalIDFieldIdx = 0; + execute(inputSelVector, *nodesVector, internalIDFieldIdx, result); + } + + static void executeRelIDs(common::ValueVector& input, common::SelectionVector& inputSelVector, + common::ValueVector& result) { + auto relsFieldIdx = 1; + KU_ASSERT(relsFieldIdx == + common::StructType::getFieldIdx(input.dataType, common::InternalKeyword::RELS)); + auto relsVector = common::StructVector::getFieldVector(&input, relsFieldIdx).get(); + auto internalIDFieldIdx = 3; + execute(inputSelVector, *relsVector, internalIDFieldIdx, result); + } + + static bool selectNodeIDs(common::ValueVector& input, + common::SelectionVector& selectionVector) { + auto nodesFieldIdx = 0; + KU_ASSERT(nodesFieldIdx == + common::StructType::getFieldIdx(input.dataType, common::InternalKeyword::NODES)); + auto nodesVector = common::StructVector::getFieldVector(&input, nodesFieldIdx).get(); + auto internalIDFieldIdx = 0; + return select(input.state->getSelVector(), *nodesVector, internalIDFieldIdx, + selectionVector); + } + + static bool selectRelIDs(common::ValueVector& input, common::SelectionVector& selectionVector) { + auto relsFieldIdx = 1; + KU_ASSERT(relsFieldIdx == + common::StructType::getFieldIdx(input.dataType, common::InternalKeyword::RELS)); + auto relsVector = common::StructVector::getFieldVector(&input, relsFieldIdx).get(); + auto internalIDFieldIdx = 3; + return select(input.state->getSelVector(), *relsVector, internalIDFieldIdx, + selectionVector); + } + +private: + static void execute(const common::SelectionVector& inputSelVector, + common::ValueVector& listVector, common::struct_field_idx_t fieldIdx, + common::ValueVector& result) { + auto listDataVector = common::ListVector::getDataVector(&listVector); + KU_ASSERT(fieldIdx == common::StructType::getFieldIdx(listDataVector->dataType, + common::InternalKeyword::ID)); + auto internalIDsVector = + common::StructVector::getFieldVector(listDataVector, fieldIdx).get(); + std::unordered_set internalIDSet; + if (inputSelVector.isUnfiltered()) { + for (auto i = 0u; i < inputSelVector.getSelSize(); ++i) { + auto& listEntry = listVector.getValue(i); + bool isTrail = isAllInternalIDDistinct(internalIDsVector, listEntry.offset, + listEntry.size, internalIDSet); + result.setValue(i, isTrail); + } + } else { + for (auto i = 0u; i < inputSelVector.getSelSize(); ++i) { + auto pos = inputSelVector[i]; + auto& listEntry = listVector.getValue(pos); + bool isTrail = isAllInternalIDDistinct(internalIDsVector, listEntry.offset, + listEntry.size, internalIDSet); + result.setValue(pos, isTrail); + } + } + } + + static bool select(const common::SelectionVector& inputSelVector, + common::ValueVector& listVector, common::struct_field_idx_t fieldIdx, 
+ common::SelectionVector& selectionVector) { + auto listDataVector = common::ListVector::getDataVector(&listVector); + KU_ASSERT(fieldIdx == common::StructType::getFieldIdx(listDataVector->dataType, + common::InternalKeyword::ID)); + auto internalIDsVector = + common::StructVector::getFieldVector(listDataVector, fieldIdx).get(); + std::unordered_set internalIDSet; + auto numSelectedValues = 0u; + auto buffer = selectionVector.getMutableBuffer(); + if (inputSelVector.isUnfiltered()) { + for (auto i = 0u; i < inputSelVector.getSelSize(); ++i) { + auto& listEntry = listVector.getValue(i); + bool isTrail = isAllInternalIDDistinct(internalIDsVector, listEntry.offset, + listEntry.size, internalIDSet); + buffer[numSelectedValues] = i; + numSelectedValues += isTrail; + } + } else { + for (auto i = 0u; i < inputSelVector.getSelSize(); ++i) { + auto pos = inputSelVector[i]; + auto& listEntry = listVector.getValue(pos); + bool isTrail = isAllInternalIDDistinct(internalIDsVector, listEntry.offset, + listEntry.size, internalIDSet); + buffer[numSelectedValues] = pos; + numSelectedValues += isTrail; + } + } + selectionVector.setSelSize(numSelectedValues); + return numSelectedValues > 0; + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/path/vector_path_functions.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/path/vector_path_functions.h new file mode 100644 index 0000000000..f7faf1b038 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/path/vector_path_functions.h @@ -0,0 +1,62 @@ +#pragma once + +#include "function/function.h" + +namespace lbug { +namespace function { + +struct NodesFunction { + static constexpr const char* name = "NODES"; + + static function_set getFunctionSet(); +}; + +struct RelsFunction { + static constexpr const char* name = "RELS"; + + static function_set getFunctionSet(); +}; + +struct RelationshipsFunction { + using alias = RelsFunction; + + static constexpr const char* name = "RELATIONSHIPS"; +}; + +struct PropertiesBindData : public FunctionBindData { + common::idx_t childIdx; + + PropertiesBindData(common::LogicalType dataType, common::idx_t childIdx) + : FunctionBindData{std::move(dataType)}, childIdx{childIdx} {} + + inline std::unique_ptr copy() const override { + return std::make_unique(resultType.copy(), childIdx); + } +}; + +struct PropertiesFunction { + static constexpr const char* name = "PROPERTIES"; + + static function_set getFunctionSet(); +}; + +struct IsTrailFunction { + static constexpr const char* name = "IS_TRAIL"; + + static function_set getFunctionSet(); +}; + +struct IsACyclicFunction { + static constexpr const char* name = "IS_ACYCLIC"; + + static function_set getFunctionSet(); +}; + +struct LengthFunction { + static constexpr const char* name = "LENGTH"; + + static function_set getFunctionSet(); +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/pointer_function_executor.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/pointer_function_executor.h new file mode 100644 index 0000000000..c295dc3fbe --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/pointer_function_executor.h @@ -0,0 +1,25 @@ +#pragma once + +#include "common/vector/value_vector.h" + +namespace lbug { +namespace function { + +struct PointerFunctionExecutor { + template + static void execute(common::ValueVector& result, common::SelectionVector& sel, void* dataPtr) { + if (sel.isUnfiltered()) { 
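+            // Unfiltered state: the selected positions are exactly 0..getSelSize()-1,
+            // so results can be written sequentially without indirecting through sel[i].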
+ for (auto i = 0u; i < sel.getSelSize(); i++) { + OP::operation(result.getValue(i), dataPtr); + } + } else { + for (auto i = 0u; i < sel.getSelSize(); i++) { + auto pos = sel[i]; + OP::operation(result.getValue(pos), dataPtr); + } + } + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/rewrite_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/rewrite_function.h new file mode 100644 index 0000000000..7753b893b3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/rewrite_function.h @@ -0,0 +1,42 @@ +#pragma once + +#include "function.h" + +namespace lbug { +namespace binder { +class ExpressionBinder; +} +namespace function { + +struct RewriteFunctionBindInput { + main::ClientContext* context; + binder::ExpressionBinder* expressionBinder; + binder::expression_vector arguments; + + RewriteFunctionBindInput(main::ClientContext* context, + binder::ExpressionBinder* expressionBinder, binder::expression_vector arguments) + : context{context}, expressionBinder{expressionBinder}, arguments{std::move(arguments)} {} +}; + +// Rewrite function to a different expression, e.g. id(n) -> n._id. +using rewrite_func_rewrite_t = + std::function(const RewriteFunctionBindInput&)>; + +// We write for the following functions +// ID(n) -> n._id +struct RewriteFunction final : Function { + rewrite_func_rewrite_t rewriteFunc; + + RewriteFunction(std::string name, std::vector parameterTypeIDs, + rewrite_func_rewrite_t rewriteFunc) + : Function{std::move(name), std::move(parameterTypeIDs)}, + rewriteFunc{std::move(rewriteFunc)} {} + EXPLICIT_COPY_DEFAULT_MOVE(RewriteFunction) + +private: + RewriteFunction(const RewriteFunction& other) + : Function{other}, rewriteFunc{other.rewriteFunc} {} +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/scalar_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/scalar_function.h new file mode 100644 index 0000000000..e910c78ad4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/scalar_function.h @@ -0,0 +1,251 @@ +#pragma once + +#include "binary_function_executor.h" +#include "const_function_executor.h" +#include "function.h" +#include "pointer_function_executor.h" +#include "ternary_function_executor.h" +#include "unary_function_executor.h" + +namespace lbug { +namespace function { + +// Evaluate function at compile time, e.g. struct_extraction. +using scalar_func_compile_exec_t = + std::function>&, + std::shared_ptr&)>; +// Execute function. +using scalar_func_exec_t = + std::function>&, + const std::vector&, common::ValueVector&, + common::SelectionVector*, void*)>; +// Execute boolean function and write result to selection vector. Fast path for filter. 
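+// Illustrative sketch only (the concrete names below are assumptions, not part of
+// this header): a comparison function would typically register both flavours, so a
+// WHERE clause can take the select fast path instead of materializing a BOOL vector:
+//
+//   ScalarFunction gt("GREATER_THAN",
+//       {common::LogicalTypeID::INT64, common::LogicalTypeID::INT64},
+//       common::LogicalTypeID::BOOL,
+//       ScalarFunction::BinaryExecFunction<int64_t, int64_t, uint8_t, GreaterThan>,
+//       ScalarFunction::BinarySelectFunction<int64_t, int64_t, GreaterThan>);
+//
+// Only the executor templates further down are defined in this file; GreaterThan
+// stands in for any comparison operation struct.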
+using scalar_func_select_t = std::function>&, common::SelectionVector&, void*)>; + +struct LBUG_API ScalarFunction : public ScalarOrAggregateFunction { + scalar_func_exec_t execFunc = nullptr; + scalar_func_select_t selectFunc = nullptr; + scalar_func_compile_exec_t compileFunc = nullptr; + bool isListLambda = false; + bool isVarLength = false; + + ScalarFunction() = default; + ScalarFunction(std::string name, std::vector parameterTypeIDs, + common::LogicalTypeID returnTypeID) + : ScalarOrAggregateFunction{std::move(name), std::move(parameterTypeIDs), returnTypeID} {} + ScalarFunction(std::string name, std::vector parameterTypeIDs, + common::LogicalTypeID returnTypeID, scalar_func_exec_t execFunc) + : ScalarOrAggregateFunction{std::move(name), std::move(parameterTypeIDs), returnTypeID}, + execFunc{std::move(execFunc)} {} + ScalarFunction(std::string name, std::vector parameterTypeIDs, + common::LogicalTypeID returnTypeID, scalar_func_exec_t execFunc, + scalar_func_select_t selectFunc) + : ScalarOrAggregateFunction{std::move(name), std::move(parameterTypeIDs), returnTypeID}, + execFunc{std::move(execFunc)}, selectFunc{std::move(selectFunc)} {} + + template + static void TernaryExecFunction(const std::vector>& params, + const std::vector& paramSelVectors, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* dataPtr = nullptr) { + KU_ASSERT(params.size() == 3); + TernaryFunctionExecutor::executeSwitch(*params[0], paramSelVectors[0], *params[1], paramSelVectors[1], + *params[2], paramSelVectors[2], result, resultSelVector, dataPtr); + } + + template + static void TernaryStringExecFunction( + const std::vector>& params, + const std::vector& paramSelVectors, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* dataPtr = nullptr) { + KU_ASSERT(params.size() == 3); + TernaryFunctionExecutor::executeSwitch(*params[0], paramSelVectors[0], *params[1], + paramSelVectors[1], *params[2], paramSelVectors[2], result, resultSelVector, dataPtr); + } + + template + static void TernaryRegexExecFunction( + const std::vector>& params, + const std::vector& paramSelVectors, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* dataPtr) { + TernaryFunctionExecutor::executeSwitch(*params[0], paramSelVectors[0], *params[1], + paramSelVectors[1], *params[2], paramSelVectors[2], result, resultSelVector, dataPtr); + } + + template + static void TernaryExecListStructFunction( + const std::vector>& params, + const std::vector& paramSelVectors, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* dataPtr = nullptr) { + KU_ASSERT(params.size() == 3); + TernaryFunctionExecutor::executeSwitch(*params[0], paramSelVectors[0], *params[1], + paramSelVectors[1], *params[2], paramSelVectors[2], result, resultSelVector, dataPtr); + } + + template + static void BinaryExecFunction(const std::vector>& params, + const std::vector& paramSelVectors, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* /*dataPtr*/ = nullptr) { + KU_ASSERT(params.size() == 2); + BinaryFunctionExecutor::execute(*params[0], + paramSelVectors[0], *params[1], paramSelVectors[1], result, resultSelVector); + } + + template + static void BinaryStringExecFunction( + const std::vector>& params, + const std::vector& paramSelVectors, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* dataPtr = nullptr) { + KU_ASSERT(params.size() == 2); + BinaryFunctionExecutor::executeSwitch(*params[0], paramSelVectors[0], 
*params[1], + paramSelVectors[1], result, resultSelVector, dataPtr); + } + + template + static void BinaryExecListStructFunction( + const std::vector>& params, + const std::vector& paramSelVectors, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* dataPtr = nullptr) { + KU_ASSERT(params.size() == 2); + BinaryFunctionExecutor::executeSwitch(*params[0], paramSelVectors[0], *params[1], + paramSelVectors[1], result, resultSelVector, dataPtr); + } + + template + static void BinaryExecWithBindData( + const std::vector>& params, + const std::vector& paramSelVectors, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* dataPtr) { + KU_ASSERT(params.size() == 2); + BinaryFunctionExecutor::executeSwitch(*params[0], paramSelVectors[0], *params[1], + paramSelVectors[1], result, resultSelVector, dataPtr); + } + + template + static bool BinarySelectFunction( + const std::vector>& params, + common::SelectionVector& selVector, void* dataPtr) { + KU_ASSERT(params.size() == 2); + return BinaryFunctionExecutor::select(*params[0], *params[1], + selVector, dataPtr); + } + + template + static bool BinarySelectWithBindData( + const std::vector>& params, + common::SelectionVector& selVector, void* dataPtr) { + KU_ASSERT(params.size() == 2); + return BinaryFunctionExecutor::select(*params[0], *params[1], selVector, dataPtr); + } + + template + static void UnaryExecFunction(const std::vector>& params, + const std::vector& paramSelVectors, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* dataPtr) { + KU_ASSERT(params.size() == 1); + EXECUTOR::template executeSwitch( + *params[0], paramSelVectors[0], result, resultSelVector, dataPtr); + } + + template + static void UnarySequenceExecFunction( + const std::vector>& params, + const std::vector& paramSelVectors, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* dataPtr) { + KU_ASSERT(params.size() == 1); + UnaryFunctionExecutor::executeSequence(*params[0], + paramSelVectors[0], result, resultSelVector, dataPtr); + } + + template + static void UnaryStringExecFunction( + const std::vector>& params, + const std::vector& paramSelVectors, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* /*dataPtr*/ = nullptr) { + KU_ASSERT(params.size() == 1); + UnaryFunctionExecutor::executeSwitch(*params[0], paramSelVectors[0], result, resultSelVector, + nullptr /* dataPtr */); + } + + template + static void UnaryCastStringExecFunction( + const std::vector>& params, + const std::vector& paramSelVectors, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* dataPtr) { + KU_ASSERT(params.size() == 1); + EXECUTOR::template executeSwitch(*params[0], paramSelVectors[0], result, resultSelVector, + dataPtr); + } + + template + static void UnaryCastExecFunction( + const std::vector>& params, + const std::vector& paramSelVectors, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* dataPtr) { + KU_ASSERT(params.size() == 1); + EXECUTOR::template executeSwitch(*params[0], + paramSelVectors[0], result, resultSelVector, dataPtr); + } + + template + static void UnaryExecNestedTypeFunction( + const std::vector>& params, + const std::vector& paramSelVectors, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* dataPtr) { + KU_ASSERT(params.size() == 1); + EXECUTOR::template executeSwitch(*params[0], paramSelVectors[0], result, resultSelVector, + dataPtr); + } + + template + static void 
UnarySetSeedFunction( + const std::vector>& params, + const std::vector& paramSelVectors, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* dataPtr) { + KU_ASSERT(params.size() == 1); + EXECUTOR::template executeSwitch( + *params[0], paramSelVectors[0], result, resultSelVector, dataPtr); + } + + template + static void NullaryExecFunction( + [[maybe_unused]] const std::vector>& params, + [[maybe_unused]] const std::vector& paramSelVectors, + common::ValueVector& result, common::SelectionVector* resultSelVector, + void* /*dataPtr*/ = nullptr) { + KU_ASSERT(params.empty() && paramSelVectors.empty()); + ConstFunctionExecutor::execute(result, *resultSelVector); + } + + template + static void NullaryAuxilaryExecFunction( + [[maybe_unused]] const std::vector>& params, + [[maybe_unused]] const std::vector& paramSelVectors, + common::ValueVector& result, common::SelectionVector* resultSelVector, void* dataPtr) { + KU_ASSERT(params.empty() && paramSelVectors.empty()); + PointerFunctionExecutor::execute(result, *resultSelVector, dataPtr); + } + + virtual std::unique_ptr copy() const { + return std::make_unique(*this); + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/scalar_macro_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/scalar_macro_function.h new file mode 100644 index 0000000000..9a8394ead2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/scalar_macro_function.h @@ -0,0 +1,42 @@ +#pragma once + +#include + +#include "parser/create_macro.h" + +namespace lbug { +namespace function { + +using macro_parameter_value_map = std::unordered_map; + +struct ScalarMacroFunction { + std::unique_ptr expression; + std::vector positionalArgs; + parser::default_macro_args defaultArgs; + + ScalarMacroFunction() = default; + + ScalarMacroFunction(std::unique_ptr expression, + std::vector positionalArgs, parser::default_macro_args defaultArgs) + : expression{std::move(expression)}, positionalArgs{std::move(positionalArgs)}, + defaultArgs{std::move(defaultArgs)} {} + + std::string getDefaultParameterName(uint64_t idx) const { return defaultArgs[idx].first; } + + uint64_t getNumArgs() const { return positionalArgs.size() + defaultArgs.size(); } + + std::vector getPositionalArgs() const { return positionalArgs; } + + macro_parameter_value_map getDefaultParameterVals() const; + + std::unique_ptr copy() const; + + void serialize(common::Serializer& serializer) const; + + std::string toCypher(const std::string& name) const; + + static std::unique_ptr deserialize(common::Deserializer& deserializer); +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/schema/offset_functions.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/schema/offset_functions.h new file mode 100644 index 0000000000..ba41162f01 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/schema/offset_functions.h @@ -0,0 +1,14 @@ +#pragma once + +#include "common/types/types.h" + +namespace lbug { +namespace function { + +struct Offset { + + static void operation(common::internalID_t& input, int64_t& result) { result = input.offset; } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/schema/vector_node_rel_functions.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/schema/vector_node_rel_functions.h new file mode 100644 index 0000000000..c917d95a71 
--- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/schema/vector_node_rel_functions.h @@ -0,0 +1,54 @@ +#pragma once + +#include "function/function.h" + +namespace lbug { +namespace function { + +struct RewriteFunctionBindInput; + +struct OffsetFunction { + static constexpr const char* name = "OFFSET"; + + static function_set getFunctionSet(); +}; + +struct IDFunction { + static constexpr const char* name = "ID"; + + static function_set getFunctionSet(); +}; + +struct StartNodeFunction { + static constexpr const char* name = "START_NODE"; + + static function_set getFunctionSet(); +}; + +struct EndNodeFunction { + static constexpr const char* name = "END_NODE"; + + static function_set getFunctionSet(); +}; + +struct LabelFunction { + static constexpr const char* name = "LABEL"; + + static function_set getFunctionSet(); + static std::shared_ptr rewriteFunc(const RewriteFunctionBindInput& input); +}; + +struct LabelsFunction { + using alias = LabelFunction; + + static constexpr const char* name = "LABELS"; +}; + +struct CostFunction { + static constexpr const char* name = "COST"; + + static function_set getFunctionSet(); +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/sequence/sequence_functions.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/sequence/sequence_functions.h new file mode 100644 index 0000000000..31e82dfc81 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/sequence/sequence_functions.h @@ -0,0 +1,21 @@ +#pragma once + +#include "function/function.h" + +namespace lbug { +namespace function { + +struct CurrValFunction { + static constexpr const char* name = "CURRVAL"; + + static function_set getFunctionSet(); +}; + +struct NextValFunction { + static constexpr const char* name = "NEXTVAL"; + + static function_set getFunctionSet(); +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/array_extract_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/array_extract_function.h new file mode 100644 index 0000000000..2d8766bf9b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/array_extract_function.h @@ -0,0 +1,65 @@ +#pragma once + +#include + +#include "common/types/ku_string.h" +#include "function/list/functions/list_len_function.h" + +namespace lbug { +namespace function { + +struct ArrayExtract { + static inline void operation(common::ku_string_t& str, int64_t& idx, + common::ku_string_t& result) { + if (idx == 0) { + result.len = 0; + return; + } + auto stringVal = str.getAsString(); + int64_t strLen = 0; + ListLen::operation(str, strLen); + auto idxPos = idx > 0 ? 
std::min(idx, strLen) : std::max(strLen + idx, (int64_t)0) + 1; + auto startPos = idxPos - 1; + auto endPos = startPos + 1; + bool isAscii = true; + for (auto i = 0u; i < std::min((size_t)idxPos + 1, stringVal.size()); i++) { + if (stringVal[i] & 0x80) { + isAscii = false; + break; + } + } + if (isAscii) { + copySubstr(str, idxPos, 1 /* length */, result, isAscii); + } else { + int64_t characterCount = 0, startBytePos = 0, endBytePos = 0; + lbug::utf8proc::utf8proc_grapheme_callback(stringVal.c_str(), stringVal.size(), + [&](int64_t gstart, int64_t /*gend*/) { + if (characterCount == startPos) { + startBytePos = gstart; + } else if (characterCount == endPos) { + endBytePos = gstart; + return false; + } + characterCount++; + return true; + }); + if (endBytePos == 0) { + endBytePos = str.len; + } + copySubstr(str, startBytePos, endBytePos - startBytePos, result, isAscii); + } + } + + static inline void copySubstr(common::ku_string_t& src, int64_t start, int64_t len, + common::ku_string_t& result, bool isAscii) { + result.len = std::min(len, src.len - start + 1); + if (isAscii) { + memcpy((uint8_t*)result.getData(), src.getData() + start - 1, result.len); + } else { + memcpy((uint8_t*)result.getData(), src.getData() + start, result.len); + } + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/base_lower_upper_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/base_lower_upper_function.h new file mode 100644 index 0000000000..7e6ed7a2f9 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/base_lower_upper_function.h @@ -0,0 +1,22 @@ +#pragma once + +#include "common/api.h" +#include "common/types/ku_string.h" +#include "common/vector/value_vector.h" + +namespace lbug { +namespace function { + +struct BaseLowerUpperFunction { + + LBUG_API static void operation(common::ku_string_t& input, common::ku_string_t& result, + common::ValueVector& resultValueVector, bool isUpper); + + static void convertCharCase(char* result, const char* input, int32_t charPos, bool toUpper, + int& originalSize, int& newSize); + static void convertCase(char* result, uint32_t len, char* input, bool toUpper); + static uint32_t getResultLen(char* inputStr, uint32_t inputLen, bool isUpper); +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/base_pad_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/base_pad_function.h new file mode 100644 index 0000000000..e48f1d66cb --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/base_pad_function.h @@ -0,0 +1,59 @@ +#pragma once + +#include "common/types/ku_string.h" +#include "common/vector/value_vector.h" +#include "utf8proc.h" + +namespace lbug { +namespace function { + +// Padding logic has been taken from DuckDB: +// https://github.com/duckdb/duckdb/blob/master/src/function/scalar/string/pad.cpp +struct BasePadOperation { +public: + static inline void operation(common::ku_string_t& src, int64_t count, + common::ku_string_t& characterToPad, common::ku_string_t& result, + common::ValueVector& resultValueVector, + void (*padOperation)(common::ku_string_t& src, int64_t count, + common::ku_string_t& characterToPad, std::string& paddedResult)) { + if (count < 0) { + count = 0; + } + std::string paddedResult; + padOperation(src, count, characterToPad, paddedResult); + 
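+        // padOperation has materialized the padded value into paddedResult;
+        // StringVector::addString below copies those bytes into the result vector's
+        // string storage and points `result` at the copy.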
common::StringVector::addString(&resultValueVector, result, paddedResult.data(), + paddedResult.size()); + } + + static std::pair padCountChars(const uint32_t count, const char* data, + const uint32_t size) { + auto str = reinterpret_cast(data); + uint32_t byteCount = 0, charCount = 0; + for (; charCount < count && byteCount < size; charCount++) { + utf8proc::utf8proc_int32_t codepoint = 0; + auto bytes = utf8proc::utf8proc_iterate(str + byteCount, size - byteCount, &codepoint); + byteCount += bytes; + } + return {byteCount, charCount}; + } + + static void insertPadding(uint32_t charCount, common::ku_string_t pad, std::string& result) { + auto padData = pad.getData(); + auto padSize = pad.len; + uint32_t padByteCount = 0; + for (auto i = 0u; i < charCount; i++) { + if (padByteCount >= padSize) { + result.insert(result.end(), (char*)padData, (char*)(padData + padByteCount)); + padByteCount = 0; + } + utf8proc::utf8proc_int32_t codepoint = 0; + auto bytes = utf8proc::utf8proc_iterate(padData + padByteCount, padSize - padByteCount, + &codepoint); + padByteCount += bytes; + } + result.insert(result.end(), (char*)padData, (char*)(padData + padByteCount)); + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/base_regexp_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/base_regexp_function.h new file mode 100644 index 0000000000..5a5bcaf1e3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/base_regexp_function.h @@ -0,0 +1,25 @@ +#pragma once + +#include + +#include "common/vector/value_vector.h" + +namespace lbug { +namespace function { + +struct BaseRegexpOperation { + static inline std::string parseCypherPattern(const std::string& pattern) { + // Cypher parses escape characters with 2 backslash eg. for expressing '.' requires '\\.' + // Since Regular Expression requires only 1 backslash '\.' 
we need to replace double slash + // with single + return std::regex_replace(pattern, std::regex(R"(\\\\)"), "\\"); + } + + static inline void copyToLbugString(const std::string& value, common::ku_string_t& kuString, + common::ValueVector& valueVector) { + common::StringVector::addString(&valueVector, kuString, value.data(), value.length()); + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/base_str_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/base_str_function.h new file mode 100644 index 0000000000..e248aff604 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/base_str_function.h @@ -0,0 +1,17 @@ +#pragma once + +#include "common/api.h" +#include "common/types/ku_string.h" +#include "common/vector/value_vector.h" + +namespace lbug { +namespace function { + +struct BaseStrOperation { +public: + LBUG_API static void operation(common::ku_string_t& input, common::ku_string_t& result, + common::ValueVector& resultValueVector, uint32_t (*strOperation)(char* data, uint32_t len)); +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/contains_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/contains_function.h new file mode 100644 index 0000000000..17f686e289 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/contains_function.h @@ -0,0 +1,19 @@ +#pragma once + +#include "common/types/ku_string.h" +#include "function/string/functions/find_function.h" + +namespace lbug { +namespace function { + +struct Contains { + static inline void operation(common::ku_string_t& left, common::ku_string_t& right, + uint8_t& result) { + int64_t pos = 0; + Find::operation(left, right, pos); + result = (pos != 0); + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/ends_with_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/ends_with_function.h new file mode 100644 index 0000000000..98c5c763d7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/ends_with_function.h @@ -0,0 +1,29 @@ +#pragma once + +#include "common/types/ku_string.h" + +namespace lbug { +namespace function { + +struct EndsWith { + static inline void operation(common::ku_string_t& left, common::ku_string_t& right, + uint8_t& result) { + if (right.len > left.len) { + result = 0; + return; + } + auto lenDiff = left.len - right.len; + auto lData = left.getData(); + auto rData = right.getData(); + for (auto i = 0u; i < right.len; i++) { + if (rData[i] != lData[lenDiff + i]) { + result = 0; + return; + } + } + result = 1; + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/find_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/find_function.h new file mode 100644 index 0000000000..383a190432 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/find_function.h @@ -0,0 +1,41 @@ +#pragma once + +#include "common/types/ku_string.h" + +namespace lbug { +namespace function { + +// The string find algorithm is copied from duckdb. 
Source code: +// https://github.com/duckdb/duckdb/blob/master/src/function/scalar/string/contains.cpp + +struct Find { + static inline void operation(common::ku_string_t& left, common::ku_string_t& right, + int64_t& result) { + if (right.len == 0) { + result = 1; + } else if (right.len > left.len) { + result = 0; + } + result = Find::find(left.getData(), left.len, right.getData(), right.len) + 1; + } + +private: + template + static int64_t unalignedNeedleSizeFind(const uint8_t* haystack, uint32_t haystackLen, + const uint8_t* needle, uint32_t needleLen, uint32_t firstMatchCharOffset); + + template + static int64_t alignedNeedleSizeFind(const uint8_t* haystack, uint32_t haystackLen, + const uint8_t* needle, uint32_t firstMatchCharOffset); + + static int64_t genericFind(const uint8_t* haystack, uint32_t haystackLen, const uint8_t* needle, + uint32_t needLen, uint32_t firstMatchCharOffset); + + // Returns the position of the first occurrence of needle in the haystack. If haystack doesn't + // contain needle, it returns -1. + static int64_t find(const uint8_t* haystack, uint32_t haystackLen, const uint8_t* needle, + uint32_t needleLen); +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/left_operation.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/left_operation.h new file mode 100644 index 0000000000..5adb7ec6cd --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/left_operation.h @@ -0,0 +1,23 @@ +#pragma once + +#include "common/types/ku_string.h" +#include "function/list/functions/list_len_function.h" +#include "substr_function.h" + +namespace lbug { +namespace function { + +struct Left { +public: + static inline void operation(common::ku_string_t& left, int64_t& right, + common::ku_string_t& result, common::ValueVector& resultValueVector) { + int64_t leftLen = 0; + ListLen::operation(left, leftLen); + int64_t len = + (right > -1) ? 
std::min(leftLen, right) : std::max(leftLen + right, (int64_t)0); + SubStr::operation(left, 1, len, result, resultValueVector); + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/lower_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/lower_function.h new file mode 100644 index 0000000000..5e73877176 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/lower_function.h @@ -0,0 +1,18 @@ +#pragma once + +#include "base_lower_upper_function.h" +#include "common/types/ku_string.h" + +namespace lbug { +namespace function { + +struct Lower { +public: + static inline void operation(common::ku_string_t& input, common::ku_string_t& result, + common::ValueVector& resultValueVector) { + BaseLowerUpperFunction::operation(input, result, resultValueVector, false /* isUpper */); + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/lpad_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/lpad_function.h new file mode 100644 index 0000000000..f2325c204e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/lpad_function.h @@ -0,0 +1,29 @@ +#pragma once + +#include "base_pad_function.h" +#include "common/types/ku_string.h" + +namespace lbug { +namespace function { + +struct Lpad : BasePadOperation { +public: + static inline void operation(common::ku_string_t& src, int64_t count, + common::ku_string_t& characterToPad, common::ku_string_t& result, + common::ValueVector& resultValueVector) { + BasePadOperation::operation(src, count, characterToPad, result, resultValueVector, + lpadOperation); + } + + static void lpadOperation(common::ku_string_t& src, int64_t count, + common::ku_string_t& characterToPad, std::string& paddedResult) { + auto srcPadInfo = + BasePadOperation::padCountChars(count, (const char*)src.getData(), src.len); + auto srcData = (const char*)src.getData(); + BasePadOperation::insertPadding(count - srcPadInfo.second, characterToPad, paddedResult); + paddedResult.insert(paddedResult.end(), srcData, srcData + srcPadInfo.first); + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/ltrim_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/ltrim_function.h new file mode 100644 index 0000000000..92ca1ba870 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/ltrim_function.h @@ -0,0 +1,30 @@ +#pragma once + +#include "base_str_function.h" +#include "common/types/ku_string.h" + +namespace lbug { +namespace function { + +struct Ltrim { + static inline void operation(common::ku_string_t& input, common::ku_string_t& result, + common::ValueVector& resultValueVector) { + BaseStrOperation::operation(input, result, resultValueVector, ltrim); + } + + static uint32_t ltrim(char* data, uint32_t len) { + auto counter = 0u; + for (; counter < len; counter++) { + if (!isspace(data[counter])) { + break; + } + } + for (uint32_t i = 0; i < len - counter; i++) { + data[i] = data[i + counter]; + } + return len - counter; + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/pad_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/pad_function.h new 
file mode 100644 index 0000000000..2af4c9844a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/pad_function.h @@ -0,0 +1,29 @@ +#pragma once + +#include "common/assert.h" +#include "common/types/ku_string.h" +#include "common/vector/value_vector.h" + +namespace lbug { +namespace function { + +struct PadOperation { +public: + static inline void operation(common::ku_string_t& src, int64_t count, + common::ku_string_t& characterToPad, common::ku_string_t& result, + common::ValueVector& resultValueVector, + void (*padOperation)(common::ku_string_t& result, common::ku_string_t& src, + common::ku_string_t& characterToPad)) { + if (count <= 0) { + result.set("", 0); + return; + } + KU_ASSERT(characterToPad.len == 1); + padOperation(result, src, characterToPad); + common::StringVector::addString(&resultValueVector, result, (const char*)result.getData(), + count); + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/regexp_extract_all_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/regexp_extract_all_function.h new file mode 100644 index 0000000000..92a3a3255d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/regexp_extract_all_function.h @@ -0,0 +1,76 @@ +#pragma once + +#include "base_regexp_function.h" +#include "common/exception/runtime.h" +#include "common/vector/value_vector.h" +#include "re2.h" + +namespace lbug { +namespace function { + +struct RegexpExtractAll : BaseRegexpOperation { + static inline void operation(common::ku_string_t& value, common::ku_string_t& pattern, + std::int64_t& group, common::list_entry_t& result, common::ValueVector& resultVector) { + std::vector matches = + regexExtractAll(value.getAsString(), pattern.getAsString(), group); + result = common::ListVector::addList(&resultVector, matches.size()); + auto resultValues = common::ListVector::getListValues(&resultVector, result); + auto resultDataVector = common::ListVector::getDataVector(&resultVector); + auto numBytesPerValue = resultDataVector->getNumBytesPerValue(); + for (const auto& match : matches) { + common::ku_string_t kuString; + copyToLbugString(match, kuString, *resultDataVector); + resultDataVector->copyFromVectorData(resultValues, resultDataVector, + reinterpret_cast(&kuString)); + resultValues += numBytesPerValue; + } + } + + static inline void operation(common::ku_string_t& value, common::ku_string_t& pattern, + common::list_entry_t& result, common::ValueVector& resultVector) { + int64_t defaultGroup = 0; + operation(value, pattern, defaultGroup, result, resultVector); + } + + static std::vector regexExtractAll(const std::string& value, + const std::string& pattern, std::int64_t& group) { + RE2 regex(parseCypherPattern(pattern)); + auto submatchCount = regex.NumberOfCapturingGroups() + 1; + if (group >= submatchCount) { + throw common::RuntimeException("Regex match group index is out of range"); + } + + regex::StringPiece input(value); + std::vector targetSubMatches; + targetSubMatches.resize(submatchCount); + uint64_t startPos = 0; + + std::vector matches; + while (regex.Match(input, startPos, input.length(), RE2::UNANCHORED, + targetSubMatches.data(), submatchCount)) { + uint64_t consumed = + static_cast(targetSubMatches[0].end() - (input.begin() + startPos)); + if (!consumed) { + // Empty match found, increment the position manually + consumed++; + while (startPos + consumed < input.length() && + 
!IsCharacter(input[startPos + consumed])) { + consumed++; + } + } + startPos += consumed; + matches.emplace_back(targetSubMatches[group]); + } + + return matches; + } + + static inline bool IsCharacter(char c) { + // Check if this character is not the middle of utf-8 character i.e. it shouldn't begin with + // 10 XX XX XX + return (c & 0xc0) != 0x80; + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/regexp_extract_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/regexp_extract_function.h new file mode 100644 index 0000000000..feaabe3fec --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/regexp_extract_function.h @@ -0,0 +1,46 @@ +#pragma once + +#include "common/exception/runtime.h" +#include "common/types/ku_string.h" +#include "common/vector/value_vector.h" +#include "function/string/functions/base_regexp_function.h" +#include "re2.h" + +namespace lbug { +namespace function { + +struct RegexpExtract : BaseRegexpOperation { + static inline void operation(common::ku_string_t& value, common::ku_string_t& pattern, + std::int64_t& group, common::ku_string_t& result, common::ValueVector& resultValueVector) { + regexExtract(value.getAsString(), pattern.getAsString(), group, result, resultValueVector); + } + + static inline void operation(common::ku_string_t& value, common::ku_string_t& pattern, + common::ku_string_t& result, common::ValueVector& resultValueVector) { + int64_t defaultGroup = 0; + regexExtract(value.getAsString(), pattern.getAsString(), defaultGroup, result, + resultValueVector); + } + + static void regexExtract(const std::string& input, const std::string& pattern, + std::int64_t& group, common::ku_string_t& result, common::ValueVector& resultValueVector) { + RE2 regex(parseCypherPattern(pattern)); + auto submatchCount = regex.NumberOfCapturingGroups() + 1; + if (group >= submatchCount) { + throw common::RuntimeException("Regex match group index is out of range"); + } + + std::vector targetSubMatches; + targetSubMatches.resize(submatchCount); + + if (!regex.Match(regex::StringPiece(input), 0, input.length(), RE2::UNANCHORED, + targetSubMatches.data(), submatchCount)) { + return; + } + + copyToLbugString(targetSubMatches[group].ToString(), result, resultValueVector); + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/regexp_matches_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/regexp_matches_function.h new file mode 100644 index 0000000000..5bc3a97fea --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/regexp_matches_function.h @@ -0,0 +1,18 @@ +#pragma once + +#include "common/types/ku_string.h" +#include "function/string/functions/base_regexp_function.h" +#include "re2.h" + +namespace lbug { +namespace function { + +struct RegexpMatches : BaseRegexpOperation { + static inline void operation(common::ku_string_t& left, common::ku_string_t& right, + uint8_t& result) { + result = RE2::PartialMatch(left.getAsString(), parseCypherPattern(right.getAsString())); + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/regexp_split_to_array_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/regexp_split_to_array_function.h new file mode 100644 index 
0000000000..1a9f96948d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/regexp_split_to_array_function.h @@ -0,0 +1,65 @@ +#pragma once + +#include "base_regexp_function.h" +#include "common/vector/value_vector.h" +#include "re2.h" + +namespace lbug { +namespace function { + +struct RegexpSplitToArray : BaseRegexpOperation { + static void operation(common::ku_string_t& value, common::ku_string_t& regex, + common::list_entry_t& result, common::ValueVector& resultVector) { + std::vector matches = + regexExtractAll(value.getAsString(), regex.getAsString()); + result = common::ListVector::addList(&resultVector, matches.size()); + auto resultValues = common::ListVector::getListValues(&resultVector, result); + auto resultDataVector = common::ListVector::getDataVector(&resultVector); + auto numBytesPerValue = resultDataVector->getNumBytesPerValue(); + common::ku_string_t kuString; + for (const auto& match : matches) { + copyToLbugString(match, kuString, *resultDataVector); + resultDataVector->copyFromVectorData(resultValues, resultDataVector, + reinterpret_cast(&kuString)); + resultValues += numBytesPerValue; + } + } + + static std::vector regexExtractAll(const std::string& value, + const std::string& pattern) { + RE2 regex(parseCypherPattern(pattern)); + + regex::StringPiece input(value); + regex::StringPiece match; + uint64_t startPos = 0; + + std::vector splitParts; + while (startPos < input.length()) { + if (regex.Match(input, startPos, input.length(), RE2::UNANCHORED, &match, 1)) { + uint64_t matchStart = match.data() - input.data(); + uint64_t matchEnd = matchStart + match.size(); + + if (startPos < matchStart) { + splitParts.emplace_back(value.substr(startPos, matchStart - startPos)); + } + + startPos = matchEnd; + + if (match.size() == 0) { + // Match size is 0. + startPos++; + } + } else { + // No more regexp matches. 
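+                // Keep whatever is left after the last match as the final split
+                // part, then stop scanning.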
+ if (startPos < input.length()) { + splitParts.emplace_back(value.substr(startPos)); + } + break; + } + } + return splitParts; + } +}; + +} // namespace function +} // namespace lbug \ No newline at end of file diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/repeat_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/repeat_function.h new file mode 100644 index 0000000000..c073519fb6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/repeat_function.h @@ -0,0 +1,26 @@ +#pragma once + +#include + +#include "common/api.h" +#include "common/types/ku_string.h" +#include "common/vector/value_vector.h" + +namespace lbug { +namespace function { + +struct Repeat { +public: + LBUG_API static void operation(common::ku_string_t& left, int64_t& right, + common::ku_string_t& result, common::ValueVector& resultValueVector); + +private: + static void repeatStr(char* data, const std::string& pattern, uint64_t count) { + for (auto i = 0u; i < count; i++) { + memcpy(data + i * pattern.length(), pattern.c_str(), pattern.length()); + } + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/reverse_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/reverse_function.h new file mode 100644 index 0000000000..9ecff898fb --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/reverse_function.h @@ -0,0 +1,25 @@ +#pragma once + +#include "common/api.h" +#include "common/types/ku_string.h" +#include "common/vector/value_vector.h" + +namespace lbug { +namespace function { + +struct Reverse { +public: + LBUG_API static void operation(common::ku_string_t& input, common::ku_string_t& result, + common::ValueVector& resultValueVector); + +private: + static uint32_t reverseStr(char* data, uint32_t len) { + for (auto i = 0u; i < len / 2; i++) { + std::swap(data[i], data[len - i - 1]); + } + return len; + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/right_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/right_function.h new file mode 100644 index 0000000000..36ea35cae5 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/right_function.h @@ -0,0 +1,23 @@ +#pragma once + +#include "common/types/ku_string.h" +#include "function/list/functions/list_len_function.h" +#include "substr_function.h" + +namespace lbug { +namespace function { + +struct Right { +public: + static inline void operation(common::ku_string_t& left, int64_t& right, + common::ku_string_t& result, common::ValueVector& resultValueVector) { + int64_t leftLen = 0; + ListLen::operation(left, leftLen); + int64_t len = + (right > -1) ? 
std::min(leftLen, right) : std::max(leftLen + right, (int64_t)0); + SubStr::operation(left, leftLen - len + 1, len, result, resultValueVector); + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/rpad_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/rpad_function.h new file mode 100644 index 0000000000..9eaa14740c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/rpad_function.h @@ -0,0 +1,29 @@ +#pragma once + +#include "base_pad_function.h" +#include "common/types/ku_string.h" + +namespace lbug { +namespace function { + +struct Rpad : BasePadOperation { +public: + static inline void operation(common::ku_string_t& src, int64_t count, + common::ku_string_t& characterToPad, common::ku_string_t& result, + common::ValueVector& resultValueVector) { + BasePadOperation::operation(src, count, characterToPad, result, resultValueVector, + rpadOperation); + } + + static void rpadOperation(common::ku_string_t& src, int64_t count, + common::ku_string_t& characterToPad, std::string& paddedResult) { + auto srcPadInfo = + BasePadOperation::padCountChars(count, (const char*)src.getData(), src.len); + auto srcData = (const char*)src.getData(); + paddedResult.insert(paddedResult.end(), srcData, srcData + srcPadInfo.first); + BasePadOperation::insertPadding(count - srcPadInfo.second, characterToPad, paddedResult); + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/rtrim_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/rtrim_function.h new file mode 100644 index 0000000000..a05454a30a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/rtrim_function.h @@ -0,0 +1,27 @@ +#pragma once + +#include "base_str_function.h" +#include "common/types/ku_string.h" + +namespace lbug { +namespace function { + +struct Rtrim { + static inline void operation(common::ku_string_t& input, common::ku_string_t& result, + common::ValueVector& resultValueVector) { + BaseStrOperation::operation(input, result, resultValueVector, rtrim); + } + + static uint32_t rtrim(char* data, uint32_t len) { + int32_t counter = len - 1; + for (; counter >= 0; counter--) { + if (!isspace(data[counter])) { + break; + } + } + return counter + 1; + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/starts_with_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/starts_with_function.h new file mode 100644 index 0000000000..c25ec3c142 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/starts_with_function.h @@ -0,0 +1,18 @@ +#pragma once + +#include "common/types/ku_string.h" + +namespace lbug { +namespace function { + +struct StartsWith { + static inline void operation(common::ku_string_t& left, common::ku_string_t& right, + uint8_t& result) { + auto lStr = left.getAsString(); + auto rStr = right.getAsString(); + result = lStr.starts_with(rStr); + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/substr_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/substr_function.h new file mode 100644 index 0000000000..861f3fcc72 --- /dev/null +++ 
b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/substr_function.h @@ -0,0 +1,73 @@ +#pragma once + +#include + +#include "common/types/ku_string.h" +#include "common/vector/value_vector.h" +#include "utf8proc.h" + +namespace lbug { +namespace function { + +struct SubStr { +public: + static inline void operation(common::ku_string_t& src, int64_t start, int64_t len, + common::ku_string_t& result, common::ValueVector& resultValueVector) { + std::string srcStr = src.getAsString(); + bool isAscii = true; + int64_t startPos = start - 1; + int64_t endPos = std::min(srcStr.size(), (size_t)(startPos + len)); + if (startPos >= endPos || startPos < 0 || startPos >= (int64_t)srcStr.size()) { + result.len = 0; + return; + } + // 1 character more than length has to be scanned for diatrics case: y + ˘ = ў. + for (auto i = 0u; i < std::min(srcStr.size(), endPos + 1); i++) { + // UTF-8 character encountered. + if (srcStr[i] & 0x80) { + isAscii = false; + break; + } + } + if (isAscii) { + copySubstr(src, start, len, result, resultValueVector, true /* isAscii */); + } else { + int64_t characterCount = 0, startBytePos = 0, endBytePos = 0; + lbug::utf8proc::utf8proc_grapheme_callback(srcStr.c_str(), srcStr.size(), + [&](int64_t gstart, int64_t /*gend*/) { + if (characterCount == startPos) { + startBytePos = gstart; + } else if (characterCount == endPos) { + endBytePos = gstart; + return false; + } + characterCount++; + return true; + }); + if (endBytePos == 0 && len != 0) { + endBytePos = src.len; + } + // In this case, the function gets the EXACT byte location to start copying from. + copySubstr(src, startBytePos, endBytePos - startBytePos, result, resultValueVector, + false /* isAscii */); + } + } + + static inline void copySubstr(common::ku_string_t& src, int64_t start, int64_t len, + common::ku_string_t& result, common::ValueVector& resultValueVector, bool isAscii) { + auto length = std::min(len, src.len - start + 1); + if (isAscii) { + // For normal ASCII char case, we get to the proper byte position to copy from by doing + // a -1 (since it is guaranteed each char is 1 byte). + common::StringVector::addString(&resultValueVector, result, + (const char*)(src.getData() + start - 1), length); + } else { + // For utf8 char copy, the function gets the exact starting byte position to copy from. 
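+ // Example (illustrative): for SUBSTR('añb', 2, 1) the grapheme scan in
+ // operation() produces startBytePos = 1 and endBytePos = 3, so this branch
+ // copies the two bytes of 'ñ' starting at byte offset 1; no -1 adjustment is
+ // applied here, unlike the ASCII branch above, which converts a 1-based
+ // character position into a byte offset.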
+ common::StringVector::addString(&resultValueVector, result, + (const char*)(src.getData() + start), length); + } + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/trim_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/trim_function.h new file mode 100644 index 0000000000..52a9ed4def --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/trim_function.h @@ -0,0 +1,24 @@ +#pragma once + +#include "common/types/ku_string.h" +#include "ltrim_function.h" +#include "rtrim_function.h" + +namespace lbug { +namespace function { + +struct Trim : BaseStrOperation { +public: + static inline void operation(common::ku_string_t& input, common::ku_string_t& result, + common::ValueVector& resultValueVector) { + BaseStrOperation::operation(input, result, resultValueVector, trim); + } + +private: + static uint32_t trim(char* data, uint32_t len) { + return Rtrim::rtrim(data, Ltrim::ltrim(data, len)); + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/upper_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/upper_function.h new file mode 100644 index 0000000000..cd315f38e2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/functions/upper_function.h @@ -0,0 +1,19 @@ +#pragma once + +#include "common/types/ku_string.h" +#include "common/vector/value_vector.h" +#include "function/string/functions/base_lower_upper_function.h" + +namespace lbug { +namespace function { + +struct Upper { +public: + static inline void operation(common::ku_string_t& input, common::ku_string_t& result, + common::ValueVector& resultValueVector) { + BaseLowerUpperFunction::operation(input, result, resultValueVector, true /* isUpper */); + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/vector_string_functions.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/vector_string_functions.h new file mode 100644 index 0000000000..0b756d67dc --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/string/vector_string_functions.h @@ -0,0 +1,255 @@ +#pragma once + +#include "function/scalar_function.h" +#include "function/string/functions/lower_function.h" +#include "function/string/functions/ltrim_function.h" +#include "function/string/functions/reverse_function.h" +#include "function/string/functions/rtrim_function.h" +#include "function/string/functions/trim_function.h" +#include "function/string/functions/upper_function.h" + +namespace lbug { +namespace function { + +struct VectorStringFunction { + template + static inline function_set getUnaryStrFunction(std::string funcName) { + function_set functionSet; + functionSet.emplace_back(std::make_unique(funcName, + std::vector{common::LogicalTypeID::STRING}, + common::LogicalTypeID::STRING, + ScalarFunction::UnaryStringExecFunction)); + return functionSet; + } +}; + +struct ArrayExtractFunction { + static constexpr const char* name = "ARRAY_EXTRACT"; + + static function_set getFunctionSet(); +}; + +struct ConcatFunction : public VectorStringFunction { + static constexpr const char* name = "CONCAT"; + + static void execFunc(const std::vector>& parameters, + const std::vector& parameterSelVectors, + common::ValueVector& result, common::SelectionVector* resultSelVector, void* /*dataPtr*/); + 
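+ // Note: as with the other exec functions declared in this header, execFunc is
+ // given one ValueVector (plus its SelectionVector) per argument and writes the
+ // concatenated result at each position selected by resultSelVector.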
+ static function_set getFunctionSet(); +}; + +struct ContainsFunction : public VectorStringFunction { + static constexpr const char* name = "CONTAINS"; + + static function_set getFunctionSet(); +}; + +struct EndsWithFunction : public VectorStringFunction { + static constexpr const char* name = "ENDS_WITH"; + + static function_set getFunctionSet(); +}; + +struct SuffixFunction { + using alias = EndsWithFunction; + + static constexpr const char* name = "SUFFIX"; +}; + +struct LeftFunction : public VectorStringFunction { + static constexpr const char* name = "LEFT"; + + static function_set getFunctionSet(); +}; + +struct LowerFunction : public VectorStringFunction { + static constexpr const char* name = "LOWER"; + + static function_set getFunctionSet() { return getUnaryStrFunction(name); } +}; + +struct ToLowerFunction : public VectorStringFunction { + using alias = LowerFunction; + + static constexpr const char* name = "TOLOWER"; +}; + +struct LcaseFunction { + using alias = LowerFunction; + + static constexpr const char* name = "LCASE"; +}; + +struct LpadFunction : public VectorStringFunction { + static constexpr const char* name = "LPAD"; + + static function_set getFunctionSet(); +}; + +struct LtrimFunction : public VectorStringFunction { + static constexpr const char* name = "LTRIM"; + + static inline function_set getFunctionSet() { return getUnaryStrFunction(name); } +}; + +struct RepeatFunction : public VectorStringFunction { + static constexpr const char* name = "REPEAT"; + + static function_set getFunctionSet(); +}; + +struct ReverseFunction : public VectorStringFunction { + static constexpr const char* name = "REVERSE"; + + static inline function_set getFunctionSet() { return getUnaryStrFunction(name); } +}; + +struct RightFunction : public VectorStringFunction { + static constexpr const char* name = "RIGHT"; + + static function_set getFunctionSet(); +}; + +struct RpadFunction : public VectorStringFunction { + static constexpr const char* name = "RPAD"; + + static function_set getFunctionSet(); +}; + +struct RtrimFunction : public VectorStringFunction { + static constexpr const char* name = "RTRIM"; + + static inline function_set getFunctionSet() { return getUnaryStrFunction(name); } +}; + +struct StartsWithFunction : public VectorStringFunction { + static constexpr const char* name = "STARTS_WITH"; + + static function_set getFunctionSet(); +}; + +struct PrefixFunction { + using alias = StartsWithFunction; + + static constexpr const char* name = "PREFIX"; +}; + +struct SubStrFunction : public VectorStringFunction { + static constexpr const char* name = "SUBSTR"; + + static function_set getFunctionSet(); +}; + +struct SubstringFunction { + using alias = SubStrFunction; + + static constexpr const char* name = "SUBSTRING"; +}; + +struct TrimFunction : public VectorStringFunction { + static constexpr const char* name = "TRIM"; + + static function_set getFunctionSet() { return getUnaryStrFunction(name); } +}; + +struct UpperFunction : public VectorStringFunction { + static constexpr const char* name = "UPPER"; + + static function_set getFunctionSet() { return getUnaryStrFunction(name); } +}; + +struct ToUpperFunction : public VectorStringFunction { + using alias = UpperFunction; + + static constexpr const char* name = "TOUPPER"; +}; + +struct UCaseFunction { + using alias = UpperFunction; + + static constexpr const char* name = "UCASE"; +}; + +struct RegexpFullMatchFunction : public VectorStringFunction { + static constexpr const char* name = "REGEXP_FULL_MATCH"; + + static function_set 
getFunctionSet(); +}; + +struct RegexpMatchesFunction : public VectorStringFunction { + static constexpr const char* name = "REGEXP_MATCHES"; + + static function_set getFunctionSet(); +}; + +struct RegexpReplaceFunction : public VectorStringFunction { + static constexpr const char* name = "REGEXP_REPLACE"; + static constexpr const char* GLOBAL_REPLACE_OPTION = "g"; + + static function_set getFunctionSet(); +}; + +struct RegexpExtractFunction : public VectorStringFunction { + static constexpr const char* name = "REGEXP_EXTRACT"; + + static function_set getFunctionSet(); +}; + +struct RegexpExtractAllFunction : public VectorStringFunction { + static constexpr const char* name = "REGEXP_EXTRACT_ALL"; + + static function_set getFunctionSet(); +}; + +struct RegexpSplitToArrayFunction : public VectorStringFunction { + static constexpr const char* name = "REGEXP_SPLIT_TO_ARRAY"; + + static function_set getFunctionSet(); +}; + +struct LevenshteinFunction : public VectorStringFunction { + static constexpr const char* name = "LEVENSHTEIN"; + + static function_set getFunctionSet(); +}; + +struct InitCapFunction : public VectorStringFunction { + static constexpr const char* name = "INITCAP"; + + static function_set getFunctionSet(); +}; + +struct StringSplitFunction { + static constexpr const char* name = "STRING_SPLIT"; + + static function_set getFunctionSet(); +}; + +struct StrSplitFunction { + using alias = StringSplitFunction; + + static constexpr const char* name = "STR_SPLIT"; +}; + +struct StringToArrayFunction { + using alias = StringSplitFunction; + + static constexpr const char* name = "STRING_TO_ARRAY"; +}; + +struct SplitPartFunction { + static constexpr const char* name = "SPLIT_PART"; + + static function_set getFunctionSet(); +}; + +struct ConcatWSFunction { + static constexpr const char* name = "CONCAT_WS"; + + static function_set getFunctionSet(); +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/struct/vector_struct_functions.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/struct/vector_struct_functions.h new file mode 100644 index 0000000000..bf148fec86 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/struct/vector_struct_functions.h @@ -0,0 +1,59 @@ +#pragma once + +#include "common/vector/value_vector.h" +#include "function/function.h" + +namespace lbug { +namespace function { + +struct StructPackFunctions { + static constexpr const char* name = "STRUCT_PACK"; + + static function_set getFunctionSet(); + + static void execFunc(const std::vector>& parameters, + const std::vector& parameterSelVectors, + common::ValueVector& result, common::SelectionVector* resultSelVector, + void* /*dataPtr*/ = nullptr); + static void undirectedRelPackExecFunc( + const std::vector>& parameters, + common::ValueVector& result, void* /*dataPtr*/ = nullptr); + static void compileFunc(FunctionBindData* bindData, + const std::vector>& parameters, + std::shared_ptr& result); + static void undirectedRelCompileFunc(FunctionBindData* bindData, + const std::vector>& parameters, + std::shared_ptr& result); +}; + +struct StructExtractBindData : public FunctionBindData { + common::idx_t childIdx; + + StructExtractBindData(common::LogicalType dataType, common::idx_t childIdx) + : FunctionBindData{std::move(dataType)}, childIdx{childIdx} {} + + std::unique_ptr copy() const override { + return std::make_unique(resultType.copy(), childIdx); + } +}; + +struct StructExtractFunctions { + static constexpr const char* name = 
"STRUCT_EXTRACT"; + + static function_set getFunctionSet(); + + static std::unique_ptr bindFunc(const ScalarBindFuncInput& input); + + static void compileFunc(FunctionBindData* bindData, + const std::vector>& parameters, + std::shared_ptr& result); +}; + +struct KeysFunctions { + static constexpr const char* name = "KEYS"; + + static function_set getFunctionSet(); +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/table/bind_data.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/table/bind_data.h new file mode 100644 index 0000000000..d6625c84c8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/table/bind_data.h @@ -0,0 +1,69 @@ +#pragma once + +#include "common/types/types.h" +#include "optional_params.h" +#include "storage/predicate/column_predicate.h" + +namespace lbug { +namespace common { +class FileSystem; +} + +namespace function { + +struct LBUG_API TableFuncBindData { + binder::expression_vector columns; + common::row_idx_t numRows; + std::unique_ptr optionalParams = nullptr; + + TableFuncBindData() : numRows{0} {} + explicit TableFuncBindData(common::row_idx_t numRows) : numRows{numRows} {} + explicit TableFuncBindData(binder::expression_vector columns) + : columns{std::move(columns)}, numRows{0} {} + TableFuncBindData(binder::expression_vector columns, common::row_idx_t numRows) + : columns{std::move(columns)}, numRows{numRows} {} + TableFuncBindData(const TableFuncBindData& other) + : columns{other.columns}, numRows{other.numRows}, + optionalParams{other.optionalParams == nullptr ? nullptr : other.optionalParams->copy()}, + columnSkips{other.columnSkips}, columnPredicates{copyVector(other.columnPredicates)} {} + TableFuncBindData& operator=(const TableFuncBindData& other) = delete; + virtual ~TableFuncBindData() = default; + + void evaluateParams(main::ClientContext* context) const { + if (!optionalParams) { + return; + } + optionalParams->evaluateParams(context); + } + common::idx_t getNumColumns() const { return columns.size(); } + void setColumnSkips(std::vector skips) { columnSkips = std::move(skips); } + std::vector getColumnSkips() const; + + void setColumnPredicates(std::vector predicates) { + columnPredicates = std::move(predicates); + } + const std::vector& getColumnPredicates() const { + return columnPredicates; + } + + virtual bool getIgnoreErrorsOption() const; + + virtual std::unique_ptr copy() const; + + template + const TARGET* constPtrCast() const { + return common::ku_dynamic_cast(this); + } + + template + TARGET& cast() { + return *common::ku_dynamic_cast(this); + } + +protected: + std::vector columnSkips; + std::vector columnPredicates; +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/table/bind_input.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/table/bind_input.h new file mode 100644 index 0000000000..b64445f0c4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/table/bind_input.h @@ -0,0 +1,65 @@ +#pragma once + +#include + +#include "binder/expression/expression.h" +#include "common/case_insensitive_map.h" +#include "common/copier_config/file_scan_info.h" +#include "common/types/value/value.h" +#include "parser/query/reading_clause/yield_variable.h" + +namespace lbug { +namespace binder { +class LiteralExpression; +class Binder; +} // namespace binder +namespace main { +class ClientContext; +} + +namespace common { +class Value; +} + +namespace function { + 
+using optional_params_t = common::case_insensitive_map_t; + +struct TableFunction; + +struct ExtraTableFuncBindInput { + virtual ~ExtraTableFuncBindInput() = default; + + template + const TARGET* constPtrCast() const { + return common::ku_dynamic_cast(this); + } +}; + +struct LBUG_API TableFuncBindInput { + binder::expression_vector params; + optional_params_t optionalParams; + binder::expression_vector optionalParamsLegacy; + std::unique_ptr extraInput = nullptr; + binder::Binder* binder = nullptr; + std::vector yieldVariables; + + TableFuncBindInput() = default; + + void addLiteralParam(common::Value value); + + std::shared_ptr getParam(common::idx_t idx) const { return params[idx]; } + common::Value getValue(common::idx_t idx) const; + template + T getLiteralVal(common::idx_t idx) const; +}; + +struct LBUG_API ExtraScanTableFuncBindInput : ExtraTableFuncBindInput { + common::FileScanInfo fileScanInfo; + std::vector expectedColumnNames; + std::vector expectedColumnTypes; + TableFunction* tableFunction = nullptr; +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/table/optional_params.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/table/optional_params.h new file mode 100644 index 0000000000..03c12920fd --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/table/optional_params.h @@ -0,0 +1,76 @@ +#pragma once + +#include "binder/expression/expression_util.h" + +namespace lbug { +namespace function { + +template +struct LogicalTypeMapping; + +template<> +struct LogicalTypeMapping { + using type = double; +}; +template<> +struct LogicalTypeMapping { + using type = bool; +}; +template<> +struct LogicalTypeMapping { + using type = uint64_t; +}; +template<> +struct LogicalTypeMapping { + using type = int64_t; +}; + +template<> +struct LogicalTypeMapping { + using type = std::string; +}; + +template +struct OptionalParam { + using T = typename LogicalTypeMapping::type; + std::shared_ptr param = nullptr; + T paramVal = PARAM::DEFAULT_VALUE; + + OptionalParam() {} + + explicit OptionalParam(std::shared_ptr param) : param{std::move(param)} {} + + void evaluateParam(main::ClientContext* context) { + if (!param) { + paramVal = PARAM::DEFAULT_VALUE; + return; + } + if constexpr (requires { PARAM::validate; }) { + paramVal = binder::ExpressionUtil::evaluateLiteral(context, param, + common::LogicalType{PARAM::TYPE}, PARAM::validate); + } else { + paramVal = binder::ExpressionUtil::evaluateLiteral(context, param, + common::LogicalType{PARAM::TYPE}, nullptr /* validateFunc */); + } + } + + bool isSet() const { return param != nullptr; } + + T getParamVal() const { return paramVal; } +}; + +struct OptionalParams { + virtual ~OptionalParams() = default; + + template + const TARGET& constCast() const { + return common::ku_dynamic_cast(*this); + } + + virtual void evaluateParams(main::ClientContext* /*context*/) = 0; + + virtual std::unique_ptr copy() = 0; +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/table/scan_file_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/table/scan_file_function.h new file mode 100644 index 0000000000..2d158987f5 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/table/scan_file_function.h @@ -0,0 +1,73 @@ +#pragma once + +#include +#include + +#include "common/constants.h" +#include "common/copier_config/file_scan_info.h" +#include "function/table/bind_data.h" +#include 
"function/table/table_function.h" + +namespace lbug { +namespace common { +class FileSystem; +} + +namespace function { + +struct ScanFileSharedState : public TableFuncSharedState { + const common::FileScanInfo fileScanInfo; + std::atomic fileIdx; + std::atomic blockIdx; + + ScanFileSharedState(common::FileScanInfo fileScanInfo, uint64_t numRows) + : TableFuncSharedState{numRows}, fileScanInfo{std::move(fileScanInfo)}, fileIdx{0}, + blockIdx{0} {} + + std::pair getNext() { + std::lock_guard guard{mtx}; + return fileIdx >= fileScanInfo.getNumFiles() ? std::make_pair(UINT64_MAX, UINT64_MAX) : + std::make_pair(fileIdx.load(), blockIdx++); + } +}; + +struct ScanFileWithProgressSharedState : ScanFileSharedState { + main::ClientContext* context; + uint64_t totalSize; // TODO(Mattias): I think we should unify the design on how we calculate the + // progress bar for scanning. Can we simply rely on a numRowsScaned stored + // in the TableFuncSharedState to determine the progress. + ScanFileWithProgressSharedState(common::FileScanInfo fileScanInfo, uint64_t numRows, + main::ClientContext* context) + : ScanFileSharedState{std::move(fileScanInfo), numRows}, context{context}, totalSize{0} {} +}; + +struct LBUG_API ScanFileBindData : public TableFuncBindData { + common::FileScanInfo fileScanInfo; + main::ClientContext* context; + common::column_id_t numWarningDataColumns = 0; + + ScanFileBindData(binder::expression_vector columns, uint64_t numRows, + common::FileScanInfo fileScanInfo, main::ClientContext* context) + : TableFuncBindData{std::move(columns), numRows}, fileScanInfo{std::move(fileScanInfo)}, + context{context} {} + ScanFileBindData(binder::expression_vector columns, uint64_t numRows, + common::FileScanInfo fileScanInfo, main::ClientContext* context, + common::column_id_t numWarningDataColumns) + : TableFuncBindData{std::move(columns), numRows}, fileScanInfo{std::move(fileScanInfo)}, + context{context}, numWarningDataColumns{numWarningDataColumns} {} + ScanFileBindData(const ScanFileBindData& other) + : TableFuncBindData{other}, fileScanInfo{other.fileScanInfo.copy()}, context{other.context}, + numWarningDataColumns{other.numWarningDataColumns} {} + + bool getIgnoreErrorsOption() const override { + return fileScanInfo.getOption(common::CopyConstants::IGNORE_ERRORS_OPTION_NAME, + common::CopyConstants::DEFAULT_IGNORE_ERRORS); + } + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/table/scan_replacement.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/table/scan_replacement.h new file mode 100644 index 0000000000..249adf51be --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/table/scan_replacement.h @@ -0,0 +1,28 @@ +#pragma once + +#include "function/table/bind_input.h" +#include "function/table/table_function.h" + +namespace lbug { +namespace function { + +struct ScanReplacementData { + TableFunction func; + TableFuncBindInput bindInput; +}; + +using scan_replace_handle_t = uint8_t*; +using handle_lookup_func_t = std::function(const std::string&)>; +using scan_replace_func_t = + std::function(std::span)>; + +struct ScanReplacement { + explicit ScanReplacement(handle_lookup_func_t lookupFunc, scan_replace_func_t replaceFunc) + : lookupFunc(std::move(lookupFunc)), replaceFunc{std::move(replaceFunc)} {} + + handle_lookup_func_t lookupFunc; + scan_replace_func_t replaceFunc; +}; + +} // namespace function +} // namespace 
lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/table/simple_table_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/table/simple_table_function.h new file mode 100644 index 0000000000..76a1e9a0ea --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/table/simple_table_function.h @@ -0,0 +1,188 @@ +#pragma once + +#include "common/system_config.h" +#include "function/table/table_function.h" + +namespace lbug { +namespace function { + +struct TableFuncMorsel { + common::offset_t startOffset; + common::offset_t endOffset; + + TableFuncMorsel(common::offset_t startOffset, common::offset_t endOffset) + : startOffset{startOffset}, endOffset{endOffset} {} + + bool hasMoreToOutput() const { return startOffset != common::INVALID_OFFSET; } + + static TableFuncMorsel createInvalidMorsel() { + return {common::INVALID_OFFSET, common::INVALID_OFFSET}; + } + + uint64_t getMorselSize() const { return endOffset - startOffset; } + + bool isInvalid() const { + return startOffset == common::INVALID_OFFSET && endOffset == common::INVALID_OFFSET; + } +}; + +using simple_internal_table_func = std::function; + +class LBUG_API SimpleTableFuncSharedState : public TableFuncSharedState { +public: + SimpleTableFuncSharedState() = default; + + explicit SimpleTableFuncSharedState(common::row_idx_t numRows, + common::offset_t maxMorselSize = common::DEFAULT_VECTOR_CAPACITY) + : TableFuncSharedState{numRows}, maxMorselSize{maxMorselSize} {} + + virtual TableFuncMorsel getMorsel(); + + common::row_idx_t curRowIdx = 0; + common::offset_t maxMorselSize = common::DEFAULT_VECTOR_CAPACITY; +}; + +struct LBUG_API SimpleTableFunc { + static std::unique_ptr initSharedState( + const TableFuncInitSharedStateInput& input); + + static table_func_t getTableFunc(simple_internal_table_func internalTableFunc); +}; + +struct CurrentSettingFunction final { + static constexpr const char* name = "CURRENT_SETTING"; + + static function_set getFunctionSet(); +}; + +struct CatalogVersionFunction final { + static constexpr const char* name = "CATALOG_VERSION"; + + static function_set getFunctionSet(); +}; + +struct DBVersionFunction final { + static constexpr const char* name = "DB_VERSION"; + + static function_set getFunctionSet(); +}; + +struct ShowTablesFunction final { + static constexpr const char* name = "SHOW_TABLES"; + + static function_set getFunctionSet(); +}; + +struct ShowWarningsFunction final { + static constexpr const char* name = "SHOW_WARNINGS"; + + static function_set getFunctionSet(); +}; + +struct ShowMacrosFunction final { + static constexpr const char* name = "SHOW_MACROS"; + + static function_set getFunctionSet(); +}; + +struct TableInfoFunction final { + static constexpr const char* name = "TABLE_INFO"; + + static function_set getFunctionSet(); +}; + +struct ShowSequencesFunction final { + static constexpr const char* name = "SHOW_SEQUENCES"; + + static function_set getFunctionSet(); +}; + +struct ShowConnectionFunction final { + static constexpr const char* name = "SHOW_CONNECTION"; + + static function_set getFunctionSet(); +}; + +struct StorageInfoFunction final { + static constexpr const char* name = "STORAGE_INFO"; + + static function_set getFunctionSet(); +}; + +struct StatsInfoFunction final { + static constexpr const char* name = "STATS_INFO"; + + static function_set getFunctionSet(); +}; + +struct FreeSpaceInfoFunction final { + static constexpr const char* name = "FSM_INFO"; + + static function_set getFunctionSet(); +}; + +struct BMInfoFunction final { + 
static constexpr const char* name = "BM_INFO"; + + static function_set getFunctionSet(); +}; + +struct FileInfoFunction final { + static constexpr const char* name = "FILE_INFO"; + + static function_set getFunctionSet(); +}; + +struct ShowAttachedDatabasesFunction final { + static constexpr const char* name = "SHOW_ATTACHED_DATABASES"; + + static function_set getFunctionSet(); +}; + +struct ShowFunctionsFunction final { + static constexpr const char* name = "SHOW_FUNCTIONS"; + + static function_set getFunctionSet(); +}; + +struct ShowLoadedExtensionsFunction final { + static constexpr const char* name = "SHOW_LOADED_EXTENSIONS"; + + static function_set getFunctionSet(); +}; + +struct ShowOfficialExtensionsFunction final { + static constexpr const char* name = "SHOW_OFFICIAL_EXTENSIONS"; + + static function_set getFunctionSet(); +}; + +struct ShowIndexesFunction final { + static constexpr const char* name = "SHOW_INDEXES"; + + static function_set getFunctionSet(); +}; + +struct ShowProjectedGraphsFunction final { + static constexpr const char* name = "SHOW_PROJECTED_GRAPHS"; + + static function_set getFunctionSet(); +}; + +struct ProjectedGraphInfoFunction final { + static constexpr const char* name = "PROJECTED_GRAPH_INFO"; + + static function_set getFunctionSet(); +}; + +// Cache a table column to the transaction local cache. +// Note this is only used for internal purpose, and only supports node tables for now. +struct LocalCacheArrayColumnFunction final { + static constexpr const char* name = "_CACHE_ARRAY_COLUMN_LOCALLY"; + + static function_set getFunctionSet(); +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/table/standalone_call_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/table/standalone_call_function.h new file mode 100644 index 0000000000..4ee96a7cc3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/table/standalone_call_function.h @@ -0,0 +1,33 @@ +#pragma once + +#include "function/function.h" + +namespace lbug { +namespace function { + +struct ClearWarningsFunction { + static constexpr const char* name = "CLEAR_WARNINGS"; + + static function_set getFunctionSet(); +}; + +struct ProjectGraphNativeFunction { + static constexpr const char* name = "PROJECT_GRAPH"; + + static function_set getFunctionSet(); +}; + +struct ProjectGraphCypherFunction { + static constexpr const char* name = "PROJECT_GRAPH_CYPHER"; + + static function_set getFunctionSet(); +}; + +struct DropProjectedGraphFunction { + static constexpr const char* name = "DROP_PROJECTED_GRAPH"; + + static function_set getFunctionSet(); +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/table/table_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/table/table_function.h new file mode 100644 index 0000000000..6eb6103176 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/table/table_function.h @@ -0,0 +1,201 @@ +#pragma once + +#include + +#include "common/data_chunk/data_chunk.h" +#include "common/mask.h" +#include "function/function.h" +#include "processor/operator/physical_operator.h" + +namespace lbug { +namespace binder { +class BoundReadingClause; +} +namespace parser { +struct YieldVariable; +class ParsedExpression; +} // namespace parser + +namespace planner { +class LogicalOperator; +class LogicalPlan; +class Planner; +} // namespace planner + +namespace processor { +struct ExecutionContext; +class 
PlanMapper; +} // namespace processor + +namespace function { + +struct TableFuncBindInput; +struct TableFuncBindData; + +// Shared state +struct LBUG_API TableFuncSharedState { + common::row_idx_t numRows = 0; + // This for now is only used for QueryHNSWIndex. + // TODO(Guodong): This is not a good way to pass semiMasks to QueryHNSWIndex function. + // However, to avoid function specific logic when we handle semi mask in mapper, so we can move + // HNSW into an extension, we have to let semiMasks be owned by a base class. + common::NodeOffsetMaskMap semiMasks; + std::mutex mtx; + + explicit TableFuncSharedState() = default; + explicit TableFuncSharedState(common::row_idx_t numRows) : numRows{numRows} {} + virtual ~TableFuncSharedState() = default; + virtual uint64_t getNumRows() const { return numRows; } + + common::table_id_map_t getSemiMasks() const { return semiMasks.getMasks(); } + + template + TARGET* ptrCast() { + return common::ku_dynamic_cast(this); + } +}; + +// Local state +struct TableFuncLocalState { + virtual ~TableFuncLocalState() = default; + + template + TARGET* ptrCast() { + return common::ku_dynamic_cast(this); + } +}; + +// Execution input +struct TableFuncInput { + TableFuncBindData* bindData; + TableFuncLocalState* localState; + TableFuncSharedState* sharedState; + processor::ExecutionContext* context; + + TableFuncInput() = default; + TableFuncInput(TableFuncBindData* bindData, TableFuncLocalState* localState, + TableFuncSharedState* sharedState, processor::ExecutionContext* context) + : bindData{bindData}, localState{localState}, sharedState{sharedState}, context{context} {} + DELETE_COPY_DEFAULT_MOVE(TableFuncInput); +}; + +// Execution output. +// We might want to merge this with TableFuncLocalState. Also not all table function output vectors +// in a single dataChunk, e.g. FTableScan. In future, if we have more cases, we should consider +// make TableFuncOutput pure virtual. 
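+// Illustrative usage: a tableFunc fills the value vectors of output.dataChunk for
+// the current call and returns a common::offset_t (typically the number of tuples
+// it produced); setOutputSize() below sets the chunk's output size, and
+// emptyTableFunc further down gives a trivial implementation.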
+struct TableFuncOutput { + common::DataChunk dataChunk; + + explicit TableFuncOutput(common::DataChunk dataChunk) : dataChunk{std::move(dataChunk)} {} + virtual ~TableFuncOutput() = default; + + void resetState(); + void setOutputSize(common::offset_t size) const; +}; + +struct LBUG_API TableFuncInitSharedStateInput final { + TableFuncBindData* bindData; + processor::ExecutionContext* context; + + TableFuncInitSharedStateInput(TableFuncBindData* bindData, processor::ExecutionContext* context) + : bindData{bindData}, context{context} {} +}; + +// Init local state +struct TableFuncInitLocalStateInput { + TableFuncSharedState& sharedState; + TableFuncBindData& bindData; + main::ClientContext* clientContext; + + TableFuncInitLocalStateInput(TableFuncSharedState& sharedState, TableFuncBindData& bindData, + main::ClientContext* clientContext) + : sharedState{sharedState}, bindData{bindData}, clientContext{clientContext} {} +}; + +// Init output +struct TableFuncInitOutputInput { + std::vector outColumnPositions; + processor::ResultSet& resultSet; + + TableFuncInitOutputInput(std::vector outColumnPositions, + processor::ResultSet& resultSet) + : outColumnPositions{std::move(outColumnPositions)}, resultSet{resultSet} {} +}; + +using table_func_bind_t = std::function(main::ClientContext*, + const TableFuncBindInput*)>; +using table_func_t = + std::function; +using table_func_init_shared_t = + std::function(const TableFuncInitSharedStateInput&)>; +using table_func_init_local_t = + std::function(const TableFuncInitLocalStateInput&)>; +using table_func_init_output_t = + std::function(const TableFuncInitOutputInput&)>; +using table_func_can_parallel_t = std::function; +using table_func_progress_t = std::function; +using table_func_finalize_t = + std::function; +using table_func_rewrite_t = + std::function; +using table_func_get_logical_plan_t = + std::function>, planner::LogicalPlan&)>; +using table_func_get_physical_plan_t = std::function( + processor::PlanMapper*, const planner::LogicalOperator*)>; +using table_func_infer_input_types = + std::function(const binder::expression_vector&)>; + +struct LBUG_API TableFunction final : Function { + table_func_t tableFunc = nullptr; + table_func_bind_t bindFunc = nullptr; + table_func_init_shared_t initSharedStateFunc = nullptr; + table_func_init_local_t initLocalStateFunc = nullptr; + table_func_init_output_t initOutputFunc = nullptr; + table_func_can_parallel_t canParallelFunc = [] { return true; }; + table_func_progress_t progressFunc = [](TableFuncSharedState*) { return 0.0; }; + table_func_finalize_t finalizeFunc = [](auto, auto) {}; + table_func_rewrite_t rewriteFunc = nullptr; + table_func_get_logical_plan_t getLogicalPlanFunc = getLogicalPlan; + table_func_get_physical_plan_t getPhysicalPlanFunc = getPhysicalPlan; + table_func_infer_input_types inferInputTypes = nullptr; + + TableFunction() {} + TableFunction(std::string name, std::vector inputTypes) + : Function{std::move(name), std::move(inputTypes)} {} + ~TableFunction() override; + TableFunction(const TableFunction&) = default; + TableFunction& operator=(const TableFunction& other) = default; + DEFAULT_BOTH_MOVE(TableFunction); + + std::string signatureToString() const override { + return common::LogicalTypeUtils::toString(parameterTypeIDs); + } + + std::unique_ptr copy() const { return std::make_unique(*this); } + + // Init local state func + static std::unique_ptr initEmptyLocalState( + const TableFuncInitLocalStateInput& input); + // Init shared state func + static std::unique_ptr 
initEmptySharedState( + const TableFuncInitSharedStateInput& input); + // Init output func + static std::unique_ptr initSingleDataChunkScanOutput( + const TableFuncInitOutputInput& input); + // Utility functions + static std::vector extractYieldVariables(const std::vector& names, + const std::vector& yieldVariables); + // Get logical plan func + static void getLogicalPlan(planner::Planner* planner, + const binder::BoundReadingClause& boundReadingClause, binder::expression_vector predicates, + planner::LogicalPlan& plan); + // Get physical plan func + static std::unique_ptr getPhysicalPlan( + processor::PlanMapper* planMapper, const planner::LogicalOperator* logicalOp); + // Table func + static common::offset_t emptyTableFunc(const TableFuncInput& input, TableFuncOutput& output); +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/ternary_function_executor.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/ternary_function_executor.h new file mode 100644 index 0000000000..ff31e7badb --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/ternary_function_executor.h @@ -0,0 +1,455 @@ +#pragma once + +#include "common/vector/value_vector.h" + +namespace lbug { +namespace function { + +struct TernaryFunctionWrapper { + template + static inline void operation(A_TYPE& a, B_TYPE& b, C_TYPE& c, RESULT_TYPE& result, + void* /*aValueVector*/, void* /*resultValueVector*/, void* /*dataPtr*/) { + OP::operation(a, b, c, result); + } +}; + +struct TernaryStringFunctionWrapper { + template + static inline void operation(A_TYPE& a, B_TYPE& b, C_TYPE& c, RESULT_TYPE& result, + void* /*aValueVector*/, void* resultValueVector, void* /*dataPtr*/) { + OP::operation(a, b, c, result, *(common::ValueVector*)resultValueVector); + } +}; + +struct TernaryRegexFunctionWrapper { + template + static inline void operation(A_TYPE& a, B_TYPE& b, C_TYPE& c, RESULT_TYPE& result, + void* /*aValueVector*/, void* resultValueVector, void* dataPtr) { + OP::operation(a, b, c, result, *(common::ValueVector*)resultValueVector, dataPtr); + } +}; + +struct TernaryListFunctionWrapper { + template + static inline void operation(A_TYPE& a, B_TYPE& b, C_TYPE& c, RESULT_TYPE& result, + void* aValueVector, void* resultValueVector, void* /*dataPtr*/) { + OP::operation(a, b, c, result, *(common::ValueVector*)aValueVector, + *(common::ValueVector*)resultValueVector); + } +}; + +struct TernaryUDFFunctionWrapper { + template + static inline void operation(A_TYPE& a, B_TYPE& b, C_TYPE& c, RESULT_TYPE& result, + void* /*aValueVector*/, void* /*resultValueVector*/, void* dataPtr) { + OP::operation(a, b, c, result, dataPtr); + } +}; + +struct TernaryFunctionExecutor { + template + static void executeOnValue(common::ValueVector& a, common::ValueVector& b, + common::ValueVector& c, common::ValueVector& result, uint64_t aPos, uint64_t bPos, + uint64_t cPos, uint64_t resPos, void* dataPtr) { + auto resValues = (RESULT_TYPE*)result.getData(); + OP_WRAPPER::template operation( + ((A_TYPE*)a.getData())[aPos], ((B_TYPE*)b.getData())[bPos], + ((C_TYPE*)c.getData())[cPos], resValues[resPos], (void*)&a, (void*)&result, dataPtr); + } + + template + static void executeAllFlat(common::ValueVector& a, common::SelectionVector* aSelVector, + common::ValueVector& b, common::SelectionVector* bSelVector, common::ValueVector& c, + common::SelectionVector* cSelVector, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* dataPtr) { + auto aPos = 
(*aSelVector)[0]; + auto bPos = (*bSelVector)[0]; + auto cPos = (*cSelVector)[0]; + auto resPos = (*resultSelVector)[0]; + result.setNull(resPos, a.isNull(aPos) || b.isNull(bPos) || c.isNull(cPos)); + if (!result.isNull(resPos)) { + executeOnValue(a, b, c, result, + aPos, bPos, cPos, resPos, dataPtr); + } + } + + template + static void executeFlatFlatUnflat(common::ValueVector& a, common::SelectionVector* aSelVector, + common::ValueVector& b, common::SelectionVector* bSelVector, common::ValueVector& c, + common::SelectionVector* cSelVector, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* dataPtr) { + auto aPos = (*aSelVector)[0]; + auto bPos = (*bSelVector)[0]; + if (a.isNull(aPos) || b.isNull(bPos)) { + result.setAllNull(); + } else if (c.hasNoNullsGuarantee()) { + if (cSelVector->isUnfiltered()) { + for (auto i = 0u; i < cSelVector->getSelSize(); ++i) { + auto rPos = (*resultSelVector)[i]; + executeOnValue(a, b, c, + result, aPos, bPos, i, rPos, dataPtr); + } + } else { + for (auto i = 0u; i < cSelVector->getSelSize(); ++i) { + auto pos = (*cSelVector)[i]; + auto rPos = (*resultSelVector)[i]; + executeOnValue(a, b, c, + result, aPos, bPos, pos, rPos, dataPtr); + } + } + } else { + if (cSelVector->isUnfiltered()) { + for (auto i = 0u; i < cSelVector->getSelSize(); ++i) { + result.setNull(i, c.isNull(i)); + if (!result.isNull(i)) { + auto rPos = (*resultSelVector)[i]; + executeOnValue(a, b, + c, result, aPos, bPos, i, rPos, dataPtr); + } + } + } else { + for (auto i = 0u; i < cSelVector->getSelSize(); ++i) { + auto pos = (*cSelVector)[i]; + result.setNull(pos, c.isNull(pos)); + if (!result.isNull(pos)) { + auto rPos = (*resultSelVector)[i]; + executeOnValue(a, b, + c, result, aPos, bPos, pos, rPos, dataPtr); + } + } + } + } + } + + template + static void executeFlatUnflatUnflat(common::ValueVector& a, common::SelectionVector* aSelVector, + common::ValueVector& b, common::SelectionVector* bSelVector, common::ValueVector& c, + [[maybe_unused]] common::SelectionVector* cSelVector, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* dataPtr) { + KU_ASSERT(bSelVector == cSelVector); + auto aPos = (*aSelVector)[0]; + if (a.isNull(aPos)) { + result.setAllNull(); + } else if (b.hasNoNullsGuarantee() && c.hasNoNullsGuarantee()) { + if (bSelVector->isUnfiltered()) { + for (auto i = 0u; i < bSelVector->getSelSize(); ++i) { + executeOnValue(a, b, c, + result, aPos, i, i, i, dataPtr); + } + } else { + for (auto i = 0u; i < bSelVector->getSelSize(); ++i) { + auto pos = (*bSelVector)[i]; + auto rPos = (*resultSelVector)[i]; + executeOnValue(a, b, c, + result, aPos, pos, pos, rPos, dataPtr); + } + } + } else { + if (bSelVector->isUnfiltered()) { + for (auto i = 0u; i < bSelVector->getSelSize(); ++i) { + result.setNull(i, b.isNull(i) || c.isNull(i)); + if (!result.isNull(i)) { + auto rPos = (*resultSelVector)[i]; + executeOnValue(a, b, + c, result, aPos, i, i, rPos, dataPtr); + } + } + } else { + for (auto i = 0u; i < bSelVector->getSelSize(); ++i) { + auto pos = (*bSelVector)[i]; + result.setNull(pos, b.isNull(pos) || c.isNull(pos)); + if (!result.isNull(pos)) { + auto rPos = (*resultSelVector)[i]; + executeOnValue(a, b, + c, result, aPos, pos, pos, rPos, dataPtr); + } + } + } + } + } + + template + static void executeFlatUnflatFlat(common::ValueVector& a, common::SelectionVector* aSelVector, + common::ValueVector& b, common::SelectionVector* bSelVector, common::ValueVector& c, + common::SelectionVector* cSelVector, common::ValueVector& result, + 
common::SelectionVector* resultSelVector, void* dataPtr) { + auto aPos = (*aSelVector)[0]; + auto cPos = (*cSelVector)[0]; + if (a.isNull(aPos) || c.isNull(cPos)) { + result.setAllNull(); + } else if (b.hasNoNullsGuarantee()) { + if (bSelVector->isUnfiltered()) { + for (auto i = 0u; i < bSelVector->getSelSize(); ++i) { + auto rPos = (*resultSelVector)[i]; + executeOnValue(a, b, c, + result, aPos, i, cPos, rPos, dataPtr); + } + } else { + for (auto i = 0u; i < bSelVector->getSelSize(); ++i) { + auto pos = (*bSelVector)[i]; + auto rPos = (*resultSelVector)[i]; + executeOnValue(a, b, c, + result, aPos, pos, cPos, rPos, dataPtr); + } + } + } else { + if (bSelVector->isUnfiltered()) { + for (auto i = 0u; i < bSelVector->getSelSize(); ++i) { + result.setNull(i, b.isNull(i)); + if (!result.isNull(i)) { + auto rPos = (*resultSelVector)[i]; + executeOnValue(a, b, + c, result, aPos, i, cPos, rPos, dataPtr); + } + } + } else { + for (auto i = 0u; i < bSelVector->getSelSize(); ++i) { + auto pos = (*bSelVector)[i]; + result.setNull(pos, b.isNull(pos)); + if (!result.isNull(pos)) { + auto rPos = (*resultSelVector)[i]; + executeOnValue(a, b, + c, result, aPos, pos, cPos, rPos, dataPtr); + } + } + } + } + } + + template + static void executeAllUnFlat(common::ValueVector& a, common::SelectionVector* aSelVector, + common::ValueVector& b, [[maybe_unused]] common::SelectionVector* bSelVector, + common::ValueVector& c, [[maybe_unused]] common::SelectionVector* cSelVector, + common::ValueVector& result, common::SelectionVector* resultSelVector, void* dataPtr) { + KU_ASSERT(aSelVector == bSelVector && bSelVector == cSelVector); + if (a.hasNoNullsGuarantee() && b.hasNoNullsGuarantee() && c.hasNoNullsGuarantee()) { + if (aSelVector->isUnfiltered()) { + for (uint64_t i = 0; i < aSelVector->getSelSize(); i++) { + auto rPos = (*resultSelVector)[i]; + executeOnValue(a, b, c, + result, i, i, i, rPos, dataPtr); + } + } else { + for (uint64_t i = 0; i < aSelVector->getSelSize(); i++) { + auto pos = (*aSelVector)[i]; + auto rPos = (*resultSelVector)[i]; + executeOnValue(a, b, c, + result, pos, pos, pos, rPos, dataPtr); + } + } + } else { + if (aSelVector->isUnfiltered()) { + for (uint64_t i = 0; i < aSelVector->getSelSize(); i++) { + result.setNull(i, a.isNull(i) || b.isNull(i) || c.isNull(i)); + if (!result.isNull(i)) { + auto rPos = (*resultSelVector)[i]; + executeOnValue(a, b, + c, result, i, i, i, rPos, dataPtr); + } + } + } else { + for (uint64_t i = 0; i < aSelVector->getSelSize(); i++) { + auto pos = (*aSelVector)[i]; + result.setNull(pos, a.isNull(pos) || b.isNull(pos) || c.isNull(pos)); + if (!result.isNull(pos)) { + auto rPos = (*resultSelVector)[i]; + executeOnValue(a, b, + c, result, pos, pos, pos, rPos, dataPtr); + } + } + } + } + } + + template + static void executeUnflatFlatFlat(common::ValueVector& a, common::SelectionVector* aSelVector, + common::ValueVector& b, common::SelectionVector* bSelVector, common::ValueVector& c, + common::SelectionVector* cSelVector, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* dataPtr) { + auto bPos = (*bSelVector)[0]; + auto cPos = (*cSelVector)[0]; + if (b.isNull(bPos) || c.isNull(cPos)) { + result.setAllNull(); + } else if (a.hasNoNullsGuarantee()) { + if (aSelVector->isUnfiltered()) { + for (auto i = 0u; i < aSelVector->getSelSize(); ++i) { + auto rPos = (*resultSelVector)[i]; + executeOnValue(a, b, c, + result, i, bPos, cPos, rPos, dataPtr); + } + } else { + for (auto i = 0u; i < aSelVector->getSelSize(); ++i) { + auto pos = 
(*aSelVector)[i]; + auto rPos = (*resultSelVector)[i]; + executeOnValue(a, b, c, + result, pos, bPos, cPos, rPos, dataPtr); + } + } + } else { + if (aSelVector->isUnfiltered()) { + for (auto i = 0u; i < aSelVector->getSelSize(); ++i) { + result.setNull(i, a.isNull(i)); + if (!result.isNull(i)) { + auto rPos = (*resultSelVector)[i]; + executeOnValue(a, b, + c, result, i, bPos, cPos, rPos, dataPtr); + } + } + } else { + for (auto i = 0u; i < aSelVector->getSelSize(); ++i) { + auto pos = (*aSelVector)[i]; + result.setNull(pos, a.isNull(pos)); + if (!result.isNull(pos)) { + auto rPos = (*resultSelVector)[i]; + executeOnValue(a, b, + c, result, pos, bPos, cPos, rPos, dataPtr); + } + } + } + } + } + + template + static void executeUnflatFlatUnflat(common::ValueVector& a, common::SelectionVector* aSelVector, + common::ValueVector& b, common::SelectionVector* bSelVector, common::ValueVector& c, + [[maybe_unused]] common::SelectionVector* cSelVector, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* dataPtr) { + KU_ASSERT(aSelVector == cSelVector); + auto bPos = (*bSelVector)[0]; + if (b.isNull(bPos)) { + result.setAllNull(); + } else if (a.hasNoNullsGuarantee() && c.hasNoNullsGuarantee()) { + if (aSelVector->isUnfiltered()) { + for (auto i = 0u; i < aSelVector->getSelSize(); ++i) { + auto rPos = (*resultSelVector)[i]; + executeOnValue(a, b, c, + result, i, bPos, i, rPos, dataPtr); + } + } else { + for (auto i = 0u; i < aSelVector->getSelSize(); ++i) { + auto pos = (*aSelVector)[i]; + auto rPos = (*resultSelVector)[i]; + executeOnValue(a, b, c, + result, pos, bPos, pos, rPos, dataPtr); + } + } + } else { + if (aSelVector->isUnfiltered()) { + for (auto i = 0u; i < aSelVector->getSelSize(); ++i) { + result.setNull(i, a.isNull(i) || c.isNull(i)); + if (!result.isNull(i)) { + auto rPos = (*resultSelVector)[i]; + executeOnValue(a, b, + c, result, i, bPos, i, rPos, dataPtr); + } + } + } else { + for (auto i = 0u; i < aSelVector->getSelSize(); ++i) { + auto pos = (*bSelVector)[i]; + result.setNull(pos, a.isNull(pos) || c.isNull(pos)); + if (!result.isNull(pos)) { + auto rPos = (*resultSelVector)[i]; + executeOnValue(a, b, + c, result, pos, bPos, pos, rPos, dataPtr); + } + } + } + } + } + + template + static void executeUnflatUnFlatFlat(common::ValueVector& a, common::SelectionVector* aSelVector, + common::ValueVector& b, [[maybe_unused]] common::SelectionVector* bSelVector, + common::ValueVector& c, common::SelectionVector* cSelVector, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* dataPtr) { + KU_ASSERT(aSelVector == bSelVector); + auto cPos = (*cSelVector)[0]; + if (c.isNull(cPos)) { + result.setAllNull(); + } else if (a.hasNoNullsGuarantee() && b.hasNoNullsGuarantee()) { + if (aSelVector->isUnfiltered()) { + for (auto i = 0u; i < aSelVector->getSelSize(); ++i) { + auto rPos = (*resultSelVector)[i]; + executeOnValue(a, b, c, + result, i, i, cPos, rPos, dataPtr); + } + } else { + for (auto i = 0u; i < aSelVector->getSelSize(); ++i) { + auto pos = (*aSelVector)[i]; + auto rPos = (*resultSelVector)[i]; + executeOnValue(a, b, c, + result, pos, pos, cPos, rPos, dataPtr); + } + } + } else { + if (aSelVector->isUnfiltered()) { + for (auto i = 0u; i < aSelVector->getSelSize(); ++i) { + result.setNull(i, a.isNull(i) || b.isNull(i)); + if (!result.isNull(i)) { + auto rPos = (*resultSelVector)[i]; + executeOnValue(a, b, + c, result, i, i, cPos, rPos, dataPtr); + } + } + } else { + for (auto i = 0u; i < aSelVector->getSelSize(); ++i) { + auto pos = 
(*aSelVector)[i]; + result.setNull(pos, a.isNull(pos) || b.isNull(pos)); + if (!result.isNull(pos)) { + auto rPos = (*resultSelVector)[i]; + executeOnValue(a, b, + c, result, pos, pos, cPos, rPos, dataPtr); + } + } + } + } + } + + template + static void executeSwitch(common::ValueVector& a, common::SelectionVector* aSelVector, + common::ValueVector& b, common::SelectionVector* bSelVector, common::ValueVector& c, + common::SelectionVector* cSelVector, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* dataPtr) { + result.resetAuxiliaryBuffer(); + if (a.state->isFlat() && b.state->isFlat() && c.state->isFlat()) { + executeAllFlat(a, aSelVector, b, + bSelVector, c, cSelVector, result, resultSelVector, dataPtr); + } else if (a.state->isFlat() && b.state->isFlat() && !c.state->isFlat()) { + executeFlatFlatUnflat(a, + aSelVector, b, bSelVector, c, cSelVector, result, resultSelVector, dataPtr); + } else if (a.state->isFlat() && !b.state->isFlat() && !c.state->isFlat()) { + executeFlatUnflatUnflat(a, + aSelVector, b, bSelVector, c, cSelVector, result, resultSelVector, dataPtr); + } else if (a.state->isFlat() && !b.state->isFlat() && c.state->isFlat()) { + executeFlatUnflatFlat(a, + aSelVector, b, bSelVector, c, cSelVector, result, resultSelVector, dataPtr); + } else if (!a.state->isFlat() && !b.state->isFlat() && !c.state->isFlat()) { + executeAllUnFlat(a, aSelVector, + b, bSelVector, c, cSelVector, result, resultSelVector, dataPtr); + } else if (!a.state->isFlat() && !b.state->isFlat() && c.state->isFlat()) { + executeUnflatUnFlatFlat(a, + aSelVector, b, bSelVector, c, cSelVector, result, resultSelVector, dataPtr); + } else if (!a.state->isFlat() && b.state->isFlat() && c.state->isFlat()) { + executeUnflatFlatFlat(a, + aSelVector, b, bSelVector, c, cSelVector, result, resultSelVector, dataPtr); + } else if (!a.state->isFlat() && b.state->isFlat() && !c.state->isFlat()) { + executeUnflatFlatUnflat(a, + aSelVector, b, bSelVector, c, cSelVector, result, resultSelVector, dataPtr); + } else { + KU_ASSERT(false); + } + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/timestamp/timestamp_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/timestamp/timestamp_function.h new file mode 100644 index 0000000000..d7aafa4dd7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/timestamp/timestamp_function.h @@ -0,0 +1,34 @@ +#pragma once + +#include "common/exception/conversion.h" +#include "common/types/interval_t.h" +#include "common/types/timestamp_t.h" +#include "function/cast/functions/numeric_cast.h" + +namespace lbug { +namespace function { + +struct Century { + static inline void operation(common::timestamp_t& timestamp, int64_t& result) { + result = common::Timestamp::getTimestampPart(common::DatePartSpecifier::CENTURY, timestamp); + } +}; + +struct EpochMs { + static inline void operation(int64_t& ms, common::timestamp_t& result) { + result = common::Timestamp::fromEpochMilliSeconds(ms); + } +}; + +struct ToTimestamp { + static inline void operation(double& sec, common::timestamp_t& result) { + int64_t ms = 0; + if (!tryCastWithOverflowCheck(sec * common::Interval::MICROS_PER_SEC, ms)) { + throw common::ConversionException("Could not convert epoch seconds to TIMESTAMP"); + } + result = common::timestamp_t(ms); + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/timestamp/vector_timestamp_functions.h 
b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/timestamp/vector_timestamp_functions.h new file mode 100644 index 0000000000..6b9568b635 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/timestamp/vector_timestamp_functions.h @@ -0,0 +1,33 @@ +#pragma once + +#include "function/function.h" + +namespace lbug { +namespace function { + +struct CenturyFunction { + static constexpr const char* name = "CENTURY"; + + static function_set getFunctionSet(); +}; + +struct EpochMsFunction { + static constexpr const char* name = "EPOCH_MS"; + + static function_set getFunctionSet(); +}; + +struct ToTimestampFunction { + static constexpr const char* name = "TO_TIMESTAMP"; + + static function_set getFunctionSet(); +}; + +struct ToEpochMsFunction { + static constexpr const char* name = "TO_EPOCH_MS"; + + static function_set getFunctionSet(); +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/udf_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/udf_function.h new file mode 100644 index 0000000000..414f949d8e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/udf_function.h @@ -0,0 +1,283 @@ +#pragma once + +#include "common/exception/binder.h" +#include "common/exception/catalog.h" +#include "common/type_utils.h" +#include "common/types/ku_string.h" +#include "function/scalar_function.h" + +namespace lbug { +namespace function { + +struct UnaryUDFExecutor { + template + static inline void operation(OPERAND_TYPE& input, RESULT_TYPE& result, void* udfFunc) { + typedef RESULT_TYPE (*unary_udf_func)(OPERAND_TYPE); + auto unaryUDFFunc = (unary_udf_func)udfFunc; + result = unaryUDFFunc(input); + } +}; + +struct BinaryUDFExecutor { + template + static inline void operation(LEFT_TYPE& left, RIGHT_TYPE& right, RESULT_TYPE& result, + void* udfFunc) { + typedef RESULT_TYPE (*binary_udf_func)(LEFT_TYPE, RIGHT_TYPE); + auto binaryUDFFunc = (binary_udf_func)udfFunc; + result = binaryUDFFunc(left, right); + } +}; + +struct TernaryUDFExecutor { + template + static inline void operation(A_TYPE& a, B_TYPE& b, C_TYPE& c, RESULT_TYPE& result, + void* udfFunc) { + typedef RESULT_TYPE (*ternary_udf_func)(A_TYPE, B_TYPE, C_TYPE); + auto ternaryUDFFunc = (ternary_udf_func)udfFunc; + result = ternaryUDFFunc(a, b, c); + } +}; + +struct UDF { + template + static bool templateValidateType(const common::LogicalTypeID& type) { + auto logicalType = common::LogicalType{type}; + auto physicalType = logicalType.getPhysicalType(); + auto physicalTypeMatch = common::TypeUtils::visit(physicalType, + [](T1) { return std::is_same::value; }); + auto logicalTypeMatch = common::TypeUtils::visit(logicalType, + [](T1) { return std::is_same::value; }); + return logicalTypeMatch || physicalTypeMatch; + } + + template + static void validateType(const common::LogicalTypeID& type) { + if (!templateValidateType(type)) { + throw common::CatalogException{ + "Incompatible udf parameter/return type and templated type."}; + } + } + + template + static function::scalar_func_exec_t createEmptyParameterExecFunc(RESULT_TYPE (*)(Args...), + const std::vector&) { + KU_UNREACHABLE; + } + + template + static function::scalar_func_exec_t createEmptyParameterExecFunc(RESULT_TYPE (*udfFunc)(), + const std::vector&) { + KU_UNUSED(udfFunc); // Disable compiler warnings. 
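+ // The lambda returned below evaluates the zero-argument UDF once per selected
+ // result position and copies the resulting value into `result`.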
+ return [udfFunc]( + [[maybe_unused]] const std::vector>& params, + [[maybe_unused]] const std::vector& paramSelVectors, + common::ValueVector& result, common::SelectionVector* resultSelVector, + void* /*dataPtr*/ = nullptr) -> void { + KU_ASSERT(params.empty() && paramSelVectors.empty()); + for (auto i = 0u; i < resultSelVector->getSelSize(); ++i) { + auto resultPos = (*resultSelVector)[i]; + result.copyFromValue(resultPos, common::Value(udfFunc())); + } + }; + } + + template + static function::scalar_func_exec_t createUnaryExecFunc(RESULT_TYPE (* /*udfFunc*/)(Args...), + const std::vector& /*parameterTypes*/) { + KU_UNREACHABLE; + } + + template + static function::scalar_func_exec_t createUnaryExecFunc(RESULT_TYPE (*udfFunc)(OPERAND_TYPE), + const std::vector& parameterTypes) { + if (parameterTypes.size() != 1) { + throw common::CatalogException{ + "Expected exactly one parameter type for unary udf. Got: " + + std::to_string(parameterTypes.size()) + "."}; + } + validateType(parameterTypes[0]); + function::scalar_func_exec_t execFunc = + [udfFunc](const std::vector>& params, + const std::vector& paramSelVectors, + common::ValueVector& result, common::SelectionVector* resultSelVector, + void* /*dataPtr*/ = nullptr) -> void { + KU_ASSERT(params.size() == 1); + UnaryFunctionExecutor::executeSwitch(*params[0], paramSelVectors[0], result, resultSelVector, + (void*)udfFunc); + }; + return execFunc; + } + + template + static function::scalar_func_exec_t createBinaryExecFunc(RESULT_TYPE (* /*udfFunc*/)(Args...), + const std::vector& /*parameterTypes*/) { + KU_UNREACHABLE; + } + + template + static function::scalar_func_exec_t createBinaryExecFunc( + RESULT_TYPE (*udfFunc)(LEFT_TYPE, RIGHT_TYPE), + const std::vector& parameterTypes) { + if (parameterTypes.size() != 2) { + throw common::CatalogException{ + "Expected exactly two parameter types for binary udf. Got: " + + std::to_string(parameterTypes.size()) + "."}; + } + validateType(parameterTypes[0]); + validateType(parameterTypes[1]); + function::scalar_func_exec_t execFunc = + [udfFunc](const std::vector>& params, + const std::vector& paramSelVectors, + common::ValueVector& result, common::SelectionVector* resultSelVector, + void* /*dataPtr*/ = nullptr) -> void { + KU_ASSERT(params.size() == 2); + BinaryFunctionExecutor::executeSwitch(*params[0], paramSelVectors[0], + *params[1], paramSelVectors[1], result, resultSelVector, (void*)udfFunc); + }; + return execFunc; + } + + template + static function::scalar_func_exec_t createTernaryExecFunc(RESULT_TYPE (* /*udfFunc*/)(Args...), + const std::vector& /*parameterTypes*/) { + KU_UNREACHABLE; + } + + template + static function::scalar_func_exec_t createTernaryExecFunc( + RESULT_TYPE (*udfFunc)(A_TYPE, B_TYPE, C_TYPE), + std::vector parameterTypes) { + if (parameterTypes.size() != 3) { + throw common::CatalogException{ + "Expected exactly three parameter types for ternary udf. 
Got: " + + std::to_string(parameterTypes.size()) + "."}; + } + validateType(parameterTypes[0]); + validateType(parameterTypes[1]); + validateType(parameterTypes[2]); + function::scalar_func_exec_t execFunc = + [udfFunc](const std::vector>& params, + const std::vector& paramSelVectors, + common::ValueVector& result, common::SelectionVector* resultSelVector, + void* /*dataPtr*/ = nullptr) -> void { + KU_ASSERT(params.size() == 3); + TernaryFunctionExecutor::executeSwitch(*params[0], paramSelVectors[0], + *params[1], paramSelVectors[1], *params[2], paramSelVectors[2], result, + resultSelVector, (void*)udfFunc); + }; + return execFunc; + } + + template + static scalar_func_exec_t getScalarExecFunc(TR (*udfFunc)(Args...), + std::vector parameterTypes) { + constexpr auto numArgs = sizeof...(Args); + switch (numArgs) { + case 0: + return createEmptyParameterExecFunc(udfFunc, std::move(parameterTypes)); + case 1: + return createUnaryExecFunc(udfFunc, std::move(parameterTypes)); + case 2: + return createBinaryExecFunc(udfFunc, std::move(parameterTypes)); + case 3: + return createTernaryExecFunc(udfFunc, std::move(parameterTypes)); + default: + throw common::BinderException("UDF function only supported until ternary!"); + } + } + + template + static common::LogicalTypeID getParameterType() { + if (std::is_same()) { + return common::LogicalTypeID::BOOL; + } else if (std::is_same()) { + return common::LogicalTypeID::INT8; + } else if (std::is_same()) { + return common::LogicalTypeID::INT16; + } else if (std::is_same()) { + return common::LogicalTypeID::INT32; + } else if (std::is_same()) { + return common::LogicalTypeID::INT64; + } else if (std::is_same()) { + return common::LogicalTypeID::INT128; + } else if (std::is_same()) { + return common::LogicalTypeID::UINT8; + } else if (std::is_same()) { + return common::LogicalTypeID::UINT16; + } else if (std::is_same()) { + return common::LogicalTypeID::UINT32; + } else if (std::is_same()) { + return common::LogicalTypeID::UINT64; + } else if (std::is_same()) { + return common::LogicalTypeID::FLOAT; + } else if (std::is_same()) { + return common::LogicalTypeID::DOUBLE; + } else if (std::is_same()) { + return common::LogicalTypeID::STRING; + } else { + KU_UNREACHABLE; + } + } + + template + static void getParameterTypesRecursive(std::vector& arguments) { + arguments.push_back(getParameterType()); + } + + template + static void getParameterTypesRecursive(std::vector& arguments) { + arguments.push_back(getParameterType()); + getParameterTypesRecursive(arguments); + } + + template + static std::vector getParameterTypes() { + std::vector parameterTypes; + if constexpr (sizeof...(Args) > 0) { + getParameterTypesRecursive(parameterTypes); + } + return parameterTypes; + } + + template + static function_set getFunction(std::string name, TR (*udfFunc)(Args...), + std::vector parameterTypes, common::LogicalTypeID returnType) { + function_set definitions; + if (returnType == common::LogicalTypeID::STRING) { + KU_UNREACHABLE; + } + validateType(returnType); + scalar_func_exec_t scalarExecFunc = getScalarExecFunc(udfFunc, parameterTypes); + definitions.push_back(std::make_unique(std::move(name), + std::move(parameterTypes), returnType, std::move(scalarExecFunc))); + return definitions; + } + + template + static function_set getFunction(std::string name, TR (*udfFunc)(Args...)) { + return getFunction(std::move(name), udfFunc, getParameterTypes(), + getParameterType()); + } + + template + static function_set getVectorizedFunction(std::string name, scalar_func_exec_t 
execFunc) { + function_set definitions; + definitions.push_back(std::make_unique(std::move(name), + getParameterTypes(), getParameterType(), std::move(execFunc))); + return definitions; + } + + static function_set getVectorizedFunction(std::string name, scalar_func_exec_t execFunc, + std::vector parameterTypes, common::LogicalTypeID returnType) { + function_set definitions; + definitions.push_back(std::make_unique(std::move(name), + std::move(parameterTypes), returnType, std::move(execFunc))); + return definitions; + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/unary_function_executor.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/unary_function_executor.h new file mode 100644 index 0000000000..94dc372f25 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/unary_function_executor.h @@ -0,0 +1,198 @@ +#pragma once + +#include "common/vector/value_vector.h" +#include "function/cast/cast_function_bind_data.h" + +namespace lbug { +namespace function { + +/** + * Unary operator assumes operation with null returns null. This does NOT applies to IS_NULL and + * IS_NOT_NULL operation. + */ + +struct UnaryFunctionWrapper { + template + static inline void operation(void* inputVector, uint64_t inputPos, void* resultVector, + uint64_t resultPos, void* /*dataPtr*/) { + auto& inputVector_ = *(common::ValueVector*)inputVector; + auto& resultVector_ = *(common::ValueVector*)resultVector; + FUNC::operation(inputVector_.getValue(inputPos), + resultVector_.getValue(resultPos)); + } +}; + +struct UnarySequenceFunctionWrapper { + template + static inline void operation(void* inputVector, uint64_t inputPos, void* resultVector, + uint64_t /* resultPos */, void* dataPtr) { + auto& inputVector_ = *(common::ValueVector*)inputVector; + auto& resultVector_ = *(common::ValueVector*)resultVector; + FUNC::operation(inputVector_.getValue(inputPos), resultVector_, dataPtr); + } +}; + +struct UnaryStringFunctionWrapper { + template + static void operation(void* inputVector, uint64_t inputPos, void* resultVector, + uint64_t resultPos, void* /*dataPtr*/) { + auto& inputVector_ = *(common::ValueVector*)inputVector; + auto& resultVector_ = *(common::ValueVector*)resultVector; + FUNC::operation(inputVector_.getValue(inputPos), + resultVector_.getValue(resultPos), resultVector_); + } +}; + +struct UnaryCastStringFunctionWrapper { + template + static void operation(void* inputVector, uint64_t inputPos, void* resultVector, + uint64_t resultPos, void* dataPtr) { + auto& inputVector_ = *(common::ValueVector*)inputVector; + auto resultVector_ = (common::ValueVector*)resultVector; + // TODO(Ziyi): the reinterpret_cast is not safe since we don't always pass + // CastFunctionBindData + FUNC::operation(inputVector_.getValue(inputPos), + resultVector_->getValue(resultPos), resultVector_, inputPos, + &reinterpret_cast(dataPtr)->option); + } +}; + +struct UnaryNestedTypeFunctionWrapper { + template + static inline void operation(void* inputVector, uint64_t inputPos, void* resultVector, + uint64_t resultPos, void* /*dataPtr*/) { + auto& inputVector_ = *(common::ValueVector*)inputVector; + auto& resultVector_ = *(common::ValueVector*)resultVector; + FUNC::operation(inputVector_.getValue(inputPos), + resultVector_.getValue(resultPos), inputVector_, resultVector_); + } +}; + +struct SetSeedFunctionWrapper { + template + static inline void operation(void* inputVector, uint64_t inputPos, void* resultVector, + uint64_t resultPos, void* 
dataPtr) { + auto& inputVector_ = *(common::ValueVector*)inputVector; + auto& resultVector_ = *(common::ValueVector*)resultVector; + resultVector_.setNull(resultPos, true /* isNull */); + FUNC::operation(inputVector_.getValue(inputPos), dataPtr); + } +}; + +struct UnaryCastFunctionWrapper { + template + static void operation(void* inputVector, uint64_t inputPos, void* resultVector, + uint64_t resultPos, void* /*dataPtr*/) { + auto& inputVector_ = *(common::ValueVector*)inputVector; + auto& resultVector_ = *(common::ValueVector*)resultVector; + FUNC::operation(inputVector_.getValue(inputPos), + resultVector_.getValue(resultPos), inputVector_, resultVector_); + } +}; + +struct UnaryCastUnionFunctionWrapper { + template + static void operation(void* inputVector, uint64_t inputPos, void* resultVector, + uint64_t resultPos, void* dataPtr) { + auto& inputVector_ = *(common::ValueVector*)inputVector; + auto& resultVector_ = *(common::ValueVector*)resultVector; + FUNC::operation(inputVector_, resultVector_, inputPos, resultPos, dataPtr); + } +}; + +struct UnaryUDFFunctionWrapper { + template + static inline void operation(void* inputVector, uint64_t inputPos, void* resultVector, + uint64_t resultPos, void* dataPtr) { + auto& inputVector_ = *(common::ValueVector*)inputVector; + auto& resultVector_ = *(common::ValueVector*)resultVector; + FUNC::operation(inputVector_.getValue(inputPos), + resultVector_.getValue(resultPos), dataPtr); + } +}; + +struct UnaryFunctionExecutor { + + template + static void executeOnValue(common::ValueVector& inputVector, uint64_t inputPos, + common::ValueVector& resultVector, uint64_t resultPos, void* dataPtr) { + OP_WRAPPER::template operation((void*)&inputVector, + inputPos, (void*)&resultVector, resultPos, dataPtr); + } + + static std::pair getSelectedPos(common::idx_t selIdx, + common::SelectionVector* operandSelVector, common::SelectionVector* resultSelVector, + bool operandIsUnfiltered, bool resultIsUnfiltered) { + common::sel_t operandPos = operandIsUnfiltered ? selIdx : (*operandSelVector)[selIdx]; + common::sel_t resultPos = resultIsUnfiltered ? 
selIdx : (*resultSelVector)[selIdx]; + return {operandPos, resultPos}; + } + + template + static void executeOnSelectedValues(common::ValueVector& operand, + common::SelectionVector* operandSelVector, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* dataPtr) { + const bool noNullsGuaranteed = operand.hasNoNullsGuarantee(); + if (noNullsGuaranteed) { + result.setAllNonNull(); + } + + const bool operandIsUnfiltered = operandSelVector->isUnfiltered(); + const bool resultIsUnfiltered = resultSelVector->isUnfiltered(); + + for (auto i = 0u; i < operandSelVector->getSelSize(); i++) { + const auto [operandPos, resultPos] = getSelectedPos(i, operandSelVector, + resultSelVector, operandIsUnfiltered, resultIsUnfiltered); + if (noNullsGuaranteed) { + executeOnValue(operand, operandPos, + result, resultPos, dataPtr); + } else { + result.setNull(resultPos, operand.isNull(operandPos)); + if (!result.isNull(resultPos)) { + executeOnValue(operand, operandPos, + result, resultPos, dataPtr); + } + } + } + } + + template + static void executeSwitch(common::ValueVector& operand, + common::SelectionVector* operandSelVector, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* dataPtr) { + result.resetAuxiliaryBuffer(); + if (operand.state->isFlat()) { + auto inputPos = (*operandSelVector)[0]; + auto resultPos = (*resultSelVector)[0]; + result.setNull(resultPos, operand.isNull(inputPos)); + if (!result.isNull(resultPos)) { + executeOnValue(operand, inputPos, + result, resultPos, dataPtr); + } + } else { + executeOnSelectedValues(operand, + operandSelVector, result, resultSelVector, dataPtr); + } + } + + template + static void execute(common::ValueVector& operand, common::SelectionVector* operandSelVector, + common::ValueVector& result, common::SelectionVector* resultSelVector) { + executeSwitch(operand, + operandSelVector, result, resultSelVector, nullptr /* dataPtr */); + } + + template + static void executeSequence(common::ValueVector& operand, + common::SelectionVector* operandSelVector, common::ValueVector& result, + common::SelectionVector* resultSelVector, void* dataPtr) { + result.resetAuxiliaryBuffer(); + auto inputPos = (*operandSelVector)[0]; + auto resultPos = (*resultSelVector)[0]; + executeOnValue(operand, + inputPos, result, resultPos, dataPtr); + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/union/functions/union_tag.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/union/functions/union_tag.h new file mode 100644 index 0000000000..cf242ac9bf --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/union/functions/union_tag.h @@ -0,0 +1,28 @@ +#pragma once + +#include "common/vector/value_vector.h" + +namespace lbug { +namespace function { + +struct UnionTag { + static void operation(common::union_entry_t& unionValue, common::ku_string_t& tag, + common::ValueVector& unionVector, common::ValueVector& tagVector) { + auto tagIdxVector = common::UnionVector::getTagVector(&unionVector); + auto tagIdx = tagIdxVector->getValue(unionValue.entry.pos); + auto tagName = common::UnionType::getFieldName(unionVector.dataType, tagIdx); + if (tagName.length() > common::ku_string_t::SHORT_STR_LENGTH) { + tag.overflowPtr = + reinterpret_cast(common::StringVector::getInMemOverflowBuffer(&tagVector) + ->allocateSpace(tagName.length())); + memcpy(reinterpret_cast(tag.overflowPtr), tagName.c_str(), tagName.length()); + memcpy(tag.prefix, tagName.c_str(), 
common::ku_string_t::PREFIX_LENGTH); + } else { + memcpy(tag.prefix, tagName.c_str(), tagName.length()); + } + tag.len = tagName.length(); + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/union/vector_union_functions.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/union/vector_union_functions.h new file mode 100644 index 0000000000..450cf0dc48 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/union/vector_union_functions.h @@ -0,0 +1,27 @@ +#pragma once + +#include "function/function.h" + +namespace lbug { +namespace function { + +struct UnionValueFunction { + static constexpr const char* name = "UNION_VALUE"; + + static function_set getFunctionSet(); +}; + +struct UnionTagFunction { + static constexpr const char* name = "UNION_TAG"; + + static function_set getFunctionSet(); +}; + +struct UnionExtractFunction { + static constexpr const char* name = "UNION_EXTRACT"; + + static function_set getFunctionSet(); +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/utility/function_string_bind_data.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/utility/function_string_bind_data.h new file mode 100644 index 0000000000..f77f0e5517 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/utility/function_string_bind_data.h @@ -0,0 +1,20 @@ +#pragma once + +#include "function/function.h" + +namespace lbug { +namespace function { + +struct FunctionStringBindData : public FunctionBindData { + explicit FunctionStringBindData(std::string str) + : FunctionBindData{common::LogicalType::STRING()}, str{std::move(str)} {} + + std::string str; + + inline std::unique_ptr copy() const override { + return std::make_unique(str); + } +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/utility/vector_utility_functions.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/utility/vector_utility_functions.h new file mode 100644 index 0000000000..137803eed5 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/utility/vector_utility_functions.h @@ -0,0 +1,51 @@ +#pragma once + +#include "function/function.h" + +namespace lbug { +namespace function { + +struct CoalesceFunction { + static constexpr const char* name = "COALESCE"; + + static function_set getFunctionSet(); +}; + +struct IfNullFunction { + static constexpr const char* name = "IFNULL"; + + static function_set getFunctionSet(); +}; + +struct ConstantOrNullFunction { + static constexpr const char* name = "CONSTANT_OR_NULL"; + + static function_set getFunctionSet(); +}; + +struct CountIfFunction { + static constexpr const char* name = "COUNT_IF"; + + static function_set getFunctionSet(); +}; + +struct ErrorFunction { + static constexpr const char* name = "ERROR"; + + static function_set getFunctionSet(); +}; + +struct NullIfFunction { + static constexpr const char* name = "NULLIF"; + + static function_set getFunctionSet(); +}; + +struct TypeOfFunction { + static constexpr const char* name = "TYPEOF"; + + static function_set getFunctionSet(); +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/uuid/functions/gen_random_uuid.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/uuid/functions/gen_random_uuid.h new file mode 100644 index 0000000000..e7e8410fab --- /dev/null +++ 
b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/uuid/functions/gen_random_uuid.h @@ -0,0 +1,13 @@ +#pragma once + +#include "common/types/uuid.h" + +namespace lbug { +namespace function { + +struct GenRandomUUID { + static void operation(common::ku_uuid_t& input, void* dataPtr); +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/uuid/vector_uuid_functions.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/uuid/vector_uuid_functions.h new file mode 100644 index 0000000000..fb0e13d0ae --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/function/uuid/vector_uuid_functions.h @@ -0,0 +1,15 @@ +#pragma once + +#include "function/function.h" + +namespace lbug { +namespace function { + +struct GenRandomUUIDFunction { + static constexpr const char* name = "GEN_RANDOM_UUID"; + + static function_set getFunctionSet(); +}; + +} // namespace function +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/graph/graph.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/graph/graph.h new file mode 100644 index 0000000000..faafef0c57 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/graph/graph.h @@ -0,0 +1,255 @@ +#pragma once + +#include +#include + +#include "common/copy_constructors.h" +#include "common/data_chunk/sel_vector.h" +#include "common/types/types.h" +#include "common/vector/value_vector.h" +#include + +namespace lbug { +namespace catalog { +class TableCatalogEntry; +} // namespace catalog +namespace transaction { +class Transaction; +} // namespace transaction + +namespace graph { +struct NativeGraphEntry; + +struct GraphRelInfo { + common::table_id_t srcTableID; + common::table_id_t dstTableID; + catalog::TableCatalogEntry* relGroupEntry; + common::oid_t relTableID; + + GraphRelInfo(common::table_id_t srcTableID, common::table_id_t dstTableID, + catalog::TableCatalogEntry* relGroupEntry, common::oid_t relTableID) + : srcTableID{srcTableID}, dstTableID{dstTableID}, relGroupEntry{relGroupEntry}, + relTableID{relTableID} {} +}; + +class LBUG_API NbrScanState { +public: + struct Chunk { + friend class NbrScanState; + + EXPLICIT_COPY_METHOD(Chunk); + // Any neighbour for which the given function returns false + // will be omitted from future iterations + // Used in GDSTask/EdgeCompute for updating the frontier + template + void forEach(Func&& func) const { + selVector.forEach([&](auto i) { func(nbrNodes, propertyVectors, i); }); + } + template + void forEachBreakWhenFalse(Func&& func) const { + selVector.forEachBreakWhenFalse([&](auto i) -> bool { return func(nbrNodes, i); }); + } + + uint64_t size() const { return selVector.getSelSize(); } + + private: + Chunk(std::span nbrNodes, common::SelectionVector& selVector, + std::span> propertyVectors); + + Chunk(const Chunk& other) noexcept + : nbrNodes{other.nbrNodes}, selVector{other.selVector}, + propertyVectors{other.propertyVectors} {} + + private: + std::span nbrNodes; + // this reference can be modified, but the underlying data will be reset the next time next + // is called + common::SelectionVector& selVector; + std::span> propertyVectors; + }; + + virtual ~NbrScanState() = default; + virtual Chunk getChunk() = 0; + + // Returns true if there are more values after the current batch + virtual bool next() = 0; + +protected: + static Chunk createChunk(std::span nbrNodes, + common::SelectionVector& selVector, + std::span> propertyVectors) { + return Chunk{nbrNodes, selVector, propertyVectors}; + } +}; + +class 
VertexScanState { +public: + struct Chunk { + friend class VertexScanState; + + size_t size() const { return nodeIDs.size(); } + std::span getNodeIDs() const { return nodeIDs; } + template + std::span getProperties(size_t propertyIndex) const { + return std::span(reinterpret_cast(propertyVectors[propertyIndex]->getData()), + nodeIDs.size()); + } + + private: + LBUG_API Chunk(std::span nodeIDs, + std::span> propertyVectors); + + private: + std::span nodeIDs; + std::span> propertyVectors; + }; + virtual Chunk getChunk() = 0; + + // Returns true if there are more values after the current batch + virtual bool next() = 0; + + virtual ~VertexScanState() = default; + +protected: + static Chunk createChunk(std::span nodeIDs, + std::span> propertyVectors) { + return Chunk{nodeIDs, propertyVectors}; + } +}; + +/** + * Graph interface to be use by GDS algorithms to get neighbors of nodes. + * + * Instances of Graph are not expected to be thread-safe. Therefore, if Graph is intended to be used + * in a parallel manner, the user should first copy() an instance and give each thread a separate + * copy. It is the responsibility of the implementing Graph class that the copy() is a lightweight + * operation that does not copy large amounts of data between instances. + */ +class Graph { +public: + class EdgeIterator { + public: + explicit constexpr EdgeIterator(NbrScanState* scanState) : scanState{scanState} {} + DEFAULT_BOTH_MOVE(EdgeIterator); + EdgeIterator(const EdgeIterator& other) = default; + EdgeIterator() : scanState{nullptr} {} + using difference_type = std::ptrdiff_t; + using value_type = NbrScanState::Chunk; + + value_type operator*() const { return scanState->getChunk(); } + EdgeIterator& operator++() { + if (!scanState->next()) { + scanState = nullptr; + } + return *this; + } + void operator++(int) { ++*this; } + bool operator==(const EdgeIterator& other) const { + // Only needed for comparing to the end, so they are equal if and only if both are null + return scanState == nullptr && other.scanState == nullptr; + } + // Counts and consumes the iterator + uint64_t count() const { + // TODO(bmwinger): avoid scanning if all that's necessary is to count the results + uint64_t result = 0; + do { + result += scanState->getChunk().size(); + } while (scanState->next()); + return result; + } + + std::vector collectNbrNodes() { + std::vector nbrNodes; + for (const auto chunk : *this) { + nbrNodes.reserve(nbrNodes.size() + chunk.size()); + chunk.forEach( + [&](auto neighbors, auto, auto i) { nbrNodes.push_back(neighbors[i]); }); + } + return nbrNodes; + } + + EdgeIterator& begin() noexcept { return *this; } + static constexpr EdgeIterator end() noexcept { return EdgeIterator(nullptr); } + + private: + NbrScanState* scanState; + }; + static_assert(std::input_iterator); + + Graph() = default; + virtual ~Graph() = default; + + virtual NativeGraphEntry* getGraphEntry() = 0; + + // Get id for all node tables. + virtual std::vector getNodeTableIDs() const = 0; + + // Get max offset of each table as a map. + virtual common::table_id_map_t getMaxOffsetMap( + transaction::Transaction* transaction) const = 0; + + // Get max offset of given table. + virtual common::offset_t getMaxOffset(transaction::Transaction* transaction, + common::table_id_t id) const = 0; + + // Get num nodes for all node tables. + virtual common::offset_t getNumNodes(transaction::Transaction* transaction) const = 0; + + // Get all possible (srcTable, dstTable, relTable)s. 
+ virtual std::vector getRelInfos(common::table_id_t srcTableID) = 0; + + // Prepares scan on the specified relationship table (works for backwards and forwards scans) + virtual std::unique_ptr prepareRelScan(const catalog::TableCatalogEntry& entry, + common::oid_t relTableID, common::table_id_t nbrTableID, + std::vector relProperties, bool randomLookup = true) = 0; + + // Get dst nodeIDs for given src nodeID using forward adjList. + virtual EdgeIterator scanFwd(common::nodeID_t nodeID, NbrScanState& state) = 0; + + // Get dst nodeIDs for given src nodeID tables using backward adjList. + virtual EdgeIterator scanBwd(common::nodeID_t nodeID, NbrScanState& state) = 0; + + class VertexIterator { + public: + explicit constexpr VertexIterator(VertexScanState* scanState) : scanState{scanState} {} + DEFAULT_BOTH_MOVE(VertexIterator); + VertexIterator(const VertexIterator& other) = default; + VertexIterator() : scanState{nullptr} {} + using difference_type = std::ptrdiff_t; + using value_type = VertexScanState::Chunk; + + value_type operator*() const { return scanState->getChunk(); } + VertexIterator& operator++() { + if (!scanState->next()) { + scanState = nullptr; + } + return *this; + } + void operator++(int) { ++*this; } + bool operator==(const VertexIterator& other) const { + // Only needed for comparing to the end, so they are equal if and only if both are null + return scanState == nullptr && other.scanState == nullptr; + } + + VertexIterator& begin() noexcept { return *this; } + static constexpr VertexIterator end() noexcept { return VertexIterator(nullptr); } + + private: + VertexScanState* scanState; + }; + static_assert(std::input_iterator); + + virtual std::unique_ptr prepareVertexScan( + catalog::TableCatalogEntry* tableEntry, const std::vector& properties) = 0; + + virtual VertexIterator scanVertices(common::offset_t startNodeOffset, + common::offset_t endNodeOffsetExclusive, VertexScanState& scanState) = 0; + + template + const TARGET& constCast() const { + return common::ku_dynamic_cast(*this); + } +}; + +} // namespace graph +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/graph/graph_entry.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/graph/graph_entry.h new file mode 100644 index 0000000000..05b51713b0 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/graph/graph_entry.h @@ -0,0 +1,50 @@ +#pragma once + +#include "catalog/catalog_entry/table_catalog_entry.h" +#include "common/copy_constructors.h" +#include "common/types/types.h" + +namespace lbug { +namespace graph { + +struct NativeGraphEntryTableInfo { + catalog::TableCatalogEntry* entry; + + std::shared_ptr nodeOrRel; + std::shared_ptr predicate; + + explicit NativeGraphEntryTableInfo(catalog::TableCatalogEntry* entry) : entry{entry} {} + NativeGraphEntryTableInfo(catalog::TableCatalogEntry* entry, + std::shared_ptr nodeOrRel, + std::shared_ptr predicate) + : entry{entry}, nodeOrRel{std::move(nodeOrRel)}, predicate{std::move(predicate)} {} +}; + +// Organize projected graph similar to CatalogEntry. When we want to share projected graph across +// statements, we need to migrate this class to catalog (or client context). 
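Editor's note on the Graph interface just above: GDS algorithms build a reusable `NbrScanState` with `prepareRelScan`, then `scanFwd`/`scanBwd` return an input iterator over `Chunk`s, and each `Chunk` exposes the neighbour node IDs through `forEach`. The sketch below shows one forward traversal under assumed names (`graph`, `relEntry`, `relTableID`, `nbrTableID` and `node` are placeholders, and the include path is an assumption).

    // Sketch only: visit the forward neighbours of a single node.
    #include "graph/graph.h"

    using namespace lbug;

    void visit_out_neighbours(graph::Graph& graph, const catalog::TableCatalogEntry& relEntry,
        common::oid_t relTableID, common::table_id_t nbrTableID, common::nodeID_t node) {
        // No rel properties requested; the scan state can be reused across scanFwd calls.
        auto state = graph.prepareRelScan(relEntry, relTableID, nbrTableID, {} /* relProperties */);
        for (const auto chunk : graph.scanFwd(node, *state)) {
            chunk.forEach([](auto nbrNodes, auto /* propertyVectors */, auto i) {
                common::nodeID_t nbr = nbrNodes[i];
                (void)nbr;  // process the neighbour here
            });
        }
        // Note: EdgeIterator::count() consumes the iterator, so counting needs a fresh scan.
    }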
+struct LBUG_API NativeGraphEntry { + std::vector nodeInfos; + std::vector relInfos; + + NativeGraphEntry() = default; + NativeGraphEntry(std::vector nodeEntries, + std::vector relEntries); + EXPLICIT_COPY_DEFAULT_MOVE(NativeGraphEntry); + + bool isEmpty() const { return nodeInfos.empty() && relInfos.empty(); } + + std::vector getNodeTableIDs() const; + std::vector getRelEntries() const; + std::vector getNodeEntries() const; + + const NativeGraphEntryTableInfo& getRelInfo(common::table_id_t tableID) const; + + void setRelPredicate(std::shared_ptr predicate); + +private: + NativeGraphEntry(const NativeGraphEntry& other) + : nodeInfos{other.nodeInfos}, relInfos{other.relInfos} {} +}; + +} // namespace graph +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/graph/graph_entry_set.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/graph/graph_entry_set.h new file mode 100644 index 0000000000..660f809b06 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/graph/graph_entry_set.h @@ -0,0 +1,42 @@ +#pragma once + +#include +#include + +#include "common/assert.h" +#include "parsed_graph_entry.h" + +namespace lbug { +namespace main { +class ClientContext; +} +namespace graph { + +class GraphEntrySet { +public: + void validateGraphNotExist(const std::string& name) const; + void validateGraphExist(const std::string& name) const; + + bool hasGraph(const std::string& name) const { return nameToEntry.contains(name); } + ParsedGraphEntry* getEntry(const std::string& name) const { + KU_ASSERT(hasGraph(name)); + return nameToEntry.at(name).get(); + } + void addGraph(const std::string& name, std::unique_ptr entry) { + nameToEntry.insert({name, std::move(entry)}); + } + void dropGraph(const std::string& name) { nameToEntry.erase(name); } + + const std::unordered_map>& + getNameToEntryMap() const { + return nameToEntry; + } + + LBUG_API static GraphEntrySet* Get(const main::ClientContext& context); + +private: + std::unordered_map> nameToEntry; +}; + +} // namespace graph +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/graph/on_disk_graph.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/graph/on_disk_graph.h new file mode 100644 index 0000000000..8225c949f3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/graph/on_disk_graph.h @@ -0,0 +1,163 @@ +#pragma once + +#include + +#include "common/assert.h" +#include "common/copy_constructors.h" +#include "common/data_chunk/sel_vector.h" +#include "common/enums/rel_direction.h" +#include "common/mask.h" +#include "common/types/types.h" +#include "common/vector/value_vector.h" +#include "graph.h" +#include "graph_entry.h" +#include "processor/operator/filtering_operator.h" +#include "storage/table/node_table.h" +#include "storage/table/rel_table.h" + +namespace lbug { +namespace storage { +class MemoryManager; +} + +namespace graph { + +class OnDiskGraphNbrScanState : public NbrScanState { + friend class OnDiskGraph; + +public: + OnDiskGraphNbrScanState(main::ClientContext* context, const catalog::TableCatalogEntry& entry, + common::oid_t relTableID, std::shared_ptr predicate); + OnDiskGraphNbrScanState(main::ClientContext* context, const catalog::TableCatalogEntry& entry, + common::oid_t relTableID, std::shared_ptr predicate, + std::vector relProperties, bool randomLookup = false); + + Chunk getChunk() override { + return createChunk(currentIter->getNbrNodes(), currentIter->getSelVectorUnsafe(), + std::span(propertyVectors.valueVectors)); + } + bool next() override; + + void 
startScan(common::RelDataDirection direction); + + class InnerIterator : public processor::SelVectorOverWriter { + public: + InnerIterator(const main::ClientContext* context, storage::RelTable* relTable, + std::unique_ptr tableScanState); + + DELETE_COPY_DEFAULT_MOVE(InnerIterator); + + std::span getNbrNodes() const { + RUNTIME_CHECK(for (size_t i = 0; i < getSelVector().getSelSize(); i++) { + KU_ASSERT( + getSelVector().getSelectedPositions()[i] < common::DEFAULT_VECTOR_CAPACITY); + }); + return std::span(&dstVector().getValue(0), + common::DEFAULT_VECTOR_CAPACITY); + } + + common::SelectionVector& getSelVectorUnsafe() { + return tableScanState->outState->getSelVectorUnsafe(); + } + + const common::SelectionVector& getSelVector() const { + return tableScanState->outState->getSelVector(); + } + + bool next(evaluator::ExpressionEvaluator* predicate, common::SemiMask* nbrNodeMask); + void initScan() const; + + common::RelDataDirection getDirection() const { return tableScanState->direction; } + + private: + common::ValueVector& dstVector() const { return *tableScanState->outputVectors[0]; } + + const main::ClientContext* context; + storage::RelTable* relTable; + std::unique_ptr tableScanState; + }; + +private: + std::unique_ptr srcNodeIDVector; + std::unique_ptr dstNodeIDVector; + common::DataChunk propertyVectors; + + std::unique_ptr relPredicateEvaluator; + common::SemiMask* nbrNodeMask = nullptr; + + std::vector directedIterators; + InnerIterator* currentIter = nullptr; +}; + +class OnDiskGraphVertexScanState final : public VertexScanState { +public: + OnDiskGraphVertexScanState(main::ClientContext& context, + const catalog::TableCatalogEntry* tableEntry, + const std::vector& propertyNames); + + void startScan(common::offset_t beginOffset, common::offset_t endOffsetExclusive); + + bool next() override; + Chunk getChunk() override { + return createChunk(std::span(&nodeIDVector->getValue(0), + nodeIDVector->getSelVectorPtr()->getSelSize()), + std::span(propertyVectors.valueVectors)); + } + +private: + const main::ClientContext& context; + const storage::NodeTable& nodeTable; + + common::DataChunk propertyVectors; + std::unique_ptr nodeIDVector; + std::unique_ptr tableScanState; + + common::offset_t numNodesToScan; + common::offset_t currentOffset; + common::offset_t endOffsetExclusive; +}; + +class LBUG_API OnDiskGraph final : public Graph { +public: + OnDiskGraph(main::ClientContext* context, NativeGraphEntry entry); + + NativeGraphEntry* getGraphEntry() override { return &graphEntry; } + + void setNodeOffsetMask(common::NodeOffsetMaskMap* maskMap) { nodeOffsetMaskMap = maskMap; } + + std::vector getNodeTableIDs() const override { + return graphEntry.getNodeTableIDs(); + } + + common::table_id_map_t getMaxOffsetMap( + transaction::Transaction* transaction) const override; + + common::offset_t getMaxOffset(transaction::Transaction* transaction, + common::table_id_t id) const override; + + common::offset_t getNumNodes(transaction::Transaction* transaction) const override; + + std::vector getRelInfos(common::table_id_t srcTableID) override; + + std::unique_ptr prepareRelScan(const catalog::TableCatalogEntry& entry, + common::oid_t relTableID, common::table_id_t nbrTableID, + std::vector relProperties, bool randomLookUp = true) override; + + EdgeIterator scanFwd(common::nodeID_t nodeID, NbrScanState& state) override; + EdgeIterator scanBwd(common::nodeID_t nodeID, NbrScanState& state) override; + + std::unique_ptr prepareVertexScan(catalog::TableCatalogEntry* tableEntry, + const std::vector& 
propertiesToScan) override; + VertexIterator scanVertices(common::offset_t beginOffset, common::offset_t endOffsetExclusive, + VertexScanState& state) override; + +private: + main::ClientContext* context; + NativeGraphEntry graphEntry; + common::NodeOffsetMaskMap* nodeOffsetMaskMap = nullptr; + common::table_id_map_t nodeIDToNodeTable; + std::vector relInfos; +}; + +} // namespace graph +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/graph/parsed_graph_entry.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/graph/parsed_graph_entry.h new file mode 100644 index 0000000000..c4915909bf --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/graph/parsed_graph_entry.h @@ -0,0 +1,59 @@ +#pragma once + +#include +#include +#include + +#include "common/cast.h" + +namespace lbug { +namespace graph { + +enum class GraphEntryType : uint8_t { + NATIVE = 0, + CYPHER = 1, +}; + +struct GraphEntryTypeUtils { + static std::string toString(GraphEntryType type); +}; + +struct LBUG_API ParsedGraphEntry { + GraphEntryType type; + + explicit ParsedGraphEntry(GraphEntryType type) : type{type} {} + virtual ~ParsedGraphEntry() = default; + + template + TARGET& cast() { + return common::ku_dynamic_cast(*this); + } +}; + +struct ParsedNativeGraphTableInfo { + std::string tableName; + std::string predicate; + + ParsedNativeGraphTableInfo(std::string tableName, std::string predicate) + : tableName{std::move(tableName)}, predicate{std::move(predicate)} {} +}; + +struct LBUG_API ParsedNativeGraphEntry : ParsedGraphEntry { + std::vector nodeInfos; + std::vector relInfos; + + ParsedNativeGraphEntry(std::vector nodeInfos, + std::vector relInfos) + : ParsedGraphEntry{GraphEntryType::NATIVE}, nodeInfos{std::move(nodeInfos)}, + relInfos{std::move(relInfos)} {} +}; + +struct LBUG_API ParsedCypherGraphEntry : ParsedGraphEntry { + std::string cypherQuery; + + explicit ParsedCypherGraphEntry(std::string cypherQuery) + : ParsedGraphEntry{GraphEntryType::CYPHER}, cypherQuery{std::move(cypherQuery)} {} +}; + +} // namespace graph +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/attached_database.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/attached_database.h new file mode 100644 index 0000000000..d6e64f871f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/attached_database.h @@ -0,0 +1,64 @@ +#pragma once + +#include +#include + +#include "extension/catalog_extension.h" +#include "transaction/transaction_manager.h" + +namespace duckdb { +class MaterializedQueryResult; +} + +namespace lbug { +namespace storage { +class StorageManager; +} // namespace storage + +namespace main { + +class AttachedDatabase { +public: + AttachedDatabase(std::string dbName, std::string dbType, + std::unique_ptr catalog) + : dbName{std::move(dbName)}, dbType{std::move(dbType)}, catalog{std::move(catalog)} {} + + virtual ~AttachedDatabase() = default; + + std::string getDBName() const { return dbName; } + + std::string getDBType() const { return dbType; } + + catalog::Catalog* getCatalog() { return catalog.get(); } + + std::unique_ptr executeQuery(const std::string& query); + + void invalidateCache(); + + template + const TARGET& constCast() const { + return common::ku_dynamic_cast(*this); + } + +protected: + std::string dbName; + std::string dbType; + std::unique_ptr catalog; +}; + +class AttachedLbugDatabase final : public AttachedDatabase { +public: + AttachedLbugDatabase(std::string dbPath, std::string dbName, std::string dbType, + ClientContext* 
clientContext); + + storage::StorageManager* getStorageManager() { return storageManager.get(); } + + transaction::TransactionManager* getTransactionManager() { return transactionManager.get(); } + +private: + std::unique_ptr storageManager; + std::unique_ptr transactionManager; +}; + +} // namespace main +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/client_config.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/client_config.h new file mode 100644 index 0000000000..1fedc75eca --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/client_config.h @@ -0,0 +1,63 @@ +#pragma once + +#include +#include + +#include "common/enums/path_semantic.h" + +namespace lbug { +namespace main { + +struct ClientConfigDefault { + // 0 means timeout is disabled by default. + static constexpr uint64_t TIMEOUT_IN_MS = 0; + static constexpr uint32_t VAR_LENGTH_MAX_DEPTH = 30; + static constexpr uint64_t SPARSE_FRONTIER_THRESHOLD = 1000; + static constexpr bool ENABLE_SEMI_MASK = true; + static constexpr bool ENABLE_ZONE_MAP = true; + static constexpr bool ENABLE_PROGRESS_BAR = false; + static constexpr uint64_t SHOW_PROGRESS_AFTER = 1000; + static constexpr common::PathSemantic RECURSIVE_PATTERN_SEMANTIC = common::PathSemantic::WALK; + static constexpr uint32_t RECURSIVE_PATTERN_FACTOR = 100; + static constexpr bool DISABLE_MAP_KEY_CHECK = true; + static constexpr uint64_t WARNING_LIMIT = 8 * 1024; + static constexpr bool ENABLE_PLAN_OPTIMIZER = true; + static constexpr bool ENABLE_INTERNAL_CATALOG = false; +}; + +struct ClientConfig { + // System home directory. + std::string homeDirectory; + // File search path. + std::string fileSearchPath; + // If using semi mask in join. + bool enableSemiMask = ClientConfigDefault::ENABLE_SEMI_MASK; + // If using zone map in scan. + bool enableZoneMap = ClientConfigDefault::ENABLE_ZONE_MAP; + // Number of threads for execution. + uint64_t numThreads = 1; + // Timeout (milliseconds). + uint64_t timeoutInMS = ClientConfigDefault::TIMEOUT_IN_MS; + // Variable length maximum depth. + uint32_t varLengthMaxDepth = ClientConfigDefault::VAR_LENGTH_MAX_DEPTH; + // Threshold determines when to switch from sparse frontier to dense frontier + uint64_t sparseFrontierThreshold = ClientConfigDefault::SPARSE_FRONTIER_THRESHOLD; + // If using progress bar. + bool enableProgressBar = ClientConfigDefault::ENABLE_PROGRESS_BAR; + // time before displaying progress bar + uint64_t showProgressAfter = ClientConfigDefault::SHOW_PROGRESS_AFTER; + // Semantic for recursive pattern, can be either WALK, TRAIL, ACYCLIC + common::PathSemantic recursivePatternSemantic = ClientConfigDefault::RECURSIVE_PATTERN_SEMANTIC; + // Scale factor for recursive pattern cardinality estimation. 
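Editor's note on ClientConfig: this struct is the per-session knob set, with `ClientConfigDefault` carrying the defaults listed above, and the `ClientContext` declared further down in this diff hands out a mutable pointer to it via `getClientConfigUnsafe()`. A small tuning sketch follows; `ctx` is an assumed pointer (e.g. obtained from `Connection::getClientContext()`, also later in this diff).

    // Sketch only: adjust a few per-connection settings at runtime.
    #include "main/client_context.h"

    void tune_session(lbug::main::ClientContext* ctx) {
        auto* cfg = ctx->getClientConfigUnsafe();
        cfg->numThreads = 4;            // worker threads for this connection's queries
        cfg->timeoutInMS = 5000;        // 0 (the default TIMEOUT_IN_MS) disables the timeout
        cfg->enableProgressBar = true;  // ENABLE_PROGRESS_BAR defaults to false
        cfg->varLengthMaxDepth = 10;    // default VAR_LENGTH_MAX_DEPTH is 30
    }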
+ uint32_t recursivePatternCardinalityScaleFactor = ClientConfigDefault::RECURSIVE_PATTERN_FACTOR; + // Maximum number of cached warnings + uint64_t warningLimit = ClientConfigDefault::WARNING_LIMIT; + bool disableMapKeyCheck = ClientConfigDefault::DISABLE_MAP_KEY_CHECK; + // If enable plan optimizer + bool enablePlanOptimizer = ClientConfigDefault::ENABLE_PLAN_OPTIMIZER; + // If use internal catalog during binding + bool enableInternalCatalog = ClientConfigDefault::ENABLE_INTERNAL_CATALOG; +}; + +} // namespace main +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/client_context.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/client_context.h new file mode 100644 index 0000000000..044f7c485f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/client_context.h @@ -0,0 +1,256 @@ +#pragma once + +#include +#include +#include +#include + +#include "common/arrow/arrow_result_config.h" +#include "common/timer.h" +#include "common/types/value/value.h" +#include "function/table/scan_replacement.h" +#include "main/client_config.h" +#include "main/prepared_statement_manager.h" +#include "main/query_result.h" +#include "prepared_statement.h" + +namespace lbug { +namespace common { +class RandomEngine; +class TaskScheduler; +class ProgressBar; +class VirtualFileSystem; +} // namespace common + +namespace catalog { +class Catalog; +} + +namespace extension { +class ExtensionManager; +} // namespace extension + +namespace graph { +class GraphEntrySet; +} + +namespace storage { +class StorageManager; +} + +namespace processor { +class ImportDB; +class WarningContext; +} // namespace processor + +namespace transaction { +class TransactionContext; +class Transaction; +} // namespace transaction + +namespace main { +struct DBConfig; +class Database; +class DatabaseManager; +class AttachedLbugDatabase; +struct SpillToDiskSetting; +struct ExtensionOption; +class EmbeddedShell; + +struct ActiveQuery { + explicit ActiveQuery(); + std::atomic interrupted; + common::Timer timer; + + void reset(); +}; + +/** + * @brief Contain client side configuration. We make profiler associated per query, so the profiler + * is not maintained in the client context. + */ +class LBUG_API ClientContext { + friend class Connection; + friend class EmbeddedShell; + friend struct SpillToDiskSetting; + friend class processor::ImportDB; + friend class processor::WarningContext; + friend class transaction::TransactionContext; + friend class common::RandomEngine; + friend class common::ProgressBar; + friend class graph::GraphEntrySet; + +public: + explicit ClientContext(Database* database); + ~ClientContext(); + + // Client config + const ClientConfig* getClientConfig() const { return &clientConfig; } + ClientConfig* getClientConfigUnsafe() { return &clientConfig; } + + // Database config + const DBConfig* getDBConfig() const; + DBConfig* getDBConfigUnsafe() const; + common::Value getCurrentSetting(const std::string& optionName) const; + + // Timer and timeout + void interrupt() { activeQuery.interrupted = true; } + bool interrupted() const { return activeQuery.interrupted; } + bool hasTimeout() const { return clientConfig.timeoutInMS != 0; } + void setQueryTimeOut(uint64_t timeoutInMS); + uint64_t getQueryTimeOut() const; + void startTimer(); + uint64_t getTimeoutRemainingInMS() const; + void resetActiveQuery() { activeQuery.reset(); } + + // Parallelism + void setMaxNumThreadForExec(uint64_t numThreads); + uint64_t getMaxNumThreadForExec() const; + + // Replace function. 
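Editor's note on the timer/timeout block of ClientContext above: cancellation is cooperative. `interrupt()` only flips the atomic flag on `ActiveQuery`, which running work is presumably expected to poll, and `hasTimeout()`/`getTimeoutRemainingInMS()` drive the client-side timeout. The sketch below cancels a long query from a watchdog thread; the query text and the five-second budget are placeholders, and `query()` is the method declared a little further down in this header.

    // Sketch only: cooperative cancellation of a running query.
    #include <chrono>
    #include <thread>
    #include "main/client_context.h"

    void run_with_watchdog(lbug::main::ClientContext* ctx) {
        std::thread watchdog([ctx] {
            std::this_thread::sleep_for(std::chrono::seconds(5));
            ctx->interrupt();  // sets ActiveQuery::interrupted
        });
        // Placeholder statement; query() returns once the query finishes or is interrupted.
        auto result = ctx->query("MATCH (a)-[:KNOWS*1..8]->(b) RETURN count(*)");
        (void)result;  // QueryResult accessors are not part of this header
        watchdog.join();
    }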
+ void addScanReplace(function::ScanReplacement scanReplacement); + std::unique_ptr tryReplaceByName( + const std::string& objectName) const; + std::unique_ptr tryReplaceByHandle( + function::scan_replace_handle_t handle) const; + + // Extension + void setExtensionOption(std::string name, common::Value value); + const ExtensionOption* getExtensionOption(std::string optionName) const; + std::string getExtensionDir() const; + + // Getters. + std::string getDatabasePath() const; + Database* getDatabase() const; + AttachedLbugDatabase* getAttachedDatabase() const; + + const CachedPreparedStatementManager& getCachedPreparedStatementManager() const { + return cachedPreparedStatementManager; + } + + bool isInMemory() const; + + static std::string getEnvVariable(const std::string& name); + static std::string getUserHomeDir(); + + void setDefaultDatabase(AttachedLbugDatabase* defaultDatabase_); + bool hasDefaultDatabase() const; + void setUseInternalCatalogEntry(bool useInternalCatalogEntry) { + this->useInternalCatalogEntry_ = useInternalCatalogEntry; + } + bool useInternalCatalogEntry() const { + return clientConfig.enableInternalCatalog ? true : useInternalCatalogEntry_; + } + + void addScalarFunction(std::string name, function::function_set definitions); + void removeScalarFunction(const std::string& name); + + void cleanUp(); + + struct QueryConfig { + QueryResultType resultType; + common::ArrowResultConfig arrowConfig; + + QueryConfig() : resultType{QueryResultType::FTABLE}, arrowConfig{} {} + QueryConfig(QueryResultType resultType, common::ArrowResultConfig arrowConfig) + : resultType{resultType}, arrowConfig{arrowConfig} {} + }; + + std::unique_ptr query(std::string_view queryStatement, + std::optional queryID = std::nullopt, QueryConfig config = {}); + std::unique_ptr prepareWithParams(std::string_view query, + std::unordered_map> inputParams = {}); + std::unique_ptr executeWithParams(PreparedStatement* preparedStatement, + std::unordered_map> inputParams, + std::optional queryID = std::nullopt); + + struct TransactionHelper { + enum class TransactionCommitAction : uint8_t { + COMMIT_IF_NEW, + COMMIT_IF_AUTO, + COMMIT_NEW_OR_AUTO, + NOT_COMMIT + }; + static bool commitIfNew(TransactionCommitAction action) { + return action == TransactionCommitAction::COMMIT_IF_NEW || + action == TransactionCommitAction::COMMIT_NEW_OR_AUTO; + } + static bool commitIfAuto(TransactionCommitAction action) { + return action == TransactionCommitAction::COMMIT_IF_AUTO || + action == TransactionCommitAction::COMMIT_NEW_OR_AUTO; + } + static TransactionCommitAction getAction(bool commitIfNew, bool commitIfAuto); + static void runFuncInTransaction(transaction::TransactionContext& context, + const std::function& fun, bool readOnlyStatement, bool isTransactionStatement, + TransactionCommitAction action); + }; + +private: + void validateTransaction(bool readOnly, bool requireTransaction) const; + + std::vector> parseQuery(std::string_view query); + + struct PrepareResult { + std::unique_ptr preparedStatement; + std::unique_ptr cachedPreparedStatement; + }; + + PrepareResult prepareNoLock(std::shared_ptr parsedStatement, + bool shouldCommitNewTransaction, + std::unordered_map> inputParams = {}); + + template + std::unique_ptr executeWithParams(PreparedStatement* preparedStatement, + std::unordered_map> params, + std::pair arg, std::pair... 
args) { + auto name = arg.first; + auto val = std::make_unique((T)arg.second); + params.insert({name, std::move(val)}); + return executeWithParams(preparedStatement, std::move(params), args...); + } + + std::unique_ptr executeNoLock(PreparedStatement* preparedStatement, + CachedPreparedStatement* cachedPreparedStatement, + std::optional queryID = std::nullopt, QueryConfig config = {}); + std::unique_ptr queryNoLock(std::string_view query, + std::optional queryID = std::nullopt, QueryConfig config = {}); + + bool canExecuteWriteQuery() const; + + std::unique_ptr handleFailedExecution(std::optional queryID, + const std::exception& e) const; + + std::mutex mtx; + // Client side configurable settings. + ClientConfig clientConfig; + // Current query. + ActiveQuery activeQuery; + // Cache prepare statement. + CachedPreparedStatementManager cachedPreparedStatementManager; + // Transaction context. + std::unique_ptr transactionContext; + // Replace external object as pointer Value; + std::vector scanReplacements; + // Extension configurable settings. + std::unordered_map extensionOptionValues; + // Random generator for UUID. + std::unique_ptr randomEngine; + // Local database. + Database* localDatabase; + // Remote database. + AttachedLbugDatabase* remoteDatabase; + // Progress bar. + std::unique_ptr progressBar; + // Warning information + std::unique_ptr warningContext; + // Graph entries + std::unique_ptr graphEntrySet; + // Whether the query can access internal tables/sequences or not. + bool useInternalCatalogEntry_ = false; + // Whether the transaction should be rolled back on destruction. If the parent database is + // closed, the rollback should be prevented or it will SEGFAULT. + bool preventTransactionRollbackOnDestruction = false; +}; + +} // namespace main +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/connection.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/connection.h new file mode 100644 index 0000000000..fac31a45a3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/connection.h @@ -0,0 +1,161 @@ +#pragma once + +#include "client_context.h" +#include "database.h" +#include "function/udf_function.h" + +namespace lbug { +namespace main { + +/** + * @brief Connection is used to interact with a Database instance. Each Connection is thread-safe. + * Multiple connections can connect to the same Database instance in a multi-threaded environment. + */ +class Connection { + friend class testing::BaseGraphTest; + friend class testing::PrivateGraphTest; + friend class testing::TestHelper; + friend class benchmark::Benchmark; + friend class ConnectionExecuteAsyncWorker; + friend class ConnectionQueryAsyncWorker; + +public: + /** + * @brief Creates a connection to the database. + * @param database A pointer to the database instance that this connection will be connected to. + */ + LBUG_API explicit Connection(Database* database); + /** + * @brief Destructs the connection. + */ + LBUG_API ~Connection(); + /** + * @brief Sets the maximum number of threads to use for execution in the current connection. + * @param numThreads The number of threads to use for execution in the current connection. + */ + LBUG_API void setMaxNumThreadForExec(uint64_t numThreads); + /** + * @brief Returns the maximum number of threads to use for execution in the current connection. + * @return the maximum number of threads to use for execution in the current connection. 
+ */ + LBUG_API uint64_t getMaxNumThreadForExec(); + + /** + * @brief Executes the given query and returns the result. + * @param query The query to execute. + * @return the result of the query. + */ + LBUG_API std::unique_ptr query(std::string_view query); + + LBUG_API std::unique_ptr queryAsArrow(std::string_view query, int64_t chunkSize); + + /** + * @brief Prepares the given query and returns the prepared statement. + * @param query The query to prepare. + * @return the prepared statement. + */ + LBUG_API std::unique_ptr prepare(std::string_view query); + + /** + * @brief Prepares the given query and returns the prepared statement. + * @param query The query to prepare. + * @param inputParams The parameter pack where each arg is a pair with the first element + * being parameter name and second element being parameter value. The only parameters that are + * relevant during prepare are ones that will be substituted with a scan source. Any other + * parameters will either be ignored or will cause an error to be thrown. + * @return the prepared statement. + */ + LBUG_API std::unique_ptr prepareWithParams(std::string_view query, + std::unordered_map> inputParams); + + /** + * @brief Executes the given prepared statement with args and returns the result. + * @param preparedStatement The prepared statement to execute. + * @param args The parameter pack where each arg is a std::pair with the first element being + * parameter name and second element being parameter value. + * @return the result of the query. + */ + template + inline std::unique_ptr execute(PreparedStatement* preparedStatement, + std::pair... args) { + std::unordered_map> inputParameters; + return executeWithParams(preparedStatement, std::move(inputParameters), args...); + } + /** + * @brief Executes the given prepared statement with inputParams and returns the result. + * @param preparedStatement The prepared statement to execute. + * @param inputParams The parameter pack where each arg is a std::pair with the first element + * being parameter name and second element being parameter value. + * @return the result of the query. + */ + LBUG_API std::unique_ptr executeWithParams(PreparedStatement* preparedStatement, + std::unordered_map> inputParams); + /** + * @brief interrupts all queries currently executing within this connection. + */ + LBUG_API void interrupt(); + + /** + * @brief sets the query timeout value of the current connection. A value of zero (the default) + * disables the timeout. 
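Editor's note on the Connection surface documented above: `query()` covers one-shot statements, `prepare()`/`execute()` cover parameterised reuse, and `setQueryTimeOut()`/`interrupt()` bound execution. The sketch below exercises the prepared-statement path; the query text, parameter name and values are placeholders, and iterating the returned `QueryResult` is left out because its accessors are not declared in this header.

    // Sketch only: reuse one prepared statement with different parameter bindings.
    #include "main/connection.h"

    void run_prepared(lbug::main::Connection& conn) {
        conn.setQueryTimeOut(10000);  // 10s budget; 0 would disable the timeout
        auto stmt = conn.prepare("MATCH (p:Person) WHERE p.age > $min_age RETURN p.name");
        for (int64_t minAge : {18, 30, 65}) {
            // Each pair is (parameter name, value); values are wrapped into common::Value
            // internally by ClientContext::executeWithParams.
            auto result = conn.execute(stmt.get(), std::make_pair(std::string("min_age"), minAge));
            (void)result;  // consume the QueryResult here
        }
    }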
+ */ + LBUG_API void setQueryTimeOut(uint64_t timeoutInMS); + + template + void createScalarFunction(std::string name, TR (*udfFunc)(Args...)) { + addScalarFunction(name, function::UDF::getFunction(name, udfFunc)); + } + + template + void createScalarFunction(std::string name, std::vector parameterTypes, + common::LogicalTypeID returnType, TR (*udfFunc)(Args...)) { + addScalarFunction(name, function::UDF::getFunction(name, udfFunc, + std::move(parameterTypes), returnType)); + } + + void addUDFFunctionSet(std::string name, function::function_set func) { + addScalarFunction(name, std::move(func)); + } + + void removeUDFFunction(std::string name) { removeScalarFunction(name); } + + template + void createVectorizedFunction(std::string name, function::scalar_func_exec_t scalarFunc) { + addScalarFunction(name, + function::UDF::getVectorizedFunction(name, std::move(scalarFunc))); + } + + void createVectorizedFunction(std::string name, + std::vector parameterTypes, common::LogicalTypeID returnType, + function::scalar_func_exec_t scalarFunc) { + addScalarFunction(name, function::UDF::getVectorizedFunction(name, std::move(scalarFunc), + std::move(parameterTypes), returnType)); + } + + ClientContext* getClientContext() { return clientContext.get(); }; + +private: + template + std::unique_ptr executeWithParams(PreparedStatement* preparedStatement, + std::unordered_map> params, + std::pair arg, std::pair... args) { + return clientContext->executeWithParams(preparedStatement, std::move(params), arg, args...); + } + + LBUG_API void addScalarFunction(std::string name, function::function_set definitions); + LBUG_API void removeScalarFunction(std::string name); + + std::unique_ptr queryWithID(std::string_view query, uint64_t queryID); + + std::unique_ptr executeWithParamsWithID(PreparedStatement* preparedStatement, + std::unordered_map> inputParams, + uint64_t queryID); + +private: + Database* database; + std::unique_ptr clientContext; + std::shared_ptr dbLifeCycleManager; +}; + +} // namespace main +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/database.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/database.h new file mode 100644 index 0000000000..63a4c75c7b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/database.h @@ -0,0 +1,204 @@ +#pragma once + +#include +#include +#include + +#if defined(__APPLE__) +#include +#endif + +#include "common/api.h" +#include "common/database_lifecycle_manager.h" +#include "lbug_fwd.h" +#include "main/db_config.h" + +namespace lbug { +namespace common { +class FileSystem; +} // namespace common + +namespace extension { +class ExtensionManager; +class TransformerExtension; +class BinderExtension; +class PlannerExtension; +class MapperExtension; +} // namespace extension + +namespace storage { +class StorageExtension; +} // namespace storage + +namespace main { +class DatabaseManager; +/** + * @brief Stores runtime configuration for creating or opening a Database + */ +struct LBUG_API SystemConfig { + /** + * @brief Creates a SystemConfig object. + * @param bufferPoolSize Max size of the buffer pool in bytes. + * The larger the buffer pool, the more data from the database files is kept in memory, + * reducing the amount of File I/O + * @param maxNumThreads The maximum number of threads to use during query execution + * @param enableCompression Whether or not to compress data on-disk for supported types + * @param readOnly If true, the database is opened read-only. 
No write transaction is + * allowed on the `Database` object. Multiple read-only `Database` objects can be created with + * the same database path. If false, the database is opened read-write. Under this mode, + * there must not be multiple `Database` objects created with the same database path. + * @param maxDBSize The maximum size of the database in bytes. Note that this is introduced + * temporarily for now to work around the default 8TB mmap address space limit in some + * environments. This will be removed once we implement a better solution later. The value + * defaults to 1 << 43 (8TB) under 64-bit environments and 1GB under 32-bit ones (see + * `DEFAULT_VM_REGION_MAX_SIZE`). + * @param autoCheckpoint If true, the database will automatically checkpoint when the size of + * the WAL file exceeds the checkpoint threshold. + * @param checkpointThreshold The threshold of the WAL file size in bytes. When the size of the + * WAL file exceeds this threshold, the database will checkpoint if autoCheckpoint is true. + * @param forceCheckpointOnClose If true, the database will force a checkpoint when closing. + * @param throwOnWalReplayFailure If true, any WAL replay failure when loading the database + * will throw an error. Otherwise, Lbug will silently ignore the failure and replay up to where + * the error occurred. + * @param enableChecksums If true, the database will use checksums to detect corruption in the + * WAL file. + */ + explicit SystemConfig(uint64_t bufferPoolSize = -1u, uint64_t maxNumThreads = 0, + bool enableCompression = true, bool readOnly = false, uint64_t maxDBSize = -1u, + bool autoCheckpoint = true, uint64_t checkpointThreshold = 16777216 /* 16MB */, + bool forceCheckpointOnClose = true, bool throwOnWalReplayFailure = true, + bool enableChecksums = true +#if defined(__APPLE__) + , + uint32_t threadQos = QOS_CLASS_DEFAULT +#endif + ); + + uint64_t bufferPoolSize; + uint64_t maxNumThreads; + bool enableCompression; + bool readOnly; + uint64_t maxDBSize; + bool autoCheckpoint; + uint64_t checkpointThreshold; + bool forceCheckpointOnClose; + bool throwOnWalReplayFailure; + bool enableChecksums; +#if defined(__APPLE__) + uint32_t threadQos; +#endif +}; + +/** + * @brief Database class is the main class of Lbug. It manages all database components. + */ +class Database { + friend class EmbeddedShell; + friend class ClientContext; + friend class Connection; + friend class testing::BaseGraphTest; + +public: + /** + * @brief Creates a database object. + * @param databasePath Database path. If left empty, or :memory: is specified, this will create + * an in-memory database. + * @param systemConfig System configurations (buffer pool size and max num threads). + */ + LBUG_API explicit Database(std::string_view databasePath, + SystemConfig systemConfig = SystemConfig()); + /** + * @brief Destructs the database object.
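A hedged sketch of constructing databases with the options documented above; the path and sizes are illustrative only, and the ":memory:" shortcut follows the constructor doc comment:

#include "main/lbug.h"

void openDatabases() {
    using namespace lbug::main;

    // On-disk, read-write database with a 1 GiB buffer pool and 4 worker threads;
    // the remaining SystemConfig fields keep their defaults.
    SystemConfig config(1ull << 30 /* bufferPoolSize */, 4 /* maxNumThreads */);
    Database onDisk("/tmp/example-db", config);   // hypothetical path

    // Per the constructor doc, an empty path or ":memory:" opens an in-memory database.
    Database inMemory(":memory:");
}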
+ */ + LBUG_API ~Database(); + + LBUG_API void registerFileSystem(std::unique_ptr fs); + + LBUG_API void registerStorageExtension(std::string name, + std::unique_ptr storageExtension); + + LBUG_API void addExtensionOption(std::string name, common::LogicalTypeID type, + common::Value defaultValue, bool isConfidential = false); + + LBUG_API void addTransformerExtension( + std::unique_ptr transformerExtension); + + std::vector getTransformerExtensions(); + + LBUG_API void addBinderExtension( + std::unique_ptr transformerExtension); + + std::vector getBinderExtensions(); + + LBUG_API void addPlannerExtension( + std::unique_ptr plannerExtension); + + std::vector getPlannerExtensions(); + + LBUG_API void addMapperExtension(std::unique_ptr mapperExtension); + + std::vector getMapperExtensions(); + + catalog::Catalog* getCatalog() { return catalog.get(); } + + const DBConfig& getConfig() const { return dbConfig; } + + std::vector getStorageExtensions(); + + uint64_t getNextQueryID(); + + storage::StorageManager* getStorageManager() { return storageManager.get(); } + + transaction::TransactionManager* getTransactionManager() { return transactionManager.get(); } + + DatabaseManager* getDatabaseManager() { return databaseManager.get(); } + + storage::MemoryManager* getMemoryManager() { return memoryManager.get(); } + + processor::QueryProcessor* getQueryProcessor() { return queryProcessor.get(); } + + extension::ExtensionManager* getExtensionManager() { return extensionManager.get(); } + + common::VirtualFileSystem* getVFS() { return vfs.get(); } + +private: + using construct_bm_func_t = + std::function(const Database&)>; + + struct QueryIDGenerator { + uint64_t queryID = 0; + std::mutex queryIDLock; + }; + + static std::unique_ptr initBufferManager(const Database& db); + void initMembers(std::string_view dbPath, construct_bm_func_t initBmFunc); + + // factory method only to be used for tests + Database(std::string_view databasePath, SystemConfig systemConfig, + construct_bm_func_t constructBMFunc); + + void validatePathInReadOnly() const; + +private: + std::string databasePath; + DBConfig dbConfig; + std::unique_ptr vfs; + std::unique_ptr bufferManager; + std::unique_ptr memoryManager; + std::unique_ptr queryProcessor; + std::unique_ptr catalog; + std::unique_ptr storageManager; + std::unique_ptr transactionManager; + std::unique_ptr lockFile; + std::unique_ptr databaseManager; + std::unique_ptr extensionManager; + QueryIDGenerator queryIDGenerator; + std::shared_ptr dbLifeCycleManager; + std::vector> transformerExtensions; + std::vector> binderExtensions; + std::vector> plannerExtensions; + std::vector> mapperExtensions; +}; + +} // namespace main +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/database_manager.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/database_manager.h new file mode 100644 index 0000000000..98b728fa82 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/database_manager.h @@ -0,0 +1,30 @@ +#pragma once + +#include "attached_database.h" + +namespace lbug { +namespace main { + +class DatabaseManager { +public: + DatabaseManager(); + + void registerAttachedDatabase(std::unique_ptr attachedDatabase); + bool hasAttachedDatabase(const std::string& name); + LBUG_API AttachedDatabase* getAttachedDatabase(const std::string& name); + void detachDatabase(const std::string& databaseName); + std::string getDefaultDatabase() const { return defaultDatabase; } + bool hasDefaultDatabase() const { return defaultDatabase != ""; } + void 
setDefaultDatabase(const std::string& databaseName); + std::vector getAttachedDatabases() const; + LBUG_API void invalidateCache(); + + LBUG_API static DatabaseManager* Get(const ClientContext& context); + +private: + std::vector> attachedDatabases; + std::string defaultDatabase; +}; + +} // namespace main +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/db_config.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/db_config.h new file mode 100644 index 0000000000..366c4b0c8b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/db_config.h @@ -0,0 +1,80 @@ +#pragma once + +#include + +#include "common/types/value/value.h" + +namespace lbug { +namespace common { +enum class LogicalTypeID : uint8_t; +} // namespace common + +namespace main { + +class ClientContext; +struct SystemConfig; + +typedef void (*set_context)(ClientContext* context, const common::Value& parameter); +typedef common::Value (*get_setting)(const ClientContext* context); + +enum class OptionType : uint8_t { CONFIGURATION = 0, EXTENSION = 1 }; + +struct Option { + std::string name; + common::LogicalTypeID parameterType; + OptionType optionType; + bool isConfidential; + + Option(std::string name, common::LogicalTypeID parameterType, OptionType optionType, + bool isConfidential) + : name{std::move(name)}, parameterType{parameterType}, optionType{optionType}, + isConfidential{isConfidential} {} + + virtual ~Option() = default; +}; + +struct ConfigurationOption final : Option { + set_context setContext; + get_setting getSetting; + + ConfigurationOption(std::string name, common::LogicalTypeID parameterType, + set_context setContext, get_setting getSetting) + : Option{std::move(name), parameterType, OptionType::CONFIGURATION, + false /* isConfidential */}, + setContext{setContext}, getSetting{getSetting} {} +}; + +struct ExtensionOption final : Option { + common::Value defaultValue; + + ExtensionOption(std::string name, common::LogicalTypeID parameterType, + common::Value defaultValue, bool isConfidential) + : Option{std::move(name), parameterType, OptionType::EXTENSION, isConfidential}, + defaultValue{std::move(defaultValue)} {} +}; + +struct DBConfig { + uint64_t bufferPoolSize; + uint64_t maxNumThreads; + bool enableCompression; + bool readOnly; + uint64_t maxDBSize; + bool enableMultiWrites; + bool autoCheckpoint; + uint64_t checkpointThreshold; + bool forceCheckpointOnClose; + bool throwOnWalReplayFailure; + bool enableChecksums; + bool enableSpillingToDisk; +#if defined(__APPLE__) + uint32_t threadQos; +#endif + + explicit DBConfig(const SystemConfig& systemConfig); + + static ConfigurationOption* getOptionByName(const std::string& optionName); + LBUG_API static bool isDBPathInMemory(const std::string& dbPath); +}; + +} // namespace main +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/lbug.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/lbug.h new file mode 100644 index 0000000000..3b67c12e1c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/lbug.h @@ -0,0 +1,22 @@ +#pragma once + +#include "common/types/date_t.h" // IWYU pragma: export +#include "common/types/dtime_t.h" // IWYU pragma: export +#include "common/types/int128_t.h" // IWYU pragma: export +#include "common/types/interval_t.h" // IWYU pragma: export +#include "common/types/timestamp_t.h" // IWYU pragma: export +#include "common/types/types.h" // IWYU pragma: export +#include "common/types/value/nested.h" // IWYU pragma: export +#include 
"common/types/value/node.h" // IWYU pragma: export +#include "common/types/value/recursive_rel.h" // IWYU pragma: export +#include "common/types/value/rel.h" // IWYU pragma: export +#include "common/types/value/value.h" // IWYU pragma: export +#include "main/connection.h" // IWYU pragma: export +#include "main/database.h" // IWYU pragma: export +#include "main/prepared_statement.h" // IWYU pragma: export +#include "main/query_result.h" // IWYU pragma: export +#include "main/query_summary.h" // IWYU pragma: export +#include "main/storage_driver.h" // IWYU pragma: export +#include "main/version.h" // IWYU pragma: export +#include "processor/result/flat_tuple.h" // IWYU pragma: export +#include "storage/storage_version_info.h" // IWYU pragma: export diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/lbug_fwd.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/lbug_fwd.h new file mode 100644 index 0000000000..f366429dd9 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/lbug_fwd.h @@ -0,0 +1,62 @@ +#pragma once + +#include + +namespace lbug { + +namespace testing { +class BaseGraphTest; +class PrivateGraphTest; +class TestHelper; +class TestRunner; +} // namespace testing + +namespace benchmark { +class Benchmark; +} // namespace benchmark + +namespace binder { +class Expression; +class BoundStatementResult; +class PropertyExpression; +} // namespace binder + +namespace catalog { +class Catalog; +} // namespace catalog + +namespace common { +enum class StatementType : uint8_t; +class Value; +struct FileInfo; +class VirtualFileSystem; +} // namespace common + +namespace storage { +class MemoryManager; +class BufferManager; +class StorageManager; +class WAL; +enum class WALReplayMode : uint8_t; +} // namespace storage + +namespace planner { +class LogicalOperator; +class LogicalPlan; +} // namespace planner + +namespace processor { +class QueryProcessor; +class FactorizedTable; +class FlatTupleIterator; +class PhysicalOperator; +class PhysicalPlan; +} // namespace processor + +namespace transaction { +class Transaction; +class TransactionManager; +class TransactionContext; +} // namespace transaction + +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/plan_printer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/plan_printer.h new file mode 100644 index 0000000000..cd39efb92e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/plan_printer.h @@ -0,0 +1,120 @@ +#pragma once + +#include +#include + +#include "common/assert.h" +#include "common/profiler.h" +#include "json_fwd.hpp" +#include "lbug_fwd.h" + +namespace lbug { +namespace main { + +class OpProfileBox { +public: + OpProfileBox(std::string opName, const std::string& paramsName, + std::vector attributes); + + inline std::string getOpName() const { return opName; } + + inline uint32_t getNumParams() const { return paramsNames.size(); } + + std::string getParamsName(uint32_t idx) const; + + std::string getAttribute(uint32_t idx) const; + + inline uint32_t getNumAttributes() const { return attributes.size(); } + + uint32_t getAttributeMaxLen() const; + +private: + std::string opName; + std::vector paramsNames; + std::vector attributes; +}; + +class OpProfileTree { +public: + OpProfileTree(const processor::PhysicalOperator* op, common::Profiler& profiler); + + explicit OpProfileTree(const planner::LogicalOperator* op); + + std::ostringstream printPlanToOstream() const; + + std::ostringstream printLogicalPlanToOstream() const; + +private: + static void 
calculateNumRowsAndColsForOp(const processor::PhysicalOperator* op, + uint32_t& numRows, uint32_t& numCols); + + static void calculateNumRowsAndColsForOp(const planner::LogicalOperator* op, uint32_t& numRows, + uint32_t& numCols); + + uint32_t fillOpProfileBoxes(const processor::PhysicalOperator* op, uint32_t rowIdx, + uint32_t colIdx, uint32_t& maxFieldWidth, common::Profiler& profiler); + + uint32_t fillOpProfileBoxes(const planner::LogicalOperator* op, uint32_t rowIdx, + uint32_t colIdx, uint32_t& maxFieldWidth); + + void printOpProfileBoxUpperFrame(uint32_t rowIdx, std::ostringstream& oss) const; + + void printOpProfileBoxes(uint32_t rowIdx, std::ostringstream& oss) const; + + void printOpProfileBoxLowerFrame(uint32_t rowIdx, std::ostringstream& oss) const; + + void prettyPrintPlanTitle(std::ostringstream& oss, std::string title) const; + + static std::string genHorizLine(uint32_t len); + + inline void validateRowIdxAndColIdx(uint32_t rowIdx, uint32_t colIdx) const { + KU_ASSERT(rowIdx < opProfileBoxes.size() && colIdx < opProfileBoxes[rowIdx].size()); + (void)rowIdx; + (void)colIdx; + } + + void insertOpProfileBox(uint32_t rowIdx, uint32_t colIdx, + std::unique_ptr opProfileBox); + + OpProfileBox* getOpProfileBox(uint32_t rowIdx, uint32_t colIdx) const; + + bool hasOpProfileBox(uint32_t rowIdx, uint32_t colIdx) const { + return rowIdx < opProfileBoxes.size() && colIdx < opProfileBoxes[rowIdx].size() && + getOpProfileBox(rowIdx, colIdx); + } + + //! Returns true if there is a valid OpProfileBox on the upper left side of the OpProfileBox + //! located at (rowIdx, colIdx). + bool hasOpProfileBoxOnUpperLeft(uint32_t rowIdx, uint32_t colIdx) const; + + uint32_t calculateRowHeight(uint32_t rowIdx) const; + +private: + std::vector>> opProfileBoxes; + uint32_t opProfileBoxWidth; + static constexpr uint32_t INDENT_WIDTH = 3u; + static constexpr uint32_t BOX_FRAME_WIDTH = 1u; + static constexpr uint32_t MIN_LOGICAL_BOX_WIDTH = 22u; +}; + +struct PlanPrinter { + static nlohmann::json printPlanToJson(const processor::PhysicalPlan* physicalPlan, + common::Profiler* profiler); + static std::ostringstream printPlanToOstream(const processor::PhysicalPlan* physicalPlan, + common::Profiler* profiler); + static std::string getOperatorName(const processor::PhysicalOperator* physicalOperator); + static std::string getOperatorParams(const processor::PhysicalOperator* physicalOperator); + + static nlohmann::json printPlanToJson(const planner::LogicalPlan* logicalPlan); + static std::ostringstream printPlanToOstream(const planner::LogicalPlan* logicalPlan); + static std::string getOperatorName(const planner::LogicalOperator* logicalOperator); + static std::string getOperatorParams(const planner::LogicalOperator* logicalOperator); + +private: + static nlohmann::json toJson(const processor::PhysicalOperator* physicalOperator, + common::Profiler& profiler_); + static nlohmann::json toJson(const planner::LogicalOperator* logicalOperator); +}; + +} // namespace main +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/prepared_statement.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/prepared_statement.h new file mode 100644 index 0000000000..d166ebc6b6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/prepared_statement.h @@ -0,0 +1,90 @@ +#pragma once + +#include +#include +#include +#include + +#include "common/api.h" +#include "common/types/value/value.h" +#include "query_summary.h" + +namespace lbug { +namespace common { +class LogicalType; +} +namespace 
parser { +class Statement; +} +namespace binder { +class Expression; +} +namespace planner { +class LogicalPlan; +} + +namespace main { + +// Prepared statement cached in client context and NEVER serialized to client side. +struct CachedPreparedStatement { + bool useInternalCatalogEntry = false; + std::shared_ptr parsedStatement; + std::unique_ptr logicalPlan; + std::vector> columns; + + CachedPreparedStatement(); + ~CachedPreparedStatement(); + + std::vector getColumnNames() const; + std::vector getColumnTypes() const; +}; + +/** + * @brief A prepared statement is a parameterized query which can avoid planning the same query for + * repeated execution. + */ +class PreparedStatement { + friend class Connection; + friend class ClientContext; + +public: + LBUG_API ~PreparedStatement(); + /** + * @return the query is prepared successfully or not. + */ + LBUG_API bool isSuccess() const; + /** + * @return the error message if the query is not prepared successfully. + */ + LBUG_API std::string getErrorMessage() const; + /** + * @return the prepared statement is read-only or not. + */ + LBUG_API bool isReadOnly() const; + + const std::unordered_set& getUnknownParameters() const { + return unknownParameters; + } + std::unordered_set getKnownParameters(); + void updateParameter(const std::string& name, common::Value* value); + void addParameter(const std::string& name, common::Value* value); + + std::string getName() const { return cachedPreparedStatementName; } + + common::StatementType getStatementType() const; + + static std::unique_ptr getPreparedStatementWithError( + const std::string& errorMessage); + +private: + bool success = true; + bool readOnly = true; + std::string errMsg; + PreparedSummary preparedSummary; + std::string cachedPreparedStatementName; + std::unordered_set unknownParameters; + std::unordered_map> parameterMap; +}; + +} // namespace main +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/prepared_statement_manager.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/prepared_statement_manager.h new file mode 100644 index 0000000000..30f3149694 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/prepared_statement_manager.h @@ -0,0 +1,30 @@ +#pragma once + +#include +#include +#include + +namespace lbug { +namespace main { + +struct CachedPreparedStatement; + +class CachedPreparedStatementManager { +public: + CachedPreparedStatementManager(); + ~CachedPreparedStatementManager(); + + std::string addStatement(std::unique_ptr statement); + + bool containsStatement(const std::string& name) const { return statementMap.contains(name); } + + CachedPreparedStatement* getCachedStatement(const std::string& name) const; + +private: + std::mutex mtx; + uint32_t currentIdx = 0; + std::unordered_map> statementMap; +}; + +} // namespace main +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/query_result.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/query_result.h new file mode 100644 index 0000000000..5c6ddef465 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/query_result.h @@ -0,0 +1,194 @@ +#pragma once + +#include + +#include "common/api.h" +#include "common/arrow/arrow.h" +#include "common/database_lifecycle_manager.h" +#include "common/types/types.h" +#include "query_summary.h" + +namespace lbug { +namespace processor { +class FlatTuple; +} +namespace main { + +enum class QueryResultType { + FTABLE = 0, + ARROW = 1, +}; + +/** + * @brief QueryResult stores the result of a 
query execution. + */ +class QueryResult { +public: + /** + * @brief Used to create a QueryResult object for the failing query. + */ + LBUG_API QueryResult(); + explicit QueryResult(QueryResultType type); + QueryResult(QueryResultType type, std::vector columnNames, + std::vector columnTypes); + + /** + * @brief Deconstructs the QueryResult object. + */ + LBUG_API virtual ~QueryResult() = 0; + /** + * @return if the query is executed successfully or not. + */ + LBUG_API bool isSuccess() const; + /** + * @return error message of the query execution if the query fails. + */ + LBUG_API std::string getErrorMessage() const; + /** + * @return number of columns in query result. + */ + LBUG_API size_t getNumColumns() const; + /** + * @return name of each column in the query result. + */ + LBUG_API std::vector getColumnNames() const; + /** + * @return dataType of each column in the query result. + */ + LBUG_API std::vector getColumnDataTypes() const; + /** + * @return query summary which stores the execution time, compiling time, plan and query + * options. + */ + LBUG_API QuerySummary* getQuerySummary() const; + QuerySummary* getQuerySummaryUnsafe(); + /** + * @return whether there are more query results to read. + */ + LBUG_API bool hasNextQueryResult() const; + /** + * @return get the next query result to read (for multiple query statements). + */ + LBUG_API QueryResult* getNextQueryResult(); + /** + * @return num of tuples in query result. + */ + LBUG_API virtual uint64_t getNumTuples() const = 0; + /** + * @return whether there are more tuples to read. + */ + LBUG_API virtual bool hasNext() const = 0; + /** + * @return next flat tuple in the query result. Note that to reduce resource allocation, all + * calls to getNext() reuse the same FlatTuple object. Since its contents will be overwritten, + * please complete processing a FlatTuple or make a copy of its data before calling getNext() + * again. + */ + LBUG_API virtual std::shared_ptr getNext() = 0; + /** + * @brief Resets the result tuple iterator. + */ + LBUG_API virtual void resetIterator() = 0; + /** + * @return string of first query result. + */ + LBUG_API virtual std::string toString() const = 0; + /** + * @brief Returns the arrow schema of the query result. + * @return datatypes of the columns as an arrow schema + * + * It is the caller's responsibility to call the release function to release the underlying data + * If converting to another arrow type, this is usually handled automatically. + */ + LBUG_API std::unique_ptr getArrowSchema() const; + /** + * @return whether there are more arrow chunk to read. + */ + LBUG_API virtual bool hasNextArrowChunk() = 0; + /** + * @brief Returns the next chunk of the query result as an arrow array. + * @param chunkSize number of tuples to return in the chunk. + * @return An arrow array representation of the next chunkSize tuples of the query result. + * + * The ArrowArray internally stores an arrow struct with fields for each of the columns. + * This can be converted to a RecordBatch with arrow's ImportRecordBatch function + * + * It is the caller's responsibility to call the release function to release the underlying data + * If converting to another arrow type, this is usually handled automatically. 
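Since the caller owns the returned Arrow structures unless a consumer takes them over (per the notes above), a hedged consumption sketch using the standard Arrow C data interface release callback; the chunk size is arbitrary:

#include "main/lbug.h"

void drainAsArrow(lbug::main::QueryResult& result) {
    auto schema = result.getArrowSchema();
    while (result.hasNextArrowChunk()) {
        auto chunk = result.getNextArrowChunk(2048 /* tuples per chunk */);
        // ... hand `chunk` to an Arrow consumer (e.g. ImportRecordBatch), which then owns it ...
        if (chunk->release != nullptr) {
            chunk->release(chunk.get());   // release only if no consumer took ownership
        }
    }
    if (schema->release != nullptr) {
        schema->release(schema.get());
    }
}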
+ */ + LBUG_API virtual std::unique_ptr getNextArrowChunk(int64_t chunkSize) = 0; + + QueryResultType getType() const { return type; } + + void setColumnNames(std::vector columnNames); + void setColumnTypes(std::vector columnTypes); + + void addNextResult(std::unique_ptr next_); + std::unique_ptr moveNextResult(); + + void setQuerySummary(std::unique_ptr summary); + + void setDBLifeCycleManager( + std::shared_ptr dbLifeCycleManager); + + static std::unique_ptr getQueryResultWithError(const std::string& errorMessage); + + template + TARGET& cast() { + return common::ku_dynamic_cast(*this); + } + template + const TARGET& constCast() const { + return common::ku_dynamic_cast(*this); + } + +protected: + void validateQuerySucceed() const; + void checkDatabaseClosedOrThrow() const; + +protected: + class QueryResultIterator { + public: + QueryResultIterator() = default; + + explicit QueryResultIterator(QueryResult* startResult) : current(startResult) {} + + void operator++() { + if (current) { + current = current->nextQueryResult.get(); + } + } + + bool isEnd() const { return current == nullptr; } + + bool hasNextQueryResult() const { return current->nextQueryResult != nullptr; } + + QueryResult* getCurrentResult() const { return current; } + + private: + QueryResult* current; + }; + + QueryResultType type; + + bool success = true; + + std::string errMsg; + + std::vector columnNames; + + std::vector columnTypes; + + std::shared_ptr tuple; + + std::unique_ptr querySummary; + + std::unique_ptr nextQueryResult; + + QueryResultIterator queryResultIterator; + + std::shared_ptr dbLifeCycleManager; +}; + +} // namespace main +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/query_result/arrow_query_result.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/query_result/arrow_query_result.h new file mode 100644 index 0000000000..0f63fec48a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/query_result/arrow_query_result.h @@ -0,0 +1,43 @@ +#pragma once + +#include "main/query_result.h" +#include "materialized_query_result.h" + +namespace lbug { +namespace main { + +class ArrowQueryResult : public QueryResult { + static constexpr QueryResultType type_ = QueryResultType::ARROW; + +public: + ArrowQueryResult(std::vector arrays, int64_t chunkSize); + ArrowQueryResult(std::vector columnNames, + std::vector columnTypes, processor::FactorizedTable& table, + int64_t chunkSize); + + uint64_t getNumTuples() const override; + + bool hasNext() const override; + + std::shared_ptr getNext() override; + + void resetIterator() override; + + std::string toString() const override; + + bool hasNextArrowChunk() override; + + std::unique_ptr getNextArrowChunk(int64_t chunkSize) override; + +private: + ArrowArray getArray(processor::FactorizedTableIterator& iterator, int64_t chunkSize); + +private: + std::vector arrays; + int64_t chunkSize_; + uint64_t numTuples = 0; + uint64_t cursor = 0; +}; + +} // namespace main +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/query_result/materialized_query_result.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/query_result/materialized_query_result.h new file mode 100644 index 0000000000..122c3cd12c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/query_result/materialized_query_result.h @@ -0,0 +1,46 @@ +#pragma once + +#include "main/query_result.h" + +namespace lbug { +namespace processor { +class FactorizedTable; +class FactorizedTableIterator; +} // namespace processor + 
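The getNext() contract described earlier (a single FlatTuple reused and overwritten on every call) means rows must be copied out before advancing; a minimal sketch, assuming FlatTuple exposes toString() as in the upstream API:

#include <string>
#include <vector>
#include "main/lbug.h"

std::vector<std::string> collectRows(lbug::main::QueryResult& result) {
    std::vector<std::string> rows;
    while (result.hasNext()) {
        auto tuple = result.getNext();
        rows.push_back(tuple->toString());   // copy the row out before the next getNext() call
    }
    return rows;
}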
+namespace main { + +class MaterializedQueryResult : public QueryResult { + static constexpr QueryResultType type_ = QueryResultType::FTABLE; + +public: + MaterializedQueryResult(); + LBUG_API explicit MaterializedQueryResult(std::shared_ptr table); + MaterializedQueryResult(std::vector columnNames, + std::vector columnTypes, + std::shared_ptr table); + ~MaterializedQueryResult() override; + + uint64_t getNumTuples() const override; + + bool hasNext() const override; + + std::shared_ptr getNext() override; + + void resetIterator() override; + + std::string toString() const override; + + bool hasNextArrowChunk() override; + + std::unique_ptr getNextArrowChunk(int64_t chunkSize) override; + + const processor::FactorizedTable& getFactorizedTable() const { return *table; } + +private: + std::shared_ptr table; + std::unique_ptr iterator; +}; + +} // namespace main +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/query_summary.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/query_summary.h new file mode 100644 index 0000000000..bdf51d214e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/query_summary.h @@ -0,0 +1,62 @@ +#pragma once + +#include + +#include "common/api.h" + +namespace lbug { +namespace common { +enum class StatementType : uint8_t; +} + +namespace main { + +/** + * @brief PreparedSummary stores the compiling time and query options of a query. + */ +struct PreparedSummary { // NOLINT(*-pro-type-member-init) + double compilingTime = 0; + common::StatementType statementType; +}; + +/** + * @brief QuerySummary stores the execution time, plan, compiling time and query options of a query. + */ +class QuerySummary { + +public: + QuerySummary() = default; + explicit QuerySummary(const PreparedSummary& preparedSummary) + : preparedSummary{preparedSummary} {} + /** + * @return query compiling time in milliseconds. + */ + LBUG_API double getCompilingTime() const; + /** + * @return query execution time in milliseconds. + */ + LBUG_API double getExecutionTime() const; + + void setExecutionTime(double time); + + void incrementCompilingTime(double increment); + + void incrementExecutionTime(double increment); + + /** + * @return true if the query is executed with EXPLAIN. + */ + bool isExplain() const; + + /** + * @return the statement type of the query. 
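For example, the summary can drive simple client-side timing (a sketch; compiling and execution time are both documented above as milliseconds, and how the total is reported is up to the caller):

#include "main/lbug.h"

double totalLatencyMs(lbug::main::QueryResult& result) {
    auto* summary = result.getQuerySummary();
    return summary->getCompilingTime() + summary->getExecutionTime();
}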
+ */ + common::StatementType getStatementType() const; + +private: + double executionTime = 0; + PreparedSummary preparedSummary; +}; + +} // namespace main +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/settings.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/settings.h new file mode 100644 index 0000000000..46a30ac4b6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/settings.h @@ -0,0 +1,149 @@ +#pragma once + +#include "common/types/value/value.h" + +namespace lbug { +namespace main { + +struct ThreadsSetting { + static constexpr auto name = "threads"; + static constexpr auto inputType = common::LogicalTypeID::UINT64; + static void setContext(ClientContext* context, const common::Value& parameter); + static common::Value getSetting(const ClientContext* context); +}; + +struct WarningLimitSetting { + static constexpr auto name = "warning_limit"; + static constexpr auto inputType = common::LogicalTypeID::UINT64; + static void setContext(ClientContext* context, const common::Value& parameter); + static common::Value getSetting(const ClientContext* context); +}; + +struct TimeoutSetting { + static constexpr auto name = "timeout"; + static constexpr auto inputType = common::LogicalTypeID::UINT64; + static void setContext(ClientContext* context, const common::Value& parameter); + static common::Value getSetting(const ClientContext* context); +}; + +struct ProgressBarSetting { + static constexpr auto name = "progress_bar"; + static constexpr auto inputType = common::LogicalTypeID::BOOL; + static void setContext(ClientContext* context, const common::Value& parameter); + static common::Value getSetting(const ClientContext* context); +}; + +struct VarLengthExtendMaxDepthSetting { + static constexpr auto name = "var_length_extend_max_depth"; + static constexpr auto inputType = common::LogicalTypeID::INT64; + static void setContext(ClientContext* context, const common::Value& parameter); + static common::Value getSetting(const ClientContext* context); +}; + +struct SparseFrontierThresholdSetting { + static constexpr auto name = "sparse_frontier_threshold"; + static constexpr auto inputType = common::LogicalTypeID::INT64; + static void setContext(ClientContext* context, const common::Value& parameter); + static common::Value getSetting(const ClientContext* context); +}; + +struct EnableSemiMaskSetting { + static constexpr auto name = "enable_semi_mask"; + static constexpr auto inputType = common::LogicalTypeID::BOOL; + static void setContext(ClientContext* context, const common::Value& parameter); + static common::Value getSetting(const ClientContext* context); +}; + +struct DisableMapKeyCheck { + static constexpr auto name = "disable_map_key_check"; + static constexpr auto inputType = common::LogicalTypeID::BOOL; + static void setContext(ClientContext* context, const common::Value& parameter); + static common::Value getSetting(const ClientContext* context); +}; + +struct EnableZoneMapSetting { + static constexpr auto name = "enable_zone_map"; + static constexpr auto inputType = common::LogicalTypeID::BOOL; + static void setContext(ClientContext* context, const common::Value& parameter); + static common::Value getSetting(const ClientContext* context); +}; + +struct HomeDirectorySetting { + static constexpr auto name = "home_directory"; + static constexpr auto inputType = common::LogicalTypeID::STRING; + static void setContext(ClientContext* context, const common::Value& parameter); + static common::Value getSetting(const ClientContext* 
context); +}; + +struct FileSearchPathSetting { + static constexpr auto name = "file_search_path"; + static constexpr auto inputType = common::LogicalTypeID::STRING; + static void setContext(ClientContext* context, const common::Value& parameter); + static common::Value getSetting(const ClientContext* context); +}; + +struct RecursivePatternSemanticSetting { + static constexpr auto name = "recursive_pattern_semantic"; + static constexpr auto inputType = common::LogicalTypeID::STRING; + static void setContext(ClientContext* context, const common::Value& parameter); + static common::Value getSetting(const ClientContext* context); +}; + +struct RecursivePatternFactorSetting { + static constexpr auto name = "recursive_pattern_factor"; + static constexpr auto inputType = common::LogicalTypeID::INT64; + static void setContext(ClientContext* context, const common::Value& parameter); + static common::Value getSetting(const ClientContext* context); +}; + +struct EnableMVCCSetting { + static constexpr auto name = "debug_enable_multi_writes"; + static constexpr auto inputType = common::LogicalTypeID::BOOL; + static void setContext(ClientContext* context, const common::Value& parameter); + static common::Value getSetting(const ClientContext* context); +}; + +struct CheckpointThresholdSetting { + static constexpr auto name = "checkpoint_threshold"; + static constexpr auto inputType = common::LogicalTypeID::INT64; + static void setContext(ClientContext* context, const common::Value& parameter); + static common::Value getSetting(const ClientContext* context); +}; + +struct AutoCheckpointSetting { + static constexpr auto name = "auto_checkpoint"; + static constexpr auto inputType = common::LogicalTypeID::BOOL; + static void setContext(ClientContext* context, const common::Value& parameter); + static common::Value getSetting(const ClientContext* context); +}; + +struct ForceCheckpointClosingDBSetting { + static constexpr auto name = "force_checkpoint_on_close"; + static constexpr auto inputType = common::LogicalTypeID::BOOL; + static void setContext(ClientContext* context, const common::Value& parameter); + static common::Value getSetting(const ClientContext* context); +}; + +struct SpillToDiskSetting { + static constexpr auto name = "spill_to_disk"; + static constexpr auto inputType = common::LogicalTypeID::BOOL; + static void setContext(ClientContext* context, const common::Value& parameter); + static common::Value getSetting(const ClientContext* context); +}; + +struct EnableOptimizerSetting { + static constexpr auto name = "enable_plan_optimizer"; + static constexpr auto inputType = common::LogicalTypeID::BOOL; + static void setContext(ClientContext* context, const common::Value& parameter); + static common::Value getSetting(const ClientContext* context); +}; + +struct EnableInternalCatalogSetting { + static constexpr auto name = "enable_internal_catalog"; + static constexpr auto inputType = common::LogicalTypeID::BOOL; + static void setContext(ClientContext* context, const common::Value& parameter); + static common::Value getSetting(const ClientContext* context); +}; + +} // namespace main +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/storage_driver.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/storage_driver.h new file mode 100644 index 0000000000..cd4cbc9409 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/storage_driver.h @@ -0,0 +1,35 @@ +#pragma once + +#include "database.h" + +namespace lbug { +namespace storage { +class Table; +} + 
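The option structs above (threads, timeout, progress_bar, ...) are typically surfaced as session settings; a hedged sketch, where the CALL-style syntax mirrors the upstream project this code is vendored from and is an assumption, not something these headers guarantee:

#include "main/lbug.h"

void tuneSession(lbug::main::Connection& conn) {
    // Assumed surface syntax for the declared option names.
    conn.query("CALL threads=4;");
    conn.query("CALL timeout=60000;");      // milliseconds
    conn.query("CALL progress_bar=true;");
}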
+namespace main { + +class ClientContext; +class LBUG_API StorageDriver { +public: + explicit StorageDriver(Database* database); + + ~StorageDriver(); + + void scan(const std::string& nodeName, const std::string& propertyName, + common::offset_t* offsets, size_t numOffsets, uint8_t* result, size_t numThreads); + + // TODO: Should merge following two functions into a single one. + uint64_t getNumNodes(const std::string& nodeName) const; + uint64_t getNumRels(const std::string& relName) const; + +private: + void scanColumn(storage::Table* table, common::column_id_t columnID, + const common::offset_t* offsets, size_t size, uint8_t* result) const; + +private: + std::unique_ptr clientContext; +}; + +} // namespace main +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/version.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/version.h new file mode 100644 index 0000000000..6b16c21dc7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/main/version.h @@ -0,0 +1,23 @@ +#pragma once +#include + +#include "common/api.h" +namespace lbug { +namespace main { + +struct Version { +public: + /** + * @brief Get the version of the Lbug library. + * @return const char* The version of the Lbug library. + */ + LBUG_API static const char* getVersion(); + + /** + * @brief Get the storage version of the Lbug library. + * @return uint64_t The storage version of the Lbug library. + */ + LBUG_API static uint64_t getStorageVersion(); +}; +} // namespace main +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/acc_hash_join_optimizer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/acc_hash_join_optimizer.h new file mode 100644 index 0000000000..c408a59fb3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/acc_hash_join_optimizer.h @@ -0,0 +1,26 @@ +#pragma once + +#include "logical_operator_visitor.h" +#include "planner/operator/logical_plan.h" + +namespace lbug { +namespace optimizer { + +// This optimizer enables the Accumulated hash join algorithm as introduced in paper "Lbug Graph +// Database Management System". +class HashJoinSIPOptimizer final : public LogicalOperatorVisitor { +public: + void rewrite(const planner::LogicalPlan* plan); + +private: + void visitOperator(planner::LogicalOperator* op); + + void visitHashJoin(planner::LogicalOperator* op) override; + + void visitIntersect(planner::LogicalOperator* op) override; + + void visitPathPropertyProbe(planner::LogicalOperator* op) override; +}; + +} // namespace optimizer +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/agg_key_dependency_optimizer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/agg_key_dependency_optimizer.h new file mode 100644 index 0000000000..211484c495 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/agg_key_dependency_optimizer.h @@ -0,0 +1,26 @@ +#pragma once + +#include "logical_operator_visitor.h" +#include "planner/operator/logical_plan.h" + +namespace lbug { +namespace optimizer { + +// This optimizer analyzes the dependency between group by keys. If key2 depends on key1 (e.g. key1 +// is a primary key column) we only hash on key1 and saves key2 as a payload. 
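A concrete case this rewrite targets (hedged; the node table and properties are hypothetical):

// p.name and p.age are functionally dependent on the primary key p.ID, so the
// aggregate only needs to hash on p.ID and can carry the other group keys as payload.
void dependentGroupKeys(lbug::main::Connection& conn) {
    conn.query("MATCH (p:Person)-[:Follows]->(q:Person) "
               "RETURN p.ID, p.name, p.age, COUNT(*);");
}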
+class AggKeyDependencyOptimizer : public LogicalOperatorVisitor { +public: + void rewrite(planner::LogicalPlan* plan); + +private: + void visitOperator(planner::LogicalOperator* op); + + void visitAggregate(planner::LogicalOperator* op) override; + void visitDistinct(planner::LogicalOperator* op) override; + + std::pair resolveKeysAndDependentKeys( + const binder::expression_vector& keys); +}; + +} // namespace optimizer +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/cardinality_updater.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/cardinality_updater.h new file mode 100644 index 0000000000..ac79b0eb07 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/cardinality_updater.h @@ -0,0 +1,42 @@ +#pragma once + +#include "optimizer/logical_operator_visitor.h" +namespace lbug { +namespace planner { +class LogicalPlan; +class CardinalityEstimator; +} // namespace planner + +namespace transaction { +class Transaction; +} + +namespace optimizer { +class CardinalityUpdater : public LogicalOperatorVisitor { +public: + explicit CardinalityUpdater(const planner::CardinalityEstimator& cardinalityEstimator, + const transaction::Transaction* transaction) + : cardinalityEstimator(cardinalityEstimator), transaction(transaction) {} + + void rewrite(planner::LogicalPlan* plan); + +private: + void visitOperator(planner::LogicalOperator* op); + void visitOperatorSwitchWithDefault(planner::LogicalOperator* op); + + void visitOperatorDefault(planner::LogicalOperator* op); + void visitScanNodeTable(planner::LogicalOperator* op) override; + void visitExtend(planner::LogicalOperator* op) override; + void visitHashJoin(planner::LogicalOperator* op) override; + void visitCrossProduct(planner::LogicalOperator* op) override; + void visitIntersect(planner::LogicalOperator* op) override; + void visitFlatten(planner::LogicalOperator* op) override; + void visitFilter(planner::LogicalOperator* op) override; + void visitAggregate(planner::LogicalOperator* op) override; + void visitLimit(planner::LogicalOperator* op) override; + + const planner::CardinalityEstimator& cardinalityEstimator; + const transaction::Transaction* transaction; +}; +} // namespace optimizer +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/correlated_subquery_unnest_solver.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/correlated_subquery_unnest_solver.h new file mode 100644 index 0000000000..e2fd1b4e5f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/correlated_subquery_unnest_solver.h @@ -0,0 +1,25 @@ +#pragma once + +#include "logical_operator_visitor.h" + +namespace lbug { +namespace optimizer { + +class CorrelatedSubqueryUnnestSolver : public LogicalOperatorVisitor { +public: + explicit CorrelatedSubqueryUnnestSolver(planner::LogicalOperator* accumulateOp) + : accumulateOp{accumulateOp} {} + void solve(planner::LogicalOperator* root_); + +private: + void visitOperator(planner::LogicalOperator* op); + void visitExpressionsScan(planner::LogicalOperator* op) final; + + void solveAccHashJoin(planner::LogicalOperator* op) const; + +private: + planner::LogicalOperator* accumulateOp; +}; + +} // namespace optimizer +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/factorization_rewriter.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/factorization_rewriter.h new file mode 100644 index 0000000000..9af07fe1de --- /dev/null +++ 
b/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/factorization_rewriter.h @@ -0,0 +1,41 @@ +#pragma once + +#include "logical_operator_visitor.h" +#include "planner/operator/logical_plan.h" + +namespace lbug { +namespace optimizer { + +class FactorizationRewriter final : public LogicalOperatorVisitor { +public: + void rewrite(planner::LogicalPlan* plan); + + void visitOperator(planner::LogicalOperator* op); + +private: + void visitHashJoin(planner::LogicalOperator* op) override; + void visitIntersect(planner::LogicalOperator* op) override; + void visitProjection(planner::LogicalOperator* op) override; + void visitAccumulate(planner::LogicalOperator* op) override; + void visitAggregate(planner::LogicalOperator* op) override; + void visitOrderBy(planner::LogicalOperator* op) override; + void visitLimit(planner::LogicalOperator* op) override; + void visitDistinct(planner::LogicalOperator* op) override; + void visitUnwind(planner::LogicalOperator* op) override; + void visitUnion(planner::LogicalOperator* op) override; + void visitFilter(planner::LogicalOperator* op) override; + void visitSetProperty(planner::LogicalOperator* op) override; + void visitDelete(planner::LogicalOperator* op) override; + void visitInsert(planner::LogicalOperator* op) override; + void visitMerge(planner::LogicalOperator* op) override; + void visitCopyTo(planner::LogicalOperator* op) override; + + std::shared_ptr appendFlattens( + std::shared_ptr op, + const std::unordered_set& groupsPos); + std::shared_ptr appendFlattenIfNecessary( + std::shared_ptr op, planner::f_group_pos groupPos); +}; + +} // namespace optimizer +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/filter_push_down_optimizer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/filter_push_down_optimizer.h new file mode 100644 index 0000000000..98b69f4b87 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/filter_push_down_optimizer.h @@ -0,0 +1,93 @@ +#pragma once + +#include "planner/operator/logical_plan.h" + +namespace lbug { +namespace main { +class ClientContext; +} +namespace optimizer { + +struct PredicateSet { + binder::expression_vector equalityPredicates; + binder::expression_vector nonEqualityPredicates; + + PredicateSet() = default; + EXPLICIT_COPY_DEFAULT_MOVE(PredicateSet); + + bool isEmpty() const { return equalityPredicates.empty() && nonEqualityPredicates.empty(); } + void clear() { + equalityPredicates.clear(); + nonEqualityPredicates.clear(); + } + + void addPredicate(std::shared_ptr predicate); + std::shared_ptr popNodePKEqualityComparison( + const binder::Expression& nodeID); + binder::expression_vector getAllPredicates(); + +private: + PredicateSet(const PredicateSet& other) + : equalityPredicates{other.equalityPredicates}, + nonEqualityPredicates{other.nonEqualityPredicates} {} +}; + +class FilterPushDownOptimizer { +public: + explicit FilterPushDownOptimizer(main::ClientContext* context) : context{context} { + predicateSet = PredicateSet(); + } + explicit FilterPushDownOptimizer(main::ClientContext* context, PredicateSet predicateSet) + : predicateSet{std::move(predicateSet)}, context{context} {} + + void rewrite(planner::LogicalPlan* plan); + +private: + std::shared_ptr visitOperator( + const std::shared_ptr& op); + // Collect predicates in FILTER + std::shared_ptr visitFilterReplace( + const std::shared_ptr& op); + // Push primary key lookup into CROSS_PRODUCT + // E.g. 
+ // Filter(a.ID=b.ID) + // CrossProduct to HashJoin + // S(a) S(b) S(a) S(b) + std::shared_ptr visitCrossProductReplace( + const std::shared_ptr& op); + + // Push FILTER into SCAN_NODE_TABLE, and turn index lookup into INDEX_SCAN. + std::shared_ptr visitScanNodeTableReplace( + const std::shared_ptr& op); + // Push Filter into EXTEND. + std::shared_ptr visitExtendReplace( + const std::shared_ptr& op); + // Push Filter into TABLE_FUNCTION_CALL + std::shared_ptr visitTableFunctionCallReplace( + const std::shared_ptr& op); + + // Finish the current push down optimization by apply remaining predicates as a single filter. + // And heuristically reorder equality predicates first in the filter. + std::shared_ptr finishPushDown( + std::shared_ptr op); + std::shared_ptr appendFilters( + const binder::expression_vector& predicates, + std::shared_ptr child); + + std::shared_ptr appendScanNodeTable( + std::shared_ptr nodeID, std::vector nodeTableIDs, + binder::expression_vector properties, std::shared_ptr child); + std::shared_ptr appendFilter( + std::shared_ptr predicate, + std::shared_ptr child); + + std::shared_ptr visitChildren( + const std::shared_ptr& op); + +private: + PredicateSet predicateSet; + main::ClientContext* context; +}; + +} // namespace optimizer +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/limit_push_down_optimizer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/limit_push_down_optimizer.h new file mode 100644 index 0000000000..64d7b9e7be --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/limit_push_down_optimizer.h @@ -0,0 +1,23 @@ +#pragma once + +#include "planner/operator/logical_plan.h" + +namespace lbug { +namespace optimizer { + +class LimitPushDownOptimizer { +public: + LimitPushDownOptimizer() : skipNumber{0}, limitNumber{common::INVALID_LIMIT} {} + + void rewrite(planner::LogicalPlan* plan); + +private: + void visitOperator(planner::LogicalOperator* op); + +private: + common::offset_t skipNumber; + common::offset_t limitNumber; +}; + +} // namespace optimizer +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/logical_operator_collector.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/logical_operator_collector.h new file mode 100644 index 0000000000..610a74a041 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/logical_operator_collector.h @@ -0,0 +1,48 @@ +#pragma once + +#include "logical_operator_visitor.h" + +namespace lbug { +namespace optimizer { + +class LogicalOperatorCollector : public LogicalOperatorVisitor { +public: + ~LogicalOperatorCollector() override = default; + + void collect(planner::LogicalOperator* op); + + bool hasOperators() const { return !ops.empty(); } + const std::vector& getOperators() const { return ops; } + +protected: + std::vector ops; +}; + +class LogicalFlattenCollector final : public LogicalOperatorCollector { +protected: + void visitFlatten(planner::LogicalOperator* op) override { ops.push_back(op); } +}; + +class LogicalFilterCollector final : public LogicalOperatorCollector { +protected: + void visitFilter(planner::LogicalOperator* op) override { ops.push_back(op); } +}; + +class LogicalScanNodeTableCollector final : public LogicalOperatorCollector { +protected: + void visitScanNodeTable(planner::LogicalOperator* op) override { ops.push_back(op); } +}; + +// TODO(Xiyang): Rename me. 
+class LogicalIndexScanNodeCollector final : public LogicalOperatorCollector { +protected: + void visitScanNodeTable(planner::LogicalOperator* op) override; +}; + +class LogicalRecursiveExtendCollector final : public LogicalOperatorCollector { +protected: + void visitRecursiveExtend(planner::LogicalOperator* op) override { ops.push_back(op); } +}; + +} // namespace optimizer +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/logical_operator_visitor.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/logical_operator_visitor.h new file mode 100644 index 0000000000..355f147c6b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/logical_operator_visitor.h @@ -0,0 +1,182 @@ +#pragma once + +#include "planner/operator/logical_operator.h" + +namespace lbug { +namespace optimizer { + +class LogicalOperatorVisitor { +public: + LogicalOperatorVisitor() = default; + virtual ~LogicalOperatorVisitor() = default; + +protected: + void visitOperatorSwitch(planner::LogicalOperator* op); + std::shared_ptr visitOperatorReplaceSwitch( + std::shared_ptr op); + + virtual void visitAccumulate(planner::LogicalOperator* /*op*/) {} + virtual std::shared_ptr visitAccumulateReplace( + std::shared_ptr op) { + return op; + } + + virtual void visitAggregate(planner::LogicalOperator* /*op*/) {} + virtual std::shared_ptr visitAggregateReplace( + std::shared_ptr op) { + return op; + } + + virtual void visitCopyFrom(planner::LogicalOperator* /*op*/) {} + virtual std::shared_ptr visitCopyFromReplace( + std::shared_ptr op) { + return op; + } + + virtual void visitCopyTo(planner::LogicalOperator* /*op*/) {} + virtual std::shared_ptr visitCopyToReplace( + std::shared_ptr op) { + return op; + } + + virtual void visitDelete(planner::LogicalOperator* /*op*/) {} + virtual std::shared_ptr visitDeleteReplace( + std::shared_ptr op) { + return op; + } + + virtual void visitDistinct(planner::LogicalOperator* /*op*/) {} + virtual std::shared_ptr visitDistinctReplace( + std::shared_ptr op) { + return op; + } + + virtual void visitEmptyResult(planner::LogicalOperator*) {} + virtual std::shared_ptr visitEmptyResultReplace( + std::shared_ptr op) { + return op; + } + + virtual void visitExpressionsScan(planner::LogicalOperator* /*op*/) {} + virtual std::shared_ptr visitExpressionsScanReplace( + std::shared_ptr op) { + return op; + } + + virtual void visitExtend(planner::LogicalOperator* /*op*/) {} + virtual std::shared_ptr visitExtendReplace( + std::shared_ptr op) { + return op; + } + + virtual void visitFilter(planner::LogicalOperator* /*op*/) {} + virtual std::shared_ptr visitFilterReplace( + std::shared_ptr op) { + return op; + } + + virtual void visitFlatten(planner::LogicalOperator* /*op*/) {} + virtual std::shared_ptr visitFlattenReplace( + std::shared_ptr op) { + return op; + } + + virtual void visitHashJoin(planner::LogicalOperator* /*op*/) {} + virtual std::shared_ptr visitHashJoinReplace( + std::shared_ptr op) { + return op; + } + + virtual void visitIntersect(planner::LogicalOperator* /*op*/) {} + virtual std::shared_ptr visitIntersectReplace( + std::shared_ptr op) { + return op; + } + + virtual void visitInsert(planner::LogicalOperator* /*op*/) {} + virtual std::shared_ptr visitInsertReplace( + std::shared_ptr op) { + return op; + } + + virtual void visitLimit(planner::LogicalOperator* /*op*/) {} + virtual std::shared_ptr visitLimitReplace( + std::shared_ptr op) { + return op; + } + + virtual void visitMerge(planner::LogicalOperator* /*op*/) {} + virtual 
std::shared_ptr visitMergeReplace( + std::shared_ptr op) { + return op; + } + + virtual void visitNodeLabelFilter(planner::LogicalOperator* /*op*/) {} + virtual std::shared_ptr visitNodeLabelFilterReplace( + std::shared_ptr op) { + return op; + } + + virtual void visitOrderBy(planner::LogicalOperator* /*op*/) {} + virtual std::shared_ptr visitOrderByReplace( + std::shared_ptr op) { + return op; + } + + virtual void visitPathPropertyProbe(planner::LogicalOperator* /*op*/) {} + virtual std::shared_ptr visitPathPropertyProbeReplace( + std::shared_ptr op) { + return op; + } + + virtual void visitProjection(planner::LogicalOperator* /*op*/) {} + virtual std::shared_ptr visitProjectionReplace( + std::shared_ptr op) { + return op; + } + + virtual void visitRecursiveExtend(planner::LogicalOperator*) {} + virtual std::shared_ptr visitRecursiveExtendReplace( + std::shared_ptr op) { + return op; + } + + virtual void visitScanNodeTable(planner::LogicalOperator* /*op*/) {} + virtual std::shared_ptr visitScanNodeTableReplace( + std::shared_ptr op) { + return op; + } + + virtual void visitSetProperty(planner::LogicalOperator*) {} + virtual std::shared_ptr visitSetPropertyReplace( + std::shared_ptr op) { + return op; + } + + virtual void visitTableFunctionCall(planner::LogicalOperator*) {} + virtual std::shared_ptr visitTableFunctionCallReplace( + std::shared_ptr op) { + return op; + } + + virtual void visitUnion(planner::LogicalOperator* /*op*/) {} + virtual std::shared_ptr visitUnionReplace( + std::shared_ptr op) { + return op; + } + + virtual void visitUnwind(planner::LogicalOperator* /*op*/) {} + virtual std::shared_ptr visitUnwindReplace( + std::shared_ptr op) { + return op; + } + + virtual void visitCrossProduct(planner::LogicalOperator* /*op*/) {} + virtual std::shared_ptr visitCrossProductReplace( + std::shared_ptr op) { + return op; + } +}; + +} // namespace optimizer +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/optimizer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/optimizer.h new file mode 100644 index 0000000000..967f32102f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/optimizer.h @@ -0,0 +1,23 @@ +#pragma once + +#include "planner/operator/logical_plan.h" + +namespace lbug { +namespace main { +class ClientContext; +} + +namespace planner { +class CardinalityEstimator; +} + +namespace optimizer { + +class Optimizer { +public: + static void optimize(planner::LogicalPlan* plan, main::ClientContext* context, + const planner::CardinalityEstimator& cardinalityEstimator); +}; + +} // namespace optimizer +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/projection_push_down_optimizer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/projection_push_down_optimizer.h new file mode 100644 index 0000000000..4a33d9ebe6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/projection_push_down_optimizer.h @@ -0,0 +1,68 @@ +#pragma once + +#include "common/enums/path_semantic.h" +#include "logical_operator_visitor.h" +#include "planner/operator/logical_plan.h" + +namespace lbug { +namespace main { +class ClientContext; +} +namespace binder { +struct BoundSetPropertyInfo; +} +namespace planner { +struct LogicalInsertInfo; +} +namespace optimizer { + +// ProjectionPushDownOptimizer implements the logic to avoid materializing unnecessary properties +// for hash join build. 
+// Note the optimization is for properties & variables only, not for general expressions. This is +// because it is hard to figure out which expression is in use: e.g. for COUNT(a.age) + 1, it could be +// that the whole expression was evaluated in a WITH clause, or only COUNT(a.age), or only a.age. +// For simplicity, we only consider push down for properties. +class ProjectionPushDownOptimizer : public LogicalOperatorVisitor { +public: + void rewrite(planner::LogicalPlan* plan); + explicit ProjectionPushDownOptimizer(common::PathSemantic semantic) : semantic(semantic){}; + +private: + void visitOperator(planner::LogicalOperator* op); + + void visitPathPropertyProbe(planner::LogicalOperator* op) override; + void visitExtend(planner::LogicalOperator* op) override; + void visitAccumulate(planner::LogicalOperator* op) override; + void visitFilter(planner::LogicalOperator* op) override; + void visitNodeLabelFilter(planner::LogicalOperator* op) override; + void visitHashJoin(planner::LogicalOperator* op) override; + void visitIntersect(planner::LogicalOperator* op) override; + void visitProjection(planner::LogicalOperator* op) override; + void visitOrderBy(planner::LogicalOperator* op) override; + void visitUnwind(planner::LogicalOperator* op) override; + void visitSetProperty(planner::LogicalOperator* op) override; + void visitInsert(planner::LogicalOperator* op) override; + void visitDelete(planner::LogicalOperator* op) override; + void visitMerge(planner::LogicalOperator* op) override; + void visitCopyFrom(planner::LogicalOperator* op) override; + void visitTableFunctionCall(planner::LogicalOperator*) override; + + void visitSetInfo(const binder::BoundSetPropertyInfo& info); + void visitInsertInfo(const planner::LogicalInsertInfo& info); + + void collectExpressionsInUse(std::shared_ptr<binder::Expression> expression); + + binder::expression_vector pruneExpressions(const binder::expression_vector& expressions); + + void preAppendProjection(planner::LogicalOperator* op, common::idx_t childIdx, + binder::expression_vector expressions); + +private: + binder::expression_set propertiesInUse; + binder::expression_set variablesInUse; + binder::expression_set nodeOrRelInUse; + common::PathSemantic semantic; +}; + +} // namespace optimizer +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/remove_factorization_rewriter.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/remove_factorization_rewriter.h new file mode 100644 index 0000000000..2bd51e8426 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/remove_factorization_rewriter.h @@ -0,0 +1,22 @@ +#pragma once + +#include "logical_operator_visitor.h" +#include "planner/operator/logical_plan.h" + +namespace lbug { +namespace optimizer { + +class RemoveFactorizationRewriter : public LogicalOperatorVisitor { +public: + void rewrite(planner::LogicalPlan* plan); + + std::shared_ptr<planner::LogicalOperator> visitOperator( + const std::shared_ptr<planner::LogicalOperator>& op); + +private: + std::shared_ptr<planner::LogicalOperator> visitFlattenReplace( + std::shared_ptr<planner::LogicalOperator> op) override; +}; + +} // namespace optimizer +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/remove_unnecessary_join_optimizer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/remove_unnecessary_join_optimizer.h new file mode 100644 index 0000000000..25509ef882 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/remove_unnecessary_join_optimizer.h @@ -0,0 +1,33 @@ +#pragma once + +#include "logical_operator_visitor.h"
+#include "planner/operator/logical_plan.h" + +namespace lbug { +namespace optimizer { + +/* Due to the nature of graph pattern, a (node)-[rel]-(node) is always interpreted as two joins. + * However, in many cases, a single join is sufficient. + * E.g. MATCH (a)-[e]->(b) RETURN e.date + * Our planner will generate a plan where the HJ is redundant. + * HJ + * / \ + * E(e) S(b) + * | + * S(a) + * This optimizer prunes such redundant joins. + */ +class RemoveUnnecessaryJoinOptimizer : public LogicalOperatorVisitor { +public: + void rewrite(planner::LogicalPlan* plan); + +private: + std::shared_ptr visitOperator( + const std::shared_ptr& op); + + std::shared_ptr visitHashJoinReplace( + std::shared_ptr op) override; +}; + +} // namespace optimizer +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/schema_populator.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/schema_populator.h new file mode 100644 index 0000000000..c2e5347cf8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/schema_populator.h @@ -0,0 +1,12 @@ +#pragma once + +#include "optimizer/logical_operator_visitor.h" +#include "planner/operator/logical_plan.h" +namespace lbug { +namespace optimizer { +class SchemaPopulator : public LogicalOperatorVisitor { +public: + void rewrite(planner::LogicalPlan* plan); +}; +} // namespace optimizer +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/top_k_optimizer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/top_k_optimizer.h new file mode 100644 index 0000000000..4a549f29f4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/optimizer/top_k_optimizer.h @@ -0,0 +1,22 @@ +#pragma once + +#include "logical_operator_visitor.h" +#include "planner/operator/logical_plan.h" + +namespace lbug { +namespace optimizer { + +class TopKOptimizer : public LogicalOperatorVisitor { +public: + void rewrite(planner::LogicalPlan* plan); + + std::shared_ptr visitOperator( + const std::shared_ptr& op); + +private: + std::shared_ptr visitLimitReplace( + std::shared_ptr op) override; +}; + +} // namespace optimizer +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/antlr_parser/lbug_cypher_parser.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/antlr_parser/lbug_cypher_parser.h new file mode 100644 index 0000000000..754d53f358 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/antlr_parser/lbug_cypher_parser.h @@ -0,0 +1,32 @@ +#pragma once + +// ANTLR4 generates code with unused parameters. 
+#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#include "cypher_parser.h" +#pragma GCC diagnostic pop + +namespace lbug { +namespace parser { + +class LbugCypherParser : public CypherParser { + +public: + explicit LbugCypherParser(antlr4::TokenStream* input) : CypherParser(input) {} + + void notifyQueryNotConcludeWithReturn(antlr4::Token* startToken) override; + + void notifyNodePatternWithoutParentheses(std::string nodeName, + antlr4::Token* startToken) override; + + void notifyInvalidNotEqualOperator(antlr4::Token* startToken) override; + + void notifyEmptyToken(antlr4::Token* startToken) override; + + void notifyReturnNotAtEnd(antlr4::Token* startToken) override; + + void notifyNonBinaryComparison(antlr4::Token* startToken) override; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/antlr_parser/parser_error_listener.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/antlr_parser/parser_error_listener.h new file mode 100644 index 0000000000..f9738384a2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/antlr_parser/parser_error_listener.h @@ -0,0 +1,22 @@ +#pragma once + +#include + +#include "antlr4-runtime.h" // IWYU pragma: keep. This is the public header. + +namespace lbug { +namespace parser { + +class ParserErrorListener : public antlr4::BaseErrorListener { + +public: + void syntaxError(antlr4::Recognizer* recognizer, antlr4::Token* offendingSymbol, size_t line, + size_t charPositionInLine, const std::string& msg, std::exception_ptr e) override; + +private: + std::string formatUnderLineError(antlr4::Recognizer& recognizer, + const antlr4::Token& offendingToken, size_t line, size_t charPositionInLine); +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/antlr_parser/parser_error_strategy.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/antlr_parser/parser_error_strategy.h new file mode 100644 index 0000000000..2e219c95f2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/antlr_parser/parser_error_strategy.h @@ -0,0 +1,16 @@ +#pragma once + +#include "antlr4-runtime.h" // IWYU pragma: keep; this is the public header. 
+ +namespace lbug { +namespace parser { + +class ParserErrorStrategy : public antlr4::DefaultErrorStrategy { + +protected: + void reportNoViableAlternative(antlr4::Parser* recognizer, + const antlr4::NoViableAltException& e) override; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/attach_database.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/attach_database.h new file mode 100644 index 0000000000..641b9d7474 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/attach_database.h @@ -0,0 +1,21 @@ +#pragma once + +#include "parsed_data/attach_info.h" +#include "parser/statement.h" + +namespace lbug { +namespace parser { + +class AttachDatabase final : public Statement { +public: + explicit AttachDatabase(AttachInfo attachInfo) + : Statement{common::StatementType::ATTACH_DATABASE}, attachInfo{std::move(attachInfo)} {} + + const AttachInfo& getAttachInfo() const { return attachInfo; } + +private: + AttachInfo attachInfo; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/copy.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/copy.h new file mode 100644 index 0000000000..cdfb6dac64 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/copy.h @@ -0,0 +1,70 @@ +#pragma once + +#include + +#include "parser/expression/parsed_expression.h" +#include "parser/scan_source.h" +#include "parser/statement.h" + +namespace lbug { +namespace parser { + +class Copy : public Statement { +public: + explicit Copy(common::StatementType type) : Statement{type} {} + + void setParsingOption(options_t options) { parsingOptions = std::move(options); } + const options_t& getParsingOptions() const { return parsingOptions; } + +protected: + options_t parsingOptions; +}; + +struct CopyFromColumnInfo { + bool inputColumnOrder = false; + std::vector columnNames; + + CopyFromColumnInfo() = default; + CopyFromColumnInfo(bool inputColumnOrder, std::vector columnNames) + : inputColumnOrder{inputColumnOrder}, columnNames{std::move(columnNames)} {} +}; + +class CopyFrom : public Copy { +public: + CopyFrom(std::unique_ptr source, std::string tableName) + : Copy{common::StatementType::COPY_FROM}, byColumn_{false}, source{std::move(source)}, + tableName{std::move(tableName)} {} + + void setByColumn() { byColumn_ = true; } + bool byColumn() const { return byColumn_; } + + BaseScanSource* getSource() const { return source.get(); } + + std::string getTableName() const { return tableName; } + + void setColumnInfo(CopyFromColumnInfo columnInfo_) { columnInfo = std::move(columnInfo_); } + CopyFromColumnInfo getCopyColumnInfo() const { return columnInfo; } + +private: + bool byColumn_; + std::unique_ptr source; + std::string tableName; + CopyFromColumnInfo columnInfo; +}; + +class CopyTo : public Copy { +public: + CopyTo(std::string filePath, std::unique_ptr statement) + : Copy{common::StatementType::COPY_TO}, filePath{std::move(filePath)}, + statement{std::move(statement)} {} + + std::string getFilePath() const { return filePath; } + const Statement* getStatement() const { return statement.get(); } + +private: + std::string filePath; + std::unique_ptr statement; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/create_macro.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/create_macro.h new file mode 100644 index 0000000000..d55a76bfc4 --- /dev/null +++ 
b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/create_macro.h @@ -0,0 +1,37 @@ +#pragma once + +#include "parser/expression/parsed_expression.h" +#include "parser/statement.h" + +namespace lbug { +namespace parser { + +using default_macro_args = std::vector>>; + +class CreateMacro final : public Statement { + static constexpr common::StatementType type_ = common::StatementType::CREATE_MACRO; + +public: + CreateMacro(std::string macroName, std::unique_ptr macroExpression, + std::vector positionalArgs, default_macro_args defaultArgs) + : Statement{type_}, macroName{std::move(macroName)}, + macroExpression{std::move(macroExpression)}, positionalArgs{std::move(positionalArgs)}, + defaultArgs{std::move(defaultArgs)} {} + + std::string getMacroName() const { return macroName; } + + ParsedExpression* getMacroExpression() const { return macroExpression.get(); } + + std::vector getPositionalArgs() const { return positionalArgs; } + + std::vector> getDefaultArgs() const; + +public: + std::string macroName; + std::unique_ptr macroExpression; + std::vector positionalArgs; + default_macro_args defaultArgs; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/database_statement.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/database_statement.h new file mode 100644 index 0000000000..0d199c837b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/database_statement.h @@ -0,0 +1,22 @@ +#pragma once + +#include + +#include "parser/statement.h" + +namespace lbug { +namespace parser { + +class DatabaseStatement : public Statement { +public: + explicit DatabaseStatement(common::StatementType type, std::string dbName) + : Statement{type}, dbName{std::move(dbName)} {} + + std::string getDBName() const { return dbName; } + +private: + std::string dbName; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/ddl/alter.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/ddl/alter.h new file mode 100644 index 0000000000..e6003dcdc4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/ddl/alter.h @@ -0,0 +1,22 @@ +#pragma once + +#include "alter_info.h" +#include "parser/statement.h" + +namespace lbug { +namespace parser { + +class Alter : public Statement { + static constexpr common::StatementType type_ = common::StatementType::ALTER; + +public: + explicit Alter(AlterInfo info) : Statement{type_}, info{std::move(info)} {} + + const AlterInfo* getInfo() const { return &info; } + +private: + AlterInfo info; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/ddl/alter_info.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/ddl/alter_info.h new file mode 100644 index 0000000000..c4c9fd7809 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/ddl/alter_info.h @@ -0,0 +1,87 @@ +#pragma once + +#include + +#include "common/copy_constructors.h" +#include "common/enums/alter_type.h" +#include "common/enums/conflict_action.h" +#include "parser/expression/parsed_expression.h" + +namespace lbug { +namespace parser { + +struct ExtraAlterInfo { + virtual ~ExtraAlterInfo() = default; + + template + const TARGET* constPtrCast() const { + return common::ku_dynamic_cast(this); + } + template + TARGET* ptrCast() { + return common::ku_dynamic_cast(this); + } +}; + +struct AlterInfo { + common::AlterType type; + std::string tableName; + std::unique_ptr extraInfo; + 
common::ConflictAction onConflict; + + AlterInfo(common::AlterType type, std::string tableName, + std::unique_ptr extraInfo, + common::ConflictAction onConflict = common::ConflictAction::ON_CONFLICT_THROW) + : type{type}, tableName{std::move(tableName)}, extraInfo{std::move(extraInfo)}, + onConflict{onConflict} {} + DELETE_COPY_DEFAULT_MOVE(AlterInfo); +}; + +struct ExtraRenameTableInfo : public ExtraAlterInfo { + std::string newName; + + explicit ExtraRenameTableInfo(std::string newName) : newName{std::move(newName)} {} +}; + +struct ExtraAddFromToConnection : public ExtraAlterInfo { + std::string srcTableName; + std::string dstTableName; + + explicit ExtraAddFromToConnection(std::string srcTableName, std::string dstTableName) + : srcTableName{std::move(srcTableName)}, dstTableName{std::move(dstTableName)} {} +}; + +struct ExtraAddPropertyInfo : public ExtraAlterInfo { + std::string propertyName; + std::string dataType; + std::unique_ptr defaultValue; + + ExtraAddPropertyInfo(std::string propertyName, std::string dataType, + std::unique_ptr defaultValue) + : propertyName{std::move(propertyName)}, dataType{std::move(dataType)}, + defaultValue{std::move(defaultValue)} {} +}; + +struct ExtraDropPropertyInfo : public ExtraAlterInfo { + std::string propertyName; + + explicit ExtraDropPropertyInfo(std::string propertyName) + : propertyName{std::move(propertyName)} {} +}; + +struct ExtraRenamePropertyInfo : public ExtraAlterInfo { + std::string propertyName; + std::string newName; + + ExtraRenamePropertyInfo(std::string propertyName, std::string newName) + : propertyName{std::move(propertyName)}, newName{std::move(newName)} {} +}; + +struct ExtraCommentInfo : public ExtraAlterInfo { + std::string comment; + + explicit ExtraCommentInfo(std::string comment) : comment{std::move(comment)} {} +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/ddl/create_sequence.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/ddl/create_sequence.h new file mode 100644 index 0000000000..1710001de2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/ddl/create_sequence.h @@ -0,0 +1,22 @@ +#pragma once + +#include "create_sequence_info.h" +#include "parser/statement.h" + +namespace lbug { +namespace parser { + +class CreateSequence final : public Statement { + static constexpr common::StatementType type_ = common::StatementType::CREATE_SEQUENCE; + +public: + explicit CreateSequence(CreateSequenceInfo info) : Statement{type_}, info{std::move(info)} {} + + CreateSequenceInfo getInfo() const { return info.copy(); } + +private: + CreateSequenceInfo info; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/ddl/create_sequence_info.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/ddl/create_sequence_info.h new file mode 100644 index 0000000000..ac8c4c9b90 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/ddl/create_sequence_info.h @@ -0,0 +1,41 @@ +#pragma once + +#include + +#include "common/copy_constructors.h" +#include "common/enums/conflict_action.h" + +namespace lbug { +namespace parser { + +enum class SequenceInfoType { + START, + INCREMENT, + MINVALUE, + MAXVALUE, + CYCLE, + INVALID, +}; + +struct CreateSequenceInfo { + std::string sequenceName; + std::string startWith = ""; + std::string increment = "1"; + std::string minValue = ""; + std::string maxValue = ""; + bool cycle = false; + common::ConflictAction onConflict; + + explicit 
CreateSequenceInfo(std::string sequenceName, common::ConflictAction onConflict) + : sequenceName{std::move(sequenceName)}, onConflict{onConflict} {} + EXPLICIT_COPY_DEFAULT_MOVE(CreateSequenceInfo); + +private: + CreateSequenceInfo(const CreateSequenceInfo& other) + : sequenceName{other.sequenceName}, startWith{other.startWith}, increment{other.increment}, + minValue{other.minValue}, maxValue{other.maxValue}, cycle{other.cycle}, + onConflict{other.onConflict} {} +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/ddl/create_table.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/ddl/create_table.h new file mode 100644 index 0000000000..a1c160c36a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/ddl/create_table.h @@ -0,0 +1,28 @@ +#pragma once + +#include "create_table_info.h" +#include "parser/scan_source.h" +#include "parser/statement.h" + +namespace lbug { +namespace parser { + +class CreateTable final : public Statement { + static constexpr common::StatementType type_ = common::StatementType::CREATE_TABLE; + +public: + explicit CreateTable(CreateTableInfo info) : Statement{type_}, info{std::move(info)} {} + + CreateTable(CreateTableInfo info, std::unique_ptr&& source) + : Statement{type_}, info{std::move(info)}, source{std::move(source)} {} + + const CreateTableInfo* getInfo() const { return &info; } + const QueryScanSource* getSource() const { return source.get(); } + +private: + CreateTableInfo info; + std::unique_ptr source; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/ddl/create_table_info.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/ddl/create_table_info.h new file mode 100644 index 0000000000..b9fa5265f4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/ddl/create_table_info.h @@ -0,0 +1,55 @@ +#pragma once + +#include +#include +#include +#include + +#include "common/enums/conflict_action.h" +#include "common/enums/table_type.h" +#include "parsed_property_definition.h" + +namespace lbug { +namespace parser { + +struct ExtraCreateTableInfo { + virtual ~ExtraCreateTableInfo() = default; + + template + const TARGET& constCast() const { + return common::ku_dynamic_cast(*this); + } +}; + +struct CreateTableInfo { + common::TableType type; + std::string tableName; + std::vector propertyDefinitions; + std::unique_ptr extraInfo; + common::ConflictAction onConflict; + + CreateTableInfo(common::TableType type, std::string tableName, + common::ConflictAction onConflict) + : type{type}, tableName{std::move(tableName)}, extraInfo{nullptr}, onConflict{onConflict} {} + DELETE_COPY_DEFAULT_MOVE(CreateTableInfo); +}; + +struct ExtraCreateNodeTableInfo final : ExtraCreateTableInfo { + std::string pKName; + + explicit ExtraCreateNodeTableInfo(std::string pKName) : pKName{std::move(pKName)} {} +}; + +struct ExtraCreateRelTableGroupInfo final : ExtraCreateTableInfo { + std::string relMultiplicity; + std::vector> srcDstTablePairs; + options_t options; + + ExtraCreateRelTableGroupInfo(std::string relMultiplicity, + std::vector> srcDstTablePairs, options_t options) + : relMultiplicity{std::move(relMultiplicity)}, + srcDstTablePairs{std::move(srcDstTablePairs)}, options{std::move(options)} {} +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/ddl/create_type.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/ddl/create_type.h new file mode 100644 index 
0000000000..806831355a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/ddl/create_type.h @@ -0,0 +1,25 @@ +#pragma once + +#include "parser/statement.h" + +namespace lbug { +namespace parser { + +class CreateType final : public Statement { + static constexpr common::StatementType type_ = common::StatementType::CREATE_TYPE; + +public: + CreateType(std::string name, std::string dataType) + : Statement{type_}, name{std::move(name)}, dataType{std::move(dataType)} {} + + std::string getName() const { return name; } + + std::string getDataType() const { return dataType; } + +private: + std::string name; + std::string dataType; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/ddl/drop.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/ddl/drop.h new file mode 100644 index 0000000000..89d7656ba0 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/ddl/drop.h @@ -0,0 +1,22 @@ +#pragma once + +#include "drop_info.h" +#include "parser/statement.h" + +namespace lbug { +namespace parser { + +class Drop : public Statement { + static constexpr common::StatementType type_ = common::StatementType::DROP; + +public: + explicit Drop(DropInfo dropInfo) : Statement{type_}, dropInfo{std::move(dropInfo)} {} + + const DropInfo& getDropInfo() const { return dropInfo; } + +private: + DropInfo dropInfo; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/ddl/drop_info.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/ddl/drop_info.h new file mode 100644 index 0000000000..87bf438fd2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/ddl/drop_info.h @@ -0,0 +1,17 @@ +#pragma once +#include + +#include "common/enums/conflict_action.h" +#include "common/enums/drop_type.h" + +namespace lbug { +namespace parser { + +struct DropInfo { + std::string name; + common::DropType dropType; + common::ConflictAction conflictAction; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/ddl/parsed_property_definition.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/ddl/parsed_property_definition.h new file mode 100644 index 0000000000..510197952e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/ddl/parsed_property_definition.h @@ -0,0 +1,43 @@ +#pragma once + +#include "parser/expression/parsed_expression.h" + +namespace lbug { +namespace parser { + +struct ParsedColumnDefinition { + std::string name; + std::string type; + + ParsedColumnDefinition(std::string name, std::string type) + : name{std::move(name)}, type{std::move(type)} {} + EXPLICIT_COPY_DEFAULT_MOVE(ParsedColumnDefinition); + +private: + ParsedColumnDefinition(const ParsedColumnDefinition& other) + : name{other.name}, type{other.type} {} +}; + +struct ParsedPropertyDefinition { + ParsedColumnDefinition columnDefinition; + std::unique_ptr defaultExpr; + + ParsedPropertyDefinition(ParsedColumnDefinition columnDefinition, + std::unique_ptr defaultExpr) + : columnDefinition{std::move(columnDefinition)}, defaultExpr{std::move(defaultExpr)} {} + EXPLICIT_COPY_DEFAULT_MOVE(ParsedPropertyDefinition); + + std::string getName() const { return columnDefinition.name; } + std::string getType() const { return columnDefinition.type; } + +private: + ParsedPropertyDefinition(const ParsedPropertyDefinition& other) + : columnDefinition{other.columnDefinition.copy()} { + if (other.defaultExpr) { + defaultExpr = 
other.defaultExpr->copy(); + } + } +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/detach_database.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/detach_database.h new file mode 100644 index 0000000000..dd5a917d2f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/detach_database.h @@ -0,0 +1,18 @@ +#pragma once + +#include + +#include "parser/database_statement.h" + +namespace lbug { +namespace parser { + +class DetachDatabase final : public DatabaseStatement { + static constexpr common::StatementType type_ = common::StatementType::DETACH_DATABASE; + +public: + explicit DetachDatabase(std::string dbName) : DatabaseStatement{type_, std::move(dbName)} {} +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/explain_statement.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/explain_statement.h new file mode 100644 index 0000000000..939b870694 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/explain_statement.h @@ -0,0 +1,27 @@ +#pragma once + +#include + +#include "common/enums/explain_type.h" +#include "parser/statement.h" + +namespace lbug { +namespace parser { + +class ExplainStatement : public Statement { +public: + ExplainStatement(std::unique_ptr statementToExplain, common::ExplainType explainType) + : Statement{common::StatementType::EXPLAIN}, + statementToExplain{std::move(statementToExplain)}, explainType{explainType} {} + + inline Statement* getStatementToExplain() const { return statementToExplain.get(); } + + inline common::ExplainType getExplainType() const { return explainType; } + +private: + std::unique_ptr statementToExplain; + common::ExplainType explainType; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/expression/parsed_case_expression.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/expression/parsed_case_expression.h new file mode 100644 index 0000000000..ba2ca09dd6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/expression/parsed_case_expression.h @@ -0,0 +1,93 @@ +#pragma once + +#include "common/copy_constructors.h" +#include "parsed_expression.h" + +namespace lbug { +namespace parser { + +struct ParsedCaseAlternative { + std::unique_ptr whenExpression; + std::unique_ptr thenExpression; + + ParsedCaseAlternative() = default; + ParsedCaseAlternative(std::unique_ptr whenExpression, + std::unique_ptr thenExpression) + : whenExpression{std::move(whenExpression)}, thenExpression{std::move(thenExpression)} {} + ParsedCaseAlternative(const ParsedCaseAlternative& other) + : whenExpression{other.whenExpression->copy()}, + thenExpression{other.thenExpression->copy()} {} + DEFAULT_BOTH_MOVE(ParsedCaseAlternative); + + void serialize(common::Serializer& serializer) const; + static ParsedCaseAlternative deserialize(common::Deserializer& deserializer); +}; + +// Cypher supports 2 types of CaseExpression +// 1. CASE a.age +// WHEN 20 THEN ... +// 2. CASE +// WHEN a.age = 20 THEN ... 
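// Illustrative sketch (not part of the vendored lbug sources): how the two syntactic forms
// map onto the ParsedCaseExpression API declared below. makeExpr is a hypothetical helper
// that builds a ParsedExpression from a snippet; the constructor, setCaseExpression and
// addCaseAlternative calls are the API from this header.
static std::unique_ptr<ParsedCaseExpression> buildExampleCase(bool withInputExpression) {
    auto caseExpr = std::make_unique<ParsedCaseExpression>("CASE ... END");
    if (withInputExpression) {
        // Form 1: CASE a.age WHEN 20 THEN ... -- each WHEN holds a comparison value.
        caseExpr->setCaseExpression(makeExpr("a.age"));
        caseExpr->addCaseAlternative(ParsedCaseAlternative(makeExpr("20"), makeExpr("...")));
    } else {
        // Form 2: CASE WHEN a.age = 20 THEN ... -- each WHEN holds a full boolean predicate.
        caseExpr->addCaseAlternative(ParsedCaseAlternative(makeExpr("a.age = 20"), makeExpr("...")));
    }
    // hasCaseExpression() is what distinguishes the two forms downstream; when no ELSE
    // branch is added, the CASE evaluates to null (see elseExpression below).
    return caseExpr;
}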
+class ParsedCaseExpression final : public ParsedExpression { + friend class ParsedExpressionChildrenVisitor; + +public: + explicit ParsedCaseExpression(std::string raw) + : ParsedExpression{common::ExpressionType::CASE_ELSE, std::move(raw)} {}; + + ParsedCaseExpression(std::string alias, std::string rawName, parsed_expr_vector children, + std::unique_ptr caseExpression, + std::vector caseAlternatives, + std::unique_ptr elseExpression) + : ParsedExpression{common::ExpressionType::CASE_ELSE, std::move(alias), std::move(rawName), + std::move(children)}, + caseExpression{std::move(caseExpression)}, caseAlternatives{std::move(caseAlternatives)}, + elseExpression{std::move(elseExpression)} {} + + ParsedCaseExpression(std::unique_ptr caseExpression, + std::vector caseAlternatives, + std::unique_ptr elseExpression) + : ParsedExpression{common::ExpressionType::CASE_ELSE}, + caseExpression{std::move(caseExpression)}, caseAlternatives{std::move(caseAlternatives)}, + elseExpression{std::move(elseExpression)} {} + + inline void setCaseExpression(std::unique_ptr expression) { + caseExpression = std::move(expression); + } + inline bool hasCaseExpression() const { return caseExpression != nullptr; } + inline ParsedExpression* getCaseExpression() const { return caseExpression.get(); } + + inline void addCaseAlternative(ParsedCaseAlternative caseAlternative) { + caseAlternatives.push_back(std::move(caseAlternative)); + } + inline uint32_t getNumCaseAlternative() const { return caseAlternatives.size(); } + inline ParsedCaseAlternative* getCaseAlternativeUnsafe(uint32_t idx) { + return &caseAlternatives[idx]; + } + inline const ParsedCaseAlternative* getCaseAlternative(uint32_t idx) const { + return &caseAlternatives[idx]; + } + + inline void setElseExpression(std::unique_ptr expression) { + elseExpression = std::move(expression); + } + inline bool hasElseExpression() const { return elseExpression != nullptr; } + inline ParsedExpression* getElseExpression() const { return elseExpression.get(); } + + static std::unique_ptr deserialize(common::Deserializer& deserializer); + + std::unique_ptr copy() const override; + +private: + void serializeInternal(common::Serializer& serializer) const override; + +private: + // Optional. If not specified, directly check next whenExpression + std::unique_ptr caseExpression; + std::vector caseAlternatives; + // Optional. 
If not specified, evaluate as null + std::unique_ptr elseExpression; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/expression/parsed_expression.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/expression/parsed_expression.h new file mode 100644 index 0000000000..94db81061f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/expression/parsed_expression.h @@ -0,0 +1,100 @@ +#pragma once + +#include +#include +#include +#include + +#include "common/cast.h" +#include "common/copy_constructors.h" +#include "common/enums/expression_type.h" +#include "common/types/types.h" + +namespace lbug { + +namespace common { +struct FileInfo; +class Serializer; +class Deserializer; +} // namespace common + +namespace parser { + +class ParsedExpression; +class ParsedExpressionChildrenVisitor; +using parsed_expr_vector = std::vector>; +using parsed_expr_pair = + std::pair, std::unique_ptr>; +using s_parsed_expr_pair = std::pair>; + +class LBUG_API ParsedExpression { + friend class ParsedExpressionChildrenVisitor; + +public: + ParsedExpression(common::ExpressionType type, std::unique_ptr child, + std::string rawName); + ParsedExpression(common::ExpressionType type, std::unique_ptr left, + std::unique_ptr right, std::string rawName); + ParsedExpression(common::ExpressionType type, std::string rawName) + : type{type}, rawName{std::move(rawName)} {} + explicit ParsedExpression(common::ExpressionType type) : type{type} {} + + ParsedExpression(common::ExpressionType type, std::string alias, std::string rawName, + parsed_expr_vector children) + : type{type}, alias{std::move(alias)}, rawName{std::move(rawName)}, + children{std::move(children)} {} + DELETE_COPY_DEFAULT_MOVE(ParsedExpression); + virtual ~ParsedExpression() = default; + + common::ExpressionType getExpressionType() const { return type; } + + void setAlias(std::string name) { alias = std::move(name); } + bool hasAlias() const { return !alias.empty(); } + std::string getAlias() const { return alias; } + + std::string getRawName() const { return rawName; } + + common::idx_t getNumChildren() const { return children.size(); } + ParsedExpression* getChild(common::idx_t idx) const { return children[idx].get(); } + void setChild(common::idx_t idx, std::unique_ptr child) { + KU_ASSERT(idx < children.size()); + children[idx] = std::move(child); + } + + std::string toString() const { return rawName; } + + virtual std::unique_ptr copy() const { + return std::make_unique(type, alias, rawName, copyVector(children)); + } + + void serialize(common::Serializer& serializer) const; + + static std::unique_ptr deserialize(common::Deserializer& deserializer); + + template + TARGET& cast() { + return common::ku_dynamic_cast(*this); + } + template + const TARGET& constCast() const { + return common::ku_dynamic_cast(*this); + } + template + const TARGET* constPtrCast() const { + return common::ku_dynamic_cast(this); + } + +private: + virtual void serializeInternal(common::Serializer&) const {} + +protected: + common::ExpressionType type; + std::string alias; + std::string rawName; + parsed_expr_vector children; +}; + +using options_t = std::unordered_map>; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/expression/parsed_expression_visitor.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/expression/parsed_expression_visitor.h new file mode 100644 index 0000000000..7aa5edfe75 --- /dev/null +++ 
b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/expression/parsed_expression_visitor.h @@ -0,0 +1,80 @@ +#pragma once + +#include "parsed_expression.h" + +namespace lbug { +namespace parser { + +class ParsedExpressionVisitor { +public: + virtual ~ParsedExpressionVisitor() = default; + + void visit(const ParsedExpression* expr); + void visitUnsafe(ParsedExpression* expr); + + virtual void visitSwitch(const ParsedExpression* expr); + virtual void visitFunctionExpr(const ParsedExpression*) {} + virtual void visitAggFunctionExpr(const ParsedExpression*) {} + virtual void visitPropertyExpr(const ParsedExpression*) {} + virtual void visitLiteralExpr(const ParsedExpression*) {} + virtual void visitVariableExpr(const ParsedExpression*) {} + virtual void visitPathExpr(const ParsedExpression*) {} + virtual void visitNodeRelExpr(const ParsedExpression*) {} + virtual void visitParamExpr(const ParsedExpression*) {} + virtual void visitSubqueryExpr(const ParsedExpression*) {} + virtual void visitCaseExpr(const ParsedExpression*) {} + virtual void visitGraphExpr(const ParsedExpression*) {} + virtual void visitLambdaExpr(const ParsedExpression*) {} + virtual void visitStar(const ParsedExpression*) {} + + void visitChildren(const ParsedExpression& expr); + void visitCaseChildren(const ParsedExpression& expr); + + virtual void visitSwitchUnsafe(ParsedExpression*) {} + virtual void visitChildrenUnsafe(ParsedExpression& expr); + virtual void visitCaseChildrenUnsafe(ParsedExpression& expr); +}; + +class ParsedParamExprCollector : public ParsedExpressionVisitor { +public: + std::vector getParamExprs() const { return paramExprs; } + bool hasParamExprs() const { return !paramExprs.empty(); } + + void visitParamExpr(const ParsedExpression* expr) override { paramExprs.push_back(expr); } + +private: + std::vector paramExprs; +}; + +class ReadWriteExprAnalyzer : public ParsedExpressionVisitor { +public: + explicit ReadWriteExprAnalyzer(main::ClientContext* context) + : ParsedExpressionVisitor{}, context{context} {} + + bool isReadOnly() const { return readOnly; } + void visitFunctionExpr(const ParsedExpression* expr) override; + +private: + main::ClientContext* context; + bool readOnly = true; +}; + +class MacroParameterReplacer : public ParsedExpressionVisitor { +public: + explicit MacroParameterReplacer( + const std::unordered_map& nameToExpr) + : nameToExpr{nameToExpr} {} + + std::unique_ptr replace(std::unique_ptr input); + +private: + void visitSwitchUnsafe(ParsedExpression* expr) override; + + std::unique_ptr getReplace(const std::string& name); + +private: + const std::unordered_map& nameToExpr; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/expression/parsed_function_expression.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/expression/parsed_function_expression.h new file mode 100644 index 0000000000..8d5efbae36 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/expression/parsed_function_expression.h @@ -0,0 +1,77 @@ +#pragma once + +#include "common/string_utils.h" +#include "parsed_expression.h" + +namespace lbug { +namespace parser { + +class ParsedFunctionExpression : public ParsedExpression { + static constexpr common::ExpressionType expressionType_ = common::ExpressionType::FUNCTION; + +public: + ParsedFunctionExpression(std::string functionName, std::string rawName, bool isDistinct = false) + : ParsedExpression{expressionType_, std::move(rawName)}, isDistinct{isDistinct}, + 
functionName{std::move(functionName)} {} + + ParsedFunctionExpression(std::string functionName, std::unique_ptr child, + std::string rawName, bool isDistinct = false) + : ParsedExpression{expressionType_, std::move(child), std::move(rawName)}, + isDistinct{isDistinct}, functionName{std::move(functionName)} {} + + ParsedFunctionExpression(std::string functionName, std::unique_ptr left, + std::unique_ptr right, std::string rawName, bool isDistinct = false) + : ParsedExpression{expressionType_, std::move(left), std::move(right), std::move(rawName)}, + isDistinct{isDistinct}, functionName{std::move(functionName)} {} + + ParsedFunctionExpression(std::string alias, std::string rawName, parsed_expr_vector children, + std::string functionName, bool isDistinct, std::vector optionalArguments) + : ParsedExpression{expressionType_, std::move(alias), std::move(rawName), + std::move(children)}, + isDistinct{isDistinct}, functionName{std::move(functionName)}, + optionalArguments{std::move(optionalArguments)} {} + + ParsedFunctionExpression(std::string functionName, bool isDistinct) + : ParsedExpression{expressionType_}, isDistinct{isDistinct}, + functionName{std::move(functionName)} {} + + bool getIsDistinct() const { return isDistinct; } + + std::string getFunctionName() const { return functionName; } + std::string getNormalizedFunctionName() const { + return common::StringUtils::getUpper(functionName); + } + + void addChild(std::unique_ptr child) { children.push_back(std::move(child)); } + + void setOptionalArguments(std::vector optionalArguments) { + this->optionalArguments = std::move(optionalArguments); + } + void addOptionalParams(std::string name, std::unique_ptr child) { + optionalArguments.push_back(std::move(name)); + children.push_back(std::move(child)); + } + + const std::vector& getOptionalArguments() const { return optionalArguments; } + + static std::unique_ptr deserialize( + common::Deserializer& deserializer); + + std::unique_ptr copy() const override { + return std::make_unique(alias, rawName, copyVector(children), + functionName, isDistinct, optionalArguments); + } + +private: + void serializeInternal(common::Serializer& serializer) const override; + +private: + bool isDistinct; + std::string functionName; + // In Lbug, function arguments must be either all required or all optional - mixing required and + // optional parameters in the same function is not allowed. 
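// Illustrative sketch (not part of the vendored lbug sources): a call whose arguments are
// all passed as named optional parameters. makeLiteral is a hypothetical helper; the
// constructor, addChild and addOptionalParams are the API declared in this class.
static std::unique_ptr<ParsedFunctionExpression> buildCallWithOptions() {
    auto call = std::make_unique<ParsedFunctionExpression>("MY_FUNC", /*isDistinct=*/false);
    // Per the rule above, either every argument is positional (addChild) or every
    // argument is named (addOptionalParams); the two styles are not mixed in one call.
    call->addOptionalParams("threshold", makeLiteral(0.5));
    call->addOptionalParams("verbose", makeLiteral(true));
    return call;
}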
+ std::vector optionalArguments; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/expression/parsed_lambda_expression.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/expression/parsed_lambda_expression.h new file mode 100644 index 0000000000..cf24ce77f4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/expression/parsed_lambda_expression.h @@ -0,0 +1,32 @@ +#pragma once + +#include "common/enums/expression_type.h" +#include "parsed_expression.h" + +namespace lbug { +namespace parser { + +class ParsedLambdaExpression : public ParsedExpression { + static constexpr const common::ExpressionType type_ = common::ExpressionType::LAMBDA; + +public: + ParsedLambdaExpression(std::vector varNames, + std::unique_ptr expr, std::string rawName) + : ParsedExpression{type_, rawName}, varNames{std::move(varNames)}, + functionExpr{std::move(expr)} {} + + std::vector getVarNames() const { return varNames; } + + ParsedExpression* getFunctionExpr() const { return functionExpr.get(); } + + std::unique_ptr copy() const override { + return std::make_unique(varNames, functionExpr->copy(), rawName); + } + +private: + std::vector varNames; + std::unique_ptr functionExpr; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/expression/parsed_literal_expression.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/expression/parsed_literal_expression.h new file mode 100644 index 0000000000..ce364fafa4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/expression/parsed_literal_expression.h @@ -0,0 +1,47 @@ +#pragma once + +#include "common/types/value/value.h" +#include "parsed_expression.h" + +namespace lbug { +namespace parser { + +class ParsedLiteralExpression : public ParsedExpression { + static constexpr common::ExpressionType expressionType = common::ExpressionType::LITERAL; + +public: + ParsedLiteralExpression(common::Value value, std::string raw) + : ParsedExpression{expressionType, std::move(raw)}, value{std::move(value)} {} + + ParsedLiteralExpression(std::string alias, std::string rawName, parsed_expr_vector children, + common::Value value) + : ParsedExpression{expressionType, std::move(alias), std::move(rawName), + std::move(children)}, + value{std::move(value)} {} + + explicit ParsedLiteralExpression(common::Value value) + : ParsedExpression{expressionType}, value{std::move(value)} {} + + common::Value getValue() const { return value; } + + static std::unique_ptr deserialize( + common::Deserializer& deserializer) { + return std::make_unique(*common::Value::deserialize(deserializer)); + } + + std::unique_ptr copy() const override { + return std::make_unique(alias, rawName, copyVector(children), + value); + } + +private: + void serializeInternal(common::Serializer& serializer) const override { + value.serialize(serializer); + } + +private: + common::Value value; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/expression/parsed_parameter_expression.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/expression/parsed_parameter_expression.h new file mode 100644 index 0000000000..03372957bd --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/expression/parsed_parameter_expression.h @@ -0,0 +1,31 @@ +#pragma once + +#include "common/assert.h" +#include "parsed_expression.h" + +namespace lbug { +namespace parser { + +class ParsedParameterExpression : 
public ParsedExpression { +public: + explicit ParsedParameterExpression(std::string parameterName, std::string raw) + : ParsedExpression{common::ExpressionType::PARAMETER, std::move(raw)}, + parameterName{std::move(parameterName)} {} + + inline std::string getParameterName() const { return parameterName; } + + static std::unique_ptr deserialize(common::Deserializer&) { + KU_UNREACHABLE; + } + + inline std::unique_ptr copy() const override { KU_UNREACHABLE; } + +private: + void serializeInternal(common::Serializer&) const override { KU_UNREACHABLE; } + +private: + std::string parameterName; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/expression/parsed_property_expression.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/expression/parsed_property_expression.h new file mode 100644 index 0000000000..bb9c6e552d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/expression/parsed_property_expression.h @@ -0,0 +1,49 @@ +#pragma once + +#include "common/constants.h" +#include "common/serializer/serializer.h" +#include "parsed_expression.h" + +namespace lbug { +namespace parser { + +class ParsedPropertyExpression : public ParsedExpression { + static constexpr common::ExpressionType expressionType_ = common::ExpressionType::PROPERTY; + +public: + ParsedPropertyExpression(std::string propertyName, std::unique_ptr child, + std::string raw) + : ParsedExpression{expressionType_, std::move(child), std::move(raw)}, + propertyName{std::move(propertyName)} {} + + ParsedPropertyExpression(std::string alias, std::string rawName, parsed_expr_vector children, + std::string propertyName) + : ParsedExpression{expressionType_, std::move(alias), std::move(rawName), + std::move(children)}, + propertyName{std::move(propertyName)} {} + + explicit ParsedPropertyExpression(std::string propertyName) + : ParsedExpression{expressionType_}, propertyName{std::move(propertyName)} {} + + std::string getPropertyName() const { return propertyName; } + bool isStar() const { return propertyName == common::InternalKeyword::STAR; } + + static std::unique_ptr deserialize( + common::Deserializer& deserializer); + + std::unique_ptr copy() const override { + return std::make_unique(alias, rawName, copyVector(children), + propertyName); + } + +private: + void serializeInternal(common::Serializer& serializer) const override { + serializer.serializeValue(propertyName); + } + +private: + std::string propertyName; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/expression/parsed_subquery_expression.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/expression/parsed_subquery_expression.h new file mode 100644 index 0000000000..53a32bbed8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/expression/parsed_subquery_expression.h @@ -0,0 +1,56 @@ +#pragma once + +#include "common/assert.h" +#include "common/enums/subquery_type.h" +#include "parsed_expression.h" +#include "parser/query/graph_pattern/pattern_element.h" +#include "parser/query/reading_clause/join_hint.h" + +namespace lbug { +namespace parser { + +class ParsedSubqueryExpression : public ParsedExpression { + static constexpr common::ExpressionType type_ = common::ExpressionType::SUBQUERY; + +public: + ParsedSubqueryExpression(common::SubqueryType subqueryType, std::string rawName) + : ParsedExpression{type_, std::move(rawName)}, subqueryType{subqueryType} {} + + common::SubqueryType getSubqueryType() 
const { return subqueryType; } + + void addPatternElement(PatternElement element) { + patternElements.push_back(std::move(element)); + } + void setPatternElements(std::vector elements) { + patternElements = std::move(elements); + } + const std::vector& getPatternElements() const { return patternElements; } + + void setWhereClause(std::unique_ptr expression) { + whereClause = std::move(expression); + } + bool hasWhereClause() const { return whereClause != nullptr; } + const ParsedExpression* getWhereClause() const { return whereClause.get(); } + + void setHint(std::shared_ptr root) { hintRoot = std::move(root); } + bool hasHint() const { return hintRoot != nullptr; } + std::shared_ptr getHint() const { return hintRoot; } + + static std::unique_ptr deserialize(common::Deserializer&) { + KU_UNREACHABLE; + } + + std::unique_ptr copy() const override { KU_UNREACHABLE; } + +private: + void serializeInternal(common::Serializer&) const override { KU_UNREACHABLE; } + +private: + common::SubqueryType subqueryType; + std::vector patternElements; + std::unique_ptr whereClause; + std::shared_ptr hintRoot = nullptr; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/expression/parsed_variable_expression.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/expression/parsed_variable_expression.h new file mode 100644 index 0000000000..9fad8f019b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/expression/parsed_variable_expression.h @@ -0,0 +1,46 @@ +#pragma once + +#include "common/serializer/deserializer.h" +#include "common/serializer/serializer.h" +#include "parsed_expression.h" + +namespace lbug { +namespace parser { + +class ParsedVariableExpression : public ParsedExpression { +public: + ParsedVariableExpression(std::string variableName, std::string raw) + : ParsedExpression{common::ExpressionType::VARIABLE, std::move(raw)}, + variableName{std::move(variableName)} {} + + ParsedVariableExpression(std::string alias, std::string rawName, parsed_expr_vector children, + std::string variableName) + : ParsedExpression{common::ExpressionType::VARIABLE, std::move(alias), std::move(rawName), + std::move(children)}, + variableName{std::move(variableName)} {} + + explicit ParsedVariableExpression(std::string variableName) + : ParsedExpression{common::ExpressionType::VARIABLE}, + variableName{std::move(variableName)} {} + + inline std::string getVariableName() const { return variableName; } + + static std::unique_ptr deserialize( + common::Deserializer& deserializer); + + inline std::unique_ptr copy() const override { + return std::make_unique(alias, rawName, copyVector(children), + variableName); + } + +private: + inline void serializeInternal(common::Serializer& serializer) const override { + serializer.serializeValue(variableName); + } + +private: + std::string variableName; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/extension_statement.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/extension_statement.h new file mode 100644 index 0000000000..6b99c8c1bb --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/extension_statement.h @@ -0,0 +1,23 @@ +#pragma once + +#include "extension/extension_action.h" +#include "statement.h" + +namespace lbug { +namespace parser { + +using namespace lbug::extension; + +class ExtensionStatement final : public Statement { +public: + explicit ExtensionStatement(std::unique_ptr info) + : 
Statement{common::StatementType::EXTENSION}, info{std::move(info)} {} + + std::unique_ptr getAuxInfo() const { return info->copy(); } + +private: + std::unique_ptr info; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/parsed_data/attach_info.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/parsed_data/attach_info.h new file mode 100644 index 0000000000..305b84b393 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/parsed_data/attach_info.h @@ -0,0 +1,20 @@ +#pragma once + +#include + +#include "parser/expression/parsed_expression.h" + +namespace lbug { +namespace parser { + +struct AttachInfo { + AttachInfo(std::string dbPath, std::string dbAlias, std::string dbType, options_t options) + : dbPath{std::move(dbPath)}, dbAlias{std::move(dbAlias)}, dbType{std::move(dbType)}, + options{std::move(options)} {} + + std::string dbPath, dbAlias, dbType; + options_t options; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/parsed_statement_visitor.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/parsed_statement_visitor.h new file mode 100644 index 0000000000..701e5635a2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/parsed_statement_visitor.h @@ -0,0 +1,66 @@ +#pragma once + +#include "statement.h" + +namespace lbug { +namespace main { +class ClientContext; +} +namespace parser { + +class SingleQuery; +class QueryPart; +class ReadingClause; +class UpdatingClause; +class WithClause; +class ReturnClause; + +class StatementVisitor { +public: + StatementVisitor() = default; + virtual ~StatementVisitor() = default; + + void visit(const Statement& statement); + +private: + // LCOV_EXCL_START + virtual void visitQuery(const Statement& statement); + virtual void visitSingleQuery(const SingleQuery* singleQuery); + virtual void visitQueryPart(const QueryPart* queryPart); + virtual void visitReadingClause(const ReadingClause* readingClause); + virtual void visitMatch(const ReadingClause* /*readingClause*/) {} + virtual void visitUnwind(const ReadingClause* /*readingClause*/) {} + virtual void visitInQueryCall(const ReadingClause* /*readingClause*/) {} + virtual void visitLoadFrom(const ReadingClause* /*readingClause*/) {} + virtual void visitUpdatingClause(const UpdatingClause* /*updatingClause*/); + virtual void visitSet(const UpdatingClause* /*updatingClause*/) {} + virtual void visitDelete(const UpdatingClause* /*updatingClause*/) {} + virtual void visitInsert(const UpdatingClause* /*updatingClause*/) {} + virtual void visitMerge(const UpdatingClause* /*updatingClause*/) {} + virtual void visitWithClause(const WithClause* /*withClause*/) {} + virtual void visitReturnClause(const ReturnClause* /*returnClause*/) {} + + virtual void visitCreateSequence(const Statement& /*statement*/) {} + virtual void visitDrop(const Statement& /*statement*/) {} + virtual void visitCreateTable(const Statement& /*statement*/) {} + virtual void visitCreateType(const Statement& /*statement*/) {} + virtual void visitAlter(const Statement& /*statement*/) {} + virtual void visitCopyFrom(const Statement& /*statement*/) {} + virtual void visitCopyTo(const Statement& /*statement*/) {} + virtual void visitStandaloneCall(const Statement& /*statement*/) {} + virtual void visitExplain(const Statement& /*statement*/); + virtual void visitCreateMacro(const Statement& /*statement*/) {} + virtual void visitTransaction(const Statement& /*statement*/) {} + 
virtual void visitExtension(const Statement& /*statement*/) {} + virtual void visitExportDatabase(const Statement& /*statement*/) {} + virtual void visitImportDatabase(const Statement& /*statement*/) {} + virtual void visitAttachDatabase(const Statement& /*statement*/) {} + virtual void visitDetachDatabase(const Statement& /*statement*/) {} + virtual void visitUseDatabase(const Statement& /*statement*/) {} + virtual void visitStandaloneCallFunction(const Statement& /*statement*/) {} + virtual void visitExtensionClause(const Statement& /*statement*/) {} + // LCOV_EXCL_STOP +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/parser.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/parser.h new file mode 100644 index 0000000000..1fc75d2085 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/parser.h @@ -0,0 +1,24 @@ +#pragma once + +#include +#include +#include + +#include "extension/transformer_extension.h" +#include "statement.h" + +namespace lbug { +namespace main { +class ClientContext; +} +namespace parser { + +class Parser { + +public: + LBUG_API static std::vector> parseQuery(std::string_view query, + std::vector transformerExtensions = {}); +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/port_db.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/port_db.h new file mode 100644 index 0000000000..5a4f13f32a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/port_db.h @@ -0,0 +1,35 @@ +#pragma once + +#include "parser/expression/parsed_expression.h" +#include "parser/statement.h" + +namespace lbug { +namespace parser { + +class ExportDB : public Statement { +public: + explicit ExportDB(std::string filePath) + : Statement{common::StatementType::EXPORT_DATABASE}, filePath{std::move(filePath)} {} + + inline void setParsingOption(options_t options) { parsingOptions = std::move(options); } + inline const options_t& getParsingOptionsRef() const { return parsingOptions; } + inline std::string getFilePath() const { return filePath; } + +private: + options_t parsingOptions; + std::string filePath; +}; + +class ImportDB : public Statement { +public: + explicit ImportDB(std::string filePath) + : Statement{common::StatementType::IMPORT_DATABASE}, filePath{std::move(filePath)} {} + + inline std::string getFilePath() const { return filePath; } + +private: + std::string filePath; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/graph_pattern/node_pattern.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/graph_pattern/node_pattern.h new file mode 100644 index 0000000000..e88a85d6f1 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/graph_pattern/node_pattern.h @@ -0,0 +1,35 @@ +#pragma once + +#include + +#include "parser/expression/parsed_expression.h" + +namespace lbug { +namespace parser { + +class NodePattern { +public: + NodePattern(std::string name, std::vector tableNames, + std::vector propertyKeyVals) + : variableName{std::move(name)}, tableNames{std::move(tableNames)}, + propertyKeyVals{std::move(propertyKeyVals)} {} + DELETE_COPY_DEFAULT_MOVE(NodePattern); + + virtual ~NodePattern() = default; + + inline std::string getVariableName() const { return variableName; } + + inline std::vector getTableNames() const { return tableNames; } + + inline const std::vector& getPropertyKeyVals() const { + return propertyKeyVals; + 
} + +protected: + std::string variableName; + std::vector tableNames; + std::vector propertyKeyVals; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/graph_pattern/pattern_element.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/graph_pattern/pattern_element.h new file mode 100644 index 0000000000..8006a5ed37 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/graph_pattern/pattern_element.h @@ -0,0 +1,36 @@ +#pragma once + +#include + +#include "pattern_element_chain.h" + +namespace lbug { +namespace parser { + +class PatternElement { +public: + explicit PatternElement(NodePattern nodePattern) : nodePattern{std::move(nodePattern)} {} + DELETE_COPY_DEFAULT_MOVE(PatternElement); + + inline void setPathName(std::string name) { pathName = std::move(name); } + inline bool hasPathName() const { return !pathName.empty(); } + inline std::string getPathName() const { return pathName; } + + inline const NodePattern* getFirstNodePattern() const { return &nodePattern; } + + inline void addPatternElementChain(PatternElementChain chain) { + patternElementChains.push_back(std::move(chain)); + } + inline uint32_t getNumPatternElementChains() const { return patternElementChains.size(); } + inline const PatternElementChain* getPatternElementChain(uint32_t idx) const { + return &patternElementChains[idx]; + } + +private: + std::string pathName; + NodePattern nodePattern; + std::vector patternElementChains; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/graph_pattern/pattern_element_chain.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/graph_pattern/pattern_element_chain.h new file mode 100644 index 0000000000..2815a62399 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/graph_pattern/pattern_element_chain.h @@ -0,0 +1,24 @@ +#pragma once + +#include "rel_pattern.h" + +namespace lbug { +namespace parser { + +class PatternElementChain { +public: + PatternElementChain(RelPattern relPattern, NodePattern nodePattern) + : relPattern{std::move(relPattern)}, nodePattern{std::move(nodePattern)} {} + DELETE_COPY_DEFAULT_MOVE(PatternElementChain); + + inline const RelPattern* getRelPattern() const { return &relPattern; } + + inline const NodePattern* getNodePattern() const { return &nodePattern; } + +private: + RelPattern relPattern; + NodePattern nodePattern; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/graph_pattern/rel_pattern.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/graph_pattern/rel_pattern.h new file mode 100644 index 0000000000..17151fab33 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/graph_pattern/rel_pattern.h @@ -0,0 +1,50 @@ +#pragma once + +#include "common/copy_constructors.h" +#include "common/enums/query_rel_type.h" +#include "node_pattern.h" + +namespace lbug { +namespace parser { + +enum class ArrowDirection : uint8_t { LEFT = 0, RIGHT = 1, BOTH = 2 }; + +struct RecursiveRelPatternInfo { + std::string lowerBound; + std::string upperBound; + std::string weightPropertyName; + std::string relName; + std::string nodeName; + std::unique_ptr whereExpression = nullptr; + bool hasProjection = false; + parsed_expr_vector relProjectionList; + parsed_expr_vector nodeProjectionList; + + RecursiveRelPatternInfo() = default; + 
DELETE_COPY_DEFAULT_MOVE(RecursiveRelPatternInfo); +}; + +class RelPattern : public NodePattern { +public: + RelPattern(std::string name, std::vector tableNames, common::QueryRelType relType, + ArrowDirection arrowDirection, std::vector propertyKeyValPairs, + RecursiveRelPatternInfo recursiveInfo) + : NodePattern{std::move(name), std::move(tableNames), std::move(propertyKeyValPairs)}, + relType{relType}, arrowDirection{arrowDirection}, + recursiveInfo{std::move(recursiveInfo)} {} + DELETE_COPY_DEFAULT_MOVE(RelPattern); + + common::QueryRelType getRelType() const { return relType; } + + ArrowDirection getDirection() const { return arrowDirection; } + + const RecursiveRelPatternInfo* getRecursiveInfo() const { return &recursiveInfo; } + +private: + common::QueryRelType relType; + ArrowDirection arrowDirection; + RecursiveRelPatternInfo recursiveInfo; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/query_part.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/query_part.h new file mode 100644 index 0000000000..9c4404d473 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/query_part.h @@ -0,0 +1,37 @@ +#pragma once + +#include "parser/query/reading_clause/reading_clause.h" +#include "parser/query/return_with_clause/with_clause.h" +#include "parser/query/updating_clause/updating_clause.h" + +namespace lbug { +namespace parser { + +class QueryPart { +public: + explicit QueryPart(WithClause withClause) : withClause{std::move(withClause)} {} + + inline uint32_t getNumUpdatingClauses() const { return updatingClauses.size(); } + inline UpdatingClause* getUpdatingClause(uint32_t idx) const { + return updatingClauses[idx].get(); + } + inline void addUpdatingClause(std::unique_ptr updatingClause) { + updatingClauses.push_back(std::move(updatingClause)); + } + + inline uint32_t getNumReadingClauses() const { return readingClauses.size(); } + inline ReadingClause* getReadingClause(uint32_t idx) const { return readingClauses[idx].get(); } + inline void addReadingClause(std::unique_ptr readingClause) { + readingClauses.push_back(std::move(readingClause)); + } + + inline const WithClause* getWithClause() const { return &withClause; } + +private: + std::vector> readingClauses; + std::vector> updatingClauses; + WithClause withClause; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/reading_clause/in_query_call_clause.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/reading_clause/in_query_call_clause.h new file mode 100644 index 0000000000..1f1236cbad --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/reading_clause/in_query_call_clause.h @@ -0,0 +1,28 @@ +#pragma once + +#include "parser/query/reading_clause/reading_clause.h" +#include "yield_variable.h" + +namespace lbug { +namespace parser { + +class InQueryCallClause final : public ReadingClause { + static constexpr common::ClauseType clauseType_ = common::ClauseType::IN_QUERY_CALL; + +public: + InQueryCallClause(std::unique_ptr functionExpression, + std::vector yieldClause) + : ReadingClause{clauseType_}, functionExpression{std::move(functionExpression)}, + yieldVariables{std::move(yieldClause)} {} + + const ParsedExpression* getFunctionExpression() const { return functionExpression.get(); } + + const std::vector& getYieldVariables() const { return yieldVariables; } + +private: + std::unique_ptr functionExpression; + std::vector 
yieldVariables; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/reading_clause/join_hint.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/reading_clause/join_hint.h new file mode 100644 index 0000000000..06909ae62e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/reading_clause/join_hint.h @@ -0,0 +1,20 @@ +#pragma once + +#include +#include + +namespace lbug { +namespace parser { + +struct JoinHintNode { + std::string variableName; + std::vector> children; + + JoinHintNode() = default; + explicit JoinHintNode(std::string name) : variableName{std::move(name)} {} + void addChild(std::shared_ptr child) { children.push_back(std::move(child)); } + bool isLeaf() const { return children.empty(); } +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/reading_clause/load_from.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/reading_clause/load_from.h new file mode 100644 index 0000000000..1716d83517 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/reading_clause/load_from.h @@ -0,0 +1,37 @@ +#pragma once + +#include "parser/ddl/parsed_property_definition.h" +#include "parser/expression/parsed_expression.h" +#include "parser/scan_source.h" +#include "reading_clause.h" + +namespace lbug { +namespace parser { + +class LoadFrom : public ReadingClause { + static constexpr common::ClauseType clauseType_ = common::ClauseType::LOAD_FROM; + +public: + explicit LoadFrom(std::unique_ptr source) + : ReadingClause{clauseType_}, source{std::move(source)} {} + + BaseScanSource* getSource() const { return source.get(); } + + void setParingOptions(options_t options) { parsingOptions = std::move(options); } + const options_t& getParsingOptions() const { return parsingOptions; } + + void setPropertyDefinitions(std::vector definitions) { + columnDefinitions = std::move(definitions); + } + const std::vector& getColumnDefinitions() const { + return columnDefinitions; + } + +private: + std::unique_ptr source; + std::vector columnDefinitions; + options_t parsingOptions; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/reading_clause/match_clause.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/reading_clause/match_clause.h new file mode 100644 index 0000000000..e86558665e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/reading_clause/match_clause.h @@ -0,0 +1,34 @@ +#pragma once + +#include "join_hint.h" +#include "parser/query/graph_pattern/pattern_element.h" +#include "reading_clause.h" + +namespace lbug { +namespace parser { + +class MatchClause : public ReadingClause { + static constexpr common::ClauseType clauseType_ = common::ClauseType::MATCH; + +public: + MatchClause(std::vector patternElements, + common::MatchClauseType matchClauseType) + : ReadingClause{clauseType_}, patternElements{std::move(patternElements)}, + matchClauseType{matchClauseType} {} + + const std::vector& getPatternElementsRef() const { return patternElements; } + + common::MatchClauseType getMatchClauseType() const { return matchClauseType; } + + void setHint(std::shared_ptr root) { hintRoot = std::move(root); } + bool hasHint() const { return hintRoot != nullptr; } + std::shared_ptr getHint() const { return hintRoot; } + +private: + std::vector patternElements; + common::MatchClauseType matchClauseType; + 
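// Root of the optional join-order hint tree attached to this MATCH clause via setHint();
+    // remains nullptr when no hint was provided.
+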
std::shared_ptr hintRoot = nullptr; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/reading_clause/reading_clause.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/reading_clause/reading_clause.h new file mode 100644 index 0000000000..ad050115e3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/reading_clause/reading_clause.h @@ -0,0 +1,33 @@ +#pragma once + +#include "common/cast.h" +#include "common/enums/clause_type.h" +#include "parser/expression/parsed_expression.h" + +namespace lbug { +namespace parser { + +class ReadingClause { +public: + explicit ReadingClause(common::ClauseType clauseType) : clauseType{clauseType} {}; + virtual ~ReadingClause() = default; + + common::ClauseType getClauseType() const { return clauseType; } + + void setWherePredicate(std::unique_ptr expression) { + wherePredicate = std::move(expression); + } + bool hasWherePredicate() const { return wherePredicate != nullptr; } + const ParsedExpression* getWherePredicate() const { return wherePredicate.get(); } + + template + const TARGET& constCast() const { + return common::ku_dynamic_cast(*this); + } + +private: + common::ClauseType clauseType; + std::unique_ptr wherePredicate; +}; +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/reading_clause/unwind_clause.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/reading_clause/unwind_clause.h new file mode 100644 index 0000000000..e731d693a2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/reading_clause/unwind_clause.h @@ -0,0 +1,27 @@ +#pragma once + +#include "parser/expression/parsed_expression.h" +#include "reading_clause.h" + +namespace lbug { +namespace parser { + +class UnwindClause : public ReadingClause { + static constexpr common::ClauseType clauseType_ = common::ClauseType::UNWIND; + +public: + UnwindClause(std::unique_ptr expression, std::string listAlias) + : ReadingClause{clauseType_}, expression{std::move(expression)}, + alias{std::move(listAlias)} {} + + const ParsedExpression* getExpression() const { return expression.get(); } + + std::string getAlias() const { return alias; } + +private: + std::unique_ptr expression; + std::string alias; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/reading_clause/yield_variable.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/reading_clause/yield_variable.h new file mode 100644 index 0000000000..9e49fa1082 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/reading_clause/yield_variable.h @@ -0,0 +1,17 @@ +#pragma once +#include + +namespace lbug { +namespace parser { + +struct YieldVariable { + std::string name; + std::string alias; + + YieldVariable(std::string name, std::string alias) + : name{std::move(name)}, alias{std::move(alias)} {} + bool hasAlias() const { return alias != ""; } +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/regular_query.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/regular_query.h new file mode 100644 index 0000000000..f76922dbe1 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/regular_query.h @@ -0,0 +1,37 @@ +#pragma once + +#include "common/types/types.h" +#include "parser/statement.h" +#include "single_query.h" + +namespace lbug { +namespace parser { 
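+// A regular query is a sequence of one or more single queries combined with UNION / UNION ALL;
+// isUnionAll[i] records whether the (i + 1)-th single query is combined via UNION ALL.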
+ +class RegularQuery : public Statement { + static constexpr common::StatementType type_ = common::StatementType::QUERY; + +public: + explicit RegularQuery(SingleQuery singleQuery) : Statement{type_} { + singleQueries.push_back(std::move(singleQuery)); + } + + void addSingleQuery(SingleQuery singleQuery, bool isUnionAllQuery) { + singleQueries.push_back(std::move(singleQuery)); + isUnionAll.push_back(isUnionAllQuery); + } + + common::idx_t getNumSingleQueries() const { return singleQueries.size(); } + + const SingleQuery* getSingleQuery(common::idx_t singleQueryIdx) const { + return &singleQueries[singleQueryIdx]; + } + + std::vector getIsUnionAll() const { return isUnionAll; } + +private: + std::vector singleQueries; + std::vector isUnionAll; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/return_with_clause/projection_body.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/return_with_clause/projection_body.h new file mode 100644 index 0000000000..66d573c25f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/return_with_clause/projection_body.h @@ -0,0 +1,56 @@ +#pragma once + +#include "common/copy_constructors.h" +#include "parser/expression/parsed_expression.h" + +namespace lbug { +namespace parser { + +class ProjectionBody { +public: + ProjectionBody(bool isDistinct, + std::vector> projectionExpressions) + : isDistinct{isDistinct}, projectionExpressions{std::move(projectionExpressions)} {} + DELETE_COPY_DEFAULT_MOVE(ProjectionBody); + + inline bool getIsDistinct() const { return isDistinct; } + + inline const std::vector>& getProjectionExpressions() const { + return projectionExpressions; + } + + inline void setOrderByExpressions(std::vector> expressions, + std::vector sortOrders) { + orderByExpressions = std::move(expressions); + isAscOrders = std::move(sortOrders); + } + inline bool hasOrderByExpressions() const { return !orderByExpressions.empty(); } + inline const std::vector>& getOrderByExpressions() const { + return orderByExpressions; + } + + inline std::vector getSortOrders() const { return isAscOrders; } + + inline void setSkipExpression(std::unique_ptr expression) { + skipExpression = std::move(expression); + } + inline bool hasSkipExpression() const { return skipExpression != nullptr; } + inline ParsedExpression* getSkipExpression() const { return skipExpression.get(); } + + inline void setLimitExpression(std::unique_ptr expression) { + limitExpression = std::move(expression); + } + inline bool hasLimitExpression() const { return limitExpression != nullptr; } + inline ParsedExpression* getLimitExpression() const { return limitExpression.get(); } + +private: + bool isDistinct; + std::vector> projectionExpressions; + std::vector> orderByExpressions; + std::vector isAscOrders; + std::unique_ptr skipExpression; + std::unique_ptr limitExpression; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/return_with_clause/return_clause.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/return_with_clause/return_clause.h new file mode 100644 index 0000000000..3e05fae542 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/return_with_clause/return_clause.h @@ -0,0 +1,23 @@ +#pragma once + +#include "projection_body.h" + +namespace lbug { +namespace parser { + +class ReturnClause { +public: + explicit ReturnClause(ProjectionBody projectionBody) + : 
projectionBody{std::move(projectionBody)} {} + DELETE_COPY_DEFAULT_MOVE(ReturnClause); + + virtual ~ReturnClause() = default; + + inline const ProjectionBody* getProjectionBody() const { return &projectionBody; } + +private: + ProjectionBody projectionBody; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/return_with_clause/with_clause.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/return_with_clause/with_clause.h new file mode 100644 index 0000000000..0d611aefca --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/return_with_clause/with_clause.h @@ -0,0 +1,26 @@ +#pragma once + +#include "return_clause.h" + +namespace lbug { +namespace parser { + +class WithClause : public ReturnClause { +public: + explicit WithClause(ProjectionBody projectionBody) : ReturnClause{std::move(projectionBody)} {} + DELETE_COPY_DEFAULT_MOVE(WithClause); + + inline void setWhereExpression(std::unique_ptr expression) { + whereExpression = std::move(expression); + } + + inline bool hasWhereExpression() const { return whereExpression != nullptr; } + + inline ParsedExpression* getWhereExpression() const { return whereExpression.get(); } + +private: + std::unique_ptr whereExpression; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/single_query.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/single_query.h new file mode 100644 index 0000000000..a16225c8cc --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/single_query.h @@ -0,0 +1,50 @@ +#pragma once + +#include + +#include "common/assert.h" +#include "common/copy_constructors.h" +#include "query_part.h" + +namespace lbug { +namespace parser { + +class SingleQuery { +public: + SingleQuery() = default; + DELETE_COPY_DEFAULT_MOVE(SingleQuery); + + inline void addQueryPart(QueryPart queryPart) { queryParts.push_back(std::move(queryPart)); } + inline uint32_t getNumQueryParts() const { return queryParts.size(); } + inline const QueryPart* getQueryPart(uint32_t idx) const { return &queryParts[idx]; } + + inline uint32_t getNumUpdatingClauses() const { return updatingClauses.size(); } + inline UpdatingClause* getUpdatingClause(uint32_t idx) const { + return updatingClauses[idx].get(); + } + inline void addUpdatingClause(std::unique_ptr updatingClause) { + updatingClauses.push_back(std::move(updatingClause)); + } + + inline uint32_t getNumReadingClauses() const { return readingClauses.size(); } + inline ReadingClause* getReadingClause(uint32_t idx) const { return readingClauses[idx].get(); } + inline void addReadingClause(std::unique_ptr readingClause) { + readingClauses.push_back(std::move(readingClause)); + } + + inline void setReturnClause(ReturnClause clause) { returnClause = std::move(clause); } + inline bool hasReturnClause() const { return returnClause.has_value(); } + inline const ReturnClause* getReturnClause() const { + KU_ASSERT(returnClause.has_value()); + return &returnClause.value(); + } + +private: + std::vector queryParts; + std::vector> readingClauses; + std::vector> updatingClauses; + std::optional returnClause; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/updating_clause/delete_clause.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/updating_clause/delete_clause.h new file mode 100644 index 0000000000..7fca67dffa --- /dev/null +++ 
b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/updating_clause/delete_clause.h @@ -0,0 +1,28 @@ +#pragma once + +#include "common/enums/delete_type.h" +#include "parser/expression/parsed_expression.h" +#include "updating_clause.h" + +namespace lbug { +namespace parser { + +class DeleteClause final : public UpdatingClause { +public: + explicit DeleteClause(common::DeleteNodeType deleteType) + : UpdatingClause{common::ClauseType::DELETE_}, deleteType{deleteType} {}; + + void addExpression(std::unique_ptr expression) { + expressions.push_back(std::move(expression)); + } + common::DeleteNodeType getDeleteClauseType() const { return deleteType; } + uint32_t getNumExpressions() const { return expressions.size(); } + ParsedExpression* getExpression(uint32_t idx) const { return expressions[idx].get(); } + +private: + common::DeleteNodeType deleteType; + parsed_expr_vector expressions; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/updating_clause/insert_clause.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/updating_clause/insert_clause.h new file mode 100644 index 0000000000..0ed44ed946 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/updating_clause/insert_clause.h @@ -0,0 +1,24 @@ +#pragma once + +#include "parser/query/graph_pattern/pattern_element.h" +#include "updating_clause.h" + +namespace lbug { +namespace parser { + +class InsertClause final : public UpdatingClause { +public: + explicit InsertClause(std::vector patternElements) + : UpdatingClause{common::ClauseType::INSERT}, + patternElements{std::move(patternElements)} {}; + + inline const std::vector& getPatternElementsRef() const { + return patternElements; + } + +private: + std::vector patternElements; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/updating_clause/merge_clause.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/updating_clause/merge_clause.h new file mode 100644 index 0000000000..fd76143e8d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/updating_clause/merge_clause.h @@ -0,0 +1,40 @@ +#pragma once + +#include "parser/query/graph_pattern/pattern_element.h" +#include "updating_clause.h" + +namespace lbug { +namespace parser { + +class MergeClause final : public UpdatingClause { +public: + explicit MergeClause(std::vector patternElements) + : UpdatingClause{common::ClauseType::MERGE}, patternElements{std::move(patternElements)} {} + + inline const std::vector& getPatternElementsRef() const { + return patternElements; + } + inline void addOnMatchSetItems(parsed_expr_pair setItem) { + onMatchSetItems.push_back(std::move(setItem)); + } + inline bool hasOnMatchSetItems() const { return !onMatchSetItems.empty(); } + inline const std::vector& getOnMatchSetItemsRef() const { + return onMatchSetItems; + } + + inline void addOnCreateSetItems(parsed_expr_pair setItem) { + onCreateSetItems.push_back(std::move(setItem)); + } + inline bool hasOnCreateSetItems() const { return !onCreateSetItems.empty(); } + inline const std::vector& getOnCreateSetItemsRef() const { + return onCreateSetItems; + } + +private: + std::vector patternElements; + std::vector onMatchSetItems; + std::vector onCreateSetItems; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/updating_clause/set_clause.h 
b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/updating_clause/set_clause.h new file mode 100644 index 0000000000..97abdad3aa --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/updating_clause/set_clause.h @@ -0,0 +1,21 @@ +#pragma once + +#include "parser/expression/parsed_expression.h" +#include "updating_clause.h" + +namespace lbug { +namespace parser { + +class SetClause final : public UpdatingClause { +public: + SetClause() : UpdatingClause{common::ClauseType::SET} {}; + + inline void addSetItem(parsed_expr_pair setItem) { setItems.push_back(std::move(setItem)); } + inline const std::vector& getSetItemsRef() const { return setItems; } + +private: + std::vector setItems; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/updating_clause/updating_clause.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/updating_clause/updating_clause.h new file mode 100644 index 0000000000..67b25f90af --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/query/updating_clause/updating_clause.h @@ -0,0 +1,26 @@ +#pragma once + +#include "common/cast.h" +#include "common/enums/clause_type.h" + +namespace lbug { +namespace parser { + +class UpdatingClause { +public: + explicit UpdatingClause(common::ClauseType clauseType) : clauseType{clauseType} {}; + virtual ~UpdatingClause() = default; + + common::ClauseType getClauseType() const { return clauseType; } + + template + const TARGET& constCast() const { + return common::ku_dynamic_cast(*this); + } + +private: + common::ClauseType clauseType; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/scan_source.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/scan_source.h new file mode 100644 index 0000000000..cd3c688380 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/scan_source.h @@ -0,0 +1,72 @@ +#pragma once + +#include +#include +#include + +#include "common/copy_constructors.h" +#include "common/enums/scan_source_type.h" +#include "expression/parsed_expression.h" +#include "parser/statement.h" + +namespace lbug { +namespace parser { + +struct BaseScanSource { + common::ScanSourceType type; + + explicit BaseScanSource(common::ScanSourceType type) : type{type} {} + virtual ~BaseScanSource() = default; + DELETE_COPY_AND_MOVE(BaseScanSource); + + template + TARGET* ptrCast() { + return common::ku_dynamic_cast(this); + } + template + const TARGET* constPtrCast() const { + return common::ku_dynamic_cast(this); + } +}; + +struct ParameterScanSource : public BaseScanSource { + std::unique_ptr paramExpression; + + explicit ParameterScanSource(std::unique_ptr paramExpression) + : BaseScanSource{common::ScanSourceType::PARAM}, + paramExpression{std::move(paramExpression)} {} +}; + +struct FileScanSource : public BaseScanSource { + std::vector filePaths; + + explicit FileScanSource(std::vector paths) + : BaseScanSource{common::ScanSourceType::FILE}, filePaths{std::move(paths)} {} +}; + +struct ObjectScanSource : public BaseScanSource { + // If multiple object presents, assuming they have a nested structure. + // E.g. 
for postgres.person, objectNames should be [postgres, person] + std::vector objectNames; + + explicit ObjectScanSource(std::vector objectNames) + : BaseScanSource{common::ScanSourceType::OBJECT}, objectNames{std::move(objectNames)} {} +}; + +struct QueryScanSource : public BaseScanSource { + std::unique_ptr statement; + + explicit QueryScanSource(std::unique_ptr statement) + : BaseScanSource{common::ScanSourceType::QUERY}, statement{std::move(statement)} {} +}; + +struct TableFuncScanSource : public BaseScanSource { + std::unique_ptr functionExpression = nullptr; + + explicit TableFuncScanSource(std::unique_ptr functionExpression) + : BaseScanSource{common::ScanSourceType::TABLE_FUNC}, + functionExpression{std::move(functionExpression)} {} +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/standalone_call.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/standalone_call.h new file mode 100644 index 0000000000..03e635716f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/standalone_call.h @@ -0,0 +1,25 @@ +#pragma once + +#include "parser/expression/parsed_expression.h" +#include "parser/statement.h" + +namespace lbug { +namespace parser { + +class StandaloneCall : public Statement { +public: + explicit StandaloneCall(std::string optionName, std::unique_ptr optionValue) + : Statement{common::StatementType::STANDALONE_CALL}, optionName{std::move(optionName)}, + optionValue{std::move(optionValue)} {} + + std::string getOptionName() const { return optionName; } + + ParsedExpression* getOptionValue() const { return optionValue.get(); } + +private: + std::string optionName; + std::unique_ptr optionValue; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/standalone_call_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/standalone_call_function.h new file mode 100644 index 0000000000..372dd7c79a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/standalone_call_function.h @@ -0,0 +1,22 @@ +#pragma once + +#include "parser/expression/parsed_expression.h" +#include "parser/statement.h" + +namespace lbug { +namespace parser { + +class StandaloneCallFunction : public Statement { +public: + explicit StandaloneCallFunction(std::unique_ptr functionExpression) + : Statement{common::StatementType::STANDALONE_CALL_FUNCTION}, + functionExpression{std::move(functionExpression)} {} + + const ParsedExpression* getFunctionExpression() const { return functionExpression.get(); } + +private: + std::unique_ptr functionExpression; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/statement.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/statement.h new file mode 100644 index 0000000000..8e63760a4e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/statement.h @@ -0,0 +1,55 @@ +#pragma once + +#include "common/cast.h" +#include "common/enums/statement_type.h" + +namespace lbug { +namespace parser { + +class Statement { +public: + explicit Statement(common::StatementType statementType) + : parsingTime{0}, statementType{statementType}, internal{false} {} + + virtual ~Statement() = default; + + common::StatementType getStatementType() const { return statementType; } + void setToInternal() { internal = true; } + bool isInternal() const { return internal; } + void setParsingTime(double time) { parsingTime = time; } + double getParsingTime() const { 
return parsingTime; }
+
+    bool requireTransaction() const {
+        switch (statementType) {
+        case common::StatementType::TRANSACTION:
+            return false;
+        default:
+            return true;
+        }
+    }
+
+    template<class TARGET>
+    TARGET& cast() {
+        return common::ku_dynamic_cast(*this);
+    }
+    template<class TARGET>
+    const TARGET& constCast() const {
+        return common::ku_dynamic_cast(*this);
+    }
+    template<class TARGET>
+    const TARGET* constPtrCast() const {
+        return common::ku_dynamic_cast(this);
+    }
+
+private:
+    double parsingTime;
+    common::StatementType statementType;
+    // By setting the statement to internal, we still execute the statement, but will not return
+    // the execution result as part of the query result returned to users.
+    // The use case for this is when a query internally generates other queries to finish first,
+    // e.g., `TableFunction::rewriteFunc`.
+    bool internal;
+};
+
+} // namespace parser
+} // namespace lbug
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/transaction_statement.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/transaction_statement.h
new file mode 100644
index 0000000000..f12c69c8ef
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/transaction_statement.h
@@ -0,0 +1,23 @@
+#pragma once
+
+#include "statement.h"
+#include "transaction/transaction_action.h"
+
+namespace lbug {
+namespace parser {
+
+class TransactionStatement : public Statement {
+    static constexpr common::StatementType statementType_ = common::StatementType::TRANSACTION;
+
+public:
+    explicit TransactionStatement(transaction::TransactionAction transactionAction)
+        : Statement{statementType_}, transactionAction{transactionAction} {}
+
+    transaction::TransactionAction getTransactionAction() const { return transactionAction; }
+
+private:
+    transaction::TransactionAction transactionAction;
+};
+
+} // namespace parser
+} // namespace lbug
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/transformer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/transformer.h
new file mode 100644
index 0000000000..2919148b8a
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/transformer.h
@@ -0,0 +1,261 @@
+#pragma once
+
+// ANTLR4 generates code with unused parameters.
+#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#include "cypher_parser.h" +#pragma GCC diagnostic pop + +#include "common/enums/conflict_action.h" +#include "extension/transformer_extension.h" +#include "parser/ddl/parsed_property_definition.h" +#include "statement.h" + +namespace lbug { +namespace main { +class ClientContext; +} +namespace parser { + +class RegularQuery; +class SingleQuery; +class QueryPart; +class UpdatingClause; +class ReadingClause; +class WithClause; +class ReturnClause; +class ProjectionBody; +class PatternElement; +class NodePattern; +class PatternElementChain; +class RelPattern; +struct ParsedCaseAlternative; +struct BaseScanSource; +struct JoinHintNode; +struct YieldVariable; + +class Transformer { +public: + Transformer(CypherParser::Ku_StatementsContext& root, + std::vector transformerExtensions) + : root{root}, transformerExtensions{std::move(transformerExtensions)} {} + + std::vector> transform(); + + void registerTransformExtension( + std::unique_ptr transformerExtension); + + std::unique_ptr transformStatement(CypherParser::OC_StatementContext& ctx); + + std::unique_ptr transformWhere(CypherParser::OC_WhereContext& ctx); + + static std::string transformVariable(CypherParser::OC_VariableContext& ctx); + std::string transformSchemaName(CypherParser::OC_SchemaNameContext& ctx); + static std::string transformSymbolicName(CypherParser::OC_SymbolicNameContext& ctx); + static std::string transformStringLiteral(antlr4::tree::TerminalNode& stringLiteral); + static common::ConflictAction transformConflictAction(CypherParser::KU_IfNotExistsContext* ctx); + + // Transform copy statement. + std::unique_ptr transformCopyTo(CypherParser::KU_CopyTOContext& ctx); + std::unique_ptr transformCopyFrom(CypherParser::KU_CopyFromContext& ctx); + std::unique_ptr transformCopyFromByColumn( + CypherParser::KU_CopyFromByColumnContext& ctx); + std::vector transformColumnNames(CypherParser::KU_ColumnNamesContext& ctx); + std::vector transformFilePaths( + const std::vector& stringLiteral); + std::unique_ptr transformScanSource(CypherParser::KU_ScanSourceContext& ctx); + options_t transformOptions(CypherParser::KU_OptionsContext& ctx); + + std::unique_ptr transformExportDatabase(CypherParser::KU_ExportDatabaseContext& ctx); + std::unique_ptr transformImportDatabase(CypherParser::KU_ImportDatabaseContext& ctx); + + // Transform query statement. + std::unique_ptr transformQuery(CypherParser::OC_QueryContext& ctx); + std::unique_ptr transformRegularQuery(CypherParser::OC_RegularQueryContext& ctx); + SingleQuery transformSingleQuery(CypherParser::OC_SingleQueryContext& ctx); + SingleQuery transformSinglePartQuery(CypherParser::OC_SinglePartQueryContext& ctx); + QueryPart transformQueryPart(CypherParser::KU_QueryPartContext& ctx); + + // Transform updating. + std::unique_ptr transformUpdatingClause( + CypherParser::OC_UpdatingClauseContext& ctx); + std::unique_ptr transformCreate(CypherParser::OC_CreateContext& ctx); + std::unique_ptr transformMerge(CypherParser::OC_MergeContext& ctx); + std::unique_ptr transformSet(CypherParser::OC_SetContext& ctx); + parsed_expr_pair transformSetItem(CypherParser::OC_SetItemContext& ctx); + std::unique_ptr transformDelete(CypherParser::OC_DeleteContext& ctx); + + // Transform reading. 
+ std::unique_ptr transformReadingClause( + CypherParser::OC_ReadingClauseContext& ctx); + std::unique_ptr transformMatch(CypherParser::OC_MatchContext& ctx); + std::unique_ptr transformUnwind(CypherParser::OC_UnwindContext& ctx); + std::vector transformYieldVariables(CypherParser::OC_YieldItemsContext& ctx); + std::unique_ptr transformInQueryCall(CypherParser::KU_InQueryCallContext& ctx); + std::unique_ptr transformLoadFrom(CypherParser::KU_LoadFromContext& ctx); + std::shared_ptr transformJoinHint(CypherParser::KU_JoinNodeContext& ctx); + + // Transform projection. + WithClause transformWith(CypherParser::OC_WithContext& ctx); + ReturnClause transformReturn(CypherParser::OC_ReturnContext& ctx); + ProjectionBody transformProjectionBody(CypherParser::OC_ProjectionBodyContext& ctx); + std::vector> transformProjectionItems( + CypherParser::OC_ProjectionItemsContext& ctx); + std::unique_ptr transformProjectionItem( + CypherParser::OC_ProjectionItemContext& ctx); + + // Transform graph pattern. + std::vector transformPattern(CypherParser::OC_PatternContext& ctx); + PatternElement transformPatternPart(CypherParser::OC_PatternPartContext& ctx); + PatternElement transformAnonymousPatternPart(CypherParser::OC_AnonymousPatternPartContext& ctx); + PatternElement transformPatternElement(CypherParser::OC_PatternElementContext& ctx); + NodePattern transformNodePattern(CypherParser::OC_NodePatternContext& ctx); + PatternElementChain transformPatternElementChain( + CypherParser::OC_PatternElementChainContext& ctx); + RelPattern transformRelationshipPattern(CypherParser::OC_RelationshipPatternContext& ctx); + std::vector transformProperties(CypherParser::KU_PropertiesContext& ctx); + std::vector transformRelTypes(CypherParser::OC_RelationshipTypesContext& ctx); + std::vector transformNodeLabels(CypherParser::OC_NodeLabelsContext& ctx); + std::string transformLabelName(CypherParser::OC_LabelNameContext& ctx); + std::string transformRelTypeName(CypherParser::OC_RelTypeNameContext& ctx); + + // Transform expression. 
+ std::unique_ptr transformExpression(CypherParser::OC_ExpressionContext& ctx); + std::unique_ptr transformOrExpression( + CypherParser::OC_OrExpressionContext& ctx); + std::unique_ptr transformXorExpression( + CypherParser::OC_XorExpressionContext& ctx); + std::unique_ptr transformAndExpression( + CypherParser::OC_AndExpressionContext& ctx); + std::unique_ptr transformNotExpression( + CypherParser::OC_NotExpressionContext& ctx); + std::unique_ptr transformComparisonExpression( + CypherParser::OC_ComparisonExpressionContext& ctx); + std::unique_ptr transformBitwiseOrOperatorExpression( + CypherParser::KU_BitwiseOrOperatorExpressionContext& ctx); + std::unique_ptr transformBitwiseAndOperatorExpression( + CypherParser::KU_BitwiseAndOperatorExpressionContext& ctx); + std::unique_ptr transformBitShiftOperatorExpression( + CypherParser::KU_BitShiftOperatorExpressionContext& ctx); + std::unique_ptr transformAddOrSubtractExpression( + CypherParser::OC_AddOrSubtractExpressionContext& ctx); + std::unique_ptr transformMultiplyDivideModuloExpression( + CypherParser::OC_MultiplyDivideModuloExpressionContext& ctx); + std::unique_ptr transformPowerOfExpression( + CypherParser::OC_PowerOfExpressionContext& ctx); + std::unique_ptr transformUnaryAddSubtractOrFactorialExpression( + CypherParser::OC_UnaryAddSubtractOrFactorialExpressionContext& ctx); + std::unique_ptr transformStringListNullOperatorExpression( + CypherParser::OC_StringListNullOperatorExpressionContext& ctx); + std::unique_ptr transformStringOperatorExpression( + CypherParser::OC_StringOperatorExpressionContext& ctx, + std::unique_ptr propertyExpression); + std::unique_ptr transformListOperatorExpression( + CypherParser::OC_ListOperatorExpressionContext& ctx, + std::unique_ptr childExpression); + std::unique_ptr transformNullOperatorExpression( + CypherParser::OC_NullOperatorExpressionContext& ctx, + std::unique_ptr propertyExpression); + std::unique_ptr transformPropertyOrLabelsExpression( + CypherParser::OC_PropertyOrLabelsExpressionContext& ctx); + std::unique_ptr transformAtom(CypherParser::OC_AtomContext& ctx); + std::unique_ptr transformLiteral(CypherParser::OC_LiteralContext& ctx); + std::unique_ptr transformBooleanLiteral( + CypherParser::OC_BooleanLiteralContext& ctx); + std::unique_ptr transformListLiteral( + CypherParser::OC_ListLiteralContext& ctx); + std::unique_ptr transformStructLiteral( + CypherParser::KU_StructLiteralContext& ctx); + std::unique_ptr transformParameterExpression( + CypherParser::OC_ParameterContext& ctx); + std::unique_ptr transformParenthesizedExpression( + CypherParser::OC_ParenthesizedExpressionContext& ctx); + std::unique_ptr transformFunctionInvocation( + CypherParser::OC_FunctionInvocationContext& ctx); + std::string transformFunctionName(CypherParser::OC_FunctionNameContext& ctx); + std::vector transformLambdaVariables(CypherParser::KU_LambdaVarsContext& ctx); + std::unique_ptr transformLambdaParameter( + CypherParser::KU_LambdaParameterContext& ctx); + std::unique_ptr transformFunctionParameterExpression( + CypherParser::KU_FunctionParameterContext& ctx); + std::unique_ptr transformPathPattern( + CypherParser::OC_PathPatternsContext& ctx); + std::unique_ptr transformExistCountSubquery( + CypherParser::OC_ExistCountSubqueryContext& ctx); + std::unique_ptr transformOcQuantifier( + CypherParser::OC_QuantifierContext& ctx); + std::unique_ptr createPropertyExpression( + CypherParser::OC_PropertyKeyNameContext& ctx, std::unique_ptr child); + std::unique_ptr createPropertyExpression( + 
CypherParser::OC_PropertyLookupContext& ctx, std::unique_ptr child); + std::unique_ptr transformCaseExpression( + CypherParser::OC_CaseExpressionContext& ctx); + ParsedCaseAlternative transformCaseAlternative(CypherParser::OC_CaseAlternativeContext& ctx); + std::unique_ptr transformNumberLiteral( + CypherParser::OC_NumberLiteralContext& ctx, bool negative); + std::unique_ptr transformProperty( + CypherParser::OC_PropertyExpressionContext& ctx); + std::string transformPropertyKeyName(CypherParser::OC_PropertyKeyNameContext& ctx); + std::unique_ptr transformIntegerLiteral( + CypherParser::OC_IntegerLiteralContext& ctx, bool negative); + std::unique_ptr transformDoubleLiteral( + CypherParser::OC_DoubleLiteralContext& ctx, bool negative); + + // Transform ddl. + std::unique_ptr transformAlterTable(CypherParser::KU_AlterTableContext& ctx); + std::unique_ptr transformCreateNodeTable( + CypherParser::KU_CreateNodeTableContext& ctx); + std::unique_ptr transformCreateRelGroup(CypherParser::KU_CreateRelTableContext& ctx); + std::unique_ptr transformCreateSequence(CypherParser::KU_CreateSequenceContext& ctx); + std::unique_ptr transformCreateType(CypherParser::KU_CreateTypeContext& ctx); + std::unique_ptr transformDrop(CypherParser::KU_DropContext& ctx); + std::unique_ptr transformRenameTable(CypherParser::KU_AlterTableContext& ctx); + std::unique_ptr transformAddFromToConnection( + CypherParser::KU_AlterTableContext& ctx); + std::unique_ptr transformDropFromToConnection( + CypherParser::KU_AlterTableContext& ctx); + std::unique_ptr transformAddProperty(CypherParser::KU_AlterTableContext& ctx); + std::unique_ptr transformDropProperty(CypherParser::KU_AlterTableContext& ctx); + std::unique_ptr transformRenameProperty(CypherParser::KU_AlterTableContext& ctx); + std::unique_ptr transformCommentOn(CypherParser::KU_CommentOnContext& ctx); + std::string transformUnionType(CypherParser::KU_UnionTypeContext& ctx); + std::string transformStructType(CypherParser::KU_StructTypeContext& ctx); + std::string transformMapType(CypherParser::KU_MapTypeContext& ctx); + std::string transformDecimalType(CypherParser::KU_DecimalTypeContext& ctx); + std::string transformDataType(CypherParser::KU_DataTypeContext& ctx); + std::string getPKName(CypherParser::KU_CreateNodeTableContext& ctx); + std::string transformPrimaryKey(CypherParser::KU_CreateNodeConstraintContext& ctx); + std::string transformPrimaryKey(CypherParser::KU_ColumnDefinitionContext& ctx); + std::vector transformColumnDefinitions( + CypherParser::KU_ColumnDefinitionsContext& ctx); + ParsedColumnDefinition transformColumnDefinition(CypherParser::KU_ColumnDefinitionContext& ctx); + std::vector transformPropertyDefinitions( + CypherParser::KU_PropertyDefinitionsContext& ctx); + + // Transform standalone call. + std::unique_ptr transformStandaloneCall(CypherParser::KU_StandaloneCallContext& ctx); + + // Transform create macro. + std::unique_ptr transformCreateMacro(CypherParser::KU_CreateMacroContext& ctx); + std::vector transformPositionalArgs(CypherParser::KU_PositionalArgsContext& ctx); + + // Transform transaction. + std::unique_ptr transformTransaction(CypherParser::KU_TransactionContext& ctx); + + // Transform extension. + std::unique_ptr transformExtension(CypherParser::KU_ExtensionContext& ctx); + + // Transform attach/detach/use database. 
+ std::unique_ptr transformAttachDatabase(CypherParser::KU_AttachDatabaseContext& ctx); + std::unique_ptr transformDetachDatabase(CypherParser::KU_DetachDatabaseContext& ctx); + std::unique_ptr transformUseDatabase(CypherParser::KU_UseDatabaseContext& ctx); + + std::unique_ptr transformExtensionStatement(antlr4::ParserRuleContext* ctx); + +private: + CypherParser::Ku_StatementsContext& root; + std::vector transformerExtensions; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/use_database.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/use_database.h new file mode 100644 index 0000000000..01c751c706 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/use_database.h @@ -0,0 +1,15 @@ +#pragma once + +#include "parser/database_statement.h" + +namespace lbug { +namespace parser { + +class UseDatabase final : public DatabaseStatement { +public: + explicit UseDatabase(std::string dbName) + : DatabaseStatement{common::StatementType::USE_DATABASE, std::move(dbName)} {} +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/visitor/standalone_call_rewriter.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/visitor/standalone_call_rewriter.h new file mode 100644 index 0000000000..138ce15212 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/visitor/standalone_call_rewriter.h @@ -0,0 +1,25 @@ +#pragma once + +#include "parser/parsed_statement_visitor.h" + +namespace lbug { +namespace parser { + +class StandaloneCallRewriter final : public StatementVisitor { +public: + explicit StandaloneCallRewriter(main::ClientContext* context, bool allowRewrite) + : StatementVisitor{}, rewriteQuery{}, context{context}, singleStatement{allowRewrite} {} + + std::string getRewriteQuery(const Statement& statement); + +private: + void visitStandaloneCallFunction(const Statement& statement) override; + +private: + std::string rewriteQuery; + main::ClientContext* context; + bool singleStatement; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/visitor/statement_read_write_analyzer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/visitor/statement_read_write_analyzer.h new file mode 100644 index 0000000000..f7f14e2f11 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/parser/visitor/statement_read_write_analyzer.h @@ -0,0 +1,46 @@ +#pragma once + +#include "parser/expression/parsed_expression.h" +#include "parser/parsed_statement_visitor.h" + +namespace lbug { +namespace parser { + +class StatementReadWriteAnalyzer final : public StatementVisitor { +public: + explicit StatementReadWriteAnalyzer(main::ClientContext* context) + : StatementVisitor{}, readOnly{true}, context{context} {} + + bool isReadOnly() const { return readOnly; } + +private: + void visitCreateSequence(const Statement& /*statement*/) override { readOnly = false; } + void visitDrop(const Statement& /*statement*/) override { readOnly = false; } + void visitCreateTable(const Statement& /*statement*/) override { readOnly = false; } + void visitCreateType(const Statement& /*statement*/) override { readOnly = false; } + void visitAlter(const Statement& /*statement*/) override { readOnly = false; } + void visitCopyFrom(const Statement& /*statement*/) override { readOnly = false; } + void visitStandaloneCall(const Statement& /*statement*/) override { readOnly = true; } + void visitStandaloneCallFunction(const 
Statement& /*statement*/) override { readOnly = false; } + void visitCreateMacro(const Statement& /*statement*/) override { readOnly = false; } + void visitExtension(const Statement& /*statement*/) override; + + void visitReadingClause(const ReadingClause* readingClause) override; + void visitWithClause(const WithClause* withClause) override; + void visitReturnClause(const ReturnClause* returnClause) override; + + void visitUpdatingClause(const UpdatingClause* /*updatingClause*/) override { + readOnly = false; + } + + void visitExtensionClause(const Statement& /* statement*/) override { readOnly = false; } + + bool isExprReadOnly(const ParsedExpression* expr); + +private: + bool readOnly; + main::ClientContext* context; +}; + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/join_order/cardinality_estimator.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/join_order/cardinality_estimator.h new file mode 100644 index 0000000000..b7e7729afc --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/join_order/cardinality_estimator.h @@ -0,0 +1,63 @@ +#pragma once + +#include "binder/query/query_graph.h" +#include "planner/operator/logical_plan.h" +#include "storage/stats/table_stats.h" + +namespace lbug { +namespace main { +class ClientContext; +} // namespace main + +namespace transaction { +class Transaction; +} // namespace transaction + +namespace planner { + +class LogicalAggregate; + +class CardinalityEstimator { +public: + explicit CardinalityEstimator(main::ClientContext* context) : context{context} {} + DELETE_COPY_DEFAULT_MOVE(CardinalityEstimator); + + void init(const binder::QueryGraph& queryGraph); + LBUG_API void init(const binder::NodeExpression& node); + + void rectifyCardinality(const binder::Expression& nodeID, cardinality_t card); + + cardinality_t estimateScanNode(const LogicalOperator& op) const; + cardinality_t estimateHashJoin(const std::vector& joinConditions, + const LogicalOperator& probeOp, const LogicalOperator& buildOp) const; + cardinality_t estimateCrossProduct(const LogicalOperator& probeOp, + const LogicalOperator& buildOp) const; + cardinality_t estimateIntersect(const binder::expression_vector& joinNodeIDs, + const LogicalOperator& probeOp, const std::vector& buildOps) const; + cardinality_t estimateFlatten(const LogicalOperator& childOp, + f_group_pos groupPosToFlatten) const; + cardinality_t estimateFilter(const LogicalOperator& childOp, + const binder::Expression& predicate) const; + cardinality_t estimateAggregate(const LogicalAggregate& op) const; + + double getExtensionRate(const binder::RelExpression& rel, + const binder::NodeExpression& boundNode, const transaction::Transaction* transaction) const; + cardinality_t multiply(double extensionRate, cardinality_t card) const; + +private: + cardinality_t getNodeIDDom(const std::string& nodeIDName) const; + cardinality_t getNumNodes(const transaction::Transaction* transaction, + const std::vector& tableIDs) const; + cardinality_t getNumRels(const transaction::Transaction* transaction, + const std::vector& tableIDs) const; + +private: + main::ClientContext* context; + // TODO(Guodong): Extend this to cover rel tables. + std::unordered_map nodeTableStats; + // The domain of nodeID is defined as the number of unique value of nodeID, i.e. num nodes. 
+    std::unordered_map nodeIDName2dom;
+};
+
+} // namespace planner
+} // namespace lbug
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/join_order/cost_model.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/join_order/cost_model.h
new file mode 100644
index 0000000000..493d566452
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/join_order/cost_model.h
@@ -0,0 +1,24 @@
+#pragma once
+
+#include "planner/operator/logical_plan.h"
+
+namespace lbug {
+namespace planner {
+
+class CostModel {
+public:
+    static uint64_t computeExtendCost(const LogicalPlan& childPlan);
+    static uint64_t computeHashJoinCost(const std::vector& joinConditions,
+        const LogicalPlan& probe, const LogicalPlan& build);
+    static uint64_t computeHashJoinCost(const binder::expression_vector& joinNodeIDs,
+        const LogicalPlan& probe, const LogicalPlan& build);
+    static uint64_t computeMarkJoinCost(const std::vector& joinConditions,
+        const LogicalPlan& probe, const LogicalPlan& build);
+    static uint64_t computeMarkJoinCost(const binder::expression_vector& joinNodeIDs,
+        const LogicalPlan& probe, const LogicalPlan& build);
+    static uint64_t computeIntersectCost(const LogicalPlan& probePlan,
+        const std::vector& buildPlans);
+};
+
+} // namespace planner
+} // namespace lbug
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/join_order/join_order_util.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/join_order/join_order_util.h
new file mode 100644
index 0000000000..f7e8cd3340
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/join_order/join_order_util.h
@@ -0,0 +1,16 @@
+#pragma once
+
+#include "planner/operator/logical_operator.h"
+
+namespace lbug {
+namespace planner {
+
+struct JoinOrderUtil {
+    // Although we do not flatten the join keys in the Build operator computation, we still need
+    // to perform cardinality and cost estimation based on their flat cardinality.
+ static uint64_t getJoinKeysFlatCardinality(const binder::expression_vector& joinNodeIDs, + const LogicalOperator& buildOp); +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/join_order/join_plan_solver.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/join_order/join_plan_solver.h new file mode 100644 index 0000000000..99fc9b36e3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/join_order/join_plan_solver.h @@ -0,0 +1,28 @@ +#pragma once + +#include "join_tree.h" +#include "planner/planner.h" + +namespace lbug { +namespace planner { + +class JoinPlanSolver { +public: + explicit JoinPlanSolver(Planner* planner) : planner{planner} {} + + LogicalPlan solve(const JoinTree& joinTree); + +private: + LogicalPlan solveTreeNode(const JoinTreeNode& current, const JoinTreeNode* parent); + + LogicalPlan solveNodeScanTreeNode(const JoinTreeNode& treeNode); + LogicalPlan solveRelScanTreeNode(const JoinTreeNode& treeNode, const JoinTreeNode& parent); + LogicalPlan solveBinaryJoinTreeNode(const JoinTreeNode& treeNode); + LogicalPlan solveMultiwayJoinTreeNode(const JoinTreeNode& treeNode); + +private: + Planner* planner; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/join_order/join_tree.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/join_order/join_tree.h new file mode 100644 index 0000000000..a2d1efaa4e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/join_order/join_tree.h @@ -0,0 +1,99 @@ +#pragma once + +#include "binder/expression/node_expression.h" + +namespace lbug { +namespace planner { + +enum class TreeNodeType : uint8_t { + NODE_SCAN = 0, + REL_SCAN = 1, + BINARY_JOIN = 5, + MULTIWAY_JOIN = 6, +}; + +struct TreeNodeTypeUtils { + static std::string toString(TreeNodeType type); +}; + +struct ExtraTreeNodeInfo { + virtual ~ExtraTreeNodeInfo() = default; + + virtual std::unique_ptr copy() const = 0; + + template + const TARGET& constCast() const { + return common::ku_dynamic_cast(*this); + } + template + TARGET& cast() { + return common::ku_dynamic_cast(*this); + } +}; + +struct ExtraJoinTreeNodeInfo : ExtraTreeNodeInfo { + std::vector> joinNodes; + binder::expression_vector predicates; + + explicit ExtraJoinTreeNodeInfo(std::shared_ptr joinNode) { + joinNodes.push_back(std::move(joinNode)); + } + explicit ExtraJoinTreeNodeInfo(std::vector> joinNodes) + : joinNodes{std::move(joinNodes)} {} + ExtraJoinTreeNodeInfo(const ExtraJoinTreeNodeInfo& other) + : joinNodes{other.joinNodes}, predicates{other.predicates} {} + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } +}; + +struct NodeRelScanInfo { + std::shared_ptr nodeOrRel; + binder::expression_vector properties; + binder::expression_vector predicates; + + NodeRelScanInfo(std::shared_ptr nodeOrRel, + binder::expression_vector properties) + : nodeOrRel{std::move(nodeOrRel)}, properties{std::move(properties)} {} +}; + +struct ExtraScanTreeNodeInfo : ExtraTreeNodeInfo { + std::unique_ptr nodeInfo; + std::vector relInfos; + binder::expression_vector predicates; + + ExtraScanTreeNodeInfo() = default; + ExtraScanTreeNodeInfo(const ExtraScanTreeNodeInfo& other) + : nodeInfo{std::make_unique(*other.nodeInfo)}, relInfos{other.relInfos} {} + + void merge(const ExtraScanTreeNodeInfo& other); + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } +}; + +struct JoinTreeNode { + TreeNodeType type; + 
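+    // Type-specific payload (e.g. ExtraScanTreeNodeInfo for scan nodes, ExtraJoinTreeNodeInfo
+    // for join nodes).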
std::unique_ptr extraInfo; + std::vector> children; + + JoinTreeNode(TreeNodeType type, std::unique_ptr extraInfo) + : type{type}, extraInfo{std::move(extraInfo)} {} + DELETE_COPY_DEFAULT_MOVE(JoinTreeNode); + + std::string toString() const; + + void addChild(std::shared_ptr child) { children.push_back(std::move(child)); } +}; + +struct JoinTree { + std::shared_ptr root; + explicit JoinTree(std::shared_ptr root) : root{std::move(root)} {} + + JoinTree(const JoinTree& other) : root{other.root} {} +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/join_order/join_tree_constructor.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/join_order/join_tree_constructor.h new file mode 100644 index 0000000000..6f7f09b714 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/join_order/join_tree_constructor.h @@ -0,0 +1,44 @@ +#pragma once + +#include "join_tree.h" +#include "planner/planner.h" + +namespace lbug { +namespace planner { + +class PropertyExprCollection; + +class JoinTreeConstructor { +public: + JoinTreeConstructor(const binder::QueryGraph& queryGraph, + const PropertyExprCollection& propertyCollection, binder::expression_vector predicates, + const QueryGraphPlanningInfo& planningInfo) + : queryGraph{queryGraph}, propertyCollection{propertyCollection}, + queryGraphPredicates{std::move(predicates)}, planningInfo{planningInfo} {} + + JoinTree construct(std::shared_ptr root); + +private: + struct IntermediateResult { + std::shared_ptr treeNode = nullptr; + binder::SubqueryGraph subqueryGraph; + }; + + IntermediateResult constructTreeNode(std::shared_ptr hintNode); + IntermediateResult constructNodeScan(std::shared_ptr expr); + IntermediateResult constructRelScan(std::shared_ptr expr); + + std::shared_ptr tryConstructNestedLoopJoin( + std::vector> joinNodes, + const JoinTreeNode& leftRoot, const JoinTreeNode& rightRoot, + const binder::expression_vector& predicates); + +private: + const binder::QueryGraph& queryGraph; + const PropertyExprCollection& propertyCollection; + binder::expression_vector queryGraphPredicates; + const QueryGraphPlanningInfo& planningInfo; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/join_order_enumerator_context.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/join_order_enumerator_context.h new file mode 100644 index 0000000000..77988383d7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/join_order_enumerator_context.h @@ -0,0 +1,52 @@ +#pragma once + +#include "planner/operator/logical_plan.h" +#include "planner/subplans_table.h" + +namespace lbug { +namespace planner { + +class JoinOrderEnumeratorContext { + friend class Planner; + +public: + JoinOrderEnumeratorContext() + : currentLevel{0}, maxLevel{0}, subPlansTable{std::make_unique()}, + queryGraph{nullptr} {} + DELETE_COPY_DEFAULT_MOVE(JoinOrderEnumeratorContext); + + void init(const binder::QueryGraph* queryGraph, const binder::expression_vector& predicates); + + binder::expression_vector getWhereExpressions() { return whereExpressionsSplitOnAND; } + + bool containPlans(const binder::SubqueryGraph& subqueryGraph) const { + return subPlansTable->containSubgraphPlans(subqueryGraph); + } + const std::vector& getPlans(const binder::SubqueryGraph& subqueryGraph) const { + return subPlansTable->getSubgraphPlans(subqueryGraph); + } + void addPlan(const binder::SubqueryGraph& subqueryGraph, LogicalPlan plan) { + 
subPlansTable->addPlan(subqueryGraph, std::move(plan)); + } + + binder::SubqueryGraph getEmptySubqueryGraph() const { + return binder::SubqueryGraph(*queryGraph); + } + binder::SubqueryGraph getFullyMatchedSubqueryGraph() const; + + const binder::QueryGraph* getQueryGraph() { return queryGraph; } + + void resetState(); + +private: + binder::expression_vector whereExpressionsSplitOnAND; + + uint32_t currentLevel; + uint32_t maxLevel; + + std::unique_ptr subPlansTable; + const binder::QueryGraph* queryGraph; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/ddl/logical_alter.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/ddl/logical_alter.h new file mode 100644 index 0000000000..19e3d41063 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/ddl/logical_alter.h @@ -0,0 +1,47 @@ +#pragma once + +#include "binder/ddl/bound_alter_info.h" +#include "planner/operator/simple/logical_simple.h" + +namespace lbug { +namespace planner { + +struct LogicalAlterPrintInfo final : OPPrintInfo { + binder::BoundAlterInfo info; + + explicit LogicalAlterPrintInfo(binder::BoundAlterInfo info) : info{std::move(info)} {} + + std::string toString() const override { return info.toString(); } + + std::unique_ptr copy() const override { + return std::unique_ptr(new LogicalAlterPrintInfo(*this)); + } + + LogicalAlterPrintInfo(const LogicalAlterPrintInfo& other) : info{other.info.copy()} {} +}; + +class LogicalAlter final : public LogicalSimple { + static constexpr LogicalOperatorType type_ = LogicalOperatorType::ALTER; + +public: + explicit LogicalAlter(binder::BoundAlterInfo info) + : LogicalSimple{type_}, info{std::move(info)} {} + + std::string getExpressionsForPrinting() const override { return info.tableName; } + + const binder::BoundAlterInfo* getInfo() const { return &info; } + + std::unique_ptr getPrintInfo() const override { + return std::make_unique(info.copy()); + } + + std::unique_ptr copy() override { + return std::make_unique(info.copy()); + } + +private: + binder::BoundAlterInfo info; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/ddl/logical_create_sequence.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/ddl/logical_create_sequence.h new file mode 100644 index 0000000000..04859ed899 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/ddl/logical_create_sequence.h @@ -0,0 +1,51 @@ +#pragma once + +#include "binder/ddl/bound_create_sequence_info.h" +#include "planner/operator/simple/logical_simple.h" + +namespace lbug { +namespace planner { + +struct LogicalCreateSequencePrintInfo final : OPPrintInfo { + std::string sequenceName; + + explicit LogicalCreateSequencePrintInfo(std::string sequenceName) + : sequenceName(std::move(sequenceName)) {} + + std::string toString() const override { return "Sequence: " + sequenceName; }; + + std::unique_ptr copy() const override { + return std::unique_ptr( + new LogicalCreateSequencePrintInfo(*this)); + } + +private: + LogicalCreateSequencePrintInfo(const LogicalCreateSequencePrintInfo& other) + : OPPrintInfo(other), sequenceName(other.sequenceName) {} +}; + +class LogicalCreateSequence : public LogicalSimple { + static constexpr LogicalOperatorType type_ = LogicalOperatorType::CREATE_SEQUENCE; + +public: + explicit LogicalCreateSequence(binder::BoundCreateSequenceInfo info) + : LogicalSimple{type_}, info{std::move(info)} {} + 
+ std::string getExpressionsForPrinting() const override { return info.sequenceName; } + + binder::BoundCreateSequenceInfo getInfo() const { return info.copy(); } + + std::unique_ptr getPrintInfo() const override { + return std::make_unique(info.sequenceName); + } + + std::unique_ptr copy() final { + return std::make_unique(info.copy()); + } + +private: + binder::BoundCreateSequenceInfo info; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/ddl/logical_create_table.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/ddl/logical_create_table.h new file mode 100644 index 0000000000..8a93c9a4f0 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/ddl/logical_create_table.h @@ -0,0 +1,49 @@ +#pragma once + +#include "binder/ddl/bound_create_table_info.h" +#include "planner/operator/simple/logical_simple.h" + +namespace lbug { +namespace planner { + +struct LogicalCreateTablePrintInfo final : OPPrintInfo { + binder::BoundCreateTableInfo info; + + explicit LogicalCreateTablePrintInfo(binder::BoundCreateTableInfo info) + : info{std::move(info)} {} + + std::string toString() const override { return info.toString(); } + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } + + LogicalCreateTablePrintInfo(const LogicalCreateTablePrintInfo& other) + : info{other.info.copy()} {} +}; + +class LogicalCreateTable final : public LogicalSimple { + static constexpr LogicalOperatorType type_ = LogicalOperatorType::CREATE_TABLE; + +public: + explicit LogicalCreateTable(binder::BoundCreateTableInfo info) + : LogicalSimple{type_}, info{std::move(info)} {} + + std::string getExpressionsForPrinting() const override { return info.tableName; } + + const binder::BoundCreateTableInfo* getInfo() const { return &info; } + + std::unique_ptr getPrintInfo() const override { + return std::make_unique(info.copy()); + } + + std::unique_ptr copy() override { + return std::make_unique(info.copy()); + } + +private: + binder::BoundCreateTableInfo info; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/ddl/logical_create_type.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/ddl/logical_create_type.h new file mode 100644 index 0000000000..a2a1f0c7e9 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/ddl/logical_create_type.h @@ -0,0 +1,51 @@ +#pragma once + +#include "planner/operator/simple/logical_simple.h" + +namespace lbug { +namespace planner { + +struct LogicalCreateTypePrintInfo final : OPPrintInfo { + std::string typeName; + std::string type; + + LogicalCreateTypePrintInfo(std::string typeName, std::string type) + : typeName(std::move(typeName)), type(std::move(type)) {} + + std::string toString() const override { return typeName + " As " + type; }; + + std::unique_ptr copy() const override { + return std::unique_ptr(new LogicalCreateTypePrintInfo(*this)); + } + +private: + LogicalCreateTypePrintInfo(const LogicalCreateTypePrintInfo& other) + : OPPrintInfo(other), typeName(other.typeName), type(other.type) {} +}; + +class LogicalCreateType : public LogicalSimple { + static constexpr LogicalOperatorType type_ = LogicalOperatorType::CREATE_TYPE; + +public: + LogicalCreateType(std::string typeName, common::LogicalType type) + : LogicalSimple{type_}, typeName{std::move(typeName)}, type{std::move(type)} {} + + std::string getExpressionsForPrinting() const override { 
return typeName; } + + const common::LogicalType& getType() const { return type; } + + std::unique_ptr getPrintInfo() const override { + return std::make_unique(typeName, type.toString()); + } + + std::unique_ptr copy() final { + return std::make_unique(typeName, type.copy()); + } + +private: + std::string typeName; + common::LogicalType type; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/ddl/logical_drop.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/ddl/logical_drop.h new file mode 100644 index 0000000000..0fdf1224c7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/ddl/logical_drop.h @@ -0,0 +1,41 @@ +#pragma once + +#include "parser/ddl/drop_info.h" +#include "planner/operator/simple/logical_simple.h" + +namespace lbug { +namespace planner { + +struct LogicalDropPrintInfo : OPPrintInfo { + std::string name; + + explicit LogicalDropPrintInfo(std::string name) : name{std::move(name)} {} + + std::string toString() const override { return name; } +}; + +class LogicalDrop : public LogicalSimple { + static constexpr LogicalOperatorType type_ = LogicalOperatorType::DROP; + +public: + explicit LogicalDrop(parser::DropInfo dropInfo) + : LogicalSimple{type_}, dropInfo{std::move(dropInfo)} {} + + std::string getExpressionsForPrinting() const override { return dropInfo.name; } + + const parser::DropInfo& getDropInfo() const { return dropInfo; } + + std::unique_ptr getPrintInfo() const override { + return std::make_unique(dropInfo.name); + } + + std::unique_ptr copy() override { + return std::make_unique(dropInfo); + } + +private: + parser::DropInfo dropInfo; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/extend/base_logical_extend.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/extend/base_logical_extend.h new file mode 100644 index 0000000000..f2ae3cb874 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/extend/base_logical_extend.h @@ -0,0 +1,85 @@ +#pragma once + +#include "binder/expression/rel_expression.h" +#include "common/enums/extend_direction.h" +#include "planner/operator/logical_operator.h" + +namespace lbug { +namespace planner { + +struct BaseLogicalExtendPrintInfo : OPPrintInfo { + // Start node of extension. + std::shared_ptr boundNode; + // End node of extension. 
+ std::shared_ptr nbrNode; + std::shared_ptr rel; + common::ExtendDirection direction; + + BaseLogicalExtendPrintInfo(std::shared_ptr boundNode, + std::shared_ptr nbrNode, std::shared_ptr rel, + common::ExtendDirection direction) + : boundNode{std::move(boundNode)}, nbrNode{std::move(nbrNode)}, rel{std::move(rel)}, + direction{direction} {} + + std::string toString() const override { + switch (direction) { + case common::ExtendDirection::FWD: { + return "(" + boundNode->toString() + ")-[" + rel->toString() + "]->(" + + nbrNode->toString() + ")"; + } + case common::ExtendDirection::BWD: { + return "(" + nbrNode->toString() + ")-[" + rel->toString() + "]->(" + + boundNode->toString() + ")"; + } + case common::ExtendDirection::BOTH: { + return "(" + boundNode->toString() + ")-[" + rel->toString() + "]-(" + + nbrNode->toString() + ")"; + } + default: { + KU_UNREACHABLE; + } + } + } +}; + +class BaseLogicalExtend : public LogicalOperator { +public: + BaseLogicalExtend(LogicalOperatorType operatorType, + std::shared_ptr boundNode, + std::shared_ptr nbrNode, std::shared_ptr rel, + common::ExtendDirection direction, bool extendFromSource_, + std::shared_ptr child) + : LogicalOperator{operatorType, std::move(child)}, boundNode{std::move(boundNode)}, + nbrNode{std::move(nbrNode)}, rel{std::move(rel)}, direction{direction}, + extendFromSource_{extendFromSource_} {} + + std::shared_ptr getBoundNode() const { return boundNode; } + std::shared_ptr getNbrNode() const { return nbrNode; } + std::shared_ptr getRel() const { return rel; } + bool isRecursive() const { return rel->isRecursive(); } + common::ExtendDirection getDirection() const { return direction; } + + bool extendFromSourceNode() const { return extendFromSource_; } + + virtual f_group_pos_set getGroupsPosToFlatten() = 0; + + std::string getExpressionsForPrinting() const override; + + std::unique_ptr getPrintInfo() const override { + return std::make_unique(boundNode, nbrNode, rel, direction); + } + +protected: + // Start node of extension. + std::shared_ptr boundNode; + // End node of extension. + std::shared_ptr nbrNode; + std::shared_ptr rel; + common::ExtendDirection direction; + // Ideally we should check this by *boundNode == *rel->getSrcNode() + // This is currently not doable due to recursive plan not setting src node correctly. 
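// Minimal sketch of the invariant described above (an assumption only; the flag
// below is stored explicitly because recursive plans do not set the source node
// reliably, so deriving it on the fly is not yet safe):
static bool wouldExtendFromSource(const binder::NodeExpression& boundNode,
    const binder::RelExpression& rel) {
    // True when the extension starts at the relationship's source node (FWD case).
    return boundNode == *rel.getSrcNode();
}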
+ bool extendFromSource_; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/extend/logical_extend.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/extend/logical_extend.h new file mode 100644 index 0000000000..fcb995d7ad --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/extend/logical_extend.h @@ -0,0 +1,47 @@ +#pragma once + +#include "planner/operator/extend/base_logical_extend.h" +#include "storage/predicate/column_predicate.h" + +namespace lbug { +namespace planner { + +class LogicalExtend final : public BaseLogicalExtend { + static constexpr LogicalOperatorType type_ = LogicalOperatorType::EXTEND; + +public: + LogicalExtend(std::shared_ptr boundNode, + std::shared_ptr nbrNode, std::shared_ptr rel, + common::ExtendDirection direction, bool extendFromSource, + binder::expression_vector properties, std::shared_ptr child, + common::cardinality_t cardinality = 0) + : BaseLogicalExtend{type_, std::move(boundNode), std::move(nbrNode), std::move(rel), + direction, extendFromSource, std::move(child)}, + scanNbrID{true}, properties{std::move(properties)} { + this->cardinality = cardinality; + } + + f_group_pos_set getGroupsPosToFlatten() override { return f_group_pos_set{}; } + void computeFactorizedSchema() override; + void computeFlatSchema() override; + + binder::expression_vector getProperties() const { return properties; } + void setPropertyPredicates(std::vector predicates) { + propertyPredicates = std::move(predicates); + } + const std::vector& getPropertyPredicates() const { + return propertyPredicates; + } + void setScanNbrID(bool scanNbrID_) { scanNbrID = scanNbrID_; } + bool shouldScanNbrID() const { return scanNbrID; } + + std::unique_ptr copy() override; + +private: + bool scanNbrID; + binder::expression_vector properties; + std::vector propertyPredicates; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/extend/logical_recursive_extend.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/extend/logical_recursive_extend.h new file mode 100644 index 0000000000..7669e4fa04 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/extend/logical_recursive_extend.h @@ -0,0 +1,64 @@ +#pragma once + +#include "function/gds/rec_joins.h" +#include "planner/operator/logical_operator.h" + +namespace lbug { +namespace planner { + +class LogicalRecursiveExtend final : public LogicalOperator { + static constexpr LogicalOperatorType operatorType_ = LogicalOperatorType::RECURSIVE_EXTEND; + +public: + LogicalRecursiveExtend(std::unique_ptr function, + const function::RJBindData& bindData, binder::expression_vector resultColumns) + : LogicalOperator{operatorType_}, function{std::move(function)}, bindData{bindData}, + resultColumns{std::move(resultColumns)}, limitNum{common::INVALID_LIMIT} {} + + void computeFlatSchema() override; + void computeFactorizedSchema() override; + + void setFunction(std::unique_ptr func) { function = std::move(func); } + const function::RJAlgorithm& getFunction() const { return *function; } + + const function::RJBindData& getBindData() const { return bindData; } + function::RJBindData& getBindDataUnsafe() { return bindData; } + + void setResultColumns(binder::expression_vector exprs) { resultColumns = std::move(exprs); } + binder::expression_vector getResultColumns() const { return resultColumns; } + + void setLimitNum(common::offset_t num) { 
limitNum = num; } + common::offset_t getLimitNum() const { return limitNum; } + + bool hasInputNodeMask() const { return hasInputNodeMask_; } + void setInputNodeMask() { hasInputNodeMask_ = true; } + + bool hasOutputNodeMask() const { return hasOutputNodeMask_; } + void setOutputNodeMask() { hasOutputNodeMask_ = true; } + + bool hasNodePredicate() const { return !children.empty(); } + + std::string getExpressionsForPrinting() const override { return function->getFunctionName(); } + + std::unique_ptr copy() override { + auto result = + std::make_unique(function->copy(), bindData, resultColumns); + result->limitNum = limitNum; + result->hasInputNodeMask_ = hasInputNodeMask_; + result->hasOutputNodeMask_ = hasOutputNodeMask_; + return result; + } + +private: + std::unique_ptr function; + function::RJBindData bindData; + binder::expression_vector resultColumns; + + common::offset_t limitNum; // TODO: remove this once recursive extend is pipelined. + + bool hasInputNodeMask_ = false; + bool hasOutputNodeMask_ = false; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/extend/recursive_join_type.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/extend/recursive_join_type.h new file mode 100644 index 0000000000..83c7c46172 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/extend/recursive_join_type.h @@ -0,0 +1,14 @@ +#pragma once + +#include + +namespace lbug { +namespace planner { + +enum class RecursiveJoinType : uint8_t { + TRACK_NONE = 0, + TRACK_PATH = 1, +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/factorization/flatten_resolver.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/factorization/flatten_resolver.h new file mode 100644 index 0000000000..f7d301a9ac --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/factorization/flatten_resolver.h @@ -0,0 +1,60 @@ +#pragma once + +#include "planner/operator/schema.h" + +namespace lbug { +namespace planner { + +class GroupDependencyAnalyzer; + +struct FlattenAllButOne { + static std::pair getGroupsPosToFlatten( + const binder::expression_vector& exprs, const Schema& schema); + static f_group_pos_set getGroupsPosToFlatten(std::shared_ptr expr, + const Schema& schema); + // Assume no requiredFlatGroups + static f_group_pos_set getGroupsPosToFlatten( + const std::unordered_set& dependentGroups, const Schema& schema); +}; + +struct FlattenAll { + static f_group_pos_set getGroupsPosToFlatten(const binder::expression_vector& exprs, + const Schema& schema); + static f_group_pos_set getGroupsPosToFlatten(std::shared_ptr expr, + const Schema& schema); + static f_group_pos_set getGroupsPosToFlatten( + const std::unordered_set& dependentGroups, const Schema& schema); +}; + +class GroupDependencyAnalyzer { +public: + GroupDependencyAnalyzer(bool collectDependentExpr, const Schema& schema) + : collectDependentExpr{collectDependentExpr}, schema{schema} {} + + binder::expression_vector getDependentExprs() const { + return binder::expression_vector{dependentExprs.begin(), dependentExprs.end()}; + } + std::unordered_set getDependentGroups() const { return dependentGroups; } + std::unordered_set getRequiredFlatGroups() const { return requiredFlatGroups; } + + void visit(std::shared_ptr expr); + +private: + void visitFunction(std::shared_ptr expr); + + void visitCase(std::shared_ptr expr); + + void 
visitNodeOrRel(std::shared_ptr expr); + + void visitSubquery(std::shared_ptr expr); + +private: + bool collectDependentExpr; + const Schema& schema; + std::unordered_set dependentGroups; + std::unordered_set requiredFlatGroups; + binder::expression_set dependentExprs; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/factorization/sink_util.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/factorization/sink_util.h new file mode 100644 index 0000000000..f8b121fdda --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/factorization/sink_util.h @@ -0,0 +1,28 @@ +#pragma once + +#include "planner/operator/schema.h" + +namespace lbug { +namespace planner { + +// This class contains the logic for re-computing factorization structure after sinking +class SinkOperatorUtil { +public: + static void mergeSchema(const Schema& inputSchema, + const binder::expression_vector& expressionsToMerge, Schema& resultSchema); + + static void recomputeSchema(const Schema& inputSchema, + const binder::expression_vector& expressionsToMerge, Schema& resultSchema); + +private: + static std::unordered_map getUnFlatPayloadsPerGroup( + const Schema& schema, const binder::expression_vector& payloads); + + static binder::expression_vector getFlatPayloads(const Schema& schema, + const binder::expression_vector& payloads); + + static uint32_t appendPayloadsToNewGroup(Schema& schema, binder::expression_vector& payloads); +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_accumulate.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_accumulate.h new file mode 100644 index 0000000000..60c2e25adc --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_accumulate.h @@ -0,0 +1,45 @@ +#pragma once + +#include "common/enums/accumulate_type.h" +#include "planner/operator/logical_operator.h" + +namespace lbug { +namespace planner { + +class LogicalAccumulate final : public LogicalOperator { + static constexpr LogicalOperatorType type = LogicalOperatorType::ACCUMULATE; + +public: + LogicalAccumulate(common::AccumulateType accumulateType, binder::expression_vector flatExprs, + std::shared_ptr mark, std::shared_ptr child) + : LogicalOperator{type, std::move(child)}, accumulateType{accumulateType}, + flatExprs{std::move(flatExprs)}, mark{std::move(mark)} {} + + void computeFactorizedSchema() override; + void computeFlatSchema() override; + + f_group_pos_set getGroupPositionsToFlatten() const; + + std::string getExpressionsForPrinting() const override { return {}; } + + common::AccumulateType getAccumulateType() const { return accumulateType; } + binder::expression_vector getPayloads() const { + return children[0]->getSchema()->getExpressionsInScope(); + } + bool hasMark() const { return mark != nullptr; } + std::shared_ptr getMark() const { return mark; } + + std::unique_ptr copy() override { + return make_unique(accumulateType, flatExprs, mark, children[0]->copy()); + } + +private: + common::AccumulateType accumulateType; + binder::expression_vector flatExprs; + // Accumulate may be used for optional match, e.g. OPTIONAL MATCH (a). In such case, we use + // mark to determine if at least one pattern is found. 
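// Illustrative sketch (hypothetical row type, not part of this header): the mark
// is a per-row BOOL telling downstream operators whether the optional pattern
// matched, so unmatched rows are kept and padded with NULLs rather than dropped,
// e.g. for MATCH (a:Person) OPTIONAL MATCH (a)-[:Knows]->(b) RETURN a.name, b.name.
struct ExampleOptionalRow {
    bool patternMatched; // value of the mark expression for this row
    bool bIsNull;        // whether the columns produced by the optional part are NULL
};
static void finalizeOptionalMatch(ExampleOptionalRow& row) {
    // Every row of `a` is kept; the mark only decides the NULL padding.
    row.bIsNull = !row.patternMatched;
}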
+ std::shared_ptr mark; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_aggregate.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_aggregate.h new file mode 100644 index 0000000000..c8c45571b6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_aggregate.h @@ -0,0 +1,79 @@ +#pragma once + +#include "planner/operator/logical_operator.h" + +namespace lbug { +namespace planner { + +struct LogicalAggregatePrintInfo final : OPPrintInfo { + binder::expression_vector keys; + binder::expression_vector aggregates; + + LogicalAggregatePrintInfo(binder::expression_vector keys, binder::expression_vector aggregates) + : keys(std::move(keys)), aggregates(std::move(aggregates)) {} + + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::unique_ptr(new LogicalAggregatePrintInfo(*this)); + } + +private: + LogicalAggregatePrintInfo(const LogicalAggregatePrintInfo& other) + : OPPrintInfo(other), keys(other.keys), aggregates(other.aggregates) {} +}; + +class LogicalAggregate final : public LogicalOperator { + static constexpr LogicalOperatorType operatorType_ = LogicalOperatorType::AGGREGATE; + +public: + LogicalAggregate(binder::expression_vector keys, binder::expression_vector aggregates, + std::shared_ptr child) + : LogicalOperator{operatorType_, std::move(child)}, keys{std::move(keys)}, + aggregates{std::move(aggregates)} {} + LogicalAggregate(binder::expression_vector keys, binder::expression_vector dependentKeys, + binder::expression_vector aggregates, std::shared_ptr child, + common::cardinality_t cardinality) + : LogicalOperator{operatorType_, std::move(child), cardinality}, keys{std::move(keys)}, + dependentKeys{std::move(dependentKeys)}, aggregates{std::move(aggregates)} {} + + void computeFactorizedSchema() override; + void computeFlatSchema() override; + + f_group_pos_set getGroupsPosToFlatten(); + + std::string getExpressionsForPrinting() const override; + + bool hasKeys() const { return !keys.empty(); } + binder::expression_vector getKeys() const { return keys; } + void setKeys(binder::expression_vector expressions) { keys = std::move(expressions); } + binder::expression_vector getDependentKeys() const { return dependentKeys; } + void setDependentKeys(binder::expression_vector expressions) { + dependentKeys = std::move(expressions); + } + binder::expression_vector getAllKeys() const { + binder::expression_vector result; + result.insert(result.end(), keys.begin(), keys.end()); + result.insert(result.end(), dependentKeys.begin(), dependentKeys.end()); + return result; + } + binder::expression_vector getAggregates() const { return aggregates; } + + std::unique_ptr copy() override { + return make_unique(keys, dependentKeys, aggregates, children[0]->copy(), + cardinality); + } + +private: + void insertAllExpressionsToGroupAndScope(f_group_pos groupPos); + +private: + binder::expression_vector keys; + // A dependentKeyExpression depend on a keyExpression (e.g. a.age depends on a.ID) and will not + // be treated as a hash key during hash aggregation. 
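// Sketch of how group-by expressions could be split (the dependency predicate is
// passed in as a parameter because no such helper is defined in this header): for
// RETURN a.ID, a.age, COUNT(*), a.ID becomes the hash key while a.age is carried
// as a dependent key since it is determined by a.ID.
static void splitGroupByKeys(const binder::expression_vector& groupByExprs,
    bool (*dependsOnExistingKey)(const std::shared_ptr<binder::Expression>&,
        const binder::expression_vector&),
    binder::expression_vector& hashKeys, binder::expression_vector& dependentKeys) {
    for (auto& expr : groupByExprs) {
        if (dependsOnExistingKey(expr, hashKeys)) {
            dependentKeys.push_back(expr); // e.g. a.age, determined by a.ID
        } else {
            hashKeys.push_back(expr); // e.g. a.ID
        }
    }
}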
+ binder::expression_vector dependentKeys; + binder::expression_vector aggregates; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_create_macro.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_create_macro.h new file mode 100644 index 0000000000..3d4e0a715e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_create_macro.h @@ -0,0 +1,51 @@ +#pragma once + +#include "function/scalar_macro_function.h" +#include "planner/operator/logical_operator.h" + +namespace lbug { +namespace planner { + +struct LogicalCreateMacroPrintInfo final : OPPrintInfo { + std::string macroName; + + explicit LogicalCreateMacroPrintInfo(std::string macroName) : macroName(std::move(macroName)) {} + + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::unique_ptr(new LogicalCreateMacroPrintInfo(*this)); + } + +private: + LogicalCreateMacroPrintInfo(const LogicalCreateMacroPrintInfo& other) + : OPPrintInfo(other), macroName(other.macroName) {} +}; + +class LogicalCreateMacro final : public LogicalOperator { + static constexpr LogicalOperatorType type_ = LogicalOperatorType::CREATE_MACRO; + +public: + LogicalCreateMacro(std::string macroName, std::unique_ptr macro) + : LogicalOperator{type_}, macroName{std::move(macroName)}, macro{std::move(macro)} {} + + void computeFactorizedSchema() override; + void computeFlatSchema() override; + + std::string getMacroName() const { return macroName; } + + std::unique_ptr getMacro() const { return macro->copy(); } + + std::string getExpressionsForPrinting() const override { return macroName; } + + std::unique_ptr copy() override { + return std::make_unique(macroName, macro->copy()); + } + +private: + std::string macroName; + std::shared_ptr macro; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_cross_product.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_cross_product.h new file mode 100644 index 0000000000..383ce1a7a8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_cross_product.h @@ -0,0 +1,48 @@ +#pragma once + +#include "common/enums/accumulate_type.h" +#include "planner/operator/logical_operator.h" +#include "planner/operator/sip/side_way_info_passing.h" + +namespace lbug { +namespace planner { + +class LogicalCrossProduct final : public LogicalOperator { + static constexpr LogicalOperatorType type_ = LogicalOperatorType::CROSS_PRODUCT; + +public: + LogicalCrossProduct(common::AccumulateType accumulateType, + std::shared_ptr mark, std::shared_ptr probeChild, + std::shared_ptr buildChild, common::cardinality_t cardinality) + : LogicalOperator{type_, std::move(probeChild), std::move(buildChild)}, + accumulateType{accumulateType}, mark{std::move(mark)} { + this->cardinality = cardinality; + } + + void computeFactorizedSchema() override; + void computeFlatSchema() override; + + std::string getExpressionsForPrinting() const override { return std::string(); } + + common::AccumulateType getAccumulateType() const { return accumulateType; } + bool hasMark() const { return mark != nullptr; } + std::shared_ptr getMark() const { return mark; } + + SIPInfo& getSIPInfoUnsafe() { return sipInfo; } + SIPInfo getSIPInfo() const { return sipInfo; } + + std::unique_ptr copy() override { + auto op = make_unique(accumulateType, mark, children[0]->copy(), + 
children[1]->copy(), cardinality); + op->sipInfo = sipInfo; + return op; + } + +private: + common::AccumulateType accumulateType; + std::shared_ptr mark; + SIPInfo sipInfo; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_distinct.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_distinct.h new file mode 100644 index 0000000000..5760a35415 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_distinct.h @@ -0,0 +1,55 @@ +#pragma once + +#include "planner/operator/logical_operator.h" + +namespace lbug { +namespace planner { + +class LogicalDistinct final : public LogicalOperator { +public: + LogicalDistinct(binder::expression_vector keys, std::shared_ptr child) + : LogicalDistinct{LogicalOperatorType::DISTINCT, keys, binder::expression_vector{}, + std::move(child)} {} + LogicalDistinct(LogicalOperatorType type, binder::expression_vector keys, + binder::expression_vector payloads, std::shared_ptr child) + : LogicalOperator{type, std::move(child)}, keys{std::move(keys)}, + payloads{std::move(payloads)}, skipNum{UINT64_MAX}, limitNum{UINT64_MAX} {} + + void computeFactorizedSchema() override; + void computeFlatSchema() override; + + virtual f_group_pos_set getGroupsPosToFlatten(); + + std::string getExpressionsForPrinting() const override; + + binder::expression_vector getKeys() const { return keys; } + void setKeys(binder::expression_vector expressions) { keys = std::move(expressions); } + binder::expression_vector getPayloads() const { return payloads; } + void setPayloads(binder::expression_vector expressions) { payloads = std::move(expressions); } + + void setSkipNum(common::offset_t num) { skipNum = num; } + bool hasSkipNum() const { return skipNum != common::INVALID_LIMIT; } + common::offset_t getSkipNum() const { return skipNum; } + void setLimitNum(common::offset_t num) { limitNum = num; } + bool hasLimitNum() const { return limitNum != common::INVALID_LIMIT; } + common::offset_t getLimitNum() const { return limitNum; } + + std::unique_ptr copy() override { + return make_unique(operatorType, keys, payloads, children[0]->copy()); + } + +protected: + binder::expression_vector getKeysAndPayloads() const; + +protected: + binder::expression_vector keys; + // Payloads meaning additional keys that are functional dependent on the keys above. 
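// Usage sketch (wiring assumed for illustration, not taken from this patch): for
// RETURN DISTINCT a, hashing on the node's internal ID is enough; its scanned
// properties are functionally determined by that ID and become payloads.
static std::shared_ptr<LogicalOperator> buildDistinctOnNodeID(
    std::shared_ptr<binder::Expression> nodeID,  // e.g. a._ID
    binder::expression_vector nodeProperties,    // e.g. a.name, a.age
    std::shared_ptr<LogicalOperator> child) {
    auto distinct = std::make_shared<LogicalDistinct>(
        binder::expression_vector{std::move(nodeID)}, std::move(child));
    distinct->setPayloads(std::move(nodeProperties));
    return distinct;
}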
+ binder::expression_vector payloads; + +private: + common::offset_t skipNum; + common::offset_t limitNum; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_dummy_sink.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_dummy_sink.h new file mode 100644 index 0000000000..5b1d292a72 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_dummy_sink.h @@ -0,0 +1,25 @@ +#pragma once + +#include "planner/operator/logical_operator.h" + +namespace lbug { +namespace planner { + +class LogicalDummySink final : public LogicalOperator { + static constexpr LogicalOperatorType type_ = LogicalOperatorType::DUMMY_SINK; + +public: + explicit LogicalDummySink(std::shared_ptr child) + : LogicalOperator{type_, {std::move(child)}} {} + + void computeFactorizedSchema() override; + void computeFlatSchema() override; + + std::string getExpressionsForPrinting() const override { return ""; } + std::unique_ptr copy() override { + return std::make_unique(children[0]->copy()); + } +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_empty_result.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_empty_result.h new file mode 100644 index 0000000000..8916c0d00d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_empty_result.h @@ -0,0 +1,36 @@ +#pragma once + +#include "logical_operator.h" + +namespace lbug { +namespace planner { + +class LogicalEmptyResult final : public LogicalOperator { +public: + explicit LogicalEmptyResult(const Schema& schema) + : LogicalOperator{LogicalOperatorType::EMPTY_RESULT}, originalSchema{schema.copy()} { + this->schema = schema.copy(); + } + + void computeFactorizedSchema() override { schema = originalSchema->copy(); } + void computeFlatSchema() override { + createEmptySchema(); + schema->createGroup(); + for (auto& e : originalSchema->getExpressionsInScope()) { + schema->insertToGroupAndScope(e, 0); + } + } + + std::string getExpressionsForPrinting() const override { return std::string{}; } + + std::unique_ptr copy() override { + return std::make_unique(*originalSchema); + } + +private: + // The original schema of the plan that generates empty result. 
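// Sketch of the intended use (the caller-side check for a statically empty
// subplan is assumed, not part of this patch): an empty subplan can be swapped
// for a LogicalEmptyResult that advertises the same schema, so parent operators
// still resolve their expressions even though no tuples will ever flow.
static std::shared_ptr<LogicalOperator> replaceWithEmptyResult(
    const std::shared_ptr<LogicalOperator>& op) {
    return std::make_shared<LogicalEmptyResult>(*op->getSchema()); // schema preserved
}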
+ std::unique_ptr originalSchema; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_explain.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_explain.h new file mode 100644 index 0000000000..b9a146844a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_explain.h @@ -0,0 +1,40 @@ +#pragma once + +#include + +#include "common/enums/explain_type.h" +#include "planner/operator/logical_operator.h" + +namespace lbug { +namespace planner { + +class LogicalExplain final : public LogicalOperator { + static constexpr LogicalOperatorType type_ = LogicalOperatorType::EXPLAIN; + +public: + LogicalExplain(std::shared_ptr child, common::ExplainType explainType, + binder::expression_vector innerResultColumns) + : LogicalOperator{type_, std::move(child)}, explainType{explainType}, + innerResultColumns{std::move(innerResultColumns)} {} + + void computeSchema(); + void computeFactorizedSchema() override; + void computeFlatSchema() override; + + std::string getExpressionsForPrinting() const override { return ""; } + + common::ExplainType getExplainType() const { return explainType; } + + binder::expression_vector getInnerResultColumns() const { return innerResultColumns; } + + std::unique_ptr copy() override { + return std::make_unique(children[0], explainType, innerResultColumns); + } + +private: + common::ExplainType explainType; + binder::expression_vector innerResultColumns; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_filter.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_filter.h new file mode 100644 index 0000000000..3a31c7af31 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_filter.h @@ -0,0 +1,47 @@ +#pragma once + +#include "binder/expression/expression.h" +#include "planner/operator/logical_operator.h" + +namespace lbug { +namespace planner { + +struct LogicalFilterPrintInfo final : OPPrintInfo { + std::shared_ptr expression; + explicit LogicalFilterPrintInfo(std::shared_ptr expression) + : expression{std::move(expression)} {} + std::string toString() const override { return expression->toString(); } +}; + +class LogicalFilter final : public LogicalOperator { +public: + LogicalFilter(std::shared_ptr expression, + std::shared_ptr child, common::cardinality_t cardinality = 0) + : LogicalOperator{LogicalOperatorType::FILTER, std::move(child), cardinality}, + expression{std::move(expression)} {} + + inline void computeFactorizedSchema() override { copyChildSchema(0); } + inline void computeFlatSchema() override { copyChildSchema(0); } + + f_group_pos_set getGroupsPosToFlatten(); + + inline std::string getExpressionsForPrinting() const override { return expression->toString(); } + + inline std::shared_ptr getPredicate() const { return expression; } + + f_group_pos getGroupPosToSelect() const; + + std::unique_ptr getPrintInfo() const override { + return std::make_unique(expression); + } + + inline std::unique_ptr copy() override { + return make_unique(expression, children[0]->copy(), cardinality); + } + +private: + std::shared_ptr expression; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_flatten.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_flatten.h new file mode 100644 index 0000000000..65d27233bb 
--- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_flatten.h @@ -0,0 +1,31 @@ +#pragma once + +#include "planner/operator/logical_operator.h" + +namespace lbug { +namespace planner { + +class LogicalFlatten final : public LogicalOperator { +public: + LogicalFlatten(f_group_pos groupPos, std::shared_ptr child, + common::cardinality_t cardinality) + : LogicalOperator{LogicalOperatorType::FLATTEN, std::move(child), cardinality}, + groupPos{groupPos} {} + + void computeFactorizedSchema() override; + void computeFlatSchema() override; + + inline std::string getExpressionsForPrinting() const override { return std::string{}; } + + inline f_group_pos getGroupPos() const { return groupPos; } + + inline std::unique_ptr copy() override { + return make_unique(groupPos, children[0]->copy(), cardinality); + } + +private: + f_group_pos groupPos; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_hash_join.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_hash_join.h new file mode 100644 index 0000000000..f6274073f4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_hash_join.h @@ -0,0 +1,69 @@ +#pragma once + +#include "common/enums/join_type.h" +#include "logical_operator.h" +#include "planner/operator/sip/side_way_info_passing.h" + +namespace lbug { +namespace planner { + +// We only support equality comparison as join condition +using join_condition_t = binder::expression_pair; + +// Probe side on left, i.e. children[0]. Build side on right, i.e. children[1]. +class LBUG_API LogicalHashJoin final : public LogicalOperator { + static constexpr LogicalOperatorType type_ = LogicalOperatorType::HASH_JOIN; + +public: + LogicalHashJoin(std::vector joinConditions, common::JoinType joinType, + std::shared_ptr mark, std::shared_ptr probeChild, + std::shared_ptr buildChild, common::cardinality_t cardinality = 0) + : LogicalOperator{type_, std::move(probeChild), std::move(buildChild)}, + joinConditions(std::move(joinConditions)), joinType{joinType}, mark{std::move(mark)} { + this->cardinality = cardinality; + } + + f_group_pos_set getGroupsPosToFlattenOnProbeSide(); + f_group_pos_set getGroupsPosToFlattenOnBuildSide(); + + void computeFactorizedSchema() override; + void computeFlatSchema() override; + + std::string getExpressionsForPrinting() const override; + + binder::expression_vector getExpressionsToMaterialize() const; + + binder::expression_vector getJoinNodeIDs() const; + static binder::expression_vector getJoinNodeIDs( + const std::vector& joinConditions); + + std::vector getJoinConditions() const { return joinConditions; } + common::JoinType getJoinType() const { return joinType; } + bool hasMark() const { return mark != nullptr; } + std::shared_ptr getMark() const { return mark; } + + SIPInfo& getSIPInfoUnsafe() { return sipInfo; } + SIPInfo getSIPInfo() const { return sipInfo; } + + std::unique_ptr copy() override; + + // Flat probe side key group in either of the following two cases: + // 1. there are multiple join nodes; + // 2. if the build side contains more than one group or the build side has projected out data + // chunks, which may increase the multiplicity of data chunks in the build side. The key is to + // keep probe side key unflat only when we know that there is only 0 or 1 match for each key. 
+ // TODO(Guodong): when the build side has only flat payloads, we should consider getting rid of + // flattening probe key, instead duplicating keys as in vectorized processing if necessary. + bool requireFlatProbeKeys() const; + + static bool isNodeIDOnlyJoin(const std::vector& joinConditions); + +private: + std::vector joinConditions; + common::JoinType joinType; + std::shared_ptr mark; // when joinType is Mark or Left + SIPInfo sipInfo; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_intersect.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_intersect.h new file mode 100644 index 0000000000..b3264a0796 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_intersect.h @@ -0,0 +1,50 @@ +#pragma once + +#include "planner/operator/logical_operator.h" +#include "planner/operator/sip/side_way_info_passing.h" + +namespace lbug { +namespace planner { + +class LogicalIntersect final : public LogicalOperator { + static constexpr LogicalOperatorType type_ = LogicalOperatorType::INTERSECT; + +public: + LogicalIntersect(std::shared_ptr intersectNodeID, + binder::expression_vector keyNodeIDs, std::shared_ptr probeChild, + std::vector> buildChildren, + common::cardinality_t cardinality = 0) + : LogicalOperator{type_, std::move(probeChild)}, + intersectNodeID{std::move(intersectNodeID)}, keyNodeIDs{std::move(keyNodeIDs)} { + for (auto& child : buildChildren) { + children.push_back(std::move(child)); + } + this->cardinality = cardinality; + } + + f_group_pos_set getGroupsPosToFlattenOnProbeSide(); + f_group_pos_set getGroupsPosToFlattenOnBuildSide(uint32_t buildIdx); + + void computeFactorizedSchema() override; + void computeFlatSchema() override; + + std::string getExpressionsForPrinting() const override { return intersectNodeID->toString(); } + + std::shared_ptr getIntersectNodeID() const { return intersectNodeID; } + uint32_t getNumBuilds() const { return keyNodeIDs.size(); } + binder::expression_vector getKeyNodeIDs() const { return keyNodeIDs; } + std::shared_ptr getKeyNodeID(uint32_t idx) const { return keyNodeIDs[idx]; } + + SIPInfo& getSIPInfoUnsafe() { return sipInfo; } + SIPInfo getSIPInfo() const { return sipInfo; } + + std::unique_ptr copy() override; + +private: + std::shared_ptr intersectNodeID; + binder::expression_vector keyNodeIDs; + SIPInfo sipInfo; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_limit.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_limit.h new file mode 100644 index 0000000000..7169fa3cf6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_limit.h @@ -0,0 +1,46 @@ +#pragma once + +#include "planner/operator/logical_operator.h" + +namespace lbug { +namespace planner { + +class LogicalLimit final : public LogicalOperator { + static constexpr LogicalOperatorType type_ = LogicalOperatorType::LIMIT; + +public: + LogicalLimit(std::shared_ptr skipNum, + std::shared_ptr limitNum, std::shared_ptr child) + : LogicalOperator{type_, std::move(child)}, skipNum{std::move(skipNum)}, + limitNum{std::move(limitNum)} {} + + f_group_pos_set getGroupsPosToFlatten(); + + void computeFactorizedSchema() override { copyChildSchema(0); } + void computeFlatSchema() override { copyChildSchema(0); } + + std::string getExpressionsForPrinting() const override; + + bool hasSkipNum() const { return skipNum != 
nullptr; } + std::shared_ptr getSkipNum() const { return skipNum; } + + bool hasLimitNum() const { return limitNum != nullptr; } + std::shared_ptr getLimitNum() const { return limitNum; } + + f_group_pos getGroupPosToSelect() const; + + std::unordered_set getGroupsPosToLimit() const { + return schema->getGroupsPosInScope(); + } + + std::unique_ptr copy() override { + return make_unique(skipNum, limitNum, children[0]->copy()); + } + +private: + std::shared_ptr skipNum; + std::shared_ptr limitNum; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_multiplcity_reducer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_multiplcity_reducer.h new file mode 100644 index 0000000000..4af60eaa56 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_multiplcity_reducer.h @@ -0,0 +1,24 @@ +#pragma once + +#include "planner/operator/logical_operator.h" + +namespace lbug { +namespace planner { + +class LogicalMultiplicityReducer final : public LogicalOperator { +public: + explicit LogicalMultiplicityReducer(std::shared_ptr child) + : LogicalOperator(LogicalOperatorType::MULTIPLICITY_REDUCER, std::move(child)) {} + + inline void computeFactorizedSchema() override { copyChildSchema(0); } + inline void computeFlatSchema() override { copyChildSchema(0); } + + inline std::string getExpressionsForPrinting() const override { return std::string(); } + + inline std::unique_ptr copy() override { + return make_unique(children[0]->copy()); + } +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_node_label_filter.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_node_label_filter.h new file mode 100644 index 0000000000..58440a1802 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_node_label_filter.h @@ -0,0 +1,33 @@ +#pragma once + +#include "planner/operator/logical_operator.h" + +namespace lbug { +namespace planner { + +class LogicalNodeLabelFilter final : public LogicalOperator { +public: + LogicalNodeLabelFilter(std::shared_ptr nodeID, + std::unordered_set tableIDSet, std::shared_ptr child) + : LogicalOperator{LogicalOperatorType::NODE_LABEL_FILTER, std::move(child)}, + nodeID{std::move(nodeID)}, tableIDSet{std::move(tableIDSet)} {} + + inline void computeFactorizedSchema() override { copyChildSchema(0); } + inline void computeFlatSchema() override { copyChildSchema(0); } + + inline std::string getExpressionsForPrinting() const override { return nodeID->toString(); } + + inline std::shared_ptr getNodeID() const { return nodeID; } + inline std::unordered_set getTableIDSet() const { return tableIDSet; } + + std::unique_ptr copy() override { + return std::make_unique(nodeID, tableIDSet, children[0]->copy()); + } + +private: + std::shared_ptr nodeID; + std::unordered_set tableIDSet; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_noop.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_noop.h new file mode 100644 index 0000000000..fbc67396bc --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_noop.h @@ -0,0 +1,35 @@ +#pragma once + +#include "planner/operator/logical_operator.h" + +namespace lbug { +namespace planner { + +// Serve as a dummy parent (usually root) for a set of children that doesn't 
have a well-defined +// parent. E.g. CREATE TABLE AS, create table & copy. +class LogicalNoop : public LogicalOperator { + static constexpr LogicalOperatorType type_ = LogicalOperatorType::NOOP; + +public: + explicit LogicalNoop(common::idx_t messageChildIdx, + std::vector> children) + : LogicalOperator{type_, {std::move(children)}}, messageChildIdx{messageChildIdx} {} + + void computeFactorizedSchema() override { createEmptySchema(); } + void computeFlatSchema() override { createEmptySchema(); } + + common::idx_t getMessageChildIdx() const { return messageChildIdx; } + + std::string getExpressionsForPrinting() const override { return ""; } + + std::unique_ptr copy() override { + return std::make_unique(messageChildIdx, copyVector(children)); + } + +private: + // For create table as. Dummy sink is the last operator and should propagate return message. + common::idx_t messageChildIdx; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_operator.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_operator.h new file mode 100644 index 0000000000..3be8831209 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_operator.h @@ -0,0 +1,149 @@ +#pragma once + +#include + +#include "common/cast.h" +#include "planner/operator/operator_print_info.h" +#include "planner/operator/schema.h" + +namespace lbug { +namespace planner { + +// This ENUM is sorted by alphabetical order. +enum class LogicalOperatorType : uint8_t { + ACCUMULATE, + AGGREGATE, + ALTER, + ATTACH_DATABASE, + COPY_FROM, + COPY_TO, + CREATE_MACRO, + CREATE_SEQUENCE, + CREATE_TABLE, + CREATE_TYPE, + CROSS_PRODUCT, + DELETE, + DETACH_DATABASE, + DISTINCT, + DROP, + DUMMY_SCAN, + DUMMY_SINK, + EMPTY_RESULT, + EXPLAIN, + EXPRESSIONS_SCAN, + EXTEND, + EXTENSION, + EXPORT_DATABASE, + FILTER, + FLATTEN, + HASH_JOIN, + IMPORT_DATABASE, + INDEX_LOOK_UP, + INTERSECT, + INSERT, + LIMIT, + MERGE, + MULTIPLICITY_REDUCER, + NODE_LABEL_FILTER, + NOOP, + ORDER_BY, + PARTITIONER, + PATH_PROPERTY_PROBE, + PROJECTION, + RECURSIVE_EXTEND, + SCAN_NODE_TABLE, + SEMI_MASKER, + SET_PROPERTY, + STANDALONE_CALL, + TABLE_FUNCTION_CALL, + TRANSACTION, + UNION_ALL, + UNWIND, + USE_DATABASE, + EXTENSION_CLAUSE, +}; + +class LogicalOperator; +using logical_op_vector_t = std::vector>; + +struct LogicalOperatorUtils { + static std::string logicalOperatorTypeToString(LogicalOperatorType type); + static bool isUpdate(LogicalOperatorType type); + static bool isAccHashJoin(const LogicalOperator& op); +}; + +class LBUG_API LogicalOperator { +public: + explicit LogicalOperator(LogicalOperatorType operatorType) + : operatorType{operatorType}, cardinality{1} {} + explicit LogicalOperator(LogicalOperatorType operatorType, + std::shared_ptr child, + std::optional cardinality = {}); + explicit LogicalOperator(LogicalOperatorType operatorType, + std::shared_ptr left, std::shared_ptr right); + explicit LogicalOperator(LogicalOperatorType operatorType, const logical_op_vector_t& children); + + virtual ~LogicalOperator() = default; + + uint32_t getNumChildren() const { return children.size(); } + std::shared_ptr getChild(uint64_t idx) const { return children[idx]; } + std::vector> getChildren() const { return children; } + void setChild(uint64_t idx, std::shared_ptr child) { + children[idx] = std::move(child); + } + void addChild(std::shared_ptr child) { children.push_back(std::move(child)); } + void setCardinality(common::cardinality_t 
cardinality_) { this->cardinality = cardinality_; } + + // Operator type. + LogicalOperatorType getOperatorType() const { return operatorType; } + bool hasUpdateRecursive(); + + // Schema + Schema* getSchema() const { return schema.get(); } + virtual void computeFactorizedSchema() = 0; + virtual void computeFlatSchema() = 0; + + // Printing. + virtual std::string getExpressionsForPrinting() const = 0; + // Print the sub-plan rooted at this operator. + virtual std::string toString(uint64_t depth = 0) const; + + virtual std::unique_ptr getPrintInfo() const { + return std::make_unique(); + } + common::cardinality_t getCardinality() const { return cardinality; } + + // TODO: remove this function once planner do not share operator across plans + virtual std::unique_ptr copy() = 0; + static logical_op_vector_t copy(const logical_op_vector_t& ops); + + template + const TARGET& constCast() const { + return common::ku_dynamic_cast(*this); + } + template + TARGET& cast() { + return common::ku_dynamic_cast(*this); + } + template + const TARGET* constPtrCast() const { + return common::ku_dynamic_cast(this); + } + template + TARGET* ptrCast() { + return common::ku_dynamic_cast(this); + } + +protected: + void createEmptySchema() { schema = std::make_unique(); } + void copyChildSchema(uint32_t idx) { schema = children[idx]->getSchema()->copy(); } + +protected: + LogicalOperatorType operatorType; + std::unique_ptr schema; + logical_op_vector_t children; + common::cardinality_t cardinality; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_order_by.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_order_by.h new file mode 100644 index 0000000000..1513205dd9 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_order_by.h @@ -0,0 +1,48 @@ +#pragma once + +#include "planner/operator/logical_operator.h" + +namespace lbug { +namespace planner { + +class LogicalOrderBy final : public LogicalOperator { +public: + LogicalOrderBy(binder::expression_vector expressionsToOrderBy, std::vector sortOrders, + std::shared_ptr child) + : LogicalOperator{LogicalOperatorType::ORDER_BY, std::move(child)}, + expressionsToOrderBy{std::move(expressionsToOrderBy)}, + isAscOrders{std::move(sortOrders)} {} + + f_group_pos_set getGroupsPosToFlatten(); + + void computeFactorizedSchema() override; + void computeFlatSchema() override; + + std::string getExpressionsForPrinting() const override; + + binder::expression_vector getExpressionsToOrderBy() const { return expressionsToOrderBy; } + std::vector getIsAscOrders() const { return isAscOrders; } + + bool isTopK() const { return hasLimitNum(); } + + void setSkipNum(std::shared_ptr num) { skipNum = std::move(num); } + bool hasSkipNum() const { return skipNum != nullptr; } + std::shared_ptr getSkipNum() const { return skipNum; } + + void setLimitNum(std::shared_ptr num) { limitNum = std::move(num); } + bool hasLimitNum() const { return limitNum != nullptr; } + std::shared_ptr getLimitNum() const { return limitNum; } + + std::unique_ptr copy() override { + return make_unique(expressionsToOrderBy, isAscOrders, children[0]->copy()); + } + +private: + binder::expression_vector expressionsToOrderBy; + std::vector isAscOrders; + std::shared_ptr skipNum = nullptr; + std::shared_ptr limitNum = nullptr; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_partitioner.h 
b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_partitioner.h new file mode 100644 index 0000000000..9dd5a88ae7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_partitioner.h @@ -0,0 +1,73 @@ +#pragma once + +#include "binder/copy/bound_copy_from.h" +#include "planner/operator/logical_operator.h" + +namespace lbug { +namespace planner { + +struct LogicalPartitioningInfo { + common::idx_t keyIdx; + + explicit LogicalPartitioningInfo(common::idx_t keyIdx) : keyIdx{keyIdx} {} + LogicalPartitioningInfo(const LogicalPartitioningInfo& other) : keyIdx{other.keyIdx} {} + + EXPLICIT_COPY_DEFAULT_MOVE(LogicalPartitioningInfo); +}; + +struct LogicalPartitionerInfo { + std::shared_ptr offset; + std::vector partitioningInfos; + + explicit LogicalPartitionerInfo(std::shared_ptr offset) + : offset{std::move(offset)} {} + LogicalPartitionerInfo(const LogicalPartitionerInfo& other) : offset{other.offset} { + for (auto& partitioningInfo : other.partitioningInfos) { + partitioningInfos.push_back(partitioningInfo.copy()); + } + } + + EXPLICIT_COPY_DEFAULT_MOVE(LogicalPartitionerInfo); + + common::idx_t getNumInfos() const { return partitioningInfos.size(); } + LogicalPartitioningInfo& getInfo(common::idx_t idx) { + KU_ASSERT(idx < partitioningInfos.size()); + return partitioningInfos[idx]; + } + const LogicalPartitioningInfo& getInfo(common::idx_t idx) const { + KU_ASSERT(idx < partitioningInfos.size()); + return partitioningInfos[idx]; + } +}; + +class LogicalPartitioner final : public LogicalOperator { + static constexpr LogicalOperatorType type = LogicalOperatorType::PARTITIONER; + +public: + LogicalPartitioner(LogicalPartitionerInfo info, binder::BoundCopyFromInfo copyFromInfo, + std::shared_ptr child) + : LogicalOperator{type, std::move(child)}, info{std::move(info)}, + copyFromInfo{std::move(copyFromInfo)} {} + + void computeFactorizedSchema() override; + void computeFlatSchema() override; + + std::string getExpressionsForPrinting() const override; + + LogicalPartitionerInfo& getInfo() { return info; } + const LogicalPartitionerInfo& getInfo() const { return info; } + + std::unique_ptr copy() override { + return make_unique(info.copy(), copyFromInfo.copy(), + children[0]->copy()); + } + +private: + LogicalPartitionerInfo info; + +public: + binder::BoundCopyFromInfo copyFromInfo; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_path_property_probe.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_path_property_probe.h new file mode 100644 index 0000000000..ee109987a2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_path_property_probe.h @@ -0,0 +1,56 @@ +#pragma once + +#include "binder/expression/rel_expression.h" +#include "planner/operator/extend/recursive_join_type.h" +#include "planner/operator/logical_operator.h" +#include "planner/operator/sip/side_way_info_passing.h" + +namespace lbug { +namespace planner { + +class LogicalPathPropertyProbe : public LogicalOperator { + static constexpr LogicalOperatorType type_ = LogicalOperatorType::PATH_PROPERTY_PROBE; + +public: + LogicalPathPropertyProbe(std::shared_ptr rel, + std::shared_ptr probeChild, std::shared_ptr nodeChild, + std::shared_ptr relChild, RecursiveJoinType joinType) + : LogicalOperator{type_, std::move(probeChild)}, recursiveRel{std::move(rel)}, + nodeChild{std::move(nodeChild)}, relChild{std::move(relChild)}, joinType{joinType} {} + 
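    // Editorial note, not part of the upstream header: a minimal construction sketch for this
    // operator, assuming `rel`, `probeChild`, `nodeBuildChild`, `relBuildChild` and `joinType`
    // are already available on the planner side. The probe child produces the recursive rel and
    // its pathNodeIDs/pathEdgeIDs; the node/rel children feed the hash tables that are probed
    // afterwards to recover path node and rel properties.
    //
    //   auto probe = std::make_shared<LogicalPathPropertyProbe>(rel, probeChild,
    //       nodeBuildChild, relBuildChild, joinType);
    //   probe->computeFactorizedSchema();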
+ void computeFactorizedSchema() final; + void computeFlatSchema() final; + + std::string getExpressionsForPrinting() const override { return recursiveRel->toString(); } + + std::shared_ptr getRel() const { return recursiveRel; } + std::shared_ptr getPathNodeIDs() const { return pathNodeIDs; } + std::shared_ptr getPathEdgeIDs() const { return pathEdgeIDs; } + + void setJoinType(RecursiveJoinType joinType_) { joinType = joinType_; } + RecursiveJoinType getJoinType() const { return joinType; } + + std::shared_ptr getNodeChild() const { return nodeChild; } + std::shared_ptr getRelChild() const { return relChild; } + + SIPInfo& getSIPInfoUnsafe() { return sipInfo; } + SIPInfo getSIPInfo() const { return sipInfo; } + + std::unique_ptr copy() override; + +private: + std::shared_ptr recursiveRel; + std::shared_ptr nodeChild; + std::shared_ptr relChild; + RecursiveJoinType joinType; + SIPInfo sipInfo; + +public: + common::ExtendDirection direction = common::ExtendDirection::FWD; + bool extendFromLeft = true; + std::shared_ptr pathNodeIDs; + std::shared_ptr pathEdgeIDs; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_plan.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_plan.h new file mode 100644 index 0000000000..80df03118f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_plan.h @@ -0,0 +1,49 @@ +#pragma once + +#include "logical_operator.h" + +namespace lbug { +namespace planner { + +using cardinality_t = uint64_t; + +class LBUG_API LogicalPlan { + friend class CardinalityEstimator; + friend class CostModel; + +public: + LogicalPlan() : cost{0} {} + LogicalPlan(const LogicalPlan& other) : lastOperator{other.lastOperator}, cost{other.cost} {} + EXPLICIT_COPY_DEFAULT_MOVE(LogicalPlan); + + void setLastOperator(std::shared_ptr op) { lastOperator = std::move(op); } + + bool isEmpty() const { return lastOperator == nullptr; } + + std::shared_ptr getLastOperator() const { return lastOperator; } + LogicalOperator& getLastOperatorRef() const { + KU_ASSERT(lastOperator); + return *lastOperator; + } + Schema* getSchema() const { return lastOperator->getSchema(); } + + cardinality_t getCardinality() const { + KU_ASSERT(lastOperator); + return lastOperator->getCardinality(); + } + + void setCost(uint64_t cost_) { cost = cost_; } + uint64_t getCost() const { return cost; } + + std::string toString() const { return lastOperator->toString(); } + + bool isProfile() const; + bool hasUpdate() const; + +private: + std::shared_ptr lastOperator; + uint64_t cost; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_plan_util.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_plan_util.h new file mode 100644 index 0000000000..eab7ec5d19 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_plan_util.h @@ -0,0 +1,26 @@ +#pragma once + +#include "planner/operator/logical_plan.h" + +namespace lbug { +namespace planner { + +class LogicalPlanUtil { +public: + static std::string encodeJoin(LogicalPlan& logicalPlan); + +private: + static std::string encode(LogicalOperator* logicalOperator); + static void encodeRecursive(LogicalOperator* logicalOperator, std::string& encodeString); + // Encode joins + static void encodeCrossProduct(LogicalOperator* logicalOperator, std::string& encodeString); + static void encodeIntersect(LogicalOperator* 
logicalOperator, std::string& encodeString); + static void encodeHashJoin(LogicalOperator* logicalOperator, std::string& encodeString); + static void encodeExtend(LogicalOperator* logicalOperator, std::string& encodeString); + static void encodeScanNodeTable(LogicalOperator* logicalOperator, std::string& encodeString); + // Encode filter + static void encodeFilter(LogicalOperator* logicalOperator, std::string& encodedString); +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_projection.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_projection.h new file mode 100644 index 0000000000..f2a52c7cee --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_projection.h @@ -0,0 +1,40 @@ +#pragma once + +#include "binder/expression/expression.h" +#include "binder/expression/expression_util.h" +#include "planner/operator/logical_operator.h" + +namespace lbug { +namespace planner { + +class LBUG_API LogicalProjection : public LogicalOperator { + static constexpr LogicalOperatorType type_ = LogicalOperatorType::PROJECTION; + +public: + LogicalProjection(binder::expression_vector expressions, std::shared_ptr child) + : LogicalOperator{type_, std::move(child)}, expressions{std::move(expressions)} {} + + void computeFactorizedSchema() override; + void computeFlatSchema() override; + + std::string getExpressionsForPrinting() const override { + return binder::ExpressionUtil::toString(expressions); + } + + binder::expression_vector getExpressionsToProject() const { return expressions; } + void setExpressionsToProject(const binder::expression_vector& expressions) { + this->expressions = expressions; + } + + std::unordered_set getDiscardedGroupsPos() const; + + std::unique_ptr copy() override { + return make_unique(expressions, children[0]->copy()); + } + +private: + binder::expression_vector expressions; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_standalone_call.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_standalone_call.h new file mode 100644 index 0000000000..0913309c52 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_standalone_call.h @@ -0,0 +1,38 @@ +#pragma once + +#include "planner/operator/logical_operator.h" + +namespace lbug { +namespace main { +struct Option; +} +namespace planner { + +class LogicalStandaloneCall final : public LogicalOperator { + static constexpr LogicalOperatorType type_ = LogicalOperatorType::STANDALONE_CALL; + +public: + LogicalStandaloneCall(const main::Option* option, + std::shared_ptr optionValue) + : LogicalOperator{type_}, option{option}, optionValue{std::move(optionValue)} {} + + const main::Option* getOption() const { return option; } + std::shared_ptr getOptionValue() const { return optionValue; } + + std::string getExpressionsForPrinting() const override; + + void computeFlatSchema() override { createEmptySchema(); } + + void computeFactorizedSchema() override { createEmptySchema(); } + + std::unique_ptr copy() override { + return make_unique(option, optionValue); + } + +protected: + const main::Option* option; + std::shared_ptr optionValue; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_table_function_call.h 
b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_table_function_call.h new file mode 100644 index 0000000000..93dd979352 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_table_function_call.h @@ -0,0 +1,46 @@ +#pragma once + +#include "function/table/bind_data.h" +#include "function/table/table_function.h" +#include "planner/operator/logical_operator.h" + +namespace lbug { +namespace planner { + +class LBUG_API LogicalTableFunctionCall final : public LogicalOperator { + static constexpr LogicalOperatorType operatorType_ = LogicalOperatorType::TABLE_FUNCTION_CALL; + +public: + LogicalTableFunctionCall(function::TableFunction tableFunc, + std::unique_ptr bindData) + : LogicalOperator{operatorType_}, tableFunc{std::move(tableFunc)}, + bindData{std::move(bindData)} { + setCardinality(this->bindData->numRows); + } + + const function::TableFunction& getTableFunc() const { return tableFunc; } + const function::TableFuncBindData* getBindData() const { return bindData.get(); } + + void setColumnSkips(std::vector columnSkips) { + bindData->setColumnSkips(std::move(columnSkips)); + } + void setColumnPredicates(std::vector predicates) { + bindData->setColumnPredicates(std::move(predicates)); + } + + void computeFlatSchema() override; + void computeFactorizedSchema() override; + + std::string getExpressionsForPrinting() const override { return tableFunc.name; } + + std::unique_ptr copy() override { + return std::make_unique(tableFunc, bindData->copy()); + } + +private: + function::TableFunction tableFunc; + std::unique_ptr bindData; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_transaction.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_transaction.h new file mode 100644 index 0000000000..7c32201d84 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_transaction.h @@ -0,0 +1,32 @@ +#pragma once + +#include "planner/operator/logical_operator.h" +#include "transaction/transaction_action.h" + +namespace lbug { +namespace planner { + +class LogicalTransaction : public LogicalOperator { + static constexpr LogicalOperatorType type_ = LogicalOperatorType::TRANSACTION; + +public: + explicit LogicalTransaction(transaction::TransactionAction transactionAction) + : LogicalOperator{type_}, transactionAction{transactionAction} {} + + std::string getExpressionsForPrinting() const final { return std::string(); } + + void computeFlatSchema() final { createEmptySchema(); } + void computeFactorizedSchema() final { createEmptySchema(); } + + transaction::TransactionAction getTransactionAction() const { return transactionAction; } + + std::unique_ptr copy() final { + return std::make_unique(transactionAction); + } + +private: + transaction::TransactionAction transactionAction; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_union.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_union.h new file mode 100644 index 0000000000..914c4428d0 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_union.h @@ -0,0 +1,38 @@ +#pragma once + +#include "logical_operator.h" + +namespace lbug { +namespace planner { + +class LogicalUnion : public LogicalOperator { +public: + LogicalUnion(binder::expression_vector expressions, + const std::vector>& children) + : 
LogicalOperator{LogicalOperatorType::UNION_ALL, children}, + expressionsToUnion{std::move(expressions)} {} + + f_group_pos_set getGroupsPosToFlatten(uint32_t childIdx); + + void computeFactorizedSchema() override; + void computeFlatSchema() override; + + std::string getExpressionsForPrinting() const override { return std::string{}; } + + binder::expression_vector getExpressionsToUnion() const { return expressionsToUnion; } + + Schema* getSchemaBeforeUnion(uint32_t idx) const { return children[idx]->getSchema(); } + + std::unique_ptr copy() override; + +private: + // If an expression to union has different flat/unflat state in different child, we + // need to flatten that expression in all the single queries. + bool requireFlatExpression(uint32_t expressionIdx); + +private: + binder::expression_vector expressionsToUnion; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_unwind.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_unwind.h new file mode 100644 index 0000000000..3d510c65f6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/logical_unwind.h @@ -0,0 +1,39 @@ +#pragma once + +#include "planner/operator/logical_operator.h" + +namespace lbug { +namespace planner { + +class LogicalUnwind : public LogicalOperator { +public: + LogicalUnwind(std::shared_ptr inExpr, + std::shared_ptr outExpr, std::shared_ptr idExpr, + std::shared_ptr childOperator) + : LogicalOperator{LogicalOperatorType::UNWIND, std::move(childOperator)}, + inExpr{std::move(inExpr)}, outExpr{std::move(outExpr)}, idExpr{std::move(idExpr)} {} + + f_group_pos_set getGroupsPosToFlatten(); + + void computeFactorizedSchema() override; + void computeFlatSchema() override; + + std::shared_ptr getInExpr() const { return inExpr; } + std::shared_ptr getOutExpr() const { return outExpr; } + bool hasIDExpr() const { return idExpr != nullptr; } + std::shared_ptr getIDExpr() const { return idExpr; } + + std::string getExpressionsForPrinting() const override { return inExpr->toString(); } + + std::unique_ptr copy() override { + return make_unique(inExpr, outExpr, idExpr, children[0]->copy()); + } + +private: + std::shared_ptr inExpr; + std::shared_ptr outExpr; + std::shared_ptr idExpr; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/operator_print_info.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/operator_print_info.h new file mode 100644 index 0000000000..253cbf83f2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/operator_print_info.h @@ -0,0 +1,19 @@ +#pragma once + +#include +#include + +namespace lbug { + +struct OPPrintInfo { + OPPrintInfo() {} + virtual ~OPPrintInfo() = default; + + virtual std::string toString() const { return std::string(); } + + virtual std::unique_ptr copy() const { return std::make_unique(); } + + static std::unique_ptr EmptyInfo() { return std::make_unique(); } +}; + +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/persistent/logical_copy_from.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/persistent/logical_copy_from.h new file mode 100644 index 0000000000..7530c6e687 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/persistent/logical_copy_from.h @@ -0,0 +1,55 @@ +#pragma once + +#include "binder/copy/bound_copy_from.h" +#include 
"planner/operator/logical_operator.h" + +namespace lbug { +namespace planner { + +struct LogicalCopyFromPrintInfo final : OPPrintInfo { + std::string tableName; + + explicit LogicalCopyFromPrintInfo(std::string tableName) : tableName(std::move(tableName)) {} + + std::string toString() const override { return "Table name: " + tableName; }; + + std::unique_ptr copy() const override { + return std::unique_ptr(new LogicalCopyFromPrintInfo(*this)); + } + +private: + LogicalCopyFromPrintInfo(const LogicalCopyFromPrintInfo& other) + : OPPrintInfo(other), tableName(other.tableName) {} +}; + +class LogicalCopyFrom final : public LogicalOperator { + static constexpr LogicalOperatorType type_ = LogicalOperatorType::COPY_FROM; + +public: + LogicalCopyFrom(binder::BoundCopyFromInfo info, std::shared_ptr child) + : LogicalOperator{type_, std::move(child), std::optional(0)}, + info{std::move(info)} {} + LogicalCopyFrom(binder::BoundCopyFromInfo info, const logical_op_vector_t& children) + : LogicalOperator{type_, children}, info{std::move(info)} {} + + std::string getExpressionsForPrinting() const override { return info.tableName; } + + void computeFactorizedSchema() override; + void computeFlatSchema() override; + + const binder::BoundCopyFromInfo* getInfo() const { return &info; } + + std::unique_ptr getPrintInfo() const override { + return std::make_unique(info.tableName); + } + + std::unique_ptr copy() override { + return make_unique(info.copy(), LogicalOperator::copy(children)); + } + +private: + binder::BoundCopyFromInfo info; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/persistent/logical_copy_to.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/persistent/logical_copy_to.h new file mode 100644 index 0000000000..3ccff5a878 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/persistent/logical_copy_to.h @@ -0,0 +1,59 @@ +#pragma once + +#include "function/export/export_function.h" +#include "planner/operator/logical_operator.h" + +namespace lbug { +namespace planner { + +struct LogicalCopyToPrintInfo final : OPPrintInfo { + std::vector columnNames; + std::string fileName; + + LogicalCopyToPrintInfo(std::vector columnNames, std::string fileName) + : columnNames(std::move(columnNames)), fileName(std::move(fileName)) {} + + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::unique_ptr(new LogicalCopyToPrintInfo(*this)); + } + +private: + LogicalCopyToPrintInfo(const LogicalCopyToPrintInfo& other) + : OPPrintInfo(other), columnNames(other.columnNames), fileName(other.fileName) {} +}; + +class LogicalCopyTo final : public LogicalOperator { +public: + LogicalCopyTo(std::unique_ptr bindData, + function::ExportFunction exportFunc, std::shared_ptr child) + : LogicalOperator{LogicalOperatorType::COPY_TO, std::move(child), + std::optional(0)}, + bindData{std::move(bindData)}, exportFunc{std::move(exportFunc)} {} + + f_group_pos_set getGroupsPosToFlatten(); + + std::string getExpressionsForPrinting() const override { return std::string{}; } + + void computeFactorizedSchema() override; + void computeFlatSchema() override; + + std::unique_ptr getBindData() const { return bindData->copy(); } + function::ExportFunction getExportFunc() const { return exportFunc; }; + + std::unique_ptr getPrintInfo() const override { + return std::make_unique(bindData->columnNames, bindData->fileName); + } + + std::unique_ptr copy() override { + return 
make_unique(bindData->copy(), exportFunc, children[0]->copy()); + } + +private: + std::unique_ptr bindData; + function::ExportFunction exportFunc; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/persistent/logical_delete.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/persistent/logical_delete.h new file mode 100644 index 0000000000..c7b9e42606 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/persistent/logical_delete.h @@ -0,0 +1,58 @@ +#pragma once + +#include "binder/query/updating_clause/bound_delete_info.h" +#include "planner/operator/logical_operator.h" + +namespace lbug { +namespace planner { + +struct LogicalDeletePrintInfo final : OPPrintInfo { + std::vector infos; + + explicit LogicalDeletePrintInfo(std::vector infos) + : infos{std::move(infos)} {} + + std::string toString() const override { + std::string result = ""; + for (auto& info : infos) { + result += info.toString(); + } + return result; + } +}; + +class LogicalDelete final : public LogicalOperator { + static constexpr LogicalOperatorType type_ = LogicalOperatorType::DELETE; + +public: + LogicalDelete(std::vector infos, + std::shared_ptr child) + : LogicalOperator{type_, std::move(child)}, infos{std::move(infos)} {} + + common::TableType getTableType() const { + KU_ASSERT(!infos.empty()); + return infos[0].tableType; + } + const std::vector& getInfos() const { return infos; } + + void computeFactorizedSchema() override { copyChildSchema(0); } + void computeFlatSchema() override { copyChildSchema(0); } + + std::string getExpressionsForPrinting() const override; + + f_group_pos_set getGroupsPosToFlatten() const; + + std::unique_ptr getPrintInfo() const override { + return std::make_unique(copyVector(infos)); + } + + std::unique_ptr copy() override { + return std::make_unique(copyVector(infos), children[0]->copy()); + } + +private: + std::vector infos; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/persistent/logical_insert.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/persistent/logical_insert.h new file mode 100644 index 0000000000..3a195cda46 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/persistent/logical_insert.h @@ -0,0 +1,57 @@ +#pragma once + +#include "common/enums/conflict_action.h" +#include "common/enums/table_type.h" +#include "planner/operator/logical_operator.h" + +namespace lbug { +namespace planner { + +struct LogicalInsertInfo { + common::TableType tableType; + std::shared_ptr pattern; + binder::expression_vector columnExprs; + binder::expression_vector columnDataExprs; + std::vector isReturnColumnExprs; + common::ConflictAction conflictAction; + + LogicalInsertInfo(common::TableType tableType, std::shared_ptr pattern, + binder::expression_vector columnExprs, binder::expression_vector columnDataExprs, + common::ConflictAction conflictAction) + : tableType{tableType}, pattern{std::move(pattern)}, columnExprs{std::move(columnExprs)}, + columnDataExprs{std::move(columnDataExprs)}, conflictAction{conflictAction} {} + EXPLICIT_COPY_DEFAULT_MOVE(LogicalInsertInfo); + +private: + LogicalInsertInfo(const LogicalInsertInfo& other) + : tableType{other.tableType}, pattern{other.pattern}, columnExprs{other.columnExprs}, + columnDataExprs{other.columnDataExprs}, isReturnColumnExprs{other.isReturnColumnExprs}, + conflictAction{other.conflictAction} {} +}; + +class 
LogicalInsert final : public LogicalOperator { + static constexpr LogicalOperatorType type_ = LogicalOperatorType::INSERT; + +public: + LogicalInsert(std::vector infos, std::shared_ptr child) + : LogicalOperator{type_, std::move(child)}, infos{std::move(infos)} {} + + void computeFactorizedSchema() override; + void computeFlatSchema() override; + + std::string getExpressionsForPrinting() const final; + + f_group_pos_set getGroupsPosToFlatten(); + + const std::vector& getInfos() const { return infos; } + + std::unique_ptr copy() override { + return std::make_unique(copyVector(infos), children[0]->copy()); + } + +private: + std::vector infos; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/persistent/logical_merge.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/persistent/logical_merge.h new file mode 100644 index 0000000000..bd9c2d1856 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/persistent/logical_merge.h @@ -0,0 +1,85 @@ +#pragma once + +#include "binder/query/updating_clause/bound_set_info.h" +#include "planner/operator/logical_operator.h" +#include "planner/operator/persistent/logical_insert.h" + +namespace lbug { +namespace planner { + +class LogicalMerge final : public LogicalOperator { + static constexpr LogicalOperatorType type_ = LogicalOperatorType::MERGE; + +public: + LogicalMerge(std::shared_ptr existenceMark, binder::expression_vector keys, + std::shared_ptr child) + : LogicalOperator{type_, std::move(child)}, existenceMark{std::move(existenceMark)}, + keys{std::move(keys)} {} + + void computeFactorizedSchema() override; + void computeFlatSchema() override; + + std::string getExpressionsForPrinting() const override { return {}; } + + f_group_pos_set getGroupsPosToFlatten(); + + std::shared_ptr getExistenceMark() const { return existenceMark; } + + void addInsertNodeInfo(LogicalInsertInfo info) { insertNodeInfos.push_back(std::move(info)); } + const std::vector& getInsertNodeInfos() const { return insertNodeInfos; } + + void addInsertRelInfo(LogicalInsertInfo info) { insertRelInfos.push_back(std::move(info)); } + const std::vector& getInsertRelInfos() const { return insertRelInfos; } + + void addOnCreateSetNodeInfo(binder::BoundSetPropertyInfo info) { + onCreateSetNodeInfos.push_back(std::move(info)); + } + const std::vector& getOnCreateSetNodeInfos() const { + return onCreateSetNodeInfos; + } + + void addOnCreateSetRelInfo(binder::BoundSetPropertyInfo info) { + onCreateSetRelInfos.push_back(std::move(info)); + } + const std::vector& getOnCreateSetRelInfos() const { + return onCreateSetRelInfos; + } + + void addOnMatchSetNodeInfo(binder::BoundSetPropertyInfo info) { + onMatchSetNodeInfos.push_back(std::move(info)); + } + const std::vector& getOnMatchSetNodeInfos() const { + return onMatchSetNodeInfos; + } + + void addOnMatchSetRelInfo(binder::BoundSetPropertyInfo info) { + onMatchSetRelInfos.push_back(std::move(info)); + } + const std::vector& getOnMatchSetRelInfos() const { + return onMatchSetRelInfos; + } + const binder::expression_vector& getKeys() const { return keys; } + + std::unique_ptr copy() override; + +private: + std::shared_ptr existenceMark; + // Create infos + std::vector insertNodeInfos; + std::vector insertRelInfos; + // On Create infos + std::vector onCreateSetNodeInfos; + std::vector onCreateSetRelInfos; + // On Match infos + std::vector onMatchSetNodeInfos; + std::vector onMatchSetRelInfos; + // Key expressions used in merge hash 
table. + // If a merge clause is taking input from previous query parts + // E.g. UNWIND [1,1,3] AS x MERGE (n:N{id:x}) + // Since we don't re-evaluate the existence of n for each x, we need to create n only for + // distinct x, i.e. 1 & 3. So there is a notion of key in MERGE. + binder::expression_vector keys; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/persistent/logical_set.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/persistent/logical_set.h new file mode 100644 index 0000000000..560a30ea5f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/persistent/logical_set.h @@ -0,0 +1,37 @@ +#pragma once + +#include "binder/query/updating_clause/bound_set_info.h" +#include "planner/operator/logical_operator.h" + +namespace lbug { +namespace planner { + +class LogicalSetProperty final : public LogicalOperator { + static constexpr LogicalOperatorType type_ = LogicalOperatorType::SET_PROPERTY; + +public: + LogicalSetProperty(std::vector infos, + std::shared_ptr child) + : LogicalOperator{type_, std::move(child)}, infos{std::move(infos)} {} + + void computeFactorizedSchema() override; + void computeFlatSchema() override; + + f_group_pos_set getGroupsPosToFlatten(uint32_t idx) const; + + std::string getExpressionsForPrinting() const override; + + common::TableType getTableType() const; + const std::vector& getInfos() const { return infos; } + const binder::BoundSetPropertyInfo& getInfo(uint32_t idx) const { return infos[idx]; } + + std::unique_ptr copy() override { + return std::make_unique(copyVector(infos), children[0]->copy()); + } + +private: + std::vector infos; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/scan/logical_dummy_scan.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/scan/logical_dummy_scan.h new file mode 100644 index 0000000000..2bea164f3e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/scan/logical_dummy_scan.h @@ -0,0 +1,25 @@ +#pragma once + +#include "planner/operator/logical_operator.h" + +namespace lbug { +namespace planner { + +class LogicalDummyScan final : public LogicalOperator { +public: + explicit LogicalDummyScan() : LogicalOperator{LogicalOperatorType::DUMMY_SCAN} {} + + void computeFactorizedSchema() override; + void computeFlatSchema() override; + + inline std::string getExpressionsForPrinting() const override { return std::string(); } + + static std::shared_ptr getDummyExpression(); + + inline std::unique_ptr copy() override { + return std::make_unique(); + } +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/scan/logical_expressions_scan.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/scan/logical_expressions_scan.h new file mode 100644 index 0000000000..a0f1ba7681 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/scan/logical_expressions_scan.h @@ -0,0 +1,40 @@ +#pragma once + +#include "binder/expression/expression_util.h" +#include "planner/operator/logical_operator.h" + +namespace lbug { +namespace planner { + +// LogicalExpressionsScan scans from an outer factorize table +class LogicalExpressionsScan final : public LogicalOperator { +public: + explicit LogicalExpressionsScan(binder::expression_vector expressions) + : LogicalOperator{LogicalOperatorType::EXPRESSIONS_SCAN}, + 
expressions{std::move(expressions)}, outerAccumulate{nullptr} {} + + inline void computeFactorizedSchema() override { computeSchema(); } + inline void computeFlatSchema() override { computeSchema(); } + + inline std::string getExpressionsForPrinting() const override { + return binder::ExpressionUtil::toString(expressions); + } + + inline binder::expression_vector getExpressions() const { return expressions; } + inline void setOuterAccumulate(LogicalOperator* op) { outerAccumulate = op; } + inline LogicalOperator* getOuterAccumulate() const { return outerAccumulate; } + + inline std::unique_ptr copy() override { + return std::make_unique(expressions); + } + +private: + void computeSchema(); + +private: + binder::expression_vector expressions; + LogicalOperator* outerAccumulate; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/scan/logical_index_look_up.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/scan/logical_index_look_up.h new file mode 100644 index 0000000000..e0485e2100 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/scan/logical_index_look_up.h @@ -0,0 +1,37 @@ +#pragma once + +#include "binder/copy/index_look_up_info.h" +#include "planner/operator/logical_operator.h" + +namespace lbug { +namespace planner { + +// This operator is specifically used to transform primary key to offset during relationship copy. +// So it is not a source operator. I would suggest move this logic into rel copy instead of +// maintaining an operator. +class LogicalPrimaryKeyLookup final : public LogicalOperator { + static constexpr LogicalOperatorType type_ = LogicalOperatorType::INDEX_LOOK_UP; + +public: + LogicalPrimaryKeyLookup(std::vector infos, + std::shared_ptr child) + : LogicalOperator{type_, std::move(child)}, infos{std::move(infos)} {} + + void computeFactorizedSchema() override; + void computeFlatSchema() override; + + std::string getExpressionsForPrinting() const override; + + uint32_t getNumInfos() const { return infos.size(); } + const binder::IndexLookupInfo& getInfo(uint32_t idx) const { return infos[idx]; } + + std::unique_ptr copy() override { + return make_unique(infos, children[0]->copy()); + } + +private: + std::vector infos; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/scan/logical_scan_node_table.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/scan/logical_scan_node_table.h new file mode 100644 index 0000000000..e363c69934 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/scan/logical_scan_node_table.h @@ -0,0 +1,111 @@ +#pragma once + +#include "binder/expression/expression_util.h" +#include "planner/operator/logical_operator.h" +#include "storage/predicate/column_predicate.h" + +namespace lbug { +namespace planner { + +enum class LogicalScanNodeTableType : uint8_t { + SCAN = 0, + PRIMARY_KEY_SCAN = 1, +}; + +struct ExtraScanNodeTableInfo { + virtual ~ExtraScanNodeTableInfo() = default; + virtual std::unique_ptr copy() const = 0; + + template + const TARGET& constCast() const { + return common::ku_dynamic_cast(*this); + } +}; + +struct PrimaryKeyScanInfo final : ExtraScanNodeTableInfo { + std::shared_ptr key; + + explicit PrimaryKeyScanInfo(std::shared_ptr key) : key{std::move(key)} {} + + std::unique_ptr copy() const override { + return std::make_unique(key); + } +}; + +struct LogicalScanNodeTablePrintInfo final : OPPrintInfo { + 
std::shared_ptr nodeID; + binder::expression_vector properties; + + LogicalScanNodeTablePrintInfo(std::shared_ptr nodeID, + binder::expression_vector properties) + : nodeID{std::move(nodeID)}, properties{std::move(properties)} {} + + std::string toString() const override { + auto result = "Tables: " + nodeID->toString(); + if (nodeID->hasAlias()) { + result += "Alias: " + nodeID->getAlias(); + } + result += ",Properties :" + binder::ExpressionUtil::toString(properties); + return result; + } +}; + +class LogicalScanNodeTable final : public LogicalOperator { + static constexpr LogicalOperatorType type_ = LogicalOperatorType::SCAN_NODE_TABLE; + static constexpr LogicalScanNodeTableType defaultScanType = LogicalScanNodeTableType::SCAN; + +public: + LogicalScanNodeTable(std::shared_ptr nodeID, + std::vector nodeTableIDs, binder::expression_vector properties, + common::cardinality_t cardinality = 0) + : LogicalOperator{type_}, scanType{defaultScanType}, nodeID{std::move(nodeID)}, + nodeTableIDs{std::move(nodeTableIDs)}, properties{std::move(properties)} { + this->cardinality = cardinality; + } + LogicalScanNodeTable(const LogicalScanNodeTable& other); + + void computeFactorizedSchema() override; + void computeFlatSchema() override; + + std::string getExpressionsForPrinting() const override { + return nodeID->toString() + " " + binder::ExpressionUtil::toString(properties); + } + + LogicalScanNodeTableType getScanType() const { return scanType; } + void setScanType(LogicalScanNodeTableType scanType_) { scanType = scanType_; } + + std::shared_ptr getNodeID() const { return nodeID; } + std::vector getTableIDs() const { return nodeTableIDs; } + + binder::expression_vector getProperties() const { return properties; } + void addProperty(std::shared_ptr expr) { + properties.push_back(std::move(expr)); + } + void setPropertyPredicates(std::vector predicates) { + propertyPredicates = std::move(predicates); + } + const std::vector& getPropertyPredicates() const { + return propertyPredicates; + } + + void setExtraInfo(std::unique_ptr info) { extraInfo = std::move(info); } + + ExtraScanNodeTableInfo* getExtraInfo() const { return extraInfo.get(); } + + std::unique_ptr getPrintInfo() const override { + return std::make_unique(nodeID, properties); + } + + std::unique_ptr copy() override; + +private: + LogicalScanNodeTableType scanType; + std::shared_ptr nodeID; + std::vector nodeTableIDs; + binder::expression_vector properties; + std::vector propertyPredicates; + std::unique_ptr extraInfo; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/schema.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/schema.h new file mode 100644 index 0000000000..e82a6d63b9 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/schema.h @@ -0,0 +1,147 @@ +#pragma once + +#include + +#include "binder/expression/expression.h" + +namespace lbug { +namespace planner { + +using f_group_pos = uint32_t; +using f_group_pos_set = std::unordered_set; +constexpr f_group_pos INVALID_F_GROUP_POS = UINT32_MAX; + +class FactorizationGroup { + friend class Schema; + friend class CardinalityEstimator; + +public: + FactorizationGroup() : flat{false}, singleState{false}, cardinalityMultiplier{1} {} + FactorizationGroup(const FactorizationGroup& other) + : flat{other.flat}, singleState{other.singleState}, + cardinalityMultiplier{other.cardinalityMultiplier}, expressions{other.expressions}, + expressionNameToPos{other.expressionNameToPos} 
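    // Editorial note, not part of the upstream header: a rough illustration of the factorization
    // idea this class models (the concrete layout described here is an assumption, only the
    // terminology comes from this file). For MATCH (a)-[e]->(b), a factorized chunk can keep `a`
    // in a flat group (one value per tuple) and `e`, `b` together in an unflat group (a list per
    // `a`), so k neighbours of one `a` are stored as one flat value plus a list of k entries
    // rather than k fully flat tuples; setFlat()/setSingleState() and the cardinality multiplier
    // below describe exactly these per-group states.
    //
    //   // hypothetical planner-side use of the Schema API declared further down:
    //   auto groupPos = schema.createGroup();
    //   schema.insertToGroupAndScope(aNodeExpr, groupPos);  // `a` gets its own group
    //   schema.flattenGroup(groupPos);                       // and is marked flat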
{} + + void setFlat() { + KU_ASSERT(!flat); + flat = true; + } + bool isFlat() const { return flat; } + void setSingleState() { + KU_ASSERT(!singleState); + singleState = true; + setFlat(); + } + bool isSingleState() const { return singleState; } + + void setMultiplier(double multiplier) { cardinalityMultiplier = multiplier; } + double getMultiplier() const { return cardinalityMultiplier; } + + void insertExpression(const std::shared_ptr& expression) { + KU_ASSERT(!expressionNameToPos.contains(expression->getUniqueName())); + expressionNameToPos.insert({expression->getUniqueName(), expressions.size()}); + expressions.push_back(expression); + } + binder::expression_vector getExpressions() const { return expressions; } + uint32_t getExpressionPos(const binder::Expression& expression) const { + KU_ASSERT(expressionNameToPos.contains(expression.getUniqueName())); + return expressionNameToPos.at(expression.getUniqueName()); + } + +private: + bool flat; + bool singleState; + double cardinalityMultiplier; + binder::expression_vector expressions; + std::unordered_map expressionNameToPos; +}; + +class Schema { +public: + common::idx_t getNumGroups() const { return groups.size(); } + + FactorizationGroup* getGroup(const std::shared_ptr& expression) const { + return getGroup(getGroupPos(expression->getUniqueName())); + } + + FactorizationGroup* getGroup(const std::string& expressionName) const { + return getGroup(getGroupPos(expressionName)); + } + + FactorizationGroup* getGroup(uint32_t pos) const { return groups[pos].get(); } + + f_group_pos createGroup(); + + void insertToScope(const std::shared_ptr& expression, uint32_t groupPos); + void insertToGroupAndScope(const std::shared_ptr& expression, + uint32_t groupPos); + // Use these unsafe insert functions only if the operator may work with duplicate expressions. + // E.g. group by a.age, a.age + void insertToScopeMayRepeat(const std::shared_ptr& expression, + uint32_t groupPos); + void insertToGroupAndScopeMayRepeat(const std::shared_ptr& expression, + uint32_t groupPos); + + void insertToGroupAndScope(const binder::expression_vector& expressions, uint32_t groupPos); + + f_group_pos getGroupPos(const binder::Expression& expression) const { + return getGroupPos(expression.getUniqueName()); + } + + f_group_pos getGroupPos(const std::string& expressionName) const; + + std::pair getExpressionPos(const binder::Expression& expression) const { + auto groupPos = getGroupPos(expression); + return std::make_pair(groupPos, groups[groupPos]->getExpressionPos(expression)); + } + + void flattenGroup(f_group_pos pos) { groups[pos]->setFlat(); } + void setGroupAsSingleState(f_group_pos pos) { groups[pos]->setSingleState(); } + + bool isExpressionInScope(const binder::Expression& expression) const; + + binder::expression_vector getExpressionsInScope() const { return expressionsInScope; } + + binder::expression_vector getExpressionsInScope(f_group_pos pos) const; + + bool evaluable(const binder::Expression& expression) const; + + void clearExpressionsInScope() { + expressionNameToGroupPos.clear(); + expressionsInScope.clear(); + } + + // Get the group positions containing at least one expression in scope. + f_group_pos_set getGroupsPosInScope() const; + + LBUG_API std::unique_ptr copy() const; + + void clear(); + +private: + size_t getNumGroups(bool isFlat) const; + +private: + std::vector> groups; + std::unordered_map expressionNameToGroupPos; + // Our projection doesn't explicitly remove expressions. 
Instead, we keep track of what + // expressions are in scope (i.e. being projected). + binder::expression_vector expressionsInScope; +}; + +class SchemaUtils { +public: + // Given a set of factorization group, a leading group is selected as the unFlat group (caller + // should ensure at most one unFlat group which is our general assumption of factorization). If + // all groups are flat, we select any (the first) group as leading group. + static f_group_pos getLeadingGroupPos(const std::unordered_set& groupPositions, + const Schema& schema); + + static void validateAtMostOneUnFlatGroup(const std::unordered_set& groupPositions, + const Schema& schema); + static void validateNoUnFlatGroup(const std::unordered_set& groupPositions, + const Schema& schema); +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/simple/logical_attach_database.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/simple/logical_attach_database.h new file mode 100644 index 0000000000..ee88ac9ce9 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/simple/logical_attach_database.h @@ -0,0 +1,46 @@ +#pragma once + +#include "binder/bound_attach_info.h" +#include "logical_simple.h" + +namespace lbug { +namespace planner { + +struct LogicalAttachDatabasePrintInfo final : OPPrintInfo { + std::string dbName; + + explicit LogicalAttachDatabasePrintInfo(std::string dbName) : dbName(std::move(dbName)) {} + + std::string toString() const override { return "Database: " + dbName; }; + + std::unique_ptr copy() const override { + return std::unique_ptr( + new LogicalAttachDatabasePrintInfo(*this)); + } + +private: + LogicalAttachDatabasePrintInfo(const LogicalAttachDatabasePrintInfo& other) + : OPPrintInfo(other), dbName(other.dbName) {} +}; + +class LogicalAttachDatabase final : public LogicalSimple { + static constexpr LogicalOperatorType type_ = LogicalOperatorType::ATTACH_DATABASE; + +public: + explicit LogicalAttachDatabase(binder::AttachInfo attachInfo) + : LogicalSimple{type_}, attachInfo{std::move(attachInfo)} {} + + binder::AttachInfo getAttachInfo() const { return attachInfo; } + + std::string getExpressionsForPrinting() const override { return attachInfo.dbPath; } + + std::unique_ptr copy() override { + return std::make_unique(attachInfo); + } + +private: + binder::AttachInfo attachInfo; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/simple/logical_detach_database.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/simple/logical_detach_database.h new file mode 100644 index 0000000000..776e4c1fb3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/simple/logical_detach_database.h @@ -0,0 +1,28 @@ +#pragma once + +#include "logical_simple.h" + +namespace lbug { +namespace planner { + +class LogicalDetachDatabase final : public LogicalSimple { + static constexpr LogicalOperatorType type_ = LogicalOperatorType::DETACH_DATABASE; + +public: + explicit LogicalDetachDatabase(std::string dbName) + : LogicalSimple{type_}, dbName{std::move(dbName)} {} + + std::string getDBName() const { return dbName; } + + std::string getExpressionsForPrinting() const override { return dbName; } + + std::unique_ptr copy() override { + return std::make_unique(dbName); + } + +private: + std::string dbName; +}; + +} // namespace planner +} // namespace lbug diff --git 
a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/simple/logical_export_db.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/simple/logical_export_db.h new file mode 100644 index 0000000000..56806da756 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/simple/logical_export_db.h @@ -0,0 +1,41 @@ +#pragma once + +#include "common/copier_config/csv_reader_config.h" +#include "common/copier_config/file_scan_info.h" +#include "logical_simple.h" + +namespace lbug { +namespace planner { + +class LogicalExportDatabase final : public LogicalSimple { + static constexpr LogicalOperatorType type_ = LogicalOperatorType::EXPORT_DATABASE; + +public: + LogicalExportDatabase(common::FileScanInfo boundFileInfo, + const std::vector>& plans, bool exportSchemaOnly) + : LogicalSimple{type_, plans}, boundFileInfo{std::move(boundFileInfo)}, + schemaOnly{exportSchemaOnly} {} + + std::string getFilePath() const { return boundFileInfo.filePaths[0]; } + common::FileType getFileType() const { return boundFileInfo.fileTypeInfo.fileType; } + common::CSVOption getCopyOption() const { + auto csvConfig = common::CSVReaderConfig::construct(boundFileInfo.options); + return csvConfig.option.copy(); + } + const common::FileScanInfo* getBoundFileInfo() const { return &boundFileInfo; } + std::string getExpressionsForPrinting() const override { return std::string{}; } + + bool isSchemaOnly() const { return schemaOnly; } + + std::unique_ptr copy() override { + return make_unique(boundFileInfo.copy(), copyVector(children), + schemaOnly); + } + +private: + common::FileScanInfo boundFileInfo; + bool schemaOnly; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/simple/logical_extension.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/simple/logical_extension.h new file mode 100644 index 0000000000..6852ee82e2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/simple/logical_extension.h @@ -0,0 +1,30 @@ +#pragma once + +#include "extension/extension_action.h" +#include "logical_simple.h" + +namespace lbug { +namespace planner { + +class LogicalExtension final : public LogicalSimple { + static constexpr LogicalOperatorType type_ = LogicalOperatorType::EXTENSION; + +public: + explicit LogicalExtension(std::unique_ptr auxInfo) + : LogicalSimple{type_}, auxInfo{std::move(auxInfo)} {} + + std::string getExpressionsForPrinting() const override { return path; } + + const extension::ExtensionAuxInfo& getAuxInfo() const { return *auxInfo; } + + std::unique_ptr copy() override { + return std::make_unique(auxInfo->copy()); + } + +private: + std::unique_ptr auxInfo; + std::string path; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/simple/logical_import_db.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/simple/logical_import_db.h new file mode 100644 index 0000000000..b5cd5241aa --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/simple/logical_import_db.h @@ -0,0 +1,32 @@ +#pragma once + +#include "logical_simple.h" + +namespace lbug { +namespace planner { + +class LogicalImportDatabase : public LogicalSimple { + static constexpr LogicalOperatorType type_ = LogicalOperatorType::IMPORT_DATABASE; + +public: + LogicalImportDatabase(std::string query, std::string indexQuery) + : LogicalSimple{type_}, query{std::move(query)}, 
indexQuery{std::move(indexQuery)} {} + + std::string getQuery() const { return query; } + + std::string getIndexQuery() const { return indexQuery; } + + std::string getExpressionsForPrinting() const override { return std::string{}; } + + std::unique_ptr copy() override { + return make_unique(query, indexQuery); + } + +private: + // see comment in BoundImportDatabase + std::string query; + std::string indexQuery; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/simple/logical_simple.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/simple/logical_simple.h new file mode 100644 index 0000000000..ada03ee2c0 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/simple/logical_simple.h @@ -0,0 +1,21 @@ +#pragma once + +#include "planner/operator/logical_operator.h" + +namespace lbug { +namespace planner { + +class LogicalSimple : public LogicalOperator { +public: + explicit LogicalSimple(LogicalOperatorType operatorType) : LogicalOperator{operatorType} {} + LogicalSimple(LogicalOperatorType operatorType, + const std::vector>& plans) + : LogicalOperator{operatorType, plans} {} + + void computeFactorizedSchema() override; + + void computeFlatSchema() override; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/simple/logical_use_database.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/simple/logical_use_database.h new file mode 100644 index 0000000000..d7914e34dd --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/simple/logical_use_database.h @@ -0,0 +1,28 @@ +#pragma once + +#include "logical_simple.h" + +namespace lbug { +namespace planner { + +class LogicalUseDatabase final : public LogicalSimple { + static constexpr LogicalOperatorType type_ = LogicalOperatorType::USE_DATABASE; + +public: + explicit LogicalUseDatabase(std::string dbName) + : LogicalSimple{type_}, dbName{std::move(dbName)} {} + + std::string getDBName() const { return dbName; } + + std::string getExpressionsForPrinting() const override { return dbName; } + + std::unique_ptr copy() override { + return std::make_unique(dbName); + } + +private: + std::string dbName; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/sip/logical_semi_masker.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/sip/logical_semi_masker.h new file mode 100644 index 0000000000..32f4c40d45 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/sip/logical_semi_masker.h @@ -0,0 +1,117 @@ +#pragma once + +#include "common/enums/extend_direction.h" +#include "common/exception/runtime.h" +#include "planner/operator/logical_operator.h" +#include "semi_mask_target_type.h" + +namespace lbug { +namespace planner { + +/* + * NODE + * offsets are collected from value vector of type NODE_ID (INTERNAL_ID) + * + * PATH + * offsets are collected from value vector of type PATH (LIST[INTERNAL_ID]). This is a fast-path + * code used when scanning properties along the path. 
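 *
 * Illustrative note added for this document (not from the upstream sources): for a recursive
 * query such as MATCH p = (a:Person)-[:Knows*1..2]->(b:Person) RETURN p, a PATH-keyed semi
 * mask collects node offsets out of the LIST[INTERNAL_ID] path vector, so the follow-up
 * property scan only touches nodes that actually appear on some returned path, whereas a
 * NODE-keyed mask is fed from a plain NODE_ID (INTERNAL_ID) vector.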
+ * + * */ +enum class SemiMaskKeyType : uint8_t { + NODE = 0, + PATH = 1, + NODE_ID_LIST = 2, +}; + +struct ExtraKeyInfo { + virtual ~ExtraKeyInfo() = default; + + template + const TARGET& constCast() const { + return common::ku_dynamic_cast(*this); + } + + virtual std::unique_ptr copy() const = 0; +}; + +struct ExtraPathKeyInfo final : ExtraKeyInfo { + common::ExtendDirection direction; + + explicit ExtraPathKeyInfo(common::ExtendDirection direction) : direction{direction} {} + + std::unique_ptr copy() const override { + return std::make_unique(direction); + } +}; + +struct ExtraNodeIDListKeyInfo final : ExtraKeyInfo { + std::shared_ptr srcNodeID; + std::shared_ptr dstNodeID; + + ExtraNodeIDListKeyInfo(std::shared_ptr srcNodeID, + std::shared_ptr dstNodeID) + : srcNodeID{std::move(srcNodeID)}, dstNodeID{std::move(dstNodeID)} {} + + std::unique_ptr copy() const override { + return std::make_unique(srcNodeID, dstNodeID); + } +}; + +class LBUG_API LogicalSemiMasker final : public LogicalOperator { + static constexpr LogicalOperatorType type_ = LogicalOperatorType::SEMI_MASKER; + +public: + LogicalSemiMasker(SemiMaskKeyType keyType, SemiMaskTargetType targetType, + std::shared_ptr key, std::vector nodeTableIDs, + std::shared_ptr child) + : LogicalOperator{type_, std::move(child)}, keyType{keyType}, targetType{targetType}, + key{std::move(key)}, nodeTableIDs{std::move(nodeTableIDs)} {} + + ~LogicalSemiMasker() override; + + void computeFactorizedSchema() override { copyChildSchema(0); } + void computeFlatSchema() override { copyChildSchema(0); } + + std::string getExpressionsForPrinting() const override { return key->toString(); } + + SemiMaskKeyType getKeyType() const { return keyType; } + + SemiMaskTargetType getTargetType() const { return targetType; } + + std::shared_ptr getKey() const { return key; } + void setExtraKeyInfo(std::unique_ptr extraInfo) { + extraKeyInfo = std::move(extraInfo); + } + ExtraKeyInfo* getExtraKeyInfo() const { return extraKeyInfo.get(); } + + std::vector getNodeTableIDs() const { return nodeTableIDs; } + + void addTarget(const LogicalOperator* op) { targetOps.push_back(op); } + std::vector getTargetOperators() const { return targetOps; } + + std::unique_ptr copy() override { + if (!targetOps.empty()) { + throw common::RuntimeException( + "LogicalSemiMasker::copy() should not be called when ops " + "is not empty. 
Raw pointers will be point to corrupted object after copy."); + } + auto result = std::make_unique(keyType, targetType, key, nodeTableIDs, + children[0]->copy()); + if (extraKeyInfo != nullptr) { + result->setExtraKeyInfo(extraKeyInfo->copy()); + } + return result; + } + +private: + SemiMaskKeyType keyType; + SemiMaskTargetType targetType; + std::shared_ptr key; + std::unique_ptr extraKeyInfo = nullptr; + std::vector nodeTableIDs; + // Operators accepting semi masker + std::vector targetOps; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/sip/semi_mask_target_type.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/sip/semi_mask_target_type.h new file mode 100644 index 0000000000..7602278d7f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/sip/semi_mask_target_type.h @@ -0,0 +1,17 @@ +#pragma once + +#include + +namespace lbug { +namespace planner { + +enum class SemiMaskTargetType : uint8_t { + SCAN_NODE = 0, + RECURSIVE_EXTEND_INPUT_NODE = 2, + RECURSIVE_EXTEND_OUTPUT_NODE = 3, + RECURSIVE_EXTEND_PATH_NODE = 4, + GDS_GRAPH_NODE = 5, +}; + +} +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/sip/side_way_info_passing.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/sip/side_way_info_passing.h new file mode 100644 index 0000000000..6d73d726b7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/operator/sip/side_way_info_passing.h @@ -0,0 +1,79 @@ +#pragma once + +#include + +namespace lbug { +namespace planner { + +enum class SemiMaskPosition : uint8_t { + NONE = 0, + ON_BUILD = 1, + ON_PROBE = 2, + PROHIBIT_PROBE_TO_BUILD = 3, + PROHIBIT = 4, +}; + +/* + * If semi mask is present, scan pipelines needs to be constructed before semi mask pipelines. + * SIPDependency instructs whether probe or build should be constructed first. + */ +enum class SIPDependency : uint8_t { + NONE = 0, + PROBE_DEPENDS_ON_BUILD = 1, + BUILD_DEPENDS_ON_PROBE = 2, +}; + +/* + * Direction of side way information passing. If direction is probe to build, then probe side + * must have been materialized, and we need to construct a new pipeline to scan materialized result. + * We use SIPDirection in the mapper to create pipelines. + * */ +enum class SIPDirection { + NONE = 0, + PROBE_TO_BUILD = 1, + BUILD_TO_PROBE = 2, + // TODO(Xiyang/Guodong): Temp hack to allow vector index search to pass semi mask. + FORCE_BUILD_TO_PROBE = 3, +}; + +/* + * + * We perform side way information passing in the following cases + * 1. Inner hash join + * + * If we add semi mask on build side, position is ON_BUILD, direction is BUILD_TO_PROBE and + * dependency is PROBE_DEPENDS_ON_BUILD (because scan need to be generated before semi masker). + * + * If we add semi mask on probe side, position is ON_PROBE, direction is PROBE_TO_BUILD + * and dependency is BUILD_DEPENDS_ON_PROBE. + * + * 2. Unnesting correlated subquery + * + * When unnesting a correlated subquery, we first accumulate the probe plan and pass information to + * the build side. Semi mask position is NONE, direction is PROBE_TO_BUILD and dependency is + * PROBE_DEPENDS_ON_BUILD. + * + * 3. Optional match after update + * + * When performing optional match update, since we only have left join operator, update pipeline + * is placed on the probe side. However, by semantic, optional match should scan updated result. 
+ * So we accumulate the probe side and make it the right-most pipeline to make sure update pipeline + * is executed before optional match pipeline. + * + * TODO(Xiyang): it worth thinking if we should simply put outer plan always on the build side. + * + * We disable semi-mask-based sip in the following cases + * + * During join order enumeration, we might disable semi mask if probe side cardinality is large to + * avoid large materialization probe side intermediate result. + * + * During filter push down, we disable semi mask if join condition is not id-based + * */ +struct SIPInfo { + SemiMaskPosition position = SemiMaskPosition::NONE; + SIPDependency dependency = SIPDependency::NONE; + SIPDirection direction = SIPDirection::NONE; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/planner.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/planner.h new file mode 100644 index 0000000000..fd3e5fea16 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/planner.h @@ -0,0 +1,352 @@ +#pragma once + +#include "binder/bound_statement.h" +#include "binder/query/query_graph.h" +#include "common/enums/accumulate_type.h" +#include "common/enums/extend_direction.h" +#include "common/enums/join_type.h" +#include "planner/join_order/cardinality_estimator.h" +#include "planner/join_order_enumerator_context.h" +#include "planner/operator/logical_plan.h" +#include "planner/operator/sip/semi_mask_target_type.h" + +namespace lbug { +namespace extension { +class PlannerExtension; +} +namespace binder { +struct BoundTableScanInfo; +struct BoundCopyFromInfo; +struct BoundInsertInfo; +struct BoundSetPropertyInfo; +struct BoundDeleteInfo; +struct BoundJoinHintNode; +class NormalizedSingleQuery; +class NormalizedQueryPart; +class BoundReadingClause; +class BoundUpdatingClause; +class BoundProjectionBody; +} // namespace binder +namespace planner { + +struct LogicalInsertInfo; + +enum class SubqueryPlanningType : uint8_t { + NONE = 0, + UNNEST_CORRELATED = 1, + CORRELATED = 2, +}; + +struct QueryGraphPlanningInfo { + // Predicate info. + binder::expression_vector predicates; + // Subquery info. + SubqueryPlanningType subqueryType = SubqueryPlanningType::NONE; + binder::expression_vector corrExprs; + cardinality_t corrExprsCard = 0; + // Join hint info. + std::shared_ptr hint = nullptr; + + bool containsCorrExpr(const binder::Expression& expr) const; +}; + +// Group property expressions based on node/relationship. +class PropertyExprCollection { +public: + void addProperties(const std::string& patternName, + std::shared_ptr property); + binder::expression_vector getProperties(const binder::Expression& pattern) const; + binder::expression_vector getProperties() const; + + void clear(); + +private: + std::unordered_map patternNameToProperties; +}; + +class LBUG_API Planner { +public: + explicit Planner(main::ClientContext* clientContext); + DELETE_COPY_AND_MOVE(Planner); + + LogicalPlan planStatement(const binder::BoundStatement& statement); + + // Plan simple statement. 
+ LogicalPlan planCreateTable(const binder::BoundStatement& statement); + LogicalPlan planCreateType(const binder::BoundStatement& statement); + LogicalPlan planCreateSequence(const binder::BoundStatement& statement); + LogicalPlan planCreateMacro(const binder::BoundStatement& statement); + LogicalPlan planDrop(const binder::BoundStatement& statement); + LogicalPlan planAlter(const binder::BoundStatement& statement); + LogicalPlan planStandaloneCall(const binder::BoundStatement& statement); + LogicalPlan planStandaloneCallFunction(const binder::BoundStatement& statement); + LogicalPlan planExplain(const binder::BoundStatement& statement); + LogicalPlan planTransaction(const binder::BoundStatement& statement); + LogicalPlan planExtension(const binder::BoundStatement& statement); + LogicalPlan planAttachDatabase(const binder::BoundStatement& statement); + LogicalPlan planDetachDatabase(const binder::BoundStatement& statement); + LogicalPlan planUseDatabase(const binder::BoundStatement& statement); + LogicalPlan planExtensionClause(const binder::BoundStatement& statement); + + // Plan copy. + LogicalPlan planCopyTo(const binder::BoundStatement& statement); + LogicalPlan planCopyFrom(const binder::BoundStatement& statement); + LogicalPlan planCopyNodeFrom(const binder::BoundCopyFromInfo* info); + LogicalPlan planCopyRelFrom(const binder::BoundCopyFromInfo* info); + + // Plan export/import database + std::vector> planExportTableData( + const binder::BoundStatement& boundExportDatabase); + LogicalPlan planExportDatabase(const binder::BoundStatement& statement); + LogicalPlan planImportDatabase(const binder::BoundStatement& statement); + + // Plan query. + LogicalPlan planQuery(const binder::BoundStatement& boundStatement); + LogicalPlan planSingleQuery(const binder::NormalizedSingleQuery& singleQuery); + void planQueryPart(const binder::NormalizedQueryPart& queryPart, LogicalPlan& prevPlan); + + // Plan read. 
+ void planReadingClause(const binder::BoundReadingClause& readingClause, LogicalPlan& plan); + void planMatchClause(const binder::BoundReadingClause& readingClause, LogicalPlan& plan); + void planUnwindClause(const binder::BoundReadingClause& readingClause, LogicalPlan& plan); + void planTableFunctionCall(const binder::BoundReadingClause& readingClause, LogicalPlan& plan); + + void planReadOp(std::shared_ptr op, + const binder::expression_vector& predicates, LogicalPlan& plan); + void planLoadFrom(const binder::BoundReadingClause& readingClause, LogicalPlan& plan); + + // Plan updating + void planUpdatingClause(const binder::BoundUpdatingClause& updatingClause, LogicalPlan& plan); + void planInsertClause(const binder::BoundUpdatingClause& updatingClause, LogicalPlan& plan); + void planMergeClause(const binder::BoundUpdatingClause& updatingClause, LogicalPlan& plan); + void planSetClause(const binder::BoundUpdatingClause& updatingClause, LogicalPlan& plan); + void planDeleteClause(const binder::BoundUpdatingClause& updatingClause, LogicalPlan& plan); + + // Plan projection + void planProjectionBody(const binder::BoundProjectionBody* projectionBody, LogicalPlan& plan); + void planAggregate(const binder::expression_vector& expressionsToAggregate, + const binder::expression_vector& expressionsToGroupBy, LogicalPlan& plan); + void planOrderBy(const binder::expression_vector& expressionsToProject, + const binder::expression_vector& expressionsToOrderBy, const std::vector& isAscOrders, + LogicalPlan& plan); + + // Plan subquery + void planOptionalMatch(const binder::QueryGraphCollection& queryGraphCollection, + const binder::expression_vector& predicates, LogicalPlan& leftPlan, + std::shared_ptr hint); + // Write whether optional match succeed or not to mark. 
+ void planOptionalMatch(const binder::QueryGraphCollection& queryGraphCollection, + const binder::expression_vector& predicates, std::shared_ptr mark, + LogicalPlan& leftPlan, std::shared_ptr hint); + void planRegularMatch(const binder::QueryGraphCollection& queryGraphCollection, + const binder::expression_vector& predicates, LogicalPlan& leftPlan, + std::shared_ptr hint); + void planSubquery(const std::shared_ptr& subquery, LogicalPlan& outerPlan); + void planSubqueryIfNecessary(std::shared_ptr expression, LogicalPlan& plan); + + static binder::expression_vector getCorrelatedExprs( + const binder::QueryGraphCollection& collection, const binder::expression_vector& predicates, + Schema* outerSchema); + + // Plan query graphs + LogicalPlan planQueryGraphCollectionInNewContext( + const binder::QueryGraphCollection& queryGraphCollection, + const QueryGraphPlanningInfo& info); + LogicalPlan planQueryGraphCollection(const binder::QueryGraphCollection& queryGraphCollection, + const QueryGraphPlanningInfo& info); + LogicalPlan planQueryGraph(const binder::QueryGraph& queryGraph, + const QueryGraphPlanningInfo& info); + + // Plan node/rel table scan + void planBaseTableScans(const QueryGraphPlanningInfo& info); + void planCorrelatedExpressionsScan(const QueryGraphPlanningInfo& info); + void planNodeScan(uint32_t nodePos); + void planNodeIDScan(uint32_t nodePos); + void planRelScan(uint32_t relPos); + void appendExtend(std::shared_ptr boundNode, + std::shared_ptr nbrNode, std::shared_ptr rel, + common::ExtendDirection direction, const binder::expression_vector& properties, + LogicalPlan& plan); + + // Plan dp level + void planLevel(uint32_t level); + void planLevelExactly(uint32_t level); + void planLevelApproximately(uint32_t level); + + // Plan worst case optimal join + void planWCOJoin(uint32_t leftLevel, uint32_t rightLevel); + void planWCOJoin(const binder::SubqueryGraph& subgraph, + const std::vector>& rels, + const std::shared_ptr& intersectNode); + + // Plan index-nested-loop join / hash join + void planInnerJoin(uint32_t leftLevel, uint32_t rightLevel); + bool tryPlanINLJoin(const binder::SubqueryGraph& subgraph, + const binder::SubqueryGraph& otherSubgraph, + const std::vector>& joinNodes); + void planInnerHashJoin(const binder::SubqueryGraph& subgraph, + const binder::SubqueryGraph& otherSubgraph, + const std::vector>& joinNodes, bool flipPlan); + + // Plan semi mask + void appendNodeSemiMask(SemiMaskTargetType targetType, const binder::NodeExpression& node, + LogicalPlan& plan); + LogicalPlan getNodeSemiMaskPlan(SemiMaskTargetType targetType, + const binder::NodeExpression& node, std::shared_ptr nodePredicate); + + // This is mostly used when we try to reinterpret function output as node and read its + // properties, e.g. query_vector_index, gds algorithms ... 
+ LogicalPlan getNodePropertyScanPlan(const binder::NodeExpression& node); + + // Append dummy sink + void appendDummySink(LogicalPlan& plan); + + // Append empty result + void appendEmptyResult(LogicalPlan& plan); + + // Append updating operators + void appendInsertNode(const std::vector& boundInsertInfos, + LogicalPlan& plan); + void appendInsertRel(const std::vector& boundInsertInfos, + LogicalPlan& plan); + + void appendSetProperty(const std::vector& infos, + LogicalPlan& plan); + void appendDelete(const std::vector& infos, LogicalPlan& plan); + std::unique_ptr createLogicalInsertInfo( + const binder::BoundInsertInfo* info) const; + + // Append projection operators + void appendProjection(const binder::expression_vector& expressionsToProject, LogicalPlan& plan); + void appendAggregate(const binder::expression_vector& expressionsToGroupBy, + const binder::expression_vector& expressionsToAggregate, LogicalPlan& plan); + void appendOrderBy(const binder::expression_vector& expressions, + const std::vector& isAscOrders, LogicalPlan& plan); + void appendMultiplicityReducer(LogicalPlan& plan); + void appendLimit(std::shared_ptr skipNum, + std::shared_ptr limitNum, LogicalPlan& plan); + + // Append scan operators + void appendExpressionsScan(const binder::expression_vector& expressions, LogicalPlan& plan); + void appendScanNodeTable(std::shared_ptr nodeID, + std::vector tableIDs, const binder::expression_vector& properties, + LogicalPlan& plan); + + // Append extend operators + void appendNonRecursiveExtend(const std::shared_ptr& boundNode, + const std::shared_ptr& nbrNode, + const std::shared_ptr& rel, common::ExtendDirection direction, + bool extendFromSource, const binder::expression_vector& properties, LogicalPlan& plan); + void appendRecursiveExtend(const std::shared_ptr& boundNode, + const std::shared_ptr& nbrNode, + const std::shared_ptr& rel, common::ExtendDirection direction, + LogicalPlan& plan); + void createPathNodePropertyScanPlan(const std::shared_ptr& node, + const binder::expression_vector& properties, LogicalPlan& plan); + void createPathRelPropertyScanPlan(const std::shared_ptr& boundNode, + const std::shared_ptr& nbrNode, + const std::shared_ptr& recursiveRel, + common::ExtendDirection direction, bool extendFromSource, + const binder::expression_vector& properties, LogicalPlan& plan); + void appendNodeLabelFilter(std::shared_ptr nodeID, + std::unordered_set tableIDSet, LogicalPlan& plan); + + // Append Join operators + void appendHashJoin(const binder::expression_vector& joinNodeIDs, common::JoinType joinType, + LogicalPlan& probePlan, LogicalPlan& buildPlan, LogicalPlan& resultPlan); + void appendHashJoin(const binder::expression_vector& joinNodeIDs, common::JoinType joinType, + std::shared_ptr mark, LogicalPlan& probePlan, LogicalPlan& buildPlan, + LogicalPlan& resultPlan); + void appendHashJoin(const std::vector& joinConditions, + common::JoinType joinType, std::shared_ptr mark, LogicalPlan& probePlan, + LogicalPlan& buildPlan, LogicalPlan& resultPlan); + void appendAccHashJoin(const std::vector& joinConditions, + common::JoinType joinType, std::shared_ptr mark, LogicalPlan& probePlan, + LogicalPlan& buildPlan, LogicalPlan& resultPlan); + void appendMarkJoin(const binder::expression_vector& joinNodeIDs, + const std::shared_ptr& mark, LogicalPlan& probePlan, + LogicalPlan& buildPlan, LogicalPlan& resultPlan); + void appendMarkJoin(const std::vector& joinConditions, + const std::shared_ptr& mark, LogicalPlan& probePlan, + LogicalPlan& buildPlan, LogicalPlan& resultPlan); + 
void appendIntersect(const std::shared_ptr& intersectNodeID, + binder::expression_vector& boundNodeIDs, LogicalPlan& probePlan, + std::vector& buildPlans); + + void appendCrossProduct(const LogicalPlan& probePlan, const LogicalPlan& buildPlan, + LogicalPlan& resultPlan); + // Optional cross product produce at least one tuple for each probe tuple + void appendOptionalCrossProduct(std::shared_ptr mark, + const LogicalPlan& probePlan, const LogicalPlan& buildPlan, LogicalPlan& resultPlan); + void appendAccOptionalCrossProduct(std::shared_ptr mark, + LogicalPlan& probePlan, const LogicalPlan& buildPlan, LogicalPlan& resultPlan); + void appendCrossProduct(common::AccumulateType accumulateType, + std::shared_ptr mark, const LogicalPlan& probePlan, + const LogicalPlan& buildPlan, LogicalPlan& resultPlan); + + // Append accumulate operators + // Skip if plan has been accumulated. + void tryAppendAccumulate(LogicalPlan& plan); + // Accumulate everything. + void appendAccumulate(LogicalPlan& plan); + // Accumulate everything. Append mark. + void appendOptionalAccumulate(std::shared_ptr mark, LogicalPlan& plan); + // Append accumulate with a set of expressions being flattened first. + void appendAccumulate(const binder::expression_vector& flatExprs, LogicalPlan& plan); + // Append accumulate with a set of expressions being flattened first. Append mark. + void appendAccumulate(common::AccumulateType accumulateType, + const binder::expression_vector& flatExprs, std::shared_ptr mark, + LogicalPlan& plan); + + void appendDummyScan(LogicalPlan& plan); + + void appendUnwind(const binder::BoundReadingClause& boundReadingClause, LogicalPlan& plan); + + void appendFlattens(const f_group_pos_set& groupsPos, LogicalPlan& plan); + void appendFlattenIfNecessary(f_group_pos groupPos, LogicalPlan& plan); + + void appendFilters(const binder::expression_vector& predicates, LogicalPlan& plan); + void appendFilter(const std::shared_ptr& predicate, LogicalPlan& plan); + + void appendTableFunctionCall(const binder::BoundTableScanInfo& info, LogicalPlan& plan); + + void appendDistinct(const binder::expression_vector& keys, LogicalPlan& plan); + + const CardinalityEstimator& getCardinalityEstimator() const { return cardinalityEstimator; } + CardinalityEstimator& getCardinliatyEstimatorUnsafe() { return cardinalityEstimator; } + + // Get operators + static std::shared_ptr getTableFunctionCall( + const binder::BoundTableScanInfo& info); + static std::shared_ptr getTableFunctionCall( + const binder::BoundReadingClause& readingClause); + + LogicalPlan createUnionPlan(std::vector& childrenPlans, + const binder::expression_vector& expressions, bool isUnionAll); + + binder::expression_vector getProperties(const binder::Expression& pattern) const; + + JoinOrderEnumeratorContext enterNewContext(); + void exitContext(JoinOrderEnumeratorContext prevContext); + PropertyExprCollection enterNewPropertyExprCollection(); + void exitPropertyExprCollection(PropertyExprCollection collection); + + static binder::expression_vector getNewlyMatchedExprs( + const std::vector& prevs, const binder::SubqueryGraph& new_, + const binder::expression_vector& exprs); + static binder::expression_vector getNewlyMatchedExprs(const binder::SubqueryGraph& prev, + const binder::SubqueryGraph& new_, const binder::expression_vector& exprs); + static binder::expression_vector getNewlyMatchedExprs(const binder::SubqueryGraph& leftPrev, + const binder::SubqueryGraph& rightPrev, const binder::SubqueryGraph& new_, + const binder::expression_vector& exprs); + 
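The hash-join, mark-join and semi-mask helpers declared above are where the planner fixes the side-way-information-passing flags declared in side_way_info_passing.h. A minimal sketch of the flag combinations documented there for an inner hash join; the helper function below is purely illustrative and is not part of the lbug sources:

    #include "planner/operator/sip/side_way_info_passing.h"

    namespace lbug {
    namespace planner {

    // Illustrative only: pick SIP flags for an inner hash join, mirroring case 1
    // of the comment in side_way_info_passing.h.
    inline SIPInfo makeInnerHashJoinSIPInfo(bool maskOnBuildSide) {
        SIPInfo info;
        if (maskOnBuildSide) {
            // Semi masker sits on the build side, masks flow build -> probe, and
            // the probe pipeline must be constructed after the build pipeline.
            info.position = SemiMaskPosition::ON_BUILD;
            info.direction = SIPDirection::BUILD_TO_PROBE;
            info.dependency = SIPDependency::PROBE_DEPENDS_ON_BUILD;
        } else {
            // Semi masker sits on the probe side, so the roles are reversed.
            info.position = SemiMaskPosition::ON_PROBE;
            info.direction = SIPDirection::PROBE_TO_BUILD;
            info.dependency = SIPDependency::BUILD_DEPENDS_ON_PROBE;
        }
        return info;
    }

    } // namespace planner
    } // namespace lbug

The unnest-correlated-subquery case from the same comment differs only in that position stays NONE while direction is PROBE_TO_BUILD and dependency is PROBE_DEPENDS_ON_BUILD.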
+private: + main::ClientContext* clientContext; + PropertyExprCollection propertyExprCollection; + CardinalityEstimator cardinalityEstimator; + JoinOrderEnumeratorContext context; + std::vector plannerExtensions; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/subplans_table.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/subplans_table.h new file mode 100644 index 0000000000..fc2b3219ce --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/planner/subplans_table.h @@ -0,0 +1,105 @@ +#pragma once + +#include + +#include "binder/query/query_graph.h" +#include "planner/operator/logical_plan.h" + +namespace lbug { +namespace planner { + +const uint64_t MAX_LEVEL_TO_PLAN_EXACTLY = 7; + +// Different from vanilla dp algorithm where one optimal plan is kept per subgraph, we keep multiple +// plans each with a different factorization structure. The following example will explain our +// rationale. +// Given a triangle with an outgoing edge +// MATCH (a)->(b)->(c), (a)->(c), (c)->(d) +// At level 3 (assume level is based on num of nodes) for subgraph "abc", if we ignore factorization +// structure, the 3 plans that intersects on "a", "b", or "c" are considered homogenous and one of +// them will be picked. +// Then at level 4 for subgraph "abcd", we know the plan that intersect on "c" will be worse because +// we need to further flatten it and extend to "d". +// Therefore, we try to be factorization aware when keeping optimal plans. +class SubgraphPlans { +public: + explicit SubgraphPlans(const binder::SubqueryGraph& subqueryGraph); + + uint64_t getMaxCost() const { return maxCost; } + + void addPlan(LogicalPlan plan); + + const std::vector& getPlans() const { return plans; } + +private: + // To balance computation time, we encode plan by only considering the flat information of the + // nodes that are involved in current subgraph. + std::bitset encodePlan(const LogicalPlan& plan); + +private: + constexpr static uint32_t MAX_NUM_PLANS = 10; + +private: + uint64_t maxCost = UINT64_MAX; + binder::expression_vector nodeIDsToEncode; + std::vector plans; + std::unordered_map, common::idx_t> + encodedPlan2PlanIdx; +}; + +// A DPLevel is a collection of plans per subgraph. All subgraph should have the same number of +// variables. 
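The factorization-aware DP table described above is easier to see in isolation: instead of keeping a single optimal plan per subgraph, keep the cheapest plan per factorization signature, so a later extend can still find the plan shape it needs. A standalone sketch of that bookkeeping, shown before the DPLevel class below; the node count, costs and signature encoding are invented for illustration and do not use the lbug classes:

    #include <bitset>
    #include <cstdint>
    #include <cstdio>
    #include <unordered_map>

    int main() {
        // One bit per join node of the subgraph; bit i set means node i is flat
        // in the plan's output, loosely mirroring SubgraphPlans::encodePlan.
        using Signature = std::bitset<8>;
        struct Plan {
            uint64_t cost;
            Signature flatNodes;
        };
        std::unordered_map<Signature, Plan> bestPerSignature;

        auto addPlan = [&](Plan plan) {
            // Plans with different signatures never replace each other, even when
            // one is cheaper, because a later step may need a specific node flat.
            auto it = bestPerSignature.find(plan.flatNodes);
            if (it == bestPerSignature.end() || plan.cost < it->second.cost) {
                bestPerSignature[plan.flatNodes] = plan;
            }
        };

        // Triangle example from the comment above: three level-3 plans for "abc"
        // that intersect on a, b or c respectively.
        addPlan(Plan{100, Signature{0b001}});
        addPlan(Plan{90, Signature{0b010}});
        addPlan(Plan{120, Signature{0b100}});

        // All three candidates survive, so the level-4 step that adds "d" can
        // pick whichever factorization avoids an extra flatten, instead of being
        // stuck with the single cheapest level-3 plan.
        std::printf("kept %zu candidate plans\n", bestPerSignature.size());
        return 0;
    }

SubgraphPlans layers the MAX_NUM_PLANS cap and the maxCost bound on top of this idea; the DPLevel class that follows then collects these per-subgraph plan sets for subgraphs with the same number of variables.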
+class DPLevel { +public: + bool contains(const binder::SubqueryGraph& subqueryGraph) const { + return subgraph2Plans.contains(subqueryGraph); + } + + const SubgraphPlans& getSubgraphPlans(const binder::SubqueryGraph& subqueryGraph) const { + return subgraph2Plans.at(subqueryGraph); + } + + std::vector getSubqueryGraphs(); + + void addPlan(const binder::SubqueryGraph& subqueryGraph, LogicalPlan plan); + + void clear() { subgraph2Plans.clear(); } + +private: + constexpr static uint32_t MAX_NUM_SUBGRAPH = 50; + +private: + binder::subquery_graph_V_map_t subgraph2Plans; +}; + +class SubPlansTable { +public: + void resize(uint32_t newSize); + + uint64_t getMaxCost(const binder::SubqueryGraph& subqueryGraph) const; + + bool containSubgraphPlans(const binder::SubqueryGraph& subqueryGraph) const; + + const std::vector& getSubgraphPlans( + const binder::SubqueryGraph& subqueryGraph) const; + + std::vector getSubqueryGraphs(uint32_t level); + + void addPlan(const binder::SubqueryGraph& subqueryGraph, LogicalPlan plan); + + void clear(); + +private: + const DPLevel& getDPLevel(const binder::SubqueryGraph& subqueryGraph) const { + return dpLevels[subqueryGraph.getTotalNumVariables()]; + } + DPLevel& getDPLevelUnsafe(const binder::SubqueryGraph& subqueryGraph) { + return dpLevels[subqueryGraph.getTotalNumVariables()]; + } + +private: + std::vector dpLevels; +}; + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/data_pos.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/data_pos.h new file mode 100644 index 0000000000..5b44d8bf25 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/data_pos.h @@ -0,0 +1,36 @@ +#pragma once + +#include + +#include "common/types/types.h" + +namespace lbug { +namespace processor { + +using data_chunk_pos_t = common::idx_t; +constexpr data_chunk_pos_t INVALID_DATA_CHUNK_POS = common::INVALID_IDX; +using value_vector_pos_t = common::idx_t; +constexpr value_vector_pos_t INVALID_VALUE_VECTOR_POS = common::INVALID_IDX; + +struct DataPos { + data_chunk_pos_t dataChunkPos; + value_vector_pos_t valueVectorPos; + + DataPos() : dataChunkPos{INVALID_DATA_CHUNK_POS}, valueVectorPos{INVALID_VALUE_VECTOR_POS} {} + explicit DataPos(data_chunk_pos_t dataChunkPos, value_vector_pos_t valueVectorPos) + : dataChunkPos{dataChunkPos}, valueVectorPos{valueVectorPos} {} + explicit DataPos(std::pair pos) + : dataChunkPos{pos.first}, valueVectorPos{pos.second} {} + + static DataPos getInvalidPos() { return DataPos(); } + bool isValid() const { + return dataChunkPos != INVALID_DATA_CHUNK_POS && valueVectorPos != INVALID_VALUE_VECTOR_POS; + } + + inline bool operator==(const DataPos& rhs) const { + return (dataChunkPos == rhs.dataChunkPos) && (valueVectorPos == rhs.valueVectorPos); + } +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/execution_context.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/execution_context.h new file mode 100644 index 0000000000..5197b8e26d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/execution_context.h @@ -0,0 +1,22 @@ +#pragma once + +#include "common/profiler.h" + +namespace lbug { +namespace main { +class ClientContext; +} +namespace processor { + +struct LBUG_API ExecutionContext { + uint64_t queryID; + common::Profiler* profiler; + main::ClientContext* clientContext; + + ExecutionContext(common::Profiler* profiler, main::ClientContext* clientContext, + uint64_t 
queryID) + : queryID{queryID}, profiler{profiler}, clientContext{clientContext} {} +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/expression_mapper.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/expression_mapper.h new file mode 100644 index 0000000000..1fec111536 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/expression_mapper.h @@ -0,0 +1,60 @@ +#pragma once + +#include "binder/expression/expression.h" +#include "expression_evaluator/expression_evaluator.h" +#include "processor/result/result_set_descriptor.h" + +namespace lbug { +namespace processor { + +class ExpressionMapper { +public: + ExpressionMapper() = default; + explicit ExpressionMapper(const planner::Schema* schema) : schema{schema} {} + ExpressionMapper(const planner::Schema* schema, evaluator::ExpressionEvaluator* parent) + : schema{schema}, parentEvaluator{parent} {} + + std::unique_ptr getEvaluator( + std::shared_ptr expression); + std::unique_ptr getConstantEvaluator( + std::shared_ptr expression); + +private: + static std::unique_ptr getLiteralEvaluator( + std::shared_ptr expression); + + static std::unique_ptr getParameterEvaluator( + std::shared_ptr expression); + + std::unique_ptr getReferenceEvaluator( + std::shared_ptr expression) const; + + static std::unique_ptr getLambdaParamEvaluator( + std::shared_ptr expression); + + std::unique_ptr getCaseEvaluator( + std::shared_ptr expression); + + std::unique_ptr getFunctionEvaluator( + std::shared_ptr expression); + + std::unique_ptr getNodeEvaluator( + std::shared_ptr expression); + + std::unique_ptr getRelEvaluator( + std::shared_ptr expression); + + std::unique_ptr getPathEvaluator( + std::shared_ptr expression); + + std::vector> getEvaluators( + const binder::expression_vector& expressions); + +private: + const planner::Schema* schema = nullptr; + // TODO: comment + evaluator::ExpressionEvaluator* parentEvaluator = nullptr; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/aggregate/aggregate_hash_table.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/aggregate/aggregate_hash_table.h new file mode 100644 index 0000000000..0d7b5a232e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/aggregate/aggregate_hash_table.h @@ -0,0 +1,342 @@ +#pragma once + +#include + +#include "aggregate_input.h" +#include "common/copy_constructors.h" +#include "common/data_chunk/data_chunk_state.h" +#include "common/null_mask.h" +#include "common/types/types.h" +#include "common/vector/value_vector.h" +#include "function/aggregate_function.h" +#include "processor/result/base_hash_table.h" +#include "processor/result/factorized_table.h" +#include "processor/result/factorized_table_schema.h" + +namespace lbug { +namespace common { +class InMemOverflowBuffer; +} +namespace storage { +class MemoryManager; +} +namespace processor { + +class HashSlot { + // upper 7 bits are for the fingerprint, the remaining 57 bits are for the pointer. 
+ // The largest pointer size seems to be 57 bytes for intel's 5-level paging + static constexpr size_t FINGERPRINT_BITS = 7; + static constexpr size_t POINTER_BITS = 57; + +public: + HashSlot(common::hash_t hash, const uint8_t* entry) + : entry(reinterpret_cast(entry) | + (hash & common::NULL_HIGH_MASKS[FINGERPRINT_BITS])) {} + + bool checkFingerprint(common::hash_t hash) const { + return (entry >> POINTER_BITS) == (hash >> POINTER_BITS); + } + + // pointer to the factorizedTable entry which stores [groupKey1, ... + // groupKeyN, aggregateState1, ..., aggregateStateN, hashValue]. + uint8_t* getEntry() const { + return reinterpret_cast(entry & common::NULL_LOWER_MASKS[POINTER_BITS]); + } + +private: + uint64_t entry; +}; + +enum class HashTableType : uint8_t { AGGREGATE_HASH_TABLE = 0, MARK_HASH_TABLE = 1 }; + +/** + * AggregateHashTable Design + * + * 1. Payload + * Entry layout: [groupKey1, ... groupKeyN, aggregateState1, ..., aggregateStateN, hashValue] + * Payload is stored in the factorizedTable. + * + * 2. Hash slot + * Layout : see HashSlot struct + * If the entry is a nullptr, then the current hashSlot is unused. + * + * 3. Collision handling + * Linear probing. When collision happens, we find the next hash slot whose entry is a + * nullptr. + * + */ +class AggregateHashTable; +using update_agg_function_t = std::function&, const std::vector&, + function::AggregateFunction&, common::ValueVector*, uint64_t, uint32_t, uint32_t)>; + +class AggregateHashTable : public BaseHashTable { +public: + AggregateHashTable(storage::MemoryManager& memoryManager, + const std::vector& keyTypes, + const std::vector& payloadTypes, uint64_t numEntriesToAllocate, + FactorizedTableSchema tableSchema) + : AggregateHashTable(memoryManager, common::LogicalType::copy(keyTypes), + common::LogicalType::copy(payloadTypes), + std::vector{} /* empty aggregates */, + std::vector{} /* empty distinct agg key*/, numEntriesToAllocate, + std::move(tableSchema)) {} + + AggregateHashTable(storage::MemoryManager& memoryManager, + std::vector keyTypes, std::vector payloadTypes, + const std::vector& aggregateFunctions, + const std::vector& distinctAggKeyTypes, uint64_t numEntriesToAllocate, + FactorizedTableSchema tableSchema); + + //! 
merge aggregate hash table by combining aggregate states under the same key + void merge(FactorizedTable&& other); + void merge(AggregateHashTable&& other) { merge(std::move(*other.factorizedTable)); } + // Must be called after merging hash tables with distinct functions, but only when the + // merged distinct tuples match the merged non-distinct tuples + void mergeDistinctAggregateInfo(); + + void finalizeAggregateStates(); + + void resize(uint64_t newSize); + void clear(); + void resizeHashTableIfNecessary(uint32_t maxNumDistinctHashKeys); + + AggregateHashTable createEmptyCopy() const { return AggregateHashTable(*this); } + + DEFAULT_BOTH_MOVE(AggregateHashTable); + AggregateHashTable* getDistinctHashTable(uint64_t aggregateFunctionIdx) const { + return distinctHashTables[aggregateFunctionIdx].get(); + } + + void appendDistinct(const std::vector& keyVectors, + common::ValueVector* aggregateVector, const common::DataChunkState* leadingState); + +protected: + virtual uint64_t append(const std::vector& keyVectors, + const common::DataChunkState* leadingState, + const std::vector& aggregateInputs, uint64_t resultSetMultiplicity) { + return append(keyVectors, std::vector{} /*dependentKeyVectors*/, + leadingState, aggregateInputs, resultSetMultiplicity); + } + + virtual uint64_t append(const std::vector& keyVectors, + const std::vector& dependentKeyVectors, + const common::DataChunkState* leadingState, + const std::vector& aggregateInputs, uint64_t resultSetMultiplicity); + + virtual uint64_t matchFTEntries(std::span keyVectors, + uint64_t numMayMatches, uint64_t numNoMatches); + + uint64_t matchFTEntries(const FactorizedTable& srcTable, uint64_t startOffset, + uint64_t numMayMatches, uint64_t numNoMatches); + + void initializeFTEntries(const std::vector& keyVectors, + const std::vector& dependentKeyVectors, + uint64_t numFTEntriesToInitialize); + void initializeFTEntries(const FactorizedTable& sourceTable, uint64_t sourceStartOffset, + uint64_t numFTEntriesToInitialize); + + uint64_t matchUnFlatVecWithFTColumn(const common::ValueVector* vector, uint64_t numMayMatches, + uint64_t& numNoMatches, uint32_t colIdx); + + uint64_t matchFlatVecWithFTColumn(const common::ValueVector* vector, uint64_t numMayMatches, + uint64_t& numNoMatches, uint32_t colIdx); + + void findHashSlots(const std::vector& keyVectors, + const std::vector& dependentKeyVectors, + const common::DataChunkState* leadingState); + + void findHashSlots(const FactorizedTable& data, uint64_t startOffset, uint64_t numTuples); + +protected: + void initializeFT(const std::vector& aggregateFunctions, + FactorizedTableSchema&& tableSchema); + + void initializeHashTable(uint64_t numEntriesToAllocate); + + void initializeTmpVectors(); + + // ! This function will only be used by distinct aggregate, which assumes that all + // groupByKeys are flat. 
+ uint8_t* findEntryInDistinctHT(const std::vector& groupByKeyVectors, + common::hash_t hash); + + void initializeFTEntryWithFlatVec(common::ValueVector* flatVector, + uint64_t numEntriesToInitialize, uint32_t colIdx); + + void initializeFTEntryWithUnFlatVec(common::ValueVector* unFlatVector, + uint64_t numEntriesToInitialize, uint32_t colIdx); + + uint8_t* createEntryInDistinctHT(const std::vector& groupByHashKeyVectors, + common::hash_t hash); + + void increaseSlotIdx(uint64_t& slotIdx) const; + + void initTmpHashSlotsAndIdxes(); + void initTmpHashSlotsAndIdxes(const FactorizedTable& sourceTable, uint64_t startOffset, + uint64_t numTuples); + + void increaseHashSlotIdxes(uint64_t numNoMatches); + + void updateAggState(const std::vector& keyVectors, + function::AggregateFunction& aggregateFunction, common::ValueVector* aggVector, + uint64_t multiplicity, uint32_t aggStateOffset, + const common::DataChunkState* firstUnFlatState); + + void updateAggStates(const std::vector& keyVectors, + const std::vector& aggregateInputs, uint64_t resultSetMultiplicity, + const common::DataChunkState* firstUnFlatState); + + void fillEntryWithInitialNullAggregateState(FactorizedTable& table, uint8_t* entry); + + //! find an uninitialized hash slot for given hash and fill hash slot with block id and + //! offset + void fillHashSlot(common::hash_t hash, uint8_t* groupByKeysAndAggregateStateBuffer); + + inline HashSlot* getHashSlot(uint64_t slotIdx) { + KU_ASSERT(slotIdx < maxNumHashSlots); + // If the slotIdx is smaller than the numHashSlotsPerBlock, then the hashSlot must be + // in the first hashSlotsBlock. We don't need to compute the blockIdx and blockOffset. + return slotIdx < ((uint64_t)1 << numSlotsPerBlockLog2) ? + (HashSlot*)(hashSlotsBlocks[0]->getData() + slotIdx * sizeof(HashSlot)) : + (HashSlot*)(hashSlotsBlocks[slotIdx >> numSlotsPerBlockLog2]->getData() + + (slotIdx & slotIdxInBlockMask) * sizeof(HashSlot)); + } + + void addDataBlocksIfNecessary(uint64_t maxNumHashSlots); + + void updateNullAggVectorState(const common::DataChunkState& keyState, + function::AggregateFunction& aggregateFunction, uint64_t multiplicity, + uint32_t aggStateOffset); + + void updateBothFlatAggVectorState(function::AggregateFunction& aggregateFunction, + common::ValueVector* aggVector, uint64_t multiplicity, uint32_t aggStateOffset); + + void updateFlatUnFlatKeyFlatAggVectorState(const common::DataChunkState& unFlatKeyState, + function::AggregateFunction& aggregateFunction, common::ValueVector* aggVector, + uint64_t multiplicity, uint32_t aggStateOffset); + + void updateFlatKeyUnFlatAggVectorState(const std::vector& flatKeyVectors, + function::AggregateFunction& aggregateFunction, common::ValueVector* aggVector, + uint64_t multiplicity, uint32_t aggStateOffset); + + void updateBothUnFlatSameDCAggVectorState(function::AggregateFunction& aggregateFunction, + common::ValueVector* aggVector, uint64_t multiplicity, uint32_t aggStateOffset); + + void updateBothUnFlatDifferentDCAggVectorState(const common::DataChunkState& unFlatKeyState, + function::AggregateFunction& aggregateFunction, common::ValueVector* aggVector, + uint64_t multiplicity, uint32_t aggStateOffset); + + static std::vector getDistinctAggKeyTypes( + const AggregateHashTable& hashTable) { + std::vector distinctAggKeyTypes(hashTable.distinctHashTables.size()); + std::transform(hashTable.distinctHashTables.begin(), hashTable.distinctHashTables.end(), + distinctAggKeyTypes.begin(), [&](const auto& distinctHashTable) { + if (distinctHashTable) { + return 
distinctHashTable->keyTypes.back().copy(); + } else { + return common::LogicalType(); + } + }); + return distinctAggKeyTypes; + } + + template + uint8_t* findEntry(common::hash_t hash, Func compareKeys) { + auto slotIdx = getSlotIdxForHash(hash); + while (true) { + auto slot = (HashSlot*)getHashSlot(slotIdx); + if (slot->getEntry() == nullptr) { + return nullptr; + } else if (slot->checkFingerprint(hash) && compareKeys(slot->getEntry())) { + return slot->getEntry(); + } + increaseSlotIdx(slotIdx); + } + } + +private: + // Does not copy the contents of the hash table and is provided as a convenient way of + // constructing more hash tables without having to hold on to or expose the construction + // arguments via createEmptyCopy + AggregateHashTable(const AggregateHashTable& other) + : AggregateHashTable(*other.memoryManager, common::LogicalType::copy(other.keyTypes), + common::LogicalType::copy(other.payloadTypes), other.aggregateFunctions, + getDistinctAggKeyTypes(other), 0, other.getTableSchema()->copy()) {} + +protected: + uint32_t hashColIdxInFT{}; + std::unique_ptr mayMatchIdxes; + std::unique_ptr noMatchIdxes; + std::unique_ptr entryIdxesToInitialize; + std::unique_ptr hashSlotsToUpdateAggState; + + std::vector payloadTypes; + std::vector aggregateFunctions; + + //! special handling of distinct aggregate + std::vector> distinctHashTables; + std::vector distinctHashEntriesProcessed; + uint32_t hashColOffsetInFT{}; + uint32_t aggStateColOffsetInFT{}; + uint32_t aggStateColIdxInFT{}; + uint32_t numBytesForKeys = 0; + uint32_t numBytesForDependentKeys = 0; + std::vector updateAggFuncs; + // Temporary arrays to hold intermediate results. + std::unique_ptr tmpValueIdxes; + std::unique_ptr tmpSlotIdxes; +}; + +struct AggregateHashTableUtils { + static std::unique_ptr createDistinctHashTable( + storage::MemoryManager& memoryManager, + const std::vector& groupByKeyTypes, + const common::LogicalType& distinctKeyType); + + static FactorizedTableSchema getTableSchemaForKeys( + const std::vector& groupByKeyTypes, + const common::LogicalType& distinctKeyType); +}; + +// Separate class since the SimpleAggregate has multiple different top-level destinations for +// partitioning +class AggregatePartitioningData { +public: + virtual ~AggregatePartitioningData() = default; + virtual void appendTuples(const FactorizedTable& table, ft_col_offset_t hashOffset) = 0; + virtual void appendDistinctTuple(size_t /*distinctFuncIndex*/, std::span /*tuple*/, + common::hash_t /*hash*/) = 0; + virtual void appendOverflow(common::InMemOverflowBuffer&& overflowBuffer) = 0; +}; + +// Fixed-sized Aggregate hash table that flushes tuples into partitions in the +// HashAggregateSharedState when full +class PartitioningAggregateHashTable final : public AggregateHashTable { +public: + PartitioningAggregateHashTable(AggregatePartitioningData* partitioningData, + storage::MemoryManager& memoryManager, std::vector keyTypes, + std::vector payloadTypes, + const std::vector& aggregateFunctions, + const std::vector& distinctAggKeyTypes, + FactorizedTableSchema tableSchema) + : AggregateHashTable(memoryManager, std::move(keyTypes), std::move(payloadTypes), + aggregateFunctions, distinctAggKeyTypes, + common::DEFAULT_VECTOR_CAPACITY /*minimum size*/, tableSchema.copy()), + tableSchema{std::move(tableSchema)}, partitioningData{partitioningData} {} + + uint64_t append(const std::vector& keyVectors, + const std::vector& dependentKeyVectors, + const common::DataChunkState* leadingState, + const std::vector& aggregateInputs, + uint64_t 
resultSetMultiplicity) override; + + void mergeIfFull(uint64_t tuplesToAdd, bool mergeAll = false); + +private: + FactorizedTableSchema tableSchema; + AggregatePartitioningData* partitioningData; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/aggregate/aggregate_input.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/aggregate/aggregate_input.h new file mode 100644 index 0000000000..f2d1db3ff4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/aggregate/aggregate_input.h @@ -0,0 +1,39 @@ +#pragma once + +#include "common/data_chunk/data_chunk.h" +#include "processor/data_pos.h" + +namespace lbug { +namespace processor { + +struct AggregateInfo { + DataPos aggVectorPos; + std::vector multiplicityChunksPos; + common::LogicalType distinctAggKeyType; + + AggregateInfo(const DataPos& aggVectorPos, std::vector multiplicityChunksPos, + common::LogicalType distinctAggKeyType) + : aggVectorPos{aggVectorPos}, multiplicityChunksPos{std::move(multiplicityChunksPos)}, + distinctAggKeyType{std::move(distinctAggKeyType)} {} + EXPLICIT_COPY_DEFAULT_MOVE(AggregateInfo); + +private: + AggregateInfo(const AggregateInfo& other) + : aggVectorPos{other.aggVectorPos}, multiplicityChunksPos{other.multiplicityChunksPos}, + distinctAggKeyType{other.distinctAggKeyType.copy()} {} +}; + +struct AggregateInput { + common::ValueVector* aggregateVector; + std::vector multiplicityChunks; + + AggregateInput() : aggregateVector{nullptr} {} + EXPLICIT_COPY_DEFAULT_MOVE(AggregateInput); + +private: + AggregateInput(const AggregateInput& other) + : aggregateVector{other.aggregateVector}, multiplicityChunks{other.multiplicityChunks} {} +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/aggregate/base_aggregate.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/aggregate/base_aggregate.h new file mode 100644 index 0000000000..d64c58a991 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/aggregate/base_aggregate.h @@ -0,0 +1,147 @@ +#pragma once + +#include + +#include "aggregate_input.h" +#include "common/mpsc_queue.h" +#include "function/aggregate_function.h" +#include "processor/operator/sink.h" +#include "processor/result/factorized_table.h" +#include "processor/result/factorized_table_schema.h" + +namespace lbug { +namespace main { +class ClientContext; +} +namespace processor { +class AggregateHashTable; + +size_t getNumPartitionsForParallelism(main::ClientContext* context); + +class BaseAggregateSharedState { + friend class BaseAggregate; + +public: + template + void finalizePartitions(std::vector& globalPartitions, Func finalizeFunc) { + for (auto& partition : globalPartitions) { + if (!partition.finalized && partition.mtx.try_lock()) { + if (partition.finalized) { + // If there was a data race in the above && a thread may get through after + // another thread has finalized this partition Ignore coverage since we can't + // reliably test this data race + // LCOV_EXCL_START + partition.mtx.unlock(); + continue; + // LCOV_EXCL_END + } + finalizeFunc(partition); + partition.finalized = true; + partition.mtx.unlock(); + } + } + } + + bool isReadyForFinalization() const { return readyForFinalization; } + +protected: + explicit BaseAggregateSharedState( + const std::vector& aggregateFunctions, size_t numPartitions); + + virtual std::pair getNextRangeToRead() = 0; + + 
~BaseAggregateSharedState() = default; + + void finalizeAggregateHashTable(const AggregateHashTable& localHashTable); + + class HashTableQueue { + public: + HashTableQueue(storage::MemoryManager* memoryManager, FactorizedTableSchema tableSchema); + + std::unique_ptr copy() const { + return std::make_unique(headBlock.load()->table.getMemoryManager(), + headBlock.load()->table.getTableSchema()->copy()); + } + ~HashTableQueue(); + + void appendTuple(std::span tuple); + + void mergeInto(AggregateHashTable& hashTable); + + bool empty() const { + auto headBlock = this->headBlock.load(); + return (headBlock == nullptr || headBlock->numTuplesReserved == 0) && + queuedTuples.approxSize() == 0; + } + + struct TupleBlock { + TupleBlock(storage::MemoryManager* memoryManager, FactorizedTableSchema tableSchema) + : numTuplesReserved{0}, numTuplesWritten{0}, + table{memoryManager, std::move(tableSchema)} { + // Start at a fixed capacity of one full block (so that concurrent writes are safe). + // If it is not filled, we resize it to the actual capacity before writing it to the + // hashTable + table.resize(table.getNumTuplesPerBlock()); + } + // numTuplesReserved may be greater than the capacity of the factorizedTable + // if threads try to write to it while a new block is being allocated + // So it should not be relied on for anything other than reserving tuples + std::atomic numTuplesReserved; + // Set after the tuple has been written to the block. + // Once numTuplesWritten == factorizedTable.getNumTuplesPerBlock() all writes have + // finished + std::atomic numTuplesWritten; + FactorizedTable table; + }; + common::MPSCQueue queuedTuples; + // When queueing tuples, they are always added to the headBlock until the headBlock is full + // (numTuplesReserved >= factorizedTable.getNumTuplesPerBlock()), then pushed into the + // queuedTuples (at which point, the numTuplesReserved may not be equal to the + // numTuplesWritten) + std::atomic headBlock; + uint64_t numTuplesPerBlock; + }; + +protected: + std::mutex mtx; + std::atomic currentOffset; + std::vector aggregateFunctions; + std::atomic numThreadsFinishedProducing; + std::atomic numThreads; + common::MPSCQueue> overflow; + uint8_t shiftForPartitioning; + bool readyForFinalization; +}; + +class BaseAggregate : public Sink { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::AGGREGATE; + +protected: + BaseAggregate(std::shared_ptr sharedState, + std::vector aggregateFunctions, + std::vector aggInfos, std::unique_ptr child, uint32_t id, + std::unique_ptr printInfo) + : Sink{type_, std::move(child), id, std::move(printInfo)}, + aggregateFunctions{std::move(aggregateFunctions)}, aggInfos{std::move(aggInfos)}, + sharedState{std::move(sharedState)} {} + + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; + + bool containDistinctAggregate() const; + + void finalizeInternal(ExecutionContext* /*context*/) override { + // Delegated to HashAggregateFinalize so it can be parallelized + sharedState->readyForFinalization = true; + } + + std::unique_ptr copy() override = 0; + +protected: + std::vector aggregateFunctions; + std::vector aggInfos; + std::vector aggInputs; + std::shared_ptr sharedState; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/aggregate/base_aggregate_scan.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/aggregate/base_aggregate_scan.h new file mode 100644 index 0000000000..e3edbbf63c --- 
/dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/aggregate/base_aggregate_scan.h @@ -0,0 +1,44 @@ +#pragma once + +#include "function/aggregate_function.h" +#include "processor/operator/physical_operator.h" + +namespace lbug { +namespace processor { + +using move_agg_result_to_vector_func = std::function; + +struct AggregateScanInfo { + std::vector aggregatesPos; + std::vector moveAggResultToVectorFuncs; +}; + +class BaseAggregateScan : public PhysicalOperator { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::AGGREGATE_SCAN; + +public: + BaseAggregateScan(AggregateScanInfo scanInfo, std::unique_ptr child, + uint32_t id, std::unique_ptr printInfo) + : PhysicalOperator{type_, std::move(child), id, std::move(printInfo)}, + scanInfo{std::move(scanInfo)} {} + + BaseAggregateScan(AggregateScanInfo scanInfo, physical_op_id id, + std::unique_ptr printInfo) + : PhysicalOperator{type_, id, std::move(printInfo)}, scanInfo{std::move(scanInfo)} {} + + bool isSource() const override { return true; } + + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; + + bool getNextTuplesInternal(ExecutionContext* context) override = 0; + + std::unique_ptr copy() override = 0; + +protected: + AggregateScanInfo scanInfo; + std::vector> aggregateVectors; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/aggregate/hash_aggregate.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/aggregate/hash_aggregate.h new file mode 100644 index 0000000000..a5204c2180 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/aggregate/hash_aggregate.h @@ -0,0 +1,205 @@ +#pragma once + +#include + +#include +#include +#include + +#include "aggregate_hash_table.h" +#include "common/cast.h" +#include "common/copy_constructors.h" +#include "common/data_chunk/data_chunk_state.h" +#include "common/in_mem_overflow_buffer.h" +#include "common/mpsc_queue.h" +#include "common/types/types.h" +#include "common/vector/value_vector.h" +#include "processor/operator/aggregate/aggregate_input.h" +#include "processor/operator/aggregate/base_aggregate.h" +#include "processor/operator/physical_operator.h" +#include "processor/result/factorized_table.h" +#include "processor/result/factorized_table_schema.h" + +namespace lbug { +namespace processor { + +struct HashAggregateInfo { + std::vector flatKeysPos; + std::vector unFlatKeysPos; + std::vector dependentKeysPos; + FactorizedTableSchema tableSchema; + + HashAggregateInfo(std::vector flatKeysPos, std::vector unFlatKeysPos, + std::vector dependentKeysPos, FactorizedTableSchema tableSchema); + EXPLICIT_COPY_DEFAULT_MOVE(HashAggregateInfo); + +private: + HashAggregateInfo(const HashAggregateInfo& other); +}; + +// NOLINTNEXTLINE(cppcoreguidelines-virtual-class-destructor): This is a final class. 
+class HashAggregateSharedState final : public BaseAggregateSharedState, + public AggregatePartitioningData { + +public: + explicit HashAggregateSharedState(main::ClientContext* context, HashAggregateInfo hashAggInfo, + const std::vector& aggregateFunctions, + std::span aggregateInfos, std::vector keyTypes, + std::vector payloadTypes); + + void appendTuples(const FactorizedTable& factorizedTable, ft_col_offset_t hashOffset) override { + auto numBytesPerTuple = factorizedTable.getTableSchema()->getNumBytesPerTuple(); + for (ft_tuple_idx_t tupleIdx = 0; tupleIdx < factorizedTable.getNumTuples(); tupleIdx++) { + auto tuple = factorizedTable.getTuple(tupleIdx); + auto hash = *reinterpret_cast(tuple + hashOffset); + auto& partition = + globalPartitions[(hash >> shiftForPartitioning) % globalPartitions.size()]; + partition.queue->appendTuple(std::span(tuple, numBytesPerTuple)); + } + } + + void appendDistinctTuple(size_t distinctFuncIndex, std::span tuple, + common::hash_t hash) override { + auto& partition = + globalPartitions[(hash >> shiftForPartitioning) % globalPartitions.size()]; + partition.distinctTableQueues[distinctFuncIndex]->appendTuple(tuple); + } + + void appendOverflow(common::InMemOverflowBuffer&& overflowBuffer) override { + overflow.push(std::make_unique(std::move(overflowBuffer))); + } + + void finalizePartitions(); + + std::pair getNextRangeToRead() override; + + void scan(std::span entries, std::vector& keyVectors, + common::offset_t startOffset, common::offset_t numRowsToScan, + std::vector& columnIndices); + + uint64_t getNumTuples() const; + + uint64_t getCurrentOffset() const { return currentOffset; } + + void setLimitNumber(uint64_t num) { limitNumber = num; } + uint64_t getLimitNumber() const { return limitNumber; } + + const FactorizedTableSchema* getTableSchema() const { + return globalPartitions[0].hashTable->getTableSchema(); + } + + const HashAggregateInfo& getAggregateInfo() const { return aggInfo; } + + void assertFinalized() const; + +protected: + std::tuple getPartitionForOffset( + common::offset_t offset) const; + + struct Partition { + std::unique_ptr hashTable; + std::mutex mtx; + std::unique_ptr queue; + // The tables storing the distinct values for distinct aggregate functions all get merged in + // the same way as the main table + std::vector> distinctTableQueues; + std::atomic finalized = false; + }; + +public: + HashAggregateInfo aggInfo; + uint64_t limitNumber; + storage::MemoryManager* memoryManager; + std::vector globalPartitions; +}; + +struct HashAggregateLocalState { + std::vector keyVectors; + std::vector dependentKeyVectors; + common::DataChunkState* leadingState = nullptr; + std::unique_ptr aggregateHashTable; + + void init(HashAggregateSharedState* sharedState, ResultSet& resultSet, + main::ClientContext* context, std::vector& aggregateFunctions, + std::vector types); + uint64_t append(const std::vector& aggregateInputs, + uint64_t multiplicity) const; +}; + +struct HashAggregatePrintInfo final : OPPrintInfo { + binder::expression_vector keys; + binder::expression_vector aggregates; + uint64_t limitNum; + + HashAggregatePrintInfo(binder::expression_vector keys, binder::expression_vector aggregates) + : keys{std::move(keys)}, aggregates{std::move(aggregates)}, limitNum{UINT64_MAX} {} + + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::unique_ptr(new HashAggregatePrintInfo(*this)); + } + +private: + HashAggregatePrintInfo(const HashAggregatePrintInfo& other) + : OPPrintInfo{other}, 
keys{other.keys}, aggregates{other.aggregates}, + limitNum{other.limitNum} {} +}; + +class HashAggregate final : public BaseAggregate { +public: + HashAggregate(std::shared_ptr sharedState, + std::vector aggregateFunctions, + std::vector aggInfos, std::unique_ptr child, uint32_t id, + std::unique_ptr printInfo) + : BaseAggregate{std::move(sharedState), std::move(aggregateFunctions), std::move(aggInfos), + std::move(child), id, std::move(printInfo)} {} + + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; + + void executeInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return make_unique(sharedState, copyVector(aggregateFunctions), + copyVector(aggInfos), children[0]->copy(), id, printInfo->copy()); + } + + const HashAggregateSharedState& getSharedStateReference() const { + return common::ku_dynamic_cast(*sharedState); + } + std::shared_ptr getSharedState() const { + return std::reinterpret_pointer_cast(sharedState); + } + +private: + HashAggregateLocalState localState; +}; + +class HashAggregateFinalize final : public Sink { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::AGGREGATE_FINALIZE; + +public: + HashAggregateFinalize(std::shared_ptr sharedState, physical_op_id id, + std::unique_ptr printInfo) + : Sink{type_, id, std::move(printInfo)}, sharedState{std::move(sharedState)} {} + + bool isSource() const override { return true; } + + void executeInternal(ExecutionContext* /*context*/) override { + KU_ASSERT(sharedState->isReadyForFinalization()); + sharedState->finalizePartitions(); + } + void finalizeInternal(ExecutionContext* /*context*/) override { + sharedState->assertFinalized(); + } + + std::unique_ptr copy() override { + return make_unique(sharedState, id, printInfo->copy()); + } + +private: + std::shared_ptr sharedState; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/aggregate/hash_aggregate_scan.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/aggregate/hash_aggregate_scan.h new file mode 100644 index 0000000000..140d0c9acc --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/aggregate/hash_aggregate_scan.h @@ -0,0 +1,40 @@ +#pragma once + +#include "processor/operator/aggregate/base_aggregate_scan.h" +#include "processor/operator/aggregate/hash_aggregate.h" + +namespace lbug { +namespace processor { + +class HashAggregateScan final : public BaseAggregateScan { +public: + HashAggregateScan(std::shared_ptr sharedState, + std::vector groupByKeyVectorsPos, AggregateScanInfo scanInfo, uint32_t id, + std::unique_ptr printInfo) + : BaseAggregateScan{std::move(scanInfo), id, std::move(printInfo)}, + groupByKeyVectorsPos{std::move(groupByKeyVectorsPos)}, + sharedState{std::move(sharedState)} {} + + std::shared_ptr getSharedState() const { return sharedState; } + + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; + + bool getNextTuplesInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return std::make_unique(sharedState, groupByKeyVectorsPos, scanInfo, id, + printInfo->copy()); + } + + double getProgress(ExecutionContext* context) const override; + +private: + std::vector groupByKeyVectorsPos; + std::vector groupByKeyVectors; + std::shared_ptr sharedState; + std::vector groupByKeyVectorsColIdxes; + std::vector entries; +}; + +} // namespace processor +} // namespace lbug diff --git 
a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/aggregate/simple_aggregate.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/aggregate/simple_aggregate.h new file mode 100644 index 0000000000..a07cc55a0b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/aggregate/simple_aggregate.h @@ -0,0 +1,156 @@ +#pragma once + +#include + +#include "common/cast.h" +#include "common/copy_constructors.h" +#include "common/in_mem_overflow_buffer.h" +#include "processor/operator/aggregate/aggregate_hash_table.h" +#include "processor/operator/aggregate/base_aggregate.h" + +namespace lbug { +namespace processor { + +// NOLINTNEXTLINE(cppcoreguidelines-virtual-class-destructor): This is a final class. +class SimpleAggregateSharedState final : public BaseAggregateSharedState { + friend class SimpleAggregate; + +public: + explicit SimpleAggregateSharedState(main::ClientContext* clientContext, + const std::vector& aggregateFunctions, + const std::vector& aggInfos); + + // The partitioningData objects need a stable pointer to this shared state + DELETE_COPY_AND_MOVE(SimpleAggregateSharedState); + + void combineAggregateStates( + const std::vector>& localAggregateStates, + common::InMemOverflowBuffer&& localOverflowBuffer); + + void finalizeAggregateStates(); + + std::pair getNextRangeToRead() override; + + function::AggregateState* getAggregateState(uint64_t idx) { + return globalAggregateStates[idx].get(); + } + + // Merges data from the queues into the distinct hash tables + // Can be run concurrently (but only after all data has been written into the queues) + void finalizePartitions(storage::MemoryManager* memoryManager, + const std::vector& aggInfos); + + bool isReadyForFinalization() const { return readyForFinalization; } + +protected: + struct Partition { + struct DistinctData { + std::unique_ptr hashTable; + std::unique_ptr queue; + std::unique_ptr state; + }; + std::mutex mtx; + std::vector distinctTables; + std::atomic finalized = false; + }; + + class SimpleAggregatePartitioningData : public AggregatePartitioningData { + public: + SimpleAggregatePartitioningData(SimpleAggregateSharedState* sharedState, size_t functionIdx) + : sharedState{sharedState}, functionIdx{functionIdx} {} + + void appendTuples(const FactorizedTable& factorizedTable, + ft_col_offset_t hashOffset) override; + void appendDistinctTuple(size_t, std::span, common::hash_t) override; + void appendOverflow(common::InMemOverflowBuffer&& overflowBuffer) override; + + private: + SimpleAggregateSharedState* sharedState; + size_t functionIdx; + }; + +private: + bool hasDistinct; + std::vector globalPartitions; + std::vector partitioningData; + common::InMemOverflowBuffer aggregateOverflowBuffer; + std::vector> globalAggregateStates; +}; + +struct SimpleAggregatePrintInfo final : OPPrintInfo { + binder::expression_vector aggregates; + + explicit SimpleAggregatePrintInfo(binder::expression_vector aggregates) + : aggregates{std::move(aggregates)} {} + + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::unique_ptr(new SimpleAggregatePrintInfo(*this)); + } + +private: + SimpleAggregatePrintInfo(const SimpleAggregatePrintInfo& other) + : OPPrintInfo{other}, aggregates{other.aggregates} {} +}; + +class SimpleAggregate final : public BaseAggregate { +public: + SimpleAggregate(std::shared_ptr sharedState, + std::vector aggregateFunctions, + std::vector aggInfos, std::unique_ptr child, uint32_t id, + std::unique_ptr printInfo) + : 
BaseAggregate{std::move(sharedState), std::move(aggregateFunctions), std::move(aggInfos), + std::move(child), id, std::move(printInfo)} {} + + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; + + void executeInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return make_unique(sharedState, copyVector(aggregateFunctions), + copyVector(aggInfos), children[0]->copy(), id, printInfo->copy()); + } + +private: + void computeAggregate(function::AggregateFunction* function, AggregateInput* input, + function::AggregateState* state, common::InMemOverflowBuffer& overflowBuffer); + + SimpleAggregateSharedState& getSharedState() { + return common::ku_dynamic_cast(*sharedState.get()); + } + +private: + std::vector> localAggregateStates; + std::vector> distinctHashTables; +}; + +class SimpleAggregateFinalize final : public Sink { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::AGGREGATE_FINALIZE; + +public: + SimpleAggregateFinalize(std::shared_ptr sharedState, + std::vector aggInfos, physical_op_id id, + std::unique_ptr printInfo) + : Sink{type_, id, std::move(printInfo)}, sharedState{std::move(sharedState)}, + aggInfos{std::move(aggInfos)} {} + + bool isSource() const override { return true; } + + void executeInternal(ExecutionContext* context) override; + + void finalizeInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return std::make_unique(sharedState, copyVector(aggInfos), id, + printInfo->copy()); + } + +private: + std::shared_ptr sharedState; + std::vector aggInfos; + std::vector> globalAggregateStates; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/aggregate/simple_aggregate_scan.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/aggregate/simple_aggregate_scan.h new file mode 100644 index 0000000000..e0ca717fbb --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/aggregate/simple_aggregate_scan.h @@ -0,0 +1,30 @@ +#pragma once + +#include "processor/operator/aggregate/base_aggregate_scan.h" +#include "processor/operator/aggregate/simple_aggregate.h" + +namespace lbug { +namespace processor { + +class SimpleAggregateScan final : public BaseAggregateScan { +public: + SimpleAggregateScan(std::shared_ptr sharedState, + AggregateScanInfo scanInfo, uint32_t id, std::unique_ptr printInfo) + : BaseAggregateScan{std::move(scanInfo), id, std::move(printInfo)}, + sharedState{std::move(sharedState)}, outDataChunk{nullptr} {} + + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; + + bool getNextTuplesInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return make_unique(sharedState, scanInfo, id, printInfo->copy()); + } + +private: + std::shared_ptr sharedState; + common::DataChunk* outDataChunk; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/arrow_result_collector.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/arrow_result_collector.h new file mode 100644 index 0000000000..83482c5539 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/arrow_result_collector.h @@ -0,0 +1,87 @@ +#pragma once + +#include + +#include "common/arrow/arrow.h" +#include "processor/operator/sink.h" +#include "processor/result/flat_tuple.h" + +namespace lbug { +namespace 
processor { + +class ArrowResultCollectorSharedState { +public: + std::vector arrays; + + void merge(const std::vector& localArrays); + +private: + std::mutex mutex; +}; + +struct ArrowResultCollectorLocalState { + std::vector arrays; + std::vector vectors; + std::vector> vectorsSelPos; + std::vector chunks; + std::vector chunkCursors; + std::unique_ptr tuple; + + // Advance cursor. + bool advance(); + // Scan from vector to tuple based on cursor. + void fillTuple(); + + void resetCursor(); +}; + +struct ArrowResultCollectorInfo { + int64_t chunkSize; + std::vector payloadPositions; + std::vector columnTypes; + + ArrowResultCollectorInfo(int64_t chunkSize, std::vector payloadPositions, + std::vector columnTypes) + : chunkSize{chunkSize}, payloadPositions{std::move(payloadPositions)}, + columnTypes{std::move(columnTypes)} {} + EXPLICIT_COPY_DEFAULT_MOVE(ArrowResultCollectorInfo); + +private: + ArrowResultCollectorInfo(const ArrowResultCollectorInfo& other) + : chunkSize{other.chunkSize}, payloadPositions{other.payloadPositions}, + columnTypes{copyVector(other.columnTypes)} {} +}; + +class ArrowResultCollector final : public Sink { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::RESULT_COLLECTOR; + +public: + ArrowResultCollector(std::shared_ptr sharedState, + ArrowResultCollectorInfo info, std::unique_ptr child, physical_op_id id, + std::unique_ptr printInfo) + : Sink{type_, std::move(child), id, std::move(printInfo)}, + sharedState{std::move(sharedState)}, info{std::move(info)} {} + + std::unique_ptr getQueryResult() const override; + + void executeInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return std::make_unique(sharedState, info.copy(), children[0]->copy(), + id, printInfo->copy()); + } + +private: + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext*) override; + + void iterateResultSet(common::ArrowRowBatch* inputBatch); + bool fillRowBatch(common::ArrowRowBatch& rowBatch); + +private: + std::shared_ptr sharedState; + ArrowResultCollectorInfo info; + ArrowResultCollectorLocalState localState; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/base_partitioner_shared_state.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/base_partitioner_shared_state.h new file mode 100644 index 0000000000..d63837f60a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/base_partitioner_shared_state.h @@ -0,0 +1,61 @@ +#pragma once + +#include +#include + +#include "common/api.h" +#include "common/types/types.h" + +namespace lbug { +namespace storage { +class NodeTable; +class RelTable; +} // namespace storage +namespace main { +class ClientContext; +} +namespace processor { + +struct LBUG_API PartitionerSharedState { + storage::NodeTable* srcNodeTable; + storage::NodeTable* dstNodeTable; + storage::RelTable* relTable; + + static constexpr size_t DIRECTIONS = 2; + std::array numNodes; + std::array + numPartitions; // num of partitions in each direction. 
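    // Editorial note, not part of the upstream patch: the two slots in numNodes/numPartitions
    // correspond to the two rel directions (DIRECTIONS == 2 above, presumably forward and
    // backward). nextPartitionIdx below is an atomic cursor that getNextPartition() advances,
    // so partitioner threads can claim partitions without taking a lock; resetState() appears
    // to rewind that state between partitioning passes.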
+ std::atomic nextPartitionIdx; + + PartitionerSharedState() + : srcNodeTable{nullptr}, dstNodeTable{nullptr}, relTable(nullptr), numNodes{0, 0}, + numPartitions{0, 0}, nextPartitionIdx{0} {} + virtual ~PartitionerSharedState() = default; + + template + TARGET& cast() { + return common::ku_dynamic_cast(*this); + } + template + const TARGET& constCast() const { + return common::ku_dynamic_cast(*this); + } + + virtual void initialize(const common::logical_type_vec_t& columnTypes, + common::idx_t numPartitioners, const main::ClientContext* clientContext); + + common::partition_idx_t getNextPartition(common::idx_t partitioningIdx); + + common::partition_idx_t getNumPartitions(common::idx_t partitioningIdx) const { + return numPartitions[partitioningIdx]; + } + common::offset_t getNumNodes(common::idx_t partitioningIdx) const { + return numNodes[partitioningIdx]; + } + + virtual void resetState(common::idx_t partitioningIdx); + + static common::partition_idx_t getNumPartitionsFromRows(common::offset_t numRows); +}; +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/cross_product.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/cross_product.h new file mode 100644 index 0000000000..2c82af1555 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/cross_product.h @@ -0,0 +1,64 @@ +#pragma once + +#include "processor/operator/physical_operator.h" +#include "processor/result/factorized_table.h" + +namespace lbug { +namespace processor { + +struct CrossProductLocalState { + std::shared_ptr table; + uint64_t maxMorselSize; + uint64_t startIdx = 0u; + + CrossProductLocalState(std::shared_ptr table, uint64_t maxMorselSize) + : table{std::move(table)}, maxMorselSize{maxMorselSize}, startIdx{0} {} + EXPLICIT_COPY_DEFAULT_MOVE(CrossProductLocalState); + + void init() { startIdx = table->getNumTuples(); } + +private: + CrossProductLocalState(const CrossProductLocalState& other) + : table{other.table}, maxMorselSize{other.maxMorselSize}, startIdx{other.startIdx} {} +}; + +struct CrossProductInfo { + std::vector outVecPos; + std::vector colIndicesToScan; + + CrossProductInfo(std::vector outVecPos, std::vector colIndicesToScan) + : outVecPos{std::move(outVecPos)}, colIndicesToScan{std::move(colIndicesToScan)} {} + EXPLICIT_COPY_DEFAULT_MOVE(CrossProductInfo); + +private: + CrossProductInfo(const CrossProductInfo& other) + : outVecPos{other.outVecPos}, colIndicesToScan{other.colIndicesToScan} {} +}; + +class CrossProduct final : public PhysicalOperator { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::CROSS_PRODUCT; + +public: + CrossProduct(CrossProductInfo info, CrossProductLocalState localState, + std::unique_ptr child, physical_op_id id, + std::unique_ptr printInfo) + : PhysicalOperator{type_, std::move(child), id, std::move(printInfo)}, + info{std::move(info)}, localState{std::move(localState)} {} + + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; + + bool getNextTuplesInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return std::make_unique(info.copy(), localState.copy(), children[0]->copy(), + id, printInfo->copy()); + } + +private: + CrossProductInfo info; + CrossProductLocalState localState; + std::vector vectorsToScan; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/ddl/alter.h 
b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/ddl/alter.h new file mode 100644 index 0000000000..f01a6b3ca5 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/ddl/alter.h @@ -0,0 +1,65 @@ +#pragma once + +#include "binder/ddl/bound_alter_info.h" +#include "expression_evaluator/expression_evaluator.h" +#include "processor/operator/sink.h" + +namespace lbug { +namespace catalog { +class TableCatalogEntry; +class RelGroupCatalogEntry; +} // namespace catalog + +namespace processor { + +struct AlterPrintInfo final : OPPrintInfo { + common::AlterType alterType; + std::string tableName; + binder::BoundAlterInfo info; + + AlterPrintInfo(common::AlterType alterType, std::string tableName, binder::BoundAlterInfo info) + : alterType{alterType}, tableName{std::move(tableName)}, info{std::move(info)} {} + + std::string toString() const override { return info.toString(); } + + std::unique_ptr copy() const override { + return std::unique_ptr(new AlterPrintInfo(*this)); + } + +private: + AlterPrintInfo(const AlterPrintInfo& other) + : OPPrintInfo{other}, alterType{other.alterType}, tableName{other.tableName}, + info{other.info.copy()} {} +}; + +class Alter final : public SimpleSink { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::ALTER; + +public: + Alter(binder::BoundAlterInfo info, + std::unique_ptr defaultValueEvaluator, + std::shared_ptr messageTable, physical_op_id id, + std::unique_ptr printInfo) + : SimpleSink{type_, std::move(messageTable), id, std::move(printInfo)}, + info{std::move(info)}, defaultValueEvaluator{std::move(defaultValueEvaluator)} {} + + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; + + void executeInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return std::make_unique(info.copy(), + defaultValueEvaluator == nullptr ? 
nullptr : defaultValueEvaluator->copy(), + messageTable, id, printInfo->copy()); + } + +private: + void alterTable(main::ClientContext* clientContext, const catalog::TableCatalogEntry& entry, + const binder::BoundAlterInfo& alterInfo); + + binder::BoundAlterInfo info; + std::unique_ptr defaultValueEvaluator; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/ddl/create_sequence.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/ddl/create_sequence.h new file mode 100644 index 0000000000..fcd8d47810 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/ddl/create_sequence.h @@ -0,0 +1,46 @@ +#pragma once + +#include "binder/ddl/bound_create_sequence_info.h" +#include "processor/operator/sink.h" + +namespace lbug { +namespace processor { + +struct CreateSequencePrintInfo final : OPPrintInfo { + std::string seqName; + + explicit CreateSequencePrintInfo(std::string seqName) : seqName{std::move(seqName)} {} + + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::unique_ptr(new CreateSequencePrintInfo(*this)); + } + +private: + CreateSequencePrintInfo(const CreateSequencePrintInfo& other) + : OPPrintInfo{other}, seqName{other.seqName} {} +}; + +class CreateSequence final : public SimpleSink { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::CREATE_SEQUENCE; + +public: + CreateSequence(binder::BoundCreateSequenceInfo info, + std::shared_ptr messageTable, physical_op_id id, + std::unique_ptr printInfo) + : SimpleSink{type_, std::move(messageTable), id, std::move(printInfo)}, + info{std::move(info)} {} + + void executeInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return std::make_unique(info.copy(), messageTable, id, printInfo->copy()); + } + +private: + binder::BoundCreateSequenceInfo info; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/ddl/create_table.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/ddl/create_table.h new file mode 100644 index 0000000000..27ad081678 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/ddl/create_table.h @@ -0,0 +1,42 @@ +#pragma once + +#include "binder/ddl/bound_create_table_info.h" +#include "processor/operator/sink.h" + +namespace lbug { +namespace processor { + +struct CreateTableSharedState { + bool tableCreated = false; +}; + +class CreateTable final : public SimpleSink { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::CREATE_TABLE; + +public: + CreateTable(binder::BoundCreateTableInfo info, std::shared_ptr messageTable, + std::shared_ptr sharedState, physical_op_id id, + std::unique_ptr printInfo) + : SimpleSink{type_, std::move(messageTable), id, std::move(printInfo)}, + info{std::move(info)}, sharedState{std::move(sharedState)} {} + + void executeInternal(ExecutionContext* context) override; + + bool terminate() const override { + // If table is not created, meaning table already exists. Then subsequent copy tasks should + // not be executed. 
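        // Editorial note, not part of the upstream patch: in other words, terminate() returns
        // true when the CREATE was a no-op because the table already existed, which presumably
        // lets the scheduler skip the dependent tasks (e.g. a COPY into the new table) that the
        // comment above refers to.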
+ return !sharedState->tableCreated; + } + + std::unique_ptr copy() override { + return std::make_unique(info.copy(), messageTable, sharedState, id, + printInfo->copy()); + } + +private: + binder::BoundCreateTableInfo info; + std::shared_ptr sharedState; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/ddl/create_type.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/ddl/create_type.h new file mode 100644 index 0000000000..2d935e7422 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/ddl/create_type.h @@ -0,0 +1,48 @@ +#pragma once + +#include "processor/operator/sink.h" + +namespace lbug { +namespace processor { + +struct CreateTypePrintInfo final : OPPrintInfo { + std::string typeName; + std::string type; + + CreateTypePrintInfo(std::string typeName, std::string type) + : typeName{std::move(typeName)}, type{std::move(type)} {} + + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::unique_ptr(new CreateTypePrintInfo(*this)); + } + +private: + CreateTypePrintInfo(const CreateTypePrintInfo& other) + : OPPrintInfo{other}, typeName{other.typeName}, type{other.type} {} +}; + +class CreateType final : public SimpleSink { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::CREATE_TYPE; + +public: + CreateType(std::string name, common::LogicalType type, + std::shared_ptr messageTable, physical_op_id id, + std::unique_ptr printInfo) + : SimpleSink{type_, std::move(messageTable), id, std::move(printInfo)}, + name{std::move(name)}, type{std::move(type)} {} + + void executeInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return std::make_unique(name, type.copy(), messageTable, id, printInfo->copy()); + } + +private: + std::string name; + common::LogicalType type; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/ddl/drop.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/ddl/drop.h new file mode 100644 index 0000000000..b11e8cdfc3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/ddl/drop.h @@ -0,0 +1,51 @@ +#pragma once + +#include "parser/ddl/drop_info.h" +#include "processor/operator/sink.h" + +namespace lbug { +namespace processor { + +struct DropPrintInfo final : OPPrintInfo { + std::string name; + + explicit DropPrintInfo(std::string name) : name{std::move(name)} {} + + std::string toString() const override { return name; } + + std::unique_ptr copy() const override { + return std::unique_ptr(new DropPrintInfo(*this)); + } + +private: + DropPrintInfo(const DropPrintInfo& other) : OPPrintInfo{other}, name{other.name} {} +}; + +class Drop final : public SimpleSink { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::DROP; + +public: + Drop(parser::DropInfo dropInfo, std::shared_ptr messageTable, + physical_op_id id, std::unique_ptr printInfo) + : SimpleSink{type_, std::move(messageTable), id, std::move(printInfo)}, + dropInfo{std::move(dropInfo)} {} + + void executeInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return make_unique(dropInfo, messageTable, id, printInfo->copy()); + } + +private: + void dropSequence(const main::ClientContext* context); + void dropTable(const main::ClientContext* context); + void dropMacro(const main::ClientContext* context); + void handleMacroExistence(const 
main::ClientContext* context); + void dropRelGroup(const main::ClientContext* context); + +private: + parser::DropInfo dropInfo; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/empty_result.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/empty_result.h new file mode 100644 index 0000000000..3c5731b83e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/empty_result.h @@ -0,0 +1,25 @@ +#pragma once + +#include "physical_operator.h" + +namespace lbug { +namespace processor { + +class EmptyResult final : public PhysicalOperator { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::EMPTY_RESULT; + +public: + EmptyResult(physical_op_id id, std::unique_ptr printInfo) + : PhysicalOperator{type_, id, std::move(printInfo)} {} + + bool isSource() const override { return true; } + + bool getNextTuplesInternal(ExecutionContext*) override; + + std::unique_ptr copy() override { + return std::make_unique(id, printInfo->copy()); + } +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/filter.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/filter.h new file mode 100644 index 0000000000..15e1a24288 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/filter.h @@ -0,0 +1,93 @@ +#pragma once + +#include "expression_evaluator/expression_evaluator.h" +#include "processor/operator/filtering_operator.h" +#include "processor/operator/physical_operator.h" + +namespace lbug { +namespace processor { + +struct FilterPrintInfo final : OPPrintInfo { + std::shared_ptr expression; + + explicit FilterPrintInfo(std::shared_ptr expression) + : expression{std::move(expression)} {} + + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::unique_ptr(new FilterPrintInfo(*this)); + } + +private: + FilterPrintInfo(const FilterPrintInfo& other) + : OPPrintInfo{other}, expression{other.expression} {} +}; + +class Filter final : public PhysicalOperator, public SelVectorOverWriter { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::FILTER; + +public: + Filter(std::unique_ptr expressionEvaluator, + uint32_t dataChunkToSelectPos, std::unique_ptr child, uint32_t id, + std::unique_ptr printInfo) + : PhysicalOperator{type_, std::move(child), id, std::move(printInfo)}, + expressionEvaluator{std::move(expressionEvaluator)}, + dataChunkToSelectPos(dataChunkToSelectPos) {} + + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; + + bool getNextTuplesInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return make_unique(expressionEvaluator->copy(), dataChunkToSelectPos, + children[0]->copy(), id, printInfo->copy()); + } + +private: + std::unique_ptr expressionEvaluator; + uint32_t dataChunkToSelectPos; + std::shared_ptr state; +}; + +struct NodeLabelFilterInfo { + DataPos nodeVectorPos; + std::unordered_set nodeLabelSet; + + NodeLabelFilterInfo(const DataPos& nodeVectorPos, + std::unordered_set nodeLabelSet) + : nodeVectorPos{nodeVectorPos}, nodeLabelSet{std::move(nodeLabelSet)} {} + NodeLabelFilterInfo(const NodeLabelFilterInfo& other) + : nodeVectorPos{other.nodeVectorPos}, nodeLabelSet{other.nodeLabelSet} {} + + std::unique_ptr copy() const { + return std::make_unique(*this); + } +}; + +class NodeLabelFiler final : public PhysicalOperator, 
public SelVectorOverWriter { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::FILTER; + +public: + NodeLabelFiler(std::unique_ptr info, + std::unique_ptr child, uint32_t id, + std::unique_ptr printInfo) + : PhysicalOperator{type_, std::move(child), id, std::move(printInfo)}, + info{std::move(info)}, nodeIDVector{nullptr} {} + + void initLocalStateInternal(ResultSet* resultSet_, ExecutionContext* context) override; + + bool getNextTuplesInternal(ExecutionContext* context) override; + + std::unique_ptr copy() final { + return std::make_unique(info->copy(), children[0]->copy(), id, + printInfo->copy()); + } + +private: + std::unique_ptr info; + common::ValueVector* nodeIDVector; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/filtering_operator.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/filtering_operator.h new file mode 100644 index 0000000000..a27b9a8075 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/filtering_operator.h @@ -0,0 +1,30 @@ +#pragma once + +#include "common/data_chunk/sel_vector.h" + +namespace lbug { +namespace common { +class DataChunkState; +} // namespace common + +namespace processor { + +class SelVectorOverWriter { +public: + SelVectorOverWriter(); + virtual ~SelVectorOverWriter() = default; + +protected: + void restoreSelVector(common::DataChunkState& dataChunkState) const; + + void saveSelVector(common::DataChunkState& dataChunkState); + +private: + virtual void resetCurrentSelVector(const common::SelectionVector& selVector); + +protected: + std::shared_ptr prevSelVector; + std::shared_ptr currentSelVector; +}; +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/flatten.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/flatten.h new file mode 100644 index 0000000000..37842ef2a9 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/flatten.h @@ -0,0 +1,42 @@ +#pragma once + +#include "processor/operator/filtering_operator.h" +#include "processor/operator/physical_operator.h" + +namespace lbug { +namespace processor { + +struct FlattenLocalState { + uint64_t currentIdx = 0; + uint64_t sizeToFlatten = 0; +}; + +class Flatten final : public PhysicalOperator, SelVectorOverWriter { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::FLATTEN; + +public: + Flatten(data_chunk_pos_t dataChunkToFlattenPos, std::unique_ptr child, + uint32_t id, std::unique_ptr printInfo) + : PhysicalOperator{type_, std::move(child), id, std::move(printInfo)}, + dataChunkToFlattenPos{dataChunkToFlattenPos}, dataChunkState{nullptr} {} + + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; + + bool getNextTuplesInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return make_unique(dataChunkToFlattenPos, children[0]->copy(), id, + printInfo->copy()); + } + +private: + void resetCurrentSelVector(const common::SelectionVector& selVector) override; + +private: + data_chunk_pos_t dataChunkToFlattenPos; + common::DataChunkState* dataChunkState; + std::unique_ptr localState; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/hash_join/hash_join_build.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/hash_join/hash_join_build.h new file mode 100644 index 
0000000000..41747b79ce --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/hash_join/hash_join_build.h @@ -0,0 +1,118 @@ +#pragma once + +#include + +#include "binder/expression/expression.h" +#include "join_hash_table.h" +#include "processor/operator/physical_operator.h" +#include "processor/operator/sink.h" +#include "processor/result/factorized_table.h" +#include "processor/result/result_set.h" + +namespace lbug { +namespace processor { + +struct HashJoinBuildPrintInfo final : OPPrintInfo { + binder::expression_vector keys; + binder::expression_vector payloads; + + HashJoinBuildPrintInfo(binder::expression_vector keys, binder::expression_vector payloads) + : keys{std::move(keys)}, payloads(std::move(payloads)) {} + + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::unique_ptr(new HashJoinBuildPrintInfo(*this)); + } + +private: + HashJoinBuildPrintInfo(const HashJoinBuildPrintInfo& other) + : OPPrintInfo{other}, keys{other.keys}, payloads{other.payloads} {} +}; + +class HashJoinBuild; + +// This is a shared state between HashJoinBuild and HashJoinProbe operators. +// Each clone of these two operators will share the same state. +// Inside the state, we keep the materialized tuples in factorizedTable, which are merged by each +// HashJoinBuild thread when they finished materializing thread-local tuples. Also, the state holds +// a global htDirectory, which will be updated by the last thread in the hash join build side +// task/pipeline, and probed by the HashJoinProbe operators. +class HashJoinSharedState { +public: + explicit HashJoinSharedState(std::unique_ptr hashTable) + : hashTable{std::move(hashTable)} {}; + + void mergeLocalHashTable(JoinHashTable& localHashTable); + + JoinHashTable* getHashTable() { return hashTable.get(); } + +protected: + std::mutex mtx; + std::unique_ptr hashTable; +}; + +struct HashJoinBuildInfo { + std::vector keysPos; + std::vector fStateTypes; + std::vector payloadsPos; + FactorizedTableSchema tableSchema; + + HashJoinBuildInfo(std::vector keysPos, std::vector fStateTypes, + std::vector payloadsPos, FactorizedTableSchema tableSchema) + : keysPos{std::move(keysPos)}, fStateTypes{std::move(fStateTypes)}, + payloadsPos{std::move(payloadsPos)}, tableSchema{std::move(tableSchema)} {} + EXPLICIT_COPY_DEFAULT_MOVE(HashJoinBuildInfo); + + common::idx_t getNumKeys() const { return keysPos.size(); } + +private: + HashJoinBuildInfo(const HashJoinBuildInfo& other) + : keysPos{other.keysPos}, fStateTypes{other.fStateTypes}, payloadsPos{other.payloadsPos}, + tableSchema{other.tableSchema.copy()} {} +}; + +class HashJoinBuild : public Sink { +public: + HashJoinBuild(PhysicalOperatorType operatorType, + std::shared_ptr sharedState, HashJoinBuildInfo info, + std::unique_ptr child, uint32_t id, + std::unique_ptr printInfo) + : Sink{operatorType, std::move(child), id, std::move(printInfo)}, + sharedState{std::move(sharedState)}, info{std::move(info)} {} + + std::shared_ptr getSharedState() const { return sharedState; } + + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; + + void executeInternal(ExecutionContext* context) override; + + void finalizeInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return make_unique(operatorType, sharedState, info.copy(), + children[0]->copy(), id, printInfo->copy()); + } + +protected: + virtual uint64_t appendVectors() { + return hashTable->appendVectors(keyVectors, payloadVectors, keyState); 
+ } + +private: + void setKeyState(common::DataChunkState* state); + +protected: + std::shared_ptr sharedState; + HashJoinBuildInfo info; + + std::vector keyVectors; + // State of unFlat key(s). If all keys are flat, it points to any flat key state. + common::DataChunkState* keyState = nullptr; + std::vector payloadVectors; + + std::unique_ptr hashTable; // local state +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/hash_join/hash_join_probe.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/hash_join/hash_join_probe.h new file mode 100644 index 0000000000..78ca3d78eb --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/hash_join/hash_join_probe.h @@ -0,0 +1,123 @@ +#pragma once + +#include "common/enums/join_type.h" +#include "processor/operator/filtering_operator.h" +#include "processor/operator/hash_join/hash_join_build.h" +#include "processor/operator/physical_operator.h" +#include "processor/result/result_set.h" + +namespace lbug { +namespace processor { + +struct ProbeState { + explicit ProbeState() + : matchedSelVector{common::DEFAULT_VECTOR_CAPACITY}, nextMatchedTupleIdx{0} { + matchedTuples = std::make_unique(common::DEFAULT_VECTOR_CAPACITY); + probedTuples = std::make_unique(common::DEFAULT_VECTOR_CAPACITY); + matchedSelVector.setToFiltered(); + } + + // Each key corresponds to a pointer with the same hash value from the ht directory. + std::unique_ptr probedTuples; + // Pointers to tuples in ht that actually matched. + std::unique_ptr matchedTuples; + // Selective index mapping each probed tuple to its probe side key vector. + common::SelectionVector matchedSelVector; + common::sel_t nextMatchedTupleIdx; +}; + +struct ProbeDataInfo { +public: + ProbeDataInfo(std::vector keysDataPos, std::vector payloadsOutPos) + : keysDataPos{std::move(keysDataPos)}, payloadsOutPos{std::move(payloadsOutPos)}, + markDataPos{UINT32_MAX, UINT32_MAX} {} + + ProbeDataInfo(const ProbeDataInfo& other) + : ProbeDataInfo{other.keysDataPos, other.payloadsOutPos} { + markDataPos = other.markDataPos; + } + + inline uint32_t getNumPayloads() const { return payloadsOutPos.size(); } + +public: + std::vector keysDataPos; + std::vector payloadsOutPos; + DataPos markDataPos; +}; + +struct HashJoinProbePrintInfo final : OPPrintInfo { + binder::expression_vector keys; + + explicit HashJoinProbePrintInfo(binder::expression_vector keys) : keys{std::move(keys)} {} + + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::unique_ptr(new HashJoinProbePrintInfo(*this)); + } + +private: + HashJoinProbePrintInfo(const HashJoinProbePrintInfo& other) + : OPPrintInfo{other}, keys{other.keys} {} +}; + +// Probe side on left, i.e. children[0] and build side on right, i.e. 
children[1] +class HashJoinProbe : public PhysicalOperator, public SelVectorOverWriter { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::HASH_JOIN_PROBE; + +public: + HashJoinProbe(std::shared_ptr sharedState, common::JoinType joinType, + bool flatProbe, const ProbeDataInfo& probeDataInfo, + std::unique_ptr probeChild, uint32_t id, + std::unique_ptr printInfo) + : PhysicalOperator{type_, std::move(probeChild), id, std::move(printInfo)}, + sharedState{std::move(sharedState)}, joinType{joinType}, flatProbe{flatProbe}, + probeDataInfo{probeDataInfo}, markVector(nullptr) {} + + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; + + bool getNextTuplesInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return make_unique(sharedState, joinType, flatProbe, probeDataInfo, + children[0]->copy(), id, printInfo->copy()); + } + +private: + bool getMatchedTuples(ExecutionContext* context) { + return flatProbe ? getMatchedTuplesForFlatKey(context) : + getMatchedTuplesForUnFlatKey(context); + } + bool getMatchedTuplesForFlatKey(ExecutionContext* context); + // We can probe a batch of input tuples if we know they have at most one match. + bool getMatchedTuplesForUnFlatKey(ExecutionContext* context); + + uint64_t getInnerJoinResult() { + return flatProbe ? getInnerJoinResultForFlatKey() : getInnerJoinResultForUnFlatKey(); + } + uint64_t getInnerJoinResultForFlatKey(); + uint64_t getInnerJoinResultForUnFlatKey(); + uint64_t getLeftJoinResult(); + uint64_t getMarkJoinResult(); + uint64_t getCountJoinResult(); + uint64_t getJoinResult(); + +private: + std::shared_ptr sharedState; + common::JoinType joinType; + bool flatProbe; + + ProbeDataInfo probeDataInfo; + std::vector vectorsToReadInto; + std::vector columnIdxsToReadFrom; + std::vector keyVectors; + common::ValueVector* markVector; + std::unique_ptr probeState; + + std::unique_ptr hashVector; + std::unique_ptr tmpHashVector; + common::SelectionVector hashSelVec; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/hash_join/join_hash_table.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/hash_join/join_hash_table.h new file mode 100644 index 0000000000..427ed60b55 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/hash_join/join_hash_table.h @@ -0,0 +1,73 @@ +#pragma once + +#include "processor/result/base_hash_table.h" +#include "processor/result/factorized_table.h" + +namespace lbug { +namespace storage { +class MemoryManager; +} +namespace processor { + +class JoinHashTable : public BaseHashTable { +public: + JoinHashTable(storage::MemoryManager& memoryManager, common::logical_type_vec_t keyTypes, + FactorizedTableSchema tableSchema); + + uint64_t appendVectors(const std::vector& keyVectors, + const std::vector& payloadVectors, common::DataChunkState* keyState); + void appendVector(common::ValueVector* vector, + const std::vector& appendInfos, ft_col_idx_t colIdx); + + // Used in worst-case optimal join + uint64_t appendVectorWithSorting(common::ValueVector* keyVector, + std::vector payloadVectors); + + void allocateHashSlots(uint64_t numTuples); + void buildHashSlots(); + + // The tmpHashResultVector may be null if there is only one keyVector + void probe(const std::vector& keyVectors, common::ValueVector& hashVector, + common::SelectionVector& hashSelVec, common::ValueVector* tmpHashResultVector, + uint8_t** probedTuples); + // All key 
vectors must be flat. Thus input is a tuple, multiple matches can be found for the + // given key tuple. + common::sel_t matchFlatKeys(const std::vector& keyVectors, + uint8_t** probedTuples, uint8_t** matchedTuples); + // Input is multiple tuples, at most one match exist for each key. + common::sel_t matchUnFlatKey(common::ValueVector* keyVector, uint8_t** probedTuples, + uint8_t** matchedTuples, common::SelectionVector& matchedTuplesSelVector); + + void lookup(std::vector& vectors, std::vector& colIdxesToScan, + uint8_t** tuplesToRead, uint64_t startPos, uint64_t numTuplesToRead) { + factorizedTable->lookup(vectors, colIdxesToScan, tuplesToRead, startPos, numTuplesToRead); + } + void merge(JoinHashTable& other) { factorizedTable->merge(*other.factorizedTable); } + uint8_t** getPrevTuple(const uint8_t* tuple) const { + return (uint8_t**)(tuple + prevPtrColOffset); + } + uint8_t* getTupleForHash(common::hash_t hash) { + auto slotIdx = getSlotIdxForHash(hash); + KU_ASSERT(slotIdx < maxNumHashSlots); + return ((uint8_t**)(hashSlotsBlocks[slotIdx >> numSlotsPerBlockLog2] + ->getData()))[slotIdx & slotIdxInBlockMask]; + } + +private: + uint8_t** findHashSlot(const uint8_t* tuple) const; + // This function returns the pointer that previously stored in the same slot. + uint8_t* insertEntry(uint8_t* tuple) const; + + // Join hash table assumes all keys to be flat. + void computeVectorHashes(std::vector keyVectors); + + common::offset_t getHashValueColOffset() const; + +private: + static constexpr uint64_t PREV_PTR_COL_IDX = 1; + static constexpr uint64_t HASH_COL_IDX = 2; + uint64_t prevPtrColOffset; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/index_lookup.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/index_lookup.h new file mode 100644 index 0000000000..a2a67b6ff0 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/index_lookup.h @@ -0,0 +1,89 @@ +#pragma once + +#include "binder/expression/expression.h" +#include "expression_evaluator/expression_evaluator.h" +#include "processor/operator/persistent/batch_insert_error_handler.h" +#include "processor/operator/physical_operator.h" + +namespace lbug { +namespace transaction { +class Transaction; +} +namespace storage { +class NodeTable; +} // namespace storage +namespace processor { + +struct BatchInsertSharedState; +struct IndexLookupInfo { + storage::NodeTable* nodeTable; + std::unique_ptr keyEvaluator; + DataPos resultVectorPos; + + IndexLookupInfo(storage::NodeTable* nodeTable, + std::unique_ptr keyEvaluator, + const DataPos& resultVectorPos) + : nodeTable{nodeTable}, keyEvaluator{std::move(keyEvaluator)}, + resultVectorPos{resultVectorPos} {} + EXPLICIT_COPY_DEFAULT_MOVE(IndexLookupInfo); + +private: + IndexLookupInfo(const IndexLookupInfo& other) + : nodeTable{other.nodeTable}, keyEvaluator{other.keyEvaluator->copy()}, + resultVectorPos{other.resultVectorPos} {} +}; + +struct IndexLookupPrintInfo final : OPPrintInfo { + binder::expression_vector expressions; + explicit IndexLookupPrintInfo(binder::expression_vector expressions) + : expressions{std::move(expressions)} {} + + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::unique_ptr(new IndexLookupPrintInfo(*this)); + } + +private: + IndexLookupPrintInfo(const IndexLookupPrintInfo& other) + : OPPrintInfo{other}, expressions{other.expressions} {} +}; + +struct IndexLookupLocalState { + explicit 
IndexLookupLocalState(std::unique_ptr errorHandler) + : errorHandler(std::move(errorHandler)) {} + + std::unique_ptr errorHandler; + std::vector warningDataVectors; +}; + +class IndexLookup final : public PhysicalOperator { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::INDEX_LOOKUP; + +public: + IndexLookup(std::vector infos, std::vector warningDataVectorPos, + std::unique_ptr child, common::idx_t id, + std::unique_ptr printInfo) + : PhysicalOperator{type_, std::move(child), id, std::move(printInfo)}, + infos{std::move(infos)}, warningDataVectorPos{std::move(warningDataVectorPos)} {} + + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; + + bool getNextTuplesInternal(ExecutionContext* context) final; + + std::unique_ptr copy() final { + return std::make_unique(copyVector(infos), warningDataVectorPos, + children[0]->copy(), getOperatorID(), printInfo->copy()); + } + +private: + void lookup(transaction::Transaction* transaction, const IndexLookupInfo& info); + +private: + std::vector infos; + std::vector warningDataVectorPos; + std::unique_ptr localState; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/intersect/intersect.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/intersect/intersect.h new file mode 100644 index 0000000000..7fed640670 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/intersect/intersect.h @@ -0,0 +1,86 @@ +#pragma once + +#include "processor/operator/hash_join/hash_join_build.h" +#include "processor/operator/physical_operator.h" + +namespace lbug { +namespace processor { + +struct IntersectDataInfo { + DataPos keyDataPos; + // TODO(Xiyang): payload is not an accurate name for intersect. + std::vector payloadsDataPos; +}; + +struct IntersectPrintInfo final : OPPrintInfo { + std::shared_ptr key; + + explicit IntersectPrintInfo(std::shared_ptr key) : key{std::move(key)} {} + + std::string toString() const override; + std::unique_ptr copy() const override { + return std::unique_ptr(new IntersectPrintInfo(*this)); + } + +private: + IntersectPrintInfo(const IntersectPrintInfo& other) : OPPrintInfo{other}, key{other.key} {} +}; + +class Intersect : public PhysicalOperator { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::INTERSECT; + +public: + Intersect(const DataPos& outputDataPos, std::vector intersectDataInfos, + std::vector> sharedHTs, + std::unique_ptr probeChild, uint32_t id, + std::unique_ptr printInfo) + : PhysicalOperator{type_, std::move(probeChild), id, std::move(printInfo)}, + outputDataPos{outputDataPos}, intersectDataInfos{std::move(intersectDataInfos)}, + sharedHTs{std::move(sharedHTs)} { + tupleIdxPerBuildSide.resize(this->sharedHTs.size(), 0); + carryBuildSideIdx = -1u; + probedFlatTuples.resize(this->sharedHTs.size()); + } + + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; + + bool getNextTuplesInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return std::make_unique(outputDataPos, intersectDataInfos, sharedHTs, + children[0]->copy(), id, printInfo->copy()); + } + +private: + // For each build side, probe its HT and return a vector of matched flat tuples. + void probeHTs(); + // Left is always the one with less num of values. 
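    // Editorial note, not part of the upstream patch: this appears to be a sorted intersection
    // over node-ID arrays whose result is expressed by shrinking the two selection vectors to
    // the matching positions. Illustrative (hypothetical) values:
    //   left  = [3, 7, 9], right = [1, 3, 9, 12]
    //   after the call, lSelVector keeps positions {0, 2} and rSelVector keeps {1, 2},
    //   i.e. the shared node IDs 3 and 9. Passing the shorter list as "left" (see the comment
    //   above) presumably lets the implementation walk the small side and search the larger one.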
+ static void twoWayIntersect(common::nodeID_t* leftNodeIDs, common::SelectionVector& lSelVector, + common::nodeID_t* rightNodeIDs, common::SelectionVector& rSelVector); + void intersectLists(const std::vector& listsToIntersect); + void populatePayloads(const std::vector& tuples, + const std::vector& listIdxes); + bool hasNextTuplesToIntersect(); + + uint32_t getNumBuilds() { return sharedHTs.size(); } + +private: + DataPos outputDataPos; + std::vector intersectDataInfos; + // payloadColumnIdxesToScanFrom and payloadVectorsToScanInto are organized by each build child. + std::vector> payloadColumnIdxesToScanFrom; + std::vector> payloadVectorsToScanInto; + std::shared_ptr outKeyVector; + std::vector> probeKeyVectors; + std::vector> intersectSelVectors; + std::vector> sharedHTs; + std::vector isIntersectListAFlatValue; + std::vector> probedFlatTuples; + // Keep track of the tuple to intersect for each build side. + std::vector tupleIdxPerBuildSide; + // This is used to indicate which build side to increment the tuple idx for. + uint32_t carryBuildSideIdx; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/intersect/intersect_build.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/intersect/intersect_build.h new file mode 100644 index 0000000000..382c2a707b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/intersect/intersect_build.h @@ -0,0 +1,57 @@ +#pragma once + +#include "binder/expression/expression_util.h" +#include "processor/operator/hash_join/hash_join_build.h" + +namespace lbug { +namespace processor { + +struct IntersectBuildPrintInfo final : OPPrintInfo { + binder::expression_vector keys; + binder::expression_vector payloads; + + IntersectBuildPrintInfo(binder::expression_vector keys, binder::expression_vector payloads) + : keys{std::move(keys)}, payloads(std::move(payloads)) {} + + std::string toString() const override { + std::string result = "Keys: "; + result += binder::ExpressionUtil::toString(keys); + if (!payloads.empty()) { + result += ", Payloads: "; + result += binder::ExpressionUtil::toString(payloads); + } + return result; + } + + std::unique_ptr copy() const override { + return std::unique_ptr(new IntersectBuildPrintInfo(*this)); + } + +private: + IntersectBuildPrintInfo(const IntersectBuildPrintInfo& other) + : OPPrintInfo{other}, keys{other.keys}, payloads{other.payloads} {} +}; + +class IntersectBuild final : public HashJoinBuild { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::INTERSECT_BUILD; + +public: + IntersectBuild(std::shared_ptr sharedState, HashJoinBuildInfo info, + std::unique_ptr child, uint32_t id, + std::unique_ptr printInfo) + : HashJoinBuild{type_, std::move(sharedState), std::move(info), std::move(child), id, + std::move(printInfo)} {} + + uint64_t appendVectors() final { + KU_ASSERT(keyVectors.size() == 1); + return hashTable->appendVectorWithSorting(keyVectors[0], payloadVectors); + } + + std::unique_ptr copy() override { + return make_unique(sharedState, info.copy(), children[0]->copy(), id, + printInfo->copy()); + } +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/limit.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/limit.h new file mode 100644 index 0000000000..38077d7147 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/limit.h @@ -0,0 +1,53 @@ +#pragma once + 
+#include + +#include "processor/operator/physical_operator.h" + +namespace lbug { +namespace processor { + +struct LimitPrintInfo final : OPPrintInfo { + uint64_t limitNum; + + explicit LimitPrintInfo(uint64_t limitNum) : limitNum{limitNum} {} + + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::unique_ptr(new LimitPrintInfo(*this)); + } + +private: + LimitPrintInfo(const LimitPrintInfo& other) : OPPrintInfo{other}, limitNum{other.limitNum} {} +}; + +class Limit final : public PhysicalOperator { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::LIMIT; + +public: + Limit(uint64_t limitNumber, std::shared_ptr counter, + uint32_t dataChunkToSelectPos, std::unordered_set dataChunksPosInScope, + std::unique_ptr child, uint32_t id, + std::unique_ptr printInfo) + : PhysicalOperator{type_, std::move(child), id, std::move(printInfo)}, + limitNumber{limitNumber}, counter{std::move(counter)}, + dataChunkToSelectPos{dataChunkToSelectPos}, + dataChunksPosInScope(std::move(dataChunksPosInScope)) {} + + bool getNextTuplesInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return make_unique(limitNumber, counter, dataChunkToSelectPos, dataChunksPosInScope, + children[0]->copy(), id, printInfo->copy()); + } + +private: + uint64_t limitNumber; + std::shared_ptr counter; + uint32_t dataChunkToSelectPos; + std::unordered_set dataChunksPosInScope; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/macro/create_macro.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/macro/create_macro.h new file mode 100644 index 0000000000..2dc19c09db --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/macro/create_macro.h @@ -0,0 +1,59 @@ +#pragma once + +#include "catalog/catalog.h" +#include "function/scalar_macro_function.h" +#include "processor/operator/sink.h" + +namespace lbug { +namespace processor { + +struct CreateMacroInfo { + std::string macroName; + std::unique_ptr macro; + + CreateMacroInfo(std::string macroName, std::unique_ptr macro) + : macroName{std::move(macroName)}, macro{std::move(macro)} {} + EXPLICIT_COPY_DEFAULT_MOVE(CreateMacroInfo); + +private: + CreateMacroInfo(const CreateMacroInfo& other) + : macroName{other.macroName}, macro{other.macro->copy()} {} +}; + +struct CreateMacroPrintInfo final : OPPrintInfo { + std::string macroName; + + explicit CreateMacroPrintInfo(std::string macroName) : macroName{std::move(macroName)} {} + + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::unique_ptr(new CreateMacroPrintInfo(*this)); + } + +private: + CreateMacroPrintInfo(const CreateMacroPrintInfo& other) + : OPPrintInfo{other}, macroName{other.macroName} {} +}; + +class CreateMacro final : public SimpleSink { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::CREATE_MACRO; + +public: + CreateMacro(CreateMacroInfo info, std::shared_ptr messageTable, + physical_op_id id, std::unique_ptr printInfo) + : SimpleSink{type_, std::move(messageTable), id, std::move(printInfo)}, + info{std::move(info)} {} + + void executeInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return std::make_unique(info.copy(), messageTable, id, printInfo->copy()); + } + +private: + CreateMacroInfo info; +}; + +} // namespace processor +} // namespace lbug diff --git 
a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/multiplicity_reducer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/multiplicity_reducer.h new file mode 100644 index 0000000000..cd1c0c0e72 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/multiplicity_reducer.h @@ -0,0 +1,34 @@ +#pragma once + +#include "processor/operator/physical_operator.h" + +namespace lbug { +namespace processor { + +class MultiplicityReducer final : public PhysicalOperator { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::MULTIPLICITY_REDUCER; + +public: + MultiplicityReducer(std::unique_ptr child, uint32_t id, + std::unique_ptr printInfo) + : PhysicalOperator{type_, std::move(child), id, std::move(printInfo)}, prevMultiplicity{1}, + numRepeat{0} {} + + bool getNextTuplesInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return make_unique(children[0]->copy(), id, printInfo->copy()); + } + +private: + void restoreMultiplicity() { resultSet->multiplicity = prevMultiplicity; } + + void saveMultiplicity() { prevMultiplicity = resultSet->multiplicity; } + +private: + uint64_t prevMultiplicity; + uint64_t numRepeat; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/order_by/key_block_merger.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/order_by/key_block_merger.h new file mode 100644 index 0000000000..b46619ff5c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/order_by/key_block_merger.h @@ -0,0 +1,203 @@ +#pragma once + +#include +#include + +#include "processor/operator/order_by/order_by_key_encoder.h" + +namespace lbug { +namespace processor { + +struct KeyBlockMergeMorsel; + +// This struct stores the string key column information. We can utilize the +// pre-computed indexes and offsets to expedite the tuple comparison in merge sort. +struct StrKeyColInfo { + StrKeyColInfo(uint32_t colOffsetInFT, uint32_t colOffsetInEncodedKeyBlock, bool isAscOrder) + : colOffsetInFT{colOffsetInFT}, colOffsetInEncodedKeyBlock{colOffsetInEncodedKeyBlock}, + isAscOrder{isAscOrder} {} + + inline uint32_t getEncodingSize() const { + return OrderByKeyEncoder::getEncodingSize( + common::LogicalType(common::LogicalTypeID::STRING)); + } + + uint32_t colOffsetInFT; + uint32_t colOffsetInEncodedKeyBlock; + bool isAscOrder; +}; + +class MergedKeyBlocks { +public: + MergedKeyBlocks(uint32_t numBytesPerTuple, uint64_t numTuples, + storage::MemoryManager* memoryManager); + + // This constructor is used to convert a dataBlock to a MergedKeyBlocks. 
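    // Editorial note, not part of the upstream patch: keys are stored fixed-width, with
    // numTuplesPerBlock tuples per memory block, so a tuple is addressed as
    //   block  = tupleIdx / numTuplesPerBlock
    //   offset = (tupleIdx % numTuplesPerBlock) * numBytesPerTuple
    // which is exactly the arithmetic getTuple() performs below.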
+ MergedKeyBlocks(uint32_t numBytesPerTuple, std::shared_ptr keyBlock); + + inline uint8_t* getTuple(uint64_t tupleIdx) const { + KU_ASSERT(tupleIdx < numTuples); + return keyBlocks[tupleIdx / numTuplesPerBlock]->getData() + + numBytesPerTuple * (tupleIdx % numTuplesPerBlock); + } + + inline uint64_t getNumTuples() const { return numTuples; } + + inline uint32_t getNumBytesPerTuple() const { return numBytesPerTuple; } + + inline uint32_t getNumTuplesPerBlock() const { return numTuplesPerBlock; } + + inline uint8_t* getKeyBlockBuffer(uint32_t idx) const { + KU_ASSERT(idx < keyBlocks.size()); + return keyBlocks[idx]->getData(); + } + + uint8_t* getBlockEndTuplePtr(uint32_t blockIdx, uint64_t endTupleIdx, + uint32_t endTupleBlockIdx) const; + +private: + uint32_t numBytesPerTuple; + uint32_t numTuplesPerBlock; + uint64_t numTuples; + std::vector> keyBlocks; + uint32_t endTupleOffset; +}; + +struct BlockPtrInfo { + BlockPtrInfo(uint64_t startTupleIdx, uint64_t endTupleIdx, MergedKeyBlocks* keyBlocks); + + inline bool hasMoreTuplesToRead() const { return curTuplePtr != endTuplePtr; } + + inline uint64_t getNumBytesLeftInCurBlock() const { return curBlockEndTuplePtr - curTuplePtr; } + + inline uint64_t getNumTuplesLeftInCurBlock() const { + return getNumBytesLeftInCurBlock() / keyBlocks->getNumBytesPerTuple(); + } + + void updateTuplePtrIfNecessary(); + + MergedKeyBlocks* keyBlocks; + uint8_t* curTuplePtr; + uint64_t curBlockIdx; + uint64_t endBlockIdx; + uint8_t* curBlockEndTuplePtr; + uint8_t* endTuplePtr; + uint64_t endTupleIdx; +}; + +class KeyBlockMerger { +public: + explicit KeyBlockMerger(std::vector factorizedTables, + std::vector& strKeyColsInfo, uint32_t numBytesPerTuple) + : factorizedTables{std::move(factorizedTables)}, strKeyColsInfo{strKeyColsInfo}, + numBytesPerTuple{numBytesPerTuple}, numBytesToCompare{numBytesPerTuple - 8}, + hasStringCol{!strKeyColsInfo.empty()} {} + + void mergeKeyBlocks(KeyBlockMergeMorsel& keyBlockMergeMorsel) const; + + inline bool compareTuplePtr(uint8_t* leftTuplePtr, uint8_t* rightTuplePtr) const { + return hasStringCol ? compareTuplePtrWithStringCol(leftTuplePtr, rightTuplePtr) : + memcmp(leftTuplePtr, rightTuplePtr, numBytesToCompare) > 0; + } + + bool compareTuplePtrWithStringCol(uint8_t* leftTuplePtr, uint8_t* rightTuplePtr) const; + +private: + void copyRemainingBlockDataToResult(BlockPtrInfo& blockToCopy, BlockPtrInfo& resultBlock) const; + +private: + // FactorizedTables[i] stores all order_by columns encoded and sorted by the ith thread. + // MergeSort uses factorizedTable to access the full contents of the string key columns + // when resolving ties. + std::vector factorizedTables; + // We also store the colIdxInFactorizedTable, colOffsetInEncodedKeyBlock, isAscOrder, isStrCol + // for each string column. So, we don't need to compute them again during merge sort. + std::vector& strKeyColsInfo; + uint32_t numBytesPerTuple; + uint32_t numBytesToCompare; + bool hasStringCol; +}; + +class KeyBlockMergeTask { +public: + KeyBlockMergeTask(std::shared_ptr leftKeyBlock, + std::shared_ptr rightKeyBlock, + std::shared_ptr resultKeyBlock, KeyBlockMerger& keyBlockMerger) + : leftKeyBlock{std::move(leftKeyBlock)}, rightKeyBlock{std::move(rightKeyBlock)}, + resultKeyBlock{std::move(resultKeyBlock)}, leftKeyBlockNextIdx{0}, + rightKeyBlockNextIdx{0}, activeMorsels{0}, keyBlockMerger{keyBlockMerger} {} + + std::unique_ptr getMorsel(); + + inline bool hasMorselLeft() const { + // Returns true if there are still morsels left in the current task. 
+ return leftKeyBlockNextIdx < leftKeyBlock->getNumTuples() || + rightKeyBlockNextIdx < rightKeyBlock->getNumTuples(); + } + +private: + uint64_t findRightKeyBlockIdx(uint8_t* leftEndTuplePtr) const; + +public: + static const uint32_t batch_size = 10000; + + std::shared_ptr leftKeyBlock; + std::shared_ptr rightKeyBlock; + std::shared_ptr resultKeyBlock; + uint64_t leftKeyBlockNextIdx; + uint64_t rightKeyBlockNextIdx; + // The counter is used to keep track of the number of morsels given to thread. + // If the counter is 0 and there is no morsel left in the current task, we can + // put the resultKeyBlock back to the keyBlock list. + uint64_t activeMorsels; + // KeyBlockMerger is used to compare the values of two tuples during the binary search. + KeyBlockMerger& keyBlockMerger; +}; + +struct KeyBlockMergeMorsel { + explicit KeyBlockMergeMorsel(uint64_t leftKeyBlockStartIdx, uint64_t leftKeyBlockEndIdx, + uint64_t rightKeyBlockStartIdx, uint64_t rightKeyBlockEndIdx) + : leftKeyBlockStartIdx{leftKeyBlockStartIdx}, leftKeyBlockEndIdx{leftKeyBlockEndIdx}, + rightKeyBlockStartIdx{rightKeyBlockStartIdx}, rightKeyBlockEndIdx{rightKeyBlockEndIdx} {} + + std::shared_ptr keyBlockMergeTask; + uint64_t leftKeyBlockStartIdx; + uint64_t leftKeyBlockEndIdx; + uint64_t rightKeyBlockStartIdx; + uint64_t rightKeyBlockEndIdx; +}; + +// A dispatcher class used to assign KeyBlockMergeMorsel to threads. +// All functions are guaranteed to be thread-safe, so callers don't need to +// acquire a lock before calling these functions. +class KeyBlockMergeTaskDispatcher { +public: + inline bool isDoneMerge() { + std::lock_guard keyBlockMergeDispatcherLock{mtx}; + // Returns true if there are no more merge task to do or the sortedKeyBlocks is empty + // (meaning that the resultSet is empty). + return sortedKeyBlocks->size() <= 1 && activeKeyBlockMergeTasks.empty(); + } + + std::unique_ptr getMorsel(); + + void doneMorsel(std::unique_ptr morsel); + + // This function is used to initialize the columns of keyBlockMergeTaskDispatcher based on + // sharedFactorizedTablesAndSortedKeyBlocks. 
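    // Editorial note, not part of the upstream patch: the merge phase appears to run as a
    // parallel binary merge: threads call getMorsel() to claim a slice of a pending
    // (left, right) -> result merge and report back via doneMorsel(), and isDoneMerge() above
    // only returns true once the shared queue has collapsed to a single sorted key block and
    // no merge task is still active.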
+ void init(storage::MemoryManager* memoryManager, + std::queue>* sortedKeyBlocks, + std::vector factorizedTables, std::vector& strKeyColsInfo, + uint64_t numBytesPerTuple); + +private: + std::mutex mtx; + + storage::MemoryManager* memoryManager = nullptr; + std::queue>* sortedKeyBlocks = nullptr; + std::vector> activeKeyBlockMergeTasks; + std::unique_ptr keyBlockMerger; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/order_by/order_by.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/order_by/order_by.h new file mode 100644 index 0000000000..88f95ab6e5 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/order_by/order_by.h @@ -0,0 +1,68 @@ +#pragma once + +#include "binder/expression/expression.h" +#include "processor/operator/sink.h" +#include "processor/result/result_set.h" +#include "sort_state.h" + +namespace lbug { +namespace processor { + +struct OrderByPrintInfo final : OPPrintInfo { + binder::expression_vector keys; + binder::expression_vector payloads; + + OrderByPrintInfo(binder::expression_vector keys, binder::expression_vector payloads) + : keys(std::move(keys)), payloads(std::move(payloads)) {} + + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::unique_ptr(new OrderByPrintInfo(*this)); + } + +private: + OrderByPrintInfo(const OrderByPrintInfo& other) + : OPPrintInfo(other), keys(other.keys), payloads(other.payloads) {} +}; + +class OrderBy final : public Sink { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::ORDER_BY; + +public: + OrderBy(OrderByDataInfo info, std::shared_ptr sharedState, + std::unique_ptr child, uint32_t id, + std::unique_ptr printInfo) + : Sink{type_, std::move(child), id, std::move(printInfo)}, info{std::move(info)}, + sharedState{std::move(sharedState)} {} + + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) final; + + void executeInternal(ExecutionContext* context) override; + + void finalize(ExecutionContext* /*context*/) override { + // TODO(Ziyi): we always call lookup function on the first factorizedTable in sharedState + // and that lookup function may read tuples in other factorizedTable, So we need to combine + // hasNoNullGuarantee with other factorizedTables. This is not a good way to solve this + // problem, and should be changed later. 
+ sharedState->combineFTHasNoNullGuarantee(); + } + + std::unique_ptr copy() override { + return std::make_unique(info.copy(), sharedState, children[0]->copy(), id, + printInfo->copy()); + } + +private: + void initGlobalStateInternal(ExecutionContext* context) override; + +private: + OrderByDataInfo info; + SortLocalState localState; + std::shared_ptr sharedState; + std::vector orderByVectors; + std::vector payloadVectors; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/order_by/order_by_data_info.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/order_by/order_by_data_info.h new file mode 100644 index 0000000000..424a0a3b31 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/order_by/order_by_data_info.h @@ -0,0 +1,38 @@ +#pragma once + +#include "processor/data_pos.h" +#include "processor/result/factorized_table_schema.h" + +namespace lbug { +namespace processor { + +struct OrderByDataInfo { + std::vector keysPos; + std::vector payloadsPos; + std::vector keyTypes; + std::vector payloadTypes; + std::vector isAscOrder; + FactorizedTableSchema payloadTableSchema; + std::vector keyInPayloadPos; + + OrderByDataInfo(std::vector keysPos, std::vector payloadsPos, + std::vector keyTypes, std::vector payloadTypes, + std::vector isAscOrder, FactorizedTableSchema payloadTableSchema, + std::vector keyInPayloadPos) + : keysPos{std::move(keysPos)}, payloadsPos{std::move(payloadsPos)}, + keyTypes{std::move(keyTypes)}, payloadTypes{std::move(payloadTypes)}, + isAscOrder{std::move(isAscOrder)}, payloadTableSchema{std::move(payloadTableSchema)}, + keyInPayloadPos{std::move(keyInPayloadPos)} {} + EXPLICIT_COPY_DEFAULT_MOVE(OrderByDataInfo); + +private: + OrderByDataInfo(const OrderByDataInfo& other) + : keysPos{other.keysPos}, payloadsPos{other.payloadsPos}, + keyTypes{common::LogicalType::copy(other.keyTypes)}, + payloadTypes{common::LogicalType::copy(other.payloadTypes)}, isAscOrder{other.isAscOrder}, + payloadTableSchema{other.payloadTableSchema.copy()}, + keyInPayloadPos{other.keyInPayloadPos} {} +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/order_by/order_by_key_encoder.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/order_by/order_by_key_encoder.h new file mode 100644 index 0000000000..ef70d5092b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/order_by/order_by_key_encoder.h @@ -0,0 +1,134 @@ +#pragma once + +#include +#include + +#include "common/vector/value_vector.h" +#include "order_by_data_info.h" +#include "processor/result/factorized_table.h" + +namespace lbug { +namespace processor { + +#define BSWAP64(x) \ + ((uint64_t)((((uint64_t)(x) & 0xff00000000000000ull) >> 56) | \ + (((uint64_t)(x) & 0x00ff000000000000ull) >> 40) | \ + (((uint64_t)(x) & 0x0000ff0000000000ull) >> 24) | \ + (((uint64_t)(x) & 0x000000ff00000000ull) >> 8) | \ + (((uint64_t)(x) & 0x00000000ff000000ull) << 8) | \ + (((uint64_t)(x) & 0x0000000000ff0000ull) << 24) | \ + (((uint64_t)(x) & 0x000000000000ff00ull) << 40) | \ + (((uint64_t)(x) & 0x00000000000000ffull) << 56))) + +#define BSWAP32(x) \ + ((uint32_t)((((uint32_t)(x) & 0xff000000) >> 24) | (((uint32_t)(x) & 0x00ff0000) >> 8) | \ + (((uint32_t)(x) & 0x0000ff00) << 8) | (((uint32_t)(x) & 0x000000ff) << 24))) + +#define BSWAP16(x) ((uint16_t)((((uint16_t)(x) & 0xff00) >> 8) | (((uint16_t)(x) & 0x00ff) << 8))) + 
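+// Editorial sketch, not part of the upstream lbug sources: a minimal illustration of how the
+// BSWAP macros above combine with a sign-bit flip to produce memcmp-orderable INT64 keys on
+// little-endian hardware, as described in the comment below. The helper name is hypothetical and
+// nothing in the encoder references it.
+static inline void encodeInt64KeySketch(int64_t key, uint8_t* resultPtr) {
+    // Lay the key out most-significant byte first (big-endian), so byte-wise comparison with
+    // memcmp matches numeric comparison of the unsigned bit patterns.
+    *(uint64_t*)resultPtr = BSWAP64((uint64_t)key);
+    // Flip the sign bit so negative keys order before positive ones.
+    resultPtr[0] ^= 0x80;
+}
+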
+// The OrderByKeyEncoder encodes all columns in the ORDER BY clause into a single binary sequence
+// that, when compared using memcmp, yields the correct overall sorting order. On little-endian
+// hardware, the least-significant byte is stored at the smallest address, so to encode the sorting
+// order we need the big-endian representation of each value. For example, suppose we want to encode
+// 73 (INT64) and 38 (INT64) as 8-byte binary strings. The little-endian encoding is
+// 73=0x4900000000000000 and 38=0x2600000000000000, which doesn't preserve the order. The big-endian
+// encoding is 73=0x0000000000000049 and 38=0x0000000000000026, which can be compared directly with
+// memcmp. In addition, the first (sign) bit is flipped to preserve ordering between positive and
+// negative numbers, so the final encodings of 73 (INT64) and 38 (INT64) as 8-byte binary strings
+// are 73=0x8000000000000049 and 38=0x8000000000000026. To handle nulls in comparisons, we add an
+// extra byte (called the NULL flag) per key that records whether the value is null.
+
+using encode_function_t =
+    std::function<void(const uint8_t* data, uint8_t* resultPtr, bool swapBytes)>;
+
+class OrderByKeyEncoder {
+public:
+    OrderByKeyEncoder(const OrderByDataInfo& orderByDataInfo, storage::MemoryManager* memoryManager,
+        uint8_t ftIdx, uint32_t numTuplesPerBlockInFT, uint32_t numBytesPerTuple);
+
+    inline std::vector<std::unique_ptr<DataBlock>>& getKeyBlocks() { return keyBlocks; }
+
+    inline uint32_t getNumBytesPerTuple() const { return numBytesPerTuple; }
+
+    inline uint32_t getNumTuplesInCurBlock() const { return keyBlocks.back()->numTuples; }
+
+    static uint32_t getNumBytesPerTuple(const std::vector<common::ValueVector*>& keyVectors);
+
+    static inline uint32_t getEncodedFTBlockIdx(const uint8_t* tupleInfoPtr) {
+        return *(uint32_t*)tupleInfoPtr;
+    }
+
+    // Note: We only encode 3 bytes for ftBlockOffset, but we are reading 4 bytes from tupleInfoPtr.
+    // We need to do a bit mask to set the most significant byte to 0x00.
+    static inline uint32_t getEncodedFTBlockOffset(const uint8_t* tupleInfoPtr) {
+        return (*(uint32_t*)(tupleInfoPtr + 4) & 0x00FFFFFF);
+    }
+
+    static inline uint8_t getEncodedFTIdx(const uint8_t* tupleInfoPtr) {
+        return *(tupleInfoPtr + 7);
+    }
+
+    static inline bool isNullVal(const uint8_t* nullBytePtr, bool isAscOrder) {
+        return *(nullBytePtr) == (isAscOrder ? UINT8_MAX : 0);
+    }
+
+    static inline bool isLongStr(const uint8_t* strBuffer, bool isAsc) {
+        return *(strBuffer + 13) == (isAsc ?
UINT8_MAX : 0); + } + + static uint32_t getEncodingSize(const common::LogicalType& dataType); + + void encodeKeys(const std::vector& orderByKeys); + + inline void clear() { keyBlocks.clear(); } + +private: + template + static inline void encodeTemplate(const uint8_t* data, uint8_t* resultPtr, bool swapBytes) { + OrderByKeyEncoder::encodeData(*(type*)data, resultPtr, swapBytes); + } + + template + static void encodeData(type /*data*/, uint8_t* /*resultPtr*/, bool /*swapBytes*/) { + KU_UNREACHABLE; + } + + static inline uint8_t flipSign(uint8_t key_byte) { return key_byte ^ 128; } + + void flipBytesIfNecessary(uint32_t keyColIdx, uint8_t* tuplePtr, uint32_t numEntriesToEncode, + common::LogicalType& type); + + void encodeFlatVector(common::ValueVector* vector, uint8_t* tuplePtr, uint32_t keyColIdx); + + void encodeUnflatVector(common ::ValueVector* vector, uint8_t* tuplePtr, uint32_t encodedTuples, + uint32_t numEntriesToEncode, uint32_t keyColIdx); + + void encodeVector(common::ValueVector* vector, uint8_t* tuplePtr, uint32_t encodedTuples, + uint32_t numEntriesToEncode, uint32_t keyColIdx); + + void encodeFTIdx(uint32_t numEntriesToEncode, uint8_t* tupleInfoPtr); + + void allocateMemoryIfFull(); + + static void getEncodingFunction(common::PhysicalTypeID physicalType, encode_function_t& func); + +private: + storage::MemoryManager* memoryManager; + std::vector> keyBlocks; + std::vector isAscOrder; + uint32_t numBytesPerTuple; + uint32_t maxNumTuplesPerBlock; + uint32_t ftBlockIdx = 0; + // Since we encode 3 bytes for ftBlockOffset, the maxFTBlockOffset is 2^24 - 1. + static const uint32_t MAX_FT_BLOCK_OFFSET = (1ul << 24) - 1; + uint32_t ftBlockOffset = 0; + // We only encode 1 byte for ftIndex, this limits the maximum number of threads of our system to + // 256. + uint8_t ftIdx; + uint32_t numTuplesPerBlockInFT; + // We need to swap the encoded binary strings if we are using little endian hardware. 
+ bool swapBytes; + std::vector encodeFunctions; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/order_by/order_by_merge.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/order_by/order_by_merge.h new file mode 100644 index 0000000000..1c8acf174f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/order_by/order_by_merge.h @@ -0,0 +1,41 @@ +#pragma once + +#include "processor/operator/order_by/sort_state.h" +#include "processor/operator/physical_operator.h" +#include "processor/operator/sink.h" +#include "processor/result/result_set.h" + +namespace lbug { +namespace processor { + +class OrderByMerge final : public Sink { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::ORDER_BY_MERGE; + +public: + OrderByMerge(std::shared_ptr sharedState, + std::shared_ptr sharedDispatcher, uint32_t id, + std::unique_ptr printInfo) + : Sink{type_, id, printInfo->copy()}, sharedState{std::move(sharedState)}, + sharedDispatcher{std::move(sharedDispatcher)} {} + + bool isSource() const override { return true; } + + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; + + void executeInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return std::make_unique(sharedState, sharedDispatcher, id, printInfo->copy()); + } + +private: + void initGlobalStateInternal(ExecutionContext* context) override; + +private: + std::shared_ptr sharedState; + std::unique_ptr localMerger; + std::shared_ptr sharedDispatcher; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/order_by/order_by_scan.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/order_by/order_by_scan.h new file mode 100644 index 0000000000..5ed037ff9c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/order_by/order_by_scan.h @@ -0,0 +1,59 @@ +#pragma once + +#include "processor/operator/order_by/sort_state.h" +#include "processor/operator/physical_operator.h" + +namespace lbug { +namespace processor { + +struct OrderByScanLocalState { + std::vector vectorsToRead; + std::unique_ptr payloadScanner; + uint64_t numTuples = 0; + uint64_t numTuplesRead = 0; + + void init(std::vector& outVectorPos, SortSharedState& sharedState, + ResultSet& resultSet); + + // NOLINTNEXTLINE(readability-make-member-function-const): Updates vectorsToRead. + uint64_t scan() { + uint64_t tuplesRead = payloadScanner->scan(vectorsToRead); + numTuplesRead += tuplesRead; + return tuplesRead; + } +}; + +// To preserve the ordering of tuples, the orderByScan operator will only +// be executed in single-thread mode. +class OrderByScan final : public PhysicalOperator { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::ORDER_BY_SCAN; + +public: + OrderByScan(std::vector outVectorPos, std::shared_ptr sharedState, + uint32_t id, std::unique_ptr printInfo) + : PhysicalOperator{type_, id, std::move(printInfo)}, outVectorPos{std::move(outVectorPos)}, + localState{std::make_unique()}, + sharedState{std::move(sharedState)} {} + + bool isSource() const override { return true; } + // Ordered table should be scanned in single-thread mode. 
+ bool isParallel() const override { return false; } + + bool getNextTuplesInternal(ExecutionContext* context) override; + + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; + + std::unique_ptr copy() override { + return std::make_unique(outVectorPos, sharedState, id, printInfo->copy()); + } + + double getProgress(ExecutionContext* context) const override; + +private: + std::vector outVectorPos; + std::unique_ptr localState; + std::shared_ptr sharedState; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/order_by/radix_sort.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/order_by/radix_sort.h new file mode 100644 index 0000000000..3ca6cf6a76 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/order_by/radix_sort.h @@ -0,0 +1,73 @@ +#pragma once + +#include + +#include "processor/operator/order_by/key_block_merger.h" +#include "processor/operator/order_by/order_by_key_encoder.h" +#include "processor/result/factorized_table.h" + +namespace lbug { +namespace processor { + +struct TieRange { +public: + uint32_t startingTupleIdx; + uint32_t endingTupleIdx; + inline uint32_t getNumTuples() const { return endingTupleIdx - startingTupleIdx + 1; } + explicit TieRange(uint32_t startingTupleIdx, uint32_t endingTupleIdx) + : startingTupleIdx{startingTupleIdx}, endingTupleIdx{endingTupleIdx} {} +}; + +// RadixSort sorts a block of binary strings using the radixSort and quickSort (only for comparing +// string overflow pointers). The algorithm loops through each column of the orderByVectors. If it +// sees a column with string, which is variable length, it will call radixSort to sort the columns +// seen so far. If there are tie tuples, it will compare the overflow ptr of strings. For subsequent +// columns, the algorithm only calls radixSort on tie tuples. +class RadixSort { +public: + RadixSort(storage::MemoryManager* memoryManager, FactorizedTable& factorizedTable, + OrderByKeyEncoder& orderByKeyEncoder, std::vector strKeyColsInfo); + + void sortSingleKeyBlock(const DataBlock& keyBlock); + +private: + void radixSort(uint8_t* keyBlockPtr, uint32_t numTuplesToSort, uint32_t numBytesSorted, + uint32_t numBytesToSort); + + std::vector findTies(uint8_t* keyBlockPtr, uint32_t numTuplesToFindTies, + uint32_t numBytesToSort, uint32_t baseTupleIdx) const; + + void fillTmpTuplePtrSortingBlock(TieRange& keyBlockTie, uint8_t* keyBlockPtr); + + void reOrderKeyBlock(TieRange& keyBlockTie, uint8_t* keyBlockPtr); + + // Some ties can't be solved in quicksort, just add them to ties. + template + void findStringTies(TieRange& keyBlockTie, uint8_t* keyBlockPtr, std::queue& ties, + StrKeyColInfo& keyColInfo); + + void solveStringTies(TieRange& keyBlockTie, uint8_t* keyBlockPtr, std::queue& ties, + StrKeyColInfo& keyColInfo); + +private: + std::unique_ptr tmpSortingResultBlock; + // Since we do radix sort on each dataBlock at a time, the maxNumber of tuples in the dataBlock + // is: LARGE_PAGE_SIZE / numBytesPerTuple. + // The size of tmpTuplePtrSortingBlock should be larger than: + // sizeof(uint8_t*) * MaxNumOfTuplePointers=(LARGE_PAGE_SIZE / numBytesPerTuple). + // Since we know: numBytesPerTuple >= sizeof(uint8_t*) (note: we put the + // tupleIdx/FactorizedTableIdx at the end of each row in dataBlock), sizeof(uint8_t*) * + // MaxNumOfTuplePointers=(LARGE_PAGE_SIZE / numBytesPerTuple) <= LARGE_PAGE_SIZE. 
As a result, + // we only need one dataBlock to store the tuplePointers while solving the string ties. + std::unique_ptr tmpTuplePtrSortingBlock; + // FactorizedTable stores all columns in the tuples that will be sorted, including the order by + // key columns. RadixSort uses factorizedTable to access the full contents of the string columns + // when resolving ties. + FactorizedTable& factorizedTable; + std::vector strKeyColsInfo; + uint32_t numBytesPerTuple; + uint32_t numBytesToRadixSort; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/order_by/sort_state.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/order_by/sort_state.h new file mode 100644 index 0000000000..aa86a28b60 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/order_by/sort_state.h @@ -0,0 +1,92 @@ +#pragma once + +#include + +#include "processor/operator/order_by/radix_sort.h" +#include "processor/result/factorized_table.h" + +namespace lbug { +namespace processor { + +class SortSharedState { +public: + SortSharedState() : nextTableIdx{0}, numBytesPerTuple{0} { + sortedKeyBlocks = std::make_unique>>(); + } + + inline uint64_t getNumBytesPerTuple() const { return numBytesPerTuple; } + + inline std::vector& getStrKeyColInfo() { return strKeyColsInfo; } + + inline std::queue>* getSortedKeyBlocks() { + return sortedKeyBlocks.get(); + } + + void init(const OrderByDataInfo& orderByDataInfo); + + std::pair getLocalPayloadTable( + storage::MemoryManager& memoryManager, const FactorizedTableSchema& payloadTableSchema); + + void appendLocalSortedKeyBlock(const std::shared_ptr& mergedDataBlocks); + + void combineFTHasNoNullGuarantee(); + + std::vector getPayloadTables() const; + + inline MergedKeyBlocks* getMergedKeyBlock() const { + return sortedKeyBlocks->empty() ? 
nullptr : sortedKeyBlocks->front().get(); + } + +private: + std::mutex mtx; + std::vector> payloadTables; + uint8_t nextTableIdx; + std::unique_ptr>> sortedKeyBlocks; + uint32_t numBytesPerTuple; + std::vector strKeyColsInfo; +}; + +class SortLocalState { +public: + void init(const OrderByDataInfo& orderByDataInfo, SortSharedState& sharedState, + storage::MemoryManager* memoryManager); + + void append(const std::vector& keyVectors, + const std::vector& payloadVectors); + + void finalize(SortSharedState& sharedState); + +private: + std::unique_ptr orderByKeyEncoder; + std::unique_ptr radixSorter; + uint64_t globalIdx = UINT64_MAX; + FactorizedTable* payloadTable = nullptr; +}; + +class PayloadScanner { +public: + PayloadScanner(MergedKeyBlocks* keyBlockToScan, std::vector payloadTables, + uint64_t skipNumber = UINT64_MAX, uint64_t limitNumber = UINT64_MAX); + + uint64_t scan(std::vector vectorsToRead); + +private: + bool scanSingleTuple(std::vector vectorsToRead) const; + + void applyLimitOnResultVectors(std::vector vectorsToRead); + +private: + bool hasUnflatColInPayload; + uint32_t payloadIdxOffset; + std::vector colsToScan; + std::unique_ptr tuplesToRead; + std::unique_ptr blockPtrInfo; + MergedKeyBlocks* keyBlockToScan; + uint32_t nextTupleIdxToReadInMergedKeyBlock; + uint64_t endTuplesIdxToReadInMergedKeyBlock; + std::vector payloadTables; + uint64_t limitNumber; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/order_by/top_k.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/order_by/top_k.h new file mode 100644 index 0000000000..f5485e0778 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/order_by/top_k.h @@ -0,0 +1,203 @@ +#pragma once + +#include + +#include "binder/expression/expression.h" +#include "processor/operator/sink.h" +#include "sort_state.h" + +namespace lbug { +namespace processor { + +struct TopKPrintInfo final : OPPrintInfo { + binder::expression_vector keys; + binder::expression_vector payloads; + uint64_t skipNum; + uint64_t limitNum; + + TopKPrintInfo(binder::expression_vector keys, binder::expression_vector payloads, + uint64_t skipNum, uint64_t limitNum) + : keys(std::move(keys)), payloads(std::move(payloads)), skipNum(skipNum), + limitNum(limitNum) {} + + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::unique_ptr(new TopKPrintInfo(*this)); + } + +private: + TopKPrintInfo(const TopKPrintInfo& other) + : OPPrintInfo(other), keys(other.keys), payloads(other.payloads), skipNum(other.skipNum), + limitNum(other.limitNum) {} +}; + +class TopKSortState { +public: + TopKSortState(); + + void init(const OrderByDataInfo& orderByDataInfo, storage::MemoryManager* memoryManager); + + void append(const std::vector& keyVectors, + const std::vector& payloadVectors); + + void finalize(); + + inline uint64_t getNumTuples() const { return numTuples; } + + inline SortSharedState* getSharedState() { return orderBySharedState.get(); } + + std::unique_ptr getScanner(uint64_t skip, uint64_t limit) const { + return std::make_unique(orderBySharedState->getMergedKeyBlock(), + orderBySharedState->getPayloadTables(), skip, limit); + } + +private: + std::unique_ptr orderByLocalState; + std::unique_ptr orderBySharedState; + + uint64_t numTuples; + storage::MemoryManager* memoryManager; +}; + +class TopKBuffer { + using vector_select_comparison_func = std::function; + +public: + explicit TopKBuffer(const 
OrderByDataInfo& orderByDataInfo) + : orderByDataInfo{&orderByDataInfo}, skip{0}, limit{0}, memoryManager{nullptr}, + hasBoundaryValue{false} { + sortState = std::make_unique(); + } + + void init(storage::MemoryManager* memoryManager, uint64_t skipNumber, uint64_t limitNumber); + + void append(const std::vector& keyVectors, + const std::vector& payloadVectors); + + void reduce(); + + // NOLINTNEXTLINE(readability-make-member-function-const): Semantically non-const. + inline void finalize() { sortState->finalize(); } + + void merge(TopKBuffer* other); + + inline std::unique_ptr getScanner() const { + return sortState->getScanner(skip, limit); + } + +private: + void initVectors(); + + template + void getSelectComparisonFunction(common::PhysicalTypeID typeID, + vector_select_comparison_func& selectFunc); + + void initCompareFuncs(); + + void setBoundaryValue(); + + bool compareBoundaryValue(const std::vector& keyVectors); + + bool compareFlatKeys(common::idx_t vectorIdxToCompare, + const std::vector keyVectors); + + void compareUnflatKeys(common::idx_t vectorIdxToCompare, + const std::vector keyVectors); + + static void appendSelState(common::SelectionVector* selVector, + common::SelectionVector* selVectorToAppend); + +public: + const OrderByDataInfo* orderByDataInfo; + std::unique_ptr sortState; + uint64_t skip; + uint64_t limit; + storage::MemoryManager* memoryManager; + std::vector compareFuncs; + std::vector equalsFuncs; + bool hasBoundaryValue; + +private: + // Holds the ownership of all temp vectors. + std::vector> tmpVectors; + std::vector> boundaryVecs; + + std::vector payloadVecsToScan; + std::vector keyVecsToScan; + std::vector lastPayloadVecsToScan; + std::vector lastKeyVecsToScan; +}; + +class TopKLocalState { +public: + void init(const OrderByDataInfo& orderByDataInfo, storage::MemoryManager* memoryManager, + ResultSet& resultSet, uint64_t skipNumber, uint64_t limitNumber); + + void append(const std::vector& keyVectors, + const std::vector& payloadVectors); + + // NOLINTNEXTLINE(readability-make-member-function-const): Semantically non-const. + inline void finalize() { buffer->finalize(); } + + std::unique_ptr buffer; +}; + +class TopKSharedState { +public: + void init(const OrderByDataInfo& orderByDataInfo, storage::MemoryManager* memoryManager, + uint64_t skipNumber, uint64_t limitNumber) { + buffer = std::make_unique(orderByDataInfo); + buffer->init(memoryManager, skipNumber, limitNumber); + } + + void mergeLocalState(TopKLocalState* localState) { + std::unique_lock lck{mtx}; + buffer->merge(localState->buffer.get()); + } + + // NOLINTNEXTLINE(readability-make-member-function-const): Semantically non-const. 
+ inline void finalize() { buffer->finalize(); } + + std::unique_ptr buffer; + +private: + std::mutex mtx; +}; + +class TopK final : public Sink { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::TOP_K; + +public: + TopK(OrderByDataInfo info, std::shared_ptr sharedState, uint64_t skipNumber, + uint64_t limitNumber, std::unique_ptr child, uint32_t id, + std::unique_ptr printInfo) + : Sink{type_, std::move(child), id, std::move(printInfo)}, info(std::move(info)), + sharedState{std::move(sharedState)}, skipNumber{skipNumber}, limitNumber{limitNumber} {} + + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; + + void initGlobalStateInternal(ExecutionContext* context) override; + + void executeInternal(ExecutionContext* context) override; + + void finalize(ExecutionContext* /*context*/) override { sharedState->finalize(); } + + std::unique_ptr copy() override { + return std::make_unique(info.copy(), sharedState, skipNumber, limitNumber, + children[0]->copy(), id, printInfo->copy()); + } + +private: + OrderByDataInfo info; + TopKLocalState localState; + std::shared_ptr sharedState; + uint64_t skipNumber; + uint64_t limitNumber; + std::vector orderByVectors; + std::vector payloadVectors; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/order_by/top_k_scanner.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/order_by/top_k_scanner.h new file mode 100644 index 0000000000..7e70f7bfbd --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/order_by/top_k_scanner.h @@ -0,0 +1,47 @@ +#pragma once + +#include "top_k.h" + +namespace lbug { +namespace processor { + +struct TopKLocalScanState { + std::vector vectorsToScan; + std::unique_ptr payloadScanner; + + void init(std::vector& outVectorPos, TopKSharedState& sharedState, + ResultSet& resultSet); + + // NOLINTNEXTLINE(readability-make-member-function-const): Semantically non-const. + inline uint64_t scan() { return payloadScanner->scan(vectorsToScan); } +}; + +class TopKScan final : public PhysicalOperator { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::TOP_K_SCAN; + +public: + TopKScan(std::vector outVectorPos, std::shared_ptr sharedState, + physical_op_id id, std::unique_ptr printInfo) + : PhysicalOperator{type_, id, std::move(printInfo)}, outVectorPos{std::move(outVectorPos)}, + localState{std::make_unique()}, sharedState{std::move(sharedState)} {} + + bool isSource() const override { return true; } + // Ordered table should be scanned in single-thread mode. 
+ bool isParallel() const override { return false; } + + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; + + bool getNextTuplesInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return std::make_unique(outVectorPos, sharedState, id, printInfo->copy()); + } + +private: + std::vector outVectorPos; + std::unique_ptr localState; + std::shared_ptr sharedState; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/partitioner.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/partitioner.h new file mode 100644 index 0000000000..0c138f29b2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/partitioner.h @@ -0,0 +1,200 @@ +#pragma once + +#include "binder/expression/expression.h" +#include "common/enums/column_evaluate_type.h" +#include "expression_evaluator/expression_evaluator.h" +#include "processor/operator/base_partitioner_shared_state.h" +#include "processor/operator/sink.h" +#include "storage/table/in_mem_chunked_node_group_collection.h" + +namespace lbug { +namespace storage { +class MemoryManager; +} // namespace storage +namespace transaction { +class Transaction; +} +namespace processor { + +using partitioner_func_t = + std::function; + +struct PartitionerFunctions { + static void partitionRelData(common::ValueVector* key, common::ValueVector* partitionIdxes); +}; + +// Partitioner operator can duplicate and partition the same data chunk from child with multiple +// partitioning methods. For example, copy of rel tables require partitioning on both FWD and BWD +// direction. Each partitioning method corresponds to a PartitioningState. +struct PartitioningBuffer { + std::vector> partitions; + + void merge(const PartitioningBuffer& localPartitioningState) const; +}; + +// NOTE: Currently, Partitioner is tightly coupled with RelBatchInsert. We should generalize it +// later when necessary. Here, each partition is essentially a node group. +struct BatchInsertSharedState; +struct PartitioningInfo; +struct PartitionerDataInfo; +struct PartitionerInfo; +struct RelBatchInsertProgressSharedState; + +struct CopyPartitionerSharedState : public PartitionerSharedState { + std::mutex mtx; + storage::MemoryManager& mm; + + explicit CopyPartitionerSharedState(storage::MemoryManager& mm) : mm{mm} {} + + std::vector> partitioningBuffers; + + void initialize(const common::logical_type_vec_t& columnTypes, common::idx_t numPartitioners, + const main::ClientContext* clientContext) override; + + void resetState(common::idx_t partitioningIdx) override; + void merge(const std::vector>& localPartitioningStates); + + // Must only be called once for any given parameters. + // The data gets moved out of the shared state since some of it may be spilled to disk and will + // need to be freed after its processed. 
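+    // Illustrative usage (editorial addition, not part of the upstream lbug sources; names are
+    // hypothetical): a consumer such as the rel batch-insert path takes ownership of one partition
+    // per call and lets the returned unique_ptr release the memory once the partition is written:
+    //
+    //     auto partition = sharedState->getPartitionBuffer(partitioningIdx, partitionIdx);
+    //     writeNodeGroup(*partition);  // hypothetical consumer of the chunked node groups
+    //     // the buffer (including any data re-loaded from disk) is freed when `partition` dies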
+ std::unique_ptr getPartitionBuffer( + common::idx_t partitioningIdx, common::partition_idx_t partitionIdx) const { + KU_ASSERT(partitioningIdx < partitioningBuffers.size()); + KU_ASSERT(partitionIdx < partitioningBuffers[partitioningIdx]->partitions.size()); + + KU_ASSERT(partitioningBuffers[partitioningIdx]->partitions[partitionIdx].get()); + auto partitioningBuffer = + std::move(partitioningBuffers[partitioningIdx]->partitions[partitionIdx]); + // This may still run out of memory if there isn't enough space for one partitioningBuffer + // per thread + partitioningBuffer->loadFromDisk(mm); + return partitioningBuffer; + } +}; + +struct PartitionerLocalState { + std::vector> partitioningBuffers; + + PartitioningBuffer* getPartitioningBuffer(common::partition_idx_t partitioningIdx) const { + KU_ASSERT(partitioningIdx < partitioningBuffers.size()); + return partitioningBuffers[partitioningIdx].get(); + } +}; + +struct PartitioningInfo { + common::idx_t keyIdx; + partitioner_func_t partitionerFunc; + + PartitioningInfo(common::idx_t keyIdx, partitioner_func_t partitionerFunc) + : keyIdx{keyIdx}, partitionerFunc{std::move(partitionerFunc)} {} + EXPLICIT_COPY_DEFAULT_MOVE(PartitioningInfo); + +private: + PartitioningInfo(const PartitioningInfo& other) + : keyIdx{other.keyIdx}, partitionerFunc{other.partitionerFunc} {} +}; + +struct PartitionerDataInfo { + std::string tableName; + std::string fromTableName; + std::string toTableName; + std::vector columnTypes; + evaluator::evaluator_vector_t columnEvaluators; + std::vector evaluateTypes; + + PartitionerDataInfo(std::string tableName, std::string fromTableName, std::string toTableName, + std::vector columnTypes, + std::vector> columnEvaluators, + std::vector evaluateTypes) + : tableName{std::move(tableName)}, fromTableName{std::move(fromTableName)}, + toTableName{std::move(toTableName)}, columnTypes{std::move(columnTypes)}, + columnEvaluators{std::move(columnEvaluators)}, evaluateTypes{std::move(evaluateTypes)} {} + EXPLICIT_COPY_DEFAULT_MOVE(PartitionerDataInfo); + +private: + PartitionerDataInfo(const PartitionerDataInfo& other) + : tableName{other.tableName}, fromTableName{other.fromTableName}, + toTableName{other.toTableName}, columnTypes{common::LogicalType::copy(other.columnTypes)}, + columnEvaluators{copyVector(other.columnEvaluators)}, evaluateTypes{other.evaluateTypes} { + } +}; + +struct PartitionerInfo { + DataPos relOffsetDataPos; + std::vector infos; + + PartitionerInfo() {} + PartitionerInfo(const PartitionerInfo& other) : relOffsetDataPos{other.relOffsetDataPos} { + infos.reserve(other.infos.size()); + for (auto& otherInfo : other.infos) { + infos.push_back(otherInfo.copy()); + } + } + + EXPLICIT_COPY_DEFAULT_MOVE(PartitionerInfo); +}; + +struct PartitionerPrintInfo final : OPPrintInfo { + binder::expression_vector expressions; + + explicit PartitionerPrintInfo(binder::expression_vector expressions) + : expressions{std::move(expressions)} {} + + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::unique_ptr(new PartitionerPrintInfo(*this)); + } + +private: + PartitionerPrintInfo(const PartitionerPrintInfo& other) + : OPPrintInfo{other}, expressions{other.expressions} {} +}; + +class Partitioner final : public Sink { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::PARTITIONER; + +public: + Partitioner(PartitionerInfo info, PartitionerDataInfo dataInfo, + std::shared_ptr sharedState, + std::unique_ptr child, physical_op_id id, + std::unique_ptr printInfo); + + void 
initGlobalStateInternal(ExecutionContext* context) override; + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; + void executeInternal(ExecutionContext* context) override; + + std::shared_ptr getSharedState() { return sharedState; } + + std::unique_ptr copy() override { + return std::make_unique(info.copy(), dataInfo.copy(), sharedState, + children[0]->copy(), id, printInfo->copy()); + } + + static void initializePartitioningStates(const common::logical_type_vec_t& columnTypes, + std::vector>& partitioningBuffers, + const std::array& + numPartitions, + common::idx_t numPartitioners); + +private: + void evaluateExpressions(uint64_t numRels) const; + common::DataChunk constructDataChunk( + const std::shared_ptr& state) const; + // TODO: For now, RelBatchInsert will guarantee all data are inside one data chunk. Should be + // generalized to resultSet later if needed. + void copyDataToPartitions(storage::MemoryManager& memoryManager, + common::partition_idx_t partitioningIdx, const common::DataChunk& chunkToCopyFrom) const; + +private: + PartitionerDataInfo dataInfo; + PartitionerInfo info; + std::shared_ptr sharedState; + std::unique_ptr localState; + + // Intermediate temp value vector. + std::unique_ptr partitionIdxes; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/path_property_probe.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/path_property_probe.h new file mode 100644 index 0000000000..5d8dba65c6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/path_property_probe.h @@ -0,0 +1,122 @@ +#pragma once + +#include "common/enums/extend_direction.h" +#include "processor/operator/hash_join/hash_join_build.h" +#include "processor/operator/physical_operator.h" + +namespace lbug { +namespace processor { + +struct PathPropertyProbeSharedState { + std::shared_ptr nodeHashTableState; + std::shared_ptr relHashTableState; + + PathPropertyProbeSharedState(std::shared_ptr nodeHashTableState, + std::shared_ptr relHashTableState) + : nodeHashTableState{std::move(nodeHashTableState)}, + relHashTableState{std::move(relHashTableState)} {} +}; + +struct PathPropertyProbeLocalState { + std::unique_ptr hashes; + std::unique_ptr probedTuples; + std::unique_ptr matchedTuples; + + PathPropertyProbeLocalState() { + hashes = std::make_unique(common::DEFAULT_VECTOR_CAPACITY); + probedTuples = std::make_unique(common::DEFAULT_VECTOR_CAPACITY); + matchedTuples = std::make_unique(common::DEFAULT_VECTOR_CAPACITY); + } +}; + +struct PathPropertyProbeInfo { + DataPos pathPos = DataPos(); + + DataPos leftNodeIDPos = DataPos(); + DataPos rightNodeIDPos = DataPos(); + DataPos inputNodeIDsPos = DataPos(); + DataPos inputEdgeIDsPos = DataPos(); + DataPos directionPos = DataPos(); + + std::unordered_map tableIDToName; + + std::vector nodeFieldIndices; + std::vector relFieldIndices; + std::vector nodeTableColumnIndices; + std::vector relTableColumnIndices; + + common::ExtendDirection extendDirection = common::ExtendDirection::FWD; + bool extendFromLeft = false; + + PathPropertyProbeInfo() = default; + EXPLICIT_COPY_DEFAULT_MOVE(PathPropertyProbeInfo); + +private: + PathPropertyProbeInfo(const PathPropertyProbeInfo& other) { + pathPos = other.pathPos; + leftNodeIDPos = other.leftNodeIDPos; + rightNodeIDPos = other.rightNodeIDPos; + inputNodeIDsPos = other.inputNodeIDsPos; + inputEdgeIDsPos = other.inputEdgeIDsPos; + directionPos = other.directionPos; + 
tableIDToName = other.tableIDToName; + nodeFieldIndices = other.nodeFieldIndices; + relFieldIndices = other.relFieldIndices; + nodeTableColumnIndices = other.nodeTableColumnIndices; + relTableColumnIndices = other.relTableColumnIndices; + extendDirection = other.extendDirection; + extendFromLeft = other.extendFromLeft; + } +}; + +class PathPropertyProbe : public PhysicalOperator { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::PATH_PROPERTY_PROBE; + +public: + PathPropertyProbe(PathPropertyProbeInfo info, + std::shared_ptr sharedState, + std::unique_ptr probeChild, uint32_t id, + std::unique_ptr printInfo) + : PhysicalOperator{type_, std::move(probeChild), id, std::move(printInfo)}, + info{std::move(info)}, sharedState{std::move(sharedState)} {} + + void initLocalStateInternal(ResultSet* resultSet_, ExecutionContext* context) final; + + bool getNextTuplesInternal(ExecutionContext* context) final; + + std::unique_ptr copy() final { + return std::make_unique(info.copy(), sharedState, children[0]->copy(), + id, printInfo->copy()); + } + +private: + void probe(JoinHashTable* hashTable, uint64_t sizeProbed, uint64_t sizeToProbe, + common::ValueVector* idVector, const std::vector& propertyVectors, + const std::vector& colIndicesToScan) const; + +private: + PathPropertyProbeInfo info; + std::shared_ptr sharedState; + PathPropertyProbeLocalState localState; + + common::ValueVector* pathNodesVector = nullptr; + common::ValueVector* pathRelsVector = nullptr; + common::ValueVector* pathNodeIDsDataVector = nullptr; + common::ValueVector* pathNodeLabelsDataVector = nullptr; + common::ValueVector* pathRelIDsDataVector = nullptr; + common::ValueVector* pathRelLabelsDataVector = nullptr; + common::ValueVector* pathSrcNodeIDsDataVector = nullptr; + common::ValueVector* pathDstNodeIDsDataVector = nullptr; + + std::vector pathNodesPropertyDataVectors; + std::vector pathRelsPropertyDataVectors; + + common::ValueVector* inputLeftNodeIDVector = nullptr; + common::ValueVector* inputRightNodeIDVector = nullptr; + common::ValueVector* inputNodeIDsVector = nullptr; + common::ValueVector* inputRelIDsVector = nullptr; + common::ValueVector* inputDirectionVector = nullptr; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/batch_insert.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/batch_insert.h new file mode 100644 index 0000000000..9d14fa3eaf --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/batch_insert.h @@ -0,0 +1,121 @@ +#pragma once + +#include "main/query_result/materialized_query_result.h" +#include "processor/operator/sink.h" +#include "processor/result/factorized_table.h" +#include "storage/page_allocator.h" +#include "storage/table/table.h" + +namespace lbug { +namespace storage { +class MemoryManager; +class ChunkedNodeGroup; +} // namespace storage +namespace processor { + +struct BatchInsertInfo { + std::string tableName; + bool compressionEnabled = true; + + std::vector warningColumnTypes; + // column types include property and warning + std::vector columnTypes; + std::vector insertColumnIDs; + std::vector outputDataColumns; + std::vector warningDataColumns; + + BatchInsertInfo(std::string tableName, std::vector warningColumnTypes) + : tableName{std::move(tableName)}, warningColumnTypes{std::move(warningColumnTypes)} {} + BatchInsertInfo(const BatchInsertInfo& other) + : tableName{other.tableName}, 
compressionEnabled{other.compressionEnabled}, + warningColumnTypes{copyVector(other.warningColumnTypes)}, + columnTypes{copyVector(other.columnTypes)}, insertColumnIDs{other.insertColumnIDs}, + outputDataColumns{other.outputDataColumns}, warningDataColumns{other.warningDataColumns} { + } + DELETE_COPY_ASSN(BatchInsertInfo); + virtual ~BatchInsertInfo() = default; + + virtual std::unique_ptr copy() const = 0; + + template + TARGET* ptrCast() { + return common::ku_dynamic_cast(this); + } +}; + +struct LBUG_API BatchInsertSharedState { + std::mutex mtx; + std::atomic numRows; + + // Use a separate mutex for numErroredRows to avoid double-locking in local error handlers + // As access to numErroredRows is independent of access to other shared state + std::mutex erroredRowMutex; + std::shared_ptr numErroredRows; + + storage::Table* table; + std::shared_ptr fTable; + + explicit BatchInsertSharedState(std::shared_ptr fTable) + : numRows{0}, numErroredRows(std::make_shared(0)), table{nullptr}, + fTable{std::move(fTable)} {}; + BatchInsertSharedState(const BatchInsertSharedState& other) = delete; + + virtual ~BatchInsertSharedState() = default; + + void incrementNumRows(common::row_idx_t numRowsToIncrement) { + numRows.fetch_add(numRowsToIncrement); + } + common::row_idx_t getNumRows() const { return numRows.load(); } + common::row_idx_t getNumErroredRows() { + common::UniqLock lockGuard{erroredRowMutex}; + return *numErroredRows; + } + + template + TARGET* ptrCast() { + return common::ku_dynamic_cast(this); + } +}; + +struct BatchInsertLocalState { + std::unique_ptr chunkedGroup; + storage::PageAllocator* optimisticAllocator = nullptr; + + virtual ~BatchInsertLocalState() = default; + + template + TARGET* ptrCast() { + return common::ku_dynamic_cast(this); + } +}; + +class LBUG_API BatchInsert : public Sink { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::BATCH_INSERT; + +public: + BatchInsert(std::unique_ptr info, + std::shared_ptr sharedState, physical_op_id id, + std::unique_ptr printInfo) + : Sink{type_, id, std::move(printInfo)}, info{std::move(info)}, + sharedState{std::move(sharedState)} {} + + ~BatchInsert() override = default; + + std::shared_ptr getResultFTable() const override { + return sharedState->fTable; + } + + std::unique_ptr getQueryResult() const override { + return std::make_unique(sharedState->fTable); + } + + std::unique_ptr copy() override = 0; + +protected: + std::unique_ptr info; + std::shared_ptr sharedState; + std::unique_ptr localState; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/batch_insert_error_handler.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/batch_insert_error_handler.h new file mode 100644 index 0000000000..90dfef69d6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/batch_insert_error_handler.h @@ -0,0 +1,53 @@ +#pragma once + +#include + +#include "common/types/types.h" +#include "processor/execution_context.h" +#include "processor/operator/persistent/reader/copy_from_error.h" + +namespace lbug { +namespace processor { +struct BatchInsertCachedError { + explicit BatchInsertCachedError(std::string message, + const std::optional& warningData = {}); + BatchInsertCachedError() = default; + + std::string message; + + // CSV Reader data + std::optional warningData; +}; + +class BatchInsertErrorHandler { +public: + BatchInsertErrorHandler(ExecutionContext* context, bool 
ignoreErrors, + std::shared_ptr sharedErrorCounter = nullptr, + std::mutex* sharedErrorCounterMtx = nullptr); + + void handleError(std::string message, const std::optional& warningData = {}); + + void handleError(BatchInsertCachedError error); + + void flushStoredErrors(); + bool getIgnoreErrors() const; + +private: + common::row_idx_t getNumErrors() const; + void addNewVectorsIfNeeded(); + void clearErrors(); + + static constexpr uint64_t LOCAL_WARNING_LIMIT = 1024; + + bool ignoreErrors; + uint64_t warningLimit; + ExecutionContext* context; + uint64_t currentInsertIdx; + + std::mutex* sharedErrorCounterMtx; + std::shared_ptr sharedErrorCounter; + + std::vector cachedErrors; +}; +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/copy_rel_batch_insert.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/copy_rel_batch_insert.h new file mode 100644 index 0000000000..d931cdd229 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/copy_rel_batch_insert.h @@ -0,0 +1,48 @@ +#pragma once + +#include "processor/operator/persistent/rel_batch_insert.h" + +namespace lbug { +namespace storage { +class CSRNodeGroup; +struct InMemChunkedCSRHeader; +} // namespace storage + +namespace processor { + +struct CopyRelBatchInsertExecutionState : RelBatchInsertExecutionState { + std::unique_ptr partitioningBuffer; +}; + +class CopyRelBatchInsert final : public RelBatchInsertImpl { +public: + std::unique_ptr copy() override { + return std::make_unique(*this); + } + + std::unique_ptr initExecutionState( + const PartitionerSharedState& partitionerSharedState, const RelBatchInsertInfo& relInfo, + common::node_group_idx_t nodeGroupIdx) override; + + void populateCSRLengths(RelBatchInsertExecutionState& executionState, + storage::InMemChunkedCSRHeader& csrHeader, common::offset_t numNodes, + const RelBatchInsertInfo& relInfo) override; + + void finalizeStartCSROffsets(RelBatchInsertExecutionState& executionState, + storage::InMemChunkedCSRHeader& csrHeader, const RelBatchInsertInfo& relInfo) override; + + void writeToTable(RelBatchInsertExecutionState& executionState, + const storage::InMemChunkedCSRHeader& csrHeader, const RelBatchInsertLocalState& localState, + BatchInsertSharedState& sharedState, const RelBatchInsertInfo& relInfo) override; + +private: + static void setRowIdxFromCSROffsets(storage::ColumnChunkData& rowIdxChunk, + storage::ColumnChunkData& csrOffsetChunk); + + static void populateCSRLengthsInternal(const storage::InMemChunkedCSRHeader& csrHeader, + common::offset_t numNodes, storage::InMemChunkedNodeGroupCollection& partition, + common::column_id_t boundNodeOffsetColumn); +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/copy_to.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/copy_to.h new file mode 100644 index 0000000000..2619147067 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/copy_to.h @@ -0,0 +1,87 @@ +#pragma once + +#include + +#include "function/export/export_function.h" +#include "processor/operator/sink.h" +#include "processor/result/result_set.h" + +namespace lbug { +namespace processor { + +struct CopyToInfo { + function::ExportFunction exportFunc; + std::unique_ptr bindData; + std::vector inputVectorPoses; + std::vector isFlatVec; + + CopyToInfo(function::ExportFunction 
exportFunc, + std::unique_ptr bindData, + std::vector inputVectorPoses, std::vector isFlatVec) + : exportFunc{std::move(exportFunc)}, bindData{std::move(bindData)}, + inputVectorPoses{std::move(inputVectorPoses)}, isFlatVec{std::move(isFlatVec)} {} + + CopyToInfo copy() const { + return CopyToInfo{exportFunc, bindData->copy(), inputVectorPoses, isFlatVec}; + } +}; + +struct CopyToLocalState { + std::unique_ptr exportFuncLocalState; + std::vector> inputVectors; +}; + +struct CopyToPrintInfo final : OPPrintInfo { + std::vector columnNames; + std::string fileName; + + CopyToPrintInfo(std::vector columnNames, std::string fileName) + : columnNames{std::move(columnNames)}, fileName{std::move(fileName)} {} + + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::unique_ptr(new CopyToPrintInfo(*this)); + } + +private: + CopyToPrintInfo(const CopyToPrintInfo& other) + : OPPrintInfo{other}, columnNames{other.columnNames}, fileName{other.fileName} {} +}; + +class CopyTo final : public Sink { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::COPY_TO; + +public: + CopyTo(CopyToInfo info, std::shared_ptr sharedState, + std::unique_ptr child, uint32_t id, + std::unique_ptr printInfo) + : Sink{type_, std::move(child), id, std::move(printInfo)}, info{std::move(info)}, + sharedState{std::move(sharedState)} {} + + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; + + void initGlobalStateInternal(ExecutionContext* context) override; + + void finalize(ExecutionContext* context) override; + + void executeInternal(ExecutionContext* context) override; + + std::pair&> getParallelFlag() { + return {std::filesystem::path(info.bindData->fileName).filename().string(), + sharedState->parallelFlag}; + } + + std::unique_ptr copy() override { + return std::make_unique(info.copy(), sharedState, children[0]->copy(), id, + printInfo->copy()); + } + +private: + CopyToInfo info; + CopyToLocalState localState; + std::shared_ptr sharedState; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/delete.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/delete.h new file mode 100644 index 0000000000..664a250008 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/delete.h @@ -0,0 +1,95 @@ +#pragma once + +#include "delete_executor.h" +#include "processor/operator/physical_operator.h" + +namespace lbug { +namespace processor { + +struct DeleteNodePrintInfo final : OPPrintInfo { + binder::expression_vector expressions; + common::DeleteNodeType deleteType; + + DeleteNodePrintInfo(binder::expression_vector expressions, common::DeleteNodeType deleteType) + : expressions{std::move(expressions)}, deleteType{deleteType} {} + + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::unique_ptr(new DeleteNodePrintInfo(*this)); + } + +private: + DeleteNodePrintInfo(const DeleteNodePrintInfo& other) + : OPPrintInfo{other}, expressions{other.expressions}, deleteType{other.deleteType} {} +}; + +class DeleteNode final : public PhysicalOperator { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::DELETE_; + +public: + DeleteNode(std::vector> executors, + std::unique_ptr child, uint32_t id, + std::unique_ptr printInfo) + : PhysicalOperator{type_, std::move(child), id, std::move(printInfo)}, + executors{std::move(executors)} {} + + bool 
isParallel() const override { return false; } + + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; + + bool getNextTuplesInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return std::make_unique(copyVector(executors), children[0]->copy(), id, + printInfo->copy()); + } + +private: + std::vector> executors; +}; + +struct DeleteRelPrintInfo final : OPPrintInfo { + binder::expression_vector expressions; + + explicit DeleteRelPrintInfo(binder::expression_vector expressions) + : expressions{std::move(expressions)} {} + + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::unique_ptr(new DeleteRelPrintInfo(*this)); + } + +private: + DeleteRelPrintInfo(const DeleteRelPrintInfo& other) + : OPPrintInfo{other}, expressions{other.expressions} {} +}; + +class DeleteRel final : public PhysicalOperator { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::DELETE_; + +public: + DeleteRel(std::vector> executors, + std::unique_ptr child, uint32_t id, + std::unique_ptr printInfo) + : PhysicalOperator{type_, std::move(child), id, std::move(printInfo)}, + executors{std::move(executors)} {} + + bool isParallel() const override { return false; } + + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; + + bool getNextTuplesInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return std::make_unique(copyVector(executors), children[0]->copy(), id, + printInfo->copy()); + } + +private: + std::vector> executors; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/delete_executor.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/delete_executor.h new file mode 100644 index 0000000000..4319c53f43 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/delete_executor.h @@ -0,0 +1,220 @@ +#pragma once + +#include +#include + +#include "common/enums/delete_type.h" +#include "common/vector/value_vector.h" +#include "processor/result/result_set.h" +#include "storage/table/node_table.h" +#include "storage/table/rel_table.h" + +namespace lbug { +namespace processor { + +struct NodeDeleteInfo { + common::DeleteNodeType deleteType; + DataPos nodeIDPos; + + common::ValueVector* nodeIDVector = nullptr; + + NodeDeleteInfo(common::DeleteNodeType deleteType, const DataPos& nodeIDPos) + : deleteType{deleteType}, nodeIDPos{nodeIDPos} {}; + EXPLICIT_COPY_DEFAULT_MOVE(NodeDeleteInfo); + + void init(const ResultSet& resultSet); + +private: + NodeDeleteInfo(const NodeDeleteInfo& other) + : deleteType{other.deleteType}, nodeIDPos{other.nodeIDPos} {} +}; + +struct NodeTableDeleteInfo { + storage::NodeTable* table; + std::unordered_set fwdRelTables; + std::unordered_set bwdRelTables; + DataPos pkPos; + + common::ValueVector* pkVector; + + NodeTableDeleteInfo(storage::NodeTable* table, + std::unordered_set fwdRelTables, + std::unordered_set bwdRelTables, const DataPos& pkPos) + : table{table}, fwdRelTables{std::move(fwdRelTables)}, + bwdRelTables{std::move(bwdRelTables)}, pkPos{pkPos}, pkVector{nullptr} {}; + EXPLICIT_COPY_DEFAULT_MOVE(NodeTableDeleteInfo); + + void init(const ResultSet& resultSet); + + void deleteFromRelTable(transaction::Transaction* transaction, + common::ValueVector* nodeIDVector) const; + void detachDeleteFromRelTable(transaction::Transaction* transaction, + 
storage::RelTableDeleteState* detachDeleteState) const; + +private: + NodeTableDeleteInfo(const NodeTableDeleteInfo& other) + : table{other.table}, fwdRelTables{other.fwdRelTables}, bwdRelTables{other.bwdRelTables}, + pkPos{other.pkPos}, pkVector{nullptr} {} +}; + +class NodeDeleteExecutor { +public: + explicit NodeDeleteExecutor(NodeDeleteInfo info) : info{std::move(info)} {} + NodeDeleteExecutor(const NodeDeleteExecutor& other) : info{other.info.copy()} {} + virtual ~NodeDeleteExecutor() = default; + + virtual void init(ResultSet* resultSet, ExecutionContext* context); + + virtual void delete_(ExecutionContext* context) = 0; + + virtual std::unique_ptr copy() const = 0; + +protected: + NodeDeleteInfo info; + std::unique_ptr dstNodeIDVector; + std::unique_ptr relIDVector; + std::unique_ptr detachDeleteState; +}; + +// Handle MATCH (n) (DETACH)? DELETE n +class EmptyNodeDeleteExecutor final : public NodeDeleteExecutor { +public: + explicit EmptyNodeDeleteExecutor(NodeDeleteInfo info) : NodeDeleteExecutor{std::move(info)} {} + EmptyNodeDeleteExecutor(const EmptyNodeDeleteExecutor& other) : NodeDeleteExecutor{other} {} + + void delete_(ExecutionContext*) override {} + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } +}; + +class SingleLabelNodeDeleteExecutor final : public NodeDeleteExecutor { +public: + SingleLabelNodeDeleteExecutor(NodeDeleteInfo info, NodeTableDeleteInfo tableInfo) + : NodeDeleteExecutor(std::move(info)), tableInfo{std::move(tableInfo)} {} + SingleLabelNodeDeleteExecutor(const SingleLabelNodeDeleteExecutor& other) + : NodeDeleteExecutor(other), tableInfo{other.tableInfo.copy()} {} + + void init(ResultSet* resultSet, ExecutionContext*) override; + void delete_(ExecutionContext* context) override; + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } + +private: + NodeTableDeleteInfo tableInfo; +}; + +class MultiLabelNodeDeleteExecutor final : public NodeDeleteExecutor { +public: + MultiLabelNodeDeleteExecutor(NodeDeleteInfo info, + common::table_id_map_t tableInfos) + : NodeDeleteExecutor(std::move(info)), tableInfos{std::move(tableInfos)} {} + MultiLabelNodeDeleteExecutor(const MultiLabelNodeDeleteExecutor& other) + : NodeDeleteExecutor(other), tableInfos{copyUnorderedMap(other.tableInfos)} {} + + void init(ResultSet* resultSet, ExecutionContext*) override; + void delete_(ExecutionContext* context) override; + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } + +private: + common::table_id_map_t tableInfos; +}; + +struct RelDeleteInfo { + DataPos srcNodeIDPos; + DataPos dstNodeIDPos; + DataPos relIDPos; + + common::ValueVector* srcNodeIDVector = nullptr; + common::ValueVector* dstNodeIDVector = nullptr; + common::ValueVector* relIDVector = nullptr; + + RelDeleteInfo() + : srcNodeIDPos{INVALID_DATA_CHUNK_POS, INVALID_VALUE_VECTOR_POS}, + dstNodeIDPos{INVALID_DATA_CHUNK_POS, INVALID_VALUE_VECTOR_POS}, + relIDPos{INVALID_DATA_CHUNK_POS, INVALID_VALUE_VECTOR_POS} {} + RelDeleteInfo(DataPos srcNodeIDPos, DataPos dstNodeIDPos, DataPos relIDPos) + : srcNodeIDPos{srcNodeIDPos}, dstNodeIDPos{dstNodeIDPos}, relIDPos{relIDPos} {} + EXPLICIT_COPY_DEFAULT_MOVE(RelDeleteInfo); + + void init(const ResultSet& resultSet); + +private: + RelDeleteInfo(const RelDeleteInfo& other) + : srcNodeIDPos{other.srcNodeIDPos}, dstNodeIDPos{other.dstNodeIDPos}, + relIDPos{other.relIDPos} {} +}; + +class RelDeleteExecutor { +public: + explicit RelDeleteExecutor(RelDeleteInfo info) : info{std::move(info)} {} + 
RelDeleteExecutor(const RelDeleteExecutor& other) : info{other.info.copy()} {} + virtual ~RelDeleteExecutor() = default; + + virtual void init(ResultSet* resultSet, ExecutionContext* context); + + virtual void delete_(ExecutionContext* context) = 0; + + virtual std::unique_ptr copy() const = 0; + +protected: + RelDeleteInfo info; +}; + +class EmptyRelDeleteExecutor final : public RelDeleteExecutor { +public: + explicit EmptyRelDeleteExecutor() : RelDeleteExecutor{RelDeleteInfo{}} {} + EmptyRelDeleteExecutor(const EmptyRelDeleteExecutor& other) : RelDeleteExecutor{other} {} + + void init(ResultSet*, ExecutionContext*) override {} + + void delete_(ExecutionContext*) override {} + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } +}; + +class SingleLabelRelDeleteExecutor final : public RelDeleteExecutor { +public: + SingleLabelRelDeleteExecutor(storage::RelTable* table, RelDeleteInfo info) + : RelDeleteExecutor(std::move(info)), table{table} {} + SingleLabelRelDeleteExecutor(const SingleLabelRelDeleteExecutor& other) + : RelDeleteExecutor{other}, table{other.table} {} + + void delete_(ExecutionContext* context) override; + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } + +private: + storage::RelTable* table; +}; + +class MultiLabelRelDeleteExecutor final : public RelDeleteExecutor { +public: + MultiLabelRelDeleteExecutor(common::table_id_map_t tableIDToTableMap, + RelDeleteInfo info) + : RelDeleteExecutor(std::move(info)), tableIDToTableMap{std::move(tableIDToTableMap)} {} + MultiLabelRelDeleteExecutor(const MultiLabelRelDeleteExecutor& other) + : RelDeleteExecutor{other}, tableIDToTableMap{other.tableIDToTableMap} {} + + void delete_(ExecutionContext* context) override; + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } + +private: + common::table_id_map_t tableIDToTableMap; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/index_builder.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/index_builder.h new file mode 100644 index 0000000000..b1286dc5c0 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/index_builder.h @@ -0,0 +1,210 @@ +#pragma once + +#include +#include + +#include "common/copy_constructors.h" +#include "common/mpsc_queue.h" +#include "common/static_vector.h" +#include "common/types/int128_t.h" +#include "common/types/types.h" +#include "common/types/uint128_t.h" +#include "processor/operator/persistent/node_batch_insert_error_handler.h" +#include "storage/index/hash_index.h" +#include "storage/index/hash_index_utils.h" +#include "storage/table/column_chunk_data.h" + +namespace lbug { +namespace transaction { +class Transaction; +}; +namespace storage { +class NodeTable; +}; +namespace processor { + +constexpr size_t SHOULD_FLUSH_QUEUE_SIZE = 32; + +constexpr size_t WARNING_DATA_BUFFER_SIZE = 64; +using OptionalWarningDataBuffer = + std::unique_ptr>; + +using OptionalWarningSourceData = std::optional; + +template +struct IndexBufferWithWarningData { + storage::IndexBuffer indexBuffer; + OptionalWarningDataBuffer warningDataBuffer; + + bool full() const; + void append(T key, common::offset_t value, OptionalWarningSourceData&& warningData); +}; + +class IndexBuilderGlobalQueues { +public: + explicit IndexBuilderGlobalQueues(transaction::Transaction* transaction, + storage::NodeTable* nodeTable); + + template + void 
insert(size_t index, IndexBufferWithWarningData elem, + NodeBatchInsertErrorHandler& errorHandler) { + auto& typedQueues = std::get>(queues).array; + typedQueues[index].push(std::move(elem)); + if (typedQueues[index].approxSize() < SHOULD_FLUSH_QUEUE_SIZE) { + return; + } + maybeConsumeIndex(index, errorHandler); + } + + void consume(NodeBatchInsertErrorHandler& errorHandler); + + common::PhysicalTypeID pkTypeID() const; + +private: + void maybeConsumeIndex(size_t index, NodeBatchInsertErrorHandler& errorHandler); + + storage::NodeTable* nodeTable; + + template + // NOLINTNEXTLINE (cppcoreguidelines-pro-type-member-init) + struct Queue { + std::array>, storage::NUM_HASH_INDEXES> + array; + // Type information to help std::visit. Value is not used + T type; + }; + + // Queues for distributing primary keys. + std::variant, Queue, Queue, Queue, Queue, + Queue, Queue, Queue, Queue, Queue, + Queue, Queue, Queue> + queues; + transaction::Transaction* transaction; +}; + +class IndexBuilderLocalBuffers { +public: + explicit IndexBuilderLocalBuffers(IndexBuilderGlobalQueues& globalQueues); + + void insert(std::string key, common::offset_t value, OptionalWarningSourceData&& warningData, + NodeBatchInsertErrorHandler& errorHandler) { + auto indexPos = storage::HashIndexUtils::getHashIndexPosition(std::string_view(key)); + auto& stringBuffer = (*std::get>(buffers))[indexPos]; + + if (stringBuffer.full()) { + // StaticVector's move constructor leaves the original vector valid and empty + globalQueues->insert(indexPos, std::move(stringBuffer), errorHandler); + } + + // moving the buffer clears it which is the expected behaviour + // NOLINTNEXTLINE (bugprone-use-after-move) + stringBuffer.append(std::move(key), value, std::move(warningData)); + } + + template + void insert(T key, common::offset_t value, OptionalWarningSourceData&& warningData, + NodeBatchInsertErrorHandler& errorHandler) { + auto indexPos = storage::HashIndexUtils::getHashIndexPosition(key); + auto& buffer = (*std::get>(buffers))[indexPos]; + + if (buffer.full()) { + globalQueues->insert(indexPos, std::move(buffer), errorHandler); + } + + // moving the buffer clears it which is the expected behaviour + // NOLINTNEXTLINE (bugprone-use-after-move) + buffer.append(key, value, std::move(warningData)); + } + + void flush(NodeBatchInsertErrorHandler& errorHandler); + +private: + IndexBuilderGlobalQueues* globalQueues; + + // These arrays are much too large to be inline. + template + using Buffers = std::array, storage::NUM_HASH_INDEXES>; + template + using UniqueBuffers = std::unique_ptr>; + std::variant, UniqueBuffers, UniqueBuffers, + UniqueBuffers, UniqueBuffers, UniqueBuffers, + UniqueBuffers, UniqueBuffers, UniqueBuffers, + UniqueBuffers, UniqueBuffers, UniqueBuffers, + UniqueBuffers> + buffers; +}; + +class IndexBuilderSharedState { + friend class IndexBuilder; + +public: + explicit IndexBuilderSharedState(transaction::Transaction* transaction, + storage::NodeTable* nodeTable) + : globalQueues{transaction, nodeTable}, nodeTable(nodeTable) {} + void consume(NodeBatchInsertErrorHandler& errorHandler) { + return globalQueues.consume(errorHandler); + } + + void addProducer() { producers.fetch_add(1, std::memory_order_relaxed); } + void quitProducer(); + bool isDone() const { return done.load(std::memory_order_relaxed); } + +private: + IndexBuilderGlobalQueues globalQueues; + storage::NodeTable* nodeTable; + + std::atomic producers; + std::atomic done; +}; + +// RAII for producer counting. 
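// A standalone sketch of the same RAII idea used by ProducerToken below. The implementation of
// IndexBuilderSharedState::quitProducer() is not part of this hunk, so the "last producer marks
// the shared state as done" step here is an assumption; all names in the sketch are illustrative.
#include <atomic>
#include <cstddef>
#include <memory>
#include <utility>

struct ProducerState {
    std::atomic<std::size_t> producers{0};
    std::atomic<bool> done{false};
    void addProducer() { producers.fetch_add(1, std::memory_order_relaxed); }
    void quitProducer() {
        // assumed: the last producer to quit flags the shared queues as fully produced
        if (producers.fetch_sub(1, std::memory_order_acq_rel) == 1) {
            done.store(true, std::memory_order_release);
        }
    }
};

class ScopedProducer {
public:
    explicit ScopedProducer(std::shared_ptr<ProducerState> state) : state{std::move(state)} {
        this->state->addProducer();
    }
    ScopedProducer(ScopedProducer&&) noexcept = default;
    ScopedProducer(const ScopedProducer&) = delete;
    ScopedProducer& operator=(const ScopedProducer&) = delete;
    ~ScopedProducer() {
        if (state) {
            state->quitProducer();  // unregister exactly once, even on early exit or exception
        }
    }

private:
    std::shared_ptr<ProducerState> state;
};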
+class ProducerToken { +public: + explicit ProducerToken(std::shared_ptr sharedState) + : sharedState(std::move(sharedState)) { + this->sharedState->addProducer(); + } + DELETE_COPY_DEFAULT_MOVE(ProducerToken); + + void quit() { + sharedState->quitProducer(); + sharedState.reset(); + } + ~ProducerToken() { + if (sharedState) { + quit(); + } + } + +private: + std::shared_ptr sharedState; +}; + +class IndexBuilder { + +public: + DELETE_COPY_DEFAULT_MOVE(IndexBuilder); + explicit IndexBuilder(std::shared_ptr sharedState); + + IndexBuilder clone() { return IndexBuilder(sharedState); } + + void insert(const storage::ColumnChunkData& chunk, + const std::vector& warningData, common::offset_t nodeOffset, + common::offset_t numNodes, NodeBatchInsertErrorHandler& errorHandler); + + ProducerToken getProducerToken() const { return ProducerToken(sharedState); } + + void finishedProducing(NodeBatchInsertErrorHandler& errorHandler); + void finalize(ExecutionContext* context, NodeBatchInsertErrorHandler& errorHandler); + +private: + bool checkNonNullConstraint(const storage::ColumnChunkData& chunk, + const std::vector& warningData, common::offset_t nodeOffset, + common::offset_t chunkOffset, NodeBatchInsertErrorHandler& errorHandler); + std::shared_ptr sharedState; + + IndexBuilderLocalBuffers localBuffers; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/insert.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/insert.h new file mode 100644 index 0000000000..475eac12bb --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/insert.h @@ -0,0 +1,53 @@ +#pragma once + +#include "insert_executor.h" +#include "processor/operator/physical_operator.h" + +namespace lbug { +namespace processor { + +struct InsertPrintInfo final : OPPrintInfo { + binder::expression_vector expressions; + common::ConflictAction action; + + InsertPrintInfo(binder::expression_vector expressions, common::ConflictAction action) + : expressions(std::move(expressions)), action(action) {} + + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::unique_ptr(new InsertPrintInfo(*this)); + } + +private: + InsertPrintInfo(const InsertPrintInfo& other) + : OPPrintInfo(other), expressions(other.expressions), action(other.action) {} +}; + +class Insert final : public PhysicalOperator { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::INSERT; + +public: + Insert(std::vector nodeExecutors, + std::vector relExecutors, std::unique_ptr child, + uint32_t id, std::unique_ptr printInfo) + : PhysicalOperator{type_, std::move(child), id, std::move(printInfo)}, + nodeExecutors{std::move(nodeExecutors)}, relExecutors{std::move(relExecutors)} {} + + bool isParallel() const override { return false; } + + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; + + bool getNextTuplesInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return std::make_unique(copyVector(nodeExecutors), copyVector(relExecutors), + children[0]->copy(), id, printInfo->copy()); + } + +private: + std::vector nodeExecutors; + std::vector relExecutors; +}; +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/insert_executor.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/insert_executor.h new file 
mode 100644 index 0000000000..3b351ba1a0 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/insert_executor.h @@ -0,0 +1,148 @@ +#pragma once + +#include "common/enums/conflict_action.h" +#include "expression_evaluator/expression_evaluator.h" +#include "processor/execution_context.h" +#include "storage/table/node_table.h" +#include "storage/table/rel_table.h" + +namespace lbug { +namespace processor { + +// Operator level info +struct NodeInsertInfo { + DataPos nodeIDPos; + // Column vector pos is invalid if it doesn't need to be projected. + std::vector columnsPos; + common::ConflictAction conflictAction; + + common::ValueVector* nodeIDVector = nullptr; + std::vector columnVectors; + + NodeInsertInfo(DataPos nodeIDPos, std::vector columnsPos, + common::ConflictAction conflictAction) + : nodeIDPos{nodeIDPos}, columnsPos{std::move(columnsPos)}, conflictAction{conflictAction} {} + EXPLICIT_COPY_DEFAULT_MOVE(NodeInsertInfo); + + void init(const ResultSet& resultSet); + + void updateNodeID(common::nodeID_t nodeID) const; + common::nodeID_t getNodeID() const; + +private: + NodeInsertInfo(const NodeInsertInfo& other) + : nodeIDPos{other.nodeIDPos}, columnsPos{other.columnsPos}, + conflictAction{other.conflictAction} {} +}; + +// Table level info +struct NodeTableInsertInfo { + storage::NodeTable* table; + evaluator::evaluator_vector_t columnDataEvaluators; + + common::ValueVector* pkVector; + std::vector columnDataVectors; + + NodeTableInsertInfo(storage::NodeTable* table, + evaluator::evaluator_vector_t columnDataEvaluators) + : table{table}, columnDataEvaluators{std::move(columnDataEvaluators)}, pkVector{nullptr} {} + EXPLICIT_COPY_DEFAULT_MOVE(NodeTableInsertInfo); + + void init(const ResultSet& resultSet, main::ClientContext* context); + +private: + NodeTableInsertInfo(const NodeTableInsertInfo& other) + : table{other.table}, columnDataEvaluators{copyVector(other.columnDataEvaluators)}, + pkVector{nullptr} {} +}; + +class NodeInsertExecutor { +public: + NodeInsertExecutor(NodeInsertInfo info, NodeTableInsertInfo tableInfo) + : info{std::move(info)}, tableInfo{std::move(tableInfo)} {} + EXPLICIT_COPY_DEFAULT_MOVE(NodeInsertExecutor); + + void init(ResultSet* resultSet, const ExecutionContext* context); + + void setNodeIDVectorToNonNull() const; + common::nodeID_t insert(main::ClientContext* context); + + // For MERGE, we might need to skip the insert for duplicate input. But still, we need to write + // the output vector for later usage. 
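// (skipInsert() just below is the hook for that duplicate case.) Editorial aside: a hypothetical
// sketch of the insert-or-skip control flow MERGE needs. The enumerator names and the Table
// interface used here are illustrative and not taken from this patch.
#include <cstdint>
#include <optional>
#include <stdexcept>

enum class SketchConflictAction { DO_THROW, DO_NOTHING };

template<typename Table, typename Key>
uint64_t insertOrSkip(Table& table, const Key& key, SketchConflictAction action) {
    if (std::optional<uint64_t> existing = table.lookup(key)) {
        // primary key conflict detected
        if (action == SketchConflictAction::DO_NOTHING) {
            return *existing;  // skip the physical insert, but still yield an ID for downstream ops
        }
        throw std::runtime_error("duplicate primary key");
    }
    return table.insert(key);  // normal append path
}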
+ void skipInsert() const; + +private: + NodeInsertExecutor(const NodeInsertExecutor& other) + : info{other.info.copy()}, tableInfo{other.tableInfo.copy()} {} + + bool checkConflict(const transaction::Transaction* transaction) const; + +private: + NodeInsertInfo info; + NodeTableInsertInfo tableInfo; +}; + +struct RelInsertInfo { + DataPos srcNodeIDPos; + DataPos dstNodeIDPos; + std::vector columnsPos; + + common::ValueVector* srcNodeIDVector; + common::ValueVector* dstNodeIDVector; + std::vector columnVectors; + + RelInsertInfo(DataPos srcNodeIDPos, DataPos dstNodeIDPos, std::vector columnsPos) + : srcNodeIDPos{srcNodeIDPos}, dstNodeIDPos{dstNodeIDPos}, columnsPos{std::move(columnsPos)}, + srcNodeIDVector{nullptr}, dstNodeIDVector{nullptr} {} + EXPLICIT_COPY_DEFAULT_MOVE(RelInsertInfo); + + void init(const ResultSet& resultSet); + +private: + RelInsertInfo(const RelInsertInfo& other) + : srcNodeIDPos{other.srcNodeIDPos}, dstNodeIDPos{other.dstNodeIDPos}, + columnsPos{other.columnsPos}, srcNodeIDVector{nullptr}, dstNodeIDVector{nullptr} {} +}; + +struct RelTableInsertInfo { + storage::RelTable* table; + evaluator::evaluator_vector_t columnDataEvaluators; + + std::vector columnDataVectors; + + RelTableInsertInfo(storage::RelTable* table, evaluator::evaluator_vector_t evaluators) + : table{table}, columnDataEvaluators{std::move(evaluators)} {} + EXPLICIT_COPY_DEFAULT_MOVE(RelTableInsertInfo); + + void init(const ResultSet& resultSet, main::ClientContext* context); + common::internalID_t getRelID() const; + +private: + RelTableInsertInfo(const RelTableInsertInfo& other) + : table{other.table}, columnDataEvaluators(copyVector(other.columnDataEvaluators)) {} +}; + +class RelInsertExecutor { +public: + RelInsertExecutor(RelInsertInfo info, RelTableInsertInfo tableInfo) + : info{std::move(info)}, tableInfo{std::move(tableInfo)} {} + EXPLICIT_COPY_DEFAULT_MOVE(RelInsertExecutor); + + void init(ResultSet* resultSet, const ExecutionContext* context); + + common::internalID_t insert(main::ClientContext* context); + + // See comment in NodeInsertExecutor. 
+ void skipInsert() const; + +private: + RelInsertExecutor(const RelInsertExecutor& other) + : info{other.info.copy()}, tableInfo{other.tableInfo.copy()} {} + +private: + RelInsertInfo info; + RelTableInsertInfo tableInfo; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/merge.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/merge.h new file mode 100644 index 0000000000..7ff8d5d2fa --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/merge.h @@ -0,0 +1,123 @@ +#pragma once + +#include "insert_executor.h" +#include "processor/operator/physical_operator.h" +#include "processor/result/pattern_creation_info_table.h" +#include "set_executor.h" + +namespace lbug { +namespace processor { + +struct MergeInfo { + std::vector> keyEvaluators; + FactorizedTableSchema tableSchema; + common::executor_info executorInfo; + DataPos existenceMark; + + MergeInfo(std::vector> keyEvaluators, + FactorizedTableSchema tableSchema, common::executor_info executorInfo, + DataPos existenceMark) + : keyEvaluators{std::move(keyEvaluators)}, tableSchema{std::move(tableSchema)}, + executorInfo{std::move(executorInfo)}, existenceMark{existenceMark} {} + EXPLICIT_COPY_DEFAULT_MOVE(MergeInfo); + +private: + MergeInfo(const MergeInfo& other) + : keyEvaluators{copyVector(other.keyEvaluators)}, tableSchema{other.tableSchema.copy()}, + executorInfo{other.executorInfo}, existenceMark{other.existenceMark} {} +}; + +struct MergePrintInfo final : OPPrintInfo { + binder::expression_vector pattern; + std::vector onCreate; + std::vector onMatch; + + MergePrintInfo(binder::expression_vector pattern, std::vector onCreate, + std::vector onMatch) + : pattern(std::move(pattern)), onCreate(std::move(onCreate)), onMatch(std::move(onMatch)) {} + + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::unique_ptr(new MergePrintInfo(*this)); + } + +private: + MergePrintInfo(const MergePrintInfo& other) + : OPPrintInfo(other), pattern(other.pattern), onCreate(other.onCreate), + onMatch(other.onMatch) {} +}; + +struct MergeLocalState { + std::vector keyVectors; + std::unique_ptr hashTable; + common::ValueVector* existenceVector = nullptr; + + void init(ResultSet& resultSet, main::ClientContext* context, MergeInfo& info); + + bool patternExists() const; + + PatternCreationInfo getPatternCreationInfo() const { + return hashTable->getPatternCreationInfo(keyVectors); + } +}; + +class Merge final : public PhysicalOperator { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::MERGE; + +public: + Merge(std::vector nodeInsertExecutors, + std::vector relInsertExecutors, + std::vector> onCreateNodeSetExecutors, + std::vector> onCreateRelSetExecutors, + std::vector> onMatchNodeSetExecutors, + std::vector> onMatchRelSetExecutors, MergeInfo info, + std::unique_ptr child, uint32_t id, + std::unique_ptr printInfo) + : PhysicalOperator{type_, std::move(child), id, std::move(printInfo)}, + nodeInsertExecutors{std::move(nodeInsertExecutors)}, + relInsertExecutors{std::move(relInsertExecutors)}, + onCreateNodeSetExecutors{std::move(onCreateNodeSetExecutors)}, + onCreateRelSetExecutors{std::move(onCreateRelSetExecutors)}, + onMatchNodeSetExecutors{std::move(onMatchNodeSetExecutors)}, + onMatchRelSetExecutors{std::move(onMatchRelSetExecutors)}, info{std::move(info)} {} + + bool isParallel() const override { return false; } + + void 
initLocalStateInternal(ResultSet* resultSet_, ExecutionContext* context) override; + + bool getNextTuplesInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return std::make_unique(copyVector(nodeInsertExecutors), + copyVector(relInsertExecutors), copyVector(onCreateNodeSetExecutors), + copyVector(onCreateRelSetExecutors), copyVector(onMatchNodeSetExecutors), + copyVector(onMatchRelSetExecutors), info.copy(), children[0]->copy(), id, + printInfo->copy()); + } + +private: + void executeOnMatch(ExecutionContext* context); + + void executeOnCreatedPattern(PatternCreationInfo& info, ExecutionContext* context); + + void executeOnNewPattern(PatternCreationInfo& info, ExecutionContext* context); + + void executeNoMatch(ExecutionContext* context); + +private: + std::vector nodeInsertExecutors; + std::vector relInsertExecutors; + + std::vector> onCreateNodeSetExecutors; + std::vector> onCreateRelSetExecutors; + + std::vector> onMatchNodeSetExecutors; + std::vector> onMatchRelSetExecutors; + + MergeInfo info; + MergeLocalState localState; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/node_batch_insert.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/node_batch_insert.h new file mode 100644 index 0000000000..45ad93c37c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/node_batch_insert.h @@ -0,0 +1,144 @@ +#pragma once + +#include "common/enums/column_evaluate_type.h" +#include "common/types/types.h" +#include "expression_evaluator/expression_evaluator.h" +#include "processor/operator/persistent/batch_insert.h" +#include "processor/operator/persistent/index_builder.h" +#include "storage/stats/table_stats.h" +#include "storage/table/chunked_node_group.h" + +namespace lbug { +namespace storage { +class MemoryManager; +} +namespace transaction { +class Transaction; +} // namespace transaction + +namespace processor { +struct ExecutionContext; + +struct NodeBatchInsertPrintInfo final : OPPrintInfo { + std::string tableName; + + explicit NodeBatchInsertPrintInfo(std::string tableName) : tableName(std::move(tableName)) {} + + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::unique_ptr(new NodeBatchInsertPrintInfo(*this)); + } + +private: + NodeBatchInsertPrintInfo(const NodeBatchInsertPrintInfo& other) + : OPPrintInfo(other), tableName(other.tableName) {} +}; + +struct NodeBatchInsertInfo final : BatchInsertInfo { + evaluator::evaluator_vector_t columnEvaluators; + std::vector evaluateTypes; + + NodeBatchInsertInfo(std::string tableName, std::vector warningColumnTypes, + std::vector> columnEvaluators, + std::vector evaluateTypes) + : BatchInsertInfo{std::move(tableName), std::move(warningColumnTypes)}, + columnEvaluators{std::move(columnEvaluators)}, evaluateTypes{std::move(evaluateTypes)} {} + + NodeBatchInsertInfo(const NodeBatchInsertInfo& other) + : BatchInsertInfo{other}, columnEvaluators{copyVector(other.columnEvaluators)}, + evaluateTypes{other.evaluateTypes} {} + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } +}; + +struct NodeBatchInsertSharedState final : BatchInsertSharedState { + // Primary key info + common::column_id_t pkColumnID; + common::LogicalType pkType; + std::optional globalIndexBuilder; + + function::TableFuncSharedState* tableFuncSharedState; + + std::vector mainDataColumns; + + // The sharedNodeGroup is 
to accumulate left data within local node groups in NodeBatchInsert + // ops. + std::unique_ptr sharedNodeGroup; + + explicit NodeBatchInsertSharedState(std::shared_ptr fTable) + : BatchInsertSharedState{std::move(fTable)}, pkColumnID{0}, + globalIndexBuilder(std::nullopt), tableFuncSharedState{nullptr}, + sharedNodeGroup{nullptr} {} + + void initPKIndex(const ExecutionContext* context); +}; + +struct NodeBatchInsertLocalState final : BatchInsertLocalState { + std::optional errorHandler; + + std::optional localIndexBuilder; + + std::shared_ptr columnState; + std::vector columnVectors; + + storage::TableStats stats; + + explicit NodeBatchInsertLocalState(std::span outputDataTypes) + : stats{outputDataTypes} {} +}; + +class NodeBatchInsert final : public BatchInsert { +public: + NodeBatchInsert(std::unique_ptr info, + std::shared_ptr sharedState, + std::unique_ptr child, physical_op_id id, + std::unique_ptr printInfo) + : BatchInsert{std::move(info), std::move(sharedState), id, std::move(printInfo)} { + children.push_back(std::move(child)); + } + + void initGlobalStateInternal(ExecutionContext* context) override; + + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; + + void executeInternal(ExecutionContext* context) override; + + void finalize(ExecutionContext* context) override; + void finalizeInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return std::make_unique(info->copy(), sharedState, children[0]->copy(), id, + printInfo->copy()); + } + + // The node group will be reset so that the only values remaining are the ones which were + // not written + void writeAndResetNodeGroup(transaction::Transaction* transaction, + std::unique_ptr& nodeGroup, + std::optional& indexBuilder, storage::MemoryManager* mm, + storage::PageAllocator& pageAllocator) const; + +private: + void evaluateExpressions(uint64_t numTuples) const; + void appendIncompleteNodeGroup(transaction::Transaction* transaction, + std::unique_ptr localNodeGroup, + std::optional& indexBuilder, storage::MemoryManager* mm) const; + void clearToIndex(storage::MemoryManager* mm, + std::unique_ptr& nodeGroup, + common::offset_t startIndexInGroup) const; + + void copyToNodeGroup(transaction::Transaction* transaction, storage::MemoryManager* mm) const; + + NodeBatchInsertErrorHandler createErrorHandler(ExecutionContext* context) const; + + void writeAndResetNodeGroup(transaction::Transaction* transaction, + std::unique_ptr& nodeGroup, + std::optional& indexBuilder, storage::MemoryManager* mm, + NodeBatchInsertErrorHandler& errorHandler, storage::PageAllocator& pageAllocator) const; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/node_batch_insert_error_handler.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/node_batch_insert_error_handler.h new file mode 100644 index 0000000000..d54a411e3b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/node_batch_insert_error_handler.h @@ -0,0 +1,61 @@ +#pragma once + +#include "common/types/types.h" +#include "common/vector/value_vector.h" +#include "processor/execution_context.h" +#include "processor/operator/persistent/batch_insert_error_handler.h" + +namespace lbug { +namespace storage { +class NodeTable; +} + +namespace processor { +template +struct IndexBuilderError { + std::string message; + T key; + common::nodeID_t nodeID; + + // CSV Reader data + 
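// (The warningData member just below carries that per-row CSV provenance: which block the row
// came from and its offset inside the block.) Editorial aside: a simplified sketch of how such
// provenance can be mapped back to a source line number when the warning is reported. This
// mirrors what SharedFileErrorHandler::getLineNumber() appears to do, but that implementation is
// not part of this hunk, so treat the arithmetic below as an assumption; names are illustrative.
#include <cstdint>
#include <vector>

struct BlockLineCount {
    uint64_t numLines = 0;
    bool doneParsing = false;
};

// Returns a 1-based line number, or 0 if some earlier block has not been fully counted yet.
inline uint64_t lineNumberFor(uint64_t blockIdx, uint64_t offsetInBlock, uint64_t headerNumRows,
    const std::vector<BlockLineCount>& linesPerBlock) {
    uint64_t line = headerNumRows + offsetInBlock + 1;  // +1 because line numbers are 1-based
    for (uint64_t i = 0; i < blockIdx; i++) {
        if (i >= linesPerBlock.size() || !linesPerBlock[i].doneParsing) {
            return 0;  // not resolvable yet; the caller would cache the error and retry later
        }
        line += linesPerBlock[i].numLines;
    }
    return line;
}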
std::optional warningData; +}; + +class NodeBatchInsertErrorHandler { +public: + NodeBatchInsertErrorHandler(ExecutionContext* context, common::LogicalTypeID pkType, + storage::NodeTable* nodeTable, bool ignoreErrors, + std::shared_ptr sharedErrorCounter, std::mutex* sharedErrorCounterMtx); + + template + void handleError(IndexBuilderError error) { + baseErrorHandler.handleError(std::move(error.message), std::move(error.warningData)); + + setCurrentErroneousRow(error.key, error.nodeID); + deleteCurrentErroneousRow(); + } + + void flushStoredErrors(); + +private: + template + void setCurrentErroneousRow(const T& key, common::nodeID_t nodeID) { + keyVector->setValue(0, key); + offsetVector->setValue(0, nodeID); + } + + void deleteCurrentErroneousRow(); + + static constexpr common::idx_t DELETE_VECTOR_SIZE = 1; + + storage::NodeTable* nodeTable; + ExecutionContext* context; + + // vectors that are reused by each deletion + std::shared_ptr keyVector; + std::shared_ptr offsetVector; + + BatchInsertErrorHandler baseErrorHandler; +}; +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/copy_from_error.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/copy_from_error.h new file mode 100644 index 0000000000..0ea3a51369 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/copy_from_error.h @@ -0,0 +1,115 @@ +#pragma once + +#include +#include +#include + +#include "common/api.h" +#include "common/constants.h" +#include "common/type_utils.h" +#include "common/types/types.h" +#include "common/types/value/value.h" + +namespace lbug { +namespace common { +class ValueVector; +} +namespace storage { +class ColumnChunkData; +} + +namespace processor { + +template +concept DataSource = + std::same_as || std::same_as; + +struct LBUG_API WarningSourceData { + // we should stick to integral types here as each value essentially adds a column to the output + // when reading from a file + using DataType = std::variant; + + static constexpr size_t BLOCK_IDX_IDX = 0; + static constexpr size_t OFFSET_IN_BLOCK_IDX = 1; + static constexpr size_t NUM_BLOCK_VALUES = 2; + + WarningSourceData() : WarningSourceData(0) {} + explicit WarningSourceData(uint64_t numSourceSpecificValues); + + template + void dumpTo(uint64_t& blockIdx, uint32_t& offsetInBlock, Types&... vars) const; + + template + static WarningSourceData constructFrom(uint64_t blockIdx, uint32_t offsetInBlock, + Types... 
newValues); + + uint64_t getBlockIdx() const; + uint32_t getOffsetInBlock() const; + + template + static WarningSourceData constructFromData(const std::vector& chunks, common::idx_t pos); + + std::array values; + uint64_t numValues; +}; + +struct LineContext { + uint64_t startByteOffset; + uint64_t endByteOffset; + + bool isCompleteLine; + + void setNewLine(uint64_t start); + void setEndOfLine(uint64_t end); +}; + +// If parsing in parallel during parsing we may not be able to determine line numbers +// Thus we have additional fields that can be used to determine line numbers + reconstruct lines +// After parsing this will be used to populate a PopulatedCopyFromError instance +struct LBUG_API CopyFromFileError { + CopyFromFileError(std::string message, WarningSourceData warningData, bool completedLine = true, + bool mustThrow = false); + + std::string message; + bool completedLine; + WarningSourceData warningData; + + bool mustThrow; + + bool operator<(const CopyFromFileError& o) const; +}; + +struct PopulatedCopyFromError { + std::string message; + std::string filePath; + std::string skippedLineOrRecord; + uint64_t lineNumber; +}; + +template +void WarningSourceData::dumpTo(uint64_t& blockIdx, uint32_t& offsetInBlock, Types&... vars) const { + static_assert(sizeof...(Types) + NUM_BLOCK_VALUES <= std::tuple_size_v); + KU_ASSERT(sizeof...(Types) + NUM_BLOCK_VALUES == numValues); + common::TypeUtils::paramPackForEach( + [this](auto idx, auto& value) { + value = std::get>(values[idx]); + }, + blockIdx, offsetInBlock, vars...); +} + +template +WarningSourceData WarningSourceData::constructFrom(uint64_t blockIdx, uint32_t offsetInBlock, + Types... newValues) { + static_assert(sizeof...(Types) + NUM_BLOCK_VALUES <= std::tuple_size_v, + "For performance reasons the number of warning metadata columns has a " + "statically-defined limit, modify " + "'common::CopyConstants::WARNING_DATA_MAX_NUM_COLUMNS' if you wish to increase it."); + + WarningSourceData ret{sizeof...(Types) + NUM_BLOCK_VALUES}; + common::TypeUtils::paramPackForEach([&ret](auto idx, auto value) { ret.values[idx] = value; }, + blockIdx, offsetInBlock, newValues...); + return ret; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/csv/base_csv_reader.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/csv/base_csv_reader.h new file mode 100644 index 0000000000..57e63753c5 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/csv/base_csv_reader.h @@ -0,0 +1,154 @@ +#pragma once + +#include +#include +#include +#include + +#include "common/copier_config/csv_reader_config.h" +#include "common/data_chunk/data_chunk.h" +#include "common/file_system/file_info.h" +#include "common/types/types.h" +#include "processor/operator/persistent/reader/copy_from_error.h" + +namespace lbug { +namespace common { +struct FileScanInfo; +} +namespace main { +class ClientContext; +} + +namespace processor { +class LocalFileErrorHandler; +class SharedFileErrorHandler; + +struct CSVColumnInfo { + uint64_t numColumns; + std::vector columnSkips; + common::column_id_t numWarningDataColumns; + + CSVColumnInfo(uint64_t numColumns, std::vector columnSkips, + common::column_id_t numWarningDataColumns) + : numColumns{numColumns}, columnSkips{std::move(columnSkips)}, + numWarningDataColumns(numWarningDataColumns) {} + EXPLICIT_COPY_DEFAULT_MOVE(CSVColumnInfo); + +private: + CSVColumnInfo(const 
CSVColumnInfo& other) + : numColumns{other.numColumns}, columnSkips{other.columnSkips}, + numWarningDataColumns(other.numWarningDataColumns) {} +}; + +class BaseCSVReader { + friend class ParsingDriver; + friend class SniffCSVNameAndTypeDriver; + +public: + // 1st element is number of successfully parsed rows + // 2nd element is number of failed to parse rows + using parse_result_t = std::pair; + + BaseCSVReader(const std::string& filePath, common::idx_t fileIdx, common::CSVOption option, + CSVColumnInfo columnInfo, main::ClientContext* context, + LocalFileErrorHandler* errorHandler); + + virtual ~BaseCSVReader() = default; + + virtual uint64_t parseBlock(common::block_idx_t blockIdx, common::DataChunk& resultChunk) = 0; + + main::ClientContext* getClientContext() const { return context; } + const common::CSVOption& getCSVOption() const { return option; } + + uint64_t getNumColumns() const { return columnInfo.numColumns; } + bool skipColumn(common::idx_t idx) const { + KU_ASSERT(idx < columnInfo.columnSkips.size()); + return columnInfo.columnSkips[idx]; + } + bool isEOF() const; + uint64_t getFileSize(); + // Get the file offset of the current buffer position. + uint64_t getFileOffset() const; + + std::string reconstructLine(uint64_t startPosition, uint64_t endPosition, bool completedLine); + + static common::column_id_t appendWarningDataColumns(std::vector& resultColumnNames, + std::vector& resultColumnTypes, + const common::FileScanInfo& fileScanInfo); + + static PopulatedCopyFromError basePopulateErrorFunc(CopyFromFileError error, + const SharedFileErrorHandler* sharedErrorHandler, BaseCSVReader* reader, + std::string filePath); + + static common::idx_t getFileIdxFunc(const CopyFromFileError& error); + +protected: + template + bool addValue(Driver&, uint64_t rowNum, common::column_id_t columnIdx, std::string_view strVal, + std::vector& escapePositions); + + //! Read BOM and header. + parse_result_t handleFirstBlock(); + + //! If this finds a BOM, it advances `position`. + void readBOM(); + parse_result_t readHeader(); + //! Reads a new buffer from the CSV file. + //! Uses the start value to ensure the current value stays within the buffer. + //! Modifies the start value to point to the new start of the current value. + //! If start is NULL, none of the buffer is kept. + //! Returns false if the file has been exhausted. + bool readBuffer(uint64_t* start); + + //! Like ReadBuffer, but only reads if position >= bufferSize. + //! If this returns true, buffer[position] is a valid character that we can read. 
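// (maybeReadBuffer(), defined next, is the guard that implements this contract: it only refills
// when position has caught up with bufferSize.) Editorial aside: a standalone sketch of that
// buffered-scan pattern, where the refill keeps the bytes of the value currently being parsed by
// moving them to the front of the buffer. Names are illustrative; the real reader pulls from a
// FileInfo handle rather than an in-memory source.
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <vector>

struct BufferedScanner {
    std::vector<char> buffer = std::vector<char>(4096);
    uint64_t bufferSize = 0;  // number of valid bytes currently in the buffer
    uint64_t position = 0;    // next byte to consume

    const char* src = nullptr;  // stand-in data source
    uint64_t srcSize = 0;
    uint64_t srcOffset = 0;

    // Refill the buffer. If start is non-null, keep buffer[*start..bufferSize) at the front so
    // the in-progress value survives the refill, and rewrite *start to its new offset (0).
    bool refill(uint64_t* start) {
        uint64_t keep = (start != nullptr && *start < bufferSize) ? bufferSize - *start : 0;
        if (keep > 0) {
            std::memmove(buffer.data(), buffer.data() + *start, keep);
        }
        if (start != nullptr) {
            *start = 0;
        }
        uint64_t toRead = std::min<uint64_t>(buffer.size() - keep, srcSize - srcOffset);
        if (toRead > 0) {
            std::memcpy(buffer.data() + keep, src + srcOffset, toRead);
        }
        srcOffset += toRead;
        bufferSize = keep + toRead;
        position = keep;      // resume scanning right after the preserved bytes
        return toRead > 0;    // false once the source is exhausted
    }

    bool maybeRefill(uint64_t* start) { return position < bufferSize || refill(start); }
};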
+ inline bool maybeReadBuffer(uint64_t* start) { + return position < bufferSize || readBuffer(start); + } + + void handleCopyException(const std::string& message, bool mustThrow = false); + + template + parse_result_t parseCSV(Driver&); + + inline bool isNewLine(char c) { return c == '\n' || c == '\r'; } + +protected: + virtual bool handleQuotedNewline() = 0; + + void skipCurrentLine(); + + void resetNumRowsInCurrentBlock(); + void increaseNumRowsInCurrentBlock(uint64_t numRows, uint64_t numErrors); + uint64_t getNumRowsInCurrentBlock() const; + uint32_t getRowOffsetInCurrentBlock() const; + + WarningSourceData getWarningSourceData() const; + +protected: + main::ClientContext* context; + common::CSVOption option; + CSVColumnInfo columnInfo; + std::unique_ptr fileInfo; + + common::block_idx_t currentBlockIdx; + uint64_t numRowsInCurrentBlock; + + uint64_t curRowIdx; + uint64_t numErrors; + + std::unique_ptr buffer; + uint64_t bufferIdx; + std::atomic bufferSize; + std::atomic position; + LineContext lineContext; + std::atomic osFileOffset; + common::idx_t fileIdx; + + LocalFileErrorHandler* errorHandler; + + bool rowEmpty = false; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/csv/dialect_detection.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/csv/dialect_detection.h new file mode 100644 index 0000000000..bf325b5c97 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/csv/dialect_detection.h @@ -0,0 +1,27 @@ +#pragma once + +#include + +#include "common/copier_config/csv_reader_config.h" + +namespace lbug { +namespace processor { + +struct DialectOption { + char delimiter = ','; + char quoteChar = '"'; + char escapeChar = '"'; + bool everQuoted = false; + bool everEscaped = false; + bool doDialectDetection = true; + + DialectOption() = default; + DialectOption(char delim, char quote, char escape) + : delimiter(delim), quoteChar(quote), escapeChar(escape), everQuoted(false), + everEscaped(false), doDialectDetection(true) {} +}; + +std::vector generateDialectOptions(const common::CSVOption& option); + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/csv/driver.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/csv/driver.h new file mode 100644 index 0000000000..133eb02afb --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/csv/driver.h @@ -0,0 +1,157 @@ +#pragma once + +#include +#include + +#include "common/data_chunk/data_chunk.h" +#include "function/table/bind_input.h" +#include "processor/operator/persistent/reader/copy_from_error.h" + +namespace lbug { +namespace main { +class ClientContext; +} + +namespace processor { + +// TODO(Keenan): Split up this file. +class BaseCSVReader; + +// Driver type identifications. 
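// (DriverType just below enumerates the parsing and sniffing passes these drivers implement.)
// Editorial aside: dialect detection, as supported by DialectOption, generateDialectOptions() and
// the per-row column counts tracked by SniffCSVDialectDriver, boils down to trying candidate
// (delimiter, quote, escape) combinations and keeping the one that yields the most consistent
// column count across sample rows. The scoring below is a simplified illustration, not the
// library's exact rule, and it ignores quoting while splitting, which the real driver does not.
#include <cstddef>
#include <string>
#include <vector>

struct CandidateDialect {
    char delimiter;
    char quoteChar;
    char escapeChar;
};

inline CandidateDialect pickDialect(const std::vector<std::string>& sampleLines,
    const std::vector<CandidateDialect>& candidates) {
    CandidateDialect best = candidates.front();  // assumes at least one candidate
    std::size_t bestScore = 0;
    for (const auto& cand : candidates) {
        std::size_t firstRowColumns = 0;
        std::size_t consistentRows = 0;
        for (std::size_t row = 0; row < sampleLines.size(); row++) {
            std::size_t columns = 1;
            for (char c : sampleLines[row]) {
                if (c == cand.delimiter) {
                    columns++;
                }
            }
            if (row == 0) {
                firstRowColumns = columns;
            }
            if (columns == firstRowColumns && firstRowColumns > 1) {
                consistentRows++;  // this row agrees with the first row's column count
            }
        }
        if (consistentRows > bestScore) {
            bestScore = consistentRows;
            best = cand;
        }
    }
    return best;
}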
+enum class DriverType { + PARSING, + PARALLEL, + SERIAL, + SNIFF_CSV_DIALECT, + SNIFF_CSV_NAME_AND_TYPE, + SNIFF_CSV_HEADER, + HEADER, + SKIP_ROW +}; + +struct WarningDataWithColumnInfo { + WarningDataWithColumnInfo(const WarningSourceData& warningSourceData, + uint64_t warningDataStartColumnIdx) + : warningDataStartColumnIdx(warningDataStartColumnIdx), data(warningSourceData) {} + + uint64_t warningDataStartColumnIdx; + WarningSourceData data; +}; + +class ParsingDriver { +public: + explicit ParsingDriver(common::DataChunk& chunk, DriverType type = DriverType::PARSING); + virtual ~ParsingDriver() = default; + + bool done(uint64_t rowNum); + virtual bool addValue(uint64_t rowNum, common::column_id_t columnIdx, std::string_view value); + virtual bool addRow(uint64_t rowNum, common::column_id_t columnCount, + std::optional warningData); + +public: + const DriverType driverType; + +private: + virtual bool doneEarly() = 0; + virtual BaseCSVReader* getReader() = 0; + +private: + common::DataChunk& chunk; + +protected: + bool rowEmpty; +}; + +class ParallelCSVReader; + +class ParallelParsingDriver : public ParsingDriver { +public: + ParallelParsingDriver(common::DataChunk& chunk, ParallelCSVReader* reader); + bool doneEarly() override; + +private: + BaseCSVReader* getReader() override; + +private: + ParallelCSVReader* reader; +}; + +class SerialCSVReader; + +class SerialParsingDriver : public ParsingDriver { +public: + SerialParsingDriver(common::DataChunk& chunk, SerialCSVReader* reader, + DriverType type = DriverType::SERIAL); + bool doneEarly() override; + +private: + BaseCSVReader* getReader() override; + +protected: + SerialCSVReader* reader; +}; + +class SniffCSVDialectDriver : public SerialParsingDriver { +public: + explicit SniffCSVDialectDriver(SerialCSVReader* reader); + + bool done(uint64_t rowNum) const; + bool addValue(uint64_t rowNum, common::column_id_t columnIdx, std::string_view value) override; + bool addRow(uint64_t rowNum, common::column_id_t columnCount, + std::optional warningData) override; + void reset(); + + void setEverQuoted() { everQuoted = true; } + void setEverEscaped() { everEscaped = true; } + void setError() { error = true; } + + bool getEverQuoted() const { return everQuoted; } + bool getEverEscaped() const { return everEscaped; } + bool getError() const { return error; } + common::idx_t getResultPosition() const { return resultPosition; } + common::idx_t getColumnCount(common::idx_t index) const { return columnCounts[index]; } + +private: + std::vector columnCounts; + common::idx_t currentColumnCount = 0; + bool error = false; + common::idx_t resultPosition = 0; + bool everQuoted = false; + bool everEscaped = false; +}; + +class SniffCSVNameAndTypeDriver : public SerialParsingDriver { +public: + SniffCSVNameAndTypeDriver(SerialCSVReader* reader, + const function::ExtraScanTableFuncBindInput* bindInput); + + bool done(uint64_t rowNum); + bool addValue(uint64_t rowNum, common::column_id_t columnIdx, std::string_view value) override; + +public: + std::vector firstRow; + std::vector> columns; + std::vector sniffType; + // if the type isn't declared in the header, sniff it +}; + +class SniffCSVHeaderDriver : public SerialParsingDriver { +public: + SniffCSVHeaderDriver(SerialCSVReader* reader, + const std::vector>& TypeDetected); + + bool done(uint64_t rowNum) const { + // Only read the first line. 
+ return (0 < rowNum); + }; + + bool addValue(uint64_t rowNum, common::column_id_t columnIdx, std::string_view value) override; + +public: + std::vector> columns; + std::vector> header; + bool detectedHeader = false; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/csv/parallel_csv_reader.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/csv/parallel_csv_reader.h new file mode 100644 index 0000000000..0fceb09ab4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/csv/parallel_csv_reader.h @@ -0,0 +1,64 @@ +#pragma once + +#include "base_csv_reader.h" +#include "common/types/types.h" +#include "function/function.h" +#include "function/table/bind_input.h" +#include "function/table/scan_file_function.h" +#include "function/table/table_function.h" +#include "processor/operator/persistent/reader/file_error_handler.h" + +namespace lbug { +namespace processor { + +//! ParallelCSVReader is a class that reads values from a stream in parallel. +class ParallelCSVReader final : public BaseCSVReader { + friend class ParallelParsingDriver; + +public: + ParallelCSVReader(const std::string& filePath, common::idx_t fileIdx, common::CSVOption option, + CSVColumnInfo columnInfo, main::ClientContext* context, + LocalFileErrorHandler* errorHandler); + + bool hasMoreToRead() const; + uint64_t parseBlock(common::block_idx_t blockIdx, common::DataChunk& resultChunk) override; + uint64_t continueBlock(common::DataChunk& resultChunk); + + void reportFinishedBlock(); + +protected: + bool handleQuotedNewline() override; + +private: + bool finishedBlock() const; + void seekToBlockStart(); +}; + +struct ParallelCSVLocalState final : public function::TableFuncLocalState { + std::unique_ptr reader; + std::unique_ptr errorHandler; + common::idx_t fileIdx = common::INVALID_IDX; +}; + +struct ParallelCSVScanSharedState final : public function::ScanFileWithProgressSharedState { + common::CSVOption csvOption; + CSVColumnInfo columnInfo; + std::atomic numBlocksReadByFiles = 0; + std::vector errorHandlers; + populate_func_t populateErrorFunc; + + ParallelCSVScanSharedState(common::FileScanInfo fileScanInfo, uint64_t numRows, + main::ClientContext* context, common::CSVOption csvOption, CSVColumnInfo columnInfo); + + void setFileComplete(uint64_t completedFileIdx); + populate_func_t constructPopulateFunc(); +}; + +struct ParallelCSVScan { + static constexpr const char* name = "READ_CSV_PARALLEL"; + + static function::function_set getFunctionSet(); +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/csv/serial_csv_reader.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/csv/serial_csv_reader.h new file mode 100644 index 0000000000..9ee1a9b25a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/csv/serial_csv_reader.h @@ -0,0 +1,67 @@ +#pragma once + +#include "base_csv_reader.h" +#include "function/function.h" +#include "function/table/bind_input.h" +#include "function/table/scan_file_function.h" +#include "processor/operator/persistent/reader/csv/dialect_detection.h" +#include "processor/operator/persistent/reader/file_error_handler.h" + +namespace lbug { +namespace processor { + +//! Serial CSV reader is a class that reads values from a stream in a single thread. 
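// (The SerialCSVReader declared next feeds these same driver callbacks from a single thread.)
// Editorial aside: the point of templating parseCSV() over a Driver is that one tokenizing loop
// can serve very different passes (materializing a chunk, sniffing names and types, counting
// columns per dialect) purely through the addValue/addRow/done callbacks. A minimal sketch of
// that shape with an illustrative driver that only counts columns; the tokenizer here is
// deliberately naive and ignores quoting and escaping.
#include <cstdint>
#include <string>
#include <vector>

template<typename Driver>
uint64_t parseRows(const std::vector<std::string>& lines, char delimiter, Driver& driver) {
    uint64_t rowNum = 0;
    for (const auto& line : lines) {
        uint64_t columnIdx = 0;
        std::string field;
        for (char c : line) {
            if (c == delimiter) {
                driver.addValue(rowNum, columnIdx++, field);  // one callback per parsed field
                field.clear();
            } else {
                field.push_back(c);
            }
        }
        driver.addValue(rowNum, columnIdx++, field);  // last field of the row
        driver.addRow(rowNum, columnIdx);             // row boundary callback
        rowNum++;
        if (driver.done(rowNum)) {                    // e.g. header sniffers stop after one row
            break;
        }
    }
    return rowNum;
}

struct ColumnCountingDriver {
    std::vector<uint64_t> columnsPerRow;
    void addValue(uint64_t, uint64_t, const std::string&) {}
    void addRow(uint64_t, uint64_t columnCount) { columnsPerRow.push_back(columnCount); }
    bool done(uint64_t) const { return false; }
};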
+class SerialCSVReader final : public BaseCSVReader { +public: + SerialCSVReader(const std::string& filePath, common::idx_t fileIdx, common::CSVOption option, + CSVColumnInfo columnInfo, main::ClientContext* context, LocalFileErrorHandler* errorHandler, + const function::ExtraScanTableFuncBindInput* bindInput = nullptr); + + //! Sniffs CSV dialect and determines skip rows, header row, column types and column names + std::vector> sniffCSV( + DialectOption& detectedDialect, bool& detectedHeader); + uint64_t parseBlock(common::block_idx_t blockIdx, common::DataChunk& resultChunk) override; + +protected: + bool handleQuotedNewline() override { return true; } + +private: + const function::ExtraScanTableFuncBindInput* bindInput; + void resetReaderState(); + DialectOption detectDialect(); + bool detectHeader(std::vector>& detectedTypes); +}; + +struct SerialCSVScanSharedState final : public function::ScanFileWithProgressSharedState { + std::unique_ptr reader; + common::CSVOption csvOption; + CSVColumnInfo columnInfo; + uint64_t totalReadSizeByFile; + std::unique_ptr sharedErrorHandler; + std::unique_ptr localErrorHandler; + uint64_t queryID; + populate_func_t populateErrorFunc; + + SerialCSVScanSharedState(common::FileScanInfo fileScanInfo, uint64_t numRows, + main::ClientContext* context, common::CSVOption csvOption, CSVColumnInfo columnInfo, + uint64_t queryID); + + void read(common::DataChunk& outputChunk); + + void initReader(main::ClientContext* context); + void finalizeReader(main::ClientContext* context) const; + + populate_func_t constructPopulateFunc() const; +}; + +struct SerialCSVScan { + static constexpr const char* name = "READ_CSV_SERIAL"; + + static function::function_set getFunctionSet(); + static void bindColumns(const function::ExtraScanTableFuncBindInput* bindInput, + std::vector& columnNames, std::vector& columnTypes, + DialectOption& detectedDialect, bool& detectedHeader, main::ClientContext* context); +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/file_error_handler.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/file_error_handler.h new file mode 100644 index 0000000000..4673b6c984 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/file_error_handler.h @@ -0,0 +1,88 @@ +#pragma once + +#include +#include +#include + +#include "common/uniq_lock.h" +#include "processor/operator/persistent/reader/copy_from_error.h" +#include "processor/warning_context.h" + +namespace lbug::processor { + +class BaseCSVReader; +class SerialCSVReader; + +struct LinesPerBlock { + uint64_t numLines; + bool doneParsingBlock; +}; + +class SharedFileErrorHandler; + +class LBUG_API SharedFileErrorHandler { +public: + explicit SharedFileErrorHandler(common::idx_t fileIdx, std::mutex* sharedMtx, + populate_func_t populateErrorFunc = {}); + + void handleError(CopyFromFileError error); + void throwCachedErrorsIfNeeded(); + + void setHeaderNumRows(uint64_t numRows); + + void updateLineNumberInfo(const std::map& linesPerBlock, + bool canThrowCachedError); + uint64_t getNumCachedErrors(); + uint64_t getLineNumber(uint64_t blockIdx, uint64_t numRowsReadInBlock) const; + + void setPopulateErrorFunc(populate_func_t newPopulateErrorFunc); + +private: + // this number can be small as we only cache errors if we wish to throw them later + static constexpr uint64_t MAX_CACHED_ERROR_COUNT = 64; + + common::UniqLock lock(); + void 
tryThrowFirstCachedError(); + + std::string getErrorMessage(PopulatedCopyFromError populatedError) const; + void throwError(CopyFromFileError error) const; + bool canGetLineNumber(uint64_t blockIdx) const; + void tryCacheError(CopyFromFileError error, const common::UniqLock&); + + std::mutex* mtx; // can be nullptr, in which case mutual exclusion is guaranteed by the caller + common::idx_t fileIdx; + std::vector linesPerBlock; + std::vector cachedErrors; + populate_func_t populateErrorFunc; + + uint64_t headerNumRows; +}; + +class LBUG_API LocalFileErrorHandler { +public: + ~LocalFileErrorHandler(); + + LocalFileErrorHandler(SharedFileErrorHandler* sharedErrorHandler, bool ignoreErrors, + main::ClientContext* context, bool cacheIgnoredErrors = true); + + void handleError(CopyFromFileError error); + void reportFinishedBlock(uint64_t blockIdx, uint64_t numRowsRead); + void setHeaderNumRows(uint64_t numRows); + void finalize(bool canThrowCachedError = true); + bool getIgnoreErrorsOption() const { return ignoreErrors; } + +private: + static constexpr uint64_t LOCAL_WARNING_LIMIT = 256; + void flushCachedErrors(bool canThrowCachedError = true); + + std::map linesPerBlock; + std::vector cachedErrors; + SharedFileErrorHandler* sharedErrorHandler; + main::ClientContext* context; + + uint64_t maxCachedErrorCount; + bool ignoreErrors; + bool cacheIgnoredErrors; +}; + +} // namespace lbug::processor diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/npy/npy_reader.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/npy/npy_reader.h new file mode 100644 index 0000000000..19d1685c3a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/npy/npy_reader.h @@ -0,0 +1,71 @@ +#pragma once + +#include +#include + +#include "common/data_chunk/data_chunk.h" +#include "common/types/types.h" +#include "function/function.h" +#include "function/table/scan_file_function.h" + +namespace lbug { +namespace processor { + +class NpyReader { +public: + explicit NpyReader(const std::string& filePath); + + ~NpyReader(); + + size_t getNumElementsPerRow() const; + + uint8_t* getPointerToRow(size_t row) const; + + inline size_t getNumRows() const { return shape[0]; } + + void readBlock(common::block_idx_t blockIdx, common::ValueVector* vectorToRead) const; + + // Used in tests only. 
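// (getType() and getShape() below are those test-only accessors.) Editorial aside: for context,
// parseHeader() has to understand the .npy container format. In version 1.0 the layout is a
// 6-byte magic "\x93NUMPY", one byte major and one byte minor version, a little-endian uint16
// header length, then an ASCII Python-dict literal with the keys 'descr' (dtype), 'fortran_order'
// and 'shape', padded and newline-terminated; the raw array data follows immediately. A small
// sketch of locating the data under that v1.0 assumption (version 2.x uses a 4-byte length):
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <stdexcept>

// Returns the byte offset at which the array data starts in a version 1.0 .npy buffer.
inline std::size_t npyV1DataOffset(const uint8_t* file, std::size_t fileSize) {
    static constexpr char MAGIC[] = "\x93NUMPY";
    if (fileSize < 10 || std::memcmp(file, MAGIC, 6) != 0) {
        throw std::runtime_error("not an .npy file");
    }
    if (file[6] != 1) {
        throw std::runtime_error("this sketch only handles .npy version 1.x");
    }
    // bytes 8 and 9 hold the header length, little endian
    const std::size_t headerLen =
        static_cast<std::size_t>(file[8]) | (static_cast<std::size_t>(file[9]) << 8);
    return 10 + headerLen;  // magic(6) + version(2) + length field(2) + the header itself
}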
+ inline common::LogicalTypeID getType() const { return type; } + inline std::vector getShape() const { return shape; } + + void validate(const common::LogicalType& type_, common::offset_t numRows); + +private: + void parseHeader(); + void parseType(std::string descr); + +private: + std::string filePath; + int fd; + size_t fileSize; + void* mmapRegion; + size_t dataOffset; + std::vector shape; + common::LogicalTypeID type; +}; + +class NpyMultiFileReader { +public: + explicit NpyMultiFileReader(const std::vector& filePaths); + + void readBlock(common::block_idx_t blockIdx, common::DataChunk& dataChunkToRead) const; + +private: + std::vector> fileReaders; +}; + +struct NpyScanSharedState final : public function::ScanFileSharedState { + explicit NpyScanSharedState(const common::FileScanInfo fileScanInfo, uint64_t numRows); + + std::unique_ptr npyMultiFileReader; +}; + +struct NpyScanFunction { + static constexpr const char* name = "READ_NPY"; + + static function::function_set getFunctionSet(); +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/boolean_column_reader.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/boolean_column_reader.h new file mode 100644 index 0000000000..0bee024172 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/boolean_column_reader.h @@ -0,0 +1,45 @@ +#pragma once + +#include "column_reader.h" +#include "templated_column_reader.h" + +namespace lbug { +namespace processor { + +struct BooleanParquetValueConversion; + +class BooleanColumnReader : public TemplatedColumnReader { +public: + static constexpr const common::PhysicalTypeID TYPE = common::PhysicalTypeID::BOOL; + +public: + BooleanColumnReader(ParquetReader& reader, common::LogicalType type, + const lbug_parquet::format::SchemaElement& schema, uint64_t schemaIdx, uint64_t maxDefine, + uint64_t maxRepeat) + : TemplatedColumnReader(reader, std::move(type), + schema, schemaIdx, maxDefine, maxRepeat), + bytePos(0){}; + + uint8_t bytePos; + + void initializeRead(uint64_t rowGroupIdx, + const std::vector& columns, + lbug_apache::thrift::protocol::TProtocol& protocol) override; + + inline void resetPage() override { bytePos = 0; } +}; + +struct BooleanParquetValueConversion { + static bool dictRead(ByteBuffer& /*dict*/, uint32_t& /*offset*/, ColumnReader& /*reader*/) { + throw common::CopyException{"Dicts for booleans make no sense"}; + } + + static bool plainRead(ByteBuffer& plainData, ColumnReader& reader); + + static inline void plainSkip(ByteBuffer& plainData, ColumnReader& reader) { + plainRead(plainData, reader); + } +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/callback_column_reader.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/callback_column_reader.h new file mode 100644 index 0000000000..29e034c088 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/callback_column_reader.h @@ -0,0 +1,41 @@ +#pragma once + +#include "column_reader.h" +#include "parquet_reader.h" +#include "templated_column_reader.h" + +namespace lbug { +namespace processor { + +template +class CallbackColumnReader + : public TemplatedColumnReader> { + using BaseType = TemplatedColumnReader>; + +public: + static constexpr const 
common::PhysicalTypeID TYPE = common::PhysicalTypeID::ANY; + +public: + CallbackColumnReader(ParquetReader& reader, common::LogicalType type_p, + const lbug_parquet::format::SchemaElement& schema_p, uint64_t file_idx_p, + uint64_t max_define_p, uint64_t max_repeat_p) + : TemplatedColumnReader>(reader, + std::move(type_p), schema_p, file_idx_p, max_define_p, max_repeat_p) {} + +protected: + void dictionary(const std::shared_ptr& dictionaryData, + uint64_t numEntries) override { + BaseType::allocateDict(numEntries * sizeof(KU_PHYSICAL_TYPE)); + auto dictPtr = (KU_PHYSICAL_TYPE*)this->dict->ptr; + for (auto i = 0u; i < numEntries; i++) { + dictPtr[i] = FUNC(dictionaryData->read()); + } + } +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/column_reader.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/column_reader.h new file mode 100644 index 0000000000..2934ddfb43 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/column_reader.h @@ -0,0 +1,126 @@ +#pragma once + +#include + +#include "common/system_config.h" +#include "common/types/types.h" +#include "common/vector/value_vector.h" +#include "parquet_dbp_decoder.h" +#include "parquet_rle_bp_decoder.h" +#include "parquet_types.h" +#include "resizable_buffer.h" +#include "thrift_tools.h" + +namespace lbug { +namespace processor { +class ParquetReader; + +typedef std::bitset parquet_filter_t; + +class ColumnReader { +public: + ColumnReader(ParquetReader& reader, common::LogicalType type, + const lbug_parquet::format::SchemaElement& schema, common::idx_t fileIdx, + uint64_t maxDefinition, uint64_t maxRepeat); + virtual ~ColumnReader() = default; + const common::LogicalType& getDataType() const { return type; } + bool hasDefines() const { return maxDefine > 0; } + bool hasRepeats() const { return maxRepeat > 0; } + virtual void skip(uint64_t numValues) { pendingSkips += numValues; } + virtual void dictionary(const std::shared_ptr& /*data*/, + uint64_t /*num_entries*/) { + KU_UNREACHABLE; + } + virtual void offsets(uint32_t* /*offsets*/, uint8_t* /*defines*/, uint64_t /*numValues*/, + parquet_filter_t& /*filter*/, uint64_t /*resultOffset*/, common::ValueVector* /*result*/) { + KU_UNREACHABLE; + } + virtual void plain(const std::shared_ptr& /*plainData*/, uint8_t* /*defines*/, + uint64_t /*numValues*/, parquet_filter_t& /*filter*/, uint64_t /*resultOffset*/, + common::ValueVector* /*result*/) { + KU_UNREACHABLE; + } + virtual void resetPage() {} + virtual uint64_t getGroupRowsAvailable() { return groupRowsAvailable; } + virtual void initializeRead(uint64_t rowGroupIdx, + const std::vector& columns, + lbug_apache::thrift::protocol::TProtocol& protocol); + virtual uint64_t getTotalCompressedSize(); + virtual void registerPrefetch(ThriftFileTransport& transport, bool allowMerge); + virtual uint64_t fileOffset() const; + virtual void applyPendingSkips(uint64_t numValues); + virtual uint64_t read(uint64_t numValues, parquet_filter_t& filter, uint8_t* defineOut, + uint8_t* repeatOut, common::ValueVector* resultOut); + static std::unique_ptr createReader(ParquetReader& reader, + common::LogicalType type, const lbug_parquet::format::SchemaElement& schema, + uint64_t fileIdx, uint64_t maxDefine, uint64_t maxRepeat); + void prepareRead(parquet_filter_t& filter); + void allocateBlock(uint64_t size); + void allocateCompressed(uint64_t size); + void 
decompressInternal(lbug_parquet::format::CompressionCodec::type codec, const uint8_t* src, + uint64_t srcSize, uint8_t* dst, uint64_t dstSize); + void preparePageV2(lbug_parquet::format::PageHeader& pageHdr); + void preparePage(lbug_parquet::format::PageHeader& pageHdr); + void prepareDataPage(lbug_parquet::format::PageHeader& pageHdr); + template + void plainTemplated(const std::shared_ptr& plainData, const uint8_t* defines, + uint64_t numValues, parquet_filter_t& filter, uint64_t resultOffset, + common::ValueVector* result) { + for (auto i = 0u; i < numValues; i++) { + if (hasDefines() && defines[i + resultOffset] != maxDefine) { + result->setNull(i + resultOffset, true); + continue; + } + result->setNull(i + resultOffset, false); + if (filter[i + resultOffset]) { + VALUE_TYPE val = CONVERSION::plainRead(*plainData, *this); + result->setValue(i + resultOffset, val); + } else { // there is still some data there that we have to skip over + CONVERSION::plainSkip(*plainData, *this); + } + } + } + +private: + static std::unique_ptr createTimestampReader(ParquetReader& reader, + common::LogicalType type, const lbug_parquet::format::SchemaElement& schema, + uint64_t fileIdx, uint64_t maxDefine, uint64_t maxRepeat); + +protected: + const lbug_parquet::format::SchemaElement& schema; + + uint64_t fileIdx; + uint64_t maxDefine; + uint64_t maxRepeat; + + ParquetReader& reader; + common::LogicalType type; + + uint64_t pendingSkips = 0; + + const lbug_parquet::format::ColumnChunk* chunk = nullptr; + + lbug_apache::thrift::protocol::TProtocol* protocol; + uint64_t pageRowsAvailable; + uint64_t groupRowsAvailable; + uint64_t chunkReadOffset; + + std::shared_ptr block; + + ResizeableBuffer compressedBuffer; + ResizeableBuffer offsetBuffer; + + std::unique_ptr dictDecoder; + std::unique_ptr defineDecoder; + std::unique_ptr repeatedDecoder; + std::unique_ptr dbpDecoder; + std::unique_ptr rleDecoder; + + // dummies for Skip() + parquet_filter_t noneFilter; + ResizeableBuffer dummyDefine; + ResizeableBuffer dummyRepeat; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/decode_utils.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/decode_utils.h new file mode 100644 index 0000000000..a0cc653611 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/decode_utils.h @@ -0,0 +1,65 @@ +#pragma once + +#include "common/exception/copy.h" +#include "common/string_format.h" +#include "resizable_buffer.h" + +namespace lbug { +namespace processor { +class ParquetDecodeUtils { + +public: + template + static T ZigzagToInt(const T n) { + return (n >> 1) ^ -(n & 1); + } + + static const uint64_t BITPACK_MASKS[]; + static const uint64_t BITPACK_MASKS_SIZE; + static const uint8_t BITPACK_DLEN; + + template + static uint32_t BitUnpack(ByteBuffer& buffer, uint8_t& bitpack_pos, T* dest, uint32_t count, + uint8_t width) { + if (width >= ParquetDecodeUtils::BITPACK_MASKS_SIZE) { + throw common::CopyException(common::stringFormat( + "The width ({}) of the bitpacked data exceeds the supported max width ({}), " + "the file might be corrupted.", + width, ParquetDecodeUtils::BITPACK_MASKS_SIZE)); + } + auto mask = BITPACK_MASKS[width]; + + for (uint32_t i = 0; i < count; i++) { + T val = (buffer.get() >> bitpack_pos) & mask; + bitpack_pos += width; + while (bitpack_pos > BITPACK_DLEN) { + buffer.inc(1); + val |= (T(buffer.get()) << 
T(BITPACK_DLEN - (bitpack_pos - width))) & mask; + bitpack_pos -= BITPACK_DLEN; + } + dest[i] = val; + } + return count; + } + + template + static T VarintDecode(ByteBuffer& buf) { + T result = 0; + uint8_t shift = 0; + while (true) { + auto byte = buf.read(); + result |= T(byte & 127) << shift; + if ((byte & 128) == 0) { + break; + } + shift += 7; + if (shift > sizeof(T) * 8) { + throw std::runtime_error("Varint-decoding found too large number"); + } + } + return result; + } +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/interval_column_reader.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/interval_column_reader.h new file mode 100644 index 0000000000..d1841212b5 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/interval_column_reader.h @@ -0,0 +1,41 @@ +#pragma once + +#include "common/constants.h" +#include "common/types/interval_t.h" +#include "templated_column_reader.h" + +namespace lbug { +namespace processor { + +struct IntervalValueConversion { + static inline common::interval_t dictRead(ByteBuffer& dict, uint32_t& offset, + ColumnReader& /*reader*/) { + return (reinterpret_cast(dict.ptr))[offset]; + } + + static common::interval_t readParquetInterval(const char* input); + + static common::interval_t plainRead(ByteBuffer& plainData, ColumnReader& reader); + + static inline void plainSkip(ByteBuffer& plain_data, ColumnReader& /*reader*/) { + plain_data.inc(common::ParquetConstants::PARQUET_INTERVAL_SIZE); + } +}; + +class IntervalColumnReader + : public TemplatedColumnReader { + +public: + IntervalColumnReader(ParquetReader& reader, common::LogicalType type, + const lbug_parquet::format::SchemaElement& schema, uint64_t fileIdx, uint64_t maxDefine, + uint64_t maxRepeat) + : TemplatedColumnReader(reader, + std::move(type), schema, fileIdx, maxDefine, maxRepeat){}; + +protected: + void dictionary(const std::shared_ptr& dictionaryData, + uint64_t numEntries) override; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/list_column_reader.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/list_column_reader.h new file mode 100644 index 0000000000..e03bb59864 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/list_column_reader.h @@ -0,0 +1,54 @@ +#pragma once + +#include "column_reader.h" + +namespace lbug { +namespace processor { + +class ListColumnReader : public ColumnReader { +public: + static constexpr const common::PhysicalTypeID TYPE = common::PhysicalTypeID::LIST; + +public: + ListColumnReader(ParquetReader& reader, common::LogicalType type, + const lbug_parquet::format::SchemaElement& schema, uint64_t schemaIdx, uint64_t maxDefine, + uint64_t maxRepeat, std::unique_ptr childColumnReader, + storage::MemoryManager* memoryManager); + + inline void initializeRead(uint64_t rowGroupIdx, + const std::vector& columns, + lbug_apache::thrift::protocol::TProtocol& protocol) override { + childColumnReader->initializeRead(rowGroupIdx, columns, protocol); + } + + uint64_t read(uint64_t numValues, parquet_filter_t& filter, uint8_t* defineOut, + uint8_t* repeatOut, common::ValueVector* resultOut) override; + + void applyPendingSkips(uint64_t numValues) override; + +private: + inline 
uint64_t getGroupRowsAvailable() override { + return childColumnReader->getGroupRowsAvailable() + overflowChildCount; + } + + inline uint64_t getTotalCompressedSize() override { + return childColumnReader->getTotalCompressedSize(); + } + + inline void registerPrefetch(ThriftFileTransport& transport, bool allow_merge) override { + childColumnReader->registerPrefetch(transport, allow_merge); + } + +private: + std::unique_ptr childColumnReader; + ResizeableBuffer childDefines; + ResizeableBuffer childRepeats; + uint8_t* childDefinesPtr; + uint8_t* childRepeatsPtr; + parquet_filter_t childFilter; + uint64_t overflowChildCount; + std::unique_ptr vectorToRead; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/parquet_dbp_decoder.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/parquet_dbp_decoder.h new file mode 100644 index 0000000000..c35506c961 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/parquet_dbp_decoder.h @@ -0,0 +1,115 @@ +#pragma once + +#include "decode_utils.h" + +namespace lbug { +namespace processor { + +class DbpDecoder { +public: + DbpDecoder(uint8_t* buffer, uint32_t buffer_len) : buffer_(buffer, buffer_len) { + // + // overall header + block_value_count = ParquetDecodeUtils::VarintDecode(buffer_); + miniblocks_per_block = ParquetDecodeUtils::VarintDecode(buffer_); + total_value_count = ParquetDecodeUtils::VarintDecode(buffer_); + start_value = + ParquetDecodeUtils::ZigzagToInt(ParquetDecodeUtils::VarintDecode(buffer_)); + + // some derivatives + KU_ASSERT(miniblocks_per_block > 0); + values_per_miniblock = block_value_count / miniblocks_per_block; + miniblock_bit_widths = std::unique_ptr(new uint8_t[miniblocks_per_block]); + + // init state to something sane + values_left_in_block = 0; + values_left_in_miniblock = 0; + miniblock_offset = 0; + min_delta = 0; + bitpack_pos = 0; + is_first_value = true; + }; + + template + void GetBatch(uint8_t* values_target_ptr, uint32_t batch_size) { + auto* values = reinterpret_cast(values_target_ptr); + + if (batch_size == 0) { + return; + } + uint64_t value_offset = 0; + + if (is_first_value) { + values[0] = start_value; + value_offset++; + is_first_value = false; + } + + if (total_value_count == 1) { // I guess it's a special case + if (batch_size > 1) { + throw std::runtime_error("DBP decode did not find enough values (have 1)"); + } + return; + } + + while (value_offset < batch_size) { + if (values_left_in_block == 0) { // need to open new block + if (bitpack_pos > 0) { // have to eat the leftovers if any + buffer_.inc(1); + } + min_delta = ParquetDecodeUtils::ZigzagToInt( + ParquetDecodeUtils::VarintDecode(buffer_)); + for (auto miniblock_idx = 0u; miniblock_idx < miniblocks_per_block; + miniblock_idx++) { + miniblock_bit_widths[miniblock_idx] = buffer_.read(); + // TODO what happens if width is 0? 
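+                    // Note: a bit width of 0 is valid per the Parquet DELTA_BINARY_PACKED spec;
+                    // such a miniblock carries no packed delta bits, so every value in it is just
+                    // the previous value plus min_delta.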
+ } + values_left_in_block = block_value_count; + miniblock_offset = 0; + bitpack_pos = 0; + values_left_in_miniblock = values_per_miniblock; + } + if (values_left_in_miniblock == 0) { + miniblock_offset++; + values_left_in_miniblock = values_per_miniblock; + } + + auto read_now = + std::min(values_left_in_miniblock, (uint64_t)batch_size - value_offset); + ParquetDecodeUtils::BitUnpack(buffer_, bitpack_pos, &values[value_offset], read_now, + miniblock_bit_widths[miniblock_offset]); + for (auto i = value_offset; i < value_offset + read_now; i++) { + values[i] = ((i == 0) ? start_value : values[i - 1]) + min_delta + values[i]; + } + value_offset += read_now; + values_left_in_miniblock -= read_now; + values_left_in_block -= read_now; + } + + if (value_offset != batch_size) { + throw std::runtime_error("DBP decode did not find enough values"); + } + start_value = values[batch_size - 1]; + } + +private: + ByteBuffer buffer_; + uint64_t block_value_count; + uint64_t miniblocks_per_block; + uint64_t total_value_count; + int64_t start_value; + uint64_t values_per_miniblock; + + std::unique_ptr miniblock_bit_widths; + uint64_t values_left_in_block; + uint64_t values_left_in_miniblock; + uint64_t miniblock_offset; + int64_t min_delta; + + bool is_first_value; + + uint8_t bitpack_pos; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/parquet_reader.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/parquet_reader.h new file mode 100644 index 0000000000..f674c31a50 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/parquet_reader.h @@ -0,0 +1,120 @@ +#pragma once + +#include "column_reader.h" +#include "common/data_chunk/data_chunk.h" +#include "common/file_system/virtual_file_system.h" +#include "common/types/types.h" +#include "function/function.h" +#include "function/table/scan_file_function.h" +#include "parquet_types.h" +#include "protocol/TCompactProtocol.h" +#include "resizable_buffer.h" + +namespace lbug { +namespace processor { + +struct ParquetReaderPrefetchConfig { + // Percentage of data in a row group span that should be scanned for enabling whole group + // prefetch + static constexpr double WHOLE_GROUP_PREFETCH_MINIMUM_SCAN = 0.95; +}; + +struct ParquetReaderScanState { + std::vector groupIdxList; + int64_t currentGroup = -1; + uint64_t groupOffset = UINT64_MAX; + std::unique_ptr fileInfo; + std::unique_ptr rootReader; + std::unique_ptr thriftFileProto; + + bool finished = false; + + ResizeableBuffer defineBuf; + ResizeableBuffer repeatBuf; + + // TODO(Ziyi): We currently only support reading from local file system, thus the prefetch + // mode is disabled by default. Add this back when we support remote file system. 
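+// Editor's note (descriptive comment, not part of the upstream header): the hybrid encoding is a
+// sequence of runs, each preceded by a varint header whose least-significant bit selects either a
+// bit-packed literal run or a repeated run whose single value follows in byte_encoded_len bytes;
+// see NextCounts() below.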
+ bool prefetchMode = false; + bool currentGroupPrefetched = false; +}; + +class ParquetReader { +public: + ParquetReader(std::string filePath, std::vector columnSkips, + main::ClientContext* context); + ~ParquetReader() = default; + + void initializeScan(ParquetReaderScanState& state, std::vector groups_to_read, + common::VirtualFileSystem* vfs); + bool scanInternal(ParquetReaderScanState& state, common::DataChunk& result); + void scan(ParquetReaderScanState& state, common::DataChunk& result); + uint64_t getNumRowsGroups() { return metadata->row_groups.size(); } + + uint32_t getNumColumns() const { return columnNames.size(); } + std::string getColumnName(uint32_t idx) const { return columnNames[idx]; } + const common::LogicalType& getColumnType(uint32_t idx) const { return columnTypes[idx]; } + + lbug_parquet::format::FileMetaData* getMetadata() const { return metadata.get(); } + +private: + std::unique_ptr createThriftProtocol( + common::FileInfo* fileInfo_, bool prefetch_mode) { + return std::make_unique< + lbug_apache::thrift::protocol::TCompactProtocolT>( + std::make_shared(fileInfo_, prefetch_mode)); + } + const lbug_parquet::format::RowGroup& getGroup(ParquetReaderScanState& state) { + KU_ASSERT( + state.currentGroup >= 0 && (uint64_t)state.currentGroup < state.groupIdxList.size()); + KU_ASSERT(state.groupIdxList[state.currentGroup] < metadata->row_groups.size()); + return metadata->row_groups[state.groupIdxList[state.currentGroup]]; + } + static common::LogicalType deriveLogicalType(const lbug_parquet::format::SchemaElement& s_ele); + void initMetadata(); + std::unique_ptr createReader(); + std::unique_ptr createReaderRecursive(uint64_t depth, uint64_t maxDefine, + uint64_t maxRepeat, uint64_t& nextSchemaIdx, uint64_t& nextFileIdx); + void prepareRowGroupBuffer(ParquetReaderScanState& state, uint64_t colIdx); + // Group span is the distance between the min page offset and the max page offset plus the max + // page compressed size + uint64_t getGroupSpan(ParquetReaderScanState& state); + uint64_t getGroupCompressedSize(ParquetReaderScanState& state); + uint64_t getGroupOffset(ParquetReaderScanState& state); + +private: + std::string filePath; + std::vector columnSkips; + std::vector columnNames; + std::vector columnTypes; + + std::unique_ptr metadata; + main::ClientContext* context; +}; + +struct ParquetScanSharedState final : function::ScanFileWithProgressSharedState { + explicit ParquetScanSharedState(common::FileScanInfo fileScanInfo, uint64_t numRows, + main::ClientContext* context, std::vector columnSkips); + + std::vector> readers; + std::vector columnSkips; + uint64_t totalRowsGroups; + std::atomic numBlocksReadByFiles; +}; + +struct ParquetScanLocalState final : function::TableFuncLocalState { + ParquetScanLocalState() : reader(nullptr) { + state = std::make_unique(); + } + + ParquetReader* reader; + std::unique_ptr state; +}; + +struct ParquetScanFunction { + static constexpr const char* name = "READ_PARQUET"; + + static function::function_set getFunctionSet(); +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/parquet_rle_bp_decoder.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/parquet_rle_bp_decoder.h new file mode 100644 index 0000000000..759a9e5712 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/parquet_rle_bp_decoder.h @@ -0,0 +1,118 @@ +#pragma once + 
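+    // When set, the scan is expected to route reads through ThriftFileTransport's prefetch path
+    // (see createThriftProtocol() and thrift_tools.h) so row-group ranges are read ahead in bulk
+    // instead of being issued as many small per-page reads.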
+#include "decode_utils.h" +#include "resizable_buffer.h" + +namespace lbug { +namespace processor { + +class RleBpDecoder { +public: + /// Create a decoder object. buffer/buffer_len is the decoded data. + /// bit_width is the width of each value (before encoding). + RleBpDecoder(uint8_t* buffer, uint32_t buffer_len, uint32_t bit_width) + : buffer_(buffer, buffer_len), bit_width_(bit_width), current_value_(0), repeat_count_(0), + literal_count_(0) { + if (bit_width >= 64) { + throw std::runtime_error("Decode bit width too large"); + } + byte_encoded_len = ((bit_width_ + 7) / 8); + max_val = (uint64_t(1) << bit_width_) - 1; + } + + template + void GetBatch(uint8_t* values_target_ptr, uint32_t batch_size) { + auto* values = reinterpret_cast(values_target_ptr); + uint32_t values_read = 0; + + while (values_read < batch_size) { + if (repeat_count_ > 0) { + int repeat_batch = std::min(batch_size - values_read, + static_cast(repeat_count_)); + std::fill(values + values_read, values + values_read + repeat_batch, + static_cast(current_value_)); + repeat_count_ -= repeat_batch; + values_read += repeat_batch; + } else if (literal_count_ > 0) { + uint32_t literal_batch = std::min(batch_size - values_read, + static_cast(literal_count_)); + uint32_t actual_read = ParquetDecodeUtils::BitUnpack(buffer_, bitpack_pos, + values + values_read, literal_batch, bit_width_); + if (literal_batch != actual_read) { + throw std::runtime_error("Did not find enough values"); + } + literal_count_ -= literal_batch; + values_read += literal_batch; + } else { + if (!NextCounts()) { + if (values_read != batch_size) { + throw std::runtime_error("RLE decode did not find enough values"); + } + return; + } + } + } + if (values_read != batch_size) { + throw std::runtime_error("RLE decode did not find enough values"); + } + } + + static uint8_t ComputeBitWidth(uint64_t val) { + if (val == 0) { + return 0; + } + uint8_t ret = 1; + while (((uint64_t)(1u << ret) - 1) < val) { + ret++; + } + return ret; + } + +private: + ByteBuffer buffer_; + + /// Number of bits needed to encode the value. Must be between 0 and 64. + uint32_t bit_width_; + uint64_t current_value_; + uint32_t repeat_count_; + uint32_t literal_count_; + uint8_t byte_encoded_len; + uint64_t max_val; + + uint8_t bitpack_pos = 0; + + /// Fills literal_count_ and repeat_count_ with next values. Returns false if there + /// are no more. + template + bool NextCounts() { + // Read the next run's indicator int, it could be a literal or repeated run. + // The int is encoded as a vlq-encoded value. + if (bitpack_pos != 0) { + buffer_.inc(1); + bitpack_pos = 0; + } + auto indicator_value = ParquetDecodeUtils::VarintDecode(buffer_); + + // lsb indicates if it is a literal run or repeated run + bool is_literal = indicator_value & 1; + if (is_literal) { + literal_count_ = (indicator_value >> 1) * 8; + } else { + repeat_count_ = indicator_value >> 1; + // (ARROW-4018) this is not big-endian compatible, lol + current_value_ = 0; + for (auto i = 0; i < byte_encoded_len; i++) { + current_value_ |= (buffer_.read() << (i * 8)); + } + // sanity check + if (repeat_count_ > 0 && current_value_ > max_val) { + throw std::runtime_error("Payload value bigger than allowed. 
Corrupted file?"); + } + } + // TODO complain if we run out of buffer + return true; + } +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/parquet_timestamp.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/parquet_timestamp.h new file mode 100644 index 0000000000..c168b776e0 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/parquet_timestamp.h @@ -0,0 +1,27 @@ +#pragma once + +#include "common/types/timestamp_t.h" + +namespace lbug { +namespace processor { + +struct Int96 { + uint32_t value[3]; +}; + +struct ParquetTimeStampUtils { + static constexpr int64_t JULIAN_TO_UNIX_EPOCH_DAYS = 2440588LL; + static constexpr int64_t MILLISECONDS_PER_DAY = 86400000LL; + static constexpr int64_t MICROSECONDS_PER_DAY = MILLISECONDS_PER_DAY * 1000LL; + static constexpr int64_t NANOSECONDS_PER_MICRO = 1000LL; + + static common::timestamp_t impalaTimestampToTimestamp(const Int96& rawTS); + static common::timestamp_t parquetTimestampMicrosToTimestamp(const int64_t& rawTS); + static common::timestamp_t parquetTimestampMsToTimestamp(const int64_t& rawTS); + static common::timestamp_t parquetTimestampNsToTimestamp(const int64_t& rawTS); + static int64_t impalaTimestampToMicroseconds(const Int96& impalaTimestamp); + static common::date_t parquetIntToDate(const int32_t& raw_date); +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/resizable_buffer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/resizable_buffer.h new file mode 100644 index 0000000000..2086b1e0f0 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/resizable_buffer.h @@ -0,0 +1,85 @@ +#pragma once + +#include +#include +#include + +#include "common/utils.h" + +namespace lbug { +namespace processor { + +class ByteBuffer { // on to the 10 thousandth impl +public: + ByteBuffer() = default; + ByteBuffer(uint8_t* ptr, uint64_t len) : ptr{ptr}, len{len} {}; + + uint8_t* ptr = nullptr; + uint64_t len = 0; + +public: + void inc(uint64_t increment) { + available(increment); + len -= increment; + ptr += increment; + } + + template + T read() { + T val = get(); + inc(sizeof(T)); + return val; + } + + template + T Load(const uint8_t* ptr) { + T ret{}; + memcpy(&ret, ptr, sizeof(ret)); + return ret; + } + + template + T get() { + available(sizeof(T)); + T val = Load(ptr); + return val; + } + + void copyTo(char* dest, uint64_t len) const { + available(len); + std::memcpy(dest, ptr, len); + } + + // NOLINTNEXTLINE(readability-make-member-function-const): Semantically non-const. 
+ void zero() { std::memset(ptr, 0, len); } + + void available(uint64_t req_len) const { + if (req_len > len) { + throw std::runtime_error("Out of buffer"); + } + } +}; + +class ResizeableBuffer : public ByteBuffer { +public: + ResizeableBuffer() = default; + explicit ResizeableBuffer(uint64_t new_size) { resize(new_size); } + void resize(uint64_t new_size) { + len = new_size; + if (new_size == 0) { + return; + } + if (new_size > allocLen) { + allocLen = common::nextPowerOfTwo(new_size); + allocatedData = std::make_unique(allocLen); + ptr = allocatedData.get(); + } + } + +private: + std::unique_ptr allocatedData; + uint64_t allocLen = 0; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/string_column_reader.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/string_column_reader.h new file mode 100644 index 0000000000..f8036fb6f2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/string_column_reader.h @@ -0,0 +1,38 @@ +#pragma once + +#include "column_reader.h" +#include "processor/operator/persistent/reader/parquet/templated_column_reader.h" + +namespace lbug { +namespace processor { + +struct StringParquetValueConversion { + static common::ku_string_t dictRead(ByteBuffer& dict, uint32_t& offset, ColumnReader& reader); + + static common::ku_string_t plainRead(ByteBuffer& plainData, ColumnReader& reader); + + static void plainSkip(ByteBuffer& plainData, ColumnReader& reader); +}; + +class StringColumnReader + : public TemplatedColumnReader { +public: + static constexpr const common::PhysicalTypeID TYPE = common::PhysicalTypeID::STRING; + +public: + StringColumnReader(ParquetReader& reader, common::LogicalType type, + const lbug_parquet::format::SchemaElement& schema, uint64_t schemaIdx, uint64_t maxDefine, + uint64_t maxRepeat); + + std::unique_ptr dictStrs; + uint64_t fixedWidthStringLength; + +public: + void dictionary(const std::shared_ptr& dictionary_data, + uint64_t numEntries) override; + static uint32_t verifyString(const char* strData, uint32_t strLen, const bool isVarchar); + uint32_t verifyString(const char* strData, uint32_t strLen); +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/struct_column_reader.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/struct_column_reader.h new file mode 100644 index 0000000000..e15c4b526e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/struct_column_reader.h @@ -0,0 +1,35 @@ +#pragma once + +#include "column_reader.h" + +namespace lbug { +namespace processor { + +class StructColumnReader : public ColumnReader { +public: + static constexpr const common::PhysicalTypeID TYPE = common::PhysicalTypeID::STRUCT; + +public: + StructColumnReader(ParquetReader& reader, common::LogicalType type, + const lbug_parquet::format::SchemaElement& schema, uint64_t schemaIdx, uint64_t maxDefine, + uint64_t maxRepeat, std::vector> childReaders); + + void initializeRead(uint64_t rowGroupIdx, + const std::vector& columns, + lbug_apache::thrift::protocol::TProtocol& protocol) override; + uint64_t read(uint64_t num_values, parquet_filter_t& filter, uint8_t* define_out, + uint8_t* repeat_out, common::ValueVector* result) override; + ColumnReader* 
getChildReader(uint64_t childIdx); + +private: + uint64_t getTotalCompressedSize() override; + void registerPrefetch(ThriftFileTransport& transport, bool allow_merge) override; + void skip(uint64_t num_values) override; + uint64_t getGroupRowsAvailable() override; + +private: + std::vector> childReaders; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/templated_column_reader.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/templated_column_reader.h new file mode 100644 index 0000000000..ed9ea1231c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/templated_column_reader.h @@ -0,0 +1,95 @@ +#pragma once + +#include "column_reader.h" +#include "resizable_buffer.h" + +namespace lbug { +namespace processor { + +template +struct TemplatedParquetValueConversion { + static VALUE_TYPE dictRead(ByteBuffer& dict, uint32_t& offset, ColumnReader& /*reader*/) { + KU_ASSERT(offset < dict.len / sizeof(VALUE_TYPE)); + return ((VALUE_TYPE*)dict.ptr)[offset]; + } + + static VALUE_TYPE plainRead(ByteBuffer& plainData, ColumnReader& /*reader*/) { + return plainData.read(); + } + + static void plainSkip(ByteBuffer& plainData, ColumnReader& /*reader*/) { + plainData.inc(sizeof(VALUE_TYPE)); + } +}; + +template +class TemplatedColumnReader : public ColumnReader { +public: + static constexpr const common::PhysicalTypeID TYPE = common::PhysicalTypeID::ANY; + +public: + TemplatedColumnReader(ParquetReader& reader, common::LogicalType type, + const lbug_parquet::format::SchemaElement& schema, uint64_t schemaIdx, uint64_t maxDefine, + uint64_t maxRepeat) + : ColumnReader(reader, std::move(type), schema, schemaIdx, maxDefine, maxRepeat){}; + + std::shared_ptr dict; + +public: + void allocateDict(uint64_t size) { + if (!dict) { + dict = std::make_shared(size); + } else { + dict->resize(size); + } + } + + void offsets(uint32_t* offsets, uint8_t* defines, uint64_t numValues, parquet_filter_t& filter, + uint64_t resultOffset, common::ValueVector* result) override { + uint64_t offsetIdx = 0; + for (auto rowIdx = 0u; rowIdx < numValues; rowIdx++) { + if (hasDefines() && defines[rowIdx + resultOffset] != maxDefine) { + result->setNull(rowIdx + resultOffset, true); + continue; + } + result->setNull(rowIdx + resultOffset, false); + if (filter[rowIdx + resultOffset]) { + VALUE_TYPE val = VALUE_CONVERSION::dictRead(*dict, offsets[offsetIdx++], *this); + result->setValue(rowIdx + resultOffset, val); + } else { + offsetIdx++; + } + } + } + + void plain(const std::shared_ptr& plainData, uint8_t* defines, uint64_t numValues, + parquet_filter_t& filter, uint64_t resultOffset, common::ValueVector* result) override { + plainTemplated(plainData, defines, numValues, filter, + resultOffset, result); + } + + void dictionary(const std::shared_ptr& data, + uint64_t /*num_entries*/) override { + dict = data; + } +}; + +template +struct CallbackParquetValueConversion { + static DUCKDB_PHYSICAL_TYPE dictRead(ByteBuffer& dict, uint32_t& offset, ColumnReader& reader) { + return TemplatedParquetValueConversion::dictRead(dict, offset, + reader); + } + + static DUCKDB_PHYSICAL_TYPE plainRead(ByteBuffer& plainData, ColumnReader& /*reader*/) { + return FUNC(plainData.read()); + } + + static void plainSkip(ByteBuffer& plainData, ColumnReader& /*reader*/) { + plainData.inc(sizeof(PARQUET_PHYSICAL_TYPE)); + } +}; + +} // namespace processor +} // 
namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/thrift_tools.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/thrift_tools.h new file mode 100644 index 0000000000..f4b7b38035 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/thrift_tools.h @@ -0,0 +1,198 @@ +#pragma once + +#include +#include + +#include "common/assert.h" +#include "common/file_system/file_info.h" +#include "transport/TVirtualTransport.h" + +namespace lbug { +namespace processor { + +// A ReadHead for prefetching data in a specific range +struct ReadHead { + ReadHead(uint64_t location, uint64_t size) : location(location), size(size){}; + // Hint info + uint64_t location; + uint64_t size; + + // Current info + std::unique_ptr data; + bool data_isset = false; + + uint64_t GetEnd() const { return size + location; } + + void Allocate() { data = std::make_unique(size); } +}; + +// Comparator for ReadHeads that are either overlapping, adjacent, or within ALLOW_GAP bytes from +// each other +struct ReadHeadComparator { + static constexpr uint64_t ALLOW_GAP = 1 << 14; // 16 KiB + bool operator()(const ReadHead* a, const ReadHead* b) const { + auto a_start = a->location; + auto a_end = a->location + a->size; + auto b_start = b->location; + + if (a_end <= UINT64_MAX - ALLOW_GAP) { + a_end += ALLOW_GAP; + } + + return a_start < b_start && a_end < b_start; + } +}; + +// Two-step read ahead buffer +// 1: register all ranges that will be read, merging ranges that are consecutive +// 2: prefetch all registered ranges +struct ReadAheadBuffer { + explicit ReadAheadBuffer(common::FileInfo* handle) : handle(handle) {} + + // The list of read heads + std::list read_heads; + // Set for merging consecutive ranges + std::set merge_set; + + common::FileInfo* handle; + + uint64_t total_size = 0; + + // Add a read head to the prefetching list + void AddReadHead(uint64_t pos, uint64_t len, bool merge_buffers = true) { + // Attempt to merge with existing + if (merge_buffers) { + ReadHead new_read_head{pos, len}; + auto lookup_set = merge_set.find(&new_read_head); + if (lookup_set != merge_set.end()) { + auto existing_head = *lookup_set; + auto new_start = + std::min(existing_head->location, new_read_head.location); + auto new_length = + std::min(existing_head->GetEnd(), new_read_head.GetEnd()) - new_start; + existing_head->location = new_start; + existing_head->size = new_length; + return; + } + } + + read_heads.emplace_front(ReadHead(pos, len)); + total_size += len; + auto& read_head = read_heads.front(); + + if (merge_buffers) { + merge_set.insert(&read_head); + } + + if (read_head.GetEnd() > handle->getFileSize()) { + throw std::runtime_error("Prefetch registered for bytes outside file"); + } + } + + // Returns the relevant read head + ReadHead* GetReadHead(uint64_t pos) { + for (auto& read_head : read_heads) { + if (pos >= read_head.location && pos < read_head.GetEnd()) { + return &read_head; + } + } + return nullptr; + } + + // Prefetch all read heads + void Prefetch() { + for (auto& read_head : read_heads) { + read_head.Allocate(); + + if (read_head.GetEnd() > handle->getFileSize()) { + throw std::runtime_error("Prefetch registered requested for bytes outside file"); + } + handle->readFromFile(read_head.data.get(), read_head.size, read_head.location); + read_head.data_isset = true; + } + } +}; + +class ThriftFileTransport + : public 
lbug_apache::thrift::transport::TVirtualTransport { +public: + static constexpr uint64_t PREFETCH_FALLBACK_BUFFERSIZE = 1000000; + + ThriftFileTransport(common::FileInfo* handle_p, bool prefetch_mode_p) + : handle(handle_p), location(0), ra_buffer(ReadAheadBuffer(handle_p)), + prefetch_mode(prefetch_mode_p) {} + + uint32_t read(uint8_t* buf, uint32_t len) { + auto prefetch_buffer = ra_buffer.GetReadHead(location); + if (prefetch_buffer != nullptr && + location - prefetch_buffer->location + len <= prefetch_buffer->size) { + KU_ASSERT(location - prefetch_buffer->location + len <= prefetch_buffer->size); + + if (!prefetch_buffer->data_isset) { + prefetch_buffer->Allocate(); + handle->readFromFile(prefetch_buffer->data.get(), prefetch_buffer->size, + prefetch_buffer->location); + prefetch_buffer->data_isset = true; + } + memcpy(buf, prefetch_buffer->data.get() + location - prefetch_buffer->location, len); + } else { + if (prefetch_mode && len < PREFETCH_FALLBACK_BUFFERSIZE && len > 0) { + Prefetch(location, std::min(PREFETCH_FALLBACK_BUFFERSIZE, + handle->getFileSize() - location)); + auto prefetch_buffer_fallback = ra_buffer.GetReadHead(location); + KU_ASSERT(location - prefetch_buffer_fallback->location + len <= + prefetch_buffer_fallback->size); + memcpy(buf, + prefetch_buffer_fallback->data.get() + location - + prefetch_buffer_fallback->location, + len); + } else { + handle->readFromFile(buf, len, location); + } + } + location += len; + return len; + } + + // Prefetch a single buffer + void Prefetch(uint64_t pos, uint64_t len) { + RegisterPrefetch(pos, len, false); + FinalizeRegistration(); + PrefetchRegistered(); + } + + // Register a buffer for prefixing + void RegisterPrefetch(uint64_t pos, uint64_t len, bool can_merge = true) { + ra_buffer.AddReadHead(pos, len, can_merge); + } + + // Prevents any further merges, should be called before PrefetchRegistered + void FinalizeRegistration() { ra_buffer.merge_set.clear(); } + + // Prefetch all previously registered ranges + void PrefetchRegistered() { ra_buffer.Prefetch(); } + + void ClearPrefetch() { + ra_buffer.read_heads.clear(); + ra_buffer.merge_set.clear(); + } + + void SetLocation(uint64_t location_p) { location = location_p; } + + uint64_t GetLocation() const { return location; } + uint64_t GetSize() { return handle->getFileSize(); } + +private: + common::FileInfo* handle; + uint64_t location; + + // Multi-buffer prefetch + ReadAheadBuffer ra_buffer; + + // Whether the prefetch mode is enabled. In this mode the DirectIO flag of the handle will be + // set and the parquet reader will manage the read buffering. 
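+    // Typical prefetch flow, using the methods above: RegisterPrefetch() each byte range that will
+    // be needed, FinalizeRegistration() to stop further merging, then PrefetchRegistered();
+    // subsequent read() calls are served from the matching ReadHead when one covers the requested
+    // range, and fall back to FileInfo::readFromFile otherwise.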
+ bool prefetch_mode; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/uuid_column_reader.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/uuid_column_reader.h new file mode 100644 index 0000000000..c9e517784e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/parquet/uuid_column_reader.h @@ -0,0 +1,39 @@ +#pragma once + +#include "common/types/uuid.h" +#include "processor/operator/persistent/reader/parquet/resizable_buffer.h" +#include "templated_column_reader.h" + +namespace lbug { +namespace processor { + +struct UUIDValueConversion { + static common::ku_uuid_t dictRead(ByteBuffer& dict, uint32_t& offset, + ColumnReader& /*reader*/) { + return reinterpret_cast(dict.ptr)[offset]; + } + + static common::ku_uuid_t ReadParquetUUID(const uint8_t* input); + + static common::ku_uuid_t plainRead(ByteBuffer& bufferData, ColumnReader& /*reader*/); + + static void plainSkip(ByteBuffer& plain_data, ColumnReader& /*reader*/) { + plain_data.inc(sizeof(common::ku_uuid_t)); + } +}; + +class UUIDColumnReader : public TemplatedColumnReader { +public: + UUIDColumnReader(ParquetReader& reader, common::LogicalType dataType, + const lbug_parquet::format::SchemaElement& schema_p, uint64_t file_idx_p, + uint64_t maxDefine, uint64_t maxRepeat) + : TemplatedColumnReader(reader, std::move(dataType), + schema_p, file_idx_p, maxDefine, maxRepeat){}; + +protected: + void dictionary(const std::shared_ptr& dictionaryData, + uint64_t numEntries) override; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/reader_bind_utils.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/reader_bind_utils.h new file mode 100644 index 0000000000..c337baf435 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/reader/reader_bind_utils.h @@ -0,0 +1,22 @@ +#pragma once + +#include "common/types/types.h" + +namespace lbug { +namespace processor { + +struct ReaderBindUtils { + static void validateNumColumns(uint32_t expectedNumber, uint32_t detectedNumber); + static void validateColumnTypes(const std::vector& columnNames, + const std::vector& expectedColumnTypes, + const std::vector& detectedColumnTypes); + static void resolveColumns(const std::vector& expectedColumnNames, + const std::vector& detectedColumnNames, + std::vector& resultColumnNames, + const std::vector& expectedColumnTypes, + const std::vector& detectedColumnTypes, + std::vector& resultColumnTypes); +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/rel_batch_insert.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/rel_batch_insert.h new file mode 100644 index 0000000000..3893cd3b42 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/rel_batch_insert.h @@ -0,0 +1,171 @@ +#pragma once + +#include "common/enums/rel_direction.h" +#include "processor/operator/partitioner.h" +#include "processor/operator/persistent/batch_insert.h" + +namespace lbug { +namespace catalog { +class RelGroupCatalogEntry; +} // namespace catalog + +namespace storage { +class CSRNodeGroup; +struct InMemChunkedCSRHeader; +} // namespace storage + +namespace processor { + +struct 
LBUG_API RelBatchInsertPrintInfo final : OPPrintInfo { + std::string tableName; + + explicit RelBatchInsertPrintInfo(std::string tableName) : tableName(std::move(tableName)) {} + + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::unique_ptr(new RelBatchInsertPrintInfo(*this)); + } + +private: + RelBatchInsertPrintInfo(const RelBatchInsertPrintInfo& other) + : OPPrintInfo(other), tableName(other.tableName) {} +}; + +struct LBUG_API RelBatchInsertProgressSharedState { + std::atomic partitionsDone; + uint64_t partitionsTotal; + + RelBatchInsertProgressSharedState() : partitionsDone{0}, partitionsTotal{0} {}; +}; + +struct LBUG_API RelBatchInsertInfo final : BatchInsertInfo { + common::RelDataDirection direction; + common::table_id_t fromTableID, toTableID; + uint64_t partitioningIdx = UINT64_MAX; + common::column_id_t boundNodeOffsetColumnID = common::INVALID_COLUMN_ID; + + RelBatchInsertInfo(std::string tableName, std::vector warningColumnTypes, + common::table_id_t fromTableID, common::table_id_t toTableID, + common::RelDataDirection direction) + : BatchInsertInfo{std::move(tableName), std::move(warningColumnTypes)}, + direction{direction}, fromTableID{fromTableID}, toTableID{toTableID} {} + RelBatchInsertInfo(const RelBatchInsertInfo& other) + : BatchInsertInfo{other}, direction{other.direction}, fromTableID{other.fromTableID}, + toTableID{other.toTableID}, partitioningIdx{other.partitioningIdx}, + boundNodeOffsetColumnID{other.boundNodeOffsetColumnID} {} + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } +}; + +struct LBUG_API RelBatchInsertLocalState final : BatchInsertLocalState { + common::partition_idx_t nodeGroupIdx = common::INVALID_NODE_GROUP_IDX; + std::unique_ptr dummyAllNullDataChunk; +}; + +struct LBUG_API RelBatchInsertExecutionState { + virtual ~RelBatchInsertExecutionState() = default; + + template + TARGET& cast() { + return common::ku_dynamic_cast(*this); + } + template + const TARGET& constCast() const { + return common::ku_dynamic_cast(*this); + } +}; + +/** + * Abstract RelBatchInsert class + * When performing rel batch insert, we typically take some source data and use it to construct the + * CSR header as well as a chunked node group that actually contains the rel properties + * Child classes can customize how data is copied from the source into the CSR chunked node group + * (which is the format in which the rels are actually stored) + * + * The following interfaces can be overriden: + * - initExecutionState(): The execution state contains any extra local state needed during the + * insertion of a single CSR node group. This reset by calling initExecutionState() before the + * insertion of each node group so make sure that the lifetime of any stored state doesn't exceed + * this. + * - populateCSRLengths(): Populates the length chunk in the CSR header. The offsets are directly + * calculated from the lengths and thus calculating the offsets doesn't need to be customized + * - finalizeStartCSROffsets(): The CSR offsets are initially calculated as start offsets. The + * default behaviour of this function is to convert the start offsets to end offsets. However, if + * any extra logic is required during this conversion, this function can be overriden. + * - writeToTable(): Writes property data to the local chunked node group. 
This function is also + * responsible for ensuring that the data is written in a way such that the copied data is in + * agreement with the CSR header + * + * Generally, the source data to be copied from should be contained in the partitionerSharedState, + * which can also be overridden. + */ +class LBUG_API RelBatchInsertImpl { +public: + virtual ~RelBatchInsertImpl() = default; + virtual std::unique_ptr copy() = 0; + virtual std::unique_ptr initExecutionState( + const PartitionerSharedState& partitionerSharedState, const RelBatchInsertInfo& relInfo, + common::node_group_idx_t nodeGroupIdx) = 0; + virtual void populateCSRLengths(RelBatchInsertExecutionState& executionState, + storage::InMemChunkedCSRHeader& csrHeader, common::offset_t numNodes, + const RelBatchInsertInfo& relInfo) = 0; + virtual void finalizeStartCSROffsets(RelBatchInsertExecutionState& executionState, + storage::InMemChunkedCSRHeader& csrHeader, const RelBatchInsertInfo& relInfo); + virtual void writeToTable(RelBatchInsertExecutionState& executionState, + const storage::InMemChunkedCSRHeader& csrHeader, const RelBatchInsertLocalState& localState, + BatchInsertSharedState& sharedState, const RelBatchInsertInfo& relInfo) = 0; +}; + +class LBUG_API RelBatchInsert : public BatchInsert { +public: + RelBatchInsert(std::unique_ptr info, + std::shared_ptr partitionerSharedState, + std::shared_ptr sharedState, physical_op_id id, + std::unique_ptr printInfo, + std::shared_ptr progressSharedState, + std::unique_ptr impl) + : BatchInsert{std::move(info), std::move(sharedState), id, std::move(printInfo)}, + partitionerSharedState{std::move(partitionerSharedState)}, + progressSharedState{std::move(progressSharedState)}, impl(std::move(impl)) {} + + bool isSource() const override { return true; } + + void initGlobalStateInternal(ExecutionContext* context) override; + void initLocalStateInternal(ResultSet* resultSet_, ExecutionContext* context) override; + + void executeInternal(ExecutionContext* context) override; + void finalizeInternal(ExecutionContext* context) override; + + void updateProgress(const ExecutionContext* context) const; + + std::unique_ptr copy() override { + return std::make_unique(info->copy(), partitionerSharedState, sharedState, + id, printInfo->copy(), progressSharedState, impl->copy()); + } + +private: + void appendNodeGroup(const catalog::RelGroupCatalogEntry& relGroupEntry, + storage::MemoryManager& mm, transaction::Transaction* transaction, + storage::CSRNodeGroup& nodeGroup, const RelBatchInsertInfo& relInfo, + const RelBatchInsertLocalState& localState); + + void populateCSRHeader(const catalog::RelGroupCatalogEntry& relGroupEntry, + RelBatchInsertExecutionState& executionState, common::offset_t startNodeOffset, + const RelBatchInsertInfo& relInfo, const RelBatchInsertLocalState& localState, + common::offset_t numNodes, bool leaveGaps); + + static void checkRelMultiplicityConstraint(const catalog::RelGroupCatalogEntry& relGroupEntry, + const storage::InMemChunkedCSRHeader& csrHeader, common::offset_t startNodeOffset, + const RelBatchInsertInfo& relInfo); + +protected: + std::shared_ptr partitionerSharedState; + std::shared_ptr progressSharedState; + std::unique_ptr impl; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/set.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/set.h new file mode 100644 index 0000000000..c6dca2a70a --- /dev/null +++ 
b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/set.h @@ -0,0 +1,77 @@ +#pragma once + +#include "processor/operator/physical_operator.h" +#include "set_executor.h" + +namespace lbug { +namespace processor { + +struct SetPropertyPrintInfo final : OPPrintInfo { + std::vector expressions; + + explicit SetPropertyPrintInfo(std::vector expressions) + : expressions(std::move(expressions)) {} + + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::unique_ptr(new SetPropertyPrintInfo(*this)); + } + +private: + SetPropertyPrintInfo(const SetPropertyPrintInfo& other) + : OPPrintInfo(other), expressions(other.expressions) {} +}; + +class SetNodeProperty final : public PhysicalOperator { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::SET_PROPERTY; + +public: + SetNodeProperty(std::vector> executors, + std::unique_ptr child, uint32_t id, + std::unique_ptr printInfo) + : PhysicalOperator{type_, std::move(child), id, std::move(printInfo)}, + executors{std::move(executors)} {} + + bool isParallel() const override { return false; } + + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; + + bool getNextTuplesInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return std::make_unique(copyVector(executors), children[0]->copy(), id, + printInfo->copy()); + } + +private: + std::vector> executors; +}; + +class SetRelProperty final : public PhysicalOperator { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::SET_PROPERTY; + +public: + SetRelProperty(std::vector> executors, + std::unique_ptr child, uint32_t id, + std::unique_ptr printInfo) + : PhysicalOperator{type_, std::move(child), id, std::move(printInfo)}, + executors{std::move(executors)} {} + + bool isParallel() const override { return false; } + + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; + + bool getNextTuplesInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return std::make_unique(copyVector(executors), children[0]->copy(), id, + printInfo->copy()); + } + +private: + std::vector> executors; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/set_executor.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/set_executor.h new file mode 100644 index 0000000000..b4c0323d89 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/set_executor.h @@ -0,0 +1,193 @@ +#pragma once + +#include "expression_evaluator/expression_evaluator.h" +#include "processor/execution_context.h" +#include "processor/result/result_set.h" +#include "storage/table/node_table.h" +#include "storage/table/rel_table.h" + +namespace lbug { +namespace processor { + +struct NodeSetInfo { + DataPos nodeIDPos; + DataPos columnVectorPos; + + std::unique_ptr evaluator; + + common::ValueVector* nodeIDVector = nullptr; + common::ValueVector* columnVector = nullptr; + common::ValueVector* columnDataVector = nullptr; + + NodeSetInfo(DataPos nodeIDPos, DataPos columnVectorPos, + std::unique_ptr evaluator) + : nodeIDPos{nodeIDPos}, columnVectorPos{columnVectorPos}, evaluator{std::move(evaluator)} {} + EXPLICIT_COPY_DEFAULT_MOVE(NodeSetInfo); + + void init(const ResultSet& resultSet, main::ClientContext* context); + +private: + NodeSetInfo(const NodeSetInfo& other) + : 
nodeIDPos{other.nodeIDPos}, columnVectorPos{other.columnVectorPos}, + evaluator{other.evaluator->copy()} {} +}; + +struct NodeTableSetInfo { + storage::NodeTable* table; + common::column_id_t columnID; + + NodeTableSetInfo(storage::NodeTable* table, common::column_id_t columnID) + : table{table}, columnID{columnID} {} + EXPLICIT_COPY_DEFAULT_MOVE(NodeTableSetInfo); + +private: + NodeTableSetInfo(const NodeTableSetInfo& other) + : table{other.table}, columnID{other.columnID} {} +}; + +class NodeSetExecutor { +public: + explicit NodeSetExecutor(NodeSetInfo info) : info{std::move(info)} {} + NodeSetExecutor(const NodeSetExecutor& other) : info{other.info.copy()} {} + virtual ~NodeSetExecutor() = default; + + virtual void init(ResultSet* resultSet, ExecutionContext* context); + + void setNodeID(common::nodeID_t nodeID) const; + + virtual void set(ExecutionContext* context) = 0; + + virtual std::unique_ptr copy() const = 0; + +protected: + NodeSetInfo info; +}; + +class SingleLabelNodeSetExecutor final : public NodeSetExecutor { +public: + SingleLabelNodeSetExecutor(NodeSetInfo info, NodeTableSetInfo tableInfo) + : NodeSetExecutor{std::move(info)}, tableInfo{std::move(tableInfo)} {} + SingleLabelNodeSetExecutor(const SingleLabelNodeSetExecutor& other) + : NodeSetExecutor{other}, tableInfo(other.tableInfo.copy()) {} + + void set(ExecutionContext* context) override; + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } + +private: + NodeTableSetInfo tableInfo; +}; + +class MultiLabelNodeSetExecutor final : public NodeSetExecutor { +public: + MultiLabelNodeSetExecutor(NodeSetInfo info, common::table_id_map_t tableInfos) + : NodeSetExecutor{std::move(info)}, tableInfos{std::move(tableInfos)} {} + MultiLabelNodeSetExecutor(const MultiLabelNodeSetExecutor& other) + : NodeSetExecutor{other}, tableInfos{copyUnorderedMap(other.tableInfos)} {} + + void set(ExecutionContext* context) override; + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } + +private: + common::table_id_map_t tableInfos; +}; + +struct RelSetInfo { + DataPos srcNodeIDPos; + DataPos dstNodeIDPos; + DataPos relIDPos; + DataPos columnVectorPos; + std::unique_ptr evaluator; + + common::ValueVector* srcNodeIDVector = nullptr; + common::ValueVector* dstNodeIDVector = nullptr; + common::ValueVector* relIDVector = nullptr; + common::ValueVector* columnVector = nullptr; + common::ValueVector* columnDataVector = nullptr; + + RelSetInfo(DataPos srcNodeIDPos, DataPos dstNodeIDPos, DataPos relIDPos, + DataPos columnVectorPos, std::unique_ptr evaluator) + : srcNodeIDPos{srcNodeIDPos}, dstNodeIDPos{dstNodeIDPos}, relIDPos{relIDPos}, + columnVectorPos{columnVectorPos}, evaluator{std::move(evaluator)} {} + EXPLICIT_COPY_DEFAULT_MOVE(RelSetInfo); + + void init(const ResultSet& resultSet, main::ClientContext* context); + +private: + RelSetInfo(const RelSetInfo& other) + : srcNodeIDPos{other.srcNodeIDPos}, dstNodeIDPos{other.dstNodeIDPos}, + relIDPos{other.relIDPos}, columnVectorPos{other.columnVectorPos}, + evaluator{other.evaluator->copy()} {} +}; + +struct RelTableSetInfo { + storage::RelTable* table; + common::column_id_t columnID; + + RelTableSetInfo(storage::RelTable* table, common::column_id_t columnID) + : table{table}, columnID{columnID} {} + EXPLICIT_COPY_DEFAULT_MOVE(RelTableSetInfo); + +private: + RelTableSetInfo(const RelTableSetInfo& other) : table{other.table}, columnID{other.columnID} {} +}; + +class RelSetExecutor { +public: + explicit RelSetExecutor(RelSetInfo info) : 
info{std::move(info)} {} + RelSetExecutor(const RelSetExecutor& other) : info{other.info.copy()} {} + virtual ~RelSetExecutor() = default; + + void init(ResultSet* resultSet, ExecutionContext* context); + + void setRelID(common::nodeID_t relID) const; + + virtual void set(ExecutionContext* context) = 0; + + virtual std::unique_ptr copy() const = 0; + +protected: + RelSetInfo info; +}; + +class SingleLabelRelSetExecutor final : public RelSetExecutor { +public: + SingleLabelRelSetExecutor(RelSetInfo info, RelTableSetInfo tableInfo) + : RelSetExecutor{std::move(info)}, tableInfo{std::move(tableInfo)} {} + SingleLabelRelSetExecutor(const SingleLabelRelSetExecutor& other) + : RelSetExecutor{other}, tableInfo{other.tableInfo.copy()} {} + + void set(ExecutionContext* context) override; + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } + +private: + RelTableSetInfo tableInfo; +}; + +class MultiLabelRelSetExecutor final : public RelSetExecutor { +public: + MultiLabelRelSetExecutor(RelSetInfo info, common::table_id_map_t tableInfos) + : RelSetExecutor{std::move(info)}, tableInfos{std::move(tableInfos)} {} + MultiLabelRelSetExecutor(const MultiLabelRelSetExecutor& other) + : RelSetExecutor{other}, tableInfos{copyUnorderedMap(other.tableInfos)} {} + + void set(ExecutionContext* context) override; + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } + +private: + common::table_id_map_t tableInfos; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/writer/parquet/basic_column_writer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/writer/parquet/basic_column_writer.h new file mode 100644 index 0000000000..8ff7018bac --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/writer/parquet/basic_column_writer.h @@ -0,0 +1,93 @@ +#pragma once + +#include "parquet_types.h" +#include "processor/operator/persistent/writer/parquet/column_writer.h" + +namespace lbug { +namespace processor { + +class BasicColumnWriterState : public ColumnWriterState { +public: + BasicColumnWriterState(lbug_parquet::format::RowGroup& rowGroup, uint64_t colIdx) + : rowGroup{rowGroup}, colIdx{colIdx} { + pageInfo.emplace_back(); + } + + lbug_parquet::format::RowGroup& rowGroup; + uint64_t colIdx; + std::vector pageInfo; + std::vector writeInfo; + std::unique_ptr statsState; + uint64_t currentPage = 0; +}; + +class BasicColumnWriter : public ColumnWriter { +public: + BasicColumnWriter(ParquetWriter& writer, uint64_t schemaIdx, + std::vector schemaPath, uint64_t maxRepeat, uint64_t maxDefine, + bool canHaveNulls) + : ColumnWriter(writer, schemaIdx, std::move(schemaPath), maxRepeat, maxDefine, + canHaveNulls) {} + +public: + std::unique_ptr initializeWriteState( + lbug_parquet::format::RowGroup& rowGroup) override; + void prepare(ColumnWriterState& state, ColumnWriterState* parent, common::ValueVector* vector, + uint64_t count) override; + void beginWrite(ColumnWriterState& state) override; + void write(ColumnWriterState& state, common::ValueVector* vector, uint64_t count) override; + void finalizeWrite(ColumnWriterState& state) override; + +protected: + void writeLevels(common::Serializer& bufferedSerializer, const std::vector& levels, + uint64_t maxValue, uint64_t startOffset, uint64_t count); + + virtual lbug_parquet::format::Encoding::type getEncoding(BasicColumnWriterState& /*state*/) { + return 
lbug_parquet::format::Encoding::PLAIN; + } + + void nextPage(BasicColumnWriterState& state); + void flushPage(BasicColumnWriterState& state); + + // Initializes the state used to track statistics during writing. Only used for scalar types. + virtual std::unique_ptr initializeStatsState() { + return std::make_unique(); + } + + // Initialize the writer for a specific page. Only used for scalar types. + virtual std::unique_ptr initializePageState( + BasicColumnWriterState& /*state*/) { + return nullptr; + } + + // Flushes the writer for a specific page. Only used for scalar types. + virtual void flushPageState(common::Serializer& /*bufferedSerializer*/, + ColumnWriterPageState* /*state*/) {} + + // Retrieves the row size of a vector at the specified location. Only used for scalar types. + virtual uint64_t getRowSize(common::ValueVector* /*vector*/, uint64_t /*index*/, + BasicColumnWriterState& /*state*/) { + KU_UNREACHABLE; + } + // Writes a (subset of a) vector to the specified serializer. Only used for scalar types. + virtual void writeVector(common::Serializer& bufferedSerializer, ColumnWriterStatistics* stats, + ColumnWriterPageState* pageState, common::ValueVector* vector, uint64_t chunkStart, + uint64_t chunkEnd) = 0; + + virtual bool hasDictionary(BasicColumnWriterState& /*writerState*/) { return false; } + // The number of elements in the dictionary. + virtual uint64_t dictionarySize(BasicColumnWriterState& /*writerState*/) { KU_UNREACHABLE; } + void writeDictionary(BasicColumnWriterState& state, + std::unique_ptr bufferedSerializer, uint64_t rowCount); + virtual void flushDictionary(BasicColumnWriterState& /*state*/, + ColumnWriterStatistics* /*stats*/) { + KU_UNREACHABLE; + } + + void setParquetStatistics(BasicColumnWriterState& state, + lbug_parquet::format::ColumnChunk& column); + void registerToRowGroup(lbug_parquet::format::RowGroup& rowGroup); +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/writer/parquet/boolean_column_writer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/writer/parquet/boolean_column_writer.h new file mode 100644 index 0000000000..43d7a6534f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/writer/parquet/boolean_column_writer.h @@ -0,0 +1,67 @@ +#pragma once + +#include "processor/operator/persistent/writer/parquet/basic_column_writer.h" + +namespace lbug { +namespace processor { + +class BooleanStatisticsState : public ColumnWriterStatistics { +public: + BooleanStatisticsState() : min{true}, max{false} {} + + bool min; + bool max; + +public: + bool hasStats() const { return !(min && !max); } + + std::string getMin() override { return getMinValue(); } + std::string getMax() override { return getMaxValue(); } + std::string getMinValue() override { + return hasStats() ? std::string(reinterpret_cast(&min), sizeof(bool)) : + std::string(); + } + std::string getMaxValue() override { + return hasStats() ? 
std::string(reinterpret_cast(&max), sizeof(bool)) : + std::string(); + } +}; + +class BooleanWriterPageState : public ColumnWriterPageState { +public: + uint8_t byte = 0; + uint8_t bytePos = 0; +}; + +class BooleanColumnWriter : public BasicColumnWriter { +public: + BooleanColumnWriter(ParquetWriter& writer, uint64_t schemaIdx, + std::vector schemaPath, uint64_t maxRepeat, uint64_t maxDefine, + bool canHaveNulls) + : BasicColumnWriter(writer, schemaIdx, std::move(schemaPath), maxRepeat, maxDefine, + canHaveNulls) {} + + inline std::unique_ptr initializeStatsState() override { + return std::make_unique(); + } + + inline uint64_t getRowSize(common::ValueVector* /*vector*/, uint64_t /*index*/, + BasicColumnWriterState& /*state*/) override { + return sizeof(bool); + } + + inline std::unique_ptr initializePageState( + BasicColumnWriterState& /*state*/) override { + return std::make_unique(); + } + + void writeVector(common::Serializer& bufferedSerializer, + ColumnWriterStatistics* writerStatistics, ColumnWriterPageState* writerPageState, + common::ValueVector* vector, uint64_t chunkStart, uint64_t chunkEnd) override; + + void flushPageState(common::Serializer& temp_writer, + ColumnWriterPageState* writerPageState) override; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/writer/parquet/column_writer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/writer/parquet/column_writer.h new file mode 100644 index 0000000000..b9983d3ad2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/writer/parquet/column_writer.h @@ -0,0 +1,109 @@ +#pragma once + +#include "common/serializer/buffer_writer.h" +#include "common/types/types.h" +#include "common/vector/value_vector.h" +#include "parquet_types.h" + +namespace lbug { +namespace processor { +class ParquetWriter; + +struct PageInformation { + uint64_t offset = 0; + uint64_t rowCount = 0; + uint64_t emptyCount = 0; + uint64_t estimatedPageSize = 0; +}; + +class ColumnWriterPageState { +public: + virtual ~ColumnWriterPageState() = default; +}; + +struct PageWriteInformation { + lbug_parquet::format::PageHeader pageHeader; + std::shared_ptr bufferWriter; + std::unique_ptr writer; + std::unique_ptr pageState; + uint64_t writePageIdx = 0; + uint64_t writeCount = 0; + uint64_t maxWriteCount = 0; + size_t compressedSize = 0; + uint8_t* compressedData = nullptr; + std::unique_ptr compressedBuf; +}; + +class ColumnWriterState { +public: + virtual ~ColumnWriterState() = default; + + std::vector definitionLevels; + std::vector repetitionLevels; + std::vector isEmpty; +}; + +class ColumnWriterStatistics { +public: + virtual ~ColumnWriterStatistics() = default; + + virtual std::string getMin() { return {}; } + virtual std::string getMax() { return {}; } + virtual std::string getMinValue() { return {}; } + virtual std::string getMaxValue() { return {}; } +}; + +class ColumnWriter { +public: + ColumnWriter(ParquetWriter& writer, uint64_t schemaIdx, std::vector schemaPath, + uint64_t maxRepeat, uint64_t maxDefine, bool canHaveNulls); + virtual ~ColumnWriter() = default; + + // Create the column writer for a specific type recursively. + // TODO(Ziyi): We currently don't have statistics to indicate whether a column + // has null value or not. So canHaveNullsToCreate is always true. 
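// Editor's note (illustrative sketch, not part of the original patch): the virtual methods
// declared in ColumnWriter are meant to be driven per row group roughly as follows; the
// writer/vector/rowGroup/count variables below are hypothetical placeholders.
//
//   auto state = writer->initializeWriteState(rowGroup);
//   if (writer->hasAnalyze()) {                      // e.g. string columns sizing a dictionary
//       writer->analyze(*state, nullptr, vector, count);
//       writer->finalizeAnalyze(*state);
//   }
//   writer->prepare(*state, nullptr, vector, count); // computes definition/repetition levels
//   writer->beginWrite(*state);
//   writer->write(*state, vector, count);
//   writer->finalizeWrite(*state);                   // flushes pages into the row group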
+ static std::unique_ptr createWriterRecursive( + std::vector& schemas, ParquetWriter& writer, + const common::LogicalType& type, const std::string& name, + std::vector schemaPathToCreate, storage::MemoryManager* mm, + uint64_t maxRepeatToCreate = 0, uint64_t maxDefineToCreate = 1, + bool canHaveNullsToCreate = true); + + virtual std::unique_ptr initializeWriteState( + lbug_parquet::format::RowGroup& rowGroup) = 0; + // Indicates whether the write need to analyse the data before preparing it. + virtual bool hasAnalyze() { return false; } + virtual void analyze(ColumnWriterState& /*state*/, ColumnWriterState* /*parent*/, + common::ValueVector* /*vector*/, uint64_t /*count*/) { + KU_UNREACHABLE; + } + // Called after all data has been passed to Analyze. + virtual void finalizeAnalyze(ColumnWriterState& /*state*/) { KU_UNREACHABLE; } + virtual void prepare(ColumnWriterState& state, ColumnWriterState* parent, + common::ValueVector* vector, uint64_t count) = 0; + virtual void beginWrite(ColumnWriterState& state) = 0; + virtual void write(ColumnWriterState& state, common::ValueVector* vector, uint64_t count) = 0; + virtual void finalizeWrite(ColumnWriterState& state) = 0; + inline uint64_t getVectorPos(common::ValueVector* vector, uint64_t idx) { + return (vector->state == nullptr || !vector->state->isFlat()) ? idx : 0; + } + + ParquetWriter& writer; + uint64_t schemaIdx; + std::vector schemaPath; + uint64_t maxRepeat; + uint64_t maxDefine; + bool canHaveNulls; + // collected stats + uint64_t nullCount; + +protected: + void handleDefineLevels(ColumnWriterState& state, ColumnWriterState* parent, + common::ValueVector* vector, uint64_t count, uint16_t defineValue, uint16_t nullValue); + void handleRepeatLevels(ColumnWriterState& stateToHandle, ColumnWriterState* parent); + void compressPage(common::BufferWriter& bufferedSerializer, size_t& compressedSize, + uint8_t*& compressedData, std::unique_ptr& compressedBuf); +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/writer/parquet/interval_column_writer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/writer/parquet/interval_column_writer.h new file mode 100644 index 0000000000..9e041850fa --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/writer/parquet/interval_column_writer.h @@ -0,0 +1,33 @@ +#pragma once + +#include "basic_column_writer.h" +#include "common/constants.h" +#include "common/types/interval_t.h" + +namespace lbug { +namespace processor { + +class IntervalColumnWriter : public BasicColumnWriter { + +public: + IntervalColumnWriter(ParquetWriter& writer, uint64_t schemaIdx, + std::vector schemaPath, uint64_t maxRepeat, uint64_t maxDefine, + bool canHaveNulls) + : BasicColumnWriter(writer, schemaIdx, std::move(schemaPath), maxRepeat, maxDefine, + canHaveNulls) {} + +public: + static void writeParquetInterval(common::interval_t input, uint8_t* result); + + void writeVector(common::Serializer& bufferedSerializer, ColumnWriterStatistics* state, + ColumnWriterPageState* pageState, common::ValueVector* vector, uint64_t chunkStart, + uint64_t chunkEnd) override; + + uint64_t getRowSize(common::ValueVector* /*vector*/, uint64_t /*index*/, + BasicColumnWriterState& /*state*/) override { + return common::ParquetConstants::PARQUET_INTERVAL_SIZE; + } +}; + +} // namespace processor +} // namespace lbug diff --git 
a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/writer/parquet/list_column_writer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/writer/parquet/list_column_writer.h new file mode 100644 index 0000000000..5d9dc38b2a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/writer/parquet/list_column_writer.h @@ -0,0 +1,45 @@ +#pragma once + +#include "processor/operator/persistent/writer/parquet/column_writer.h" + +namespace lbug { +namespace processor { + +class ListColumnWriter : public ColumnWriter { +public: + ListColumnWriter(ParquetWriter& writer, uint64_t schemaIdx, std::vector schema, + uint64_t maxRepeat, uint64_t maxDefine, std::unique_ptr childWriter, + bool canHaveNulls) + : ColumnWriter(writer, schemaIdx, std::move(schema), maxRepeat, maxDefine, canHaveNulls), + childWriter(std::move(childWriter)) {} + + std::unique_ptr initializeWriteState( + lbug_parquet::format::RowGroup& rowGroup) override; + bool hasAnalyze() override; + void analyze(ColumnWriterState& writerState, ColumnWriterState* parent, + common::ValueVector* vector, uint64_t count) override; + void finalizeAnalyze(ColumnWriterState& writerState) override; + void prepare(ColumnWriterState& writerState, ColumnWriterState* parent, + common::ValueVector* vector, uint64_t count) override; + void beginWrite(ColumnWriterState& state) override; + void write(ColumnWriterState& writerState, common::ValueVector* vector, + uint64_t count) override; + void finalizeWrite(ColumnWriterState& writerState) override; + +private: + std::unique_ptr childWriter; +}; + +class ListColumnWriterState : public ColumnWriterState { +public: + ListColumnWriterState(lbug_parquet::format::RowGroup& rowGroup, uint64_t colIdx) + : rowGroup{rowGroup}, colIdx{colIdx} {} + + lbug_parquet::format::RowGroup& rowGroup; + uint64_t colIdx; + std::unique_ptr childState; + uint64_t parentIdx = 0; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/writer/parquet/parquet_rle_bp_encoder.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/writer/parquet/parquet_rle_bp_encoder.h new file mode 100644 index 0000000000..bae2a35098 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/writer/parquet/parquet_rle_bp_encoder.h @@ -0,0 +1,43 @@ +#pragma once + +#include "common/serializer/serializer.h" + +namespace lbug { +namespace processor { + +class RleBpEncoder { +public: + explicit RleBpEncoder(uint32_t bitWidth); + +public: + // NOTE: Prepare is only required if a byte count is required BEFORE writing + // This is the case with e.g. writing repetition/definition levels + // If GetByteCount() is not required, prepare can be safely skipped. + void beginPrepare(uint32_t firstValue); + void prepareValue(uint32_t value); + void finishPrepare(); + + void beginWrite(uint32_t first_value); + void writeValue(common::Serializer& writer, uint32_t value); + void finishWrite(common::Serializer& writer); + + uint64_t getByteCount() const; + + static uint8_t getVarintSize(uint32_t val); + +private: + //! meta information + uint32_t byteWidth; + //! 
RLE run information + uint64_t byteCount; + uint64_t runCount; + uint64_t currentRunCount; + uint32_t lastValue; + +private: + void finishRun(); + void writeRun(common::Serializer& bufferedSerializer); +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/writer/parquet/parquet_writer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/writer/parquet/parquet_writer.h new file mode 100644 index 0000000000..172e2d781f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/writer/parquet/parquet_writer.h @@ -0,0 +1,94 @@ +#pragma once + +#include + +#include "common/data_chunk/data_chunk.h" +#include "common/file_system/file_info.h" +#include "common/types/types.h" +#include "parquet_types.h" +#include "processor/operator/persistent/writer/parquet/column_writer.h" +#include "processor/result/factorized_table.h" +#include "protocol/TProtocol.h" + +namespace lbug { +namespace main { +class ClientContext; +} + +namespace processor { + +class ParquetWriterTransport : public lbug_apache::thrift::protocol::TTransport { +public: + explicit ParquetWriterTransport(common::FileInfo* fileInfo, common::offset_t& offset) + : fileInfo{fileInfo}, offset{offset} {} + + inline bool isOpen() const override { return true; } + + void open() override {} + + void close() override {} + + inline void write_virt(const uint8_t* buf, uint32_t len) override { + fileInfo->writeFile(buf, len, offset); + offset += len; + } + +private: + common::FileInfo* fileInfo; + common::offset_t& offset; +}; + +struct PreparedRowGroup { + lbug_parquet::format::RowGroup rowGroup; + std::vector> states; +}; + +class ParquetWriter { +public: + ParquetWriter(std::string fileName, std::vector types, + std::vector names, lbug_parquet::format::CompressionCodec::type codec, + main::ClientContext* context); + + inline common::offset_t getOffset() const { return fileOffset; } + inline void write(const uint8_t* buf, uint32_t len) { + fileInfo->writeFile(buf, len, fileOffset); + fileOffset += len; + } + inline lbug_parquet::format::CompressionCodec::type getCodec() { return codec; } + inline lbug_apache::thrift::protocol::TProtocol* getProtocol() { return protocol.get(); } + inline lbug_parquet::format::Type::type getParquetType(uint64_t schemaIdx) { + return fileMetaData.schema[schemaIdx].type; + } + void flush(FactorizedTable& ft); + void finalize(); + static lbug_parquet::format::Type::type convertToParquetType(const common::LogicalType& type); + static void setSchemaProperties(const common::LogicalType& type, + lbug_parquet::format::SchemaElement& schemaElement); + +private: + void prepareRowGroup(FactorizedTable& ft, PreparedRowGroup& result); + void flushRowGroup(PreparedRowGroup& rowGroup); + void readFromFT(FactorizedTable& ft, std::vector vectorsToRead, + uint64_t& numTuplesRead); + inline uint64_t getNumTuples(common::DataChunk* unflatChunk) { + return unflatChunk->getNumValueVectors() != 0 ? 
+ unflatChunk->state->getSelVector().getSelSize() : + 1; + } + +private: + std::string fileName; + std::vector types; + std::vector columnNames; + lbug_parquet::format::CompressionCodec::type codec; + std::unique_ptr fileInfo; + std::shared_ptr protocol; + lbug_parquet::format::FileMetaData fileMetaData; + std::mutex lock; + std::vector> columnWriters; + common::offset_t fileOffset; + storage::MemoryManager* mm; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/writer/parquet/standard_column_writer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/writer/parquet/standard_column_writer.h new file mode 100644 index 0000000000..53372dcac8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/writer/parquet/standard_column_writer.h @@ -0,0 +1,105 @@ +#pragma once + +#include "basic_column_writer.h" +#include "common/serializer/serializer.h" +#include "function/cast/functions/numeric_limits.h" +#include "function/comparison/comparison_functions.h" + +namespace lbug { +namespace processor { + +template +class NumericStatisticsState : public ColumnWriterStatistics { +public: + NumericStatisticsState() + : min(function::NumericLimits::maximum()), max(function::NumericLimits::minimum()) {} + + T min; + T max; + +public: + bool hasStats() { return min <= max; } + + std::string getMin() override { + return function::NumericLimits::isSigned() ? getMinValue() : std::string(); + } + std::string getMax() override { + return function::NumericLimits::isSigned() ? getMaxValue() : std::string(); + } + std::string getMinValue() override { + return hasStats() ? std::string((char*)&min, sizeof(T)) : std::string(); + } + std::string getMaxValue() override { + return hasStats() ? 
std::string((char*)&max, sizeof(T)) : std::string(); + } +}; + +struct BaseParquetOperator { + template + inline static std::unique_ptr initializeStats() { + return std::make_unique>(); + } + + template + static void handleStats(ColumnWriterStatistics* stats, SRC /*sourceValue*/, TGT targetValue) { + auto& numericStats = (NumericStatisticsState&)*stats; + uint8_t result = 0; + function::LessThan::operation(targetValue, numericStats.min, result, + nullptr /* leftVector */, nullptr /* rightVector */); + if (result != 0) { + numericStats.min = targetValue; + } + function::GreaterThan::operation(targetValue, numericStats.max, result, + nullptr /* leftVector */, nullptr /* rightVector */); + if (result != 0) { + numericStats.max = targetValue; + } + } +}; + +struct ParquetCastOperator : public BaseParquetOperator { + template + static TGT Operation(SRC input) { + return TGT(input); + } +}; + +template +class StandardColumnWriter : public BasicColumnWriter { +public: + StandardColumnWriter(ParquetWriter& writer, uint64_t schemaIdx, + std::vector schemaPath, uint64_t maxRepeat, uint64_t maxDefine, + bool canHaveNulls) + : BasicColumnWriter(writer, schemaIdx, std::move(schemaPath), maxRepeat, maxDefine, + canHaveNulls) {} + + std::unique_ptr initializeStatsState() override { + return OP::template initializeStats(); + } + + void templatedWritePlain(common::ValueVector* vector, ColumnWriterStatistics* stats, + uint64_t chunkStart, uint64_t chunkEnd, common::Serializer& ser) { + for (auto r = chunkStart; r < chunkEnd; r++) { + auto pos = getVectorPos(vector, r); + if (!vector->isNull(pos)) { + TGT targetValue = OP::template Operation(vector->getValue(pos)); + OP::template handleStats(stats, vector->getValue(pos), targetValue); + ser.write(targetValue); + } + } + } + + void writeVector(common::Serializer& bufferedSerializer, ColumnWriterStatistics* stats, + ColumnWriterPageState* /*pageState*/, common::ValueVector* vector, uint64_t chunkStart, + uint64_t chunkEnd) override { + templatedWritePlain(vector, stats, chunkStart, chunkEnd, bufferedSerializer); + } + + uint64_t getRowSize(common::ValueVector* /*vector*/, uint64_t /*index*/, + BasicColumnWriterState& /*state*/) override { + return sizeof(TGT); + } +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/writer/parquet/string_column_writer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/writer/parquet/string_column_writer.h new file mode 100644 index 0000000000..be001181f3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/writer/parquet/string_column_writer.h @@ -0,0 +1,142 @@ +#pragma once + +#include "parquet_rle_bp_encoder.h" +#include "parquet_types.h" +#include "processor/operator/persistent/writer/parquet/basic_column_writer.h" + +namespace lbug { +namespace processor { + +struct StringHash { + std::size_t operator()(const common::ku_string_t& k) const; +}; + +struct StringEquality { + bool operator()(const common::ku_string_t& a, const common::ku_string_t& b) const; +}; + +template +using string_map_t = std::unordered_map; + +class StringStatisticsState : public ColumnWriterStatistics { +public: + bool hasStats = false; + bool valuesTooBig = false; + std::string min; + std::string max; + +public: + bool hasValidStats() const { return hasStats; } + + void update(const common::ku_string_t& val); + + std::string getMin() override { return getMinValue(); } + std::string getMax() override 
{ return getMaxValue(); } + std::string getMinValue() override { return hasValidStats() ? min : std::string(); } + std::string getMaxValue() override { return hasValidStats() ? max : std::string(); } +}; + +class StringColumnWriterState : public BasicColumnWriterState { +public: + StringColumnWriterState(lbug_parquet::format::RowGroup& rowGroup, uint64_t colIdx, + storage::MemoryManager* mm) + : BasicColumnWriterState{rowGroup, colIdx}, estimatedDictPageSize{0}, + estimatedRlePagesSize{0}, estimatedPlainSize{0}, + overflowBuffer{std::make_unique(mm)}, keyBitWidth{0} {} + + // Analysis state. + uint64_t estimatedDictPageSize; + uint64_t estimatedRlePagesSize; + uint64_t estimatedPlainSize; + + // Dictionary and accompanying string heap. + string_map_t dictionary; + std::unique_ptr overflowBuffer; + // key_bit_width== 0 signifies the chunk is written in plain encoding + uint32_t keyBitWidth; + + bool isDictionaryEncoded() const { return keyBitWidth != 0; } +}; + +class StringWriterPageState : public ColumnWriterPageState { +public: + explicit StringWriterPageState(uint32_t bitWidth, const string_map_t& values) + : bitWidth(bitWidth), dictionary(values), encoder(bitWidth), writtenValue(false) { + KU_ASSERT(isDictionaryEncoded() || (bitWidth == 0 && dictionary.empty())); + } + + inline bool isDictionaryEncoded() const { return bitWidth != 0; } + // If 0, we're writing a plain page. + uint32_t bitWidth; + const string_map_t& dictionary; + RleBpEncoder encoder; + bool writtenValue; +}; + +class StringColumnWriter : public BasicColumnWriter { +public: + StringColumnWriter(ParquetWriter& writer, uint64_t schemaIdx, + std::vector schemaPath, uint64_t maxRepeat, uint64_t maxDefine, + bool canHaveNulls, storage::MemoryManager* mm) + : BasicColumnWriter(writer, schemaIdx, std::move(schemaPath), maxRepeat, maxDefine, + canHaveNulls), + mm{mm} {} + +public: + inline std::unique_ptr initializeStatsState() override { + return std::make_unique(); + } + + std::unique_ptr initializeWriteState( + lbug_parquet::format::RowGroup& rowGroup) override; + + inline bool hasAnalyze() override { return true; } + + inline std::unique_ptr initializePageState( + BasicColumnWriterState& state_p) override { + auto& state = reinterpret_cast(state_p); + return std::make_unique(state.keyBitWidth, state.dictionary); + } + + inline lbug_parquet::format::Encoding::type getEncoding( + BasicColumnWriterState& writerState) override { + auto& state = reinterpret_cast(writerState); + return state.isDictionaryEncoded() ? 
lbug_parquet::format::Encoding::RLE_DICTIONARY : + lbug_parquet::format::Encoding::PLAIN; + } + + inline bool hasDictionary(BasicColumnWriterState& writerState) override { + auto& state = reinterpret_cast(writerState); + return state.isDictionaryEncoded(); + } + + inline uint64_t dictionarySize(BasicColumnWriterState& writerState) override { + auto& state = reinterpret_cast(writerState); + KU_ASSERT(state.isDictionaryEncoded()); + return state.dictionary.size(); + } + + void analyze(ColumnWriterState& writerState, ColumnWriterState* parent, + common::ValueVector* vector, uint64_t count) override; + + void finalizeAnalyze(ColumnWriterState& writerState) override; + + void writeVector(common::Serializer& bufferedSerializer, ColumnWriterStatistics* statsToWrite, + ColumnWriterPageState* writerPageState, common::ValueVector* vector, uint64_t chunkStart, + uint64_t chunkEnd) override; + + void flushPageState(common::Serializer& bufferedSerializer, + ColumnWriterPageState* writerPageState) override; + + void flushDictionary(BasicColumnWriterState& writerState, + ColumnWriterStatistics* writerStats) override; + + uint64_t getRowSize(common::ValueVector* vector, uint64_t index, + BasicColumnWriterState& writerState) override; + +private: + storage::MemoryManager* mm; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/writer/parquet/struct_column_writer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/writer/parquet/struct_column_writer.h new file mode 100644 index 0000000000..97bafd0737 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/writer/parquet/struct_column_writer.h @@ -0,0 +1,44 @@ +#pragma once + +#include "processor/operator/persistent/writer/parquet/column_writer.h" + +namespace lbug { +namespace processor { + +class StructColumnWriter : public ColumnWriter { +public: + StructColumnWriter(ParquetWriter& writer, uint64_t schemaIdx, std::vector schema, + uint64_t maxRepeat, uint64_t maxDefine, + std::vector> childWriter, bool canHaveNull) + : ColumnWriter{writer, schemaIdx, std::move(schema), maxRepeat, maxDefine, canHaveNull}, + childWriters{std::move(childWriter)} {} + + std::vector> childWriters; + +public: + std::unique_ptr initializeWriteState( + lbug_parquet::format::RowGroup& rowGroup) override; + bool hasAnalyze() override; + void analyze(ColumnWriterState& state, ColumnWriterState* parent, common::ValueVector* vector, + uint64_t count) override; + void finalizeAnalyze(ColumnWriterState& state) override; + void prepare(ColumnWriterState& state, ColumnWriterState* parent, common::ValueVector* vector, + uint64_t count) override; + + void beginWrite(ColumnWriterState& state) override; + void write(ColumnWriterState& state, common::ValueVector* vector, uint64_t count) override; + void finalizeWrite(ColumnWriterState& state) override; +}; + +class StructColumnWriterState : public ColumnWriterState { +public: + StructColumnWriterState(lbug_parquet::format::RowGroup& rowGroup, uint64_t colIdx) + : rowGroup(rowGroup), colIdx(colIdx) {} + + lbug_parquet::format::RowGroup& rowGroup; + uint64_t colIdx; + std::vector> childStates; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/writer/parquet/uuid_column_writer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/writer/parquet/uuid_column_writer.h new file mode 100644 
index 0000000000..87ea2406b6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/persistent/writer/parquet/uuid_column_writer.h @@ -0,0 +1,28 @@ +#pragma once + +#include "basic_column_writer.h" +#include "common/constants.h" + +namespace lbug { +namespace processor { + +class UUIDColumnWriter : public BasicColumnWriter { +public: + UUIDColumnWriter(ParquetWriter& writer, uint64_t schemaIdx, std::vector schemaPath, + uint64_t maxRepeat, uint64_t maxDefine, bool canHaveNulls) + : BasicColumnWriter(writer, schemaIdx, std::move(schemaPath), maxRepeat, maxDefine, + canHaveNulls) {} + +public: + void writeVector(common::Serializer& bufferedSerializer, ColumnWriterStatistics* state, + ColumnWriterPageState* pageState, common::ValueVector* vector, uint64_t chunkStart, + uint64_t chunkEnd) override; + + uint64_t getRowSize(common::ValueVector* /*vector*/, uint64_t /*index*/, + BasicColumnWriterState& /*state*/) override { + return common::ParquetConstants::PARQUET_UUID_SIZE; + } +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/physical_operator.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/physical_operator.h new file mode 100644 index 0000000000..795693987b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/physical_operator.h @@ -0,0 +1,180 @@ +#pragma once + +#include "planner/operator/operator_print_info.h" +#include "processor/result/result_set.h" + +namespace lbug::common { +class Profiler; +class NumericMetric; +class TimeMetric; +} // namespace lbug::common +namespace lbug { +namespace processor { +struct ExecutionContext; + +using physical_op_id = uint32_t; + +enum class PhysicalOperatorType : uint8_t { + ALTER, + AGGREGATE, + AGGREGATE_FINALIZE, + AGGREGATE_SCAN, + ATTACH_DATABASE, + BATCH_INSERT, + COPY_TO, + CREATE_MACRO, + CREATE_SEQUENCE, + CREATE_TABLE, + CREATE_TYPE, + CROSS_PRODUCT, + DETACH_DATABASE, + DELETE_, + DROP, + DUMMY_SINK, + DUMMY_SIMPLE_SINK, + EMPTY_RESULT, + EXPORT_DATABASE, + EXTENSION_CLAUSE, + FILTER, + FLATTEN, + HASH_JOIN_BUILD, + HASH_JOIN_PROBE, + IMPORT_DATABASE, + INDEX_LOOKUP, + INSERT, + INTERSECT_BUILD, + INTERSECT, + INSTALL_EXTENSION, + LIMIT, + LOAD_EXTENSION, + MERGE, + MULTIPLICITY_REDUCER, + PARTITIONER, + PATH_PROPERTY_PROBE, + PRIMARY_KEY_SCAN_NODE_TABLE, + PROJECTION, + PROFILE, + RECURSIVE_EXTEND, + RESULT_COLLECTOR, + SCAN_NODE_TABLE, + SCAN_REL_TABLE, + SEMI_MASKER, + SET_PROPERTY, + SKIP, + STANDALONE_CALL, + TABLE_FUNCTION_CALL, + TOP_K, + TOP_K_SCAN, + TRANSACTION, + ORDER_BY, + ORDER_BY_MERGE, + ORDER_BY_SCAN, + UNION_ALL_SCAN, + UNWIND, + USE_DATABASE, + UNINSTALL_EXTENSION, +}; + +class PhysicalOperator; +struct PhysicalOperatorUtils { + static std::string operatorToString(const PhysicalOperator* physicalOp); + LBUG_API static std::string operatorTypeToString(PhysicalOperatorType operatorType); +}; + +struct OperatorMetrics { + common::TimeMetric& executionTime; + common::NumericMetric& numOutputTuple; + + OperatorMetrics(common::TimeMetric& executionTime, common::NumericMetric& numOutputTuple) + : executionTime{executionTime}, numOutputTuple{numOutputTuple} {} +}; + +using physical_op_vector_t = std::vector>; + +class LBUG_API PhysicalOperator { +public: + // Leaf operator + PhysicalOperator(PhysicalOperatorType operatorType, physical_op_id id, + std::unique_ptr printInfo) + : id{id}, operatorType{operatorType}, resultSet(nullptr), printInfo{std::move(printInfo)} {} + // 
Unary operator + PhysicalOperator(PhysicalOperatorType operatorType, std::unique_ptr child, + physical_op_id id, std::unique_ptr printInfo); + // Binary operator + PhysicalOperator(PhysicalOperatorType operatorType, std::unique_ptr left, + std::unique_ptr right, physical_op_id id, + std::unique_ptr printInfo); + PhysicalOperator(PhysicalOperatorType operatorType, physical_op_vector_t children, + physical_op_id id, std::unique_ptr printInfo); + + virtual ~PhysicalOperator() = default; + + physical_op_id getOperatorID() const { return id; } + + PhysicalOperatorType getOperatorType() const { return operatorType; } + + virtual bool isSource() const { return false; } + virtual bool isSink() const { return false; } + virtual bool isParallel() const { return true; } + + void addChild(std::unique_ptr op) { children.push_back(std::move(op)); } + PhysicalOperator* getChild(common::idx_t idx) const { return children[idx].get(); } + common::idx_t getNumChildren() const { return children.size(); } + std::unique_ptr moveUnaryChild(); + + // Global state is initialized once. + void initGlobalState(ExecutionContext* context); + // Local state is initialized for each thread. + void initLocalState(ResultSet* resultSet, ExecutionContext* context); + + bool getNextTuple(ExecutionContext* context); + + virtual void finalize(ExecutionContext* context); + + std::unordered_map getProfilerKeyValAttributes( + common::Profiler& profiler) const; + std::vector getProfilerAttributes(common::Profiler& profiler) const; + + const OPPrintInfo* getPrintInfo() const { return printInfo.get(); } + + virtual std::unique_ptr copy() = 0; + + virtual double getProgress(ExecutionContext* context) const; + + template + TARGET* ptrCast() { + return common::ku_dynamic_cast(this); + } + template + const TARGET& constCast() { + return common::ku_dynamic_cast(*this); + } + +protected: + virtual void initGlobalStateInternal(ExecutionContext* /*context*/) {} + virtual void initLocalStateInternal(ResultSet* /*resultSet_*/, ExecutionContext* /*context*/) {} + // Return false if no more tuples to pull, otherwise return true + virtual bool getNextTuplesInternal(ExecutionContext* context) = 0; + + std::string getTimeMetricKey() const { return "time-" + std::to_string(id); } + std::string getNumTupleMetricKey() const { return "numTuple-" + std::to_string(id); } + + void registerProfilingMetrics(common::Profiler* profiler); + + double getExecutionTime(common::Profiler& profiler) const; + uint64_t getNumOutputTuples(common::Profiler& profiler) const; + + virtual void finalizeInternal(ExecutionContext* /*context*/) {} + +protected: + physical_op_id id; + std::unique_ptr metrics; + PhysicalOperatorType operatorType; + + physical_op_vector_t children; + ResultSet* resultSet; + std::unique_ptr printInfo; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/profile.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/profile.h new file mode 100644 index 0000000000..0151387d84 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/profile.h @@ -0,0 +1,34 @@ +#pragma once + +#include "processor/operator/sink.h" + +namespace lbug { +namespace processor { +class PhysicalPlan; + +struct ProfileInfo { + PhysicalPlan* physicalPlan = nullptr; +}; + +class Profile final : public SimpleSink { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::PROFILE; + +public: + Profile(ProfileInfo info, std::shared_ptr messageTable, 
physical_op_id id, + std::unique_ptr printInfo) + : SimpleSink{type_, std::move(messageTable), id, std::move(printInfo)}, info{info} {} + + void setPhysicalPlan(PhysicalPlan* physicalPlan) { info.physicalPlan = physicalPlan; } + + void executeInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return std::make_unique(info, messageTable, id, printInfo->copy()); + } + +private: + ProfileInfo info; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/projection.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/projection.h new file mode 100644 index 0000000000..7f23a5219f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/projection.h @@ -0,0 +1,79 @@ +#pragma once + +#include "binder/expression/expression.h" +#include "expression_evaluator/expression_evaluator.h" +#include "processor/operator/physical_operator.h" + +namespace lbug { +namespace processor { + +struct ProjectionPrintInfo final : OPPrintInfo { + binder::expression_vector expressions; + + explicit ProjectionPrintInfo(binder::expression_vector expressions) + : expressions{std::move(expressions)} {} + + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::unique_ptr(new ProjectionPrintInfo(*this)); + } + +private: + ProjectionPrintInfo(const ProjectionPrintInfo& other) + : OPPrintInfo{other}, expressions{other.expressions} {} +}; + +struct ProjectionInfo { + std::vector> evaluators; + std::vector exprsOutputPos; + std::unordered_set activeChunkIndices; + std::unordered_set discardedChunkIndices; + + ProjectionInfo() = default; + EXPLICIT_COPY_DEFAULT_MOVE(ProjectionInfo); + + void addEvaluator(std::unique_ptr evaluator, + const DataPos& outputPos) { + evaluators.push_back(std::move(evaluator)); + exprsOutputPos.push_back(outputPos); + activeChunkIndices.insert(outputPos.dataChunkPos); + } + +private: + ProjectionInfo(const ProjectionInfo& other) + : evaluators{copyVector(other.evaluators)}, exprsOutputPos{other.exprsOutputPos}, + activeChunkIndices{other.activeChunkIndices}, + discardedChunkIndices{other.discardedChunkIndices} {} +}; + +class Projection final : public PhysicalOperator { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::PROJECTION; + +public: + Projection(ProjectionInfo info, std::unique_ptr child, physical_op_id id, + std::unique_ptr printInfo) + : PhysicalOperator(type_, std::move(child), id, std::move(printInfo)), + info{std::move(info)}, prevMultiplicity{1} {} + + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; + + bool getNextTuplesInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return std::make_unique(info.copy(), children[0]->copy(), id, + printInfo->copy()); + } + +private: + void saveMultiplicity() { prevMultiplicity = resultSet->multiplicity; } + + void restoreMultiplicity() { resultSet->multiplicity = prevMultiplicity; } + +private: + ProjectionInfo info; + uint64_t prevMultiplicity; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/recursive_extend.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/recursive_extend.h new file mode 100644 index 0000000000..b877c0df7d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/recursive_extend.h @@ -0,0 +1,55 @@ +#pragma once + +#include 
"function/gds/rec_joins.h" +#include "processor/operator/sink.h" + +namespace lbug { +namespace processor { + +struct RecursiveExtendPrintInfo final : OPPrintInfo { + std::string funcName; + + explicit RecursiveExtendPrintInfo(std::string funcName) : funcName{std::move(funcName)} {} + + std::string toString() const override { return funcName; } + + std::unique_ptr copy() const override { + return std::unique_ptr(new RecursiveExtendPrintInfo(*this)); + } + +private: + RecursiveExtendPrintInfo(const RecursiveExtendPrintInfo& other) + : OPPrintInfo{other}, funcName{other.funcName} {} +}; + +class RecursiveExtend : public Sink { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::RECURSIVE_EXTEND; + +public: + RecursiveExtend(std::unique_ptr function, function::RJBindData bindData, + std::shared_ptr sharedState, uint32_t id, + std::unique_ptr printInfo) + : Sink{type_, id, std::move(printInfo)}, function{std::move(function)}, bindData{bindData}, + sharedState{std::move(sharedState)} {} + + std::shared_ptr getSharedState() const { return sharedState; } + + bool isSource() const override { return true; } + + bool isParallel() const override { return false; } + + void executeInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return std::make_unique(function->copy(), bindData, sharedState, id, + printInfo->copy()); + } + +private: + std::unique_ptr function; + function::RJBindData bindData; + std::shared_ptr sharedState; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/recursive_extend_shared_state.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/recursive_extend_shared_state.h new file mode 100644 index 0000000000..bd2a200ac2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/recursive_extend_shared_state.h @@ -0,0 +1,50 @@ +#pragma once + +#include "common/counter.h" +#include "common/mask.h" +#include "graph/graph.h" +#include "processor/result/factorized_table_pool.h" + +namespace lbug { +namespace processor { + +struct RecursiveExtendSharedState { + std::unique_ptr graph; + std::unique_ptr counter = nullptr; + + RecursiveExtendSharedState(std::shared_ptr fTable, + std::unique_ptr graph, common::offset_t limitNumber) + : graph{std::move(graph)}, factorizedTablePool{std::move(fTable)} { + if (limitNumber != common::INVALID_LIMIT) { + counter = std::make_unique(limitNumber); + } + } + + void setInputNodeMask(std::unique_ptr maskMap) { + inputNodeMask = std::move(maskMap); + } + common::NodeOffsetMaskMap* getInputNodeMaskMap() const { return inputNodeMask.get(); } + + void setOutputNodeMask(std::unique_ptr maskMap) { + outputNodeMask = std::move(maskMap); + } + common::NodeOffsetMaskMap* getOutputNodeMaskMap() const { return outputNodeMask.get(); } + + void setPathNodeMask(std::unique_ptr maskMap) { + pathNodeMask = std::move(maskMap); + } + common::NodeOffsetMaskMap* getPathNodeMaskMap() const { return pathNodeMask.get(); } + + bool exceedLimit() const { return !(counter == nullptr) && counter->exceedLimit(); } + +public: + FactorizedTablePool factorizedTablePool; + +private: + std::unique_ptr inputNodeMask = nullptr; + std::unique_ptr outputNodeMask = nullptr; + std::unique_ptr pathNodeMask = nullptr; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/result_collector.h 
b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/result_collector.h new file mode 100644 index 0000000000..4c91b83655 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/result_collector.h @@ -0,0 +1,107 @@ +#pragma once + +#include + +#include "binder/expression/expression.h" +#include "common/enums/accumulate_type.h" +#include "processor/operator/sink.h" +#include "processor/result/factorized_table.h" + +namespace lbug { +namespace processor { + +class ResultCollectorSharedState { +public: + explicit ResultCollectorSharedState(std::shared_ptr table) + : table{std::move(table)} {} + + void mergeLocalTable(FactorizedTable& localTable) { + std::unique_lock lck{mtx}; + table->merge(localTable); + } + + std::shared_ptr getTable() { return table; } + +private: + std::mutex mtx; + std::shared_ptr table; +}; + +struct ResultCollectorInfo { + common::AccumulateType accumulateType; + FactorizedTableSchema tableSchema; + std::vector payloadPositions; + + ResultCollectorInfo(common::AccumulateType accumulateType, FactorizedTableSchema tableSchema, + std::vector payloadPositions) + : accumulateType{accumulateType}, tableSchema{std::move(tableSchema)}, + payloadPositions{std::move(payloadPositions)} {} + EXPLICIT_COPY_DEFAULT_MOVE(ResultCollectorInfo); + +private: + ResultCollectorInfo(const ResultCollectorInfo& other) + : accumulateType{other.accumulateType}, tableSchema{other.tableSchema.copy()}, + payloadPositions{other.payloadPositions} {} +}; + +struct ResultCollectorPrintInfo final : OPPrintInfo { + binder::expression_vector expressions; + common::AccumulateType accumulateType; + + ResultCollectorPrintInfo(binder::expression_vector expressions, + common::AccumulateType accumulateType) + : expressions{std::move(expressions)}, accumulateType{accumulateType} {} + ResultCollectorPrintInfo(const ResultCollectorPrintInfo& other) + : OPPrintInfo{other}, expressions{other.expressions}, accumulateType{other.accumulateType} { + } + + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } +}; + +class ResultCollector final : public Sink { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::RESULT_COLLECTOR; + +public: + ResultCollector(ResultCollectorInfo info, + std::shared_ptr sharedState, + std::unique_ptr child, uint32_t id, + std::unique_ptr printInfo) + : Sink{type_, std::move(child), id, std::move(printInfo)}, info{std::move(info)}, + sharedState{std::move(sharedState)} {} + + void executeInternal(ExecutionContext* context) override; + + void finalizeInternal(ExecutionContext* context) override; + + std::shared_ptr getResultFTable() const override { + return sharedState->getTable(); + } + + std::unique_ptr getQueryResult() const override; + + std::unique_ptr copy() override { + return std::make_unique(info.copy(), sharedState, children[0]->copy(), id, + printInfo->copy()); + } + +private: + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; + + void initNecessaryLocalState(ResultSet* resultSet, ExecutionContext* context); + +private: + ResultCollectorInfo info; + std::shared_ptr sharedState; + std::vector payloadVectors; + std::vector payloadAndMarkVectors; + + std::unique_ptr markVector; + std::unique_ptr localTable; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/scan/primary_key_scan_node_table.h 
b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/scan/primary_key_scan_node_table.h new file mode 100644 index 0000000000..a27d1c48f6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/scan/primary_key_scan_node_table.h @@ -0,0 +1,73 @@ +#pragma once + +#include "expression_evaluator/expression_evaluator.h" +#include "processor/operator/scan/scan_node_table.h" + +namespace lbug { +namespace processor { + +struct PrimaryKeyScanPrintInfo final : OPPrintInfo { + binder::expression_vector expressions; + std::string key; + std::string alias; + + PrimaryKeyScanPrintInfo(binder::expression_vector expressions, std::string key, + std::string alias) + : expressions(std::move(expressions)), key(std::move(key)), alias{std::move(alias)} {} + + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::unique_ptr(new PrimaryKeyScanPrintInfo(*this)); + } + +private: + PrimaryKeyScanPrintInfo(const PrimaryKeyScanPrintInfo& other) + : OPPrintInfo(other), expressions(other.expressions), alias(other.alias) {} +}; + +struct PrimaryKeyScanSharedState { + std::mutex mtx; + + common::idx_t numTables; + common::idx_t cursor; + + explicit PrimaryKeyScanSharedState(common::idx_t numTables) : numTables{numTables}, cursor{0} {} + + common::idx_t getTableIdx(); +}; + +class PrimaryKeyScanNodeTable : public ScanTable { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::PRIMARY_KEY_SCAN_NODE_TABLE; + +public: + PrimaryKeyScanNodeTable(ScanOpInfo opInfo, std::vector tableInfos, + std::unique_ptr indexEvaluator, + std::shared_ptr sharedState, physical_op_id id, + std::unique_ptr printInfo) + : ScanTable{type_, std::move(opInfo), id, std::move(printInfo)}, scanState{nullptr}, + tableInfos{std::move(tableInfos)}, indexEvaluator{std::move(indexEvaluator)}, + sharedState{std::move(sharedState)} {} + + bool isSource() const override { return true; } + + void initLocalStateInternal(ResultSet*, ExecutionContext*) override; + + bool getNextTuplesInternal(ExecutionContext* context) override; + + bool isParallel() const override { return false; } + + std::unique_ptr copy() override { + return std::make_unique(opInfo.copy(), copyVector(tableInfos), + indexEvaluator->copy(), sharedState, id, printInfo->copy()); + } + +private: + std::unique_ptr scanState; + std::vector tableInfos; + std::unique_ptr indexEvaluator; + std::shared_ptr sharedState; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/scan/scan_multi_rel_tables.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/scan/scan_multi_rel_tables.h new file mode 100644 index 0000000000..4cf70685e7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/scan/scan_multi_rel_tables.h @@ -0,0 +1,93 @@ +#pragma once + +#include "processor/operator/scan/scan_rel_table.h" + +namespace lbug { +namespace processor { + +struct DirectionInfo { + bool extendFromSource; + DataPos directionPos; + + DirectionInfo() : extendFromSource{false}, directionPos{DataPos::getInvalidPos()} {} + EXPLICIT_COPY_DEFAULT_MOVE(DirectionInfo); + + bool needFlip(common::RelDataDirection relDataDirection) const; + +private: + DirectionInfo(const DirectionInfo& other) + : extendFromSource{other.extendFromSource}, directionPos{other.directionPos} {} +}; + +class RelTableCollectionScanner { + friend class ScanMultiRelTable; + +public: + explicit RelTableCollectionScanner(std::vector relInfos) + : 
relInfos{std::move(relInfos)} {} + EXPLICIT_COPY_DEFAULT_MOVE(RelTableCollectionScanner); + + bool empty() const { return relInfos.empty(); } + + void resetState() { + currentTableIdx = 0; + nextTableIdx = 0; + } + + void addRelInfos(std::vector relInfos_) { + for (auto& relInfo : relInfos_) { + relInfos.push_back(std::move(relInfo)); + } + } + + bool scan(main::ClientContext* context, storage::RelTableScanState& scanState, + const std::vector& outVectors); + +private: + RelTableCollectionScanner(const RelTableCollectionScanner& other) + : relInfos{copyVector(other.relInfos)} {} + +private: + std::vector relInfos; + std::vector directionValues; + common::ValueVector* directionVector = nullptr; + common::idx_t currentTableIdx = common::INVALID_IDX; + uint32_t nextTableIdx = 0; +}; + +class ScanMultiRelTable final : public ScanTable { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::SCAN_REL_TABLE; + +public: + ScanMultiRelTable(ScanOpInfo info, DirectionInfo directionInfo, + common::table_id_map_t scanners, + std::unique_ptr child, physical_op_id id, + std::unique_ptr printInfo) + : ScanTable{type_, std::move(info), std::move(child), id, std::move(printInfo)}, + directionInfo{std::move(directionInfo)}, scanState{nullptr}, boundNodeIDVector{nullptr}, + scanners{std::move(scanners)}, currentScanner{nullptr} {} + + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; + + bool getNextTuplesInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return make_unique(opInfo.copy(), directionInfo.copy(), + copyUnorderedMap(scanners), children[0]->copy(), id, printInfo->copy()); + } + +private: + void resetState(); + void initCurrentScanner(const common::nodeID_t& nodeID); + +private: + DirectionInfo directionInfo; + std::unique_ptr scanState; + + common::ValueVector* boundNodeIDVector; + common::table_id_map_t scanners; + RelTableCollectionScanner* currentScanner; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/scan/scan_node_table.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/scan/scan_node_table.h new file mode 100644 index 0000000000..0ca38bd6c8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/scan/scan_node_table.h @@ -0,0 +1,126 @@ +#pragma once + +#include "processor/operator/scan/scan_table.h" +#include "storage/predicate/column_predicate.h" +#include "storage/table/node_table.h" + +namespace lbug { +namespace processor { + +struct ScanNodeTableProgressSharedState { + std::atomic numGroupsScanned; + common::node_group_idx_t numGroups; + + ScanNodeTableProgressSharedState() : numGroupsScanned{0}, numGroups{0} {}; +}; + +class ScanNodeTableSharedState { +public: + explicit ScanNodeTableSharedState(std::unique_ptr semiMask) + : table{nullptr}, currentCommittedGroupIdx{common::INVALID_NODE_GROUP_IDX}, + currentUnCommittedGroupIdx{common::INVALID_NODE_GROUP_IDX}, numCommittedNodeGroups{0}, + numUnCommittedNodeGroups{0}, semiMask{std::move(semiMask)} {}; + + void initialize(const transaction::Transaction* transaction, storage::NodeTable* table, + ScanNodeTableProgressSharedState& progressSharedState); + + void nextMorsel(storage::NodeTableScanState& scanState, + ScanNodeTableProgressSharedState& progressSharedState); + + common::SemiMask* getSemiMask() const { return semiMask.get(); } + +private: + std::mutex mtx; + storage::NodeTable* table; + common::node_group_idx_t 
currentCommittedGroupIdx; + common::node_group_idx_t currentUnCommittedGroupIdx; + common::node_group_idx_t numCommittedNodeGroups; + common::node_group_idx_t numUnCommittedNodeGroups; + std::unique_ptr semiMask; +}; + +struct ScanNodeTablePrintInfo final : OPPrintInfo { + std::vector tableNames; + std::string alias; + binder::expression_vector properties; + + ScanNodeTablePrintInfo(std::vector tableNames, std::string alias, + binder::expression_vector properties) + : tableNames{std::move(tableNames)}, alias{std::move(alias)}, + properties{std::move(properties)} {} + + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::unique_ptr(new ScanNodeTablePrintInfo(*this)); + } + +private: + ScanNodeTablePrintInfo(const ScanNodeTablePrintInfo& other) + : OPPrintInfo{other}, tableNames{other.tableNames}, alias{other.alias}, + properties{other.properties} {} +}; + +struct ScanNodeTableInfo : ScanTableInfo { + ScanNodeTableInfo(storage::Table* table, + std::vector columnPredicates) + : ScanTableInfo{table, std::move(columnPredicates)} {} + EXPLICIT_COPY_DEFAULT_MOVE(ScanNodeTableInfo); + + void initScanState(storage::TableScanState& scanState, + const std::vector& outVectors, main::ClientContext* context) override; + +private: + ScanNodeTableInfo(const ScanNodeTableInfo& other) : ScanTableInfo{other} {} +}; + +class ScanNodeTable final : public ScanTable { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::SCAN_NODE_TABLE; + +public: + ScanNodeTable(ScanOpInfo opInfo, std::vector tableInfos, + std::vector> sharedStates, uint32_t id, + std::unique_ptr printInfo, + std::shared_ptr progressSharedState) + : ScanTable{type_, std::move(opInfo), id, std::move(printInfo)}, currentTableIdx{0}, + scanState{nullptr}, tableInfos{std::move(tableInfos)}, + sharedStates{std::move(sharedStates)}, + progressSharedState{std::move(progressSharedState)} { + KU_ASSERT(this->tableInfos.size() == this->sharedStates.size()); + } + + common::table_id_map_t getSemiMasks() const; + + bool isSource() const override { return true; } + + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; + + bool getNextTuplesInternal(ExecutionContext* context) override; + + const ScanNodeTableSharedState& getSharedState(common::idx_t idx) const { + KU_ASSERT(idx < sharedStates.size()); + return *sharedStates[idx]; + } + + std::unique_ptr copy() override { + return std::make_unique(opInfo.copy(), copyVector(tableInfos), sharedStates, + id, printInfo->copy(), progressSharedState); + } + + double getProgress(ExecutionContext* context) const override; + +private: + void initGlobalStateInternal(ExecutionContext* context) override; + + void initCurrentTable(ExecutionContext* context); + +private: + common::idx_t currentTableIdx; + std::unique_ptr scanState; + std::vector tableInfos; + std::vector> sharedStates; + std::shared_ptr progressSharedState; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/scan/scan_rel_table.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/scan/scan_rel_table.h new file mode 100644 index 0000000000..004d6db60e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/scan/scan_rel_table.h @@ -0,0 +1,87 @@ +#pragma once + +#include "binder/expression/rel_expression.h" +#include "common/enums/extend_direction.h" +#include "processor/operator/scan/scan_table.h" +#include "storage/predicate/column_predicate.h" 
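// Editor's note (illustrative, not part of the original patch): the node- and rel-table scan
// operators in this directory share one pattern: a shared state hands out work under a mutex
// (for node tables, one node group per call to nextMorsel), while every thread owns its own
// TableScanState and output ValueVectors. A rough per-thread loop, with hypothetical names:
//
//   for (;;) {
//       sharedState->nextMorsel(*localScanState, progress); // pick the next node group
//       if (noGroupAssigned(*localScanState)) break;        // shared state is exhausted
//       while (table->scan(transaction, *localScanState)) { // fill the output vectors
//           /* downstream operators consume the vectors referenced by outVectorsPos */
//       }
//   }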
+#include "storage/table/rel_table.h" + +namespace lbug { +namespace storage { +class MemoryManager; +} +namespace processor { + +struct ScanRelTableInfo : ScanTableInfo { + common::RelDataDirection direction; + + ScanRelTableInfo(storage::Table* table, + std::vector columnPredicates, + common::RelDataDirection direction) + : ScanTableInfo{table, std::move(columnPredicates)}, direction{direction} {} + EXPLICIT_COPY_DEFAULT_MOVE(ScanRelTableInfo); + + void initScanState(storage::TableScanState& scanState, + const std::vector& outVectors, main::ClientContext* context) override; + +private: + ScanRelTableInfo(const ScanRelTableInfo& other) + : ScanTableInfo{other}, direction{other.direction} {} +}; + +struct ScanRelTablePrintInfo final : OPPrintInfo { + std::vector tableNames; + binder::expression_vector properties; + std::shared_ptr boundNode; + std::shared_ptr rel; + std::shared_ptr nbrNode; + common::ExtendDirection direction; + std::string alias; + + ScanRelTablePrintInfo(std::vector tableNames, binder::expression_vector properties, + std::shared_ptr boundNode, + std::shared_ptr rel, std::shared_ptr nbrNode, + common::ExtendDirection direction, std::string alias) + : tableNames{std::move(tableNames)}, properties{std::move(properties)}, + boundNode{std::move(boundNode)}, rel{std::move(rel)}, nbrNode{std::move(nbrNode)}, + direction{direction}, alias{std::move(alias)} {} + + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::unique_ptr(new ScanRelTablePrintInfo(*this)); + } + +private: + ScanRelTablePrintInfo(const ScanRelTablePrintInfo& other) + : OPPrintInfo{other}, tableNames{other.tableNames}, properties{other.properties}, + boundNode{other.boundNode}, rel{other.rel}, nbrNode{other.nbrNode}, + direction{other.direction}, alias{other.alias} {} +}; + +class ScanRelTable final : public ScanTable { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::SCAN_REL_TABLE; + +public: + ScanRelTable(ScanOpInfo info, ScanRelTableInfo tableInfo, + std::unique_ptr child, physical_op_id id, + std::unique_ptr printInfo) + : ScanTable{type_, std::move(info), std::move(child), id, std::move(printInfo)}, + tableInfo{std::move(tableInfo)} {} + + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; + + bool getNextTuplesInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return std::make_unique(opInfo.copy(), tableInfo.copy(), children[0]->copy(), + id, printInfo->copy()); + } + +protected: + ScanRelTableInfo tableInfo; + std::unique_ptr scanState; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/scan/scan_table.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/scan/scan_table.h new file mode 100644 index 0000000000..28e7c696da --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/scan/scan_table.h @@ -0,0 +1,112 @@ +#pragma once + +#include "binder/expression/expression.h" +#include "processor/operator/physical_operator.h" +#include "storage/table/table.h" + +namespace lbug { +namespace processor { + +struct ScanOpInfo { + // Node ID vector position. 
+ DataPos nodeIDPos; + // Output vector (properties or CSRs) positions + std::vector outVectorsPos; + + ScanOpInfo(DataPos nodeIDPos, std::vector outVectorsPos) + : nodeIDPos{nodeIDPos}, outVectorsPos{std::move(outVectorsPos)} {} + EXPLICIT_COPY_DEFAULT_MOVE(ScanOpInfo); + +private: + ScanOpInfo(const ScanOpInfo& other) + : nodeIDPos{other.nodeIDPos}, outVectorsPos{other.outVectorsPos} {} +}; + +// For multi-table scan, a column with the same name could be of different types. In such case, +// we scan the original type from storage and then cast at operator level +class ColumnCaster { +public: + explicit ColumnCaster(common::LogicalType columnType) : columnType{std::move(columnType)} {} + EXPLICIT_COPY_DEFAULT_MOVE(ColumnCaster); + + void setCastExpr(std::shared_ptr expr) { castExpr = std::move(expr); } + bool hasCast() const { return castExpr != nullptr; } + + // Generate temporary vectors for scanning + void init(common::ValueVector* vectorAfterCasting, storage::MemoryManager* memoryManager); + // Get temporary vector for scanning. This vector has the same data type as column. + common::ValueVector* getVectorBeforeCasting() const { return vectorBeforeCasting.get(); } + + void cast(); + +private: + ColumnCaster(const ColumnCaster& other) + : columnType{other.columnType.copy()}, castExpr{other.castExpr} {} + + common::LogicalType columnType; + std::shared_ptr castExpr; + + // vector for scanning; same data type as column + std::shared_ptr vectorBeforeCasting = nullptr; + // vector after casting. This should be the vector in result set so we don't manage its life + // cycle + common::ValueVector* vectorAfterCasting = nullptr; + + std::vector> funcInputVectors; + std::vector funcInputSelVectors; +}; + +struct ScanTableInfo { + storage::Table* table; + + ScanTableInfo(storage::Table* table, std::vector columnPredicates) + : table{table}, columnPredicates{std::move(columnPredicates)} {} + virtual ~ScanTableInfo() = default; + + void addColumnInfo(common::column_id_t columnID, ColumnCaster caster); + + virtual void initScanState(storage::TableScanState& scanState, + const std::vector& outVectors, main::ClientContext* context) = 0; + + void castColumns(); + +protected: + ScanTableInfo(const ScanTableInfo& other) + : table{other.table}, columnIDs{other.columnIDs}, + columnPredicates{copyVector(other.columnPredicates)}, + columnCasters{copyVector(other.columnCasters)}, hasColumnCaster{other.hasColumnCaster} {} + + void initScanStateVectors(storage::TableScanState& scanState, + const std::vector& outVectors, storage::MemoryManager* memoryManager); + + // Column ids to scan + std::vector columnIDs; + // Column predicates for zone map + std::vector columnPredicates; + // Column cast handler for multi table scan of the same column name but different type + std::vector columnCasters; + bool hasColumnCaster = false; +}; + +class ScanTable : public PhysicalOperator { +public: + ScanTable(PhysicalOperatorType operatorType, ScanOpInfo info, + std::unique_ptr child, physical_op_id id, + std::unique_ptr printInfo) + : PhysicalOperator{operatorType, std::move(child), id, std::move(printInfo)}, + opInfo{std::move(info)} {} + + ScanTable(PhysicalOperatorType operatorType, ScanOpInfo info, uint32_t id, + std::unique_ptr printInfo) + : PhysicalOperator{operatorType, id, std::move(printInfo)}, opInfo{std::move(info)} {} + +protected: + void initLocalStateInternal(ResultSet*, ExecutionContext*) override; + +protected: + ScanOpInfo opInfo; + std::vector outVectors; +}; + +} // namespace processor +} // namespace 
lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/semi_masker.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/semi_masker.h new file mode 100644 index 0000000000..f008ae5556 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/semi_masker.h @@ -0,0 +1,213 @@ +#pragma once + +#include + +#include "common/enums/extend_direction.h" +#include "common/mask.h" +#include "processor/operator/physical_operator.h" + +namespace lbug { +namespace processor { + +class BaseSemiMasker; + +struct SemiMaskerLocalState { + common::table_id_map_t> localMasksPerTable; + common::SemiMask* singleTableRef = nullptr; + + void maskSingleTable(common::offset_t offset) const { singleTableRef->mask(offset); } + void maskMultiTable(common::nodeID_t nodeID) const { + KU_ASSERT(localMasksPerTable.contains(nodeID.tableID)); + localMasksPerTable.at(nodeID.tableID)->mask(nodeID.offset); + } +}; + +class SemiMaskerSharedState { +public: + explicit SemiMaskerSharedState( + common::table_id_map_t> masksPerTable) + : masksPerTable{std::move(masksPerTable)} {} + + SemiMaskerLocalState* appendLocalState(); + + void mergeToGlobal(); + +private: + common::table_id_map_t> masksPerTable; + std::vector> localInfos; + std::mutex mtx; +}; + +struct SemiMaskerPrintInfo final : OPPrintInfo { + std::vector operatorNames; + + explicit SemiMaskerPrintInfo(std::vector operatorNames) + : operatorNames{std::move(operatorNames)} {} + + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::unique_ptr(new SemiMaskerPrintInfo(*this)); + } + +private: + SemiMaskerPrintInfo(const SemiMaskerPrintInfo& other) + : OPPrintInfo{other}, operatorNames{other.operatorNames} {} +}; + +class BaseSemiMasker : public PhysicalOperator { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::SEMI_MASKER; + +protected: + BaseSemiMasker(DataPos keyPos, std::shared_ptr sharedState, + std::unique_ptr child, uint32_t id, + std::unique_ptr printInfo) + : PhysicalOperator{type_, std::move(child), id, std::move(printInfo)}, keyPos{keyPos}, + keyVector{nullptr}, sharedState{std::move(sharedState)}, localState{nullptr} {} + + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; + + void finalizeInternal(ExecutionContext* context) final; + +protected: + DataPos keyPos; + common::ValueVector* keyVector; + std::shared_ptr sharedState; + SemiMaskerLocalState* localState; +}; + +class SingleTableSemiMasker final : public BaseSemiMasker { +public: + SingleTableSemiMasker(DataPos keyPos, std::shared_ptr sharedState, + std::unique_ptr child, uint32_t id, + std::unique_ptr printInfo) + : BaseSemiMasker{keyPos, std::move(sharedState), std::move(child), id, + std::move(printInfo)} {} + + bool getNextTuplesInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return std::make_unique(keyPos, sharedState, children[0]->copy(), id, + printInfo->copy()); + } +}; + +class MultiTableSemiMasker final : public BaseSemiMasker { +public: + MultiTableSemiMasker(DataPos keyPos, std::shared_ptr sharedState, + std::unique_ptr child, uint32_t id, + std::unique_ptr printInfo) + : BaseSemiMasker{keyPos, std::move(sharedState), std::move(child), id, + std::move(printInfo)} {} + + bool getNextTuplesInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return std::make_unique(keyPos, sharedState, children[0]->copy(), id, + printInfo->copy()); + } +}; + +class NodeIDsSemiMask 
: public BaseSemiMasker { +protected: + NodeIDsSemiMask(DataPos keyPos, DataPos srcNodeIDPos, DataPos dstNodeIDPos, + std::shared_ptr sharedState, std::unique_ptr child, + uint32_t id, std::unique_ptr printInfo) + : BaseSemiMasker{keyPos, std::move(sharedState), std::move(child), id, + std::move(printInfo)}, + srcNodeIDPos{srcNodeIDPos}, dstNodeIDPos{dstNodeIDPos} {} + + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) final; + +protected: + DataPos srcNodeIDPos; + DataPos dstNodeIDPos; + + common::ValueVector* srcNodeIDVector = nullptr; + common::ValueVector* dstNodeIDVector = nullptr; +}; + +class NodeIDsSingleTableSemiMasker final : public NodeIDsSemiMask { +public: + NodeIDsSingleTableSemiMasker(DataPos keyPos, DataPos srcNodeIDPos, DataPos dstNodeIDPos, + std::shared_ptr sharedState, std::unique_ptr child, + uint32_t id, std::unique_ptr printInfo) + : NodeIDsSemiMask{keyPos, srcNodeIDPos, dstNodeIDPos, std::move(sharedState), + std::move(child), id, std::move(printInfo)} {} + + bool getNextTuplesInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return std::make_unique(keyPos, srcNodeIDPos, dstNodeIDPos, + sharedState, children[0]->copy(), id, printInfo->copy()); + } +}; + +class NodeIDsMultipleTableSemiMasker final : public NodeIDsSemiMask { +public: + NodeIDsMultipleTableSemiMasker(DataPos keyPos, DataPos srcNodeIDPos, DataPos dstNodeIDPos, + std::shared_ptr sharedState, std::unique_ptr child, + uint32_t id, std::unique_ptr printInfo) + : NodeIDsSemiMask{keyPos, srcNodeIDPos, dstNodeIDPos, std::move(sharedState), + std::move(child), id, std::move(printInfo)} {} + + bool getNextTuplesInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return std::make_unique(keyPos, srcNodeIDPos, dstNodeIDPos, + sharedState, children[0]->copy(), id, printInfo->copy()); + } +}; + +class PathSemiMasker : public BaseSemiMasker { +protected: + PathSemiMasker(DataPos keyPos, std::shared_ptr sharedState, + std::unique_ptr child, uint32_t id, + std::unique_ptr printInfo, common::ExtendDirection direction) + : BaseSemiMasker{keyPos, std::move(sharedState), std::move(child), id, + std::move(printInfo)}, + direction{direction} {} + + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) final; + +protected: + common::ValueVector* pathRelsVector = nullptr; + common::ValueVector* pathRelsSrcIDDataVector = nullptr; + common::ValueVector* pathRelsDstIDDataVector = nullptr; + common::ExtendDirection direction; +}; + +class PathSingleTableSemiMasker final : public PathSemiMasker { +public: + PathSingleTableSemiMasker(DataPos keyPos, std::shared_ptr sharedState, + std::unique_ptr child, uint32_t id, + std::unique_ptr printInfo, common::ExtendDirection direction) + : PathSemiMasker{keyPos, std::move(sharedState), std::move(child), id, std::move(printInfo), + direction} {} + + bool getNextTuplesInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return std::make_unique(keyPos, sharedState, children[0]->copy(), + id, printInfo->copy(), direction); + } +}; + +class PathMultipleTableSemiMasker final : public PathSemiMasker { +public: + PathMultipleTableSemiMasker(DataPos keyPos, std::shared_ptr sharedState, + std::unique_ptr child, uint32_t id, + std::unique_ptr printInfo, common::ExtendDirection direction) + : PathSemiMasker{keyPos, std::move(sharedState), std::move(child), id, std::move(printInfo), + direction} {} + + bool getNextTuplesInternal(ExecutionContext* 
context) override; + + std::unique_ptr copy() override { + return std::make_unique(keyPos, sharedState, + children[0]->copy(), id, printInfo->copy(), direction); + } +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/simple/attach_database.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/simple/attach_database.h new file mode 100644 index 0000000000..e510bd3e64 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/simple/attach_database.h @@ -0,0 +1,47 @@ +#pragma once + +#include "binder/bound_attach_info.h" +#include "processor/operator/sink.h" + +namespace lbug { +namespace processor { + +struct AttachDatabasePrintInfo final : OPPrintInfo { + std::string dbName; + std::string dbPath; + + AttachDatabasePrintInfo(std::string dbName, std::string dbPath) + : dbName{std::move(dbName)}, dbPath{std::move(dbPath)} {} + + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::unique_ptr(new AttachDatabasePrintInfo(*this)); + } + +private: + AttachDatabasePrintInfo(const AttachDatabasePrintInfo& other) + : OPPrintInfo{other}, dbName{other.dbName}, dbPath{other.dbPath} {} +}; + +class AttachDatabase final : public SimpleSink { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::ATTACH_DATABASE; + +public: + AttachDatabase(binder::AttachInfo attachInfo, std::shared_ptr messageTable, + physical_op_id id, std::unique_ptr printInfo) + : SimpleSink{type_, std::move(messageTable), id, std::move(printInfo)}, + attachInfo{std::move(attachInfo)} {} + + void executeInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return std::make_unique(attachInfo, messageTable, id, printInfo->copy()); + } + +private: + binder::AttachInfo attachInfo; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/simple/detach_database.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/simple/detach_database.h new file mode 100644 index 0000000000..d145b884a0 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/simple/detach_database.h @@ -0,0 +1,44 @@ +#pragma once + +#include "processor/operator/sink.h" + +namespace lbug { +namespace processor { + +struct DetatchDatabasePrintInfo final : OPPrintInfo { + std::string name; + + explicit DetatchDatabasePrintInfo(std::string name) : name{std::move(name)} {} + + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::unique_ptr(new DetatchDatabasePrintInfo(*this)); + } + +private: + DetatchDatabasePrintInfo(const DetatchDatabasePrintInfo& other) + : OPPrintInfo{other}, name{other.name} {} +}; + +class DetachDatabase final : public SimpleSink { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::DETACH_DATABASE; + +public: + DetachDatabase(std::string dbName, std::shared_ptr messageTable, + physical_op_id id, std::unique_ptr printInfo) + : SimpleSink{type_, std::move(messageTable), id, std::move(printInfo)}, + dbName{std::move(dbName)} {} + + void executeInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return std::make_unique(dbName, messageTable, id, printInfo->copy()); + } + +private: + std::string dbName; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/simple/export_db.h 
b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/simple/export_db.h new file mode 100644 index 0000000000..d6d1a4fa9b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/simple/export_db.h @@ -0,0 +1,60 @@ +#pragma once + +#include "common/copier_config/file_scan_info.h" +#include "processor/operator/sink.h" + +namespace lbug { +namespace processor { + +struct ExportDBSharedState final { + std::unordered_map*> canUseParallelReader; +}; + +struct ExportDBPrintInfo final : OPPrintInfo { + std::string filePath; + common::case_insensitive_map_t options; + + ExportDBPrintInfo(std::string filePath, common::case_insensitive_map_t options) + : filePath{std::move(filePath)}, options{std::move(options)} {} + + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::unique_ptr(new ExportDBPrintInfo(*this)); + } + +private: + ExportDBPrintInfo(const ExportDBPrintInfo& other) + : OPPrintInfo{other}, filePath{other.filePath}, options{other.options} {} +}; + +class ExportDB final : public SimpleSink { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::EXPORT_DATABASE; + +public: + ExportDB(common::FileScanInfo boundFileInfo, bool schemaOnly, + std::shared_ptr messageTable, physical_op_id id, + std::unique_ptr printInfo, + std::shared_ptr sharedState = std::make_shared()) + : SimpleSink{type_, std::move(messageTable), id, std::move(printInfo)}, + boundFileInfo{std::move(boundFileInfo)}, schemaOnly{schemaOnly}, + sharedState{std::move(sharedState)} {} + + void executeInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return std::make_unique(boundFileInfo.copy(), schemaOnly, messageTable, id, + printInfo->copy(), sharedState); + } + + void addToParallelReaderMap(const std::string& file, const std::atomic& parallelFlag) { + sharedState->canUseParallelReader.insert({file, ¶llelFlag}); + } + +private: + common::FileScanInfo boundFileInfo; + bool schemaOnly; + std::shared_ptr sharedState = nullptr; +}; +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/simple/extension_print_info.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/simple/extension_print_info.h new file mode 100644 index 0000000000..00cf3bf4bb --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/simple/extension_print_info.h @@ -0,0 +1,16 @@ +#pragma once + +#include "planner/operator/operator_print_info.h" + +namespace lbug { +namespace processor { + +struct ExtensionPrintInfo : OPPrintInfo { + std::string extensionName; + + explicit ExtensionPrintInfo(std::string extensionName) + : extensionName{std::move(extensionName)} {} +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/simple/import_db.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/simple/import_db.h new file mode 100644 index 0000000000..0cf5785c17 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/simple/import_db.h @@ -0,0 +1,30 @@ +#pragma once + +#include "processor/operator/sink.h" + +namespace lbug { +namespace processor { + +class ImportDB final : public SimpleSink { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::IMPORT_DATABASE; + +public: + ImportDB(std::string query, std::string indexQuery, + std::shared_ptr messageTable, physical_op_id id, + std::unique_ptr printInfo) + : 
SimpleSink{type_, std::move(messageTable), id, std::move(printInfo)}, + query{std::move(query)}, indexQuery{std::move(indexQuery)} {} + + void executeInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return std::make_unique(query, indexQuery, messageTable, id, printInfo->copy()); + } + +private: + std::string query; + std::string indexQuery; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/simple/install_extension.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/simple/install_extension.h new file mode 100644 index 0000000000..19f319868d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/simple/install_extension.h @@ -0,0 +1,45 @@ +#pragma once + +#include "extension/extension_installer.h" +#include "extension_print_info.h" +#include "processor/operator/sink.h" + +namespace lbug { +namespace processor { + +struct InstallExtensionPrintInfo final : public ExtensionPrintInfo { + explicit InstallExtensionPrintInfo(std::string extensionName) + : ExtensionPrintInfo{std::move(extensionName)} {} + + std::string toString() const override { return "Install " + extensionName; } + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } +}; + +class InstallExtension final : public SimpleSink { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::INSTALL_EXTENSION; + +public: + InstallExtension(extension::InstallExtensionInfo info, + std::shared_ptr messageTable, physical_op_id id, + std::unique_ptr printInfo) + : SimpleSink{type_, std::move(messageTable), id, std::move(printInfo)}, + info{std::move(info)} {} + + void executeInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return std::make_unique(info, messageTable, id, printInfo->copy()); + } + +private: + void setOutputMessage(bool installed, storage::MemoryManager* memoryManager); + +private: + extension::InstallExtensionInfo info; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/simple/load_extension.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/simple/load_extension.h new file mode 100644 index 0000000000..899bf01461 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/simple/load_extension.h @@ -0,0 +1,45 @@ +#pragma once + +#include "processor/operator/sink.h" + +namespace lbug { +namespace processor { + +struct LoadExtensionPrintInfo final : OPPrintInfo { + std::string extensionName; + + explicit LoadExtensionPrintInfo(std::string extensionName) + : extensionName{std::move(extensionName)} {} + + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::unique_ptr(new LoadExtensionPrintInfo(*this)); + } + +private: + LoadExtensionPrintInfo(const LoadExtensionPrintInfo& other) + : OPPrintInfo{other}, extensionName{other.extensionName} {} +}; + +class LoadExtension final : public SimpleSink { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::LOAD_EXTENSION; + +public: + LoadExtension(std::string path, std::shared_ptr messageTable, + physical_op_id id, std::unique_ptr printInfo) + : SimpleSink{type_, std::move(messageTable), id, std::move(printInfo)}, + path{std::move(path)} {} + + void executeInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return std::make_unique(path, messageTable, 
id, printInfo->copy()); + } + +private: + std::string path; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/simple/uninstall_extension.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/simple/uninstall_extension.h new file mode 100644 index 0000000000..3283d2bd4b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/simple/uninstall_extension.h @@ -0,0 +1,40 @@ +#pragma once + +#include "extension_print_info.h" +#include "processor/operator/sink.h" + +namespace lbug { +namespace processor { + +struct UninstallExtensionPrintInfo final : public ExtensionPrintInfo { + explicit UninstallExtensionPrintInfo(std::string extensionName) + : ExtensionPrintInfo{std::move(extensionName)} {} + + std::string toString() const override { return "Uninstall " + extensionName; } + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } +}; + +class UninstallExtension final : public SimpleSink { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::UNINSTALL_EXTENSION; + +public: + UninstallExtension(std::string path, std::shared_ptr messageTable, + physical_op_id id, std::unique_ptr printInfo) + : SimpleSink{type_, std::move(messageTable), id, std::move(printInfo)}, + path{std::move(path)} {} + + void executeInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return std::make_unique(path, messageTable, id, printInfo->copy()); + } + +private: + std::string path; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/simple/use_database.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/simple/use_database.h new file mode 100644 index 0000000000..773bd72bb3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/simple/use_database.h @@ -0,0 +1,44 @@ +#pragma once + +#include "processor/operator/sink.h" + +namespace lbug { +namespace processor { + +struct UseDatabasePrintInfo final : OPPrintInfo { + std::string dbName; + + explicit UseDatabasePrintInfo(std::string dbName) : dbName(std::move(dbName)) {} + + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::unique_ptr(new UseDatabasePrintInfo(*this)); + } + +private: + UseDatabasePrintInfo(const UseDatabasePrintInfo& other) + : OPPrintInfo(other), dbName(other.dbName) {} +}; + +class UseDatabase final : public SimpleSink { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::USE_DATABASE; + +public: + UseDatabase(std::string dbName, std::shared_ptr messageTable, + physical_op_id id, std::unique_ptr printInfo) + : SimpleSink{type_, std::move(messageTable), id, std::move(printInfo)}, + dbName{std::move(dbName)} {} + + void executeInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return std::make_unique(dbName, messageTable, id, printInfo->copy()); + } + +private: + std::string dbName; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/sink.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/sink.h new file mode 100644 index 0000000000..bd9d9182bb --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/sink.h @@ -0,0 +1,123 @@ +#pragma once + +#include "common/exception/internal.h" +#include "common/metric.h" +#include 
"processor/operator/physical_operator.h" +#include "processor/result/factorized_table.h" +#include "processor/result/result_set_descriptor.h" + +namespace lbug { +namespace main { +class QueryResult; +} +namespace processor { + +class LBUG_API Sink : public PhysicalOperator { +public: + Sink(PhysicalOperatorType operatorType, physical_op_id id, + std::unique_ptr printInfo) + : PhysicalOperator{operatorType, id, std::move(printInfo)} {} + Sink(PhysicalOperatorType operatorType, std::unique_ptr child, + physical_op_id id, std::unique_ptr printInfo) + : PhysicalOperator{operatorType, std::move(child), id, std::move(printInfo)} {} + + bool isSink() const override { return true; } + + void setDescriptor(std::unique_ptr descriptor) { + KU_ASSERT(resultSetDescriptor == nullptr); + resultSetDescriptor = std::move(descriptor); + } + std::unique_ptr getResultSet(storage::MemoryManager* memoryManager); + + void execute(ResultSet* resultSet, ExecutionContext* context) { + initLocalState(resultSet, context); + metrics->executionTime.start(); + executeInternal(context); + metrics->executionTime.stop(); + } + + virtual std::unique_ptr getQueryResult() const { + throw common::InternalException( + common::stringFormat("{} operator does not implement getQueryResult.", + PhysicalOperatorUtils::operatorTypeToString(operatorType))); + } + + virtual std::shared_ptr getResultFTable() const { + throw common::InternalException(common::stringFormat( + "Trying to get result table from {} operator which doesn't have one.", + PhysicalOperatorUtils::operatorTypeToString(operatorType))); + } + + virtual bool terminate() const { return false; } + + std::unique_ptr copy() override = 0; + +protected: + virtual void executeInternal(ExecutionContext* context) = 0; + + bool getNextTuplesInternal(ExecutionContext* /*context*/) final { + throw common::InternalException( + "getNextTupleInternal() should not be called on sink operator."); + } + +protected: + std::unique_ptr resultSetDescriptor; +}; + +class LBUG_API DummySink final : public Sink { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::DUMMY_SINK; + +public: + DummySink(std::unique_ptr child, uint32_t id) + : Sink{type_, std::move(child), id, OPPrintInfo::EmptyInfo()} {} + + std::unique_ptr copy() override { + return std::make_unique(children[0]->copy(), id); + } + +protected: + void executeInternal(ExecutionContext* context) override { + while (children[0]->getNextTuple(context)) { + // DO NOTHING. + } + } +}; + +class SimpleSink : public Sink { +public: + SimpleSink(PhysicalOperatorType operatorType, std::shared_ptr messageTable, + physical_op_id id, std::unique_ptr printInfo) + : Sink{operatorType, id, std::move(printInfo)}, messageTable{std::move(messageTable)} {} + + bool isSource() const final { return true; } + bool isParallel() const final { return false; } + + std::unique_ptr getQueryResult() const override; + + std::shared_ptr getResultFTable() const override { return messageTable; } + +protected: + void appendMessage(const std::string& msg, storage::MemoryManager* memoryManager); + +protected: + std::shared_ptr messageTable; +}; + +// For cases like Export. We need a parent for ExportDB and multiple CopyTo. This parent does not +// have any logic other than propagating the result fTable. 
+class DummySimpleSink final : public SimpleSink { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::DUMMY_SIMPLE_SINK; + +public: + DummySimpleSink(std::shared_ptr messageTable, physical_op_id id) + : SimpleSink{type_, std::move(messageTable), id, OPPrintInfo::EmptyInfo()} {} + + void executeInternal(ExecutionContext*) override {} + + std::unique_ptr copy() override { + return std::make_unique(messageTable, id); + } +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/skip.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/skip.h new file mode 100644 index 0000000000..a83c0e3471 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/skip.h @@ -0,0 +1,56 @@ +#pragma once + +#include + +#include "processor/operator/filtering_operator.h" +#include "processor/operator/physical_operator.h" + +namespace lbug { +namespace processor { + +struct SkipPrintInfo final : OPPrintInfo { + uint64_t number; + + explicit SkipPrintInfo(std::int64_t number) : number(number) {} + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::unique_ptr(new SkipPrintInfo(*this)); + } + +private: + SkipPrintInfo(const SkipPrintInfo& other) : OPPrintInfo(other), number(other.number) {} +}; + +class Skip final : public PhysicalOperator, public SelVectorOverWriter { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::SKIP; + +public: + Skip(uint64_t skipNumber, std::shared_ptr counter, + uint32_t dataChunkToSelectPos, std::unordered_set dataChunksPosInScope, + std::unique_ptr child, uint32_t id, + std::unique_ptr printInfo) + : PhysicalOperator{type_, std::move(child), id, std::move(printInfo)}, + skipNumber{skipNumber}, counter{std::move(counter)}, + dataChunkToSelectPos{dataChunkToSelectPos}, + dataChunksPosInScope{std::move(dataChunksPosInScope)} {} + + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; + + bool getNextTuplesInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return make_unique(skipNumber, counter, dataChunkToSelectPos, dataChunksPosInScope, + children[0]->copy(), id, printInfo->copy()); + } + +private: + uint64_t skipNumber; + std::shared_ptr counter; + uint32_t dataChunkToSelectPos; + std::shared_ptr dataChunkToSelect; + std::unordered_set dataChunksPosInScope; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/standalone_call.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/standalone_call.h new file mode 100644 index 0000000000..7495f0a845 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/standalone_call.h @@ -0,0 +1,65 @@ +#pragma once + +#include "common/types/value/value.h" +#include "processor/operator/physical_operator.h" + +namespace lbug { +namespace main { +struct Option; +} +namespace processor { + +struct StandaloneCallPrintInfo final : OPPrintInfo { + std::string functionName; + + explicit StandaloneCallPrintInfo(std::string functionName) + : functionName(std::move(functionName)) {} + + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::unique_ptr(new StandaloneCallPrintInfo(*this)); + } + +private: + StandaloneCallPrintInfo(const StandaloneCallPrintInfo& other) + : OPPrintInfo(other), functionName(other.functionName) {} +}; + +struct 
StandaloneCallInfo { + const main::Option* option; + common::Value optionValue; + // TODO: we should remove this. + bool hasExecuted = false; + + StandaloneCallInfo(const main::Option* option, common::Value optionValue) + : option{option}, optionValue{std::move(optionValue)} {} + EXPLICIT_COPY_DEFAULT_MOVE(StandaloneCallInfo); + +private: + StandaloneCallInfo(const StandaloneCallInfo& other) + : option{other.option}, optionValue{other.optionValue} {} +}; + +class StandaloneCall final : public PhysicalOperator { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::STANDALONE_CALL; + +public: + StandaloneCall(StandaloneCallInfo info, uint32_t id, std::unique_ptr printInfo) + : PhysicalOperator{type_, id, std::move(printInfo)}, standaloneCallInfo{std::move(info)} {} + + bool isSource() const override { return true; } + bool isParallel() const override { return false; } + + bool getNextTuplesInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return std::make_unique(standaloneCallInfo.copy(), id, printInfo->copy()); + } + +private: + StandaloneCallInfo standaloneCallInfo; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/table_function_call.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/table_function_call.h new file mode 100644 index 0000000000..cffdb62c82 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/table_function_call.h @@ -0,0 +1,82 @@ +#pragma once + +#include "function/table/bind_data.h" +#include "function/table/table_function.h" +#include "processor/operator/physical_operator.h" + +namespace lbug { +namespace processor { + +struct TableFunctionCallInfo { + function::TableFunction function{}; + std::unique_ptr bindData; + std::vector outPosV; + + TableFunctionCallInfo() = default; + EXPLICIT_COPY_DEFAULT_MOVE(TableFunctionCallInfo); + +private: + TableFunctionCallInfo(const TableFunctionCallInfo& other) { + function = other.function; + bindData = other.bindData->copy(); + outPosV = other.outPosV; + } +}; + +struct TableFunctionCallPrintInfo final : OPPrintInfo { + std::string funcName; + binder::expression_vector exprs; + + explicit TableFunctionCallPrintInfo(std::string funcName, binder::expression_vector exprs) + : funcName(std::move(funcName)), exprs(std::move(exprs)) {} + + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::unique_ptr(new TableFunctionCallPrintInfo(*this)); + } + +private: + TableFunctionCallPrintInfo(const TableFunctionCallPrintInfo& other) + : OPPrintInfo(other), funcName(other.funcName), exprs(other.exprs) {} +}; + +class LBUG_API TableFunctionCall final : public PhysicalOperator { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::TABLE_FUNCTION_CALL; + +public: + TableFunctionCall(TableFunctionCallInfo info, + std::shared_ptr sharedState, uint32_t id, + std::unique_ptr printInfo) + : PhysicalOperator{type_, id, std::move(printInfo)}, info{std::move(info)}, + sharedState{std::move(sharedState)} {} + + const TableFunctionCallInfo& getInfo() const { return info; } + std::shared_ptr getSharedState() const { return sharedState; } + + bool isSource() const override { return true; } + + bool isParallel() const override { return info.function.canParallelFunc(); } + + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; + + bool getNextTuplesInternal(ExecutionContext* context) override; + 
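+    // Rough lifecycle of this source operator, as driven by the task scheduler
+    // (a simplified sketch; the real calls go through PhysicalOperator::getNextTuple
+    // and the Task plumbing rather than being invoked directly):
+    //   initLocalStateInternal(resultSet, context);            // once per thread
+    //   while (getNextTuplesInternal(context)) { /* consume output vectors */ }
+    //   finalizeInternal(context);                             // once, when all threads finish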
+ void finalizeInternal(ExecutionContext* context) override; + + double getProgress(ExecutionContext* context) const override; + + std::unique_ptr copy() override { + return std::make_unique(info.copy(), sharedState, id, printInfo->copy()); + } + +private: + TableFunctionCallInfo info; + std::shared_ptr sharedState; + std::unique_ptr localState = nullptr; + std::unique_ptr funcInput = nullptr; + std::unique_ptr funcOutput = nullptr; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/table_scan/ftable_scan_function.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/table_scan/ftable_scan_function.h new file mode 100644 index 0000000000..d8ed0ae8c6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/table_scan/ftable_scan_function.h @@ -0,0 +1,35 @@ +#pragma once + +#include "function/table/bind_data.h" +#include "function/table/table_function.h" +#include "processor/result/factorized_table.h" + +namespace lbug { +namespace processor { + +struct FTableScanBindData : public function::TableFuncBindData { + std::shared_ptr table; + std::vector columnIndices; + uint64_t morselSize; + + FTableScanBindData(std::shared_ptr table, + std::vector columnIndices, uint64_t morselSize) + : table{std::move(table)}, columnIndices{std::move(columnIndices)}, morselSize{morselSize} { + } + FTableScanBindData(const FTableScanBindData& other) + : function::TableFuncBindData{other}, table{other.table}, + columnIndices{other.columnIndices}, morselSize{other.morselSize} {} + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } +}; + +struct FTableScan { + static constexpr const char* name = "READ_FTABLE"; + + static std::unique_ptr getFunction(); +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/table_scan/union_all_scan.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/table_scan/union_all_scan.h new file mode 100644 index 0000000000..518954d5bd --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/table_scan/union_all_scan.h @@ -0,0 +1,97 @@ +#pragma once + +#include + +#include "binder/expression/expression.h" +#include "processor/operator/physical_operator.h" +#include "processor/result/factorized_table.h" + +namespace lbug { +namespace processor { + +struct UnionAllScanPrintInfo final : OPPrintInfo { + binder::expression_vector expressions; + + explicit UnionAllScanPrintInfo(binder::expression_vector expressions) + : expressions(std::move(expressions)) {} + + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::unique_ptr(new UnionAllScanPrintInfo(*this)); + } + +private: + UnionAllScanPrintInfo(const UnionAllScanPrintInfo& other) + : OPPrintInfo(other), expressions(other.expressions) {} +}; + +struct UnionAllScanInfo { + std::vector outputPositions; + std::vector columnIndices; + + UnionAllScanInfo(std::vector outputPositions, std::vector columnIndices) + : outputPositions{std::move(outputPositions)}, columnIndices{std::move(columnIndices)} {} + EXPLICIT_COPY_DEFAULT_MOVE(UnionAllScanInfo); + +private: + UnionAllScanInfo(const UnionAllScanInfo& other) + : outputPositions{other.outputPositions}, columnIndices{other.columnIndices} {} +}; + +struct UnionAllScanMorsel { + FactorizedTable* table; + uint64_t startTupleIdx; + uint64_t numTuples; + + UnionAllScanMorsel(FactorizedTable* table, uint64_t 
startTupleIdx, uint64_t numTuples) + : table{table}, startTupleIdx{startTupleIdx}, numTuples{numTuples} {} +}; + +class UnionAllScanSharedState { +public: + UnionAllScanSharedState(std::vector> tables, + uint64_t maxMorselSize) + : tables{std::move(tables)}, maxMorselSize{maxMorselSize}, tableIdx{0}, + nextTupleIdxToScan{0} {} + + std::unique_ptr getMorsel(); + +private: + std::unique_ptr getMorselNoLock(FactorizedTable* table); + +private: + std::mutex mtx; + std::vector> tables; + uint64_t maxMorselSize; + uint64_t tableIdx; + uint64_t nextTupleIdxToScan; +}; + +class UnionAllScan : public PhysicalOperator { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::UNION_ALL_SCAN; + +public: + UnionAllScan(UnionAllScanInfo info, std::shared_ptr sharedState, + physical_op_id id, std::unique_ptr printInfo) + : PhysicalOperator{type_, id, std::move(printInfo)}, info{std::move(info)}, + sharedState{std::move(sharedState)} {} + + bool isSource() const final { return true; } + + void initLocalStateInternal(ResultSet* resultSet_, ExecutionContext* context) final; + + bool getNextTuplesInternal(ExecutionContext* context) final; + + std::unique_ptr copy() override { + return std::make_unique(info.copy(), sharedState, id, printInfo->copy()); + } + +private: + UnionAllScanInfo info; + std::shared_ptr sharedState; + std::vector vectors; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/transaction.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/transaction.h new file mode 100644 index 0000000000..0ca31f2e70 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/transaction.h @@ -0,0 +1,60 @@ +#pragma once + +#include "processor/operator/physical_operator.h" +#include "transaction/transaction_action.h" + +namespace lbug { +namespace transaction { +class TransactionContext; +} // namespace transaction + +namespace processor { + +struct TransactionPrintInfo final : OPPrintInfo { + transaction::TransactionAction action; + + explicit TransactionPrintInfo(transaction::TransactionAction action) : action(action) {} + + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::unique_ptr(new TransactionPrintInfo(*this)); + } + +private: + TransactionPrintInfo(const TransactionPrintInfo& other) + : OPPrintInfo(other), action(other.action) {} +}; + +class Transaction final : public PhysicalOperator { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::TRANSACTION; + +public: + Transaction(transaction::TransactionAction transactionAction, uint32_t id, + std::unique_ptr printInfo) + : PhysicalOperator{type_, id, std::move(printInfo)}, transactionAction{transactionAction}, + hasExecuted{false} {} + + bool isSource() const final { return true; } + bool isParallel() const final { return false; } + + void initLocalStateInternal(ResultSet* /*resultSet_*/, ExecutionContext* /*context*/) final { + hasExecuted = false; + } + + bool getNextTuplesInternal(ExecutionContext* context) final; + + std::unique_ptr copy() override { + return std::make_unique(transactionAction, id, printInfo->copy()); + } + +private: + void validateActiveTransaction(const transaction::TransactionContext& context) const; + +private: + transaction::TransactionAction transactionAction; + bool hasExecuted; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/unwind.h 
b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/unwind.h new file mode 100644 index 0000000000..72b12db443 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/operator/unwind.h @@ -0,0 +1,66 @@ +#pragma once + +#include "expression_evaluator/expression_evaluator.h" +#include "processor/operator/physical_operator.h" +#include "processor/result/result_set.h" + +namespace lbug { +namespace processor { + +struct UnwindPrintInfo final : OPPrintInfo { + std::shared_ptr inExpression; + std::shared_ptr outExpression; + + UnwindPrintInfo(std::shared_ptr inExpression, + std::shared_ptr outExpression) + : inExpression(std::move(inExpression)), outExpression(std::move(outExpression)) {} + + std::string toString() const override; + + std::unique_ptr copy() const override { + return std::unique_ptr(new UnwindPrintInfo(*this)); + } + +private: + UnwindPrintInfo(const UnwindPrintInfo& other) + : OPPrintInfo(other), inExpression(other.inExpression), outExpression(other.outExpression) { + } +}; + +class Unwind : public PhysicalOperator { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::UNWIND; + +public: + Unwind(DataPos outDataPos, DataPos idPos, + std::unique_ptr expressionEvaluator, + std::unique_ptr child, uint32_t id, + std::unique_ptr printInfo) + : PhysicalOperator{type_, std::move(child), id, std::move(printInfo)}, + outDataPos{outDataPos}, idPos(idPos), expressionEvaluator{std::move(expressionEvaluator)}, + startIndex{0u} {} + + bool getNextTuplesInternal(ExecutionContext* context) override; + + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; + + std::unique_ptr copy() override { + return make_unique(outDataPos, idPos, expressionEvaluator->copy(), + children[0]->copy(), id, printInfo->copy()); + } + +private: + bool hasMoreToRead() const; + void copyTuplesToOutVector(uint64_t startPos, uint64_t endPos) const; + + DataPos outDataPos; + DataPos idPos; + + std::unique_ptr expressionEvaluator; + std::shared_ptr outValueVector; + common::ValueVector* idVector = nullptr; + uint32_t startIndex; + common::list_entry_t listEntry; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/physical_plan.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/physical_plan.h new file mode 100644 index 0000000000..c4faa87997 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/physical_plan.h @@ -0,0 +1,20 @@ +#pragma once + +#include + +#include "processor/operator/physical_operator.h" + +namespace lbug { +namespace processor { + +class PhysicalPlan { +public: + explicit PhysicalPlan(std::unique_ptr lastOperator) + : lastOperator{std::move(lastOperator)} {} + +public: + std::unique_ptr lastOperator; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/plan_mapper.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/plan_mapper.h new file mode 100644 index 0000000000..3f907facf7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/plan_mapper.h @@ -0,0 +1,258 @@ +#pragma once + +#include "common/arrow/arrow_result_config.h" +#include "main/query_result.h" +#include "planner/operator/logical_operator.h" +#include "processor/execution_context.h" +#include "processor/operator/result_collector.h" +#include "processor/physical_plan.h" + +namespace lbug { +namespace common { +enum class RelDataDirection : uint8_t; +class SemiMask; +class 
NodeOffsetMaskMap; +class SemiMask; +} // namespace common +namespace main { +class ClientContext; +} +namespace extension { +class MapperExtension; +} + +namespace binder { +struct BoundCopyFromInfo; +struct BoundDeleteInfo; +struct BoundSetPropertyInfo; +} // namespace binder + +namespace catalog { +class TableCatalogEntry; +} + +namespace planner { +class LogicalSemiMasker; +struct LogicalInsertInfo; +class LogicalCopyFrom; +class LogicalPlan; +} // namespace planner + +namespace processor { + +struct HashJoinBuildInfo; +struct AggregateInfo; +class NodeInsertExecutor; +class RelInsertExecutor; +class NodeSetExecutor; +class RelSetExecutor; +class NodeDeleteExecutor; +class RelDeleteExecutor; +struct NodeTableDeleteInfo; +struct NodeTableSetInfo; +struct RelTableSetInfo; +struct BatchInsertSharedState; +struct PartitionerSharedState; +class RelBatchInsertImpl; +class ArrowResultCollector; + +class PlanMapper { +public: + explicit PlanMapper(ExecutionContext* executionContext); + + std::unique_ptr getPhysicalPlan(const planner::LogicalPlan* logicalPlan, + const binder::expression_vector& expressions, main::QueryResultType resultType, + common::ArrowResultConfig arrowConfig); + + uint32_t getOperatorID() { return physicalOperatorID++; } + + static DataPos getDataPos(const binder::Expression& expression, const planner::Schema& schema) { + return DataPos(schema.getExpressionPos(expression)); + } + + // Assume scans all columns of table in the same order as given expressions. + LBUG_API std::unique_ptr createFTableScanAligned( + const binder::expression_vector& exprs, const planner::Schema* schema, + std::shared_ptr table, uint64_t maxMorselSize, + physical_op_vector_t children); + + LBUG_API std::unique_ptr mapOperator( + const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapAccumulate( + const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapAggregate(const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapAlter(const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapAttachDatabase( + const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapCopyFrom(const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapCopyNodeFrom( + const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapCopyRelFrom( + const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapCopyTo(const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapCreateMacro( + const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapCreateSequence( + const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapCreateTable( + const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapCreateType( + const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapCrossProduct( + const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapDelete(const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapDeleteNode( + const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapDeleteRel(const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapDetachDatabase( + const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapDistinct(const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapDrop(const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapDummyScan(const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapDummySink(const planner::LogicalOperator* logicalOperator); + 
std::unique_ptr mapEmptyResult( + const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapExplain(const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapExpressionsScan( + const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapExtend(const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapExtension(const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapExportDatabase( + const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapFilter(const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapFlatten(const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapHashJoin(const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapImportDatabase( + const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapIndexLookup( + const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapIntersect(const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapInsert(const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapLimit(const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapMerge(const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapMultiplicityReducer( + const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapNodeLabelFilter( + const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapNoop(const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapOrderBy(const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapPartitioner( + const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapPathPropertyProbe( + const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapProjection( + const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapRecursiveExtend( + const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapScanNodeTable( + const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapSemiMasker( + const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapSetProperty( + const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapSetNodeProperty( + const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapSetRelProperty( + const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapStandaloneCall( + const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapTableFunctionCall( + const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapTransaction( + const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapUnionAll(const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapUnwind(const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapUseDatabase( + const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapExtensionClause( + const planner::LogicalOperator* logicalOperator); + + std::unique_ptr createResultCollector(common::AccumulateType accumulateType, + const binder::expression_vector& expressions, planner::Schema* schema, + std::unique_ptr prevOperator); + std::unique_ptr createArrowResultCollector( + common::ArrowResultConfig arrowConfig, const binder::expression_vector& expressions, + planner::Schema* schema, std::unique_ptr prevOperator); + + // Scan fTable + std::unique_ptr createFTableScan(const binder::expression_vector& exprs, + std::vector colIndices, const planner::Schema* schema, + std::shared_ptr table, uint64_t maxMorselSize, + 
physical_op_vector_t children); + // Scan is the leaf operator of physical plan. + std::unique_ptr createFTableScan(const binder::expression_vector& exprs, + const std::vector& colIndices, const planner::Schema* schema, + std::shared_ptr table, uint64_t maxMorselSize); + // Do not scan anything from table. Serves as a control logic of pull model. + std::unique_ptr createEmptyFTableScan(std::shared_ptr table, + uint64_t maxMorselSize, physical_op_vector_t children); + std::unique_ptr createEmptyFTableScan(std::shared_ptr table, + uint64_t maxMorselSize, std::unique_ptr child); + // Do not scan anything from table. Serves as a control logic of pull model. + // Scan is the leaf operator of physical plan. + std::unique_ptr createEmptyFTableScan(std::shared_ptr table, + uint64_t maxMorselSize); + // Assume scans all columns of table in the same order as given expressions. + // Scan fTable without row offset. + // Scan is the leaf operator of physical plan. + std::unique_ptr createFTableScanAligned( + const binder::expression_vector& exprs, const planner::Schema* schema, + std::shared_ptr table, uint64_t maxMorselSize); + + static HashJoinBuildInfo createHashBuildInfo(const planner::Schema& buildSideSchema, + const binder::expression_vector& keys, const binder::expression_vector& payloads); + + std::unique_ptr createDistinctHashAggregate( + const binder::expression_vector& keys, const binder::expression_vector& payloads, + planner::Schema* inSchema, planner::Schema* outSchema, + std::unique_ptr prevOperator); + std::unique_ptr createHashAggregate(const binder::expression_vector& keys, + const binder::expression_vector& payloads, const binder::expression_vector& aggregates, + planner::Schema* inSchema, planner::Schema* outSchema, + std::unique_ptr prevOperator); + + NodeInsertExecutor getNodeInsertExecutor(const planner::LogicalInsertInfo* boundInfo, + const planner::Schema& inSchema, const planner::Schema& outSchema) const; + RelInsertExecutor getRelInsertExecutor(const planner::LogicalInsertInfo* boundInfo, + const planner::Schema& inSchema, const planner::Schema& outSchema) const; + std::unique_ptr getNodeSetExecutor( + const binder::BoundSetPropertyInfo& boundInfo, const planner::Schema& schema) const; + std::unique_ptr getRelSetExecutor(const binder::BoundSetPropertyInfo& boundInfo, + const planner::Schema& schema) const; + std::unique_ptr getNodeDeleteExecutor( + const binder::BoundDeleteInfo& boundInfo, const planner::Schema& schema) const; + std::unique_ptr getRelDeleteExecutor( + const binder::BoundDeleteInfo& boundInfo, const planner::Schema& schema) const; + NodeTableDeleteInfo getNodeTableDeleteInfo(const catalog::TableCatalogEntry& entry, + DataPos pkPos) const; + + static void mapSIPJoin(PhysicalOperator* joinRoot); + + static std::vector getDataPos(const binder::expression_vector& expressions, + const planner::Schema& schema); + static FactorizedTableSchema createFlatFTableSchema( + const binder::expression_vector& expressions, const planner::Schema& schema); + std::unique_ptr createSemiMask(common::table_id_t tableID) const; + + void addOperatorMapping(const planner::LogicalOperator* logicalOp, + PhysicalOperator* physicalOp) { + KU_ASSERT(!logicalOpToPhysicalOpMap.contains(logicalOp)); + logicalOpToPhysicalOpMap.insert({logicalOp, physicalOp}); + } + void eraseOperatorMapping(const planner::LogicalOperator* logicalOp) { + KU_ASSERT(logicalOpToPhysicalOpMap.contains(logicalOp)); + logicalOpToPhysicalOpMap.erase(logicalOp); + } + +public: + ExecutionContext* executionContext; + 
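+    // Both pointers are non-owning; clientContext is presumably the client
+    // context carried by executionContext, cached here for convenience.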
main::ClientContext* clientContext; + +private: + std::unordered_map logicalOpToPhysicalOpMap; + physical_op_id physicalOperatorID; + std::vector mapperExtensions; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/processor.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/processor.h new file mode 100644 index 0000000000..6db96b29d6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/processor.h @@ -0,0 +1,37 @@ +#pragma once + +#include "common/task_system/task_scheduler.h" + +namespace lbug { +namespace main { +class QueryResult; +} +namespace processor { +class FactorizedTable; +class PhysicalPlan; +class PhysicalOperator; +class QueryProcessor { + +public: +#if defined(__APPLE__) + explicit QueryProcessor(uint64_t numThreads, uint32_t threadQos); +#else + explicit QueryProcessor(uint64_t numThreads); +#endif + + common::TaskScheduler* getTaskScheduler() { return taskScheduler.get(); } + + std::unique_ptr execute(PhysicalPlan* physicalPlan, + ExecutionContext* context); + +private: + void decomposePlanIntoTask(PhysicalOperator* op, common::Task* task, ExecutionContext* context); + + void initTask(common::Task* task); + +private: + std::unique_ptr taskScheduler; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/processor_task.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/processor_task.h new file mode 100644 index 0000000000..5d2247df62 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/processor_task.h @@ -0,0 +1,28 @@ +#pragma once + +#include "common/task_system/task.h" +#include "processor/operator/sink.h" + +namespace lbug { +namespace processor { + +class ProcessorTask : public common::Task { + friend class QueryProcessor; + +public: + ProcessorTask(Sink* sink, ExecutionContext* executionContext); + + void run() override; + + void finalize() override; + + bool terminate() override; + +private: + bool sharedStateInitialized; + Sink* sink; + ExecutionContext* executionContext; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/result/base_hash_table.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/result/base_hash_table.h new file mode 100644 index 0000000000..a6f6812993 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/result/base_hash_table.h @@ -0,0 +1,80 @@ +#pragma once + +#include + +#include "common/copy_constructors.h" +#include "common/system_config.h" +#include "common/types/types.h" +#include "common/vector/value_vector.h" +#include "processor/result/factorized_table.h" +#include "processor/result/factorized_table_schema.h" + +namespace lbug { +namespace storage { +class MemoryManager; +} +namespace processor { + +using compare_function_t = + std::function; +using ft_compare_function_t = + std::function; + +class BaseHashTable { +public: + BaseHashTable(storage::MemoryManager& memoryManager, common::logical_type_vec_t keyTypes); + + virtual ~BaseHashTable() = default; + + DELETE_COPY_DEFAULT_MOVE(BaseHashTable); + + const FactorizedTableSchema* getTableSchema() const { + return factorizedTable->getTableSchema(); + } + uint64_t getNumEntries() const { return factorizedTable->getNumTuples(); } + uint64_t getCapacity() const { return maxNumHashSlots; } + const FactorizedTable* getFactorizedTable() const { return factorizedTable.get(); } + +protected: + static constexpr 
uint64_t HASH_BLOCK_SIZE = common::TEMP_PAGE_SIZE; + + uint64_t getSlotIdxForHash(common::hash_t hash) const { return hash % maxNumHashSlots; } + void setMaxNumHashSlots(uint64_t newSize); + + void computeVectorHashes(const std::vector& flatKeyVectors) { + computeVectorHashes(constSpan(flatKeyVectors)); + } + void computeVectorHashes(std::span flatKeyVectors); + void initSlotConstant(uint64_t numSlotsPerBlock); + bool matchFlatVecWithEntry(const std::vector& keyVectors, + const uint8_t* entry); + + template + std::span constSpan(const std::vector& vector) { + return std::span(const_cast(vector.data()), vector.size()); + } + +private: + void initCompareFuncs(); + void initTmpHashVector(); + +protected: + uint64_t maxNumHashSlots; + uint64_t numSlotsPerBlockLog2; + uint64_t slotIdxInBlockMask; + std::vector> hashSlotsBlocks; + storage::MemoryManager* memoryManager; + std::unique_ptr factorizedTable; + std::vector compareEntryFuncs; + std::vector ftCompareEntryFuncs; + std::vector keyTypes; + // Temporary arrays to hold intermediate results for appending. + std::shared_ptr hashState; + std::unique_ptr hashVector; + common::SelectionVector hashSelVec; + std::unique_ptr tmpHashResultVector; + std::unique_ptr tmpHashCombineResultVector; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/result/factorized_table.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/result/factorized_table.h new file mode 100644 index 0000000000..d39a2e9418 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/result/factorized_table.h @@ -0,0 +1,322 @@ +#pragma once + +#include +#include + +#include "common/in_mem_overflow_buffer.h" +#include "common/types/value/value.h" +#include "common/vector/value_vector.h" +#include "factorized_table_schema.h" +#include "flat_tuple.h" + +namespace lbug { +namespace storage { +class MemoryManager; +} +namespace processor { + +struct BlockAppendingInfo { + BlockAppendingInfo(uint8_t* data, uint64_t numTuplesToAppend) + : data{data}, numTuplesToAppend{numTuplesToAppend} {} + + uint8_t* data; + uint64_t numTuplesToAppend; +}; + +// This struct allocates and holds one bmBackedBlock when constructed. The bmBackedBlock will be +// released when this struct goes out of scope. +class DataBlock { +public: + DataBlock(storage::MemoryManager* mm, uint64_t size); + ~DataBlock(); + + uint8_t* getData() const; + std::span getSizedData() const; + uint8_t* getWritableData() const; + void resetNumTuplesAndFreeSize(); + void resetToZero(); + + // Manually set the underlying memory buffer to evicted to avoid double free + void preventDestruction(); + + static void copyTuples(DataBlock* blockToCopyFrom, ft_tuple_idx_t tupleIdxToCopyFrom, + DataBlock* blockToCopyInto, ft_tuple_idx_t tupleIdxToCopyTo, uint32_t numTuplesToCopy, + uint32_t numBytesPerTuple); + +public: + uint32_t numTuples; + uint64_t freeSize; + +private: + std::unique_ptr block; +}; + +class DataBlockCollection { +public: + // This interface is used for unFlat tuple blocks, for which numBytesPerTuple and + // numTuplesPerBlock are useless. 
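// Illustrative aside (not lbug source): a minimal, standalone sketch of how a 64-bit hash
// can be mapped to a global slot index and then split into a hash-slot block plus an offset
// within that block, assuming slots-per-block is a power of two -- which is what the
// numSlotsPerBlockLog2 and slotIdxInBlockMask members of BaseHashTable above suggest.
// All concrete values below are hypothetical.
#include <cstdint>
#include <cstdio>

int main() {
    const uint64_t numSlotsPerBlockLog2 = 10;                  // hypothetical: 1024 slots per block
    const uint64_t slotIdxInBlockMask = (1ULL << numSlotsPerBlockLog2) - 1;
    const uint64_t maxNumHashSlots = 1ULL << 20;               // hypothetical table capacity

    const uint64_t hash = 0x9E3779B97F4A7C15ULL;               // some precomputed key hash
    const uint64_t slotIdx = hash % maxNumHashSlots;           // as in getSlotIdxForHash above
    const uint64_t blockIdx = slotIdx >> numSlotsPerBlockLog2; // which hash-slot block
    const uint64_t slotInBlock = slotIdx & slotIdxInBlockMask; // offset inside that block

    std::printf("slot %llu -> block %llu, offset %llu\n", (unsigned long long)slotIdx,
        (unsigned long long)blockIdx, (unsigned long long)slotInBlock);
    return 0;
}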
+ DataBlockCollection() : numBytesPerTuple{UINT32_MAX}, numTuplesPerBlock{UINT32_MAX} {} + DataBlockCollection(uint32_t numBytesPerTuple, uint32_t numTuplesPerBlock) + : numBytesPerTuple{numBytesPerTuple}, numTuplesPerBlock{numTuplesPerBlock} {} + + void append(std::unique_ptr otherBlock) { blocks.push_back(std::move(otherBlock)); } + void append(std::vector> otherBlocks) { + std::move(begin(otherBlocks), end(otherBlocks), back_inserter(blocks)); + } + void append(std::unique_ptr other) { append(std::move(other->blocks)); } + bool needAllocation(uint64_t size) const { return isEmpty() || blocks.back()->freeSize < size; } + + bool isEmpty() const { return blocks.empty(); } + const std::vector>& getBlocks() const { return blocks; } + DataBlock* getBlock(ft_block_idx_t blockIdx) { return blocks[blockIdx].get(); } + DataBlock* getLastBlock() { return blocks.back().get(); } + + void merge(DataBlockCollection& other); + void preventDestruction() const { + for (auto& block : blocks) { + block->preventDestruction(); + } + } + +private: + uint32_t numBytesPerTuple; + uint32_t numTuplesPerBlock; + std::vector> blocks; +}; + +class FlatTupleIterator; + +class LBUG_API FactorizedTable { + friend FlatTupleIterator; + friend class JoinHashTable; + friend class PathPropertyProbe; + +public: + FactorizedTable(storage::MemoryManager* memoryManager, FactorizedTableSchema tableSchema); + ~FactorizedTable(); + void append(const std::vector& vectors); + + //! This function appends an empty tuple to the factorizedTable and returns a pointer to that + //! tuple. + uint8_t* appendEmptyTuple(); + + // This function scans numTuplesToScan of rows to vectors starting at tupleIdx. Callers are + // responsible for making sure all the parameters are valid. + void scan(std::span vectors, ft_tuple_idx_t tupleIdx, + uint64_t numTuplesToScan) const { + std::vector colIdxes(tableSchema.getNumColumns()); + iota(colIdxes.begin(), colIdxes.end(), 0); + scan(vectors, tupleIdx, numTuplesToScan, colIdxes); + } + bool isEmpty() const { return getNumTuples() == 0; } + void scan(std::span vectors, ft_tuple_idx_t tupleIdx, + uint64_t numTuplesToScan, std::span colIdxToScan) const; + // TODO(Guodong): Unify these two interfaces along with `readUnflatCol`. + // startPos is the starting position in the tuplesToRead, not the starting position in the + // factorizedTable + void lookup(std::span vectors, std::span colIdxesToScan, + uint8_t** tuplesToRead, uint64_t startPos, uint64_t numTuplesToRead) const; + void lookup(std::vector& vectors, + const common::SelectionVector* selVector, std::vector& colIdxesToScan, + uint8_t* tupleToRead) const; + void lookup(std::vector& vectors, std::vector& colIdxesToScan, + std::vector& tupleIdxesToRead, uint64_t startPos, + uint64_t numTuplesToRead) const; + + // When we merge two factorizedTables, we need to update the hasNoNullGuarantee based on + // other factorizedTable. 
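// Illustrative aside (not lbug source): a sketch of the append planning implied by
// BlockAppendingInfo and DataBlockCollection::needAllocation above -- use the free space of
// the last block first, then fall back to freshly allocated blocks. The helper name, the
// parameters, and the AppendPlanEntry struct are hypothetical.
#include <algorithm>
#include <cstdint>
#include <vector>

struct AppendPlanEntry {
    uint64_t blockIdx;   // existing or newly allocated block receiving tuples
    uint64_t numTuples;  // how many fixed-width tuples are written into it
};

std::vector<AppendPlanEntry> planAppend(uint64_t numTuplesToAppend, uint64_t numBytesPerTuple,
    uint64_t lastBlockFreeBytes, uint64_t tuplesPerBlock, uint64_t numExistingBlocks) {
    std::vector<AppendPlanEntry> plan;
    uint64_t remaining = numTuplesToAppend;
    const uint64_t fitInLast = lastBlockFreeBytes / numBytesPerTuple;
    if (numExistingBlocks > 0 && fitInLast > 0 && remaining > 0) {
        const uint64_t n = std::min(fitInLast, remaining);   // top up the last block first
        plan.push_back({numExistingBlocks - 1, n});
        remaining -= n;
    }
    uint64_t nextBlockIdx = numExistingBlocks;
    while (remaining > 0) {                                   // then allocate fresh blocks
        const uint64_t n = std::min(tuplesPerBlock, remaining);
        plan.push_back({nextBlockIdx++, n});
        remaining -= n;
    }
    return plan;
}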
+ void mergeMayContainNulls(FactorizedTable& other); + void merge(FactorizedTable& other); + + common::InMemOverflowBuffer* getInMemOverflowBuffer() const { + return inMemOverflowBuffer.get(); + } + + bool hasUnflatCol() const; + bool hasUnflatCol(std::vector& colIdxes) const { + return std::any_of(colIdxes.begin(), colIdxes.end(), + [this](ft_col_idx_t colIdx) { return !tableSchema.getColumn(colIdx)->isFlat(); }); + } + + uint64_t getNumTuples() const { return numTuples; } + uint64_t getTotalNumFlatTuples() const; + uint64_t getNumFlatTuples(ft_tuple_idx_t tupleIdx) const; + + const std::vector>& getTupleDataBlocks() { + return flatTupleBlockCollection->getBlocks(); + } + const FactorizedTableSchema* getTableSchema() const { return &tableSchema; } + + template + TYPE getData(ft_block_idx_t blockIdx, ft_block_offset_t blockOffset, + ft_col_offset_t colOffset) const { + return *((TYPE*)getCell(blockIdx, blockOffset, colOffset)); + } + + uint8_t* getTuple(ft_tuple_idx_t tupleIdx) const; + + void updateFlatCell(uint8_t* tuplePtr, ft_col_idx_t colIdx, common::ValueVector* valueVector, + uint32_t pos); + void updateFlatCellNoNull(uint8_t* ftTuplePtr, ft_col_idx_t colIdx, void* dataBuf) { + memcpy(ftTuplePtr + tableSchema.getColOffset(colIdx), dataBuf, + tableSchema.getColumn(colIdx)->getNumBytes()); + } + + uint64_t getNumTuplesPerBlock() const { return numFlatTuplesPerBlock; } + + bool hasNoNullGuarantee(ft_col_idx_t colIdx) const { + return tableSchema.getColumn(colIdx)->hasNoNullGuarantee(); + } + + bool isOverflowColNull(const uint8_t* nullBuffer, ft_tuple_idx_t tupleIdx, + ft_col_idx_t colIdx) const; + bool isNonOverflowColNull(const uint8_t* nullBuffer, ft_col_idx_t colIdx) const; + bool isNonOverflowColNull(ft_tuple_idx_t tupleIdx, ft_col_idx_t colIdx) const; + void setNonOverflowColNull(uint8_t* nullBuffer, ft_col_idx_t colIdx); + void clear(); + + storage::MemoryManager* getMemoryManager() { return memoryManager; } + + void resize(uint64_t numTuples); + + template + void forEach(Func func) { + for (auto& tupleBlock : flatTupleBlockCollection->getBlocks()) { + uint8_t* tuple = tupleBlock->getData(); + for (auto i = 0u; i < tupleBlock->numTuples; i++) { + func(tuple); + tuple += getTableSchema()->getNumBytesPerTuple(); + } + } + } + + static std::shared_ptr EmptyTable(storage::MemoryManager* mm) { + return std::make_shared(mm, FactorizedTableSchema()); + } + + void setPreventDestruction(bool preventDestruction) { + this->preventDestruction = preventDestruction; + } + +private: + void setOverflowColNull(uint8_t* nullBuffer, ft_col_idx_t colIdx, ft_tuple_idx_t tupleIdx); + + uint64_t computeNumTuplesToAppend( + const std::vector& vectorsToAppend) const; + + uint8_t* getCell(ft_block_idx_t blockIdx, ft_block_offset_t blockOffset, + ft_col_offset_t colOffset) const { + return flatTupleBlockCollection->getBlock(blockIdx)->getData() + + blockOffset * tableSchema.getNumBytesPerTuple() + colOffset; + } + std::pair getBlockIdxAndTupleIdxInBlock( + uint64_t tupleIdx) const { + return std::make_pair(tupleIdx / numFlatTuplesPerBlock, tupleIdx % numFlatTuplesPerBlock); + } + + std::vector allocateFlatTupleBlocks(uint64_t numTuplesToAppend); + uint8_t* allocateUnflatTupleBlock(uint32_t numBytes); + void copyFlatVectorToFlatColumn(const common::ValueVector& vector, + const BlockAppendingInfo& blockAppendInfo, ft_col_idx_t colIdx); + void copyUnflatVectorToFlatColumn(const common::ValueVector& vector, + const BlockAppendingInfo& blockAppendInfo, uint64_t numAppendedTuples, ft_col_idx_t colIdx); + void 
copyVectorToFlatColumn(const common::ValueVector& vector, + const BlockAppendingInfo& blockAppendInfo, uint64_t numAppendedTuples, + ft_col_idx_t colIdx) { + vector.state->isFlat() ? + copyFlatVectorToFlatColumn(vector, blockAppendInfo, colIdx) : + copyUnflatVectorToFlatColumn(vector, blockAppendInfo, numAppendedTuples, colIdx); + } + void copyVectorToUnflatColumn(const common::ValueVector& vector, + const BlockAppendingInfo& blockAppendInfo, ft_col_idx_t colIdx); + void copyVectorToColumn(const common::ValueVector& vector, + const BlockAppendingInfo& blockAppendInfo, uint64_t numAppendedTuples, ft_col_idx_t colIdx); + common::overflow_value_t appendVectorToUnflatTupleBlocks(const common::ValueVector& vector, + ft_col_idx_t colIdx); + + // TODO(Guodong): Unify these two `readUnflatCol()` with a (possibly templated) copy executor. + void readUnflatCol(uint8_t** tuplesToRead, ft_col_idx_t colIdx, + common::ValueVector& vector) const; + void readUnflatCol(const uint8_t* tupleToRead, const common::SelectionVector& selVector, + ft_col_idx_t colIdx, common::ValueVector& vector) const; + void readFlatColToFlatVector(uint8_t* tupleToRead, ft_col_idx_t colIdx, + common::ValueVector& vector, common::sel_t pos) const; + void readFlatColToUnflatVector(uint8_t** tuplesToRead, ft_col_idx_t colIdx, + common::ValueVector& vector, uint64_t numTuplesToRead) const; + void readFlatCol(uint8_t** tuplesToRead, ft_col_idx_t colIdx, common::ValueVector& vector, + uint64_t numTuplesToRead) const; + +private: + storage::MemoryManager* memoryManager; + // Table Schema. Keeping track of factorization structure. + FactorizedTableSchema tableSchema; + // Number of rows in table. + uint64_t numTuples; + // Radix sort requires there is a fixed number of tuple in a block. + uint64_t flatTupleBlockSize; + uint32_t numFlatTuplesPerBlock; + // Data blocks for flat tuples. + std::unique_ptr flatTupleBlockCollection; + // Data blocks for unFlat tuples. + std::unique_ptr unFlatTupleBlockCollection; + // Overflow buffer storing variable size part of an entry. + std::unique_ptr inMemOverflowBuffer; + // Prevent destruction of the underlying data structures when the factorized table is + // destructed. If the parent database is closed, the underlying data structures is already + // destructed, so destruction will cause double free. + bool preventDestruction = false; +}; + +class FactorizedTableIterator { +public: + explicit FactorizedTableIterator(FactorizedTable& factorizedTable); + + bool hasNext() { + return nextTupleIdx < factorizedTable.getNumTuples() || nextFlatTupleIdx < numFlatTuples; + } + + void getNext(FlatTuple& tuple); + + void resetState(); + +private: + // The dataChunkPos may be not consecutive, which means some entries in the + // flatTuplePositionsInDataChunk is invalid. We put pair(UINT64_MAX, UINT64_MAX) in the + // invalid entries. + bool isValidDataChunkPos(uint32_t dataChunkPos) const { + return flatTuplePositionsInDataChunk[dataChunkPos].first != UINT64_MAX; + } + + void readUnflatColToFlatTuple(ft_col_idx_t colIdx, uint8_t* valueBuffer, FlatTuple& tuple); + + void readFlatColToFlatTuple(ft_col_idx_t colIdx, uint8_t* valueBuffer, FlatTuple& tuple); + + // We put pair(UINT64_MAX, UINT64_MAX) in all invalid entries in + // FlatTuplePositionsInDataChunk. + void updateInvalidEntriesInFlatTuplePositionsInDataChunk(); + + // This function is used to update the number of elements in the dataChunk when we want + // to iterate a new tuple. 
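// Illustrative aside (not lbug source): the flat-tuple addressing arithmetic behind
// getBlockIdxAndTupleIdxInBlock and getCell above, restated as two standalone helpers.
// A global tuple index splits into (block, offset-in-block); a cell address is the block
// base plus offset * numBytesPerTuple plus the column's byte offset.
#include <cstdint>
#include <utility>

std::pair<uint64_t, uint64_t> splitTupleIdx(uint64_t tupleIdx, uint64_t numFlatTuplesPerBlock) {
    return {tupleIdx / numFlatTuplesPerBlock, tupleIdx % numFlatTuplesPerBlock};
}

const uint8_t* cellAddress(const uint8_t* blockData, uint64_t tupleOffsetInBlock,
    uint64_t numBytesPerTuple, uint64_t colByteOffset) {
    return blockData + tupleOffsetInBlock * numBytesPerTuple + colByteOffset;
}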
+ void updateNumElementsInDataChunk(); + + // This function updates the flatTuplePositionsInDataChunk, so that getNextFlatTuple() can + // correctly outputs the next flat tuple in the current tuple. For example, we want to read + // two unFlat columns, which are on different dataChunks A,B and both have 100 columns. The + // flatTuplePositionsInDataChunk after the first call to getNextFlatTuple() looks like: + // {dataChunkA : [0, 100]}, {dataChunkB : [0, 100]} This function updates the + // flatTuplePositionsInDataChunk to: {dataChunkA: [1, 100]}, {dataChunkB: [0, 100]}. Meaning + // that the getNextFlatTuple() should read the second element in the first unflat column and + // the first element in the second unflat column. + void updateFlatTuplePositionsInDataChunk(); + + const FactorizedTable& factorizedTable; + uint8_t* currentTupleBuffer; + uint64_t numFlatTuples; + ft_tuple_idx_t nextFlatTupleIdx; + ft_tuple_idx_t nextTupleIdx; + // This field stores the (nextIdxToReadInDataChunk, numElementsInDataChunk) of each dataChunk. + std::vector> flatTuplePositionsInDataChunk; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/result/factorized_table_pool.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/result/factorized_table_pool.h new file mode 100644 index 0000000000..3e05db6dff --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/result/factorized_table_pool.h @@ -0,0 +1,36 @@ +#pragma once + +#include +#include + +#include "processor/result/factorized_table.h" + +namespace lbug { +namespace processor { + +// We implement a local ftable pool to avoid generate many small ftables when running GDS. +// Alternative solutions are directly writing to global ftable with partition so conflict is +// minimized. Or we optimize ftable to be more memory efficient when number of tuples is small. +class LBUG_API FactorizedTablePool { +public: + explicit FactorizedTablePool(std::shared_ptr globalTable) + : globalTable{std::move(globalTable)} {} + DELETE_COPY_AND_MOVE(FactorizedTablePool); + + FactorizedTable* claimLocalTable(storage::MemoryManager* mm); + + void returnLocalTable(FactorizedTable* table); + + void mergeLocalTables(); + + std::shared_ptr getGlobalTable() const { return globalTable; } + +private: + std::mutex mtx; + std::shared_ptr globalTable; + std::stack availableLocalTables; + std::vector> localTables; +}; + +} // namespace processor +} // namespace lbug \ No newline at end of file diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/result/factorized_table_schema.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/result/factorized_table_schema.h new file mode 100644 index 0000000000..8909da5ed5 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/result/factorized_table_schema.h @@ -0,0 +1,97 @@ +#pragma once + +#include "common/assert.h" +#include "common/copy_constructors.h" +#include "common/types/types.h" + +namespace lbug { +namespace processor { + +// TODO(Guodong/Ziyi): Move these typedef to common and unify them with the ones without `ft_`. 
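// Illustrative aside (not lbug source): the "odometer" update described in the comment above
// for flatTuplePositionsInDataChunk, as a standalone function. Each entry holds
// (nextIdxToRead, numElements); advancing bumps the first chunk's cursor and carries into
// later chunks on wrap-around, lazily enumerating the cartesian product of unflat positions.
// The UINT64_MAX sentinel handling for invalid entries is omitted here.
#include <cstdint>
#include <utility>
#include <vector>

bool advanceFlatTuplePositions(std::vector<std::pair<uint64_t, uint64_t>>& positions) {
    for (auto& [nextIdx, numElements] : positions) {
        if (++nextIdx < numElements) {
            return true;   // no carry needed; the next combination is ready
        }
        nextIdx = 0;       // wrap this chunk and carry into the next one
    }
    return false;          // wrapped past the last chunk: all combinations produced
}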
+typedef uint64_t ft_tuple_idx_t; +typedef uint32_t ft_col_idx_t; +typedef uint32_t ft_col_offset_t; +typedef uint32_t ft_block_idx_t; +typedef uint32_t ft_block_offset_t; + +class ColumnSchema { +public: + ColumnSchema(bool isUnFlat, common::idx_t groupID, uint32_t numBytes) + : isUnFlat{isUnFlat}, groupID{groupID}, numBytes{numBytes}, mayContainNulls{false} {} + EXPLICIT_COPY_DEFAULT_MOVE(ColumnSchema); + + bool isFlat() const { return !isUnFlat; } + + common::idx_t getGroupID() const { return groupID; } + + uint32_t getNumBytes() const { return numBytes; } + + bool operator==(const ColumnSchema& other) const { + return isUnFlat == other.isUnFlat && groupID == other.groupID && numBytes == other.numBytes; + } + bool operator!=(const ColumnSchema& other) const { return !(*this == other); } + + void setMayContainsNullsToTrue() { mayContainNulls = true; } + + bool hasNoNullGuarantee() const { return !mayContainNulls; } + +private: + ColumnSchema(const ColumnSchema& other); + +private: + // This following two information can alternatively be maintained at table schema + // level as a column group information. + // Whether column is unFlat. + bool isUnFlat; + // Group id. + common::idx_t groupID; + // Num bytes of the column. + uint32_t numBytes; + // Whether column may contain nulls. + // If this field is true, the column can still be all non-null. + bool mayContainNulls; +}; + +class LBUG_API FactorizedTableSchema { +public: + FactorizedTableSchema() = default; + EXPLICIT_COPY_DEFAULT_MOVE(FactorizedTableSchema); + + void appendColumn(ColumnSchema column); + + const ColumnSchema* getColumn(ft_col_idx_t idx) const { return &columns[idx]; } + + uint32_t getNumColumns() const { return columns.size(); } + + ft_col_offset_t getNullMapOffset() const { return numBytesForDataPerTuple; } + + uint32_t getNumBytesPerTuple() const { return numBytesPerTuple; } + + ft_col_offset_t getColOffset(ft_col_idx_t idx) const { return colOffsets[idx]; } + + void setMayContainsNullsToTrue(ft_col_idx_t idx) { + KU_ASSERT(idx < columns.size()); + columns[idx].setMayContainsNullsToTrue(); + } + + bool isEmpty() const { return columns.empty(); } + + bool operator==(const FactorizedTableSchema& other) const; + bool operator!=(const FactorizedTableSchema& other) const { return !(*this == other); } + + uint64_t getNumFlatColumns() const; + uint64_t getNumUnFlatColumns() const; + +private: + FactorizedTableSchema(const FactorizedTableSchema& other); + +private: + std::vector columns; + uint32_t numBytesForDataPerTuple = 0; + uint32_t numBytesForNullMapPerTuple = 0; + uint32_t numBytesPerTuple = 0; + std::vector colOffsets; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/result/factorized_table_util.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/result/factorized_table_util.h new file mode 100644 index 0000000000..e6e5b02b57 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/result/factorized_table_util.h @@ -0,0 +1,29 @@ +#pragma once + +#include "factorized_table.h" +#include "planner/operator/schema.h" + +namespace lbug { +namespace processor { + +class FactorizedTableUtils { +public: + static FactorizedTableSchema createFTableSchema(const binder::expression_vector& exprs, + const planner::Schema& schema); + static FactorizedTableSchema createFlatTableSchema( + std::vector columnTypes); + + // TODO(Ziyi): These two functions are used to store the copy message in a factorizedTable + // because the current 
QueryProcessor::execute requires the last operator in the physical plan + // must be ResultCollector. We should remove this class after we remove the assumption that the + // last operator in the pipeline must be resultCollector. + static void appendStringToTable(FactorizedTable* factorizedTable, const std::string& outputMsg, + storage::MemoryManager* memoryManager); + static std::shared_ptr getFactorizedTableForOutputMsg( + const std::string& outputMsg, storage::MemoryManager* memoryManager); + static LBUG_API std::shared_ptr getSingleStringColumnFTable( + storage::MemoryManager* mm); +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/result/flat_tuple.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/result/flat_tuple.h new file mode 100644 index 0000000000..a3e6a378a0 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/result/flat_tuple.h @@ -0,0 +1,68 @@ +#pragma once + +#include +#include +#include + +#include "common/api.h" +#include "common/types/value/value.h" + +namespace lbug { +namespace processor { + +/** + * @brief Stores a vector of Values. + */ +class FlatTuple { +public: + explicit FlatTuple(const std::vector& types); + + DELETE_COPY_AND_MOVE(FlatTuple); + + /** + * @return number of values in the FlatTuple. + */ + LBUG_API common::idx_t len() const; + /** + * @brief Get a pointer to the value at the specified index. + * @param idx The index of the value to retrieve. + * @return A pointer to the Value at the specified index. + */ + LBUG_API common::Value* getValue(common::idx_t idx); + + /** + * @brief Access the value at the specified index by reference. + * @param idx The index of the value to access. + * @return A reference to the Value at the specified index. + */ + LBUG_API common::Value& operator[](common::idx_t idx); + + /** + * @brief Access the value at the specified index by const reference. + * @param idx The index of the value to access. + * @return A const reference to the Value at the specified index. + */ + LBUG_API const common::Value& operator[](common::idx_t idx) const; + + /** + * @brief Convert the FlatTuple to a string representation. + * @return A string representation of all values in the FlatTuple. + */ + LBUG_API std::string toString() const; + + /** + * @param colsWidth The length of each column + * @param delimiter The delimiter to separate each value. + * @param maxWidth The maximum length of each column. Only the first maxWidth number of + * characters of each column will be displayed. + * @return all values in string format. 
+ */ + LBUG_API std::string toString(const std::vector& colsWidth, + const std::string& delimiter = "|", uint32_t maxWidth = -1); + +private: + std::vector values; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/result/pattern_creation_info_table.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/result/pattern_creation_info_table.h new file mode 100644 index 0000000000..90259bc6c0 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/result/pattern_creation_info_table.h @@ -0,0 +1,37 @@ +#pragma once + +#include "processor/operator/aggregate/aggregate_hash_table.h" + +namespace lbug { +namespace processor { + +struct PatternCreationInfo { + uint8_t* tuple; + bool hasCreated; + + common::nodeID_t getPatternID(common::executor_id_t matchExecutorID) const { + auto ftColIndex = matchExecutorID; + return *(common::nodeID_t*)(tuple + ftColIndex * sizeof(common::nodeID_t)); + } + + void updateID(common::executor_id_t executorID, common::executor_info executorInfo, + common::nodeID_t nodeID) const; +}; + +class PatternCreationInfoTable : public AggregateHashTable { +public: + PatternCreationInfoTable(storage::MemoryManager& memoryManager, + std::vector keyTypes, FactorizedTableSchema tableSchema); + + PatternCreationInfo getPatternCreationInfo(const std::vector& keyVectors); + + uint64_t matchFTEntries(std::span keyVectors, + uint64_t numMayMatches, uint64_t numNoMatches) override; + +private: + uint8_t* tuple; + ft_col_offset_t idColOffset; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/result/result_set.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/result/result_set.h new file mode 100644 index 0000000000..b89512bfc8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/result/result_set.h @@ -0,0 +1,45 @@ +#pragma once + +#include + +#include "common/data_chunk/data_chunk.h" +#include "processor/data_pos.h" +#include "result_set_descriptor.h" + +namespace lbug { +namespace processor { + +class ResultSet { +public: + ResultSet() : ResultSet(0) {} + explicit ResultSet(common::idx_t numDataChunks) : multiplicity{1}, dataChunks(numDataChunks) {} + ResultSet(ResultSetDescriptor* resultSetDescriptor, storage::MemoryManager* memoryManager); + + void insert(common::idx_t pos, std::shared_ptr dataChunk) { + KU_ASSERT(dataChunks.size() > pos); + dataChunks[pos] = std::move(dataChunk); + } + + std::shared_ptr getDataChunk(data_chunk_pos_t dataChunkPos) { + return dataChunks[dataChunkPos]; + } + std::shared_ptr getValueVector(const DataPos& dataPos) const { + return dataChunks[dataPos.dataChunkPos]->valueVectors[dataPos.valueVectorPos]; + } + + // Our projection does NOT explicitly remove dataChunk from resultSet. Therefore, caller should + // always provide a set of positions when reading from multiple dataChunks. 
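// Illustrative aside (not lbug source, semantics assumed): in a factorized result set the
// logical tuple count is taken here to be the product of the selected sizes of the data
// chunks in scope, scaled by the result-set multiplicity, matching the shape of
// getNumTuples / getNumTuplesWithoutMultiplicity declared just below.
#include <cstdint>
#include <unordered_set>
#include <vector>

uint64_t numTuplesSketch(const std::vector<uint64_t>& chunkSelectedSizes,
    const std::unordered_set<uint32_t>& dataChunksPosInScope, uint64_t multiplicity) {
    uint64_t withoutMultiplicity = 1;
    for (uint32_t pos : dataChunksPosInScope) {
        withoutMultiplicity *= chunkSelectedSizes[pos];   // factorized chunks multiply out
    }
    return withoutMultiplicity * multiplicity;
}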
+ uint64_t getNumTuples(const std::unordered_set& dataChunksPosInScope) { + return getNumTuplesWithoutMultiplicity(dataChunksPosInScope) * multiplicity; + } + + uint64_t getNumTuplesWithoutMultiplicity( + const std::unordered_set& dataChunksPosInScope); + +public: + uint64_t multiplicity; + std::vector> dataChunks; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/result/result_set_descriptor.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/result/result_set_descriptor.h new file mode 100644 index 0000000000..0256327ba4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/result/result_set_descriptor.h @@ -0,0 +1,44 @@ +#pragma once + +#include "common/types/types.h" + +namespace lbug { +namespace planner { +class Schema; +} // namespace planner + +namespace processor { + +struct DataChunkDescriptor { + bool isSingleState; + std::vector logicalTypes; + + explicit DataChunkDescriptor(bool isSingleState) : isSingleState{isSingleState} {} + DataChunkDescriptor(const DataChunkDescriptor& other) + : isSingleState{other.isSingleState}, + logicalTypes(common::LogicalType::copy(other.logicalTypes)) {} + + inline std::unique_ptr copy() const { + return std::make_unique(*this); + } +}; + +struct LBUG_API ResultSetDescriptor { + std::vector> dataChunkDescriptors; + + ResultSetDescriptor() = default; + explicit ResultSetDescriptor( + std::vector> dataChunkDescriptors) + : dataChunkDescriptors{std::move(dataChunkDescriptors)} {} + explicit ResultSetDescriptor(planner::Schema* schema); + DELETE_BOTH_COPY(ResultSetDescriptor); + + std::unique_ptr copy() const; + + static std::unique_ptr EmptyDescriptor() { + return std::make_unique(); + } +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/warning_context.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/warning_context.h new file mode 100644 index 0000000000..dbac96ab51 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/processor/warning_context.h @@ -0,0 +1,68 @@ +#pragma once + +#include +#include +#include + +#include "common/api.h" +#include "common/types/types.h" +#include "processor/operator/persistent/reader/copy_from_error.h" + +namespace lbug { +namespace common { +class ValueVector; +} +namespace storage { +class ColumnChunkData; +} +namespace main { +struct ClientConfig; +} +namespace processor { + +class SerialCSVReader; + +struct WarningInfo { + uint64_t queryID; + PopulatedCopyFromError warning; + + WarningInfo(PopulatedCopyFromError warning, uint64_t queryID) + : queryID(queryID), warning(std::move(warning)) {} +}; + +using populate_func_t = std::function; +using get_file_idx_func_t = std::function; + +class LBUG_API WarningContext { +public: + explicit WarningContext(main::ClientConfig* clientConfig); + + void appendWarningMessages(const std::vector& messages); + + void populateWarnings(uint64_t queryID, populate_func_t populateFunc = {}, + get_file_idx_func_t getFileIdxFunc = {}); + void defaultPopulateAllWarnings(uint64_t queryID); + + const std::vector& getPopulatedWarnings() const; + uint64_t getWarningCount(uint64_t queryID); + void clearPopulatedWarnings(); + + void setIgnoreErrorsForCurrentQuery(bool ignoreErrors); + // NOTE: this function only works if the logical operator is COPY FROM + // for other operators setIgnoreErrorsForCurrentQuery() is not called + bool getIgnoreErrorsOption() const; + + static WarningContext* Get(const 
main::ClientContext& context); + +private: + std::mutex mtx; + main::ClientConfig* clientConfig; + std::vector unpopulatedWarnings; + std::vector populatedWarnings; + uint64_t queryWarningCount; + uint64_t numStoredWarnings; + bool ignoreErrorsOption; +}; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/buffer_manager/buffer_manager.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/buffer_manager/buffer_manager.h new file mode 100644 index 0000000000..5d51ab0491 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/buffer_manager/buffer_manager.h @@ -0,0 +1,298 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "common/types/types.h" +#include "storage/buffer_manager/memory_manager.h" +#include "storage/buffer_manager/page_state.h" +#include "storage/enums/page_read_policy.h" +#include "storage/file_handle.h" + +namespace lbug { +namespace main { +struct DBConfig; +}; +namespace common { +class VirtualFileSystem; +}; +namespace testing { +class FlakyBufferManager; +class BufferManagerTest; +class CopyTestHelper; +}; // namespace testing +namespace storage { +class ChunkedNodeGroup; +class Spiller; + +// This class keeps state info for pages potentially can be evicted. +// The page state of a candidate is set to be MARKED when it is first enqueued. After enqueued, if +// the candidate was recently accessed, it is no longer immediately evictable. See the state +// transition diagram above `BufferManager` class declaration for more details. +struct EvictionCandidate { + friend class EvictionQueue; + + // If the candidate is MARKED, it is evictable. + static bool isEvictable(uint64_t currPageStateAndVersion) { + return PageState::getState(currPageStateAndVersion) == PageState::MARKED; + } + // If the candidate was recently read optimistically, it is second chance evictable. + static bool isSecondChanceEvictable(uint64_t currPageStateAndVersion) { + return PageState::getState(currPageStateAndVersion) == PageState::UNLOCKED; + } + + bool operator==(const EvictionCandidate& other) const { + return fileIdx == other.fileIdx && pageIdx == other.pageIdx; + } + + // Returns false if the candidate was not empty, or if another thread set the value first + bool set(const EvictionCandidate& newCandidate); + + uint32_t fileIdx = UINT32_MAX; + common::page_idx_t pageIdx = common::INVALID_PAGE_IDX; +}; + +// A circular buffer queue storing eviction candidates +// One candidate should be stored for each page currently in memory +class EvictionQueue { +public: + static constexpr auto EMPTY = EvictionCandidate{UINT32_MAX, common::INVALID_PAGE_IDX}; + static constexpr size_t BATCH_SIZE = 64; + explicit EvictionQueue(uint64_t capacity) + // Capacity needs to be a multiple of the batch size + : insertCursor{0}, evictionCursor{0}, size{0}, + capacity{capacity + (BATCH_SIZE - capacity % BATCH_SIZE)}, + data{std::make_unique[]>(this->capacity)} {} + + bool insert(uint32_t fileIndex, common::page_idx_t pageIndex); + + // Produces the next non-empty candidate to be tried for eviction. + // Note that it is still possible (though unlikely) for another thread to evict this candidate, + // so it is not guaranteed to be empty. 
+ // The PageState should be locked, and then the atomic checked against the version used when + // locking the page state to make sure there wasn't a data race + std::span, BATCH_SIZE> next(); + void clear(std::atomic& candidate); + + uint64_t getSize() const { return size; } + uint64_t getEvictionCursor() const { return evictionCursor; } + uint64_t getCapacity() const { return capacity; } + +private: + std::atomic insertCursor; + std::atomic evictionCursor; + std::atomic size; + const uint64_t capacity; + std::unique_ptr[]> data; +}; + +/** + * The Buffer Manager (BM) is a centralized manager of database memory resources. + * It provides two main functionalities: + * 1) it provides the high-level functionality to pin() and unpin() the pages of the database files + * used by storage structures, such as the Column, Lists, or HashIndex in the storage layer, and + * operates via their FileHandle to read/write the page data into/out of one of the frames. + * 2) it provides optimistic read of pages, which optimistically read unlocked or marked pages + * without acquiring locks. + * 3) it supports the MemoryManager (MM) to allocate memory buffers that are not + * backed by any disk files. Similar to disk files, MM provides in-mem file handles to the BM to + * pin/unpin pages. Pin happens when MM tries to allocate a new memory buffer, and unpin happens + * when MM tries to reclaim a memory buffer. + * + * Specifically, in BM's context, page is the basic management unit of data in a file. The file can + * be a disk file, such as a column file, or an in-mem file, such as an temp in-memory file kept by + * the MM. Frame is the basic management unit of data resides in a VMRegion, namely in a virtual + * memory space. Each page is uniquely mapped to a frame, and it can be cached into or evicted from + * the frame. See `VMRegion` for more details. + * + * When users unpin their pages, the BM might spill them to disk. The behavior of what is guaranteed + * to be kept in frame and what can be spilled to disk is directly determined by the pin/unpin + * calls of the users. + * + * Also, BM provides some specialized functionalities for WAL files: + * 1) it supports the caller to set pinned pages as dirty, which will be safely written back to disk + * when the pages are evicted; + * 2) it supports the caller to flush or remove pages from the BM; + * 3) it supports the caller to directly update the content of a frame. + * + * All accesses to the BM are through a FileHandle. This design is to decentralize the management + * of page states from the BM to each file handle itself. Thus, each on-disk file should have a + * unique FileHandle, and MM also holds a unique FileHandle, which is backed by a temp in-mem + * file, for all memory buffer allocations + * + * To start a Database, users need to specify the max size of the memory usage (`maxSize`) in BM. + * If users don't specify the value, the system will set maxSize to available physical mem * + * DEFAULT_PHY_MEM_SIZE_RATIO_FOR_BM (defined in constants.h). + * The BM relies on virtual memory regions mapped through `mmap` to anonymous address spaces. + * 1) For disk pages, BM allocates a virtual memory region of DEFAULT_VM_REGION_MAX_SIZE (defined in + * constants.h), which is usually much larger than `maxSize`, and is expected to be large enough to + * contain all disk pages. Each disk page in database files is directly mapped to a unique + * PAGE_SIZE frame in the region. 
+ * 2) For each FileHandle backed by a temp in-mem file in MM, BM allocates a virtual memory region + * of `maxSize` for it. Each memory buffer is mapped to a unique TEMP_PAGE_SIZE frame in that + * region. Both disk pages and memory buffers are all managed by the BM to make sure that actually + * used physical memory doesn't go beyond max size specified by users. Currently, the BM uses a + * queue based replacement policy and the MADV_DONTNEED hint to explicitly control evictions. See + * comments above `claimAFrame()` for more details. + * + * Page states in BM: + * A page can be in one of the four states: a) LOCKED, b) UNLOCKED, c) MARKED, d) EVICTED. + * Every page is initialized as in the EVICTED state. + * The state transition diagram of page X is as follows (oRead refers to optimisticRead): + * Note: optimistic reads on UNLOCKED pages don't make any changes to pages' states. oRead on + * UNLOCKED is omitted in the diagram. + * + * 7.2. pin(pY): evict pX. 7.1. pin(pY): tryLock(pX) + * |<-------------------------|<------------------------------------------------------------| + * | | 4. pin(pX) | + * | |<------------------------------------------------------------| + * | 1. pin(pX) | 5. pin(pX) 6. pin(pY): 2nd chance eviction | + * EVICTED ------------------> LOCKED <-------------UNLOCKED ------------------------------> MARKED + * | | 3. oRead(pX) | + * | <--------------------------------------| + * | 2. unpin(pX): enqueue pX & increment version | + * -------------------------------------------------------------> + * + * 1. When page pX at EVICTED state, and it is pinned, it transits to the Locked state. `pin` will + * first acquire the exclusive lock on the page, then read the page from the disk into its frame. + * The caller can safely make changes to the page. + * 2. When the caller finishes changes to the page, it calls `unpin`, which releases the lock on the + * page, puts the page into the eviction queue, and increments its version. The page now transits to + * the MARKED state. Note that currently the page is still cached, but it is ready to be evicted. + * The page version number is used to identify any potential writing on the page. Each time a page + * transits from LOCKED to MARKED state, we will increment its version. This happens when a page is + * pinned, then unpinned. During the pin and unpin, we assume the page's content in its + * corresponding frame might have changed, thus, we increment the version number to forbid stale + * reads on it; + * 3. The MARKED page can be optimistically read by the caller, setting the page's state to + * UNLOCKED. For evicted pages, optimistic reads will trigger pin and unpin to read pages from disk + * into frames. + * 4. The MARKED page can be pinned again by the caller, setting the page's state to LOCKED. + * 5. The UNLOCKED page can also be pinned again by the caller, setting the page's state to LOCKED. + * 6. During eviction, UNLOCKED pages will be checked if they are second chance evictable. If so, + * they will be set to MARKED, and their eviction candidates will be moved back to the eviction + * queue. + * 7. During eviction, if the page is in the MARKED state, it will be LOCKED first (7.1), then + * removed from its frame, and set to EVICTED (7.2). + * + * The design is inspired by vmcache in the paper "Virtual-Memory Assisted Buffer Management" + * (https://www.cs.cit.tum.de/fileadmin/w00cfj/dis/_my_direct_uploads/vmcache.pdf). 
+ * We would also like to thank Fadhil Abubaker for doing the initial research and prototyping of + * Umbra's design in his CS 848 course project: + * https://github.com/fabubaker/lbug/blob/umbra-bm/final_project_report.pdf. + */ +class BufferManager { + friend class testing::FlakyBufferManager; + friend class testing::BufferManagerTest; + friend class testing::CopyTestHelper; + + friend class FileHandle; + friend class MemoryManager; + +public: + BufferManager(const std::string& databasePath, const std::string& spillToDiskPath, + uint64_t bufferPoolSize, uint64_t maxDBSize, common::VirtualFileSystem* vfs, bool readOnly); + virtual ~BufferManager(); + + // Currently, these functions are specifically used only for WAL files. + void removeFilePagesFromFrames(FileHandle& fileHandle); + void updateFrameIfPageIsInFrameWithoutLock(common::file_idx_t fileIdx, const uint8_t* newPage, + common::page_idx_t pageIdx); + + // For files that are managed by BM, their FileHandles should be created through this function. + FileHandle* getFileHandle(const std::string& filePath, uint8_t flags, + common::VirtualFileSystem* vfs, main::ClientContext* context) { + fileHandles.emplace_back( + std::make_unique(filePath, flags, this, fileHandles.size(), vfs, context)); + return fileHandles.back().get(); + } + + uint64_t getMemoryLimit() const { return bufferPoolSize; } + uint64_t getUsedMemory() const { return usedMemory; } + + void getSpillerOrSkip(std::function func) { + if (spiller) { + return func(*spiller); + } + } + + void resetSpiller(std::string spillPath); + + // This function only works when run in a single-threaded context + // Iterates through the eviction queue and removes any elements that have already been evicted + // (due to some external intervention) + void removeEvictedCandidates(); + +protected: + // Reclaims used memory until the given size to reserve is available. + // The specified amount of memory will be recorded as being used + virtual bool reserve(uint64_t sizeToReserve); + +private: + uint8_t* pin(FileHandle& fileHandle, common::page_idx_t pageIdx, + PageReadPolicy pageReadPolicy = PageReadPolicy::READ_PAGE); + void optimisticRead(FileHandle& fileHandle, common::page_idx_t pageIdx, + const std::function& func); + // The function assumes that the requested page is already pinned. + void unpin(FileHandle& fileHandle, common::page_idx_t pageIdx); + uint8_t* getFrame(FileHandle& fileHandle, common::page_idx_t pageIdx) const { +#if BM_MALLOC + return fileHandle.getPageState(pageIdx)->getPage(); +#else + return vmRegions[fileHandle.getPageSizeClass()]->getFrame(fileHandle.getFrameIdx(pageIdx)); +#endif + } + common::frame_group_idx_t addNewFrameGroup( + common::PageSizeClass pageSizeClass [[maybe_unused]]) { +#if BM_MALLOC + return 0; +#else + return vmRegions[pageSizeClass]->addNewFrameGroup(); +#endif + } + void removePageFromFrameIfNecessary(FileHandle& fileHandle, common::page_idx_t pageIdx); + + static void verifySizeParams(uint64_t bufferPoolSize, uint64_t maxDBSize); + + bool claimAFrame(FileHandle& fileHandle, common::page_idx_t pageIdx, + PageReadPolicy pageReadPolicy); + // Return number of bytes freed. 
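// Illustrative aside (not the lbug implementation): a schematic, standalone version-validated
// optimistic read in the spirit of optimisticRead() declared above. Read the page's combined
// state/version word, read the frame, and accept the result only if the word is unchanged
// afterwards; otherwise retry. Handling of EVICTED pages (pin + unpin) and of the MARKED ->
// UNLOCKED transition is deliberately left out.
#include <atomic>
#include <cstdint>

template <typename ReadFn>
void optimisticReadSketch(std::atomic<uint64_t>& stateAndVersion, const uint8_t* frame,
    ReadFn&& readFn) {
    constexpr uint64_t LOCKED = 1;
    while (true) {
        const uint64_t before = stateAndVersion.load();
        if ((before >> 56) == LOCKED) {
            continue;                        // a writer holds the page; spin in this sketch
        }
        readFn(frame);                       // optimistic read of the frame contents
        if (stateAndVersion.load() == before) {
            return;                          // state and version unchanged: read was consistent
        }
        // the page was pinned/unpinned (version bumped) while reading; retry
    }
}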
+ uint64_t tryEvictPage(std::atomic& candidate); + + void cachePageIntoFrame(FileHandle& fileHandle, common::page_idx_t pageIdx, + PageReadPolicy pageReadPolicy); + void removePageFromFrame(FileHandle& fileHandle, common::page_idx_t pageIdx, bool shouldFlush); + + uint64_t freeUsedMemory(uint64_t size); + + void releaseFrameForPage(FileHandle& fileHandle [[maybe_unused]], + common::page_idx_t pageIdx [[maybe_unused]]) { +#if BM_MALLOC + // Page is freed instead in PageState::resetToEvicted +#else + vmRegions[fileHandle.getPageSizeClass()]->releaseFrame(fileHandle.getFrameIdx(pageIdx)); +#endif + } + + uint64_t evictPages(); + +private: + std::atomic bufferPoolSize; + EvictionQueue evictionQueue; + // Total memory used + std::atomic usedMemory; + // Amount of memory used, which cannot be evicted + std::atomic nonEvictableMemory; + // Each VMRegion corresponds to a virtual memory region of a specific page size. Currently, we + // hold two sizes of REGULAR_PAGE and TEMP_PAGE. + std::array, 2> vmRegions; + std::vector> fileHandles; + std::unique_ptr spiller; + common::VirtualFileSystem* vfs; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/buffer_manager/memory_manager.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/buffer_manager/memory_manager.h new file mode 100644 index 0000000000..3915e899d4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/buffer_manager/memory_manager.h @@ -0,0 +1,110 @@ +#pragma once + +#include +#include +#include +#include + +#include "common/system_config.h" +#include "common/types/types.h" +#include "storage/buffer_manager/spill_result.h" +#include + +namespace lbug { + +namespace common { +class VirtualFileSystem; +} + +namespace storage { + +class MemoryManager; +class FileHandle; +class BufferManager; +class ChunkedNodeGroup; +template +class MmAllocator; + +class MemoryBuffer { + friend class Spiller; + +public: + LBUG_API MemoryBuffer(MemoryManager* mm, common::page_idx_t blockIdx, uint8_t* buffer, + uint64_t size = common::TEMP_PAGE_SIZE); + LBUG_API ~MemoryBuffer(); + DELETE_COPY_AND_MOVE(MemoryBuffer); + + std::span getBuffer() const { + KU_ASSERT(!evicted); + return buffer; + } + uint8_t* getData() const { return getBuffer().data(); } + + MemoryManager* getMemoryManager() const { return mm; } + + // Manually set the evicted state of the buffer to avoid double free. + void preventDestruction() { evicted = true; } + +private: + // Can be called multiple times safely + void prepareLoadFromDisk(); + + // Must only be called once before loading from disk + SpillResult setSpilledToDisk(uint64_t filePosition); + +private: + std::span buffer; + uint64_t filePosition = UINT64_MAX; + MemoryManager* mm; + common::page_idx_t pageIdx; + bool evicted; +}; + +/* + * The Memory Manager (MM) is used for allocating/reclaiming intermediate memory blocks. + * It can allocate a memory buffer of size PAGE_256KB from the buffer manager backed by a + * BMFileHandle with temp in-mem file. + * + * The MemoryManager holds a BMFileHandle backed by + * a temp in-mem file, and is responsible for allocating/reclaiming memory buffers of its size class + * from the buffer manager. The MemoryManager keeps track of free pages in the BMFileHandle, so + * that it can reuse those freed pages without allocating new pages. The MemoryManager is + * thread-safe, so that multiple threads can allocate/reclaim memory blocks with the same size class + * at the same time. 
+ * + * MM will return a MemoryBuffer to the caller, which is a wrapper of the allocated memory block, + * and it will automatically call its allocator to reclaim the memory block when it is destroyed. + */ +class LBUG_API MemoryManager { + friend class MemoryBuffer; + template + friend class MmAllocator; + +public: + MemoryManager(BufferManager* bm, common::VirtualFileSystem* vfs); + + ~MemoryManager() = default; + + std::unique_ptr allocateBuffer(bool initializeToZero = false, + uint64_t size = common::TEMP_PAGE_SIZE); + common::page_offset_t getPageSize() const { return pageSize; } + + BufferManager* getBufferManager() const { return bm; } + + static MemoryManager* Get(const main::ClientContext& context); + +private: + void freeBlock(common::page_idx_t pageIdx, std::span buffer); + void updateUsedMemoryForFreedBlock(common::page_idx_t pageIdx, std::span buffer); + std::span mallocBuffer(bool initializeToZero, uint64_t size); + +private: + FileHandle* fh; + BufferManager* bm; + common::page_offset_t pageSize; + std::stack freePages; + std::mutex allocatorLock; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/buffer_manager/mm_allocator.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/buffer_manager/mm_allocator.h new file mode 100644 index 0000000000..570c2f7579 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/buffer_manager/mm_allocator.h @@ -0,0 +1,54 @@ +#pragma once + +#include "memory_manager.h" + +namespace lbug { +namespace storage { + +template +class MmAllocator { +public: + using value_type = T; + + explicit MmAllocator(MemoryManager* mm) : mm{mm} {} + + MmAllocator(const MmAllocator& other) : mm{other.mm} {} + MmAllocator& operator=(const MmAllocator& other) = default; + DELETE_BOTH_MOVE(MmAllocator); + + [[nodiscard]] T* allocate(const std::size_t size) { + KU_ASSERT_UNCONDITIONAL(mm != nullptr); + KU_ASSERT_UNCONDITIONAL(size > 0); + KU_ASSERT_UNCONDITIONAL(size <= std::numeric_limits::max() / sizeof(T)); + + auto buffer = mm->mallocBuffer(false, size * sizeof(T)); + auto p = reinterpret_cast(buffer.data()); + + // Ensure proper alignment + KU_ASSERT_UNCONDITIONAL(reinterpret_cast(p) % alignof(T) == 0); + + return p; + } + + void deallocate(T* p, const std::size_t size) noexcept { + KU_ASSERT_UNCONDITIONAL(mm != nullptr); + KU_ASSERT_UNCONDITIONAL(p != nullptr); + KU_ASSERT_UNCONDITIONAL(size > 0); + + const auto buffer = std::span(reinterpret_cast(p), size * sizeof(T)); + if (buffer.data() != nullptr) { + mm->freeBlock(common::INVALID_PAGE_IDX, buffer); + } + } + +private: + MemoryManager* mm; +}; + +template +bool operator==(const MmAllocator& a, const MmAllocator& b) { + return a.mm == b.mm; +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/buffer_manager/page_state.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/buffer_manager/page_state.h new file mode 100644 index 0000000000..8d504b0eb0 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/buffer_manager/page_state.h @@ -0,0 +1,117 @@ +#pragma once + +#include + +#include "common/assert.h" + +// Alternative variant of the buffer manager which doesn't rely on MADV_DONTNEED (on Unix) for +// evicting pages (which is unavailable in Webassembly runtimes) +#if BM_MALLOC +#include +#endif + +namespace lbug { +namespace storage { + +// Keeps the state information of a page in a file. 
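// Illustrative aside (not lbug source): a condensed restatement of the state/version packing
// used by the PageState class declared just below. A single 64-bit word keeps the state
// (UNLOCKED/LOCKED/MARKED/EVICTED) in its top byte and a 56-bit version in the low bits; the
// version is bumped on every pin/unpin cycle so optimistic readers can detect changes. The
// dirty bit is ignored in this sketch.
#include <cstdint>

namespace page_state_sketch {
constexpr uint64_t STATE_SHIFT = 56;
constexpr uint64_t VERSION_MASK = 0x00FFFFFFFFFFFFFFULL;

constexpr uint64_t getState(uint64_t word) { return word >> STATE_SHIFT; }
constexpr uint64_t getVersion(uint64_t word) { return word & VERSION_MASK; }
// Same version, new state (e.g. when a pin takes the page lock).
constexpr uint64_t withState(uint64_t word, uint64_t state) {
    return (word & VERSION_MASK) | (state << STATE_SHIFT);
}
// New state and incremented version (e.g. when unpinning after a possible write).
constexpr uint64_t withStateBumpVersion(uint64_t word, uint64_t state) {
    return ((word & VERSION_MASK) + 1) | (state << STATE_SHIFT);
}
} // namespace page_state_sketch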
+class PageState { + static constexpr uint64_t DIRTY_MASK = 0x0080000000000000; + static constexpr uint64_t STATE_MASK = 0xFF00000000000000; + static constexpr uint64_t VERSION_MASK = 0x00FFFFFFFFFFFFFF; + static constexpr uint64_t NUM_BITS_TO_SHIFT_FOR_STATE = 56; + +public: + static constexpr uint64_t UNLOCKED = 0; + static constexpr uint64_t LOCKED = 1; + static constexpr uint64_t MARKED = 2; + static constexpr uint64_t EVICTED = 3; + + PageState() { stateAndVersion.store(EVICTED << NUM_BITS_TO_SHIFT_FOR_STATE); } + + uint64_t getState() const { return getState(stateAndVersion.load()); } + static uint64_t getState(uint64_t stateAndVersion) { + return (stateAndVersion & STATE_MASK) >> NUM_BITS_TO_SHIFT_FOR_STATE; + } + static uint64_t getVersion(uint64_t stateAndVersion) { return stateAndVersion & VERSION_MASK; } + static uint64_t updateStateWithSameVersion(uint64_t oldStateAndVersion, uint64_t newState) { + return ((oldStateAndVersion << 8) >> 8) | (newState << NUM_BITS_TO_SHIFT_FOR_STATE); + } + static uint64_t updateStateAndIncrementVersion(uint64_t oldStateAndVersion, uint64_t newState) { + return (((oldStateAndVersion << 8) >> 8) + 1) | (newState << NUM_BITS_TO_SHIFT_FOR_STATE); + } + void spinLock(uint64_t oldStateAndVersion) { + while (true) { + if (tryLock(oldStateAndVersion)) { + return; + } + } + } + bool tryLock(uint64_t oldStateAndVersion) { + return stateAndVersion.compare_exchange_strong(oldStateAndVersion, + updateStateWithSameVersion(oldStateAndVersion, LOCKED)); + } + void unlock() { + // TODO(Keenan / Guodong): Track down this rare bug and re-enable the assert. Ref #2289. + // KU_ASSERT(getState(stateAndVersion.load()) == LOCKED); + stateAndVersion.store(updateStateAndIncrementVersion(stateAndVersion.load(), UNLOCKED)); + } + void unlockUnchanged() { + // TODO(Keenan / Guodong): Track down this rare bug and re-enable the assert. Ref #2289. + // KU_ASSERT(getState(stateAndVersion.load()) == LOCKED); + stateAndVersion.store(updateStateWithSameVersion(stateAndVersion.load(), UNLOCKED)); + } + // Change page state from Mark to Unlocked. + bool tryClearMark(uint64_t oldStateAndVersion) { + KU_ASSERT(getState(oldStateAndVersion) == MARKED); + return stateAndVersion.compare_exchange_strong(oldStateAndVersion, + updateStateWithSameVersion(oldStateAndVersion, UNLOCKED)); + } + bool tryMark(uint64_t oldStateAndVersion) { + return stateAndVersion.compare_exchange_strong(oldStateAndVersion, + updateStateWithSameVersion(oldStateAndVersion, MARKED)); + } + + void setDirty() { + KU_ASSERT(getState(stateAndVersion.load()) == LOCKED); + stateAndVersion |= DIRTY_MASK; + } + void clearDirty() { + KU_ASSERT(getState(stateAndVersion.load()) == LOCKED); + stateAndVersion &= ~DIRTY_MASK; + } + // Meant to be used when flushing in a single thread. 
+ // Should not be used if other threads are modifying the page state + void clearDirtyWithoutLock() { stateAndVersion &= ~DIRTY_MASK; } + bool isDirty() const { return stateAndVersion & DIRTY_MASK; } + uint64_t getStateAndVersion() const { return stateAndVersion.load(); } + + void resetToEvicted() { + stateAndVersion.store(EVICTED << NUM_BITS_TO_SHIFT_FOR_STATE); +#if BM_MALLOC + page.reset(); +#endif + } + +#if BM_MALLOC + uint8_t* getPage() const { return page.get(); } + uint8_t* allocatePage(uint64_t pageSize) { + page = std::make_unique(pageSize); + return page.get(); + } + uint16_t getReaderCount() const { return readerCount; } + void addReader() { readerCount++; } + void removeReader() { readerCount--; } +#endif + +private: + // Highest 1 bit is dirty bit, and the rest are page state and version bits. + // In the rest bits, the lowest 1 byte is state, and the rest are version. + std::atomic stateAndVersion; +#if BM_MALLOC + std::unique_ptr page; + std::atomic readerCount; +#endif +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/buffer_manager/spill_result.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/buffer_manager/spill_result.h new file mode 100644 index 0000000000..45907925cf --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/buffer_manager/spill_result.h @@ -0,0 +1,20 @@ +#pragma once + +#include + +namespace lbug { +namespace storage { + +struct SpillResult { + uint64_t memoryFreed = 0; + uint64_t memoryNowEvictable = 0; + + SpillResult& operator+=(const SpillResult& other) { + memoryFreed += other.memoryFreed; + memoryNowEvictable += other.memoryNowEvictable; + return *this; + } +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/buffer_manager/spiller.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/buffer_manager/spiller.h new file mode 100644 index 0000000000..fe04968546 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/buffer_manager/spiller.h @@ -0,0 +1,47 @@ +#pragma once + +#include "storage/buffer_manager/memory_manager.h" +#include "storage/file_handle.h" + +namespace lbug { +namespace common { +class VirtualFileSystem; +}; +namespace storage { +class InMemChunkedNodeGroup; + +class BufferManager; +class ColumnChunkData; + +// This should only be used with a LocalFileSystem +class Spiller { +public: + Spiller(std::string tmpFilePath, BufferManager& bufferManager, common::VirtualFileSystem* vfs); + void addUnusedChunk(InMemChunkedNodeGroup* nodeGroup); + void clearUnusedChunk(InMemChunkedNodeGroup* nodeGroup); + SpillResult spillToDisk(ColumnChunkData& chunk) const; + void loadFromDisk(ColumnChunkData& chunk) const; + // reclaims memory from the next full partitioner group in the set + // and returns the amount of memory reclaimed + // If the set is empty, returns zero + SpillResult claimNextGroup(); + // Must only be used once all chunks have been loaded from disk. 
+ void clearFile(); + ~Spiller(); + +private: + FileHandle* getOrCreateDataFH() const; + FileHandle* getDataFH() const; + +private: + std::string tmpFilePath; + BufferManager& bufferManager; + common::VirtualFileSystem* vfs; + std::unordered_set fullPartitionerGroups; + std::atomic dataFH; + std::mutex partitionerGroupsMtx; + mutable std::mutex fileCreationMutex; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/buffer_manager/vm_region.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/buffer_manager/vm_region.h new file mode 100644 index 0000000000..1bf70837a7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/buffer_manager/vm_region.h @@ -0,0 +1,50 @@ +#pragma once + +#include + +#include "common/constants.h" +#include "common/types/types.h" + +namespace lbug { +namespace storage { + +// A VMRegion holds a virtual memory region of a certain size allocated through mmap. +// The region is divided into frame groups, each of which is a group of frames of the same size. +// Each FileHandle should grab a frame group each time when they add a new file page group (see +// `FileHandle::addNewPageGroupWithoutLock`). In this way, each file page group uniquely +// corresponds to a frame group, thus, a page also uniquely corresponds to a frame in a VMRegion. +class VMRegion { + friend class BufferManager; + +public: + explicit VMRegion(common::PageSizeClass pageSizeClass, uint64_t maxRegionSize); + ~VMRegion(); + + common::frame_group_idx_t addNewFrameGroup(); + + // Use `MADV_DONTNEED` to release physical memory associated with this frame. + void releaseFrame(common::frame_idx_t frameIdx) const; + + // Returns true if the memory address is within the reserved virtual memory region + bool contains(const uint8_t* address) const { + return address >= region && address < region + getMaxRegionSize(); + } + inline uint8_t* getFrame(common::frame_idx_t frameIdx) const { + return region + (static_cast(frameIdx) * frameSize); + } + +private: + inline uint64_t getMaxRegionSize() const { + return maxNumFrameGroups * frameSize * common::StorageConstants::PAGE_GROUP_SIZE; + } + +private: + std::mutex mtx; + uint8_t* region; + uint32_t frameSize; + uint64_t numFrameGroups; + uint64_t maxNumFrameGroups; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/checkpointer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/checkpointer.h new file mode 100644 index 0000000000..1497ccb07b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/checkpointer.h @@ -0,0 +1,62 @@ +#pragma once + +#include "storage/database_header.h" +#include "storage/page_range.h" + +namespace lbug { +namespace transaction { +class Transaction; +} +namespace catalog { +class Catalog; +} +namespace common { +class VirtualFileSystem; +} // namespace common +namespace testing { +struct FSMLeakChecker; +} +namespace main { +class AttachedLbugDatabase; +} // namespace main + +namespace storage { +class StorageManager; + +class Checkpointer { + friend class main::AttachedLbugDatabase; + friend struct testing::FSMLeakChecker; + +public: + explicit Checkpointer(main::ClientContext& clientContext); + virtual ~Checkpointer(); + + void writeCheckpoint(); + void rollback(); + + void readCheckpoint(); + + static bool canAutoCheckpoint(const main::ClientContext& clientContext, + const transaction::Transaction& transaction); + +protected: + virtual bool checkpointStorage(); + 
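// Illustrative aside (POSIX-only, not lbug source): the mechanism behind the VMRegion header
// earlier in this hunk -- reserve a large anonymous mapping up front, hand out fixed-size
// frames by pointer arithmetic, and release a frame's physical memory with MADV_DONTNEED
// while keeping its virtual address reserved. Error handling is omitted; wasm runtimes cannot
// use this path, which is why the BM_MALLOC fallback exists.
#include <cstdint>
#include <sys/mman.h>

struct VMRegionSketch {
    uint8_t* region = nullptr;
    uint64_t frameSize = 0;

    uint8_t* getFrame(uint64_t frameIdx) const { return region + frameIdx * frameSize; }
    void releaseFrame(uint64_t frameIdx) const {
        madvise(getFrame(frameIdx), frameSize, MADV_DONTNEED);   // drop physical pages only
    }
};

VMRegionSketch reserveRegion(uint64_t frameSize, uint64_t numFrames) {
    void* p = mmap(nullptr, frameSize * numFrames, PROT_READ | PROT_WRITE,
        MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);                     // virtual reservation only
    return {static_cast<uint8_t*>(p), frameSize};
}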
virtual void serializeCatalogAndMetadata(DatabaseHeader& databaseHeader, + bool hasStorageChanges); + virtual void writeDatabaseHeader(const DatabaseHeader& header); + virtual void logCheckpointAndApplyShadowPages(); + +private: + static void readCheckpoint(main::ClientContext* context, catalog::Catalog* catalog, + StorageManager* storageManager); + + PageRange serializeCatalog(const catalog::Catalog& catalog, StorageManager& storageManager); + PageRange serializeMetadata(const catalog::Catalog& catalog, StorageManager& storageManager); + +protected: + main::ClientContext& clientContext; + bool isInMemory; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/compression/bitpacking_int128.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/compression/bitpacking_int128.h new file mode 100644 index 0000000000..a21f8b7b1f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/compression/bitpacking_int128.h @@ -0,0 +1,17 @@ +// Adapted from +// https://github.com/duckdb/duckdb/blob/main/src/include/duckdb/common/bitpacking.hpp + +#pragma once + +#include "common/types/int128_t.h" + +namespace lbug::storage { + +struct Int128Packer { + static void pack(const common::int128_t* __restrict in, uint32_t* __restrict out, + uint8_t width); + static void unpack(const uint32_t* __restrict in, common::int128_t* __restrict out, + uint8_t width); +}; + +} // namespace lbug::storage diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/compression/bitpacking_utils.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/compression/bitpacking_utils.h new file mode 100644 index 0000000000..f7671de291 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/compression/bitpacking_utils.h @@ -0,0 +1,23 @@ +// Adapted from +// https://github.com/duckdb/duckdb/blob/main/src/include/duckdb/common/bitpacking.hpp + +#pragma once + +#include "storage/compression/compression.h" + +namespace lbug::storage { + +template +struct BitpackingUtils { + using CompressedType = + std::conditional_t= sizeof(uint32_t), uint32_t, uint8_t>; + static constexpr size_t sizeOfCompressedTypeBits = sizeof(CompressedType) * 8; + + static void unpackSingle(const uint8_t* __restrict src, UncompressedType* __restrict dst, + uint16_t bitWidth, size_t srcOffset); + + static void packSingle(const UncompressedType src, uint8_t* __restrict dstCursor, + uint16_t bitWidth, size_t dstOffset); +}; + +} // namespace lbug::storage diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/compression/compression.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/compression/compression.h new file mode 100644 index 0000000000..0319afc228 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/compression/compression.h @@ -0,0 +1,498 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "alp/state.hpp" +#include "common/assert.h" +#include "common/null_mask.h" +#include "common/numeric_utils.h" +#include "common/types/types.h" +#include + +namespace lbug { +namespace common { +class ValueVector; +class NullMask; +} // namespace common + +namespace storage { +class ColumnChunkData; + +struct PageCursor; + +template +concept StorageValueType = (common::numeric_utils::IsIntegral || std::floating_point); +// Type storing values in the column chunk statistics +// Only supports integers (up to 128bit), floats and bools +union StorageValue { + int64_t signedInt; + uint64_t unsignedInt; + double 
floatVal; + common::int128_t signedInt128; + + StorageValue() = default; + template + requires std::same_as, common::int128_t> + explicit StorageValue(T value) : signedInt128(value) {} + + template + requires std::integral && std::numeric_limits::is_signed + // zero-initialize union padding + explicit StorageValue(T value) : StorageValue(common::int128_t(0)) { + signedInt = value; + } + + template + requires std::integral && (!std::numeric_limits::is_signed) + explicit StorageValue(T value) : StorageValue(common::int128_t(0)) { + unsignedInt = value; + } + + template + requires std::is_floating_point_v + explicit StorageValue(T value) : StorageValue(common::int128_t(0)) { + floatVal = value; + } + + bool operator==(const StorageValue& other) const { + // We zero-initialize any padding bits, so we can compare values to check equality + return this->signedInt128 == other.signedInt128; + } + + template + StorageValue& operator=(const T& val) { + return *this = StorageValue(val); + } + + template + T get() const { + if constexpr (std::same_as, common::int128_t>) { + return signedInt128; + } else if constexpr (std::integral) { + if constexpr (std::numeric_limits::is_signed) { + return static_cast(signedInt); + } else { + return static_cast(unsignedInt); + } + } else if constexpr (std::is_floating_point()) { + return floatVal; + } else { + KU_UNREACHABLE; + } + } + + bool gt(const StorageValue& other, common::PhysicalTypeID type) const; + + // If the type cannot be stored in the statistics, readFromVector will return nullopt + static std::optional readFromVector(const common::ValueVector& vector, + common::offset_t posInVector); +}; +static_assert(std::is_trivial_v); + +std::pair, std::optional> getMinMaxStorageValue( + const ColumnChunkData& data, uint64_t offset, uint64_t numValues, + common::PhysicalTypeID physicalType, bool valueRequiredIfUnsupported = false); + +// Expects bools to be one bool per bit (like ColumnChunkData, not like ValueVector) +std::pair, std::optional> getMinMaxStorageValue( + const uint8_t* data, uint64_t offset, uint64_t numValues, common::PhysicalTypeID physicalType, + const common::NullMask* nullMask, bool valueRequiredIfUnsupported = false); + +std::pair, std::optional> getMinMaxStorageValue( + const common::ValueVector& data, uint64_t offset, uint64_t numValues, + common::PhysicalTypeID physicalType, bool valueRequiredIfUnsupported = false); + +// Returns the size of the data type in bytes +uint32_t getDataTypeSizeInChunk(const common::LogicalType& dataType); +uint32_t getDataTypeSizeInChunk(const common::PhysicalTypeID& dataType); + +// Compression type is written to the data header both so we can usually catch issues when we +// decompress uncompressed data by mistake, and to allow for runtime-configurable compression. 
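A short usage sketch for the StorageValue union above; the values and the include path are illustrative and assume the header is reachable as storage/compression/compression.h. The point it demonstrates is that every constructor zero-initializes the widest member first, so operator== can compare the int128 representation without tripping over padding.

#include <cassert>
#include <cstdint>

#include "storage/compression/compression.h" // assumed include path

inline void storageValueDemo() {
    using lbug::storage::StorageValue;

    // Signed integers are routed into the signedInt member ...
    StorageValue minVal(int32_t{-42});
    StorageValue maxVal(int32_t{1000});
    assert(minVal.get<int32_t>() == -42);
    assert(maxVal.get<int32_t>() == 1000);
    assert(!(minVal == maxVal)); // padding was zeroed, so this is a plain bit compare

    // ... while floating-point values use the dedicated floatVal member.
    StorageValue f(3.5);
    assert(f.get<double>() == 3.5);
}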
+enum class CompressionType : uint8_t { + UNCOMPRESSED = 0, + INTEGER_BITPACKING = 1, + BOOLEAN_BITPACKING = 2, + CONSTANT = 3, + ALP = 4, +}; + +struct ExtraMetadata { + virtual ~ExtraMetadata() = default; + virtual std::unique_ptr copy() = 0; +}; + +// used only for compressing floats/doubles +struct ALPMetadata : ExtraMetadata { + ALPMetadata() : exp(0), fac(0), exceptionCount(0), exceptionCapacity(0) {} + explicit ALPMetadata(const alp::state& alpState, common::PhysicalTypeID physicalType); + + uint8_t exp; + uint8_t fac; + uint32_t exceptionCount; + uint32_t exceptionCapacity; + + void serialize(common::Serializer& serializer) const; + static ALPMetadata deserialize(common::Deserializer& deserializer); + + std::unique_ptr copy() override; +}; + +struct InPlaceUpdateLocalState { + struct FloatState { + size_t newExceptionCount; + } floatState; +}; + +// Data statistics used for determining how to handle compressed data +struct LBUG_API CompressionMetadata { + + // Minimum and maximum are upper and lower bounds for the data. + // Updates and deletions may cause them to no longer be the exact minimums and maximums, + // but no value will be larger than the maximum or smaller than the minimum + StorageValue min; + StorageValue max; + + CompressionType compression; + + std::optional> extraMetadata; + + std::vector children; + + CompressionMetadata(StorageValue min, StorageValue max, CompressionType compression) + : min(min), max(max), compression(compression), extraMetadata() {} + + // constructor for float metadata + CompressionMetadata(StorageValue min, StorageValue max, CompressionType compression, + const alp::state& state, StorageValue minEncoded, StorageValue maxEncoded, + common::PhysicalTypeID physicalType); + + CompressionMetadata(const CompressionMetadata&); + CompressionMetadata& operator=(const CompressionMetadata&); + + static size_t getChildCount(CompressionType compressionType); + + inline bool isConstant() const { return compression == CompressionType::CONSTANT; } + const CompressionMetadata& getChild(common::offset_t idx) const; + + // accessors for additionalMetadata + inline const ExtraMetadata* getExtraMetadata() const { + KU_ASSERT(extraMetadata.has_value()); + return extraMetadata.value().get(); + } + inline ExtraMetadata* getExtraMetadata() { + KU_ASSERT(extraMetadata.has_value()); + return extraMetadata.value().get(); + } + inline const ALPMetadata* floatMetadata() const { + return common::ku_dynamic_cast(getExtraMetadata()); + } + inline ALPMetadata* floatMetadata() { + return common::ku_dynamic_cast(getExtraMetadata()); + } + + void serialize(common::Serializer& serializer) const; + static CompressionMetadata deserialize(common::Deserializer& deserializer); + + // Returns the number of values which will be stored in the given data size + // This must be consistent with the compression implementation for the given size + uint64_t numValues(uint64_t dataSize, common::PhysicalTypeID dataType) const; + uint64_t numValues(uint64_t dataSize, const common::LogicalType& dataType) const; + // Returns true if and only if the provided value within the vector can be updated + // in this chunk in-place. 
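For the integer-bitpacked case, the in-place check described above boils down to a width test: values are stored as unsigned offsets from the chunk minimum, so an update fits in place only while its offset still fits the bit width chosen for [min, max]. The standalone sketch below is a simplified version of that test; it ignores the boolean and ALP cases and the library's exact policy.

#include <bit>
#include <cassert>
#include <cstdint>

namespace sketch {
// Bits needed to represent any offset inside [min, max] (frame-of-reference width).
inline uint8_t bitWidthFor(int64_t min, int64_t max) {
    const uint64_t range = static_cast<uint64_t>(max) - static_cast<uint64_t>(min);
    return static_cast<uint8_t>(std::bit_width(range));
}

// A new value can be written in place while its offset from `min` still fits
// the width the chunk was packed with; anything wider forces a re-compress.
inline bool fitsExistingWidth(int64_t value, int64_t min, int64_t max) {
    if (value < min) {
        return false; // offsets are unsigned distances from min
    }
    const uint64_t offset = static_cast<uint64_t>(value) - static_cast<uint64_t>(min);
    return std::bit_width(offset) <= bitWidthFor(min, max);
}

inline void demo() {
    assert(bitWidthFor(100, 107) == 3);        // offsets 0..7 need 3 bits
    assert(fitsExistingWidth(105, 100, 107));  // offset 5 fits in 3 bits
    assert(!fitsExistingWidth(200, 100, 107)); // offset 100 needs 7 bits
}
} // namespace sketch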
+ bool canUpdateInPlace(const uint8_t* data, uint32_t pos, uint64_t numValues, + common::PhysicalTypeID physicalType, InPlaceUpdateLocalState& localUpdateState, + const std::optional& nullMask = std::nullopt) const; + bool canAlwaysUpdateInPlace() const; + + std::string toString(const common::PhysicalTypeID physicalType) const; +}; + +class CompressionAlg { +public: + virtual ~CompressionAlg() = default; + + // Takes a single uncompressed value from the srcBuffer and compresses it into the dstBuffer + // Offsets refer to value offsets, not byte offsets + // + // nullMask may be null if no mask is available (all values are non-null) + // Storage of null values is handled by the implementation and decompression of null values + // does not have to produce the original value passed to this function. + virtual void setValuesFromUncompressed(const uint8_t* srcBuffer, common::offset_t srcOffset, + uint8_t* dstBuffer, common::offset_t dstOffset, common::offset_t numValues, + const CompressionMetadata& metadata, const common::NullMask* nullMask) const = 0; + + // Takes uncompressed data from the srcBuffer and compresses it into the dstBuffer + // + // stores only as much data in dstBuffer as will fit, and advances the srcBuffer pointer + // to the beginning of the next value to store. + // (This means that we can't start the next page on an unaligned value. + // Maybe instead we could use value offsets, but the compression algorithms + // usually work on aligned chunks anyway) + // + // dstBufferSize is the size in bytes + // numValuesRemaining is the number of values remaining in the srcBuffer to be compressed. + // compressNextPage must store the least of either the number of values per page + // (as calculated by CompressionMetadata::numValues), or the remaining number of values. 
+ // + // returns the size in bytes of the compressed data within the page (rounded up to the nearest + // byte) + virtual uint64_t compressNextPage(const uint8_t*& srcBuffer, uint64_t numValuesRemaining, + uint8_t* dstBuffer, uint64_t dstBufferSize, + const struct CompressionMetadata& metadata) const = 0; + + // Takes compressed data from the srcBuffer and decompresses it into the dstBuffer + // Offsets refer to value offsets, not byte offsets + // srcBuffer points to the beginning of a page + virtual void decompressFromPage(const uint8_t* srcBuffer, uint64_t srcOffset, + uint8_t* dstBuffer, uint64_t dstOffset, uint64_t numValues, + const CompressionMetadata& metadata) const = 0; + + virtual CompressionType getCompressionType() const = 0; +}; + +class ConstantCompression final : public CompressionAlg { +public: + explicit ConstantCompression(const common::LogicalType& logicalType) + : numBytesPerValue{static_cast(getDataTypeSizeInChunk(logicalType))}, + dataType{logicalType.getPhysicalType()} {} + static std::optional analyze(const ColumnChunkData& chunk); + + // Shouldn't be used, there's a special case when compressing which ends early for constant + // compression + uint64_t compressNextPage(const uint8_t*&, uint64_t, uint8_t*, uint64_t, + const struct CompressionMetadata&) const override { + return 0; + }; + + static void decompressValues(uint8_t* dstBuffer, uint64_t dstOffset, uint64_t numValues, + common::PhysicalTypeID physicalType, uint32_t numBytesPerValue, + const CompressionMetadata& metadata); + + void decompressFromPage(const uint8_t* /*srcBuffer*/, uint64_t /*srcOffset*/, + uint8_t* dstBuffer, uint64_t dstOffset, uint64_t numValues, + const CompressionMetadata& metadata) const override; + + void copyFromPage(const uint8_t* /*srcBuffer*/, uint64_t /*srcOffset*/, uint8_t* dstBuffer, + uint64_t dstOffset, uint64_t numValues, const CompressionMetadata& metadata) const; + + // Nothing to do; constant compressed data is only updated if the update is to the same value + void setValuesFromUncompressed(const uint8_t*, common::offset_t, uint8_t*, common::offset_t, + common::offset_t, const CompressionMetadata&, + const common::NullMask* /*nullMask*/) const override {} + + CompressionType getCompressionType() const override { return CompressionType::CONSTANT; } + +private: + uint8_t numBytesPerValue; + common::PhysicalTypeID dataType; +}; + +// Compression alg which does not compress values and instead just copies them. 
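ConstantCompression above is the degenerate case: the chunk stores no per-value data at all, and decompression just replicates one constant (carried in the chunk metadata) into the destination buffer. Below is a standalone sketch of that fill, with the constant passed in directly rather than read from CompressionMetadata.

#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

namespace sketch {
// Replicate a single constant into dst[dstOffset .. dstOffset + numValues).
inline void decompressConstant(const void* constantValue, uint32_t numBytesPerValue,
    uint8_t* dst, uint64_t dstOffset, uint64_t numValues) {
    for (uint64_t i = 0; i < numValues; i++) {
        std::memcpy(dst + (dstOffset + i) * numBytesPerValue, constantValue, numBytesPerValue);
    }
}

inline void demo() {
    const int32_t constant = 7;
    std::vector<int32_t> out(4, 0);
    decompressConstant(&constant, sizeof(constant),
        reinterpret_cast<uint8_t*>(out.data()), /*dstOffset=*/1, /*numValues=*/3);
    assert(out[0] == 0 && out[1] == 7 && out[2] == 7 && out[3] == 7);
}
} // namespace sketch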
+class Uncompressed : public CompressionAlg { +public: + explicit Uncompressed(common::PhysicalTypeID physicalType) + : numBytesPerValue{getDataTypeSizeInChunk(physicalType)} {} + explicit Uncompressed(const common::LogicalType& logicalType) + : Uncompressed(logicalType.getPhysicalType()) {} + explicit Uncompressed(uint8_t numBytesPerValue) : numBytesPerValue{numBytesPerValue} {} + + Uncompressed(const Uncompressed&) = default; + + inline void setValuesFromUncompressed(const uint8_t* srcBuffer, common::offset_t srcOffset, + uint8_t* dstBuffer, common::offset_t dstOffset, common::offset_t numValues, + const CompressionMetadata& /*metadata*/, const common::NullMask* /*nullMask*/) const final { + memcpy(dstBuffer + dstOffset * numBytesPerValue, srcBuffer + srcOffset * numBytesPerValue, + numBytesPerValue * numValues); + } + + static uint64_t numValues(uint64_t dataSize, common::PhysicalTypeID physicalType); + static uint64_t numValues(uint64_t dataSize, const common::LogicalType& logicalType); + + inline uint64_t compressNextPage(const uint8_t*& srcBuffer, uint64_t numValuesRemaining, + uint8_t* dstBuffer, uint64_t dstBufferSize, + const struct CompressionMetadata& /*metadata*/) const override { + if (numBytesPerValue == 0) { + return 0; + } + uint64_t numValues = std::min(numValuesRemaining, dstBufferSize / numBytesPerValue); + uint64_t sizeToCopy = numValues * numBytesPerValue; + KU_ASSERT(sizeToCopy <= dstBufferSize); + std::memcpy(dstBuffer, srcBuffer, sizeToCopy); + srcBuffer += sizeToCopy; + return sizeToCopy; + } + + inline void decompressFromPage(const uint8_t* srcBuffer, uint64_t srcOffset, uint8_t* dstBuffer, + uint64_t dstOffset, uint64_t numValues, + const CompressionMetadata& /*metadata*/) const override { + std::memcpy(dstBuffer + dstOffset * numBytesPerValue, + srcBuffer + srcOffset * numBytesPerValue, numValues * numBytesPerValue); + } + + CompressionType getCompressionType() const override { return CompressionType::UNCOMPRESSED; } + +protected: + const uint32_t numBytesPerValue; +}; + +template +struct BitpackInfo { + uint8_t bitWidth; + bool hasNegative; + T offset; +}; + +template +concept IntegerBitpackingType = (common::numeric_utils::IsIntegral && !std::same_as); + +// Augmented with Frame of Reference encoding using an offset stored in the compression metadata +template +class IntegerBitpacking : public CompressionAlg { + using U = common::numeric_utils::MakeUnSignedT; + +public: + // This is an implementation detail of the fastpfor bitpacking algorithm + static constexpr uint64_t CHUNK_SIZE = 32; + +public: + IntegerBitpacking() = default; + IntegerBitpacking(const IntegerBitpacking&) = default; + + void setValuesFromUncompressed(const uint8_t* srcBuffer, common::offset_t srcOffset, + uint8_t* dstBuffer, common::offset_t dstOffset, common::offset_t numValues, + const CompressionMetadata& metadata, const common::NullMask* nullMask) const final; + + static BitpackInfo getPackingInfo(const CompressionMetadata& metadata); + + static inline uint64_t numValues(uint64_t dataSize, const BitpackInfo& info) { + if (info.bitWidth == 0) { + return UINT64_MAX; + } + auto numValues = dataSize * 8 / info.bitWidth; + return numValues; + } + + static inline uint64_t numValues(uint64_t dataSize, const CompressionMetadata& metadata) { + auto info = getPackingInfo(metadata); + return numValues(dataSize, info); + } + + uint64_t compressNextPage(const uint8_t*& srcBuffer, uint64_t numValuesRemaining, + uint8_t* dstBuffer, uint64_t dstBufferSize, + const struct CompressionMetadata& metadata) 
const final; + + void decompressFromPage(const uint8_t* srcBuffer, uint64_t srcOffset, uint8_t* dstBuffer, + uint64_t dstOffset, uint64_t numValues, + const struct CompressionMetadata& metadata) const final; + + static bool canUpdateInPlace(std::span value, const CompressionMetadata& metadata, + const std::optional& nullMask = std::nullopt, + uint64_t nullMaskOffset = 0); + + CompressionType getCompressionType() const override { + return CompressionType::INTEGER_BITPACKING; + } + +protected: + // Read multiple values from within a chunk. Cannot span multiple chunks. + void getValues(const uint8_t* chunkStart, uint8_t pos, uint8_t* dst, uint8_t numValuesToRead, + const BitpackInfo& header) const; + + inline const uint8_t* getChunkStart(const uint8_t* buffer, uint64_t pos, + uint8_t bitWidth) const { + // Order of operations is important so that pos is rounded down to a multiple of + // CHUNK_SIZE + return buffer + (pos / CHUNK_SIZE) * bitWidth * CHUNK_SIZE / 8; + } + + void packPartialChunk(const U* srcBuffer, uint8_t* dstBuffer, size_t posInDst, + BitpackInfo info, size_t remainingValues) const; + + void copyValuesToTempChunkWithOffset(const U* srcBuffer, U* tmpBuffer, BitpackInfo info, + size_t numValuesToCopy) const; + + void setPartialChunkInPlace(const uint8_t* srcBuffer, common::offset_t posInSrc, + uint8_t* dstBuffer, common::offset_t posInDst, common::offset_t numValues, + const BitpackInfo& header) const; +}; + +class BooleanBitpacking : public CompressionAlg { +public: + BooleanBitpacking() = default; + BooleanBitpacking(const BooleanBitpacking&) = default; + + void setValuesFromUncompressed(const uint8_t* srcBuffer, common::offset_t srcOffset, + uint8_t* dstBuffer, common::offset_t dstOffset, common::offset_t numValues, + const CompressionMetadata& metadata, const common::NullMask* nullMask) const final; + + static inline uint64_t numValues(uint64_t dataSize) { return dataSize * 8; } + + uint64_t compressNextPage(const uint8_t*& srcBuffer, uint64_t numValuesRemaining, + uint8_t* dstBuffer, uint64_t dstBufferSize, + const struct CompressionMetadata& metadata) const final; + + void decompressFromPage(const uint8_t* srcBuffer, uint64_t srcOffset, uint8_t* dstBuffer, + uint64_t dstOffset, uint64_t numValues, const CompressionMetadata& metadata) const final; + + void copyFromPage(const uint8_t* srcBuffer, uint64_t srcOffset, uint8_t* dstBuffer, + uint64_t dstOffset, uint64_t numValues, const CompressionMetadata& metadata) const; + + CompressionType getCompressionType() const override { + return CompressionType::BOOLEAN_BITPACKING; + } +}; + +class CompressedFunctor { +public: + CompressedFunctor(const CompressedFunctor&) = default; + +protected: + explicit CompressedFunctor(const common::LogicalType& logicalType) + : constant{logicalType}, uncompressed{logicalType}, + physicalType{logicalType.getPhysicalType()} {} + const ConstantCompression constant; + const Uncompressed uncompressed; + const BooleanBitpacking booleanBitpacking; + const common::PhysicalTypeID physicalType; +}; + +class ReadCompressedValuesFromPageToVector : public CompressedFunctor { +public: + explicit ReadCompressedValuesFromPageToVector(const common::LogicalType& logicalType) + : CompressedFunctor(logicalType) {} + ReadCompressedValuesFromPageToVector(const ReadCompressedValuesFromPageToVector&) = default; + + void operator()(const uint8_t* frame, PageCursor& pageCursor, common::ValueVector* resultVector, + uint32_t posInVector, uint64_t numValuesToRead, const CompressionMetadata& metadata); +}; + +class 
ReadCompressedValuesFromPage : public CompressedFunctor { +public: + explicit ReadCompressedValuesFromPage(const common::LogicalType& logicalType) + : CompressedFunctor(logicalType) {} + ReadCompressedValuesFromPage(const ReadCompressedValuesFromPage&) = default; + + void operator()(const uint8_t* frame, PageCursor& pageCursor, uint8_t* result, + uint32_t startPosInResult, uint64_t numValuesToRead, const CompressionMetadata& metadata); +}; + +class WriteCompressedValuesToPage : public CompressedFunctor { +public: + explicit WriteCompressedValuesToPage(const common::LogicalType& logicalType) + : CompressedFunctor(logicalType) {} + WriteCompressedValuesToPage(const WriteCompressedValuesToPage&) = default; + + void operator()(uint8_t* frame, uint16_t posInFrame, const uint8_t* data, + common::offset_t dataOffset, common::offset_t numValues, + const CompressionMetadata& metadata, const common::NullMask* nullMask = nullptr); + + void operator()(uint8_t* frame, uint16_t posInFrame, common::ValueVector* vector, + uint32_t posInVector, common::offset_t numValues, const CompressionMetadata& metadata); +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/compression/float_compression.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/compression/float_compression.h new file mode 100644 index 0000000000..3f8e3f7c3e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/compression/float_compression.h @@ -0,0 +1,95 @@ +#pragma once + +#include + +#include "storage/compression/compression.h" +#include + +namespace lbug { +namespace common { +class ValueVector; +class NullMask; +} // namespace common + +namespace storage { +class ColumnChunkData; + +struct PageCursor; + +template +struct EncodeException { + T value; + uint32_t posInChunk; + + static constexpr size_t sizeInBytes() { return sizeof(value) + sizeof(posInChunk); } + + static size_t numPagesFromExceptions(size_t exceptionCount); + + static size_t exceptionBytesPerPage(); + + bool operator<(const EncodeException& o) const; +}; + +template +struct EncodeExceptionView { + // Used to access ALP exceptions that are stored in buffers + // We don't use the EncodeException struct directly since we don't want to copy struct padding + explicit EncodeExceptionView(std::byte* val) { bytes = val; } + + EncodeException getValue(common::offset_t elementOffset = 0) const; + void setValue(EncodeException exception, common::offset_t elementOffset = 0); + std::byte* bytes; +}; + +template +class FloatCompression final : public CompressionAlg { +public: + using EncodedType = std::conditional_t, int64_t, int32_t>; + static constexpr size_t MAX_EXCEPTION_FACTOR = 4; + +public: + FloatCompression(); + + void setValuesFromUncompressed(const uint8_t* srcBuffer, common::offset_t srcOffset, + uint8_t* dstBuffer, common::offset_t dstOffset, common::offset_t numValues, + const CompressionMetadata& metadata, const common::NullMask* nullMask) const override; + + static uint64_t numValues(uint64_t dataSize, const CompressionMetadata& metadata); + + // this is included to satisfy the CompressionAlg interface but we don't actually use it + uint64_t compressNextPage(const uint8_t*& srcBuffer, uint64_t numValuesRemaining, + uint8_t* dstBuffer, uint64_t dstBufferSize, + const CompressionMetadata& metadata) const override; + + uint64_t compressNextPageWithExceptions(const uint8_t*& srcBuffer, uint64_t srcOffset, + uint64_t numValuesRemaining, uint8_t* dstBuffer, uint64_t dstBufferSize, + 
EncodeExceptionView exceptionBuffer, uint64_t exceptionBufferSize, + uint64_t& exceptionCount, const CompressionMetadata& metadata) const; + + // does not patch exceptions (this is handled by the column reader) + void decompressFromPage(const uint8_t* srcBuffer, uint64_t srcOffset, uint8_t* dstBuffer, + uint64_t dstOffset, uint64_t numValues, const CompressionMetadata& metadata) const override; + + static bool canUpdateInPlace(std::span value, const CompressionMetadata& metadata, + InPlaceUpdateLocalState& localUpdateState, + const std::optional& nullMask = std::nullopt, + uint64_t nullMaskOffset = 0); + + CompressionType getCompressionType() const override { return CompressionType::ALP; } + + static BitpackInfo getBitpackInfo(const CompressionMetadata& metadata); + + // Returns number of pages for storing bitpacked ALP values (excluding pages reserved for + // exceptions) + static common::page_idx_t getNumDataPages(common::page_idx_t numTotalPages, + const CompressionMetadata& compMeta); + +private: + const CompressionAlg& getEncodedFloatBitpacker(const CompressionMetadata& metadata) const; + + ConstantCompression constantEncodedFloatBitpacker; + IntegerBitpacking encodedFloatBitpacker; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/compression/sign_extend.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/compression/sign_extend.h new file mode 100644 index 0000000000..e86eb1dc13 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/compression/sign_extend.h @@ -0,0 +1,56 @@ +#pragma once + +/* Adapted from + * https://github.com/duckdb/duckdb/blob/312b9954507386305544a42c4f43c2bd410a64cb/src/include/duckdb/common/bitpacking.hpp#L190-L199 + * + * Copyright 2018-2023 Stichting DuckDB Foundation + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and + * associated documentation files (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, publish, distribute, + * sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all copies or + * substantial portions of the Software. 
+ */ + +#include + +#include + +#include "common/assert.h" +#include "common/numeric_utils.h" +#include "common/utils.h" + +namespace lbug { +namespace storage { + +template +void Store(const T& val, uint8_t* ptr) { + memcpy(ptr, (void*)&val, sizeof(val)); +} + +template +T Load(const uint8_t* ptr) { + T ret{}; + memcpy(&ret, ptr, sizeof(ret)); + return ret; +} + +// Sign bit extension +template, uint64_t CHUNK_SIZE> +static void SignExtend(uint8_t* dst, uint8_t width) { + KU_ASSERT(width < sizeof(T) * 8); + T const mask = T_U(1) << (width - 1); + for (uint64_t i = 0; i < CHUNK_SIZE; ++i) { + T value = Load(dst + i * sizeof(T)); + const T_U andMask = common::BitmaskUtils::all1sMaskForLeastSignificantBits(width); + value = value & andMask; + T result = (value ^ mask) - mask; + Store(result, dst + i * sizeof(T)); + } +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/database_header.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/database_header.h new file mode 100644 index 0000000000..2df60b0f86 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/database_header.h @@ -0,0 +1,32 @@ +#pragma once + +#include + +#include "common/types/uuid.h" +#include "storage/page_range.h" + +namespace lbug { +namespace storage { +class PageManager; + +struct DatabaseHeader { + PageRange catalogPageRange; + PageRange metadataPageRange; + + // An ID that is unique between lbug databases + // Used to ensure that files such as the WAL match the current database + common::ku_uuid_t databaseID{0}; + + void updateCatalogPageRange(PageManager& pageManager, PageRange newPageRange); + void freeMetadataPageRange(PageManager& pageManager) const; + void serialize(common::Serializer& ser) const; + + static DatabaseHeader deserialize(common::Deserializer& deSer); + static DatabaseHeader createInitialHeader(common::RandomEngine* randomEngine); + + // If we haven't written a header to the database file yet (e.g. 
if the file is empty) this + // function will return a nullopt + static std::optional readDatabaseHeader(common::FileInfo& dataFileInfo); +}; +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/disk_array.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/disk_array.h new file mode 100644 index 0000000000..ece4e68025 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/disk_array.h @@ -0,0 +1,411 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "common/copy_constructors.h" +#include "common/types/types.h" +#include "storage/buffer_manager/memory_manager.h" +#include "storage/shadow_utils.h" +#include "storage/storage_utils.h" +#include "transaction/transaction.h" +#include +#include + +namespace lbug { +namespace storage { +class PageAllocator; +class FileHandle; +class BufferManager; + +static constexpr uint64_t NUM_PAGE_IDXS_PER_PIP = + (common::LBUG_PAGE_SIZE - sizeof(common::page_idx_t)) / sizeof(common::page_idx_t); + +struct DiskArrayHeader { + DiskArrayHeader() : numElements{0}, firstPIPPageIdx{common::INVALID_PAGE_IDX} {} + bool operator==(const DiskArrayHeader& other) const = default; + + uint64_t numElements; + common::page_idx_t firstPIPPageIdx; + uint32_t _padding{}; +}; +static_assert(std::has_unique_object_representations_v); + +/** + * Data for page-based storage helper functions + */ +struct PageStorageInfo { + explicit PageStorageInfo(uint64_t elementSize); + + uint64_t alignedElementSize; + uint64_t numElementsPerPage; +}; + +// TODO(bmwinger): this should use the memoryManager +struct PIP { + PIP() : nextPipPageIdx{ShadowUtils::NULL_PAGE_IDX}, pageIdxs{} { + for (auto& pageIdx : pageIdxs) { + pageIdx = ShadowUtils::NULL_PAGE_IDX; + } + } + + common::page_idx_t nextPipPageIdx; + common::page_idx_t pageIdxs[NUM_PAGE_IDXS_PER_PIP]; +}; +static_assert(sizeof(PIP) == common::LBUG_PAGE_SIZE); + +struct PIPWrapper { + PIPWrapper(const FileHandle& fileHandle, common::page_idx_t pipPageIdx); + + explicit PIPWrapper(common::page_idx_t pipPageIdx) : pipPageIdx(pipPageIdx) {} + + common::page_idx_t pipPageIdx; + PIP pipContents; +}; + +struct PIPUpdates { + // Since PIPs are only appended to, the only existing PIP which may be modified is the last one + // This gets tracked separately to make indexing into newPIPs simpler. + std::optional updatedLastPIP; + std::vector newPIPs; + + void clear() { + updatedLastPIP.reset(); + newPIPs.clear(); + } +}; + +/** + * DiskArray stores a disk-based array in a file. The array is broken down into a predefined and + * stable header page, i.e., the header page of the array is always in a pre-allocated page in the + * file. The header page contains the pointer to the first ``page indices page'' (pip). Each pip + * stores a list of page indices that store the ``array pages''. Each PIP also stores the pageIdx of + * the next PIP if one exists (or we use StorageConstants::NULL_PAGE_IDX as null). Array pages store + * the actual data in the array. + * + * Storage structures can use multiple disk arrays in a single file by giving each one a different + * pre-allocated stable header pageIdxs. + * + * We clarify the following abbreviations and conventions in the variables used in these files: + *
+ *   • pip: Page Indices Page
+ *   • pipIdx: logical index of a PIP in the DiskArray. For example, a variable pipIdx with
+ * value 5 indicates the 5th PIP, not the physical disk pageIdx of where that PIP is stored.
+ *   • pipPageIdx: the physical disk pageIdx of some PIP
+ *   • AP: Array Page
+ *   • apIdx: logical index of an array page in the DiskArray. For example, a variable apIdx with
+ * value 5 indicates the 5th array page of the DiskArray (i.e., the physical pageIdx of this AP
+ * would correspond to the 5th element in the first PIP), not the physical disk pageIdx of an
+ * array page.
+ *   • apPageIdx: the physical disk pageIdx of some AP.
+ */ +class DiskArrayInternal { +public: + // Used when loading from file + DiskArrayInternal(FileHandle& fileHandle, const DiskArrayHeader& headerForReadTrx, + DiskArrayHeader& headerForWriteTrx, ShadowFile* shadowFile, uint64_t elementSize, + bool bypassShadowing = false); + + uint64_t getNumElements( + transaction::TransactionType trxType = transaction::TransactionType::READ_ONLY); + + void get(uint64_t idx, const transaction::Transaction* transaction, std::span val); + + // Note: This function is to be used only by the WRITE trx. + void update(const transaction::Transaction* transaction, uint64_t idx, + std::span val); + + // Note: Currently, this function doesn't support shrinking the size of the array. + uint64_t resize(PageAllocator& pageAllocator, const transaction::Transaction* transaction, + uint64_t newNumElements, std::span defaultVal); + + void checkpointInMemoryIfNecessary() { + std::unique_lock xlock{this->diskArraySharedMtx}; + checkpointOrRollbackInMemoryIfNecessaryNoLock(true /* is checkpoint */); + } + void rollbackInMemoryIfNecessary() { + std::unique_lock xlock{this->diskArraySharedMtx}; + checkpointOrRollbackInMemoryIfNecessaryNoLock(false /* is rollback */); + } + + void checkpoint(); + + void reclaimStorage(PageAllocator& pageAllocator) const; + + // Write WriteIterator for making fast bulk changes to the disk array + // The pages are cached while the elements are stored on the same page + // Designed for sequential writes, but supports random writes too (at the cost that the page + // caching is only beneficial when seeking from one element to another on the same page) + // + // The iterator is not locked, allowing multiple to be used at the same time, but access to + // individual pages is locked through the FileHandle. It will hang if you seek/pushback on the + // same page as another iterator in an overlapping scope. + struct WriteIterator { + DiskArrayInternal& diskArray; + PageCursor apCursor; + uint32_t valueSize; + // TODO(bmwinger): Instead of pinning the page and updating in-place, it might be better to + // read and cache the page, then write the page to the WAL if it's ever modified. However + // when doing bulk hashindex inserts, there's a high likelihood that every page accessed + // will be modified, so it may be faster this way. + ShadowPageAndFrame shadowPageAndFrame; + static const transaction::TransactionType TRX_TYPE = + transaction::TransactionType::CHECKPOINT; + uint64_t idx; + DEFAULT_MOVE_CONSTRUCT(WriteIterator); + + // Constructs WriteIterator in an invalid state. 
Seek must be called before accessing data + WriteIterator(uint32_t valueSize, DiskArrayInternal& diskArray) + : diskArray(diskArray), apCursor(), valueSize(valueSize), + shadowPageAndFrame{common::INVALID_PAGE_IDX, common::INVALID_PAGE_IDX, nullptr}, + idx(0) { + diskArray.hasTransactionalUpdates = true; + } + + WriteIterator& seek(size_t newIdx); + // Adds a new element to the disk array and seeks to the new element + void pushBack(PageAllocator& pageAllocator, const transaction::Transaction* transaction, + std::span val); + + inline WriteIterator& operator+=(size_t increment) { return seek(idx + increment); } + + ~WriteIterator() { unpin(); } + + std::span operator*() const { + KU_ASSERT(idx < diskArray.headerForWriteTrx.numElements); + KU_ASSERT(shadowPageAndFrame.originalPage != common::INVALID_PAGE_IDX); + return std::span(shadowPageAndFrame.frame + apCursor.elemPosInPage, valueSize); + } + + uint64_t size() const { return diskArray.headerForWriteTrx.numElements; } + + private: + void unpin(); + void getPage(common::page_idx_t newPageIdx, bool isNewlyAdded); + }; + + WriteIterator iter_mut(uint64_t valueSize); + + inline common::page_idx_t getAPIdx(uint64_t idx) const; + +protected: + // Updates to new pages (new to this transaction) bypass the wal file. + void updatePage(uint64_t pageIdx, bool isNewPage, std::function updateOp); + + void updateLastPageOnDisk(); + + uint64_t getNumElementsNoLock(transaction::TransactionType trxType) const { + return getDiskArrayHeader(trxType).numElements; + } + + uint64_t getNumAPs(const DiskArrayHeader& header) const { + return (header.numElements + storageInfo.numElementsPerPage - 1) / + storageInfo.numElementsPerPage; + } + + void setNextPIPPageIDxOfPIPNoLock(uint64_t pipIdxOfPreviousPIP, + common::page_idx_t nextPIPPageIdx); + + // This function does division and mod and should not be used in performance critical code. + common::page_idx_t getAPPageIdxNoLock(common::page_idx_t apIdx, + transaction::TransactionType trxType = transaction::TransactionType::READ_ONLY); + + // pipIdx is the idx of the PIP, and not the physical pageIdx. This function assumes + // that the caller has called hasPIPUpdatesNoLock and received true. + common::page_idx_t getUpdatedPageIdxOfPipNoLock(uint64_t pipIdx); + + void clearWALPageVersionAndRemovePageFromFrameIfNecessary(common::page_idx_t pageIdx); + + void checkpointOrRollbackInMemoryIfNecessaryNoLock(bool isCheckpoint); + +private: + bool checkOutOfBoundAccess(transaction::TransactionType trxType, uint64_t idx) const; + bool hasPIPUpdatesNoLock(uint64_t pipIdx) const; + + const DiskArrayHeader& getDiskArrayHeader(transaction::TransactionType trxType) const { + if (trxType == transaction::TransactionType::CHECKPOINT) { + return headerForWriteTrx; + } + return header; + } + + // Returns the apPageIdx of the AP with idx apIdx and a bool indicating whether the apPageIdx is + // a newly inserted page. 
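The pip/AP vocabulary from the DiskArray comment above turns into two divisions: an element index maps to a logical array page, and that AP's physical page is found through the PIP that owns it. The standalone sketch below shows the address math; the per-page and per-PIP capacities used in the demo are illustrative, the real ones come from PageStorageInfo and NUM_PAGE_IDXS_PER_PIP.

#include <cassert>
#include <cstdint>

namespace sketch {
struct DiskArrayAddress {
    uint64_t apIdx;     // logical array page that holds the element
    uint64_t posInPage; // element slot within that AP
    uint64_t pipIdx;    // logical PIP that records the AP's physical pageIdx
    uint64_t posInPIP;  // slot of the AP inside that PIP
};

inline DiskArrayAddress locate(uint64_t elementIdx, uint64_t numElementsPerPage,
    uint64_t numPageIdxsPerPIP) {
    DiskArrayAddress addr{};
    addr.apIdx = elementIdx / numElementsPerPage;
    addr.posInPage = elementIdx % numElementsPerPage;
    addr.pipIdx = addr.apIdx / numPageIdxsPerPIP;
    addr.posInPIP = addr.apIdx % numPageIdxsPerPIP;
    return addr;
}

inline void demo() {
    // e.g. 512 eight-byte elements per 4KB page and 1023 page idxs per PIP
    const auto a = locate(100000, 512, 1023);
    assert(a.apIdx == 195 && a.posInPage == 160);
    assert(a.pipIdx == 0 && a.posInPIP == 195);
}
} // namespace sketch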
+ std::pair getAPPageIdxAndAddAPToPIPIfNecessaryForWriteTrxNoLock( + PageAllocator& pageAllocator, const transaction::Transaction* transaction, + common::page_idx_t apIdx); + +protected: + PageStorageInfo storageInfo; + FileHandle& fileHandle; + const DiskArrayHeader& header; + DiskArrayHeader& headerForWriteTrx; + bool hasTransactionalUpdates; + ShadowFile* shadowFile; + std::vector pips; + PIPUpdates pipUpdates; + std::shared_mutex diskArraySharedMtx; + // For write transactions only + common::page_idx_t lastAPPageIdx; + common::page_idx_t lastPageOnDisk; +}; + +template +inline std::span getSpan(U& val) { + return std::span(reinterpret_cast(&val), sizeof(U)); +} + +template +class DiskArray { + static_assert(sizeof(U) <= common::LBUG_PAGE_SIZE); + +public: + // If bypassWAL is set, the buffer manager is used to pages new to this transaction to the + // original file, but does not handle flushing them. BufferManager::flushAllDirtyPagesInFrames + // should be called on this file handle exactly once during prepare commit. + DiskArray(FileHandle& fileHandle, const DiskArrayHeader& headerForReadTrx, + DiskArrayHeader& headerForWriteTrx, ShadowFile* shadowFile, bool bypassWAL = false) + : diskArray(fileHandle, headerForReadTrx, headerForWriteTrx, shadowFile, sizeof(U), + bypassWAL) {} + + // Note: This function is to be used only by the WRITE trx. + inline void update(const transaction::Transaction* transaction, uint64_t idx, U val) { + diskArray.update(transaction, idx, getSpan(val)); + } + + inline U get(uint64_t idx, const transaction::Transaction* transaction) { + U val; + diskArray.get(idx, transaction, getSpan(val)); + return val; + } + + // Note: Currently, this function doesn't support shrinking the size of the array. + inline uint64_t resize(PageAllocator& pageAllocator, + const transaction::Transaction* transaction, uint64_t newNumElements) { + U defaultVal; + return diskArray.resize(pageAllocator, transaction, newNumElements, getSpan(defaultVal)); + } + + inline uint64_t getNumElements( + transaction::TransactionType trxType = transaction::TransactionType::READ_ONLY) { + return diskArray.getNumElements(trxType); + } + + inline void checkpointInMemoryIfNecessary() { diskArray.checkpointInMemoryIfNecessary(); } + inline void rollbackInMemoryIfNecessary() { diskArray.rollbackInMemoryIfNecessary(); } + inline void checkpoint() { diskArray.checkpoint(); } + inline void reclaimStorage(PageAllocator& pageAllocator) const { + diskArray.reclaimStorage(pageAllocator); + } + + class WriteIterator { + public: + explicit WriteIterator(DiskArrayInternal::WriteIterator&& iter) : iter(std::move(iter)) {} + inline U& operator*() { return *reinterpret_cast((*iter).data()); } + DELETE_COPY_DEFAULT_MOVE(WriteIterator); + + inline WriteIterator& operator+=(size_t dist) { + iter += dist; + return *this; + } + + inline WriteIterator& seek(size_t idx) { + iter.seek(idx); + return *this; + } + + inline uint64_t idx() const { return iter.idx; } + inline uint64_t getAPIdx() const { return iter.apCursor.pageIdx; } + + inline WriteIterator& pushBack(PageAllocator& pageAllocator, + const transaction::Transaction* transaction, U val) { + iter.pushBack(pageAllocator, transaction, getSpan(val)); + return *this; + } + + inline uint64_t size() const { return iter.size(); } + + private: + DiskArrayInternal::WriteIterator iter; + }; + + inline WriteIterator iter_mut() { return WriteIterator{diskArray.iter_mut(sizeof(U))}; } + inline uint64_t getAPIdx(uint64_t idx) const { return diskArray.getAPIdx(idx); } + static 
constexpr uint32_t getAlignedElementSize() { return std::bit_ceil(sizeof(U)); } + +private: + DiskArrayInternal diskArray; +}; + +class BlockVectorInternal { +public: + using element_construct_func_t = std::function; + + explicit BlockVectorInternal(MemoryManager& memoryManager, size_t elementSize) + : storageInfo{elementSize}, numElements{0}, memoryManager{memoryManager} {} + + // This function is designed to be used during building of a disk array, i.e., during loading. + // In particular, it changes the needed capacity non-transactionally, i.e., without writing + // anything to the wal. + void resize(uint64_t newNumElements, const element_construct_func_t& defaultConstructor); + + inline uint64_t size() const { return numElements; } + + // [] operator can be used to update elements, e.g., diskArray[5] = 4, when building an + // InMemDiskArrayBuilder without transactional updates. This changes the contents directly in + // memory and not on disk (nor on the wal). + uint8_t* operator[](uint64_t idx) const; + +private: + inline uint64_t getNumArrayPagesNeededForElements(uint64_t numElements) const { + return (numElements + this->storageInfo.numElementsPerPage - 1) / + this->storageInfo.numElementsPerPage; + } + +protected: + std::vector> inMemArrayPages; + PageStorageInfo storageInfo; + uint64_t numElements; + MemoryManager& memoryManager; +}; + +template +class BlockVector { +public: + explicit BlockVector(MemoryManager& memoryManager, uint64_t numElements = 0) + : vector(memoryManager, sizeof(U)) { + resize(numElements); + } + + ~BlockVector() { + for (uint64_t i = 0; i < size(); ++i) { + operator[](i).~U(); + } + } + + inline U& operator[](uint64_t idx) { return *(U*)vector[idx]; } + + inline void resize(uint64_t newNumElements) { + // NOLINTNEXTLINE(readability-non-const-parameter) placement-new requires non-const ptr + static constexpr auto defaultConstructor = [](uint8_t* data) { + [[maybe_unused]] auto* p = new (data) U(); + KU_ASSERT(p); + }; + vector.resize(newNumElements, defaultConstructor); + } + + inline uint64_t size() const { return vector.size(); } + + static constexpr uint32_t getAlignedElementSize() { + return DiskArray::getAlignedElementSize(); + } + +private: + BlockVectorInternal vector; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/disk_array_collection.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/disk_array_collection.h new file mode 100644 index 0000000000..80f429c4cd --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/disk_array_collection.h @@ -0,0 +1,78 @@ +#pragma once + +#include + +#include "common/types/types.h" +#include "disk_array.h" + +namespace lbug { +namespace storage { + +class FileHandle; + +class DiskArrayCollection { + struct HeaderPage { + explicit HeaderPage(uint32_t numHeaders = 0) + : nextHeaderPage{common::INVALID_PAGE_IDX}, numHeaders{numHeaders} {} + static constexpr size_t NUM_HEADERS_PER_PAGE = + (common::LBUG_PAGE_SIZE - sizeof(common::page_idx_t) - sizeof(uint32_t)) / + sizeof(DiskArrayHeader); + + bool operator==(const HeaderPage&) const = default; + + std::array headers; + common::page_idx_t nextHeaderPage; + uint32_t numHeaders; + }; + static_assert(std::has_unique_object_representations_v); + +public: + DiskArrayCollection(FileHandle& fileHandle, ShadowFile& shadowFile, + bool bypassShadowing = false); + DiskArrayCollection(FileHandle& fileHandle, ShadowFile& shadowFile, + common::page_idx_t firstHeaderPage, bool bypassShadowing = 
false); + + void checkpoint(common::page_idx_t firstHeaderPage, PageAllocator& pageAllocator); + + void checkpointInMemory() { + for (size_t i = 0; i < headersForWriteTrx.size(); i++) { + *headersForReadTrx[i] = *headersForWriteTrx[i]; + } + headerPagesOnDisk = headersForReadTrx.size(); + } + + void rollbackCheckpoint() { + for (size_t i = 0; i < headersForWriteTrx.size(); i++) { + *headersForWriteTrx[i] = *headersForReadTrx[i]; + } + } + + void reclaimStorage(PageAllocator& pageAllocator, common::page_idx_t firstHeaderPage) const; + + template + std::unique_ptr> getDiskArray(uint32_t idx) { + KU_ASSERT(idx < numHeaders); + auto& readHeader = headersForReadTrx[idx / HeaderPage::NUM_HEADERS_PER_PAGE] + ->headers[idx % HeaderPage::NUM_HEADERS_PER_PAGE]; + auto& writeHeader = headersForWriteTrx[idx / HeaderPage::NUM_HEADERS_PER_PAGE] + ->headers[idx % HeaderPage::NUM_HEADERS_PER_PAGE]; + return std::make_unique>(fileHandle, readHeader, writeHeader, &shadowFile, + bypassShadowing); + } + + size_t addDiskArray(); + + void populateNextHeaderPage(PageAllocator& pageAllocator, common::page_idx_t indexInMemory); + +private: + FileHandle& fileHandle; + ShadowFile& shadowFile; + bool bypassShadowing; + common::page_idx_t headerPagesOnDisk; + std::vector> headersForReadTrx; + std::vector> headersForWriteTrx; + uint64_t numHeaders; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/enums/csr_node_group_scan_source.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/enums/csr_node_group_scan_source.h new file mode 100644 index 0000000000..0a2061281f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/enums/csr_node_group_scan_source.h @@ -0,0 +1,12 @@ +#pragma once + +#include + +namespace lbug::storage { +enum class CSRNodeGroupScanSource : uint8_t { + COMMITTED_PERSISTENT = 0, + COMMITTED_IN_MEMORY = 1, + UNCOMMITTED = 2, + NONE = 10 +}; +} // namespace lbug::storage diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/enums/page_read_policy.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/enums/page_read_policy.h new file mode 100644 index 0000000000..222382ca18 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/enums/page_read_policy.h @@ -0,0 +1,11 @@ +#pragma once + +#include + +namespace lbug { +namespace storage { + +enum class PageReadPolicy : uint8_t { READ_PAGE = 0, DONT_READ_PAGE = 1 }; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/enums/residency_state.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/enums/residency_state.h new file mode 100644 index 0000000000..5eaf4dddd4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/enums/residency_state.h @@ -0,0 +1,30 @@ +#pragma once + +#include +#include + +#include "common/assert.h" + +namespace lbug { +namespace storage { + +enum class ResidencyState : uint8_t { IN_MEMORY = 0, ON_DISK = 1 }; + +struct ResidencyStateUtils { + static std::string toString(ResidencyState residencyState) { + switch (residencyState) { + case ResidencyState::IN_MEMORY: { + return "IN_MEMORY"; + } + case ResidencyState::ON_DISK: { + return "ON_DISK"; + } + default: { + KU_UNREACHABLE; + } + } + } +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/file_db_id_utils.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/file_db_id_utils.h new file mode 100644 index 0000000000..4c3d50dfeb 
--- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/file_db_id_utils.h @@ -0,0 +1,18 @@ +#pragma once + +#include "common/file_system/file_info.h" +#include "common/serializer/serializer.h" +#include "common/types/uuid.h" +namespace lbug { +namespace storage { +struct FileDBIDUtils { + // For some temporary DB files such as the WAL and shadow file + // We want to verify that they actually match the current database before replaying + // We do this by adding a unique UUID to the header of the data.kz file + // And making sure they match the IDs of the temporary files + static void verifyDatabaseID(const common::FileInfo& fileInfo, + common::ku_uuid_t expectedDatabaseID, common::ku_uuid_t databaseID); + static void writeDatabaseID(common::Serializer& ser, common::ku_uuid_t databaseID); +}; +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/file_handle.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/file_handle.h new file mode 100644 index 0000000000..d1629284e2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/file_handle.h @@ -0,0 +1,172 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "common/assert.h" +#include "common/concurrent_vector.h" +#include "common/constants.h" +#include "common/copy_constructors.h" +#include "common/file_system/file_info.h" +#include "common/types/types.h" +#include "storage/buffer_manager/page_state.h" +#include "storage/buffer_manager/vm_region.h" +#include "storage/enums/page_read_policy.h" +#include "storage/page_manager.h" + +namespace lbug { +namespace main { +class ClientContext; +} + +namespace common { +class VirtualFileSystem; +} + +namespace storage { +// FileHandle serves several purposes: +// 1) holds basic state information of a file, including FileInfo, flags, pageSize, +// numPages, and pageCapacity. +// 2) provides utility methods to read/write pages from/to the file. +// 3) holds the state of each page in the file in buffer manager. File Handle is the bridge between +// a data structure and the Buffer Manager that abstracts the file in which that data structure is +// stored. + +class ShadowFile; +class BufferManager; +class FileHandle { +public: + friend class BufferManager; + friend class ShadowFile; + + constexpr static uint8_t isLargePagedMask{0b0000'0001}; // represents 1st least sig. bit (LSB) + constexpr static uint8_t isNewInMemoryTmpFileMask{0b0000'0010}; // represents 2nd LSB + // createIfNotExistsMask only applies to existing db files; tmp i-memory files are not created + constexpr static uint8_t createIfNotExistsMask{0b0000'0100}; // represents 3rd LSB + constexpr static uint8_t isReadOnlyMask{0b0000'1000}; // represents 4th LSB + constexpr static uint8_t isLockRequiredMask{0b1000'0000}; // represents 8th LSB + + // READ_ONLY subsumes DEFAULT_PAGED, PERSISTENT, and NO_CREATE. 
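The "READ_ONLY subsumes DEFAULT_PAGED, PERSISTENT, and NO_CREATE" remark in FileHandle below reads naturally once the composite open flags are laid next to the single-bit masks: a read-only persistent file sets only the read-only bit, and default-paged, persistent, and no-create are simply the zero state of the other bits. The following standalone restatement (mask values copied from the FileHandle declarations) makes that explicit with static_asserts.

#include <cstdint>

namespace sketch {
// Single-bit masks, copied from FileHandle.
constexpr uint8_t isLargePagedMask = 0b0000'0001;
constexpr uint8_t isNewInMemoryTmpFileMask = 0b0000'0010;
constexpr uint8_t createIfNotExistsMask = 0b0000'0100;
constexpr uint8_t isReadOnlyMask = 0b0000'1000;

// Composite open flags, also copied from FileHandle.
constexpr uint8_t O_PERSISTENT_FILE_READ_ONLY = 0b0000'1000;
constexpr uint8_t O_IN_MEM_TEMP_FILE = 0b0000'0011;
constexpr uint8_t O_PERSISTENT_FILE_IN_MEM = 0b0000'0010;

// Read-only persistent file: only the read-only bit is set; "default paged",
// "persistent" and "no create" are the zero state of the other bits.
static_assert((O_PERSISTENT_FILE_READ_ONLY & isReadOnlyMask) != 0);
static_assert((O_PERSISTENT_FILE_READ_ONLY &
                  (isLargePagedMask | isNewInMemoryTmpFileMask | createIfNotExistsMask)) == 0);

// A temp file is large-paged AND a new in-memory tmp file, so isInMemoryMode()
// (!isLargePaged && isNewTmpFile) is false for it ...
static_assert((O_IN_MEM_TEMP_FILE & isLargePagedMask) != 0);
static_assert((O_IN_MEM_TEMP_FILE & isNewInMemoryTmpFileMask) != 0);

// ... while O_PERSISTENT_FILE_IN_MEM sets only the tmp bit, which is exactly
// the combination isInMemoryMode() tests for.
static_assert((O_PERSISTENT_FILE_IN_MEM & isLargePagedMask) == 0);
static_assert((O_PERSISTENT_FILE_IN_MEM & isNewInMemoryTmpFileMask) != 0);
} // namespace sketch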
+ constexpr static uint8_t O_PERSISTENT_FILE_READ_ONLY{0b0000'1000}; + constexpr static uint8_t O_PERSISTENT_FILE_CREATE_NOT_EXISTS{0b0000'0100}; + constexpr static uint8_t O_IN_MEM_TEMP_FILE{0b0000'0011}; + constexpr static uint8_t O_PERSISTENT_FILE_IN_MEM{0b0000'0010}; + constexpr static uint8_t O_LOCKED_PERSISTENT_FILE{0b1000'0000}; + + FileHandle(const std::string& path, uint8_t fhFlags, BufferManager* bm, uint32_t fileIndex, + common::VirtualFileSystem* vfs, main::ClientContext* context); + // File handles are registered with the buffer manager and must not be moved or copied + DELETE_COPY_AND_MOVE(FileHandle); + + uint8_t* pinPage(common::page_idx_t pageIdx, PageReadPolicy readPolicy); + void optimisticReadPage(common::page_idx_t pageIdx, + const std::function& readOp); + // The function assumes that the requested page is already pinned. + void unpinPage(common::page_idx_t pageIdx); + + // This function assumes the page is already LOCKED. + void setLockedPageDirty(common::page_idx_t pageIdx) { + KU_ASSERT(pageIdx < numPages); + pageStates[pageIdx].setDirty(); + } + + common::file_idx_t getFileIndex() const { return fileIndex; } + uint8_t* getFrame(common::page_idx_t pageIdx); + PageState* getPageState(common::page_idx_t pageIdx) { return &pageStates[pageIdx]; } + + // Pages added through these APIs are not tracked by the FSM + // If allocating pages from the data.kz file it's recommended to do so using the PageManager + common::page_idx_t addNewPage(); + common::page_idx_t addNewPages(common::page_idx_t numNewPages); + + void removePageIdxAndTruncateIfNecessary(common::page_idx_t pageIdx); + void removePageFromFrameIfNecessary(common::page_idx_t pageIdx); + void flushAllDirtyPagesInFrames(); + + void readPageFromDisk(uint8_t* frame, common::page_idx_t pageIdx) const { + KU_ASSERT(!isInMemoryMode()); + KU_ASSERT(pageIdx < numPages); + fileInfo->readFromFile(frame, getPageSize(), pageIdx * getPageSize()); + } + void writePageToFile(const uint8_t* buffer, common::page_idx_t pageIdx) { + KU_ASSERT(pageIdx < numPages); + writePagesToFile(buffer, getPageSize(), pageIdx); + } + void writePagesToFile(const uint8_t* buffer, uint64_t size, common::page_idx_t startPageIdx); + + bool isInMemoryMode() const { return !isLargePaged() && isNewTmpFile(); } + + common::page_idx_t getNumPages() const { return numPages; } + common::FileInfo* getFileInfo() const { return fileInfo.get(); } + void resetFileInfo() { fileInfo.reset(); } + + uint64_t getPageSize() const { + return isLargePaged() ? 
common::TEMP_PAGE_SIZE : common::LBUG_PAGE_SIZE; + } + + PageManager* getPageManager() { return pageManager.get(); } + +private: + bool isLargePaged() const { return fhFlags & isLargePagedMask; } + bool isNewTmpFile() const { return fhFlags & isNewInMemoryTmpFileMask; } + bool isReadOnlyFile() const { return fhFlags & isReadOnlyMask; } + bool createFileIfNotExists() const { return fhFlags & createIfNotExistsMask; } + bool isLockRequired() const { return fhFlags & isLockRequiredMask; } + + common::page_idx_t addNewPageWithoutLock(); + void constructPersistentFileHandle(const std::string& path, common::VirtualFileSystem* vfs, + main::ClientContext* context); + void constructTmpFileHandle(const std::string& path); + common::frame_idx_t getFrameIdx(common::page_idx_t pageIdx) { + KU_ASSERT(pageIdx < pageCapacity); + return (frameGroupIdxes[pageIdx >> common::StorageConstants::PAGE_GROUP_SIZE_LOG2] + << common::StorageConstants::PAGE_GROUP_SIZE_LOG2) | + (pageIdx & common::StorageConstants::PAGE_IDX_IN_GROUP_MASK); + } + common::PageSizeClass getPageSizeClass() const { return pageSizeClass; } + + void addNewPageGroupWithoutLock(); + common::page_group_idx_t getNumPageGroups() const { + return ceil(static_cast(numPages) / common::StorageConstants::PAGE_GROUP_SIZE); + } + // This function is intended to be used after a fileInfo is created and we want the file + // to have no pages and page locks. Should be called after ensuring that the buffer manager + // does not hold any of the pages of the file. + void resetToZeroPagesAndPageCapacity(); + void flushPageIfDirtyWithoutLock(common::page_idx_t pageIdx); + +private: + // Intended to be used to coordinate calls to functions that change in the internal data + // structures of the file handle. + std::shared_mutex fhSharedMutex; + + uint8_t fhFlags; + std::unique_ptr fileInfo; + common::file_idx_t fileIndex; + // Actually allocated/used number of pages in the file. + std::atomic numPages; + // This is the maximum number of pages the filehandle can currently support. + uint32_t pageCapacity; + + BufferManager* bm; + common::PageSizeClass pageSizeClass; + // With a page group size of 2^10 and an 256KB index size, the access cost increases + // only with each 128GB added to the file + common::ConcurrentVector + pageStates; + // Each file page group corresponds to a frame group in the VMRegion. 
+ // Just one frame group for each page group, so performance is less sensitive than pageStates + // and left at the default which won't increase access cost for the frame groups until 16TB of + // data has been written + common::ConcurrentVector frameGroupIdxes; + + std::unique_ptr pageManager; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/free_space_manager.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/free_space_manager.h new file mode 100644 index 0000000000..3fb61b822d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/free_space_manager.h @@ -0,0 +1,94 @@ +/** + * We would like to thank Mingkun Ni and Mayank Jasoria for doing the initial research and + * prototyping for the FreeSpaceManager in their CS 848 course project: + * https://github.com/ericpolo/lbug_cs848 + */ + +#pragma once + +#include +#include + +#include "common/types/types.h" +namespace lbug::storage { + +class BufferManager; +struct PageRange; +struct FreeEntryIterator; +class FileHandle; + +class FreeSpaceManager { +public: + static bool entryCmp(const PageRange& a, const PageRange& b); + using sorted_free_list_t = std::set; + using free_list_t = std::vector; + + FreeSpaceManager(); + + void addFreePages(PageRange entry); + void evictAndAddFreePages(FileHandle* fileHandle, PageRange entry); + std::optional popFreePages(common::page_idx_t numPages); + + // These pages are not reusable until the end of the next checkpoint + void addUncheckpointedFreePages(PageRange entry); + void rollbackCheckpoint(); + + common::page_idx_t getMaxNumPagesForSerialization() const; + void serialize(common::Serializer& serializer) const; + void deserialize(common::Deserializer& deSer); + void finalizeCheckpoint(FileHandle* fileHandle); + + common::row_idx_t getNumEntries() const; + std::vector getEntries(common::row_idx_t startOffset, + common::row_idx_t endOffset) const; + + // When a page is freed by the FSM, it evicts it from the BM. However, if the page is freed, + // then reused over and over, it can be appended to the eviction queue multiple times. To + // prevent multiple entries of the same page from existing in the eviction queue, at the end of + // each checkpoint we remove any already-evicted pages. 
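As the FreeSpaceManager comment above explains, a page that is freed and reused repeatedly could otherwise sit in the eviction queue several times; the end-of-checkpoint cleanup simply drops entries whose page has already been evicted. A standalone sketch of that filter follows; the queue type and the isAlreadyEvicted predicate are illustrative stand-ins, not lbug APIs.

#include <algorithm>
#include <cstdint>
#include <functional>
#include <vector>

namespace sketch {
using page_idx_t = uint32_t;

// Remove queue entries whose page was already evicted, so each page appears at most once.
inline void clearEvictedEntries(std::vector<page_idx_t>& evictionQueue,
    const std::function<bool(page_idx_t)>& isAlreadyEvicted) {
    evictionQueue.erase(
        std::remove_if(evictionQueue.begin(), evictionQueue.end(), isAlreadyEvicted),
        evictionQueue.end());
}
} // namespace sketch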
+ void clearEvictedBufferManagerEntriesIfNeeded(BufferManager* bufferManager); + +private: + PageRange splitPageRange(PageRange chunk, common::page_idx_t numRequiredPages); + void mergePageRanges(free_list_t newInitialEntries, FileHandle* fileHandle); + void handleLastPageRange(PageRange pageRange, FileHandle* fileHandle); + void resetFreeLists(); + static common::idx_t getLevel(common::page_idx_t numPages); + void evictPages(FileHandle* fileHandle, const PageRange& entry); + + template + void serializeInternal(ValueProcessor& serializer) const; + + std::vector freeLists; + free_list_t uncheckpointedFreePageRanges; + common::row_idx_t numEntries; + bool needClearEvictedEntries; +}; + +/** + * Used for iterating over all entries in the FreeSpaceManager + * Note that the iterator may become invalidated in the FSM is modified + */ +struct FreeEntryIterator { + explicit FreeEntryIterator(const std::vector& freeLists) + : FreeEntryIterator(freeLists, 0) {} + + FreeEntryIterator(const std::vector& freeLists, + common::idx_t freeListIdx_) + : freeLists(freeLists), freeListIdx(freeListIdx_) { + advanceFreeListIdx(); + } + + void advance(common::row_idx_t numEntries); + void operator++(); + PageRange operator*() const; + bool done() const; + + void advanceFreeListIdx(); + + const std::vector& freeLists; + common::idx_t freeListIdx; + FreeSpaceManager::sorted_free_list_t::const_iterator freeListIt; +}; + +} // namespace lbug::storage diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/index/hash_index.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/index/hash_index.h new file mode 100644 index 0000000000..276932ff93 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/index/hash_index.h @@ -0,0 +1,508 @@ +#pragma once + +#include +#include +#include + +#include "common/cast.h" +#include "common/serializer/buffer_reader.h" +#include "common/serializer/serializer.h" +#include "common/type_utils.h" +#include "common/types/ku_string.h" +#include "common/types/types.h" +#include "hash_index_header.h" +#include "hash_index_slot.h" +#include "index.h" +#include "storage/buffer_manager/memory_manager.h" +#include "storage/disk_array_collection.h" +#include "storage/index/hash_index_utils.h" +#include "storage/index/in_mem_hash_index.h" +#include "storage/local_storage/local_hash_index.h" + +namespace lbug { +namespace common { +class VirtualFileSystem; +} +namespace transaction { +class Transaction; +enum class TransactionType : uint8_t; +} // namespace transaction +namespace storage { + +class FileHandle; +class BufferManager; +class OverflowFileHandle; +template +class DiskArray; +class PageManager; + +class OnDiskHashIndex { +public: + virtual ~OnDiskHashIndex() = default; + virtual bool checkpoint(PageAllocator& pageAllocator) = 0; + virtual bool checkpointInMemory() = 0; + virtual bool rollbackInMemory() = 0; + virtual void rollbackCheckpoint() = 0; + virtual void bulkReserve(uint64_t numValuesToAppend) = 0; + virtual void reclaimStorage(PageAllocator& pageAllocator) = 0; + virtual bool tryLock() = 0; + virtual std::unique_lock adoptLock() = 0; +}; + +// HashIndex is the entrance to handle all updates and lookups into the index after building from +// scratch through InMemHashIndex. +// The index consists of two parts, one is the persistent storage (from the persistent index file), +// and the other is the local storage. All lookups/deletions/insertions go through local storage, +// and then the persistent storage if necessary. 
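+//
+// A small illustrative scenario (hypothetical key): inserting key "a" in a write transaction only
+// touches the local storage, so a later lookup of "a" in that transaction is served from local
+// insertions, while a lookup of a key that exists only on disk falls through to the persistent
+// storage, as detailed for the individual interfaces below.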
+// +// Key interfaces: +// - lookup(): Given a key, find its result. Return true if the key is found, else, return false. +// Lookups go through the local storage first, check if the key is marked as deleted or not, then +// check whether it can be found inside local insertions or not. If the key is neither marked as +// deleted nor found in local insertions, we proceed to lookups in the persistent store. +// - delete(): Delete the given key. +// Deletions are directly marked in the local storage. +// - insert(): Insert the given key and value. Return true if the given key doesn't exist in the +// index before insertion, otherwise, return false. +// First check if the key to be inserted already exists in local insertions or the persistent +// store. If the key doesn't exist yet, append it to local insertions, and also remove it from +// local deletions if it was marked as deleted. +// +// T is the key type used to access values +// S is the stored type, which is usually the same as T, with the exception of strings +template +class HashIndex final : public OnDiskHashIndex { +public: + HashIndex(MemoryManager& memoryManager, OverflowFileHandle* overflowFileHandle, + DiskArrayCollection& diskArrays, uint64_t indexPos, ShadowFile* shadowFile, + const HashIndexHeader& indexHeaderForReadTrx, HashIndexHeader& indexHeaderForWriteTrx); + + ~HashIndex() override; + +public: + using OnDiskSlotType = Slot; + static constexpr auto PERSISTENT_SLOT_CAPACITY = getSlotCapacity(); + + static_assert(DiskArray::getAlignedElementSize() <= + common::HashIndexConstants::SLOT_CAPACITY_BYTES); + static_assert(DiskArray::getAlignedElementSize() > + common::HashIndexConstants::SLOT_CAPACITY_BYTES / 2); + + using Key = + typename std::conditional, std::string_view, T>::type; + // For read transactions, local storage is skipped, lookups are performed on the persistent + // storage. For write transactions, lookups are performed in the local storage first, then in + // the persistent storage if necessary. In details, there are three cases for the local storage + // lookup: + // - the key is found in the local storage, directly return true; + // - the key has been marked as deleted in the local storage, return false; + // - the key is neither deleted nor found in the local storage, lookup in the persistent + // storage. + bool lookupInternal(const transaction::Transaction* transaction, Key key, + common::offset_t& result, visible_func isVisible) { + auto localLookupState = localStorage->lookup(key, result, isVisible); + if (localLookupState == HashIndexLocalLookupState::KEY_DELETED) { + return false; + } + if (localLookupState == HashIndexLocalLookupState::KEY_FOUND) { + return true; + } + KU_ASSERT(localLookupState == HashIndexLocalLookupState::KEY_NOT_EXIST); + return lookupInPersistentIndex(transaction, key, result, isVisible); + } + + // For deletions, we don't check if the deleted keys exist or not. Thus, we don't need to check + // in the persistent storage and directly delete keys in the local storage. + void deleteInternal(Key key) const { localStorage->deleteKey(key); } + // Discards from local storage, but will not insert a deletion (used for rollbacks) + bool discardLocal(Key key) const { return localStorage->discard(key); } + + // For insertions, we first check in the local storage. 
There are three cases: + // - the key is found in the local storage, return false; + // - the key is marked as deleted in the local storage, insert the key to the local storage; + // - the key doesn't exist in the local storage, check if the key exists in the persistent + // index, if + // so, return false, else insert the key to the local storage. + using InsertType = InMemHashIndex::OwnedType; + bool insertInternal(const transaction::Transaction* transaction, InsertType&& key, + common::offset_t value, visible_func isVisible) { + common::offset_t tmpResult = 0; + auto localLookupState = localStorage->lookup(key, tmpResult, isVisible); + if (localLookupState == HashIndexLocalLookupState::KEY_FOUND) { + return false; + } + if (localLookupState != HashIndexLocalLookupState::KEY_DELETED) { + if (lookupInPersistentIndex(transaction, key, tmpResult, isVisible)) { + return false; + } + } + return localStorage->insert(std::move(key), value, isVisible); + } + + using BufferKeyType = + typename std::conditional, std::string, T>::type; + // Appends the buffer to the index. Returns the number of values successfully inserted + // Note that this function does not acquire locks internally, as the caller is expected to hold + // the lock already. + size_t appendNoLock(const transaction::Transaction* transaction, + IndexBuffer& buffer, uint64_t bufferOffset, visible_func isVisible) { + // Check if values already exist in persistent storage + if (indexHeaderForWriteTrx.numEntries > 0) { + localStorage->reserveSpaceForAppendNoLock(buffer.size() - bufferOffset); + size_t numValuesInserted = 0; + common::offset_t result = 0; + for (size_t i = bufferOffset; i < buffer.size(); i++) { + auto& [key, value] = buffer[i]; + if (lookupInPersistentIndex(transaction, key, result, isVisible)) { + return i - bufferOffset; + } else { + numValuesInserted += + localStorage->appendNoLock(std::move(key), value, isVisible); + } + } + return numValuesInserted; + } else { + return localStorage->appendNoLock(buffer, bufferOffset, isVisible); + } + } + + bool tryLock() override { return localStorage->tryLock(); } + std::unique_lock adoptLock() override { return localStorage->adoptLock(); } + + bool checkpoint(PageAllocator& pageAllocator) override; + bool checkpointInMemory() override; + bool rollbackInMemory() override; + void rollbackCheckpoint() override; + void reclaimStorage(PageAllocator& pageAllocator) override; + +private: + bool lookupInPersistentIndex(const transaction::Transaction* transaction, Key key, + common::offset_t& result, visible_func isVisible) { + auto& header = transaction->getType() == transaction::TransactionType::CHECKPOINT ? 
+ this->indexHeaderForWriteTrx : + this->indexHeaderForReadTrx; + // There may not be any primary key slots if we try to lookup on an empty index + if (header.numEntries == 0) { + return false; + } + auto hashValue = HashIndexUtils::hash(key); + auto fingerprint = HashIndexUtils::getFingerprintForHash(hashValue); + auto iter = getSlotIterator(HashIndexUtils::getPrimarySlotIdForHash(header, hashValue), + transaction); + do { + auto entryPos = + findMatchedEntryInSlot(transaction, iter.slot, key, fingerprint, isVisible); + if (entryPos != SlotHeader::INVALID_ENTRY_POS) { + result = iter.slot.entries[entryPos].value; + return true; + } + } while (nextChainedSlot(transaction, iter)); + return false; + } + void deleteFromPersistentIndex(const transaction::Transaction* transaction, Key key, + visible_func isVisible); + + entry_pos_t findMatchedEntryInSlot(const transaction::Transaction* transaction, + const OnDiskSlotType& slot, Key key, uint8_t fingerprint, + const visible_func& isVisible) const { + for (auto entryPos = 0u; entryPos < PERSISTENT_SLOT_CAPACITY; entryPos++) { + if (slot.header.isEntryValid(entryPos) && + slot.header.fingerprints[entryPos] == fingerprint && + equals(transaction, key, slot.entries[entryPos].key) && + isVisible(slot.entries[entryPos].value)) { + return entryPos; + } + } + return SlotHeader::INVALID_ENTRY_POS; + } + + inline void updateSlot(const transaction::Transaction* transaction, const SlotInfo& slotInfo, + const OnDiskSlotType& slot) { + slotInfo.slotType == SlotType::PRIMARY ? + pSlots->update(transaction, slotInfo.slotId, slot) : + oSlots->update(transaction, slotInfo.slotId, slot); + } + + inline OnDiskSlotType getSlot(const transaction::Transaction* transaction, + const SlotInfo& slotInfo) const { + return slotInfo.slotType == SlotType::PRIMARY ? 
pSlots->get(slotInfo.slotId, transaction) : + oSlots->get(slotInfo.slotId, transaction); + } + + void splitSlots(PageAllocator& pageAllocator, const transaction::Transaction* transaction, + HashIndexHeader& header, slot_id_t numSlotsToSplit); + + // Resizes the local storage to support the given number of new entries + void bulkReserve(uint64_t newEntries) override; + // Resizes the on-disk index to support the given number of new entries + void reserve(PageAllocator& pageAllocator, const transaction::Transaction* transaction, + uint64_t newEntries); + + struct HashIndexEntryView { + slot_id_t diskSlotId; + uint8_t fingerprint; + const SlotEntry::OwnedType>* entry; + }; + + void sortEntries(const transaction::Transaction* transaction, + const InMemHashIndex& insertLocalStorage, + typename InMemHashIndex::SlotIterator& slotToMerge, + std::vector& entries); + void mergeBulkInserts(PageAllocator& pageAllocator, const transaction::Transaction* transaction, + const InMemHashIndex& insertLocalStorage); + // Returns the number of elements merged which matched the given slot id + size_t mergeSlot(PageAllocator& pageAllocator, const transaction::Transaction* transaction, + const std::vector& slotToMerge, + typename DiskArray::WriteIterator& diskSlotIterator, + typename DiskArray::WriteIterator& diskOverflowSlotIterator, + slot_id_t diskSlotId); + + inline bool equals(const transaction::Transaction* /*transaction*/, Key keyToLookup, + const T& keyInEntry) const { + return keyToLookup == keyInEntry; + } + + inline common::hash_t hashStored(const transaction::Transaction* /*transaction*/, + const T& key) const { + return HashIndexUtils::hash(key); + } + + inline common::hash_t hashStored(const transaction::Transaction* /*transaction*/, + std::string_view key) const { + return HashIndexUtils::hash(key); + } + + struct SlotIterator { + SlotInfo slotInfo; + OnDiskSlotType slot; + }; + + SlotIterator getSlotIterator(slot_id_t slotId, const transaction::Transaction* transaction) { + return SlotIterator{SlotInfo{slotId, SlotType::PRIMARY}, + getSlot(transaction, SlotInfo{slotId, SlotType::PRIMARY})}; + } + + bool nextChainedSlot(const transaction::Transaction* transaction, SlotIterator& iter) const { + KU_ASSERT(iter.slotInfo.slotType == SlotType::PRIMARY || + iter.slotInfo.slotId != iter.slot.header.nextOvfSlotId); + if (iter.slot.header.nextOvfSlotId != SlotHeader::INVALID_OVERFLOW_SLOT_ID) { + iter.slotInfo.slotId = iter.slot.header.nextOvfSlotId; + iter.slotInfo.slotType = SlotType::OVF; + iter.slot = getSlot(transaction, iter.slotInfo); + return true; + } + return false; + } + + std::vector> getChainedSlots( + const transaction::Transaction* transaction, slot_id_t pSlotId); + +private: + ShadowFile* shadowFile; + uint64_t headerPageIdx; + std::unique_ptr> pSlots; + std::unique_ptr> oSlots; + OverflowFileHandle* overflowFileHandle; + std::unique_ptr> localStorage; + const HashIndexHeader& indexHeaderForReadTrx; + HashIndexHeader& indexHeaderForWriteTrx; + MemoryManager& memoryManager; +}; + +template<> +common::hash_t HashIndex::hashStored( + const transaction::Transaction* transaction, const common::ku_string_t& key) const; + +template<> +bool HashIndex::equals(const transaction::Transaction* transaction, + std::string_view keyToLookup, const common::ku_string_t& keyInEntry) const; + +struct PrimaryKeyIndexStorageInfo final : IndexStorageInfo { + common::page_idx_t firstHeaderPage; + common::page_idx_t overflowHeaderPage; + + PrimaryKeyIndexStorageInfo() + : firstHeaderPage{common::INVALID_PAGE_IDX}, 
overflowHeaderPage{common::INVALID_PAGE_IDX} {} + PrimaryKeyIndexStorageInfo(common::page_idx_t firstHeaderPage, + common::page_idx_t overflowHeaderPage) + : firstHeaderPage{firstHeaderPage}, overflowHeaderPage{overflowHeaderPage} {} + + DELETE_COPY_DEFAULT_MOVE(PrimaryKeyIndexStorageInfo); + + std::shared_ptr serialize() const override { + auto bufferWriter = std::make_shared(); + auto serializer = common::Serializer(bufferWriter); + serializer.write(firstHeaderPage); + serializer.write(overflowHeaderPage); + return bufferWriter; + } + + static std::unique_ptr deserialize( + std::unique_ptr reader); +}; + +class PrimaryKeyIndex final : public Index { +public: + static constexpr const char* DEFAULT_NAME = "_PK"; + + struct InsertState final : Index::InsertState { + visible_func isVisible; // Function to check visibility of the inserted key + + explicit InsertState(visible_func isVisible_) : isVisible{std::move(isVisible_)} {} + }; + + // Construct an existing index + PrimaryKeyIndex(IndexInfo indexInfo, std::unique_ptr storageInfo, + bool inMemMode, MemoryManager& memoryManager, PageAllocator& pageAllocator, + ShadowFile* shadowFile); + ~PrimaryKeyIndex() override; + + static std::unique_ptr createNewIndex(IndexInfo indexInfo, bool inMemMode, + MemoryManager& memoryManager, PageAllocator& pageAllocator, ShadowFile* shadowFile); + + template + inline HashIndex>* getTypedHashIndex(T key) { + return common::ku_dynamic_cast>*>( + hashIndices[HashIndexUtils::getHashIndexPosition(key)].get()); + } + template + inline HashIndex* getTypedHashIndexByPos(uint64_t indexPos) { + return common::ku_dynamic_cast>*>(hashIndices[indexPos].get()); + } + + bool tryLockTypedIndex(uint64_t indexPos) { return hashIndices[indexPos]->tryLock(); } + std::unique_lock adoptLockOfTypedIndex(uint64_t indexPos) { + return hashIndices[indexPos]->adoptLock(); + } + + bool lookup(const transaction::Transaction* trx, common::ku_string_t key, + common::offset_t& result, visible_func isVisible) { + return lookup(trx, key.getAsStringView(), result, isVisible); + } + template + inline bool lookup(const transaction::Transaction* trx, T key, common::offset_t& result, + visible_func isVisible) { + KU_ASSERT(indexInfo.keyDataTypes[0] == common::TypeUtils::getPhysicalTypeIDForType()); + return getTypedHashIndex(key)->lookupInternal(trx, key, result, isVisible); + } + + bool lookup(const transaction::Transaction* trx, common::ValueVector* keyVector, + uint64_t vectorPos, common::offset_t& result, visible_func isVisible); + + std::unique_ptr initInsertState(main::ClientContext*, + visible_func isVisible) override { + return std::make_unique(isVisible); + } + void insert(transaction::Transaction*, const common::ValueVector&, + const std::vector&, Index::InsertState&) override { + // DO NOTHING. + // For hash index, we don't need to do anything here because the insertions are handled when + // the transaction commits. 
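+        // (Sketch of the intended flow, inferred from this header rather than stated: since
+        // needCommitInsert() below returns true, the insertions are applied by commitInsert()
+        // when the transaction commits.)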
+ } + bool insert(const transaction::Transaction* transaction, common::ku_string_t key, + common::offset_t value, visible_func isVisible) { + return insert(transaction, key.getAsString(), value, isVisible); + } + template + inline bool insert(const transaction::Transaction* transaction, T key, common::offset_t value, + visible_func isVisible) { + KU_ASSERT(indexInfo.keyDataTypes[0] == common::TypeUtils::getPhysicalTypeIDForType()); + return getTypedHashIndex(key)->insertInternal(transaction, std::move(key), value, + isVisible); + } + bool insert(const transaction::Transaction* transaction, const common::ValueVector* keyVector, + uint64_t vectorPos, common::offset_t value, visible_func isVisible); + bool needCommitInsert() const override { return true; } + void commitInsert(transaction::Transaction* transaction, + const common::ValueVector& nodeIDVector, + const std::vector& indexVectors, + Index::InsertState& insertState) override; + + // Appends the buffer to the index. Returns the number of values successfully inserted. + // If a key fails to insert, it immediately returns without inserting any more values, + // and the returned value is also the index of the key which failed to insert. + template + size_t appendWithIndexPosNoLock(const transaction::Transaction* transaction, + IndexBuffer& buffer, uint64_t bufferOffset, uint64_t indexPos, visible_func isVisible) { + KU_ASSERT(indexInfo.keyDataTypes[0] == common::TypeUtils::getPhysicalTypeIDForType()); + KU_ASSERT(std::all_of(buffer.begin(), buffer.end(), [&](auto& elem) { + return HashIndexUtils::getHashIndexPosition(elem.first) == indexPos; + })); + return getTypedHashIndexByPos>(indexPos)->appendNoLock(transaction, buffer, + bufferOffset, isVisible); + } + + void bulkReserve(uint64_t numValuesToAppend) { + uint32_t eachSize = numValuesToAppend / NUM_HASH_INDEXES + 1; + for (auto i = 0u; i < NUM_HASH_INDEXES; i++) { + hashIndices[i]->bulkReserve(eachSize); + } + } + + void delete_(common::ku_string_t key) { return delete_(key.getAsStringView()); } + std::unique_ptr initDeleteState(const transaction::Transaction* /*transaction*/, + MemoryManager* /*mm*/, visible_func /*isVisible*/) override { + return std::make_unique(); + } + void delete_(transaction::Transaction* /*transaction*/, + const common::ValueVector& /*nodeIDVector*/, DeleteState& /*deleteState*/) override { + // DO NOTHING. 
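+        // (Inference only, not stated in this header: a hash index keyed on primary-key values
+        // cannot locate entries from node IDs alone, so deletions are expected to go through the
+        // key-based delete_ overloads below.)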
+ } + template + inline void delete_(T key) { + KU_ASSERT(indexInfo.keyDataTypes[0] == common::TypeUtils::getPhysicalTypeIDForType()); + return getTypedHashIndex(key)->deleteInternal(key); + } + + bool discardLocal(common::ku_string_t key) { return discardLocal(key.getAsStringView()); } + template + inline bool discardLocal(T key) { + KU_ASSERT(indexInfo.keyDataTypes[0] == common::TypeUtils::getPhysicalTypeIDForType()); + return getTypedHashIndex(key)->discardLocal(key); + } + + void delete_(common::ValueVector* keyVector); + + void checkpointInMemory() override; + void checkpoint(main::ClientContext*, storage::PageAllocator& pageAllocator) override; + OverflowFile* getOverflowFile() const { return overflowFile.get(); } + + void rollbackCheckpoint() override; + + common::PhysicalTypeID keyTypeID() const { + KU_ASSERT(indexInfo.keyDataTypes.size() == 1); + return indexInfo.keyDataTypes[0]; + } + void reclaimStorage(PageAllocator& pageAllocator) const; + + static LBUG_API std::unique_ptr load(main::ClientContext* context, + StorageManager* storageManager, IndexInfo indexInfo, std::span storageInfoBuffer); + + static IndexType getIndexType() { + static const IndexType HASH_INDEX_TYPE{"HASH", IndexConstraintType::PRIMARY, + IndexDefinitionType::BUILTIN, load}; + return HASH_INDEX_TYPE; + } + +private: + void writeHeaders(PageAllocator& pageAllocator) const; + + void initOverflowAndSubIndices(bool inMemMode, MemoryManager& mm, PageAllocator& pageAllocator, + PrimaryKeyIndexStorageInfo& storageInfo); + + common::page_idx_t getFirstHeaderPage() const; + + common::page_idx_t getDiskArrayFirstHeaderPage() const; + +private: + std::unique_ptr overflowFile; + std::vector> hashIndices; + std::vector hashIndexHeadersForReadTrx; + std::vector hashIndexHeadersForWriteTrx; + ShadowFile& shadowFile; + // Stores both primary and overflow slots + std::unique_ptr hashIndexDiskArrays; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/index/hash_index_header.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/index/hash_index_header.h new file mode 100644 index 0000000000..3e5e776e8b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/index/hash_index_header.h @@ -0,0 +1,68 @@ +#pragma once + +#include "hash_index_slot.h" + +namespace lbug { +namespace storage { + +struct HashIndexHeaderOnDisk { + explicit HashIndexHeaderOnDisk() + : nextSplitSlotId{0}, numEntries{0}, + firstFreeOverflowSlotId{SlotHeader::INVALID_OVERFLOW_SLOT_ID}, currentLevel{0} {} + slot_id_t nextSplitSlotId; + uint64_t numEntries; + slot_id_t firstFreeOverflowSlotId; + uint8_t currentLevel; + uint8_t _padding[7]{}; +}; +static_assert(std::has_unique_object_representations_v); + +class HashIndexHeader { +public: + explicit HashIndexHeader() + : currentLevel{1}, levelHashMask{1}, higherLevelHashMask{3}, nextSplitSlotId{0}, + numEntries{0}, firstFreeOverflowSlotId{SlotHeader::INVALID_OVERFLOW_SLOT_ID} {} + + explicit HashIndexHeader(const HashIndexHeaderOnDisk& onDiskHeader) + : currentLevel{onDiskHeader.currentLevel}, levelHashMask{(1ull << this->currentLevel) - 1}, + higherLevelHashMask{(1ull << (this->currentLevel + 1)) - 1}, + nextSplitSlotId{onDiskHeader.nextSplitSlotId}, numEntries{onDiskHeader.numEntries}, + firstFreeOverflowSlotId{onDiskHeader.firstFreeOverflowSlotId} {} + + inline void incrementLevel() { + currentLevel++; + nextSplitSlotId = 0; + levelHashMask = (1 << currentLevel) - 1; + higherLevelHashMask = (1 << (currentLevel + 1)) - 1; + } + 
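+    // Worked example of the masks above (illustrative values only): with currentLevel == 2,
+    // levelHashMask == 0b011 and higherLevelHashMask == 0b111. A key whose hash ends in 0b101
+    // first maps to slot 0b01; once slot 1 has been split (nextSplitSlotId > 1), lookups use the
+    // wider mask for that key and address slot 0b101 instead. incrementNextSplitSlotId() below
+    // advances the split pointer one slot at a time and bumps the level once every slot of the
+    // current level has been split.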
inline void incrementNextSplitSlotId() { + if (nextSplitSlotId < (1ull << currentLevel) - 1) { + nextSplitSlotId++; + } else { + incrementLevel(); + } + } + + inline void write(HashIndexHeaderOnDisk& onDiskHeader) const { + onDiskHeader.currentLevel = currentLevel; + onDiskHeader.nextSplitSlotId = nextSplitSlotId; + onDiskHeader.numEntries = numEntries; + onDiskHeader.firstFreeOverflowSlotId = firstFreeOverflowSlotId; + } + +public: + uint64_t currentLevel; + uint64_t levelHashMask; + uint64_t higherLevelHashMask; + // Id of the next slot to split when resizing the hash index + slot_id_t nextSplitSlotId; + uint64_t numEntries; + // Id of the first in a chain of empty overflow slots which have been reclaimed during slot + // splitting. The nextOvfSlotId field in the slot's header indicates the next slot in the chain. + // These slots should be used first when allocating new overflow slots + // TODO(bmwinger): Make use of this in the on-disk hash index + slot_id_t firstFreeOverflowSlotId; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/index/hash_index_slot.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/index/hash_index_slot.h new file mode 100644 index 0000000000..89313fc1c4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/index/hash_index_slot.h @@ -0,0 +1,88 @@ +#pragma once + +#include +#include +#include +#include + +#include "common/assert.h" +#include "common/constants.h" +#include "common/types/types.h" +#include + +namespace lbug { +namespace storage { + +using entry_pos_t = uint8_t; +using slot_id_t = uint64_t; + +class SlotHeader { +public: + static const entry_pos_t INVALID_ENTRY_POS = UINT8_MAX; + static const slot_id_t INVALID_OVERFLOW_SLOT_ID = UINT64_MAX; + // For a header of 32 bytes. 
+ // This is smaller than the possible number of entries with an 1-byte key like uint8_t, + // but the additional size would limit the number of entries for 8-byte keys, so we + // instead restrict the capacity to 20 + static constexpr uint8_t FINGERPRINT_CAPACITY = 20; + + SlotHeader() : fingerprints{}, validityMask{0}, nextOvfSlotId{INVALID_OVERFLOW_SLOT_ID} {} + + void reset() { + validityMask = 0; + nextOvfSlotId = INVALID_OVERFLOW_SLOT_ID; + } + + inline bool isEntryValid(uint32_t entryPos) const { + return validityMask & ((uint32_t)1 << entryPos); + } + inline void setEntryValid(entry_pos_t entryPos, uint8_t fingerprint) { + validityMask |= ((uint32_t)1 << entryPos); + fingerprints[entryPos] = fingerprint; + } + inline void setEntryInvalid(entry_pos_t entryPos) { + validityMask &= ~((uint32_t)1 << entryPos); + } + + inline entry_pos_t numEntries() const { return std::popcount(validityMask); } + +public: + std::array fingerprints; + uint32_t validityMask; + slot_id_t nextOvfSlotId; +}; +static_assert(std::has_unique_object_representations_v); + +template +struct SlotEntry { + SlotEntry(T _key, common::offset_t _value) : key{std::move(_key)}, value{_value} { + // Zero padding, if any + if constexpr (sizeof(T) + sizeof(common::offset_t) < sizeof(SlotEntry)) { + auto padding = sizeof(SlotEntry) - sizeof(T) - sizeof(common::offset_t); + memset(reinterpret_cast(&key) + sizeof(T), 0, padding); + // Assumes that all the padding follows the key + KU_ASSERT((std::byte*)&key + sizeof(key) + padding == (std::byte*)&value); + } + } + SlotEntry() : SlotEntry(T{}, 0) {} + + T key; + common::offset_t value; +}; + +template +static constexpr uint8_t getSlotCapacity() { + return std::min((common::HashIndexConstants::SLOT_CAPACITY_BYTES - sizeof(SlotHeader)) / + sizeof(SlotEntry), + static_cast(SlotHeader::FINGERPRINT_CAPACITY)); +} + +template +struct Slot { + Slot() : header{}, entries{} {} + SlotHeader header; + std::array, getSlotCapacity()> entries; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/index/hash_index_utils.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/index/hash_index_utils.h new file mode 100644 index 0000000000..4900e21e85 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/index/hash_index_utils.h @@ -0,0 +1,86 @@ +#pragma once + +#include +#include + +#include "common/constants.h" +#include "common/system_config.h" +#include "common/types/ku_string.h" +#include "common/types/types.h" +#include "function/hash/hash_functions.h" +#include "storage/index/hash_index_header.h" + +namespace lbug { +namespace storage { + +static constexpr uint64_t NUM_HASH_INDEXES = common::HashIndexConstants::NUM_HASH_INDEXES; +static constexpr uint64_t NUM_HASH_INDEXES_LOG2 = common::HashIndexConstants::NUM_HASH_INDEXES_LOG2; + +static constexpr common::page_idx_t INDEX_HEADER_PAGES = 2; +static constexpr uint64_t INDEX_HEADERS_PER_PAGE = + common::LBUG_PAGE_SIZE / sizeof(HashIndexHeaderOnDisk); + +static constexpr common::page_idx_t P_SLOTS_HEADER_PAGE_IDX = 0; +static constexpr common::page_idx_t O_SLOTS_HEADER_PAGE_IDX = 1; +static constexpr common::page_idx_t NUM_HEADER_PAGES = 2; +static constexpr uint64_t INDEX_HEADER_IDX_IN_ARRAY = 0; + +// so that all 256 hash indexes can be stored in two pages, the HashIndexHeaderOnDisk must be +// smaller than 32 bytes +static_assert(NUM_HASH_INDEXES * sizeof(HashIndexHeaderOnDisk) <= + common::LBUG_PAGE_SIZE * INDEX_HEADER_PAGES); + +enum class SlotType : uint8_t 
{ PRIMARY = 0, OVF = 1 }; + +struct SlotInfo { + slot_id_t slotId{UINT64_MAX}; + SlotType slotType{SlotType::PRIMARY}; + + bool operator==(const SlotInfo&) const = default; +}; + +class HashIndexUtils { + +public: + static constexpr auto INVALID_OVF_INFO = + SlotInfo{SlotHeader::INVALID_OVERFLOW_SLOT_ID, SlotType::OVF}; + + static bool areStringPrefixAndLenEqual(std::string_view keyToLookup, + const common::ku_string_t& keyInEntry) { + auto prefixLen = + std::min(static_cast(keyInEntry.len), common::ku_string_t::PREFIX_LENGTH); + return keyToLookup.length() == keyInEntry.len && + memcmp(keyToLookup.data(), keyInEntry.prefix, prefixLen) == 0; + } + + template + static common::hash_t hash(const T& key) { + common::hash_t hash = 0; + function::Hash::operation(key, hash); + return hash; + } + + static uint8_t getFingerprintForHash(common::hash_t hash) { + // Last 8 bits before the bits used to calculate the hash index position is the fingerprint + return (hash >> (64 - NUM_HASH_INDEXES_LOG2 - 8)) & 255; + } + + static slot_id_t getPrimarySlotIdForHash(const HashIndexHeader& indexHeader, + common::hash_t hash) { + auto slotId = hash & indexHeader.levelHashMask; + if (slotId < indexHeader.nextSplitSlotId) { + slotId = hash & indexHeader.higherLevelHashMask; + } + return slotId; + } + + static uint64_t getHashIndexPosition(common::IndexHashable auto key) { + return (HashIndexUtils::hash(key) >> (64 - NUM_HASH_INDEXES_LOG2)) & (NUM_HASH_INDEXES - 1); + } + + static uint64_t getNumRequiredEntries(uint64_t numEntries) { + return ceil(static_cast(numEntries) * common::DEFAULT_HT_LOAD_FACTOR); + } +}; +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/index/in_mem_hash_index.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/index/in_mem_hash_index.h new file mode 100644 index 0000000000..9047d1ac52 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/index/in_mem_hash_index.h @@ -0,0 +1,316 @@ +#pragma once + +#include + +#include "common/static_vector.h" +#include "common/types/ku_string.h" +#include "common/types/types.h" +#include "storage/buffer_manager/memory_manager.h" +#include "storage/disk_array.h" +#include "storage/index/hash_index_header.h" +#include "storage/index/hash_index_slot.h" +#include "storage/index/hash_index_utils.h" +#include "storage/overflow_file.h" + +namespace lbug { +namespace storage { + +using visible_func = std::function; + +constexpr size_t INDEX_BUFFER_SIZE = 1024; +template +using IndexBuffer = common::StaticVector, INDEX_BUFFER_SIZE>; + +template +using HashIndexType = + std::conditional_t || std::same_as, + common::ku_string_t, T>; + +/** + * Basic index file consists of three disk arrays: indexHeader, primary slots (pSlots), and overflow + * slots (oSlots). + * + * 1. HashIndexHeader contains the current state of the hash tables (level and split information: + * currentLevel, levelHashMask, higherLevelHashMask, nextSplitSlotId; key data type). + * + * 2. Given a key, it is mapped to one of the pSlots based on its hash value and the level and + * splitting info. The actual key and value are either stored in the pSlot, or in a chained overflow + * slots (oSlots) of the pSlot. + * + * The slot data structure: + * Each slot (p/oSlot) consists of a slot header and several entries. The max number of entries in + * slot is given by HashIndexConstants::SLOT_CAPACITY. The size of the slot is given by + * (sizeof(SlotHeader) + (SLOT_CAPACITY * sizeof(Entry)). 
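+ *
+ * As a rough illustration (assuming 256-byte slots and a 32-byte SlotHeader, per the comments in
+ * hash_index_slot.h): with an 8-byte key each entry takes 16 bytes, so a slot holds
+ * min((256 - 32) / 16, FINGERPRINT_CAPACITY) = 14 entries.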
+ * + * SlotHeader: [numEntries, validityMask, nextOvfSlotId] + * Entry: [key (fixed sized part), node_offset] + * + * 3. oSlots are used to store entries that comes to the designated primary slot that has already + * been filled to the capacity. Several overflow slots can be chained after the single primary slot + * as a singly linked link-list. Each slot's SlotHeader has information about the next overflow slot + * in the chain and also the number of filled entries in that slot. + * + * */ + +// T is the key type stored in the slots. +// For strings this is different than the type used when inserting/searching +// (see BufferKeyType and Key) +template +class InMemHashIndex final { +public: + using OwnedType = std::conditional_t, std::string, T>; + using KeyType = std::conditional_t, std::string_view, T>; + static_assert(std::is_constructible_v); + static_assert(std::is_constructible_v); + + static constexpr auto SLOT_CAPACITY = getSlotCapacity(); + using InMemSlotType = Slot; + + // Size of the validity mask + static_assert(SLOT_CAPACITY <= sizeof(SlotHeader().validityMask) * 8); + static_assert(SLOT_CAPACITY <= std::numeric_limits::max() + 1); + + // sanity check to make sure we aren't accidentally making slots for some types larger than 256 + // bytes. + static_assert(DiskArray::getAlignedElementSize() <= + common::HashIndexConstants::SLOT_CAPACITY_BYTES); + + // the size of Slot depends on the size of T and should always be close to the + // SLOT_CAPACITY_BYTES + static_assert(DiskArray::getAlignedElementSize() > + common::HashIndexConstants::SLOT_CAPACITY_BYTES / 2); + +public: + explicit InMemHashIndex(MemoryManager& memoryManager, OverflowFileHandle* overflowFileHandle); + + // Reserves space for at least the specified number of elements. + // This reserves space for numEntries in total, regardless of existing entries in the builder + void reserve(uint32_t numEntries); + // Allocates the given number of new slots, ignoo + void allocateSlots(uint32_t numSlots); + + void reserveSpaceForAppend(uint32_t numNewEntries) { + reserve(indexHeader.numEntries + numNewEntries); + } + + // Appends the buffer to the index. Returns the number of values successfully inserted. + // I.e. 
if a key fails to insert, its index will be the return value + size_t append(IndexBuffer& buffer, uint64_t bufferOffset, visible_func isVisible) { + reserve(indexHeader.numEntries + buffer.size() - bufferOffset); + common::hash_t hashes[INDEX_BUFFER_SIZE]; + for (size_t i = bufferOffset; i < buffer.size(); i++) { + hashes[i] = HashIndexUtils::hash(buffer[i].first); + auto& [key, value] = buffer[i]; + if (!appendInternal(std::move(key), value, hashes[i], isVisible)) { + return i - bufferOffset; + } + } + return buffer.size() - bufferOffset; + } + + bool append(OwnedType&& key, common::offset_t value, visible_func isVisible) { + reserve(indexHeader.numEntries + 1); + return appendInternal(std::move(key), value, HashIndexUtils::hash(key), isVisible); + } + bool lookup(KeyType key, common::offset_t& result, visible_func isVisible) { + // This needs to be fast if the builder is empty since this function is always tried + // when looking up in the persistent hash index + if (this->indexHeader.numEntries == 0) { + return false; + } + auto hashValue = HashIndexUtils::hash(key); + auto fingerprint = HashIndexUtils::getFingerprintForHash(hashValue); + auto slotId = HashIndexUtils::getPrimarySlotIdForHash(this->indexHeader, hashValue); + SlotIterator iter(slotId, this); + auto entryPos = findEntry(iter, key, fingerprint, isVisible); + if (entryPos != SlotHeader::INVALID_ENTRY_POS) { + result = iter.slot->entries[entryPos].value; + return true; + } + return false; + } + + uint64_t size() const { return this->indexHeader.numEntries; } + bool empty() const { return size() == 0; } + + void clear(); + + struct SlotIterator { + explicit SlotIterator(slot_id_t newSlotId, const InMemHashIndex* builder) + : slotInfo{newSlotId, SlotType::PRIMARY}, slot(builder->getSlot(slotInfo)) {} + explicit SlotIterator(SlotInfo slotInfo, const InMemHashIndex* builder) + : slotInfo{slotInfo}, slot(builder->getSlot(slotInfo)) {} + SlotInfo slotInfo; + InMemSlotType* slot; + }; + + // Leaves the slot pointer pointing at the last slot to make it easier to add a new one + bool nextChainedSlot(SlotIterator& iter) const { + KU_ASSERT(iter.slotInfo.slotType == SlotType::PRIMARY || + iter.slotInfo.slotId != iter.slot->header.nextOvfSlotId); + if (iter.slot->header.nextOvfSlotId != SlotHeader::INVALID_OVERFLOW_SLOT_ID) { + iter.slotInfo.slotId = iter.slot->header.nextOvfSlotId; + iter.slotInfo.slotType = SlotType::OVF; + iter.slot = getSlot(iter.slotInfo); + return true; + } + return false; + } + + uint64_t numPrimarySlots() const { return pSlots->size(); } + uint64_t numOverflowSlots() const { return oSlots->size(); } + + const HashIndexHeader& getIndexHeader() const { return indexHeader; } + + // Deletes key, maintaining gapless structure by replacing it with the last entry in the + // slot + bool deleteKey(KeyType key) { + if (this->indexHeader.numEntries == 0) { + return false; + } + auto hashValue = HashIndexUtils::hash(key); + auto fingerprint = HashIndexUtils::getFingerprintForHash(hashValue); + auto slotId = HashIndexUtils::getPrimarySlotIdForHash(this->indexHeader, hashValue); + SlotIterator iter(slotId, this); + std::optional deletedPos; + do { + for (auto entryPos = 0u; entryPos < SLOT_CAPACITY; entryPos++) { + if (iter.slot->header.isEntryValid(entryPos) && + iter.slot->header.fingerprints[entryPos] == fingerprint && + equals(key, iter.slot->entries[entryPos].key)) { + deletedPos = entryPos; + break; + } + } + if (deletedPos.has_value()) { + break; + } + } while (nextChainedSlot(iter)); + + if (deletedPos.has_value()) { 
+ // Find the last valid entry and move it into the deleted position + auto newIter = iter; + while (nextChainedSlot(newIter)) {} + if (newIter.slotInfo != iter.slotInfo || + *deletedPos != newIter.slot->header.numEntries() - 1) { + KU_ASSERT(newIter.slot->header.numEntries() > 0); + auto lastEntryPos = newIter.slot->header.numEntries() - 1; + iter.slot->entries[*deletedPos] = newIter.slot->entries[lastEntryPos]; + iter.slot->header.setEntryValid(*deletedPos, + newIter.slot->header.fingerprints[lastEntryPos]); + newIter.slot->header.setEntryInvalid(lastEntryPos); + } else { + iter.slot->header.setEntryInvalid(*deletedPos); + } + + if (newIter.slot->header.numEntries() == 0) { + reclaimOverflowSlots(SlotIterator(slotId, this)); + } + + return true; + } + return false; + } + +private: + // Assumes that space has already been allocated for the entry + bool appendInternal(OwnedType&& key, common::offset_t value, common::hash_t hash, + visible_func isVisible) { + auto fingerprint = HashIndexUtils::getFingerprintForHash(hash); + auto slotID = HashIndexUtils::getPrimarySlotIdForHash(this->indexHeader, hash); + SlotIterator iter(slotID, this); + // The builder never keeps holes and doesn't support deletions + // Check the valid entries, then insert at the end if we don't find one which matches + auto entryPos = findEntry(iter, key, fingerprint, isVisible); + auto numEntries = iter.slot->header.numEntries(); + if (entryPos != SlotHeader::INVALID_ENTRY_POS) { + // The key already exists + return false; + } else if (numEntries < SLOT_CAPACITY) [[likely]] { + // The key does not exist and the last slot has free space + insert(std::move(key), iter.slot, numEntries, value, fingerprint); + this->indexHeader.numEntries++; + return true; + } + // The last slot is full. Insert a new one + insertToNewOvfSlot(std::move(key), iter.slot, value, fingerprint); + this->indexHeader.numEntries++; + return true; + } + InMemSlotType* getSlot(const SlotInfo& slotInfo) const; + + uint32_t allocatePSlots(uint32_t numSlotsToAllocate); + uint32_t allocateAOSlot(); + /* + * When a slot is split, we add a new slot, which ends up with an + * id equal to the slot to split's ID + (1 << header.currentLevel). + * Values are then rehashed using a hash index which is one bit wider than before, + * meaning they either stay in the existing slot, or move into the new one. 
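+     * For instance (illustrative values): splitting slot 1 at currentLevel == 2 creates slot
+     * 1 + (1 << 2) == 5; entries in slot 1 whose hash has bit 2 set are rehashed into slot 5,
+     * while the rest stay in slot 1.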
+ */ + void splitSlot(); + // Reclaims empty overflow slots to be re-used, starting from the given slot iterator + void reclaimOverflowSlots(SlotIterator iter); + void addFreeOverflowSlot(InMemSlotType& overflowSlot, SlotInfo slotInfo); + uint64_t countSlots(SlotIterator iter) const; + // Make sure that the free overflow slot chain is at least as long as the totalSlotsRequired + void reserveOverflowSlots(uint64_t totalSlotsRequired); + + bool equals(KeyType keyToLookup, const OwnedType& keyInEntry) const { + return keyToLookup == keyInEntry; + } + + void insert(OwnedType&& key, InMemSlotType* slot, uint8_t entryPos, common::offset_t value, + uint8_t fingerprint) { + KU_ASSERT(HashIndexUtils::getFingerprintForHash(HashIndexUtils::hash(key)) == fingerprint); + auto& entry = slot->entries[entryPos]; + entry = SlotEntry(std::move(key), value); + slot->header.setEntryValid(entryPos, fingerprint); + } + + void insertToNewOvfSlot(OwnedType&& key, InMemSlotType* previousSlot, common::offset_t offset, + uint8_t fingerprint) { + auto newSlotId = allocateAOSlot(); + previousSlot->header.nextOvfSlotId = newSlotId; + auto newSlot = getSlot(SlotInfo{newSlotId, SlotType::OVF}); + auto entryPos = 0u; // Always insert to the first entry when there is a new slot. + insert(std::move(key), newSlot, entryPos, offset, fingerprint); + } + + common::hash_t hashStored(const OwnedType& key) const; + InMemSlotType* clearNextOverflowAndAdvanceIter(SlotIterator& iter); + + // Finds the entry matching the given key. The iterator will be advanced and will either point + // to the slot containing the matching entry, or the last slot available + entry_pos_t findEntry(SlotIterator& iter, KeyType key, uint8_t fingerprint, + visible_func isVisible) { + do { + auto numEntries = iter.slot->header.numEntries(); + KU_ASSERT(numEntries == std::countr_one(iter.slot->header.validityMask)); + for (auto entryPos = 0u; entryPos < numEntries; entryPos++) { + if (iter.slot->header.fingerprints[entryPos] == fingerprint && + equals(key, iter.slot->entries[entryPos].key) && + isVisible(iter.slot->entries[entryPos].value)) [[unlikely]] { + // Value already exists + return entryPos; + } + } + if (numEntries < SLOT_CAPACITY) { + return SlotHeader::INVALID_ENTRY_POS; + } + } while (nextChainedSlot(iter)); + return SlotHeader::INVALID_ENTRY_POS; + } + +private: + // TODO: might be more efficient to use a vector for each slot since this is now only needed + // in-memory and it would remove the need to handle overflow slots. 
+ OverflowFileHandle* overflowFileHandle; + std::unique_ptr> pSlots; + std::unique_ptr> oSlots; + HashIndexHeader indexHeader; + MemoryManager& memoryManager; + uint64_t numFreeSlots; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/index/index.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/index/index.h new file mode 100644 index 0000000000..f1fb5c46f4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/index/index.h @@ -0,0 +1,239 @@ +#pragma once + +#include + +#include "common/serializer/buffer_writer.h" +#include "common/types/types.h" +#include "common/vector/value_vector.h" +#include "in_mem_hash_index.h" +#include + +namespace lbug::storage { +class StorageManager; +} +namespace lbug { +namespace transaction { +class Transaction; +} // namespace transaction + +namespace storage { + +enum class IndexConstraintType : uint8_t { + PRIMARY = 0, // Primary key index + SECONDARY_NON_UNIQUE = 1, // Secondary index that is not unique +}; + +enum class IndexDefinitionType : uint8_t { + BUILTIN = 0, + EXTENSION = 1, +}; + +class Index; +struct IndexInfo; +using index_load_func_t = std::function(main::ClientContext* context, + StorageManager* storageManager, IndexInfo, std::span)>; + +struct LBUG_API IndexType { + std::string typeName; + IndexConstraintType constraintType; + IndexDefinitionType definitionType; + index_load_func_t loadFunc; + + IndexType(std::string typeName, IndexConstraintType constraintType, + IndexDefinitionType definitionType, index_load_func_t loadFunc) + : typeName{std::move(typeName)}, constraintType{constraintType}, + definitionType{definitionType}, loadFunc{std::move(loadFunc)} {} +}; + +struct LBUG_API IndexInfo { + std::string name; + std::string indexType; + common::table_id_t tableID; + std::vector columnIDs; + std::vector keyDataTypes; + bool isPrimary; + bool isBuiltin; + + IndexInfo(std::string name, std::string indexType, common::table_id_t tableID, + std::vector columnIDs, + std::vector keyDataTypes, bool isPrimary, bool isBuiltin) + : name{std::move(name)}, indexType{std::move(indexType)}, tableID{tableID}, + columnIDs{std::move(columnIDs)}, keyDataTypes{std::move(keyDataTypes)}, + isPrimary{isPrimary}, isBuiltin{isBuiltin} {} + + void serialize(common::Serializer& ser) const; + static IndexInfo deserialize(common::Deserializer& deSer); +}; + +struct LBUG_API IndexStorageInfo { + IndexStorageInfo() {} + virtual ~IndexStorageInfo(); + DELETE_COPY_DEFAULT_MOVE(IndexStorageInfo); + + virtual std::shared_ptr serialize() const; + + template + TARGET& cast() { + return common::ku_dynamic_cast(*this); + } + + template + const TARGET& constCast() const { + return common::ku_dynamic_cast(*this); + } +}; + +class LBUG_API Index { +public: + struct InsertState { + virtual ~InsertState(); + template + TARGET& cast() { + return common::ku_dynamic_cast(*this); + } + }; + + struct DeleteState { + virtual ~DeleteState(); + template + TARGET& cast() { + return common::ku_dynamic_cast(*this); + } + }; + + struct UpdateState { + virtual ~UpdateState(); + template + TARGET& cast() { + return common::ku_dynamic_cast(*this); + } + }; + + Index(IndexInfo indexInfo, std::unique_ptr storageInfo) + : indexInfo{std::move(indexInfo)}, storageInfo{std::move(storageInfo)}, + storageInfoBuffer{nullptr}, storageInfoBufferSize{0}, loaded{true} {} + Index(IndexInfo indexInfo, std::unique_ptr storageBuffer, uint32_t storageBufferSize) + : indexInfo{std::move(indexInfo)}, storageInfo{nullptr}, + 
storageInfoBuffer{std::move(storageBuffer)}, storageInfoBufferSize{storageBufferSize}, + loaded{false} {} + virtual ~Index(); + + DELETE_COPY_AND_MOVE(Index); + + bool isPrimary() const { return indexInfo.isPrimary; } + bool isExtension() const { return indexInfo.isBuiltin; } + bool isLoaded() const { return loaded; } + std::string getName() const { return indexInfo.name; } + IndexInfo getIndexInfo() const { return indexInfo; } + + virtual std::unique_ptr initInsertState(main::ClientContext* context, + visible_func isVisible) = 0; + virtual void insert(transaction::Transaction*, const common::ValueVector&, + const std::vector&, InsertState&) { + // DO NOTHING. + } + virtual std::unique_ptr initUpdateState(main::ClientContext* /*context*/, + common::column_id_t /*columnID*/, visible_func /*isVisible*/) { + KU_UNREACHABLE; + } + virtual void update(transaction::Transaction* /*transaction*/, + const common::ValueVector& /*nodeIDVector*/, common::ValueVector& /*propertyVector*/, + UpdateState& /*updateState*/) { + KU_UNREACHABLE; + } + virtual std::unique_ptr initDeleteState( + const transaction::Transaction* transaction, MemoryManager* mm, visible_func isVisible) = 0; + virtual void delete_(transaction::Transaction* transaction, + const common::ValueVector& nodeIDVector, DeleteState& deleteState) = 0; + virtual bool needCommitInsert() const { return false; } + virtual void commitInsert(transaction::Transaction*, const common::ValueVector&, + const std::vector&, InsertState&) { + // DO NOTHING. + } + + virtual void checkpointInMemory() { + // DO NOTHING. + }; + virtual void checkpoint(main::ClientContext*, PageAllocator&) { + // DO NOTHING. + } + virtual void rollbackCheckpoint() { + // DO NOTHING. + } + virtual void finalize(main::ClientContext*) { + // DO NOTHING. + } + + std::span getStorageBuffer() const { + KU_ASSERT(!loaded); + return std::span(storageInfoBuffer.get(), storageInfoBufferSize); + } + const IndexStorageInfo& getStorageInfo() const { return *storageInfo; } + bool isBuiltOnColumn(common::column_id_t columnID) const; + + virtual void serialize(common::Serializer& ser) const; + + template + TARGET& cast() { + return common::ku_dynamic_cast(*this); + } + +protected: + IndexInfo indexInfo; + std::unique_ptr storageInfo; + std::unique_ptr storageInfoBuffer; + uint64_t storageInfoBufferSize; + bool loaded; +}; + +class IndexHolder { +public: + explicit IndexHolder(std::unique_ptr loadedIndex); + IndexHolder(IndexInfo indexInfo, std::unique_ptr storageInfoBuffer, + uint32_t storageInfoBufferSize); + + std::string getName() const { return indexInfo.name; } + bool isLoaded() const { return loaded; } + + void serialize(common::Serializer& ser) const; + LBUG_API void load(main::ClientContext* context, StorageManager* storageManager); + bool needCommitInsert() const { return index->needCommitInsert(); } + // NOLINTNEXTLINE(readability-make-member-function-const): Semantically non-const. + void checkpoint(main::ClientContext* context, PageAllocator& pageAllocator) { + if (loaded) { + KU_ASSERT(index); + index->checkpoint(context, pageAllocator); + } + } + // NOLINTNEXTLINE(readability-make-member-function-const): Semantically non-const. + void rollbackCheckpoint() { + if (loaded) { + KU_ASSERT(index); + index->rollbackCheckpoint(); + } + } + // NOLINTNEXTLINE(readability-make-member-function-const): Semantically non-const. 
+ void finalize(main::ClientContext* context) { + if (loaded) { + KU_ASSERT(index); + index->finalize(context); + } + } + + Index* getIndex() const { + KU_ASSERT(index); + return index.get(); + } + +private: + IndexInfo indexInfo; + std::unique_ptr storageInfoBuffer; + uint64_t storageInfoBufferSize; + bool loaded; + + // Loaded index structure. + std::unique_ptr index; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/local_cached_column.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/local_cached_column.h new file mode 100644 index 0000000000..fa273717dc --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/local_cached_column.h @@ -0,0 +1,22 @@ +#pragma once + +#include "storage/table/column_chunk_data.h" +#include "transaction/transaction.h" + +namespace lbug { +namespace storage { + +class LBUG_API CachedColumn : public transaction::LocalCacheObject { +public: + static std::string getKey(common::table_id_t tableID, common::property_id_t propertyID) { + return common::stringFormat("{}-{}", tableID, propertyID); + } + explicit CachedColumn(common::table_id_t tableID, common::property_id_t propertyID) + : LocalCacheObject{getKey(tableID, propertyID)}, columnChunks{} {} + DELETE_BOTH_COPY(CachedColumn); + + std::vector> columnChunks; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/local_storage/local_hash_index.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/local_storage/local_hash_index.h new file mode 100644 index 0000000000..cc216accf1 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/local_storage/local_hash_index.h @@ -0,0 +1,234 @@ +#pragma once + +#include "common/string_utils.h" +#include "common/type_utils.h" +#include "storage/index/in_mem_hash_index.h" + +namespace lbug { +namespace storage { + +// NOTE: This is a dummy base class to hide the templation from LocalHashIndex, while allows casting +// to templated HashIndexLocalStorage. Same for OnDiskHashIndex. +class BaseHashIndexLocalStorage { +public: + virtual ~BaseHashIndexLocalStorage() = default; +}; + +enum class HashIndexLocalLookupState : uint8_t { KEY_FOUND, KEY_DELETED, KEY_NOT_EXIST }; + +// Local storage consists of two in memory indexes. One (localInsertionIndex) is to keep track of +// all newly inserted entries, and the other (localDeletionIndex) is to keep track of newly deleted +// entries (not available in localInsertionIndex). We assume that in a transaction, the insertions +// and deletions are very small, thus they can be kept in memory. 
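+// As an illustration (hypothetical key k): insert(k) appends k to localInsertions and clears any
+// earlier local deletion of k; deleteKey(k) removes k from localInsertions if it was inserted
+// locally, and otherwise records k in localDeletions; lookup(k) then reports KEY_DELETED,
+// KEY_FOUND or KEY_NOT_EXIST, checking deletions before insertions.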
+template +class HashIndexLocalStorage final : public BaseHashIndexLocalStorage { +public: + using OwnedType = InMemHashIndex::OwnedType; + using KeyType = InMemHashIndex::KeyType; + + explicit HashIndexLocalStorage(MemoryManager& memoryManager, OverflowFileHandle* handle) + : localDeletions{}, localInsertions{memoryManager, handle} {} + HashIndexLocalLookupState lookup(KeyType key, common::offset_t& result, + visible_func isVisible) { + std::shared_lock sLock{mtx}; + if (localDeletions.contains(key)) { + return HashIndexLocalLookupState::KEY_DELETED; + } + if (localInsertions.lookup(key, result, isVisible)) { + return HashIndexLocalLookupState::KEY_FOUND; + } + return HashIndexLocalLookupState::KEY_NOT_EXIST; + } + + void deleteKey(KeyType key) { + std::unique_lock xLock{mtx}; + if (!localInsertions.deleteKey(key)) { + localDeletions.insert(static_cast(key)); + } + } + + bool discard(KeyType key) { + std::unique_lock xLock{mtx}; + return localInsertions.deleteKey(key); + } + + bool insert(OwnedType&& key, common::offset_t value, visible_func isVisible) { + std::unique_lock xLock{mtx}; + auto iter = localDeletions.find(key); + if (iter != localDeletions.end()) { + localDeletions.erase(iter); + } + return localInsertions.append(std::move(key), value, isVisible); + } + + void reserveSpaceForAppend(uint32_t numNewEntries) { + std::unique_lock xLock{mtx}; + reserveSpaceForAppendNoLock(numNewEntries); + } + + void reserveSpaceForAppendNoLock(uint32_t numNewEntries) { + localInsertions.reserveSpaceForAppend(numNewEntries); + } + + bool append(OwnedType&& key, common::offset_t value, visible_func isVisible) { + std::unique_lock xLock{mtx}; + return appendNoLock(std::move(key), value, isVisible); + } + + bool appendNoLock(OwnedType&& key, common::offset_t value, visible_func isVisible) { + return localInsertions.append(std::move(key), value, isVisible); + } + + size_t append(IndexBuffer& buffer, uint64_t bufferOffset, visible_func isVisible) { + std::unique_lock xLock{mtx}; + return appendNoLock(buffer, bufferOffset, isVisible); + } + + size_t appendNoLock(IndexBuffer& buffer, uint64_t bufferOffset, + visible_func isVisible) { + return localInsertions.append(buffer, bufferOffset, isVisible); + } + + bool hasUpdates() { + std::shared_lock sLock{mtx}; + return !(localInsertions.empty() && localDeletions.empty()); + } + + int64_t getNetInserts() { + std::shared_lock sLock{mtx}; + return static_cast(localInsertions.size()) - localDeletions.size(); + } + + void clear() { + std::unique_lock xLock{mtx}; + localInsertions.clear(); + localDeletions.clear(); + } + + void applyLocalChanges(const std::function& deleteOp, + const std::function&)>& insertOp) { + std::shared_lock sLock{mtx}; + for (auto& key : localDeletions) { + deleteOp(key); + } + insertOp(localInsertions); + } + + void reserveInserts(uint64_t newEntries) { + std::unique_lock xLock{mtx}; + localInsertions.reserve(newEntries); + } + + bool tryLock() { return mtx.try_lock(); } + std::unique_lock adoptLock() { + return std::unique_lock{mtx, std::adopt_lock}; + } + +private: + // When the storage type is string, allow the key type to be string_view with a custom hash + // function + using hash_function = std::conditional_t, + common::StringUtils::string_hash, std::hash>; + std::shared_mutex mtx; + std::unordered_set> localDeletions; + InMemHashIndex localInsertions; +}; + +class LocalHashIndex { +public: + explicit LocalHashIndex(MemoryManager& memoryManager, common::PhysicalTypeID keyDataTypeID, + OverflowFileHandle* overflowFileHandle) + : 
keyDataTypeID{keyDataTypeID} { + common::TypeUtils::visit( + keyDataTypeID, + [&](common::ku_string_t) { + localIndex = std::make_unique>( + memoryManager, overflowFileHandle); + }, + [&](T) { + localIndex = std::make_unique>(memoryManager, nullptr); + }, + [&](auto) { KU_UNREACHABLE; }); + } + + common::offset_t lookup(const common::ValueVector& keyVector, common::sel_t pos, + visible_func isVisible) { + common::offset_t result = common::INVALID_OFFSET; + common::TypeUtils::visit( + keyDataTypeID, + [&]( + T) { result = lookup(keyVector.getValue(pos), isVisible); }, + [](auto) { KU_UNREACHABLE; }); + return result; + } + + common::offset_t lookup(const common::ValueVector& keyVector, visible_func isVisible) { + KU_ASSERT(keyVector.state->getSelVector().getSelSize() == 1); + auto pos = keyVector.state->getSelVector().getSelectedPositions()[0]; + return lookup(keyVector, pos, isVisible); + } + + common::offset_t lookup(const common::ku_string_t key, visible_func isVisible) { + return lookup(key.getAsStringView(), isVisible); + } + template + common::offset_t lookup(T key, visible_func isVisible) { + common::offset_t result = common::INVALID_OFFSET; + common::ku_dynamic_cast>*>(localIndex.get()) + ->lookup(key, result, isVisible); + return result; + } + + bool insert(const common::ValueVector& keyVector, common::offset_t startNodeOffset, + visible_func isVisible) { + common::length_t numInserted = 0; + common::TypeUtils::visit( + keyDataTypeID, + [&](T) { + for (auto i = 0u; i < keyVector.state->getSelVector().getSelSize(); i++) { + const auto pos = keyVector.state->getSelVector().getSelectedPositions()[i]; + numInserted += + insert(keyVector.getValue(pos), startNodeOffset + i, isVisible); + } + }, + [](auto) { KU_UNREACHABLE; }); + return numInserted == keyVector.state->getSelVector().getSelSize(); + } + + bool insert(const common::ku_string_t key, common::offset_t value, visible_func isVisible) { + return insert(key.getAsString(), value, isVisible); + } + template + bool insert(T key, common::offset_t value, visible_func isVisible) { + KU_ASSERT(keyDataTypeID == common::TypeUtils::getPhysicalTypeIDForType()); + return common::ku_dynamic_cast>*>(localIndex.get()) + ->insert(std::move(key), value, isVisible); + } + + void delete_(const common::ValueVector& keyVector) { + common::TypeUtils::visit( + keyDataTypeID, + [&](T) { + for (auto i = 0u; i < keyVector.state->getSelVector().getSelSize(); i++) { + const auto pos = keyVector.state->getSelVector().getSelectedPositions()[i]; + delete_(keyVector.getValue(pos)); + } + }, + [](auto) { KU_UNREACHABLE; }); + } + + void delete_(const common::ku_string_t key) { delete_(key.getAsStringView()); } + template + void delete_(T key) { + KU_ASSERT(keyDataTypeID == common::TypeUtils::getPhysicalTypeIDForType()); + common::ku_dynamic_cast>*>(localIndex.get()) + ->deleteKey(key); + } + +private: + common::PhysicalTypeID keyDataTypeID; + std::unique_ptr localIndex; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/local_storage/local_node_table.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/local_storage/local_node_table.h new file mode 100644 index 0000000000..5767cd6fb0 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/local_storage/local_node_table.h @@ -0,0 +1,63 @@ +#pragma once + +#include "common/copy_constructors.h" +#include "storage/local_storage/local_hash_index.h" +#include "storage/local_storage/local_table.h" +#include 
"storage/table/node_group_collection.h" + +namespace lbug { +namespace storage { + +struct TableScanState; +class MemoryManager; + +class LocalNodeTable final : public LocalTable { +public: + LocalNodeTable(const catalog::TableCatalogEntry* tableEntry, Table& table, MemoryManager& mm); + DELETE_COPY_AND_MOVE(LocalNodeTable); + + bool insert(transaction::Transaction* transaction, TableInsertState& insertState) override; + bool update(transaction::Transaction* transaction, TableUpdateState& updateState) override; + bool delete_(transaction::Transaction* transaction, TableDeleteState& deleteState) override; + bool addColumn(TableAddColumnState& addColumnState) override; + + common::offset_t validateUniquenessConstraint(const transaction::Transaction* transaction, + const common::ValueVector& pkVector) const; + + common::TableType getTableType() const override { return common::TableType::NODE; } + + void clear(MemoryManager& mm) override; + + common::row_idx_t getNumTotalRows() override { return nodeGroups.getNumTotalRows(); } + common::node_group_idx_t getNumNodeGroups() const { return nodeGroups.getNumNodeGroups(); } + + NodeGroup* getNodeGroup(common::node_group_idx_t nodeGroupIdx) const { + return nodeGroups.getNodeGroup(nodeGroupIdx); + } + NodeGroupCollection& getNodeGroups() { return nodeGroups; } + + bool lookupPK(const transaction::Transaction* transaction, const common::ValueVector* keyVector, + common::sel_t pos, common::offset_t& result) const; + + TableStats getStats() const { return nodeGroups.getStats(); } + common::offset_t getStartOffset() const { return startOffset; } + + static std::vector getNodeTableColumnTypes( + const catalog::TableCatalogEntry& table); + +private: + void initLocalHashIndex(MemoryManager& mm); + bool isVisible(const transaction::Transaction* transaction, common::offset_t offset) const; + +private: + // This is equivalent to the num of committed nodes in the table. 
+ common::offset_t startOffset; + PageCursor overflowCursor; + std::unique_ptr overflowFile; + OverflowFileHandle* overflowFileHandle; + std::unique_ptr hashIndex; + NodeGroupCollection nodeGroups; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/local_storage/local_rel_table.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/local_storage/local_rel_table.h new file mode 100644 index 0000000000..e4bd066c3e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/local_storage/local_rel_table.h @@ -0,0 +1,97 @@ +#pragma once + +#include + +#include "common/enums/rel_direction.h" +#include "storage/local_storage/local_table.h" +#include "storage/table/csr_node_group.h" + +namespace lbug { +namespace storage { +class MemoryManager; + +static constexpr common::column_id_t LOCAL_BOUND_NODE_ID_COLUMN_ID = 0; +static constexpr common::column_id_t LOCAL_NBR_NODE_ID_COLUMN_ID = 1; +static constexpr common::column_id_t LOCAL_REL_ID_COLUMN_ID = 2; + +class RelTable; +struct TableScanState; +struct RelTableUpdateState; + +struct DirectedCSRIndex { + using index_t = std::map; + + explicit DirectedCSRIndex(common::RelDataDirection direction) : direction(direction) {} + + bool isEmpty() const { return index.empty(); } + void clear() { index.clear(); } + + common::RelDataDirection direction; + index_t index; +}; + +class LocalRelTable final : public LocalTable { +public: + LocalRelTable(const catalog::TableCatalogEntry* tableEntry, const Table& table, + MemoryManager& mm); + DELETE_COPY_AND_MOVE(LocalRelTable); + + bool insert(transaction::Transaction* transaction, TableInsertState& state) override; + bool update(transaction::Transaction* transaction, TableUpdateState& state) override; + bool delete_(transaction::Transaction* transaction, TableDeleteState& state) override; + bool addColumn(TableAddColumnState& addColumnState) override; + + bool checkIfNodeHasRels(common::ValueVector* srcNodeIDVector, + common::RelDataDirection direction) const; + + common::TableType getTableType() const override { return common::TableType::REL; } + + static void initializeScan(TableScanState& state); + bool scan(const transaction::Transaction* transaction, TableScanState& state) const; + + void clear(MemoryManager&) override { + localNodeGroup.reset(); + for (auto& index : directedIndices) { + index.clear(); + } + } + bool isEmpty() const { + KU_ASSERT(directedIndices.size() >= 1); + RUNTIME_CHECK(for (const auto& index + : directedIndices) { + KU_ASSERT(index.index.empty() == directedIndices[0].index.empty()); + }); + return directedIndices[0].isEmpty(); + } + + common::column_id_t getNumColumns() const { return localNodeGroup->getDataTypes().size(); } + common::row_idx_t getNumTotalRows() override { return localNodeGroup->getNumRows(); } + + DirectedCSRIndex::index_t& getCSRIndex(common::RelDataDirection direction) { + const auto directionIdx = common::RelDirectionUtils::relDirectionToKeyIdx(direction); + KU_ASSERT(directionIdx < directedIndices.size()); + return directedIndices[directionIdx].index; + } + NodeGroup& getLocalNodeGroup() const { return *localNodeGroup; } + + static std::vector rewriteLocalColumnIDs( + common::RelDataDirection direction, const std::vector& columnIDs); + static common::column_id_t rewriteLocalColumnID(common::RelDataDirection direction, + common::column_id_t columnID); + +private: + common::row_idx_t findMatchingRow(const transaction::Transaction* transaction, + const std::vector& rowIndicesToCheck, 
common::offset_t relOffset) const; + +private: + // We don't duplicate local rel tuples. Tuples are stored same as node tuples. + // Chunks stored in local rel table are organized as follows: + // [srcNodeID, dstNodeID, relID, property1, property2, ...] + // All local rel tuples are stored in a single node group, and they are indexed by src/dst + // NodeID. + std::vector directedIndices; + std::unique_ptr localNodeGroup; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/local_storage/local_storage.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/local_storage/local_storage.h new file mode 100644 index 0000000000..fcd6f2b429 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/local_storage/local_storage.h @@ -0,0 +1,43 @@ +#pragma once + +#include + +#include "common/copy_constructors.h" +#include "storage/local_storage/local_table.h" +#include "storage/optimistic_allocator.h" + +namespace lbug { +namespace main { +class ClientContext; +} // namespace main +namespace storage { +// Data structures in LocalStorage are not thread-safe. +// For now, we only support single thread insertions and updates. Once we optimize them with +// multiple threads, LocalStorage and its related data structures should be reworked to be +// thread-safe. +class LocalStorage { +public: + explicit LocalStorage(main::ClientContext& clientContext) : clientContext{clientContext} {} + DELETE_COPY_AND_MOVE(LocalStorage); + + // Do nothing if the table already exists, otherwise create a new local table. + LocalTable* getOrCreateLocalTable(Table& table); + // Return nullptr if no local table exists. + LocalTable* getLocalTable(common::table_id_t tableID) const; + + PageAllocator* addOptimisticAllocator(); + + void commit(); + void rollback(); + +private: + main::ClientContext& clientContext; + std::unordered_map> tables; + + // The mutex is only needed when working with the optimistic allocators + std::mutex mtx; + std::vector> optimisticAllocators; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/local_storage/local_table.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/local_storage/local_table.h new file mode 100644 index 0000000000..28b3252293 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/local_storage/local_table.h @@ -0,0 +1,57 @@ +#pragma once + +#include "common/enums/table_type.h" +#include "storage/table/table.h" + +namespace lbug { +namespace transaction { +class Transaction; +} // namespace transaction + +namespace storage { +class MemoryManager; + +struct TableAddColumnState; +struct TableInsertState; +struct TableUpdateState; +struct TableDeleteState; +class LocalTable { +public: + virtual ~LocalTable() = default; + + virtual bool insert(transaction::Transaction* transaction, TableInsertState& insertState) = 0; + virtual bool update(transaction::Transaction* transaction, TableUpdateState& updateState) = 0; + virtual bool delete_(transaction::Transaction* transaction, TableDeleteState& deleteState) = 0; + virtual bool addColumn(TableAddColumnState& addColumnState) = 0; + virtual void clear(MemoryManager& mm) = 0; + virtual common::TableType getTableType() const = 0; + virtual common::row_idx_t getNumTotalRows() = 0; + + template + const TARGET& constCast() { + return common::ku_dynamic_cast(*this); + } + template + TARGET& cast() { + return common::ku_dynamic_cast(*this); + } + template + TARGET* ptrCast() { + return 
common::ku_dynamic_cast(this); + } + template + const TARGET* ptrCast() const { + return common::ku_dynamic_cast(this); + } + +protected: + // TODO(Guodong): Revisit this interface. We don't need to pass in Table here, instead should + // pass in a struct that describes Table, e.g., TableInfo. + explicit LocalTable(const Table& table) : table{table} {} + +protected: + const Table& table; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/optimistic_allocator.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/optimistic_allocator.h new file mode 100644 index 0000000000..6b56097a83 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/optimistic_allocator.h @@ -0,0 +1,31 @@ +#pragma once + +#include "storage/page_allocator.h" + +namespace lbug { +namespace storage { + +class PageManager; + +/** + * Manages any optimistically allocated pages (e.g. during COPY) so that they can be freed if a + * rollback occurs. + * This class is designed to be thread-local so accesses are not guaranteed to be thread-safe. + */ +class OptimisticAllocator : public PageAllocator { +public: + explicit OptimisticAllocator(PageManager& pageManager); + + PageRange allocatePageRange(common::page_idx_t numPages) override; + + void freePageRange(PageRange block) override; + + void rollback(); + void commit(); + +private: + PageManager& pageManager; + std::vector optimisticallyAllocatedPages; +}; +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/overflow_file.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/overflow_file.h new file mode 100644 index 0000000000..5b1841eaa3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/overflow_file.h @@ -0,0 +1,156 @@ +#pragma once + +#include +#include +#include +#include + +#include "common/types/types.h" +#include "storage//file_handle.h" +#include "storage/index/hash_index_utils.h" +#include "storage/storage_utils.h" + +namespace lbug { +namespace storage { + +class OverflowFile; + +// Stores the current state of the overflow file +// The cursors in use are stored here so that we can write new pages directly +// to the overflow file, and in the case of an interruption and rollback the header will +// still record the correct place in the file to allocate new pages +// +// The first page managed by each handle is also stored for cases where we wish to iterate through +// all the managed pages (e.g. when reclaiming pages) +struct StringOverflowFileHeader { + struct Entry { + common::page_idx_t startPageIdx{common::INVALID_PAGE_IDX}; + PageCursor cursor; + } entries[NUM_HASH_INDEXES]; + + // pages starts at one to reserve space for this header + StringOverflowFileHeader() : entries{} {} +}; +static_assert(std::has_unique_object_representations_v); + +class OverflowFileHandle { + +public: + OverflowFileHandle(OverflowFile& overflowFile, StringOverflowFileHeader::Entry& entry) + : startPageIdx(entry.startPageIdx), nextPosToWriteTo(entry.cursor), + overflowFile{overflowFile} {} + // The OverflowFile stores the handles and returns pointers to them. 
+ // Moving the handle would invalidate those pointers + OverflowFileHandle(OverflowFileHandle&& other) = delete; + + std::string readString(transaction::TransactionType trxType, + const common::ku_string_t& str) const; + + bool equals(transaction::TransactionType trxType, std::string_view keyToLookup, + const common::ku_string_t& keyInEntry) const; + + common::ku_string_t writeString(PageAllocator* pageAllocator, std::string_view rawString); + common::ku_string_t writeString(PageAllocator* pageAllocator, const char* rawString) { + return writeString(pageAllocator, std::string_view(rawString)); + } + + void checkpoint(); + void checkpointInMemory() { pageWriteCache.clear(); } + void rollbackInMemory(PageCursor nextPosToWriteTo_) { + pageWriteCache.clear(); + this->nextPosToWriteTo = nextPosToWriteTo_; + } + void reclaimStorage(PageAllocator& pageAllocator); + +private: + uint8_t* addANewPage(PageAllocator* pageAllocator); + void setStringOverflow(PageAllocator* pageAllocator, const char* inMemSrcStr, uint64_t len, + common::ku_string_t& diskDstString); + + void read(transaction::TransactionType trxType, common::page_idx_t pageIdx, + const std::function& func) const; + +private: + static constexpr common::page_idx_t END_OF_PAGE = + common::LBUG_PAGE_SIZE - sizeof(common::page_idx_t); + // Index of the first page managed by this handle + common::page_idx_t& startPageIdx; + // This is the index of the last free byte to which we can write. + PageCursor& nextPosToWriteTo; + OverflowFile& overflowFile; + + struct CachedPage { + std::unique_ptr buffer; + bool newPage = false; + }; + + // Cached pages which have been written in the current transaction + std::unordered_map pageWriteCache; +}; + +class ShadowFile; +class OverflowFile { + friend class OverflowFileHandle; + +public: + // For reading an existing overflow file + OverflowFile(FileHandle* fileHandle, MemoryManager& memoryManager, ShadowFile* shadowFile, + common::page_idx_t headerPageIdx); + + virtual ~OverflowFile() = default; + + // Handles contain a reference to the overflow file + OverflowFile(OverflowFile&& other) = delete; + + void rollbackInMemory(); + void checkpoint(PageAllocator& pageAllocator); + void checkpointInMemory(); + + void reclaimStorage(PageAllocator& pageAllocator) const; + + common::page_idx_t getHeaderPageIdx() const { return headerPageIdx; } + + OverflowFileHandle* addHandle() { + KU_ASSERT(handles.size() < NUM_HASH_INDEXES); + handles.emplace_back( + std::make_unique(*this, header.entries[handles.size()])); + return handles.back().get(); + } + + FileHandle* getFileHandle() const { + KU_ASSERT(fileHandle); + return fileHandle; + } + +protected: + explicit OverflowFile(MemoryManager& memoryManager); + + common::page_idx_t getNewPageIdx(PageAllocator* pageAllocator); + +private: + void readFromDisk(transaction::TransactionType trxType, common::page_idx_t pageIdx, + const std::function& func) const; + + // Writes new pages directly to the file and existing pages to the WAL + void writePageToDisk(common::page_idx_t pageIdx, uint8_t* data, bool newPage) const; + +protected: + static constexpr uint64_t HEADER_PAGE_IDX = 0; + + std::vector> handles; + StringOverflowFileHeader header; + FileHandle* fileHandle; + ShadowFile* shadowFile; + MemoryManager& memoryManager; + std::atomic pageCounter; + std::atomic headerChanged; + common::page_idx_t headerPageIdx; +}; + +class InMemOverflowFile final : public OverflowFile { +public: + explicit InMemOverflowFile(MemoryManager& memoryManager) : OverflowFile{memoryManager} {} +}; + 
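+// Illustrative usage sketch (not part of the original lbug header; added for clarity). It relies
+// only on the declarations above; `pageAllocator` (a PageAllocator*) and `trxType` (a
+// transaction::TransactionType) are assumed to come from the surrounding storage/transaction code.
+//
+//   OverflowFileHandle* handle = overflowFile.addHandle();
+//   // Strings that do not fit inline are written to overflow pages; the returned ku_string_t
+//   // keeps the short prefix plus a reference into the overflow file.
+//   common::ku_string_t entry = handle->writeString(pageAllocator, "a string too long to inline");
+//   std::string roundTrip = handle->readString(trxType, entry);
+//   handle->checkpoint(); // flush the pages cached during the current transaction
+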
+} // namespace storage
+} // namespace lbug
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/page_allocator.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/page_allocator.h
new file mode 100644
index 0000000000..f4f748f6cd
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/page_allocator.h
@@ -0,0 +1,28 @@
+#pragma once
+
+#include "storage/page_range.h"
+
+namespace lbug {
+namespace storage {
+
+class FileHandle;
+
+class PageAllocator {
+public:
+    explicit PageAllocator(FileHandle* fileHandle) : dataFH(fileHandle) {}
+    virtual ~PageAllocator() = default;
+
+    virtual PageRange allocatePageRange(common::page_idx_t numPages) = 0;
+    common::page_idx_t allocatePage() { return allocatePageRange(1).startPageIdx; }
+
+    // Only used during checkpoint
+    virtual void freePageRange(PageRange block) = 0;
+    void freePage(common::page_idx_t pageIdx) { freePageRange(PageRange(pageIdx, 1)); }
+
+    FileHandle* getDataFH() const { return dataFH; }
+
+private:
+    FileHandle* dataFH;
+};
+} // namespace storage
+} // namespace lbug
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/page_manager.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/page_manager.h
new file mode 100644
index 0000000000..76edd3c109
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/page_manager.h
@@ -0,0 +1,60 @@
+#pragma once
+
+#include <mutex>
+
+#include "common/types/types.h"
+#include "storage/free_space_manager.h"
+#include "storage/page_allocator.h"
+#include "storage/page_range.h"
+
+namespace lbug {
+namespace transaction {
+enum class TransactionType : uint8_t;
+}
+namespace storage {
+struct PageCursor;
+struct DBFileID;
+class PageManager;
+class FileHandle;
+
+class PageManager : public PageAllocator {
+public:
+    explicit PageManager(FileHandle* fileHandle)
+        : PageAllocator(fileHandle), freeSpaceManager(std::make_unique<FreeSpaceManager>()),
+          fileHandle(fileHandle), version(0) {}
+
+    uint64_t getVersion() const { return version; }
+    bool changedSinceLastCheckpoint() const { return version != 0; }
+    void resetVersion() { version = 0; }
+
+    PageRange allocatePageRange(common::page_idx_t numPages) override;
+    void freePageRange(PageRange block) override;
+    void freeImmediatelyRewritablePageRange(FileHandle* fileHandle, PageRange block);
+
+    // The page manager must first allocate space for itself so that its serialized version also
+    // tracks the pages allocated for itself.
+    // Thus this function also allocates and returns the space for the serialized storage manager.
+    common::page_idx_t estimatePagesNeededForSerialize();
+    void serialize(common::Serializer& serializer);
+    void deserialize(common::Deserializer& deSer);
+    void finalizeCheckpoint();
+    void rollbackCheckpoint() { freeSpaceManager->rollbackCheckpoint(); }
+
+    common::row_idx_t getNumFreeEntries() const { return freeSpaceManager->getNumEntries(); }
+    std::vector<PageRange> getFreeEntries(common::row_idx_t startOffset,
+        common::row_idx_t endOffset) const {
+        return freeSpaceManager->getEntries(startOffset, endOffset);
+    }
+
+    void clearEvictedBMEntriesIfNeeded(BufferManager* bufferManager);
+
+    static PageManager* Get(const main::ClientContext& context);
+
+private:
+    std::unique_ptr<FreeSpaceManager> freeSpaceManager;
+    std::mutex mtx;
+    FileHandle* fileHandle;
+    uint64_t version;
+};
+} // namespace storage
+} // namespace lbug
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/page_range.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/page_range.h
new file mode 100644
index 0000000000..45e54961c9
---
/dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/page_range.h @@ -0,0 +1,20 @@ +#pragma once + +#include "common/types/types.h" + +namespace lbug::storage { + +struct PageRange { + PageRange() : startPageIdx(common::INVALID_PAGE_IDX), numPages(0){}; + PageRange(common::page_idx_t startPageIdx, common::page_idx_t numPages) + : startPageIdx(startPageIdx), numPages(numPages) {} + + PageRange subrange(common::page_idx_t newStartPage) const { + KU_ASSERT(newStartPage <= numPages); + return {startPageIdx + newStartPage, numPages - newStartPage}; + } + + common::page_idx_t startPageIdx; + common::page_idx_t numPages; +}; +} // namespace lbug::storage diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/predicate/column_predicate.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/predicate/column_predicate.h new file mode 100644 index 0000000000..5692e76428 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/predicate/column_predicate.h @@ -0,0 +1,65 @@ +#pragma once + +#include "binder/expression/expression.h" +#include "common/cast.h" +#include "common/enums/zone_map_check_result.h" + +namespace lbug { +namespace storage { + +struct MergedColumnChunkStats; + +class ColumnPredicate; +class LBUG_API ColumnPredicateSet { +public: + ColumnPredicateSet() = default; + EXPLICIT_COPY_DEFAULT_MOVE(ColumnPredicateSet); + + void addPredicate(std::unique_ptr predicate) { + predicates.push_back(std::move(predicate)); + } + void tryAddPredicate(const binder::Expression& column, const binder::Expression& predicate); + bool isEmpty() const { return predicates.empty(); } + + common::ZoneMapCheckResult checkZoneMap(const MergedColumnChunkStats& stats) const; + + std::string toString() const; + +private: + ColumnPredicateSet(const ColumnPredicateSet& other) + : predicates{copyVector(other.predicates)} {} + +private: + std::vector> predicates; +}; + +class LBUG_API ColumnPredicate { +public: + ColumnPredicate(std::string columnName, common::ExpressionType expressionType) + : columnName{std::move(columnName)}, expressionType(expressionType) {} + + virtual ~ColumnPredicate() = default; + + virtual common::ZoneMapCheckResult checkZoneMap(const MergedColumnChunkStats& stats) const = 0; + + virtual std::string toString(); + + virtual std::unique_ptr copy() const = 0; + + template + const TARGET& constCast() const { + return common::ku_dynamic_cast(*this); + } + +protected: + std::string columnName; + common::ExpressionType expressionType; +}; + +struct LBUG_API ColumnPredicateUtil { + static std::unique_ptr tryConvert(const binder::Expression& column, + const binder::Expression& predicate); +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/predicate/constant_predicate.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/predicate/constant_predicate.h new file mode 100644 index 0000000000..1fca555239 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/predicate/constant_predicate.h @@ -0,0 +1,29 @@ +#pragma once + +#include "column_predicate.h" +#include "common/enums/expression_type.h" +#include "common/types/value/value.h" + +namespace lbug { +namespace storage { + +class ColumnConstantPredicate : public ColumnPredicate { +public: + ColumnConstantPredicate(std::string columnName, common::ExpressionType expressionType, + common::Value value) + : ColumnPredicate{std::move(columnName), expressionType}, value{std::move(value)} {} + + common::ZoneMapCheckResult checkZoneMap(const 
MergedColumnChunkStats& stats) const override; + + std::string toString() override; + + std::unique_ptr copy() const override { + return std::make_unique(columnName, expressionType, value); + } + +private: + common::Value value; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/predicate/null_predicate.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/predicate/null_predicate.h new file mode 100644 index 0000000000..bb1683011f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/predicate/null_predicate.h @@ -0,0 +1,25 @@ +#pragma once + +#include "column_predicate.h" +#include "common/enums/expression_type.h" + +namespace lbug { +namespace storage { + +class ColumnNullPredicate : public ColumnPredicate { +public: + explicit ColumnNullPredicate(std::string columnName, common::ExpressionType type) + : ColumnPredicate{std::move(columnName), type} { + KU_ASSERT( + type == common::ExpressionType::IS_NULL || type == common::ExpressionType::IS_NOT_NULL); + } + + common::ZoneMapCheckResult checkZoneMap(const MergedColumnChunkStats& stats) const override; + + std::unique_ptr copy() const override { + return std::make_unique(columnName, expressionType); + } +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/shadow_file.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/shadow_file.h new file mode 100644 index 0000000000..ec6b692890 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/shadow_file.h @@ -0,0 +1,72 @@ +#pragma once + +#include "common/types/uuid.h" +#include "storage/file_handle.h" + +namespace lbug { +namespace storage { + +struct ShadowPageRecord { + common::file_idx_t originalFileIdx = common::INVALID_PAGE_IDX; + common::page_idx_t originalPageIdx = common::INVALID_PAGE_IDX; + + void serialize(common::Serializer& serializer) const; + static ShadowPageRecord deserialize(common::Deserializer& deserializer); +}; + +struct ShadowFileHeader { + common::ku_uuid_t databaseID{0}; + common::page_idx_t numShadowPages = 0; +}; +static_assert(std::is_trivially_copyable_v); + +class BufferManager; +// NOTE: This class is NOT thread-safe for now, as we are not checkpointing in parallel yet. +class ShadowFile { +public: + ShadowFile(BufferManager& bm, common::VirtualFileSystem* vfs, const std::string& databasePath); + + // TODO(Guodong): Remove originalFile param. + bool hasShadowPage(common::file_idx_t originalFile, common::page_idx_t originalPage) const { + return shadowPagesMap.contains(originalFile) && + shadowPagesMap.at(originalFile).contains(originalPage); + } + void clearShadowPage(common::file_idx_t originalFile, common::page_idx_t originalPage); + common::page_idx_t getShadowPage(common::file_idx_t originalFile, + common::page_idx_t originalPage) const; + common::page_idx_t getOrCreateShadowPage(common::file_idx_t originalFile, + common::page_idx_t originalPage); + + FileHandle& getShadowingFH() const { return *shadowingFH; } + + void applyShadowPages(main::ClientContext& context) const; + + void flushAll(main::ClientContext& context) const; + // Clear any buffer in the WAL writer. Also truncate the WAL file to 0 bytes. + void clear(BufferManager& bm); + // Reset the WAL writer to nullptr, and remove the WAL file if it exists. + void reset(); + + // Replay shadow page records from the shadow file to the original data file. This is used + // during recovery. 
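+    // Illustrative recovery sketch (not part of the original lbug header): conceptually, startup
+    // recovery replays the shadow pages and then discards the shadow state, roughly:
+    //   ShadowFile::replayShadowPageRecords(context); // copy shadow pages over the original pages
+    //   shadowFile.clear(bm);                         // then truncate the shadow/WAL state
+    // where `context` and `bm` are the client context and buffer manager assumed to be available
+    // at recovery time.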
+ static void replayShadowPageRecords(main::ClientContext& context); + +private: + FileHandle* getOrCreateShadowingFH(); + +private: + BufferManager& bm; + std::string shadowFilePath; + common::VirtualFileSystem* vfs; + // This is the file handle for the shadow file. It is created lazily when the first shadow page + // is created. + FileHandle* shadowingFH; + // The map caches shadow page idxes for pages in original files. + std::unordered_map> + shadowPagesMap; + std::vector shadowPageRecords; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/shadow_utils.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/shadow_utils.h new file mode 100644 index 0000000000..749b732480 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/shadow_utils.h @@ -0,0 +1,57 @@ +#pragma once + +#include + +#include "common/copy_constructors.h" +#include "common/types/types.h" + +namespace lbug { +namespace transaction { +enum class TransactionType : uint8_t; +} // namespace transaction + +namespace storage { + +struct DBFileID; +class FileHandle; +class BufferManager; +class ShadowFile; + +struct ShadowPageAndFrame { + ShadowPageAndFrame(common::page_idx_t originalPageIdx, common::page_idx_t pageIdxInShadow, + uint8_t* frame) + : originalPage{originalPageIdx}, shadowPage{pageIdxInShadow}, frame{frame} {} + + DELETE_COPY_DEFAULT_MOVE(ShadowPageAndFrame); + + common::page_idx_t originalPage; + common::page_idx_t shadowPage; + uint8_t* frame; +}; + +class ShadowUtils { +public: + constexpr static common::page_idx_t NULL_PAGE_IDX = common::INVALID_PAGE_IDX; + + // Where possible, updatePage/insertNewPage should be used instead + static ShadowPageAndFrame createShadowVersionIfNecessaryAndPinPage( + common::page_idx_t originalPage, bool skipReadingOriginalPage, FileHandle& fileHandle, + ShadowFile& shadowFile); + + static std::pair getFileHandleAndPhysicalPageIdxToPin( + FileHandle& fileHandle, common::page_idx_t pageIdx, const ShadowFile& shadowFile, + transaction::TransactionType trxType); + + static void readShadowVersionOfPage(const FileHandle& fileHandle, + common::page_idx_t originalPageIdx, const ShadowFile& shadowFile, + const std::function& readOp); + + // Note: This function updates a page "transactionally", i.e., creates the WAL version of the + // page if it doesn't exist. For the original page to be updated, the current WRITE trx needs to + // commit and checkpoint. + static void updatePage(FileHandle& fileHandle, common::page_idx_t originalPageIdx, + bool skipReadingOriginalPage, ShadowFile& shadowFile, + const std::function& updateOp); +}; +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/stats/column_stats.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/stats/column_stats.h new file mode 100644 index 0000000000..8df6da8c14 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/stats/column_stats.h @@ -0,0 +1,62 @@ +#pragma once + +#include + +#include "common/serializer/deserializer.h" +#include "common/serializer/serializer.h" +#include "common/vector/value_vector.h" +#include "storage/stats/hyperloglog.h" + +namespace lbug { +namespace storage { + +class ColumnStats { +public: + ColumnStats() = default; + explicit ColumnStats(const common::LogicalType& dataType); + EXPLICIT_COPY_DEFAULT_MOVE(ColumnStats); + + common::cardinality_t getNumDistinctValues() const { return hll ? 
hll->count() : 0; } + + void update(const common::ValueVector* vector); + + void merge(const ColumnStats& other) { + if (hll) { + KU_ASSERT(other.hll); + hll->merge(*other.hll); + }; + } + + void serialize(common::Serializer& serializer) const { + serializer.writeDebuggingInfo("has_hll"); + serializer.serializeValue(hll.has_value()); + if (hll) { + serializer.writeDebuggingInfo("hll"); + hll->serialize(serializer); + } + } + + static ColumnStats deserialize(common::Deserializer& deserializer) { + ColumnStats columnStats; + std::string info; + deserializer.validateDebuggingInfo(info, "has_hll"); + bool hasHll = false; + deserializer.deserializeValue(hasHll); + if (hasHll) { + deserializer.validateDebuggingInfo(info, "hll"); + columnStats.hll = HyperLogLog::deserialize(deserializer); + } + return columnStats; + } + +private: + ColumnStats(const ColumnStats& other) : hll{other.hll}, hashes{nullptr} {} + +private: + std::optional hll; + // Preallocated vector for hash values. + std::unique_ptr hashes; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/stats/hyperloglog.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/stats/hyperloglog.h new file mode 100644 index 0000000000..624c91e33b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/stats/hyperloglog.h @@ -0,0 +1,55 @@ +// This HyperLogLog implementation is taken from duckdb. +// Source code: +// https://github.com/duckdb/duckdb/blob/main/src/include/duckdb/common/types/hyperloglog.hpp + +#pragma once + +#include + +#include "common/utils.h" + +namespace lbug { +namespace storage { + +class HyperLogLog { +public: + static constexpr common::cardinality_t P = 6; + static constexpr common::cardinality_t Q = 64 - P; + static constexpr common::cardinality_t M = 1 << P; + static constexpr double ALPHA = 0.721347520444481703680; // 1 / (2 log(2)) + +public: + HyperLogLog() : k{} {} // NOLINT(*-pro-type-member-init) + + //! Algorithm 1 + void insertElement(common::hash_t h) { + const auto i = h & ((1 << P) - 1); + h >>= P; + h |= static_cast(1) << Q; + const uint8_t z = static_cast(common::CountZeros::Trailing(h) + 1); + update(i, z); + } + + void update(const common::idx_t& i, const uint8_t& z) { k[i] = std::max(k[i], z); } + + uint8_t getRegister(const common::idx_t& i) const { return k[i]; } + + common::cardinality_t count() const; + + //! Algorithm 2 + void merge(const HyperLogLog& other); + + void serialize(common::Serializer& serializer) const; + static HyperLogLog deserialize(common::Deserializer& deserializer); + + //! Algorithm 4 + void extractCounts(uint32_t* c) const; + //! 
Algorithm 6 + static int64_t estimateCardinality(const uint32_t* c); + +private: + std::array k; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/stats/table_stats.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/stats/table_stats.h new file mode 100644 index 0000000000..50caad9fab --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/stats/table_stats.h @@ -0,0 +1,69 @@ +#pragma once + +#include "common/types/types.h" +#include "storage/stats/column_stats.h" + +namespace lbug::common { +class LogicalType; +} +namespace lbug { +namespace storage { + +class TableStats { +public: + explicit TableStats(std::span dataTypes); + + EXPLICIT_COPY_DEFAULT_MOVE(TableStats); + + void incrementCardinality(common::cardinality_t increment) { cardinality += increment; } + + void merge(const TableStats& other) { + std::vector columnIDs; + for (auto i = 0u; i < columnStats.size(); i++) { + columnIDs.push_back(i); + } + merge(columnIDs, other); + } + + void merge(const std::vector& columnIDs, const TableStats& other) { + cardinality += other.cardinality; + KU_ASSERT(columnIDs.size() == other.columnStats.size()); + for (auto i = 0u; i < columnIDs.size(); ++i) { + auto columnID = columnIDs[i]; + KU_ASSERT(columnID < columnStats.size()); + columnStats[columnID].merge(other.columnStats[i]); + } + } + + common::cardinality_t getTableCard() const { return cardinality; } + + common::cardinality_t getNumDistinctValues(common::column_id_t columnID) const { + KU_ASSERT(columnID < columnStats.size()); + return columnStats[columnID].getNumDistinctValues(); + } + + void update(const std::vector& vectors, + size_t numColumns = std::numeric_limits::max()); + void update(const std::vector& columnIDs, + const std::vector& vectors, + size_t numColumns = std::numeric_limits::max()); + + ColumnStats& addNewColumn(const common::LogicalType& dataType) { + columnStats.emplace_back(dataType); + return columnStats.back(); + } + + void serialize(common::Serializer& serializer) const; + TableStats deserialize(common::Deserializer& deserializer); + +private: + TableStats(const TableStats& other); + +private: + // Note: cardinality is the estimated number of rows in the table. It is not always up-to-date. 
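+    // Illustrative example (not part of the original lbug header): the counter only moves through
+    // explicit calls, e.g. after
+    //   stats.incrementCardinality(1000);
+    //   stats.merge(otherStats);   // where otherStats.getTableCard() == 500
+    // getTableCard() reports 1500, regardless of what has happened to those rows since, which is
+    // why the note above calls it an estimate.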
+ common::cardinality_t cardinality; + std::vector columnStats; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/storage_extension.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/storage_extension.h new file mode 100644 index 0000000000..63afc86679 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/storage_extension.h @@ -0,0 +1,33 @@ +#pragma once + +#include "main/attached_database.h" + +namespace lbug { +namespace binder { +struct AttachOption; +} + +namespace storage { + +using attach_function_t = std::unique_ptr (*)(std::string dbPath, + std::string dbName, main::ClientContext* clientContext, + const binder::AttachOption& attachOption); + +class StorageExtension { +public: + explicit StorageExtension(attach_function_t attachFunction) : attachFunction{attachFunction} {} + virtual bool canHandleDB(std::string /*dbType*/) const { return false; } + + std::unique_ptr attach(std::string dbName, std::string dbPath, + main::ClientContext* clientContext, const binder::AttachOption& attachOption) const { + return attachFunction(std::move(dbName), std::move(dbPath), clientContext, attachOption); + } + + virtual ~StorageExtension() = default; + +private: + attach_function_t attachFunction; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/storage_manager.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/storage_manager.h new file mode 100644 index 0000000000..fb68c657e0 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/storage_manager.h @@ -0,0 +1,103 @@ +#pragma once + +#include + +#include "catalog/catalog.h" +#include "shadow_file.h" +#include "storage/index/index.h" +#include "storage/wal/wal.h" + +namespace lbug { +namespace main { +class Database; +} // namespace main + +namespace catalog { +class CatalogEntry; +class NodeTableCatalogEntry; +class RelGroupCatalogEntry; +struct RelTableCatalogInfo; +} // namespace catalog + +namespace storage { +class Table; +class NodeTable; +class RelTable; +class DiskArrayCollection; +struct DatabaseHeader; + +class LBUG_API StorageManager { +public: + StorageManager(const std::string& databasePath, bool readOnly, bool enableChecksums, + MemoryManager& memoryManager, bool enableCompression, common::VirtualFileSystem* vfs); + ~StorageManager(); + + Table* getTable(common::table_id_t tableID); + + static void recover(main::ClientContext& clientContext, bool throwOnWalReplayFailure, + bool enableChecksums); + + void createTable(catalog::TableCatalogEntry* entry); + void addRelTable(catalog::RelGroupCatalogEntry* entry, + const catalog::RelTableCatalogInfo& info); + + bool checkpoint(main::ClientContext* context, PageAllocator& pageAllocator); + void finalizeCheckpoint(); + void rollbackCheckpoint(const catalog::Catalog& catalog); + + WAL& getWAL() const; + ShadowFile& getShadowFile() const; + FileHandle* getDataFH() const { return dataFH; } + std::string getDatabasePath() const { return databasePath; } + bool isReadOnly() const { return readOnly; } + bool compressionEnabled() const { return enableCompression; } + bool isInMemory() const { return inMemory; } + + void registerIndexType(IndexType indexType) { + registeredIndexTypes.push_back(std::move(indexType)); + } + std::optional> getIndexType( + const std::string& typeName) const; + + void serialize(const catalog::Catalog& catalog, common::Serializer& ser); + // We need to pass in the catalog and storageManager 
explicitly as they can be from + // attachedDatabase. + void deserialize(main::ClientContext* context, const catalog::Catalog* catalog, + common::Deserializer& deSer); + + void initDataFileHandle(common::VirtualFileSystem* vfs, main::ClientContext* context); + + // If the database header hasn't been created yet, calling these methods will create + return + // the header + common::ku_uuid_t getOrInitDatabaseID(const main::ClientContext& clientContext); + const storage::DatabaseHeader* getOrInitDatabaseHeader( + const main::ClientContext& clientContext); + + void setDatabaseHeader(std::unique_ptr header); + + static StorageManager* Get(const main::ClientContext& context); + +private: + void createNodeTable(catalog::NodeTableCatalogEntry* entry); + + void createRelTableGroup(catalog::RelGroupCatalogEntry* entry); + + void reclaimDroppedTables(const catalog::Catalog& catalog); + +private: + std::mutex mtx; + std::string databasePath; + std::unique_ptr databaseHeader; + bool readOnly; + FileHandle* dataFH; + std::unordered_map> tables; + MemoryManager& memoryManager; + std::unique_ptr wal; + std::unique_ptr shadowFile; + bool enableCompression; + bool inMemory; + std::vector registeredIndexTypes; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/storage_utils.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/storage_utils.h new file mode 100644 index 0000000000..dee8b66e8f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/storage_utils.h @@ -0,0 +1,91 @@ +#pragma once + +#include +#include +#include + +#include "common/constants.h" +#include "common/system_config.h" +#include "common/types/types.h" +#include + +namespace lbug { +namespace storage { + +struct PageCursor { + PageCursor() : PageCursor{UINT32_MAX, UINT16_MAX} {}; + PageCursor(common::page_idx_t pageIdx, uint32_t posInPage) + : pageIdx{pageIdx}, elemPosInPage{posInPage} {}; + + void nextPage() { + pageIdx++; + elemPosInPage = 0; + } + + common::page_idx_t pageIdx; + // Larger than necessary, but PageCursor is directly written to disk + // and adding an explicit padding field messes with structured bindings + uint32_t elemPosInPage; +}; +static_assert(std::has_unique_object_representations_v); + +template +concept NumericType = std::is_integral_v || std::floating_point; + +class StorageUtils { +public: + enum class ColumnType { + DEFAULT = 0, + INDEX = 1, // This is used for index columns in STRING columns. + OFFSET = 2, // This is used for offset columns in LIST and STRING columns. + DATA = 3, // This is used for data columns in LIST and STRING columns. 
+ CSR_OFFSET = 4, + CSR_LENGTH = 5, + STRUCT_CHILD = 6, + NULL_MASK = 7, + }; + + template + static uint64_t divideAndRoundUpTo(T1 v1, T2 v2) { + return std::ceil(static_cast(v1) / static_cast(v2)); + } + + static std::string getColumnName(const std::string& propertyName, ColumnType type, + const std::string& prefix); + + static common::offset_t getStartOffsetOfNodeGroup(common::node_group_idx_t nodeGroupIdx) { + return nodeGroupIdx << common::StorageConfig::NODE_GROUP_SIZE_LOG2; + } + static common::node_group_idx_t getNodeGroupIdx(common::offset_t nodeOffset) { + return nodeOffset >> common::StorageConfig::NODE_GROUP_SIZE_LOG2; + } + static std::pair getNodeGroupIdxAndOffsetInChunk( + common::offset_t nodeOffset) { + auto nodeGroupIdx = getNodeGroupIdx(nodeOffset); + auto offsetInChunk = nodeOffset - getStartOffsetOfNodeGroup(nodeGroupIdx); + return std::make_pair(nodeGroupIdx, offsetInChunk); + } + + static std::string getWALFilePath(const std::string& path) { + return common::stringFormat("{}.{}", path, common::StorageConstants::WAL_FILE_SUFFIX); + } + static std::string getShadowFilePath(const std::string& path) { + return common::stringFormat("{}.{}", path, common::StorageConstants::SHADOWING_SUFFIX); + } + static std::string getTmpFilePath(const std::string& path) { + return common::stringFormat("{}.{}", path, common::StorageConstants::TEMP_FILE_SUFFIX); + } + + static std::string expandPath(const main::ClientContext* context, const std::string& path); + + // Note: This is a relatively slow function because of division and mod and making std::pair. + // It is not meant to be used in performance critical code path. + static std::pair getQuotientRemainder(uint64_t i, uint64_t divisor) { + return std::make_pair(i / divisor, i % divisor); + } + + static uint32_t getDataTypeSize(const common::LogicalType& type); +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/storage_version_info.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/storage_version_info.h new file mode 100644 index 0000000000..5ba0601bb7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/storage_version_info.h @@ -0,0 +1,31 @@ +#pragma once + +#include +#include +#include + +#include "common/api.h" + +namespace lbug { +namespace storage { + +using storage_version_t = uint64_t; + +struct StorageVersionInfo { + static std::unordered_map getStorageVersionInfo() { + return {{"0.11.1", 39}, {"0.11.0", 39}, {"0.10.0", 38}, {"0.9.0", 37}, {"0.8.0", 36}, + {"0.7.1.1", 35}, {"0.7.0", 34}, {"0.6.0.6", 33}, {"0.6.0.5", 32}, {"0.6.0.2", 31}, + {"0.6.0.1", 31}, {"0.6.0", 28}, {"0.5.0", 28}, {"0.4.2", 27}, {"0.4.1", 27}, + {"0.4.0", 27}, {"0.3.2", 26}, {"0.3.1", 26}, {"0.3.0", 26}, {"0.2.1", 25}, + {"0.2.0", 25}, {"0.1.0", 24}, {"0.0.12.3", 24}, {"0.0.12.2", 24}, {"0.0.12.1", 24}, + {"0.0.12", 23}, {"0.0.11", 23}, {"0.0.10", 23}, {"0.0.9", 23}, {"0.0.8", 17}, + {"0.0.7", 15}, {"0.0.6", 9}, {"0.0.5", 8}, {"0.0.4", 7}, {"0.0.3", 1}}; + } + + static LBUG_API storage_version_t getStorageVersion(); + + static constexpr const char* MAGIC_BYTES = "LBUG"; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/chunked_node_group.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/chunked_node_group.h new file mode 100644 index 0000000000..1e461c6f09 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/chunked_node_group.h @@ -0,0 +1,260 @@ +#pragma once 
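+// Illustrative note (not part of the original lbug header): a node group is a fixed-size slice of
+// a table's offset space, which is why the StorageUtils helpers above reduce to shift/mask
+// arithmetic. For example, with a hypothetical StorageConfig::NODE_GROUP_SIZE_LOG2 of 17
+// (131072 rows per group), node offset 300000 maps to nodeGroupIdx 2 and
+// offsetInChunk 300000 - 2 * 131072 = 37856.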
+ +#include +#include +#include +#include + +#include "common/enums/rel_multiplicity.h" +#include "common/types/types.h" +#include "storage/buffer_manager/memory_manager.h" +#include "storage/buffer_manager/spill_result.h" +#include "storage/enums/residency_state.h" +#include "storage/table/column_chunk.h" +#include "storage/table/column_chunk_data.h" +#include "storage/table/version_info.h" + +namespace lbug { +namespace common { +class SelectionVector; +} // namespace common + +namespace transaction { +class Transaction; +} // namespace transaction + +namespace storage { +class MemoryManager; + +class Column; +struct TableScanState; +struct TableAddColumnState; +struct NodeGroupScanState; +class ColumnStats; +class FileHandle; +class PageAllocator; + +enum class NodeGroupDataFormat : uint8_t { REGULAR = 0, CSR = 1 }; + +class LBUG_API InMemChunkedNodeGroup { + friend class ChunkedNodeGroup; + +public: + virtual ~InMemChunkedNodeGroup() = default; + InMemChunkedNodeGroup(MemoryManager& mm, const std::vector& columnTypes, + bool enableCompression, uint64_t capacity, common::row_idx_t startRowIdx); + InMemChunkedNodeGroup(std::vector>&& chunks, + common::row_idx_t startRowIdx); + // Moves the specified columns out of base + InMemChunkedNodeGroup(InMemChunkedNodeGroup& base, + const std::vector& selectedColumns); + + // Also marks the chunks as in-use + // I.e. if you want to be able to spill to disk again you must call setUnused first + void loadFromDisk(const MemoryManager& mm); + // returns the amount of space reclaimed in bytes + SpillResult spillToDisk(); + void setUnused(const MemoryManager& mm); + + bool isFull() const { return numRows == capacity; } + common::idx_t getNumColumns() const { return chunks.size(); } + common::row_idx_t getStartRowIdx() const { return startRowIdx; } + common::row_idx_t getNumRows() const { return numRows; } + common::row_idx_t getCapacity() const { return capacity; } + void setNumRows(common::offset_t numRows_); + + ColumnChunkData& getColumnChunk(const common::column_id_t columnID) { + KU_ASSERT(columnID < chunks.size()); + return *chunks[columnID]; + } + + const ColumnChunkData& getColumnChunk(const common::column_id_t columnID) const { + KU_ASSERT(columnID < chunks.size()); + return *chunks[columnID]; + } + + uint64_t append(const std::vector& columnVectors, + common::row_idx_t startRowInVectors, uint64_t numValuesToAppend); + + // Appends up to numValuesToAppend from the other chunked node group, returning the actual + // number of values appended. 
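+    // Illustrative caller pattern (not part of the original lbug header): since the return value
+    // may be smaller than numRowsToAppend (for instance when this group reaches its capacity),
+    // callers are expected to loop until the source rows are drained, roughly:
+    //   common::offset_t copied = 0;
+    //   while (copied < total) {
+    //       copied += dst.append(src, srcStart + copied, total - copied);
+    //       // ... rotate `dst` to a fresh chunked group once it reports isFull() ...
+    //   }
+    // where `dst`, `src`, `srcStart` and `total` are hypothetical names for this sketch.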
+ common::offset_t append(const InMemChunkedNodeGroup& other, + common::offset_t offsetInOtherNodeGroup, common::offset_t numRowsToAppend); + + void resizeChunks(uint64_t newSize); + void resetToEmpty(); + void resetToAllNull() const; + + // Moves the specified columns out of base + void merge(InMemChunkedNodeGroup& base, + const std::vector& columnsToMergeInto); + + void write(const InMemChunkedNodeGroup& data, common::column_id_t offsetColumnID); + virtual void writeToColumnChunk(common::idx_t chunkIdx, common::idx_t vectorIdx, + const std::vector>& data, ColumnChunkData& offsetChunk) { + chunks[chunkIdx]->write(data[vectorIdx].get(), &offsetChunk, common::RelMultiplicity::ONE); + } + + std::unique_ptr moveColumnChunk(const common::column_id_t columnID) { + KU_ASSERT(columnID < chunks.size()); + return std::move(chunks[columnID]); + } + + virtual std::unique_ptr flush(transaction::Transaction* transaction, + PageAllocator& pageAllocator); + +protected: + std::unique_ptr flushInternal(ColumnChunkData& chunk, + PageAllocator& pageAllocator); + +protected: + common::row_idx_t startRowIdx; + std::atomic numRows; + uint64_t capacity; + std::vector> chunks; + std::mutex spillToDiskMutex; + // Used to track if the group may be in use and to verify that spillToDisk is only called when + // it is safe to do so. If false, it is safe to spill the data to disk. + bool dataInUse; +}; + +// Collection of ColumnChunks for each column in a particular Node Group +class ChunkedNodeGroup { + friend class InMemChunkedNodeGroup; + +public: + ChunkedNodeGroup(std::vector> chunks, + common::row_idx_t startRowIdx, NodeGroupDataFormat format = NodeGroupDataFormat::REGULAR); + // Moves the specified columns out of base + ChunkedNodeGroup(InMemChunkedNodeGroup& base, + const std::vector& selectedColumns, + NodeGroupDataFormat format = NodeGroupDataFormat::REGULAR); + ChunkedNodeGroup(ChunkedNodeGroup& base, + const std::vector& selectedColumns); + ChunkedNodeGroup(MemoryManager& mm, ChunkedNodeGroup& base, + std::span columnTypes, + std::span baseColumnIDs); + ChunkedNodeGroup(MemoryManager& mm, const std::vector& columnTypes, + bool enableCompression, uint64_t capacity, common::row_idx_t startRowIdx, + ResidencyState residencyState, NodeGroupDataFormat format = NodeGroupDataFormat::REGULAR); + virtual ~ChunkedNodeGroup() = default; + + common::idx_t getNumColumns() const { return chunks.size(); } + common::row_idx_t getStartRowIdx() const { return startRowIdx; } + common::row_idx_t getNumRows() const { return numRows; } + const ColumnChunk& getColumnChunk(const common::column_id_t columnID) const { + KU_ASSERT(columnID < chunks.size()); + return *chunks[columnID]; + } + ColumnChunk& getColumnChunk(const common::column_id_t columnID) { + KU_ASSERT(columnID < chunks.size()); + return *chunks[columnID]; + } + std::unique_ptr moveColumnChunk(const common::column_id_t columnID) { + KU_ASSERT(columnID < chunks.size()); + return std::move(chunks[columnID]); + } + bool isFullOrOnDisk() const { + return numRows == capacity || residencyState == ResidencyState::ON_DISK; + } + ResidencyState getResidencyState() const { return residencyState; } + NodeGroupDataFormat getFormat() const { return format; } + + void resetNumRowsFromChunks(); + void truncate(common::offset_t numRows); + void setVersionInfo(std::unique_ptr versionInfo) { + this->versionInfo = std::move(versionInfo); + } + void resetVersionAndUpdateInfo(); + + uint64_t append(const transaction::Transaction* transaction, + const std::vector& columnVectors, 
common::row_idx_t startRowInVectors, + uint64_t numValuesToAppend); + common::offset_t append(const transaction::Transaction* transaction, + const std::vector& columnIDs, const ChunkedNodeGroup& other, + common::offset_t offsetInOtherNodeGroup, common::offset_t numRowsToAppend); + common::offset_t append(const transaction::Transaction* transaction, + const std::vector& columnIDs, const InMemChunkedNodeGroup& other, + common::offset_t offsetInOtherNodeGroup, common::offset_t numRowsToAppend); + common::offset_t append(const transaction::Transaction* transaction, + const std::vector& columnIDs, std::span other, + common::offset_t offsetInOtherNodeGroup, common::offset_t numRowsToAppend); + common::offset_t append(const transaction::Transaction* transaction, + const std::vector& columnIDs, std::span other, + common::offset_t offsetInOtherNodeGroup, common::offset_t numRowsToAppend); + + void scan(const transaction::Transaction* transaction, const TableScanState& scanState, + const NodeGroupScanState& nodeGroupScanState, common::offset_t rowIdxInGroup, + common::length_t numRowsToScan) const; + + template + void scanCommitted(transaction::Transaction* transaction, TableScanState& scanState, + InMemChunkedNodeGroup& output) const; + + bool hasUpdates() const; + bool hasDeletions(const transaction::Transaction* transaction) const; + common::row_idx_t getNumUpdatedRows(const transaction::Transaction* transaction, + common::column_id_t columnID); + + bool lookup(const transaction::Transaction* transaction, const TableScanState& state, + const NodeGroupScanState& nodeGroupScanState, common::offset_t rowIdxInChunk, + common::sel_t posInOutput) const; + + void update(const transaction::Transaction* transaction, common::row_idx_t rowIdxInChunk, + common::column_id_t columnID, const common::ValueVector& propertyVector); + + bool delete_(const transaction::Transaction* transaction, common::row_idx_t rowIdxInChunk); + + void addColumn(MemoryManager& mm, const TableAddColumnState& addColumnState, + bool enableCompression, PageAllocator* pageAllocator, ColumnStats* newColumnStats); + + bool isDeleted(const transaction::Transaction* transaction, common::row_idx_t rowInChunk) const; + bool isInserted(const transaction::Transaction* transaction, + common::row_idx_t rowInChunk) const; + bool hasAnyUpdates(const transaction::Transaction* transaction, common::column_id_t columnID, + common::row_idx_t startRow, common::length_t numRowsToCheck) const; + common::row_idx_t getNumDeletions(const transaction::Transaction* transaction, + common::row_idx_t startRow, common::length_t numRowsToCheck) const; + bool hasVersionInfo() const { return versionInfo != nullptr; } + + static std::unique_ptr flushEmpty(MemoryManager& mm, + const std::vector& columnTypes, bool enableCompression, + uint64_t capacity, common::row_idx_t startRowIdx, PageAllocator& pageAllocator); + + void commitInsert(common::row_idx_t startRow, common::row_idx_t numRowsToCommit, + common::transaction_t commitTS); + void rollbackInsert(common::row_idx_t startRow, common::row_idx_t numRows_, + common::transaction_t commitTS); + void commitDelete(common::row_idx_t startRow, common::row_idx_t numRows_, + common::transaction_t commitTS); + void rollbackDelete(common::row_idx_t startRow, common::row_idx_t numRows_, + common::transaction_t commitTS); + virtual void reclaimStorage(PageAllocator& pageAllocator) const; + + uint64_t getEstimatedMemoryUsage() const; + + virtual void serialize(common::Serializer& serializer) const; + static std::unique_ptr 
deserialize(MemoryManager& memoryManager, + common::Deserializer& deSer); + + template + TARGET& cast() { + return common::ku_dynamic_cast(*this); + } + template + const TARGET& cast() const { + return common::ku_dynamic_cast(*this); + } + +protected: + NodeGroupDataFormat format; + ResidencyState residencyState; + common::row_idx_t startRowIdx; + uint64_t capacity; + std::atomic numRows; + std::vector> chunks; + std::unique_ptr versionInfo; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/column.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/column.h new file mode 100644 index 0000000000..b1f3873cb3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/column.h @@ -0,0 +1,203 @@ +#pragma once + +#include "common/null_mask.h" +#include "common/types/types.h" +#include "storage/table/column_reader_writer.h" + +namespace lbug { +namespace storage { +class MemoryManager; + +class NullColumn; +class StructColumn; +class RelTableData; +struct ColumnCheckpointState; +class PageAllocator; +struct ChunkState; + +class ColumnChunk; +class Column { + friend class StringColumn; + friend class StructColumn; + friend class ListColumn; + friend class RelTableData; + +public: + Column(std::string name, common::LogicalType dataType, FileHandle* dataFH, MemoryManager* mm, + ShadowFile* shadowFile, bool enableCompression, bool requireNullColumn = true); + Column(std::string name, common::PhysicalTypeID physicalType, FileHandle* dataFH, + MemoryManager* mm, ShadowFile* shadowFile, bool enableCompression, + bool requireNullColumn = true); + + virtual ~Column(); + + void populateExtraChunkState(SegmentState& state) const; + + static std::unique_ptr flushChunkData(const ColumnChunkData& chunkData, + PageAllocator& pageAllocator); + static std::unique_ptr flushNonNestedChunkData( + const ColumnChunkData& chunkData, PageAllocator& pageAllocator); + static ColumnChunkMetadata flushData(const ColumnChunkData& chunkData, + PageAllocator& pageAllocator); + + // Use lookupInternal to specialize + void lookupValue(const ChunkState& state, common::offset_t nodeOffset, + common::ValueVector* resultVector, uint32_t posInVector) const; + + // Scan from [offsetInChunk, offsetInChunk + length) (use scanInternal to specialize). + // + // The selectionVector in the resultVector's dataState should only select positions up to the + // length parameter (E.g. if you want to scan just position 2047 you have to pass a length of + // 2048; that is, like an unfiltered scan of 0-2047 but filtering everything but the value at + // index 2047). + // Primitive columns may scan more than the filtered values + virtual void scan(const ChunkState& state, common::offset_t startOffsetInGroup, + common::offset_t length, common::ValueVector* resultVector, uint64_t offsetInVector) const; + // Scan from [offsetInChunk, offsetInChunk + length) (use scanInternal to specialize). + // Appends to the end of the columnChunk + void scan(const ChunkState& state, ColumnChunkData* columnChunk, + common::offset_t offsetInChunk = 0, common::offset_t numValues = UINT64_MAX) const; + // Scan from [offsetInChunk, offsetInChunk + length) (use scanInternal to specialize). 
+ // Appends to the end of the columnChunk + virtual void scanSegment(const SegmentState& state, ColumnChunkData* columnChunk, + common::offset_t offsetInSegment, common::offset_t numValue) const; + // Scan to raw data (does not scan any nested data and should only be used on primitive columns) + void scanSegment(const SegmentState& state, common::offset_t startOffsetInSegment, + common::offset_t length, uint8_t* result) const; + + common::LogicalType& getDataType() { return dataType; } + const common::LogicalType& getDataType() const { return dataType; } + + Column* getNullColumn() const; + + std::string_view getName() const { return name; } + + // Batch write to a set of sequential pages. + void write(ColumnChunkData& persistentChunk, ChunkState& state, common::offset_t dstOffset, + const ColumnChunkData& data, common::offset_t srcOffset, common::length_t numValues) const; + + virtual void writeSegment(ColumnChunkData& persistentChunk, SegmentState& state, + common::offset_t dstOffsetInSegment, const ColumnChunkData& data, + common::offset_t srcOffset, common::length_t numValues) const; + + // Append values to the end of the node group, resizing it if necessary + // Expects bools to be one bool per bit (like ColumnChunkData) + common::offset_t appendValues(ColumnChunkData& persistentChunk, SegmentState& state, + const uint8_t* data, const common::NullMask* nullChunkData, + common::offset_t numValues) const; + + template + TARGET& cast() { + return common::ku_dynamic_cast(*this); + } + template + const TARGET& cast() const { + return common::ku_dynamic_cast(*this); + } + + // Return value is the new segments if segment splitting occurs during an out of place + // checkpoint + virtual std::vector> checkpointSegment( + ColumnCheckpointState&& checkpointState, PageAllocator& pageAllocator, + bool canSplitSegment = true) const; + +protected: + // For a scan that includes a selectionVector, the startOffsetInVector should be considered to + // be an offset for the selected positions within the selectionVector The offset of a given pos + // from the selectionVector within the segment is equal to: + // startOffsetInSegment + pos - startOffsetInVector + // Note that the positions in the selectionVector may not be in the range covered by the segment + // Out of range positions should be ignored + virtual void scanSegment(const SegmentState& state, common::offset_t startOffsetInSegment, + common::row_idx_t numValuesToScan, common::ValueVector* resultVector, + common::offset_t startOffsetInVector) const; + + virtual void lookupInternal(const SegmentState& state, common::offset_t offsetInSegment, + common::ValueVector* resultVector, uint32_t posInVector) const; + + void writeValues(ChunkState& state, common::offset_t dstOffset, const uint8_t* data, + const common::NullMask* nullChunkData, common::offset_t srcOffset = 0, + common::offset_t numValues = 1) const; + + void writeValuesInternal(SegmentState& state, common::offset_t dstOffsetInSegment, + const uint8_t* data, const common::NullMask* nullChunkData, common::offset_t srcOffset = 0, + common::offset_t numValues = 1) const; + + void updateStatistics(ColumnChunkMetadata& metadata, common::offset_t maxIndex, + const std::optional& min, const std::optional& max) const; + +protected: + bool isEndOffsetOutOfPagesCapacity(const ColumnChunkMetadata& metadata, + common::offset_t endOffset) const; + + virtual bool canCheckpointInPlace(const SegmentState& state, + const ColumnCheckpointState& checkpointState) const; + + void 
checkpointColumnChunkInPlace(SegmentState& state, + const ColumnCheckpointState& checkpointState, PageAllocator& pageAllocator) const; + + void checkpointNullData(const ColumnCheckpointState& checkpointState, + PageAllocator& pageAllocator) const; + + std::vector> checkpointColumnChunkOutOfPlace( + const SegmentState& state, const ColumnCheckpointState& checkpointState, + PageAllocator& pageAllocator, bool canSplitSegment) const; + + // check if val is in range [start, end) + static bool isInRange(uint64_t val, uint64_t start, uint64_t end) { + return val >= start && val < end; + } + +protected: + std::string name; + common::LogicalType dataType; + MemoryManager* mm; + FileHandle* dataFH; + ShadowFile* shadowFile; + std::unique_ptr nullColumn; + read_values_to_vector_func_t readToVectorFunc; + write_values_func_t writeFunc; + read_values_to_page_func_t readToPageFunc; + bool enableCompression; + + std::unique_ptr columnReadWriter; +}; + +class InternalIDColumn final : public Column { +public: + InternalIDColumn(std::string name, FileHandle* dataFH, MemoryManager* mm, + ShadowFile* shadowFile, bool enableCompression); + + void scan(const ChunkState& state, common::offset_t startOffsetInGroup, common::offset_t length, + common::ValueVector* resultVector, uint64_t offsetInVector) const override { + Column::scan(state, startOffsetInGroup, length, resultVector, offsetInVector); + populateCommonTableID(resultVector); + } + + void lookupInternal(const SegmentState& state, common::offset_t offsetInSegment, + common::ValueVector* resultVector, uint32_t posInVector) const override { + Column::lookupInternal(state, offsetInSegment, resultVector, posInVector); + populateCommonTableID(resultVector); + } + + common::table_id_t getCommonTableID() const { return commonTableID; } + // TODO(Guodong): This function should be removed through rewriting INTERNAL_ID as STRUCT. 
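    // Presumed call pattern, inferred from the scan()/lookupInternal() overrides above
    // (illustrative only, not an upstream guarantee): the owning table assigns the common
    // table ID once, and every later scan stamps it into the produced internalID values, e.g.
    //   internalIDColumn.setCommonTableID(nodeTableID);      // once, when the table is built
    //   internalIDColumn.scan(state, 0, length, &result, 0); // result now carries nodeTableID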
+ void setCommonTableID(common::table_id_t tableID) { commonTableID = tableID; } + +private: + void populateCommonTableID(const common::ValueVector* resultVector) const; + +private: + common::table_id_t commonTableID; +}; + +struct ColumnFactory { + static std::unique_ptr createColumn(std::string name, common::LogicalType dataType, + FileHandle* dataFH, MemoryManager* mm, ShadowFile* shadowFile, bool enableCompression); + static std::unique_ptr createColumn(std::string name, + common::PhysicalTypeID physicalType, FileHandle* dataFH, MemoryManager* mm, + ShadowFile* shadowFile, bool enableCompression); +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/column_chunk.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/column_chunk.h new file mode 100644 index 0000000000..61dc765a52 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/column_chunk.h @@ -0,0 +1,336 @@ +#pragma once + +#include +#include + +#include "common/assert.h" +#include "common/cast.h" +#include "common/types/types.h" +#include "storage/enums/residency_state.h" +#include "storage/table/column_chunk_data.h" +#include "storage/table/update_info.h" + +namespace lbug { +namespace storage { +class PageAllocator; +class MemoryManager; +class Column; +struct ColumnChunkScanner; + +struct ChunkCheckpointState { + std::unique_ptr chunkData; + // Start offset in the column chunk of the beginning of the chunk data + common::row_idx_t startRow; + common::length_t numRows; + + ChunkCheckpointState(std::unique_ptr chunkData, common::row_idx_t startRow, + common::length_t numRows) + : chunkData{std::move(chunkData)}, startRow{startRow}, numRows{numRows} {} +}; + +struct SegmentCheckpointState { + const ColumnChunkData& chunkData; + common::row_idx_t startRowInData; + common::row_idx_t offsetInSegment; + common::row_idx_t numRows; +}; + +template +std::pair::iterator, common::offset_t> genericFindSegment( + std::span segments, common::offset_t offsetInChunk) { + auto offsetInSegment = offsetInChunk; + auto segment = segments.begin(); + while (segment != segments.end()) { + if (offsetInSegment < (**segment).getNumValues()) { + return std::make_pair(segment, offsetInSegment); + } + offsetInSegment -= (**segment).getNumValues(); + segment++; + } + return std::make_pair(segments.end(), 0); +} + +template + Func> +common::offset_t genericRangeSegments(std::span segments, + common::offset_t offsetInChunk, common::length_t length, Func func) { + // TODO(bmwinger): try binary search (might only make a difference for a very large number + // of segments) + auto [segment, offsetInSegment] = genericFindSegment(segments, offsetInChunk); + return genericRangeSegmentsFromIt(segments, segment, offsetInSegment, length, std::move(func)); +} + +// dstOffset starts from 0 and is the offset in the output data for a given segment +// (it increases by lengthInSegment for each segment) +// Returns the total number of values scanned (input length can be longer than the available +// values) +template + Func> +common::offset_t genericRangeSegmentsFromIt(std::span segments, + typename std::span::iterator segment, common::offset_t offsetInSegment, + common::length_t length, Func func) { + common::offset_t lengthScanned = 0; + while (lengthScanned < length && segment != segments.end()) { + KU_ASSERT((**segment).getNumValues() > offsetInSegment); + auto lengthInSegment = + std::min(length - lengthScanned, (**segment).getNumValues() - offsetInSegment); + 
func(*segment, offsetInSegment, lengthInSegment, lengthScanned); + lengthScanned += lengthInSegment; + segment++; + offsetInSegment = 0; + } + return lengthScanned; +} + +class ColumnChunk; +struct ColumnCheckpointState { + ColumnChunkData& persistentData; + std::vector segmentCheckpointStates; + common::row_idx_t endRowIdxToWrite; + + ColumnCheckpointState(ColumnChunkData& persistentData, + std::vector segmentCheckpointStates) + : persistentData{persistentData}, + segmentCheckpointStates{std::move(segmentCheckpointStates)}, endRowIdxToWrite{0} { + for (const auto& chunkCheckpointState : this->segmentCheckpointStates) { + const auto endRowIdx = + chunkCheckpointState.offsetInSegment + chunkCheckpointState.numRows; + if (endRowIdx > endRowIdxToWrite) { + endRowIdxToWrite = endRowIdx; + } + } + } +}; + +struct ChunkState { + const Column* column = nullptr; + std::vector segmentStates; + + void reclaimAllocatedPages(PageAllocator& pageAllocator) const; + + std::pair findSegment( + common::offset_t offsetInChunk) const; + + // dstOffset starts from 0 and is the offset in the output data for a given segment + // (it increases by lengthInSegment for each segment) + // Returns the total number of values scanned (input length can be longer than the available + // values) + template + Func> + common::offset_t rangeSegments(common::offset_t offsetInChunk, common::length_t length, + Func func) { + return genericRangeSegments(std::span(segmentStates), offsetInChunk, length, func); + } + + // TODO(bmwinger): the above function should be const and only isn't because of ALP exception + // chunk modifications. The SegmentState& should also be const for the same reason + template + Func> + common::offset_t rangeSegments(common::offset_t offsetInChunk, common::length_t length, + Func func) const { + return const_cast(this)->rangeSegments(offsetInChunk, length, func); + } +}; + +class ColumnChunk { +public: + ColumnChunk(MemoryManager& mm, common::LogicalType&& dataType, uint64_t capacity, + bool enableCompression, ResidencyState residencyState, bool initializeToZero = true); + ColumnChunk(MemoryManager& mm, common::LogicalType&& dataType, bool enableCompression, + ColumnChunkMetadata metadata); + ColumnChunk(bool enableCompression, std::unique_ptr data); + ColumnChunk(bool enableCompression, std::vector> segments); + + void initializeScanState(ChunkState& state, const Column* column) const; + void scan(const transaction::Transaction* transaction, const ChunkState& state, + common::ValueVector& output, common::offset_t offsetInChunk, common::length_t length) const; + template + void scanCommitted(const transaction::Transaction* transaction, ChunkState& chunkState, + ColumnChunkScanner& output, common::row_idx_t startRow = 0, + common::row_idx_t numRows = common::INVALID_ROW_IDX) const; + template + void scanCommitted(const transaction::Transaction* transaction, ChunkState& chunkState, + ColumnChunkData& output, common::row_idx_t startRow = 0, + common::row_idx_t numRows = common::INVALID_ROW_IDX) const; + void lookup(const transaction::Transaction* transaction, const ChunkState& state, + common::offset_t rowInChunk, common::ValueVector& output, + common::sel_t posInOutputVector) const; + void update(const transaction::Transaction* transaction, common::offset_t offsetInChunk, + const common::ValueVector& values); + + uint64_t getEstimatedMemoryUsage() const { + if (getResidencyState() == ResidencyState::ON_DISK) { + return 0; + } + uint64_t memUsage = 0; + for (auto& segment : data) { + memUsage += 
segment->getEstimatedMemoryUsage(); + } + return memUsage; + } + void serialize(common::Serializer& serializer) const; + static std::unique_ptr deserialize(MemoryManager& mm, common::Deserializer& deSer); + + uint64_t getNumValues() const { + uint64_t numValues = 0; + for (const auto& chunk : data) { + numValues += chunk->getNumValues(); + } + return numValues; + } + uint64_t getCapacity() const { + uint64_t capacity = 0; + for (const auto& chunk : data) { + capacity += chunk->getCapacity(); + } + return capacity; + } + void truncate(uint64_t numValues) { + uint64_t seenValues = 0; + uint64_t seenSegments = 0; + for (auto& segment : data) { + seenSegments++; + if (seenValues + segment->getNumValues() < numValues) { + seenValues += segment->getNumValues(); + } else { + segment->setNumValues(numValues - seenValues); + break; + } + } + data.resize(seenSegments); + } + + common::row_idx_t getNumUpdatedRows(const transaction::Transaction* transaction) const; + + // TODO(bmwinger): Segments could probably share a single datatype + const common::LogicalType& getDataType() const { return data.front()->getDataType(); } + bool isCompressionEnabled() const { return enableCompression; } + + ResidencyState getResidencyState() const { + auto state = data.front()->getResidencyState(); + RUNTIME_CHECK(for (auto& chunk : data) { KU_ASSERT(chunk->getResidencyState() == state); }); + return state; + } + bool hasUpdates() const { return updateInfo.isSet(); } + bool hasUpdates(const transaction::Transaction* transaction, common::row_idx_t startRow, + common::length_t numRows) const; + void resetUpdateInfo() { updateInfo.reset(); } + + MergedColumnChunkStats getMergedColumnChunkStats() const; + + void reclaimStorage(PageAllocator& pageAllocator) const; + + void append(common::ValueVector* vector, const common::SelectionView& selView); + void append(const ColumnChunk* other, common::offset_t startPosInOtherChunk, + uint32_t numValuesToAppend); + + void append(const ColumnChunkData* other, common::offset_t startPosInOtherChunk, + uint32_t numValuesToAppend); + + template Func> + void mapValues(Func func, uint64_t startOffset = 0, uint64_t endOffset = UINT64_MAX) { + rangeSegments(startOffset, endOffset == UINT64_MAX ? 
UINT64_MAX : endOffset - startOffset, + [&](auto& segment, auto offsetInSegment, auto lengthInSegment, auto dstOffset) { + KU_ASSERT(segment->getResidencyState() == ResidencyState::IN_MEMORY); + auto* segmentData = segment->template getData(); + for (size_t i = offsetInSegment; i < lengthInSegment; i++) { + func(segmentData[i], dstOffset + i - offsetInSegment); + } + }); + } + + template + T getValue(common::offset_t pos) const { + KU_ASSERT(pos < getCapacity()); + auto [segment, offsetInSegment] = genericFindSegment(std::span(data), pos); + KU_ASSERT(segment->get() != nullptr); + KU_ASSERT((*segment)->getResidencyState() == ResidencyState::IN_MEMORY); + return (*segment)->template getValue(offsetInSegment); + } + + template + void setValue(T val, common::offset_t pos) const { + KU_ASSERT(pos < getCapacity()); + auto [segment, offsetInSegment] = genericFindSegment(std::span(data), pos); + KU_ASSERT(segment->get() != nullptr); + KU_ASSERT((*segment)->getResidencyState() == ResidencyState::IN_MEMORY); + (*segment)->template setValue(val, pos); + } + + void flush(PageAllocator& pageAllocator) { + for (auto& segment : data) { + KU_ASSERT(segment->getResidencyState() == ResidencyState::IN_MEMORY); + segment->flush(pageAllocator); + } + } + + void populateWithDefaultVal(evaluator::ExpressionEvaluator& defaultEvaluator, + uint64_t& numValues_, ColumnStats* newColumnStats) { + KU_ASSERT(data.size() == 1 && data.back()->getNumValues() == 0); + data.back()->populateWithDefaultVal(defaultEvaluator, numValues_, newColumnStats); + } + + void finalize() { + for (auto& segment : data) { + KU_ASSERT(segment->getResidencyState() == ResidencyState::IN_MEMORY); + segment->finalize(); + } + } + + // TODO(bmwinger): This is not ideal; it's just a workaround for storage_info + // We should either provide a way for ColumnChunk to provide its own details about the + // storage structure, or maybe change the type of data to allow us to directly return a + // std::span to get read-only info about the segments efficiently + std::vector getSegments() const { + std::vector segments; + for (const auto& segment : data) { + segments.push_back(segment.get()); + } + return segments; + } + + void checkpoint(Column& column, std::vector&& chunkCheckpointStates, + PageAllocator& pageAllocator); + + void write(Column& column, ChunkState& state, common::offset_t dstOffset, + const ColumnChunkData& dataToWrite, common::offset_t srcOffset, common::length_t numValues); + + void syncNumValues() { + for (auto& segment : data) { + segment->syncNumValues(); + } + } + + void setTableID(common::table_id_t tableID) { + for (const auto& segment : data) { + auto internalIDSegment = common::ku_dynamic_cast(segment.get()); + internalIDSegment->setTableID(tableID); + } + } + +private: + template + void rangeSegments(common::offset_t offsetInChunk, common::length_t length, Func func) const { + genericRangeSegments(std::span(data), offsetInChunk, length, func); + } + + void scanInMemSegments(ColumnChunkScanner& output, common::offset_t startRow, + common::offset_t numRows) const; + +private: + // TODO(Guodong): This field should be removed. Ideally it shouldn't be cached anywhere in + // storage structures, instead should be fed into functions needed from ClientContext + // dbConfig. 
+ bool enableCompression; + std::vector> data; + UpdateInfo updateInfo; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/column_chunk_data.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/column_chunk_data.h new file mode 100644 index 0000000000..20f2b00f6f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/column_chunk_data.h @@ -0,0 +1,509 @@ +#pragma once + +#include +#include +#include +#include + +#include "common/data_chunk/sel_vector.h" +#include "common/enums/rel_multiplicity.h" +#include "common/null_mask.h" +#include "common/system_config.h" +#include "common/types/types.h" +#include "common/vector/value_vector.h" +#include "storage/buffer_manager/memory_manager.h" +#include "storage/compression/compression.h" +#include "storage/enums/residency_state.h" +#include "storage/table/column_chunk_metadata.h" +#include "storage/table/column_chunk_stats.h" +#include "storage/table/in_memory_exception_chunk.h" + +namespace lbug::storage { +class PageManager; +} +namespace lbug { +namespace evaluator { +class ExpressionEvaluator; +} // namespace evaluator + +namespace transaction { +class Transaction; +} // namespace transaction + +namespace storage { + +class Column; +class NullChunkData; +class ColumnStats; +class PageAllocator; +class FileHandle; + +// TODO(bmwinger): Hide access to variables. +struct SegmentState { + const Column* column; + ColumnChunkMetadata metadata; + uint64_t numValuesPerPage = UINT64_MAX; + std::unique_ptr nullState; + + // Used for struct/list/string columns. + std::vector childrenStates; + + // Used for floating point columns + std::variant>, + std::unique_ptr>> + alpExceptionChunk; + + explicit SegmentState(bool hasNull = true) : column{nullptr} { + if (hasNull) { + nullState = std::make_unique(false /*hasNull*/); + } + } + SegmentState(ColumnChunkMetadata metadata, uint64_t numValuesPerPage) + : column{nullptr}, metadata{std::move(metadata)}, numValuesPerPage{numValuesPerPage} { + nullState = std::make_unique(false /*hasNull*/); + } + + SegmentState& getChildState(common::idx_t childIdx) { + KU_ASSERT(childIdx < childrenStates.size()); + return childrenStates[childIdx]; + } + const SegmentState& getChildState(common::idx_t childIdx) const { + KU_ASSERT(childIdx < childrenStates.size()); + return childrenStates[childIdx]; + } + + template + InMemoryExceptionChunk* getExceptionChunk() { + using GetType = std::unique_ptr>; + KU_ASSERT(std::holds_alternative(alpExceptionChunk)); + return std::get(alpExceptionChunk).get(); + } + + template + const InMemoryExceptionChunk* getExceptionChunkConst() const { + using GetType = std::unique_ptr>; + KU_ASSERT(std::holds_alternative(alpExceptionChunk)); + return std::get(alpExceptionChunk).get(); + } + + void reclaimAllocatedPages(PageAllocator& pageAllocator) const; + + // Used by rangeSegments in column_chunk.h to provide the same interface as the segments stored + // in ColumnChunk inside unique_ptr + SegmentState& operator*() { return *this; } + const SegmentState& operator*() const { return *this; } + uint64_t getNumValues() const { return metadata.numValues; } +}; + +class Spiller; +// Base data segment covers all fixed-sized data types. 
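// For intuition only (hypothetical values): an in-memory INT64 chunk behaves roughly like a
// growable typed array through the accessors declared below, e.g.
//   chunk.setValue<int64_t>(42, /*pos=*/0);   // also grows numValues to 1
//   int64_t v = chunk.getValue<int64_t>(0);   // reads 42 back
//   int64_t* raw = chunk.getData<int64_t>();  // raw pointer into the backing buffer
// Both accessors assert that the chunk is not in the ON_DISK residency state.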
+class LBUG_API ColumnChunkData { +public: + friend struct ColumnChunkFactory; + // For spilling to disk, we need access to the underlying buffer + friend class Spiller; + + ColumnChunkData(MemoryManager& mm, common::LogicalType dataType, uint64_t capacity, + bool enableCompression, ResidencyState residencyState, bool hasNullData, + bool initializeToZero = true); + ColumnChunkData(MemoryManager& mm, common::LogicalType dataType, bool enableCompression, + const ColumnChunkMetadata& metadata, bool hasNullData, bool initializeToZero = true); + ColumnChunkData(MemoryManager& mm, common::PhysicalTypeID physicalType, bool enableCompression, + const ColumnChunkMetadata& metadata, bool hasNullData, bool initializeToZero = true); + virtual ~ColumnChunkData(); + + template + T getValue(common::offset_t pos) const { + KU_ASSERT(pos < numValues); + KU_ASSERT(residencyState != ResidencyState::ON_DISK); + return getData()[pos]; + } + template + void setValue(T val, common::offset_t pos) { + KU_ASSERT(pos < capacity); + KU_ASSERT(residencyState != ResidencyState::ON_DISK); + getData()[pos] = val; + if (pos >= numValues) { + numValues = pos + 1; + } + if constexpr (StorageValueType) { + inMemoryStats.update(StorageValue{val}, dataType.getPhysicalType()); + } + } + + virtual bool isNull(common::offset_t pos) const; + void setNullData(std::unique_ptr nullData_) { nullData = std::move(nullData_); } + bool hasNullData() const { return nullData != nullptr; } + NullChunkData* getNullData() { return nullData.get(); } + const NullChunkData* getNullData() const { return nullData.get(); } + std::optional getNullMask() const; + std::unique_ptr moveNullData() { return std::move(nullData); } + + common::LogicalType& getDataType() { return dataType; } + const common::LogicalType& getDataType() const { return dataType; } + ResidencyState getResidencyState() const { return residencyState; } + bool isCompressionEnabled() const { return enableCompression; } + ColumnChunkMetadata& getMetadata() { + KU_ASSERT(residencyState == ResidencyState::ON_DISK); + return metadata; + } + const ColumnChunkMetadata& getMetadata() const { + KU_ASSERT(residencyState == ResidencyState::ON_DISK); + return metadata; + } + void setMetadata(const ColumnChunkMetadata& metadata_) { + KU_ASSERT(residencyState == ResidencyState::ON_DISK); + metadata = metadata_; + } + + // Only have side effects on in-memory or temporary chunks. 
+ virtual void resetToAllNull(); + virtual void resetToEmpty(); + + // Note that the startPageIdx is not known, so it will always be common::INVALID_PAGE_IDX + virtual ColumnChunkMetadata getMetadataToFlush() const; + + virtual void append(common::ValueVector* vector, const common::SelectionView& selView); + virtual void append(const ColumnChunkData* other, common::offset_t startPosInOtherChunk, + uint32_t numValuesToAppend); + + virtual void flush(PageAllocator& pageAllocator); + + ColumnChunkMetadata flushBuffer(PageAllocator& pageAllocator, const PageRange& entry, + const ColumnChunkMetadata& metadata) const; + + static common::page_idx_t getNumPagesForBytes(uint64_t numBytes) { + return (numBytes + common::LBUG_PAGE_SIZE - 1) / common::LBUG_PAGE_SIZE; + } + + uint64_t getNumBytesPerValue() const { return numBytesPerValue; } + uint8_t* getData() const; + template + T* getData() const { + return reinterpret_cast(getData()); + } + uint64_t getBufferSize() const; + + virtual void initializeScanState(SegmentState& state, const Column* column) const; + virtual void scan(common::ValueVector& output, common::offset_t offset, common::length_t length, + common::sel_t posInOutputVector = 0) const; + virtual void lookup(common::offset_t offsetInChunk, common::ValueVector& output, + common::sel_t posInOutputVector) const; + + // TODO(Guodong): In general, this is not a good interface. Instead of passing in + // `offsetInVector`, we should flatten the vector to pos at `offsetInVector`. + virtual void write(const common::ValueVector* vector, common::offset_t offsetInVector, + common::offset_t offsetInChunk); + virtual void write(ColumnChunkData* chunk, ColumnChunkData* offsetsInChunk, + common::RelMultiplicity multiplicity); + virtual void write(const ColumnChunkData* srcChunk, common::offset_t srcOffsetInChunk, + common::offset_t dstOffsetInChunk, common::offset_t numValuesToCopy); + + virtual void setToInMemory(); + // numValues must be at least the number of values the ColumnChunk was first initialized + // with + // reverse data and zero the part exceeding the original size + virtual void resize(uint64_t newCapacity); + // the opposite of the resize method, just simple resize + virtual void resizeWithoutPreserve(uint64_t newCapacity); + + void populateWithDefaultVal(evaluator::ExpressionEvaluator& defaultEvaluator, + uint64_t& numValues_, ColumnStats* newColumnStats); + virtual void finalize() { + KU_ASSERT(residencyState != ResidencyState::ON_DISK); + // DO NOTHING. + } + + uint64_t getCapacity() const { return capacity; } + uint64_t getNumValues() const { return numValues; } + // TODO(Guodong): Alternatively, we can let `getNumValues` read from metadata when ON_DISK. 
+ virtual void resetNumValuesFromMetadata(); + virtual void setNumValues(uint64_t numValues_); + // Just to provide the same interface for handleAppendException + inline void truncate(uint64_t numValues_) { setNumValues(numValues_); } + virtual void syncNumValues() {} + virtual bool numValuesSanityCheck() const; + + virtual bool sanityCheck() const; + + virtual uint64_t getEstimatedMemoryUsage() const; + bool shouldSplit() const { + // TODO(bmwinger): this should use the inMemoryStats to avoid scanning the data, however not + // all functions update them + return numValues > 1 && getSizeOnDisk() > std::max(getMinimumSizeOnDisk(), + common::StorageConfig::MAX_SEGMENT_SIZE); + } + const ColumnChunkStats& getInMemoryStats() const; + + // The minimum size is a function of the type's complexity and the page size + // If the page size is large, or the type is very complex, this could be larger than the max + // segment size (in which case we will treat the minimum size as the max segment size) E.g. if + // LBUG_PAGE_SIZE == MAX_SEGMENT_SIZE, even a normal column with non-constant-compressed nulls + // would have two pages and be detected as needing to split, even if the pages are nowhere near + // full. + // + // TODO(bmwinger): This was added to work around the issue of complex nested types having a + // larger initial size than the max segment size + // It should ideally be removed + virtual uint64_t getMinimumSizeOnDisk() const; + virtual uint64_t getSizeOnDisk() const; + // Not guaranteed to be accurate; not all functions keep the in memory statistics up to date! + virtual uint64_t getSizeOnDiskInMemoryStats() const; + + virtual void serialize(common::Serializer& serializer) const; + static std::unique_ptr deserialize(MemoryManager& mm, + common::Deserializer& deSer); + + template + TARGET& cast() { + return common::ku_dynamic_cast(*this); + } + template + const TARGET& cast() const { + return common::ku_dynamic_cast(*this); + } + MemoryManager& getMemoryManager() const; + + void loadFromDisk(); + SpillResult spillToDisk(); + + MergedColumnChunkStats getMergedColumnChunkStats() const; + + void updateStats(const common::ValueVector* vector, const common::SelectionView& selVector); + + virtual void reclaimStorage(PageAllocator& pageAllocator); + + std::vector> split(bool targetMaxSize = false) const; + +protected: + // Initializes the data buffer and functions. They are (and should be) only called in + // constructor. + void initializeBuffer(common::PhysicalTypeID physicalType, MemoryManager& mm, + bool initializeToZero); + void initializeFunction(); + + // Note: This function is not setting child/null chunk data recursively. 
+ void setToOnDisk(const ColumnChunkMetadata& metadata); + + virtual void copyVectorToBuffer(common::ValueVector* vector, common::offset_t startPosInChunk, + const common::SelectionView& selView); + + void resetInMemoryStats(); + +private: + using flush_buffer_func_t = std::function, + FileHandle*, const PageRange&, const ColumnChunkMetadata&)>; + flush_buffer_func_t initializeFlushBufferFunction( + std::shared_ptr compression) const; + uint64_t getBufferSize(uint64_t capacity_) const; + +protected: + using get_metadata_func_t = std::function, + uint64_t, StorageValue, StorageValue)>; + using get_min_max_func_t = + std::function(const uint8_t*, uint64_t)>; + + ResidencyState residencyState; + common::LogicalType dataType; + bool enableCompression; + uint32_t numBytesPerValue; + uint64_t capacity; + std::unique_ptr buffer; + std::unique_ptr nullData; + uint64_t numValues; + flush_buffer_func_t flushBufferFunction; + get_metadata_func_t getMetadataFunction; + + // On-disk metadata for column chunk. + ColumnChunkMetadata metadata; + + // Stats for any in-memory updates applied to the column chunk + // This will be merged with the on-disk metadata to get the overall stats + ColumnChunkStats inMemoryStats; +}; + +template<> +inline void ColumnChunkData::setValue(bool val, common::offset_t pos) { + KU_ASSERT(pos < capacity); + KU_ASSERT(residencyState != ResidencyState::ON_DISK); + // Buffer is rounded up to the nearest 8 bytes so that this cast is safe + common::NullMask::setNull(getData(), pos, val); + if (pos >= numValues) { + numValues = pos + 1; + } + inMemoryStats.update(StorageValue{val}, dataType.getPhysicalType()); +} + +template<> +inline bool ColumnChunkData::getValue(common::offset_t pos) const { + // Buffer is rounded up to the nearest 8 bytes so that this cast is safe + return common::NullMask::isNull(getData(), pos); +} + +// Stored as bitpacked booleans in-memory and on-disk +class BoolChunkData : public ColumnChunkData { +public: + BoolChunkData(MemoryManager& mm, uint64_t capacity, bool enableCompression, ResidencyState type, + bool hasNullChunk) + : ColumnChunkData(mm, common::LogicalType::BOOL(), capacity, + // Booleans are always bitpacked, but this can also enable constant compression + enableCompression, type, hasNullChunk, true) {} + BoolChunkData(MemoryManager& mm, bool enableCompression, const ColumnChunkMetadata& metadata, + bool hasNullData) + : ColumnChunkData{mm, common::LogicalType::BOOL(), enableCompression, metadata, hasNullData, + true} {} + + void append(common::ValueVector* vector, const common::SelectionView& sel) final; + void append(const ColumnChunkData* other, common::offset_t startPosInOtherChunk, + uint32_t numValuesToAppend) override; + + void scan(common::ValueVector& output, common::offset_t offset, common::length_t length, + common::sel_t posInOutputVector = 0) const override; + void lookup(common::offset_t offsetInChunk, common::ValueVector& output, + common::sel_t posInOutputVector) const override; + + void write(const common::ValueVector* vector, common::offset_t offsetInVector, + common::offset_t offsetInChunk) override; + void write(ColumnChunkData* chunk, ColumnChunkData* dstOffsets, + common::RelMultiplicity multiplicity) final; + void write(const ColumnChunkData* srcChunk, common::offset_t srcOffsetInChunk, + common::offset_t dstOffsetInChunk, common::offset_t numValuesToCopy) override; +}; + +class NullChunkData final : public BoolChunkData { +public: + NullChunkData(MemoryManager& mm, uint64_t capacity, bool enableCompression, 
ResidencyState type) + : BoolChunkData(mm, capacity, enableCompression, type, false /*hasNullData*/) {} + NullChunkData(MemoryManager& mm, bool enableCompression, const ColumnChunkMetadata& metadata) + : BoolChunkData{mm, enableCompression, metadata, false /*hasNullData*/} {} + + // Maybe this should be combined with BoolChunkData if the only difference is these + // functions? + bool isNull(common::offset_t pos) const override { return getValue(pos); } + void setNull(common::offset_t pos, bool isNull); + + bool noNullsGuaranteedInMem() const { + return !inMemoryStats.max || !inMemoryStats.max->get(); + } + bool allNullsGuaranteedInMem() const { + return !inMemoryStats.min || inMemoryStats.min->get(); + } + bool haveNoNullsGuaranteed() const; + bool haveAllNullsGuaranteed() const; + + void resetToEmpty() override { + memset(getData(), 0 /* non null */, getBufferSize()); + numValues = 0; + inMemoryStats.min = inMemoryStats.max = std::nullopt; + } + void resetToNoNull() { + memset(getData(), 0 /* non null */, getBufferSize()); + inMemoryStats.min = inMemoryStats.max = false; + } + void resetToAllNull() override { + memset(getData(), 0xFF /* null */, getBufferSize()); + inMemoryStats.min = inMemoryStats.max = true; + } + + void copyFromBuffer(const uint64_t* srcBuffer, uint64_t srcOffset, uint64_t dstOffset, + uint64_t numBits) { + KU_ASSERT(numBits > 0); + common::NullMask::copyNullMask(srcBuffer, srcOffset, getData(), dstOffset, + numBits); + auto [min, max] = common::NullMask::getMinMax(srcBuffer, srcOffset, numBits); + if (!inMemoryStats.min.has_value() || min < inMemoryStats.min->get()) { + inMemoryStats.min = min; + } + if (!inMemoryStats.max.has_value() || max > inMemoryStats.max->get()) { + inMemoryStats.max = max; + } + if ((dstOffset + numBits) >= numValues) { + numValues = dstOffset + numBits; + } + } + + // Appends the null data from the vector's null mask + void appendNulls(const common::ValueVector* vector, const common::SelectionView& selView, + common::offset_t startPosInChunk); + + // NullChunkData::scan updates the null mask of output vector + void scan(common::ValueVector& output, common::offset_t offset, common::length_t length, + common::sel_t posInOutputVector = 0) const override; + + void append(const ColumnChunkData* other, common::offset_t startPosInOtherChunk, + uint32_t numValuesToAppend) override; + + void write(const common::ValueVector* vector, common::offset_t offsetInVector, + common::offset_t offsetInChunk) override; + void write(const ColumnChunkData* srcChunk, common::offset_t srcOffsetInChunk, + common::offset_t dstOffsetInChunk, common::offset_t numValuesToCopy) override; + + void serialize(common::Serializer& serializer) const override; + static std::unique_ptr deserialize(MemoryManager& mm, + common::Deserializer& deSer); + + common::NullMask getNullMask() const; +}; + +class LBUG_API InternalIDChunkData final : public ColumnChunkData { +public: + // TODO(Guodong): Should make InternalIDChunkData has no NULL. 
+ // Physically, we only materialize offset of INTERNAL_ID, which is same as UINT64, + InternalIDChunkData(MemoryManager& mm, uint64_t capacity, bool enableCompression, + ResidencyState residencyState) + : ColumnChunkData(mm, common::LogicalType::INTERNAL_ID(), capacity, enableCompression, + residencyState, false /*hasNullData*/), + commonTableID{common::INVALID_TABLE_ID} {} + InternalIDChunkData(MemoryManager& mm, bool enableCompression, + const ColumnChunkMetadata& metadata) + : ColumnChunkData{mm, common::LogicalType::INTERNAL_ID(), enableCompression, metadata, + false /*hasNullData*/}, + commonTableID{common::INVALID_TABLE_ID} {} + + void append(common::ValueVector* vector, const common::SelectionView& selView) override; + + void copyVectorToBuffer(common::ValueVector* vector, common::offset_t startPosInChunk, + const common::SelectionView& selView) override; + + void copyInt64VectorToBuffer(common::ValueVector* vector, common::offset_t startPosInChunk, + const common::SelectionView& selView) const; + + void scan(common::ValueVector& output, common::offset_t offset, common::length_t length, + common::sel_t posInOutputVector = 0) const override; + void lookup(common::offset_t offsetInChunk, common::ValueVector& output, + common::sel_t posInOutputVector) const override; + + void write(const common::ValueVector* vector, common::offset_t offsetInVector, + common::offset_t offsetInChunk) override; + + void append(const ColumnChunkData* other, common::offset_t startPosInOtherChunk, + uint32_t numValuesToAppend) override; + + void setTableID(common::table_id_t tableID) { commonTableID = tableID; } + common::table_id_t getTableID() const { return commonTableID; } + + common::offset_t operator[](common::offset_t pos) const { + return getValue(pos); + } + common::offset_t& operator[](common::offset_t pos) { return getData()[pos]; } + +private: + common::table_id_t commonTableID; +}; + +struct ColumnChunkFactory { + static std::unique_ptr createColumnChunkData(MemoryManager& mm, + common::LogicalType dataType, bool enableCompression, uint64_t capacity, + ResidencyState residencyState, bool hasNullData = true, bool initializeToZero = true); + static std::unique_ptr createColumnChunkData(MemoryManager& mm, + common::LogicalType dataType, bool enableCompression, ColumnChunkMetadata& metadata, + bool hasNullData, bool initializeToZero); + + static std::unique_ptr createNullChunkData(MemoryManager& mm, + bool enableCompression, uint64_t capacity, ResidencyState type) { + return std::make_unique(mm, capacity, enableCompression, type); + } +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/column_chunk_metadata.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/column_chunk_metadata.h new file mode 100644 index 0000000000..e2dc38441c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/column_chunk_metadata.h @@ -0,0 +1,81 @@ +#pragma once + +#include "common/types/types.h" +#include "storage/compression/compression.h" +#include "storage/page_range.h" + +namespace lbug::storage { +struct ColumnChunkMetadata { + PageRange pageRange; + uint64_t numValues; + CompressionMetadata compMeta; + + common::page_idx_t getStartPageIdx() const { return pageRange.startPageIdx; } + common::page_idx_t getNumPages() const { return pageRange.numPages; } + + // Returns the number of pages used to store data + // In the case of ALP compression, this does not include the number of pages used to store + // exceptions + 
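    // For intuition (hypothetical numbers): if a DOUBLE segment occupies 10 pages in total,
    // 2 of which hold ALP exceptions, getNumPages() returns 10 while
    // getNumDataPages(common::PhysicalTypeID::DOUBLE) returns 8; for non-float or
    // uncompressed segments the two counts are expected to coincide.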
common::page_idx_t getNumDataPages(common::PhysicalTypeID dataType) const; + + void serialize(common::Serializer& serializer) const; + static ColumnChunkMetadata deserialize(common::Deserializer& deserializer); + + // TODO(Guodong): Delete copy constructor. + ColumnChunkMetadata() + : pageRange(common::INVALID_PAGE_IDX, 0), numValues{0}, + compMeta(StorageValue(), StorageValue(), CompressionType::CONSTANT) {} + ColumnChunkMetadata(common::page_idx_t pageIdx, common::page_idx_t numPages, uint64_t numValues, + const CompressionMetadata& compMeta) + : pageRange(pageIdx, numPages), numValues(numValues), compMeta(compMeta) {} +}; + +class GetCompressionMetadata { + std::shared_ptr alg; + const common::LogicalType& dataType; + +public: + GetCompressionMetadata(std::shared_ptr alg, const common::LogicalType& dataType) + : alg{std::move(alg)}, dataType{dataType} {} + + GetCompressionMetadata(const GetCompressionMetadata& other) = default; + + ColumnChunkMetadata operator()(std::span buffer, uint64_t numValues, + StorageValue min, StorageValue max) const; +}; + +class GetBitpackingMetadata { + std::shared_ptr alg; + const common::LogicalType& dataType; + +public: + GetBitpackingMetadata(std::shared_ptr alg, const common::LogicalType& dataType) + : alg{std::move(alg)}, dataType{dataType} {} + + GetBitpackingMetadata(const GetBitpackingMetadata& other) = default; + + ColumnChunkMetadata operator()(std::span buffer, uint64_t numValues, + StorageValue min, StorageValue max); +}; + +template +class GetFloatCompressionMetadata { + std::shared_ptr alg; + const common::LogicalType& dataType; + +public: + GetFloatCompressionMetadata(std::shared_ptr alg, + const common::LogicalType& dataType) + : alg{std::move(alg)}, dataType{dataType} {} + + GetFloatCompressionMetadata(const GetFloatCompressionMetadata& other) = default; + + ColumnChunkMetadata operator()(std::span buffer, uint64_t numValues, + StorageValue min, StorageValue max); +}; + +ColumnChunkMetadata uncompressedGetMetadata(common::PhysicalTypeID dataType, uint64_t numValues, + StorageValue min, StorageValue max); + +ColumnChunkMetadata booleanGetMetadata(uint64_t numValues, StorageValue min, StorageValue max); +} // namespace lbug::storage diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/column_chunk_scanner.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/column_chunk_scanner.h new file mode 100644 index 0000000000..d4a04a077f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/column_chunk_scanner.h @@ -0,0 +1,27 @@ +#pragma once + +#include + +#include "common/types/types.h" +namespace lbug { +namespace transaction { +class Transaction; +} +namespace storage { +class ColumnChunkData; +class UpdateInfo; + +struct ColumnChunkScanner { + using scan_func_t = std::function; + + virtual ~ColumnChunkScanner(){}; + virtual void scanSegment(common::offset_t offsetInSegment, common::offset_t segmentLength, + scan_func_t scanFunc) = 0; + virtual void applyCommittedUpdates(const UpdateInfo& updateInfo, + const transaction::Transaction* transaction, common::offset_t startRow, + common::offset_t numRows) = 0; + virtual uint64_t getNumValues() = 0; +}; +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/column_chunk_stats.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/column_chunk_stats.h new file mode 100644 index 0000000000..0fb5cfc65c --- /dev/null +++ 
b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/column_chunk_stats.h @@ -0,0 +1,36 @@ +#pragma once + +#include "storage/compression/compression.h" +namespace common { +class ValueVector; +} +namespace lbug::storage { +class ColumnChunkData; + +struct LBUG_API ColumnChunkStats { + std::optional max; + std::optional min; + + void update(std::optional min, std::optional max, + common::PhysicalTypeID dataType); + void update(StorageValue val, common::PhysicalTypeID dataType); + void update(const common::ValueVector& valueVector, uint64_t offset, uint64_t numValues, + common::PhysicalTypeID physicalType); + void update(const ColumnChunkData& data, uint64_t offset, uint64_t numValues, + common::PhysicalTypeID physicalType); + void reset(); +}; + +struct MergedColumnChunkStats { + MergedColumnChunkStats(ColumnChunkStats stats, bool guaranteedNoNulls, bool guaranteedAllNulls) + : stats(stats), guaranteedNoNulls(guaranteedNoNulls), + guaranteedAllNulls(guaranteedAllNulls) {} + + ColumnChunkStats stats; + bool guaranteedNoNulls; + bool guaranteedAllNulls; + + void merge(const MergedColumnChunkStats& o, common::PhysicalTypeID dataType); +}; + +} // namespace lbug::storage diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/column_reader_writer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/column_reader_writer.h new file mode 100644 index 0000000000..74f3fceb56 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/column_reader_writer.h @@ -0,0 +1,91 @@ +#pragma once + +#include "storage/compression/float_compression.h" + +namespace lbug { +namespace transaction { +class Transaction; +} +namespace storage { + +class FileHandle; +class ColumnReadWriter; +class ShadowFile; +struct ColumnChunkMetadata; +struct SegmentState; + +template +using read_value_from_page_func_t = std::function; + +template +using read_values_from_page_func_t = std::function; + +using read_values_to_vector_func_t = read_values_from_page_func_t; +using read_values_to_page_func_t = read_values_from_page_func_t; + +template +using write_values_to_page_func_t = std::function; + +using write_values_from_vector_func_t = write_values_to_page_func_t; + +using write_values_func_t = write_values_to_page_func_t; + +using filter_func_t = std::function; + +struct ColumnReadWriterFactory { + static std::unique_ptr createColumnReadWriter(common::PhysicalTypeID dataType, + FileHandle* dataFH, ShadowFile* shadowFile); +}; + +class ColumnReadWriter { +public: + ColumnReadWriter(FileHandle* dataFH, ShadowFile* shadowFile); + + virtual ~ColumnReadWriter() = default; + + virtual void readCompressedValueToPage(const SegmentState& state, common::offset_t nodeOffset, + uint8_t* result, uint32_t offsetInResult, + const read_value_from_page_func_t& readFunc) = 0; + + virtual void readCompressedValueToVector(const SegmentState& state, common::offset_t nodeOffset, + common::ValueVector* result, uint32_t offsetInResult, + const read_value_from_page_func_t& readFunc) = 0; + + virtual uint64_t readCompressedValuesToPage(const SegmentState& state, uint8_t* result, + uint32_t startOffsetInResult, uint64_t startOffsetInSegment, uint64_t length, + const read_values_from_page_func_t& readFunc, + const std::optional& filterFunc = {}) = 0; + + virtual uint64_t readCompressedValuesToVector(const SegmentState& state, + common::ValueVector* result, uint32_t startOffsetInResult, uint64_t startOffsetInSegment, + uint64_t length, const read_values_from_page_func_t& readFunc, + const std::optional& 
filterFunc = {}) = 0; + + virtual void writeValueToPageFromVector(SegmentState& state, common::offset_t offsetInChunk, + common::ValueVector* vectorToWriteFrom, uint32_t posInVectorToWriteFrom, + const write_values_from_vector_func_t& writeFromVectorFunc) = 0; + + virtual void writeValuesToPageFromBuffer(SegmentState& state, common::offset_t dstOffset, + const uint8_t* data, const common::NullMask* nullChunkData, common::offset_t srcOffset, + common::offset_t numValues, const write_values_func_t& writeFunc) = 0; + + void readFromPage(common::page_idx_t pageIdx, + const std::function& readFunc) const; + + void updatePageWithCursor(PageCursor cursor, + const std::function& writeOp) const; + +protected: + static PageCursor getPageCursorForOffsetInGroup(common::offset_t offsetInChunk, + common::page_idx_t groupPageIdx, uint64_t numValuesPerPage); + +private: + FileHandle* dataFH; + ShadowFile* shadowFile; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/combined_chunk_scanner.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/combined_chunk_scanner.h new file mode 100644 index 0000000000..1fe48f08cf --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/combined_chunk_scanner.h @@ -0,0 +1,30 @@ +#pragma once + +#include "storage/table/column_chunk_data.h" +#include "storage/table/column_chunk_scanner.h" +#include "storage/table/update_info.h" +namespace lbug { +namespace storage { +// Scans all the segments into a single output chunk +struct CombinedChunkScanner : public ColumnChunkScanner { + explicit CombinedChunkScanner(ColumnChunkData& output) + : output(output), numValuesBeforeScan(output.getNumValues()) {} + + void scanSegment(common::offset_t offsetInSegment, common::offset_t length, + scan_func_t scanFunc) override { + scanFunc(output, offsetInSegment, length); + } + + void applyCommittedUpdates(const UpdateInfo& updateInfo, + const transaction::Transaction* transaction, common::offset_t startRow, + common::offset_t numRows) override { + updateInfo.scanCommitted(transaction, output, numValuesBeforeScan, startRow, numRows); + } + + uint64_t getNumValues() override { return output.getNumValues(); } + + ColumnChunkData& output; + common::offset_t numValuesBeforeScan; +}; +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/compression_flush_buffer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/compression_flush_buffer.h new file mode 100644 index 0000000000..126991c9b8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/compression_flush_buffer.h @@ -0,0 +1,45 @@ +#pragma once + +#include "storage/compression/compression.h" +#include "storage/table/column_chunk_metadata.h" + +namespace lbug::storage { +class FileHandle; + +class CompressedFlushBuffer { + std::shared_ptr alg; + common::PhysicalTypeID dataType; + +public: + CompressedFlushBuffer(std::shared_ptr alg, common::PhysicalTypeID dataType) + : alg{std::move(alg)}, dataType{dataType} {} + CompressedFlushBuffer(std::shared_ptr alg, const common::LogicalType& dataType) + : CompressedFlushBuffer(std::move(alg), dataType.getPhysicalType()) {} + + CompressedFlushBuffer(const CompressedFlushBuffer& other) = default; + + ColumnChunkMetadata operator()(std::span buffer, FileHandle* dataFH, + const PageRange& entry, const ColumnChunkMetadata& metadata) const; +}; + +template +class CompressedFloatFlushBuffer { + std::shared_ptr alg; + 
common::PhysicalTypeID dataType; + +public: + CompressedFloatFlushBuffer(std::shared_ptr alg, + common::PhysicalTypeID dataType); + CompressedFloatFlushBuffer(std::shared_ptr alg, + const common::LogicalType& dataType); + + CompressedFloatFlushBuffer(const CompressedFloatFlushBuffer& other) = default; + + ColumnChunkMetadata operator()(std::span buffer, FileHandle* dataFH, + const PageRange& entry, const ColumnChunkMetadata& metadata) const; +}; + +ColumnChunkMetadata uncompressedFlushBuffer(std::span buffer, FileHandle* dataFH, + const PageRange& entry, const ColumnChunkMetadata& metadata); + +} // namespace lbug::storage diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/csr_chunked_node_group.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/csr_chunked_node_group.h new file mode 100644 index 0000000000..c750ad4af5 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/csr_chunked_node_group.h @@ -0,0 +1,215 @@ +#pragma once + +#include + +#include "storage/enums/residency_state.h" +#include "storage/table/chunked_node_group.h" +#include "storage/table/column_chunk.h" + +namespace lbug { +namespace storage { +class PageAllocator; +class MemoryManager; + +struct CSRRegion { + common::idx_t regionIdx = common::INVALID_IDX; + common::idx_t level = common::INVALID_IDX; + common::offset_t leftNodeOffset = common::INVALID_OFFSET; + common::offset_t rightNodeOffset = common::INVALID_OFFSET; + int64_t sizeChange = 0; + // Track if there is any updates to persistent data in this region per column in table. + // Note: should be accessed with columnID. + std::vector hasUpdates; + // Note: `sizeChange` equal to 0 is not enough to indicate the region has no insert or + // delete. It might just be num of insertions are equal to num of deletions. + // hasInsertions is true if there are insertions that are not deleted yet in this region. + bool hasInsertions = false; + bool hasPersistentDeletions = false; + + CSRRegion(common::idx_t regionIdx, common::idx_t level); + + bool needCheckpoint() const { + return hasInsertions || hasPersistentDeletions || + std::any_of(hasUpdates.begin(), hasUpdates.end(), + [](bool hasUpdate) { return hasUpdate; }); + } + bool needCheckpointColumn(common::column_id_t columnID) const { + KU_ASSERT(columnID < hasUpdates.size()); + return hasInsertions || hasPersistentDeletions || hasUpdates[columnID]; + } + bool hasDeletionsOrInsertions() const { return hasInsertions || hasPersistentDeletions; } + common::idx_t getLeftLeafRegionIdx() const { return regionIdx << level; } + common::idx_t getRightLeafRegionIdx() const { + const auto rightRegionIdx = + getLeftLeafRegionIdx() + (static_cast(1) << level) - 1; + constexpr auto maxNumRegions = + common::StorageConfig::NODE_GROUP_SIZE / common::StorageConfig::CSR_LEAF_REGION_SIZE; + if (rightRegionIdx >= maxNumRegions) { + return maxNumRegions - 1; + } + return rightRegionIdx; + } + // Return true if other is within the realm of this region. 
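    // For intuition, using the leaf-region arithmetic above: {regionIdx = 4, level = 1}
    // covers leaf regions [8, 9] and {regionIdx = 1, level = 3} covers leaf regions [8, 15],
    // so one would expect the larger region's isWithin() to return true for the smaller one.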
+ bool isWithin(const CSRRegion& other) const; + + static CSRRegion upgradeLevel(const std::vector& leafRegions, + const CSRRegion& region); +}; + +struct LBUG_API InMemChunkedCSRHeader { + std::unique_ptr offset; + std::unique_ptr length; + bool randomLookup = false; + + InMemChunkedCSRHeader(MemoryManager& memoryManager, bool enableCompression, uint64_t capacity); + InMemChunkedCSRHeader(std::unique_ptr offset, + std::unique_ptr length) + : offset{std::move(offset)}, length{std::move(length)} { + KU_ASSERT(this->offset && this->length); + } + + common::offset_t getStartCSROffset(common::offset_t nodeOffset) const; + common::offset_t getEndCSROffset(common::offset_t nodeOffset) const; + common::length_t getCSRLength(common::offset_t nodeOffset) const; + common::length_t getGapSize(common::length_t length) const; + + bool sanityCheck() const; + void copyFrom(const InMemChunkedCSRHeader& other) const; + void fillDefaultValues(common::offset_t newNumValues) const; + void setNumValues(const common::offset_t numValues) const { + offset->setNumValues(numValues); + length->setNumValues(numValues); + } + + // Return a vector of CSR offsets for the end of each CSR region. + common::offset_vec_t populateStartCSROffsetsFromLength(bool leaveGaps) const; + void populateEndCSROffsetFromStartAndLength() const; + void finalizeCSRRegionEndOffsets(const common::offset_vec_t& rightCSROffsetOfRegions) const; + void populateRegionCSROffsets(const CSRRegion& region, + const InMemChunkedCSRHeader& oldHeader) const; + void populateEndCSROffsets(const common::offset_vec_t& gaps) const; + common::idx_t getNumRegions() const; + +private: + static common::length_t computeGapFromLength(common::length_t length); +}; + +struct ChunkedCSRHeader { + std::unique_ptr offset; + std::unique_ptr length; + bool randomLookup = false; + + ChunkedCSRHeader(MemoryManager& memoryManager, bool enableCompression, uint64_t capacity, + ResidencyState residencyState); + ChunkedCSRHeader(bool enableCompression, InMemChunkedCSRHeader&& other) + : offset{std::make_unique(enableCompression, std::move(other.offset))}, + length{std::make_unique(enableCompression, std::move(other.length))}, + randomLookup{other.randomLookup} {} + ChunkedCSRHeader(std::unique_ptr offset, std::unique_ptr length) + : offset{std::move(offset)}, length{std::move(length)} { + KU_ASSERT(this->offset && this->length); + } + + common::offset_t getStartCSROffset(common::offset_t nodeOffset) const; + common::offset_t getEndCSROffset(common::offset_t nodeOffset) const; + common::length_t getCSRLength(common::offset_t nodeOffset) const; + common::length_t getGapSize(common::length_t length) const; + + bool sanityCheck() const; + + // Return a vector of CSR offsets for the end of each CSR region. 
+ common::offset_vec_t populateStartCSROffsetsFromLength(bool leaveGaps) const; + void populateEndCSROffsetFromStartAndLength() const; + void finalizeCSRRegionEndOffsets(const common::offset_vec_t& rightCSROffsetOfRegions) const; + void populateRegionCSROffsets(const CSRRegion& region, const ChunkedCSRHeader& oldHeader) const; + void populateEndCSROffsets(const common::offset_vec_t& gaps) const; + common::idx_t getNumRegions() const; + +private: + static common::length_t computeGapFromLength(common::length_t length); +}; + +class InMemChunkedCSRNodeGroup; + +struct CSRNodeGroupCheckpointState; +class ChunkedCSRNodeGroup final : public ChunkedNodeGroup { + friend class InMemChunkedCSRNodeGroup; + +public: + ChunkedCSRNodeGroup(MemoryManager& mm, const std::vector& columnTypes, + bool enableCompression, uint64_t capacity, common::offset_t startOffset, + ResidencyState residencyState) + : ChunkedNodeGroup{mm, columnTypes, enableCompression, capacity, startOffset, + residencyState, NodeGroupDataFormat::CSR}, + csrHeader{mm, enableCompression, common::StorageConfig::NODE_GROUP_SIZE, residencyState} { + } + ChunkedCSRNodeGroup(InMemChunkedCSRNodeGroup& base, + const std::vector& selectedColumns); + ChunkedCSRNodeGroup(ChunkedCSRNodeGroup& base, + const std::vector& selectedColumns) + : ChunkedNodeGroup{base, selectedColumns}, csrHeader{std::move(base.csrHeader)} {} + ChunkedCSRNodeGroup(MemoryManager& mm, ChunkedCSRNodeGroup& base, + std::span columnTypes, + std::span baseColumnIDs) + : ChunkedNodeGroup(mm, base, columnTypes, baseColumnIDs), + csrHeader(std::move(base.csrHeader)) {} + ChunkedCSRNodeGroup(ChunkedCSRHeader csrHeader, + std::vector> chunks, common::row_idx_t startRowIdx) + : ChunkedNodeGroup{std::move(chunks), startRowIdx, NodeGroupDataFormat::CSR}, + csrHeader{std::move(csrHeader)} {} + + ChunkedCSRHeader& getCSRHeader() { return csrHeader; } + const ChunkedCSRHeader& getCSRHeader() const { return csrHeader; } + + void serialize(common::Serializer& serializer) const override; + static std::unique_ptr deserialize(MemoryManager& memoryManager, + common::Deserializer& deSer); + + void scanCSRHeader(MemoryManager& memoryManager, CSRNodeGroupCheckpointState& csrState) const; + + void reclaimStorage(PageAllocator& pageAllocator) const override; + +private: + ChunkedCSRHeader csrHeader; +}; + +class InMemChunkedCSRNodeGroup final : public InMemChunkedNodeGroup { + friend class ChunkedCSRNodeGroup; + +public: + InMemChunkedCSRNodeGroup(MemoryManager& mm, const std::vector& columnTypes, + bool enableCompression, uint64_t capacity, common::offset_t startOffset) + : InMemChunkedNodeGroup{mm, columnTypes, enableCompression, capacity, startOffset}, + csrHeader{mm, enableCompression, common::StorageConfig::NODE_GROUP_SIZE} {} + + InMemChunkedCSRNodeGroup(InMemChunkedCSRNodeGroup& base, + const std::vector& selectedColumns) + : InMemChunkedNodeGroup{base, selectedColumns}, csrHeader{std::move(base.csrHeader)} {} + + InMemChunkedCSRHeader& getCSRHeader() { return csrHeader; } + const InMemChunkedCSRHeader& getCSRHeader() const { return csrHeader; } + + // this does not override ChunkedNodeGroup::merge() since clang-tidy analyzer + // seems to struggle with detecting the std::move of the header unless this is inlined + void mergeChunkedCSRGroup(InMemChunkedCSRNodeGroup& base, + const std::vector& columnsToMergeInto) { + InMemChunkedNodeGroup::merge(base, columnsToMergeInto); + csrHeader = InMemChunkedCSRHeader(std::move(base.csrHeader.offset), + std::move(base.csrHeader.length)); + } + + void 
writeToColumnChunk(common::idx_t chunkIdx, common::idx_t vectorIdx, + const std::vector>& data, + ColumnChunkData& offsetChunk) override { + chunks[chunkIdx]->write(data[vectorIdx].get(), &offsetChunk, common::RelMultiplicity::MANY); + } + + std::unique_ptr flush(transaction::Transaction* transaction, + PageAllocator& pageAllocator) override; + +private: + InMemChunkedCSRHeader csrHeader; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/csr_node_group.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/csr_node_group.h new file mode 100644 index 0000000000..c797779b8f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/csr_node_group.h @@ -0,0 +1,291 @@ +#pragma once + +#include +#include + +#include "common/constants.h" +#include "common/system_config.h" +#include "storage/enums/csr_node_group_scan_source.h" +#include "storage/table/csr_chunked_node_group.h" +#include "storage/table/node_group.h" + +namespace lbug { +namespace transaction { +class Transaction; +} +namespace storage { +class MemoryManager; + +using row_idx_vec_t = std::vector; + +struct csr_list_t { + common::row_idx_t startRow = common::INVALID_ROW_IDX; + common::length_t length = 0; +}; + +// Store rows of a CSR list. +// If rows of the CSR list are stored in a sequential order, then `isSequential` is set to true. +// rowIndices consists of startRowIdx and length. +// Otherwise, rowIndices records the row indices of each row in the CSR list. +struct NodeCSRIndex { + // TODO(Guodong): Should seperate `isSequential` and `rowIndices` to two different data + // structures. so the struct can be more space efficient. + bool isSequential = false; + row_idx_vec_t rowIndices; // TODO(Guodong): Should optimze the vector to a more space-efficient + // data structure. + + bool isEmpty() const { return getNumRows() == 0; } + common::row_idx_t getNumRows() const { + if (isSequential) { + return rowIndices[1]; + } + return rowIndices.size(); + } + row_idx_vec_t getRows() const { + if (isSequential) { + row_idx_vec_t result; + result.reserve(rowIndices[1]); + for (common::row_idx_t i = 0u; i < rowIndices[1]; ++i) { + result.push_back(i + rowIndices[0]); + } + return result; + } + return rowIndices; + } + + void clear() { + isSequential = false; + rowIndices.clear(); + } + + void turnToNonSequential() { + if (isSequential) { + row_idx_vec_t newIndices; + newIndices.reserve(rowIndices[1]); + for (common::row_idx_t i = 0u; i < rowIndices[1]; ++i) { + newIndices.push_back(i + rowIndices[0]); + } + rowIndices = std::move(newIndices); + isSequential = false; + } + } + void setInvalid(common::idx_t idx) { + KU_ASSERT(!isSequential); + KU_ASSERT(idx < rowIndices.size()); + rowIndices[idx] = common::INVALID_ROW_IDX; + } +}; + +// TODO(Guodong): Split CSRIndex into two levels: one level per csr leaf region, another per node +// under the leaf region. This should be more space efficient. +struct CSRIndex { + std::array indices; + + common::row_idx_t getNumRows(common::offset_t offset) const { + return indices[offset].getNumRows(); + } + + common::offset_t getMaxOffsetWithRels() const { + common::offset_t maxOffset = 0; + for (auto i = 0u; i < indices.size(); i++) { + if (!indices[i].isEmpty()) { + maxOffset = i; + } + } + return maxOffset; + } +}; + +// TODO(Guodong): Serialize the info to disk. This should be a config per node group. 
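// Illustrative, standalone sketch (not from the lbug sources) of the NodeCSRIndex encoding
// described above: when isSequential is true, rowIndices holds {startRowIdx, length};
// otherwise it holds one explicit row index per rel. Toy* names are invented.
#include <cstdint>
#include <vector>

struct ToyNodeIndex {
    bool isSequential = false;
    std::vector<uint64_t> rowIndices; // sequential: {startRow, length}; otherwise: explicit rows

    uint64_t numRows() const { return isSequential ? rowIndices[1] : rowIndices.size(); }

    std::vector<uint64_t> rows() const {
        if (!isSequential) {
            return rowIndices;
        }
        std::vector<uint64_t> expanded;
        expanded.reserve(rowIndices[1]);
        for (uint64_t i = 0; i < rowIndices[1]; ++i) {
            expanded.push_back(rowIndices[0] + i); // expand the compact {start, length} form
        }
        return expanded;
    }
};

// A batch-inserted list occupying rows 100..102 keeps the compact form:
//   ToyNodeIndex idx{true, {100, 3}};   // idx.rows() -> {100, 101, 102}
// After an out-of-order insert it must switch to explicit indices, e.g. {100, 101, 102, 250}.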
+struct PackedCSRInfo { + static_assert(common::StorageConfig::NODE_GROUP_SIZE_LOG2 > + common::StorageConfig::CSR_LEAF_REGION_SIZE_LOG2); + uint64_t calibratorTreeHeight = common::StorageConfig::NODE_GROUP_SIZE_LOG2 - + common::StorageConfig::CSR_LEAF_REGION_SIZE_LOG2; + double highDensityStep = (common::StorageConstants::LEAF_HIGH_CSR_DENSITY - + common::StorageConstants::PACKED_CSR_DENSITY) / + static_cast(calibratorTreeHeight); + + constexpr PackedCSRInfo() noexcept = default; +}; + +class CSRNodeGroup; +struct RelTableScanState; +struct CSRNodeGroupScanState final : NodeGroupScanState { + // Cached offsets and lengths for a sequence of CSR lists within the current vector of + // boundNodes. + std::unique_ptr header; + + std::optional> cachedScannedVectorsSelBitset; + // The total number of rows (i.e., rels) in the current node group. + common::row_idx_t numTotalRows; + // The number of rows (i.e., rels) that have been scanned so far in current node group. + common::row_idx_t numCachedRows; + common::row_idx_t nextCachedRowToScan; + + // States at the csr list level. Cached during scan over a single csr list. + NodeCSRIndex inMemCSRList; + + CSRNodeGroupScanSource source; + + // This is for local scan state where we don't need `header`. + explicit CSRNodeGroupScanState() + : header{nullptr}, numTotalRows{0}, numCachedRows{0}, nextCachedRowToScan{0}, + source{CSRNodeGroupScanSource::COMMITTED_PERSISTENT} {} + explicit CSRNodeGroupScanState(common::idx_t numChunks) + : NodeGroupScanState{numChunks}, header{nullptr}, numTotalRows{0}, numCachedRows{0}, + nextCachedRowToScan{0}, source{CSRNodeGroupScanSource::COMMITTED_PERSISTENT} {} + explicit CSRNodeGroupScanState(MemoryManager& mm, bool randomLookup = false) + : numTotalRows{0}, numCachedRows{0}, nextCachedRowToScan{0}, + source{CSRNodeGroupScanSource::COMMITTED_PERSISTENT} { + header = std::make_unique(mm, false, + randomLookup ? 1 : common::StorageConfig::NODE_GROUP_SIZE); + } + + bool tryScanCachedTuples(RelTableScanState& tableScanState); +}; + +struct CSRNodeGroupCheckpointState final : NodeGroupCheckpointState { + Column* csrOffsetColumn; + Column* csrLengthColumn; + + std::unique_ptr oldHeader; + std::unique_ptr newHeader; + + CSRNodeGroupCheckpointState(std::vector columnIDs, + std::vector columns, PageAllocator& pageAllocator, MemoryManager* mm, + Column* csrOffsetCol, Column* csrLengthCol) + : NodeGroupCheckpointState{std::move(columnIDs), std::move(columns), pageAllocator, mm}, + csrOffsetColumn{csrOffsetCol}, csrLengthColumn{csrLengthCol} {} +}; + +static constexpr common::column_id_t NBR_ID_COLUMN_ID = 0; +static constexpr common::column_id_t REL_ID_COLUMN_ID = 1; + +// Data in a CSRNodeGroup is organized as follows: +// - persistent data: checkpointed data or flushed data from batch insert. `persistentChunkGroup`. +// - transient data: data that is being committed but kept in memory. `chunkedGroups`. +// Persistent data are organized in CSR format. +// Transient data are organized similar to normal node groups. Tuples are always appended to the end +// of `chunkedGroups`. We keep an extra csrIndex to track the vector of row indices for each bound +// node. 
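// Illustrative, standalone sketch (not from the lbug sources) of the split described above:
// checkpointed rels live in CSR layout while freshly committed rels sit in appended chunks,
// with a per-bound-node index pointing at their rows. A scan for one bound node reads the
// persistent CSR list first and then the in-memory rows. All names and types are invented.
#include <cstdint>
#include <unordered_map>
#include <vector>

struct ToyCSRNodeGroup {
    // Persistent side: classic CSR over the checkpointed rels.
    std::vector<uint64_t> csrStart;       // per bound node, start position into persistentRels
    std::vector<uint64_t> csrLen;         // per bound node, number of persistent rels
    std::vector<uint64_t> persistentRels; // rel rows laid out in CSR order
    // Transient side: rows appended since the last checkpoint, grouped by bound node.
    std::unordered_map<uint64_t, std::vector<uint64_t>> inMemRowsByNode;

    std::vector<uint64_t> scanNode(uint64_t node) const {
        std::vector<uint64_t> result;
        for (uint64_t i = 0; i < csrLen[node]; ++i) {
            result.push_back(persistentRels[csrStart[node] + i]);
        }
        if (auto it = inMemRowsByNode.find(node); it != inMemRowsByNode.end()) {
            result.insert(result.end(), it->second.begin(), it->second.end());
        }
        return result;
    }
};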
+class CSRNodeGroup final : public NodeGroup { +public: + static constexpr PackedCSRInfo DEFAULT_PACKED_CSR_INFO{}; + + CSRNodeGroup(MemoryManager& mm, const common::node_group_idx_t nodeGroupIdx, + const bool enableCompression, std::vector dataTypes) + : NodeGroup{mm, nodeGroupIdx, enableCompression, std::move(dataTypes), + common::INVALID_OFFSET, NodeGroupDataFormat::CSR} {} + CSRNodeGroup(MemoryManager& mm, const common::node_group_idx_t nodeGroupIdx, + const bool enableCompression, std::unique_ptr chunkedNodeGroup) + : NodeGroup{mm, nodeGroupIdx, enableCompression, common::INVALID_OFFSET, + NodeGroupDataFormat::CSR}, + persistentChunkGroup{std::move(chunkedNodeGroup)} { + for (auto i = 0u; i < persistentChunkGroup->getNumColumns(); i++) { + dataTypes.push_back(persistentChunkGroup->getColumnChunk(i).getDataType().copy()); + } + } + + void initializeScanState(const transaction::Transaction* transaction, + TableScanState& state) const override; + NodeGroupScanResult scan(const transaction::Transaction* transaction, + TableScanState& state) const override; + + void appendChunkedCSRGroup(const transaction::Transaction* transaction, + const std::vector& columnIDs, InMemChunkedCSRNodeGroup& chunkedGroup); + void append(const transaction::Transaction* transaction, + const std::vector& columnIDs, common::offset_t boundOffsetInGroup, + std::span chunks, common::row_idx_t startRowInChunks, + common::row_idx_t numRows); + + void update(const transaction::Transaction* transaction, CSRNodeGroupScanSource source, + common::row_idx_t rowIdxInGroup, common::column_id_t columnID, + const common::ValueVector& propertyVector); + bool delete_(const transaction::Transaction* transaction, CSRNodeGroupScanSource source, + common::row_idx_t rowIdxInGroup); + + void addColumn(TableAddColumnState& addColumnState, PageAllocator* pageAllocator, + ColumnStats* newColumnStats) override; + + void checkpoint(MemoryManager& memoryManager, NodeGroupCheckpointState& state) override; + void reclaimStorage(PageAllocator& pageAllocator, const common::UniqLock& lock) const override; + + bool isEmpty() const override { return !persistentChunkGroup && NodeGroup::isEmpty(); } + + ChunkedNodeGroup* getPersistentChunkedGroup() const { return persistentChunkGroup.get(); } + void setPersistentChunkedGroup(std::unique_ptr chunkedNodeGroup) { + KU_ASSERT(chunkedNodeGroup->getFormat() == NodeGroupDataFormat::CSR); + persistentChunkGroup = std::move(chunkedNodeGroup); + } + + void serialize(common::Serializer& serializer) override; + +private: + void initScanForCommittedPersistent(const transaction::Transaction* transaction, + RelTableScanState& relScanState, CSRNodeGroupScanState& nodeGroupScanState) const; + static void initScanForCommittedInMem(RelTableScanState& relScanState, + CSRNodeGroupScanState& nodeGroupScanState); + + void updateCSRIndex(common::offset_t boundNodeOffsetInGroup, common::row_idx_t startRow, + common::length_t length) const; + + NodeGroupScanResult scanCommittedPersistent(const transaction::Transaction* transaction, + RelTableScanState& tableState, CSRNodeGroupScanState& nodeGroupScanState) const; + NodeGroupScanResult scanCommittedPersistentWithCache( + const transaction::Transaction* transaction, RelTableScanState& tableState, + CSRNodeGroupScanState& nodeGroupScanState) const; + NodeGroupScanResult scanCommittedPersistentWithoutCache( + const transaction::Transaction* transaction, RelTableScanState& tableState, + CSRNodeGroupScanState& nodeGroupScanState) const; + + NodeGroupScanResult 
scanCommittedInMem(const transaction::Transaction* transaction, + RelTableScanState& tableState, CSRNodeGroupScanState& nodeGroupScanState) const; + NodeGroupScanResult scanCommittedInMemSequential(const transaction::Transaction* transaction, + const RelTableScanState& tableState, CSRNodeGroupScanState& nodeGroupScanState) const; + NodeGroupScanResult scanCommittedInMemRandom(const transaction::Transaction* transaction, + const RelTableScanState& tableState, CSRNodeGroupScanState& nodeGroupScanState) const; + + void checkpointInMemOnly(const common::UniqLock& lock, NodeGroupCheckpointState& state); + void checkpointInMemAndOnDisk(const common::UniqLock& lock, NodeGroupCheckpointState& state); + + void populateCSRLengthInMemOnly(const common::UniqLock& lock, common::offset_t numNodes, + const CSRNodeGroupCheckpointState& csrState); + + void collectRegionChangesAndUpdateHeaderLength(const common::UniqLock& lock, CSRRegion& region, + const CSRNodeGroupCheckpointState& csrState) const; + void collectInMemRegionChangesAndUpdateHeaderLength(const common::UniqLock& lock, + CSRRegion& region, const CSRNodeGroupCheckpointState& csrState) const; + void collectOnDiskRegionChangesAndUpdateHeaderLength(const common::UniqLock& lock, + CSRRegion& region, const CSRNodeGroupCheckpointState& csrState) const; + + std::vector collectLeafRegionsAndCSRLength(const common::UniqLock& lock, + const CSRNodeGroupCheckpointState& csrState) const; + void collectPersistentUpdatesInRegion(CSRRegion& region, + const CSRNodeGroupCheckpointState& csrState) const; + + common::row_idx_t getNumDeletionsForNodeInPersistentData(common::offset_t nodeOffset, + const CSRNodeGroupCheckpointState& csrState) const; + + static void redistributeCSRRegions(const CSRNodeGroupCheckpointState& csrState, + const std::vector& leafRegions); + static std::vector mergeRegionsToCheckpoint( + const CSRNodeGroupCheckpointState& csrState, const std::vector& leafRegions); + static bool isWithinDensityBound(const InMemChunkedCSRHeader& header, + const std::vector& leafRegions, const CSRRegion& region); + + void checkpointColumn(const common::UniqLock& lock, common::column_id_t columnID, + const CSRNodeGroupCheckpointState& csrState, const std::vector& regions) const; + std::vector checkpointColumnInRegion(const common::UniqLock& lock, + common::column_id_t columnID, const CSRNodeGroupCheckpointState& csrState, + const CSRRegion& region) const; + void checkpointCSRHeaderColumns(const CSRNodeGroupCheckpointState& csrState) const; + void finalizeCheckpoint(const common::UniqLock& lock); + +private: + std::unique_ptr persistentChunkGroup; + std::unique_ptr csrIndex; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/dictionary_chunk.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/dictionary_chunk.h new file mode 100644 index 0000000000..271e069165 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/dictionary_chunk.h @@ -0,0 +1,93 @@ +#pragma once + +#include "storage/enums/residency_state.h" +#include "storage/table/column_chunk_data.h" + +namespace lbug { +namespace storage { +class MemoryManager; + +class DictionaryChunk { +public: + using string_offset_t = uint64_t; + using string_index_t = uint32_t; + + DictionaryChunk(MemoryManager& mm, uint64_t capacity, bool enableCompression, + ResidencyState residencyState); + // A pointer to the dictionary chunk is stored in the StringOps for the indexTable + // and can't be modified easily. 
Moving would invalidate that pointer + DictionaryChunk(DictionaryChunk&& other) = delete; + + void setToInMemory() { + stringDataChunk->setToInMemory(); + offsetChunk->setToInMemory(); + indexTable.clear(); + } + void resetToEmpty(); + + uint64_t getStringLength(string_index_t index) const; + + string_index_t appendString(std::string_view val); + + std::string_view getString(string_index_t index) const; + + ColumnChunkData* getStringDataChunk() const { return stringDataChunk.get(); } + ColumnChunkData* getOffsetChunk() const { return offsetChunk.get(); } + void setOffsetChunk(std::unique_ptr chunk) { offsetChunk = std::move(chunk); } + void setStringDataChunk(std::unique_ptr chunk) { + stringDataChunk = std::move(chunk); + } + + void resetNumValuesFromMetadata(); + + bool sanityCheck() const; + + uint64_t getEstimatedMemoryUsage() const; + + void serialize(common::Serializer& serializer) const; + static std::unique_ptr deserialize(MemoryManager& memoryManager, + common::Deserializer& deSer); + + void flush(PageAllocator& pageAllocator); + +private: + bool enableCompression; + // String data is stored as a UINT8 chunk, using the numValues in the chunk to track the number + // of characters stored. + std::unique_ptr stringDataChunk; + std::unique_ptr offsetChunk; + + struct DictionaryEntry { + string_index_t index; + + std::string_view get(const DictionaryChunk& dict) const { return dict.getString(index); } + }; + + struct StringOps { + explicit StringOps(const DictionaryChunk* dict) : dict(dict) {} + const DictionaryChunk* dict; + using hash_type = std::hash; + using is_transparent = void; + + std::size_t operator()(const DictionaryEntry& entry) const { + return std::hash()(entry.get(*dict)); + } + std::size_t operator()(const char* str) const { return hash_type{}(str); } + std::size_t operator()(std::string_view str) const { return hash_type{}(str); } + std::size_t operator()(std::string const& str) const { return hash_type{}(str); } + + bool operator()(const DictionaryEntry& lhs, const DictionaryEntry& rhs) const { + return lhs.get(*dict) == rhs.get(*dict); + } + bool operator()(const DictionaryEntry& lhs, std::string_view rhs) const { + return lhs.get(*dict) == rhs; + } + bool operator()(std::string_view lhs, const DictionaryEntry& rhs) const { + return lhs == rhs.get(*dict); + } + }; + + std::unordered_set indexTable; +}; +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/dictionary_column.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/dictionary_column.h new file mode 100644 index 0000000000..ad2f8110eb --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/dictionary_column.h @@ -0,0 +1,56 @@ +#pragma once + +#include "dictionary_chunk.h" +#include "storage/table/column.h" +#include "storage/table/column_chunk_data.h" +#include "storage/table/string_chunk_data.h" + +namespace lbug { +namespace storage { + +class DictionaryColumn { +public: + DictionaryColumn(const std::string& name, FileHandle* dataFH, MemoryManager* mm, + ShadowFile* shadowFile, bool enableCompression); + + void scan(const SegmentState& state, DictionaryChunk& dictChunk) const; + // Offsets to scan should be a sorted list of pairs mapping the index of the entry in the string + // dictionary (as read from the index column) to the output index in the result vector to store + // the string. 
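// Illustrative, standalone sketch (not from the lbug sources) of the offsetsToScan contract
// described above: sorted (dictionary index -> output position) pairs let the dictionary be
// read in index order while results land at arbitrary output positions. Names are invented.
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

// dictionary[i] is the string stored at dictionary index i.
// offsetsToScan pairs (dictionaryIndex, outputPos), sorted by dictionaryIndex so the
// dictionary data is visited mostly sequentially even when outputPos is arbitrary.
void scanStrings(const std::vector<std::string>& dictionary,
    const std::vector<std::pair<uint32_t, uint64_t>>& offsetsToScan,
    std::vector<std::string>& result) {
    for (const auto& [dictIdx, outPos] : offsetsToScan) {
        result[outPos] = dictionary[dictIdx];
    }
}

// E.g. with dictionary {"ann", "bob"} and offsetsToScan {{0, 2}, {1, 0}}, a result vector of
// size 3 becomes {"bob", "", "ann"}.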
+ template + void scan(const SegmentState& offsetState, const SegmentState& dataState, + std::vector>& offsetsToScan, + Result* result, const ColumnChunkMetadata& indexMeta) const; + + DictionaryChunk::string_index_t append(const DictionaryChunk& dictChunk, SegmentState& state, + std::string_view val) const; + + bool canCommitInPlace(const SegmentState& state, uint64_t numNewStrings, + uint64_t totalStringLengthToAdd) const; + + Column* getDataColumn() const { return dataColumn.get(); } + Column* getOffsetColumn() const { return offsetColumn.get(); } + +private: + void scanOffsets(const SegmentState& state, DictionaryChunk::string_offset_t* offsets, + uint64_t index, uint64_t numValues, uint64_t dataSize) const; + void scanValue(const SegmentState& dataState, uint64_t startOffset, uint64_t endOffset, + StringChunkData* result, uint64_t offsetInVector) const; + void scanValue(const SegmentState& dataState, uint64_t startOffset, uint64_t endOffset, + common::ValueVector* resultVector, uint64_t offsetInVector) const; + + static bool canDataCommitInPlace(const SegmentState& dataState, + uint64_t totalStringLengthToAdd); + bool canOffsetCommitInPlace(const SegmentState& offsetState, const SegmentState& dataState, + uint64_t numNewStrings, uint64_t totalStringLengthToAdd) const; + +private: + // The offset column stores the offsets for each index, and the data column stores the data in + // order. Values are never removed from the dictionary during in-place updates, only appended to + // the end. + std::unique_ptr dataColumn; + std::unique_ptr offsetColumn; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/group_collection.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/group_collection.h new file mode 100644 index 0000000000..d9d5db35ed --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/group_collection.h @@ -0,0 +1,139 @@ +#pragma once + +#include +#include + +#include "common/serializer/deserializer.h" +#include "common/serializer/serializer.h" +#include "common/types/types.h" +#include "common/uniq_lock.h" +#include "common/utils.h" + +namespace lbug { +namespace storage { +class MemoryManager; + +template +class GroupCollection { +public: + GroupCollection() {} + + common::UniqLock lock() const { return common::UniqLock{mtx}; } + + void deserializeGroups(MemoryManager& memoryManager, common::Deserializer& deSer, + const std::vector& columnTypes) { + auto lockGuard = lock(); + deSer.deserializeVectorOfPtrs(groups, [&](common::Deserializer& deser) { + return T::deserialize(memoryManager, deser, columnTypes); + }); + } + + void removeTrailingGroups([[maybe_unused]] const common::UniqLock& lock, + common::idx_t numGroupsToRemove) { + KU_ASSERT(lock.isLocked()); + KU_ASSERT(numGroupsToRemove <= groups.size()); + groups.erase(groups.end() - numGroupsToRemove, groups.end()); + } + + void serializeGroups(common::Serializer& ser) { + auto lockGuard = lock(); + ser.serializeVectorOfPtrs(groups); + } + + void appendGroup(const common::UniqLock& lock, std::unique_ptr group) { + KU_ASSERT(group); + KU_ASSERT(lock.isLocked()); + KU_UNUSED(lock); + groups.push_back(std::move(group)); + } + T* getGroup(const common::UniqLock& lock, common::idx_t groupIdx) const { + KU_ASSERT(lock.isLocked()); + KU_UNUSED(lock); + KU_ASSERT(groupIdx < groups.size()); + return groups[groupIdx].get(); + } + T* getGroupNoLock(common::idx_t groupIdx) const { + KU_ASSERT(groupIdx < groups.size()); + return 
groups[groupIdx].get(); + } + void replaceGroup(const common::UniqLock& lock, common::idx_t groupIdx, + std::unique_ptr group) { + KU_ASSERT(group); + KU_ASSERT(lock.isLocked()); + KU_UNUSED(lock); + if (groupIdx >= groups.size()) { + groups.resize(groupIdx + 1); + } + groups[groupIdx] = std::move(group); + } + + void resize(const common::UniqLock& lock, common::idx_t newSize) { + KU_ASSERT(lock.isLocked()); + KU_UNUSED(lock); + if (newSize <= groups.size()) { + return; + } + groups.resize(newSize); + } + + bool isEmpty(const common::UniqLock& lock) const { + KU_ASSERT(lock.isLocked()); + KU_UNUSED(lock); + return groups.empty(); + } + common::idx_t getNumGroups(const common::UniqLock& lock) const { + KU_ASSERT(lock.isLocked()); + KU_UNUSED(lock); + return groups.size(); + } + common::idx_t getNumGroupsNoLock() const { return groups.size(); } + + const std::vector>& getAllGroups(const common::UniqLock& lock) const { + KU_ASSERT(lock.isLocked()); + KU_UNUSED(lock); + return groups; + } + T* getFirstGroup(const common::UniqLock& lock) const { + KU_ASSERT(lock.isLocked()); + KU_UNUSED(lock); + if (groups.empty()) { + return nullptr; + } + return groups.front().get(); + } + T* getFirstGroupNoLock() const { + if (groups.empty()) { + return nullptr; + } + return groups.front().get(); + } + T* getLastGroup(const common::UniqLock& lock) const { + KU_ASSERT(lock.isLocked()); + KU_UNUSED(lock); + if (groups.empty()) { + return nullptr; + } + return groups.back().get(); + } + + void clear(const common::UniqLock& lock) { + KU_ASSERT(lock.isLocked()); + KU_UNUSED(lock); + groups.clear(); + } + + common::idx_t getNumEmptyTrailingGroups(const common::UniqLock& lock) { + const auto& groupsVector = getAllGroups(lock); + return common::safeIntegerConversion( + std::find_if(groupsVector.rbegin(), groupsVector.rend(), + [](const auto& group) { return (group->getNumRows() != 0); }) - + groupsVector.rbegin()); + } + +private: + mutable std::mutex mtx; + std::vector> groups; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/in_mem_chunked_node_group_collection.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/in_mem_chunked_node_group_collection.h new file mode 100644 index 0000000000..bbb5772940 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/in_mem_chunked_node_group_collection.h @@ -0,0 +1,55 @@ +#pragma once + +#include "storage/table/chunked_node_group.h" + +namespace lbug { +namespace transaction { +class Transaction; +} // namespace transaction + +namespace storage { + +class LBUG_API InMemChunkedNodeGroupCollection { +public: + explicit InMemChunkedNodeGroupCollection(std::vector types) + : types{std::move(types)} {} + DELETE_BOTH_COPY(InMemChunkedNodeGroupCollection); + + static std::pair getChunkIdxAndOffsetInChunk( + common::row_idx_t rowIdx) { + return std::make_pair(rowIdx / common::StorageConfig::CHUNKED_NODE_GROUP_CAPACITY, + rowIdx % common::StorageConfig::CHUNKED_NODE_GROUP_CAPACITY); + } + + const std::vector>& getChunkedGroups() { + return chunkedGroups; + } + InMemChunkedNodeGroup& getChunkedGroup(common::node_group_idx_t groupIdx) const { + KU_ASSERT(groupIdx < chunkedGroups.size()); + return *chunkedGroups[groupIdx]; + } + + // Return num of rows before append. + void append(MemoryManager& memoryManager, const std::vector& vectors, + common::row_idx_t startRowInVectors, common::row_idx_t numRowsToAppend); + + // `merge` are directly moving the chunkedGroup to the collection. 
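// Illustrative, standalone sketch (not from the lbug sources) of the lock-witness pattern used
// by GroupCollection above: callers acquire the collection lock once and pass it into every
// accessor, which only asserts that it is held. std::unique_lock stands in for common::UniqLock.
#include <cassert>
#include <memory>
#include <mutex>
#include <vector>

template<typename T>
class ToyGroupCollection {
public:
    // The caller acquires the lock once and threads it through every accessor.
    std::unique_lock<std::mutex> lock() const { return std::unique_lock<std::mutex>{mtx}; }

    void appendGroup(const std::unique_lock<std::mutex>& lock, std::unique_ptr<T> group) {
        assert(lock.owns_lock()); // witness only; never locked or unlocked here
        groups.push_back(std::move(group));
    }
    T* getGroup(const std::unique_lock<std::mutex>& lock, size_t idx) const {
        assert(lock.owns_lock() && idx < groups.size());
        return groups[idx].get();
    }

private:
    mutable std::mutex mtx;
    std::vector<std::unique_ptr<T>> groups;
};

// Usage: auto guard = collection.lock();
//        collection.appendGroup(guard, std::make_unique<int>(42));
// A multi-step sequence stays atomic without re-locking inside every call.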
+ void merge(std::unique_ptr chunkedGroup); + void merge(InMemChunkedNodeGroupCollection& other); + + uint64_t getNumChunkedGroups() const { return chunkedGroups.size(); } + void clear() { chunkedGroups.clear(); } + + void loadFromDisk(MemoryManager& memoryManager) { + for (auto& group : chunkedGroups) { + group->loadFromDisk(memoryManager); + } + } + +private: + std::vector types; + std::vector> chunkedGroups; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/in_memory_exception_chunk.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/in_memory_exception_chunk.h new file mode 100644 index 0000000000..ba37534741 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/in_memory_exception_chunk.h @@ -0,0 +1,65 @@ +#pragma once + +#include "storage/compression/float_compression.h" +#include "storage/table/column_reader_writer.h" + +namespace lbug { +namespace transaction { +class Transaction; +} +namespace storage { + +class Column; +class MemoryManager; +class ColumnReadWriter; +struct ColumnChunkMetadata; +struct ChunkState; + +// In memory representation of ALP exception chunk +// NOTE: read and write operations on this chunk cannot both be performed on this +// Additionally, each exceptionIdx can be updated at most once before finalizing +// After such operations are performed, you must call finalizeAndFlushToDisk() before reading again +template +class LBUG_API InMemoryExceptionChunk { +public: + InMemoryExceptionChunk(const SegmentState& state, FileHandle* dataFH, + MemoryManager* memoryManager, ShadowFile* shadowFile); + ~InMemoryExceptionChunk(); + + void finalizeAndFlushToDisk(SegmentState& state); + + void addException(EncodeException exception); + + void removeExceptionAt(size_t exceptionIdx); + + EncodeException getExceptionAt(size_t exceptionIdx) const; + + common::offset_t findFirstExceptionAtOrPastOffset(common::offset_t offsetInChunk) const; + + size_t getExceptionCount() const; + + void writeException(EncodeException exception, size_t exceptionIdx); + +private: + static PageCursor getExceptionPageCursor(const ColumnChunkMetadata& metadata, + PageCursor pageBaseCursor, size_t exceptionCapacity); + + void finalize(SegmentState& state); + + static constexpr common::PhysicalTypeID physicalType = + std::is_same_v ? 
common::PhysicalTypeID::ALP_EXCEPTION_FLOAT : + common::PhysicalTypeID::ALP_EXCEPTION_DOUBLE; + + size_t exceptionCount; + size_t finalizedExceptionCount; + size_t exceptionCapacity; + + common::NullMask emptyMask; + + std::unique_ptr column; + std::unique_ptr chunkData; + std::unique_ptr chunkState; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/lazy_segment_scanner.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/lazy_segment_scanner.h new file mode 100644 index 0000000000..1e39880466 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/lazy_segment_scanner.h @@ -0,0 +1,86 @@ +#pragma once + +#include "storage/table/column_chunk.h" +#include "storage/table/column_chunk_scanner.h" +namespace lbug { +namespace storage { +struct LazySegmentData { + std::unique_ptr segmentData; + common::offset_t startOffsetInSegment; + common::offset_t length; + ColumnChunkScanner::scan_func_t scanFunc; + + // Used for genericRangeSegments() + const LazySegmentData& operator*() const { return *this; } + common::offset_t getNumValues() const { return length; } +}; + +// Separately scans each segment in a column chunk +// Avoids scanning a segment unless a call to updateScannedValue() is made for the current segment +class LazySegmentScanner : public ColumnChunkScanner { +public: + LazySegmentScanner(MemoryManager& mm, common::LogicalType columnType, bool enableCompression) + : numValues(0), mm(mm), columnType(std::move(columnType)), + enableCompression(enableCompression) {} + + struct Iterator { + common::offset_t segmentIdx; + common::offset_t offsetInSegment; + LazySegmentScanner& segmentScanner; + + void advance(common::offset_t n); + void operator++() { advance(1); } + LazySegmentData& operator*() const; + LazySegmentData* operator->() const { return &*(*this); } + }; + + Iterator begin() { return Iterator{0, 0, *this}; } + + // Since we lazily scan segments + // This actually only adds the information needed to scan the segment + // Either updateScannedValue() or scanSegmentIfNeeded must be called to actually scan + void scanSegment(common::offset_t offsetInSegment, common::offset_t segmentLength, + scan_func_t newScanFunc) override; + + void applyCommittedUpdates(const UpdateInfo& updateInfo, + const transaction::Transaction* transaction, common::offset_t startRow, + common::offset_t numRows) override; + + uint64_t getNumValues() override { return numValues; } + + void scanSegmentIfNeeded(LazySegmentData& segment); + void scanSegmentIfNeeded(common::idx_t segmentIdx) { + scanSegmentIfNeeded(segments[segmentIdx]); + } + + template + Func> + void rangeSegments(Iterator startIt, common::length_t length, Func func); + +private: + std::vector segments; + + uint64_t numValues; + + MemoryManager& mm; + common::LogicalType columnType; + bool enableCompression; +}; + +inline LazySegmentData& LazySegmentScanner::Iterator::operator*() const { + KU_ASSERT(segmentIdx < segmentScanner.segments.size() && + offsetInSegment < segmentScanner.segments[segmentIdx].length); + return segmentScanner.segments[segmentIdx]; +} + +template< + std::invocable Func> +void LazySegmentScanner::rangeSegments(Iterator startIt, common::length_t length, Func func) { + auto segmentSpan = std::span(segments); + genericRangeSegmentsFromIt(segmentSpan, segmentSpan.begin() + startIt.segmentIdx, + startIt.offsetInSegment, length, func); +} + +} // namespace storage +} // namespace lbug diff --git 
a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/list_chunk_data.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/list_chunk_data.h new file mode 100644 index 0000000000..748d643922 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/list_chunk_data.h @@ -0,0 +1,135 @@ +#pragma once + +#include "common/data_chunk/sel_vector.h" +#include "common/types/types.h" +#include "storage/table/column_chunk_data.h" + +namespace lbug { +namespace storage { +class MemoryManager; + +class LBUG_API ListChunkData final : public ColumnChunkData { +public: + static constexpr common::idx_t SIZE_COLUMN_CHILD_READ_STATE_IDX = 0; + static constexpr common::idx_t DATA_COLUMN_CHILD_READ_STATE_IDX = 1; + static constexpr common::idx_t OFFSET_COLUMN_CHILD_READ_STATE_IDX = 2; + static constexpr size_t CHILD_COLUMN_COUNT = 3; + + ListChunkData(MemoryManager& mm, common::LogicalType dataType, uint64_t capacity, + bool enableCompression, ResidencyState residencyState); + ListChunkData(MemoryManager& mm, common::LogicalType dataType, bool enableCompression, + const ColumnChunkMetadata& metadata); + + ColumnChunkData* getOffsetColumnChunk() const { return offsetColumnChunk.get(); } + + ColumnChunkData* getDataColumnChunk() const { return dataColumnChunk.get(); } + std::unique_ptr moveDataColumnChunk() { return std::move(dataColumnChunk); } + + ColumnChunkData* getSizeColumnChunk() const { return sizeColumnChunk.get(); } + std::unique_ptr moveSizeColumnChunk() { return std::move(sizeColumnChunk); } + + void setOffsetColumnChunk(std::unique_ptr offsetColumnChunk_) { + offsetColumnChunk = std::move(offsetColumnChunk_); + } + void setDataColumnChunk(std::unique_ptr dataColumnChunk_) { + dataColumnChunk = std::move(dataColumnChunk_); + } + void setSizeColumnChunk(std::unique_ptr sizeColumnChunk_) { + sizeColumnChunk = std::move(sizeColumnChunk_); + } + + void resetToEmpty() override; + + void setNumValues(uint64_t numValues_) override { + ColumnChunkData::setNumValues(numValues_); + sizeColumnChunk->setNumValues(numValues_); + offsetColumnChunk->setNumValues(numValues_); + } + + void resetNumValuesFromMetadata() override; + void syncNumValues() override { + numValues = offsetColumnChunk->getNumValues(); + metadata.numValues = numValues; + } + + void append(common::ValueVector* vector, const common::SelectionView& selVector) override; + + void initializeScanState(SegmentState& state, const Column* column) const override; + void scan(common::ValueVector& output, common::offset_t offset, common::length_t length, + common::sel_t posInOutputVector) const override; + void lookup(common::offset_t offsetInChunk, common::ValueVector& output, + common::sel_t posInOutputVector) const override; + + // Note: `write` assumes that no `append` will be called afterward. 
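// Illustrative, standalone sketch (not from the lbug sources) of the lazy-scan idea behind
// LazySegmentScanner above: each segment records how it *would* be scanned and is only
// materialized the first time one of its values is needed. Toy* names are invented and
// std::function stands in for scan_func_t.
#include <cstdint>
#include <functional>
#include <optional>
#include <vector>

struct ToyLazySegment {
    uint64_t length = 0;
    std::function<std::vector<int>()> scanFunc; // how to materialize this segment
    std::optional<std::vector<int>> data;       // filled on first use only
};

class ToyLazyScanner {
public:
    void addSegment(uint64_t length, std::function<std::vector<int>()> scanFunc) {
        segments.push_back({length, std::move(scanFunc), std::nullopt});
    }
    // Only the segment containing the requested value is ever scanned.
    int valueAt(size_t segmentIdx, size_t offsetInSegment) {
        auto& seg = segments[segmentIdx];
        if (!seg.data) {
            seg.data = seg.scanFunc(); // lazy materialization
        }
        return (*seg.data)[offsetInSegment];
    }

private:
    std::vector<ToyLazySegment> segments;
};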
+ void write(const common::ValueVector* vector, common::offset_t offsetInVector, + common::offset_t offsetInChunk) override; + void write(ColumnChunkData* chunk, ColumnChunkData* dstOffsets, + common::RelMultiplicity multiplicity) override; + void write(const ColumnChunkData* srcChunk, common::offset_t srcOffsetInChunk, + common::offset_t dstOffsetInChunk, common::offset_t numValuesToCopy) override; + + void resizeDataColumnChunk(uint64_t numValues) const { dataColumnChunk->resize(numValues); } + + void setToInMemory() override { + ColumnChunkData::setToInMemory(); + sizeColumnChunk->setToInMemory(); + offsetColumnChunk->setToInMemory(); + dataColumnChunk->setToInMemory(); + KU_ASSERT(offsetColumnChunk->getNumValues() == numValues); + } + void resize(uint64_t newCapacity) override { + ColumnChunkData::resize(newCapacity); + sizeColumnChunk->resize(newCapacity); + offsetColumnChunk->resize(newCapacity); + } + + void resizeWithoutPreserve(uint64_t newCapacity) override { + ColumnChunkData::resizeWithoutPreserve(newCapacity); + sizeColumnChunk->resizeWithoutPreserve(newCapacity); + offsetColumnChunk->resizeWithoutPreserve(newCapacity); + } + + common::offset_t getListStartOffset(common::offset_t offset) const; + + common::offset_t getListEndOffset(common::offset_t offset) const; + + common::list_size_t getListSize(common::offset_t offset) const; + + void resetOffset(); + void resetFromOtherChunk(ListChunkData* other); + void finalize() override; + bool isOffsetsConsecutiveAndSortedAscending(uint64_t startPos, uint64_t endPos) const; + bool sanityCheck() const override; + + uint64_t getEstimatedMemoryUsage() const override; + + void serialize(common::Serializer& serializer) const override; + static void deserialize(common::Deserializer& deSer, ColumnChunkData& chunkData); + + void flush(PageAllocator& pageAllocator) override; + uint64_t getMinimumSizeOnDisk() const override; + uint64_t getSizeOnDisk() const override; + uint64_t getSizeOnDiskInMemoryStats() const override; + void reclaimStorage(PageAllocator& pageAllocator) override; + +protected: + void copyListValues(const common::list_entry_t& entry, common::ValueVector* dataVector); + +private: + void append(const ColumnChunkData* other, common::offset_t startPosInOtherChunk, + uint32_t numValuesToAppend) override; + + void appendNullList(); + + void setOffsetChunkValue(common::offset_t val, common::offset_t pos); + +protected: + std::unique_ptr offsetColumnChunk; + std::unique_ptr sizeColumnChunk; + std::unique_ptr dataColumnChunk; + // we use checkOffsetSortedAsc flag to indicate that we do not trigger random write + bool checkOffsetSortedAsc; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/list_column.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/list_column.h new file mode 100644 index 0000000000..91d06b84c3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/list_column.h @@ -0,0 +1,105 @@ +#pragma once + +#include + +#include "column.h" +#include "common/types/types.h" + +// List is a nested data type which is stored as three chunks: +// 1. Offset column (type: INT64). Using offset to partition the data column into multiple lists. +// 2. Size column. Stores the size of each list. +// 3. Data column. Stores the actual data of the list. +// Similar to other data types, nulls are stored in the null column. 
+// Example layout for list of INT64: +// Four lists: [4,7,8,12], null, [2, 3], [] +// Offset column: [4, 4, 6, 6] +// Size column: [4, 0, 2, 0] +// data column: [4, 7, 8, 12, 2, 3] +// When updating the data, we first append the data to the data column, and then update the offset +// and size accordingly. Besides offset column, we introduce an extra size column here to enable +// in-place updates of a list column. In a list column chunk, offsets of lists are not always sorted +// after updates. This is good for writes, but it introduces extra overheads for scans, as lists can +// be scattered, and scans have to be broken into multiple small reads. To achieve a balance between +// reads and writes, during updates, we rewrite the whole list column chunk in ascending order +// when the offsets are not sorted in ascending order and the size of data column chunk is larger +// than half of its capacity. + +namespace lbug { +namespace storage { + +struct ListOffsetSizeInfo { + common::offset_t numTotal; + std::unique_ptr offsetColumnChunk; + std::unique_ptr sizeColumnChunk; + + ListOffsetSizeInfo(common::offset_t numTotal, + std::unique_ptr offsetColumnChunk, + std::unique_ptr sizeColumnChunk) + : numTotal{numTotal}, offsetColumnChunk{std::move(offsetColumnChunk)}, + sizeColumnChunk{std::move(sizeColumnChunk)} {} + + common::list_size_t getListSize(uint64_t pos) const; + common::offset_t getListEndOffset(uint64_t pos) const; + common::offset_t getListStartOffset(uint64_t pos) const; + + bool isOffsetSortedAscending(uint64_t startPos, uint64_t endPos) const; +}; + +class ListColumn final : public Column { + static constexpr common::idx_t SIZE_COLUMN_CHILD_READ_STATE_IDX = 0; + static constexpr common::idx_t DATA_COLUMN_CHILD_READ_STATE_IDX = 1; + static constexpr common::idx_t OFFSET_COLUMN_CHILD_READ_STATE_IDX = 2; + static constexpr size_t CHILD_COLUMN_COUNT = 3; + +public: + ListColumn(std::string name, common::LogicalType dataType, FileHandle* dataFH, + MemoryManager* mm, ShadowFile* shadowFile, bool enableCompression); + + static bool disableCompressionOnData(const common::LogicalType& dataType); + + static std::unique_ptr flushChunkData(const ColumnChunkData& chunk, + PageAllocator& pageAllocator); + + Column* getOffsetColumn() const { return offsetColumn.get(); } + Column* getSizeColumn() const { return sizeColumn.get(); } + Column* getDataColumn() const { return dataColumn.get(); } + + std::vector> checkpointSegment( + ColumnCheckpointState&& checkpointState, PageAllocator& pageAllocator, + bool canSplitSegment = true) const override; + +protected: + void scanSegment(const SegmentState& state, common::offset_t startOffsetInChunk, + common::row_idx_t numValuesToScan, common::ValueVector* resultVector, + common::offset_t offsetInResult) const override; + + void scanSegment(const SegmentState& state, ColumnChunkData* resultChunk, + common::offset_t startOffsetInSegment, common::row_idx_t numValuesToScan) const override; + + void lookupInternal(const SegmentState& state, common::offset_t nodeOffset, + common::ValueVector* resultVector, uint32_t posInVector) const override; + +private: + void scanUnfiltered(const SegmentState& state, common::ValueVector* resultVector, + uint64_t numValuesToScan, const ListOffsetSizeInfo& listOffsetInfoInStorage, + common::offset_t offsetInResult) const; + void scanFiltered(const SegmentState& state, common::offset_t startOffsetInChunk, + common::ValueVector* offsetVector, const ListOffsetSizeInfo& listOffsetInfoInStorage, + common::offset_t offsetInResult) 
const; + + common::offset_t readOffset(const SegmentState& state, + common::offset_t offsetInNodeGroup) const; + common::list_size_t readSize(const SegmentState& state, + common::offset_t offsetInNodeGroup) const; + + ListOffsetSizeInfo getListOffsetSizeInfo(const SegmentState& state, + common::offset_t startOffsetInSegment, common::offset_t numOffsetsToRead) const; + +private: + std::unique_ptr offsetColumn; + std::unique_ptr sizeColumn; + std::unique_ptr dataColumn; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/node_group.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/node_group.h new file mode 100644 index 0000000000..f52518fd84 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/node_group.h @@ -0,0 +1,269 @@ +#pragma once + +#include + +#include "common/uniq_lock.h" +#include "storage/enums/residency_state.h" +#include "storage/table/chunked_node_group.h" +#include "storage/table/group_collection.h" +#include "storage/table/version_record_handler.h" + +namespace lbug { +namespace transaction { +class Transaction; +} // namespace transaction + +namespace storage { +class MemoryManager; + +class ColumnStats; +struct TableAddColumnState; +class NodeGroup; + +struct NodeGroupScanState { + // Index of committed but not yet checkpointed chunked group to scan. + common::idx_t chunkedGroupIdx = 0; + common::row_idx_t nextRowToScan = 0; + // State of each chunk in the checkpointed chunked group. + std::vector chunkStates; + + explicit NodeGroupScanState() {} + explicit NodeGroupScanState(common::idx_t numChunks) { chunkStates.resize(numChunks); } + + virtual ~NodeGroupScanState() = default; + DELETE_COPY_DEFAULT_MOVE(NodeGroupScanState); + + template + TARGET& cast() { + return common::ku_dynamic_cast(*this); + } + template + const TARGET& constCast() { + return common::ku_dynamic_cast(*this); + } +}; + +struct NodeGroupCheckpointState { + std::vector columnIDs; + std::vector columns; + PageAllocator& pageAllocator; + MemoryManager* mm; + + NodeGroupCheckpointState(std::vector columnIDs, + std::vector columns, PageAllocator& pageAllocator, MemoryManager* mm) + : columnIDs{std::move(columnIDs)}, columns{std::move(columns)}, + pageAllocator{pageAllocator}, mm{mm} {} + virtual ~NodeGroupCheckpointState() = default; + + template + const T& cast() const { + return common::ku_dynamic_cast(*this); + } + template + T& cast() { + return common::ku_dynamic_cast(*this); + } +}; + +struct NodeGroupScanResult { + + common::row_idx_t startRow = common::INVALID_ROW_IDX; + common::row_idx_t numRows = 0; + + constexpr NodeGroupScanResult() noexcept = default; + constexpr NodeGroupScanResult(common::row_idx_t startRow, common::row_idx_t numRows) noexcept + : startRow{startRow}, numRows{numRows} {} + + bool operator==(const NodeGroupScanResult& other) const { + return startRow == other.startRow && numRows == other.numRows; + } +}; + +static auto NODE_GROUP_SCAN_EMPTY_RESULT = NodeGroupScanResult{}; + +struct TableScanState; +class NodeGroup { +public: + NodeGroup(MemoryManager& mm, const common::node_group_idx_t nodeGroupIdx, + const bool enableCompression, std::vector dataTypes, + common::row_idx_t capacity = common::StorageConfig::NODE_GROUP_SIZE, + NodeGroupDataFormat format = NodeGroupDataFormat::REGULAR) + : mm{mm}, nodeGroupIdx{nodeGroupIdx}, format{format}, enableCompression{enableCompression}, + numRows{0}, nextRowToAppend{0}, capacity{capacity}, dataTypes{std::move(dataTypes)} {} + 
NodeGroup(MemoryManager& mm, const common::node_group_idx_t nodeGroupIdx, + const bool enableCompression, std::unique_ptr chunkedNodeGroup, + common::row_idx_t capacity = common::StorageConfig::NODE_GROUP_SIZE, + NodeGroupDataFormat format = NodeGroupDataFormat::REGULAR) + : mm{mm}, nodeGroupIdx{nodeGroupIdx}, format{format}, enableCompression{enableCompression}, + numRows{chunkedNodeGroup->getStartRowIdx() + chunkedNodeGroup->getNumRows()}, + nextRowToAppend{numRows}, capacity{capacity} { + for (auto i = 0u; i < chunkedNodeGroup->getNumColumns(); i++) { + dataTypes.push_back(chunkedNodeGroup->getColumnChunk(i).getDataType().copy()); + } + const auto lock = chunkedGroups.lock(); + chunkedGroups.appendGroup(lock, std::move(chunkedNodeGroup)); + } + NodeGroup(MemoryManager& mm, const common::node_group_idx_t nodeGroupIdx, + const bool enableCompression, common::row_idx_t capacity, NodeGroupDataFormat format) + : mm{mm}, nodeGroupIdx{nodeGroupIdx}, format{format}, enableCompression{enableCompression}, + numRows{0}, nextRowToAppend{0}, capacity{capacity} {} + virtual ~NodeGroup() = default; + + virtual bool isEmpty() const { return numRows.load() == 0; } + virtual common::row_idx_t getNumRows() const { return numRows.load(); } + void moveNextRowToAppend(common::row_idx_t numRowsToAppend) { + nextRowToAppend += numRowsToAppend; + } + common::row_idx_t getNumRowsLeftToAppend() const { return capacity - nextRowToAppend; } + bool isFull() const { return numRows.load() == capacity; } + const std::vector& getDataTypes() const { return dataTypes; } + NodeGroupDataFormat getFormat() const { return format; } + common::row_idx_t append(const transaction::Transaction* transaction, + const std::vector& columnIDs, ChunkedNodeGroup& chunkedGroup, + common::row_idx_t startRowIdx, common::row_idx_t numRowsToAppend); + common::row_idx_t append(const transaction::Transaction* transaction, + const std::vector& columnIDs, InMemChunkedNodeGroup& chunkedGroup, + common::row_idx_t startRowIdx, common::row_idx_t numRowsToAppend); + common::row_idx_t append(const transaction::Transaction* transaction, + const std::vector& columnIDs, + std::span chunkedGroup, common::row_idx_t startRowIdx, + common::row_idx_t numRowsToAppend); + common::row_idx_t append(const transaction::Transaction* transaction, + const std::vector& columnIDs, + std::span chunkedGroup, common::row_idx_t startRowIdx, + common::row_idx_t numRowsToAppend); + void append(const transaction::Transaction* transaction, + const std::vector& vectors, common::row_idx_t startRowIdx, + common::row_idx_t numRowsToAppend); + + void merge(transaction::Transaction* transaction, + std::unique_ptr chunkedGroup); + + virtual void initializeScanState(const transaction::Transaction* transaction, + TableScanState& state) const; + void initializeScanState(const transaction::Transaction* transaction, + const common::UniqLock& lock, TableScanState& state) const; + virtual NodeGroupScanResult scan(const transaction::Transaction* transaction, + TableScanState& state) const; + + virtual NodeGroupScanResult scan(transaction::Transaction* transaction, TableScanState& state, + common::offset_t startOffsetInGroup, common::offset_t numRowsToScan) const; + + bool lookup(const transaction::Transaction* transaction, const TableScanState& state, + common::sel_t posInSel = 0) const; + bool lookupNoLock(const transaction::Transaction* transaction, const TableScanState& state, + common::sel_t posInSel = 0) const; + // TODO(Guodong): These should be merged together with `lookup`. 
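// Illustrative, standalone sketch (not from the lbug sources) of driving a batched scan with
// (startRow, numRows) results, assuming an empty result marks the end of the group as
// NODE_GROUP_SCAN_EMPTY_RESULT above suggests. scanNextBatch is a made-up stand-in.
#include <algorithm>
#include <cstdint>
#include <cstdio>

struct ToyScanResult {
    uint64_t startRow = UINT64_MAX; // UINT64_MAX plays the role of an "invalid row" sentinel
    uint64_t numRows = 0;
};

// Made-up stand-in for a node-group scan: hands out rows in fixed-size batches.
ToyScanResult scanNextBatch(uint64_t totalRows, uint64_t& cursor, uint64_t batchSize) {
    if (cursor >= totalRows) {
        return {}; // empty result: the group is exhausted
    }
    const uint64_t n = std::min(batchSize, totalRows - cursor);
    ToyScanResult result{cursor, n};
    cursor += n;
    return result;
}

int main() {
    uint64_t cursor = 0;
    for (auto batch = scanNextBatch(10, cursor, 4); batch.numRows != 0;
         batch = scanNextBatch(10, cursor, 4)) {
        std::printf("rows [%llu, %llu)\n", (unsigned long long)batch.startRow,
            (unsigned long long)(batch.startRow + batch.numRows));
    }
    return 0; // prints [0, 4), [4, 8), [8, 10)
}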
+ bool lookupMultiple(const common::UniqLock& lock, const transaction::Transaction* transaction, + const TableScanState& state) const; + bool lookupMultiple(const transaction::Transaction* transaction, + const TableScanState& state) const; + + void update(const transaction::Transaction* transaction, common::row_idx_t rowIdxInGroup, + common::column_id_t columnID, const common::ValueVector& propertyVector); + bool delete_(const transaction::Transaction* transaction, common::row_idx_t rowIdxInGroup); + + bool hasDeletions(const transaction::Transaction* transaction) const; + virtual void addColumn(TableAddColumnState& addColumnState, PageAllocator* pageAllocator, + ColumnStats* newColumnStats); + + void applyFuncToChunkedGroups(version_record_handler_op_t func, common::row_idx_t startRow, + common::row_idx_t numRows, common::transaction_t commitTS) const; + void rollbackInsert(common::row_idx_t startRow); + void reclaimStorage(PageAllocator& pageAllocator) const; + virtual void reclaimStorage(PageAllocator& pageAllocator, const common::UniqLock& lock) const; + + virtual void checkpoint(MemoryManager& memoryManager, NodeGroupCheckpointState& state); + + uint64_t getEstimatedMemoryUsage() const; + + virtual void serialize(common::Serializer& serializer); + static std::unique_ptr deserialize(MemoryManager& mm, common::Deserializer& deSer, + const std::vector& columnTypes); + + common::node_group_idx_t getNumChunkedGroups() const { + const auto lock = chunkedGroups.lock(); + return chunkedGroups.getNumGroups(lock); + } + ChunkedNodeGroup* getChunkedNodeGroup(common::node_group_idx_t groupIdx) const { + const auto lock = chunkedGroups.lock(); + return chunkedGroups.getGroup(lock, groupIdx); + } + + template + TARGET& cast() { + return common::ku_dynamic_cast(*this); + } + template + const TARGET& cast() const { + return common::ku_dynamic_cast(*this); + } + + bool isVisible(const transaction::Transaction* transaction, + common::row_idx_t rowIdxInGroup) const; + bool isVisibleNoLock(const transaction::Transaction* transaction, + common::row_idx_t rowIdxInGroup) const; + bool isDeleted(const transaction::Transaction* transaction, + common::offset_t offsetInGroup) const; + bool isInserted(const transaction::Transaction* transaction, + common::offset_t offsetInGroup) const; + + common::node_group_idx_t getNodeGroupIdx() const { return nodeGroupIdx; } + +protected: + static constexpr auto INVALID_CHUNKED_GROUP_IDX = UINT32_MAX; + static constexpr auto INVALID_START_ROW_IDX = UINT64_MAX; + +protected: + void checkpointDataTypesNoLock(const NodeGroupCheckpointState& state); + +private: + std::pair findChunkedGroupIdxFromRowIdxNoLock( + common::row_idx_t rowIdx) const; + ChunkedNodeGroup* findChunkedGroupFromRowIdx(const common::UniqLock& lock, + common::row_idx_t rowIdx) const; + ChunkedNodeGroup* findChunkedGroupFromRowIdxNoLock(common::row_idx_t rowIdx) const; + + std::unique_ptr checkpointInMemOnly(MemoryManager& memoryManager, + const common::UniqLock& lock, const NodeGroupCheckpointState& state) const; + std::unique_ptr checkpointInMemAndOnDisk(MemoryManager& memoryManager, + const common::UniqLock& lock, NodeGroupCheckpointState& state) const; + std::unique_ptr checkpointVersionInfo(const common::UniqLock& lock, + const transaction::Transaction* transaction) const; + + template + common::row_idx_t getNumResidentRows(const common::UniqLock& lock) const; + template + std::unique_ptr scanAllInsertedAndVersions(MemoryManager& memoryManager, + const common::UniqLock& lock, const std::vector& columnIDs, 
+ const std::vector& columns) const; + + virtual NodeGroupScanResult scanInternal(const common::UniqLock& lock, + transaction::Transaction* transaction, TableScanState& state, + common::offset_t startOffsetInGroup, common::offset_t numRowsToScan) const; + + common::row_idx_t getStartRowIdxInGroupNoLock() const; + common::row_idx_t getStartRowIdxInGroup(const common::UniqLock& lock) const; + + void scanCommittedUpdatesForColumn(std::vector& chunkCheckpointStates, + MemoryManager& memoryManager, const common::UniqLock& lock, common::column_id_t columnID, + const Column* column) const; + +protected: + MemoryManager& mm; + common::node_group_idx_t nodeGroupIdx; + NodeGroupDataFormat format; + bool enableCompression; + std::atomic numRows; + // `nextRowToAppend` is a cursor to allow us to pre-reserve a set of rows to append before + // acutally appending data. This is an optimization to reduce lock-contention when appending in + // parallel. + // TODO(Guodong): Remove this field. + common::row_idx_t nextRowToAppend; + common::row_idx_t capacity; + std::vector dataTypes; + GroupCollection chunkedGroups; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/node_group_collection.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/node_group_collection.h new file mode 100644 index 0000000000..7696e2473e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/node_group_collection.h @@ -0,0 +1,125 @@ +#pragma once + +#include "storage/stats/table_stats.h" +#include "storage/table/group_collection.h" +#include "storage/table/node_group.h" + +namespace lbug { +namespace transaction { +class Transaction; +} +namespace storage { +class MemoryManager; + +class NodeGroupCollection { +public: + NodeGroupCollection(MemoryManager& mm, const std::vector& types, + bool enableCompression, ResidencyState residency = ResidencyState::IN_MEMORY, + const VersionRecordHandler* versionRecordHandler = nullptr); + + void append(const transaction::Transaction* transaction, + const std::vector& vectors); + void append(const transaction::Transaction* transaction, + const std::vector& columnIDs, const NodeGroupCollection& other); + void append(const transaction::Transaction* transaction, + const std::vector& columnIDs, const NodeGroup& nodeGroup); + + // This function only tries to append data into the last node group, and if the last node group + // is not enough to hold all the data, it will append partially and return the number of rows + // appended. + // The returned values are the startOffset and numValuesAppended. + // NOTE: This is specially coded to only be used by NodeBatchInsert for now. 
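// Illustrative, standalone sketch (not from the lbug sources) of the partial-append contract
// noted above: only the trailing node group is filled, and the returned pair tells the caller
// the start offset and how many rows actually fit. Toy* names are invented.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

struct ToyNodeGroup {
    uint64_t capacity = 0;
    std::vector<int> rows;
    uint64_t freeSlots() const { return capacity - rows.size(); }
};

// Returns {startOffsetInGroup, numAppended}; numAppended may be smaller than the input size,
// in which case the caller is expected to start a fresh node group for the remainder.
std::pair<uint64_t, uint64_t> appendToLastGroup(ToyNodeGroup& last, const std::vector<int>& input) {
    const uint64_t start = last.rows.size();
    const uint64_t n = std::min<uint64_t>(last.freeSlots(), input.size());
    last.rows.insert(last.rows.end(), input.begin(),
        input.begin() + static_cast<std::ptrdiff_t>(n));
    return {start, n};
}

// With capacity 8 and 5 rows already present, appending 6 more yields {5, 3}: three rows land
// in the last group and the remaining three need a new group.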
+ std::pair appendToLastNodeGroupAndFlushWhenFull( + transaction::Transaction* transaction, const std::vector& columnIDs, + InMemChunkedNodeGroup& chunkedGroup, PageAllocator& pageAllocator); + + common::row_idx_t getNumTotalRows() const; + common::node_group_idx_t getNumNodeGroups() const { + const auto lock = nodeGroups.lock(); + return nodeGroups.getNumGroups(lock); + } + common::node_group_idx_t getNumNodeGroupsNoLock() const { + return nodeGroups.getNumGroupsNoLock(); + } + NodeGroup* getNodeGroupNoLock(const common::node_group_idx_t groupIdx) const { + KU_ASSERT(nodeGroups.getGroupNoLock(groupIdx)->getNodeGroupIdx() == groupIdx); + return nodeGroups.getGroupNoLock(groupIdx); + } + NodeGroup* getNodeGroup(const common::node_group_idx_t groupIdx, + bool mayOutOfBound = false) const { + const auto lock = nodeGroups.lock(); + if (mayOutOfBound && groupIdx >= nodeGroups.getNumGroups(lock)) { + return nullptr; + } + KU_ASSERT(nodeGroups.getGroupNoLock(groupIdx)->getNodeGroupIdx() == groupIdx); + return nodeGroups.getGroup(lock, groupIdx); + } + NodeGroup* getOrCreateNodeGroup(const transaction::Transaction* transaction, + common::node_group_idx_t groupIdx, NodeGroupDataFormat format); + + void setNodeGroup(const common::node_group_idx_t nodeGroupIdx, + std::unique_ptr group) { + const auto lock = nodeGroups.lock(); + nodeGroups.replaceGroup(lock, nodeGroupIdx, std::move(group)); + } + + void rollbackInsert(common::row_idx_t numRows_, bool updateNumRows = true); + + void clear() { + const auto lock = nodeGroups.lock(); + nodeGroups.clear(lock); + } + + common::column_id_t getNumColumns() const { return types.size(); } + + void addColumn(TableAddColumnState& addColumnState, PageAllocator* pageAllocator = nullptr); + + uint64_t getEstimatedMemoryUsage() const; + + void checkpoint(MemoryManager& memoryManager, NodeGroupCheckpointState& state); + void reclaimStorage(PageAllocator& pageAllocator) const; + + TableStats getStats() const { + auto lock = nodeGroups.lock(); + return stats.copy(); + } + TableStats getStats(const common::UniqLock& lock) const { + KU_ASSERT(lock.isLocked()); + KU_UNUSED(lock); + return stats.copy(); + } + void mergeStats(const TableStats& stats) { + auto lock = nodeGroups.lock(); + this->stats.merge(stats); + } + void mergeStats(const std::vector& columnIDs, const TableStats& stats) { + auto lock = nodeGroups.lock(); + this->stats.merge(columnIDs, stats); + } + + void serialize(common::Serializer& ser); + void deserialize(common::Deserializer& deSer, MemoryManager& memoryManager); + + void pushInsertInfo(const transaction::Transaction* transaction, + common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, + common::row_idx_t numRows, const VersionRecordHandler* versionRecordHandler, + bool incrementNumRows); + +private: + void pushInsertInfo(const transaction::Transaction* transaction, const NodeGroup* nodeGroup, + common::row_idx_t numRows); + +private: + MemoryManager& mm; + bool enableCompression; + // Num rows in the collection regardless of deletions. 
+ std::atomic numTotalRows; + std::vector types; + GroupCollection nodeGroups; + ResidencyState residency; + TableStats stats; + const VersionRecordHandler* versionRecordHandler; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/node_table.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/node_table.h new file mode 100644 index 0000000000..18ba1f59d4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/node_table.h @@ -0,0 +1,230 @@ +#pragma once + +#include "common/types/types.h" +#include "storage/index/hash_index.h" +#include "storage/table/node_group_collection.h" +#include "storage/table/table.h" + +namespace lbug { +namespace evaluator { +class ExpressionEvaluator; +} // namespace evaluator + +namespace catalog { +class NodeTableCatalogEntry; +} // namespace catalog + +namespace transaction { +class Transaction; +} // namespace transaction + +namespace storage { + +struct LBUG_API NodeTableScanState : TableScanState { + NodeTableScanState(common::ValueVector* nodeIDVector, + std::vector outputVectors, + std::shared_ptr outChunkState) + : TableScanState{nodeIDVector, std::move(outputVectors), std::move(outChunkState)} { + nodeGroupScanState = std::make_unique(this->columnIDs.size()); + } + + void setToTable(const transaction::Transaction* transaction, Table* table_, + std::vector columnIDs_, + std::vector columnPredicateSets_ = {}, + common::RelDataDirection direction = common::RelDataDirection::INVALID) override; + + bool scanNext(transaction::Transaction* transaction) override; + + NodeGroupScanResult scanNext(transaction::Transaction* transaction, + common::offset_t startOffset, common::offset_t numNodes); +}; + +// There is a vtable bug related to the Apple clang v15.0.0+. Adding the `FINAL` specifier to +// derived class causes casting failures in Apple platform. 
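// Illustrative, standalone sketch (not from the lbug sources) of the scan-state hierarchy
// convention referenced above: derived scan states stay non-final because of the reported
// Apple clang 15+ casting issue, and downcasts go through a debug-checked cast helper in the
// spirit of ku_dynamic_cast. Toy* names are invented.
#include <cassert>
#include <cstdint>
#include <memory>

struct ToyScanState {
    virtual ~ToyScanState() = default;
    // dynamic_cast only in assertions, plain static_cast otherwise.
    template<typename Target>
    Target& cast() {
        assert(dynamic_cast<Target*>(this) != nullptr);
        return *static_cast<Target*>(this);
    }
};

// Deliberately not marked `final`: per the note above, `final` on these derived scan-state
// structs has been seen to break casts with Apple clang 15+.
struct ToyNodeScanState : ToyScanState {
    uint64_t nextRow = 0;
};

// Usage: std::unique_ptr<ToyScanState> s = std::make_unique<ToyNodeScanState>();
//        s->cast<ToyNodeScanState>().nextRow = 42;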
+struct LBUG_API NodeTableInsertState : TableInsertState { + common::ValueVector& nodeIDVector; + const common::ValueVector& pkVector; + std::vector> indexInsertStates; + + NodeTableInsertState(common::ValueVector& nodeIDVector, const common::ValueVector& pkVector, + std::vector propertyVectors) + : TableInsertState{std::move(propertyVectors)}, nodeIDVector{nodeIDVector}, + pkVector{pkVector} {} + + NodeTableInsertState(const NodeTableInsertState&) = delete; +}; + +struct LBUG_API NodeTableUpdateState : TableUpdateState { + common::ValueVector& nodeIDVector; + std::vector> indexUpdateState; + + NodeTableUpdateState(common::column_id_t columnID, common::ValueVector& nodeIDVector, + common::ValueVector& propertyVector) + : TableUpdateState{columnID, propertyVector}, nodeIDVector{nodeIDVector} {} + + NodeTableUpdateState(const NodeTableUpdateState&) = delete; + + bool needToUpdateIndex(common::idx_t idx) const { + return idx < indexUpdateState.size() && indexUpdateState[idx] != nullptr; + } +}; + +struct LBUG_API NodeTableDeleteState : TableDeleteState { + common::ValueVector& nodeIDVector; + common::ValueVector& pkVector; + + explicit NodeTableDeleteState(common::ValueVector& nodeIDVector, common::ValueVector& pkVector) + : nodeIDVector{nodeIDVector}, pkVector{pkVector} {} +}; + +class NodeTable; +struct IndexScanHelper { + explicit IndexScanHelper(NodeTable* table, Index* index) : table{table}, index(index) {} + virtual ~IndexScanHelper() = default; + + virtual std::unique_ptr initScanState( + const transaction::Transaction* transaction, common::DataChunk& dataChunk); + virtual bool processScanOutput(main::ClientContext* context, NodeGroupScanResult scanResult, + const std::vector& scannedVectors) = 0; + + NodeTable* table; + Index* index; +}; + +class NodeTableVersionRecordHandler final : public VersionRecordHandler { +public: + explicit NodeTableVersionRecordHandler(NodeTable* table); + + void applyFuncToChunkedGroups(version_record_handler_op_t func, + common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, + common::row_idx_t numRows, common::transaction_t commitTS) const override; + void rollbackInsert(main::ClientContext* context, common::node_group_idx_t nodeGroupIdx, + common::row_idx_t startRow, common::row_idx_t numRows) const override; + +private: + NodeTable* table; +}; + +class StorageManager; + +class LBUG_API NodeTable final : public Table { +public: + NodeTable(const StorageManager* storageManager, + const catalog::NodeTableCatalogEntry* nodeTableEntry, MemoryManager* mm); + + common::row_idx_t getNumTotalRows(const transaction::Transaction* transaction) override; + + void initScanState(transaction::Transaction* transaction, TableScanState& scanState, + bool resetCachedBoundNodeIDs = true) const override; + void initScanState(transaction::Transaction* transaction, TableScanState& scanState, + common::table_id_t tableID, common::offset_t startOffset) const; + + bool scanInternal(transaction::Transaction* transaction, TableScanState& scanState) override; + template + bool lookup(const transaction::Transaction* transaction, const TableScanState& scanState) const; + // TODO(Guodong): This should be merged together with `lookup`. + template + bool lookupMultiple(transaction::Transaction* transaction, TableScanState& scanState) const; + + // Return the max node offset during insertions. 
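
// Illustrative sketch, not from the lbug sources: applyFuncToChunkedGroups above takes a
// version_record_handler_op_t, which version_record_handler.h (later in this diff) defines
// as a pointer to a ChunkedNodeGroup member function. A minimal standard-C++ refresher on
// invoking such member pointers (Group/commitInsert/rollbackInsert are hypothetical
// stand-ins, not the lbug types):

#include <cstdint>
#include <vector>

struct Group {
    void commitInsert(uint64_t startRow, uint64_t numRows, uint64_t commitTS) { /* ... */ }
    void rollbackInsert(uint64_t startRow, uint64_t numRows, uint64_t commitTS) { /* ... */ }
};

using group_op_t = void (Group::*)(uint64_t, uint64_t, uint64_t);

void applyToGroups(std::vector<Group>& groups, group_op_t op, uint64_t startRow,
    uint64_t numRows, uint64_t commitTS) {
    for (auto& group : groups) {
        (group.*op)(startRow, numRows, commitTS);  // invoke the member pointer on each group
    }
}

// Usage: applyToGroups(groups, &Group::commitInsert, 0, 10, /*commitTS=*/42);
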
+ common::offset_t validateUniquenessConstraint(const transaction::Transaction* transaction, + const std::vector& propertyVectors) const; + + void initInsertState(main::ClientContext* context, TableInsertState& insertState) override; + void insert(transaction::Transaction* transaction, TableInsertState& insertState) override; + void initUpdateState(main::ClientContext* context, TableUpdateState& updateState) const; + void update(transaction::Transaction* transaction, TableUpdateState& updateState) override; + bool delete_(transaction::Transaction* transaction, TableDeleteState& deleteState) override; + + void addColumn(transaction::Transaction* transaction, TableAddColumnState& addColumnState, + PageAllocator& pageAllocator) override; + bool isVisible(const transaction::Transaction* transaction, common::offset_t offset) const; + bool isVisibleNoLock(const transaction::Transaction* transaction, + common::offset_t offset) const; + + bool lookupPK(const transaction::Transaction* transaction, common::ValueVector* keyVector, + uint64_t vectorPos, common::offset_t& result) const; + + void addIndex(std::unique_ptr index); + void dropIndex(const std::string& name); + + common::column_id_t getPKColumnID() const { return pkColumnID; } + PrimaryKeyIndex* getPKIndex() const { + const auto index = getIndex(PrimaryKeyIndex::DEFAULT_NAME); + KU_ASSERT(index.has_value()); + return &index.value()->cast(); + } + std::optional> getIndexHolder(const std::string& name); + std::optional getIndex(const std::string& name) const; + std::vector& getIndexes() { return indexes; } + + common::column_id_t getNumColumns() const { return columns.size(); } + Column& getColumn(common::column_id_t columnID) { + KU_ASSERT(columnID < columns.size()); + return *columns[columnID]; + } + const Column& getColumn(common::column_id_t columnID) const { + KU_ASSERT(columnID < columns.size()); + return *columns[columnID]; + } + + std::pair appendToLastNodeGroup( + transaction::Transaction* transaction, const std::vector& columnIDs, + InMemChunkedNodeGroup& chunkedGroup, PageAllocator& pageAllocator); + + void commit(main::ClientContext* context, catalog::TableCatalogEntry* tableEntry, + LocalTable* localTable) override; + bool checkpoint(main::ClientContext* context, catalog::TableCatalogEntry* tableEntry, + PageAllocator& pageAllocator) override; + void rollbackCheckpoint() override; + void reclaimStorage(PageAllocator& pageAllocator) const override; + + void rollbackPKIndexInsert(main::ClientContext* context, common::row_idx_t startRow, + common::row_idx_t numRows_, common::node_group_idx_t nodeGroupIdx_); + void rollbackGroupCollectionInsert(common::row_idx_t numRows_); + + common::node_group_idx_t getNumCommittedNodeGroups() const { + return nodeGroups->getNumNodeGroups(); + } + + common::node_group_idx_t getNumNodeGroups() const { return nodeGroups->getNumNodeGroups(); } + common::offset_t getNumTuplesInNodeGroup(common::node_group_idx_t nodeGroupIdx) const { + return nodeGroups->getNodeGroup(nodeGroupIdx)->getNumRows(); + } + NodeGroup* getNodeGroup(common::node_group_idx_t nodeGroupIdx) const { + return nodeGroups->getNodeGroup(nodeGroupIdx); + } + NodeGroup* getNodeGroupNoLock(common::node_group_idx_t nodeGroupIdx) const { + return nodeGroups->getNodeGroupNoLock(nodeGroupIdx); + } + + TableStats getStats(const transaction::Transaction* transaction) const; + // NOLINTNEXTLINE(readability-make-member-function-const): Semantically non-const. 
+ void mergeStats(const std::vector& columnIDs, const TableStats& stats) { + nodeGroups->mergeStats(columnIDs, stats); + } + + void serialize(common::Serializer& serializer) const override; + void deserialize(main::ClientContext* context, StorageManager* storageManager, + common::Deserializer& deSer) override; + +private: + void validatePkNotExists(const transaction::Transaction* transaction, + common::ValueVector* pkVector) const; + + visible_func getVisibleFunc(const transaction::Transaction* transaction) const; + common::DataChunk constructDataChunkForColumns( + const std::vector& columnIDs) const; + void scanIndexColumns(main::ClientContext* context, IndexScanHelper& scanHelper, + const NodeGroupCollection& nodeGroups_) const; + +private: + std::vector> columns; + std::unique_ptr nodeGroups; + common::column_id_t pkColumnID; + std::vector indexes; + NodeTableVersionRecordHandler versionRecordHandler; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/null_column.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/null_column.h new file mode 100644 index 0000000000..5284a51cff --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/null_column.h @@ -0,0 +1,20 @@ +#pragma once + +#include "common/system_config.h" +#include "storage/table/column.h" + +namespace lbug { +namespace storage { + +// Page size must be aligned to 8 byte chunks for the 64-bit NullMask algorithms to work +// without the possibility of memory errors from reading/writing off the end of a page. +static_assert(common::LBUG_PAGE_SIZE % 8 == 0); + +class NullColumn final : public Column { +public: + NullColumn(const std::string& name, FileHandle* dataFH, MemoryManager* mm, + ShadowFile* shadowFile, bool enableCompression); +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/rel_table.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/rel_table.h new file mode 100644 index 0000000000..17a6d95f87 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/rel_table.h @@ -0,0 +1,255 @@ +#pragma once + +#include "catalog/catalog_entry/rel_group_catalog_entry.h" +#include "storage/table/rel_table_data.h" +#include "storage/table/table.h" + +namespace lbug { +namespace evaluator { +class ExpressionEvaluator; +} // namespace evaluator +namespace transaction { +class Transaction; +} +namespace storage { +class MemoryManager; + +struct LocalRelTableScanState; +struct RelTableScanState : TableScanState { + common::RelDataDirection direction; + common::sel_t currBoundNodeIdx; + Column* csrOffsetColumn; + Column* csrLengthColumn; + bool randomLookup; + + // This is a reference of the original selVector of the input boundNodeIDVector. + common::SelectionVector cachedBoundNodeSelVector; + + std::unique_ptr localTableScanState; + + RelTableScanState(MemoryManager& mm, common::ValueVector* nodeIDVector, + std::vector outputVectors, + std::shared_ptr outChunkState, bool randomLookup = false) + : TableScanState{nodeIDVector, std::move(outputVectors), std::move(outChunkState)}, + direction{common::RelDataDirection::INVALID}, currBoundNodeIdx{0}, + csrOffsetColumn{nullptr}, csrLengthColumn{nullptr}, randomLookup{randomLookup}, + localTableScanState{nullptr} { + nodeGroupScanState = std::make_unique(mm, randomLookup); + } + + // This is for local table scan state. 
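
// Illustrative sketch, not from the lbug sources: the static_assert in null_column.h above
// (LBUG_PAGE_SIZE % 8 == 0) exists because a 64-bit null mask reads and writes whole
// uint64_t words; if the page size were not a multiple of 8 bytes, the last word would
// straddle the end of the page. A self-contained word-based mask that shows why whole-word
// accesses stay in bounds (PAGE_SIZE/setNull/isNull are hypothetical names):

#include <cstddef>
#include <cstdint>
#include <cstring>

constexpr std::size_t PAGE_SIZE = 4096;           // must be a multiple of 8
static_assert(PAGE_SIZE % 8 == 0, "64-bit mask words must not run off the page");

void setNull(uint8_t* page, std::size_t pos) {
    uint64_t word;
    std::memcpy(&word, page + (pos / 64) * 8, 8);  // always a full, in-bounds word
    word |= (uint64_t{1} << (pos % 64));
    std::memcpy(page + (pos / 64) * 8, &word, 8);
}

bool isNull(const uint8_t* page, std::size_t pos) {
    uint64_t word;
    std::memcpy(&word, page + (pos / 64) * 8, 8);
    return (word >> (pos % 64)) & 1u;
}

// Usage: given a zeroed page buffer of PAGE_SIZE bytes, setNull(page, 70); isNull(page, 70) == true.
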
+ RelTableScanState(common::ValueVector* nodeIDVector, + std::vector outputVectors, + std::shared_ptr outChunkState) + : TableScanState{nodeIDVector, std::move(outputVectors), std::move(outChunkState)}, + direction{common::RelDataDirection::INVALID}, currBoundNodeIdx{0}, + csrOffsetColumn{nullptr}, csrLengthColumn{nullptr}, randomLookup{false}, + localTableScanState{nullptr} { + nodeGroupScanState = std::make_unique(); + } + + void setToTable(const transaction::Transaction* transaction, Table* table_, + std::vector columnIDs_, + std::vector columnPredicateSets_, + common::RelDataDirection direction_) override; + + void initState(transaction::Transaction* transaction, NodeGroup* nodeGroup, + bool resetCachedBoundNodeIDs = true) override; + + bool scanNext(transaction::Transaction* transaction) override; + + void setNodeIDVectorToFlat(common::sel_t selPos) const; + +private: + bool hasUnCommittedData() const; + + void initCachedBoundNodeIDSelVector(); + void initStateForCommitted(const transaction::Transaction* transaction); + void initStateForUncommitted(); +}; + +class LocalRelTable; +struct LocalRelTableScanState final : RelTableScanState { + LocalRelTable* localRelTable; + // TODO(Guodong): Copy of rowIndices here is only to simplify the implementation. We can always + // go to the fwdIndex/bwdIndex inside LocalRelTable to find the row indices. We can revisit this + // if the copy of rowIndices from fwdIndex/bwdIndex becomes a problem. + row_idx_vec_t rowIndices; + common::row_idx_t nextRowToScan = 0; + + LocalRelTableScanState(const RelTableScanState& baseScanState, LocalRelTable* localRelTable, + std::vector columnIDs) + : RelTableScanState{baseScanState.nodeIDVector, baseScanState.outputVectors, + baseScanState.outState}, + localRelTable{localRelTable} { + this->columnIDs = std::move(columnIDs); + this->direction = baseScanState.direction; + // Setting source to UNCOMMITTED is not necessary but just to keep it semantically + // consistent. + this->source = TableScanSource::UNCOMMITTED; + this->nodeGroupScanState->chunkStates.resize(this->columnIDs.size()); + } +}; + +struct LBUG_API RelTableInsertState : TableInsertState { + common::ValueVector& srcNodeIDVector; + common::ValueVector& dstNodeIDVector; + + common::ValueVector& getBoundNodeIDVector(common::RelDataDirection direction) const { + return direction == common::RelDataDirection::FWD ? srcNodeIDVector : dstNodeIDVector; + } + + RelTableInsertState(common::ValueVector& srcNodeIDVector, common::ValueVector& dstNodeIDVector, + std::vector propertyVectors) + : TableInsertState{std::move(propertyVectors)}, srcNodeIDVector{srcNodeIDVector}, + dstNodeIDVector{dstNodeIDVector} {} +}; + +struct RelTableUpdateState final : TableUpdateState { + common::ValueVector& srcNodeIDVector; + common::ValueVector& dstNodeIDVector; + common::ValueVector& relIDVector; + + common::ValueVector& getBoundNodeIDVector(common::RelDataDirection direction) const { + return direction == common::RelDataDirection::FWD ? 
srcNodeIDVector : dstNodeIDVector; + } + + RelTableUpdateState(common::column_id_t columnID, common::ValueVector& srcNodeIDVector, + common::ValueVector& dstNodeIDVector, common::ValueVector& relIDVector, + common::ValueVector& propertyVector) + : TableUpdateState{columnID, propertyVector}, srcNodeIDVector{srcNodeIDVector}, + dstNodeIDVector{dstNodeIDVector}, relIDVector{relIDVector} {} +}; + +struct LBUG_API RelTableDeleteState final : TableDeleteState { + common::ValueVector& srcNodeIDVector; + common::ValueVector& dstNodeIDVector; + common::ValueVector& relIDVector; + common::RelDataDirection detachDeleteDirection; + + common::ValueVector& getBoundNodeIDVector(common::RelDataDirection direction) const { + return direction == common::RelDataDirection::FWD ? srcNodeIDVector : dstNodeIDVector; + } + + RelTableDeleteState(common::ValueVector& srcNodeIDVector, common::ValueVector& dstNodeIDVector, + common::ValueVector& relIDVector, + common::RelDataDirection detachDeleteDirection = common::RelDataDirection::FWD) + : srcNodeIDVector{srcNodeIDVector}, dstNodeIDVector{dstNodeIDVector}, + relIDVector{relIDVector}, detachDeleteDirection{detachDeleteDirection} {} +}; + +class LBUG_API RelTable final : public Table { +public: + using rel_multiplicity_constraint_throw_func_t = + std::function; + + RelTable(catalog::RelGroupCatalogEntry* relGroupEntry, common::table_id_t fromTableID, + common::table_id_t toTableID, const StorageManager* storageManager, + MemoryManager* memoryManager); + + common::table_id_t getFromNodeTableID() const { return fromNodeTableID; } + common::table_id_t getToNodeTableID() const { return toNodeTableID; } + + void initScanState(transaction::Transaction* transaction, TableScanState& scanState, + bool resetCachedBoundNodeSelVec = true) const override; + + bool scanInternal(transaction::Transaction* transaction, TableScanState& scanState) override; + + void initInsertState(main::ClientContext*, TableInsertState&) override { + // DO NOTHING. 
+ } + void insert(transaction::Transaction* transaction, TableInsertState& insertState) override; + void update(transaction::Transaction* transaction, TableUpdateState& updateState) override; + bool delete_(transaction::Transaction* transaction, TableDeleteState& deleteState) override; + + // Deletes all edges attached to the node(s) specified in the deleteState + // Currently only supports deleting from a single src node + // Note that since the rel table doesn't store nodes this doesn't delete the node itself + void detachDelete(transaction::Transaction* transaction, RelTableDeleteState* deleteState); + bool checkIfNodeHasRels(transaction::Transaction* transaction, + common::RelDataDirection direction, common::ValueVector* srcNodeIDVector) const; + void throwIfNodeHasRels(transaction::Transaction* transaction, + common::RelDataDirection direction, common::ValueVector* srcNodeIDVector, + const rel_multiplicity_constraint_throw_func_t& throwFunc) const; + + void addColumn(transaction::Transaction* transaction, TableAddColumnState& addColumnState, + PageAllocator& pageAllocator) override; + Column* getCSROffsetColumn(common::RelDataDirection direction) const { + return getDirectedTableData(direction)->getCSROffsetColumn(); + } + Column* getCSRLengthColumn(common::RelDataDirection direction) const { + return getDirectedTableData(direction)->getCSRLengthColumn(); + } + common::column_id_t getNumColumns() const { + KU_ASSERT(directedRelData.size() >= 1); + RUNTIME_CHECK(for (const auto& relData + : directedRelData) { + KU_ASSERT(relData->getNumColumns() == directedRelData[0]->getNumColumns()); + }); + return directedRelData[0]->getNumColumns(); + } + Column* getColumn(common::column_id_t columnID, common::RelDataDirection direction) const { + return getDirectedTableData(direction)->getColumn(columnID); + } + + NodeGroup* getOrCreateNodeGroup(const transaction::Transaction* transaction, + common::node_group_idx_t nodeGroupIdx, common::RelDataDirection direction) const; + + void commit(main::ClientContext* context, catalog::TableCatalogEntry* tableEntry, + LocalTable* localTable) override; + bool checkpoint(main::ClientContext*, catalog::TableCatalogEntry* tableEntry, + PageAllocator& pageAllocator) override; + void rollbackCheckpoint() override {}; + void reclaimStorage(PageAllocator& pageAllocator) const override; + + common::row_idx_t getNumTotalRows(const transaction::Transaction* transaction) override; + + RelTableData* getDirectedTableData(common::RelDataDirection direction) const; + + common::offset_t reserveRelOffsets(common::offset_t numRels) { + std::unique_lock xLck{relOffsetMtx}; + const auto currentRelOffset = nextRelOffset; + nextRelOffset += numRels; + return currentRelOffset; + } + + void pushInsertInfo(const transaction::Transaction* transaction, + common::RelDataDirection direction, const CSRNodeGroup& nodeGroup, + common::row_idx_t numRows_, CSRNodeGroupScanSource source) const; + + std::vector getStorageDirections() const; + common::table_id_t getRelGroupID() const { return relGroupID; } + + void serialize(common::Serializer& ser) const override; + void deserialize(main::ClientContext* context, StorageManager* storageManager, + common::Deserializer& deSer) override; + +private: + static void prepareCommitForNodeGroup(const transaction::Transaction* transaction, + const std::vector& columnIDs, const NodeGroup& localNodeGroup, + CSRNodeGroup& csrNodeGroup, common::offset_t boundOffsetInGroup, + const row_idx_vec_t& rowIndices, common::column_id_t skippedColumn); + + void 
updateRelOffsets(const LocalRelTable& localRelTable); + + static common::offset_t getCommittedOffset(common::offset_t uncommittedOffset, + common::offset_t maxCommittedOffset); + + void detachDeleteForCSRRels(transaction::Transaction* transaction, RelTableData* tableData, + RelTableData* reverseTableData, RelTableScanState* relDataReadState, + RelTableDeleteState* deleteState); + + void checkRelMultiplicityConstraint(transaction::Transaction* transaction, + const TableInsertState& state) const; + +private: + common::table_id_t relGroupID; + common::table_id_t fromNodeTableID; + common::table_id_t toNodeTableID; + std::mutex relOffsetMtx; + common::offset_t nextRelOffset; + std::vector> directedRelData; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/rel_table_data.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/rel_table_data.h new file mode 100644 index 0000000000..ca0a3675b4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/rel_table_data.h @@ -0,0 +1,167 @@ +#pragma once + +#include + +#include "common/enums/rel_direction.h" +#include "common/enums/rel_multiplicity.h" +#include "storage/table/column.h" +#include "storage/table/csr_node_group.h" +#include "storage/table/node_group_collection.h" + +namespace lbug { +namespace catalog { +class RelGroupCatalogEntry; +} +namespace transaction { +class Transaction; +} +namespace storage { +class Table; +class MemoryManager; +class RelTableData; + +struct CSRHeaderColumns { + std::unique_ptr offset; + std::unique_ptr length; +}; + +class PersistentVersionRecordHandler final : public VersionRecordHandler { +public: + explicit PersistentVersionRecordHandler(RelTableData* relTableData); + + void applyFuncToChunkedGroups(version_record_handler_op_t func, + common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, + common::row_idx_t numRows, common::transaction_t commitTS) const override; + void rollbackInsert(main::ClientContext* context, common::node_group_idx_t nodeGroupIdx, + common::row_idx_t startRow, common::row_idx_t numRows) const override; + +private: + RelTableData* relTableData; +}; + +class InMemoryVersionRecordHandler final : public VersionRecordHandler { +public: + explicit InMemoryVersionRecordHandler(RelTableData* relTableData); + + void applyFuncToChunkedGroups(version_record_handler_op_t func, + common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, + common::row_idx_t numRows, common::transaction_t commitTS) const override; + void rollbackInsert(main::ClientContext* context, common::node_group_idx_t nodeGroupIdx, + common::row_idx_t startRow, common::row_idx_t numRows) const override; + +private: + RelTableData* relTableData; +}; + +class RelTableData { +public: + RelTableData(FileHandle* dataFH, MemoryManager* mm, ShadowFile* shadowFile, + const catalog::RelGroupCatalogEntry& relGroupEntry, Table& table, + common::RelDataDirection direction, common::table_id_t nbrTableID, bool enableCompression); + + bool update(transaction::Transaction* transaction, common::ValueVector& boundNodeIDVector, + const common::ValueVector& relIDVector, common::column_id_t columnID, + const common::ValueVector& dataVector) const; + bool delete_(transaction::Transaction* transaction, common::ValueVector& boundNodeIDVector, + const common::ValueVector& relIDVector); + void addColumn(TableAddColumnState& addColumnState, PageAllocator& pageAllocator); + + bool checkIfNodeHasRels(transaction::Transaction* transaction, + 
common::ValueVector* srcNodeIDVector) const; + + Column* getNbrIDColumn() const { return columns[NBR_ID_COLUMN_ID].get(); } + Column* getCSROffsetColumn() const { return csrHeaderColumns.offset.get(); } + Column* getCSRLengthColumn() const { return csrHeaderColumns.length.get(); } + common::column_id_t getNumColumns() const { return columns.size(); } + Column* getColumn(common::column_id_t columnID) const { return columns[columnID].get(); } + std::vector getColumns() const { + std::vector result; + result.reserve(columns.size()); + for (const auto& column : columns) { + result.push_back(column.get()); + } + return result; + } + common::node_group_idx_t getNumNodeGroups() const { return nodeGroups->getNumNodeGroups(); } + NodeGroup* getNodeGroup(common::node_group_idx_t nodeGroupIdx) const { + return nodeGroups->getNodeGroup(nodeGroupIdx, true /*mayOutOfBound*/); + } + NodeGroup* getOrCreateNodeGroup(const transaction::Transaction* transaction, + common::node_group_idx_t nodeGroupIdx) const { + return nodeGroups->getOrCreateNodeGroup(transaction, nodeGroupIdx, + NodeGroupDataFormat::CSR); + } + + common::RelMultiplicity getMultiplicity() const { return multiplicity; } + + TableStats getStats() const { return nodeGroups->getStats(); } + + void reclaimStorage(PageAllocator& pageAllocator) const; + void checkpoint(const std::vector& columnIDs, + PageAllocator& pageAllocator); + + void pushInsertInfo(const transaction::Transaction* transaction, const CSRNodeGroup& nodeGroup, + common::row_idx_t numRows_, CSRNodeGroupScanSource source); + + void serialize(common::Serializer& serializer) const; + void deserialize(common::Deserializer& deSerializer, MemoryManager& memoryManager); + + NodeGroup* getNodeGroupNoLock(common::node_group_idx_t nodeGroupIdx) const { + return nodeGroups->getNodeGroupNoLock(nodeGroupIdx); + } + + void rollbackGroupCollectionInsert(common::row_idx_t numRows_, bool isPersistent); + + common::RelDataDirection getDirection() const { return direction; } + +private: + void initCSRHeaderColumns(FileHandle* dataFH); + void initPropertyColumns(const catalog::RelGroupCatalogEntry& relGroupEntry, + common::table_id_t nbrTableID, FileHandle* dataFH); + + std::pair findMatchingRow( + transaction::Transaction* transaction, common::ValueVector& boundNodeIDVector, + const common::ValueVector& relIDVector) const; + + template + static double divideNoRoundUp(T1 v1, T2 v2) { + static_assert(std::is_arithmetic_v && std::is_arithmetic_v); + return static_cast(v1) / static_cast(v2); + } + template + static uint64_t multiplyAndRoundUpTo(T1 v1, T2 v2) { + static_assert(std::is_arithmetic_v && std::is_arithmetic_v); + return std::ceil(static_cast(v1) * static_cast(v2)); + } + + std::vector getColumnTypes() const { + std::vector types; + types.reserve(columns.size()); + for (const auto& column : columns) { + types.push_back(column->getDataType().copy()); + } + return types; + } + + const VersionRecordHandler* getVersionRecordHandler(CSRNodeGroupScanSource source) const; + +private: + Table& table; + MemoryManager* mm; + ShadowFile* shadowFile; + bool enableCompression; + PackedCSRInfo packedCSRInfo; + common::RelDataDirection direction; + common::RelMultiplicity multiplicity; + + std::unique_ptr nodeGroups; + + CSRHeaderColumns csrHeaderColumns; + std::vector> columns; + + PersistentVersionRecordHandler persistentVersionRecordHandler; + InMemoryVersionRecordHandler inMemoryVersionRecordHandler; +}; + +} // namespace storage +} // namespace lbug diff --git 
a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/string_chunk_data.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/string_chunk_data.h new file mode 100644 index 0000000000..8d0d45f94a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/string_chunk_data.h @@ -0,0 +1,111 @@ +#pragma once + +#include "common/assert.h" +#include "common/data_chunk/sel_vector.h" +#include "common/types/types.h" +#include "storage/table/column_chunk_data.h" +#include "storage/table/dictionary_chunk.h" + +namespace lbug { +namespace storage { +class MemoryManager; + +class StringChunkData final : public ColumnChunkData { +public: + static constexpr common::idx_t DATA_COLUMN_CHILD_READ_STATE_IDX = 0; + static constexpr common::idx_t OFFSET_COLUMN_CHILD_READ_STATE_IDX = 1; + static constexpr common::idx_t INDEX_COLUMN_CHILD_READ_STATE_IDX = 2; + static constexpr common::idx_t CHILD_COLUMN_COUNT = 3; + + StringChunkData(MemoryManager& mm, common::LogicalType dataType, uint64_t capacity, + bool enableCompression, ResidencyState residencyState); + StringChunkData(MemoryManager& mm, bool enableCompression, const ColumnChunkMetadata& metadata); + void resetToEmpty() override; + + void append(common::ValueVector* vector, const common::SelectionView& selView) override; + void append(const ColumnChunkData* other, common::offset_t startPosInOtherChunk, + uint32_t numValuesToAppend) override; + ColumnChunkData* getIndexColumnChunk(); + const ColumnChunkData* getIndexColumnChunk() const; + + void initializeScanState(SegmentState& state, const Column* column) const override; + void scan(common::ValueVector& output, common::offset_t offset, common::length_t length, + common::sel_t posInOutputVector = 0) const override; + void lookup(common::offset_t offsetInChunk, common::ValueVector& output, + common::sel_t posInOutputVector) const override; + + void write(const common::ValueVector* vector, common::offset_t offsetInVector, + common::offset_t offsetInChunk) override; + void write(ColumnChunkData* chunk, ColumnChunkData* dstOffsets, + common::RelMultiplicity multiplicity) override; + void write(const ColumnChunkData* srcChunk, common::offset_t srcOffsetInChunk, + common::offset_t dstOffsetInChunk, common::offset_t numValuesToCopy) override; + + template + T getValue(common::offset_t /*pos*/) const { + KU_UNREACHABLE; + } + + uint64_t getStringLength(common::offset_t pos) const { + const auto index = indexColumnChunk->getValue(pos); + return dictionaryChunk->getStringLength(index); + } + + void setIndexChunk(std::unique_ptr indexChunk) { + indexColumnChunk = std::move(indexChunk); + } + DictionaryChunk& getDictionaryChunk() { return *dictionaryChunk; } + const DictionaryChunk& getDictionaryChunk() const { return *dictionaryChunk; } + + void finalize() override; + + void flush(PageAllocator& pageAllocator) override; + uint64_t getSizeOnDisk() const override; + uint64_t getMinimumSizeOnDisk() const override; + uint64_t getSizeOnDiskInMemoryStats() const override; + void reclaimStorage(PageAllocator& pageAllocator) override; + + void resetNumValuesFromMetadata() override; + void syncNumValues() override { + numValues = indexColumnChunk->getNumValues(); + metadata.numValues = numValues; + } + + void setToInMemory() override; + void resize(uint64_t newCapacity) override; + void resizeWithoutPreserve(uint64_t newCapacity) override; + uint64_t getEstimatedMemoryUsage() const override; + + void serialize(common::Serializer& serializer) const override; + static void 
deserialize(common::Deserializer& deSer, ColumnChunkData& chunkData); + +private: + void appendStringColumnChunk(const StringChunkData* other, + common::offset_t startPosInOtherChunk, uint32_t numValuesToAppend); + + void setValueFromString(std::string_view value, uint64_t pos); + + void updateNumValues(size_t newValue); + + void setNumValues(uint64_t numValues) override { + ColumnChunkData::setNumValues(numValues); + indexColumnChunk->setNumValues(numValues); + needFinalize = true; + } + +private: + std::unique_ptr indexColumnChunk; + + std::unique_ptr dictionaryChunk; + // If we never update a value, we don't need to prune unused strings in finalize + bool needFinalize; +}; + +// STRING +template<> +std::string StringChunkData::getValue(common::offset_t pos) const; +template<> +std::string_view StringChunkData::getValue(common::offset_t pos) const; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/string_column.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/string_column.h new file mode 100644 index 0000000000..fc2a548294 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/string_column.h @@ -0,0 +1,67 @@ +#pragma once + +#include "common/types/types.h" +#include "storage/buffer_manager/memory_manager.h" +#include "storage/table/dictionary_column.h" + +namespace lbug { +namespace storage { + +class StringColumn final : public Column { +public: + enum class ChildStateIndex : common::idx_t { DATA = 0, OFFSET = 1, INDEX = 2 }; + static constexpr size_t CHILD_STATE_COUNT = 3; + + StringColumn(std::string name, common::LogicalType dataType, FileHandle* dataFH, + MemoryManager* mm, ShadowFile* shadowFile, bool enableCompression); + + static std::unique_ptr flushChunkData(const ColumnChunkData& chunkData, + PageAllocator& pageAllocator); + + void writeSegment(ColumnChunkData& persistentChunk, SegmentState& state, + common::offset_t dstOffsetInSegment, const ColumnChunkData& data, + common::offset_t srcOffset, common::length_t numValues) const override; + + std::vector> checkpointSegment( + ColumnCheckpointState&& checkpointState, PageAllocator& pageAllocator, + bool canSplitSegment = true) const override; + + const DictionaryColumn& getDictionary() const { return dictionary; } + const Column* getIndexColumn() const { return indexColumn.get(); } + + static SegmentState& getChildState(SegmentState& state, ChildStateIndex child); + static const SegmentState& getChildState(const SegmentState& state, ChildStateIndex child); + +protected: + void scanSegment(const SegmentState& state, common::offset_t startOffsetInChunk, + common::row_idx_t numValuesToScan, common::ValueVector* resultVector, + common::offset_t offsetInResult) const override; + + void scanSegment(const SegmentState& state, ColumnChunkData* resultChunk, + common::offset_t startOffsetInSegment, common::row_idx_t numValuesToScan) const override; + + void scanUnfiltered(const SegmentState& state, common::offset_t startOffsetInChunk, + common::offset_t numValuesToRead, common::ValueVector* resultVector, + common::sel_t startPosInVector = 0) const; + void scanFiltered(const SegmentState& state, common::offset_t startOffsetInChunk, + common::ValueVector* resultVector, common::sel_t startPosInVector) const; + + void lookupInternal(const SegmentState& state, common::offset_t nodeOffset, + common::ValueVector* resultVector, uint32_t posInVector) const override; + +private: + bool canCheckpointInPlace(const SegmentState& state, + const 
ColumnCheckpointState& checkpointState) const override; + + bool canIndexCommitInPlace(const SegmentState& state, uint64_t numStrings, + common::offset_t maxOffset) const; + +private: + // Main column stores indices of values in the dictionary + DictionaryColumn dictionary; + + std::unique_ptr indexColumn; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/struct_chunk_data.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/struct_chunk_data.h new file mode 100644 index 0000000000..27ad6061ac --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/struct_chunk_data.h @@ -0,0 +1,96 @@ +#pragma once + +#include "common/data_chunk/sel_vector.h" +#include "common/types/types.h" +#include "storage/table/column_chunk_data.h" + +namespace lbug { +namespace storage { +class MemoryManager; + +class StructChunkData final : public ColumnChunkData { +public: + StructChunkData(MemoryManager& mm, common::LogicalType dataType, uint64_t capacity, + bool enableCompression, ResidencyState residencyState); + StructChunkData(MemoryManager& mm, common::LogicalType dataType, bool enableCompression, + const ColumnChunkMetadata& metadata); + + ColumnChunkData* getChild(common::idx_t childIdx) { + KU_ASSERT(childIdx < childChunks.size()); + return childChunks[childIdx].get(); + } + std::unique_ptr moveChild(common::idx_t childIdx) { + KU_ASSERT(childIdx < childChunks.size()); + return std::move(childChunks[childIdx]); + } + + void finalize() override; + + uint64_t getEstimatedMemoryUsage() const override; + + void resetNumValuesFromMetadata() override; + void syncNumValues() override { + KU_ASSERT(!childChunks.empty()); + numValues = childChunks[0]->getNumValues(); + metadata.numValues = numValues; + } + + void serialize(common::Serializer& serializer) const override; + static void deserialize(common::Deserializer& deSer, ColumnChunkData& chunkData); + + common::idx_t getNumChildren() const { return childChunks.size(); } + const ColumnChunkData& getChild(common::idx_t childIdx) const { + KU_ASSERT(childIdx < childChunks.size()); + return *childChunks[childIdx]; + } + void setChild(common::idx_t childIdx, std::unique_ptr childChunk) { + KU_ASSERT(childIdx < childChunks.size()); + childChunks[childIdx] = std::move(childChunk); + } + + void flush(PageAllocator& pageAllocator) override; + uint64_t getSizeOnDisk() const override; + uint64_t getMinimumSizeOnDisk() const override; + uint64_t getSizeOnDiskInMemoryStats() const override; + void reclaimStorage(PageAllocator& pageAllocator) override; + +protected: + void append(const ColumnChunkData* other, common::offset_t startPosInOtherChunk, + uint32_t numValuesToAppend) override; + void append(common::ValueVector* vector, const common::SelectionView& selView) override; + + void scan(common::ValueVector& output, common::offset_t offset, common::length_t length, + common::sel_t posInOutputVector = 0) const override; + void lookup(common::offset_t offsetInChunk, common::ValueVector& output, + common::sel_t posInOutputVector) const override; + void initializeScanState(SegmentState& state, const Column* column) const override; + + void write(const common::ValueVector* vector, common::offset_t offsetInVector, + common::offset_t offsetInChunk) override; + void write(ColumnChunkData* chunk, ColumnChunkData* dstOffsets, + common::RelMultiplicity multiplicity) override; + void write(const ColumnChunkData* srcChunk, common::offset_t srcOffsetInChunk, + common::offset_t 
dstOffsetInChunk, common::offset_t numValuesToCopy) override; + + void setToInMemory() override; + void resize(uint64_t newCapacity) override; + void resizeWithoutPreserve(uint64_t newCapacity) override; + + void resetToEmpty() override; + void resetToAllNull() override; + + bool numValuesSanityCheck() const override; + + void setNumValues(uint64_t numValues) override { + ColumnChunkData::setNumValues(numValues); + for (auto& childChunk : childChunks) { + childChunk->setNumValues(numValues); + } + } + +private: + std::vector> childChunks; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/struct_column.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/struct_column.h new file mode 100644 index 0000000000..2584ef2112 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/struct_column.h @@ -0,0 +1,48 @@ +#pragma once + +#include "common/types/types.h" +#include "storage/table/column.h" + +namespace lbug { +namespace storage { +class MemoryManager; + +class StructColumn final : public Column { +public: + StructColumn(std::string name, common::LogicalType dataType, FileHandle* dataFH, + MemoryManager* mm, ShadowFile* shadowFile, bool enableCompression); + + static std::unique_ptr flushChunkData(const ColumnChunkData& chunk, + PageAllocator& pageAllocator); + + Column* getChild(common::idx_t childIdx) const { + KU_ASSERT(childIdx < childColumns.size()); + return childColumns[childIdx].get(); + } + void writeSegment(ColumnChunkData& persistentChunk, SegmentState& state, + common::offset_t offsetInSegment, const ColumnChunkData& data, common::offset_t dataOffset, + common::length_t numValues) const override; + + std::vector> checkpointSegment( + ColumnCheckpointState&& checkpointState, PageAllocator& pageAllocator, + bool canSplitSegment = true) const override; + +protected: + void scanSegment(const SegmentState& state, ColumnChunkData* resultChunk, + common::offset_t startOffsetInSegment, common::row_idx_t numValuesToScan) const override; + void scanSegment(const SegmentState& state, common::offset_t startOffsetInSegment, + common::row_idx_t numValuesToScan, common::ValueVector* resultVector, + common::offset_t offsetInResult) const override; + + void lookupInternal(const SegmentState& state, common::offset_t offsetInSegment, + common::ValueVector* resultVector, uint32_t posInVector) const override; + + bool canCheckpointInPlace(const SegmentState& state, + const ColumnCheckpointState& checkpointState) const override; + +private: + std::vector> childColumns; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/table.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/table.h new file mode 100644 index 0000000000..60f32a1a25 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/table.h @@ -0,0 +1,221 @@ +#pragma once + +#include "catalog/catalog_entry/table_catalog_entry.h" +#include "common/enums/rel_direction.h" +#include "common/mask.h" +#include "storage/predicate/column_predicate.h" +#include "storage/table/column.h" +#include "storage/table/column_chunk_data.h" +#include "storage/table/node_group.h" + +namespace lbug { +namespace evaluator { +class ExpressionEvaluator; +} // namespace evaluator +namespace storage { +class MemoryManager; +class Table; + +enum class TableScanSource : uint8_t { COMMITTED = 0, UNCOMMITTED = 1, NONE = UINT8_MAX }; + +struct LBUG_API TableScanState { + 
Table* table; + std::unique_ptr rowIdxVector; + // Node/Rel ID vector. We assume all output vectors are within the same DataChunk as this one. + common::ValueVector* nodeIDVector; + std::vector outputVectors; + std::shared_ptr outState; + std::vector columnIDs; + common::SemiMask* semiMask; + + // Only used when scan from persistent data. + std::vector columns; + + TableScanSource source; + common::node_group_idx_t nodeGroupIdx; + NodeGroup* nodeGroup = nullptr; + std::unique_ptr nodeGroupScanState; + + std::vector columnPredicateSets; + + TableScanState(common::ValueVector* nodeIDVector, + std::vector outputVectors, + std::shared_ptr outChunkState) + : table{nullptr}, nodeIDVector(nodeIDVector), outputVectors{std::move(outputVectors)}, + outState{std::move(outChunkState)}, semiMask{nullptr}, source{TableScanSource::NONE}, + nodeGroupIdx{common::INVALID_NODE_GROUP_IDX} { + rowIdxVector = std::make_unique(common::LogicalType::INT64()); + rowIdxVector->state = outState; + } + + TableScanState(std::vector columnIDs, std::vector columns) + : table{nullptr}, nodeIDVector(nullptr), outState{nullptr}, columnIDs{std::move(columnIDs)}, + semiMask{nullptr}, columns{std::move(columns)}, source{TableScanSource::NONE}, + nodeGroupIdx{common::INVALID_NODE_GROUP_IDX} {} + + virtual ~TableScanState(); + DELETE_COPY_DEFAULT_MOVE(TableScanState); + + virtual void setToTable(const transaction::Transaction* transaction, Table* table_, + std::vector columnIDs_, + std::vector columnPredicateSets_, + common::RelDataDirection direction = common::RelDataDirection::INVALID); + + // Note that `resetCachedBoundNodeSelVec` is only applicable to RelTable for now. + virtual void initState(transaction::Transaction* transaction, NodeGroup* nodeGroup, + bool /*resetCachedBoundNodeSelVev*/ = true) { + KU_ASSERT(nodeGroup); + this->nodeGroup = nodeGroup; + this->nodeGroup->initializeScanState(transaction, *this); + } + + virtual bool scanNext(transaction::Transaction*) { KU_UNREACHABLE; } + + void resetOutVectors(); + + template + TARGET& cast() { + return common::ku_dynamic_cast(*this); + } + template + const TARGET& cast() const { + return common::ku_dynamic_cast(*this); + } +}; + +struct LBUG_API TableInsertState { + std::vector propertyVectors; + // TODO(Guodong): Remove this when we have a better way to skip WAL logging for FTS. + bool logToWAL; + + explicit TableInsertState(std::vector propertyVectors); + virtual ~TableInsertState(); + + template + const T& constCast() const { + return common::ku_dynamic_cast(*this); + } + template + T& cast() { + return common::ku_dynamic_cast(*this); + } +}; + +struct LBUG_API TableUpdateState { + common::column_id_t columnID; + common::ValueVector& propertyVector; + // TODO(Guodong): Remove this when we have a better way to skip WAL logging for FTS. 
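
// Illustrative sketch, not from the lbug sources: the cast<TARGET>() helpers above let
// callers down-cast a base scan/insert state to the concrete node/rel variant. One common
// way to implement such a checked down-cast (an assumption about what ku_dynamic_cast
// does, shown with hypothetical BaseState/NodeState names): verify with dynamic_cast in
// debug builds, then use the cheap static_cast.

#include <cassert>

struct BaseState {
    virtual ~BaseState() = default;

    template<typename TARGET>
    TARGET& cast() {
        assert(dynamic_cast<TARGET*>(this) != nullptr);  // type check in debug builds only
        return static_cast<TARGET&>(*this);
    }
};

struct NodeState final : BaseState {
    int nodeGroupIdx = 0;
};

// Usage: NodeState state; BaseState& base = state; base.cast<NodeState>().nodeGroupIdx = 3;
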
+ bool logToWAL; + + TableUpdateState(common::column_id_t columnID, common::ValueVector& propertyVector); + virtual ~TableUpdateState(); + + template + const T& constCast() const { + return common::ku_dynamic_cast(*this); + } + template + T& cast() { + return common::ku_dynamic_cast(*this); + } +}; + +struct LBUG_API TableDeleteState { + bool logToWAL; + + TableDeleteState(); + + virtual ~TableDeleteState(); + + template + const T& constCast() const { + return common::ku_dynamic_cast(*this); + } + template + T& cast() { + return common::ku_dynamic_cast(*this); + } +}; + +struct TableAddColumnState final { + const binder::PropertyDefinition& propertyDefinition; + evaluator::ExpressionEvaluator& defaultEvaluator; + + TableAddColumnState(const binder::PropertyDefinition& propertyDefinition, + evaluator::ExpressionEvaluator& defaultEvaluator) + : propertyDefinition{propertyDefinition}, defaultEvaluator{defaultEvaluator} {} + ~TableAddColumnState() = default; +}; + +class LocalTable; +class StorageManager; +class LBUG_API Table { +public: + Table(const catalog::TableCatalogEntry* tableEntry, const StorageManager* storageManager, + MemoryManager* memoryManager); + virtual ~Table(); + + common::TableType getTableType() const { return tableType; } + common::table_id_t getTableID() const { return tableID; } + std::string getTableName() const { return tableName; } + + // Note that `resetCachedBoundNodeIDs` is only applicable to RelTable for now. + virtual void initScanState(transaction::Transaction* transaction, TableScanState& readState, + bool resetCachedBoundNodeSelVec = true) const = 0; + bool scan(transaction::Transaction* transaction, TableScanState& scanState); + + virtual void initInsertState(main::ClientContext* context, TableInsertState& insertState) = 0; + virtual void insert(transaction::Transaction* transaction, TableInsertState& insertState) = 0; + virtual void update(transaction::Transaction* transaction, TableUpdateState& updateState) = 0; + virtual bool delete_(transaction::Transaction* transaction, TableDeleteState& deleteState) = 0; + + virtual void addColumn(transaction::Transaction* transaction, + TableAddColumnState& addColumnState, PageAllocator& pageAllocator) = 0; + void dropColumn() { setHasChanges(); } + + virtual void commit(main::ClientContext* context, catalog::TableCatalogEntry* tableEntry, + LocalTable* localTable) = 0; + virtual bool checkpoint(main::ClientContext* context, catalog::TableCatalogEntry* tableEntry, + PageAllocator& pageAllocator) = 0; + virtual void rollbackCheckpoint() = 0; + virtual void reclaimStorage(PageAllocator& pageAllocator) const = 0; + + virtual common::row_idx_t getNumTotalRows(const transaction::Transaction* transaction) = 0; + + void setHasChanges() { hasChanges = true; } + + template + TARGET& cast() { + return common::ku_dynamic_cast(*this); + } + template + const TARGET& cast() const { + return common::ku_dynamic_cast(*this); + } + template + TARGET* ptrCast() { + return common::ku_dynamic_cast(this); + } + + static common::DataChunk constructDataChunk(MemoryManager* mm, + std::vector types); + + virtual void serialize(common::Serializer& serializer) const = 0; + virtual void deserialize(main::ClientContext* context, StorageManager* storageManager, + common::Deserializer& deSer) = 0; + +protected: + virtual bool scanInternal(transaction::Transaction* transaction, TableScanState& scanState) = 0; + +protected: + common::TableType tableType; + common::table_id_t tableID; + std::string tableName; + bool enableCompression; + MemoryManager* 
memoryManager; + ShadowFile* shadowFile; + std::atomic hasChanges; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/update_info.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/update_info.h new file mode 100644 index 0000000000..b6caf7b3fb --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/update_info.h @@ -0,0 +1,132 @@ +#pragma once + +#include +#include + +#include "column_chunk_data.h" +#include "common/types/types.h" + +namespace lbug { +namespace common { +class ValueVector; +} // namespace common + +namespace transaction { +class Transaction; +} // namespace transaction + +namespace storage { +class MemoryManager; + +class ColumnChunkData; +struct VectorUpdateInfo { + common::transaction_t version; + std::array rowsInVector; + common::sel_t numRowsUpdated; + // Older versions. + std::unique_ptr prev; + // Newer versions. + VectorUpdateInfo* next; + + std::unique_ptr data; + + VectorUpdateInfo() + : version{common::INVALID_TRANSACTION}, rowsInVector{}, numRowsUpdated(0), prev(nullptr), + next{nullptr}, data{nullptr} {} + VectorUpdateInfo(MemoryManager& memoryManager, const common::transaction_t transactionID, + common::LogicalType dataType) + : version{transactionID}, rowsInVector{}, numRowsUpdated{0}, prev{nullptr}, next{nullptr} { + data = ColumnChunkFactory::createColumnChunkData(memoryManager, std::move(dataType), false, + common::DEFAULT_VECTOR_CAPACITY, ResidencyState::IN_MEMORY); + } + + std::unique_ptr movePrev() { return std::move(prev); } + void setPrev(std::unique_ptr prev_) { this->prev = std::move(prev_); } + VectorUpdateInfo* getPrev() const { return prev.get(); } + void setNext(VectorUpdateInfo* next_) { this->next = next_; } + VectorUpdateInfo* getNext() const { return next; } +}; + +struct UpdateNode { + mutable std::shared_mutex mtx; + std::unique_ptr info; + + UpdateNode() : info{nullptr} {} + UpdateNode(UpdateNode&& other) noexcept : info{std::move(other.info)} {} + UpdateNode(const UpdateNode& other) = delete; + + bool isEmpty() const { + std::shared_lock lock{mtx}; + return info != nullptr; + } + void clear() { + std::unique_lock lock{mtx}; + info = nullptr; + } +}; + +class UpdateInfo { +public: + using iterate_read_from_row_func_t = + std::function; + + UpdateInfo() {} + + VectorUpdateInfo& update(MemoryManager& memoryManager, + const transaction::Transaction* transaction, common::idx_t vectorIdx, + common::sel_t rowIdxInVector, const common::ValueVector& values); + + void clearVectorInfo(common::idx_t vectorIdx) { + std::unique_lock lock{mtx}; + updates[vectorIdx]->clear(); + } + + common::idx_t getNumVectors() const { + std::shared_lock lock{mtx}; + return updates.size(); + } + + void scan(const transaction::Transaction* transaction, common::ValueVector& output, + common::offset_t offsetInChunk, common::length_t length) const; + void lookup(const transaction::Transaction* transaction, common::offset_t rowInChunk, + common::ValueVector& output, common::sel_t posInOutputVector) const; + + void scanCommitted(const transaction::Transaction* transaction, ColumnChunkData& output, + common::offset_t startOffsetInOutput, common::row_idx_t startRowScanned, + common::row_idx_t numRows) const; + + void iterateVectorInfo(const transaction::Transaction* transaction, common::idx_t idx, + const std::function& func) const; + + void commit(common::idx_t vectorIdx, VectorUpdateInfo* info, common::transaction_t commitTS); + void rollback(common::idx_t vectorIdx, 
common::transaction_t version); + + common::row_idx_t getNumUpdatedRows(const transaction::Transaction* transaction) const; + + bool hasUpdates(const transaction::Transaction* transaction, common::row_idx_t startRow, + common::length_t numRows) const; + + bool isSet() const { + std::shared_lock lock{mtx}; + return !updates.empty(); + } + void reset() { + std::unique_lock lock{mtx}; + updates.clear(); + } + + void iterateScan(const transaction::Transaction* transaction, uint64_t startOffsetToScan, + uint64_t numRowsToScan, uint64_t startPosInOutput, + const iterate_read_from_row_func_t& readFromRowFunc) const; + +private: + UpdateNode& getUpdateNode(common::idx_t vectorIdx); + UpdateNode& getOrCreateUpdateNode(common::idx_t vectorIdx); + +private: + mutable std::shared_mutex mtx; + std::vector> updates; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/version_info.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/version_info.h new file mode 100644 index 0000000000..b13738f0c2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/version_info.h @@ -0,0 +1,65 @@ +#pragma once + +#include "common/data_chunk/sel_vector.h" +#include "common/types/types.h" + +namespace lbug { +namespace transaction { +class Transaction; +} // namespace transaction + +namespace storage { + +class ChunkedNodeGroup; +struct VectorVersionInfo; + +class LBUG_API VersionInfo { +public: + VersionInfo(); + ~VersionInfo(); + DELETE_BOTH_COPY(VersionInfo); + + void append(common::transaction_t transactionID, common::row_idx_t startRow, + common::row_idx_t numRows); + bool delete_(common::transaction_t transactionID, common::row_idx_t rowIdx); + + bool isSelected(common::transaction_t startTS, common::transaction_t transactionID, + common::row_idx_t rowIdx) const; + void getSelVectorToScan(common::transaction_t startTS, common::transaction_t transactionID, + common::SelectionVector& selVector, common::row_idx_t startRow, + common::row_idx_t numRows) const; + + void clearVectorInfo(common::idx_t vectorIdx); + + bool hasDeletions() const; + common::row_idx_t getNumDeletions(const transaction::Transaction* transaction, + common::row_idx_t startRow, common::length_t numRows) const; + bool hasInsertions() const; + bool isDeleted(const transaction::Transaction* transaction, common::row_idx_t rowInChunk) const; + bool isInserted(const transaction::Transaction* transaction, + common::row_idx_t rowInChunk) const; + + bool hasDeletions(const transaction::Transaction* transaction) const; + + common::idx_t getNumVectors() const { return vectorsInfo.size(); } + + void commitInsert(common::row_idx_t startRow, common::row_idx_t numRows, + common::transaction_t commitTS); + void rollbackInsert(common::row_idx_t startRow, common::row_idx_t numRows); + void commitDelete(common::row_idx_t startRow, common::row_idx_t numRows, + common::transaction_t commitTS); + void rollbackDelete(common::row_idx_t startRow, common::row_idx_t numRows); + + void serialize(common::Serializer& serializer) const; + static std::unique_ptr deserialize(common::Deserializer& deSer); + +private: + // Return nullptr when vectorIdx is out of range or when the vector is not created. 
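
// Illustrative sketch, not from the lbug sources: isSelected/isInserted/isDeleted above
// implement per-row MVCC visibility. A common formulation, assuming transaction IDs are
// drawn from a range above every commit timestamp so an ID can double as an "uncommitted"
// marker (RowVersion and these field names are hypothetical):

#include <cstdint>

constexpr uint64_t NO_VERSION = UINT64_MAX;

struct RowVersion {
    uint64_t insertVersion = NO_VERSION;  // commit TS, or the writer's transaction ID while uncommitted
    uint64_t deleteVersion = NO_VERSION;  // same encoding for a pending or committed delete
};

// A version is visible if this reader wrote it, or it committed at/before the reader's start TS.
bool versionVisible(uint64_t version, uint64_t startTS, uint64_t transactionID) {
    return version == transactionID || version <= startTS;
}

bool isSelected(const RowVersion& row, uint64_t startTS, uint64_t transactionID) {
    const bool inserted = row.insertVersion == NO_VERSION  // row predates version tracking
        || versionVisible(row.insertVersion, startTS, transactionID);
    const bool deleted = row.deleteVersion != NO_VERSION
        && versionVisible(row.deleteVersion, startTS, transactionID);
    return inserted && !deleted;
}
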
+ VectorVersionInfo* getVectorVersionInfo(common::idx_t vectorIdx) const; + VectorVersionInfo& getOrCreateVersionInfo(common::idx_t vectorIdx); + + std::vector> vectorsInfo; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/version_record_handler.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/version_record_handler.h new file mode 100644 index 0000000000..aca176e222 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/table/version_record_handler.h @@ -0,0 +1,29 @@ +#pragma once + +#include "common/types/types.h" +#include "transaction/transaction.h" + +namespace lbug { + +namespace storage { + +class ChunkedNodeGroup; + +using version_record_handler_op_t = void ( + ChunkedNodeGroup::*)(common::row_idx_t, common::row_idx_t, common::transaction_t); + +// Note: these handlers are not safe to use in multi-threaded contexts without external locking +class VersionRecordHandler { +public: + virtual ~VersionRecordHandler() = default; + + virtual void applyFuncToChunkedGroups(version_record_handler_op_t func, + common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, + common::row_idx_t numRows, common::transaction_t commitTS) const = 0; + + virtual void rollbackInsert(main::ClientContext* context, common::node_group_idx_t nodeGroupIdx, + common::row_idx_t startRow, common::row_idx_t numRows) const; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/undo_buffer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/undo_buffer.h new file mode 100644 index 0000000000..6c806a3634 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/undo_buffer.h @@ -0,0 +1,129 @@ +#pragma once + +#include + +#include "buffer_manager/memory_manager.h" +#include "common/types/types.h" + +namespace lbug { +namespace catalog { +class CatalogEntry; +class CatalogSet; +class SequenceCatalogEntry; +struct SequenceRollbackData; +} // namespace catalog +namespace transaction { +class Transaction; +} + +namespace main { +class ClientContext; +} +namespace storage { +class VersionRecordHandler; + +class UndoMemoryBuffer { +public: + static constexpr uint64_t UNDO_MEMORY_BUFFER_INIT_CAPACITY = common::LBUG_PAGE_SIZE; + + explicit UndoMemoryBuffer(std::unique_ptr buffer, uint64_t capacity) + : buffer{std::move(buffer)}, capacity{capacity} { + currentPosition = 0; + } + + uint8_t* getDataUnsafe() const { return buffer->getData(); } + uint8_t const* getData() const { return buffer->getData(); } + uint64_t getSize() const { return capacity; } + uint64_t getCurrentPosition() const { return currentPosition; } + void moveCurrentPosition(uint64_t offset) { + KU_ASSERT(currentPosition + offset <= capacity); + currentPosition += offset; + } + bool canFit(uint64_t size_) const { return currentPosition + size_ <= this->capacity; } + +private: + std::unique_ptr buffer; + uint64_t capacity; + uint64_t currentPosition; +}; + +class UndoBuffer; +class UndoBufferIterator { +public: + explicit UndoBufferIterator(const UndoBuffer& undoBuffer) : undoBuffer{undoBuffer} {} + + template + void iterate(F&& callback); + template + void reverseIterate(F&& callback); + +private: + const UndoBuffer& undoBuffer; +}; + +class UpdateInfo; +class VersionInfo; +struct VectorUpdateInfo; +class WAL; +// This class is not thread safe, as it is supposed to be accessed by a single thread. 
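
// Illustrative sketch, not from the lbug sources: UndoMemoryBuffer above is a bump-allocated
// arena -- records are appended at currentPosition, and a new buffer is started once
// canFit() fails. A self-contained analogue of what createUndoRecord(size) plausibly does,
// based only on the members shown here (BumpBuffer/UndoLog are hypothetical names):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <vector>

struct BumpBuffer {
    std::vector<uint8_t> data;
    std::size_t currentPosition = 0;

    explicit BumpBuffer(std::size_t capacity) : data(capacity) {}
    bool canFit(std::size_t size) const { return currentPosition + size <= data.size(); }
};

class UndoLog {
public:
    uint8_t* createRecord(std::size_t size) {
        if (buffers.empty() || !buffers.back()->canFit(size)) {
            // Grow: a new buffer is at least as large as the requested record.
            const std::size_t capacity = std::max<std::size_t>(INIT_CAPACITY, size);
            buffers.push_back(std::make_unique<BumpBuffer>(capacity));
        }
        BumpBuffer& buffer = *buffers.back();
        uint8_t* record = buffer.data.data() + buffer.currentPosition;
        buffer.currentPosition += size;  // the "moveCurrentPosition" step
        return record;
    }

private:
    static constexpr std::size_t INIT_CAPACITY = 4096;
    std::vector<std::unique_ptr<BumpBuffer>> buffers;
};
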
+class UndoBuffer { + friend class UndoBufferIterator; + +public: + enum class UndoRecordType : uint16_t { + CATALOG_ENTRY = 0, + SEQUENCE_ENTRY = 1, + UPDATE_INFO = 6, + INSERT_INFO = 7, + DELETE_INFO = 8, + }; + + explicit UndoBuffer(MemoryManager* mm) : mm{mm} {} + + void createCatalogEntry(catalog::CatalogSet& catalogSet, catalog::CatalogEntry& catalogEntry); + void createSequenceChange(catalog::SequenceCatalogEntry& sequenceEntry, + const catalog::SequenceRollbackData& data); + void createInsertInfo(common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, + common::row_idx_t numRows, const VersionRecordHandler* versionRecordHandler); + void createDeleteInfo(common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, + common::row_idx_t numRows, const VersionRecordHandler* versionRecordHandler); + void createVectorUpdateInfo(UpdateInfo* updateInfo, common::idx_t vectorIdx, + VectorUpdateInfo* vectorUpdateInfo, common::transaction_t version); + + void commit(common::transaction_t commitTS) const; + void rollback(main::ClientContext* context) const; + +private: + uint8_t* createUndoRecord(uint64_t size); + + void createVersionInfo(UndoRecordType recordType, common::row_idx_t startRow, + common::row_idx_t numRows, const VersionRecordHandler* versionRecordHandler, + common::node_group_idx_t nodeGroupIdx = 0); + + static void commitRecord(UndoRecordType recordType, const uint8_t* record, + common::transaction_t commitTS); + static void rollbackRecord(main::ClientContext* context, UndoRecordType recordType, + const uint8_t* record); + + static void commitCatalogEntryRecord(const uint8_t* record, common::transaction_t commitTS); + static void rollbackCatalogEntryRecord(const uint8_t* record); + + static void commitSequenceEntry(uint8_t const* entry, common::transaction_t commitTS); + static void rollbackSequenceEntry(uint8_t const* entry); + + static void commitVersionInfo(UndoRecordType recordType, const uint8_t* record, + common::transaction_t commitTS); + static void rollbackVersionInfo(main::ClientContext* context, UndoRecordType recordType, + const uint8_t* record); + + static void commitVectorUpdateInfo(const uint8_t* record, common::transaction_t commitTS); + static void rollbackVectorUpdateInfo(const uint8_t* record); + +private: + std::mutex mtx; + MemoryManager* mm; + std::vector memoryBuffers; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/wal/checksum_reader.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/wal/checksum_reader.h new file mode 100644 index 0000000000..dcf6a165a9 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/wal/checksum_reader.h @@ -0,0 +1,37 @@ +#pragma once + +#include +#include + +#include "common/serializer/deserializer.h" +#include "common/serializer/reader.h" +#include "storage/buffer_manager/memory_manager.h" + +namespace lbug { +namespace storage { +class ChecksumReader : public common::Reader { +public: + explicit ChecksumReader(common::FileInfo& fileInfo, MemoryManager& memoryManager, + std::string_view checksumMismatchMessage); + + void read(uint8_t* data, uint64_t size) override; + bool finished() override; + + void onObjectBegin() override; + // Reads the stored checksum + // Also computes + verifies the checksum for the entry that has just been read against the + // stored value + void onObjectEnd() override; + + uint64_t getReadOffset() const; + +private: + common::Deserializer deserializer; + + std::optional currentEntrySize; + 
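
// Illustrative sketch, not from the lbug sources: onObjectBegin/onObjectEnd above frame each
// WAL entry so the stored checksum can be verified against the bytes just read. A
// self-contained analogue of that [size][checksum][payload] framing; FNV-1a is used only as
// a stand-in hash, since the actual checksum algorithm is not shown in this diff:

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <vector>

uint64_t fnv1a(const uint8_t* data, std::size_t size) {
    uint64_t hash = 1469598103934665603ull;
    for (std::size_t i = 0; i < size; i++) {
        hash = (hash ^ data[i]) * 1099511628211ull;
    }
    return hash;
}

// Writer side: emit [size][checksum][payload] for each entry.
void writeEntry(std::vector<uint8_t>& out, const std::vector<uint8_t>& payload) {
    const uint64_t size = payload.size();
    const uint64_t checksum = fnv1a(payload.data(), payload.size());
    out.insert(out.end(), reinterpret_cast<const uint8_t*>(&size),
        reinterpret_cast<const uint8_t*>(&size) + sizeof(size));
    out.insert(out.end(), reinterpret_cast<const uint8_t*>(&checksum),
        reinterpret_cast<const uint8_t*>(&checksum) + sizeof(checksum));
    out.insert(out.end(), payload.begin(), payload.end());
}

// Reader side: recompute the checksum over the payload and compare with the stored value.
std::vector<uint8_t> readEntry(const uint8_t* in, std::size_t& offset) {
    uint64_t size = 0;
    uint64_t checksum = 0;
    std::memcpy(&size, in + offset, sizeof(size));
    std::memcpy(&checksum, in + offset + sizeof(size), sizeof(checksum));
    std::vector<uint8_t> payload(in + offset + 16, in + offset + 16 + size);
    if (fnv1a(payload.data(), payload.size()) != checksum) {
        throw std::runtime_error("checksum mismatch");
    }
    offset += 16 + size;
    return payload;
}
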
std::unique_ptr entryBuffer; + + std::string_view checksumMismatchMessage; +}; +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/wal/checksum_writer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/wal/checksum_writer.h new file mode 100644 index 0000000000..6e13bbcfbe --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/wal/checksum_writer.h @@ -0,0 +1,40 @@ +#pragma once + +#include +#include + +#include "common/serializer/serializer.h" +#include "common/serializer/writer.h" +#include "storage/buffer_manager/memory_manager.h" + +namespace lbug { +namespace storage { +class ChecksumWriter; + +// A wrapper on top of another Writer that accumulates serialized data +// Then flushes that data (along with a computed checksum) when the data has completed serializing +class ChecksumWriter : public common::Writer { +public: + explicit ChecksumWriter(std::shared_ptr outputWriter, + MemoryManager& memoryManager); + + void write(const uint8_t* data, uint64_t size) override; + uint64_t getSize() const override; + + void clear() override; + void sync() override; + + void flush() override; + + void onObjectBegin() override; + // Calculate checksum + write the checksum + serialized contents to underlying writer + void onObjectEnd() override; + +private: + common::Serializer outputSerializer; + std::optional currentEntrySize; + std::unique_ptr entryBuffer; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/wal/local_wal.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/wal/local_wal.h new file mode 100644 index 0000000000..8129a584f8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/wal/local_wal.h @@ -0,0 +1,62 @@ +#pragma once + +#include "storage/wal/wal_record.h" + +namespace lbug { +namespace binder { +struct BoundAlterInfo; +} // namespace binder +namespace common { +class InMemFileWriter; +class ValueVector; +} // namespace common +namespace catalog { +class CatalogEntry; +} // namespace catalog + +namespace storage { +class WAL; +class LocalWAL { + friend class WAL; + +public: + explicit LocalWAL(MemoryManager& mm, bool enableChecksums); + + void logCreateCatalogEntryRecord(catalog::CatalogEntry* catalogEntry, bool isInternal); + void logDropCatalogEntryRecord(uint64_t tableID, catalog::CatalogEntryType type); + void logAlterCatalogEntryRecord(const binder::BoundAlterInfo* alterInfo); + void logUpdateSequenceRecord(common::sequence_id_t sequenceID, uint64_t kCount); + + void logTableInsertion(common::table_id_t tableID, common::TableType tableType, + common::row_idx_t numRows, const std::vector& vectors); + void logNodeDeletion(common::table_id_t tableID, common::offset_t nodeOffset, + common::ValueVector* pkVector); + void logNodeUpdate(common::table_id_t tableID, common::column_id_t columnID, + common::offset_t nodeOffset, common::ValueVector* propertyVector); + void logRelDelete(common::table_id_t tableID, common::ValueVector* srcNodeVector, + common::ValueVector* dstNodeVector, common::ValueVector* relIDVector); + void logRelDetachDelete(common::table_id_t tableID, common::RelDataDirection direction, + common::ValueVector* srcNodeVector); + void logRelUpdate(common::table_id_t tableID, common::column_id_t columnID, + common::ValueVector* srcNodeVector, common::ValueVector* dstNodeVector, + common::ValueVector* relIDVector, common::ValueVector* propertyVector); + + void logLoadExtension(std::string path); + + void 
logBeginTransaction(); + void logCommit(); + + void clear(); + uint64_t getSize(); + +private: + void addNewWALRecord(const WALRecord& walRecord); + +private: + std::mutex mtx; + std::shared_ptr inMemWriter; + common::Serializer serializer; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/wal/wal.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/wal/wal.h new file mode 100644 index 0000000000..0389d5a907 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/wal/wal.h @@ -0,0 +1,53 @@ +#pragma once + +#include "storage/wal/wal_record.h" + +namespace lbug { +namespace common { +class BufferedFileWriter; +class VirtualFileSystem; +} // namespace common + +namespace storage { +class LocalWAL; +class WAL { +public: + WAL(const std::string& dbPath, bool readOnly, bool enableChecksums, + common::VirtualFileSystem* vfs); + ~WAL(); + + void logCommittedWAL(LocalWAL& localWAL, main::ClientContext* context); + void logAndFlushCheckpoint(main::ClientContext* context); + + // Clear any buffer in the WAL writer. Also truncate the WAL file to 0 bytes. + void clear(); + // Reset the WAL writer to nullptr, and remove the WAL file if it exists. + void reset(); + + uint64_t getFileSize(); + + static WAL* Get(const main::ClientContext& context); + +private: + void initWriter(main::ClientContext* context); + void addNewWALRecordNoLock(const WALRecord& walRecord); + void flushAndSyncNoLock(); + void writeHeader(main::ClientContext& context); + +private: + std::mutex mtx; + std::string walPath; + bool inMemory; + [[maybe_unused]] bool readOnly; + common::VirtualFileSystem* vfs; + std::unique_ptr fileInfo; + + // Since most writes to the shared WAL will be flushing local WAL (which has its own checksums), + // these writes can go through the normal writer. We do still need a checksum writer though for + // writing COMMIT/CHECKPOINT records + std::unique_ptr serializer; + bool enableChecksums; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/wal/wal_record.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/wal/wal_record.h new file mode 100644 index 0000000000..c8c6c40e6e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/wal/wal_record.h @@ -0,0 +1,341 @@ +#pragma once + +#include + +#include "binder/ddl/bound_alter_info.h" +#include "catalog/catalog_entry/catalog_entry.h" +#include "catalog/catalog_entry/sequence_catalog_entry.h" +#include "common/enums/rel_direction.h" +#include "common/enums/table_type.h" +#include "common/types/uuid.h" +#include "common/vector/value_vector.h" + +namespace lbug { +namespace common { +class Serializer; +class Deserializer; +} // namespace common + +namespace storage { + +enum class WALRecordType : uint8_t { + INVALID_RECORD = 0, // This is not used for any record. 0 is reserved to detect cases where we + // accidentally read from an empty buffer. 
+ BEGIN_TRANSACTION_RECORD = 1, + COMMIT_RECORD = 2, + + COPY_TABLE_RECORD = 13, + CREATE_CATALOG_ENTRY_RECORD = 14, + DROP_CATALOG_ENTRY_RECORD = 16, + ALTER_TABLE_ENTRY_RECORD = 17, + UPDATE_SEQUENCE_RECORD = 18, + TABLE_INSERTION_RECORD = 30, + NODE_DELETION_RECORD = 31, + NODE_UPDATE_RECORD = 32, + REL_DELETION_RECORD = 33, + REL_DETACH_DELETE_RECORD = 34, + REL_UPDATE_RECORD = 35, + + LOAD_EXTENSION_RECORD = 100, + + CHECKPOINT_RECORD = 254, +}; + +struct WALHeader { + common::ku_uuid_t databaseID; + bool enableChecksums; +}; + +struct WALRecord { + WALRecordType type = WALRecordType::INVALID_RECORD; + + WALRecord() = default; + explicit WALRecord(WALRecordType type) : type{type} {} + virtual ~WALRecord() = default; + DELETE_COPY_DEFAULT_MOVE(WALRecord); + + virtual void serialize(common::Serializer& serializer) const; + static std::unique_ptr deserialize(common::Deserializer& deserializer, + const main::ClientContext& clientContext); + + template + const TARGET& constCast() const { + return common::ku_dynamic_cast(*this); + } + template + TARGET& cast() { + return common::ku_dynamic_cast(*this); + } +}; + +struct BeginTransactionRecord final : WALRecord { + BeginTransactionRecord() : WALRecord{WALRecordType::BEGIN_TRANSACTION_RECORD} {} + + void serialize(common::Serializer& serializer) const override; + static std::unique_ptr deserialize(common::Deserializer& deserializer); +}; + +struct CommitRecord final : WALRecord { + CommitRecord() : WALRecord{WALRecordType::COMMIT_RECORD} {} + + void serialize(common::Serializer& serializer) const override; + static std::unique_ptr deserialize(common::Deserializer& deserializer); +}; + +struct CheckpointRecord final : WALRecord { + CheckpointRecord() : WALRecord{WALRecordType::CHECKPOINT_RECORD} {} + + void serialize(common::Serializer& serializer) const override; + static std::unique_ptr deserialize(common::Deserializer& deserializer); +}; + +struct CreateCatalogEntryRecord final : WALRecord { + catalog::CatalogEntry* catalogEntry; + std::unique_ptr ownedCatalogEntry; + bool isInternal = false; + + CreateCatalogEntryRecord() + : WALRecord{WALRecordType::CREATE_CATALOG_ENTRY_RECORD}, catalogEntry{nullptr} {} + CreateCatalogEntryRecord(catalog::CatalogEntry* catalogEntry, bool isInternal) + : WALRecord{WALRecordType::CREATE_CATALOG_ENTRY_RECORD}, catalogEntry{catalogEntry}, + isInternal{isInternal} {} + + void serialize(common::Serializer& serializer) const override; + static std::unique_ptr deserialize( + common::Deserializer& deserializer); +}; + +struct CopyTableRecord final : WALRecord { + common::table_id_t tableID; + + CopyTableRecord() + : WALRecord{WALRecordType::COPY_TABLE_RECORD}, tableID{common::INVALID_TABLE_ID} {} + explicit CopyTableRecord(common::table_id_t tableID) + : WALRecord{WALRecordType::COPY_TABLE_RECORD}, tableID{tableID} {} + + void serialize(common::Serializer& serializer) const override; + static std::unique_ptr deserialize(common::Deserializer& deserializer); +}; + +struct DropCatalogEntryRecord final : WALRecord { + common::oid_t entryID; + catalog::CatalogEntryType entryType; + + DropCatalogEntryRecord() + : WALRecord{WALRecordType::DROP_CATALOG_ENTRY_RECORD}, entryID{common::INVALID_OID}, + entryType{} {} + DropCatalogEntryRecord(common::table_id_t entryID, catalog::CatalogEntryType entryType) + : WALRecord{WALRecordType::DROP_CATALOG_ENTRY_RECORD}, entryID{entryID}, + entryType{entryType} {} + + void serialize(common::Serializer& serializer) const override; + static std::unique_ptr 
deserialize(common::Deserializer& deserializer); +}; + +struct AlterTableEntryRecord final : WALRecord { + const binder::BoundAlterInfo* alterInfo; + std::unique_ptr ownedAlterInfo; + + AlterTableEntryRecord() + : WALRecord{WALRecordType::ALTER_TABLE_ENTRY_RECORD}, alterInfo{nullptr} {} + explicit AlterTableEntryRecord(const binder::BoundAlterInfo* alterInfo) + : WALRecord{WALRecordType::ALTER_TABLE_ENTRY_RECORD}, alterInfo{alterInfo} {} + + void serialize(common::Serializer& serializer) const override; + static std::unique_ptr deserialize(common::Deserializer& deserializer); +}; + +struct UpdateSequenceRecord final : WALRecord { + common::sequence_id_t sequenceID; + uint64_t kCount; + + UpdateSequenceRecord() + : WALRecord{WALRecordType::UPDATE_SEQUENCE_RECORD}, sequenceID{0}, kCount{0} {} + UpdateSequenceRecord(common::sequence_id_t sequenceID, uint64_t kCount) + : WALRecord{WALRecordType::UPDATE_SEQUENCE_RECORD}, sequenceID{sequenceID}, kCount{kCount} { + } + + void serialize(common::Serializer& serializer) const override; + static std::unique_ptr deserialize(common::Deserializer& deserializer); +}; + +struct TableInsertionRecord final : WALRecord { + common::table_id_t tableID; + common::TableType tableType; + common::row_idx_t numRows; + std::vector vectors; + std::vector> ownedVectors; + + TableInsertionRecord() + : WALRecord{WALRecordType::TABLE_INSERTION_RECORD}, tableID{common::INVALID_TABLE_ID}, + tableType{common::TableType::UNKNOWN}, numRows{0} {} + TableInsertionRecord(common::table_id_t tableID, common::TableType tableType, + common::row_idx_t numRows, const std::vector& vectors) + : WALRecord{WALRecordType::TABLE_INSERTION_RECORD}, tableID{tableID}, tableType{tableType}, + numRows{numRows}, vectors{vectors} {} + TableInsertionRecord(common::table_id_t tableID, common::TableType tableType, + common::row_idx_t numRows, std::vector> vectors) + : WALRecord{WALRecordType::TABLE_INSERTION_RECORD}, tableID{tableID}, tableType{tableType}, + numRows{numRows}, ownedVectors{std::move(vectors)} {} + + void serialize(common::Serializer& serializer) const override; + static std::unique_ptr deserialize(common::Deserializer& deserializer, + const main::ClientContext& clientContext); +}; + +struct NodeDeletionRecord final : WALRecord { + common::table_id_t tableID; + common::offset_t nodeOffset; + common::ValueVector* pkVector; + std::unique_ptr ownedPKVector; + + NodeDeletionRecord() + : WALRecord{WALRecordType::NODE_DELETION_RECORD}, tableID{common::INVALID_TABLE_ID}, + nodeOffset{common::INVALID_OFFSET}, pkVector{nullptr} {} + NodeDeletionRecord(common::table_id_t tableID, common::offset_t nodeOffset, + common::ValueVector* pkVector) + : WALRecord{WALRecordType::NODE_DELETION_RECORD}, tableID{tableID}, nodeOffset{nodeOffset}, + pkVector{pkVector} {} + NodeDeletionRecord(common::table_id_t tableID, common::offset_t nodeOffset, + std::unique_ptr pkVector) + : WALRecord{WALRecordType::NODE_DELETION_RECORD}, tableID{tableID}, nodeOffset{nodeOffset}, + pkVector{nullptr}, ownedPKVector{std::move(pkVector)} {} + + void serialize(common::Serializer& serializer) const override; + static std::unique_ptr deserialize(common::Deserializer& deserializer, + const main::ClientContext& clientContext); +}; + +struct NodeUpdateRecord final : WALRecord { + common::table_id_t tableID; + common::column_id_t columnID; + common::offset_t nodeOffset; + common::ValueVector* propertyVector; + std::unique_ptr ownedPropertyVector; + + NodeUpdateRecord() + : WALRecord{WALRecordType::NODE_UPDATE_RECORD}, 
tableID{common::INVALID_TABLE_ID}, + columnID{common::INVALID_COLUMN_ID}, nodeOffset{common::INVALID_OFFSET}, + propertyVector{nullptr} {} + NodeUpdateRecord(common::table_id_t tableID, common::column_id_t columnID, + common::offset_t nodeOffset, common::ValueVector* propertyVector) + : WALRecord{WALRecordType::NODE_UPDATE_RECORD}, tableID{tableID}, columnID{columnID}, + nodeOffset{nodeOffset}, propertyVector{propertyVector} {} + NodeUpdateRecord(common::table_id_t tableID, common::column_id_t columnID, + common::offset_t nodeOffset, std::unique_ptr propertyVector) + : WALRecord{WALRecordType::NODE_UPDATE_RECORD}, tableID{tableID}, columnID{columnID}, + nodeOffset{nodeOffset}, propertyVector{nullptr}, + ownedPropertyVector{std::move(propertyVector)} {} + + void serialize(common::Serializer& serializer) const override; + static std::unique_ptr deserialize(common::Deserializer& deserializer, + const main::ClientContext& clientContext); +}; + +struct RelDeletionRecord final : WALRecord { + common::table_id_t tableID; + common::ValueVector* srcNodeIDVector; + common::ValueVector* dstNodeIDVector; + common::ValueVector* relIDVector; + std::unique_ptr ownedSrcNodeIDVector; + std::unique_ptr ownedDstNodeIDVector; + std::unique_ptr ownedRelIDVector; + + RelDeletionRecord() + : WALRecord{WALRecordType::REL_DELETION_RECORD}, tableID{common::INVALID_TABLE_ID}, + srcNodeIDVector{nullptr}, dstNodeIDVector{nullptr}, relIDVector{nullptr} {} + RelDeletionRecord(common::table_id_t tableID, common::ValueVector* srcNodeIDVector, + common::ValueVector* dstNodeIDVector, common::ValueVector* relIDVector) + : WALRecord{WALRecordType::REL_DELETION_RECORD}, tableID{tableID}, + srcNodeIDVector{srcNodeIDVector}, dstNodeIDVector{dstNodeIDVector}, + relIDVector{relIDVector} {} + RelDeletionRecord(common::table_id_t tableID, + std::unique_ptr srcNodeIDVector, + std::unique_ptr dstNodeIDVector, + std::unique_ptr relIDVector) + : WALRecord{WALRecordType::REL_DELETION_RECORD}, tableID{tableID}, srcNodeIDVector{nullptr}, + dstNodeIDVector{nullptr}, relIDVector{nullptr}, + ownedSrcNodeIDVector{std::move(srcNodeIDVector)}, + ownedDstNodeIDVector{std::move(dstNodeIDVector)}, + ownedRelIDVector{std::move(relIDVector)} {} + + void serialize(common::Serializer& serializer) const override; + static std::unique_ptr deserialize(common::Deserializer& deserializer, + const main::ClientContext& clientContext); +}; + +struct RelDetachDeleteRecord final : WALRecord { + common::table_id_t tableID; + common::RelDataDirection direction; + common::ValueVector* srcNodeIDVector; + std::unique_ptr ownedSrcNodeIDVector; + + RelDetachDeleteRecord() + : WALRecord{WALRecordType::REL_DETACH_DELETE_RECORD}, tableID{common::INVALID_TABLE_ID}, + direction{common::RelDataDirection::FWD}, srcNodeIDVector{nullptr} {} + RelDetachDeleteRecord(common::table_id_t tableID, common::RelDataDirection direction, + common::ValueVector* srcNodeIDVector) + : WALRecord{WALRecordType::REL_DETACH_DELETE_RECORD}, tableID{tableID}, + direction{direction}, srcNodeIDVector{srcNodeIDVector} {} + RelDetachDeleteRecord(common::table_id_t tableID, common::RelDataDirection direction, + std::unique_ptr srcNodeIDVector) + : WALRecord{WALRecordType::REL_DETACH_DELETE_RECORD}, tableID{tableID}, + direction{direction}, srcNodeIDVector{nullptr}, + ownedSrcNodeIDVector{std::move(srcNodeIDVector)} {} + + void serialize(common::Serializer& serializer) const override; + static std::unique_ptr deserialize(common::Deserializer& deserializer, + const main::ClientContext& clientContext); +}; 
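A note on the pattern used by the data-change records above: each record carries both raw ValueVector pointers and matching owned unique_ptr fields. The raw pointers are filled on the logging path, where the transaction's vectors are only borrowed long enough to serialize them, while the owned fields are populated by deserialize when the record is rebuilt from the WAL during replay. A reduced, self-contained sketch of that borrow-on-write / own-on-replay shape, with Payload and ToyDeletionRecord as invented stand-ins rather than the library's types, is:

#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

// Toy stand-in for common::ValueVector.
struct Payload {
    std::vector<uint8_t> bytes;
};

struct ToyDeletionRecord {
    uint64_t tableID = 0;
    // Borrowed on the logging path; the record never owns it.
    Payload* pk = nullptr;
    // Owned only when the record was reconstructed during replay.
    std::unique_ptr<Payload> ownedPK;

    // Logging-side constructor: borrows the caller's payload.
    ToyDeletionRecord(uint64_t tableID, Payload* pk) : tableID{tableID}, pk{pk} {}
    // Replay-side constructor: takes ownership of the deserialized payload.
    ToyDeletionRecord(uint64_t tableID, std::unique_ptr<Payload> pk)
        : tableID{tableID}, ownedPK{std::move(pk)} {}

    // Either way, consumers read through a single accessor.
    const Payload& payload() const { return pk ? *pk : *ownedPK; }
};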
+ +struct RelUpdateRecord final : WALRecord { + common::table_id_t tableID; + common::column_id_t columnID; + common::ValueVector* srcNodeIDVector; + common::ValueVector* dstNodeIDVector; + common::ValueVector* relIDVector; + common::ValueVector* propertyVector; + std::unique_ptr ownedSrcNodeIDVector; + std::unique_ptr ownedDstNodeIDVector; + std::unique_ptr ownedRelIDVector; + std::unique_ptr ownedPropertyVector; + + RelUpdateRecord() + : WALRecord{WALRecordType::REL_UPDATE_RECORD}, tableID{common::INVALID_TABLE_ID}, + columnID{common::INVALID_COLUMN_ID}, srcNodeIDVector{nullptr}, dstNodeIDVector{nullptr}, + relIDVector{nullptr}, propertyVector{nullptr} {} + RelUpdateRecord(common::table_id_t tableID, common::column_id_t columnID, + common::ValueVector* srcNodeIDVector, common::ValueVector* dstNodeIDVector, + common::ValueVector* relIDVector, common::ValueVector* propertyVector) + : WALRecord{WALRecordType::REL_UPDATE_RECORD}, tableID{tableID}, columnID{columnID}, + srcNodeIDVector{srcNodeIDVector}, dstNodeIDVector{dstNodeIDVector}, + relIDVector{relIDVector}, propertyVector{propertyVector} {} + RelUpdateRecord(common::table_id_t tableID, common::column_id_t columnID, + std::unique_ptr srcNodeIDVector, + std::unique_ptr dstNodeIDVector, + std::unique_ptr relIDVector, + std::unique_ptr propertyVector) + : WALRecord{WALRecordType::REL_UPDATE_RECORD}, tableID{tableID}, columnID{columnID}, + srcNodeIDVector{nullptr}, dstNodeIDVector{nullptr}, relIDVector{nullptr}, + propertyVector{nullptr}, ownedSrcNodeIDVector{std::move(srcNodeIDVector)}, + ownedDstNodeIDVector{std::move(dstNodeIDVector)}, + ownedRelIDVector{std::move(relIDVector)}, ownedPropertyVector{std::move(propertyVector)} { + } + + void serialize(common::Serializer& serializer) const override; + static std::unique_ptr deserialize(common::Deserializer& deserializer, + const main::ClientContext& clientContext); +}; + +struct LoadExtensionRecord final : WALRecord { + std::string path; + + explicit LoadExtensionRecord(std::string path) + : WALRecord{WALRecordType::LOAD_EXTENSION_RECORD}, path{std::move(path)} {} + + void serialize(common::Serializer& serializer) const override; + static std::unique_ptr deserialize(common::Deserializer& deserializer); +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/wal/wal_replayer.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/wal/wal_replayer.h new file mode 100644 index 0000000000..58f79130f5 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/storage/wal/wal_replayer.h @@ -0,0 +1,60 @@ +#pragma once + +#include "storage/wal/wal_record.h" + +namespace lbug { +namespace main { +class ClientContext; +} // namespace main + +namespace storage { +class WALReplayer { +public: + explicit WALReplayer(main::ClientContext& clientContext); + + void replay(bool throwOnWalReplayFailure, bool enableChecksums) const; + +private: + struct WALReplayInfo { + uint64_t offsetDeserialized = 0; + bool isLastRecordCheckpoint = false; + }; + + void replayWALRecord(WALRecord& walRecord) const; + void replayCreateCatalogEntryRecord(WALRecord& walRecord) const; + void replayDropCatalogEntryRecord(const WALRecord& walRecord) const; + void replayAlterTableEntryRecord(const WALRecord& walRecord) const; + void replayTableInsertionRecord(const WALRecord& walRecord) const; + void replayNodeDeletionRecord(const WALRecord& walRecord) const; + void replayNodeUpdateRecord(const WALRecord& walRecord) const; + void replayRelDeletionRecord(const 
WALRecord& walRecord) const; + void replayRelDetachDeletionRecord(const WALRecord& walRecord) const; + void replayRelUpdateRecord(const WALRecord& walRecord) const; + void replayCopyTableRecord(const WALRecord& walRecord) const; + void replayUpdateSequenceRecord(const WALRecord& walRecord) const; + + void replayNodeTableInsertRecord(const WALRecord& walRecord) const; + void replayRelTableInsertRecord(const WALRecord& walRecord) const; + + void replayLoadExtensionRecord(const WALRecord& walRecord) const; + + // This function is used to deserialize the WAL records without actually applying them to the + // storage. + WALReplayInfo dryReplay(common::FileInfo& fileInfo, bool throwOnWalReplayFailure, + bool enableChecksums) const; + + void removeWALAndShadowFiles() const; + void removeFileIfExists(const std::string& path) const; + + std::unique_ptr openWALFile() const; + void syncWALFile(const common::FileInfo& fileInfo) const; + void truncateWALFile(common::FileInfo& fileInfo, uint64_t size) const; + +private: + main::ClientContext& clientContext; + std::string walPath; + std::string shadowFilePath; +}; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/transaction/transaction.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/transaction/transaction.h new file mode 100644 index 0000000000..7bbaf4207e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/transaction/transaction.h @@ -0,0 +1,171 @@ +#pragma once + +#include +#include + +#include "common/types/types.h" + +namespace lbug { +namespace binder { +struct BoundAlterInfo; +} +namespace catalog { +class CatalogEntry; +class CatalogSet; +class SequenceCatalogEntry; +struct SequenceRollbackData; +} // namespace catalog +namespace main { +class ClientContext; +} // namespace main +namespace storage { +class LocalWAL; +class LocalStorage; +class UndoBuffer; +class WAL; +class VersionInfo; +class UpdateInfo; +struct VectorUpdateInfo; +class ChunkedNodeGroup; +class VersionRecordHandler; +} // namespace storage +namespace transaction { +class TransactionManager; + +enum class TransactionType : uint8_t { READ_ONLY, WRITE, CHECKPOINT, DUMMY, RECOVERY }; + +class LocalCacheManager; +class LBUG_API LocalCacheObject { +public: + explicit LocalCacheObject(std::string key) : key{std::move(key)} {} + + virtual ~LocalCacheObject() = default; + + std::string getKey() const { return key; } + + template + T* cast() { + return common::ku_dynamic_cast(this); + } + +private: + std::string key; +}; + +class LocalCacheManager { +public: + bool contains(const std::string& key) { + std::unique_lock lck{mtx}; + return cachedObjects.contains(key); + } + LocalCacheObject& at(const std::string& key) { + std::unique_lock lck{mtx}; + return *cachedObjects.at(key); + } + bool put(std::unique_ptr object); + + void remove(const std::string& key) { + std::unique_lock lck{mtx}; + cachedObjects.erase(key); + } + +private: + std::unordered_map> cachedObjects; + std::mutex mtx; +}; + +class LBUG_API Transaction { + friend class TransactionManager; + +public: + static constexpr common::transaction_t DUMMY_TRANSACTION_ID = 0; + static constexpr common::transaction_t DUMMY_START_TIMESTAMP = 0; + static constexpr common::transaction_t START_TRANSACTION_ID = + static_cast(1) << 63; + + Transaction(main::ClientContext& clientContext, TransactionType transactionType, + common::transaction_t transactionID, common::transaction_t startTS); + + explicit Transaction(TransactionType transactionType) noexcept; + 
Transaction(TransactionType transactionType, common::transaction_t ID, + common::transaction_t startTS) noexcept; + + ~Transaction(); + + TransactionType getType() const { return type; } + bool isReadOnly() const { return TransactionType::READ_ONLY == type; } + bool isWriteTransaction() const { return TransactionType::WRITE == type; } + bool isDummy() const { return TransactionType::DUMMY == type; } + bool isRecovery() const { return TransactionType::RECOVERY == type; } + common::transaction_t getID() const { return ID; } + common::transaction_t getStartTS() const { return startTS; } + common::transaction_t getCommitTS() const { return commitTS; } + int64_t getCurrentTS() const { return currentTS; } + + void setForceCheckpoint() { forceCheckpoint = true; } + bool shouldAppendToUndoBuffer() const { + // Only write transactions and recovery transactions should append to the undo buffer. + return isWriteTransaction() || isRecovery(); + } + bool shouldLogToWAL() const; + storage::LocalWAL& getLocalWAL() const { + KU_ASSERT(localWAL); + return *localWAL; + } + + bool shouldForceCheckpoint() const; + + void commit(storage::WAL* wal); + void rollback(storage::WAL* wal); + + storage::LocalStorage* getLocalStorage() const { return localStorage.get(); } + LocalCacheManager& getLocalCacheManager() { return localCacheManager; } + bool isUnCommitted(common::table_id_t tableID, common::offset_t nodeOffset) const; + common::row_idx_t getLocalRowIdx(common::table_id_t tableID, + common::offset_t nodeOffset) const { + return nodeOffset - getMinUncommittedNodeOffset(tableID); + } + common::offset_t getUncommittedOffset(common::table_id_t tableID, + common::row_idx_t localRowIdx) const { + return getMinUncommittedNodeOffset(tableID) + localRowIdx; + } + + void pushCreateDropCatalogEntry(catalog::CatalogSet& catalogSet, + catalog::CatalogEntry& catalogEntry, bool isInternal, bool skipLoggingToWAL = false); + void pushAlterCatalogEntry(catalog::CatalogSet& catalogSet, catalog::CatalogEntry& catalogEntry, + const binder::BoundAlterInfo& alterInfo); + void pushSequenceChange(catalog::SequenceCatalogEntry* sequenceEntry, int64_t kCount, + const catalog::SequenceRollbackData& data); + void pushInsertInfo(common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, + common::row_idx_t numRows, const storage::VersionRecordHandler* versionRecordHandler) const; + void pushDeleteInfo(common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, + common::row_idx_t numRows, const storage::VersionRecordHandler* versionRecordHandler) const; + void pushVectorUpdateInfo(storage::UpdateInfo& updateInfo, common::idx_t vectorIdx, + storage::VectorUpdateInfo& vectorUpdateInfo, common::transaction_t version) const; + + static Transaction* Get(const main::ClientContext& context); + +private: + common::offset_t getMinUncommittedNodeOffset(common::table_id_t tableID) const; + +private: + TransactionType type; + common::transaction_t ID; + common::transaction_t startTS; + common::transaction_t commitTS; + int64_t currentTS; + main::ClientContext* clientContext; + std::unique_ptr localStorage; + std::unique_ptr undoBuffer; + std::unique_ptr localWAL; + LocalCacheManager localCacheManager; + bool forceCheckpoint; + std::atomic hasCatalogChanges; +}; + +// TODO(bmwinger): These shouldn't need to be exported +extern LBUG_API Transaction DUMMY_TRANSACTION; +extern LBUG_API Transaction DUMMY_CHECKPOINT_TRANSACTION; + +} // namespace transaction +} // namespace lbug diff --git 
a/graph-wasm/lbug-0.12.2/lbug-src/src/include/transaction/transaction_action.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/transaction/transaction_action.h new file mode 100644 index 0000000000..25fb64c9b2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/transaction/transaction_action.h @@ -0,0 +1,23 @@ +#pragma once + +#include +#include + +namespace lbug { +namespace transaction { + +enum class TransactionAction : uint8_t { + BEGIN_READ = 0, + BEGIN_WRITE = 1, + COMMIT = 10, + ROLLBACK = 20, + CHECKPOINT = 30, +}; + +class TransactionActionUtils { +public: + static std::string toString(TransactionAction action); +}; + +} // namespace transaction +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/transaction/transaction_context.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/transaction/transaction_context.h new file mode 100644 index 0000000000..9f69fab5e5 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/transaction/transaction_context.h @@ -0,0 +1,68 @@ +#pragma once + +#include + +#include "transaction.h" + +namespace lbug { + +namespace main { +class ClientContext; +} + +namespace transaction { + +/** + * If the connection is in AUTO_COMMIT mode, any query over the connection will be wrapped around + * a transaction and committed (even if the query is READ_ONLY). + * If the connection is in MANUAL transaction mode, which happens only if an application + * manually begins a transaction (see below), then an application has to manually commit or + * rollback the transaction by calling commit() or rollback(). + * + * AUTO_COMMIT is the default mode when a Connection is created. If an application calls + * begin[ReadOnly/Write]Transaction at any point, the mode switches to MANUAL. This creates + * an "active transaction" in the connection. When a connection is in MANUAL mode and the + * active transaction is rolled back or committed, then the active transaction is removed (so + * the connection no longer has an active transaction), and the mode automatically switches + * back to AUTO_COMMIT. + * Note: When a Connection object is deconstructed, if the connection has an active (manual) + * transaction, then the active transaction is rolled back. 
+ */ +enum class TransactionMode : uint8_t { AUTO = 0, MANUAL = 1 }; + +class LBUG_API TransactionContext { +public: + explicit TransactionContext(main::ClientContext& clientContext); + ~TransactionContext(); + + bool isAutoTransaction() const { return mode == TransactionMode::AUTO; } + + void beginReadTransaction(); + void beginWriteTransaction(); + void beginAutoTransaction(bool readOnlyStatement); + void beginRecoveryTransaction(); + void validateManualTransaction(bool readOnlyStatement) const; + + void commit(); + void rollback(); + + TransactionMode getTransactionMode() const { return mode; } + bool hasActiveTransaction() const { return activeTransaction != nullptr; } + Transaction* getActiveTransaction() const { return activeTransaction; } + + void clearTransaction(); + + static TransactionContext* Get(const main::ClientContext& context); + +private: + void beginTransactionInternal(TransactionType transactionType); + +private: + std::mutex mtx; + main::ClientContext& clientContext; + TransactionMode mode; + Transaction* activeTransaction; +}; + +} // namespace transaction +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/include/transaction/transaction_manager.h b/graph-wasm/lbug-0.12.2/lbug-src/src/include/transaction/transaction_manager.h new file mode 100644 index 0000000000..17c3733b89 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/include/transaction/transaction_manager.h @@ -0,0 +1,82 @@ +#pragma once + +#include +#include + +#include "common/constants.h" +#include "common/uniq_lock.h" +#include "storage/checkpointer.h" +#include "storage/wal/wal.h" +#include "transaction/transaction.h" + +namespace lbug { +namespace main { +class ClientContext; +} // namespace main + +namespace testing { +class DBTest; +class FlakyBufferManager; +class FlakyCheckpointer; +} // namespace testing + +namespace transaction { + +class TransactionManager { + friend class testing::DBTest; + friend class testing::FlakyBufferManager; + friend class testing::FlakyCheckpointer; + + using init_checkpointer_func_t = + std::function(main::ClientContext&)>; + static std::unique_ptr initCheckpointer( + main::ClientContext& clientContext); + +public: + // Timestamp starts from 1. 0 is reserved for the dummy system transaction. + explicit TransactionManager(storage::WAL& wal) + : wal{wal}, lastTransactionID{Transaction::START_TRANSACTION_ID}, lastTimestamp{1} { + initCheckpointerFunc = initCheckpointer; + } + + Transaction* beginTransaction(main::ClientContext& clientContext, TransactionType type); + + void commit(main::ClientContext& clientContext, Transaction* transaction); + void rollback(main::ClientContext& clientContext, Transaction* transaction); + + void checkpoint(main::ClientContext& clientContext); + + static TransactionManager* Get(const main::ClientContext& context); + +private: + bool hasNoActiveTransactions() const; + void checkpointNoLock(main::ClientContext& clientContext); + + // This functions locks the mutex to start new transactions. + common::UniqLock stopNewTransactionsAndWaitUntilAllTransactionsLeave(); + + bool hasActiveWriteTransactionNoLock() const; + + // Note: Used by DBTest::createDB only. 
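To make the AUTO/MANUAL switching described in the TransactionContext comment above concrete, the toy state machine below models just that behaviour: the context starts in AUTO mode, switches to MANUAL while an explicit transaction is active, and falls back to AUTO once that transaction commits or rolls back. ToyTxContext, beginManual and run are invented illustrative names, not the library's API:

#include <cassert>

enum class Mode { AUTO, MANUAL };

// Illustrative model of the mode transitions documented above.
class ToyTxContext {
public:
    Mode mode() const { return active ? Mode::MANUAL : Mode::AUTO; }

    // Explicitly beginning a transaction puts the context into MANUAL mode.
    void beginManual() {
        assert(!active); // nested manual transactions are not allowed here
        active = true;
    }

    // Committing or rolling back drops the active transaction; the context
    // automatically returns to AUTO mode.
    void commit() { active = false; }
    void rollback() { active = false; }

    // In AUTO mode every statement is wrapped in its own implicit transaction,
    // which is committed even for read-only statements.
    template <typename F>
    void run(F&& statement) {
        const bool implicit = (mode() == Mode::AUTO);
        if (implicit) {
            active = true; // implicit transaction wraps this single statement
        }
        statement();
        if (implicit) {
            commit();
        }
        // In MANUAL mode the statement simply runs inside the caller's transaction.
    }

private:
    bool active = false;
};

A caller would run statements directly while in AUTO mode, or bracket several run() calls between beginManual() and commit() to group them into one transaction.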
+ void setCheckPointWaitTimeoutForTransactionsToLeaveInMicros(uint64_t waitTimeInMicros) { + checkpointWaitTimeoutInMicros = waitTimeInMicros; + } + + void clearTransactionNoLock(common::transaction_t transactionID); + +private: + storage::WAL& wal; + std::vector> activeTransactions; + common::transaction_t lastTransactionID; + common::transaction_t lastTimestamp; + // This mutex is used to ensure thread safety and letting only one public function to be called + // at any time except the stopNewTransactionsAndWaitUntilAllReadTransactionsLeave + // function, which needs to let calls to coming and rollback. + std::mutex mtxForSerializingPublicFunctionCalls; + std::mutex mtxForStartingNewTransactions; + uint64_t checkpointWaitTimeoutInMicros = common::DEFAULT_CHECKPOINT_WAIT_TIMEOUT_IN_MICROS; + + init_checkpointer_func_t initCheckpointerFunc; +}; +} // namespace transaction +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/main/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/main/CMakeLists.txt new file mode 100644 index 0000000000..eccee7adc6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/main/CMakeLists.txt @@ -0,0 +1,22 @@ +add_subdirectory(query_result) + +add_library(lbug_main + OBJECT + attached_database.cpp + client_context.cpp + connection.cpp + database.cpp + database_manager.cpp + plan_printer.cpp + prepared_statement.cpp + prepared_statement_manager.cpp + query_result.cpp + query_summary.cpp + storage_driver.cpp + version.cpp + db_config.cpp + settings.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/main/attached_database.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/main/attached_database.cpp new file mode 100644 index 0000000000..5d19be963a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/main/attached_database.cpp @@ -0,0 +1,68 @@ +#include "main/attached_database.h" + +#include "common/exception/runtime.h" +#include "common/file_system/virtual_file_system.h" +#include "main/client_context.h" +#include "main/db_config.h" +#include "storage/checkpointer.h" +#include "storage/storage_manager.h" +#include "storage/storage_utils.h" +#include "transaction/transaction_manager.h" + +namespace lbug { +namespace main { + +void AttachedDatabase::invalidateCache() { + if (dbType != common::ATTACHED_LBUG_DB_TYPE) { + auto catalogExtension = catalog->ptrCast(); + catalogExtension->invalidateCache(); + } +} + +static void validateEmptyWAL(const std::string& path, ClientContext* context) { + auto vfs = common::VirtualFileSystem::GetUnsafe(*context); + auto walFilePath = storage::StorageUtils::getWALFilePath(path); + if (vfs->fileOrPathExists(walFilePath, context)) { + auto walFile = vfs->openFile(walFilePath, + common::FileOpenFlags(common::FileFlags::READ_ONLY), context); + if (walFile->getFileSize() > 0) { + throw common::RuntimeException(common::stringFormat( + "Cannot attach an external Lbug database with non-empty wal file. Try manually " + "checkpointing the external database (i.e., run \"CHECKPOINT;\").")); + } + } +} + +AttachedLbugDatabase::AttachedLbugDatabase(std::string dbPath, std::string dbName, + std::string dbType, ClientContext* clientContext) + : AttachedDatabase{std::move(dbName), std::move(dbType), nullptr /* catalog */} { + auto vfs = common::VirtualFileSystem::GetUnsafe(*clientContext); + if (DBConfig::isDBPathInMemory(dbPath)) { + throw common::RuntimeException("Cannot attach an in-memory Lbug database. 
Please give a " + "path to an on-disk Lbug database directory."); + } + auto path = vfs->expandPath(clientContext, dbPath); + // Note: S3 directory path may end with a '/'. + if (path.ends_with('/')) { + path = path.substr(0, path.size() - 1); + } + if (!vfs->fileOrPathExists(path, clientContext)) { + throw common::RuntimeException(common::stringFormat( + "Cannot attach a remote Lbug database due to invalid path: {}.", path)); + } + catalog = std::make_unique(); + validateEmptyWAL(path, clientContext); + storageManager = std::make_unique(path, true /* isReadOnly */, + clientContext->getDBConfig()->enableChecksums, *storage::MemoryManager::Get(*clientContext), + clientContext->getDBConfig()->enableCompression, vfs); + transactionManager = + std::make_unique(storageManager->getWAL()); + + storageManager->initDataFileHandle(vfs, clientContext); + if (storageManager->getDataFH()->getNumPages() > 0) { + storage::Checkpointer::readCheckpoint(clientContext, catalog.get(), storageManager.get()); + } +} + +} // namespace main +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/main/client_context.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/main/client_context.cpp new file mode 100644 index 0000000000..3c8a2888c6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/main/client_context.cpp @@ -0,0 +1,613 @@ +#include "main/client_context.h" + +#include "binder/binder.h" +#include "common/exception/checkpoint.h" +#include "common/exception/connection.h" +#include "common/exception/runtime.h" +#include "common/file_system/virtual_file_system.h" +#include "common/random_engine.h" +#include "common/string_utils.h" +#include "common/task_system/progress_bar.h" +#include "extension/extension.h" +#include "extension/extension_manager.h" +#include "graph/graph_entry_set.h" +#include "main/attached_database.h" +#include "main/database.h" +#include "main/database_manager.h" +#include "main/db_config.h" +#include "optimizer/optimizer.h" +#include "parser/parser.h" +#include "parser/visitor/standalone_call_rewriter.h" +#include "parser/visitor/statement_read_write_analyzer.h" +#include "planner/planner.h" +#include "processor/plan_mapper.h" +#include "processor/processor.h" +#include "storage/buffer_manager/buffer_manager.h" +#include "storage/buffer_manager/spiller.h" +#include "storage/storage_manager.h" +#include "transaction/transaction_context.h" +#include + +#if defined(_WIN32) +#include "common/windows_utils.h" +#endif + +using namespace lbug::parser; +using namespace lbug::binder; +using namespace lbug::common; +using namespace lbug::catalog; +using namespace lbug::planner; +using namespace lbug::processor; +using namespace lbug::transaction; + +namespace lbug { +namespace main { + +ActiveQuery::ActiveQuery() : interrupted{false} {} + +void ActiveQuery::reset() { + interrupted = false; + timer = Timer(); +} + +ClientContext::ClientContext(Database* database) : localDatabase{database} { + transactionContext = std::make_unique(*this); + randomEngine = std::make_unique(); + remoteDatabase = nullptr; + graphEntrySet = std::make_unique(); + clientConfig.homeDirectory = getUserHomeDir(); + clientConfig.fileSearchPath = ""; + clientConfig.enableSemiMask = ClientConfigDefault::ENABLE_SEMI_MASK; + clientConfig.enableZoneMap = ClientConfigDefault::ENABLE_ZONE_MAP; + clientConfig.numThreads = database->dbConfig.maxNumThreads; + clientConfig.timeoutInMS = ClientConfigDefault::TIMEOUT_IN_MS; + clientConfig.varLengthMaxDepth = ClientConfigDefault::VAR_LENGTH_MAX_DEPTH; + 
clientConfig.enableProgressBar = ClientConfigDefault::ENABLE_PROGRESS_BAR; + clientConfig.showProgressAfter = ClientConfigDefault::SHOW_PROGRESS_AFTER; + clientConfig.recursivePatternSemantic = ClientConfigDefault::RECURSIVE_PATTERN_SEMANTIC; + clientConfig.recursivePatternCardinalityScaleFactor = + ClientConfigDefault::RECURSIVE_PATTERN_FACTOR; + clientConfig.disableMapKeyCheck = ClientConfigDefault::DISABLE_MAP_KEY_CHECK; + clientConfig.warningLimit = ClientConfigDefault::WARNING_LIMIT; + progressBar = std::make_unique(clientConfig.enableProgressBar); + warningContext = std::make_unique(&clientConfig); +} + +ClientContext::~ClientContext() { + if (preventTransactionRollbackOnDestruction) { + return; + } + if (Transaction::Get(*this)) { + getDatabase()->transactionManager->rollback(*this, Transaction::Get(*this)); + } +} + +const DBConfig* ClientContext::getDBConfig() const { + return &getDatabase()->dbConfig; +} + +DBConfig* ClientContext::getDBConfigUnsafe() const { + return &getDatabase()->dbConfig; +} + +uint64_t ClientContext::getTimeoutRemainingInMS() const { + KU_ASSERT(hasTimeout()); + const auto elapsed = activeQuery.timer.getElapsedTimeInMS(); + return elapsed >= clientConfig.timeoutInMS ? 0 : clientConfig.timeoutInMS - elapsed; +} + +void ClientContext::startTimer() { + if (hasTimeout()) { + activeQuery.timer.start(); + } +} + +void ClientContext::setQueryTimeOut(uint64_t timeoutInMS) { + lock_t lck{mtx}; + clientConfig.timeoutInMS = timeoutInMS; +} + +uint64_t ClientContext::getQueryTimeOut() const { + return clientConfig.timeoutInMS; +} + +void ClientContext::setMaxNumThreadForExec(uint64_t numThreads) { + lock_t lck{mtx}; + if (numThreads == 0) { + numThreads = localDatabase->dbConfig.maxNumThreads; + } + clientConfig.numThreads = numThreads; +} + +uint64_t ClientContext::getMaxNumThreadForExec() const { + return clientConfig.numThreads; +} + +Value ClientContext::getCurrentSetting(const std::string& optionName) const { + auto lowerCaseOptionName = optionName; + StringUtils::toLower(lowerCaseOptionName); + // Firstly, try to find in built-in options. + const auto option = DBConfig::getOptionByName(lowerCaseOptionName); + if (option != nullptr) { + return option->getSetting(this); + } + // Secondly, try to find in current client session. + if (extensionOptionValues.contains(lowerCaseOptionName)) { + return extensionOptionValues.at(lowerCaseOptionName); + } + // Lastly, find the default value in db clientConfig. + const auto defaultOption = getExtensionOption(lowerCaseOptionName); + if (defaultOption != nullptr) { + return defaultOption->defaultValue; + } + throw RuntimeException{"Invalid option name: " + lowerCaseOptionName + "."}; +} + +void ClientContext::addScanReplace(function::ScanReplacement scanReplacement) { + scanReplacements.push_back(std::move(scanReplacement)); +} + +std::unique_ptr ClientContext::tryReplaceByName( + const std::string& objectName) const { + for (auto& scanReplacement : scanReplacements) { + auto replaceHandles = scanReplacement.lookupFunc(objectName); + if (replaceHandles.empty()) { + continue; // Fail to replace. 
+ } + return scanReplacement.replaceFunc(std::span(replaceHandles.begin(), replaceHandles.end())); + } + return {}; +} + +std::unique_ptr ClientContext::tryReplaceByHandle( + function::scan_replace_handle_t handle) const { + auto handleSpan = std::span{&handle, 1}; + for (auto& scanReplacement : scanReplacements) { + auto replaceData = scanReplacement.replaceFunc(handleSpan); + if (replaceData == nullptr) { + continue; // Fail to replace. + } + return replaceData; + } + return nullptr; +} + +void ClientContext::setExtensionOption(std::string name, Value value) { + StringUtils::toLower(name); + extensionOptionValues.insert_or_assign(name, std::move(value)); +} + +const main::ExtensionOption* ClientContext::getExtensionOption(std::string optionName) const { + return localDatabase->extensionManager->getExtensionOption(optionName); +} + +std::string ClientContext::getExtensionDir() const { + return stringFormat("{}/.lbug/extension/{}/{}/", clientConfig.homeDirectory, + LBUG_EXTENSION_VERSION, extension::getPlatform()); +} + +std::string ClientContext::getDatabasePath() const { + return localDatabase->databasePath; +} + +Database* ClientContext::getDatabase() const { + return localDatabase; +} + +AttachedLbugDatabase* ClientContext::getAttachedDatabase() const { + return remoteDatabase; +} + +bool ClientContext::isInMemory() const { + if (remoteDatabase != nullptr) { + // If we are connected to a remote database, we assume it is not in memory. + return false; + } + return localDatabase->storageManager->isInMemory(); +} + +std::string ClientContext::getEnvVariable(const std::string& name) { +#if defined(_WIN32) + auto envValue = WindowsUtils::utf8ToUnicode(name.c_str()); + auto result = _wgetenv(envValue.c_str()); + if (!result) { + return std::string(); + } + return WindowsUtils::unicodeToUTF8(result); +#else + const char* env = getenv(name.c_str()); // NOLINT(*-mt-unsafe) + if (!env) { + return std::string(); + } + return env; +#endif +} + +std::string ClientContext::getUserHomeDir() { +#if defined(_WIN32) + return getEnvVariable("USERPROFILE"); +#else + return getEnvVariable("HOME"); +#endif +} + +void ClientContext::setDefaultDatabase(AttachedLbugDatabase* defaultDatabase_) { + remoteDatabase = defaultDatabase_; +} + +bool ClientContext::hasDefaultDatabase() const { + return remoteDatabase != nullptr; +} + +void ClientContext::addScalarFunction(std::string name, function::function_set definitions) { + TransactionHelper::runFuncInTransaction( + *transactionContext, + [&]() { + localDatabase->catalog->addFunction(Transaction::Get(*this), + CatalogEntryType::SCALAR_FUNCTION_ENTRY, std::move(name), std::move(definitions)); + }, + false /*readOnlyStatement*/, false /*isTransactionStatement*/, + TransactionHelper::TransactionCommitAction::COMMIT_IF_NEW); +} + +void ClientContext::removeScalarFunction(const std::string& name) { + TransactionHelper::runFuncInTransaction( + *transactionContext, + [&]() { localDatabase->catalog->dropFunction(Transaction::Get(*this), name); }, + false /*readOnlyStatement*/, false /*isTransactionStatement*/, + TransactionHelper::TransactionCommitAction::COMMIT_IF_NEW); +} + +void ClientContext::cleanUp() { + VirtualFileSystem::GetUnsafe(*this)->cleanUP(this); +} + +std::unique_ptr ClientContext::prepareWithParams(std::string_view query, + std::unordered_map> inputParams) { + std::unique_lock lck{mtx}; + auto parsedStatements = std::vector>(); + try { + parsedStatements = parseQuery(query); + } catch (std::exception& exception) { + return 
PreparedStatement::getPreparedStatementWithError(exception.what()); + } + if (parsedStatements.size() > 1) { + return PreparedStatement::getPreparedStatementWithError( + "Connection Exception: We do not support prepare multiple statements."); + } + + // The binder deals with the parameter values as shared ptrs + // Copy the params to a new map that matches the format that the binder expects + std::unordered_map> inputParamsTmp; + for (auto& [key, value] : inputParams) { + inputParamsTmp.insert(std::make_pair(key, std::make_shared(*value))); + } + auto [preparedStatement, cachedStatement] = prepareNoLock(parsedStatements[0], + true /*shouldCommitNewTransaction*/, std::move(inputParamsTmp)); + preparedStatement->cachedPreparedStatementName = + cachedPreparedStatementManager.addStatement(std::move(cachedStatement)); + useInternalCatalogEntry_ = false; + return std::move(preparedStatement); +} + +static void bindParametersNoLock(PreparedStatement& preparedStatement, + const std::unordered_map>& inputParams) { + for (auto& key : preparedStatement.getKnownParameters()) { + if (inputParams.contains(key)) { + // Found input. Update parameter map. + preparedStatement.updateParameter(key, inputParams.at(key).get()); + } + } + for (auto& key : preparedStatement.getUnknownParameters()) { + if (!inputParams.contains(key)) { + throw Exception("Parameter " + key + " not found."); + } + preparedStatement.addParameter(key, inputParams.at(key).get()); + } +} + +std::unique_ptr ClientContext::executeWithParams(PreparedStatement* preparedStatement, + std::unordered_map> inputParams, + std::optional queryID) { // NOLINT(performance-unnecessary-value-param): It doesn't + // make sense to pass the map as a const reference. + lock_t lck{mtx}; + if (!preparedStatement->isSuccess()) { + return QueryResult::getQueryResultWithError(preparedStatement->errMsg); + } + try { + bindParametersNoLock(*preparedStatement, inputParams); + } catch (std::exception& e) { + return QueryResult::getQueryResultWithError(e.what()); + } + auto name = preparedStatement->getName(); + // LCOV_EXCL_START + // The following should never happen. But we still throw just in case. 
+ if (!cachedPreparedStatementManager.containsStatement(name)) { + return QueryResult::getQueryResultWithError( + stringFormat("Cannot find prepared statement with name {}.", name)); + } + // LCOV_EXCL_STOP + auto cachedStatement = cachedPreparedStatementManager.getCachedStatement(name); + // rebind + auto [newPreparedStatement, newCachedStatement] = + prepareNoLock(cachedStatement->parsedStatement, false /*shouldCommitNewTransaction*/, + preparedStatement->parameterMap); + useInternalCatalogEntry_ = false; + return executeNoLock(newPreparedStatement.get(), newCachedStatement.get(), queryID); +} + +std::unique_ptr ClientContext::query(std::string_view query, + std::optional queryID, QueryConfig config) { + lock_t lck{mtx}; + return queryNoLock(query, queryID, config); +} + +std::unique_ptr ClientContext::queryNoLock(std::string_view query, + std::optional queryID, QueryConfig config) { + auto parsedStatements = std::vector>(); + try { + parsedStatements = parseQuery(query); + } catch (std::exception& exception) { + return QueryResult::getQueryResultWithError(exception.what()); + } + std::unique_ptr queryResult; + QueryResult* lastResult = nullptr; + double internalCompilingTime = 0.0, internalExecutionTime = 0.0; + for (const auto& statement : parsedStatements) { + auto [preparedStatement, cachedStatement] = + prepareNoLock(statement, false /*shouldCommitNewTransaction*/); + auto currentQueryResult = + executeNoLock(preparedStatement.get(), cachedStatement.get(), queryID, config); + if (!currentQueryResult->isSuccess()) { + if (!lastResult) { + queryResult = std::move(currentQueryResult); + } else { + queryResult->addNextResult(std::move(currentQueryResult)); + } + break; + } + auto currentQuerySummary = currentQueryResult->getQuerySummary(); + if (statement->isInternal()) { + // The result of internal statements should be invisible to end users. Skip chaining the + // result of internal statements to the final result to end users. 
+ internalCompilingTime += currentQuerySummary->getCompilingTime(); + internalExecutionTime += currentQuerySummary->getExecutionTime(); + continue; + } + currentQuerySummary->incrementCompilingTime(internalCompilingTime); + currentQuerySummary->incrementExecutionTime(internalExecutionTime); + if (!lastResult) { + // first result of the query + queryResult = std::move(currentQueryResult); + lastResult = queryResult.get(); + } else { + auto current = currentQueryResult.get(); + lastResult->addNextResult(std::move(currentQueryResult)); + lastResult = current; + } + } + useInternalCatalogEntry_ = false; + return queryResult; +} + +std::vector> ClientContext::parseQuery(std::string_view query) { + if (query.empty()) { + throw ConnectionException("Query is empty."); + } + std::vector> statements; + auto parserTimer = TimeMetric(true /*enable*/); + parserTimer.start(); + auto parsedStatements = Parser::parseQuery(query, localDatabase->getTransformerExtensions()); + parserTimer.stop(); + const auto avgParsingTime = parserTimer.getElapsedTimeMS() / parsedStatements.size() / 1.0; + StandaloneCallRewriter standaloneCallAnalyzer{this, parsedStatements.size() == 1}; + for (auto i = 0u; i < parsedStatements.size(); i++) { + auto rewriteQuery = standaloneCallAnalyzer.getRewriteQuery(*parsedStatements[i]); + if (rewriteQuery.empty()) { + parsedStatements[i]->setParsingTime(avgParsingTime); + statements.push_back(std::move(parsedStatements[i])); + } else { + parserTimer.start(); + auto rewrittenStatements = + Parser::parseQuery(rewriteQuery, localDatabase->getTransformerExtensions()); + parserTimer.stop(); + const auto avgRewriteParsingTime = + parserTimer.getElapsedTimeMS() / rewrittenStatements.size() / 1.0; + KU_ASSERT(rewrittenStatements.size() >= 1); + for (auto j = 0u; j < rewrittenStatements.size() - 1; j++) { + rewrittenStatements[j]->setParsingTime(avgParsingTime + avgRewriteParsingTime); + rewrittenStatements[j]->setToInternal(); + statements.push_back(std::move(rewrittenStatements[j])); + } + auto lastRewrittenStatement = rewrittenStatements.back(); + lastRewrittenStatement->setParsingTime(avgParsingTime + avgRewriteParsingTime); + statements.push_back(std::move(lastRewrittenStatement)); + } + } + return statements; +} + +void ClientContext::validateTransaction(bool readOnly, bool requireTransaction) const { + if (!canExecuteWriteQuery() && !readOnly) { + throw ConnectionException("Cannot execute write operations in a read-only database!"); + } + if (requireTransaction && transactionContext->hasActiveTransaction()) { + KU_ASSERT(!transactionContext->isAutoTransaction()); + transactionContext->validateManualTransaction(readOnly); + } +} + +ClientContext::PrepareResult ClientContext::prepareNoLock( + std::shared_ptr parsedStatement, bool shouldCommitNewTransaction, + std::unordered_map> inputParams) { + auto preparedStatement = std::make_unique(); + auto cachedStatement = std::make_unique(); + cachedStatement->parsedStatement = parsedStatement; + cachedStatement->useInternalCatalogEntry = useInternalCatalogEntry_; + auto prepareTimer = TimeMetric(true /* enable */); + prepareTimer.start(); + try { + preparedStatement->preparedSummary.statementType = parsedStatement->getStatementType(); + auto readWriteAnalyzer = StatementReadWriteAnalyzer(this); + TransactionHelper::runFuncInTransaction( + *transactionContext, [&]() -> void { readWriteAnalyzer.visit(*parsedStatement); }, + true /* readOnly */, false /* */, + TransactionHelper::TransactionCommitAction::COMMIT_IF_NEW); + preparedStatement->readOnly 
= readWriteAnalyzer.isReadOnly(); + validateTransaction(preparedStatement->readOnly, parsedStatement->requireTransaction()); + TransactionHelper::runFuncInTransaction( + *transactionContext, + [&]() -> void { + auto binder = Binder(this, localDatabase->getBinderExtensions()); + auto expressionBinder = binder.getExpressionBinder(); + for (auto& [name, value] : inputParams) { + expressionBinder->addParameter(name, value); + } + const auto boundStatement = binder.bind(*parsedStatement); + preparedStatement->unknownParameters = expressionBinder->getUnknownParameters(); + preparedStatement->parameterMap = expressionBinder->getKnownParameters(); + cachedStatement->columns = boundStatement->getStatementResult()->getColumns(); + auto planner = Planner(this); + auto bestPlan = planner.planStatement(*boundStatement); + optimizer::Optimizer::optimize(&bestPlan, this, planner.getCardinalityEstimator()); + cachedStatement->logicalPlan = std::make_unique(std::move(bestPlan)); + }, + preparedStatement->isReadOnly(), + preparedStatement->getStatementType() == StatementType::TRANSACTION, + TransactionHelper::getAction(shouldCommitNewTransaction, + false /*shouldCommitAutoTransaction*/)); + } catch (std::exception& exception) { + preparedStatement->success = false; + preparedStatement->errMsg = exception.what(); + } + prepareTimer.stop(); + preparedStatement->preparedSummary.compilingTime = + parsedStatement->getParsingTime() + prepareTimer.getElapsedTimeMS(); + return {std::move(preparedStatement), std::move(cachedStatement)}; +} + +std::unique_ptr ClientContext::executeNoLock(PreparedStatement* preparedStatement, + CachedPreparedStatement* cachedStatement, std::optional queryID, + QueryConfig queryConfig) { + if (!preparedStatement->isSuccess()) { + return QueryResult::getQueryResultWithError(preparedStatement->errMsg); + } + useInternalCatalogEntry_ = cachedStatement->useInternalCatalogEntry; + this->resetActiveQuery(); + this->startTimer(); + auto executingTimer = TimeMetric(true /* enable */); + executingTimer.start(); + std::unique_ptr result; + try { + bool isTransactionStatement = + preparedStatement->getStatementType() == StatementType::TRANSACTION; + TransactionHelper::runFuncInTransaction( + *transactionContext, + [&]() -> void { + const auto profiler = std::make_unique(); + profiler->enabled = cachedStatement->logicalPlan->isProfile(); + if (!queryID) { + queryID = localDatabase->getNextQueryID(); + } + const auto executionContext = + std::make_unique(profiler.get(), this, *queryID); + auto mapper = PlanMapper(executionContext.get()); + const auto physicalPlan = mapper.getPhysicalPlan(cachedStatement->logicalPlan.get(), + cachedStatement->columns, queryConfig.resultType, queryConfig.arrowConfig); + if (isTransactionStatement) { + result = localDatabase->queryProcessor->execute(physicalPlan.get(), + executionContext.get()); + } else { + if (preparedStatement->getStatementType() == StatementType::COPY_FROM) { + // Note: We always force checkpoint for COPY_FROM statement. 
+ Transaction::Get(*this)->setForceCheckpoint(); + } + result = localDatabase->queryProcessor->execute(physicalPlan.get(), + executionContext.get()); + } + }, + preparedStatement->isReadOnly(), isTransactionStatement, + TransactionHelper::getAction(true /*shouldCommitNewTransaction*/, + !isTransactionStatement /*shouldCommitAutoTransaction*/)); + } catch (std::exception& e) { + useInternalCatalogEntry_ = false; + return handleFailedExecution(queryID, e); + } + const auto memoryManager = storage::MemoryManager::Get(*this); + memoryManager->getBufferManager()->getSpillerOrSkip([](auto& spiller) { spiller.clearFile(); }); + executingTimer.stop(); + result->setColumnNames(cachedStatement->getColumnNames()); + result->setColumnTypes(cachedStatement->getColumnTypes()); + auto summary = std::make_unique(preparedStatement->preparedSummary); + summary->setExecutionTime(executingTimer.getElapsedTimeMS()); + result->setQuerySummary(std::move(summary)); + return result; +} + +std::unique_ptr ClientContext::handleFailedExecution(std::optional queryID, + const std::exception& e) const { + const auto memoryManager = storage::MemoryManager::Get(*this); + memoryManager->getBufferManager()->getSpillerOrSkip([](auto& spiller) { spiller.clearFile(); }); + if (queryID.has_value()) { + progressBar->endProgress(queryID.value()); + } + return QueryResult::getQueryResultWithError(e.what()); +} + +ClientContext::TransactionHelper::TransactionCommitAction +ClientContext::TransactionHelper::getAction(bool commitIfNew, bool commitIfAuto) { + if (commitIfNew && commitIfAuto) { + return TransactionCommitAction::COMMIT_NEW_OR_AUTO; + } + if (commitIfNew) { + return TransactionCommitAction::COMMIT_IF_NEW; + } + if (commitIfAuto) { + return TransactionCommitAction::COMMIT_IF_AUTO; + } + return TransactionCommitAction::NOT_COMMIT; +} + +// If there is an active transaction in the context, we execute the function in the current active +// transaction. If there is no active transaction, we start an auto commit transaction. +void ClientContext::TransactionHelper::runFuncInTransaction(TransactionContext& context, + const std::function& fun, bool readOnlyStatement, bool isTransactionStatement, + TransactionCommitAction action) { + KU_ASSERT(context.isAutoTransaction() || context.hasActiveTransaction()); + const bool requireNewTransaction = + context.isAutoTransaction() && !context.hasActiveTransaction() && !isTransactionStatement; + if (requireNewTransaction) { + context.beginAutoTransaction(readOnlyStatement); + } + try { + fun(); + if ((requireNewTransaction && commitIfNew(action)) || + (context.isAutoTransaction() && commitIfAuto(action))) { + context.commit(); + } + } catch (CheckpointException&) { + context.clearTransaction(); + throw; + } catch (std::exception&) { + context.rollback(); + throw; + } +} + +bool ClientContext::canExecuteWriteQuery() const { + if (getDBConfig()->readOnly) { + return false; + } + // Note: we can only attach a remote lbug database in read-only mode and only one + // remote lbug database can be attached. 
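+ // Illustrative sketch (assumed behaviour of the surrounding API): once a remote
+ // lbug database has been attached through the manager queried below, this method
+ // returns false, so a later CREATE or COPY on the same connection fails with the
+ // "read-only database" error raised in validateTransaction. The exact ATTACH
+ // statement is not shown here and would be an assumption.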
+ const auto dbManager = DatabaseManager::Get(*this); + for (const auto& attachedDB : dbManager->getAttachedDatabases()) { + if (attachedDB->getDBType() == ATTACHED_LBUG_DB_TYPE) { + return false; + } + } + return true; +} + +} // namespace main +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/main/connection.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/main/connection.cpp new file mode 100644 index 0000000000..2a663100cd --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/main/connection.cpp @@ -0,0 +1,111 @@ +#include "main/connection.h" + +#include + +#include "common/random_engine.h" + +using namespace lbug::parser; +using namespace lbug::binder; +using namespace lbug::common; +using namespace lbug::planner; +using namespace lbug::processor; +using namespace lbug::transaction; + +namespace lbug { +namespace main { + +Connection::Connection(Database* database) { + KU_ASSERT(database != nullptr); + this->database = database; + this->dbLifeCycleManager = database->dbLifeCycleManager; + clientContext = std::make_unique(database); +} + +Connection::~Connection() { + clientContext->preventTransactionRollbackOnDestruction = dbLifeCycleManager->isDatabaseClosed; +} + +void Connection::setMaxNumThreadForExec(uint64_t numThreads) { + dbLifeCycleManager->checkDatabaseClosedOrThrow(); + clientContext->setMaxNumThreadForExec(numThreads); +} + +uint64_t Connection::getMaxNumThreadForExec() { + dbLifeCycleManager->checkDatabaseClosedOrThrow(); + return clientContext->getMaxNumThreadForExec(); +} + +std::unique_ptr Connection::prepare(std::string_view query) { + dbLifeCycleManager->checkDatabaseClosedOrThrow(); + return clientContext->prepareWithParams(query); +} + +std::unique_ptr Connection::prepareWithParams(std::string_view query, + std::unordered_map> inputParams) { + dbLifeCycleManager->checkDatabaseClosedOrThrow(); + return clientContext->prepareWithParams(query, std::move(inputParams)); +} + +std::unique_ptr Connection::query(std::string_view queryStatement) { + dbLifeCycleManager->checkDatabaseClosedOrThrow(); + auto queryResult = clientContext->query(queryStatement); + queryResult->setDBLifeCycleManager(dbLifeCycleManager); + return queryResult; +} + +std::unique_ptr Connection::queryAsArrow(std::string_view query, int64_t chunkSize) { + dbLifeCycleManager->checkDatabaseClosedOrThrow(); + auto queryResult = clientContext->query(query, std::nullopt, + {QueryResultType::ARROW, ArrowResultConfig{chunkSize}}); + queryResult->setDBLifeCycleManager(dbLifeCycleManager); + return queryResult; +} + +std::unique_ptr Connection::queryWithID(std::string_view queryStatement, + uint64_t queryID) { + dbLifeCycleManager->checkDatabaseClosedOrThrow(); + auto queryResult = clientContext->query(queryStatement, queryID); + queryResult->setDBLifeCycleManager(dbLifeCycleManager); + return queryResult; +} + +void Connection::interrupt() { + dbLifeCycleManager->checkDatabaseClosedOrThrow(); + clientContext->interrupt(); +} + +void Connection::setQueryTimeOut(uint64_t timeoutInMS) { + dbLifeCycleManager->checkDatabaseClosedOrThrow(); + clientContext->setQueryTimeOut(timeoutInMS); +} + +std::unique_ptr Connection::executeWithParams(PreparedStatement* preparedStatement, + std::unordered_map> inputParams) { + dbLifeCycleManager->checkDatabaseClosedOrThrow(); + auto queryResult = clientContext->executeWithParams(preparedStatement, std::move(inputParams)); + queryResult->setDBLifeCycleManager(dbLifeCycleManager); + return queryResult; +} + +std::unique_ptr Connection::executeWithParamsWithID( + 
PreparedStatement* preparedStatement, + std::unordered_map> inputParams, uint64_t queryID) { + dbLifeCycleManager->checkDatabaseClosedOrThrow(); + auto queryResult = + clientContext->executeWithParams(preparedStatement, std::move(inputParams), queryID); + queryResult->setDBLifeCycleManager(dbLifeCycleManager); + return queryResult; +} + +void Connection::addScalarFunction(std::string name, function::function_set definitions) { + dbLifeCycleManager->checkDatabaseClosedOrThrow(); + clientContext->addScalarFunction(name, std::move(definitions)); +} + +void Connection::removeScalarFunction(std::string name) { + dbLifeCycleManager->checkDatabaseClosedOrThrow(); + clientContext->removeScalarFunction(name); +} + +} // namespace main +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/main/database.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/main/database.cpp new file mode 100644 index 0000000000..24eba72f1a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/main/database.cpp @@ -0,0 +1,239 @@ +#include "main/database.h" + +#include "extension/binder_extension.h" +#include "extension/extension_manager.h" +#include "extension/mapper_extension.h" +#include "extension/planner_extension.h" +#include "extension/transformer_extension.h" +#include "main/client_context.h" +#include "main/database_manager.h" +#include "storage/buffer_manager/buffer_manager.h" + +#if defined(_WIN32) +#include +#else +#include +#endif + +#include "common/exception/exception.h" +#include "common/file_system/virtual_file_system.h" +#include "main/db_config.h" +#include "processor/processor.h" +#include "storage/storage_extension.h" +#include "storage/storage_manager.h" +#include "storage/storage_utils.h" +#include "transaction/transaction_manager.h" + +using namespace lbug::catalog; +using namespace lbug::common; +using namespace lbug::storage; +using namespace lbug::transaction; + +namespace lbug { +namespace main { + +SystemConfig::SystemConfig(uint64_t bufferPoolSize_, uint64_t maxNumThreads, bool enableCompression, + bool readOnly, uint64_t maxDBSize, bool autoCheckpoint, uint64_t checkpointThreshold, + bool forceCheckpointOnClose, bool throwOnWalReplayFailure, bool enableChecksums +#if defined(__APPLE__) + , + uint32_t threadQos +#endif + ) + : maxNumThreads{maxNumThreads}, enableCompression{enableCompression}, readOnly{readOnly}, + autoCheckpoint{autoCheckpoint}, checkpointThreshold{checkpointThreshold}, + forceCheckpointOnClose{forceCheckpointOnClose}, + throwOnWalReplayFailure(throwOnWalReplayFailure), enableChecksums(enableChecksums) { +#if defined(__APPLE__) + this->threadQos = threadQos; +#endif + if (bufferPoolSize_ == -1u || bufferPoolSize_ == 0) { +#if defined(_WIN32) + MEMORYSTATUSEX status; + status.dwLength = sizeof(status); + GlobalMemoryStatusEx(&status); + auto systemMemSize = (std::uint64_t)status.ullTotalPhys; +#else + auto systemMemSize = static_cast(sysconf(_SC_PHYS_PAGES)) * + static_cast(sysconf(_SC_PAGESIZE)); +#endif + bufferPoolSize_ = static_cast( + BufferPoolConstants::DEFAULT_PHY_MEM_SIZE_RATIO_FOR_BM * + static_cast(std::min(systemMemSize, static_cast(UINTPTR_MAX)))); + // On 32-bit systems or systems with extremely large memory, the buffer pool size may + // exceed the maximum size of a VMRegion. In this case, we set the buffer pool size to + // 80% of the maximum size of a VMRegion. 
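+ // Worked example (assuming the default ratio is 0.8): on a 16 GiB machine roughly
+ // 12.8 GiB would be requested above, and the std::min below then caps it at
+ // 0.8 * DEFAULT_VM_REGION_MAX_SIZE whenever that cap is smaller (e.g. 32-bit builds).
+ // The concrete numbers are only an illustration; the real values come from
+ // BufferPoolConstants.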
+ bufferPoolSize_ = static_cast(std::min(static_cast(bufferPoolSize_), + BufferPoolConstants::DEFAULT_VM_REGION_MAX_SIZE * + BufferPoolConstants::DEFAULT_PHY_MEM_SIZE_RATIO_FOR_BM)); + } + bufferPoolSize = bufferPoolSize_; +#ifndef __SINGLE_THREADED__ + if (maxNumThreads == 0) { + this->maxNumThreads = std::thread::hardware_concurrency(); + } +#else + // In single-threaded mode, even if the user specifies a number of threads, + // it will be ignored and set to 0. + this->maxNumThreads = 1; +#endif + if (maxDBSize == -1u) { + maxDBSize = BufferPoolConstants::DEFAULT_VM_REGION_MAX_SIZE; + } + this->maxDBSize = maxDBSize; +} + +Database::Database(std::string_view databasePath, SystemConfig systemConfig) + : Database(databasePath, systemConfig, initBufferManager) {} + +Database::Database(std::string_view databasePath, SystemConfig systemConfig, + construct_bm_func_t constructBMFunc) + : dbConfig(systemConfig) { + initMembers(databasePath, constructBMFunc); +} + +std::unique_ptr Database::initBufferManager(const Database& db) { + return std::make_unique(db.databasePath, + StorageUtils::getTmpFilePath(db.databasePath), db.dbConfig.bufferPoolSize, + db.dbConfig.maxDBSize, db.vfs.get(), db.dbConfig.readOnly); +} + +void Database::initMembers(std::string_view dbPath, construct_bm_func_t initBmFunc) { + // To expand a path with home directory(~), we have to pass in a dummy clientContext which + // handles the home directory expansion. + const auto dbPathStr = std::string(dbPath); + auto clientContext = ClientContext(this); + databasePath = StorageUtils::expandPath(&clientContext, dbPathStr); + + if (std::filesystem::is_directory(databasePath)) { + throw RuntimeException("Database path cannot be a directory: " + databasePath); + } + vfs = std::make_unique(databasePath); + validatePathInReadOnly(); + + bufferManager = initBmFunc(*this); + memoryManager = std::make_unique(bufferManager.get(), vfs.get()); +#if defined(__APPLE__) + queryProcessor = + std::make_unique(dbConfig.maxNumThreads, dbConfig.threadQos); +#else + queryProcessor = std::make_unique(dbConfig.maxNumThreads); +#endif + + catalog = std::make_unique(); + storageManager = std::make_unique(databasePath, dbConfig.readOnly, + dbConfig.enableChecksums, *memoryManager, dbConfig.enableCompression, vfs.get()); + transactionManager = std::make_unique(storageManager->getWAL()); + databaseManager = std::make_unique(); + + extensionManager = std::make_unique(); + dbLifeCycleManager = std::make_shared(); + if (clientContext.isInMemory()) { + storageManager->initDataFileHandle(vfs.get(), &clientContext); + extensionManager->autoLoadLinkedExtensions(&clientContext); + return; + } + StorageManager::recover(clientContext, dbConfig.throwOnWalReplayFailure, + dbConfig.enableChecksums); +} + +Database::~Database() { + if (!dbConfig.readOnly && dbConfig.forceCheckpointOnClose) { + try { + ClientContext clientContext(this); + transactionManager->checkpoint(clientContext); + } catch (...) {} // NOLINT + } + dbLifeCycleManager->isDatabaseClosed = true; +} + +// NOLINTNEXTLINE(readability-make-member-function-const): Semantically non-const function. +void Database::registerFileSystem(std::unique_ptr fs) { + vfs->registerFileSystem(std::move(fs)); +} + +// NOLINTNEXTLINE(readability-make-member-function-const): Semantically non-const function. 
+void Database::registerStorageExtension(std::string name, + std::unique_ptr storageExtension) { + extensionManager->registerStorageExtension(std::move(name), std::move(storageExtension)); +} + +// NOLINTNEXTLINE(readability-make-member-function-const): Semantically non-const function. +void Database::addExtensionOption(std::string name, LogicalTypeID type, Value defaultValue, + bool isConfidential) { + extensionManager->addExtensionOption(std::move(name), type, std::move(defaultValue), + isConfidential); +} + +void Database::addTransformerExtension( + std::unique_ptr transformerExtension) { + transformerExtensions.push_back(std::move(transformerExtension)); +} + +std::vector Database::getTransformerExtensions() { + std::vector transformers; + for (auto& transformerExtension : transformerExtensions) { + transformers.push_back(transformerExtension.get()); + } + return transformers; +} + +void Database::addBinderExtension( + std::unique_ptr transformerExtension) { + binderExtensions.push_back(std::move(transformerExtension)); +} + +std::vector Database::getBinderExtensions() { + std::vector binders; + for (auto& binderExtension : binderExtensions) { + binders.push_back(binderExtension.get()); + } + return binders; +} + +void Database::addPlannerExtension(std::unique_ptr plannerExtension) { + plannerExtensions.push_back(std::move(plannerExtension)); +} + +std::vector Database::getPlannerExtensions() { + std::vector planners; + for (auto& plannerExtension : plannerExtensions) { + planners.push_back(plannerExtension.get()); + } + return planners; +} + +void Database::addMapperExtension(std::unique_ptr mapperExtension) { + mapperExtensions.push_back(std::move(mapperExtension)); +} + +std::vector Database::getMapperExtensions() { + std::vector mappers; + for (auto& mapperExtension : mapperExtensions) { + mappers.push_back(mapperExtension.get()); + } + return mappers; +} + +std::vector Database::getStorageExtensions() { + return extensionManager->getStorageExtensions(); +} + +void Database::validatePathInReadOnly() const { + if (dbConfig.readOnly) { + if (DBConfig::isDBPathInMemory(databasePath)) { + throw Exception("Cannot open an in-memory database under READ ONLY mode."); + } + if (!vfs->fileOrPathExists(databasePath)) { + throw Exception("Cannot create an empty database under READ ONLY mode."); + } + } +} + +uint64_t Database::getNextQueryID() { + std::unique_lock lock(queryIDGenerator.queryIDLock); + return queryIDGenerator.queryID++; +} + +} // namespace main +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/main/database_manager.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/main/database_manager.cpp new file mode 100644 index 0000000000..0b7de4a3fa --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/main/database_manager.cpp @@ -0,0 +1,88 @@ +#include "main/database_manager.h" + +#include "common/exception/runtime.h" +#include "common/string_utils.h" +#include "main/client_context.h" +#include "main/database.h" + +using namespace lbug::common; + +namespace lbug { +namespace main { + +DatabaseManager::DatabaseManager() : defaultDatabase{""} {} + +void DatabaseManager::registerAttachedDatabase(std::unique_ptr attachedDatabase) { + if (defaultDatabase == "") { + defaultDatabase = attachedDatabase->getDBName(); + } + if (hasAttachedDatabase(attachedDatabase->getDBName())) { + throw RuntimeException{stringFormat( + "Duplicate attached database name: {}. 
Attached database name must be unique.", + attachedDatabase->getDBName())}; + } + attachedDatabases.push_back(std::move(attachedDatabase)); +} + +bool DatabaseManager::hasAttachedDatabase(const std::string& name) { + auto upperCaseName = StringUtils::getUpper(name); + for (auto& attachedDatabase : attachedDatabases) { + auto attachedDBName = StringUtils::getUpper(attachedDatabase->getDBName()); + if (attachedDBName == upperCaseName) { + return true; + } + } + return false; +} + +AttachedDatabase* DatabaseManager::getAttachedDatabase(const std::string& name) { + auto upperCaseName = StringUtils::getUpper(name); + for (auto& attachedDatabase : attachedDatabases) { + auto attachedDBName = StringUtils::getUpper(attachedDatabase->getDBName()); + if (attachedDBName == upperCaseName) { + return attachedDatabase.get(); + } + } + throw RuntimeException{stringFormat("No database named {}.", name)}; +} + +void DatabaseManager::detachDatabase(const std::string& databaseName) { + auto upperCaseName = StringUtils::getUpper(databaseName); + for (auto it = attachedDatabases.begin(); it != attachedDatabases.end(); ++it) { + auto attachedDBName = (*it)->getDBName(); + StringUtils::toUpper(attachedDBName); + if (attachedDBName == upperCaseName) { + attachedDatabases.erase(it); + return; + } + } + throw RuntimeException{stringFormat("Database: {} doesn't exist.", databaseName)}; +} + +void DatabaseManager::setDefaultDatabase(const std::string& databaseName) { + if (getAttachedDatabase(databaseName) == nullptr) { + throw RuntimeException{stringFormat("No database named {}.", databaseName)}; + } + defaultDatabase = databaseName; +} + +std::vector DatabaseManager::getAttachedDatabases() const { + std::vector attachedDatabasesPtr; + for (auto& attachedDatabase : attachedDatabases) { + attachedDatabasesPtr.push_back(attachedDatabase.get()); + } + return attachedDatabasesPtr; +} + +void DatabaseManager::invalidateCache() { + for (auto& attachedDatabase : attachedDatabases) { + attachedDatabase->invalidateCache(); + } +} + +DatabaseManager* DatabaseManager::Get(const ClientContext& context) { + return context.getDatabase()->getDatabaseManager(); +} + +} // namespace main +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/main/db_config.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/main/db_config.cpp new file mode 100644 index 0000000000..b835ba0d75 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/main/db_config.cpp @@ -0,0 +1,57 @@ +#include "main/db_config.h" + +#include "common/string_utils.h" +#include "main/database.h" +#include "main/settings.h" + +using namespace lbug::common; + +namespace lbug { +namespace main { + +#define GET_CONFIGURATION(_PARAM) \ + { _PARAM::name, _PARAM::inputType, _PARAM::setContext, _PARAM::getSetting } + +static ConfigurationOption options[] = { // NOLINT(cert-err58-cpp): + GET_CONFIGURATION(ThreadsSetting), GET_CONFIGURATION(TimeoutSetting), + GET_CONFIGURATION(WarningLimitSetting), GET_CONFIGURATION(VarLengthExtendMaxDepthSetting), + GET_CONFIGURATION(SparseFrontierThresholdSetting), GET_CONFIGURATION(EnableSemiMaskSetting), + GET_CONFIGURATION(DisableMapKeyCheck), GET_CONFIGURATION(EnableZoneMapSetting), + GET_CONFIGURATION(HomeDirectorySetting), GET_CONFIGURATION(FileSearchPathSetting), + GET_CONFIGURATION(ProgressBarSetting), GET_CONFIGURATION(RecursivePatternSemanticSetting), + GET_CONFIGURATION(RecursivePatternFactorSetting), GET_CONFIGURATION(EnableMVCCSetting), + GET_CONFIGURATION(CheckpointThresholdSetting), GET_CONFIGURATION(AutoCheckpointSetting), + 
GET_CONFIGURATION(ForceCheckpointClosingDBSetting), GET_CONFIGURATION(SpillToDiskSetting), + GET_CONFIGURATION(EnableOptimizerSetting), GET_CONFIGURATION(EnableInternalCatalogSetting)}; + +DBConfig::DBConfig(const SystemConfig& systemConfig) + : bufferPoolSize{systemConfig.bufferPoolSize}, maxNumThreads{systemConfig.maxNumThreads}, + enableCompression{systemConfig.enableCompression}, readOnly{systemConfig.readOnly}, + maxDBSize{systemConfig.maxDBSize}, enableMultiWrites{false}, + autoCheckpoint{systemConfig.autoCheckpoint}, + checkpointThreshold{systemConfig.checkpointThreshold}, + forceCheckpointOnClose{systemConfig.forceCheckpointOnClose}, + throwOnWalReplayFailure(systemConfig.throwOnWalReplayFailure), + enableChecksums(systemConfig.enableChecksums), enableSpillingToDisk{true} { +#if defined(__APPLE__) + this->threadQos = systemConfig.threadQos; +#endif +} + +ConfigurationOption* DBConfig::getOptionByName(const std::string& optionName) { + auto lOptionName = optionName; + StringUtils::toLower(lOptionName); + for (auto& internalOption : options) { + if (internalOption.name == lOptionName) { + return &internalOption; + } + } + return nullptr; +} + +bool DBConfig::isDBPathInMemory(const std::string& dbPath) { + return dbPath.empty() || dbPath == ":memory:"; +} + +} // namespace main +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/main/plan_printer.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/main/plan_printer.cpp new file mode 100644 index 0000000000..ed18126bab --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/main/plan_printer.cpp @@ -0,0 +1,445 @@ +#include "main/plan_printer.h" + +#include + +#include "json.hpp" +#include "planner/operator/logical_plan.h" +#include "processor/physical_plan.h" + +using namespace lbug::common; +using namespace lbug::planner; +using namespace lbug::processor; + +namespace lbug { +namespace main { + +OpProfileBox::OpProfileBox(std::string opName, const std::string& paramsName, + std::vector attributes) + : opName{std::move(opName)}, attributes{std::move(attributes)} { + std::stringstream paramsStream{paramsName}; + std::string paramStr = ""; + std::string subStr; + bool subParam = false; + // This loop splits the parameters by commas, while not + // splitting up parameters that are operators. + while (paramsStream.good()) { + getline(paramsStream, subStr, ','); + if (subStr.find('(') != std::string::npos && subStr.find(')') == std::string::npos) { + paramStr = subStr; + subParam = true; + continue; + } + if (subParam && subStr.find(')') == std::string::npos) { + paramStr += "," + subStr; + continue; + } + if (subParam) { + subStr = paramStr + ")"; + paramStr = ""; + subParam = false; + } + // This if statement discards any strings that are completely whitespace. 
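+ // Illustrative trace (hypothetical input): for a params string "a, , b" this loop
+ // keeps "a" and " b" and silently drops the middle fragment, which is nothing but
+ // whitespace; fragments belonging to a parenthesised call are first merged by the
+ // branches above rather than being pushed as separate parameters.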
+ if (subStr.find_first_not_of(" \t\n\v\f\r") != std::string::npos) { + paramsNames.push_back(subStr); + } + } +} + +uint32_t OpProfileBox::getAttributeMaxLen() const { + auto maxAttributeLen = opName.length(); + for (auto& param : paramsNames) { + maxAttributeLen = std::max(param.length(), maxAttributeLen); + } + for (auto& attribute : attributes) { + maxAttributeLen = std::max(attribute.length(), maxAttributeLen); + } + return maxAttributeLen; +} + +std::string OpProfileBox::getParamsName(uint32_t idx) const { + KU_ASSERT(idx < paramsNames.size()); + return paramsNames[idx]; +} + +std::string OpProfileBox::getAttribute(uint32_t idx) const { + KU_ASSERT(idx < attributes.size()); + return attributes[idx]; +} + +OpProfileTree::OpProfileTree(const PhysicalOperator* op, Profiler& profiler) { + auto numRows = 0u, numCols = 0u; + calculateNumRowsAndColsForOp(op, numRows, numCols); + opProfileBoxes.resize(numRows); + for_each(opProfileBoxes.begin(), opProfileBoxes.end(), + [numCols](std::vector>& profileBoxes) { + profileBoxes.resize(numCols); + }); + auto maxFieldWidth = 0u; + fillOpProfileBoxes(op, 0 /* rowIdx */, 0 /* colIdx */, maxFieldWidth, profiler); + // The width of each profileBox = fieldWidth + leftIndentWidth + boxLeftFrameWidth + + // rightIndentWidth + boxRightFrameWidth; + this->opProfileBoxWidth = maxFieldWidth + 2 * (INDENT_WIDTH + BOX_FRAME_WIDTH); +} + +OpProfileTree::OpProfileTree(const LogicalOperator* op) { + auto numRows = 0u, numCols = 0u; + calculateNumRowsAndColsForOp(op, numRows, numCols); + opProfileBoxes.resize(numRows); + for_each(opProfileBoxes.begin(), opProfileBoxes.end(), + [numCols](std::vector>& profileBoxes) { + profileBoxes.resize(numCols); + }); + auto maxFieldWidth = 0u; + fillOpProfileBoxes(op, 0 /* rowIdx */, 0 /* colIdx */, maxFieldWidth); + // The width of each profileBox = fieldWidth + leftIndentWidth + boxLeftFrameWidth + + // rightIndentWidth + boxRightFrameWidth; + this->opProfileBoxWidth = std::max( + maxFieldWidth + 2 * (INDENT_WIDTH + BOX_FRAME_WIDTH), MIN_LOGICAL_BOX_WIDTH); +} + +void printSpaceIfNecessary(uint32_t idx, std::ostringstream& oss) { + if (idx > 0) { + oss << " "; + } +} + +std::ostringstream OpProfileTree::printPlanToOstream() const { + std::ostringstream oss; + prettyPrintPlanTitle(oss, "Physical Plan"); + for (auto i = 0u; i < opProfileBoxes.size(); i++) { + printOpProfileBoxUpperFrame(i, oss); + printOpProfileBoxes(i, oss); + printOpProfileBoxLowerFrame(i, oss); + } + return oss; +} + +std::ostringstream OpProfileTree::printLogicalPlanToOstream() const { + std::ostringstream oss; + prettyPrintPlanTitle(oss, "Logical Plan"); + for (auto i = 0u; i < opProfileBoxes.size(); i++) { + printOpProfileBoxUpperFrame(i, oss); + printOpProfileBoxes(i, oss); + printOpProfileBoxLowerFrame(i, oss); + } + return oss; +} + +void OpProfileTree::calculateNumRowsAndColsForOp(const PhysicalOperator* op, uint32_t& numRows, + uint32_t& numCols) { + if (!op->getNumChildren()) { + numRows = 1; + numCols = 1; + return; + } + + for (auto i = 0u; i < op->getNumChildren(); i++) { + auto numRowsInChild = 0u, numColsInChild = 0u; + calculateNumRowsAndColsForOp(op->getChild(i), numRowsInChild, numColsInChild); + numCols += numColsInChild; + numRows = std::max(numRowsInChild, numRows); + } + numRows++; +} + +void OpProfileTree::calculateNumRowsAndColsForOp(const LogicalOperator* op, uint32_t& numRows, + uint32_t& numCols) { + if (!op->getNumChildren()) { + numRows = 1; + numCols = 1; + return; + } + + for (auto i = 0u; i < op->getNumChildren(); i++) { + auto 
numRowsInChild = 0u, numColsInChild = 0u; + calculateNumRowsAndColsForOp(op->getChild(i).get(), numRowsInChild, numColsInChild); + numCols += numColsInChild; + numRows = std::max(numRowsInChild, numRows); + } + numRows++; +} + +uint32_t OpProfileTree::fillOpProfileBoxes(const PhysicalOperator* op, uint32_t rowIdx, + uint32_t colIdx, uint32_t& maxFieldWidth, Profiler& profiler) { + auto opProfileBox = std::make_unique(PlanPrinter::getOperatorName(op), + PlanPrinter::getOperatorParams(op), op->getProfilerAttributes(profiler)); + maxFieldWidth = std::max(opProfileBox->getAttributeMaxLen(), maxFieldWidth); + insertOpProfileBox(rowIdx, colIdx, std::move(opProfileBox)); + if (!op->getNumChildren()) { + return 1; + } + + uint32_t colOffset = 0; + for (auto i = 0u; i < op->getNumChildren(); i++) { + colOffset += fillOpProfileBoxes(op->getChild(i), rowIdx + 1, colIdx + colOffset, + maxFieldWidth, profiler); + } + return colOffset; +} + +uint32_t OpProfileTree::fillOpProfileBoxes(const LogicalOperator* op, uint32_t rowIdx, + uint32_t colIdx, uint32_t& maxFieldWidth) { + auto opProfileBox = std::make_unique(PlanPrinter::getOperatorName(op), + PlanPrinter::getOperatorParams(op), + std::vector{"Cardinality: " + std::to_string(op->getCardinality())}); + maxFieldWidth = std::max(opProfileBox->getAttributeMaxLen(), maxFieldWidth); + insertOpProfileBox(rowIdx, colIdx, std::move(opProfileBox)); + if (!op->getNumChildren()) { + return 1; + } + + uint32_t colOffset = 0; + for (auto i = 0u; i < op->getNumChildren(); i++) { + colOffset += fillOpProfileBoxes(op->getChild(i).get(), rowIdx + 1, colIdx + colOffset, + maxFieldWidth); + } + return colOffset; +} + +void OpProfileTree::printOpProfileBoxUpperFrame(uint32_t rowIdx, std::ostringstream& oss) const { + for (auto i = 0u; i < opProfileBoxes[rowIdx].size(); i++) { + printSpaceIfNecessary(i, oss); + if (getOpProfileBox(rowIdx, i)) { + // If the current box has a parent, we need to put a "┴" in the box upper frame to + // connect to its parent. + if (hasOpProfileBoxOnUpperLeft(rowIdx, i)) { + auto leftFrameLength = (opProfileBoxWidth - 2 * BOX_FRAME_WIDTH - 1) / 2; + oss << "┌" << genHorizLine(leftFrameLength) << "┴" + << genHorizLine(opProfileBoxWidth - 2 * BOX_FRAME_WIDTH - 1 - leftFrameLength) + << "┐"; + } else { + oss << "┌" << genHorizLine(opProfileBoxWidth - 2 * BOX_FRAME_WIDTH) << "┐"; + } + } else { + oss << std::string(opProfileBoxWidth, ' '); + } + } + oss << '\n'; +} + +static std::string dashedLineAccountingForIndex(uint32_t width, uint32_t indent) { + return std::string(width - (1 + indent) * 2, '-'); +} + +void OpProfileTree::printOpProfileBoxes(uint32_t rowIdx, std::ostringstream& oss) const { + auto height = calculateRowHeight(rowIdx); + auto halfWayPoint = height / 2; + uint32_t offset = 0; + for (auto i = 0u; i < height; i++) { + for (auto j = 0u; j < opProfileBoxes[rowIdx].size(); j++) { + auto opProfileBox = getOpProfileBox(rowIdx, j); + if (opProfileBox && + i < 2 * (opProfileBox->getNumAttributes() + 1) + opProfileBox->getNumParams()) { + printSpaceIfNecessary(j, oss); + std::string textToPrint; + unsigned int numParams = opProfileBox->getNumParams(); + if (i == 0) { + textToPrint = opProfileBox->getOpName(); + } else if (i == 1) { // NOLINT(bugprone-branch-clone): Merging these branches is a + // logical error, and this conditional chain is pleasant. 
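+ // Illustrative layout: for a box with two params and one attribute the visible
+ // rows are i=0 the operator name, i=1 a dashed rule, i=2..3 the params, i=4
+ // another dashed rule and i=5 the attribute, which is the indexing this
+ // conditional chain implements.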
+ textToPrint = dashedLineAccountingForIndex(opProfileBoxWidth, INDENT_WIDTH); + } else if (i <= numParams + 1) { + textToPrint = opProfileBox->getParamsName(i - 2); + } else if ((i - numParams - 1) % 2) { + textToPrint = dashedLineAccountingForIndex(opProfileBoxWidth, INDENT_WIDTH); + } else { + textToPrint = opProfileBox->getAttribute((i - numParams - 1) / 2 - 1); + } + auto numLeftSpaces = + (opProfileBoxWidth - (1 + INDENT_WIDTH) * 2 - textToPrint.length()) / 2; + auto numRightSpace = opProfileBoxWidth - (1 + INDENT_WIDTH) * 2 - + textToPrint.length() - numLeftSpaces; + oss << "│" << std::string(INDENT_WIDTH + numLeftSpaces, ' ') << textToPrint + << std::string(INDENT_WIDTH + numRightSpace, ' ') << "│"; + } else if (opProfileBox) { + // If we have printed out all the attributes in the current opProfileBox, print + // empty spaces as placeholders. + printSpaceIfNecessary(j, oss); + oss << "│" << std::string(opProfileBoxWidth - 2, ' ') << "│"; + } else { + if (hasOpProfileBox(rowIdx + 1, j) && i >= halfWayPoint) { + auto leftHorizLineSize = (opProfileBoxWidth - 1) / 2; + if (i == halfWayPoint) { + oss << genHorizLine(leftHorizLineSize + 1); + if (hasOpProfileBox(rowIdx + 1, j + 4) && !hasOpProfileBox(rowIdx, j + 1)) { + oss << "┬" << genHorizLine(opProfileBoxWidth - 1 - leftHorizLineSize); + } else { + if ((hasOpProfileBox(rowIdx + 1, j + 1) && + !hasOpProfileBox(rowIdx, j) && + !hasOpProfileBox(rowIdx, j + 1)) || + (hasOpProfileBox(rowIdx + 1, j + 2) && + !hasOpProfileBox(rowIdx, j + 1))) { + oss << "┬" << genHorizLine(opProfileBoxWidth / 2); + } else { + oss << "┐" + << std::string(opProfileBoxWidth - 1 - leftHorizLineSize, ' '); + } + } + } else if (i > halfWayPoint) { + printSpaceIfNecessary(j, oss); + oss << std::string(leftHorizLineSize, ' ') << "│" + << std::string(opProfileBoxWidth - 1 - leftHorizLineSize, ' '); + } + } else if (((hasOpProfileBox(rowIdx + 1, j + 1) && + !hasOpProfileBox(rowIdx, j + 1)) || + (hasOpProfileBox(rowIdx + 1, j + 3) && + !hasOpProfileBox(rowIdx, j + 3) && + !hasOpProfileBox(rowIdx, j + 1) && + !hasOpProfileBox(rowIdx, j + 2)) || + (hasOpProfileBox(rowIdx + 1, j - 2) && + !hasOpProfileBox(rowIdx, j - 2) && + hasOpProfileBox(rowIdx + 1, j + 3))) && + i == halfWayPoint && !hasOpProfileBox(rowIdx, j + 2)) { + oss << genHorizLine(opProfileBoxWidth + 1); + offset = offset == 0 ? 1 : 0; + } else { + printSpaceIfNecessary(j, oss); + oss << std::string(opProfileBoxWidth, ' '); + } + } + } + oss << '\n'; + } +} + +void OpProfileTree::printOpProfileBoxLowerFrame(uint32_t rowIdx, std::ostringstream& oss) const { + for (auto i = 0u; i < opProfileBoxes[rowIdx].size(); i++) { + if (getOpProfileBox(rowIdx, i)) { + printSpaceIfNecessary(i, oss); + // If the current opProfileBox has a child, we need to print out a connector to it. + if (hasOpProfileBox(rowIdx + 1, i)) { + auto leftFrameLength = (opProfileBoxWidth - 2 * BOX_FRAME_WIDTH - 1) / 2; + oss << "└" << genHorizLine(leftFrameLength) << "┬" + << genHorizLine(opProfileBoxWidth - 2 * BOX_FRAME_WIDTH - 1 - leftFrameLength) + << "┘"; + } else { + oss << "└" << genHorizLine(opProfileBoxWidth - 2) << "┘"; + } + } else if (hasOpProfileBox(rowIdx + 1, i)) { + // If there is a opProfileBox at the bottom, we need to print out a vertical line to + // connect it. 
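+ // Illustrative rendering (approximate): the two cases produce frames like
+ //     └────┬────┘   (a box with a child: connector embedded in the lower frame)
+ //          │        (an empty cell above a child: a bare vertical line)
+ // so each child box lines up under its parent in the printed plan.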
+ auto leftFrameLength = (opProfileBoxWidth - 1) / 2; + printSpaceIfNecessary(i, oss); + oss << std::string(leftFrameLength, ' ') << "│" + << std::string(opProfileBoxWidth - leftFrameLength - 1, ' '); + } else { + printSpaceIfNecessary(i, oss); + oss << std::string(opProfileBoxWidth, ' '); + } + } + oss << '\n'; +} + +void OpProfileTree::prettyPrintPlanTitle(std::ostringstream& oss, std::string title) const { + const std::string plan = title; + oss << "┌" << genHorizLine(opProfileBoxWidth - 2) << "┐" << '\n'; + oss << "│┌" << genHorizLine(opProfileBoxWidth - 4) << "┐│" << '\n'; + auto numLeftSpaces = (opProfileBoxWidth - plan.length() - 2 * (2 + INDENT_WIDTH)) / 2; + auto numRightSpaces = + opProfileBoxWidth - plan.length() - 2 * (2 + INDENT_WIDTH) - numLeftSpaces; + oss << "││" << std::string(INDENT_WIDTH + numLeftSpaces, ' ') << plan + << std::string(INDENT_WIDTH + numRightSpaces, ' ') << "││" << '\n'; + oss << "│└" << genHorizLine(opProfileBoxWidth - 4) << "┘│" << '\n'; + oss << "└" << genHorizLine(opProfileBoxWidth - 2) << "┘" << '\n'; +} + +std::string OpProfileTree::genHorizLine(uint32_t len) { + std::ostringstream tableFrame; + for (auto i = 0u; i < len; i++) { + tableFrame << "─"; + } + return tableFrame.str(); +} + +void OpProfileTree::insertOpProfileBox(uint32_t rowIdx, uint32_t colIdx, + std::unique_ptr opProfileBox) { + validateRowIdxAndColIdx(rowIdx, colIdx); + opProfileBoxes[rowIdx][colIdx] = std::move(opProfileBox); +} + +OpProfileBox* OpProfileTree::getOpProfileBox(uint32_t rowIdx, uint32_t colIdx) const { + validateRowIdxAndColIdx(rowIdx, colIdx); + return opProfileBoxes[rowIdx][colIdx].get(); +} + +bool OpProfileTree::hasOpProfileBoxOnUpperLeft(uint32_t rowIdx, uint32_t colIdx) const { + validateRowIdxAndColIdx(rowIdx, colIdx); + for (auto i = 0u; i <= colIdx; i++) { + if (hasOpProfileBox(rowIdx - 1, i)) { + return true; + } + } + return false; +} + +uint32_t OpProfileTree::calculateRowHeight(uint32_t rowIdx) const { + validateRowIdxAndColIdx(rowIdx, 0 /* colIdx */); + auto height = 0u; + for (auto i = 0u; i < opProfileBoxes[rowIdx].size(); i++) { + auto opProfileBox = getOpProfileBox(rowIdx, i); + if (opProfileBox) { + height = std::max(height, + 2 * opProfileBox->getNumAttributes() + opProfileBox->getNumParams()); + } + } + return height + 2; +} + +nlohmann::json PlanPrinter::printPlanToJson(const PhysicalPlan* physicalPlan, Profiler* profiler) { + return toJson(physicalPlan->lastOperator.get(), *profiler); +} + +std::ostringstream PlanPrinter::printPlanToOstream(const PhysicalPlan* physicalPlan, + Profiler* profiler) { + return OpProfileTree(physicalPlan->lastOperator.get(), *profiler).printPlanToOstream(); +} + +nlohmann::json PlanPrinter::printPlanToJson(const LogicalPlan* logicalPlan) { + return toJson(logicalPlan->getLastOperator().get()); +} + +std::ostringstream PlanPrinter::printPlanToOstream(const LogicalPlan* logicalPlan) { + return OpProfileTree(logicalPlan->getLastOperator().get()).printLogicalPlanToOstream(); +} + +std::string PlanPrinter::getOperatorName(const PhysicalOperator* physicalOperator) { + return PhysicalOperatorUtils::operatorToString(physicalOperator); +} + +std::string PlanPrinter::getOperatorParams(const PhysicalOperator* physicalOperator) { + return physicalOperator->getPrintInfo()->toString(); +} + +std::string PlanPrinter::getOperatorName(const LogicalOperator* logicalOperator) { + return LogicalOperatorUtils::logicalOperatorTypeToString(logicalOperator->getOperatorType()); +} + +std::string PlanPrinter::getOperatorParams(const 
LogicalOperator* logicalOperator) { + return logicalOperator->getPrintInfo()->toString(); +} + +nlohmann::json PlanPrinter::toJson(const PhysicalOperator* physicalOperator, Profiler& profiler_) { + auto json = nlohmann::json(); + json["Name"] = getOperatorName(physicalOperator); + if (profiler_.enabled) { + for (auto& [key, val] : physicalOperator->getProfilerKeyValAttributes(profiler_)) { + json[key] = val; + } + } + for (auto i = 0u; i < physicalOperator->getNumChildren(); ++i) { + json["Child" + std::to_string(i)] = toJson(physicalOperator->getChild(i), profiler_); + } + return json; +} + +nlohmann::json PlanPrinter::toJson(const LogicalOperator* logicalOperator) { + auto json = nlohmann::json(); + json["Name"] = getOperatorName(logicalOperator); + for (auto i = 0u; i < logicalOperator->getNumChildren(); ++i) { + json["Child" + std::to_string(i)] = toJson(logicalOperator->getChild(i).get()); + } + return json; +} + +} // namespace main +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/main/prepared_statement.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/main/prepared_statement.cpp new file mode 100644 index 0000000000..23ed9f646f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/main/prepared_statement.cpp @@ -0,0 +1,89 @@ +#include "main/prepared_statement.h" + +#include "binder/expression/expression.h" // IWYU pragma: keep +#include "common/exception/binder.h" +#include "common/types/value/value.h" +#include "planner/operator/logical_plan.h" // IWYU pragma: keep + +using namespace lbug::common; + +namespace lbug { +namespace main { + +CachedPreparedStatement::CachedPreparedStatement() = default; +CachedPreparedStatement::~CachedPreparedStatement() = default; + +std::vector CachedPreparedStatement::getColumnNames() const { + std::vector names; + for (auto& column : columns) { + names.push_back(column->toString()); + } + return names; +} + +std::vector CachedPreparedStatement::getColumnTypes() const { + std::vector types; + for (auto& column : columns) { + types.push_back(column->getDataType().copy()); + } + return types; +} + +bool PreparedStatement::isSuccess() const { + return success; +} + +std::string PreparedStatement::getErrorMessage() const { + return errMsg; +} + +bool PreparedStatement::isReadOnly() const { + return readOnly; +} + +StatementType PreparedStatement::getStatementType() const { + return preparedSummary.statementType; +} + +static void validateParam(const std::string& paramName, Value* newVal, Value* oldVal) { + if (newVal->getDataType().getLogicalTypeID() == LogicalTypeID::POINTER && + newVal->getValue() != oldVal->getValue()) { + throw BinderException(stringFormat( + "When preparing the current statement the dataframe passed into parameter " + "'{}' was different from the one provided during prepare. 
Dataframes parameters " + "are only used during prepare; please make sure that they are either not passed into " + "execute or they match the one passed during prepare.", + paramName)); + } +} + +std::unordered_set PreparedStatement::getKnownParameters() { + std::unordered_set result; + for (auto& [k, _] : parameterMap) { + result.insert(k); + } + return result; +} + +void PreparedStatement::updateParameter(const std::string& name, Value* value) { + KU_ASSERT(parameterMap.contains(name)); + validateParam(name, value, parameterMap.at(name).get()); + *parameterMap.at(name) = std::move(*value); +} + +void PreparedStatement::addParameter(const std::string& name, Value* value) { + parameterMap.insert({name, std::make_shared(*value)}); +} + +PreparedStatement::~PreparedStatement() = default; + +std::unique_ptr PreparedStatement::getPreparedStatementWithError( + const std::string& errorMessage) { + auto preparedStatement = std::make_unique(); + preparedStatement->success = false; + preparedStatement->errMsg = errorMessage; + return preparedStatement; +} + +} // namespace main +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/main/prepared_statement_manager.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/main/prepared_statement_manager.cpp new file mode 100644 index 0000000000..62c21b3a84 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/main/prepared_statement_manager.cpp @@ -0,0 +1,29 @@ +#include "main/prepared_statement_manager.h" + +#include "common/assert.h" +#include "main/prepared_statement.h" + +namespace lbug { +namespace main { + +CachedPreparedStatementManager::CachedPreparedStatementManager() = default; + +CachedPreparedStatementManager::~CachedPreparedStatementManager() = default; + +std::string CachedPreparedStatementManager::addStatement( + std::unique_ptr statement) { + std::unique_lock lck{mtx}; + auto idx = std::to_string(currentIdx); + currentIdx++; + statementMap.insert({idx, std::move(statement)}); + return idx; +} + +CachedPreparedStatement* CachedPreparedStatementManager::getCachedStatement( + const std::string& name) const { + KU_ASSERT(containsStatement(name)); + return statementMap.at(name).get(); +} + +} // namespace main +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/main/query_result.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/main/query_result.cpp new file mode 100644 index 0000000000..1432e9b4b6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/main/query_result.cpp @@ -0,0 +1,132 @@ +#include "main/query_result.h" + +#include "common/arrow/arrow_converter.h" +#include "main/query_result/materialized_query_result.h" +#include "processor/result/flat_tuple.h" + +using namespace lbug::common; +using namespace lbug::processor; + +namespace lbug { +namespace main { + +QueryResult::QueryResult() + : type{QueryResultType::FTABLE}, nextQueryResult{nullptr}, queryResultIterator{this}, + dbLifeCycleManager{nullptr} {} + +QueryResult::QueryResult(QueryResultType type) + : type{type}, nextQueryResult{nullptr}, queryResultIterator{this}, dbLifeCycleManager{nullptr} { + +} + +QueryResult::QueryResult(QueryResultType type, std::vector columnNames, + std::vector columnTypes) + : type{type}, columnNames{std::move(columnNames)}, columnTypes{std::move(columnTypes)}, + nextQueryResult{nullptr}, queryResultIterator{this}, dbLifeCycleManager{nullptr} { + tuple = std::make_shared(this->columnTypes); +} + +QueryResult::~QueryResult() = default; + +bool QueryResult::isSuccess() const { + return success; +} + +std::string 
QueryResult::getErrorMessage() const { + return errMsg; +} + +size_t QueryResult::getNumColumns() const { + return columnTypes.size(); +} + +std::vector QueryResult::getColumnNames() const { + return columnNames; +} + +std::vector QueryResult::getColumnDataTypes() const { + return LogicalType::copy(columnTypes); +} + +QuerySummary* QueryResult::getQuerySummary() const { + return querySummary.get(); +} + +QuerySummary* QueryResult::getQuerySummaryUnsafe() { + return querySummary.get(); +} + +void QueryResult::checkDatabaseClosedOrThrow() const { + if (!dbLifeCycleManager) { + return; + } + dbLifeCycleManager->checkDatabaseClosedOrThrow(); +} + +bool QueryResult::hasNextQueryResult() const { + checkDatabaseClosedOrThrow(); + return queryResultIterator.hasNextQueryResult(); +} + +QueryResult* QueryResult::getNextQueryResult() { + checkDatabaseClosedOrThrow(); + if (hasNextQueryResult()) { + ++queryResultIterator; + return queryResultIterator.getCurrentResult(); + } + return nullptr; +} + +std::unique_ptr QueryResult::getArrowSchema() const { + checkDatabaseClosedOrThrow(); + return ArrowConverter::toArrowSchema(getColumnDataTypes(), getColumnNames(), + false /* fallbackExtensionTypes */); +} + +void QueryResult::validateQuerySucceed() const { + if (!success) { + throw Exception(errMsg); + } +} + +void QueryResult::setColumnNames(std::vector columnNames) { + this->columnNames = std::move(columnNames); +} + +void QueryResult::setColumnTypes(std::vector columnTypes) { + this->columnTypes = std::move(columnTypes); + tuple = std::make_shared(this->columnTypes); +} + +void QueryResult::addNextResult(std::unique_ptr next_) { + nextQueryResult = std::move(next_); +} + +std::unique_ptr QueryResult::moveNextResult() { + return std::move(nextQueryResult); +} + +void QueryResult::setQuerySummary(std::unique_ptr summary) { + querySummary = std::move(summary); +} + +void QueryResult::setDBLifeCycleManager( + std::shared_ptr dbLifeCycleManager) { + this->dbLifeCycleManager = dbLifeCycleManager; + if (nextQueryResult) { + nextQueryResult->setDBLifeCycleManager(dbLifeCycleManager); + } +} + +std::unique_ptr QueryResult::getQueryResultWithError(const std::string& errorMessage) { + // TODO(Xiyang): consider introduce error query result class. 
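+ // Illustrative caller pattern (query text is hypothetical): callers probe the
+ // flags set here instead of catching an exception, e.g.
+ //     auto res = conn.query("MATCH (n) RETURN n;");
+ //     if (!res->isSuccess()) { std::cerr << res->getErrorMessage() << "\n"; }
+ // isSuccess() and getErrorMessage() are the accessors defined earlier in this file.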
+ auto queryResult = std::make_unique(); + queryResult->success = false; + queryResult->errMsg = errorMessage; + queryResult->nextQueryResult = nullptr; + queryResult->queryResultIterator = QueryResultIterator{queryResult.get()}; + return queryResult; +} + +} // namespace main +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/main/query_result/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/main/query_result/CMakeLists.txt new file mode 100644 index 0000000000..71ac782bb3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/main/query_result/CMakeLists.txt @@ -0,0 +1,8 @@ +add_library(lbug_main_query_result + OBJECT + arrow_query_result.cpp + materialized_query_result.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/main/query_result/arrow_query_result.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/main/query_result/arrow_query_result.cpp new file mode 100644 index 0000000000..9b7e224e38 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/main/query_result/arrow_query_result.cpp @@ -0,0 +1,81 @@ +#include "main/query_result/arrow_query_result.h" + +#include "common/arrow/arrow_row_batch.h" +#include "common/exception/not_implemented.h" +#include "common/exception/runtime.h" +#include "processor/result/factorized_table.h" + +using namespace lbug::common; +using namespace lbug::processor; + +namespace lbug { +namespace main { + +ArrowQueryResult::ArrowQueryResult(std::vector arrays, int64_t chunkSize) + : QueryResult{type_}, arrays{std::move(arrays)}, chunkSize_{chunkSize} { + for (auto& array : this->arrays) { + numTuples += array.length; + } +} + +ArrowQueryResult::ArrowQueryResult(std::vector columnNames, + std::vector columnTypes, FactorizedTable& table, int64_t chunkSize) + : QueryResult{type_, std::move(columnNames), std::move(columnTypes)}, chunkSize_{chunkSize} { + auto iterator = FactorizedTableIterator(table); + while (iterator.hasNext()) { + arrays.push_back(getArray(iterator, chunkSize)); + } +} + +uint64_t ArrowQueryResult::getNumTuples() const { + return numTuples; +} + +ArrowArray ArrowQueryResult::getArray(FactorizedTableIterator& iterator, int64_t chunkSize) { + auto rowBatch = ArrowRowBatch(columnTypes, chunkSize, false /* fallbackExtensionTypes */); + auto rowBatchSize = 0u; + while (rowBatchSize < chunkSize) { + if (!iterator.hasNext()) { + break; + } + (void)iterator.getNext(*tuple); + rowBatch.append(*tuple); + rowBatchSize++; + numTuples++; + } + return rowBatch.toArray(columnTypes); +} + +bool ArrowQueryResult::hasNext() const { + throw NotImplementedException( + "ArrowQueryResult does not implement hasNext. Use MaterializedQueryResult instead."); +} + +std::shared_ptr ArrowQueryResult::getNext() { + throw NotImplementedException( + "ArrowQueryResult does not implement getNext. Use MaterializedQueryResult instead."); +} + +void ArrowQueryResult::resetIterator() { + cursor = 0u; +} + +std::string ArrowQueryResult::toString() const { + throw NotImplementedException( + "ArrowQueryResult does not implement toString. 
Use MaterializedQueryResult instead."); +} + +bool ArrowQueryResult::hasNextArrowChunk() { + return cursor < arrays.size(); +} + +std::unique_ptr ArrowQueryResult::getNextArrowChunk(int64_t chunkSize) { + if (chunkSize != chunkSize_) { + throw RuntimeException( + stringFormat("Chunk size does not match expected value {}.", chunkSize_)); + } + return std::make_unique(arrays[cursor++]); +} + +} // namespace main +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/main/query_result/materialized_query_result.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/main/query_result/materialized_query_result.cpp new file mode 100644 index 0000000000..9dc67235f8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/main/query_result/materialized_query_result.cpp @@ -0,0 +1,110 @@ +#include "main/query_result/materialized_query_result.h" + +#include "common/arrow/arrow_row_batch.h" +#include "common/exception/runtime.h" +#include "processor/result/factorized_table.h" +#include "processor/result/flat_tuple.h" + +using namespace lbug::common; +using namespace lbug::processor; + +namespace lbug { +namespace main { + +MaterializedQueryResult::MaterializedQueryResult() = default; + +MaterializedQueryResult::MaterializedQueryResult(std::shared_ptr table) + : QueryResult{type_}, table{std::move(table)} { + iterator = std::make_unique(*this->table); +} + +MaterializedQueryResult::MaterializedQueryResult(std::vector columnNames, + std::vector columnTypes, std::shared_ptr table) + : QueryResult{type_, std::move(columnNames), std::move(columnTypes)}, table{std::move(table)} { + iterator = std::make_unique(*this->table); +} + +MaterializedQueryResult::~MaterializedQueryResult() { + if (!dbLifeCycleManager) { + return; + } + if (!table) { + return; + } + table->setPreventDestruction(dbLifeCycleManager->isDatabaseClosed); +} + +uint64_t MaterializedQueryResult::getNumTuples() const { + checkDatabaseClosedOrThrow(); + validateQuerySucceed(); + return table->getTotalNumFlatTuples(); +} + +bool MaterializedQueryResult::hasNext() const { + checkDatabaseClosedOrThrow(); + validateQuerySucceed(); + return iterator->hasNext(); +} + +std::shared_ptr MaterializedQueryResult::getNext() { + checkDatabaseClosedOrThrow(); + validateQuerySucceed(); + if (!hasNext()) { + throw RuntimeException( + "No more tuples in QueryResult, Please check hasNext() before calling getNext()."); + } + iterator->getNext(*tuple); + return tuple; +} + +void MaterializedQueryResult::resetIterator() { + checkDatabaseClosedOrThrow(); + validateQuerySucceed(); + iterator->resetState(); +} + +std::string MaterializedQueryResult::toString() const { + checkDatabaseClosedOrThrow(); + if (!isSuccess()) { + return errMsg; + } + std::string result; + // print header + for (auto i = 0u; i < columnNames.size(); ++i) { + if (i != 0) { + result += "|"; + } + result += columnNames[i]; + } + result += "\n"; + auto tuple_ = FlatTuple(this->columnTypes); + auto iterator_ = FactorizedTableIterator(*table); + while (iterator->hasNext()) { + iterator->getNext(tuple_); + result += tuple_.toString(); + } + return result; +} + +bool MaterializedQueryResult::hasNextArrowChunk() { + return hasNext(); +} + +std::unique_ptr MaterializedQueryResult::getNextArrowChunk(int64_t chunkSize) { + checkDatabaseClosedOrThrow(); + auto rowBatch = + std::make_unique(columnTypes, chunkSize, false /* fallbackExtensionTypes */); + auto rowBatchSize = 0u; + while (rowBatchSize < chunkSize) { + if (!iterator->hasNext()) { + break; + } + (void)iterator->getNext(*tuple); + 
rowBatch->append(*tuple); + rowBatchSize++; + } + return std::make_unique(rowBatch->toArray(columnTypes)); +} + +} // namespace main +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/main/query_summary.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/main/query_summary.cpp new file mode 100644 index 0000000000..2f079e7747 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/main/query_summary.cpp @@ -0,0 +1,39 @@ +#include "main/query_summary.h" + +#include "common/enums/statement_type.h" + +using namespace lbug::common; + +namespace lbug { +namespace main { + +double QuerySummary::getCompilingTime() const { + return preparedSummary.compilingTime; +} + +double QuerySummary::getExecutionTime() const { + return executionTime; +} + +void QuerySummary::setExecutionTime(double time) { + executionTime = time; +} + +void QuerySummary::incrementCompilingTime(double increment) { + preparedSummary.compilingTime += increment; +} + +void QuerySummary::incrementExecutionTime(double increment) { + executionTime += increment; +} + +bool QuerySummary::isExplain() const { + return preparedSummary.statementType == StatementType::EXPLAIN; +} + +StatementType QuerySummary::getStatementType() const { + return preparedSummary.statementType; +} + +} // namespace main +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/main/settings.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/main/settings.cpp new file mode 100644 index 0000000000..c7e6950d91 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/main/settings.cpp @@ -0,0 +1,227 @@ +#include "main/settings.h" + +#include "common/exception/runtime.h" +#include "common/task_system/progress_bar.h" +#include "main/client_context.h" +#include "main/db_config.h" +#include "storage/buffer_manager/buffer_manager.h" +#include "storage/buffer_manager/memory_manager.h" +#include "storage/storage_utils.h" + +namespace lbug { +namespace main { + +void ThreadsSetting::setContext(ClientContext* context, const common::Value& parameter) { + parameter.validateType(inputType); + context->getClientConfigUnsafe()->numThreads = parameter.getValue(); +} + +common::Value ThreadsSetting::getSetting(const ClientContext* context) { + return common::Value(context->getClientConfig()->numThreads); +} + +void WarningLimitSetting::setContext(ClientContext* context, const common::Value& parameter) { + parameter.validateType(inputType); + context->getClientConfigUnsafe()->warningLimit = parameter.getValue(); +} + +common::Value WarningLimitSetting::getSetting(const ClientContext* context) { + return common::Value(context->getClientConfig()->warningLimit); +} + +void TimeoutSetting::setContext(ClientContext* context, const common::Value& parameter) { + parameter.validateType(inputType); + context->getClientConfigUnsafe()->timeoutInMS = parameter.getValue(); +} + +common::Value TimeoutSetting::getSetting(const ClientContext* context) { + return common::Value(context->getClientConfig()->timeoutInMS); +} + +void ProgressBarSetting::setContext(ClientContext* context, const common::Value& parameter) { + parameter.validateType(inputType); + context->getClientConfigUnsafe()->enableProgressBar = parameter.getValue(); + common::ProgressBar::Get(*context)->toggleProgressBarPrinting(parameter.getValue()); +} + +common::Value ProgressBarSetting::getSetting(const ClientContext* context) { + return common::Value(context->getClientConfig()->enableProgressBar); +} + +void VarLengthExtendMaxDepthSetting::setContext(ClientContext* context, + const common::Value& parameter) { + 
parameter.validateType(inputType); + context->getClientConfigUnsafe()->varLengthMaxDepth = parameter.getValue(); +} + +common::Value VarLengthExtendMaxDepthSetting::getSetting(const ClientContext* context) { + return common::Value(context->getClientConfig()->varLengthMaxDepth); +} + +void SparseFrontierThresholdSetting::setContext(ClientContext* context, + const common::Value& parameter) { + parameter.validateType(inputType); + context->getClientConfigUnsafe()->sparseFrontierThreshold = parameter.getValue(); +} + +common::Value SparseFrontierThresholdSetting::getSetting(const ClientContext* context) { + return common::Value(context->getClientConfig()->sparseFrontierThreshold); +} + +void EnableSemiMaskSetting::setContext(ClientContext* context, const common::Value& parameter) { + parameter.validateType(inputType); + context->getClientConfigUnsafe()->enableSemiMask = parameter.getValue(); +} + +common::Value EnableSemiMaskSetting::getSetting(const ClientContext* context) { + return common::Value(context->getClientConfig()->enableSemiMask); +} + +void DisableMapKeyCheck::setContext(ClientContext* context, const common::Value& parameter) { + parameter.validateType(inputType); + context->getClientConfigUnsafe()->disableMapKeyCheck = parameter.getValue(); +} + +common::Value DisableMapKeyCheck::getSetting(const ClientContext* context) { + return common::Value(context->getClientConfig()->disableMapKeyCheck); +} + +void EnableZoneMapSetting::setContext(ClientContext* context, const common::Value& parameter) { + parameter.validateType(inputType); + context->getClientConfigUnsafe()->enableZoneMap = parameter.getValue(); +} + +common::Value EnableZoneMapSetting::getSetting(const ClientContext* context) { + return common::Value(context->getClientConfig()->enableZoneMap); +} + +void HomeDirectorySetting::setContext(ClientContext* context, const common::Value& parameter) { + parameter.validateType(inputType); + context->getClientConfigUnsafe()->homeDirectory = parameter.getValue(); +} + +common::Value HomeDirectorySetting::getSetting(const ClientContext* context) { + return common::Value::createValue(context->getClientConfig()->homeDirectory); +} + +void FileSearchPathSetting::setContext(ClientContext* context, const common::Value& parameter) { + parameter.validateType(inputType); + context->getClientConfigUnsafe()->fileSearchPath = parameter.getValue(); +} + +common::Value FileSearchPathSetting::getSetting(const ClientContext* context) { + return common::Value::createValue(context->getClientConfig()->fileSearchPath); +} + +void RecursivePatternSemanticSetting::setContext(ClientContext* context, + const common::Value& parameter) { + parameter.validateType(inputType); + const auto input = parameter.getValue(); + context->getClientConfigUnsafe()->recursivePatternSemantic = + common::PathSemanticUtils::fromString(input); +} + +common::Value RecursivePatternSemanticSetting::getSetting(const ClientContext* context) { + const auto result = + common::PathSemanticUtils::toString(context->getClientConfig()->recursivePatternSemantic); + return common::Value::createValue(result); +} + +void RecursivePatternFactorSetting::setContext(ClientContext* context, + const common::Value& parameter) { + parameter.validateType(inputType); + context->getClientConfigUnsafe()->recursivePatternCardinalityScaleFactor = + parameter.getValue(); +} + +common::Value RecursivePatternFactorSetting::getSetting(const ClientContext* context) { + return common::Value::createValue( + 
context->getClientConfig()->recursivePatternCardinalityScaleFactor); +} + +void EnableMVCCSetting::setContext(ClientContext* context, const common::Value& parameter) { + KU_ASSERT(parameter.getDataType().getLogicalTypeID() == common::LogicalTypeID::BOOL); + // TODO: This is a temporary solution to make tests of multiple write transactions easier. + context->getDBConfigUnsafe()->enableMultiWrites = parameter.getValue(); +} + +common::Value EnableMVCCSetting::getSetting(const ClientContext* context) { + return common::Value(context->getDBConfig()->enableMultiWrites); +} + +void CheckpointThresholdSetting::setContext(ClientContext* context, + const common::Value& parameter) { + parameter.validateType(inputType); + context->getDBConfigUnsafe()->checkpointThreshold = parameter.getValue(); +} + +common::Value CheckpointThresholdSetting::getSetting(const ClientContext* context) { + return common::Value(context->getDBConfig()->checkpointThreshold); +} + +void AutoCheckpointSetting::setContext(ClientContext* context, const common::Value& parameter) { + parameter.validateType(inputType); + context->getDBConfigUnsafe()->autoCheckpoint = parameter.getValue(); +} + +common::Value AutoCheckpointSetting::getSetting(const ClientContext* context) { + return common::Value(context->getDBConfig()->autoCheckpoint); +} + +void ForceCheckpointClosingDBSetting::setContext(ClientContext* context, + const common::Value& parameter) { + parameter.validateType(inputType); + context->getDBConfigUnsafe()->forceCheckpointOnClose = parameter.getValue(); +} + +common::Value ForceCheckpointClosingDBSetting::getSetting(const ClientContext* context) { + return common::Value(context->getDBConfig()->forceCheckpointOnClose); +} + +void EnableOptimizerSetting::setContext(ClientContext* context, const common::Value& parameter) { + parameter.validateType(inputType); + context->getClientConfigUnsafe()->enablePlanOptimizer = parameter.getValue(); +} + +common::Value EnableOptimizerSetting::getSetting(const ClientContext* context) { + return common::Value::createValue(context->getClientConfig()->enablePlanOptimizer); +} + +void EnableInternalCatalogSetting::setContext(ClientContext* context, + const common::Value& parameter) { + parameter.validateType(inputType); + context->getClientConfigUnsafe()->enableInternalCatalog = parameter.getValue(); +} + +common::Value EnableInternalCatalogSetting::getSetting(const ClientContext* context) { + return common::Value::createValue(context->getClientConfig()->enableInternalCatalog); +} + +void SpillToDiskSetting::setContext(ClientContext* context, const common::Value& parameter) { + parameter.validateType(inputType); + context->getDBConfigUnsafe()->enableSpillingToDisk = parameter.getValue(); + const auto& dbConfig = *context->getDBConfig(); + std::string spillPath; + if (dbConfig.enableSpillingToDisk) { + if (context->isInMemory()) { + throw common::RuntimeException( + "Cannot set spill_to_disk to true for an in-memory database!"); + } + if (!context->canExecuteWriteQuery()) { + throw common::RuntimeException( + "Cannot set spill_to_disk to true for a read only database!"); + } + spillPath = storage::StorageUtils::getTmpFilePath(context->getDatabasePath()); + } else { + // Set path to empty will disable spiller. 
+ spillPath = ""; + } + storage::MemoryManager::Get(*context)->getBufferManager()->resetSpiller(spillPath); +} + +common::Value SpillToDiskSetting::getSetting(const ClientContext* context) { + return common::Value::createValue(context->getDBConfig()->enableSpillingToDisk); +} + +} // namespace main +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/main/storage_driver.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/main/storage_driver.cpp new file mode 100644 index 0000000000..167fe46163 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/main/storage_driver.cpp @@ -0,0 +1,180 @@ +#include "main/storage_driver.h" + +#include + +#include "catalog/catalog_entry/table_catalog_entry.h" +#include "main/client_context.h" +#include "storage/storage_manager.h" +#include "storage/table/node_table.h" + +using namespace lbug::common; +using namespace lbug::transaction; +using namespace lbug::storage; +using namespace lbug::catalog; + +namespace lbug { +namespace main { + +StorageDriver::StorageDriver(Database* database) { + clientContext = std::make_unique(database); +} + +StorageDriver::~StorageDriver() = default; + +static TableCatalogEntry* getEntry(const ClientContext& context, const std::string& tableName) { + return Catalog::Get(context)->getTableCatalogEntry(Transaction::Get(context), tableName); +} + +static Table* getTable(const ClientContext& context, const std::string& tableName) { + return StorageManager::Get(context)->getTable(getEntry(context, tableName)->getTableID()); +} + +static bool validateNumericalType(const LogicalType& type) { + switch (type.getLogicalTypeID()) { + case LogicalTypeID::BOOL: + case LogicalTypeID::INT128: + case LogicalTypeID::INT64: + case LogicalTypeID::INT32: + case LogicalTypeID::INT16: + case LogicalTypeID::INT8: + case LogicalTypeID::UINT64: + case LogicalTypeID::UINT32: + case LogicalTypeID::UINT16: + case LogicalTypeID::UINT8: + case LogicalTypeID::DOUBLE: + case LogicalTypeID::FLOAT: + return true; + default: + return false; + } +} + +static std::string getUnsupportedTypeErrMsg(const LogicalType& type) { + return stringFormat("Unsupported data type {}.", type.toString()); +} + +static uint32_t getElementSize(const LogicalType& type) { + switch (type.getLogicalTypeID()) { + case LogicalTypeID::BOOL: + case LogicalTypeID::INT128: + case LogicalTypeID::INT64: + case LogicalTypeID::INT32: + case LogicalTypeID::INT16: + case LogicalTypeID::INT8: + case LogicalTypeID::UINT64: + case LogicalTypeID::UINT32: + case LogicalTypeID::UINT16: + case LogicalTypeID::UINT8: + case LogicalTypeID::DOUBLE: + case LogicalTypeID::FLOAT: + return PhysicalTypeUtils::getFixedTypeSize(type.getPhysicalType()); + case LogicalTypeID::ARRAY: { + auto& childType = ArrayType::getChildType(type); + if (!validateNumericalType(childType)) { + throw RuntimeException(getUnsupportedTypeErrMsg(type)); + } + auto numElements = ArrayType::getNumElements(type); + return numElements * PhysicalTypeUtils::getFixedTypeSize(childType.getPhysicalType()); + } + default: + throw RuntimeException(getUnsupportedTypeErrMsg(type)); + } +} + +void StorageDriver::scan(const std::string& nodeName, const std::string& propertyName, + common::offset_t* offsets, size_t numOffsets, uint8_t* result, size_t numThreads) { + clientContext->query("BEGIN TRANSACTION READ ONLY;"); + auto entry = getEntry(*clientContext, nodeName); + auto columnID = entry->getColumnID(propertyName); + auto table = getTable(*clientContext, nodeName); + auto& dataType = table->ptrCast()->getColumn(columnID).getDataType(); + 
auto elementSize = getElementSize(dataType); + auto numOffsetsPerThread = numOffsets / numThreads + 1; + auto remainingNumOffsets = numOffsets; + auto current_buffer = result; + std::vector threads; + while (remainingNumOffsets > 0) { + auto numOffsetsToScan = std::min(numOffsetsPerThread, remainingNumOffsets); + threads.emplace_back(&StorageDriver::scanColumn, this, table, columnID, offsets, + numOffsetsToScan, current_buffer); + offsets += numOffsetsToScan; + current_buffer += numOffsetsToScan * elementSize; + remainingNumOffsets -= numOffsetsToScan; + } + for (auto& thread : threads) { + thread.join(); + } + clientContext->query("COMMIT"); +} + +uint64_t StorageDriver::getNumNodes(const std::string& nodeName) const { + clientContext->query("BEGIN TRANSACTION READ ONLY;"); + auto transaction = Transaction::Get(*clientContext); + auto result = getTable(*clientContext, nodeName)->getNumTotalRows(transaction); + clientContext->query("COMMIT"); + return result; +} + +uint64_t StorageDriver::getNumRels(const std::string& relName) const { + clientContext->query("BEGIN TRANSACTION READ ONLY;"); + auto transaction = Transaction::Get(*clientContext); + auto result = getTable(*clientContext, relName)->getNumTotalRows(transaction); + clientContext->query("COMMIT"); + return result; +} + +void StorageDriver::scanColumn(Table* table, column_id_t columnID, const offset_t* offsets, + size_t size, uint8_t* result) const { + // Create scan state. + auto nodeTable = table->ptrCast(); + auto column = &nodeTable->getColumn(columnID); + // Create value vectors + auto idVector = std::make_unique(LogicalType::INTERNAL_ID()); + auto columnVector = std::make_unique(column->getDataType().copy(), + MemoryManager::Get(*clientContext)); + auto vectorState = DataChunkState::getSingleValueDataChunkState(); + idVector->state = vectorState; + columnVector->state = vectorState; + auto scanState = std::make_unique(idVector.get(), + std::vector{columnVector.get()}, vectorState); + auto transaction = Transaction::Get(*clientContext); + switch (auto physicalType = column->getDataType().getPhysicalType()) { + case PhysicalTypeID::BOOL: + case PhysicalTypeID::INT128: + case PhysicalTypeID::INT64: + case PhysicalTypeID::INT32: + case PhysicalTypeID::INT16: + case PhysicalTypeID::INT8: + case PhysicalTypeID::UINT64: + case PhysicalTypeID::UINT32: + case PhysicalTypeID::UINT16: + case PhysicalTypeID::UINT8: + case PhysicalTypeID::DOUBLE: + case PhysicalTypeID::FLOAT: { + for (auto i = 0u; i < size; ++i) { + idVector->setValue(0, nodeID_t{offsets[i], table->getTableID()}); + [[maybe_unused]] auto res = nodeTable->lookup(transaction, *scanState); + memcpy(result, columnVector->getData(), + PhysicalTypeUtils::getFixedTypeSize(physicalType)); + } + } break; + case PhysicalTypeID::ARRAY: { + auto& childType = ArrayType::getChildType(column->getDataType()); + auto elementSize = PhysicalTypeUtils::getFixedTypeSize(childType.getPhysicalType()); + auto numElements = ArrayType::getNumElements(column->getDataType()); + auto arraySize = elementSize * numElements; + for (auto i = 0u; i < size; ++i) { + idVector->setValue(0, nodeID_t{offsets[i], table->getTableID()}); + [[maybe_unused]] auto res = nodeTable->lookup(transaction, *scanState); + auto dataVector = ListVector::getDataVector(columnVector.get()); + memcpy(result, dataVector->getData() + i * arraySize, arraySize); + } + } break; + default: + throw RuntimeException(stringFormat("Not supported data type in StorageDriver::scanColumn", + PhysicalTypeUtils::toString(physicalType))); + } +} 
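For reference, the offset partitioning performed by StorageDriver::scan above can be reproduced in isolation. The standalone sketch below (standard C++ only; chunkOffsets and every other name in it are illustrative and not part of the lbug API) shows the same arithmetic: each worker gets numOffsets / numThreads + 1 offsets, so there are never more than numThreads chunks and the last chunk picks up the remainder.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <thread>
#include <utility>
#include <vector>

// Each chunk is a contiguous [start, start + count) slice of the offset array.
static std::vector<std::pair<std::size_t, std::size_t>> chunkOffsets(std::size_t numOffsets,
    std::size_t numThreads) {
    std::vector<std::pair<std::size_t, std::size_t>> chunks;
    // Same arithmetic as scan(): round the per-thread share up by one.
    const auto perThread = numOffsets / numThreads + 1;
    std::size_t start = 0;
    auto remaining = numOffsets;
    while (remaining > 0) {
        const auto count = std::min(perThread, remaining);
        chunks.emplace_back(start, count);
        start += count;
        remaining -= count;
    }
    return chunks;
}

int main() {
    // With 10 offsets and 4 threads, perThread = 10 / 4 + 1 = 3, so the chunks
    // are 3, 3, 3 and 1; each chunk is handed to its own worker thread.
    std::vector<std::thread> workers;
    for (const auto& [start, count] : chunkOffsets(10, 4)) {
        std::cout << "chunk start=" << start << " count=" << count << "\n";
        workers.emplace_back([](std::size_t /*chunkStart*/, std::size_t /*chunkCount*/) {
            // A real worker would scan its slice of offsets and write into
            // result + chunkStart * elementSize, as scanColumn does above.
        }, start, count);
    }
    for (auto& worker : workers) {
        worker.join();
    }
    return 0;
}

One consequence of the rounding is that a small batch may use fewer threads than requested (for example 4 offsets on 4 threads become two chunks of 2), which matches the behaviour of the loop in scan.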
+ +} // namespace main +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/main/version.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/main/version.cpp new file mode 100644 index 0000000000..253bd7452f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/main/version.cpp @@ -0,0 +1,15 @@ +#include "main/version.h" + +#include "storage/storage_version_info.h" + +namespace lbug { +namespace main { +const char* Version::getVersion() { + return LBUG_CMAKE_VERSION; +} + +uint64_t Version::getStorageVersion() { + return storage::StorageVersionInfo::getStorageVersion(); +} +} // namespace main +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/CMakeLists.txt new file mode 100644 index 0000000000..7ba4d4ed2b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/CMakeLists.txt @@ -0,0 +1,21 @@ +add_library(lbug_optimizer + OBJECT + acc_hash_join_optimizer.cpp + agg_key_dependency_optimizer.cpp + cardinality_updater.cpp + correlated_subquery_unnest_solver.cpp + factorization_rewriter.cpp + filter_push_down_optimizer.cpp + logical_operator_collector.cpp + logical_operator_visitor.cpp + optimizer.cpp + projection_push_down_optimizer.cpp + schema_populator.cpp + remove_factorization_rewriter.cpp + remove_unnecessary_join_optimizer.cpp + top_k_optimizer.cpp + limit_push_down_optimizer.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/acc_hash_join_optimizer.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/acc_hash_join_optimizer.cpp new file mode 100644 index 0000000000..d15f1de604 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/acc_hash_join_optimizer.cpp @@ -0,0 +1,412 @@ +#include "optimizer/acc_hash_join_optimizer.h" + +#include "catalog/catalog_entry/table_catalog_entry.h" +#include "optimizer/logical_operator_collector.h" +#include "planner/operator/extend/logical_recursive_extend.h" +#include "planner/operator/logical_accumulate.h" +#include "planner/operator/logical_hash_join.h" +#include "planner/operator/logical_intersect.h" +#include "planner/operator/logical_path_property_probe.h" +#include "planner/operator/scan/logical_scan_node_table.h" +#include "planner/operator/sip/logical_semi_masker.h" + +using namespace lbug::common; +using namespace lbug::binder; +using namespace lbug::planner; +using namespace lbug::function; + +namespace lbug { +namespace optimizer { + +static std::shared_ptr appendAccumulate(std::shared_ptr child) { + auto accumulate = std::make_shared(AccumulateType::REGULAR, + expression_vector{}, nullptr /* mark */, std::move(child)); + accumulate->computeFlatSchema(); + return accumulate; +} + +static table_id_vector_t getTableIDs(const std::vector& entries) { + table_id_vector_t result; + for (auto& entry : entries) { + result.push_back(entry->getTableID()); + } + return result; +} + +static std::vector getTableIDs(const LogicalOperator* op, + SemiMaskTargetType targetType) { + switch (op->getOperatorType()) { + case LogicalOperatorType::SCAN_NODE_TABLE: { + return op->constCast().getTableIDs(); + } + case LogicalOperatorType::RECURSIVE_EXTEND: { + auto& bindData = op->constCast().getBindData(); + switch (targetType) { + case SemiMaskTargetType::RECURSIVE_EXTEND_INPUT_NODE: { + auto& node = bindData.nodeInput->constCast(); + return getTableIDs(node.getEntries()); + } + case SemiMaskTargetType::RECURSIVE_EXTEND_OUTPUT_NODE: { + auto& node = 
bindData.nodeOutput->constCast(); + return getTableIDs(node.getEntries()); + } + default: + KU_UNREACHABLE; + } + } + default: + KU_UNREACHABLE; + } +} + +static bool sameTableIDs(const std::unordered_set& set, + const std::vector& ids) { + if (set.size() != ids.size()) { + return false; + } + for (auto id : ids) { + if (!set.contains(id)) { + return false; + } + } + return true; +} + +static bool haveSameTableIDs(const std::vector& ops, + SemiMaskTargetType targetType) { + std::unordered_set tableIDSet; + for (auto id : getTableIDs(ops[0], targetType)) { + tableIDSet.insert(id); + } + for (auto i = 0u; i < ops.size(); ++i) { + if (!sameTableIDs(tableIDSet, getTableIDs(ops[i], targetType))) { + return false; + } + } + return true; +} + +static bool haveSameType(const std::vector& ops) { + for (auto i = 0u; i < ops.size(); ++i) { + if (ops[i]->getOperatorType() != ops[0]->getOperatorType()) { + return false; + } + } + return true; +} + +bool sanityCheckCandidates(const std::vector& ops, + SemiMaskTargetType targetType) { + KU_ASSERT(!ops.empty()); + if (!haveSameType(ops)) { + return false; + } + if (!haveSameTableIDs(ops, targetType)) { + return false; + } + return true; +} + +static std::shared_ptr appendSemiMasker(SemiMaskKeyType keyType, + SemiMaskTargetType targetType, std::shared_ptr key, + std::vector candidates, std::shared_ptr child) { + auto tableIDs = getTableIDs(candidates[0], targetType); + auto semiMasker = + std::make_shared(keyType, targetType, key, tableIDs, child); + for (auto candidate : candidates) { + semiMasker->addTarget(candidate); + } + semiMasker->computeFlatSchema(); + return semiMasker; +} + +void HashJoinSIPOptimizer::rewrite(const LogicalPlan* plan) { + visitOperator(plan->getLastOperator().get()); +} + +void HashJoinSIPOptimizer::visitOperator(LogicalOperator* op) { + // bottom up traversal + for (auto i = 0u; i < op->getNumChildren(); ++i) { + visitOperator(op->getChild(i).get()); + } + visitOperatorSwitch(op); +} + +static bool subPlanContainsFilter(LogicalOperator* root) { + auto filterCollector = LogicalFilterCollector(); + filterCollector.collect(root); + auto indexScanNodeCollector = LogicalIndexScanNodeCollector(); + indexScanNodeCollector.collect(root); + if (!filterCollector.hasOperators() && !indexScanNodeCollector.hasOperators()) { + return false; + } + return true; +} + +// Probe side is qualified if it is selective. +static bool isProbeSideQualified(LogicalOperator* probeRoot) { + if (probeRoot->getOperatorType() == LogicalOperatorType::ACCUMULATE) { + // No Acc hash join if probe side has already been accumulated. This can be solved. + return false; + } + // Probe side is not selective. So we don't apply acc hash join. + return subPlanContainsFilter(probeRoot); +} + +// Find all ScanNodeIDs under root which scans parameter nodeID. Note that there might be +// multiple ScanNodeIDs matches because both node and rel table scans will trigger scanNodeIDs. +static std::vector getScanNodeCandidates(const Expression& nodeID, + LogicalOperator* root) { + std::vector result; + auto collector = LogicalScanNodeTableCollector(); + collector.collect(root); + for (auto& op : collector.getOperators()) { + auto& scan = op->constCast(); + if (scan.getScanType() != LogicalScanNodeTableType::SCAN) { + // Do not apply semi mask to index scan. 
+ continue; + } + if (nodeID.getUniqueName() == scan.getNodeID()->getUniqueName()) { + result.push_back(op); + } + } + return result; +} + +static std::vector getRecursiveExtendInputNodeCandidates(const Expression& nodeID, + LogicalOperator* root) { + std::vector result; + auto collector = LogicalRecursiveExtendCollector(); + collector.collect(root); + for (auto& op : collector.getOperators()) { + auto& recursiveExtend = op->constCast(); + auto& bindData = recursiveExtend.getBindData(); + if (nodeID == *bindData.nodeInput->constCast().getInternalID()) { + result.push_back(op); + } + } + return result; +} + +static std::vector getRecursiveExtendOutputNodeCandidates( + const Expression& nodeID, LogicalOperator* root) { + std::vector result; + auto collector = LogicalRecursiveExtendCollector(); + collector.collect(root); + for (auto op : collector.getOperators()) { + auto& recursiveExtend = op->constCast(); + auto& bindData = recursiveExtend.getBindData(); + if (nodeID == *bindData.nodeOutput->constCast().getInternalID()) { + result.push_back(op); + } + } + return result; +} + +static std::shared_ptr tryApplySemiMask(std::shared_ptr nodeID, + std::shared_ptr fromRoot, LogicalOperator* toRoot) { + // TODO(Xiyang): Check if a semi mask can/need to be applied to ScanNodeTable, RecursiveJoin & + // GDS at the same time + auto recursiveExtendInputNodeCandidates = + getRecursiveExtendInputNodeCandidates(*nodeID, toRoot); + if (!recursiveExtendInputNodeCandidates.empty()) { + for (auto& op : recursiveExtendInputNodeCandidates) { + op->cast().setInputNodeMask(); + } + auto targetType = SemiMaskTargetType::RECURSIVE_EXTEND_INPUT_NODE; + KU_ASSERT(sanityCheckCandidates(recursiveExtendInputNodeCandidates, targetType)); + return appendSemiMasker(SemiMaskKeyType::NODE, targetType, std::move(nodeID), + recursiveExtendInputNodeCandidates, std::move(fromRoot)); + } + auto recursiveExtendNodeCandidates = getRecursiveExtendOutputNodeCandidates(*nodeID, toRoot); + if (!recursiveExtendNodeCandidates.empty()) { + for (auto& op : recursiveExtendNodeCandidates) { + op->cast().setOutputNodeMask(); + } + auto targetType = SemiMaskTargetType::RECURSIVE_EXTEND_OUTPUT_NODE; + KU_ASSERT(sanityCheckCandidates(recursiveExtendNodeCandidates, targetType)); + return appendSemiMasker(SemiMaskKeyType::NODE, targetType, std::move(nodeID), + recursiveExtendNodeCandidates, std::move(fromRoot)); + } + auto scanNodeCandidates = getScanNodeCandidates(*nodeID, toRoot); + if (!scanNodeCandidates.empty()) { + return appendSemiMasker(SemiMaskKeyType::NODE, SemiMaskTargetType::SCAN_NODE, + std::move(nodeID), scanNodeCandidates, std::move(fromRoot)); + } + return nullptr; +} + +static bool tryProbeToBuildHJSIP(LogicalOperator* op) { + auto& hashJoin = op->cast(); + if (!isProbeSideQualified(op->getChild(0).get())) { + return false; + } + auto probeRoot = hashJoin.getChild(0); + auto buildRoot = hashJoin.getChild(1); + auto hasSemiMaskApplied = false; + for (auto& nodeID : hashJoin.getJoinNodeIDs()) { + auto newProbeRoot = tryApplySemiMask(nodeID, probeRoot, buildRoot.get()); + if (newProbeRoot != nullptr) { + probeRoot = newProbeRoot; + hasSemiMaskApplied = true; + } + } + if (!hasSemiMaskApplied) { + return false; + } + auto& sipInfo = hashJoin.getSIPInfoUnsafe(); + sipInfo.position = SemiMaskPosition::ON_PROBE; + sipInfo.dependency = SIPDependency::PROBE_DEPENDS_ON_BUILD; + sipInfo.direction = SIPDirection::PROBE_TO_BUILD; + hashJoin.setChild(0, appendAccumulate(probeRoot)); + return true; +} + +static bool 
isBuildSideQualified(LogicalOperator* buildRoot) { + if (subPlanContainsFilter(buildRoot)) { + return true; + } + // TODO(Xiyang): this may not be the best solution. Most of the time we will pass a semi mask + // to GDS (recursive join) operator and make it generate small result. Though there are also + // exceptions. In such case we will pay a bit overhead. + auto op = buildRoot; + while (op->getNumChildren() == 1) { + op = op->getChild(0).get(); + } + return op->getOperatorType() == LogicalOperatorType::RECURSIVE_EXTEND; +} + +static bool tryBuildToProbeHJSIP(LogicalOperator* op) { + auto& hashJoin = op->cast(); + if (hashJoin.getJoinType() != JoinType::INNER) { + return false; + } + if (hashJoin.getSIPInfo().direction != SIPDirection::FORCE_BUILD_TO_PROBE && + !isBuildSideQualified(op->getChild(1).get())) { + return false; + } + auto probeRoot = hashJoin.getChild(0); + auto buildRoot = hashJoin.getChild(1); + auto hasSemiMaskApplied = false; + for (auto& nodeID : hashJoin.getJoinNodeIDs()) { + auto newBuildRoot = tryApplySemiMask(nodeID, buildRoot, probeRoot.get()); + if (newBuildRoot != nullptr) { + buildRoot = newBuildRoot; + hasSemiMaskApplied = true; + } + } + if (!hasSemiMaskApplied) { + return false; + } + auto& sipInfo = hashJoin.getSIPInfoUnsafe(); + sipInfo.position = SemiMaskPosition::ON_BUILD; + sipInfo.dependency = SIPDependency::BUILD_DEPENDS_ON_PROBE; + sipInfo.direction = SIPDirection::BUILD_TO_PROBE; + hashJoin.setChild(1, buildRoot); + return true; +} + +void HashJoinSIPOptimizer::visitHashJoin(LogicalOperator* op) { + auto& hashJoin = op->cast(); + if (LogicalOperatorUtils::isAccHashJoin(hashJoin)) { + return; + } + if (hashJoin.getSIPInfo().position == SemiMaskPosition::PROHIBIT) { + return; + } + if (tryBuildToProbeHJSIP(op)) { // Try build to probe SIP first. + return; + } + if (hashJoin.getSIPInfo().position == SemiMaskPosition::PROHIBIT_PROBE_TO_BUILD) { + return; + } + tryProbeToBuildHJSIP(op); +} + +// TODO(Xiyang): we don't apply SIP from build to probe. 
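Before the intersect and path-property-probe cases below, the decision order used for hash joins above can be condensed into the following standalone sketch (illustrative C++ only, not lbug code; every name is invented for the example, and it omits the FORCE_BUILD_TO_PROBE override as well as the check that a semi mask can actually be attached to a scan or recursive extend):

#include <iostream>

enum class SIPChoice { None, BuildToProbe, ProbeToBuild };

// A boiled-down stand-in for the logical plan shapes the pass inspects.
struct JoinShape {
    bool alreadyAccHashJoin = false;   // acc hash join: leave untouched
    bool prohibitAll = false;          // SemiMaskPosition::PROHIBIT
    bool prohibitProbeToBuild = false; // SemiMaskPosition::PROHIBIT_PROBE_TO_BUILD
    bool innerJoin = true;
    bool buildSideSelective = false;   // filter/index scan, or recursive extend, under build
    bool probeSideSelective = false;   // filter/index scan under probe, probe not yet accumulated
};

static SIPChoice chooseSIP(const JoinShape& join) {
    if (join.alreadyAccHashJoin || join.prohibitAll) {
        return SIPChoice::None;
    }
    // Build-to-probe is tried first, and only for inner joins.
    if (join.innerJoin && join.buildSideSelective) {
        return SIPChoice::BuildToProbe;
    }
    if (join.prohibitProbeToBuild || !join.probeSideSelective) {
        return SIPChoice::None;
    }
    // Otherwise the probe side is accumulated and masks flow probe-to-build.
    return SIPChoice::ProbeToBuild;
}

int main() {
    JoinShape join;
    join.probeSideSelective = true; // e.g. a filter sits on the probe side only
    std::cout << (chooseSIP(join) == SIPChoice::ProbeToBuild ? "probe-to-build" : "other") << "\n";
    return 0;
}

In the real pass the two "selective" flags correspond to isBuildSideQualified and isProbeSideQualified shown above.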
+void HashJoinSIPOptimizer::visitIntersect(LogicalOperator* op) { + auto& intersect = op->cast(); + switch (intersect.getSIPInfo().position) { + case SemiMaskPosition::PROHIBIT_PROBE_TO_BUILD: + case SemiMaskPosition::PROHIBIT: + return; + default: + break; + } + if (!isProbeSideQualified(op->getChild(0).get())) { + return; + } + auto probeRoot = intersect.getChild(0); + auto hasSemiMaskApplied = false; + for (auto& nodeID : intersect.getKeyNodeIDs()) { + std::vector ops; + for (auto i = 1u; i < intersect.getNumChildren(); ++i) { + auto buildRoot = intersect.getChild(i); + for (auto& op_ : getScanNodeCandidates(*nodeID, buildRoot.get())) { + ops.push_back(op_); + } + } + if (!ops.empty()) { + probeRoot = appendSemiMasker(SemiMaskKeyType::NODE, SemiMaskTargetType::SCAN_NODE, + nodeID, ops, probeRoot); + hasSemiMaskApplied = true; + } + } + if (!hasSemiMaskApplied) { + return; + } + auto& sipInfo = intersect.getSIPInfoUnsafe(); + sipInfo.position = SemiMaskPosition::ON_PROBE; + sipInfo.dependency = SIPDependency::PROBE_DEPENDS_ON_BUILD; + sipInfo.direction = SIPDirection::PROBE_TO_BUILD; + intersect.setChild(0, appendAccumulate(probeRoot)); +} + +void HashJoinSIPOptimizer::visitPathPropertyProbe(LogicalOperator* op) { + auto& pathPropertyProbe = op->cast(); + switch (pathPropertyProbe.getSIPInfo().position) { + case SemiMaskPosition::PROHIBIT_PROBE_TO_BUILD: + case SemiMaskPosition::PROHIBIT: + return; + default: + break; + } + if (pathPropertyProbe.getJoinType() == RecursiveJoinType::TRACK_NONE) { + return; + } + auto recursiveRel = pathPropertyProbe.getRel(); + auto nodeID = recursiveRel->getRecursiveInfo()->node->getInternalID(); + std::vector opsToApplySemiMask; + if (pathPropertyProbe.getNodeChild() != nullptr) { + auto child = pathPropertyProbe.getNodeChild().get(); + for (auto op_ : getScanNodeCandidates(*nodeID, child)) { + opsToApplySemiMask.push_back(op_); + } + } + if (pathPropertyProbe.getRelChild() != nullptr) { + auto child = pathPropertyProbe.getRelChild().get(); + for (auto op_ : getScanNodeCandidates(*nodeID, child)) { + opsToApplySemiMask.push_back(op_); + } + } + if (opsToApplySemiMask.empty()) { + return; + } + KU_ASSERT( + pathPropertyProbe.getChild(0)->getOperatorType() == LogicalOperatorType::RECURSIVE_EXTEND); + auto semiMasker = appendSemiMasker(SemiMaskKeyType::NODE_ID_LIST, SemiMaskTargetType::SCAN_NODE, + recursiveRel->getRecursiveInfo()->bindData->pathNodeIDsExpr, opsToApplySemiMask, + pathPropertyProbe.getChild(0)); + auto srcNodeID = recursiveRel->getSrcNode()->getInternalID(); + auto dstNodeID = recursiveRel->getDstNode()->getInternalID(); + semiMasker->setExtraKeyInfo(std::make_unique(srcNodeID, dstNodeID)); + pathPropertyProbe.setChild(0, semiMasker); + + auto& sipInfo = pathPropertyProbe.getSIPInfoUnsafe(); + sipInfo.position = SemiMaskPosition::ON_PROBE; + sipInfo.dependency = SIPDependency::PROBE_DEPENDS_ON_BUILD; + sipInfo.direction = SIPDirection::PROBE_TO_BUILD; +} + +} // namespace optimizer +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/agg_key_dependency_optimizer.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/agg_key_dependency_optimizer.cpp new file mode 100644 index 0000000000..91334eaba3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/agg_key_dependency_optimizer.cpp @@ -0,0 +1,91 @@ +#include "optimizer/agg_key_dependency_optimizer.h" + +#include "binder/expression/expression_util.h" +#include "binder/expression/property_expression.h" +#include "planner/operator/logical_aggregate.h" 
+#include "planner/operator/logical_distinct.h" + +using namespace lbug::binder; +using namespace lbug::common; +using namespace lbug::planner; + +namespace lbug { +namespace optimizer { + +void AggKeyDependencyOptimizer::rewrite(planner::LogicalPlan* plan) { + visitOperator(plan->getLastOperator().get()); +} + +void AggKeyDependencyOptimizer::visitOperator(planner::LogicalOperator* op) { + // bottom up traversal + for (auto i = 0u; i < op->getNumChildren(); ++i) { + visitOperator(op->getChild(i).get()); + } + visitOperatorSwitch(op); +} + +void AggKeyDependencyOptimizer::visitAggregate(planner::LogicalOperator* op) { + auto agg = (LogicalAggregate*)op; + auto [keys, dependentKeys] = resolveKeysAndDependentKeys(agg->getKeys()); + agg->setKeys(keys); + agg->setDependentKeys(dependentKeys); +} + +void AggKeyDependencyOptimizer::visitDistinct(planner::LogicalOperator* op) { + auto distinct = (LogicalDistinct*)op; + auto [keys, dependentKeys] = resolveKeysAndDependentKeys(distinct->getKeys()); + distinct->setKeys(keys); + distinct->setPayloads(dependentKeys); +} + +std::pair +AggKeyDependencyOptimizer::resolveKeysAndDependentKeys(const expression_vector& inputKeys) { + // Consider example RETURN a.ID, a.age, COUNT(*). + // We first collect a.ID into primaryKeys. Then collect "a" into primaryVarNames. + // Finally, we loop through all group by keys to put non-primary key properties under name "a" + // into dependentKeyExpressions. + + // Collect primary variables from keys. + std::unordered_set primaryVarNames; + for (auto& key : inputKeys) { + if (key->expressionType == ExpressionType::PROPERTY) { + auto property = (PropertyExpression*)key.get(); + if (property->isPrimaryKey() || property->isInternalID()) { + primaryVarNames.insert(property->getVariableName()); + } + } + } + // Resolve key dependency. + binder::expression_vector keys; + binder::expression_vector dependentKeys; + for (auto& key : inputKeys) { + if (key->expressionType == ExpressionType::PROPERTY) { + auto property = (PropertyExpression*)key.get(); + if (property->isPrimaryKey() || + property->isInternalID()) { // NOLINT(bugprone-branch-clone): Collapsing + // is a logical error. + // Primary properties are always keys. + keys.push_back(key); + } else if (primaryVarNames.contains(property->getVariableName())) { + // Properties depend on any primary property are dependent keys. + // e.g. a.age depends on a._id + dependentKeys.push_back(key); + } else { + keys.push_back(key); + } + } else if (ExpressionUtil::isNodePattern(*key) || ExpressionUtil::isRelPattern(*key)) { + if (primaryVarNames.contains(key->getUniqueName())) { + // e.g. 
a depends on a._id + dependentKeys.push_back(key); + } else { + keys.push_back(key); + } + } else { + keys.push_back(key); + } + } + return std::make_pair(std::move(keys), std::move(dependentKeys)); +} + +} // namespace optimizer +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/cardinality_updater.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/cardinality_updater.cpp new file mode 100644 index 0000000000..a0b0f9f068 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/cardinality_updater.cpp @@ -0,0 +1,137 @@ +#include "optimizer/cardinality_updater.h" + +#include "binder/expression/expression_util.h" +#include "planner/join_order/cardinality_estimator.h" +#include "planner/operator/extend/logical_extend.h" +#include "planner/operator/logical_aggregate.h" +#include "planner/operator/logical_filter.h" +#include "planner/operator/logical_flatten.h" +#include "planner/operator/logical_hash_join.h" +#include "planner/operator/logical_intersect.h" +#include "planner/operator/logical_limit.h" +#include "planner/operator/logical_plan.h" + +namespace lbug::optimizer { +void CardinalityUpdater::rewrite(planner::LogicalPlan* plan) { + visitOperator(plan->getLastOperator().get()); +} + +void CardinalityUpdater::visitOperator(planner::LogicalOperator* op) { + for (auto i = 0u; i < op->getNumChildren(); ++i) { + visitOperator(op->getChild(i).get()); + } + visitOperatorSwitchWithDefault(op); +} + +void CardinalityUpdater::visitOperatorSwitchWithDefault(planner::LogicalOperator* op) { + switch (op->getOperatorType()) { + case planner::LogicalOperatorType::SCAN_NODE_TABLE: { + visitScanNodeTable(op); + break; + } + case planner::LogicalOperatorType::EXTEND: { + visitExtend(op); + break; + } + case planner::LogicalOperatorType::HASH_JOIN: { + visitHashJoin(op); + break; + } + case planner::LogicalOperatorType::CROSS_PRODUCT: { + visitCrossProduct(op); + break; + } + case planner::LogicalOperatorType::INTERSECT: { + visitIntersect(op); + break; + } + case planner::LogicalOperatorType::FLATTEN: { + visitFlatten(op); + break; + } + case planner::LogicalOperatorType::FILTER: { + visitFilter(op); + break; + } + case planner::LogicalOperatorType::LIMIT: { + visitLimit(op); + break; + } + case planner::LogicalOperatorType::AGGREGATE: { + visitAggregate(op); + break; + } + default: { + visitOperatorDefault(op); + break; + } + } +} + +void CardinalityUpdater::visitOperatorDefault(planner::LogicalOperator* op) { + if (op->getNumChildren() == 1) { + op->setCardinality(op->getChild(0)->getCardinality()); + } +} + +void CardinalityUpdater::visitScanNodeTable(planner::LogicalOperator* op) { + op->setCardinality(cardinalityEstimator.estimateScanNode(*op)); +} + +void CardinalityUpdater::visitExtend(planner::LogicalOperator* op) { + KU_ASSERT(transaction); + auto& extend = op->cast(); + const auto extensionRate = cardinalityEstimator.getExtensionRate(*extend.getRel(), + *extend.getBoundNode(), transaction); + extend.setCardinality( + cardinalityEstimator.multiply(extensionRate, op->getChild(0)->getCardinality())); +} + +void CardinalityUpdater::visitHashJoin(planner::LogicalOperator* op) { + auto& hashJoin = op->cast(); + KU_ASSERT(hashJoin.getNumChildren() >= 2); + hashJoin.setCardinality(cardinalityEstimator.estimateHashJoin(hashJoin.getJoinConditions(), + *hashJoin.getChild(0), *hashJoin.getChild(1))); +} + +void CardinalityUpdater::visitCrossProduct(planner::LogicalOperator* op) { + op->setCardinality( + cardinalityEstimator.estimateCrossProduct(*op->getChild(0), 
*op->getChild(1))); +} + +void CardinalityUpdater::visitIntersect(planner::LogicalOperator* op) { + auto& intersect = op->cast(); + KU_ASSERT(intersect.getNumChildren() >= 2); + std::vector buildOps; + for (uint32_t i = 1; i < intersect.getNumChildren(); ++i) { + buildOps.push_back(intersect.getChild(i).get()); + } + intersect.setCardinality(cardinalityEstimator.estimateIntersect(intersect.getKeyNodeIDs(), + *intersect.getChild(0), buildOps)); +} + +void CardinalityUpdater::visitFlatten(planner::LogicalOperator* op) { + auto& flatten = op->cast(); + flatten.setCardinality( + cardinalityEstimator.estimateFlatten(*flatten.getChild(0), flatten.getGroupPos())); +} + +void CardinalityUpdater::visitFilter(planner::LogicalOperator* op) { + auto& filter = op->cast(); + filter.setCardinality( + cardinalityEstimator.estimateFilter(*filter.getChild(0), *filter.getPredicate())); +} + +void CardinalityUpdater::visitLimit(planner::LogicalOperator* op) { + auto& limit = op->cast(); + if (limit.hasLimitNum() && binder::ExpressionUtil::canEvaluateAsLiteral(*limit.getLimitNum())) { + limit.setCardinality(binder::ExpressionUtil::evaluateAsSkipLimit(*limit.getLimitNum())); + } +} + +void CardinalityUpdater::visitAggregate(planner::LogicalOperator* op) { + auto& aggregate = op->cast(); + aggregate.setCardinality(cardinalityEstimator.estimateAggregate(aggregate)); +} + +} // namespace lbug::optimizer diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/correlated_subquery_unnest_solver.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/correlated_subquery_unnest_solver.cpp new file mode 100644 index 0000000000..823102899a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/correlated_subquery_unnest_solver.cpp @@ -0,0 +1,51 @@ +#include "optimizer/correlated_subquery_unnest_solver.h" + +#include "common/exception/internal.h" +#include "planner/operator/logical_hash_join.h" +#include "planner/operator/scan/logical_expressions_scan.h" + +using namespace lbug::planner; + +namespace lbug { +namespace optimizer { + +void CorrelatedSubqueryUnnestSolver::solve(planner::LogicalOperator* root_) { + visitOperator(root_); +} + +void CorrelatedSubqueryUnnestSolver::visitOperator(LogicalOperator* op) { + visitOperatorSwitch(op); + if (LogicalOperatorUtils::isAccHashJoin(*op)) { + solveAccHashJoin(op); + return; + } + for (auto i = 0u; i < op->getNumChildren(); ++i) { + visitOperator(op->getChild(i).get()); + } +} + +void CorrelatedSubqueryUnnestSolver::solveAccHashJoin(LogicalOperator* op) const { + auto& hashJoin = op->cast(); + auto& sipInfo = hashJoin.getSIPInfoUnsafe(); + sipInfo.dependency = SIPDependency::BUILD_DEPENDS_ON_PROBE; + sipInfo.direction = SIPDirection::PROBE_TO_BUILD; + auto acc = op->getChild(0).get(); + auto rightSolver = std::make_unique(acc); + rightSolver->solve(hashJoin.getChild(1).get()); + auto leftSolver = std::make_unique(accumulateOp); + leftSolver->solve(acc->getChild(0).get()); +} + +void CorrelatedSubqueryUnnestSolver::visitExpressionsScan(LogicalOperator* op) { + auto expressionsScan = op->ptrCast(); + // LCOV_EXCL_START + if (accumulateOp == nullptr) { + throw common::InternalException( + "Failed to execute CorrelatedSubqueryUnnestSolver. 
This should not happen."); + } + // LCOV_EXCL_STOP + expressionsScan->setOuterAccumulate(accumulateOp); +} + +} // namespace optimizer +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/factorization_rewriter.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/factorization_rewriter.cpp new file mode 100644 index 0000000000..332fb285b8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/factorization_rewriter.cpp @@ -0,0 +1,191 @@ +#include "optimizer/factorization_rewriter.h" + +#include "binder/expression_visitor.h" +#include "planner/operator/factorization/flatten_resolver.h" +#include "planner/operator/logical_accumulate.h" +#include "planner/operator/logical_aggregate.h" +#include "planner/operator/logical_distinct.h" +#include "planner/operator/logical_filter.h" +#include "planner/operator/logical_flatten.h" +#include "planner/operator/logical_hash_join.h" +#include "planner/operator/logical_intersect.h" +#include "planner/operator/logical_limit.h" +#include "planner/operator/logical_order_by.h" +#include "planner/operator/logical_projection.h" +#include "planner/operator/logical_union.h" +#include "planner/operator/logical_unwind.h" +#include "planner/operator/persistent/logical_copy_to.h" +#include "planner/operator/persistent/logical_delete.h" +#include "planner/operator/persistent/logical_insert.h" +#include "planner/operator/persistent/logical_merge.h" +#include "planner/operator/persistent/logical_set.h" + +using namespace lbug::common; +using namespace lbug::binder; +using namespace lbug::planner; + +namespace lbug { +namespace optimizer { + +void FactorizationRewriter::rewrite(planner::LogicalPlan* plan) { + visitOperator(plan->getLastOperator().get()); +} + +void FactorizationRewriter::visitOperator(planner::LogicalOperator* op) { + // bottom-up traversal + for (auto i = 0u; i < op->getNumChildren(); ++i) { + visitOperator(op->getChild(i).get()); + } + visitOperatorSwitch(op); + op->computeFactorizedSchema(); +} + +void FactorizationRewriter::visitHashJoin(planner::LogicalOperator* op) { + // TODO(Royi) correctly set the cardinality here + auto& hashJoin = op->cast(); + auto groupsPosToFlattenOnProbeSide = hashJoin.getGroupsPosToFlattenOnProbeSide(); + hashJoin.setChild(0, appendFlattens(hashJoin.getChild(0), groupsPosToFlattenOnProbeSide)); + auto groupsPosToFlattenOnBuildSide = hashJoin.getGroupsPosToFlattenOnBuildSide(); + hashJoin.setChild(1, appendFlattens(hashJoin.getChild(1), groupsPosToFlattenOnBuildSide)); +} + +void FactorizationRewriter::visitIntersect(planner::LogicalOperator* op) { + auto& intersect = op->cast(); + auto groupsPosToFlattenOnProbeSide = intersect.getGroupsPosToFlattenOnProbeSide(); + intersect.setChild(0, appendFlattens(intersect.getChild(0), groupsPosToFlattenOnProbeSide)); + for (auto i = 0u; i < intersect.getNumBuilds(); ++i) { + auto groupPosToFlatten = intersect.getGroupsPosToFlattenOnBuildSide(i); + auto childIdx = i + 1; // skip probe + intersect.setChild(childIdx, + appendFlattens(intersect.getChild(childIdx), groupPosToFlatten)); + } +} + +void FactorizationRewriter::visitProjection(planner::LogicalOperator* op) { + auto& projection = op->cast(); + bool hasRandomFunction = false; + for (auto& expr : projection.getExpressionsToProject()) { + if (ExpressionVisitor::isRandom(*expr)) { + hasRandomFunction = true; + } + } + if (hasRandomFunction) { + // Fall back to tuple-at-a-time evaluation. 
+ auto groupsPos = op->getChild(0)->getSchema()->getGroupsPosInScope(); + auto groupsPosToFlatten = + FlattenAll::getGroupsPosToFlatten(groupsPos, *op->getChild(0)->getSchema()); + projection.setChild(0, appendFlattens(projection.getChild(0), groupsPosToFlatten)); + } else { + for (auto& expression : projection.getExpressionsToProject()) { + auto groupsPosToFlatten = + FlattenAllButOne::getGroupsPosToFlatten(expression, *op->getChild(0)->getSchema()); + projection.setChild(0, appendFlattens(projection.getChild(0), groupsPosToFlatten)); + } + } +} + +void FactorizationRewriter::visitAccumulate(planner::LogicalOperator* op) { + auto& accumulate = op->cast(); + auto groupsPosToFlatten = accumulate.getGroupPositionsToFlatten(); + accumulate.setChild(0, appendFlattens(accumulate.getChild(0), groupsPosToFlatten)); +} + +void FactorizationRewriter::visitAggregate(planner::LogicalOperator* op) { + auto& aggregate = op->cast(); + auto groupsPosToFlatten = aggregate.getGroupsPosToFlatten(); + aggregate.setChild(0, appendFlattens(aggregate.getChild(0), groupsPosToFlatten)); +} + +void FactorizationRewriter::visitOrderBy(planner::LogicalOperator* op) { + auto& orderBy = op->cast(); + auto groupsPosToFlatten = orderBy.getGroupsPosToFlatten(); + orderBy.setChild(0, appendFlattens(orderBy.getChild(0), groupsPosToFlatten)); +} + +void FactorizationRewriter::visitLimit(planner::LogicalOperator* op) { + auto& limit = op->cast(); + auto groupsPosToFlatten = limit.getGroupsPosToFlatten(); + limit.setChild(0, appendFlattens(limit.getChild(0), groupsPosToFlatten)); +} + +void FactorizationRewriter::visitDistinct(planner::LogicalOperator* op) { + auto& distinct = op->cast(); + auto groupsPosToFlatten = distinct.getGroupsPosToFlatten(); + distinct.setChild(0, appendFlattens(distinct.getChild(0), groupsPosToFlatten)); +} + +void FactorizationRewriter::visitUnwind(planner::LogicalOperator* op) { + auto& unwind = op->cast(); + auto groupsPosToFlatten = unwind.getGroupsPosToFlatten(); + unwind.setChild(0, appendFlattens(unwind.getChild(0), groupsPosToFlatten)); +} + +void FactorizationRewriter::visitUnion(planner::LogicalOperator* op) { + auto& union_ = op->cast(); + for (auto i = 0u; i < union_.getNumChildren(); ++i) { + auto groupsPosToFlatten = union_.getGroupsPosToFlatten(i); + union_.setChild(i, appendFlattens(union_.getChild(i), groupsPosToFlatten)); + } +} + +void FactorizationRewriter::visitFilter(planner::LogicalOperator* op) { + auto& filter = op->cast(); + auto groupsPosToFlatten = filter.getGroupsPosToFlatten(); + filter.setChild(0, appendFlattens(filter.getChild(0), groupsPosToFlatten)); +} + +void FactorizationRewriter::visitSetProperty(planner::LogicalOperator* op) { + auto& set = op->cast(); + for (auto i = 0u; i < set.getInfos().size(); ++i) { + auto groupsPos = set.getGroupsPosToFlatten(i); + set.setChild(0, appendFlattens(set.getChild(0), groupsPos)); + } +} + +void FactorizationRewriter::visitDelete(planner::LogicalOperator* op) { + auto& delete_ = op->cast(); + auto groupsPosToFlatten = delete_.getGroupsPosToFlatten(); + delete_.setChild(0, appendFlattens(delete_.getChild(0), groupsPosToFlatten)); +} + +void FactorizationRewriter::visitInsert(planner::LogicalOperator* op) { + auto& insert = op->cast(); + auto groupsPosToFlatten = insert.getGroupsPosToFlatten(); + insert.setChild(0, appendFlattens(insert.getChild(0), groupsPosToFlatten)); +} + +void FactorizationRewriter::visitMerge(planner::LogicalOperator* op) { + auto& merge = op->cast(); + auto groupsPosToFlatten = 
merge.getGroupsPosToFlatten(); + merge.setChild(0, appendFlattens(merge.getChild(0), groupsPosToFlatten)); +} + +void FactorizationRewriter::visitCopyTo(planner::LogicalOperator* op) { + auto& copyTo = op->cast(); + auto groupsPosToFlatten = copyTo.getGroupsPosToFlatten(); + copyTo.setChild(0, appendFlattens(copyTo.getChild(0), groupsPosToFlatten)); +} + +std::shared_ptr FactorizationRewriter::appendFlattens( + std::shared_ptr op, + const std::unordered_set& groupsPos) { + auto currentChild = std::move(op); + for (auto groupPos : groupsPos) { + currentChild = appendFlattenIfNecessary(std::move(currentChild), groupPos); + } + return currentChild; +} + +std::shared_ptr FactorizationRewriter::appendFlattenIfNecessary( + std::shared_ptr op, planner::f_group_pos groupPos) { + if (op->getSchema()->getGroup(groupPos)->isFlat()) { + return op; + } + // we set the cardinalities in a separate pass + auto flatten = std::make_shared(groupPos, std::move(op), 0 /* cardinality */); + flatten->computeFactorizedSchema(); + return flatten; +} + +} // namespace optimizer +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/filter_push_down_optimizer.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/filter_push_down_optimizer.cpp new file mode 100644 index 0000000000..05c6036ba9 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/filter_push_down_optimizer.cpp @@ -0,0 +1,332 @@ +#include "optimizer/filter_push_down_optimizer.h" + +#include "binder/expression/literal_expression.h" +#include "binder/expression/property_expression.h" +#include "binder/expression/scalar_function_expression.h" +#include "main/client_context.h" +#include "planner/operator/extend/logical_extend.h" +#include "planner/operator/logical_empty_result.h" +#include "planner/operator/logical_filter.h" +#include "planner/operator/logical_hash_join.h" +#include "planner/operator/logical_table_function_call.h" +#include "planner/operator/scan/logical_scan_node_table.h" + +using namespace lbug::binder; +using namespace lbug::common; +using namespace lbug::planner; +using namespace lbug::storage; + +namespace lbug { +namespace optimizer { + +void FilterPushDownOptimizer::rewrite(LogicalPlan* plan) { + visitOperator(plan->getLastOperator()); +} + +std::shared_ptr FilterPushDownOptimizer::visitOperator( + const std::shared_ptr& op) { + switch (op->getOperatorType()) { + case LogicalOperatorType::FILTER: { + return visitFilterReplace(op); + } + case LogicalOperatorType::CROSS_PRODUCT: { + return visitCrossProductReplace(op); + } + case LogicalOperatorType::EXTEND: { + return visitExtendReplace(op); + } + case LogicalOperatorType::SCAN_NODE_TABLE: { + return visitScanNodeTableReplace(op); + } + case LogicalOperatorType::TABLE_FUNCTION_CALL: { + return visitTableFunctionCallReplace(op); + } + default: { // Stop current push down for unhandled operator. + return visitChildren(op); + } + } +} + +std::shared_ptr FilterPushDownOptimizer::visitChildren( + const std::shared_ptr& op) { + for (auto i = 0u; i < op->getNumChildren(); ++i) { + // Start new push down for child. 
+ auto optimizer = FilterPushDownOptimizer(context); + op->setChild(i, optimizer.visitOperator(op->getChild(i))); + } + op->computeFlatSchema(); + return finishPushDown(op); +} + +std::shared_ptr FilterPushDownOptimizer::visitFilterReplace( + const std::shared_ptr& op) { + auto& filter = op->constCast(); + auto predicate = filter.getPredicate(); + if (predicate->expressionType == ExpressionType::LITERAL) { + // Avoid executing child plan if literal is Null or False. + auto& literalExpr = predicate->constCast(); + if (literalExpr.isNull() || !literalExpr.getValue().getValue()) { + return std::make_shared(*op->getSchema()); + } + // Ignore if literal is True. + } else { + predicateSet.addPredicate(predicate); + } + return visitOperator(filter.getChild(0)); +} + +std::shared_ptr FilterPushDownOptimizer::visitCrossProductReplace( + const std::shared_ptr& op) { + auto remainingPSet = PredicateSet(); + auto probePSet = PredicateSet(); + auto buildPSet = PredicateSet(); + for (auto& p : predicateSet.getAllPredicates()) { + auto inProbe = op->getChild(0)->getSchema()->evaluable(*p); + auto inBuild = op->getChild(1)->getSchema()->evaluable(*p); + if (inProbe && !inBuild) { + probePSet.addPredicate(p); + } else if (!inProbe && inBuild) { + buildPSet.addPredicate(p); + } else { + remainingPSet.addPredicate(p); + } + } + KU_ASSERT(op->getNumChildren() == 2); + // Push probe side + auto probeOptimizer = FilterPushDownOptimizer(context, std::move(probePSet)); + op->setChild(0, probeOptimizer.visitOperator(op->getChild(0))); + // Push build side + auto buildOptimizer = FilterPushDownOptimizer(context, std::move(buildPSet)); + op->setChild(1, buildOptimizer.visitOperator(op->getChild(1))); + + auto probeSchema = op->getChild(0)->getSchema(); + auto buildSchema = op->getChild(1)->getSchema(); + expression_vector predicates; + std::vector joinConditions; + for (auto& predicate : remainingPSet.equalityPredicates) { + auto left = predicate->getChild(0); + auto right = predicate->getChild(1); + // TODO(Xiyang): this can only rewrite left = right, we should also be able to do + // expr(left), expr(right) + if (probeSchema->isExpressionInScope(*left) && buildSchema->isExpressionInScope(*right)) { + joinConditions.emplace_back(left, right); + } else if (probeSchema->isExpressionInScope(*right) && + buildSchema->isExpressionInScope(*left)) { + joinConditions.emplace_back(right, left); + } else { + // Collect predicates that cannot be rewritten as join conditions. + predicates.push_back(predicate); + } + } + if (joinConditions.empty()) { // Nothing to push down. Terminate. + return finishPushDown(op); + } + auto hashJoin = std::make_shared(joinConditions, JoinType::INNER, + nullptr /* mark */, op->getChild(0), op->getChild(1), 0 /* cardinality */); + // For non-id based joins, we disable side way information passing. + hashJoin->getSIPInfoUnsafe().position = SemiMaskPosition::PROHIBIT; + hashJoin->computeFlatSchema(); + // Apply remaining predicates. 
+ predicates.insert(predicates.end(), remainingPSet.nonEqualityPredicates.begin(), + remainingPSet.nonEqualityPredicates.end()); + if (predicates.empty()) { + return hashJoin; + } + return appendFilters(predicates, hashJoin); +} + +static ColumnPredicateSet getPredicateSet(const Expression& column, + const binder::expression_vector& predicates) { + auto predicateSet = ColumnPredicateSet(); + for (auto& predicate : predicates) { + auto columnPredicate = ColumnPredicateUtil::tryConvert(column, *predicate); + if (columnPredicate == nullptr) { + continue; + } + predicateSet.addPredicate(std::move(columnPredicate)); + } + return predicateSet; +} + +static std::vector getColumnPredicateSets(const expression_vector& columns, + const expression_vector& predicates) { + std::vector predicateSets; + for (auto& column : columns) { + predicateSets.push_back(getPredicateSet(*column, predicates)); + } + return predicateSets; +} + +static bool isConstantExpression(const std::shared_ptr expression) { + switch (expression->expressionType) { + case ExpressionType::LITERAL: + case ExpressionType::PARAMETER: { + return true; + } + // TODO(Xiyang): fold parameter expression in binder. + case ExpressionType::FUNCTION: { + auto& func = expression->constCast(); + if (func.getFunction().name == "CAST") { + return isConstantExpression(func.getChild(0)); + } else { + return false; + } + } + default: + return false; + } +} + +std::shared_ptr FilterPushDownOptimizer::visitScanNodeTableReplace( + const std::shared_ptr& op) { + auto& scan = op->cast(); + auto nodeID = scan.getNodeID(); + // Apply column predicates. + if (context->getClientConfig()->enableZoneMap) { + scan.setPropertyPredicates( + getColumnPredicateSets(scan.getProperties(), predicateSet.getAllPredicates())); + } + // Apply index scan + auto tableIDs = scan.getTableIDs(); + std::shared_ptr primaryKeyEqualityComparison = nullptr; + if (tableIDs.size() == 1) { + primaryKeyEqualityComparison = predicateSet.popNodePKEqualityComparison(*nodeID); + } + if (primaryKeyEqualityComparison != nullptr) { // Try rewrite index scan + auto rhs = primaryKeyEqualityComparison->getChild(1); + if (isConstantExpression(rhs)) { + auto extraInfo = std::make_unique(rhs); + scan.setScanType(LogicalScanNodeTableType::PRIMARY_KEY_SCAN); + scan.setExtraInfo(std::move(extraInfo)); + scan.computeFlatSchema(); + } else { + // Cannot rewrite and add predicate back. + predicateSet.addPredicate(primaryKeyEqualityComparison); + } + } + return finishPushDown(op); +} + +std::shared_ptr FilterPushDownOptimizer::visitTableFunctionCallReplace( + const std::shared_ptr& op) { + auto& tableFunctionCall = op->cast(); + auto columnPredicates = getColumnPredicateSets(tableFunctionCall.getBindData()->columns, + predicateSet.getAllPredicates()); + tableFunctionCall.setColumnPredicates(std::move(columnPredicates)); + return finishPushDown(op); +} + +std::shared_ptr FilterPushDownOptimizer::visitExtendReplace( + const std::shared_ptr& op) { + if (op->ptrCast()->isRecursive() || + !context->getClientConfig()->enableZoneMap) { + return visitChildren(op); + } + auto& extend = op->cast(); + // Apply column predicates. 
+ auto columnPredicates = + getColumnPredicateSets(extend.getProperties(), predicateSet.getAllPredicates()); + extend.setPropertyPredicates(std::move(columnPredicates)); + return visitChildren(op); +} + +std::shared_ptr FilterPushDownOptimizer::finishPushDown( + std::shared_ptr op) { + if (predicateSet.isEmpty()) { + return op; + } + auto predicates = predicateSet.getAllPredicates(); + auto root = appendFilters(predicates, op); + predicateSet.clear(); + return root; +} + +std::shared_ptr FilterPushDownOptimizer::appendScanNodeTable( + std::shared_ptr nodeID, std::vector nodeTableIDs, + binder::expression_vector properties, std::shared_ptr child) { + if (properties.empty()) { + return child; + } + auto printInfo = std::make_unique(); + auto scanNodeTable = std::make_shared(std::move(nodeID), + std::move(nodeTableIDs), std::move(properties)); + scanNodeTable->computeFlatSchema(); + return scanNodeTable; +} + +std::shared_ptr FilterPushDownOptimizer::appendFilters( + const expression_vector& predicates, std::shared_ptr child) { + if (predicates.empty()) { + return child; + } + auto root = child; + for (auto& p : predicates) { + root = appendFilter(p, root); + } + return root; +} + +std::shared_ptr FilterPushDownOptimizer::appendFilter( + std::shared_ptr predicate, std::shared_ptr child) { + auto printInfo = std::make_unique(); + auto filter = std::make_shared(std::move(predicate), std::move(child)); + filter->computeFlatSchema(); + return filter; +} + +void PredicateSet::addPredicate(std::shared_ptr predicate) { + if (predicate->expressionType == ExpressionType::EQUALS) { + equalityPredicates.push_back(std::move(predicate)); + } else { + nonEqualityPredicates.push_back(std::move(predicate)); + } +} + +static bool isNodePrimaryKey(const Expression& expression, const Expression& nodeID) { + if (expression.expressionType != ExpressionType::PROPERTY) { + // not property + return false; + } + auto& property = expression.constCast(); + if (property.getVariableName() != nodeID.constCast().getVariableName()) { + // not property for node + return false; + } + return property.isPrimaryKey(); +} + +std::shared_ptr PredicateSet::popNodePKEqualityComparison(const Expression& nodeID) { + // We pop when the first primary key equality comparison is found. + auto resultPredicateIdx = INVALID_IDX; + for (auto i = 0u; i < equalityPredicates.size(); ++i) { + auto predicate = equalityPredicates[i]; + if (isNodePrimaryKey(*predicate->getChild(0), nodeID)) { + resultPredicateIdx = i; + break; + } else if (isNodePrimaryKey(*predicate->getChild(1), nodeID)) { + // Normalize primary key to LHS. 
+ auto leftChild = predicate->getChild(0); + auto rightChild = predicate->getChild(1); + predicate->setChild(1, leftChild); + predicate->setChild(0, rightChild); + resultPredicateIdx = i; + break; + } + } + if (resultPredicateIdx != INVALID_IDX) { + auto result = equalityPredicates[resultPredicateIdx]; + equalityPredicates.erase(equalityPredicates.begin() + resultPredicateIdx); + return result; + } + return nullptr; +} + +expression_vector PredicateSet::getAllPredicates() { + expression_vector result; + result.insert(result.end(), equalityPredicates.begin(), equalityPredicates.end()); + result.insert(result.end(), nonEqualityPredicates.begin(), nonEqualityPredicates.end()); + return result; +} + +} // namespace optimizer +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/limit_push_down_optimizer.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/limit_push_down_optimizer.cpp new file mode 100644 index 0000000000..5f929a905d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/limit_push_down_optimizer.cpp @@ -0,0 +1,85 @@ +#include "optimizer/limit_push_down_optimizer.h" + +#include "binder/expression/expression_util.h" +#include "common/exception/runtime.h" +#include "planner/operator/extend/logical_recursive_extend.h" +#include "planner/operator/logical_distinct.h" +#include "planner/operator/logical_hash_join.h" +#include "planner/operator/logical_limit.h" + +using namespace lbug::binder; +using namespace lbug::common; +using namespace lbug::planner; + +namespace lbug { +namespace optimizer { + +void LimitPushDownOptimizer::rewrite(LogicalPlan* plan) { + visitOperator(plan->getLastOperator().get()); +} + +void LimitPushDownOptimizer::visitOperator(planner::LogicalOperator* op) { + switch (op->getOperatorType()) { + case LogicalOperatorType::LIMIT: { + auto& limit = op->constCast(); + if (limit.hasSkipNum() && ExpressionUtil::canEvaluateAsLiteral(*limit.getSkipNum())) { + skipNumber = ExpressionUtil::evaluateAsSkipLimit(*limit.getSkipNum()); + } + if (limit.hasLimitNum() && ExpressionUtil::canEvaluateAsLiteral(*limit.getLimitNum())) { + limitNumber = ExpressionUtil::evaluateAsSkipLimit(*limit.getLimitNum()); + } + visitOperator(limit.getChild(0).get()); + return; + } + case LogicalOperatorType::MULTIPLICITY_REDUCER: + case LogicalOperatorType::EXPLAIN: + case LogicalOperatorType::ACCUMULATE: + case LogicalOperatorType::PROJECTION: { + visitOperator(op->getChild(0).get()); + return; + } + case LogicalOperatorType::DISTINCT: { + if (limitNumber == INVALID_LIMIT && skipNumber == 0) { + return; + } + auto& distinctOp = op->cast(); + distinctOp.setLimitNum(limitNumber); + distinctOp.setSkipNum(skipNumber); + return; + } + case LogicalOperatorType::HASH_JOIN: { + if (limitNumber == INVALID_LIMIT && skipNumber == 0) { + return; + } + if (op->getChild(0)->getOperatorType() == LogicalOperatorType::HASH_JOIN) { + op->ptrCast()->getSIPInfoUnsafe().position = SemiMaskPosition::NONE; + // OP is the hash join reading destination node property. Continue push limit down. + op = op->getChild(0).get(); + } + if (op->getChild(0)->getOperatorType() == LogicalOperatorType::PATH_PROPERTY_PROBE) { + // LCOV_EXCL_START + if (op->getChild(0)->getChild(0)->getOperatorType() != + LogicalOperatorType::RECURSIVE_EXTEND) { + throw RuntimeException("Trying to push limit to a non RECURSIVE_EXTEND operator. 
" + "This should never happen."); + } + // LCOV_EXCL_STOP + auto& extend = op->getChild(0)->getChild(0)->cast(); + extend.setLimitNum(skipNumber + limitNumber); + } + return; + } + case LogicalOperatorType::UNION_ALL: { + for (auto i = 0u; i < op->getNumChildren(); ++i) { + auto optimizer = LimitPushDownOptimizer(); + optimizer.visitOperator(op->getChild(i).get()); + } + return; + } + default: + return; + } +} + +} // namespace optimizer +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/logical_operator_collector.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/logical_operator_collector.cpp new file mode 100644 index 0000000000..a17f6f0e32 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/logical_operator_collector.cpp @@ -0,0 +1,23 @@ +#include "optimizer/logical_operator_collector.h" + +#include "planner/operator/scan/logical_scan_node_table.h" + +namespace lbug { +namespace optimizer { + +void LogicalOperatorCollector::collect(planner::LogicalOperator* op) { + for (auto i = 0u; i < op->getNumChildren(); ++i) { + collect(op->getChild(i).get()); + } + visitOperatorSwitch(op); +} + +void LogicalIndexScanNodeCollector::visitScanNodeTable(planner::LogicalOperator* op) { + auto scan = op->constCast(); + if (scan.getScanType() == planner::LogicalScanNodeTableType::PRIMARY_KEY_SCAN) { + ops.push_back(op); + } +} + +} // namespace optimizer +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/logical_operator_visitor.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/logical_operator_visitor.cpp new file mode 100644 index 0000000000..89d454ce5f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/logical_operator_visitor.cpp @@ -0,0 +1,186 @@ +#include "optimizer/logical_operator_visitor.h" + +using namespace lbug::planner; + +namespace lbug { +namespace optimizer { + +void LogicalOperatorVisitor::visitOperatorSwitch(LogicalOperator* op) { + switch (op->getOperatorType()) { + case LogicalOperatorType::ACCUMULATE: { + visitAccumulate(op); + } break; + case LogicalOperatorType::AGGREGATE: { + visitAggregate(op); + } break; + case LogicalOperatorType::COPY_FROM: { + visitCopyFrom(op); + } break; + case LogicalOperatorType::COPY_TO: { + visitCopyTo(op); + } break; + case LogicalOperatorType::DELETE: { + visitDelete(op); + } break; + case LogicalOperatorType::DISTINCT: { + visitDistinct(op); + } break; + case LogicalOperatorType::EMPTY_RESULT: { + visitEmptyResult(op); + } break; + case LogicalOperatorType::EXPRESSIONS_SCAN: { + visitExpressionsScan(op); + } break; + case LogicalOperatorType::EXTEND: { + visitExtend(op); + } break; + case LogicalOperatorType::FILTER: { + visitFilter(op); + } break; + case LogicalOperatorType::FLATTEN: { + visitFlatten(op); + } break; + case LogicalOperatorType::HASH_JOIN: { + visitHashJoin(op); + } break; + case LogicalOperatorType::INTERSECT: { + visitIntersect(op); + } break; + case LogicalOperatorType::INSERT: { + visitInsert(op); + } break; + case LogicalOperatorType::LIMIT: { + visitLimit(op); + } break; + case LogicalOperatorType::MERGE: { + visitMerge(op); + } break; + case LogicalOperatorType::NODE_LABEL_FILTER: { + visitNodeLabelFilter(op); + } break; + case LogicalOperatorType::ORDER_BY: { + visitOrderBy(op); + } break; + case LogicalOperatorType::PATH_PROPERTY_PROBE: { + visitPathPropertyProbe(op); + } break; + case LogicalOperatorType::PROJECTION: { + visitProjection(op); + } break; + case LogicalOperatorType::RECURSIVE_EXTEND: { + visitRecursiveExtend(op); + } 
break; + case LogicalOperatorType::SCAN_NODE_TABLE: { + visitScanNodeTable(op); + } break; + case LogicalOperatorType::SET_PROPERTY: { + visitSetProperty(op); + } break; + case LogicalOperatorType::TABLE_FUNCTION_CALL: { + visitTableFunctionCall(op); + } break; + case LogicalOperatorType::UNION_ALL: { + visitUnion(op); + } break; + case LogicalOperatorType::UNWIND: { + visitUnwind(op); + } break; + case LogicalOperatorType::CROSS_PRODUCT: { + visitCrossProduct(op); + } + default: + return; + } +} + +std::shared_ptr LogicalOperatorVisitor::visitOperatorReplaceSwitch( + std::shared_ptr op) { + switch (op->getOperatorType()) { + case LogicalOperatorType::ACCUMULATE: { + return visitAccumulateReplace(op); + } + case LogicalOperatorType::AGGREGATE: { + return visitAggregateReplace(op); + } + case LogicalOperatorType::COPY_FROM: { + return visitCopyFromReplace(op); + } + case LogicalOperatorType::COPY_TO: { + return visitCopyToReplace(op); + } + case LogicalOperatorType::DELETE: { + return visitDeleteReplace(op); + } + case LogicalOperatorType::DISTINCT: { + return visitDistinctReplace(op); + } + case LogicalOperatorType::EMPTY_RESULT: { + return visitEmptyResultReplace(op); + } + case LogicalOperatorType::EXPRESSIONS_SCAN: { + return visitExpressionsScanReplace(op); + } + case LogicalOperatorType::EXTEND: { + return visitExtendReplace(op); + } + case LogicalOperatorType::FILTER: { + return visitFilterReplace(op); + } + case LogicalOperatorType::FLATTEN: { + return visitFlattenReplace(op); + } + case LogicalOperatorType::HASH_JOIN: { + return visitHashJoinReplace(op); + } + case LogicalOperatorType::INTERSECT: { + return visitIntersectReplace(op); + } + case LogicalOperatorType::INSERT: { + return visitInsertReplace(op); + } + case LogicalOperatorType::LIMIT: { + return visitLimitReplace(op); + } + case LogicalOperatorType::MERGE: { + return visitMergeReplace(op); + } + case LogicalOperatorType::NODE_LABEL_FILTER: { + return visitNodeLabelFilterReplace(op); + } + case LogicalOperatorType::ORDER_BY: { + return visitOrderByReplace(op); + } + case LogicalOperatorType::PATH_PROPERTY_PROBE: { + return visitPathPropertyProbeReplace(op); + } + case LogicalOperatorType::PROJECTION: { + return visitProjectionReplace(op); + } + case LogicalOperatorType::RECURSIVE_EXTEND: { + return visitRecursiveExtendReplace(op); + } + case LogicalOperatorType::SCAN_NODE_TABLE: { + return visitScanNodeTableReplace(op); + } + case LogicalOperatorType::SET_PROPERTY: { + return visitSetPropertyReplace(op); + } + case LogicalOperatorType::TABLE_FUNCTION_CALL: { + return visitTableFunctionCallReplace(op); + } + case LogicalOperatorType::UNION_ALL: { + return visitUnionReplace(op); + } + case LogicalOperatorType::UNWIND: { + return visitUnwindReplace(op); + } + case LogicalOperatorType::CROSS_PRODUCT: { + return visitCrossProductReplace(op); + } + default: + return op; + } +} + +} // namespace optimizer +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/optimizer.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/optimizer.cpp new file mode 100644 index 0000000000..e7f6f04283 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/optimizer.cpp @@ -0,0 +1,81 @@ +#include "optimizer/optimizer.h" + +#include "main/client_context.h" +#include "optimizer/acc_hash_join_optimizer.h" +#include "optimizer/agg_key_dependency_optimizer.h" +#include "optimizer/cardinality_updater.h" +#include "optimizer/correlated_subquery_unnest_solver.h" +#include "optimizer/factorization_rewriter.h" +#include 
"optimizer/filter_push_down_optimizer.h" +#include "optimizer/limit_push_down_optimizer.h" +#include "optimizer/projection_push_down_optimizer.h" +#include "optimizer/remove_factorization_rewriter.h" +#include "optimizer/remove_unnecessary_join_optimizer.h" +#include "optimizer/schema_populator.h" +#include "optimizer/top_k_optimizer.h" +#include "planner/operator/logical_explain.h" +#include "transaction/transaction.h" + +namespace lbug { +namespace optimizer { + +void Optimizer::optimize(planner::LogicalPlan* plan, main::ClientContext* context, + const planner::CardinalityEstimator& cardinalityEstimator) { + if (context->getClientConfig()->enablePlanOptimizer) { + // Factorization structure should be removed before further optimization can be applied. + auto removeFactorizationRewriter = RemoveFactorizationRewriter(); + removeFactorizationRewriter.rewrite(plan); + + auto correlatedSubqueryUnnestSolver = CorrelatedSubqueryUnnestSolver(nullptr); + correlatedSubqueryUnnestSolver.solve(plan->getLastOperator().get()); + + auto removeUnnecessaryJoinOptimizer = RemoveUnnecessaryJoinOptimizer(); + removeUnnecessaryJoinOptimizer.rewrite(plan); + + auto filterPushDownOptimizer = FilterPushDownOptimizer(context); + filterPushDownOptimizer.rewrite(plan); + + auto projectionPushDownOptimizer = + ProjectionPushDownOptimizer(context->getClientConfig()->recursivePatternSemantic); + projectionPushDownOptimizer.rewrite(plan); + + auto limitPushDownOptimizer = LimitPushDownOptimizer(); + limitPushDownOptimizer.rewrite(plan); + + if (context->getClientConfig()->enableSemiMask) { + // HashJoinSIPOptimizer should be applied after optimizers that manipulate hash join. + auto hashJoinSIPOptimizer = HashJoinSIPOptimizer(); + hashJoinSIPOptimizer.rewrite(plan); + } + + auto topKOptimizer = TopKOptimizer(); + topKOptimizer.rewrite(plan); + + auto factorizationRewriter = FactorizationRewriter(); + factorizationRewriter.rewrite(plan); + + // AggKeyDependencyOptimizer doesn't change factorization structure and thus can be put + // after FactorizationRewriter. 
+ auto aggKeyDependencyOptimizer = AggKeyDependencyOptimizer(); + aggKeyDependencyOptimizer.rewrite(plan); + + // for EXPLAIN LOGICAL we need to update the cardinalities for the optimized plan + // we don't need to do this otherwise as we don't use the cardinalities after planning + if (plan->getLastOperatorRef().getOperatorType() == planner::LogicalOperatorType::EXPLAIN) { + const auto& explain = plan->getLastOperatorRef().cast(); + if (explain.getExplainType() == common::ExplainType::LOGICAL_PLAN) { + auto cardinalityUpdater = CardinalityUpdater(cardinalityEstimator, + transaction::Transaction::Get(*context)); + cardinalityUpdater.rewrite(plan); + } + } + } else { + // we still need to compute the schema for each operator even if we have optimizations + // disabled + auto schemaPopulator = SchemaPopulator{}; + schemaPopulator.rewrite(plan); + } +} + +} // namespace optimizer +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/projection_push_down_optimizer.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/projection_push_down_optimizer.cpp new file mode 100644 index 0000000000..59c642f936 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/projection_push_down_optimizer.cpp @@ -0,0 +1,368 @@ +#include "optimizer/projection_push_down_optimizer.h" + +#include "binder/expression_visitor.h" +#include "function/gds/gds_function_collection.h" +#include "function/gds/rec_joins.h" +#include "planner/operator/extend/logical_extend.h" +#include "planner/operator/extend/logical_recursive_extend.h" +#include "planner/operator/logical_accumulate.h" +#include "planner/operator/logical_filter.h" +#include "planner/operator/logical_hash_join.h" +#include "planner/operator/logical_intersect.h" +#include "planner/operator/logical_node_label_filter.h" +#include "planner/operator/logical_order_by.h" +#include "planner/operator/logical_path_property_probe.h" +#include "planner/operator/logical_projection.h" +#include "planner/operator/logical_table_function_call.h" +#include "planner/operator/logical_unwind.h" +#include "planner/operator/persistent/logical_copy_from.h" +#include "planner/operator/persistent/logical_delete.h" +#include "planner/operator/persistent/logical_insert.h" +#include "planner/operator/persistent/logical_merge.h" +#include "planner/operator/persistent/logical_set.h" + +using namespace lbug::common; +using namespace lbug::planner; +using namespace lbug::binder; +using namespace lbug::function; + +namespace lbug { +namespace optimizer { + +void ProjectionPushDownOptimizer::rewrite(LogicalPlan* plan) { + visitOperator(plan->getLastOperator().get()); +} + +void ProjectionPushDownOptimizer::visitOperator(LogicalOperator* op) { + visitOperatorSwitch(op); + if (op->getOperatorType() == LogicalOperatorType::PROJECTION) { + // We will start a new optimizer once a projection is encountered. 
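+ // (visitProjection spawns a fresh optimizer seeded only with the projected
+ // expressions, so the subtree below the projection is pruned against that
+ // smaller set instead of the expressions collected so far.)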
+ return; + } + // top-down traversal + for (auto i = 0u; i < op->getNumChildren(); ++i) { + visitOperator(op->getChild(i).get()); + } + op->computeFlatSchema(); +} + +void ProjectionPushDownOptimizer::visitPathPropertyProbe(LogicalOperator* op) { + auto& pathPropertyProbe = op->cast(); + auto child = pathPropertyProbe.getChild(0); + KU_ASSERT(child->getOperatorType() == LogicalOperatorType::RECURSIVE_EXTEND); + if (nodeOrRelInUse.contains(pathPropertyProbe.getRel())) { + return; // Path is needed + } + // Path is not needed + pathPropertyProbe.setJoinType(planner::RecursiveJoinType::TRACK_NONE); + auto extend = child->ptrCast(); + auto functionName = extend->getFunction().getFunctionName(); + if (functionName == VarLenJoinsFunction::name) { + extend->getBindDataUnsafe().writePath = false; + } else if (functionName == SingleSPPathsFunction::name) { + extend->setFunction(SingleSPDestinationsFunction::getAlgorithm()); + } else if (functionName == AllSPPathsFunction::name) { + extend->setFunction(AllSPDestinationsFunction::getAlgorithm()); + } else if (functionName == WeightedSPPathsFunction::name) { + extend->setFunction(WeightedSPDestinationsFunction::getAlgorithm()); + } + extend->setResultColumns(extend->getFunction().getResultColumns(extend->getBindData())); +} + +void ProjectionPushDownOptimizer::visitExtend(LogicalOperator* op) { + auto& extend = op->cast(); + const auto boundNodeID = extend.getBoundNode()->getInternalID(); + collectExpressionsInUse(boundNodeID); + const auto nbrNodeID = extend.getNbrNode()->getInternalID(); + extend.setScanNbrID(propertiesInUse.contains(nbrNodeID)); +} + +void ProjectionPushDownOptimizer::visitAccumulate(LogicalOperator* op) { + auto& accumulate = op->constCast(); + if (accumulate.getAccumulateType() != AccumulateType::REGULAR) { + return; + } + auto expressionsBeforePruning = accumulate.getPayloads(); + auto expressionsAfterPruning = pruneExpressions(expressionsBeforePruning); + if (expressionsBeforePruning.size() == expressionsAfterPruning.size()) { + return; + } + preAppendProjection(op, 0, expressionsAfterPruning); +} + +void ProjectionPushDownOptimizer::visitFilter(LogicalOperator* op) { + auto& filter = op->constCast(); + collectExpressionsInUse(filter.getPredicate()); +} + +void ProjectionPushDownOptimizer::visitNodeLabelFilter(LogicalOperator* op) { + auto& filter = op->constCast(); + collectExpressionsInUse(filter.getNodeID()); +} + +void ProjectionPushDownOptimizer::visitHashJoin(LogicalOperator* op) { + auto& hashJoin = op->constCast(); + for (auto& [probeJoinKey, buildJoinKey] : hashJoin.getJoinConditions()) { + collectExpressionsInUse(probeJoinKey); + collectExpressionsInUse(buildJoinKey); + } + if (hashJoin.getJoinType() == JoinType::MARK) { // no need to perform push down for mark join. + return; + } + auto expressionsBeforePruning = hashJoin.getExpressionsToMaterialize(); + auto expressionsAfterPruning = pruneExpressions(expressionsBeforePruning); + if (expressionsBeforePruning.size() == expressionsAfterPruning.size()) { + // TODO(Xiyang): replace this with a separate optimizer. 
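+ // Nothing was pruned, so inserting a projection under the build side
+ // would be a no-op; bail out early.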
+ return; + } + preAppendProjection(op, 1, expressionsAfterPruning); +} + +void ProjectionPushDownOptimizer::visitIntersect(LogicalOperator* op) { + auto& intersect = op->constCast(); + collectExpressionsInUse(intersect.getIntersectNodeID()); + for (auto i = 0u; i < intersect.getNumBuilds(); ++i) { + auto childIdx = i + 1; // skip probe + auto keyNodeID = intersect.getKeyNodeID(i); + collectExpressionsInUse(keyNodeID); + // Note: we have a potential bug under intersect.cpp. The following code ensures build key + // and intersect key always appear as the first and second column. Should be removed once + // the bug is fixed. + expression_vector expressionsBeforePruning; + expression_vector expressionsAfterPruning; + for (auto& expression : + intersect.getChild(childIdx)->getSchema()->getExpressionsInScope()) { + if (expression->getUniqueName() == intersect.getIntersectNodeID()->getUniqueName() || + expression->getUniqueName() == keyNodeID->getUniqueName()) { + continue; + } + expressionsBeforePruning.push_back(expression); + } + expressionsAfterPruning.push_back(keyNodeID); + expressionsAfterPruning.push_back(intersect.getIntersectNodeID()); + for (auto& expression : pruneExpressions(expressionsBeforePruning)) { + expressionsAfterPruning.push_back(expression); + } + if (expressionsBeforePruning.size() == expressionsAfterPruning.size()) { + return; + } + + preAppendProjection(op, childIdx, expressionsAfterPruning); + } +} + +void ProjectionPushDownOptimizer::visitProjection(LogicalOperator* op) { + // Projection operator defines the start of a projection push down until the next projection + // operator is seen. + ProjectionPushDownOptimizer optimizer(this->semantic); + auto& projection = op->constCast(); + for (auto& expression : projection.getExpressionsToProject()) { + optimizer.collectExpressionsInUse(expression); + } + optimizer.visitOperator(op->getChild(0).get()); +} + +void ProjectionPushDownOptimizer::visitOrderBy(LogicalOperator* op) { + auto& orderBy = op->constCast(); + for (auto& expression : orderBy.getExpressionsToOrderBy()) { + collectExpressionsInUse(expression); + } + auto expressionsBeforePruning = orderBy.getChild(0)->getSchema()->getExpressionsInScope(); + auto expressionsAfterPruning = pruneExpressions(expressionsBeforePruning); + if (expressionsBeforePruning.size() == expressionsAfterPruning.size()) { + return; + } + preAppendProjection(op, 0, expressionsAfterPruning); +} + +void ProjectionPushDownOptimizer::visitUnwind(LogicalOperator* op) { + auto& unwind = op->constCast(); + collectExpressionsInUse(unwind.getInExpr()); +} + +void ProjectionPushDownOptimizer::visitInsert(LogicalOperator* op) { + auto& insert = op->constCast(); + for (auto& info : insert.getInfos()) { + visitInsertInfo(info); + } +} + +void ProjectionPushDownOptimizer::visitDelete(LogicalOperator* op) { + auto& delete_ = op->constCast(); + auto& infos = delete_.getInfos(); + KU_ASSERT(!infos.empty()); + switch (infos[0].tableType) { + case TableType::NODE: { + for (auto& info : infos) { + auto& node = info.pattern->constCast(); + collectExpressionsInUse(node.getInternalID()); + for (auto entry : node.getEntries()) { + collectExpressionsInUse(node.getPrimaryKey(entry->getTableID())); + } + } + } break; + case TableType::REL: { + for (auto& info : infos) { + auto& rel = info.pattern->constCast(); + collectExpressionsInUse(rel.getSrcNode()->getInternalID()); + collectExpressionsInUse(rel.getDstNode()->getInternalID()); + KU_ASSERT(rel.getRelType() == QueryRelType::NON_RECURSIVE); + if (!rel.isEmpty()) { + 
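+ // An empty rel pattern binds no rel tables and therefore has no internal
+ // ID column to keep alive for the delete.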
collectExpressionsInUse(rel.getInternalID()); + } + } + } break; + default: + KU_UNREACHABLE; + } +} + +void ProjectionPushDownOptimizer::visitMerge(LogicalOperator* op) { + auto& merge = op->constCast(); + collectExpressionsInUse(merge.getExistenceMark()); + for (auto& info : merge.getInsertNodeInfos()) { + visitInsertInfo(info); + } + for (auto& info : merge.getInsertRelInfos()) { + visitInsertInfo(info); + } + for (auto& info : merge.getOnCreateSetNodeInfos()) { + visitSetInfo(info); + } + for (auto& info : merge.getOnMatchSetNodeInfos()) { + visitSetInfo(info); + } + for (auto& info : merge.getOnCreateSetRelInfos()) { + visitSetInfo(info); + } + for (auto& info : merge.getOnMatchSetRelInfos()) { + visitSetInfo(info); + } +} + +void ProjectionPushDownOptimizer::visitSetProperty(LogicalOperator* op) { + auto& set = op->constCast(); + for (auto& info : set.getInfos()) { + visitSetInfo(info); + } +} + +void ProjectionPushDownOptimizer::visitCopyFrom(LogicalOperator* op) { + auto& copyFrom = op->constCast(); + for (auto& expr : copyFrom.getInfo()->getSourceColumns()) { + collectExpressionsInUse(expr); + } + if (copyFrom.getInfo()->offset) { + collectExpressionsInUse(copyFrom.getInfo()->offset); + } +} + +void ProjectionPushDownOptimizer::visitTableFunctionCall(LogicalOperator* op) { + auto& tableFunctionCall = op->cast(); + std::vector columnSkips; + for (auto& column : tableFunctionCall.getBindData()->columns) { + columnSkips.push_back(!variablesInUse.contains(column)); + } + tableFunctionCall.setColumnSkips(std::move(columnSkips)); +} + +void ProjectionPushDownOptimizer::visitSetInfo(const binder::BoundSetPropertyInfo& info) { + switch (info.tableType) { + case TableType::NODE: { + auto& node = info.pattern->constCast(); + collectExpressionsInUse(node.getInternalID()); + } break; + case TableType::REL: { + auto& rel = info.pattern->constCast(); + collectExpressionsInUse(rel.getSrcNode()->getInternalID()); + collectExpressionsInUse(rel.getDstNode()->getInternalID()); + collectExpressionsInUse(rel.getInternalID()); + } break; + default: + KU_UNREACHABLE; + } + collectExpressionsInUse(info.columnData); +} + +void ProjectionPushDownOptimizer::visitInsertInfo(const LogicalInsertInfo& info) { + if (info.tableType == TableType::REL) { + auto& rel = info.pattern->constCast(); + collectExpressionsInUse(rel.getSrcNode()->getInternalID()); + collectExpressionsInUse(rel.getDstNode()->getInternalID()); + collectExpressionsInUse(rel.getInternalID()); + } + for (auto i = 0u; i < info.columnExprs.size(); ++i) { + if (info.isReturnColumnExprs[i]) { + collectExpressionsInUse(info.columnExprs[i]); + } + collectExpressionsInUse(info.columnDataExprs[i]); + } +} + +// See comments above this class for how to collect expressions in use. 
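+ // Properties, variables and node/rel patterns land in separate sets
+ // (propertiesInUse, variablesInUse, nodeOrRelInUse); pruneExpressions
+ // consults the set matching each expression type and keeps every other
+ // expression type unconditionally, since those are not tracked.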
+void ProjectionPushDownOptimizer::collectExpressionsInUse( + std::shared_ptr expression) { + switch (expression->expressionType) { + case ExpressionType::PROPERTY: { + propertiesInUse.insert(expression); + return; + } + case ExpressionType::VARIABLE: { + variablesInUse.insert(expression); + return; + } + case ExpressionType::PATTERN: { + nodeOrRelInUse.insert(expression); + for (auto& child : ExpressionChildrenCollector::collectChildren(*expression)) { + collectExpressionsInUse(child); + } + return; + } + default: + for (auto& child : ExpressionChildrenCollector::collectChildren(*expression)) { + collectExpressionsInUse(child); + } + } +} + +binder::expression_vector ProjectionPushDownOptimizer::pruneExpressions( + const binder::expression_vector& expressions) { + expression_set expressionsAfterPruning; + for (auto& expression : expressions) { + switch (expression->expressionType) { + case ExpressionType::PROPERTY: { + if (propertiesInUse.contains(expression)) { + expressionsAfterPruning.insert(expression); + } + } break; + case ExpressionType::VARIABLE: { + if (variablesInUse.contains(expression)) { + expressionsAfterPruning.insert(expression); + } + } break; + case ExpressionType::PATTERN: { + if (nodeOrRelInUse.contains(expression)) { + expressionsAfterPruning.insert(expression); + } + } break; + default: // We don't track other expression types so always assume they will be in use. + expressionsAfterPruning.insert(expression); + } + } + return expression_vector{expressionsAfterPruning.begin(), expressionsAfterPruning.end()}; +} + +void ProjectionPushDownOptimizer::preAppendProjection(LogicalOperator* op, idx_t childIdx, + binder::expression_vector expressions) { + if (expressions.empty()) { + // We don't have a way to handle + return; + } + auto projection = + std::make_shared(std::move(expressions), op->getChild(childIdx)); + projection->computeFlatSchema(); + op->setChild(childIdx, std::move(projection)); +} + +} // namespace optimizer +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/remove_factorization_rewriter.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/remove_factorization_rewriter.cpp new file mode 100644 index 0000000000..4b54b07556 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/remove_factorization_rewriter.cpp @@ -0,0 +1,39 @@ +#include "optimizer/remove_factorization_rewriter.h" + +#include "common/exception/internal.h" +#include "optimizer/logical_operator_collector.h" + +using namespace lbug::common; +using namespace lbug::planner; + +namespace lbug { +namespace optimizer { + +void RemoveFactorizationRewriter::rewrite(planner::LogicalPlan* plan) { + auto root = plan->getLastOperator(); + visitOperator(root); + auto collector = LogicalFlattenCollector(); + collector.collect(root.get()); + if (collector.hasOperators()) { + throw InternalException("Remove factorization rewriter failed."); + } +} + +std::shared_ptr RemoveFactorizationRewriter::visitOperator( + const std::shared_ptr& op) { + // bottom-up traversal + for (auto i = 0u; i < op->getNumChildren(); ++i) { + op->setChild(i, visitOperator(op->getChild(i))); + } + auto result = visitOperatorReplaceSwitch(op); + result->computeFlatSchema(); + return result; +} + +std::shared_ptr RemoveFactorizationRewriter::visitFlattenReplace( + std::shared_ptr op) { + return op->getChild(0); +} + +} // namespace optimizer +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/remove_unnecessary_join_optimizer.cpp 
b/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/remove_unnecessary_join_optimizer.cpp new file mode 100644 index 0000000000..c307ff7ef2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/remove_unnecessary_join_optimizer.cpp @@ -0,0 +1,58 @@ +#include "optimizer/remove_unnecessary_join_optimizer.h" + +#include "planner/operator/logical_hash_join.h" +#include "planner/operator/scan/logical_scan_node_table.h" + +using namespace lbug::common; +using namespace lbug::planner; + +namespace lbug { +namespace optimizer { + +void RemoveUnnecessaryJoinOptimizer::rewrite(LogicalPlan* plan) { + visitOperator(plan->getLastOperator()); +} + +std::shared_ptr RemoveUnnecessaryJoinOptimizer::visitOperator( + const std::shared_ptr& op) { + // bottom-up traversal + for (auto i = 0u; i < op->getNumChildren(); ++i) { + op->setChild(i, visitOperator(op->getChild(i))); + } + auto result = visitOperatorReplaceSwitch(op); + result->computeFlatSchema(); + return result; +} + +std::shared_ptr RemoveUnnecessaryJoinOptimizer::visitHashJoinReplace( + std::shared_ptr op) { + auto hashJoin = (LogicalHashJoin*)op.get(); + switch (hashJoin->getJoinType()) { + case JoinType::MARK: + case JoinType::LEFT: { + // Do not prune no-trivial join type + return op; + } + default: + break; + } + // TODO(Xiyang): Double check on these changes here. + if (op->getChild(1)->getOperatorType() == LogicalOperatorType::SCAN_NODE_TABLE) { + const auto scanNode = ku_dynamic_cast(op->getChild(1).get()); + if (scanNode->getProperties().empty()) { + // Build side is trivial. Prune build side. + return op->getChild(0); + } + } + if (op->getChild(0)->getOperatorType() == LogicalOperatorType::SCAN_NODE_TABLE) { + const auto scanNode = ku_dynamic_cast(op->getChild(0).get()); + if (scanNode->getProperties().empty()) { + // Probe side is trivial. Prune probe side. 
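+ // e.g. for MATCH (a)-[e]->(b) RETURN e.* the node-table scan that only
+ // produces an internal ID adds nothing on top of the extend that already
+ // emits that ID, so the join with it can be dropped and the other side
+ // returned as-is.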
+ return op->getChild(1); + } + } + return op; +} + +} // namespace optimizer +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/schema_populator.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/schema_populator.cpp new file mode 100644 index 0000000000..fd7447329a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/schema_populator.cpp @@ -0,0 +1,15 @@ +#include "optimizer/schema_populator.h" + +namespace lbug::optimizer { + +static void populateSchemaRecursive(planner::LogicalOperator* op) { + for (auto i = 0u; i < op->getNumChildren(); ++i) { + populateSchemaRecursive(op->getChild(i).get()); + } + op->computeFactorizedSchema(); +} + +void SchemaPopulator::rewrite(planner::LogicalPlan* plan) { + populateSchemaRecursive(plan->getLastOperator().get()); +} +} // namespace lbug::optimizer diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/top_k_optimizer.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/top_k_optimizer.cpp new file mode 100644 index 0000000000..4cbae4be32 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/optimizer/top_k_optimizer.cpp @@ -0,0 +1,63 @@ +#include "optimizer/top_k_optimizer.h" + +#include "planner/operator/logical_limit.h" +#include "planner/operator/logical_order_by.h" + +using namespace lbug::planner; +using namespace lbug::common; + +namespace lbug { +namespace optimizer { + +void TopKOptimizer::rewrite(planner::LogicalPlan* plan) { + plan->setLastOperator(visitOperator(plan->getLastOperator())); +} + +std::shared_ptr TopKOptimizer::visitOperator( + const std::shared_ptr& op) { + // bottom-up traversal + for (auto i = 0u; i < op->getNumChildren(); ++i) { + op->setChild(i, visitOperator(op->getChild(i))); + } + auto result = visitOperatorReplaceSwitch(op); + result->computeFlatSchema(); + return result; +} + +// TODO(Xiyang): we should probably remove the projection between ORDER BY and MULTIPLICITY REDUCER +// We search for pattern +// ORDER BY -> PROJECTION -> MULTIPLICITY REDUCER -> LIMIT +// ORDER BY -> MULTIPLICITY REDUCER -> LIMIT +// and rewrite as TOP_K +std::shared_ptr TopKOptimizer::visitLimitReplace( + std::shared_ptr op) { + auto limit = op->ptrCast(); + if (!limit->hasLimitNum()) { + return op; // only skip no limit. 
No need to rewrite + } + auto multiplicityReducer = limit->getChild(0); + KU_ASSERT(multiplicityReducer->getOperatorType() == LogicalOperatorType::MULTIPLICITY_REDUCER); + auto projectionOrOrderBy = multiplicityReducer->getChild(0); + std::shared_ptr orderBy; + if (projectionOrOrderBy->getOperatorType() == LogicalOperatorType::PROJECTION) { + if (projectionOrOrderBy->getChild(0)->getOperatorType() != LogicalOperatorType::ORDER_BY) { + return op; + } + orderBy = std::static_pointer_cast(projectionOrOrderBy->getChild(0)); + } else if (projectionOrOrderBy->getOperatorType() == LogicalOperatorType::ORDER_BY) { + orderBy = std::static_pointer_cast(projectionOrOrderBy); + } else { + return op; + } + KU_ASSERT(orderBy != nullptr); + if (limit->hasLimitNum()) { + orderBy->setLimitNum(limit->getLimitNum()); + } + if (limit->hasSkipNum()) { + orderBy->setSkipNum(limit->getSkipNum()); + } + return projectionOrOrderBy; +} + +} // namespace optimizer +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/parser/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/CMakeLists.txt new file mode 100644 index 0000000000..46bd1217f9 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/CMakeLists.txt @@ -0,0 +1,15 @@ +add_subdirectory(antlr_parser) +add_subdirectory(expression) +add_subdirectory(transform) +add_subdirectory(visitor) + +add_library(lbug_parser + OBJECT + create_macro.cpp + parser.cpp + parsed_statement_visitor.cpp + transformer.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/parser/antlr_parser/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/antlr_parser/CMakeLists.txt new file mode 100644 index 0000000000..4c1121ff05 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/antlr_parser/CMakeLists.txt @@ -0,0 +1,9 @@ +add_library(lbug_parser_antlr_parser + OBJECT + lbug_cypher_parser.cpp + parser_error_listener.cpp + parser_error_strategy.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/parser/antlr_parser/lbug_cypher_parser.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/antlr_parser/lbug_cypher_parser.cpp new file mode 100644 index 0000000000..08b3da6427 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/antlr_parser/lbug_cypher_parser.cpp @@ -0,0 +1,43 @@ +#include "parser/antlr_parser/lbug_cypher_parser.h" + +#include + +namespace lbug { +namespace parser { + +void LbugCypherParser::notifyQueryNotConcludeWithReturn(antlr4::Token* startToken) { + auto errorMsg = "Query must conclude with RETURN clause"; + notifyErrorListeners(startToken, errorMsg, nullptr); +} + +void LbugCypherParser::notifyNodePatternWithoutParentheses(std::string nodeName, + antlr4::Token* startToken) { + auto errorMsg = + "Parentheses are required to identify nodes in patterns, i.e. (" + nodeName + ")"; + notifyErrorListeners(startToken, errorMsg, nullptr); +} + +void LbugCypherParser::notifyInvalidNotEqualOperator(antlr4::Token* startToken) { + auto errorMsg = "Unknown operation '!=' (you probably meant to use '<>', which is the operator " + "for inequality testing.)"; + notifyErrorListeners(startToken, errorMsg, nullptr); +} + +void LbugCypherParser::notifyEmptyToken(antlr4::Token* startToken) { + auto errorMsg = + "'' is not a valid token name. 
Token names cannot be empty or contain any null-bytes"; + notifyErrorListeners(startToken, errorMsg, nullptr); +} + +void LbugCypherParser::notifyReturnNotAtEnd(antlr4::Token* startToken) { + auto errorMsg = "RETURN can only be used at the end of the query"; + notifyErrorListeners(startToken, errorMsg, nullptr); +} + +void LbugCypherParser::notifyNonBinaryComparison(antlr4::Token* startToken) { + auto errorMsg = "Non-binary comparison (e.g. a=b=c) is not supported"; + notifyErrorListeners(startToken, errorMsg, nullptr); +} + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/parser/antlr_parser/parser_error_listener.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/antlr_parser/parser_error_listener.cpp new file mode 100644 index 0000000000..ef265ed665 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/antlr_parser/parser_error_listener.cpp @@ -0,0 +1,36 @@ +#include "parser/antlr_parser/parser_error_listener.h" + +#include "common/exception/parser.h" +#include "common/string_utils.h" + +using namespace antlr4; +using namespace lbug::common; + +namespace lbug { +namespace parser { + +void ParserErrorListener::syntaxError(Recognizer* recognizer, Token* offendingSymbol, size_t line, + size_t charPositionInLine, const std::string& msg, std::exception_ptr /*e*/) { + auto finalError = msg + " (line: " + std::to_string(line) + + ", offset: " + std::to_string(charPositionInLine) + ")\n" + + formatUnderLineError(*recognizer, *offendingSymbol, line, charPositionInLine); + throw ParserException(finalError); +} + +std::string ParserErrorListener::formatUnderLineError(Recognizer& recognizer, + const Token& offendingToken, size_t line, size_t charPositionInLine) { + auto tokens = (CommonTokenStream*)recognizer.getInputStream(); + auto input = tokens->getTokenSource()->getInputStream()->toString(); + auto errorLine = StringUtils::split(input, "\n", false)[line - 1]; + auto underLine = std::string(" "); + for (auto i = 0u; i < charPositionInLine; ++i) { + underLine += " "; + } + for (auto i = offendingToken.getStartIndex(); i <= offendingToken.getStopIndex(); ++i) { + underLine += "^"; + } + return "\"" + errorLine + "\"\n" + underLine; +} + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/parser/antlr_parser/parser_error_strategy.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/antlr_parser/parser_error_strategy.cpp new file mode 100644 index 0000000000..e88c900777 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/antlr_parser/parser_error_strategy.cpp @@ -0,0 +1,21 @@ +#include "parser/antlr_parser/parser_error_strategy.h" + +namespace lbug { +namespace parser { + +void ParserErrorStrategy::reportNoViableAlternative(antlr4::Parser* recognizer, + const antlr4::NoViableAltException& e) { + auto tokens = recognizer->getTokenStream(); + auto errorMsg = + tokens ? + antlr4::Token::EOF == e.getStartToken()->getType() ? 
+ "Unexpected end of input" : + "Invalid input <" + tokens->getText(e.getStartToken(), e.getOffendingToken()) + ">" : + "Unknown input"; + auto expectedRuleName = recognizer->getRuleNames()[recognizer->getContext()->getRuleIndex()]; + errorMsg += ": expected rule " + expectedRuleName; + recognizer->notifyErrorListeners(e.getOffendingToken(), errorMsg, make_exception_ptr(e)); +} + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/parser/create_macro.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/create_macro.cpp new file mode 100644 index 0000000000..2035738a3e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/create_macro.cpp @@ -0,0 +1,15 @@ +#include "parser/create_macro.h" + +namespace lbug { +namespace parser { + +std::vector> CreateMacro::getDefaultArgs() const { + std::vector> defaultArgsToReturn; + for (auto& defaultArg : defaultArgs) { + defaultArgsToReturn.emplace_back(defaultArg.first, defaultArg.second.get()); + } + return defaultArgsToReturn; +} + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/parser/expression/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/expression/CMakeLists.txt new file mode 100644 index 0000000000..c17a8ea8aa --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/expression/CMakeLists.txt @@ -0,0 +1,12 @@ +add_library(lbug_parser_expression + OBJECT + parsed_case_expression.cpp + parsed_expression.cpp + parsed_expression_visitor.cpp + parsed_function_expression.cpp + parsed_property_expression.cpp + parsed_variable_expression.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/parser/expression/parsed_case_expression.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/expression/parsed_case_expression.cpp new file mode 100644 index 0000000000..e3ea092ca7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/expression/parsed_case_expression.cpp @@ -0,0 +1,52 @@ +#include "parser/expression/parsed_case_expression.h" + +#include "common/serializer/deserializer.h" +#include "common/serializer/serializer.h" + +using namespace lbug::common; + +namespace lbug { +namespace parser { + +void ParsedCaseAlternative::serialize(Serializer& serializer) const { + whenExpression->serialize(serializer); + thenExpression->serialize(serializer); +} + +ParsedCaseAlternative ParsedCaseAlternative::deserialize(Deserializer& deserializer) { + auto whenExpression = ParsedExpression::deserialize(deserializer); + auto thenExpression = ParsedExpression::deserialize(deserializer); + return ParsedCaseAlternative(std::move(whenExpression), std::move(thenExpression)); +} + +std::unique_ptr ParsedCaseExpression::deserialize( + Deserializer& deserializer) { + std::unique_ptr caseExpression; + deserializer.deserializeOptionalValue(caseExpression); + std::vector caseAlternatives; + deserializer.deserializeVector(caseAlternatives); + std::unique_ptr elseExpression; + deserializer.deserializeOptionalValue(elseExpression); + return std::make_unique(std::move(caseExpression), + std::move(caseAlternatives), std::move(elseExpression)); +} + +std::unique_ptr ParsedCaseExpression::copy() const { + std::vector caseAlternativesCopy; + caseAlternativesCopy.reserve(caseAlternatives.size()); + for (auto& caseAlternative : caseAlternatives) { + caseAlternativesCopy.push_back(caseAlternative); + } + return std::make_unique(alias, rawName, copyVector(children), + caseExpression ? 
caseExpression->copy() : nullptr, std::move(caseAlternativesCopy), + elseExpression ? elseExpression->copy() : nullptr); +} + +void ParsedCaseExpression::serializeInternal(Serializer& serializer) const { + serializer.serializeOptionalValue(caseExpression); + serializer.serializeVector(caseAlternatives); + serializer.serializeOptionalValue(elseExpression); +} + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/parser/expression/parsed_expression.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/expression/parsed_expression.cpp new file mode 100644 index 0000000000..5fbdd7fe77 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/expression/parsed_expression.cpp @@ -0,0 +1,82 @@ +#include "parser/expression/parsed_expression.h" + +#include "common/serializer/deserializer.h" +#include "common/serializer/serializer.h" +#include "parser/expression/parsed_case_expression.h" +#include "parser/expression/parsed_function_expression.h" +#include "parser/expression/parsed_literal_expression.h" +#include "parser/expression/parsed_parameter_expression.h" +#include "parser/expression/parsed_property_expression.h" +#include "parser/expression/parsed_subquery_expression.h" +#include "parser/expression/parsed_variable_expression.h" + +using namespace lbug::common; + +namespace lbug { +namespace parser { + +ParsedExpression::ParsedExpression(ExpressionType type, std::unique_ptr child, + std::string rawName) + : type{type}, rawName{std::move(rawName)} { + children.push_back(std::move(child)); +} + +ParsedExpression::ParsedExpression(ExpressionType type, std::unique_ptr left, + std::unique_ptr right, std::string rawName) + : type{type}, rawName{std::move(rawName)} { + children.push_back(std::move(left)); + children.push_back(std::move(right)); +} + +void ParsedExpression::serialize(Serializer& serializer) const { + serializer.serializeValue(type); + serializer.serializeValue(alias); + serializer.serializeValue(rawName); + serializer.serializeVectorOfPtrs(children); + serializeInternal(serializer); +} + +std::unique_ptr ParsedExpression::deserialize(Deserializer& deserializer) { + auto type = ExpressionType::INVALID; + std::string alias; + std::string rawName; + parsed_expr_vector children; + deserializer.deserializeValue(type); + deserializer.deserializeValue(alias); + deserializer.deserializeValue(rawName); + deserializer.deserializeVectorOfPtrs(children); + std::unique_ptr parsedExpression; + switch (type) { + case ExpressionType::CASE_ELSE: { + parsedExpression = ParsedCaseExpression::deserialize(deserializer); + } break; + case ExpressionType::FUNCTION: { + parsedExpression = ParsedFunctionExpression::deserialize(deserializer); + } break; + case ExpressionType::LITERAL: { + parsedExpression = ParsedLiteralExpression::deserialize(deserializer); + } break; + case ExpressionType::PARAMETER: { + parsedExpression = ParsedParameterExpression::deserialize(deserializer); + } break; + case ExpressionType::PROPERTY: { + parsedExpression = ParsedPropertyExpression::deserialize(deserializer); + } break; + case ExpressionType::SUBQUERY: { + parsedExpression = ParsedSubqueryExpression::deserialize(deserializer); + } break; + case ExpressionType::VARIABLE: { + parsedExpression = ParsedVariableExpression::deserialize(deserializer); + } break; + default: { + KU_UNREACHABLE; + } + } + parsedExpression->alias = std::move(alias); + parsedExpression->rawName = std::move(rawName); + parsedExpression->children = std::move(children); + return parsedExpression; +} + +} // 
namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/parser/expression/parsed_expression_visitor.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/expression/parsed_expression_visitor.cpp new file mode 100644 index 0000000000..81fa4cefc2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/expression/parsed_expression_visitor.cpp @@ -0,0 +1,229 @@ +#include "parser/expression/parsed_expression_visitor.h" + +#include "catalog/catalog.h" +#include "catalog/catalog_entry/function_catalog_entry.h" +#include "common/exception/not_implemented.h" +#include "parser/expression/parsed_case_expression.h" +#include "parser/expression/parsed_function_expression.h" +#include "parser/expression/parsed_lambda_expression.h" +#include "transaction/transaction.h" + +using namespace lbug::common; +using namespace lbug::catalog; + +namespace lbug { +namespace parser { + +void ParsedExpressionVisitor::visit(const ParsedExpression* expr) { + visitChildren(*expr); + visitSwitch(expr); +} + +void ParsedExpressionVisitor::visitUnsafe(ParsedExpression* expr) { + visitChildrenUnsafe(*expr); + visitSwitchUnsafe(expr); +} + +void ParsedExpressionVisitor::visitSwitch(const ParsedExpression* expr) { + switch (expr->getExpressionType()) { + case ExpressionType::OR: + case ExpressionType::XOR: + case ExpressionType::AND: + case ExpressionType::NOT: + case ExpressionType::EQUALS: + case ExpressionType::NOT_EQUALS: + case ExpressionType::GREATER_THAN: + case ExpressionType::GREATER_THAN_EQUALS: + case ExpressionType::LESS_THAN: + case ExpressionType::LESS_THAN_EQUALS: + case ExpressionType::IS_NULL: + case ExpressionType::IS_NOT_NULL: + case ExpressionType::FUNCTION: { + visitFunctionExpr(expr); + } break; + case ExpressionType::AGGREGATE_FUNCTION: { + visitAggFunctionExpr(expr); + } break; + case ExpressionType::PROPERTY: { + visitPropertyExpr(expr); + } break; + case ExpressionType::LITERAL: { + visitLiteralExpr(expr); + } break; + case ExpressionType::VARIABLE: { + visitVariableExpr(expr); + } break; + case ExpressionType::PATH: { + visitPathExpr(expr); + } break; + case ExpressionType::PATTERN: { + visitNodeRelExpr(expr); + } break; + case ExpressionType::PARAMETER: { + visitParamExpr(expr); + } break; + case ExpressionType::SUBQUERY: { + visitSubqueryExpr(expr); + } break; + case ExpressionType::CASE_ELSE: { + visitCaseExpr(expr); + } break; + case ExpressionType::GRAPH: { + visitGraphExpr(expr); + } break; + case ExpressionType::LAMBDA: { + visitLambdaExpr(expr); + } break; + case ExpressionType::STAR: { + visitStar(expr); + } break; + // LCOV_EXCL_START + default: + throw NotImplementedException("ExpressionVisitor::visitSwitch"); + // LCOV_EXCL_STOP + } +} + +void ParsedExpressionVisitor::visitChildren(const ParsedExpression& expr) { + switch (expr.getExpressionType()) { + case ExpressionType::CASE_ELSE: { + visitCaseChildren(expr); + } break; + case ExpressionType::LAMBDA: { + auto& lambda = expr.constCast(); + visit(lambda.getFunctionExpr()); + } break; + default: { + for (auto i = 0u; i < expr.getNumChildren(); ++i) { + visit(expr.getChild(i)); + } + } + } +} + +void ParsedExpressionVisitor::visitChildrenUnsafe(ParsedExpression& expr) { + switch (expr.getExpressionType()) { + case ExpressionType::CASE_ELSE: { + visitCaseChildrenUnsafe(expr); + } break; + default: { + for (auto i = 0u; i < expr.getNumChildren(); ++i) { + visitUnsafe(expr.getChild(i)); + } + } + } +} + +void ParsedExpressionVisitor::visitCaseChildren(const ParsedExpression& expr) { + auto& caseExpr = 
expr.constCast(); + if (caseExpr.hasCaseExpression()) { + visit(caseExpr.getCaseExpression()); + } + for (auto i = 0u; i < caseExpr.getNumCaseAlternative(); ++i) { + auto alternative = caseExpr.getCaseAlternative(i); + visit(alternative->whenExpression.get()); + visit(alternative->thenExpression.get()); + } + if (caseExpr.hasElseExpression()) { + visit(caseExpr.getElseExpression()); + } +} + +void ParsedExpressionVisitor::visitCaseChildrenUnsafe(ParsedExpression& expr) { + auto& caseExpr = expr.cast(); + if (caseExpr.hasCaseExpression()) { + visitUnsafe(caseExpr.getCaseExpression()); + } + for (auto i = 0u; i < caseExpr.getNumCaseAlternative(); ++i) { + auto alternative = caseExpr.getCaseAlternative(i); + visitUnsafe(alternative->whenExpression.get()); + visitUnsafe(alternative->thenExpression.get()); + } + if (caseExpr.hasElseExpression()) { + visitUnsafe(caseExpr.getElseExpression()); + } +} + +void ReadWriteExprAnalyzer::visitFunctionExpr(const ParsedExpression* expr) { + if (expr->getExpressionType() != ExpressionType::FUNCTION) { + // Can be AND/OR/... which guarantees to be readonly. + return; + } + auto funcName = expr->constCast().getFunctionName(); + auto catalog = Catalog::Get(*context); + // Assume user cannot add function with sideeffect, i.e. all non-readonly function is + // registered when database starts. + auto transaction = &transaction::DUMMY_TRANSACTION; + if (!catalog->containsFunction(transaction, funcName)) { + return; + } + auto entry = catalog->getFunctionEntry(transaction, funcName); + if (entry->getType() != CatalogEntryType::SCALAR_FUNCTION_ENTRY) { + // Can be macro function which guarantees to be readonly. + return; + } + auto& funcSet = entry->constPtrCast()->getFunctionSet(); + KU_ASSERT(!funcSet.empty()); + if (!funcSet[0]->isReadOnly) { + readOnly = false; + } +} + +std::unique_ptr MacroParameterReplacer::replace( + std::unique_ptr input) { + if (nameToExpr.contains(input->getRawName())) { + return nameToExpr.at(input->getRawName())->copy(); + } + visitUnsafe(input.get()); + return input; +} + +void MacroParameterReplacer::visitSwitchUnsafe(ParsedExpression* expr) { + switch (expr->getExpressionType()) { + case ExpressionType::CASE_ELSE: { + auto& caseExpr = expr->cast(); + if (caseExpr.hasCaseExpression()) { + auto replace = getReplace(caseExpr.getCaseExpression()->getRawName()); + if (replace) { + caseExpr.setCaseExpression(std::move(replace)); + } + } + for (auto i = 0u; i < caseExpr.getNumCaseAlternative(); i++) { + auto caseAlternative = caseExpr.getCaseAlternativeUnsafe(i); + auto whenReplace = getReplace(caseAlternative->whenExpression->getRawName()); + auto thenReplace = getReplace(caseAlternative->thenExpression->getRawName()); + if (whenReplace) { + caseAlternative->whenExpression = std::move(whenReplace); + } + if (thenReplace) { + caseAlternative->thenExpression = std::move(thenReplace); + } + } + if (caseExpr.hasElseExpression()) { + auto replace = getReplace(caseExpr.getElseExpression()->getRawName()); + if (replace) { + caseExpr.setElseExpression(std::move(replace)); + } + } + } break; + default: { + for (auto i = 0u; i < expr->getNumChildren(); ++i) { + auto child = expr->getChild(i); + auto replace = getReplace(child->getRawName()); + if (replace) { + expr->setChild(i, std::move(replace)); + } + } + } + } +} + +std::unique_ptr MacroParameterReplacer::getReplace(const std::string& name) { + if (nameToExpr.contains(name)) { + return nameToExpr.at(name)->copy(); + } + return nullptr; +} + +} // namespace parser +} // namespace lbug diff --git 
a/graph-wasm/lbug-0.12.2/lbug-src/src/parser/expression/parsed_function_expression.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/expression/parsed_function_expression.cpp new file mode 100644 index 0000000000..e61063e38c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/expression/parsed_function_expression.cpp @@ -0,0 +1,31 @@ +#include "parser/expression/parsed_function_expression.h" + +#include "common/serializer/deserializer.h" +#include "common/serializer/serializer.h" + +using namespace lbug::common; + +namespace lbug { +namespace parser { + +std::unique_ptr ParsedFunctionExpression::deserialize( + Deserializer& deserializer) { + bool isDistinct = false; + deserializer.deserializeValue(isDistinct); + std::string functionName; + deserializer.deserializeValue(functionName); + std::vector optionalArguments; + deserializer.deserializeVector(optionalArguments); + auto result = std::make_unique(std::move(functionName), isDistinct); + result->setOptionalArguments(std::move(optionalArguments)); + return result; +} + +void ParsedFunctionExpression::serializeInternal(Serializer& serializer) const { + serializer.serializeValue(isDistinct); + serializer.serializeValue(functionName); + serializer.serializeVector(optionalArguments); +} + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/parser/expression/parsed_property_expression.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/expression/parsed_property_expression.cpp new file mode 100644 index 0000000000..fa4d613110 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/expression/parsed_property_expression.cpp @@ -0,0 +1,18 @@ +#include "parser/expression/parsed_property_expression.h" + +#include "common/serializer/deserializer.h" + +using namespace lbug::common; + +namespace lbug { +namespace parser { + +std::unique_ptr ParsedPropertyExpression::deserialize( + Deserializer& deserializer) { + std::string propertyName; + deserializer.deserializeValue(propertyName); + return std::make_unique(std::move(propertyName)); +} + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/parser/expression/parsed_variable_expression.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/expression/parsed_variable_expression.cpp new file mode 100644 index 0000000000..07bdcb9eeb --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/expression/parsed_variable_expression.cpp @@ -0,0 +1,18 @@ +#include "parser/expression/parsed_variable_expression.h" + +#include "common/serializer/deserializer.h" + +using namespace lbug::common; + +namespace lbug { +namespace parser { + +std::unique_ptr ParsedVariableExpression::deserialize( + Deserializer& deserializer) { + std::string variableName; + deserializer.deserializeValue(variableName); + return std::make_unique(std::move(variableName)); +} + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/parser/parsed_statement_visitor.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/parsed_statement_visitor.cpp new file mode 100644 index 0000000000..4bd54f65a3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/parsed_statement_visitor.cpp @@ -0,0 +1,155 @@ +#include "parser/parsed_statement_visitor.h" + +#include "common/cast.h" +#include "parser/explain_statement.h" +#include "parser/query/regular_query.h" + +using namespace lbug::common; + +namespace lbug { +namespace parser { + +void StatementVisitor::visit(const Statement& statement) { + switch 
(statement.getStatementType()) { + case StatementType::QUERY: { + visitQuery(statement); + } break; + case StatementType::CREATE_SEQUENCE: { + visitCreateSequence(statement); + } break; + case StatementType::DROP: { + visitDrop(statement); + } break; + case StatementType::CREATE_TABLE: { + visitCreateTable(statement); + } break; + case StatementType::CREATE_TYPE: { + visitCreateType(statement); + } break; + case StatementType::ALTER: { + visitAlter(statement); + } break; + case StatementType::COPY_FROM: { + visitCopyFrom(statement); + } break; + case StatementType::COPY_TO: { + visitCopyTo(statement); + } break; + case StatementType::STANDALONE_CALL: { + visitStandaloneCall(statement); + } break; + case StatementType::EXPLAIN: { + visitExplain(statement); + } break; + case StatementType::CREATE_MACRO: { + visitCreateMacro(statement); + } break; + case StatementType::TRANSACTION: { + visitTransaction(statement); + } break; + case StatementType::EXTENSION: { + visitExtension(statement); + } break; + case StatementType::EXPORT_DATABASE: { + visitExportDatabase(statement); + } break; + case StatementType::IMPORT_DATABASE: { + visitImportDatabase(statement); + } break; + case StatementType::ATTACH_DATABASE: { + visitAttachDatabase(statement); + } break; + case StatementType::DETACH_DATABASE: { + visitDetachDatabase(statement); + } break; + case StatementType::USE_DATABASE: { + visitUseDatabase(statement); + } break; + case StatementType::STANDALONE_CALL_FUNCTION: { + visitStandaloneCallFunction(statement); + } break; + case StatementType::EXTENSION_CLAUSE: { + visitExtensionClause(statement); + } break; + default: + KU_UNREACHABLE; + } +} + +void StatementVisitor::visitExplain(const Statement& statement) { + auto& explainStatement = ku_dynamic_cast(statement); + visit(*explainStatement.getStatementToExplain()); +} + +void StatementVisitor::visitQuery(const Statement& statement) { + auto& regularQuery = ku_dynamic_cast(statement); + for (auto i = 0u; i < regularQuery.getNumSingleQueries(); ++i) { + visitSingleQuery(regularQuery.getSingleQuery(i)); + } +} + +void StatementVisitor::visitSingleQuery(const SingleQuery* singleQuery) { + for (auto i = 0u; i < singleQuery->getNumQueryParts(); ++i) { + visitQueryPart(singleQuery->getQueryPart(i)); + } + for (auto i = 0u; i < singleQuery->getNumReadingClauses(); ++i) { + visitReadingClause(singleQuery->getReadingClause(i)); + } + for (auto i = 0u; i < singleQuery->getNumUpdatingClauses(); ++i) { + visitUpdatingClause(singleQuery->getUpdatingClause(i)); + } + if (singleQuery->hasReturnClause()) { + visitReturnClause(singleQuery->getReturnClause()); + } +} + +void StatementVisitor::visitQueryPart(const QueryPart* queryPart) { + for (auto i = 0u; i < queryPart->getNumReadingClauses(); ++i) { + visitReadingClause(queryPart->getReadingClause(i)); + } + for (auto i = 0u; i < queryPart->getNumUpdatingClauses(); ++i) { + visitUpdatingClause(queryPart->getUpdatingClause(i)); + } + visitWithClause(queryPart->getWithClause()); +} + +void StatementVisitor::visitReadingClause(const ReadingClause* readingClause) { + switch (readingClause->getClauseType()) { + case ClauseType::MATCH: { + visitMatch(readingClause); + } break; + case ClauseType::UNWIND: { + visitUnwind(readingClause); + } break; + case ClauseType::IN_QUERY_CALL: { + visitInQueryCall(readingClause); + } break; + case ClauseType::LOAD_FROM: { + visitLoadFrom(readingClause); + } break; + default: + KU_UNREACHABLE; + } +} + +void StatementVisitor::visitUpdatingClause(const UpdatingClause* updatingClause) { + 
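+ // Same dispatch pattern as visitReadingClause: route on the clause type
+ // and let subclasses override the per-clause hooks (visitSet, visitDelete,
+ // visitInsert, visitMerge) they care about.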
switch (updatingClause->getClauseType()) { + case ClauseType::SET: { + visitSet(updatingClause); + } break; + case ClauseType::DELETE_: { + visitDelete(updatingClause); + } break; + case ClauseType::INSERT: { + visitInsert(updatingClause); + } break; + case ClauseType::MERGE: { + visitMerge(updatingClause); + } break; + default: + KU_UNREACHABLE; + } +} + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/parser/parser.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/parser.cpp new file mode 100644 index 0000000000..2b15d41380 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/parser.cpp @@ -0,0 +1,54 @@ +#include "parser/parser.h" + +// ANTLR4 generates code with unused parameters. +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#include "cypher_lexer.h" +#pragma GCC diagnostic pop + +#include "common/exception/parser.h" +#include "common/string_utils.h" +#include "parser/antlr_parser/lbug_cypher_parser.h" +#include "parser/antlr_parser/parser_error_listener.h" +#include "parser/antlr_parser/parser_error_strategy.h" +#include "parser/transformer.h" + +using namespace antlr4; + +namespace lbug { +namespace parser { + +std::vector> Parser::parseQuery(std::string_view query, + std::vector transformerExtensions) { + auto queryStr = std::string(query); + queryStr = common::StringUtils::ltrim(queryStr); + queryStr = common::StringUtils::ltrimNewlines(queryStr); + // LCOV_EXCL_START + // We should have enforced this in connection, but I also realize empty query will cause + // antlr to hang. So enforce a duplicate check here. + if (queryStr.empty()) { + throw common::ParserException( + "Cannot parse empty query. This should be handled in connection."); + } + // LCOV_EXCL_STOP + + auto inputStream = ANTLRInputStream(queryStr); + auto parserErrorListener = ParserErrorListener(); + + auto cypherLexer = CypherLexer(&inputStream); + cypherLexer.removeErrorListeners(); + cypherLexer.addErrorListener(&parserErrorListener); + auto tokens = CommonTokenStream(&cypherLexer); + tokens.fill(); + + auto lbugCypherParser = LbugCypherParser(&tokens); + lbugCypherParser.removeErrorListeners(); + lbugCypherParser.addErrorListener(&parserErrorListener); + lbugCypherParser.setErrorHandler(std::make_shared()); + + Transformer transformer(*lbugCypherParser.ku_Statements(), std::move(transformerExtensions)); + return transformer.transform(); +} + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/CMakeLists.txt new file mode 100644 index 0000000000..541cd6845b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/CMakeLists.txt @@ -0,0 +1,22 @@ +add_library(lbug_parser_transform + OBJECT + transform_copy.cpp + transform_ddl.cpp + transform_expression.cpp + transform_graph_pattern.cpp + transform_macro.cpp + transform_projection.cpp + transform_query.cpp + transform_reading_clause.cpp + transform_standalone_call.cpp + transform_transaction.cpp + transform_updating_clause.cpp + transform_extension.cpp + transform_port_db.cpp + transform_attach_database.cpp + transform_detach_database.cpp + transform_use_database.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_attach_database.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_attach_database.cpp new file mode 100644 index 
0000000000..e7b141ff1a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_attach_database.cpp @@ -0,0 +1,19 @@ +#include "parser/attach_database.h" +#include "parser/transformer.h" + +namespace lbug { +namespace parser { + +std::unique_ptr Transformer::transformAttachDatabase( + CypherParser::KU_AttachDatabaseContext& ctx) { + auto dbPath = transformStringLiteral(*ctx.StringLiteral()); + auto dbAlias = ctx.oC_SchemaName() ? transformSchemaName(*ctx.oC_SchemaName()) : ""; + auto dbType = transformSymbolicName(*ctx.oC_SymbolicName()); + auto attachOption = ctx.kU_Options() ? transformOptions(*ctx.kU_Options()) : options_t{}; + AttachInfo attachInfo{std::move(dbPath), std::move(dbAlias), std::move(dbType), + std::move(attachOption)}; + return std::make_unique(std::move(attachInfo)); +} + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_copy.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_copy.cpp new file mode 100644 index 0000000000..a169938656 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_copy.cpp @@ -0,0 +1,109 @@ +#include "common/assert.h" +#include "parser/copy.h" +#include "parser/expression/parsed_literal_expression.h" +#include "parser/scan_source.h" +#include "parser/transformer.h" + +using namespace lbug::common; + +namespace lbug { +namespace parser { + +std::unique_ptr Transformer::transformCopyTo(CypherParser::KU_CopyTOContext& ctx) { + std::string filePath = transformStringLiteral(*ctx.StringLiteral()); + auto regularQuery = transformQuery(*ctx.oC_Query()); + auto copyTo = std::make_unique(std::move(filePath), std::move(regularQuery)); + if (ctx.kU_Options()) { + copyTo->setParsingOption(transformOptions(*ctx.kU_Options())); + } + return copyTo; +} + +std::unique_ptr Transformer::transformCopyFrom(CypherParser::KU_CopyFromContext& ctx) { + auto source = transformScanSource(*ctx.kU_ScanSource()); + auto tableName = transformSchemaName(*ctx.oC_SchemaName()); + auto copyFrom = std::make_unique(std::move(source), std::move(tableName)); + CopyFromColumnInfo info; + info.inputColumnOrder = ctx.kU_ColumnNames(); + if (ctx.kU_ColumnNames()) { + info.columnNames = transformColumnNames(*ctx.kU_ColumnNames()); + } + if (ctx.kU_Options()) { + copyFrom->setParsingOption(transformOptions(*ctx.kU_Options())); + } + copyFrom->setColumnInfo(std::move(info)); + return copyFrom; +} + +std::unique_ptr Transformer::transformCopyFromByColumn( + CypherParser::KU_CopyFromByColumnContext& ctx) { + auto source = std::make_unique(transformFilePaths(ctx.StringLiteral())); + auto tableName = transformSchemaName(*ctx.oC_SchemaName()); + auto copyFrom = std::make_unique(std::move(source), std::move(tableName)); + copyFrom->setByColumn(); + return copyFrom; +} + +std::vector Transformer::transformColumnNames( + CypherParser::KU_ColumnNamesContext& ctx) { + std::vector columnNames; + for (auto& schemaName : ctx.oC_SchemaName()) { + columnNames.push_back(transformSchemaName(*schemaName)); + } + return columnNames; +} + +std::vector Transformer::transformFilePaths( + const std::vector& stringLiteral) { + std::vector csvFiles; + csvFiles.reserve(stringLiteral.size()); + for (auto& csvFile : stringLiteral) { + csvFiles.push_back(transformStringLiteral(*csvFile)); + } + return csvFiles; +} + +std::unique_ptr Transformer::transformScanSource( + CypherParser::KU_ScanSourceContext& ctx) { + if (ctx.kU_FilePaths()) { + auto filePaths = 
transformFilePaths(ctx.kU_FilePaths()->StringLiteral()); + return std::make_unique(std::move(filePaths)); + } else if (ctx.oC_Query()) { + auto query = transformQuery(*ctx.oC_Query()); + return std::make_unique(std::move(query)); + } else if (ctx.oC_Variable()) { + std::vector objectNames; + objectNames.push_back(transformVariable(*ctx.oC_Variable())); + if (ctx.oC_SchemaName()) { + objectNames.push_back(transformSchemaName(*ctx.oC_SchemaName())); + } + return std::make_unique(std::move(objectNames)); + } else if (ctx.oC_FunctionInvocation()) { + auto functionExpression = transformFunctionInvocation(*ctx.oC_FunctionInvocation()); + return std::make_unique(std::move(functionExpression)); + } else if (ctx.oC_Parameter()) { + auto paramExpression = transformParameterExpression(*ctx.oC_Parameter()); + return std::make_unique(std::move(paramExpression)); + } + KU_UNREACHABLE; +} + +options_t Transformer::transformOptions(CypherParser::KU_OptionsContext& ctx) { + options_t options; + for (auto loadOption : ctx.kU_Option()) { + auto optionName = transformSymbolicName(*loadOption->oC_SymbolicName()); + // Check if the literal exists, otherwise set the value to true by default + if (loadOption->oC_Literal()) { + // If there is a literal, transform it and use it as the value + options.emplace(optionName, transformLiteral(*loadOption->oC_Literal())); + } else { + // If no literal is provided, set the default value to true + options.emplace(optionName, + std::make_unique(Value(true), "true")); + } + } + return options; +} + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_ddl.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_ddl.cpp new file mode 100644 index 0000000000..0780959741 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_ddl.cpp @@ -0,0 +1,395 @@ +#include "common/exception/parser.h" +#include "common/string_format.h" +#include "parser/ddl/alter.h" +#include "parser/ddl/create_sequence.h" +#include "parser/ddl/create_table.h" +#include "parser/ddl/create_type.h" +#include "parser/ddl/drop.h" +#include "parser/ddl/drop_info.h" +#include "parser/transformer.h" + +using namespace lbug::common; +using namespace lbug::catalog; + +namespace lbug { +namespace parser { + +std::unique_ptr Transformer::transformAlterTable( + CypherParser::KU_AlterTableContext& ctx) { + if (ctx.kU_AlterOptions()->kU_AddProperty()) { + return transformAddProperty(ctx); + } + if (ctx.kU_AlterOptions()->kU_DropProperty()) { + return transformDropProperty(ctx); + } + if (ctx.kU_AlterOptions()->kU_RenameTable()) { + return transformRenameTable(ctx); + } + if (ctx.kU_AlterOptions()->kU_AddFromToConnection()) { + return transformAddFromToConnection(ctx); + } + if (ctx.kU_AlterOptions()->kU_DropFromToConnection()) { + return transformDropFromToConnection(ctx); + } + return transformRenameProperty(ctx); +} + +std::string Transformer::getPKName(CypherParser::KU_CreateNodeTableContext& ctx) { + auto pkCount = 0; + std::string pkName; + auto& propertyDefinitions = *ctx.kU_PropertyDefinitions(); + for (auto& definition : propertyDefinitions.kU_PropertyDefinition()) { + if (definition->PRIMARY() && definition->KEY()) { + pkCount++; + pkName = transformPrimaryKey(*definition->kU_ColumnDefinition()); + } + } + if (ctx.kU_CreateNodeConstraint()) { + // In the case where no pkName has been found, or the Node Constraint's name is different + // than the pkName found, add the counter. 
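+ // A table-level constraint that names the same column as an inline PRIMARY KEY is not counted twice.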
+ if (pkCount == 0 || transformPrimaryKey(*ctx.kU_CreateNodeConstraint()) != pkName) { + pkCount++; + } + pkName = transformPrimaryKey(*ctx.kU_CreateNodeConstraint()); + } + if (pkCount == 0) { + // Raise exception when no PRIMARY KEY is specified. + throw ParserException("Can not find primary key."); + } else if (pkCount > 1) { + // Raise exception when multiple PRIMARY KEY are specified. + throw ParserException("Found multiple primary keys."); + } + return pkName; +} + +ConflictAction Transformer::transformConflictAction(CypherParser::KU_IfNotExistsContext* ctx) { + if (ctx != nullptr) { + return ConflictAction::ON_CONFLICT_DO_NOTHING; + } + return ConflictAction::ON_CONFLICT_THROW; +} + +std::unique_ptr Transformer::transformCreateNodeTable( + CypherParser::KU_CreateNodeTableContext& ctx) { + auto tableName = transformSchemaName(*ctx.oC_SchemaName()); + auto createTableInfo = + CreateTableInfo(TableType::NODE, tableName, transformConflictAction(ctx.kU_IfNotExists())); + // If CREATE NODE TABLE AS syntax + if (ctx.oC_Query()) { + return std::make_unique(std::move(createTableInfo), + std::make_unique(transformQuery(*ctx.oC_Query()))); + } else { + createTableInfo.propertyDefinitions = + transformPropertyDefinitions(*ctx.kU_PropertyDefinitions()); + createTableInfo.extraInfo = std::make_unique(getPKName(ctx)); + return std::make_unique(std::move(createTableInfo)); + } +} + +std::unique_ptr Transformer::transformCreateRelGroup( + CypherParser::KU_CreateRelTableContext& ctx) { + auto tableName = transformSchemaName(*ctx.oC_SchemaName()); + std::string relMultiplicity = "MANY_MANY"; + if (ctx.oC_SymbolicName()) { + relMultiplicity = transformSymbolicName(*ctx.oC_SymbolicName()); + } + options_t options; + if (ctx.kU_Options()) { + options = transformOptions(*ctx.kU_Options()); + } + std::vector> fromToPairs; + for (auto& fromTo : ctx.kU_FromToConnections()->kU_FromToConnection()) { + auto src = transformSchemaName(*fromTo->oC_SchemaName(0)); + auto dst = transformSchemaName(*fromTo->oC_SchemaName(1)); + fromToPairs.emplace_back(src, dst); + } + std::unique_ptr extraInfo = + std::make_unique(relMultiplicity, std::move(fromToPairs), + std::move(options)); + auto conflictAction = transformConflictAction(ctx.kU_IfNotExists()); + auto createTableInfo = CreateTableInfo(common::TableType::REL, tableName, conflictAction); + if (ctx.kU_PropertyDefinitions()) { + createTableInfo.propertyDefinitions = + transformPropertyDefinitions(*ctx.kU_PropertyDefinitions()); + } + createTableInfo.extraInfo = std::move(extraInfo); + if (ctx.oC_Query()) { + auto scanSource = std::make_unique(transformQuery(*ctx.oC_Query())); + return std::make_unique(std::move(createTableInfo), std::move(scanSource)); + } else { + return std::make_unique(std::move(createTableInfo)); + } +} + +std::unique_ptr Transformer::transformCreateSequence( + CypherParser::KU_CreateSequenceContext& ctx) { + auto sequenceName = transformSchemaName(*ctx.oC_SchemaName()); + auto createSequenceInfo = CreateSequenceInfo(sequenceName, + ctx.kU_IfNotExists() ? 
common::ConflictAction::ON_CONFLICT_DO_NOTHING : + common::ConflictAction::ON_CONFLICT_THROW); + std::unordered_set applied; + for (auto seqOption : ctx.kU_SequenceOptions()) { + SequenceInfoType type; // NOLINT(*-init-variables) + std::string typeString; + CypherParser::OC_IntegerLiteralContext* valCtx = nullptr; + std::string* valOption = nullptr; + if (seqOption->kU_StartWith()) { + type = SequenceInfoType::START; + typeString = "START"; + valCtx = seqOption->kU_StartWith()->oC_IntegerLiteral(); + valOption = &createSequenceInfo.startWith; + *valOption = seqOption->kU_StartWith()->MINUS() ? "-" : ""; + } else if (seqOption->kU_IncrementBy()) { + type = SequenceInfoType::INCREMENT; + typeString = "INCREMENT"; + valCtx = seqOption->kU_IncrementBy()->oC_IntegerLiteral(); + valOption = &createSequenceInfo.increment; + *valOption = seqOption->kU_IncrementBy()->MINUS() ? "-" : ""; + } else if (seqOption->kU_MinValue()) { + type = SequenceInfoType::MINVALUE; + typeString = "MINVALUE"; + if (!seqOption->kU_MinValue()->NO()) { + valCtx = seqOption->kU_MinValue()->oC_IntegerLiteral(); + valOption = &createSequenceInfo.minValue; + *valOption = seqOption->kU_MinValue()->MINUS() ? "-" : ""; + } + } else if (seqOption->kU_MaxValue()) { + type = SequenceInfoType::MAXVALUE; + typeString = "MAXVALUE"; + if (!seqOption->kU_MaxValue()->NO()) { + valCtx = seqOption->kU_MaxValue()->oC_IntegerLiteral(); + valOption = &createSequenceInfo.maxValue; + *valOption = seqOption->kU_MaxValue()->MINUS() ? "-" : ""; + } + } else { // seqOption->kU_Cycle() + type = SequenceInfoType::CYCLE; + typeString = "CYCLE"; + if (!seqOption->kU_Cycle()->NO()) { + createSequenceInfo.cycle = true; + } + } + if (applied.find(type) != applied.end()) { + throw ParserException(typeString + " should be passed at most once."); + } + applied.insert(type); + + if (valCtx && valOption) { + *valOption += valCtx->DecimalInteger()->getText(); + } + } + return std::make_unique(std::move(createSequenceInfo)); +} + +std::unique_ptr Transformer::transformCreateType( + CypherParser::KU_CreateTypeContext& ctx) { + auto name = transformSchemaName(*ctx.oC_SchemaName()); + auto type = transformDataType(*ctx.kU_DataType()); + return std::make_unique(name, type); +} + +DropType transformDropType(CypherParser::KU_DropContext& ctx) { + if (ctx.TABLE()) { + return DropType::TABLE; + } else if (ctx.SEQUENCE()) { + return DropType::SEQUENCE; + } else if (ctx.MACRO()) { + return DropType::MACRO; + } else { + KU_UNREACHABLE; + } +} + +std::unique_ptr Transformer::transformDrop(CypherParser::KU_DropContext& ctx) { + auto name = transformSchemaName(*ctx.oC_SchemaName()); + auto dropType = transformDropType(ctx); + auto conflictAction = ctx.kU_IfExists() ? 
common::ConflictAction::ON_CONFLICT_DO_NOTHING : + common::ConflictAction::ON_CONFLICT_THROW; + return std::make_unique(DropInfo{std::move(name), dropType, conflictAction}); +} + +std::unique_ptr Transformer::transformRenameTable( + CypherParser::KU_AlterTableContext& ctx) { + auto tableName = transformSchemaName(*ctx.oC_SchemaName()); + auto newName = transformSchemaName(*ctx.kU_AlterOptions()->kU_RenameTable()->oC_SchemaName()); + auto extraInfo = std::make_unique(std::move(newName)); + auto info = AlterInfo(AlterType::RENAME, tableName, std::move(extraInfo)); + return std::make_unique(std::move(info)); +} + +std::unique_ptr Transformer::transformAddFromToConnection( + CypherParser::KU_AlterTableContext& ctx) { + auto tableName = transformSchemaName(*ctx.oC_SchemaName()); + auto schemaNameCtx = + ctx.kU_AlterOptions()->kU_AddFromToConnection()->kU_FromToConnection()->oC_SchemaName(); + KU_ASSERT(schemaNameCtx.size() == 2); + auto srcTableName = transformSchemaName(*schemaNameCtx[0]); + auto dstTableName = transformSchemaName(*schemaNameCtx[1]); + auto extraInfo = std::make_unique(std::move(srcTableName), + std::move(dstTableName)); + ConflictAction action = ConflictAction::ON_CONFLICT_THROW; + if (ctx.kU_AlterOptions()->kU_AddFromToConnection()->kU_IfNotExists()) { + action = ConflictAction::ON_CONFLICT_DO_NOTHING; + } + auto info = AlterInfo(AlterType::ADD_FROM_TO_CONNECTION, std::move(tableName), + std::move(extraInfo), action); + return std::make_unique(std::move(info)); +} + +std::unique_ptr Transformer::transformDropFromToConnection( + CypherParser::KU_AlterTableContext& ctx) { + auto tableName = transformSchemaName(*ctx.oC_SchemaName()); + auto schemaNameCtx = + ctx.kU_AlterOptions()->kU_DropFromToConnection()->kU_FromToConnection()->oC_SchemaName(); + KU_ASSERT(schemaNameCtx.size() == 2); + auto srcTableName = transformSchemaName(*schemaNameCtx[0]); + auto dstTableName = transformSchemaName(*schemaNameCtx[1]); + auto extraInfo = std::make_unique(std::move(srcTableName), + std::move(dstTableName)); + ConflictAction action = ConflictAction::ON_CONFLICT_THROW; + if (ctx.kU_AlterOptions()->kU_DropFromToConnection()->kU_IfExists()) { + action = ConflictAction::ON_CONFLICT_DO_NOTHING; + } + auto info = AlterInfo(AlterType::DROP_FROM_TO_CONNECTION, std::move(tableName), + std::move(extraInfo), action); + return std::make_unique(std::move(info)); +} + +std::unique_ptr Transformer::transformAddProperty( + CypherParser::KU_AlterTableContext& ctx) { + auto tableName = transformSchemaName(*ctx.oC_SchemaName()); + auto addPropertyCtx = ctx.kU_AlterOptions()->kU_AddProperty(); + auto propertyName = transformPropertyKeyName(*addPropertyCtx->oC_PropertyKeyName()); + auto dataType = transformDataType(*addPropertyCtx->kU_DataType()); + std::unique_ptr defaultValue = nullptr; + if (addPropertyCtx->kU_Default()) { + defaultValue = transformExpression(*addPropertyCtx->kU_Default()->oC_Expression()); + } + auto extraInfo = std::make_unique(std::move(propertyName), + std::move(dataType), std::move(defaultValue)); + ConflictAction action = ConflictAction::ON_CONFLICT_THROW; + if (addPropertyCtx->kU_IfNotExists()) { + action = ConflictAction::ON_CONFLICT_DO_NOTHING; + } + auto info = AlterInfo(AlterType::ADD_PROPERTY, tableName, std::move(extraInfo), action); + return std::make_unique(std::move(info)); +} + +std::unique_ptr Transformer::transformDropProperty( + CypherParser::KU_AlterTableContext& ctx) { + auto tableName = transformSchemaName(*ctx.oC_SchemaName()); + auto dropProperty = 
ctx.kU_AlterOptions()->kU_DropProperty(); + auto propertyName = transformPropertyKeyName(*dropProperty->oC_PropertyKeyName()); + auto extraInfo = std::make_unique(std::move(propertyName)); + ConflictAction action = ConflictAction::ON_CONFLICT_THROW; + if (dropProperty->kU_IfExists()) { + action = ConflictAction::ON_CONFLICT_DO_NOTHING; + } + auto info = AlterInfo(AlterType::DROP_PROPERTY, tableName, std::move(extraInfo), action); + return std::make_unique(std::move(info)); +} + +std::unique_ptr Transformer::transformRenameProperty( + CypherParser::KU_AlterTableContext& ctx) { + auto tableName = transformSchemaName(*ctx.oC_SchemaName()); + auto propertyName = transformPropertyKeyName( + *ctx.kU_AlterOptions()->kU_RenameProperty()->oC_PropertyKeyName()[0]); + auto newName = transformPropertyKeyName( + *ctx.kU_AlterOptions()->kU_RenameProperty()->oC_PropertyKeyName()[1]); + auto extraInfo = std::make_unique(propertyName, newName); + auto info = AlterInfo(AlterType::RENAME_PROPERTY, tableName, std::move(extraInfo)); + return std::make_unique(std::move(info)); +} + +std::unique_ptr Transformer::transformCommentOn(CypherParser::KU_CommentOnContext& ctx) { + auto tableName = transformSchemaName(*ctx.oC_SchemaName()); + auto comment = transformStringLiteral(*ctx.StringLiteral()); + auto extraInfo = std::make_unique(comment); + auto info = AlterInfo(AlterType::COMMENT, tableName, std::move(extraInfo)); + return std::make_unique(std::move(info)); +} + +std::vector Transformer::transformColumnDefinitions( + CypherParser::KU_ColumnDefinitionsContext& ctx) { + std::vector definitions; + for (auto& definition : ctx.kU_ColumnDefinition()) { + definitions.emplace_back(transformColumnDefinition(*definition)); + } + return definitions; +} + +ParsedColumnDefinition Transformer::transformColumnDefinition( + CypherParser::KU_ColumnDefinitionContext& ctx) { + auto propertyName = transformPropertyKeyName(*ctx.oC_PropertyKeyName()); + auto dataType = transformDataType(*ctx.kU_DataType()); + return ParsedColumnDefinition(propertyName, dataType); +} + +std::vector Transformer::transformPropertyDefinitions( + CypherParser::KU_PropertyDefinitionsContext& ctx) { + std::vector definitions; + for (auto& definition : ctx.kU_PropertyDefinition()) { + auto columnDefinition = transformColumnDefinition(*definition->kU_ColumnDefinition()); + std::unique_ptr defaultExpr = nullptr; + if (definition->kU_Default()) { + defaultExpr = transformExpression(*definition->kU_Default()->oC_Expression()); + } + definitions.push_back( + ParsedPropertyDefinition(std::move(columnDefinition), std::move(defaultExpr))); + } + return definitions; +} + +static std::string convertColumnDefinitionsToString( + const std::vector& columnDefinitions) { + std::string result; + for (auto& columnDefinition : columnDefinitions) { + result += common::stringFormat("{} {},", columnDefinition.name, columnDefinition.type); + } + return result.substr(0, result.length() - 1); +} + +std::string Transformer::transformUnionType(CypherParser::KU_UnionTypeContext& ctx) { + return common::stringFormat("{}({})", ctx.UNION()->getText(), + convertColumnDefinitionsToString(transformColumnDefinitions(*ctx.kU_ColumnDefinitions()))); +} + +std::string Transformer::transformStructType(CypherParser::KU_StructTypeContext& ctx) { + return common::stringFormat("{}({})", ctx.STRUCT()->getText(), + convertColumnDefinitionsToString(transformColumnDefinitions(*ctx.kU_ColumnDefinitions()))); +} + +std::string Transformer::transformMapType(CypherParser::KU_MapTypeContext& ctx) { + 
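// Render MAP types as "MAP(keyType,valueType)" from the two nested data types. +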
return common::stringFormat("{}({},{})", ctx.MAP()->getText(), + transformDataType(*ctx.kU_DataType()[0]), transformDataType(*ctx.kU_DataType()[1])); +} + +std::string Transformer::transformDecimalType(CypherParser::KU_DecimalTypeContext& ctx) { + return ctx.getText(); +} + +std::string Transformer::transformDataType(CypherParser::KU_DataTypeContext& ctx) { + if (ctx.oC_SymbolicName()) { + return transformSymbolicName(*ctx.oC_SymbolicName()); + } else if (ctx.kU_UnionType()) { + return transformUnionType(*ctx.kU_UnionType()); + } else if (ctx.kU_StructType()) { + return transformStructType(*ctx.kU_StructType()); + } else if (ctx.kU_MapType()) { + return transformMapType(*ctx.kU_MapType()); + } else if (ctx.kU_DecimalType()) { + return transformDecimalType(*ctx.kU_DecimalType()); + } else { + return transformDataType(*ctx.kU_DataType()) + ctx.kU_ListIdentifiers()->getText(); + } +} + +std::string Transformer::transformPrimaryKey(CypherParser::KU_CreateNodeConstraintContext& ctx) { + return transformPropertyKeyName(*ctx.oC_PropertyKeyName()); +} + +std::string Transformer::transformPrimaryKey(CypherParser::KU_ColumnDefinitionContext& ctx) { + return transformPropertyKeyName(*ctx.oC_PropertyKeyName()); +} + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_detach_database.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_detach_database.cpp new file mode 100644 index 0000000000..c8656d5e8a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_detach_database.cpp @@ -0,0 +1,14 @@ +#include "parser/detach_database.h" +#include "parser/transformer.h" + +namespace lbug { +namespace parser { + +std::unique_ptr Transformer::transformDetachDatabase( + CypherParser::KU_DetachDatabaseContext& ctx) { + auto dbName = transformSchemaName(*ctx.oC_SchemaName()); + return std::make_unique(std::move(dbName)); +} + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_expression.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_expression.cpp new file mode 100644 index 0000000000..45c1e20b8a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_expression.cpp @@ -0,0 +1,704 @@ +#include "function/aggregate/count_star.h" +#include "function/arithmetic/vector_arithmetic_functions.h" +#include "function/cast/functions/cast_from_string_functions.h" +#include "function/list/vector_list_functions.h" +#include "function/string/vector_string_functions.h" +#include "function/struct/vector_struct_functions.h" +#include "parser/expression/parsed_case_expression.h" +#include "parser/expression/parsed_function_expression.h" +#include "parser/expression/parsed_lambda_expression.h" +#include "parser/expression/parsed_literal_expression.h" +#include "parser/expression/parsed_parameter_expression.h" +#include "parser/expression/parsed_property_expression.h" +#include "parser/expression/parsed_subquery_expression.h" +#include "parser/expression/parsed_variable_expression.h" +#include "parser/transformer.h" + +using namespace lbug::common; +using namespace lbug::function; + +namespace lbug { +namespace parser { + +std::unique_ptr Transformer::transformExpression( + CypherParser::OC_ExpressionContext& ctx) { + return transformOrExpression(*ctx.oC_OrExpression()); +} + +std::unique_ptr Transformer::transformOrExpression( + CypherParser::OC_OrExpressionContext& ctx) { + std::unique_ptr expression; + for 
(auto& xorExpression : ctx.oC_XorExpression()) { + auto next = transformXorExpression(*xorExpression); + if (!expression) { + expression = std::move(next); + } else { + auto rawName = expression->getRawName() + " OR " + next->getRawName(); + expression = std::make_unique(ExpressionType::OR, + std::move(expression), std::move(next), rawName); + } + } + return expression; +} + +std::unique_ptr Transformer::transformXorExpression( + CypherParser::OC_XorExpressionContext& ctx) { + std::unique_ptr expression; + for (auto& andExpression : ctx.oC_AndExpression()) { + auto next = transformAndExpression(*andExpression); + if (!expression) { + expression = std::move(next); + } else { + auto rawName = expression->getRawName() + " XOR " + next->getRawName(); + expression = std::make_unique(ExpressionType::XOR, + std::move(expression), std::move(next), rawName); + } + } + return expression; +} + +std::unique_ptr Transformer::transformAndExpression( + CypherParser::OC_AndExpressionContext& ctx) { + std::unique_ptr expression; + for (auto& notExpression : ctx.oC_NotExpression()) { + auto next = transformNotExpression(*notExpression); + if (!expression) { + expression = std::move(next); + } else { + auto rawName = expression->getRawName() + " AND " + next->getRawName(); + expression = std::make_unique(ExpressionType::AND, + std::move(expression), std::move(next), rawName); + } + } + return expression; +} + +std::unique_ptr Transformer::transformNotExpression( + CypherParser::OC_NotExpressionContext& ctx) { + auto result = transformComparisonExpression(*ctx.oC_ComparisonExpression()); + if (!ctx.NOT().empty()) { + for ([[maybe_unused]] auto& _ : ctx.NOT()) { + auto rawName = "NOT " + result->toString(); + result = std::make_unique(ExpressionType::NOT, std::move(result), + std::move(rawName)); + } + } + return result; +} + +std::unique_ptr Transformer::transformComparisonExpression( + CypherParser::OC_ComparisonExpressionContext& ctx) { + if (1 == ctx.kU_BitwiseOrOperatorExpression().size()) { + return transformBitwiseOrOperatorExpression(*ctx.kU_BitwiseOrOperatorExpression(0)); + } + // Antlr parser throws error for conjunctive comparison. + // Transformer should only handle the case of single comparison operator. 
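+ // E.g. a chained comparison such as a < b < c is rejected by the grammar, so at most one operator reaches the transformer.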
+ KU_ASSERT(ctx.kU_ComparisonOperator().size() == 1); + auto left = transformBitwiseOrOperatorExpression(*ctx.kU_BitwiseOrOperatorExpression(0)); + auto right = transformBitwiseOrOperatorExpression(*ctx.kU_BitwiseOrOperatorExpression(1)); + auto comparisonOperator = ctx.kU_ComparisonOperator()[0]->getText(); + if (comparisonOperator == "=") { + return std::make_unique(ExpressionType::EQUALS, std::move(left), + std::move(right), ctx.getText()); + } else if (comparisonOperator == "<>") { + return std::make_unique(ExpressionType::NOT_EQUALS, std::move(left), + std::move(right), ctx.getText()); + } else if (comparisonOperator == ">") { + return std::make_unique(ExpressionType::GREATER_THAN, std::move(left), + std::move(right), ctx.getText()); + } else if (comparisonOperator == ">=") { + return std::make_unique(ExpressionType::GREATER_THAN_EQUALS, + std::move(left), std::move(right), ctx.getText()); + } else if (comparisonOperator == "<") { + return std::make_unique(ExpressionType::LESS_THAN, std::move(left), + std::move(right), ctx.getText()); + } else { + KU_ASSERT(comparisonOperator == "<="); + return std::make_unique(ExpressionType::LESS_THAN_EQUALS, std::move(left), + std::move(right), ctx.getText()); + } +} + +std::unique_ptr Transformer::transformBitwiseOrOperatorExpression( + CypherParser::KU_BitwiseOrOperatorExpressionContext& ctx) { + std::unique_ptr expression; + for (auto i = 0ul; i < ctx.kU_BitwiseAndOperatorExpression().size(); ++i) { + auto next = transformBitwiseAndOperatorExpression(*ctx.kU_BitwiseAndOperatorExpression(i)); + if (!expression) { + expression = std::move(next); + } else { + auto rawName = expression->getRawName() + " | " + next->getRawName(); + expression = std::make_unique(BitwiseOrFunction::name, + std::move(expression), std::move(next), rawName); + } + } + return expression; +} + +std::unique_ptr Transformer::transformBitwiseAndOperatorExpression( + CypherParser::KU_BitwiseAndOperatorExpressionContext& ctx) { + std::unique_ptr expression; + for (auto i = 0ul; i < ctx.kU_BitShiftOperatorExpression().size(); ++i) { + auto next = transformBitShiftOperatorExpression(*ctx.kU_BitShiftOperatorExpression(i)); + if (!expression) { + expression = std::move(next); + } else { + auto rawName = expression->getRawName() + " & " + next->getRawName(); + expression = std::make_unique(BitwiseAndFunction::name, + std::move(expression), std::move(next), rawName); + } + } + return expression; +} + +std::unique_ptr Transformer::transformBitShiftOperatorExpression( + CypherParser::KU_BitShiftOperatorExpressionContext& ctx) { + std::unique_ptr expression; + for (auto i = 0ul; i < ctx.oC_AddOrSubtractExpression().size(); ++i) { + auto next = transformAddOrSubtractExpression(*ctx.oC_AddOrSubtractExpression(i)); + if (!expression) { + expression = std::move(next); + } else { + auto bitShiftOperator = ctx.kU_BitShiftOperator(i - 1)->getText(); + auto rawName = + expression->getRawName() + " " + bitShiftOperator + " " + next->getRawName(); + if (bitShiftOperator == "<<") { + expression = std::make_unique(BitShiftLeftFunction::name, + std::move(expression), std::move(next), rawName); + } else { + KU_ASSERT(bitShiftOperator == ">>"); + expression = std::make_unique(BitShiftRightFunction::name, + std::move(expression), std::move(next), rawName); + } + } + } + return expression; +} + +std::unique_ptr Transformer::transformAddOrSubtractExpression( + CypherParser::OC_AddOrSubtractExpressionContext& ctx) { + std::unique_ptr expression; + for (auto i = 0ul; i < 
ctx.oC_MultiplyDivideModuloExpression().size(); ++i) { + auto next = + transformMultiplyDivideModuloExpression(*ctx.oC_MultiplyDivideModuloExpression(i)); + if (!expression) { + expression = std::move(next); + } else { + auto arithmeticOperator = ctx.kU_AddOrSubtractOperator(i - 1)->getText(); + auto rawName = + expression->getRawName() + " " + arithmeticOperator + " " + next->getRawName(); + expression = std::make_unique(arithmeticOperator, + std::move(expression), std::move(next), rawName); + } + } + return expression; +} + +std::unique_ptr Transformer::transformMultiplyDivideModuloExpression( + CypherParser::OC_MultiplyDivideModuloExpressionContext& ctx) { + std::unique_ptr expression; + for (auto i = 0ul; i < ctx.oC_PowerOfExpression().size(); i++) { + auto next = transformPowerOfExpression(*ctx.oC_PowerOfExpression(i)); + if (!expression) { + expression = std::move(next); + } else { + auto arithmeticOperator = ctx.kU_MultiplyDivideModuloOperator(i - 1)->getText(); + auto rawName = + expression->getRawName() + " " + arithmeticOperator + " " + next->getRawName(); + expression = std::make_unique(arithmeticOperator, + std::move(expression), std::move(next), rawName); + } + } + return expression; +} + +std::unique_ptr Transformer::transformPowerOfExpression( + CypherParser::OC_PowerOfExpressionContext& ctx) { + std::unique_ptr expression; + for (auto& stringListNullOperatorExpression : ctx.oC_StringListNullOperatorExpression()) { + auto next = transformStringListNullOperatorExpression(*stringListNullOperatorExpression); + if (!expression) { + expression = std::move(next); + } else { + auto rawName = expression->getRawName() + " ^ " + next->getRawName(); + expression = std::make_unique(PowerFunction::name, + std::move(expression), std::move(next), rawName); + } + } + return expression; +} + +std::unique_ptr Transformer::transformUnaryAddSubtractOrFactorialExpression( + CypherParser::OC_UnaryAddSubtractOrFactorialExpressionContext& ctx) { + auto atomCtx = ctx.oC_PropertyOrLabelsExpression()->oC_Atom(); + bool isNumberLiteral = atomCtx->oC_Literal() && atomCtx->oC_Literal()->oC_NumberLiteral(); + std::unique_ptr result; + if (isNumberLiteral) { + // Try parse -number as a signed literal. 
This is to avoid + // -170141183460469231731687303715884105728 being parsed as + // 170141183460469231731687303715884105728 and cause overflow + result = transformNumberLiteral(*atomCtx->oC_Literal()->oC_NumberLiteral(), + ctx.MINUS().size() % 2 == 1); + } else { + result = transformPropertyOrLabelsExpression(*ctx.oC_PropertyOrLabelsExpression()); + } + if (ctx.FACTORIAL()) { // Factorial has a higher precedence + auto raw = result->toString() + "!"; + result = std::make_unique(FactorialFunction::name, + std::move(result), std::move(raw)); + } + if (!ctx.MINUS().empty() && !isNumberLiteral) { + for ([[maybe_unused]] auto& _ : ctx.MINUS()) { + auto raw = "-" + result->toString(); + result = std::make_unique(NegateFunction::name, + std::move(result), std::move(raw)); + } + } + return result; +} + +std::unique_ptr Transformer::transformStringListNullOperatorExpression( + CypherParser::OC_StringListNullOperatorExpressionContext& ctx) { + auto unaryAddSubtractOrFactorialExpression = transformUnaryAddSubtractOrFactorialExpression( + *ctx.oC_UnaryAddSubtractOrFactorialExpression()); + if (ctx.oC_NullOperatorExpression()) { + return transformNullOperatorExpression(*ctx.oC_NullOperatorExpression(), + std::move(unaryAddSubtractOrFactorialExpression)); + } + if (!ctx.oC_ListOperatorExpression().empty()) { + auto result = transformListOperatorExpression(*ctx.oC_ListOperatorExpression(0), + std::move(unaryAddSubtractOrFactorialExpression)); + for (auto i = 1u; i < ctx.oC_ListOperatorExpression().size(); ++i) { + result = transformListOperatorExpression(*ctx.oC_ListOperatorExpression(i), + std::move(result)); + } + return result; + } + if (ctx.oC_StringOperatorExpression()) { + return transformStringOperatorExpression(*ctx.oC_StringOperatorExpression(), + std::move(unaryAddSubtractOrFactorialExpression)); + } + return unaryAddSubtractOrFactorialExpression; +} + +std::unique_ptr Transformer::transformStringOperatorExpression( + CypherParser::OC_StringOperatorExpressionContext& ctx, + std::unique_ptr propertyExpression) { + auto rawExpression = propertyExpression->getRawName() + " " + ctx.getText(); + auto right = transformPropertyOrLabelsExpression(*ctx.oC_PropertyOrLabelsExpression()); + if (ctx.STARTS()) { + return std::make_unique(StartsWithFunction::name, + std::move(propertyExpression), std::move(right), rawExpression); + } else if (ctx.ENDS()) { + return std::make_unique(EndsWithFunction::name, + std::move(propertyExpression), std::move(right), rawExpression); + } else if (ctx.CONTAINS()) { + return std::make_unique(ContainsFunction::name, + std::move(propertyExpression), std::move(right), rawExpression); + } else { + KU_ASSERT(ctx.oC_RegularExpression()); + return std::make_unique(RegexpFullMatchFunction::name, + std::move(propertyExpression), std::move(right), rawExpression); + } +} + +std::unique_ptr Transformer::transformListOperatorExpression( + CypherParser::OC_ListOperatorExpressionContext& ctx, std::unique_ptr child) { + auto raw = child->getRawName() + ctx.getText(); + if (ctx.IN()) { // x IN y + auto listContains = + std::make_unique(ListContainsFunction::name, std::move(raw)); + auto right = transformPropertyOrLabelsExpression(*ctx.oC_PropertyOrLabelsExpression()); + listContains->addChild(std::move(right)); + listContains->addChild(std::move(child)); + return listContains; + } + if (ctx.COLON() || ctx.DOTDOT()) { // x[:]/x[..] 
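+ // Missing slice bounds are filled in with the literal defaults 1 and -1 below.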
+ auto listSlice = + std::make_unique(ListSliceFunction::name, std::move(raw)); + listSlice->addChild(std::move(child)); + std::unique_ptr left; + std::unique_ptr right; + if (ctx.oC_Expression().size() == 2) { // [left:right]/[left..right] + left = transformExpression(*ctx.oC_Expression(0)); + right = transformExpression(*ctx.oC_Expression(1)); + } else if (ctx.oC_Expression().size() == 0) { // [:]/[..] + left = std::make_unique(Value(1), "1"); + right = std::make_unique(Value(-1), "-1"); + } else { + if (ctx.children[1]->getText() == ":" || + ctx.children[1]->getText() == "..") { // [:right]/[..right] + left = std::make_unique(Value(1), "1"); + right = transformExpression(*ctx.oC_Expression(0)); + } else { // [left:]/[left..] + left = transformExpression(*ctx.oC_Expression(0)); + right = std::make_unique(Value(-1), "-1"); + } + } + listSlice->addChild(std::move(left)); + listSlice->addChild(std::move(right)); + return listSlice; + } + // x[a] + auto listExtract = + std::make_unique(ListExtractFunction::name, std::move(raw)); + listExtract->addChild(std::move(child)); + KU_ASSERT(ctx.oC_Expression().size() == 1); + listExtract->addChild(transformExpression(*ctx.oC_Expression()[0])); + return listExtract; +} + +std::unique_ptr Transformer::transformNullOperatorExpression( + CypherParser::OC_NullOperatorExpressionContext& ctx, + std::unique_ptr propertyExpression) { + auto rawExpression = propertyExpression->getRawName() + " " + ctx.getText(); + KU_ASSERT(ctx.IS() && ctx.NULL_()); + return ctx.NOT() ? std::make_unique(ExpressionType::IS_NOT_NULL, + std::move(propertyExpression), rawExpression) : + std::make_unique(ExpressionType::IS_NULL, + std::move(propertyExpression), rawExpression); +} + +std::unique_ptr Transformer::transformPropertyOrLabelsExpression( + CypherParser::OC_PropertyOrLabelsExpressionContext& ctx) { + auto atom = transformAtom(*ctx.oC_Atom()); + if (!ctx.oC_PropertyLookup().empty()) { + auto lookUpCtx = ctx.oC_PropertyLookup(0); + auto result = createPropertyExpression(*lookUpCtx, std::move(atom)); + for (auto i = 1u; i < ctx.oC_PropertyLookup().size(); ++i) { + lookUpCtx = ctx.oC_PropertyLookup(i); + result = createPropertyExpression(*lookUpCtx, std::move(result)); + } + return result; + } + return atom; +} + +std::unique_ptr Transformer::transformAtom(CypherParser::OC_AtomContext& ctx) { + if (ctx.oC_Literal()) { + return transformLiteral(*ctx.oC_Literal()); + } else if (ctx.oC_Parameter()) { + return transformParameterExpression(*ctx.oC_Parameter()); + } else if (ctx.oC_CaseExpression()) { + return transformCaseExpression(*ctx.oC_CaseExpression()); + } else if (ctx.oC_ParenthesizedExpression()) { + return transformParenthesizedExpression(*ctx.oC_ParenthesizedExpression()); + } else if (ctx.oC_FunctionInvocation()) { + return transformFunctionInvocation(*ctx.oC_FunctionInvocation()); + } else if (ctx.oC_PathPatterns()) { + return transformPathPattern(*ctx.oC_PathPatterns()); + } else if (ctx.oC_ExistCountSubquery()) { + return transformExistCountSubquery(*ctx.oC_ExistCountSubquery()); + } else if (ctx.oC_Quantifier()) { + return transformOcQuantifier(*ctx.oC_Quantifier()); + } else { + KU_ASSERT(ctx.oC_Variable()); + return std::make_unique(transformVariable(*ctx.oC_Variable()), + ctx.getText()); + } +} + +std::unique_ptr Transformer::transformLiteral( + CypherParser::OC_LiteralContext& ctx) { + if (ctx.oC_NumberLiteral()) { + return transformNumberLiteral(*ctx.oC_NumberLiteral(), false /*negative*/); + } else if (ctx.oC_BooleanLiteral()) { + return 
transformBooleanLiteral(*ctx.oC_BooleanLiteral()); + } else if (ctx.StringLiteral()) { + return std::make_unique( + Value(LogicalType::STRING(), transformStringLiteral(*ctx.StringLiteral())), + ctx.getText()); + } else if (ctx.NULL_()) { + return std::make_unique(Value::createNullValue(), ctx.getText()); + } else if (ctx.kU_StructLiteral()) { + return transformStructLiteral(*ctx.kU_StructLiteral()); + } else { + KU_ASSERT(ctx.oC_ListLiteral()); + return transformListLiteral(*ctx.oC_ListLiteral()); + } +} + +std::unique_ptr Transformer::transformBooleanLiteral( + CypherParser::OC_BooleanLiteralContext& ctx) { + if (ctx.TRUE()) { + return std::make_unique(Value(true), ctx.getText()); + } else if (ctx.FALSE()) { + return std::make_unique(Value(false), ctx.getText()); + } + KU_UNREACHABLE; +} + +std::unique_ptr Transformer::transformListLiteral( + CypherParser::OC_ListLiteralContext& ctx) { + auto listCreation = + std::make_unique(ListCreationFunction::name, ctx.getText()); + if (ctx.oC_Expression() == nullptr) { // empty list + return listCreation; + } + listCreation->addChild(transformExpression(*ctx.oC_Expression())); + for (auto& listEntry : ctx.kU_ListEntry()) { + if (listEntry->oC_Expression() == nullptr) { + auto nullValue = Value::createNullValue(); + listCreation->addChild( + std::make_unique(nullValue, nullValue.toString())); + } else { + listCreation->addChild(transformExpression(*listEntry->oC_Expression())); + } + } + return listCreation; +} + +std::unique_ptr Transformer::transformStructLiteral( + CypherParser::KU_StructLiteralContext& ctx) { + auto structPack = + std::make_unique(StructPackFunctions::name, ctx.getText()); + for (auto& structField : ctx.kU_StructField()) { + auto structExpr = transformExpression(*structField->oC_Expression()); + std::string paramName; + if (structField->oC_SymbolicName()) { + paramName = transformSymbolicName(*structField->oC_SymbolicName()); + } else { + paramName = transformStringLiteral(*structField->StringLiteral()); + } + structPack->addOptionalParams(std::move(paramName), std::move(structExpr)); + } + return structPack; +} + +std::unique_ptr Transformer::transformParameterExpression( + CypherParser::OC_ParameterContext& ctx) { + auto parameterName = + ctx.oC_SymbolicName() ? 
ctx.oC_SymbolicName()->getText() : ctx.DecimalInteger()->getText(); + return std::make_unique(parameterName, ctx.getText()); +} + +std::unique_ptr Transformer::transformParenthesizedExpression( + CypherParser::OC_ParenthesizedExpressionContext& ctx) { + return transformExpression(*ctx.oC_Expression()); +} + +std::unique_ptr Transformer::transformFunctionInvocation( + CypherParser::OC_FunctionInvocationContext& ctx) { + if (ctx.STAR()) { + return std::make_unique(CountStarFunction::name, ctx.getText()); + } + std::string functionName; + if (ctx.COUNT()) { + functionName = "COUNT"; + } else if (ctx.CAST()) { + functionName = "CAST"; + } else { + functionName = transformFunctionName(*ctx.oC_FunctionName()); + } + auto expression = std::make_unique(functionName, ctx.getText(), + ctx.DISTINCT() != nullptr); + if (ctx.CAST()) { + for (auto& functionParameter : ctx.kU_FunctionParameter()) { + expression->addChild(transformFunctionParameterExpression(*functionParameter)); + } + if (ctx.kU_DataType()) { + expression->addChild(std::make_unique( + common::Value(transformDataType(*ctx.kU_DataType())))); + } + } else { + for (auto& functionParameter : ctx.kU_FunctionParameter()) { + auto parsedFunctionParameter = transformFunctionParameterExpression(*functionParameter); + if (functionParameter->oC_SymbolicName()) { + // Optional parameter + expression->addOptionalParams( + transformSymbolicName(*functionParameter->oC_SymbolicName()), + std::move(parsedFunctionParameter)); + } else { + expression->addChild(std::move(parsedFunctionParameter)); + } + } + } + return expression; +} + +std::string Transformer::transformFunctionName(CypherParser::OC_FunctionNameContext& ctx) { + return transformSymbolicName(*ctx.oC_SymbolicName()); +} + +std::vector Transformer::transformLambdaVariables( + CypherParser::KU_LambdaVarsContext& ctx) { + std::vector lambdaVariables; + lambdaVariables.reserve(ctx.oC_SymbolicName().size()); + for (auto& var : ctx.oC_SymbolicName()) { + lambdaVariables.push_back(transformSymbolicName(*var)); + } + return lambdaVariables; +} + +std::unique_ptr Transformer::transformLambdaParameter( + CypherParser::KU_LambdaParameterContext& ctx) { + auto vars = transformLambdaVariables(*ctx.kU_LambdaVars()); + auto lambdaOperation = transformExpression(*ctx.oC_Expression()); + return std::make_unique(std::move(vars), std::move(lambdaOperation), + ctx.getText()); +} + +std::unique_ptr Transformer::transformFunctionParameterExpression( + CypherParser::KU_FunctionParameterContext& ctx) { + if (ctx.kU_LambdaParameter()) { + return transformLambdaParameter(*ctx.kU_LambdaParameter()); + } else { + auto expression = transformExpression(*ctx.oC_Expression()); + if (ctx.oC_SymbolicName()) { + expression->setAlias(transformSymbolicName(*ctx.oC_SymbolicName())); + } + return expression; + } +} + +std::unique_ptr Transformer::transformPathPattern( + CypherParser::OC_PathPatternsContext& ctx) { + auto subquery = std::make_unique(SubqueryType::EXISTS, ctx.getText()); + auto patternElement = PatternElement(transformNodePattern(*ctx.oC_NodePattern())); + for (auto& chain : ctx.oC_PatternElementChain()) { + patternElement.addPatternElementChain(transformPatternElementChain(*chain)); + } + subquery->addPatternElement(std::move(patternElement)); + return subquery; +} + +std::unique_ptr Transformer::transformExistCountSubquery( + CypherParser::OC_ExistCountSubqueryContext& ctx) { + auto type = ctx.EXISTS() ? 
SubqueryType::EXISTS : SubqueryType::COUNT; + auto subquery = std::make_unique(type, ctx.getText()); + subquery->setPatternElements(transformPattern(*ctx.oC_Pattern())); + if (ctx.oC_Where()) { + subquery->setWhereClause(transformWhere(*ctx.oC_Where())); + } + if (ctx.kU_Hint()) { + subquery->setHint(transformJoinHint(*ctx.kU_Hint()->kU_JoinNode())); + } + return subquery; +} + +std::unique_ptr Transformer::transformOcQuantifier( + CypherParser::OC_QuantifierContext& ctx) { + auto variable = transformVariable(*ctx.oC_FilterExpression()->oC_IdInColl()->oC_Variable()); + auto whereExpr = transformWhere(*ctx.oC_FilterExpression()->oC_Where()); + auto lambdaRaw = variable + "->" + whereExpr->toString(); + auto lambdaExpr = std::make_unique(std::vector{variable}, + std::move(whereExpr), lambdaRaw); + std::string quantifierName; + if (ctx.ALL()) { + quantifierName = "ALL"; + } else if (ctx.ANY()) { + quantifierName = "ANY"; + } else if (ctx.NONE()) { + quantifierName = "NONE"; + } else if (ctx.SINGLE()) { + quantifierName = "SINGLE"; + } + auto listExpr = transformExpression(*ctx.oC_FilterExpression()->oC_IdInColl()->oC_Expression()); + return std::make_unique(quantifierName, std::move(listExpr), + std::move(lambdaExpr), ctx.getText()); +} + +std::unique_ptr Transformer::createPropertyExpression( + CypherParser::OC_PropertyKeyNameContext& ctx, std::unique_ptr child) { + auto key = transformPropertyKeyName(ctx); + return std::make_unique(key, std::move(child), + child->toString() + "." + key); +} + +std::unique_ptr Transformer::createPropertyExpression( + CypherParser::OC_PropertyLookupContext& ctx, std::unique_ptr child) { + auto key = + ctx.STAR() ? InternalKeyword::STAR : transformPropertyKeyName(*ctx.oC_PropertyKeyName()); + return std::make_unique(key, std::move(child), + child->toString() + ctx.getText()); +} + +std::unique_ptr Transformer::transformCaseExpression( + CypherParser::OC_CaseExpressionContext& ctx) { + std::unique_ptr caseExpression = nullptr; + std::unique_ptr elseExpression = nullptr; + if (ctx.ELSE()) { + if (ctx.oC_Expression().size() == 1) { + elseExpression = transformExpression(*ctx.oC_Expression(0)); + } else { + KU_ASSERT(ctx.oC_Expression().size() == 2); + caseExpression = transformExpression(*ctx.oC_Expression(0)); + elseExpression = transformExpression(*ctx.oC_Expression(1)); + } + } else { + if (ctx.oC_Expression().size() == 1) { + caseExpression = transformExpression(*ctx.oC_Expression(0)); + } + } + auto parsedCaseExpression = std::make_unique(ctx.getText()); + parsedCaseExpression->setCaseExpression(std::move(caseExpression)); + parsedCaseExpression->setElseExpression(std::move(elseExpression)); + for (auto& caseAlternative : ctx.oC_CaseAlternative()) { + parsedCaseExpression->addCaseAlternative(transformCaseAlternative(*caseAlternative)); + } + return parsedCaseExpression; +} + +ParsedCaseAlternative Transformer::transformCaseAlternative( + CypherParser::OC_CaseAlternativeContext& ctx) { + auto whenExpression = transformExpression(*ctx.oC_Expression(0)); + auto thenExpression = transformExpression(*ctx.oC_Expression(1)); + return ParsedCaseAlternative(std::move(whenExpression), std::move(thenExpression)); +} + +std::unique_ptr Transformer::transformNumberLiteral( + CypherParser::OC_NumberLiteralContext& ctx, bool negative) { + if (ctx.oC_IntegerLiteral()) { + return transformIntegerLiteral(*ctx.oC_IntegerLiteral(), negative); + } else { + KU_ASSERT(ctx.oC_DoubleLiteral()); + return transformDoubleLiteral(*ctx.oC_DoubleLiteral(), negative); + } +} + 
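+// transformProperty builds a property access expression (atom.key) from a single property lookup.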
+std::unique_ptr Transformer::transformProperty( + CypherParser::OC_PropertyExpressionContext& ctx) { + auto child = transformAtom(*ctx.oC_Atom()); + return createPropertyExpression(*ctx.oC_PropertyLookup(), std::move(child)); +} + +std::string Transformer::transformPropertyKeyName(CypherParser::OC_PropertyKeyNameContext& ctx) { + return transformSchemaName(*ctx.oC_SchemaName()); +} + +std::unique_ptr Transformer::transformIntegerLiteral( + CypherParser::OC_IntegerLiteralContext& ctx, bool negative) { + auto text = ctx.DecimalInteger()->getText(); + if (negative) { + text = '-' + text; + } + ku_string_t literal{text.c_str(), text.length()}; + int64_t result = 0; + if (function::CastString::tryCast(literal, result)) { + return std::make_unique(Value(result), ctx.getText()); + } + int128_t result128 = 0; + if (function::trySimpleIntegerCast(reinterpret_cast(literal.getData()), + literal.len, result128)) { + return std::make_unique(Value(result128), ctx.getText()); + } + uint128_t resultu128 = 0; + function::CastString::operation(literal, resultu128); + return std::make_unique(Value(resultu128), ctx.getText()); +} + +std::unique_ptr Transformer::transformDoubleLiteral( + CypherParser::OC_DoubleLiteralContext& ctx, bool negative) { + auto text = ctx.ExponentDecimalReal() ? ctx.ExponentDecimalReal()->getText() : + ctx.RegularDecimalReal()->getText(); + if (negative) { + text = '-' + text; + } + ku_string_t literal{text.c_str(), text.length()}; + double result = 0; + function::CastString::operation(literal, result); + return std::make_unique(Value(result), ctx.getText()); +} + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_extension.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_extension.cpp new file mode 100644 index 0000000000..5df14cd166 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_extension.cpp @@ -0,0 +1,43 @@ +#include "extension/extension.h" +#include "parser/extension_statement.h" +#include "parser/transformer.h" + +using namespace lbug::common; +using namespace lbug::extension; + +namespace lbug { +namespace parser { + +std::unique_ptr Transformer::transformExtension(CypherParser::KU_ExtensionContext& ctx) { + if (ctx.kU_InstallExtension()) { + auto extensionRepo = + ctx.kU_InstallExtension()->StringLiteral() ? + transformStringLiteral(*ctx.kU_InstallExtension()->StringLiteral()) : + ExtensionUtils::OFFICIAL_EXTENSION_REPO; + + auto installExtensionAuxInfo = std::make_unique( + std::move(extensionRepo), transformVariable(*ctx.kU_InstallExtension()->oC_Variable()), + ctx.kU_InstallExtension()->FORCE()); + return std::make_unique(std::move(installExtensionAuxInfo)); + } else if (ctx.kU_UpdateExtension()) { + // Update extension is a syntax sugar for force install extension. + auto installExtensionAuxInfo = std::make_unique( + ExtensionUtils::OFFICIAL_EXTENSION_REPO, + transformVariable(*ctx.kU_UpdateExtension()->oC_Variable()), true /* forceInstall */); + return std::make_unique(std::move(installExtensionAuxInfo)); + } else if (ctx.kU_UninstallExtension()) { + auto path = transformVariable(*ctx.kU_UninstallExtension()->oC_Variable()); + return std::make_unique( + std::make_unique(ExtensionAction::UNINSTALL, std::move(path))); + } else { + auto path = ctx.kU_LoadExtension()->StringLiteral() ? 
+ transformStringLiteral(*ctx.kU_LoadExtension()->StringLiteral()) : + transformVariable(*ctx.kU_LoadExtension()->oC_Variable()); + auto installExtensionAuxInfo = + std::make_unique(ExtensionAction::LOAD, std::move(path)); + return std::make_unique(std::move(installExtensionAuxInfo)); + } +} + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_graph_pattern.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_graph_pattern.cpp new file mode 100644 index 0000000000..107f6f3947 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_graph_pattern.cpp @@ -0,0 +1,213 @@ +#include "common/assert.h" +#include "parser/query/graph_pattern/pattern_element.h" +#include "parser/transformer.h" + +using namespace lbug::common; + +namespace lbug { +namespace parser { + +std::vector Transformer::transformPattern(CypherParser::OC_PatternContext& ctx) { + std::vector patterns; + for (auto& patternPart : ctx.oC_PatternPart()) { + patterns.push_back(transformPatternPart(*patternPart)); + } + return patterns; +} + +PatternElement Transformer::transformPatternPart(CypherParser::OC_PatternPartContext& ctx) { + auto patternElement = transformAnonymousPatternPart(*ctx.oC_AnonymousPatternPart()); + if (ctx.oC_Variable()) { + auto variable = transformVariable(*ctx.oC_Variable()); + patternElement.setPathName(variable); + } + return patternElement; +} + +PatternElement Transformer::transformAnonymousPatternPart( + CypherParser::OC_AnonymousPatternPartContext& ctx) { + return transformPatternElement(*ctx.oC_PatternElement()); +} + +PatternElement Transformer::transformPatternElement(CypherParser::OC_PatternElementContext& ctx) { + if (ctx.oC_PatternElement()) { // parenthesized pattern element + return transformPatternElement(*ctx.oC_PatternElement()); + } + auto patternElement = PatternElement(transformNodePattern(*ctx.oC_NodePattern())); + if (!ctx.oC_PatternElementChain().empty()) { + for (auto& patternElementChain : ctx.oC_PatternElementChain()) { + patternElement.addPatternElementChain( + transformPatternElementChain(*patternElementChain)); + } + } + return patternElement; +} + +NodePattern Transformer::transformNodePattern(CypherParser::OC_NodePatternContext& ctx) { + auto variable = std::string(); + if (ctx.oC_Variable()) { + variable = transformVariable(*ctx.oC_Variable()); + } + auto nodeLabels = std::vector{}; + if (ctx.oC_NodeLabels()) { + nodeLabels = transformNodeLabels(*ctx.oC_NodeLabels()); + } + auto properties = std::vector>>{}; + if (ctx.kU_Properties()) { + properties = transformProperties(*ctx.kU_Properties()); + } + return NodePattern(std::move(variable), std::move(nodeLabels), std::move(properties)); +} + +PatternElementChain Transformer::transformPatternElementChain( + CypherParser::OC_PatternElementChainContext& ctx) { + return PatternElementChain(transformRelationshipPattern(*ctx.oC_RelationshipPattern()), + transformNodePattern(*ctx.oC_NodePattern())); +} + +RelPattern Transformer::transformRelationshipPattern( + CypherParser::OC_RelationshipPatternContext& ctx) { + auto relDetail = ctx.oC_RelationshipDetail(); + auto variable = std::string(); + auto relTypes = std::vector{}; + auto properties = std::vector>>{}; + // Parse name, label & properties + if (relDetail) { + if (relDetail->oC_Variable()) { + variable = transformVariable(*relDetail->oC_Variable()); + } + if (relDetail->oC_RelationshipTypes()) { + relTypes = transformRelTypes(*relDetail->oC_RelationshipTypes()); + } + if 
(relDetail->kU_Properties()) { + properties = transformProperties(*relDetail->kU_Properties()); + } + } + // Parse direction + ArrowDirection arrowDirection; // NOLINT(*-init-variables) + if (ctx.oC_LeftArrowHead()) { + arrowDirection = ArrowDirection::LEFT; + } else if (ctx.oC_RightArrowHead()) { + arrowDirection = ArrowDirection::RIGHT; + } else { + arrowDirection = ArrowDirection::BOTH; + } + // Parse recursive info + auto relType = QueryRelType::NON_RECURSIVE; + auto recursiveInfo = RecursiveRelPatternInfo(); + + if (relDetail && relDetail->kU_RecursiveDetail()) { + auto recursiveDetail = relDetail->kU_RecursiveDetail(); + // Parse recursive type + auto recursiveType = recursiveDetail->kU_RecursiveType(); + if (recursiveType) { + if (recursiveType->ALL()) { + if (recursiveType->WSHORTEST()) { + relType = QueryRelType::ALL_WEIGHTED_SHORTEST; + recursiveInfo.weightPropertyName = + transformPropertyKeyName(*recursiveType->oC_PropertyKeyName()); + } else { + relType = QueryRelType::ALL_SHORTEST; + } + } else if (recursiveType->WSHORTEST()) { + relType = QueryRelType::WEIGHTED_SHORTEST; + recursiveInfo.weightPropertyName = + transformPropertyKeyName(*recursiveType->oC_PropertyKeyName()); + } else if (recursiveDetail->kU_RecursiveType()->SHORTEST()) { + relType = QueryRelType::SHORTEST; + } else if (recursiveDetail->kU_RecursiveType()->TRAIL()) { + relType = QueryRelType::VARIABLE_LENGTH_TRAIL; + } else if (recursiveDetail->kU_RecursiveType()->ACYCLIC()) { + relType = QueryRelType::VARIABLE_LENGTH_ACYCLIC; + } else { + relType = QueryRelType::VARIABLE_LENGTH_WALK; + } + } else { + relType = QueryRelType::VARIABLE_LENGTH_WALK; + } + // Parse lower, upper bound + auto lowerBound = std::string("1"); + auto upperBound = std::string(""); + auto range = recursiveDetail->oC_RangeLiteral(); + if (range) { + if (range->oC_IntegerLiteral()) { + lowerBound = range->oC_IntegerLiteral()->getText(); + upperBound = lowerBound; + } + if (range->oC_LowerBound()) { + lowerBound = range->oC_LowerBound()->getText(); + } + if (range->oC_UpperBound()) { + upperBound = range->oC_UpperBound()->getText(); + } + } + recursiveInfo.lowerBound = lowerBound; + recursiveInfo.upperBound = upperBound; + // Parse recursive comprehension + auto comprehension = recursiveDetail->kU_RecursiveComprehension(); + if (comprehension) { + recursiveInfo.relName = transformVariable(*comprehension->oC_Variable(0)); + recursiveInfo.nodeName = transformVariable(*comprehension->oC_Variable(1)); + if (comprehension->oC_Where()) { + recursiveInfo.whereExpression = transformWhere(*comprehension->oC_Where()); + } + if (!comprehension->kU_RecursiveProjectionItems().empty()) { + recursiveInfo.hasProjection = true; + KU_ASSERT(comprehension->kU_RecursiveProjectionItems().size() == 2); + auto relProjectionList = + comprehension->kU_RecursiveProjectionItems(0)->oC_ProjectionItems(); + if (relProjectionList) { + recursiveInfo.relProjectionList = transformProjectionItems(*relProjectionList); + } + auto nodeProjectionList = + comprehension->kU_RecursiveProjectionItems(1)->oC_ProjectionItems(); + if (nodeProjectionList) { + recursiveInfo.nodeProjectionList = + transformProjectionItems(*nodeProjectionList); + } + } + } + } + return RelPattern(variable, relTypes, relType, arrowDirection, std::move(properties), + std::move(recursiveInfo)); +} + +std::vector Transformer::transformProperties( + CypherParser::KU_PropertiesContext& ctx) { + std::vector>> result; + KU_ASSERT(ctx.oC_PropertyKeyName().size() == ctx.oC_Expression().size()); + for (auto i = 0u; i < 
ctx.oC_PropertyKeyName().size(); ++i) { + auto propertyKeyName = transformPropertyKeyName(*ctx.oC_PropertyKeyName(i)); + auto expression = transformExpression(*ctx.oC_Expression(i)); + result.emplace_back(propertyKeyName, std::move(expression)); + } + return result; +} + +std::vector Transformer::transformRelTypes( + CypherParser::OC_RelationshipTypesContext& ctx) { + std::vector relTypes; + for (auto& relType : ctx.oC_RelTypeName()) { + relTypes.push_back(transformRelTypeName(*relType)); + } + return relTypes; +} + +std::vector Transformer::transformNodeLabels(CypherParser::OC_NodeLabelsContext& ctx) { + std::vector nodeLabels; + for (auto& labelName : ctx.oC_LabelName()) { + nodeLabels.push_back(transformLabelName(*labelName)); + } + return nodeLabels; +} + +std::string Transformer::transformLabelName(CypherParser::OC_LabelNameContext& ctx) { + return transformSchemaName(*ctx.oC_SchemaName()); +} + +std::string Transformer::transformRelTypeName(CypherParser::OC_RelTypeNameContext& ctx) { + return transformSchemaName(*ctx.oC_SchemaName()); +} + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_macro.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_macro.cpp new file mode 100644 index 0000000000..2c210848f5 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_macro.cpp @@ -0,0 +1,34 @@ +#include "parser/create_macro.h" +#include "parser/transformer.h" + +namespace lbug { +namespace parser { + +std::vector Transformer::transformPositionalArgs( + CypherParser::KU_PositionalArgsContext& ctx) { + std::vector positionalArgs; + for (auto& positionalArg : ctx.oC_SymbolicName()) { + positionalArgs.push_back(transformSymbolicName(*positionalArg)); + } + return positionalArgs; +} + +std::unique_ptr Transformer::transformCreateMacro( + CypherParser::KU_CreateMacroContext& ctx) { + auto macroName = transformFunctionName(*ctx.oC_FunctionName()); + auto macroExpression = transformExpression(*ctx.oC_Expression()); + std::vector positionalArgs; + if (ctx.kU_PositionalArgs()) { + positionalArgs = transformPositionalArgs(*ctx.kU_PositionalArgs()); + } + default_macro_args defaultArgs; + for (auto& defaultArg : ctx.kU_DefaultArg()) { + defaultArgs.emplace_back(transformSymbolicName(*defaultArg->oC_SymbolicName()), + transformLiteral(*defaultArg->oC_Literal())); + } + return std::make_unique(std::move(macroName), std::move(macroExpression), + std::move(positionalArgs), std::move(defaultArgs)); +} + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_port_db.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_port_db.cpp new file mode 100644 index 0000000000..37d97efff4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_port_db.cpp @@ -0,0 +1,26 @@ +#include "parser/port_db.h" +#include "parser/transformer.h" + +using namespace lbug::common; + +namespace lbug { +namespace parser { + +std::unique_ptr Transformer::transformExportDatabase( + CypherParser::KU_ExportDatabaseContext& ctx) { + std::string filePath = transformStringLiteral(*ctx.StringLiteral()); + auto exportDB = std::make_unique(std::move(filePath)); + if (ctx.kU_Options()) { + exportDB->setParsingOption(transformOptions(*ctx.kU_Options())); + } + return exportDB; +} + +std::unique_ptr Transformer::transformImportDatabase( + CypherParser::KU_ImportDatabaseContext& ctx) { + std::string filePath = 
transformStringLiteral(*ctx.StringLiteral()); + return std::make_unique(std::move(filePath)); +} + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_projection.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_projection.cpp new file mode 100644 index 0000000000..e41363ebcd --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_projection.cpp @@ -0,0 +1,66 @@ +#include "parser/query/return_with_clause/return_clause.h" +#include "parser/query/return_with_clause/with_clause.h" +#include "parser/transformer.h" + +using namespace lbug::common; + +namespace lbug { +namespace parser { + +WithClause Transformer::transformWith(CypherParser::OC_WithContext& ctx) { + auto withClause = WithClause(transformProjectionBody(*ctx.oC_ProjectionBody())); + if (ctx.oC_Where()) { + withClause.setWhereExpression(transformWhere(*ctx.oC_Where())); + } + return withClause; +} + +ReturnClause Transformer::transformReturn(CypherParser::OC_ReturnContext& ctx) { + return ReturnClause(transformProjectionBody(*ctx.oC_ProjectionBody())); +} + +ProjectionBody Transformer::transformProjectionBody(CypherParser::OC_ProjectionBodyContext& ctx) { + auto projectionBody = ProjectionBody(nullptr != ctx.DISTINCT(), + transformProjectionItems(*ctx.oC_ProjectionItems())); + if (ctx.oC_Order()) { + std::vector> orderByExpressions; + std::vector isAscOrders; + for (auto& sortItem : ctx.oC_Order()->oC_SortItem()) { + orderByExpressions.push_back(transformExpression(*sortItem->oC_Expression())); + isAscOrders.push_back(!(sortItem->DESC() || sortItem->DESCENDING())); + } + projectionBody.setOrderByExpressions(std::move(orderByExpressions), std::move(isAscOrders)); + } + if (ctx.oC_Skip()) { + projectionBody.setSkipExpression(transformExpression(*ctx.oC_Skip()->oC_Expression())); + } + if (ctx.oC_Limit()) { + projectionBody.setLimitExpression(transformExpression(*ctx.oC_Limit()->oC_Expression())); + } + return projectionBody; +} + +std::vector> Transformer::transformProjectionItems( + CypherParser::OC_ProjectionItemsContext& ctx) { + std::vector> projectionExpressions; + if (ctx.STAR()) { + projectionExpressions.push_back( + std::make_unique(ExpressionType::STAR, ctx.STAR()->getText())); + } + for (auto& projectionItem : ctx.oC_ProjectionItem()) { + projectionExpressions.push_back(transformProjectionItem(*projectionItem)); + } + return projectionExpressions; +} + +std::unique_ptr Transformer::transformProjectionItem( + CypherParser::OC_ProjectionItemContext& ctx) { + auto expression = transformExpression(*ctx.oC_Expression()); + if (ctx.AS()) { + expression->setAlias(transformVariable(*ctx.oC_Variable())); + } + return expression; +} + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_query.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_query.cpp new file mode 100644 index 0000000000..526c24343d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_query.cpp @@ -0,0 +1,60 @@ +#include "parser/query/regular_query.h" +#include "parser/transformer.h" + +namespace lbug { +namespace parser { + +std::unique_ptr Transformer::transformQuery(CypherParser::OC_QueryContext& ctx) { + return transformRegularQuery(*ctx.oC_RegularQuery()); +} + +std::unique_ptr Transformer::transformRegularQuery( + CypherParser::OC_RegularQueryContext& ctx) { + auto regularQuery = 
std::make_unique(transformSingleQuery(*ctx.oC_SingleQuery())); + for (auto unionClause : ctx.oC_Union()) { + regularQuery->addSingleQuery(transformSingleQuery(*unionClause->oC_SingleQuery()), + unionClause->ALL()); + } + return regularQuery; +} + +SingleQuery Transformer::transformSingleQuery(CypherParser::OC_SingleQueryContext& ctx) { + auto singleQuery = + ctx.oC_MultiPartQuery() ? + transformSinglePartQuery(*ctx.oC_MultiPartQuery()->oC_SinglePartQuery()) : + transformSinglePartQuery(*ctx.oC_SinglePartQuery()); + if (ctx.oC_MultiPartQuery()) { + for (auto queryPart : ctx.oC_MultiPartQuery()->kU_QueryPart()) { + singleQuery.addQueryPart(transformQueryPart(*queryPart)); + } + } + return singleQuery; +} + +SingleQuery Transformer::transformSinglePartQuery(CypherParser::OC_SinglePartQueryContext& ctx) { + auto singleQuery = SingleQuery(); + for (auto& readingClause : ctx.oC_ReadingClause()) { + singleQuery.addReadingClause(transformReadingClause(*readingClause)); + } + for (auto& updatingClause : ctx.oC_UpdatingClause()) { + singleQuery.addUpdatingClause(transformUpdatingClause(*updatingClause)); + } + if (ctx.oC_Return()) { + singleQuery.setReturnClause(transformReturn(*ctx.oC_Return())); + } + return singleQuery; +} + +QueryPart Transformer::transformQueryPart(CypherParser::KU_QueryPartContext& ctx) { + auto queryPart = QueryPart(transformWith(*ctx.oC_With())); + for (auto& readingClause : ctx.oC_ReadingClause()) { + queryPart.addReadingClause(transformReadingClause(*readingClause)); + } + for (auto& updatingClause : ctx.oC_UpdatingClause()) { + queryPart.addUpdatingClause(transformUpdatingClause(*updatingClause)); + } + return queryPart; +} + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_reading_clause.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_reading_clause.cpp new file mode 100644 index 0000000000..cd927d39ca --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_reading_clause.cpp @@ -0,0 +1,118 @@ +#include "common/assert.h" +#include "parser/query/reading_clause/in_query_call_clause.h" +#include "parser/query/reading_clause/load_from.h" +#include "parser/query/reading_clause/match_clause.h" +#include "parser/query/reading_clause/unwind_clause.h" +#include "parser/transformer.h" + +using namespace lbug::common; + +namespace lbug { +namespace parser { + +std::unique_ptr Transformer::transformReadingClause( + CypherParser::OC_ReadingClauseContext& ctx) { + if (ctx.oC_Match()) { + return transformMatch(*ctx.oC_Match()); + } else if (ctx.oC_Unwind()) { + return transformUnwind(*ctx.oC_Unwind()); + } else if (ctx.kU_InQueryCall()) { + return transformInQueryCall(*ctx.kU_InQueryCall()); + } else if (ctx.kU_LoadFrom()) { + return transformLoadFrom(*ctx.kU_LoadFrom()); + } + KU_UNREACHABLE; +} + +std::unique_ptr Transformer::transformMatch(CypherParser::OC_MatchContext& ctx) { + auto matchClauseType = + ctx.OPTIONAL() ? 
MatchClauseType::OPTIONAL_MATCH : MatchClauseType::MATCH; + auto matchClause = + std::make_unique(transformPattern(*ctx.oC_Pattern()), matchClauseType); + if (ctx.oC_Where()) { + matchClause->setWherePredicate(transformWhere(*ctx.oC_Where())); + } + if (ctx.kU_Hint()) { + matchClause->setHint(transformJoinHint(*ctx.kU_Hint()->kU_JoinNode())); + } + return matchClause; +} + +std::shared_ptr Transformer::transformJoinHint( + CypherParser::KU_JoinNodeContext& ctx) { + if (!ctx.MULTI_JOIN().empty()) { + auto joinNode = std::make_shared(); + joinNode->addChild(transformJoinHint(*ctx.kU_JoinNode(0))); + for (auto& schemaNameCtx : ctx.oC_SchemaName()) { + joinNode->addChild(std::make_shared(transformSchemaName(*schemaNameCtx))); + } + return joinNode; + } + if (!ctx.oC_SchemaName().empty()) { + return std::make_shared(transformSchemaName(*ctx.oC_SchemaName(0))); + } + if (ctx.kU_JoinNode().size() == 1) { + return transformJoinHint(*ctx.kU_JoinNode(0)); + } + KU_ASSERT(ctx.kU_JoinNode().size() == 2); + auto joinNode = std::make_shared(); + joinNode->addChild(transformJoinHint(*ctx.kU_JoinNode(0))); + joinNode->addChild(transformJoinHint(*ctx.kU_JoinNode(1))); + return joinNode; +} + +std::unique_ptr Transformer::transformUnwind(CypherParser::OC_UnwindContext& ctx) { + auto expression = transformExpression(*ctx.oC_Expression()); + auto transformedVariable = transformVariable(*ctx.oC_Variable()); + return std::make_unique(std::move(expression), std::move(transformedVariable)); +} + +std::vector Transformer::transformYieldVariables( + CypherParser::OC_YieldItemsContext& ctx) { + std::vector yieldVariables; + std::string name; + for (auto& yieldItem : ctx.oC_YieldItem()) { + std::string alias; + if (yieldItem->AS()) { + alias = transformVariable(*yieldItem->oC_Variable(1)); + } + name = transformVariable(*yieldItem->oC_Variable(0)); + yieldVariables.emplace_back(name, alias); + } + return yieldVariables; +} + +std::unique_ptr Transformer::transformInQueryCall( + CypherParser::KU_InQueryCallContext& ctx) { + auto functionExpression = + Transformer::transformFunctionInvocation(*ctx.oC_FunctionInvocation()); + std::vector yieldVariables; + if (ctx.oC_YieldItems()) { + yieldVariables = transformYieldVariables(*ctx.oC_YieldItems()); + } + auto inQueryCall = std::make_unique(std::move(functionExpression), + std::move(yieldVariables)); + if (ctx.oC_Where()) { + inQueryCall->setWherePredicate(transformWhere(*ctx.oC_Where())); + } + return inQueryCall; +} + +std::unique_ptr Transformer::transformLoadFrom( + CypherParser::KU_LoadFromContext& ctx) { + auto source = transformScanSource(*ctx.kU_ScanSource()); + auto loadFrom = std::make_unique(std::move(source)); + if (ctx.kU_ColumnDefinitions()) { + loadFrom->setPropertyDefinitions(transformColumnDefinitions(*ctx.kU_ColumnDefinitions())); + } + if (ctx.kU_Options()) { + loadFrom->setParingOptions(transformOptions(*ctx.kU_Options())); + } + if (ctx.oC_Where()) { + loadFrom->setWherePredicate(transformWhere(*ctx.oC_Where())); + } + return loadFrom; +} + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_standalone_call.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_standalone_call.cpp new file mode 100644 index 0000000000..9063c2c3c0 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_standalone_call.cpp @@ -0,0 +1,21 @@ +#include "parser/standalone_call.h" +#include "parser/standalone_call_function.h" +#include "parser/transformer.h" + 
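+// A standalone CALL statement is either a function invocation (CALL f(...)) or a
+// database-option assignment (CALL option=value); the transformer below selects the
+// statement type by checking which sub-rule the parser matched.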
+namespace lbug { +namespace parser { + +std::unique_ptr Transformer::transformStandaloneCall( + CypherParser::KU_StandaloneCallContext& ctx) { + if (ctx.oC_FunctionInvocation()) { + return std::make_unique( + transformFunctionInvocation(*ctx.oC_FunctionInvocation())); + } else { + auto optionName = transformSymbolicName(*ctx.oC_SymbolicName()); + auto parameter = transformExpression(*ctx.oC_Expression()); + return std::make_unique(std::move(optionName), std::move(parameter)); + } +} + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_transaction.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_transaction.cpp new file mode 100644 index 0000000000..22adaeecf2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_transaction.cpp @@ -0,0 +1,32 @@ +#include "common/assert.h" +#include "parser/transaction_statement.h" +#include "parser/transformer.h" + +using namespace lbug::transaction; +using namespace lbug::common; + +namespace lbug { +namespace parser { + +std::unique_ptr Transformer::transformTransaction( + CypherParser::KU_TransactionContext& ctx) { + if (ctx.TRANSACTION()) { + if (ctx.READ()) { + return std::make_unique(TransactionAction::BEGIN_READ); + } + return std::make_unique(TransactionAction::BEGIN_WRITE); + } + if (ctx.COMMIT()) { + return std::make_unique(TransactionAction::COMMIT); + } + if (ctx.ROLLBACK()) { + return std::make_unique(TransactionAction::ROLLBACK); + } + if (ctx.CHECKPOINT()) { + return std::make_unique(TransactionAction::CHECKPOINT); + } + KU_UNREACHABLE; +} + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_updating_clause.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_updating_clause.cpp new file mode 100644 index 0000000000..dbd4db57dd --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_updating_clause.cpp @@ -0,0 +1,79 @@ +#include "common/assert.h" +#include "parser/query/updating_clause/delete_clause.h" +#include "parser/query/updating_clause/insert_clause.h" +#include "parser/query/updating_clause/merge_clause.h" +#include "parser/query/updating_clause/set_clause.h" +#include "parser/transformer.h" + +namespace lbug { +namespace parser { + +std::unique_ptr Transformer::transformUpdatingClause( + CypherParser::OC_UpdatingClauseContext& ctx) { + if (ctx.oC_Create()) { + return transformCreate(*ctx.oC_Create()); + } else if (ctx.oC_Merge()) { + return transformMerge(*ctx.oC_Merge()); + } else if (ctx.oC_Set()) { + return transformSet(*ctx.oC_Set()); + } else { + KU_ASSERT(ctx.oC_Delete()); + return transformDelete(*ctx.oC_Delete()); + } +} + +std::unique_ptr Transformer::transformCreate(CypherParser::OC_CreateContext& ctx) { + return std::make_unique(transformPattern(*ctx.oC_Pattern())); +} + +std::unique_ptr Transformer::transformMerge(CypherParser::OC_MergeContext& ctx) { + auto mergeClause = std::make_unique(transformPattern(*ctx.oC_Pattern())); + for (auto& mergeActionCtx : ctx.oC_MergeAction()) { + if (mergeActionCtx->MATCH()) { + for (auto& setItemCtx : mergeActionCtx->oC_Set()->oC_SetItem()) { + mergeClause->addOnMatchSetItems(transformSetItem(*setItemCtx)); + } + } else { + for (auto& setItemCtx : mergeActionCtx->oC_Set()->oC_SetItem()) { + mergeClause->addOnCreateSetItems(transformSetItem(*setItemCtx)); + } + } + } + return mergeClause; +} + +std::unique_ptr 
Transformer::transformSet(CypherParser::OC_SetContext& ctx) { + auto setClause = std::make_unique(); + if (ctx.kU_Properties()) { + auto child = transformAtom(*ctx.oC_Atom()); + for (auto i = 0u; i < ctx.kU_Properties()->oC_PropertyKeyName().size(); ++i) { + auto propertyKeyName = createPropertyExpression( + *ctx.kU_Properties()->oC_PropertyKeyName(i), child->copy()); + auto expression = transformExpression(*ctx.kU_Properties()->oC_Expression(i)); + setClause->addSetItem(make_pair(std::move(propertyKeyName), std::move(expression))); + } + } else { + for (auto& setItem : ctx.oC_SetItem()) { + setClause->addSetItem(transformSetItem(*setItem)); + } + } + return setClause; +} + +parsed_expr_pair Transformer::transformSetItem(CypherParser::OC_SetItemContext& ctx) { + return make_pair(transformProperty(*ctx.oC_PropertyExpression()), + transformExpression(*ctx.oC_Expression())); +} + +std::unique_ptr Transformer::transformDelete(CypherParser::OC_DeleteContext& ctx) { + auto deleteClauseType = + ctx.DETACH() ? common::DeleteNodeType::DETACH_DELETE : common::DeleteNodeType::DELETE; + auto deleteClause = std::make_unique(deleteClauseType); + for (auto& expression : ctx.oC_Expression()) { + deleteClause->addExpression(transformExpression(*expression)); + } + return deleteClause; +} + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_use_database.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_use_database.cpp new file mode 100644 index 0000000000..d72b83e8e8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transform/transform_use_database.cpp @@ -0,0 +1,14 @@ +#include "parser/transformer.h" +#include "parser/use_database.h" + +namespace lbug { +namespace parser { + +std::unique_ptr Transformer::transformUseDatabase( + CypherParser::KU_UseDatabaseContext& ctx) { + auto dbName = transformSchemaName(*ctx.oC_SchemaName()); + return std::make_unique(std::move(dbName)); +} + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transformer.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transformer.cpp new file mode 100644 index 0000000000..003a25f0bf --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/transformer.cpp @@ -0,0 +1,213 @@ +#include "parser/transformer.h" + +#include + +#include "common/assert.h" +#include "common/exception/parser.h" +#include "extension/transformer_extension.h" +#include "parser/explain_statement.h" +#include "parser/query/regular_query.h" // IWYU pragma: keep (fixes a forward declaration error) + +using namespace lbug::common; + +namespace lbug { +namespace parser { + +std::vector> Transformer::transform() { + std::vector> statements; + for (auto& oc_Statement : root.oC_Cypher()) { + auto statement = transformStatement(*oc_Statement->oC_Statement()); + if (oc_Statement->oC_AnyCypherOption()) { + auto cypherOption = oc_Statement->oC_AnyCypherOption(); + auto explainType = ExplainType::PROFILE; + if (cypherOption->oC_Explain()) { + explainType = cypherOption->oC_Explain()->LOGICAL() ? 
ExplainType::LOGICAL_PLAN : + ExplainType::PHYSICAL_PLAN; + } + statements.push_back( + std::make_unique(std::move(statement), explainType)); + continue; + } + statements.push_back(std::move(statement)); + } + return statements; +} + +std::unique_ptr Transformer::transformStatement(CypherParser::OC_StatementContext& ctx) { + if (ctx.oC_Query()) { + return transformQuery(*ctx.oC_Query()); + } else if (ctx.kU_CreateNodeTable()) { + return transformCreateNodeTable(*ctx.kU_CreateNodeTable()); + } else if (ctx.kU_CreateRelTable()) { + return transformCreateRelGroup(*ctx.kU_CreateRelTable()); + } else if (ctx.kU_CreateSequence()) { + return transformCreateSequence(*ctx.kU_CreateSequence()); + } else if (ctx.kU_CreateType()) { + return transformCreateType(*ctx.kU_CreateType()); + } else if (ctx.kU_CreateUser()) { + return transformExtensionStatement(ctx.kU_CreateUser()); + } else if (ctx.kU_CreateRole()) { + return transformExtensionStatement(ctx.kU_CreateRole()); + } else if (ctx.kU_Drop()) { + return transformDrop(*ctx.kU_Drop()); + } else if (ctx.kU_AlterTable()) { + return transformAlterTable(*ctx.kU_AlterTable()); + } else if (ctx.kU_CopyFromByColumn()) { + return transformCopyFromByColumn(*ctx.kU_CopyFromByColumn()); + } else if (ctx.kU_CopyFrom()) { + return transformCopyFrom(*ctx.kU_CopyFrom()); + } else if (ctx.kU_CopyTO()) { + return transformCopyTo(*ctx.kU_CopyTO()); + } else if (ctx.kU_StandaloneCall()) { + return transformStandaloneCall(*ctx.kU_StandaloneCall()); + } else if (ctx.kU_CreateMacro()) { + return transformCreateMacro(*ctx.kU_CreateMacro()); + } else if (ctx.kU_CommentOn()) { + return transformCommentOn(*ctx.kU_CommentOn()); + } else if (ctx.kU_Transaction()) { + return transformTransaction(*ctx.kU_Transaction()); + } else if (ctx.kU_Extension()) { + return transformExtension(*ctx.kU_Extension()); + } else if (ctx.kU_ExportDatabase()) { + return transformExportDatabase(*ctx.kU_ExportDatabase()); + } else if (ctx.kU_ImportDatabase()) { + return transformImportDatabase(*ctx.kU_ImportDatabase()); + } else if (ctx.kU_AttachDatabase()) { + return transformAttachDatabase(*ctx.kU_AttachDatabase()); + } else if (ctx.kU_DetachDatabase()) { + return transformDetachDatabase(*ctx.kU_DetachDatabase()); + } else if (ctx.kU_UseDatabase()) { + return transformUseDatabase(*ctx.kU_UseDatabase()); + } else { + KU_UNREACHABLE; + } +} + +std::unique_ptr Transformer::transformWhere(CypherParser::OC_WhereContext& ctx) { + return transformExpression(*ctx.oC_Expression()); +} + +std::string Transformer::transformSchemaName(CypherParser::OC_SchemaNameContext& ctx) { + return transformSymbolicName(*ctx.oC_SymbolicName()); +} + +std::string Transformer::transformStringLiteral(antlr4::tree::TerminalNode& stringLiteral) { + auto str = stringLiteral.getText(); + std::string content = str.substr(1, str.length() - 2); + std::string result; + result.reserve(content.length()); + for (auto i = 0u; i < content.length(); i++) { + if (content[i] == '\\' && i + 1 < content.length()) { + char next = content[i + 1]; + switch (next) { + case '\\': + case '\'': + case '"': { + result += next; + i++; + } break; + case 'b': + case 'B': { + result += '\b'; + i++; + } break; + case 'f': + case 'F': { + result += '\f'; + i++; + } break; + case 'n': + case 'N': { + result += '\n'; + i++; + } break; + case 'r': + case 'R': { + result += '\r'; + i++; + } break; + case 't': + case 'T': { + result += '\t'; + i++; + } break; + case 'x': + case 'X': { + result += content.substr(i, 4); + i += 3; + } break; + case 'u': + case 
'U': { + // Handle \uHHHH and \UHHHHHHHH unicode escape sequences + if (next == 'u' || next == 'U') { + int hexDigits = (next == 'u') ? 4 : 8; + if (i + 1 + hexDigits > content.length()) { + KU_UNREACHABLE; + } + std::string hexStr = content.substr(i + 2, hexDigits); + char* endPtr = nullptr; + long hexValue = std::strtol(hexStr.c_str(), &endPtr, 16); + if (endPtr != hexStr.c_str() + hexDigits) { + KU_UNREACHABLE; + } + // Convert Unicode code point to UTF-8 + if (hexValue <= 0x7F) { + result += static_cast(hexValue); + } else if (hexValue <= 0x7FF) { + result += static_cast(0xC0 | (hexValue >> 6)); + result += static_cast(0x80 | (hexValue & 0x3F)); + } else if (hexValue <= 0xFFFF) { + result += static_cast(0xE0 | (hexValue >> 12)); + result += static_cast(0x80 | ((hexValue >> 6) & 0x3F)); + result += static_cast(0x80 | (hexValue & 0x3F)); + } else if (hexValue <= 0x10FFFF) { + result += static_cast(0xF0 | (hexValue >> 18)); + result += static_cast(0x80 | ((hexValue >> 12) & 0x3F)); + result += static_cast(0x80 | ((hexValue >> 6) & 0x3F)); + result += static_cast(0x80 | (hexValue & 0x3F)); + } else { + KU_UNREACHABLE; + } + i += 1 + hexDigits; + } + } break; + default: + KU_UNREACHABLE; + } + } else { + result += content[i]; + } + } + + return result; +} + +std::string Transformer::transformVariable(CypherParser::OC_VariableContext& ctx) { + return transformSymbolicName(*ctx.oC_SymbolicName()); +} +std::string Transformer::transformSymbolicName(CypherParser::OC_SymbolicNameContext& ctx) { + if (ctx.EscapedSymbolicName()) { + std::string escapedSymbolName = ctx.EscapedSymbolicName()->getText(); + // escapedSymbolName symbol will be of form "`Some.Value`". Therefore, we need to sanitize + // it such that we don't store the symbol with escape character. + return escapedSymbolName.substr(1, escapedSymbolName.size() - 2); + } else { + KU_ASSERT(ctx.HexLetter() || ctx.UnescapedSymbolicName() || ctx.kU_NonReservedKeywords()); + return ctx.getText(); + } +} + +std::unique_ptr Transformer::transformExtensionStatement( + antlr4::ParserRuleContext* ctx) { + for (auto& transformerExtension : transformerExtensions) { + auto statement = transformerExtension->transform(ctx); + if (statement) { + return statement; + } + } + throw common::ParserException{ + "Failed parse the statement. 
Do you forget to load the extension?"}; +} + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/parser/visitor/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/visitor/CMakeLists.txt new file mode 100644 index 0000000000..32fbc07f93 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/visitor/CMakeLists.txt @@ -0,0 +1,9 @@ +add_library( + lbug_parser_visitor + OBJECT + standalone_call_rewriter.cpp + statement_read_write_analyzer.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/parser/visitor/standalone_call_rewriter.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/visitor/standalone_call_rewriter.cpp new file mode 100644 index 0000000000..69c5910e51 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/visitor/standalone_call_rewriter.cpp @@ -0,0 +1,50 @@ +#include "parser/visitor/standalone_call_rewriter.h" + +#include "binder/binder.h" +#include "binder/bound_standalone_call_function.h" +#include "catalog/catalog.h" +#include "common/exception/parser.h" +#include "main/client_context.h" +#include "parser/expression/parsed_function_expression.h" +#include "parser/standalone_call_function.h" +#include "transaction/transaction.h" +#include "transaction/transaction_context.h" + +namespace lbug { +namespace parser { + +std::string StandaloneCallRewriter::getRewriteQuery(const Statement& statement) { + visit(statement); + return rewriteQuery; +} + +void StandaloneCallRewriter::visitStandaloneCallFunction(const Statement& statement) { + auto& standaloneCallFunc = statement.constCast(); + main::ClientContext::TransactionHelper::runFuncInTransaction( + *transaction::TransactionContext::Get(*context), + [&]() -> void { + auto funcName = standaloneCallFunc.getFunctionExpression() + ->constPtrCast() + ->getFunctionName(); + if (!catalog::Catalog::Get(*context)->containsFunction( + transaction::Transaction::Get(*context), funcName) && + !singleStatement) { + throw common::ParserException{ + funcName + " must be called in a query which doesn't have other statements."}; + } + binder::Binder binder{context}; + const auto boundStatement = binder.bind(standaloneCallFunc); + auto& boundStandaloneCall = + boundStatement->constCast(); + const auto func = + boundStandaloneCall.getTableFunction().constPtrCast(); + if (func->rewriteFunc) { + rewriteQuery = func->rewriteFunc(*context, *boundStandaloneCall.getBindData()); + } + }, + true /*readOnlyStatement*/, false /*isTransactionStatement*/, + main::ClientContext::TransactionHelper::TransactionCommitAction::COMMIT_IF_NEW); +} + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/parser/visitor/statement_read_write_analyzer.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/visitor/statement_read_write_analyzer.cpp new file mode 100644 index 0000000000..d54a31c8c5 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/parser/visitor/statement_read_write_analyzer.cpp @@ -0,0 +1,54 @@ +#include "parser/visitor/statement_read_write_analyzer.h" + +#include "main/client_context.h" +#include "main/db_config.h" +#include "parser/expression/parsed_expression_visitor.h" +#include "parser/query/reading_clause/reading_clause.h" +#include "parser/query/return_with_clause/with_clause.h" + +namespace lbug { +namespace parser { + +void StatementReadWriteAnalyzer::visitExtension(const Statement& /*statement*/) { + // We allow LOAD EXTENSION to run in read-only mode. 
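+    // The statement is reported as read-only only when the database itself is
+    // read-only, so the extension statement passes the read/write check in both modes.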
+ if (context->getDBConfig()->readOnly) { + readOnly = true; + } else { + readOnly = false; + } +} + +void StatementReadWriteAnalyzer::visitReadingClause(const ReadingClause* readingClause) { + if (readingClause->hasWherePredicate()) { + if (!isExprReadOnly(readingClause->getWherePredicate())) { + readOnly = false; + } + } +} + +void StatementReadWriteAnalyzer::visitWithClause(const WithClause* withClause) { + for (auto& expr : withClause->getProjectionBody()->getProjectionExpressions()) { + if (!isExprReadOnly(expr.get())) { + readOnly = false; + return; + } + } +} + +void StatementReadWriteAnalyzer::visitReturnClause(const ReturnClause* returnClause) { + for (auto& expr : returnClause->getProjectionBody()->getProjectionExpressions()) { + if (!isExprReadOnly(expr.get())) { + readOnly = false; + return; + } + } +} + +bool StatementReadWriteAnalyzer::isExprReadOnly(const ParsedExpression* expr) { + auto analyzer = ReadWriteExprAnalyzer(context); + analyzer.visit(expr); + return analyzer.isReadOnly(); +} + +} // namespace parser +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/CMakeLists.txt new file mode 100644 index 0000000000..5d600fea53 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/CMakeLists.txt @@ -0,0 +1,14 @@ +add_subdirectory(join_order) +add_subdirectory(operator) +add_subdirectory(plan) + +add_library(lbug_planner + OBJECT + join_order_enumerator_context.cpp + planner.cpp + query_planner.cpp + subplans_table.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/join_order/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/join_order/CMakeLists.txt new file mode 100644 index 0000000000..76ce7a3212 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/join_order/CMakeLists.txt @@ -0,0 +1,12 @@ +add_library(lbug_planner_join_order + OBJECT + cardinality_estimator.cpp + cost_model.cpp + join_order_util.cpp + join_plan_solver.cpp + join_tree.cpp + join_tree_constructor.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/join_order/cardinality_estimator.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/join_order/cardinality_estimator.cpp new file mode 100644 index 0000000000..e99a651ace --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/join_order/cardinality_estimator.cpp @@ -0,0 +1,247 @@ +#include "planner/join_order/cardinality_estimator.h" + +#include "binder/expression/property_expression.h" +#include "main/client_context.h" +#include "planner/join_order/join_order_util.h" +#include "planner/operator/logical_aggregate.h" +#include "planner/operator/logical_hash_join.h" +#include "planner/operator/scan/logical_scan_node_table.h" +#include "storage/storage_manager.h" +#include "storage/table/node_table.h" +#include "storage/table/rel_table.h" + +using namespace lbug::binder; +using namespace lbug::common; +using namespace lbug::transaction; + +namespace lbug { +namespace planner { + +static cardinality_t atLeastOne(uint64_t x) { + return x == 0 ? 
1 : x; +} + +void CardinalityEstimator::init(const QueryGraph& queryGraph) { + for (auto i = 0u; i < queryGraph.getNumQueryNodes(); ++i) { + init(*queryGraph.getQueryNode(i)); + } + for (uint64_t i = 0u; i < queryGraph.getNumQueryRels(); ++i) { + auto rel = queryGraph.getQueryRel(i); + if (QueryRelTypeUtils::isRecursive(rel->getRelType())) { + auto recursiveInfo = rel->getRecursiveInfo(); + init(*recursiveInfo->node); + } + } +} + +void CardinalityEstimator::init(const NodeExpression& node) { + auto key = node.getInternalID()->getUniqueName(); + cardinality_t numNodes = 0u; + auto storageManager = storage::StorageManager::Get(*context); + auto transaction = transaction::Transaction::Get(*context); + for (auto tableID : node.getTableIDs()) { + auto stats = + storageManager->getTable(tableID)->cast().getStats(transaction); + numNodes += stats.getTableCard(); + if (!nodeTableStats.contains(tableID)) { + nodeTableStats.insert({tableID, std::move(stats)}); + } + } + if (!nodeIDName2dom.contains(key)) { + nodeIDName2dom.insert({key, numNodes}); + } +} + +void CardinalityEstimator::rectifyCardinality(const Expression& nodeID, cardinality_t card) { + KU_ASSERT(nodeIDName2dom.contains(nodeID.getUniqueName())); + auto newCard = std::min(nodeIDName2dom.at(nodeID.getUniqueName()), card); + nodeIDName2dom[nodeID.getUniqueName()] = newCard; +} + +cardinality_t CardinalityEstimator::getNodeIDDom(const std::string& nodeIDName) const { + KU_ASSERT(nodeIDName2dom.contains(nodeIDName)); + return nodeIDName2dom.at(nodeIDName); +} + +uint64_t CardinalityEstimator::estimateScanNode(const LogicalOperator& op) const { + const auto& scan = op.constCast(); + switch (scan.getScanType()) { + case LogicalScanNodeTableType::PRIMARY_KEY_SCAN: + return 1; + default: + return atLeastOne(getNodeIDDom(scan.getNodeID()->getUniqueName())); + } +} + +uint64_t CardinalityEstimator::estimateAggregate(const LogicalAggregate& op) const { + // TODO(Royi) we can use HLL to better estimate the number of distinct keys here + return op.getKeys().empty() ? 
1 : op.getChild(0)->getCardinality(); +} + +cardinality_t CardinalityEstimator::multiply(double extensionRate, cardinality_t card) const { + return atLeastOne(extensionRate * card); +} + +uint64_t CardinalityEstimator::estimateHashJoin( + const std::vector& joinConditions, const LogicalOperator& probeOp, + const LogicalOperator& buildOp) const { + if (LogicalHashJoin::isNodeIDOnlyJoin(joinConditions)) { + cardinality_t denominator = 1u; + auto joinKeys = LogicalHashJoin::getJoinNodeIDs(joinConditions); + for (auto& joinKey : joinKeys) { + if (nodeIDName2dom.contains(joinKey->getUniqueName())) { + denominator *= getNodeIDDom(joinKey->getUniqueName()); + } + } + return atLeastOne(probeOp.getCardinality() * + JoinOrderUtil::getJoinKeysFlatCardinality(joinKeys, buildOp) / + atLeastOne(denominator)); + } else { + // Naively estimate the cardinality if the join is non-ID based + cardinality_t estCardinality = probeOp.getCardinality() * buildOp.getCardinality(); + for (size_t i = 0; i < joinConditions.size(); ++i) { + estCardinality *= PlannerKnobs::EQUALITY_PREDICATE_SELECTIVITY; + } + return atLeastOne(estCardinality); + } +} + +uint64_t CardinalityEstimator::estimateCrossProduct(const LogicalOperator& probeOp, + const LogicalOperator& buildOp) const { + return atLeastOne(probeOp.getCardinality() * buildOp.getCardinality()); +} + +uint64_t CardinalityEstimator::estimateIntersect(const expression_vector& joinNodeIDs, + const LogicalOperator& probeOp, const std::vector& buildOps) const { + // Formula 1: treat intersect as a Filter on probe side. + uint64_t estCardinality1 = + probeOp.getCardinality() * PlannerKnobs::NON_EQUALITY_PREDICATE_SELECTIVITY; + // Formula 2: assume independence on join conditions. + cardinality_t denominator = 1u; + for (auto& joinNodeID : joinNodeIDs) { + denominator *= getNodeIDDom(joinNodeID->getUniqueName()); + } + auto numerator = probeOp.getCardinality(); + for (auto& buildOp : buildOps) { + numerator *= buildOp->getCardinality(); + } + auto estCardinality2 = numerator / atLeastOne(denominator); + // Pick minimum between the two formulas. 
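+    // Illustration with hypothetical numbers: a probe cardinality of 10,000 and a 0.1
+    // selectivity make formula 1 yield 1,000; two build sides of 50,000 rows each with a
+    // join-node domain product of 10^12 make formula 2 yield
+    // 10,000 * 50,000 * 50,000 / 10^12 = 25, so the smaller estimate (25) is returned.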
+ return atLeastOne(std::min(estCardinality1, estCardinality2)); +} + +uint64_t CardinalityEstimator::estimateFlatten(const LogicalOperator& childOp, + f_group_pos groupPosToFlatten) const { + auto group = childOp.getSchema()->getGroup(groupPosToFlatten); + return atLeastOne(childOp.getCardinality() * group->cardinalityMultiplier); +} + +static bool isPrimaryKey(const Expression& expression) { + if (expression.expressionType != ExpressionType::PROPERTY) { + return false; + } + return ((PropertyExpression&)expression).isPrimaryKey(); +} + +static bool isSingleLabelledProperty(const Expression& expression) { + if (expression.expressionType != ExpressionType::PROPERTY) { + return false; + } + return expression.constCast().isSingleLabel(); +} + +static std::optional getTableStatsIfPossible(main::ClientContext* context, + const Expression& predicate, + const std::unordered_map& nodeTableStats) { + KU_ASSERT(predicate.getNumChildren() >= 1); + if (isSingleLabelledProperty(*predicate.getChild(0))) { + auto& propertyExpr = predicate.getChild(0)->cast(); + auto tableID = propertyExpr.getSingleTableID(); + if (nodeTableStats.contains(tableID) && propertyExpr.hasProperty(tableID)) { + auto transaction = Transaction::Get(*context); + auto entry = + catalog::Catalog::Get(*context)->getTableCatalogEntry(transaction, tableID); + auto columnID = entry->getColumnID(propertyExpr.getPropertyName()); + if (columnID != INVALID_COLUMN_ID && columnID != ROW_IDX_COLUMN_ID) { + auto& stats = nodeTableStats.at(tableID); + return atLeastOne(stats.getNumDistinctValues(columnID)); + } + } + } + return {}; +} + +uint64_t CardinalityEstimator::estimateFilter(const LogicalOperator& childPlan, + const Expression& predicate) const { + if (predicate.expressionType == ExpressionType::EQUALS) { + if (isPrimaryKey(*predicate.getChild(0)) || isPrimaryKey(*predicate.getChild(1))) { + return 1; + } else { + const auto numDistinctValues = + getTableStatsIfPossible(context, predicate, nodeTableStats); + if (numDistinctValues.has_value()) { + return atLeastOne(childPlan.getCardinality() / numDistinctValues.value()); + } + return atLeastOne( + childPlan.getCardinality() * PlannerKnobs::EQUALITY_PREDICATE_SELECTIVITY); + } + } else { + return atLeastOne( + childPlan.getCardinality() * PlannerKnobs::NON_EQUALITY_PREDICATE_SELECTIVITY); + } +} + +uint64_t CardinalityEstimator::getNumNodes(const Transaction*, + const std::vector& tableIDs) const { + cardinality_t numNodes = 0u; + for (auto& tableID : tableIDs) { + KU_ASSERT(nodeTableStats.contains(tableID)); + numNodes += nodeTableStats.at(tableID).getTableCard(); + } + return atLeastOne(numNodes); +} + +uint64_t CardinalityEstimator::getNumRels(const Transaction* transaction, + const std::vector& tableIDs) const { + cardinality_t numRels = 0u; + for (auto tableID : tableIDs) { + numRels += + storage::StorageManager::Get(*context)->getTable(tableID)->getNumTotalRows(transaction); + } + return atLeastOne(numRels); +} + +double CardinalityEstimator::getExtensionRate(const RelExpression& rel, + const NodeExpression& boundNode, const Transaction* transaction) const { + auto numBoundNodes = static_cast(getNumNodes(transaction, boundNode.getTableIDs())); + auto numRels = static_cast(getNumRels(transaction, rel.getInnerRelTableIDs())); + KU_ASSERT(numBoundNodes > 0); + auto oneHopExtensionRate = numRels / atLeastOne(numBoundNodes); + switch (rel.getRelType()) { + case QueryRelType::NON_RECURSIVE: { + return oneHopExtensionRate; + } + case QueryRelType::VARIABLE_LENGTH_WALK: + case 
QueryRelType::VARIABLE_LENGTH_TRAIL: + case QueryRelType::VARIABLE_LENGTH_ACYCLIC: { + auto rate = oneHopExtensionRate * + std::max(rel.getRecursiveInfo()->bindData->upperBound, 1); + return rate * context->getClientConfig()->recursivePatternCardinalityScaleFactor; + } + case QueryRelType::SHORTEST: + case QueryRelType::ALL_SHORTEST: + case QueryRelType::WEIGHTED_SHORTEST: + case QueryRelType::ALL_WEIGHTED_SHORTEST: { + auto rate = std::min( + oneHopExtensionRate * + std::max(rel.getRecursiveInfo()->bindData->upperBound, 1), + numRels); + return rate * context->getClientConfig()->recursivePatternCardinalityScaleFactor; + } + default: + KU_UNREACHABLE; + } +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/join_order/cost_model.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/join_order/cost_model.cpp new file mode 100644 index 0000000000..25f3e4c3cb --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/join_order/cost_model.cpp @@ -0,0 +1,57 @@ +#include "planner/join_order/cost_model.h" + +#include "common/constants.h" +#include "planner/join_order/join_order_util.h" +#include "planner/operator/logical_hash_join.h" + +using namespace lbug::common; + +namespace lbug { +namespace planner { + +uint64_t CostModel::computeExtendCost(const LogicalPlan& childPlan) { + return childPlan.getCost() + childPlan.getCardinality(); +} + +uint64_t CostModel::computeHashJoinCost(const std::vector& joinConditions, + const LogicalPlan& probe, const LogicalPlan& build) { + return computeHashJoinCost(LogicalHashJoin::getJoinNodeIDs(joinConditions), probe, build); +} + +uint64_t CostModel::computeHashJoinCost(const binder::expression_vector& joinNodeIDs, + const LogicalPlan& probe, const LogicalPlan& build) { + uint64_t cost = 0ul; + cost += probe.getCost(); + cost += build.getCost(); + cost += probe.getCardinality(); + cost += PlannerKnobs::BUILD_PENALTY * + JoinOrderUtil::getJoinKeysFlatCardinality(joinNodeIDs, build.getLastOperatorRef()); + return cost; +} + +uint64_t CostModel::computeMarkJoinCost(const std::vector& joinConditions, + const LogicalPlan& probe, const LogicalPlan& build) { + return computeMarkJoinCost(LogicalHashJoin::getJoinNodeIDs(joinConditions), probe, build); +} + +uint64_t CostModel::computeMarkJoinCost(const binder::expression_vector& joinNodeIDs, + const LogicalPlan& probe, const LogicalPlan& build) { + return computeHashJoinCost(joinNodeIDs, probe, build); +} + +uint64_t CostModel::computeIntersectCost(const LogicalPlan& probePlan, + const std::vector& buildPlans) { + uint64_t cost = 0ul; + cost += probePlan.getCost(); + // TODO(Xiyang): think of how to calculate intersect cost such that it will be picked in worst + // case. 
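+    // For now the cost is probe cost + probe cardinality + the cost of each build side;
+    // build cardinalities only contribute through their own plan costs.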
+ cost += probePlan.getCardinality(); + for (auto& buildPlan : buildPlans) { + KU_ASSERT(buildPlan.getCardinality() >= 1); + cost += buildPlan.getCost(); + } + return cost; +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/join_order/join_order_util.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/join_order/join_order_util.cpp new file mode 100644 index 0000000000..c6526171f5 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/join_order/join_order_util.cpp @@ -0,0 +1,24 @@ +#include "planner/join_order/join_order_util.h" + +namespace lbug { +namespace planner { + +uint64_t JoinOrderUtil::getJoinKeysFlatCardinality(const binder::expression_vector& joinNodeIDs, + const LogicalOperator& buildOp) { + auto schema = buildOp.getSchema(); + f_group_pos_set unFlatGroupsPos; + for (auto& joinID : joinNodeIDs) { + auto groupPos = schema->getGroupPos(*joinID); + if (!schema->getGroup(groupPos)->isFlat()) { + unFlatGroupsPos.insert(groupPos); + } + } + auto cardinality = buildOp.getCardinality(); + for (auto groupPos : unFlatGroupsPos) { + cardinality *= schema->getGroup(groupPos)->getMultiplier(); + } + return cardinality; +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/join_order/join_plan_solver.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/join_order/join_plan_solver.cpp new file mode 100644 index 0000000000..346c92f262 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/join_order/join_plan_solver.cpp @@ -0,0 +1,149 @@ +#include "planner/join_order/join_plan_solver.h" + +#include "common/enums/extend_direction.h" + +using namespace lbug::binder; +using namespace lbug::common; + +namespace lbug { +namespace planner { + +LogicalPlan JoinPlanSolver::solve(const JoinTree& joinTree) { + return solveTreeNode(*joinTree.root, nullptr); +} + +LogicalPlan JoinPlanSolver::solveTreeNode(const JoinTreeNode& current, const JoinTreeNode* parent) { + switch (current.type) { + case TreeNodeType::NODE_SCAN: { + return solveNodeScanTreeNode(current); + } + case TreeNodeType::REL_SCAN: { + KU_ASSERT(parent != nullptr); + return solveRelScanTreeNode(current, *parent); + } + case TreeNodeType::BINARY_JOIN: { + return solveBinaryJoinTreeNode(current); + } + case TreeNodeType::MULTIWAY_JOIN: { + return solveMultiwayJoinTreeNode(current); + } + default: + KU_UNREACHABLE; + } +} + +static ExtendDirection getExtendDirection(const RelExpression& rel, + const NodeExpression& boundNode) { + if (rel.getDirectionType() == binder::RelDirectionType::BOTH) { + return ExtendDirection::BOTH; + } + if (*rel.getSrcNode() == boundNode) { + return ExtendDirection::FWD; + } else { + return ExtendDirection::BWD; + } +} + +static std::shared_ptr getOtherNode(const RelExpression& rel, + const NodeExpression& boundNode) { + if (*rel.getSrcNode() == boundNode) { + return rel.getDstNode(); + } + return rel.getSrcNode(); +} + +LogicalPlan JoinPlanSolver::solveNodeScanTreeNode(const JoinTreeNode& treeNode) { + auto& extraInfo = treeNode.extraInfo->constCast(); + KU_ASSERT(extraInfo.nodeInfo != nullptr); + auto& nodeInfo = *extraInfo.nodeInfo; + auto boundNode = std::static_pointer_cast(nodeInfo.nodeOrRel); + auto plan = LogicalPlan(); + planner->appendScanNodeTable(boundNode->getInternalID(), boundNode->getTableIDs(), + nodeInfo.properties, plan); + planner->appendFilters(nodeInfo.predicates, plan); + for (auto& relInfo : extraInfo.relInfos) { + auto rel = std::static_pointer_cast(relInfo.nodeOrRel); + 
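+        // Each rel grouped under this node scan is planned as an Extend from the scanned
+        // bound node towards its other endpoint, followed by that rel's filters.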
auto nbrNode = getOtherNode(*rel, *boundNode); + auto direction = getExtendDirection(*rel, *boundNode); + planner->appendExtend(boundNode, nbrNode, rel, direction, relInfo.properties, plan); + planner->appendFilters(relInfo.predicates, plan); + } + planner->appendFilters(extraInfo.predicates, plan); + return plan; +} + +LogicalPlan JoinPlanSolver::solveRelScanTreeNode(const JoinTreeNode& treeNode, + const JoinTreeNode& parent) { + auto& extraInfo = treeNode.extraInfo->constCast(); + auto& relInfo = extraInfo.relInfos[0]; + auto rel = std::static_pointer_cast(relInfo.nodeOrRel); + std::shared_ptr boundNode = nullptr; + std::shared_ptr nbrNode = nullptr; + switch (parent.type) { + case TreeNodeType::BINARY_JOIN: { + auto& joinExtraInfo = parent.extraInfo->constCast(); + if (joinExtraInfo.joinNodes.size() == 1) { + boundNode = joinExtraInfo.joinNodes[0]; + } else { + boundNode = rel->getSrcNode(); + } + nbrNode = getOtherNode(*rel, *boundNode); + } break; + case TreeNodeType::MULTIWAY_JOIN: { + auto& joinExtraInfo = parent.extraInfo->constCast(); + KU_ASSERT(joinExtraInfo.joinNodes.size() == 1); + nbrNode = joinExtraInfo.joinNodes[0]; + boundNode = getOtherNode(*rel, *nbrNode); + } break; + default: + KU_UNREACHABLE; + } + auto direction = getExtendDirection(*rel, *boundNode); + auto plan = LogicalPlan(); + planner->appendScanNodeTable(boundNode->getInternalID(), boundNode->getTableIDs(), + expression_vector{}, plan); + planner->appendExtend(boundNode, nbrNode, rel, direction, relInfo.properties, plan); + planner->appendFilters(relInfo.predicates, plan); + return plan; +} + +LogicalPlan JoinPlanSolver::solveBinaryJoinTreeNode(const JoinTreeNode& treeNode) { + auto probePlan = solveTreeNode(*treeNode.children[0], &treeNode); + auto buildPlan = solveTreeNode(*treeNode.children[1], &treeNode); + auto& extraInfo = treeNode.extraInfo->constCast(); + binder::expression_vector joinNodeIDs; + for (auto& expr : extraInfo.joinNodes) { + joinNodeIDs.push_back(expr->constCast().getInternalID()); + } + auto plan = LogicalPlan(); + planner->appendHashJoin(joinNodeIDs, JoinType::INNER, probePlan, buildPlan, plan); + planner->appendFilters(extraInfo.predicates, plan); + return plan; +} + +LogicalPlan JoinPlanSolver::solveMultiwayJoinTreeNode(const JoinTreeNode& treeNode) { + auto& extraInfo = treeNode.extraInfo->constCast(); + KU_ASSERT(extraInfo.joinNodes.size() == 1); + auto& joinNode = extraInfo.joinNodes[0]->constCast(); + auto probePlan = solveTreeNode(*treeNode.children[0], &treeNode); + std::vector buildPlans; + expression_vector boundNodeIDs; + for (auto i = 1u; i < treeNode.children.size(); ++i) { + auto child = treeNode.children[i]; + KU_ASSERT(child->type == TreeNodeType::REL_SCAN); + auto& childExtraInfo = child->extraInfo->constCast(); + auto rel = std::static_pointer_cast(childExtraInfo.relInfos[0].nodeOrRel); + auto boundNode = *rel->getSrcNode() == joinNode ? rel->getDstNode() : rel->getSrcNode(); + buildPlans.push_back(solveTreeNode(*child, &treeNode).copy()); + boundNodeIDs.push_back(boundNode->constCast().getInternalID()); + } + auto plan = LogicalPlan(); + // TODO(Xiyang): provide an interface to append operator to resultPlan. 
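+    // appendIntersect extends probePlan in place, so its resulting operator is copied
+    // into `plan` as the last operator right after the call.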
+ planner->appendIntersect(joinNode.getInternalID(), boundNodeIDs, probePlan, buildPlans); + plan.setLastOperator(probePlan.getLastOperator()); + planner->appendFilters(extraInfo.predicates, plan); + return plan; +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/join_order/join_tree.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/join_order/join_tree.cpp new file mode 100644 index 0000000000..8abe00dac2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/join_order/join_tree.cpp @@ -0,0 +1,57 @@ +#include "planner/join_order/join_tree.h" + +namespace lbug { +namespace planner { + +std::string TreeNodeTypeUtils::toString(TreeNodeType type) { + switch (type) { + case TreeNodeType::NODE_SCAN: + return "NODE_SCAN"; + case TreeNodeType::REL_SCAN: + return "REL_SCAN"; + case TreeNodeType::BINARY_JOIN: + return "BINARY_JOIN"; + case TreeNodeType::MULTIWAY_JOIN: + return "MULTIWAY_JOIN"; + default: + KU_UNREACHABLE; + } +} + +void ExtraScanTreeNodeInfo::merge(const ExtraScanTreeNodeInfo& other) { + KU_ASSERT(other.nodeInfo == nullptr && other.relInfos.size() == 1); + relInfos.push_back(other.relInfos[0]); +} + +std::string JoinTreeNode::toString() const { + switch (type) { + case TreeNodeType::NODE_SCAN: + case TreeNodeType::REL_SCAN: { + auto& scanInfo = extraInfo->constCast(); + auto result = "Scan(" + scanInfo.nodeInfo->nodeOrRel->toString(); + for (auto relInfo : scanInfo.relInfos) { + result += "," + relInfo.nodeOrRel->toString(); + } + result += ")"; + return result; + } + case TreeNodeType::BINARY_JOIN: { + KU_ASSERT(children.size() == 2); + return "JOIN(" + children[0]->toString() + "," + children[1]->toString() + ")"; + } + case TreeNodeType::MULTIWAY_JOIN: { + KU_ASSERT(!children.empty()); + auto result = "MULTI_JOIN(" + children[0]->toString(); + for (auto i = 1u; i < children.size(); ++i) { + result += "," + children[i]->toString(); + } + return result; + } + default: { + KU_UNREACHABLE; + } + } +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/join_order/join_tree_constructor.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/join_order/join_tree_constructor.cpp new file mode 100644 index 0000000000..8f0fc8dd5a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/join_order/join_tree_constructor.cpp @@ -0,0 +1,207 @@ +#include "planner/join_order/join_tree_constructor.h" + +#include "binder/expression/expression_util.h" +#include "binder/query/reading_clause/bound_join_hint.h" +#include "common/exception/binder.h" +#include "common/exception/not_implemented.h" +#include "planner/planner.h" + +using namespace lbug::binder; +using namespace lbug::common; + +namespace lbug { +namespace planner { + +JoinTree JoinTreeConstructor::construct(std::shared_ptr root) { + if (planningInfo.subqueryType == SubqueryPlanningType::CORRELATED) { + throw NotImplementedException( + stringFormat("Hint join pattern has correlation with previous " + "patterns. 
This is not supported yet.")); + } + return JoinTree(constructTreeNode(root).treeNode); +} + +static std::vector> getJoinNodes(const SubqueryGraph& subgraph, + const SubqueryGraph& otherSubgraph) { + KU_ASSERT(&subgraph.queryGraph == &otherSubgraph.queryGraph); + std::vector> joinNodes; + for (auto idx : subgraph.getNbrNodeIndices()) { + if (otherSubgraph.queryNodesSelector[idx]) { + joinNodes.push_back(subgraph.queryGraph.getQueryNode(idx)); + } + } + return joinNodes; +} + +static std::vector intersect(std::vector left, + std::vector right) { + std::vector result; + auto set = std::unordered_set{right.begin(), right.end()}; + for (auto idx : left) { + if (set.contains(idx)) { + result.push_back(idx); + } + } + return result; +} + +std::shared_ptr getIntersectNode(const QueryGraph& queryGraph, + const std::vector& buildSubgraphs) { + auto candidates = buildSubgraphs[0].getNbrNodeIndices(); + for (auto i = 1u; i < buildSubgraphs.size(); ++i) { + candidates = intersect(candidates, buildSubgraphs[i].getNbrNodeIndices()); + } + if (candidates.size() != 1) { + throw BinderException("Cannot resolve join condition for multi-way join."); + } + return queryGraph.getQueryNode(candidates[0]); +} + +JoinTreeConstructor::IntermediateResult JoinTreeConstructor::constructTreeNode( + std::shared_ptr hintNode) { + // Construct leaf scans. + if (hintNode->isLeaf()) { + if (ExpressionUtil::isNodePattern(*hintNode->nodeOrRel)) { + return constructNodeScan(hintNode->nodeOrRel); + } else { + KU_ASSERT(ExpressionUtil::isRelPattern(*hintNode->nodeOrRel) || + ExpressionUtil::isRecursiveRelPattern(*hintNode->nodeOrRel)); + return constructRelScan(hintNode->nodeOrRel); + } + } + // Construct binary join. + if (hintNode->isBinary()) { + auto left = constructTreeNode(hintNode->children[0]); + auto right = constructTreeNode(hintNode->children[1]); + auto joinNodes = getJoinNodes(left.subqueryGraph, right.subqueryGraph); + if (joinNodes.empty()) { + joinNodes = getJoinNodes(right.subqueryGraph, left.subqueryGraph); + } + if (joinNodes.empty()) { + throw BinderException(stringFormat("Cannot resolve join condition between {} and {}.", + left.treeNode->toString(), right.treeNode->toString())); + } + auto newSubgraph = left.subqueryGraph; + newSubgraph.addSubqueryGraph(right.subqueryGraph); + auto predicates = Planner::getNewlyMatchedExprs(left.subqueryGraph, right.subqueryGraph, + newSubgraph, queryGraphPredicates); + // First try to construct as index nested loop join. + auto nestedLoopTreeNode = + tryConstructNestedLoopJoin(joinNodes, *left.treeNode, *right.treeNode, predicates); + if (nestedLoopTreeNode != nullptr) { + return {nestedLoopTreeNode, newSubgraph}; + } + // Cannot construct index nested loop join. Fall back to hash join. 
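+        // The BINARY_JOIN node keeps both children together with the resolved join nodes
+        // and predicates; JoinPlanSolver::solveBinaryJoinTreeNode later turns it into a
+        // hash join on the join nodes' internal IDs.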
+ auto extraInfo = std::make_unique(joinNodes); + extraInfo->predicates = predicates; + auto treeNode = + std::make_shared(TreeNodeType::BINARY_JOIN, std::move(extraInfo)); + treeNode->addChild(left.treeNode); + treeNode->addChild(right.treeNode); + return {treeNode, newSubgraph}; + } + // Construct multi-way join + KU_ASSERT(hintNode->isMultiWay()); + auto probe = constructTreeNode(hintNode->children[0]); + auto newSubgraph = probe.subqueryGraph; + std::vector> childrenNodes; + childrenNodes.push_back(probe.treeNode); + std::vector buildSubgraphs; + for (auto i = 1u; i < hintNode->children.size(); ++i) { + auto build = constructTreeNode(hintNode->children[i]); + if (build.treeNode->type != TreeNodeType::REL_SCAN) { + throw BinderException(stringFormat( + "Cannot construct multi-way join because build side is not a relationship table.")); + } + newSubgraph.addSubqueryGraph(build.subqueryGraph); + childrenNodes.push_back(build.treeNode); + buildSubgraphs.push_back(build.subqueryGraph); + } + auto joinNode = getIntersectNode(queryGraph, buildSubgraphs); + auto subgraphs = buildSubgraphs; + subgraphs.push_back(probe.subqueryGraph); + auto predicates = Planner::getNewlyMatchedExprs(subgraphs, newSubgraph, queryGraphPredicates); + auto extraInfo = std::make_unique(joinNode); + extraInfo->predicates = predicates; + auto treeNode = + std::make_shared(TreeNodeType::MULTIWAY_JOIN, std::move(extraInfo)); + for (auto& child : childrenNodes) { + treeNode->addChild(child); + } + return {treeNode, newSubgraph}; +} + +JoinTreeConstructor::IntermediateResult JoinTreeConstructor::constructNodeScan( + std::shared_ptr expr) { + auto& node = expr->constCast(); + auto nodeIdx = queryGraph.getQueryNodeIdx(node.getUniqueName()); + auto emptySubgraph = SubqueryGraph(queryGraph); + auto newSubgraph = SubqueryGraph(queryGraph); + newSubgraph.addQueryNode(nodeIdx); + auto extraInfo = std::make_unique(); + // See Planner::planBaseTableScans for how we plan unnest correlated subqueries. 
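+    // For unnest-correlated subqueries the correlated node is emitted as a bare scan:
+    // no properties and no predicates are attached at this point.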
+ if (planningInfo.subqueryType == SubqueryPlanningType::UNNEST_CORRELATED && + planningInfo.containsCorrExpr(*node.getInternalID())) { + extraInfo->nodeInfo = std::make_unique(expr, expression_vector{}); + ; + auto treeNode = + std::make_shared(TreeNodeType::NODE_SCAN, std::move(extraInfo)); + return {treeNode, newSubgraph}; + } + auto properties = propertyCollection.getProperties(*expr); + auto predicates = + Planner::getNewlyMatchedExprs(emptySubgraph, newSubgraph, queryGraphPredicates); + auto nodeScanInfo = std::make_unique(expr, properties); + nodeScanInfo->predicates = predicates; + extraInfo->nodeInfo = std::move(nodeScanInfo); + auto treeNode = std::make_shared(TreeNodeType::NODE_SCAN, std::move(extraInfo)); + return {treeNode, newSubgraph}; +} + +JoinTreeConstructor::IntermediateResult JoinTreeConstructor::constructRelScan( + std::shared_ptr expr) { + auto& rel = expr->constCast(); + auto relIdx = queryGraph.getQueryRelIdx(rel.getUniqueName()); + auto emptySubgraph = SubqueryGraph(queryGraph); + auto newSubgraph = SubqueryGraph(queryGraph); + newSubgraph.addQueryRel(relIdx); + auto properties = propertyCollection.getProperties(*expr); + auto predicates = + Planner::getNewlyMatchedExprs(emptySubgraph, newSubgraph, queryGraphPredicates); + auto relScanInfo = NodeRelScanInfo(expr, properties); + relScanInfo.predicates = predicates; + auto extraInfo = std::make_unique(); + extraInfo->relInfos.push_back(std::move(relScanInfo)); + auto treeNode = std::make_shared(TreeNodeType::REL_SCAN, std::move(extraInfo)); + return {treeNode, newSubgraph}; +} + +std::shared_ptr JoinTreeConstructor::tryConstructNestedLoopJoin( + std::vector> joinNodes, const JoinTreeNode& leftRoot, + const JoinTreeNode& rightRoot, const binder::expression_vector& predicates) { + if (joinNodes.size() > 1) { + return nullptr; + } + if (leftRoot.type == TreeNodeType::REL_SCAN && rightRoot.type == TreeNodeType::NODE_SCAN) { + return tryConstructNestedLoopJoin(joinNodes, rightRoot, leftRoot, predicates); + } + if (leftRoot.type != TreeNodeType::NODE_SCAN) { + return nullptr; + } + if (rightRoot.type != TreeNodeType::REL_SCAN) { + return nullptr; + } + auto joinNode = joinNodes[0]; + auto& leftExtraInfo = leftRoot.extraInfo->constCast(); + auto& rightExtraInfo = rightRoot.extraInfo->constCast(); + if (*leftExtraInfo.nodeInfo->nodeOrRel != *joinNode) { + return nullptr; + } + auto newExtraInfo = std::make_unique(leftExtraInfo); + newExtraInfo->relInfos.push_back(rightExtraInfo.relInfos[0]); + newExtraInfo->predicates = predicates; + return std::make_shared(TreeNodeType::NODE_SCAN, std::move(newExtraInfo)); +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/join_order_enumerator_context.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/join_order_enumerator_context.cpp new file mode 100644 index 0000000000..0bc0aefdd4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/join_order_enumerator_context.cpp @@ -0,0 +1,37 @@ +#include "planner/join_order_enumerator_context.h" + +using namespace lbug::binder; + +namespace lbug { +namespace planner { + +void JoinOrderEnumeratorContext::init(const QueryGraph* queryGraph_, + const expression_vector& predicates) { + whereExpressionsSplitOnAND = predicates; + this->queryGraph = queryGraph_; + // clear and resize subPlansTable + subPlansTable->clear(); + maxLevel = queryGraph_->getNumQueryNodes() + queryGraph_->getNumQueryRels() + 1; + subPlansTable->resize(maxLevel); + // Restart from level 1 for new query part so that 
we get hashJoin based plans + // that uses subplans coming from previous query part.See example in planRelIndexJoin(). + currentLevel = 1; +} + +SubqueryGraph JoinOrderEnumeratorContext::getFullyMatchedSubqueryGraph() const { + auto subqueryGraph = SubqueryGraph(*queryGraph); + for (auto i = 0u; i < queryGraph->getNumQueryNodes(); ++i) { + subqueryGraph.addQueryNode(i); + } + for (auto i = 0u; i < queryGraph->getNumQueryRels(); ++i) { + subqueryGraph.addQueryRel(i); + } + return subqueryGraph; +} + +void JoinOrderEnumeratorContext::resetState() { + subPlansTable = std::make_unique(); +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/CMakeLists.txt new file mode 100644 index 0000000000..2ad0838d77 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/CMakeLists.txt @@ -0,0 +1,38 @@ +add_subdirectory(extend) +add_subdirectory(factorization) +add_subdirectory(persistent) +add_subdirectory(scan) +add_subdirectory(simple) +add_subdirectory(sip) + +add_library(lbug_planner_operator + OBJECT + logical_accumulate.cpp + logical_aggregate.cpp + logical_create_macro.cpp + logical_cross_product.cpp + logical_distinct.cpp + logical_dummy_scan.cpp + logical_dummy_sink.cpp + logical_explain.cpp + logical_filter.cpp + logical_flatten.cpp + logical_hash_join.cpp + logical_table_function_call.cpp + logical_intersect.cpp + logical_limit.cpp + logical_operator.cpp + logical_order_by.cpp + logical_partitioner.cpp + logical_path_property_probe.cpp + logical_plan.cpp + logical_plan_util.cpp + logical_projection.cpp + logical_standalone_call.cpp + logical_union.cpp + logical_unwind.cpp + schema.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/extend/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/extend/CMakeLists.txt new file mode 100644 index 0000000000..65e8c3f7f4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/extend/CMakeLists.txt @@ -0,0 +1,9 @@ +add_library(lbug_planner_extend + OBJECT + base_logical_extend.cpp + logical_extend.cpp + logical_recursive_extend.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) \ No newline at end of file diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/extend/base_logical_extend.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/extend/base_logical_extend.cpp new file mode 100644 index 0000000000..92b1c87982 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/extend/base_logical_extend.cpp @@ -0,0 +1,55 @@ +#include "planner/operator/extend/base_logical_extend.h" + +using namespace lbug::common; + +namespace lbug { +namespace planner { + +static std::string relToString(const binder::RelExpression& rel) { + auto result = rel.toString(); + switch (rel.getRelType()) { + case QueryRelType::SHORTEST: { + result += "SHORTEST"; + } break; + case QueryRelType::ALL_SHORTEST: { + result += "ALL SHORTEST"; + } break; + default: + break; + } + if (QueryRelTypeUtils::isRecursive(rel.getRelType())) { + auto bindData = rel.getRecursiveInfo()->bindData.get(); + result += std::to_string(bindData->lowerBound); + result += ".."; + result += std::to_string(bindData->upperBound); + } + return result; +} + +std::string BaseLogicalExtend::getExpressionsForPrinting() const { + auto result = boundNode->toString(); + switch (direction) { + case 
ExtendDirection::FWD: { + result += "-"; + result += relToString(*rel); + result += "->"; + } break; + case ExtendDirection::BWD: { + result += "<-"; + result += relToString(*rel); + result += "-"; + } break; + case ExtendDirection::BOTH: { + result += "<-"; + result += relToString(*rel); + result += "->"; + } break; + default: + KU_UNREACHABLE; + } + result += nbrNode->toString(); + return result; +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/extend/logical_extend.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/extend/logical_extend.cpp new file mode 100644 index 0000000000..141058bf92 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/extend/logical_extend.cpp @@ -0,0 +1,43 @@ +#include "planner/operator/extend/logical_extend.h" + +namespace lbug { +namespace planner { + +void LogicalExtend::computeFactorizedSchema() { + copyChildSchema(0); + const auto boundGroupPos = schema->getGroupPos(*boundNode->getInternalID()); + if (!schema->getGroup(boundGroupPos)->isFlat()) { + schema->flattenGroup(boundGroupPos); + } + uint32_t nbrGroupPos = 0u; + nbrGroupPos = schema->createGroup(); + schema->insertToGroupAndScope(nbrNode->getInternalID(), nbrGroupPos); + for (auto& property : properties) { + schema->insertToGroupAndScope(property, nbrGroupPos); + } + if (rel->hasDirectionExpr()) { + schema->insertToGroupAndScope(rel->getDirectionExpr(), nbrGroupPos); + } +} + +void LogicalExtend::computeFlatSchema() { + copyChildSchema(0); + schema->insertToGroupAndScope(nbrNode->getInternalID(), 0); + for (auto& property : properties) { + schema->insertToGroupAndScope(property, 0); + } + if (rel->hasDirectionExpr()) { + schema->insertToGroupAndScope(rel->getDirectionExpr(), 0); + } +} + +std::unique_ptr LogicalExtend::copy() { + auto extend = std::make_unique(boundNode, nbrNode, rel, direction, + extendFromSource_, properties, children[0]->copy(), cardinality); + extend->setPropertyPredicates(copyVector(propertyPredicates)); + extend->scanNbrID = scanNbrID; + return extend; +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/extend/logical_recursive_extend.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/extend/logical_recursive_extend.cpp new file mode 100644 index 0000000000..57a351b169 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/extend/logical_recursive_extend.cpp @@ -0,0 +1,23 @@ +#include "planner/operator/extend/logical_recursive_extend.h" + +namespace lbug { +namespace planner { + +void LogicalRecursiveExtend::computeFlatSchema() { + createEmptySchema(); + schema->createGroup(); + for (auto& expr : resultColumns) { + schema->insertToGroupAndScope(expr, 0); + } +} + +void LogicalRecursiveExtend::computeFactorizedSchema() { + createEmptySchema(); + auto pos = schema->createGroup(); + for (auto& e : resultColumns) { + schema->insertToGroupAndScope(e, pos); + } +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/factorization/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/factorization/CMakeLists.txt new file mode 100644 index 0000000000..eaa94b16f6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/factorization/CMakeLists.txt @@ -0,0 +1,8 @@ +add_library(lbug_planner_factorization + OBJECT + flatten_resolver.cpp + sink_util.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) 
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/factorization/flatten_resolver.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/factorization/flatten_resolver.cpp new file mode 100644 index 0000000000..c903b8a48b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/factorization/flatten_resolver.cpp @@ -0,0 +1,221 @@ +#include "planner/operator/factorization/flatten_resolver.h" + +#include "binder/expression/case_expression.h" +#include "binder/expression/lambda_expression.h" +#include "binder/expression/node_expression.h" +#include "binder/expression/rel_expression.h" +#include "binder/expression/scalar_function_expression.h" +#include "binder/expression/subquery_expression.h" +#include "common/exception/not_implemented.h" +#include "planner/operator/schema.h" + +using namespace lbug::common; +using namespace lbug::binder; + +namespace lbug { +namespace planner { + +std::pair FlattenAllButOne::getGroupsPosToFlatten( + const expression_vector& exprs, const Schema& schema) { + f_group_pos_set result; + f_group_pos_set dependentGroups; + for (auto expr : exprs) { + auto analyzer = GroupDependencyAnalyzer(false /* collectDependentExpr */, schema); + analyzer.visit(expr); + for (auto pos : analyzer.getRequiredFlatGroups()) { + result.insert(pos); + } + for (auto pos : analyzer.getDependentGroups()) { + dependentGroups.insert(pos); + } + } + std::vector candidates; + for (auto pos : dependentGroups) { + if (!schema.getGroup(pos)->isFlat() && !result.contains(pos)) { + candidates.push_back(pos); + } + } + for (auto i = 1u; i < candidates.size(); ++i) { + result.insert(candidates[i]); + } + if (candidates.empty()) { + return std::make_pair(INVALID_F_GROUP_POS, result); + } else { + return std::make_pair(candidates[0], result); + } +} + +f_group_pos_set FlattenAllButOne::getGroupsPosToFlatten(std::shared_ptr expr, + const Schema& schema) { + auto analyzer = GroupDependencyAnalyzer(false /* collectDependentExpr */, schema); + analyzer.visit(expr); + f_group_pos_set result = analyzer.getRequiredFlatGroups(); + std::vector candidates; + for (auto groupPos : analyzer.getDependentGroups()) { + if (!schema.getGroup(groupPos)->isFlat() && !result.contains(groupPos)) { + candidates.push_back(groupPos); + } + } + // Keep the first group as unFlat. 
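+    // (For example, if the dependent groups are one flat group and two unflat
+    // groups, the first unflat candidate stays unflat and only the remaining
+    // unflat group is added to the set of groups to flatten.)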
+ for (auto i = 1u; i < candidates.size(); ++i) { + result.insert(candidates[i]); + } + return result; +} + +f_group_pos_set FlattenAllButOne::getGroupsPosToFlatten( + const std::unordered_set& dependentGroups, const Schema& schema) { + f_group_pos_set result; + std::vector candidates; + for (auto groupPos : dependentGroups) { + if (!schema.getGroup(groupPos)->isFlat()) { + candidates.push_back(groupPos); + } + } + for (auto i = 1u; i < candidates.size(); ++i) { + result.insert(candidates[i]); + } + return result; +} + +f_group_pos_set FlattenAll::getGroupsPosToFlatten(const expression_vector& exprs, + const Schema& schema) { + f_group_pos_set result; + for (auto& expr : exprs) { + for (auto pos : getGroupsPosToFlatten(expr, schema)) { + result.insert(pos); + } + } + return result; +} + +f_group_pos_set FlattenAll::getGroupsPosToFlatten(std::shared_ptr expr, + const Schema& schema) { + auto analyzer = GroupDependencyAnalyzer(false /* collectDependentExpr */, schema); + analyzer.visit(expr); + return getGroupsPosToFlatten(analyzer.getDependentGroups(), schema); +} + +f_group_pos_set FlattenAll::getGroupsPosToFlatten( + const std::unordered_set& dependentGroups, const Schema& schema) { + f_group_pos_set result; + for (auto groupPos : dependentGroups) { + if (!schema.getGroup(groupPos)->isFlat()) { + result.insert(groupPos); + } + } + return result; +} + +void GroupDependencyAnalyzer::visit(std::shared_ptr expr) { + if (schema.isExpressionInScope(*expr)) { + dependentGroups.insert(schema.getGroupPos(*expr)); + if (collectDependentExpr) { + dependentExprs.insert(expr); + } + return; + } + switch (expr->expressionType) { + case ExpressionType::FUNCTION: + return visitFunction(expr); + case ExpressionType::CASE_ELSE: { + visitCase(expr); + } break; + case ExpressionType::PATTERN: { + visitNodeOrRel(expr); + } break; + case ExpressionType::SUBQUERY: { + visitSubquery(expr); + } break; + case ExpressionType::LAMBDA: { + visit(expr->constCast().getFunctionExpr()); + } break; + case ExpressionType::LITERAL: + case ExpressionType::AGGREGATE_FUNCTION: + case ExpressionType::PROPERTY: + case ExpressionType::VARIABLE: + case ExpressionType::PATH: + case ExpressionType::PARAMETER: + case ExpressionType::GRAPH: + case ExpressionType::OR: + case ExpressionType::XOR: + case ExpressionType::AND: + case ExpressionType::NOT: + case ExpressionType::EQUALS: + case ExpressionType::NOT_EQUALS: + case ExpressionType::GREATER_THAN: + case ExpressionType::GREATER_THAN_EQUALS: + case ExpressionType::LESS_THAN: + case ExpressionType::LESS_THAN_EQUALS: + case ExpressionType::IS_NULL: + case ExpressionType::IS_NOT_NULL: { + for (auto& child : expr->getChildren()) { + visit(child); + } + } break; + // LCOV_EXCL_START + default: + throw NotImplementedException("GroupDependencyAnalyzer::visit"); + // LCOV_EXCL_STOP + } +} + +void GroupDependencyAnalyzer::visitFunction(std::shared_ptr expr) { + auto& funcExpr = expr->constCast(); + for (auto& child : expr->getChildren()) { + visit(child); + } + // For list lambda we need to flatten all dependent expressions in lambda function + // E.g. 
MATCH (a)->(b) RETURN list_filter(a.list, x -> x>b.age) + if (funcExpr.getFunction().isListLambda) { + auto lambdaFunctionAnalyzer = GroupDependencyAnalyzer(collectDependentExpr, schema); + lambdaFunctionAnalyzer.visit(funcExpr.getChild(1)); + requiredFlatGroups = lambdaFunctionAnalyzer.getDependentGroups(); + } +} + +void GroupDependencyAnalyzer::visitCase(std::shared_ptr expr) { + auto& caseExpression = expr->constCast(); + for (auto i = 0u; i < caseExpression.getNumCaseAlternatives(); ++i) { + auto caseAlternative = caseExpression.getCaseAlternative(i); + visit(caseAlternative->whenExpression); + visit(caseAlternative->thenExpression); + } + visit(caseExpression.getElseExpression()); +} + +void GroupDependencyAnalyzer::visitNodeOrRel(std::shared_ptr expr) { + for (auto& p : expr->constCast().getPropertyExpressions()) { + visit(p); + } + switch (expr->getDataType().getLogicalTypeID()) { + case LogicalTypeID::NODE: { + auto& node = expr->constCast(); + visit(node.getInternalID()); + } break; + case LogicalTypeID::REL: + case LogicalTypeID::RECURSIVE_REL: { + auto& rel = expr->constCast(); + visit(rel.getSrcNode()->getInternalID()); + visit(rel.getDstNode()->getInternalID()); + if (rel.hasDirectionExpr()) { + visit(rel.getDirectionExpr()); + } + } break; + default: + KU_UNREACHABLE; + } +} + +void GroupDependencyAnalyzer::visitSubquery(std::shared_ptr expr) { + auto& subqueryExpr = expr->constCast(); + for (auto& node : subqueryExpr.getQueryGraphCollection()->getQueryNodes()) { + visit(node->getInternalID()); + } + if (subqueryExpr.hasWhereExpression()) { + visit(subqueryExpr.getWhereExpression()); + } +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/factorization/sink_util.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/factorization/sink_util.cpp new file mode 100644 index 0000000000..c930bb73c1 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/factorization/sink_util.cpp @@ -0,0 +1,70 @@ +#include "planner/operator/factorization/sink_util.h" + +namespace lbug { +namespace planner { + +void SinkOperatorUtil::mergeSchema(const Schema& inputSchema, + const binder::expression_vector& expressionsToMerge, Schema& resultSchema) { + auto flatPayloads = getFlatPayloads(inputSchema, expressionsToMerge); + auto unFlatPayloadsPerGroup = getUnFlatPayloadsPerGroup(inputSchema, expressionsToMerge); + if (unFlatPayloadsPerGroup.empty()) { + appendPayloadsToNewGroup(resultSchema, flatPayloads); + } else { + if (!flatPayloads.empty()) { + auto groupPos = appendPayloadsToNewGroup(resultSchema, flatPayloads); + resultSchema.setGroupAsSingleState(groupPos); + } + for (auto& [inputGroupPos, payloads] : unFlatPayloadsPerGroup) { + auto resultGroupPos = appendPayloadsToNewGroup(resultSchema, payloads); + resultSchema.getGroup(resultGroupPos) + ->setMultiplier(inputSchema.getGroup(inputGroupPos)->getMultiplier()); + } + } +} + +void SinkOperatorUtil::recomputeSchema(const Schema& inputSchema, + const binder::expression_vector& expressionsToMerge, Schema& resultSchema) { + KU_ASSERT(!expressionsToMerge.empty()); + resultSchema.clear(); + mergeSchema(inputSchema, expressionsToMerge, resultSchema); +} + +std::unordered_map +SinkOperatorUtil::getUnFlatPayloadsPerGroup(const Schema& schema, + const binder::expression_vector& payloads) { + std::unordered_map result; + for (auto& payload : payloads) { + auto groupPos = schema.getGroupPos(*payload); + if (schema.getGroup(groupPos)->isFlat()) { + continue; + } + if 
(!result.contains(groupPos)) { + result.insert({groupPos, binder::expression_vector{}}); + } + result.at(groupPos).push_back(payload); + } + return result; +} + +binder::expression_vector SinkOperatorUtil::getFlatPayloads(const Schema& schema, + const binder::expression_vector& payloads) { + binder::expression_vector result; + for (auto& payload : payloads) { + if (schema.getGroup(payload)->isFlat()) { + result.push_back(payload); + } + } + return result; +} + +uint32_t SinkOperatorUtil::appendPayloadsToNewGroup(Schema& schema, + binder::expression_vector& payloads) { + auto outputGroupPos = schema.createGroup(); + for (auto& payload : payloads) { + schema.insertToGroupAndScope(payload, outputGroupPos); + } + return outputGroupPos; +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_accumulate.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_accumulate.cpp new file mode 100644 index 0000000000..06567054f2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_accumulate.cpp @@ -0,0 +1,34 @@ +#include "planner/operator/logical_accumulate.h" + +#include "planner/operator/factorization/flatten_resolver.h" +#include "planner/operator/factorization/sink_util.h" + +namespace lbug { +namespace planner { + +void LogicalAccumulate::computeFactorizedSchema() { + createEmptySchema(); + auto childSchema = children[0]->getSchema(); + SinkOperatorUtil::recomputeSchema(*childSchema, getPayloads(), *schema); + if (mark != nullptr) { + auto groupPos = schema->createGroup(); + schema->setGroupAsSingleState(groupPos); + schema->insertToGroupAndScope(mark, groupPos); + } +} + +void LogicalAccumulate::computeFlatSchema() { + copyChildSchema(0); + if (mark != nullptr) { + schema->insertToGroupAndScope(mark, 0); + } +} + +f_group_pos_set LogicalAccumulate::getGroupPositionsToFlatten() const { + f_group_pos_set result; + auto childSchema = children[0]->getSchema(); + return FlattenAll::getGroupsPosToFlatten(flatExprs, *childSchema); +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_aggregate.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_aggregate.cpp new file mode 100644 index 0000000000..4386f21802 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_aggregate.cpp @@ -0,0 +1,88 @@ +#include "planner/operator/logical_aggregate.h" + +#include "binder/expression/aggregate_function_expression.h" +#include "binder/expression/expression_util.h" +#include "planner/operator/factorization/flatten_resolver.h" +#include "planner/operator/schema.h" + +namespace lbug { +namespace planner { + +std::string LogicalAggregatePrintInfo::toString() const { + std::string result = ""; + result += "Group By: "; + result += binder::ExpressionUtil::toString(keys); + result += ", Aggregates: "; + result += binder::ExpressionUtil::toString(aggregates); + return result; +} + +void LogicalAggregate::computeFactorizedSchema() { + createEmptySchema(); + auto groupPos = schema->createGroup(); + insertAllExpressionsToGroupAndScope(groupPos); +} + +void LogicalAggregate::computeFlatSchema() { + createEmptySchema(); + schema->createGroup(); + insertAllExpressionsToGroupAndScope(0 /* groupPos */); +} + +f_group_pos_set LogicalAggregate::getGroupsPosToFlatten() { + auto [unflatGroup, flattenedGroups] = + FlattenAllButOne::getGroupsPosToFlatten(getAllKeys(), *children[0]->getSchema()); + // Flatten distinct 
aggregates if they are from a different group than the unflat key group + // Regular aggregates can be processed when unflat, but distinct aggregates get added to their + // own AggregateHashTable and have the same input limitations as the aggregate groups + if (unflatGroup != INVALID_F_GROUP_POS) { + for (const auto& aggregate : aggregates) { + auto funcExpr = aggregate->constPtrCast(); + auto analyzer = GroupDependencyAnalyzer(false /* collectDependentExpr */, + *children[0]->getSchema()); + analyzer.visit(aggregate); + for (const auto& group : analyzer.getRequiredFlatGroups()) { + flattenedGroups.insert(group); + } + if (funcExpr->isDistinct()) { + for (const auto& group : analyzer.getDependentGroups()) { + if (group != unflatGroup) { + flattenedGroups.insert(group); + } + } + } + } + } + return flattenedGroups; +} + +std::string LogicalAggregate::getExpressionsForPrinting() const { + std::string result = "Group By ["; + for (auto& expression : keys) { + result += expression->toString() + ", "; + } + for (auto& expression : dependentKeys) { + result += expression->toString() + ", "; + } + result += "], Aggregate ["; + for (auto& expression : aggregates) { + result += expression->toString() + ", "; + } + result += "]"; + return result; +} + +void LogicalAggregate::insertAllExpressionsToGroupAndScope(f_group_pos groupPos) { + for (auto& expression : keys) { + schema->insertToGroupAndScopeMayRepeat(expression, groupPos); + } + for (auto& expression : dependentKeys) { + schema->insertToGroupAndScopeMayRepeat(expression, groupPos); + } + for (auto& expression : aggregates) { + schema->insertToGroupAndScopeMayRepeat(expression, groupPos); + } +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_create_macro.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_create_macro.cpp new file mode 100644 index 0000000000..18a193c1ea --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_create_macro.cpp @@ -0,0 +1,21 @@ +#include "planner/operator/logical_create_macro.h" + +namespace lbug { +namespace planner { + +std::string LogicalCreateMacroPrintInfo::toString() const { + std::string result = "Macro: "; + result += macroName; + return result; +} + +void LogicalCreateMacro::computeFlatSchema() { + createEmptySchema(); +} + +void LogicalCreateMacro::computeFactorizedSchema() { + createEmptySchema(); +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_cross_product.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_cross_product.cpp new file mode 100644 index 0000000000..035e5cc903 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_cross_product.cpp @@ -0,0 +1,34 @@ +#include "planner/operator/logical_cross_product.h" + +#include "planner/operator/factorization/sink_util.h" + +namespace lbug { +namespace planner { + +void LogicalCrossProduct::computeFactorizedSchema() { + auto probeSchema = children[0]->getSchema(); + auto buildSchema = children[1]->getSchema(); + schema = probeSchema->copy(); + SinkOperatorUtil::mergeSchema(*buildSchema, buildSchema->getExpressionsInScope(), *schema); + if (mark != nullptr) { + auto groupPos = schema->createGroup(); + schema->setGroupAsSingleState(groupPos); + schema->insertToGroupAndScope(mark, groupPos); + } +} + +void LogicalCrossProduct::computeFlatSchema() { + auto probeSchema = children[0]->getSchema(); + auto buildSchema = 
children[1]->getSchema(); + schema = probeSchema->copy(); + KU_ASSERT(schema->getNumGroups() == 1); + for (auto& expression : buildSchema->getExpressionsInScope()) { + schema->insertToGroupAndScope(expression, 0); + } + if (mark != nullptr) { + schema->insertToGroupAndScope(mark, 0); + } +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_distinct.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_distinct.cpp new file mode 100644 index 0000000000..3785fa49f2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_distinct.cpp @@ -0,0 +1,42 @@ +#include "planner/operator/logical_distinct.h" + +#include "binder/expression/expression_util.h" +#include "planner/operator/factorization/flatten_resolver.h" + +namespace lbug { +namespace planner { + +void LogicalDistinct::computeFactorizedSchema() { + createEmptySchema(); + auto groupPos = schema->createGroup(); + for (auto& expression : getKeysAndPayloads()) { + schema->insertToGroupAndScope(expression, groupPos); + } +} + +void LogicalDistinct::computeFlatSchema() { + createEmptySchema(); + schema->createGroup(); + for (auto& expression : getKeysAndPayloads()) { + schema->insertToGroupAndScope(expression, 0); + } +} + +f_group_pos_set LogicalDistinct::getGroupsPosToFlatten() { + auto childSchema = children[0]->getSchema(); + return FlattenAll::getGroupsPosToFlatten(getKeysAndPayloads(), *childSchema); +} + +std::string LogicalDistinct::getExpressionsForPrinting() const { + return binder::ExpressionUtil::toString(getKeysAndPayloads()); +} + +binder::expression_vector LogicalDistinct::getKeysAndPayloads() const { + binder::expression_vector result; + result.insert(result.end(), keys.begin(), keys.end()); + result.insert(result.end(), payloads.begin(), payloads.end()); + return result; +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_dummy_scan.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_dummy_scan.cpp new file mode 100644 index 0000000000..c72e49b1f5 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_dummy_scan.cpp @@ -0,0 +1,27 @@ +#include "planner/operator/scan/logical_dummy_scan.h" + +#include "binder/expression/literal_expression.h" +#include "common/constants.h" + +using namespace lbug::common; + +namespace lbug { +namespace planner { + +void LogicalDummyScan::computeFactorizedSchema() { + createEmptySchema(); + schema->createGroup(); +} + +void LogicalDummyScan::computeFlatSchema() { + createEmptySchema(); + schema->createGroup(); +} + +std::shared_ptr LogicalDummyScan::getDummyExpression() { + return std::make_shared( + Value::createNullValue(LogicalType::STRING()), InternalKeyword::PLACE_HOLDER); +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_dummy_sink.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_dummy_sink.cpp new file mode 100644 index 0000000000..203307f3cc --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_dummy_sink.cpp @@ -0,0 +1,15 @@ +#include "planner/operator/logical_dummy_sink.h" + +namespace lbug { +namespace planner { + +void LogicalDummySink::computeFactorizedSchema() { + copyChildSchema(0); +} + +void LogicalDummySink::computeFlatSchema() { + copyChildSchema(0); +} + +} // namespace planner +} // namespace lbug diff --git 
a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_explain.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_explain.cpp new file mode 100644 index 0000000000..4a2ffd5b15 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_explain.cpp @@ -0,0 +1,35 @@ +#include "planner/operator/logical_explain.h" + +using namespace lbug::common; + +namespace lbug { +namespace planner { + +void LogicalExplain::computeSchema() { + switch (explainType) { + case ExplainType::PROFILE: + if (children[0]->getSchema()) { + copyChildSchema(0); + } else { + createEmptySchema(); + } + break; + case ExplainType::PHYSICAL_PLAN: + case ExplainType::LOGICAL_PLAN: + createEmptySchema(); + break; + default: + KU_UNREACHABLE; + } +} + +void LogicalExplain::computeFlatSchema() { + computeSchema(); +} + +void LogicalExplain::computeFactorizedSchema() { + computeSchema(); +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_filter.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_filter.cpp new file mode 100644 index 0000000000..99fa982fec --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_filter.cpp @@ -0,0 +1,22 @@ +#include "planner/operator/logical_filter.h" + +#include "planner/operator/factorization/flatten_resolver.h" + +namespace lbug { +namespace planner { + +f_group_pos_set LogicalFilter::getGroupsPosToFlatten() { + auto childSchema = children[0]->getSchema(); + return FlattenAllButOne::getGroupsPosToFlatten(expression, *childSchema); +} + +f_group_pos LogicalFilter::getGroupPosToSelect() const { + auto childSchema = children[0]->getSchema(); + auto analyzer = GroupDependencyAnalyzer(false, *childSchema); + analyzer.visit(expression); + SchemaUtils::validateAtMostOneUnFlatGroup(analyzer.getDependentGroups(), *childSchema); + return SchemaUtils::getLeadingGroupPos(analyzer.getDependentGroups(), *childSchema); +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_flatten.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_flatten.cpp new file mode 100644 index 0000000000..9a3a2f9429 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_flatten.cpp @@ -0,0 +1,18 @@ +#include "planner/operator/logical_flatten.h" + +using namespace lbug::common; + +namespace lbug { +namespace planner { + +void LogicalFlatten::computeFactorizedSchema() { + copyChildSchema(0); + schema->flattenGroup(groupPos); +} + +void LogicalFlatten::computeFlatSchema() { + throw InternalException("LogicalFlatten::computeFlatSchema() should never be used."); +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_hash_join.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_hash_join.cpp new file mode 100644 index 0000000000..b54383df37 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_hash_join.cpp @@ -0,0 +1,207 @@ +#include "planner/operator/logical_hash_join.h" + +#include "planner/operator/factorization/flatten_resolver.h" +#include "planner/operator/factorization/sink_util.h" +#include "planner/operator/scan/logical_scan_node_table.h" + +using namespace lbug::common; + +namespace lbug { +namespace planner { + +f_group_pos_set LogicalHashJoin::getGroupsPosToFlattenOnProbeSide() { + f_group_pos_set result; + if (!requireFlatProbeKeys()) { + return result; 
+ } + auto probeSchema = children[0]->getSchema(); + for (auto& [probeKey, buildKey] : joinConditions) { + result.insert(probeSchema->getGroupPos(*probeKey)); + } + return result; +} + +f_group_pos_set LogicalHashJoin::getGroupsPosToFlattenOnBuildSide() { + auto buildSchema = children[1]->getSchema(); + f_group_pos_set joinNodesGroupPos; + for (auto& [probeKey, buildKey] : joinConditions) { + joinNodesGroupPos.insert(buildSchema->getGroupPos(*buildKey)); + } + return FlattenAllButOne::getGroupsPosToFlatten(joinNodesGroupPos, *buildSchema); +} + +void LogicalHashJoin::computeFactorizedSchema() { + auto probeSchema = children[0]->getSchema(); + auto buildSchema = children[1]->getSchema(); + schema = probeSchema->copy(); + switch (joinType) { + case JoinType::INNER: + case JoinType::LEFT: + case JoinType::COUNT: { + // Populate group position mapping + std::unordered_map buildToProbeKeyGroupPositionMap; + for (auto& [probeKey, buildKey] : joinConditions) { + auto probeKeyGroupPos = probeSchema->getGroupPos(*probeKey); + auto buildKeyGroupPos = buildSchema->getGroupPos(*buildKey); + if (!buildToProbeKeyGroupPositionMap.contains(buildKeyGroupPos)) { + buildToProbeKeyGroupPositionMap.insert({buildKeyGroupPos, probeKeyGroupPos}); + } + } + // Resolve expressions to materialize in each group + binder::expression_vector expressionsToMaterializeInNonKeyGroups; + for (auto groupIdx = 0u; groupIdx < buildSchema->getNumGroups(); ++groupIdx) { + auto expressions = buildSchema->getExpressionsInScope(groupIdx); + if (buildToProbeKeyGroupPositionMap.contains(groupIdx)) { // merge key group + auto probeKeyGroupPos = buildToProbeKeyGroupPositionMap.at(groupIdx); + for (auto& expression : expressions) { + // Join key may repeat for internal ID based joins + schema->insertToGroupAndScopeMayRepeat(expression, probeKeyGroupPos); + } + } else { + for (auto& expression : expressions) { + expressionsToMaterializeInNonKeyGroups.push_back(expression); + } + } + } + SinkOperatorUtil::mergeSchema(*buildSchema, expressionsToMaterializeInNonKeyGroups, + *schema); + if (mark != nullptr) { + auto groupPos = schema->getGroupPos(*joinConditions[0].first); + schema->insertToGroupAndScope(mark, groupPos); + } + } break; + case JoinType::MARK: { + std::unordered_set probeSideKeyGroupPositions; + for (auto& [probeKey, buildKey] : joinConditions) { + probeSideKeyGroupPositions.insert(probeSchema->getGroupPos(*probeKey)); + } + if (probeSideKeyGroupPositions.size() > 1) { + SchemaUtils::validateNoUnFlatGroup(probeSideKeyGroupPositions, *probeSchema); + } + auto markPos = *probeSideKeyGroupPositions.begin(); + schema->insertToGroupAndScope(mark, markPos); + } break; + default: + KU_UNREACHABLE; + } +} + +void LogicalHashJoin::computeFlatSchema() { + auto probeSchema = children[0]->getSchema(); + auto buildSchema = children[1]->getSchema(); + schema = probeSchema->copy(); + switch (joinType) { + case JoinType::INNER: + case JoinType::LEFT: + case JoinType::COUNT: { + for (auto& expression : buildSchema->getExpressionsInScope()) { + // Join key may repeat for internal ID based joins. 
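+            // (The probe-side schema may already contain the same internal ID
+            // expression, so the insertion must tolerate a repeated entry.)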
+ schema->insertToGroupAndScopeMayRepeat(expression, 0); + } + if (mark != nullptr) { + schema->insertToGroupAndScope(mark, 0); + } + } break; + case JoinType::MARK: { + schema->insertToGroupAndScope(mark, 0); + } break; + default: + KU_UNREACHABLE; + } +} + +std::string LogicalHashJoin::getExpressionsForPrinting() const { + if (isNodeIDOnlyJoin(joinConditions)) { + return binder::ExpressionUtil::toStringOrdered(getJoinNodeIDs()); + } + return binder::ExpressionUtil::toString(joinConditions); +} + +binder::expression_vector LogicalHashJoin::getExpressionsToMaterialize() const { + switch (joinType) { + case JoinType::INNER: + case JoinType::LEFT: + case JoinType::COUNT: { + return children[1]->getSchema()->getExpressionsInScope(); + } + case JoinType::MARK: { + return binder::expression_vector{}; + } + default: + KU_UNREACHABLE; + } +} + +std::unique_ptr LogicalHashJoin::copy() { + auto op = std::make_unique(joinConditions, joinType, mark, children[0]->copy(), + children[1]->copy(), cardinality); + op->sipInfo = sipInfo; + return op; +} + +bool LogicalHashJoin::isNodeIDOnlyJoin(const std::vector& joinConditions) { + for (auto& [probeKey, buildKey] : joinConditions) { + if (probeKey->getUniqueName() != buildKey->getUniqueName() || + probeKey->getDataType().getLogicalTypeID() != common::LogicalTypeID::INTERNAL_ID) { + return false; + } + } + return true; +} + +binder::expression_vector LogicalHashJoin::getJoinNodeIDs() const { + return getJoinNodeIDs(joinConditions); +} + +binder::expression_vector LogicalHashJoin::getJoinNodeIDs( + const std::vector& joinConditions) { + binder::expression_vector result; + for (auto& [probeKey, _] : joinConditions) { + if (probeKey->expressionType != ExpressionType::PROPERTY) { + continue; + } + if (probeKey->dataType.getLogicalTypeID() != LogicalTypeID::INTERNAL_ID) { + continue; + } + result.push_back(probeKey); + } + return result; +} + +class JoinNodeIDUniquenessAnalyzer { +public: + static bool isUnique(const LogicalOperator* op, const binder::Expression& joinNodeID) { + switch (op->getOperatorType()) { + case LogicalOperatorType::FILTER: + case LogicalOperatorType::FLATTEN: + case LogicalOperatorType::LIMIT: + case LogicalOperatorType::PROJECTION: + case LogicalOperatorType::SEMI_MASKER: + return isUnique(op->getChild(0).get(), joinNodeID); + case LogicalOperatorType::SCAN_NODE_TABLE: + return *op->constCast().getNodeID() == joinNodeID; + default: + return false; + } + } +}; + +bool LogicalHashJoin::requireFlatProbeKeys() const { + // Flatten for multiple join keys. + if (joinConditions.size() > 1) { + return true; + } + // Flatten for left join. + if (joinType == JoinType::LEFT || joinType == JoinType::COUNT) { + return true; // TODO(Guodong): fix this. We shouldn't require flatten. + } + auto& [probeKey, buildKey] = joinConditions[0]; + // Flatten for non-ID-based join. 
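+    // (e.g. joining on a property value instead of an internal node ID.)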
+ if (probeKey->dataType.getLogicalTypeID() != LogicalTypeID::INTERNAL_ID) { + return true; + } + return !JoinNodeIDUniquenessAnalyzer::isUnique(children[1].get(), *buildKey); +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_intersect.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_intersect.cpp new file mode 100644 index 0000000000..7125928fde --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_intersect.cpp @@ -0,0 +1,71 @@ +#include "planner/operator/logical_intersect.h" + +namespace lbug { +namespace planner { + +f_group_pos_set LogicalIntersect::getGroupsPosToFlattenOnProbeSide() { + f_group_pos_set result; + for (auto& keyNodeID : keyNodeIDs) { + result.insert(children[0]->getSchema()->getGroupPos(*keyNodeID)); + } + return result; +} + +f_group_pos_set LogicalIntersect::getGroupsPosToFlattenOnBuildSide(uint32_t buildIdx) { + f_group_pos_set result; + auto childIdx = buildIdx + 1; // skip probe + result.insert(children[childIdx]->getSchema()->getGroupPos(*keyNodeIDs[buildIdx])); + return result; +} + +void LogicalIntersect::computeFactorizedSchema() { + auto probeSchema = children[0]->getSchema(); + schema = probeSchema->copy(); + // Write intersect node and rels into a new group regardless of whether rel is n-n. + auto outGroupPos = schema->createGroup(); + schema->insertToGroupAndScope(intersectNodeID, outGroupPos); + for (auto i = 1u; i < children.size(); ++i) { + auto buildSchema = children[i]->getSchema(); + auto keyNodeID = keyNodeIDs[i - 1]; + // Write rel properties into output group. + for (auto& expression : buildSchema->getExpressionsInScope()) { + if (expression->getUniqueName() == intersectNodeID->getUniqueName() || + expression->getUniqueName() == keyNodeID->getUniqueName()) { + continue; + } + schema->insertToGroupAndScope(expression, outGroupPos); + } + } +} + +void LogicalIntersect::computeFlatSchema() { + auto probeSchema = children[0]->getSchema(); + schema = probeSchema->copy(); + schema->insertToGroupAndScope(intersectNodeID, 0); + for (auto i = 1u; i < children.size(); ++i) { + auto buildSchema = children[i]->getSchema(); + auto keyNodeID = keyNodeIDs[i - 1]; + for (auto& expression : buildSchema->getExpressionsInScope()) { + if (expression->getUniqueName() == intersectNodeID->getUniqueName() || + expression->getUniqueName() == keyNodeID->getUniqueName()) { + continue; + } + schema->insertToGroupAndScope(expression, 0); + } + } +} + +std::unique_ptr LogicalIntersect::copy() { + std::vector> buildChildren; + for (auto i = 1u; i < children.size(); ++i) { + buildChildren.push_back(children[i]->copy()); + } + auto op = make_unique(intersectNodeID, keyNodeIDs, children[0]->copy(), + std::move(buildChildren), cardinality); + op->sipInfo = sipInfo; + op->cardinality = cardinality; + return op; +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_limit.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_limit.cpp new file mode 100644 index 0000000000..e9d14f17a0 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_limit.cpp @@ -0,0 +1,45 @@ +#include "planner/operator/logical_limit.h" + +#include "binder/expression/expression_util.h" +#include "planner/operator/factorization/flatten_resolver.h" + +using namespace lbug::binder; + +namespace lbug { +namespace planner { + +std::string LogicalLimit::getExpressionsForPrinting() const { + 
std::string result; + if (hasSkipNum()) { + result += "SKIP "; + if (ExpressionUtil::canEvaluateAsLiteral(*skipNum)) { + result += std::to_string(ExpressionUtil::evaluateAsSkipLimit(*skipNum)); + } + } + if (hasLimitNum()) { + if (!result.empty()) { + result += ","; + } + result += "LIMIT "; + if (ExpressionUtil::canEvaluateAsLiteral(*limitNum)) { + result += std::to_string(ExpressionUtil::evaluateAsSkipLimit(*limitNum)); + } + } + return result; +} + +f_group_pos_set LogicalLimit::getGroupsPosToFlatten() { + auto childSchema = children[0]->getSchema(); + return FlattenAllButOne::getGroupsPosToFlatten(childSchema->getGroupsPosInScope(), + *childSchema); +} + +f_group_pos LogicalLimit::getGroupPosToSelect() const { + auto childSchema = children[0]->getSchema(); + auto groupsPosInScope = childSchema->getGroupsPosInScope(); + SchemaUtils::validateAtMostOneUnFlatGroup(groupsPosInScope, *childSchema); + return SchemaUtils::getLeadingGroupPos(groupsPosInScope, *childSchema); +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_operator.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_operator.cpp new file mode 100644 index 0000000000..9f80089c91 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_operator.cpp @@ -0,0 +1,195 @@ +#include "planner/operator/logical_operator.h" + +#include "common/exception/runtime.h" + +using namespace lbug::common; + +namespace lbug { +namespace planner { + +// LCOV_EXCL_START +std::string LogicalOperatorUtils::logicalOperatorTypeToString(LogicalOperatorType type) { + switch (type) { + case LogicalOperatorType::ACCUMULATE: + return "ACCUMULATE"; + case LogicalOperatorType::AGGREGATE: + return "AGGREGATE"; + case LogicalOperatorType::ALTER: + return "ALTER"; + case LogicalOperatorType::ATTACH_DATABASE: + return "ATTACH_DATABASE"; + case LogicalOperatorType::COPY_FROM: + return "COPY_FROM"; + case LogicalOperatorType::COPY_TO: + return "COPY_TO"; + case LogicalOperatorType::CREATE_MACRO: + return "CREATE_MACRO"; + case LogicalOperatorType::CREATE_SEQUENCE: + return "CREATE_SEQUENCE"; + case LogicalOperatorType::CREATE_TABLE: + return "CREATE_TABLE"; + case LogicalOperatorType::CROSS_PRODUCT: + return "CROSS_PRODUCT"; + case LogicalOperatorType::DELETE: + return "DELETE_NODE"; + case LogicalOperatorType::DETACH_DATABASE: + return "DETACH_DATABASE"; + case LogicalOperatorType::DISTINCT: + return "DISTINCT"; + case LogicalOperatorType::DROP: + return "DROP"; + case LogicalOperatorType::DUMMY_SCAN: + return "DUMMY_SCAN"; + case LogicalOperatorType::DUMMY_SINK: + return "DUMMY_SINK"; + case LogicalOperatorType::EMPTY_RESULT: + return "EMPTY_RESULT"; + case LogicalOperatorType::EXPLAIN: + return "EXPLAIN"; + case LogicalOperatorType::EXPRESSIONS_SCAN: + return "EXPRESSIONS_SCAN"; + case LogicalOperatorType::EXTENSION: + return "LOAD"; + case LogicalOperatorType::EXPORT_DATABASE: + return "EXPORT_DATABASE"; + case LogicalOperatorType::EXTEND: + return "EXTEND"; + case LogicalOperatorType::FILTER: + return "FILTER"; + case LogicalOperatorType::FLATTEN: + return "FLATTEN"; + case LogicalOperatorType::HASH_JOIN: + return "HASH_JOIN"; + case LogicalOperatorType::IMPORT_DATABASE: + return "IMPORT_DATABASE"; + case LogicalOperatorType::INDEX_LOOK_UP: + return "INDEX_LOOK_UP"; + case LogicalOperatorType::INTERSECT: + return "INTERSECT"; + case LogicalOperatorType::INSERT: + return "INSERT"; + case LogicalOperatorType::LIMIT: + return "LIMIT"; + case 
LogicalOperatorType::MERGE: + return "MERGE"; + case LogicalOperatorType::MULTIPLICITY_REDUCER: + return "MULTIPLICITY_REDUCER"; + case LogicalOperatorType::NODE_LABEL_FILTER: + return "NODE_LABEL_FILTER"; + case LogicalOperatorType::NOOP: + return "NOOP"; + case LogicalOperatorType::ORDER_BY: + return "ORDER_BY"; + case LogicalOperatorType::PARTITIONER: + return "PARTITIONER"; + case LogicalOperatorType::PATH_PROPERTY_PROBE: + return "PATH_PROPERTY_PROBE"; + case LogicalOperatorType::PROJECTION: + return "PROJECTION"; + case LogicalOperatorType::RECURSIVE_EXTEND: + return "RECURSIVE_EXTEND"; + case LogicalOperatorType::SCAN_NODE_TABLE: + return "SCAN_NODE_TABLE"; + case LogicalOperatorType::SEMI_MASKER: + return "SEMI_MASKER"; + case LogicalOperatorType::SET_PROPERTY: + return "SET_PROPERTY"; + case LogicalOperatorType::STANDALONE_CALL: + return "STANDALONE_CALL"; + case LogicalOperatorType::TABLE_FUNCTION_CALL: + return "TABLE_FUNCTION_CALL"; + case LogicalOperatorType::TRANSACTION: + return "TRANSACTION"; + case LogicalOperatorType::UNION_ALL: + return "UNION_ALL"; + case LogicalOperatorType::UNWIND: + return "UNWIND"; + case LogicalOperatorType::USE_DATABASE: + return "USE_DATABASE"; + case LogicalOperatorType::CREATE_TYPE: + return "CREATE_TYPE"; + case LogicalOperatorType::EXTENSION_CLAUSE: + return "EXTENSION_CLAUSE"; + default: + throw RuntimeException("Unknown logical operator type."); + } +} +// LCOV_EXCL_STOP + +bool LogicalOperatorUtils::isUpdate(LogicalOperatorType type) { + switch (type) { + case LogicalOperatorType::INSERT: + case LogicalOperatorType::DELETE: + case LogicalOperatorType::SET_PROPERTY: + case LogicalOperatorType::MERGE: + return true; + default: + return false; + } +} + +bool LogicalOperatorUtils::isAccHashJoin(const LogicalOperator& op) { + return op.getOperatorType() == LogicalOperatorType::HASH_JOIN && + op.getChild(0)->getOperatorType() == LogicalOperatorType::ACCUMULATE; +} + +LogicalOperator::LogicalOperator(LogicalOperatorType operatorType, + std::shared_ptr child, std::optional cardinality) + : operatorType{operatorType}, + cardinality{cardinality.has_value() ? 
cardinality.value() : child->getCardinality()} { + children.push_back(std::move(child)); +} + +LogicalOperator::LogicalOperator(LogicalOperatorType operatorType, + std::shared_ptr left, std::shared_ptr right) + : LogicalOperator{operatorType} { + children.push_back(std::move(left)); + children.push_back(std::move(right)); +} + +LogicalOperator::LogicalOperator(LogicalOperatorType operatorType, + const logical_op_vector_t& children) + : LogicalOperator{operatorType} { + for (auto& child : children) { + this->children.push_back(child); + } +} + +bool LogicalOperator::hasUpdateRecursive() { + if (LogicalOperatorUtils::isUpdate(operatorType)) { + return true; + } + for (auto& child : children) { + if (child->hasUpdateRecursive()) { + return true; + } + } + return false; +} + +std::string LogicalOperator::toString(uint64_t depth) const { + auto padding = std::string(depth * 4, ' '); + std::string result = padding; + result += LogicalOperatorUtils::logicalOperatorTypeToString(operatorType) + "[" + + getExpressionsForPrinting() + "]"; + if (children.size() == 1) { + result += "\n" + children[0]->toString(depth); + } else { + for (auto& child : children) { + result += "\n" + padding + "CHILD:\n" + child->toString(depth + 1); + } + } + return result; +} + +logical_op_vector_t LogicalOperator::copy(const logical_op_vector_t& ops) { + logical_op_vector_t result; + result.reserve(ops.size()); + for (auto& op : ops) { + result.push_back(op->copy()); + } + return result; +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_order_by.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_order_by.cpp new file mode 100644 index 0000000000..e2ab8b5ad8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_order_by.cpp @@ -0,0 +1,59 @@ +#include "planner/operator/logical_order_by.h" + +#include "binder/expression/expression_util.h" +#include "planner/operator/factorization/flatten_resolver.h" +#include "planner/operator/factorization/sink_util.h" + +using namespace lbug::binder; + +namespace lbug { +namespace planner { + +f_group_pos_set LogicalOrderBy::getGroupsPosToFlatten() { + // We only allow orderby key(s) to be unflat, if they are all part of the same factorization + // group and there is no other factorized group in the schema, so any payload is also unflat + // and part of the same factorization group. The rationale for this limitation is this: (1) + // to keep both the frontend and orderby operators simpler, we want order by to not change + // the schema, so the input and output of order by should have the same factorization + // structure. (2) Because orderby needs to flatten the keys to sort, if a key column that is + // unflat is the input, we need to somehow flatten it in the factorized table. However + // whenever we can we want to avoid adding an explicit flatten operator as this makes us + // fall back to tuple-at-a-time processing. However in the specified limited case, we can + // give factorized table a set of unflat vectors (all in the same datachunk/factorization + // group), sort the table, and scan into unflat vectors, so the schema remains the same. In + // more complicated cases, e.g., when there are 2 factorization groups, FactorizedTable + // cannot read back a flat column into an unflat std::vector. 
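+    // (Concretely: with a single factorization group in scope, sorting can keep
+    // the keys and payloads unflat; with more than one group, every group the
+    // order-by keys depend on is flattened, as done below.)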
+ auto childSchema = children[0]->getSchema(); + if (childSchema->getNumGroups() > 1) { + return FlattenAll::getGroupsPosToFlatten(expressionsToOrderBy, *childSchema); + } + return f_group_pos_set{}; +} + +void LogicalOrderBy::computeFactorizedSchema() { + createEmptySchema(); + auto childSchema = children[0]->getSchema(); + SinkOperatorUtil::recomputeSchema(*childSchema, childSchema->getExpressionsInScope(), *schema); +} + +void LogicalOrderBy::computeFlatSchema() { + createEmptySchema(); + schema->createGroup(); + for (auto& expression : children[0]->getSchema()->getExpressionsInScope()) { + schema->insertToGroupAndScope(expression, 0); + } +} + +std::string LogicalOrderBy::getExpressionsForPrinting() const { + auto result = ExpressionUtil::toString(expressionsToOrderBy) + " "; + if (hasSkipNum() && ExpressionUtil::canEvaluateAsLiteral(*skipNum)) { + result += "SKIP " + std::to_string(ExpressionUtil::evaluateAsSkipLimit(*skipNum)); + } + if (hasLimitNum() && ExpressionUtil::canEvaluateAsLiteral(*limitNum)) { + result += "LIMIT " + std::to_string(ExpressionUtil::evaluateAsSkipLimit(*limitNum)); + } + return result; +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_partitioner.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_partitioner.cpp new file mode 100644 index 0000000000..7c79b3ae20 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_partitioner.cpp @@ -0,0 +1,41 @@ +#include "planner/operator/logical_partitioner.h" + +#include "binder/expression/expression_util.h" +#include "common/exception/runtime.h" + +namespace lbug { +namespace planner { + +static void validateSingleGroup(const Schema& schema) { + if (schema.getNumGroups() != 1) { + throw common::RuntimeException( + "Try to partition multiple factorization group. 
This should not happen."); + } +} + +void LogicalPartitioner::computeFactorizedSchema() { + copyChildSchema(0); + // LCOV_EXCL_START + validateSingleGroup(*schema); + // LCOV_EXCL_STOP + schema->insertToGroupAndScope(info.offset, 0); +} + +void LogicalPartitioner::computeFlatSchema() { + copyChildSchema(0); + // LCOV_EXCL_START + validateSingleGroup(*schema); + // LCOV_EXCL_STOP + schema->insertToGroupAndScope(info.offset, 0); +} + +std::string LogicalPartitioner::getExpressionsForPrinting() const { + binder::expression_vector expressions; + for (auto& partitioningInfo : info.partitioningInfos) { + expressions.push_back(copyFromInfo.columnExprs[partitioningInfo.keyIdx]); + } + return binder::ExpressionUtil::toString(expressions); +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_path_property_probe.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_path_property_probe.cpp new file mode 100644 index 0000000000..3fc3ad48d9 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_path_property_probe.cpp @@ -0,0 +1,57 @@ +#include "planner/operator/logical_path_property_probe.h" + +#include "optimizer/factorization_rewriter.h" +#include "optimizer/remove_factorization_rewriter.h" + +namespace lbug { +namespace planner { + +void LogicalPathPropertyProbe::computeFactorizedSchema() { + copyChildSchema(0); + if (pathNodeIDs != nullptr) { + KU_ASSERT(schema->getNumGroups() == 1); + schema->insertToGroupAndScope(recursiveRel, 0); + } + + if (nodeChild != nullptr) { + auto rewriter = optimizer::FactorizationRewriter(); + rewriter.visitOperator(nodeChild.get()); + } + if (relChild != nullptr) { + auto rewriter = optimizer::FactorizationRewriter(); + rewriter.visitOperator(relChild.get()); + } +} + +void LogicalPathPropertyProbe::computeFlatSchema() { + copyChildSchema(0); + if (pathNodeIDs != nullptr) { + KU_ASSERT(schema->getNumGroups() == 1); + schema->insertToGroupAndScope(recursiveRel, 0); + } + + if (nodeChild != nullptr) { + auto rewriter = optimizer::RemoveFactorizationRewriter(); + rewriter.visitOperator(nodeChild); + } + if (relChild != nullptr) { + auto rewriter = optimizer::RemoveFactorizationRewriter(); + rewriter.visitOperator(relChild); + } +} + +std::unique_ptr LogicalPathPropertyProbe::copy() { + auto nodeChildCopy = nodeChild == nullptr ? nullptr : nodeChild->copy(); + auto relChildCopy = relChild == nullptr ? 
nullptr : relChild->copy(); + auto op = std::make_unique(recursiveRel, children[0]->copy(), + std::move(nodeChildCopy), std::move(relChildCopy), joinType); + op->sipInfo = sipInfo; + op->direction = direction; + op->extendFromLeft = extendFromLeft; + op->pathNodeIDs = pathNodeIDs; + op->pathEdgeIDs = pathEdgeIDs; + return op; +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_plan.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_plan.cpp new file mode 100644 index 0000000000..188857a3e8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_plan.cpp @@ -0,0 +1,19 @@ +#include "planner/operator/logical_plan.h" + +#include "planner/operator/logical_explain.h" + +namespace lbug { +namespace planner { + +bool LogicalPlan::isProfile() const { + return lastOperator->getOperatorType() == LogicalOperatorType::EXPLAIN && + reinterpret_cast(lastOperator.get())->getExplainType() == + common::ExplainType::PROFILE; +} + +bool LogicalPlan::hasUpdate() const { + return lastOperator->hasUpdateRecursive(); +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_plan_util.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_plan_util.cpp new file mode 100644 index 0000000000..b54c879886 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_plan_util.cpp @@ -0,0 +1,105 @@ +#include "planner/operator/logical_plan_util.h" + +#include "binder/expression/property_expression.h" +#include "planner/operator/extend/logical_extend.h" +#include "planner/operator/logical_hash_join.h" +#include "planner/operator/logical_intersect.h" +#include "planner/operator/scan/logical_scan_node_table.h" + +using namespace lbug::binder; + +namespace lbug { +namespace planner { + +std::string LogicalPlanUtil::encodeJoin(LogicalPlan& logicalPlan) { + return encode(logicalPlan.getLastOperator().get()); +} + +std::string LogicalPlanUtil::encode(LogicalOperator* logicalOperator) { + std::string result; + encodeRecursive(logicalOperator, result); + return result; +} + +void LogicalPlanUtil::encodeRecursive(LogicalOperator* logicalOperator, std::string& encodeString) { + switch (logicalOperator->getOperatorType()) { + case LogicalOperatorType::CROSS_PRODUCT: { + encodeCrossProduct(logicalOperator, encodeString); + for (auto i = 0u; i < logicalOperator->getNumChildren(); ++i) { + encodeString += "{"; + encodeRecursive(logicalOperator->getChild(i).get(), encodeString); + encodeString += "}"; + } + } break; + case LogicalOperatorType::INTERSECT: { + encodeIntersect(logicalOperator, encodeString); + for (auto i = 0u; i < logicalOperator->getNumChildren(); ++i) { + encodeString += "{"; + encodeRecursive(logicalOperator->getChild(i).get(), encodeString); + encodeString += "}"; + } + } break; + case LogicalOperatorType::HASH_JOIN: { + encodeHashJoin(logicalOperator, encodeString); + encodeString += "{"; + encodeRecursive(logicalOperator->getChild(0).get(), encodeString); + encodeString += "}{"; + encodeRecursive(logicalOperator->getChild(1).get(), encodeString); + encodeString += "}"; + } break; + case LogicalOperatorType::EXTEND: { + encodeExtend(logicalOperator, encodeString); + encodeRecursive(logicalOperator->getChild(0).get(), encodeString); + } break; + case LogicalOperatorType::SCAN_NODE_TABLE: { + encodeScanNodeTable(logicalOperator, encodeString); + } break; + case LogicalOperatorType::FILTER: { + 
encodeFilter(logicalOperator, encodeString); + encodeRecursive(logicalOperator->getChild(0).get(), encodeString); + } break; + default: + for (auto i = 0u; i < logicalOperator->getNumChildren(); ++i) { + encodeRecursive(logicalOperator->getChild(i).get(), encodeString); + } + } +} + +void LogicalPlanUtil::encodeCrossProduct(LogicalOperator* /*logicalOperator*/, + std::string& encodeString) { + encodeString += "CP()"; +} + +void LogicalPlanUtil::encodeIntersect(LogicalOperator* logicalOperator, std::string& encodeString) { + auto& logicalIntersect = logicalOperator->constCast(); + encodeString += "I(" + logicalIntersect.getIntersectNodeID()->toString() + ")"; +} + +void LogicalPlanUtil::encodeHashJoin(LogicalOperator* logicalOperator, std::string& encodeString) { + auto& logicalHashJoin = logicalOperator->constCast(); + encodeString += "HJ(" + logicalHashJoin.getExpressionsForPrinting() + ")"; +} + +void LogicalPlanUtil::encodeExtend(LogicalOperator* logicalOperator, std::string& encodeString) { + auto& logicalExtend = logicalOperator->constCast(); + encodeString += "E(" + logicalExtend.getNbrNode()->toString() + ")"; +} + +void LogicalPlanUtil::encodeScanNodeTable(LogicalOperator* logicalOperator, + std::string& encodeString) { + auto& scan = logicalOperator->constCast(); + if (scan.getScanType() == LogicalScanNodeTableType::PRIMARY_KEY_SCAN) { + encodeString += "IndexScan"; + } else { + encodeString += "S"; + } + encodeString += + "(" + scan.getNodeID()->constCast().getRawVariableName() + ")"; +} + +void LogicalPlanUtil::encodeFilter(LogicalOperator*, std::string& encodedString) { + encodedString += "Filter()"; +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_projection.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_projection.cpp new file mode 100644 index 0000000000..0f7f7379f8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_projection.cpp @@ -0,0 +1,60 @@ +#include "planner/operator/logical_projection.h" + +#include "planner/operator/factorization/flatten_resolver.h" + +namespace lbug { +namespace planner { + +void LogicalProjection::computeFactorizedSchema() { + auto childSchema = children[0]->getSchema(); + schema = childSchema->copy(); + schema->clearExpressionsInScope(); + for (auto& expression : expressions) { + auto groupPos = INVALID_F_GROUP_POS; + if (childSchema->isExpressionInScope(*expression)) { // expression to reference + groupPos = childSchema->getGroupPos(*expression); + schema->insertToScopeMayRepeat(expression, groupPos); + } else { // expression to evaluate + auto analyzer = GroupDependencyAnalyzer(false, *childSchema); + analyzer.visit(expression); + auto dependentGroupPos = analyzer.getDependentGroups(); + SchemaUtils::validateAtMostOneUnFlatGroup(dependentGroupPos, *childSchema); + if (dependentGroupPos.empty()) { // constant + groupPos = schema->createGroup(); + schema->setGroupAsSingleState(groupPos); + } else { + groupPos = SchemaUtils::getLeadingGroupPos(dependentGroupPos, *childSchema); + } + schema->insertToGroupAndScopeMayRepeat(expression, groupPos); + } + } +} + +void LogicalProjection::computeFlatSchema() { + copyChildSchema(0); + auto childSchema = children[0]->getSchema(); + schema->clearExpressionsInScope(); + for (auto& expression : expressions) { + if (childSchema->isExpressionInScope(*expression)) { + schema->insertToScopeMayRepeat(expression, 0); + } else { + schema->insertToGroupAndScopeMayRepeat(expression, 0); + } + } 
+} + +std::unordered_set LogicalProjection::getDiscardedGroupsPos() const { + auto groupsPosInScopeBeforeProjection = children[0]->getSchema()->getGroupsPosInScope(); + auto groupsPosInScopeAfterProjection = schema->getGroupsPosInScope(); + std::unordered_set discardGroupsPos; + for (auto i = 0u; i < schema->getNumGroups(); ++i) { + if (groupsPosInScopeBeforeProjection.contains(i) && + !groupsPosInScopeAfterProjection.contains(i)) { + discardGroupsPos.insert(i); + } + } + return discardGroupsPos; +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_standalone_call.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_standalone_call.cpp new file mode 100644 index 0000000000..9277b4ada8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_standalone_call.cpp @@ -0,0 +1,13 @@ +#include "planner/operator/logical_standalone_call.h" + +#include "main/db_config.h" + +namespace lbug { +namespace planner { + +std::string LogicalStandaloneCall::getExpressionsForPrinting() const { + return option->name; +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_table_function_call.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_table_function_call.cpp new file mode 100644 index 0000000000..e78d0712a2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_table_function_call.cpp @@ -0,0 +1,23 @@ +#include "planner/operator/logical_table_function_call.h" + +namespace lbug { +namespace planner { + +void LogicalTableFunctionCall::computeFlatSchema() { + createEmptySchema(); + auto groupPos = schema->createGroup(); + for (auto& expr : bindData->columns) { + schema->insertToGroupAndScope(expr, groupPos); + } +} + +void LogicalTableFunctionCall::computeFactorizedSchema() { + createEmptySchema(); + auto groupPos = schema->createGroup(); + for (auto& expr : bindData->columns) { + schema->insertToGroupAndScope(expr, groupPos); + } +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_union.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_union.cpp new file mode 100644 index 0000000000..b9fdc1ba67 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_union.cpp @@ -0,0 +1,57 @@ +#include "planner/operator/logical_union.h" + +#include "planner/operator/factorization/flatten_resolver.h" +#include "planner/operator/factorization/sink_util.h" + +namespace lbug { +namespace planner { + +f_group_pos_set LogicalUnion::getGroupsPosToFlatten(uint32_t childIdx) { + f_group_pos_set groupsPos; + auto childSchema = children[childIdx]->getSchema(); + for (auto i = 0u; i < expressionsToUnion.size(); ++i) { + if (requireFlatExpression(i)) { + auto expression = childSchema->getExpressionsInScope()[i]; + groupsPos.insert(childSchema->getGroupPos(*expression)); + } + } + return FlattenAll::getGroupsPosToFlatten(groupsPos, *childSchema); +} + +void LogicalUnion::computeFactorizedSchema() { + auto firstChildSchema = children[0]->getSchema(); + createEmptySchema(); + SinkOperatorUtil::recomputeSchema(*firstChildSchema, firstChildSchema->getExpressionsInScope(), + *schema); +} + +void LogicalUnion::computeFlatSchema() { + createEmptySchema(); + schema->createGroup(); + for (auto& expression : children[0]->getSchema()->getExpressionsInScope()) { + schema->insertToGroupAndScope(expression, 0); + } +} + 
+std::unique_ptr LogicalUnion::copy() { + std::vector> copiedChildren; + copiedChildren.reserve(getNumChildren()); + for (auto i = 0u; i < getNumChildren(); ++i) { + copiedChildren.push_back(getChild(i)->copy()); + } + return make_unique(expressionsToUnion, std::move(copiedChildren)); +} + +bool LogicalUnion::requireFlatExpression(uint32_t expressionIdx) { + for (auto& child : children) { + auto childSchema = child->getSchema(); + auto expression = childSchema->getExpressionsInScope()[expressionIdx]; + if (childSchema->getGroup(expression)->isFlat()) { + return true; + } + } + return false; +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_unwind.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_unwind.cpp new file mode 100644 index 0000000000..38735c2e8d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/logical_unwind.cpp @@ -0,0 +1,34 @@ +#include "planner/operator/logical_unwind.h" + +#include "planner/operator/factorization/flatten_resolver.h" + +using namespace lbug::binder; +using namespace lbug::common; + +namespace lbug { +namespace planner { + +f_group_pos_set LogicalUnwind::getGroupsPosToFlatten() { + auto childSchema = children[0]->getSchema(); + return FlattenAll::getGroupsPosToFlatten(inExpr, *childSchema); +} + +void LogicalUnwind::computeFactorizedSchema() { + copyChildSchema(0); + auto groupPos = schema->createGroup(); + schema->insertToGroupAndScope(outExpr, groupPos); + if (hasIDExpr()) { + schema->insertToGroupAndScope(idExpr, groupPos); + } +} + +void LogicalUnwind::computeFlatSchema() { + copyChildSchema(0); + schema->insertToGroupAndScope(outExpr, 0); + if (hasIDExpr()) { + schema->insertToGroupAndScope(idExpr, 0); + } +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/persistent/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/persistent/CMakeLists.txt new file mode 100644 index 0000000000..93a3eea661 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/persistent/CMakeLists.txt @@ -0,0 +1,13 @@ +add_library(lbug_planner_persistent + OBJECT + logical_copy_from.cpp + logical_copy_to.cpp + logical_insert.cpp + logical_delete.cpp + logical_merge.cpp + logical_set.cpp +) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/persistent/logical_copy_from.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/persistent/logical_copy_from.cpp new file mode 100644 index 0000000000..7c77b89ffe --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/persistent/logical_copy_from.cpp @@ -0,0 +1,17 @@ +#include "planner/operator/persistent/logical_copy_from.h" + +using namespace lbug::common; + +namespace lbug { +namespace planner { + +void LogicalCopyFrom::computeFactorizedSchema() { + copyChildSchema(0); +} + +void LogicalCopyFrom::computeFlatSchema() { + copyChildSchema(0); +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/persistent/logical_copy_to.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/persistent/logical_copy_to.cpp new file mode 100644 index 0000000000..260ad1abd0 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/persistent/logical_copy_to.cpp @@ -0,0 +1,33 @@ +#include "planner/operator/persistent/logical_copy_to.h" + +#include 
"planner/operator/factorization/flatten_resolver.h" + +namespace lbug { +namespace planner { + +std::string LogicalCopyToPrintInfo::toString() const { + std::string result = ""; + result += "Export: "; + for (auto& name : columnNames) { + result += name + ", "; + } + result += "To: " + fileName; + return result; +} + +void LogicalCopyTo::computeFactorizedSchema() { + copyChildSchema(0); +} + +void LogicalCopyTo::computeFlatSchema() { + copyChildSchema(0); +} + +f_group_pos_set LogicalCopyTo::getGroupsPosToFlatten() { + auto childSchema = children[0]->getSchema(); + auto dependentGroupsPos = childSchema->getGroupsPosInScope(); + return FlattenAllButOne::getGroupsPosToFlatten(dependentGroupsPos, *childSchema); +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/persistent/logical_delete.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/persistent/logical_delete.cpp new file mode 100644 index 0000000000..40107e84fd --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/persistent/logical_delete.cpp @@ -0,0 +1,47 @@ +#include "planner/operator/persistent/logical_delete.h" + +#include "binder/expression/expression_util.h" +#include "binder/expression/node_expression.h" +#include "binder/expression/rel_expression.h" +#include "planner/operator/factorization/flatten_resolver.h" + +using namespace lbug::binder; +using namespace lbug::common; + +namespace lbug { +namespace planner { + +std::string LogicalDelete::getExpressionsForPrinting() const { + expression_vector patterns; + for (auto& info : infos) { + patterns.push_back(info.pattern); + } + return ExpressionUtil::toString(patterns); +} + +f_group_pos_set LogicalDelete::getGroupsPosToFlatten() const { + KU_ASSERT(!infos.empty()); + const auto childSchema = children[0]->getSchema(); + f_group_pos_set dependentGroupPos; + switch (infos[0].tableType) { + case TableType::NODE: { + for (auto& info : infos) { + auto nodeID = info.pattern->constCast().getInternalID(); + dependentGroupPos.insert(childSchema->getGroupPos(*nodeID)); + } + } break; + case TableType::REL: { + for (auto& info : infos) { + auto& rel = info.pattern->constCast(); + dependentGroupPos.insert(childSchema->getGroupPos(*rel.getSrcNode()->getInternalID())); + dependentGroupPos.insert(childSchema->getGroupPos(*rel.getDstNode()->getInternalID())); + } + } break; + default: + KU_UNREACHABLE; + } + return FlattenAll::getGroupsPosToFlatten(dependentGroupPos, *childSchema); +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/persistent/logical_insert.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/persistent/logical_insert.cpp new file mode 100644 index 0000000000..20238fd4e1 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/persistent/logical_insert.cpp @@ -0,0 +1,60 @@ +#include "planner/operator/persistent/logical_insert.h" + +#include "binder/expression/node_expression.h" +#include "common/cast.h" +#include "planner/operator/factorization/flatten_resolver.h" + +using namespace lbug::common; +using namespace lbug::binder; + +namespace lbug { +namespace planner { + +void LogicalInsert::computeFactorizedSchema() { + copyChildSchema(0); + for (auto& info : infos) { + auto groupPos = schema->createGroup(); + schema->setGroupAsSingleState(groupPos); + for (auto i = 0u; i < info.columnExprs.size(); ++i) { + if (info.isReturnColumnExprs[i]) { + schema->insertToGroupAndScope(info.columnExprs[i], groupPos); 
+ } + } + if (info.tableType == TableType::NODE) { + auto node = ku_dynamic_cast(info.pattern.get()); + schema->insertToGroupAndScopeMayRepeat(node->getInternalID(), groupPos); + } + } +} + +void LogicalInsert::computeFlatSchema() { + copyChildSchema(0); + for (auto& info : infos) { + for (auto i = 0u; i < info.columnExprs.size(); ++i) { + if (info.isReturnColumnExprs[i]) { + schema->insertToGroupAndScope(info.columnExprs[i], 0); + } + } + if (info.tableType == TableType::NODE) { + auto node = ku_dynamic_cast(info.pattern.get()); + schema->insertToGroupAndScopeMayRepeat(node->getInternalID(), 0); + } + } +} + +std::string LogicalInsert::getExpressionsForPrinting() const { + std::string result; + for (auto i = 0u; i < infos.size() - 1; ++i) { + result += infos[i].pattern->toString() + ","; + } + result += infos[infos.size() - 1].pattern->toString(); + return result; +} + +f_group_pos_set LogicalInsert::getGroupsPosToFlatten() { + auto childSchema = children[0]->getSchema(); + return FlattenAll::getGroupsPosToFlatten(childSchema->getGroupsPosInScope(), *childSchema); +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/persistent/logical_merge.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/persistent/logical_merge.cpp new file mode 100644 index 0000000000..af0b6464f8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/persistent/logical_merge.cpp @@ -0,0 +1,51 @@ +#include "planner/operator/persistent/logical_merge.h" + +#include "binder/expression/node_expression.h" +#include "common/cast.h" +#include "planner/operator/factorization/flatten_resolver.h" + +using namespace lbug::binder; +using namespace lbug::common; + +namespace lbug { +namespace planner { + +void LogicalMerge::computeFactorizedSchema() { + copyChildSchema(0); + for (auto& info : insertNodeInfos) { + // Predicate iri is not matched but needs to be inserted. 
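+ // e.g. in MERGE (a)-[:KNOWS]->(b) where only `a` was matched, `b` has no
+ // internal ID in scope yet; it gets its own single-state factorization group
+ // below before the insert path fills it in.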
+ auto node = ku_dynamic_cast(info.pattern.get()); + if (!schema->isExpressionInScope(*node->getInternalID())) { + auto groupPos = schema->createGroup(); + schema->setGroupAsSingleState(groupPos); + schema->insertToGroupAndScope(node->getInternalID(), groupPos); + } + } +} + +void LogicalMerge::computeFlatSchema() { + copyChildSchema(0); + for (auto& info : insertNodeInfos) { + auto node = ku_dynamic_cast(info.pattern.get()); + schema->insertToGroupAndScopeMayRepeat(node->getInternalID(), 0); + } +} + +f_group_pos_set LogicalMerge::getGroupsPosToFlatten() { + auto childSchema = children[0]->getSchema(); + return FlattenAll::getGroupsPosToFlatten(childSchema->getGroupsPosInScope(), *childSchema); +} + +std::unique_ptr LogicalMerge::copy() { + auto merge = std::make_unique(existenceMark, keys, children[0]->copy()); + merge->insertNodeInfos = copyVector(insertNodeInfos); + merge->insertRelInfos = copyVector(insertRelInfos); + merge->onCreateSetNodeInfos = copyVector(onCreateSetNodeInfos); + merge->onCreateSetRelInfos = copyVector(onCreateSetRelInfos); + merge->onMatchSetNodeInfos = copyVector(onMatchSetNodeInfos); + merge->onMatchSetRelInfos = copyVector(onMatchSetRelInfos); + return merge; +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/persistent/logical_set.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/persistent/logical_set.cpp new file mode 100644 index 0000000000..b09c57e887 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/persistent/logical_set.cpp @@ -0,0 +1,61 @@ +#include "planner/operator/persistent/logical_set.h" + +#include "binder/expression/expression_util.h" +#include "binder/expression/rel_expression.h" +#include "planner/operator/factorization/flatten_resolver.h" + +using namespace lbug::binder; +using namespace lbug::common; + +namespace lbug { +namespace planner { + +void LogicalSetProperty::computeFactorizedSchema() { + copyChildSchema(0); +} + +void LogicalSetProperty::computeFlatSchema() { + copyChildSchema(0); +} + +f_group_pos_set LogicalSetProperty::getGroupsPosToFlatten(uint32_t idx) const { + f_group_pos_set result; + auto childSchema = children[0]->getSchema(); + auto& info = infos[idx]; + switch (getTableType()) { + case TableType::NODE: { + auto node = info.pattern->constPtrCast(); + result.insert(childSchema->getGroupPos(*node->getInternalID())); + } break; + case TableType::REL: { + auto rel = info.pattern->constPtrCast(); + result.insert(childSchema->getGroupPos(*rel->getSrcNode()->getInternalID())); + result.insert(childSchema->getGroupPos(*rel->getDstNode()->getInternalID())); + } break; + default: + KU_UNREACHABLE; + } + auto analyzer = GroupDependencyAnalyzer(false, *childSchema); + analyzer.visit(info.columnData); + for (auto& groupPos : analyzer.getDependentGroups()) { + result.insert(groupPos); + } + return FlattenAll::getGroupsPosToFlatten(result, *childSchema); +} + +std::string LogicalSetProperty::getExpressionsForPrinting() const { + std::string result = + ExpressionUtil::toString(std::make_pair(infos[0].column, infos[0].columnData)); + for (auto i = 1u; i < infos.size(); ++i) { + result += ExpressionUtil::toString(std::make_pair(infos[i].column, infos[i].columnData)); + } + return result; +} + +common::TableType LogicalSetProperty::getTableType() const { + KU_ASSERT(!infos.empty()); + return infos[0].tableType; +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/scan/CMakeLists.txt 
b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/scan/CMakeLists.txt new file mode 100644 index 0000000000..f2b03bca0f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/scan/CMakeLists.txt @@ -0,0 +1,9 @@ +add_library(lbug_planner_scan + OBJECT + logical_expressions_scan.cpp + logical_index_look_up.cpp + logical_scan_node_table.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/scan/logical_expressions_scan.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/scan/logical_expressions_scan.cpp new file mode 100644 index 0000000000..c1f767461a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/scan/logical_expressions_scan.cpp @@ -0,0 +1,15 @@ +#include "planner/operator/scan/logical_expressions_scan.h" + +namespace lbug { +namespace planner { + +void LogicalExpressionsScan::computeSchema() { + createEmptySchema(); + schema->createGroup(); + for (auto& expression : expressions) { + schema->insertToGroupAndScope(expression, 0); + } +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/scan/logical_index_look_up.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/scan/logical_index_look_up.cpp new file mode 100644 index 0000000000..6937bf337a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/scan/logical_index_look_up.cpp @@ -0,0 +1,35 @@ +#include "planner/operator/scan/logical_index_look_up.h" + +#include "binder/expression/expression_util.h" + +namespace lbug { +namespace planner { + +std::string LogicalPrimaryKeyLookup::getExpressionsForPrinting() const { + binder::expression_vector expressions; + for (auto& info : infos) { + expressions.push_back(info.offset); + } + return binder::ExpressionUtil::toString(expressions); +} + +void LogicalPrimaryKeyLookup::computeFactorizedSchema() { + copyChildSchema(0); + for (auto& info : infos) { + auto groupPos = 0u; + if (schema->isExpressionInScope(*info.key)) { + groupPos = schema->getGroupPos(*info.key); + } + schema->insertToGroupAndScope(info.offset, groupPos); + } +} + +void LogicalPrimaryKeyLookup::computeFlatSchema() { + copyChildSchema(0); + for (auto& info : infos) { + schema->insertToGroupAndScope(info.offset, 0); + } +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/scan/logical_scan_node_table.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/scan/logical_scan_node_table.cpp new file mode 100644 index 0000000000..177ecd647e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/scan/logical_scan_node_table.cpp @@ -0,0 +1,47 @@ +#include "planner/operator/scan/logical_scan_node_table.h" + +namespace lbug { +namespace planner { + +LogicalScanNodeTable::LogicalScanNodeTable(const LogicalScanNodeTable& other) + : LogicalOperator{type_}, scanType{other.scanType}, nodeID{other.nodeID}, + nodeTableIDs{other.nodeTableIDs}, properties{other.properties}, + propertyPredicates{copyVector(other.propertyPredicates)} { + if (other.extraInfo != nullptr) { + setExtraInfo(other.extraInfo->copy()); + } + this->cardinality = other.cardinality; +} + +void LogicalScanNodeTable::computeFactorizedSchema() { + createEmptySchema(); + const auto groupPos = schema->createGroup(); + KU_ASSERT(groupPos == 0); + schema->insertToGroupAndScope(nodeID, groupPos); + for (auto& property : properties) { + schema->insertToGroupAndScope(property, groupPos); 
+ } + switch (scanType) { + case LogicalScanNodeTableType::PRIMARY_KEY_SCAN: { + schema->setGroupAsSingleState(groupPos); + } break; + default: + break; + } +} + +void LogicalScanNodeTable::computeFlatSchema() { + createEmptySchema(); + schema->createGroup(); + schema->insertToGroupAndScope(nodeID, 0); + for (auto& property : properties) { + schema->insertToGroupAndScope(property, 0); + } +} + +std::unique_ptr LogicalScanNodeTable::copy() { + return std::make_unique(*this); +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/schema.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/schema.cpp new file mode 100644 index 0000000000..b1123e9526 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/schema.cpp @@ -0,0 +1,163 @@ +#include "planner/operator/schema.h" + +#include "binder/expression_visitor.h" +#include "common/exception/internal.h" + +using namespace lbug::binder; +using namespace lbug::common; + +namespace lbug { +namespace planner { + +f_group_pos Schema::createGroup() { + auto pos = groups.size(); + groups.push_back(std::make_unique()); + return pos; +} + +void Schema::insertToScope(const std::shared_ptr& expression, f_group_pos groupPos) { + KU_ASSERT(!expressionNameToGroupPos.contains(expression->getUniqueName())); + expressionNameToGroupPos.insert({expression->getUniqueName(), groupPos}); + KU_ASSERT(getGroup(groupPos)->expressionNameToPos.contains(expression->getUniqueName())); + expressionsInScope.push_back(expression); +} + +void Schema::insertToGroupAndScope(const std::shared_ptr& expression, + f_group_pos groupPos) { + KU_ASSERT(!expressionNameToGroupPos.contains(expression->getUniqueName())); + expressionNameToGroupPos.insert({expression->getUniqueName(), groupPos}); + groups[groupPos]->insertExpression(expression); + expressionsInScope.push_back(expression); +} + +void Schema::insertToScopeMayRepeat(const std::shared_ptr& expression, + uint32_t groupPos) { + if (expressionNameToGroupPos.contains(expression->getUniqueName())) { + return; + } + insertToScope(expression, groupPos); +} + +void Schema::insertToGroupAndScopeMayRepeat(const std::shared_ptr& expression, + uint32_t groupPos) { + if (expressionNameToGroupPos.contains(expression->getUniqueName())) { + return; + } + insertToGroupAndScope(expression, groupPos); +} + +void Schema::insertToGroupAndScope(const expression_vector& expressions, f_group_pos groupPos) { + for (auto& expression : expressions) { + insertToGroupAndScope(expression, groupPos); + } +} + +f_group_pos Schema::getGroupPos(const std::string& expressionName) const { + KU_ASSERT(expressionNameToGroupPos.contains(expressionName)); + return expressionNameToGroupPos.at(expressionName); +} + +bool Schema::isExpressionInScope(const Expression& expression) const { + for (auto& expressionInScope : expressionsInScope) { + if (expressionInScope->getUniqueName() == expression.getUniqueName()) { + return true; + } + } + return false; +} + +expression_vector Schema::getExpressionsInScope(f_group_pos pos) const { + expression_vector result; + for (auto& expression : expressionsInScope) { + if (getGroupPos(expression->getUniqueName()) == pos) { + result.push_back(expression); + } + } + return result; +} + +bool Schema::evaluable(const Expression& expression) const { + auto inScope = isExpressionInScope(expression); + if (expression.expressionType == ExpressionType::LITERAL || inScope) { + return true; + } + auto children = 
ExpressionChildrenCollector::collectChildren(expression); + if (children.empty()) { + return inScope; + } else { + for (auto& child : children) { + if (!evaluable(*child)) { + return false; + } + } + return true; + } +} + +std::unordered_set Schema::getGroupsPosInScope() const { + std::unordered_set result; + for (auto& expressionInScope : expressionsInScope) { + result.insert(getGroupPos(expressionInScope->getUniqueName())); + } + return result; +} + +std::unique_ptr Schema::copy() const { + auto newSchema = std::make_unique(); + newSchema->expressionNameToGroupPos = expressionNameToGroupPos; + for (auto& group : groups) { + newSchema->groups.push_back(std::make_unique(*group)); + } + newSchema->expressionsInScope = expressionsInScope; + return newSchema; +} + +void Schema::clear() { + groups.clear(); + clearExpressionsInScope(); +} + +size_t Schema::getNumGroups(bool isFlat) const { + auto result = 0u; + for (auto groupPos : getGroupsPosInScope()) { + result += groups[groupPos]->isFlat() == isFlat; + } + return result; +} + +f_group_pos SchemaUtils::getLeadingGroupPos(const std::unordered_set& groupPositions, + const Schema& schema) { + auto leadingGroupPos = INVALID_F_GROUP_POS; + for (auto groupPos : groupPositions) { + if (!schema.getGroup(groupPos)->isFlat()) { + return groupPos; + } + leadingGroupPos = groupPos; + } + return leadingGroupPos; +} + +void SchemaUtils::validateAtMostOneUnFlatGroup( + const std::unordered_set& groupPositions, const Schema& schema) { + auto hasUnFlatGroup = false; + for (auto groupPos : groupPositions) { + if (!schema.getGroup(groupPos)->isFlat()) { + if (hasUnFlatGroup) { + throw InternalException("Unexpected multiple unFlat factorization groups found."); + } + hasUnFlatGroup = true; + } + } +} + +void SchemaUtils::validateNoUnFlatGroup(const std::unordered_set& groupPositions, + const Schema& schema) { + for (auto groupPos : groupPositions) { + if (!schema.getGroup(groupPos)->isFlat()) { + throw InternalException("Unexpected unFlat factorization group found."); + } + } +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/simple/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/simple/CMakeLists.txt new file mode 100644 index 0000000000..f0a3a9edcf --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/simple/CMakeLists.txt @@ -0,0 +1,7 @@ +add_library(lbug_planner_simple + OBJECT + logical_simple.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/simple/logical_simple.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/simple/logical_simple.cpp new file mode 100644 index 0000000000..ed6d16417c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/simple/logical_simple.cpp @@ -0,0 +1,15 @@ +#include "planner/operator/simple/logical_simple.h" + +namespace lbug { +namespace planner { + +void LogicalSimple::computeFlatSchema() { + createEmptySchema(); +} + +void LogicalSimple::computeFactorizedSchema() { + createEmptySchema(); +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/sip/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/sip/CMakeLists.txt new file mode 100644 index 0000000000..450b9fc170 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/sip/CMakeLists.txt @@ -0,0 +1,7 @@ +add_library(lbug_planner_sip + OBJECT + 
logical_semi_masker.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/sip/logical_semi_masker.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/sip/logical_semi_masker.cpp new file mode 100644 index 0000000000..24327c7799 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/operator/sip/logical_semi_masker.cpp @@ -0,0 +1,7 @@ +#include "planner/operator/sip/logical_semi_masker.h" + +namespace lbug { +namespace planner { +LogicalSemiMasker::~LogicalSemiMasker() = default; +} +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/CMakeLists.txt new file mode 100644 index 0000000000..2282652c0d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/CMakeLists.txt @@ -0,0 +1,38 @@ +add_library(lbug_planner_plan_operator + OBJECT + append_accumulate.cpp + append_aggregate.cpp + append_insert.cpp + append_cross_product.cpp + append_delete.cpp + append_distinct.cpp + append_dummy_scan.cpp + append_empty_result.cpp + append_expressions_scan.cpp + append_extend.cpp + append_filter.cpp + append_flatten.cpp + append_table_function_call.cpp + append_join.cpp + append_limit.cpp + append_multiplicity_reducer.cpp + append_order_by.cpp + append_projection.cpp + append_scan_node_table.cpp + append_set.cpp + append_simple.cpp + append_unwind.cpp + plan_copy.cpp + plan_join_order.cpp + plan_node_scan.cpp + plan_node_semi_mask.cpp + plan_projection.cpp + plan_read.cpp + plan_single_query.cpp + plan_subquery.cpp + plan_update.cpp + plan_port_db.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_accumulate.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_accumulate.cpp new file mode 100644 index 0000000000..b0cb4bb3d3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_accumulate.cpp @@ -0,0 +1,40 @@ +#include "planner/operator/logical_accumulate.h" +#include "planner/planner.h" + +using namespace lbug::binder; +using namespace lbug::common; + +namespace lbug { +namespace planner { + +void Planner::tryAppendAccumulate(LogicalPlan& plan) { + if (plan.getLastOperator()->getOperatorType() == LogicalOperatorType::ACCUMULATE) { + return; + } + appendAccumulate(plan); +} + +void Planner::appendAccumulate(LogicalPlan& plan) { + appendAccumulate(AccumulateType::REGULAR, expression_vector{}, nullptr /* mark */, plan); +} + +void Planner::appendOptionalAccumulate(std::shared_ptr mark, LogicalPlan& plan) { + appendAccumulate(AccumulateType::OPTIONAL_, expression_vector{}, mark, plan); +} + +void Planner::appendAccumulate(const expression_vector& flatExprs, LogicalPlan& plan) { + appendAccumulate(AccumulateType::REGULAR, flatExprs, nullptr /* mark */, plan); +} + +void Planner::appendAccumulate(AccumulateType accumulateType, const expression_vector& flatExprs, + std::shared_ptr mark, LogicalPlan& plan) { + auto op = + make_shared(accumulateType, flatExprs, mark, plan.getLastOperator()); + appendFlattens(op->getGroupPositionsToFlatten(), plan); + op->setChild(0, plan.getLastOperator()); + op->computeFactorizedSchema(); + plan.setLastOperator(std::move(op)); +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_aggregate.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_aggregate.cpp new file mode 100644 index 
0000000000..373f9c11fa --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_aggregate.cpp @@ -0,0 +1,21 @@ +#include "planner/operator/logical_aggregate.h" +#include "planner/planner.h" + +using namespace lbug::binder; + +namespace lbug { +namespace planner { + +void Planner::appendAggregate(const expression_vector& expressionsToGroupBy, + const expression_vector& expressionsToAggregate, LogicalPlan& plan) { + auto aggregate = make_shared(expressionsToGroupBy, expressionsToAggregate, + plan.getLastOperator()); + appendFlattens(aggregate->getGroupsPosToFlatten(), plan); + aggregate->setChild(0, plan.getLastOperator()); + aggregate->computeFactorizedSchema(); + aggregate->setCardinality(cardinalityEstimator.estimateAggregate(*aggregate)); + plan.setLastOperator(std::move(aggregate)); +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_cross_product.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_cross_product.cpp new file mode 100644 index 0000000000..150176237f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_cross_product.cpp @@ -0,0 +1,43 @@ +#include "planner/operator/logical_cross_product.h" +#include "planner/planner.h" + +using namespace lbug::binder; +using namespace lbug::common; + +namespace lbug { +namespace planner { + +void Planner::appendCrossProduct(const LogicalPlan& probePlan, const LogicalPlan& buildPlan, + LogicalPlan& resultPlan) { + appendCrossProduct(AccumulateType::REGULAR, nullptr /* mark */, probePlan, buildPlan, + resultPlan); +} + +void Planner::appendOptionalCrossProduct(std::shared_ptr mark, + const LogicalPlan& probePlan, const LogicalPlan& buildPlan, LogicalPlan& resultPlan) { + appendCrossProduct(AccumulateType::OPTIONAL_, mark, probePlan, buildPlan, resultPlan); +} + +void Planner::appendAccOptionalCrossProduct(std::shared_ptr mark, + LogicalPlan& probePlan, const LogicalPlan& buildPlan, LogicalPlan& resultPlan) { + KU_ASSERT(probePlan.hasUpdate()); + tryAppendAccumulate(probePlan); + appendCrossProduct(AccumulateType::OPTIONAL_, mark, probePlan, buildPlan, resultPlan); + auto& sipInfo = resultPlan.getLastOperator()->cast().getSIPInfoUnsafe(); + sipInfo.direction = SIPDirection::PROBE_TO_BUILD; +} + +void Planner::appendCrossProduct(AccumulateType accumulateType, std::shared_ptr mark, + const LogicalPlan& probePlan, const LogicalPlan& buildPlan, LogicalPlan& resultPlan) { + auto crossProduct = make_shared(accumulateType, mark, + probePlan.getLastOperator(), buildPlan.getLastOperator(), + cardinalityEstimator.estimateCrossProduct(probePlan.getLastOperatorRef(), + buildPlan.getLastOperatorRef())); + crossProduct->computeFactorizedSchema(); + // update cost + resultPlan.setCost(probePlan.getCardinality() + buildPlan.getCardinality()); + resultPlan.setLastOperator(std::move(crossProduct)); +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_delete.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_delete.cpp new file mode 100644 index 0000000000..71fa26ba56 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_delete.cpp @@ -0,0 +1,19 @@ +#include "binder/query/updating_clause/bound_delete_info.h" +#include "planner/operator/persistent/logical_delete.h" +#include "planner/planner.h" + +using namespace lbug::binder; + +namespace lbug { +namespace planner { + +void Planner::appendDelete(const std::vector& infos, LogicalPlan& plan) { + 
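+ // Node deletes need the node ID group flat and rel deletes need both the src
+ // and dst node ID groups flat, so the required flattens are appended before
+ // the delete operator is wired onto the plan.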
auto delete_ = std::make_shared(copyVector(infos), plan.getLastOperator()); + appendFlattens(delete_->getGroupsPosToFlatten(), plan); + delete_->setChild(0, plan.getLastOperator()); + delete_->computeFactorizedSchema(); + plan.setLastOperator(std::move(delete_)); +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_distinct.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_distinct.cpp new file mode 100644 index 0000000000..e75af6490c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_distinct.cpp @@ -0,0 +1,19 @@ +#include "planner/operator/logical_distinct.h" +#include "planner/planner.h" + +using namespace lbug::binder; + +namespace lbug { +namespace planner { + +void Planner::appendDistinct(const expression_vector& keys, LogicalPlan& plan) { + KU_ASSERT(!keys.empty()); + auto distinct = make_shared(keys, plan.getLastOperator()); + appendFlattens(distinct->getGroupsPosToFlatten(), plan); + distinct->setChild(0, plan.getLastOperator()); + distinct->computeFactorizedSchema(); + plan.setLastOperator(std::move(distinct)); +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_dummy_scan.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_dummy_scan.cpp new file mode 100644 index 0000000000..ada53fb57a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_dummy_scan.cpp @@ -0,0 +1,15 @@ +#include "planner/operator/scan/logical_dummy_scan.h" +#include "planner/planner.h" + +namespace lbug { +namespace planner { + +void Planner::appendDummyScan(LogicalPlan& plan) { + KU_ASSERT(plan.isEmpty()); + auto dummyScan = std::make_shared(); + dummyScan->computeFactorizedSchema(); + plan.setLastOperator(std::move(dummyScan)); +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_empty_result.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_empty_result.cpp new file mode 100644 index 0000000000..a8266ecdfa --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_empty_result.cpp @@ -0,0 +1,14 @@ +#include "planner/operator/logical_empty_result.h" +#include "planner/planner.h" + +namespace lbug { +namespace planner { + +void Planner::appendEmptyResult(LogicalPlan& plan) { + auto op = std::make_shared(*plan.getSchema()); + op->computeFactorizedSchema(); + plan.setLastOperator(std::move(op)); +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_expressions_scan.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_expressions_scan.cpp new file mode 100644 index 0000000000..7ae5536f22 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_expressions_scan.cpp @@ -0,0 +1,16 @@ +#include "planner/operator/scan/logical_expressions_scan.h" +#include "planner/planner.h" + +using namespace lbug::binder; + +namespace lbug { +namespace planner { + +void Planner::appendExpressionsScan(const expression_vector& expressions, LogicalPlan& plan) { + auto expressionsScan = std::make_shared(expressions); + expressionsScan->computeFactorizedSchema(); + plan.setLastOperator(expressionsScan); +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_extend.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_extend.cpp new file mode 100644 index 0000000000..e915043deb --- 
/dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_extend.cpp @@ -0,0 +1,191 @@ +#include + +#include "catalog/catalog.h" +#include "catalog/catalog_entry/rel_group_catalog_entry.h" +#include "common/enums/join_type.h" +#include "planner/join_order/cost_model.h" +#include "planner/operator/extend/logical_extend.h" +#include "planner/operator/extend/logical_recursive_extend.h" +#include "planner/operator/extend/recursive_join_type.h" +#include "planner/operator/logical_node_label_filter.h" +#include "planner/operator/logical_path_property_probe.h" +#include "planner/planner.h" +#include "transaction/transaction.h" + +using namespace lbug::common; +using namespace lbug::binder; +using namespace lbug::catalog; +using namespace lbug::transaction; +using namespace lbug::function; + +namespace lbug { +namespace planner { + +static std::unordered_set getBoundNodeTableIDSet(const RelExpression& rel, + ExtendDirection extendDirection) { + std::unordered_set result; + for (auto entry : rel.getEntries()) { + auto& groupEntry = entry->constCast(); + switch (extendDirection) { + case ExtendDirection::FWD: { + result.merge(groupEntry.getBoundNodeTableIDSet(RelDataDirection::FWD)); + } break; + case ExtendDirection::BWD: { + result.merge(groupEntry.getBoundNodeTableIDSet(RelDataDirection::BWD)); + } break; + case ExtendDirection::BOTH: { + result.merge(groupEntry.getBoundNodeTableIDSet(RelDataDirection::FWD)); + result.merge(groupEntry.getBoundNodeTableIDSet(RelDataDirection::BWD)); + } break; + default: + KU_UNREACHABLE; + } + } + return result; +} + +static std::unordered_set getNbrNodeTableIDSet(const RelExpression& rel, + ExtendDirection extendDirection) { + std::unordered_set result; + for (auto entry : rel.getEntries()) { + auto& groupEntry = entry->constCast(); + switch (extendDirection) { + case ExtendDirection::FWD: { + result.merge(groupEntry.getNbrNodeTableIDSet(RelDataDirection::FWD)); + } break; + case ExtendDirection::BWD: { + result.merge(groupEntry.getNbrNodeTableIDSet(RelDataDirection::BWD)); + } break; + case ExtendDirection::BOTH: { + result.merge(groupEntry.getNbrNodeTableIDSet(RelDataDirection::FWD)); + result.merge(groupEntry.getNbrNodeTableIDSet(RelDataDirection::BWD)); + } break; + default: + KU_UNREACHABLE; + } + } + return result; +} + +void Planner::appendNonRecursiveExtend(const std::shared_ptr& boundNode, + const std::shared_ptr& nbrNode, const std::shared_ptr& rel, + ExtendDirection direction, bool extendFromSource, const expression_vector& properties, + LogicalPlan& plan) { + // Filter bound node label if we know some incoming nodes won't have any outgoing rel. This + // cannot be done at binding time because the pruning is affected by extend direction. + auto boundNodeTableIDSet = getBoundNodeTableIDSet(*rel, direction); + if (boundNode->getNumEntries() > boundNodeTableIDSet.size()) { + appendNodeLabelFilter(boundNode->getInternalID(), boundNodeTableIDSet, plan); + } + auto properties_ = properties; + // Append extend + auto extend = make_shared(boundNode, nbrNode, rel, direction, extendFromSource, + properties_, plan.getLastOperator()); + extend->computeFactorizedSchema(); + // Update cost & cardinality. Note that extend does not change factorized cardinality. 
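+ // e.g. extending N flat bound-node tuples through an average of k rels keeps
+ // N factorized tuples; only the neighbour node's group multiplier is scaled
+ // by the estimated extension rate below.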
+ auto transaction = Transaction::Get(*clientContext); + const auto extensionRate = cardinalityEstimator.getExtensionRate(*rel, *boundNode, transaction); + extend->setCardinality(plan.getLastOperator()->getCardinality()); + plan.setCost(CostModel::computeExtendCost(plan)); + auto group = extend->getSchema()->getGroup(nbrNode->getInternalID()); + group->setMultiplier(extensionRate); + plan.setLastOperator(std::move(extend)); + auto nbrNodeTableIDSet = getNbrNodeTableIDSet(*rel, direction); + if (nbrNodeTableIDSet.size() > nbrNode->getNumEntries()) { + appendNodeLabelFilter(nbrNode->getInternalID(), nbrNode->getTableIDsSet(), plan); + } +} + +void Planner::appendRecursiveExtend(const std::shared_ptr& boundNode, + const std::shared_ptr& nbrNode, const std::shared_ptr& rel, + ExtendDirection direction, LogicalPlan& plan) { + // GDS pipeline + auto recursiveInfo = rel->getRecursiveInfo(); + // Fill bind data with direction information. This can only be decided at planning time. + auto bindData = recursiveInfo->bindData.get(); + bindData->nodeOutput = nbrNode; + bindData->nodeInput = boundNode; + bindData->extendDirection = direction; + // If we extend from right to left, we need to print path in reverse direction. + bindData->flipPath = *boundNode == *rel->getRightNode(); + auto resultColumns = recursiveInfo->function->getResultColumns(*bindData); + auto recursiveExtend = std::make_shared(recursiveInfo->function->copy(), + *recursiveInfo->bindData, resultColumns); + if (recursiveInfo->nodePredicate != nullptr) { + auto p = getNodeSemiMaskPlan(SemiMaskTargetType::RECURSIVE_EXTEND_PATH_NODE, + *recursiveInfo->node, recursiveInfo->nodePredicate); + recursiveExtend->addChild(p.getLastOperator()); + } + recursiveExtend->computeFactorizedSchema(); + auto probePlan = LogicalPlan(); + probePlan.setLastOperator(std::move(recursiveExtend)); + // Scan path node property pipeline + std::shared_ptr pathNodePropertyScanRoot = nullptr; + if (!recursiveInfo->nodeProjectionList.empty()) { + auto pathNodePropertyScanPlan = LogicalPlan(); + createPathNodePropertyScanPlan(recursiveInfo->node, recursiveInfo->nodeProjectionList, + pathNodePropertyScanPlan); + pathNodePropertyScanRoot = pathNodePropertyScanPlan.getLastOperator(); + } + // Scan path rel property pipeline + std::shared_ptr pathRelPropertyScanRoot = nullptr; + if (!recursiveInfo->relProjectionList.empty()) { + auto pathRelPropertyScanPlan = LogicalPlan(); + auto relProperties = recursiveInfo->relProjectionList; + relProperties.push_back(recursiveInfo->rel->getInternalID()); + bool extendFromSource = *boundNode == *rel->getSrcNode(); + createPathRelPropertyScanPlan(recursiveInfo->node, recursiveInfo->nodeCopy, + recursiveInfo->rel, direction, extendFromSource, relProperties, + pathRelPropertyScanPlan); + pathRelPropertyScanRoot = pathRelPropertyScanPlan.getLastOperator(); + } + // Construct path by probing scanned properties + auto pathPropertyProbe = + std::make_shared(rel, probePlan.getLastOperator(), + pathNodePropertyScanRoot, pathRelPropertyScanRoot, RecursiveJoinType::TRACK_PATH); + pathPropertyProbe->direction = direction; + pathPropertyProbe->extendFromLeft = *boundNode == *rel->getLeftNode(); + pathPropertyProbe->pathNodeIDs = recursiveInfo->bindData->pathNodeIDsExpr; + pathPropertyProbe->pathEdgeIDs = recursiveInfo->bindData->pathEdgeIDsExpr; + pathPropertyProbe->computeFactorizedSchema(); + auto transaction = Transaction::Get(*clientContext); + auto extensionRate = cardinalityEstimator.getExtensionRate(*rel, *boundNode, transaction); + auto 
resultCard = + cardinalityEstimator.multiply(extensionRate, plan.getLastOperator()->getCardinality()); + pathPropertyProbe->setCardinality(resultCard); + probePlan.setLastOperator(pathPropertyProbe); + probePlan.setCost(plan.getCardinality()); + + // Join with input node + auto joinConditions = expression_vector{boundNode->getInternalID()}; + appendHashJoin(joinConditions, JoinType::INNER, probePlan, plan, plan); + // Hash join above is joining input node with its properties. So 1-1 match is guaranteed and + // thus should not change cardinality. + plan.getLastOperator()->setCardinality(resultCard); +} + +void Planner::createPathNodePropertyScanPlan(const std::shared_ptr& node, + const expression_vector& properties, LogicalPlan& plan) { + appendScanNodeTable(node->getInternalID(), node->getTableIDs(), properties, plan); +} + +void Planner::createPathRelPropertyScanPlan(const std::shared_ptr& boundNode, + const std::shared_ptr& nbrNode, const std::shared_ptr& rel, + ExtendDirection direction, bool extendFromSource, const expression_vector& properties, + LogicalPlan& plan) { + appendScanNodeTable(boundNode->getInternalID(), boundNode->getTableIDs(), {}, plan); + appendNonRecursiveExtend(boundNode, nbrNode, rel, direction, extendFromSource, properties, + plan); + appendProjection(properties, plan); +} + +void Planner::appendNodeLabelFilter(std::shared_ptr nodeID, + std::unordered_set tableIDSet, LogicalPlan& plan) { + auto filter = std::make_shared(std::move(nodeID), std::move(tableIDSet), + plan.getLastOperator()); + filter->computeFactorizedSchema(); + plan.setLastOperator(std::move(filter)); +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_filter.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_filter.cpp new file mode 100644 index 0000000000..0d9d7de4ed --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_filter.cpp @@ -0,0 +1,28 @@ +#include "planner/operator/logical_filter.h" +#include "planner/planner.h" + +using namespace lbug::binder; + +namespace lbug { +namespace planner { + +void Planner::appendFilters(const expression_vector& predicates, LogicalPlan& plan) { + for (auto& predicate : predicates) { + appendFilter(predicate, plan); + } +} + +void Planner::appendFilter(const std::shared_ptr& predicate, LogicalPlan& plan) { + planSubqueryIfNecessary(predicate, plan); + auto filter = make_shared(predicate, plan.getLastOperator()); + appendFlattens(filter->getGroupsPosToFlatten(), plan); + filter->setChild(0, plan.getLastOperator()); + filter->computeFactorizedSchema(); + // estimate cardinality + filter->setCardinality( + cardinalityEstimator.estimateFilter(plan.getLastOperatorRef(), *predicate)); + plan.setLastOperator(std::move(filter)); +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_flatten.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_flatten.cpp new file mode 100644 index 0000000000..7b1ea2ab3b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_flatten.cpp @@ -0,0 +1,25 @@ +#include "planner/operator/logical_flatten.h" +#include "planner/planner.h" + +namespace lbug { +namespace planner { + +void Planner::appendFlattens(const f_group_pos_set& groupsPos, LogicalPlan& plan) { + for (auto groupPos : groupsPos) { + appendFlattenIfNecessary(groupPos, plan); + } +} + +void Planner::appendFlattenIfNecessary(f_group_pos groupPos, LogicalPlan& plan) { + auto group = 
plan.getSchema()->getGroup(groupPos); + if (group->isFlat()) { + return; + } + auto flatten = make_shared(groupPos, plan.getLastOperator(), + cardinalityEstimator.estimateFlatten(plan.getLastOperatorRef(), groupPos)); + flatten->computeFactorizedSchema(); + plan.setLastOperator(std::move(flatten)); +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_insert.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_insert.cpp new file mode 100644 index 0000000000..9fe3a027d5 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_insert.cpp @@ -0,0 +1,55 @@ +#include "binder/query/updating_clause/bound_insert_info.h" +#include "planner/operator/persistent/logical_insert.h" +#include "planner/planner.h" + +using namespace lbug::common; +using namespace lbug::binder; + +namespace lbug { +namespace planner { + +std::unique_ptr Planner::createLogicalInsertInfo( + const BoundInsertInfo* info) const { + auto insertInfo = std::make_unique(info->tableType, info->pattern, + info->columnExprs, info->columnDataExprs, info->conflictAction); + binder::expression_set propertyExprSet; + for (auto& expr : getProperties(*info->pattern)) { + propertyExprSet.insert(expr); + } + for (auto& expr : insertInfo->columnExprs) { + insertInfo->isReturnColumnExprs.push_back(propertyExprSet.contains(expr)); + } + return insertInfo; +} + +void Planner::appendInsertNode(const std::vector& boundInsertInfos, + LogicalPlan& plan) { + std::vector logicalInfos; + logicalInfos.reserve(boundInsertInfos.size()); + for (auto& boundInfo : boundInsertInfos) { + logicalInfos.push_back(createLogicalInsertInfo(boundInfo)->copy()); + } + auto insertNode = + std::make_shared(std::move(logicalInfos), plan.getLastOperator()); + appendFlattens(insertNode->getGroupsPosToFlatten(), plan); + insertNode->setChild(0, plan.getLastOperator()); + insertNode->computeFactorizedSchema(); + plan.setLastOperator(insertNode); +} + +void Planner::appendInsertRel(const std::vector& boundInsertInfos, + LogicalPlan& plan) { + std::vector logicalInfos; + logicalInfos.reserve(boundInsertInfos.size()); + for (auto& boundInfo : boundInsertInfos) { + logicalInfos.push_back(createLogicalInsertInfo(boundInfo)->copy()); + } + auto insertRel = + std::make_shared(std::move(logicalInfos), plan.getLastOperator()); + appendFlattens(insertRel->getGroupsPosToFlatten(), plan); + insertRel->setChild(0, plan.getLastOperator()); + insertRel->computeFactorizedSchema(); + plan.setLastOperator(insertRel); +} +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_join.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_join.cpp new file mode 100644 index 0000000000..9c514756e5 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_join.cpp @@ -0,0 +1,123 @@ +#include "planner/join_order/cost_model.h" +#include "planner/operator/logical_hash_join.h" +#include "planner/operator/logical_intersect.h" +#include "planner/planner.h" + +using namespace lbug::common; +using namespace lbug::binder; + +namespace lbug { +namespace planner { + +void Planner::appendHashJoin(const expression_vector& joinNodeIDs, JoinType joinType, + LogicalPlan& probePlan, LogicalPlan& buildPlan, LogicalPlan& resultPlan) { + appendHashJoin(joinNodeIDs, joinType, nullptr /* mark */, probePlan, buildPlan, resultPlan); +} + +void Planner::appendHashJoin(const expression_vector& joinNodeIDs, JoinType joinType, + std::shared_ptr mark, 
LogicalPlan& probePlan, LogicalPlan& buildPlan, + LogicalPlan& resultPlan) { + std::vector joinConditions; + for (auto& joinNodeID : joinNodeIDs) { + joinConditions.emplace_back(joinNodeID, joinNodeID); + } + appendHashJoin(joinConditions, joinType, mark, probePlan, buildPlan, resultPlan); +} + +void Planner::appendHashJoin(const std::vector& joinConditions, JoinType joinType, + std::shared_ptr mark, LogicalPlan& probePlan, LogicalPlan& buildPlan, + LogicalPlan& resultPlan) { + auto hashJoin = make_shared(joinConditions, joinType, mark, + probePlan.getLastOperator(), buildPlan.getLastOperator()); + // Apply flattening to probe side + auto groupsPosToFlattenOnProbeSide = hashJoin->getGroupsPosToFlattenOnProbeSide(); + appendFlattens(groupsPosToFlattenOnProbeSide, probePlan); + hashJoin->setChild(0, probePlan.getLastOperator()); + // Apply flattening to build side + appendFlattens(hashJoin->getGroupsPosToFlattenOnBuildSide(), buildPlan); + hashJoin->setChild(1, buildPlan.getLastOperator()); + hashJoin->computeFactorizedSchema(); + // Check for sip + if (probePlan.getCardinality() > buildPlan.getCardinality() * PlannerKnobs::SIP_RATIO) { + hashJoin->getSIPInfoUnsafe().position = SemiMaskPosition::PROHIBIT_PROBE_TO_BUILD; + } + // Update cost + hashJoin->setCardinality(cardinalityEstimator.estimateHashJoin(joinConditions, + probePlan.getLastOperatorRef(), buildPlan.getLastOperatorRef())); + resultPlan.setCost(CostModel::computeHashJoinCost(joinConditions, probePlan, buildPlan)); + resultPlan.setLastOperator(std::move(hashJoin)); +} + +void Planner::appendAccHashJoin(const std::vector& joinConditions, + JoinType joinType, std::shared_ptr mark, LogicalPlan& probePlan, + LogicalPlan& buildPlan, LogicalPlan& resultPlan) { + KU_ASSERT(probePlan.hasUpdate()); + tryAppendAccumulate(probePlan); + appendHashJoin(joinConditions, joinType, mark, probePlan, buildPlan, resultPlan); + auto& sipInfo = probePlan.getLastOperator()->cast().getSIPInfoUnsafe(); + sipInfo.direction = SIPDirection::PROBE_TO_BUILD; +} + +void Planner::appendMarkJoin(const expression_vector& joinNodeIDs, + const std::shared_ptr& mark, LogicalPlan& probePlan, LogicalPlan& buildPlan, + LogicalPlan& resultPlan) { + std::vector joinConditions; + for (auto& joinNodeID : joinNodeIDs) { + joinConditions.emplace_back(joinNodeID, joinNodeID); + } + appendMarkJoin(joinConditions, mark, probePlan, buildPlan, resultPlan); +} + +void Planner::appendMarkJoin(const std::vector& joinConditions, + const std::shared_ptr& mark, LogicalPlan& probePlan, LogicalPlan& buildPlan, + LogicalPlan& resultPlan) { + auto hashJoin = make_shared(joinConditions, JoinType::MARK, mark, + probePlan.getLastOperator(), buildPlan.getLastOperator()); + // Apply flattening to probe side + appendFlattens(hashJoin->getGroupsPosToFlattenOnProbeSide(), probePlan); + hashJoin->setChild(0, probePlan.getLastOperator()); + // Apply flattening to build side + appendFlattens(hashJoin->getGroupsPosToFlattenOnBuildSide(), buildPlan); + hashJoin->setChild(1, buildPlan.getLastOperator()); + hashJoin->computeFactorizedSchema(); + // update cost. Mark join does not change cardinality. 
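+ // Each probe-side tuple receives exactly one boolean mark (e.g. when planning
+ // an EXISTS subquery), so the probe cardinality is carried through unchanged.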
+ hashJoin->setCardinality(probePlan.getCardinality()); + resultPlan.setCost(CostModel::computeMarkJoinCost(joinConditions, probePlan, buildPlan)); + resultPlan.setLastOperator(std::move(hashJoin)); +} + +void Planner::appendIntersect(const std::shared_ptr& intersectNodeID, + expression_vector& boundNodeIDs, LogicalPlan& probePlan, std::vector& buildPlans) { + KU_ASSERT(boundNodeIDs.size() == buildPlans.size()); + std::vector> buildChildren; + expression_vector keyNodeIDs; + for (auto i = 0u; i < buildPlans.size(); ++i) { + keyNodeIDs.push_back(boundNodeIDs[i]); + buildChildren.push_back(buildPlans[i].getLastOperator()); + } + auto intersect = make_shared(intersectNodeID, std::move(keyNodeIDs), + probePlan.getLastOperator(), std::move(buildChildren)); + appendFlattens(intersect->getGroupsPosToFlattenOnProbeSide(), probePlan); + intersect->setChild(0, probePlan.getLastOperator()); + for (auto i = 0u; i < buildPlans.size(); ++i) { + appendFlattens(intersect->getGroupsPosToFlattenOnBuildSide(i), buildPlans[i]); + intersect->setChild(i + 1, buildPlans[i].getLastOperator()); + auto ratio = probePlan.getCardinality() / buildPlans[i].getCardinality(); + if (ratio > PlannerKnobs::SIP_RATIO) { + intersect->getSIPInfoUnsafe().position = SemiMaskPosition::PROHIBIT; + } + } + intersect->computeFactorizedSchema(); + // update cost + std::vector buildOps; + for (const auto& plan : buildPlans) { + buildOps.push_back(plan.getLastOperator().get()); + } + intersect->setCardinality(cardinalityEstimator.estimateIntersect(boundNodeIDs, + probePlan.getLastOperatorRef(), buildOps)); + probePlan.setCost(CostModel::computeIntersectCost(probePlan, buildPlans)); + probePlan.setLastOperator(std::move(intersect)); +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_limit.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_limit.cpp new file mode 100644 index 0000000000..79bd52b0b6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_limit.cpp @@ -0,0 +1,17 @@ +#include "planner/operator/logical_limit.h" +#include "planner/planner.h" + +namespace lbug { +namespace planner { + +void Planner::appendLimit(std::shared_ptr skipNum, + std::shared_ptr limitNum, LogicalPlan& plan) { + auto limit = make_shared(skipNum, limitNum, plan.getLastOperator()); + appendFlattens(limit->getGroupsPosToFlatten(), plan); + limit->setChild(0, plan.getLastOperator()); + limit->computeFactorizedSchema(); + plan.setLastOperator(std::move(limit)); +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_multiplicity_reducer.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_multiplicity_reducer.cpp new file mode 100644 index 0000000000..77d1aa3919 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_multiplicity_reducer.cpp @@ -0,0 +1,14 @@ +#include "planner/operator/logical_multiplcity_reducer.h" +#include "planner/planner.h" + +namespace lbug { +namespace planner { + +void Planner::appendMultiplicityReducer(LogicalPlan& plan) { + auto multiplicityReducer = make_shared(plan.getLastOperator()); + multiplicityReducer->computeFactorizedSchema(); + plan.setLastOperator(std::move(multiplicityReducer)); +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_order_by.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_order_by.cpp new file mode 100644 index 0000000000..0ca37ab657 --- 
/dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_order_by.cpp @@ -0,0 +1,19 @@ +#include "planner/operator/logical_order_by.h" +#include "planner/planner.h" + +using namespace lbug::binder; + +namespace lbug { +namespace planner { + +void Planner::appendOrderBy(const expression_vector& expressions, + const std::vector& isAscOrders, LogicalPlan& plan) { + auto orderBy = make_shared(expressions, isAscOrders, plan.getLastOperator()); + appendFlattens(orderBy->getGroupsPosToFlatten(), plan); + orderBy->setChild(0, plan.getLastOperator()); + orderBy->computeFactorizedSchema(); + plan.setLastOperator(std::move(orderBy)); +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_projection.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_projection.cpp new file mode 100644 index 0000000000..96919e75d0 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_projection.cpp @@ -0,0 +1,38 @@ +#include "binder/expression_visitor.h" +#include "planner/operator/factorization/flatten_resolver.h" +#include "planner/operator/logical_projection.h" +#include "planner/planner.h" + +using namespace lbug::binder; + +namespace lbug { +namespace planner { + +void Planner::appendProjection(const expression_vector& expressionsToProject, LogicalPlan& plan) { + for (auto& expression : expressionsToProject) { + planSubqueryIfNecessary(expression, plan); + } + bool hasRandomFunction = false; + for (auto& expr : expressionsToProject) { + if (ExpressionVisitor::isRandom(*expr)) { + hasRandomFunction = true; + } + } + if (hasRandomFunction) { + // Fall back to tuple-at-a-time evaluation. + appendMultiplicityReducer(plan); + appendFlattens(plan.getSchema()->getGroupsPosInScope(), plan); + } else { + for (auto& expression : expressionsToProject) { + auto groupsPosToFlatten = + FlattenAllButOne::getGroupsPosToFlatten(expression, *plan.getSchema()); + appendFlattens(groupsPosToFlatten, plan); + } + } + auto projection = make_shared(expressionsToProject, plan.getLastOperator()); + projection->computeFactorizedSchema(); + plan.setLastOperator(std::move(projection)); +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_scan_node_table.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_scan_node_table.cpp new file mode 100644 index 0000000000..72167dee03 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_scan_node_table.cpp @@ -0,0 +1,33 @@ +#include "binder/expression/property_expression.h" +#include "planner/operator/scan/logical_scan_node_table.h" +#include "planner/planner.h" + +using namespace lbug::common; +using namespace lbug::binder; + +namespace lbug { +namespace planner { + +static expression_vector removeInternalIDProperty(const expression_vector& expressions) { + expression_vector result; + for (auto expr : expressions) { + if (expr->constCast().isInternalID()) { + continue; + } + result.push_back(expr); + } + return result; +} + +void Planner::appendScanNodeTable(std::shared_ptr nodeID, + std::vector tableIDs, const expression_vector& properties, LogicalPlan& plan) { + auto propertiesToScan_ = removeInternalIDProperty(properties); + auto scan = make_shared(std::move(nodeID), std::move(tableIDs), + propertiesToScan_); + scan->computeFactorizedSchema(); + scan->setCardinality(cardinalityEstimator.estimateScanNode(*scan)); + plan.setLastOperator(std::move(scan)); +} + +} // namespace planner +} // 
namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_set.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_set.cpp new file mode 100644 index 0000000000..c1de9e5b1a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_set.cpp @@ -0,0 +1,22 @@ +#include "binder/query/updating_clause/bound_set_info.h" +#include "planner/operator/persistent/logical_set.h" +#include "planner/planner.h" + +using namespace lbug::binder; + +namespace lbug { +namespace planner { + +void Planner::appendSetProperty(const std::vector& infos, LogicalPlan& plan) { + auto set = std::make_shared(copyVector(infos), plan.getLastOperator()); + for (auto i = 0u; i < set->getInfos().size(); ++i) { + auto groupsPos = set->getGroupsPosToFlatten(i); + appendFlattens(groupsPos, plan); + set->setChild(0, plan.getLastOperator()); + } + set->computeFactorizedSchema(); + plan.setLastOperator(std::move(set)); +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_simple.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_simple.cpp new file mode 100644 index 0000000000..60fca72419 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_simple.cpp @@ -0,0 +1,170 @@ +#include "binder/bound_attach_database.h" +#include "binder/bound_create_macro.h" +#include "binder/bound_detach_database.h" +#include "binder/bound_explain.h" +#include "binder/bound_extension_statement.h" +#include "binder/bound_standalone_call.h" +#include "binder/bound_standalone_call_function.h" +#include "binder/bound_transaction_statement.h" +#include "binder/bound_use_database.h" +#include "binder/ddl/bound_alter.h" +#include "binder/ddl/bound_create_sequence.h" +#include "binder/ddl/bound_create_table.h" +#include "binder/ddl/bound_create_type.h" +#include "binder/ddl/bound_drop.h" +#include "extension/planner_extension.h" +#include "planner/operator/ddl/logical_alter.h" +#include "planner/operator/ddl/logical_create_sequence.h" +#include "planner/operator/ddl/logical_create_table.h" +#include "planner/operator/ddl/logical_create_type.h" +#include "planner/operator/ddl/logical_drop.h" +#include "planner/operator/logical_create_macro.h" +#include "planner/operator/logical_explain.h" +#include "planner/operator/logical_noop.h" +#include "planner/operator/logical_standalone_call.h" +#include "planner/operator/logical_table_function_call.h" +#include "planner/operator/logical_transaction.h" +#include "planner/operator/simple/logical_attach_database.h" +#include "planner/operator/simple/logical_detach_database.h" +#include "planner/operator/simple/logical_extension.h" +#include "planner/operator/simple/logical_use_database.h" +#include "planner/planner.h" + +using namespace lbug::binder; +using namespace lbug::common; + +namespace lbug { +namespace planner { + +static LogicalPlan getSimplePlan(std::shared_ptr op) { + LogicalPlan plan; + op->computeFactorizedSchema(); + plan.setLastOperator(std::move(op)); + return plan; +} + +LogicalPlan Planner::planCreateTable(const BoundStatement& statement) { + auto& createTable = statement.constCast(); + auto& info = createTable.getInfo(); + // If it is a CREATE NODE TABLE AS, then copy as well + if (createTable.hasCopyInfo()) { + std::vector> children; + switch (info.type) { + case catalog::CatalogEntryType::NODE_TABLE_ENTRY: { + children.push_back(planCopyNodeFrom(&createTable.getCopyInfo()).getLastOperator()); + } break; + case catalog::CatalogEntryType::REL_GROUP_ENTRY: 
{ + children.push_back(planCopyRelFrom(&createTable.getCopyInfo()).getLastOperator()); + } break; + default: { + KU_UNREACHABLE; + } + } + auto create = std::make_shared(info.copy()); + children.push_back(std::move(create)); + auto noop = std::make_shared(children.size() - 1, children); + return getSimplePlan(std::move(noop)); + } + auto op = std::make_shared(info.copy()); + return getSimplePlan(std::move(op)); +} + +LogicalPlan Planner::planCreateType(const BoundStatement& statement) { + auto& createType = statement.constCast(); + auto op = + std::make_shared(createType.getName(), createType.getType().copy()); + return getSimplePlan(std::move(op)); +} + +LogicalPlan Planner::planCreateSequence(const BoundStatement& statement) { + auto& createSequence = statement.constCast(); + auto& info = createSequence.getInfo(); + auto op = std::make_shared(info.copy()); + return getSimplePlan(std::move(op)); +} + +LogicalPlan Planner::planDrop(const BoundStatement& statement) { + auto& dropTable = statement.constCast(); + auto op = std::make_shared(dropTable.getDropInfo()); + return getSimplePlan(std::move(op)); +} + +LogicalPlan Planner::planAlter(const BoundStatement& statement) { + auto& alter = statement.constCast(); + auto op = std::make_shared(alter.getInfo().copy()); + return getSimplePlan(std::move(op)); +} + +LogicalPlan Planner::planStandaloneCall(const BoundStatement& statement) { + auto& standaloneCallClause = statement.constCast(); + auto op = std::make_shared(standaloneCallClause.getOption(), + standaloneCallClause.getOptionValue()); + return getSimplePlan(std::move(op)); +} + +LogicalPlan Planner::planStandaloneCallFunction(const BoundStatement& statement) { + auto& standaloneCallFunctionClause = statement.constCast(); + auto op = + std::make_shared(standaloneCallFunctionClause.getTableFunction(), + standaloneCallFunctionClause.getBindData()->copy()); + return getSimplePlan(std::move(op)); +} + +LogicalPlan Planner::planExplain(const BoundStatement& statement) { + auto& explain = statement.constCast(); + auto statementToExplain = explain.getStatementToExplain(); + auto planToExplain = planStatement(*statementToExplain); + auto op = std::make_shared(planToExplain.getLastOperator(), + explain.getExplainType(), statementToExplain->getStatementResult()->getColumns()); + return getSimplePlan(std::move(op)); +} + +LogicalPlan Planner::planCreateMacro(const BoundStatement& statement) { + auto& createMacro = statement.constCast(); + auto op = + std::make_shared(createMacro.getMacroName(), createMacro.getMacro()); + return getSimplePlan(std::move(op)); +} + +LogicalPlan Planner::planTransaction(const BoundStatement& statement) { + auto& transactionStatement = statement.constCast(); + auto op = std::make_shared(transactionStatement.getTransactionAction()); + return getSimplePlan(std::move(op)); +} + +LogicalPlan Planner::planExtension(const BoundStatement& statement) { + auto& extensionStatement = statement.constCast(); + auto op = std::make_shared(extensionStatement.getAuxInfo()); + return getSimplePlan(std::move(op)); +} + +LogicalPlan Planner::planAttachDatabase(const BoundStatement& statement) { + auto& boundAttachDatabase = statement.constCast(); + auto op = std::make_shared(boundAttachDatabase.getAttachInfo()); + return getSimplePlan(std::move(op)); +} + +LogicalPlan Planner::planDetachDatabase(const BoundStatement& statement) { + auto& boundDetachDatabase = statement.constCast(); + auto op = std::make_shared(boundDetachDatabase.getDBName()); + return getSimplePlan(std::move(op)); +} + 
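+// Note: each plan* helper in this file follows the same shape: build one logical operator from
+// the bound statement and hand it to getSimplePlan, which computes the factorized schema and
+// installs the operator as the plan root. Illustrative sketch (template argument spelled out
+// for clarity, e.g. for USE DATABASE):
+//   auto op = std::make_shared<LogicalUseDatabase>(boundUseDatabase.getDBName());
+//   return getSimplePlan(std::move(op));
+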
+LogicalPlan Planner::planUseDatabase(const BoundStatement& statement) { + auto& boundUseDatabase = statement.constCast(); + auto op = std::make_shared(boundUseDatabase.getDBName()); + return getSimplePlan(std::move(op)); +} + +LogicalPlan Planner::planExtensionClause(const BoundStatement& statement) { + for (auto& plannerExtension : plannerExtensions) { + auto op = plannerExtension->plan(statement); + if (op) { + return getSimplePlan(op); + } + } + KU_UNREACHABLE; +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_table_function_call.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_table_function_call.cpp new file mode 100644 index 0000000000..b0688801fa --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_table_function_call.cpp @@ -0,0 +1,33 @@ +#include "binder/bound_table_scan_info.h" +#include "binder/query/reading_clause/bound_table_function_call.h" +#include "planner/operator/logical_table_function_call.h" +#include "planner/planner.h" + +using namespace lbug::binder; + +namespace lbug { +namespace planner { + +void Planner::appendTableFunctionCall(const BoundTableScanInfo& info, LogicalPlan& plan) { + auto call = std::make_shared(info.func, info.bindData->copy()); + call->computeFactorizedSchema(); + plan.setLastOperator(std::move(call)); +} + +std::shared_ptr Planner::getTableFunctionCall(const BoundTableScanInfo& info) { + auto call = std::make_shared(info.func, info.bindData->copy()); + call->computeFactorizedSchema(); + return call; +} + +std::shared_ptr Planner::getTableFunctionCall( + const BoundReadingClause& readingClause) { + auto& call = readingClause.constCast(); + auto op = + std::make_shared(call.getTableFunc(), call.getBindData()->copy()); + op->computeFactorizedSchema(); + return op; +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_unwind.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_unwind.cpp new file mode 100644 index 0000000000..53404d2858 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/append_unwind.cpp @@ -0,0 +1,22 @@ +#include "binder/query/reading_clause/bound_unwind_clause.h" +#include "planner/operator/logical_unwind.h" +#include "planner/planner.h" + +using namespace lbug::binder; +using namespace lbug::common; + +namespace lbug { +namespace planner { + +void Planner::appendUnwind(const BoundReadingClause& readingClause, LogicalPlan& plan) { + auto& unwindClause = ku_dynamic_cast(readingClause); + auto unwind = make_shared(unwindClause.getInExpr(), unwindClause.getOutExpr(), + unwindClause.getIDExpr(), plan.getLastOperator()); + appendFlattens(unwind->getGroupsPosToFlatten(), plan); + unwind->setChild(0, plan.getLastOperator()); + unwind->computeFactorizedSchema(); + plan.setLastOperator(unwind); +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/plan_copy.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/plan_copy.cpp new file mode 100644 index 0000000000..75405d674f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/plan_copy.cpp @@ -0,0 +1,141 @@ +#include "binder/copy/bound_copy_from.h" +#include "binder/copy/bound_copy_to.h" +#include "catalog/catalog.h" +#include "catalog/catalog_entry/rel_group_catalog_entry.h" +#include "planner/operator/logical_partitioner.h" +#include "planner/operator/persistent/logical_copy_from.h" +#include 
"planner/operator/persistent/logical_copy_to.h" +#include "planner/operator/scan/logical_index_look_up.h" +#include "planner/planner.h" +#include "transaction/transaction.h" + +using namespace lbug::binder; +using namespace lbug::storage; +using namespace lbug::catalog; +using namespace lbug::common; +using namespace lbug::function; + +namespace lbug { +namespace planner { + +static void appendIndexScan(const ExtraBoundCopyRelInfo& extraInfo, LogicalPlan& plan) { + auto indexScan = + std::make_shared(extraInfo.infos, plan.getLastOperator()); + indexScan->computeFactorizedSchema(); + plan.setLastOperator(std::move(indexScan)); +} + +static void appendPartitioner(const BoundCopyFromInfo& copyFromInfo, LogicalPlan& plan, + const std::vector& directions) { + LogicalPartitionerInfo info(copyFromInfo.offset); + for (auto& direction : directions) { + info.partitioningInfos.push_back( + LogicalPartitioningInfo(RelDirectionUtils::relDirectionToKeyIdx(direction))); + } + auto partitioner = std::make_shared(std::move(info), copyFromInfo.copy(), + plan.getLastOperator()); + partitioner->computeFactorizedSchema(); + plan.setLastOperator(std::move(partitioner)); +} + +static void appendCopyFrom(const BoundCopyFromInfo& info, LogicalPlan& plan) { + auto op = make_shared(info.copy(), plan.getLastOperator()); + op->computeFactorizedSchema(); + plan.setLastOperator(std::move(op)); +} + +LogicalPlan Planner::planCopyFrom(const BoundStatement& statement) { + auto& copyFrom = statement.constCast(); + auto copyFromInfo = copyFrom.getInfo(); + switch (copyFromInfo->tableType) { + case TableType::NODE: { + return planCopyNodeFrom(copyFromInfo); + } + case TableType::REL: { + return planCopyRelFrom(copyFromInfo); + } + default: + KU_UNREACHABLE; + } +} + +LogicalPlan Planner::planCopyNodeFrom(const BoundCopyFromInfo* info) { + auto plan = LogicalPlan(); + switch (info->source->type) { + case ScanSourceType::FILE: + case ScanSourceType::OBJECT: { + auto& scanSource = info->source->constCast(); + appendTableFunctionCall(scanSource.info, plan); + } break; + case ScanSourceType::QUERY: { + auto& querySource = info->source->constCast(); + plan = planQuery(*querySource.statement); + if (plan.getSchema()->getNumGroups() > 1) { + // Copy operator assumes all input are in the same data chunk. If this is not the case, + // we first materialize input in flat form into a factorized table. + appendAccumulate(AccumulateType::REGULAR, plan.getSchema()->getExpressionsInScope(), + nullptr /* mark */, plan); + } + } break; + default: + KU_UNREACHABLE; + } + appendCopyFrom(*info, plan); + return plan; +} + +LogicalPlan Planner::planCopyRelFrom(const BoundCopyFromInfo* info) { + auto plan = LogicalPlan(); + switch (info->source->type) { + case ScanSourceType::FILE: + case ScanSourceType::OBJECT: { + auto& fileSource = info->source->constCast(); + appendTableFunctionCall(fileSource.info, plan); + } break; + case ScanSourceType::QUERY: { + auto& querySource = info->source->constCast(); + plan = planQuery(*querySource.statement); + if (plan.getSchema()->getNumGroups() == 1 && !plan.getSchema()->getGroup(0)->isFlat()) { + break; + } + // Copy operator assumes all input are in the same data chunk. If this is not the case, + // we first materialize input in flat form into a factorized table. 
+ appendAccumulate(AccumulateType::REGULAR, plan.getSchema()->getExpressionsInScope(), + nullptr /* mark */, plan); + } break; + default: + KU_UNREACHABLE; + } + auto& extraInfo = info->extraInfo->constCast(); + // If the table entry doesn't exist, assume both directions + std::vector directions = {RelDataDirection::FWD, RelDataDirection::BWD}; + auto catalog = Catalog::Get(*clientContext); + auto transaction = transaction::Transaction::Get(*clientContext); + if (catalog->containsTable(transaction, info->tableName)) { + const auto& relGroupEntry = catalog->getTableCatalogEntry(transaction, info->tableName) + ->constCast(); + directions = relGroupEntry.getRelDataDirections(); + } + appendIndexScan(extraInfo, plan); + appendPartitioner(*info, plan, directions); + appendCopyFrom(*info, plan); + return plan; +} + +LogicalPlan Planner::planCopyTo(const BoundStatement& statement) { + auto& boundCopyTo = statement.constCast(); + auto regularQuery = boundCopyTo.getRegularQuery(); + std::vector columnNames; + for (auto& column : regularQuery->getStatementResult()->getColumns()) { + columnNames.push_back(column->toString()); + } + KU_ASSERT(regularQuery->getStatementType() == StatementType::QUERY); + auto plan = planStatement(*regularQuery); + auto copyTo = make_shared(boundCopyTo.getBindData()->copy(), + boundCopyTo.getExportFunc(), plan.getLastOperator()); + plan.setLastOperator(std::move(copyTo)); + return plan; +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/plan_join_order.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/plan_join_order.cpp new file mode 100644 index 0000000000..b582b8c668 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/plan_join_order.cpp @@ -0,0 +1,627 @@ +#include + +#include "binder/expression_visitor.h" +#include "common/enums/join_type.h" +#include "common/enums/rel_direction.h" +#include "common/utils.h" +#include "planner/join_order/cost_model.h" +#include "planner/join_order/join_plan_solver.h" +#include "planner/join_order/join_tree_constructor.h" +#include "planner/operator/scan/logical_scan_node_table.h" +#include "planner/planner.h" + +using namespace lbug::binder; +using namespace lbug::common; + +namespace lbug { +namespace planner { + +LogicalPlan Planner::planQueryGraphCollectionInNewContext( + const QueryGraphCollection& queryGraphCollection, const QueryGraphPlanningInfo& info) { + auto prevContext = enterNewContext(); + auto plan = planQueryGraphCollection(queryGraphCollection, info); + exitContext(std::move(prevContext)); + return plan; +} + +static int32_t getConnectedQueryGraphIdx(const QueryGraphCollection& queryGraphCollection, + const QueryGraphPlanningInfo& info) { + for (auto i = 0u; i < queryGraphCollection.getNumQueryGraphs(); ++i) { + auto queryGraph = queryGraphCollection.getQueryGraph(i); + for (auto& queryNode : queryGraph->getQueryNodes()) { + if (info.containsCorrExpr(*queryNode->getInternalID())) { + return i; + } + } + } + return -1; +} + +LogicalPlan Planner::planQueryGraphCollection(const QueryGraphCollection& queryGraphCollection, + const QueryGraphPlanningInfo& info) { + KU_ASSERT(queryGraphCollection.getNumQueryGraphs() > 0); + auto& corrExprs = info.corrExprs; + int32_t queryGraphIdxToPlanExpressionsScan = -1; + if (info.subqueryType == SubqueryPlanningType::CORRELATED) { + // Pick a query graph to plan ExpressionsScan. If -1 is returned, we fall back to cross + // product. 
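+        // Only a query graph that contains at least one correlated node can absorb the
+        // ExpressionsScan through a join; getConnectedQueryGraphIdx returns the first such
+        // graph, or -1 when none of the graphs references a correlated expression.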
+ queryGraphIdxToPlanExpressionsScan = getConnectedQueryGraphIdx(queryGraphCollection, info); + } + std::unordered_set evaluatedPredicatesIndices; + std::vector planPerQueryGraph; + for (auto i = 0u; i < queryGraphCollection.getNumQueryGraphs(); ++i) { + auto queryGraph = queryGraphCollection.getQueryGraph(i); + // Extract predicates for current query graph + std::unordered_set predicateToEvaluateIndices; + for (auto j = 0u; j < info.predicates.size(); ++j) { + if (info.predicates[j]->expressionType == ExpressionType::LITERAL) { + continue; + } + if (evaluatedPredicatesIndices.contains(j)) { + continue; + } + if (queryGraph->canProjectExpression(info.predicates[j])) { + predicateToEvaluateIndices.insert(j); + } + } + evaluatedPredicatesIndices.insert(predicateToEvaluateIndices.begin(), + predicateToEvaluateIndices.end()); + expression_vector predicatesToEvaluate; + for (auto idx : predicateToEvaluateIndices) { + predicatesToEvaluate.push_back(info.predicates[idx]); + } + LogicalPlan plan; + auto newInfo = info; + newInfo.predicates = predicatesToEvaluate; + switch (info.subqueryType) { + case SubqueryPlanningType::NONE: + case SubqueryPlanningType::UNNEST_CORRELATED: { + plan = planQueryGraph(*queryGraph, newInfo); + } break; + case SubqueryPlanningType::CORRELATED: { + if (i == (uint32_t)queryGraphIdxToPlanExpressionsScan) { + // Plan ExpressionsScan with current query graph. + plan = planQueryGraph(*queryGraph, newInfo); + } else { + // Plan current query graph as an isolated query graph. + newInfo.subqueryType = SubqueryPlanningType::NONE; + plan = planQueryGraph(*queryGraph, newInfo); + } + } break; + default: + KU_UNREACHABLE; + } + planPerQueryGraph.push_back(std::move(plan)); + } + // Fail to plan ExpressionsScan with any query graph. Plan it independently and fall back to + // cross product. 
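+    // When no query graph could absorb the ExpressionsScan, scan the correlated expressions as
+    // their own one-operator plan; the distinct keeps one row per correlated value so the cross
+    // product below does not re-multiply duplicated outer tuples.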
+ if (info.subqueryType == SubqueryPlanningType::CORRELATED && + queryGraphIdxToPlanExpressionsScan == -1) { + auto plan = LogicalPlan(); + appendExpressionsScan(corrExprs, plan); + appendDistinct(corrExprs, plan); + planPerQueryGraph.push_back(std::move(plan)); + } + // Take cross products + auto plan = planPerQueryGraph[0].copy(); + for (auto i = 1u; i < planPerQueryGraph.size(); ++i) { + appendCrossProduct(plan, planPerQueryGraph[i], plan); + } + // Apply remaining predicates + expression_vector remainingPredicates; + for (auto i = 0u; i < info.predicates.size(); ++i) { + if (!evaluatedPredicatesIndices.contains(i)) { + remainingPredicates.push_back(info.predicates[i]); + } + } + for (auto& predicate : remainingPredicates) { + appendFilter(predicate, plan); + } + return plan; +} + +LogicalPlan Planner::planQueryGraph(const QueryGraph& queryGraph, + const QueryGraphPlanningInfo& info) { + context.init(&queryGraph, info.predicates); + cardinalityEstimator.init(queryGraph); + if (info.hint != nullptr) { + auto constructor = + JoinTreeConstructor(queryGraph, propertyExprCollection, info.predicates, info); + auto joinTree = constructor.construct(info.hint); + auto plan = JoinPlanSolver(this).solve(joinTree); + return plan.copy(); + } + planBaseTableScans(info); + context.currentLevel++; + while (context.currentLevel < context.maxLevel) { + planLevel(context.currentLevel++); + } + + auto& plans = context.getPlans(context.getFullyMatchedSubqueryGraph()); + auto bestIdx = 0; + for (auto i = 1u; i < plans.size(); ++i) { + if (plans[i].getCost() < plans[bestIdx].getCost()) { + bestIdx = i; + } + } + auto bestPlan = plans[bestIdx].copy(); + if (queryGraph.isEmpty()) { + appendEmptyResult(bestPlan); + } + return bestPlan; +} + +void Planner::planLevel(uint32_t level) { + KU_ASSERT(level > 1); + if (level > MAX_LEVEL_TO_PLAN_EXACTLY) { + planLevelApproximately(level); + } else { + planLevelExactly(level); + } +} + +void Planner::planLevelExactly(uint32_t level) { + auto maxLeftLevel = floor(level / 2.0); + for (auto leftLevel = 1u; leftLevel <= maxLeftLevel; ++leftLevel) { + auto rightLevel = level - leftLevel; + if (leftLevel > 1) { // wcoj requires at least 2 rels + planWCOJoin(leftLevel, rightLevel); + } + planInnerJoin(leftLevel, rightLevel); + } +} + +void Planner::planLevelApproximately(uint32_t level) { + planInnerJoin(1, level - 1); +} + +void Planner::planBaseTableScans(const QueryGraphPlanningInfo& info) { + auto queryGraph = context.getQueryGraph(); + switch (info.subqueryType) { + case SubqueryPlanningType::NONE: { + for (auto nodePos = 0u; nodePos < queryGraph->getNumQueryNodes(); ++nodePos) { + planNodeScan(nodePos); + } + } break; + case SubqueryPlanningType::UNNEST_CORRELATED: { + for (auto nodePos = 0u; nodePos < queryGraph->getNumQueryNodes(); ++nodePos) { + auto queryNode = queryGraph->getQueryNode(nodePos); + if (info.containsCorrExpr(*queryNode->getInternalID())) { + // NodeID will be a join condition with outer plan so very likely we will apply a + // semi mask later in the optimization stage. So we can assume the cardinality will + // not exceed outer plan cardinality. + cardinalityEstimator.rectifyCardinality(*queryNode->getInternalID(), + info.corrExprsCard); + // In un-nested subquery, e.g. MATCH (a) OPTIONAL MATCH (a)-[e1]->(b), the inner + // query ("(a)-[e1]->(b)") needs to scan a, which is already scanned in the outer + // query (a). 
To avoid scanning storage twice, we keep track of node table "a" and + // make sure when planning inner query, we only scan internal ID of "a". + planNodeIDScan(nodePos); + } else { + planNodeScan(nodePos); + } + } + } break; + case SubqueryPlanningType::CORRELATED: { + for (auto nodePos = 0u; nodePos < queryGraph->getNumQueryNodes(); ++nodePos) { + auto queryNode = queryGraph->getQueryNode(nodePos); + if (info.containsCorrExpr(*queryNode->getInternalID())) { + continue; + } + planNodeScan(nodePos); + } + planCorrelatedExpressionsScan(info); + } break; + default: + KU_UNREACHABLE; + } + for (auto relPos = 0u; relPos < queryGraph->getNumQueryRels(); ++relPos) { + planRelScan(relPos); + } +} + +void Planner::planCorrelatedExpressionsScan(const QueryGraphPlanningInfo& info) { + auto queryGraph = context.getQueryGraph(); + auto newSubgraph = context.getEmptySubqueryGraph(); + auto& corrExprs = info.corrExprs; + for (auto nodePos = 0u; nodePos < queryGraph->getNumQueryNodes(); ++nodePos) { + auto queryNode = queryGraph->getQueryNode(nodePos); + if (info.containsCorrExpr(*queryNode->getInternalID())) { + newSubgraph.addQueryNode(nodePos); + } + } + auto plan = LogicalPlan(); + appendExpressionsScan(corrExprs, plan); + plan.getLastOperator()->setCardinality(info.corrExprsCard); + auto predicates = getNewlyMatchedExprs(context.getEmptySubqueryGraph(), newSubgraph, + context.getWhereExpressions()); + appendFilters(predicates, plan); + appendDistinct(corrExprs, plan); + context.addPlan(newSubgraph, std::move(plan)); +} + +void Planner::planNodeScan(uint32_t nodePos) { + auto node = context.queryGraph->getQueryNode(nodePos); + auto newSubgraph = context.getEmptySubqueryGraph(); + newSubgraph.addQueryNode(nodePos); + auto plan = LogicalPlan(); + auto properties = getProperties(*node); + appendScanNodeTable(node->getInternalID(), node->getTableIDs(), properties, plan); + auto predicates = getNewlyMatchedExprs(context.getEmptySubqueryGraph(), newSubgraph, + context.getWhereExpressions()); + appendFilters(predicates, plan); + context.addPlan(newSubgraph, std::move(plan)); +} + +void Planner::planNodeIDScan(uint32_t nodePos) { + auto node = context.queryGraph->getQueryNode(nodePos); + auto newSubgraph = context.getEmptySubqueryGraph(); + newSubgraph.addQueryNode(nodePos); + auto plan = LogicalPlan(); + appendScanNodeTable(node->getInternalID(), node->getTableIDs(), {}, plan); + context.addPlan(newSubgraph, std::move(plan)); +} + +static std::pair, std::shared_ptr> +getBoundAndNbrNodes(const RelExpression& rel, ExtendDirection direction) { + KU_ASSERT(direction != ExtendDirection::BOTH); + auto boundNode = direction == ExtendDirection::FWD ? rel.getSrcNode() : rel.getDstNode(); + auto dstNode = direction == ExtendDirection::FWD ? 
rel.getDstNode() : rel.getSrcNode(); + return make_pair(boundNode, dstNode); +} + +static ExtendDirection getExtendDirection(const binder::RelExpression& relExpression, + const binder::NodeExpression& boundNode) { + if (relExpression.getDirectionType() == binder::RelDirectionType::BOTH) { + KU_ASSERT(relExpression.getExtendDirections().size() == common::NUM_REL_DIRECTIONS); + return ExtendDirection::BOTH; + } + if (relExpression.getSrcNodeName() == boundNode.getUniqueName()) { + return ExtendDirection::FWD; + } else { + return ExtendDirection::BWD; + } +} + +void Planner::planRelScan(uint32_t relPos) { + const auto rel = context.queryGraph->getQueryRel(relPos); + auto newSubgraph = context.getEmptySubqueryGraph(); + newSubgraph.addQueryRel(relPos); + const auto predicates = getNewlyMatchedExprs(context.getEmptySubqueryGraph(), newSubgraph, + context.getWhereExpressions()); + for (const auto direction : rel->getExtendDirections()) { + auto plan = LogicalPlan(); + auto [boundNode, nbrNode] = getBoundAndNbrNodes(*rel, direction); + const auto extendDirection = getExtendDirection(*rel, *boundNode); + appendScanNodeTable(boundNode->getInternalID(), boundNode->getTableIDs(), {}, plan); + appendExtend(boundNode, nbrNode, rel, extendDirection, getProperties(*rel), plan); + appendFilters(predicates, plan); + context.addPlan(newSubgraph, std::move(plan)); + } +} + +void Planner::appendExtend(std::shared_ptr boundNode, + std::shared_ptr nbrNode, std::shared_ptr rel, + ExtendDirection direction, const binder::expression_vector& properties, LogicalPlan& plan) { + switch (rel->getRelType()) { + case QueryRelType::NON_RECURSIVE: { + auto extendFromSource = *boundNode == *rel->getSrcNode(); + appendNonRecursiveExtend(boundNode, nbrNode, rel, direction, extendFromSource, properties, + plan); + } break; + case QueryRelType::VARIABLE_LENGTH_WALK: + case QueryRelType::VARIABLE_LENGTH_TRAIL: + case QueryRelType::VARIABLE_LENGTH_ACYCLIC: + case QueryRelType::SHORTEST: + case QueryRelType::ALL_SHORTEST: + case QueryRelType::WEIGHTED_SHORTEST: + case QueryRelType::ALL_WEIGHTED_SHORTEST: { + appendRecursiveExtend(boundNode, nbrNode, rel, direction, plan); + } break; + default: + KU_UNREACHABLE; + } +} + +static std::unordered_map>> +populateIntersectRelCandidates(const QueryGraph& queryGraph, const SubqueryGraph& subgraph) { + std::unordered_map>> + intersectNodePosToRelsMap; + for (auto relPos : subgraph.getRelNbrPositions()) { + auto rel = queryGraph.getQueryRel(relPos); + if (!queryGraph.containsQueryNode(rel->getSrcNodeName()) || + !queryGraph.containsQueryNode(rel->getDstNodeName())) { + continue; + } + auto srcNodePos = queryGraph.getQueryNodeIdx(rel->getSrcNodeName()); + auto dstNodePos = queryGraph.getQueryNodeIdx(rel->getDstNodeName()); + auto isSrcConnected = subgraph.queryNodesSelector[srcNodePos]; + auto isDstConnected = subgraph.queryNodesSelector[dstNodePos]; + // Closing rel should be handled with inner join. + if (isSrcConnected && isDstConnected) { + continue; + } + auto intersectNodePos = isSrcConnected ? 
dstNodePos : srcNodePos; + if (!intersectNodePosToRelsMap.contains(intersectNodePos)) { + intersectNodePosToRelsMap.insert( + {intersectNodePos, std::vector>{}}); + } + intersectNodePosToRelsMap.at(intersectNodePos).push_back(rel); + } + return intersectNodePosToRelsMap; +} + +void Planner::planWCOJoin(uint32_t leftLevel, uint32_t rightLevel) { + KU_ASSERT(leftLevel <= rightLevel); + auto queryGraph = context.getQueryGraph(); + for (auto& rightSubgraph : context.subPlansTable->getSubqueryGraphs(rightLevel)) { + auto candidates = populateIntersectRelCandidates(*queryGraph, rightSubgraph); + for (auto& [intersectNodePos, rels] : candidates) { + if (rels.size() == leftLevel) { + auto intersectNode = queryGraph->getQueryNode(intersectNodePos); + planWCOJoin(rightSubgraph, rels, intersectNode); + } + } + } +} + +static LogicalOperator* getSequentialScan(LogicalOperator* op) { + switch (op->getOperatorType()) { + case LogicalOperatorType::FLATTEN: + case LogicalOperatorType::FILTER: + case LogicalOperatorType::EXTEND: + case LogicalOperatorType::PROJECTION: { // operators we directly search through + return getSequentialScan(op->getChild(0).get()); + } + case LogicalOperatorType::SCAN_NODE_TABLE: { + return op; + } + default: + return nullptr; + } +} + +// Check whether given node ID has sequential guarantee on the plan. +static bool isNodeSequentialOnPlan(const LogicalPlan& plan, const NodeExpression& node) { + const auto seqScan = getSequentialScan(plan.getLastOperator().get()); + if (seqScan == nullptr) { + return false; + } + const auto sequentialScan = ku_dynamic_cast(seqScan); + return sequentialScan->getNodeID()->getUniqueName() == node.getInternalID()->getUniqueName(); +} + +// As a heuristic for wcoj, we always pick rel scan that starts from the bound node. +static LogicalPlan getWCOJBuildPlanForRel(const std::vector& candidatePlans, + const NodeExpression& boundNode) { + for (auto& candidatePlan : candidatePlans) { + if (isNodeSequentialOnPlan(candidatePlan, boundNode)) { + return candidatePlan.copy(); + } + } + return LogicalPlan(); +} + +void Planner::planWCOJoin(const SubqueryGraph& subgraph, + const std::vector>& rels, + const std::shared_ptr& intersectNode) { + auto newSubgraph = subgraph; + std::vector prevSubgraphs; + prevSubgraphs.push_back(subgraph); + expression_vector boundNodeIDs; + std::vector relPlans; + for (auto& rel : rels) { + auto boundNode = rel->getSrcNodeName() == intersectNode->getUniqueName() ? + rel->getDstNode() : + rel->getSrcNode(); + + // stop if the rel pattern's supported rel directions don't contain the current direction + const auto extendDirection = getExtendDirection(*rel, *boundNode); + if (extendDirection != ExtendDirection::BOTH && + !containsValue(rel->getExtendDirections(), extendDirection)) { + return; + } + + boundNodeIDs.push_back(boundNode->getInternalID()); + auto relPos = context.getQueryGraph()->getQueryRelIdx(rel->getUniqueName()); + auto prevSubgraph = context.getEmptySubqueryGraph(); + prevSubgraph.addQueryRel(relPos); + prevSubgraphs.push_back(subgraph); + newSubgraph.addQueryRel(relPos); + // fetch build plans for rel + auto relSubgraph = context.getEmptySubqueryGraph(); + relSubgraph.addQueryRel(relPos); + KU_ASSERT(context.subPlansTable->containSubgraphPlans(relSubgraph)); + auto& relPlanCandidates = context.subPlansTable->getSubgraphPlans(relSubgraph); + auto relPlan = getWCOJBuildPlanForRel(relPlanCandidates, *boundNode); + if (relPlan.isEmpty()) { // Cannot find a suitable rel plan. 
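+            // A WCOJ build side must scan sequentially from its bound node (see
+            // getWCOJBuildPlanForRel); if no candidate plan satisfies that, give up on the
+            // whole WCOJ attempt for this subgraph.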
+ return; + } + relPlans.push_back(std::move(relPlan)); + } + auto predicates = + getNewlyMatchedExprs(prevSubgraphs, newSubgraph, context.getWhereExpressions()); + for (auto& leftPlan : context.getPlans(subgraph)) { + // Disable WCOJ if intersect node is in the scope of probe plan. This happens in the case + // like, MATCH (a)-[e1]->(b), (b)-[e2]->(a), (a)-[e3]->(b). + // When we perform edge-at-a-time enumeration, at some point we will in the state of e1 as + // probe side and e2, e3 as build side and we attempt to apply WCOJ. However, the right + // approach is to build e1, e2, e3 and intersect on a common node (either a or b). + // I tend to disable WCOJ for this case for now. The proper fix should be move to + // node-at-a-time enumeration and re-enable WCOJ. + // TODO(Xiyang): Fixme according to the description above. + if (leftPlan.getSchema()->isExpressionInScope(*intersectNode->getInternalID())) { + continue; + } + auto leftPlanCopy = leftPlan.copy(); + std::vector rightPlansCopy; + rightPlansCopy.reserve(relPlans.size()); + for (auto& relPlan : relPlans) { + rightPlansCopy.push_back(relPlan.copy()); + } + appendIntersect(intersectNode->getInternalID(), boundNodeIDs, leftPlanCopy, rightPlansCopy); + for (auto& predicate : predicates) { + appendFilter(predicate, leftPlanCopy); + } + context.subPlansTable->addPlan(newSubgraph, std::move(leftPlanCopy)); + } +} + +// E.g. Query graph (a)-[e1]->(b), (b)-[e2]->(a) and join between (a)-[e1] and [e2] +// Since (b) is not in the scope of any join subgraph, join node is analyzed as (a) only, However, +// [e1] and [e2] are also connected at (b) implicitly. So actual join nodes should be (a) and (b). +// We prune such join. +// Note that this does not mean we may lose good plan. An equivalent join can be found between [e2] +// and (a)-[e1]->(b). +static bool needPruneImplicitJoins(const SubqueryGraph& leftSubgraph, + const SubqueryGraph& rightSubgraph, uint32_t numJoinNodes) { + auto leftNodePositions = leftSubgraph.getNodePositionsIgnoringNodeSelector(); + auto rightNodePositions = rightSubgraph.getNodePositionsIgnoringNodeSelector(); + auto intersectionSize = 0u; + for (auto& pos : leftNodePositions) { + if (rightNodePositions.contains(pos)) { + intersectionSize++; + } + } + return intersectionSize != numJoinNodes; +} + +void Planner::planInnerJoin(uint32_t leftLevel, uint32_t rightLevel) { + KU_ASSERT(leftLevel <= rightLevel); + for (auto& rightSubgraph : context.subPlansTable->getSubqueryGraphs(rightLevel)) { + for (auto& nbrSubgraph : rightSubgraph.getNbrSubgraphs(leftLevel)) { + // E.g. MATCH (a)->(b) MATCH (b)->(c) + // Since we merge query graph for multipart query, during enumeration for the second + // match, the query graph is (a)->(b)->(c). However, we omit plans corresponding to the + // first match (i.e. (a)->(b)). 
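+            // A neighbour subgraph with no stored plans is therefore expected here and is
+            // simply skipped rather than treated as a planner error.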
+ if (!context.containPlans(nbrSubgraph)) { + continue; + } + auto joinNodePositions = rightSubgraph.getConnectedNodePos(nbrSubgraph); + auto joinNodes = context.queryGraph->getQueryNodes(joinNodePositions); + if (needPruneImplicitJoins(nbrSubgraph, rightSubgraph, joinNodes.size())) { + continue; + } + // If index nested loop (INL) join is possible, we prune hash join plans + if (tryPlanINLJoin(rightSubgraph, nbrSubgraph, joinNodes)) { + continue; + } + planInnerHashJoin(rightSubgraph, nbrSubgraph, joinNodes, leftLevel != rightLevel); + } + } +} + +bool Planner::tryPlanINLJoin(const SubqueryGraph& subgraph, const SubqueryGraph& otherSubgraph, + const std::vector>& joinNodes) { + if (joinNodes.size() > 1) { + return false; + } + if (!subgraph.isSingleRel() && !otherSubgraph.isSingleRel()) { + return false; + } + if (subgraph.isSingleRel()) { // Always put single rel subgraph to right. + return tryPlanINLJoin(otherSubgraph, subgraph, joinNodes); + } + auto relPos = UINT32_MAX; + for (auto i = 0u; i < context.queryGraph->getNumQueryRels(); ++i) { + if (otherSubgraph.queryRelsSelector[i]) { + relPos = i; + } + } + KU_ASSERT(relPos != UINT32_MAX); + auto rel = context.queryGraph->getQueryRel(relPos); + const auto& boundNode = joinNodes[0]; + auto nbrNode = + boundNode->getUniqueName() == rel->getSrcNodeName() ? rel->getDstNode() : rel->getSrcNode(); + auto extendDirection = getExtendDirection(*rel, *boundNode); + if (extendDirection != common::ExtendDirection::BOTH && + !common::containsValue(rel->getExtendDirections(), extendDirection)) { + return false; + } + auto newSubgraph = subgraph; + newSubgraph.addQueryRel(relPos); + auto predicates = getNewlyMatchedExprs(subgraph, newSubgraph, context.getWhereExpressions()); + bool hasAppliedINLJoin = false; + for (auto& prevPlan : context.getPlans(subgraph)) { + if (isNodeSequentialOnPlan(prevPlan, *boundNode)) { + auto plan = prevPlan.copy(); + appendExtend(boundNode, nbrNode, rel, extendDirection, getProperties(*rel), plan); + appendFilters(predicates, plan); + context.addPlan(newSubgraph, std::move(plan)); + hasAppliedINLJoin = true; + } + } + return hasAppliedINLJoin; +} + +void Planner::planInnerHashJoin(const SubqueryGraph& subgraph, const SubqueryGraph& otherSubgraph, + const std::vector>& joinNodes, bool flipPlan) { + auto newSubgraph = subgraph; + newSubgraph.addSubqueryGraph(otherSubgraph); + auto maxCost = context.subPlansTable->getMaxCost(newSubgraph); + expression_vector joinNodeIDs; + for (auto& joinNode : joinNodes) { + joinNodeIDs.push_back(joinNode->getInternalID()); + } + auto predicates = + getNewlyMatchedExprs(subgraph, otherSubgraph, newSubgraph, context.getWhereExpressions()); + for (auto& leftPlan : context.getPlans(subgraph)) { + for (auto& rightPlan : context.getPlans(otherSubgraph)) { + if (CostModel::computeHashJoinCost(joinNodeIDs, leftPlan, rightPlan) < maxCost) { + auto leftPlanProbeCopy = leftPlan.copy(); + auto rightPlanBuildCopy = rightPlan.copy(); + appendHashJoin(joinNodeIDs, JoinType::INNER, leftPlanProbeCopy, rightPlanBuildCopy, + leftPlanProbeCopy); + appendFilters(predicates, leftPlanProbeCopy); + context.addPlan(newSubgraph, std::move(leftPlanProbeCopy)); + } + // flip build and probe side to get another HashJoin plan + if (flipPlan && + CostModel::computeHashJoinCost(joinNodeIDs, rightPlan, leftPlan) < maxCost) { + auto leftPlanBuildCopy = leftPlan.copy(); + auto rightPlanProbeCopy = rightPlan.copy(); + appendHashJoin(joinNodeIDs, JoinType::INNER, rightPlanProbeCopy, leftPlanBuildCopy, + rightPlanProbeCopy); + 
appendFilters(predicates, rightPlanProbeCopy); + context.addPlan(newSubgraph, std::move(rightPlanProbeCopy)); + } + } + } +} + +static bool isExpressionNewlyMatched(const std::vector& prevs, + const SubqueryGraph& newSubgraph, const std::shared_ptr& expression) { + auto collector = DependentVarNameCollector(); + collector.visit(expression); + auto variables = collector.getVarNames(); + for (auto& prev : prevs) { + if (prev.containAllVariables(variables)) { + return false; // matched in prev subgraph + } + } + return newSubgraph.containAllVariables(variables); +} + +expression_vector Planner::getNewlyMatchedExprs(const std::vector& prevs, + const SubqueryGraph& new_, const expression_vector& exprs) { + expression_vector result; + for (auto& expr : exprs) { + if (isExpressionNewlyMatched(prevs, new_, expr)) { + result.push_back(expr); + } + } + return result; +} + +expression_vector Planner::getNewlyMatchedExprs(const SubqueryGraph& prev, + const SubqueryGraph& new_, const expression_vector& exprs) { + return getNewlyMatchedExprs(std::vector{prev}, new_, exprs); +} + +expression_vector Planner::getNewlyMatchedExprs(const SubqueryGraph& leftPrev, + const SubqueryGraph& rightPrev, const SubqueryGraph& new_, const expression_vector& exprs) { + return getNewlyMatchedExprs(std::vector{leftPrev, rightPrev}, new_, exprs); +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/plan_node_scan.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/plan_node_scan.cpp new file mode 100644 index 0000000000..74609714c3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/plan_node_scan.cpp @@ -0,0 +1,19 @@ +#include "planner/planner.h" + +using namespace lbug::binder; + +namespace lbug { +namespace planner { + +LogicalPlan Planner::getNodePropertyScanPlan(const NodeExpression& node) { + auto properties = getProperties(node); + auto scanPlan = LogicalPlan(); + if (properties.empty()) { + return scanPlan; + } + appendScanNodeTable(node.getInternalID(), node.getTableIDs(), properties, scanPlan); + return scanPlan; +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/plan_node_semi_mask.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/plan_node_semi_mask.cpp new file mode 100644 index 0000000000..f614d970f3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/plan_node_semi_mask.cpp @@ -0,0 +1,48 @@ +#include "binder/expression/expression_util.h" +#include "binder/expression/property_expression.h" +#include "binder/expression_visitor.h" +#include "planner/operator/logical_dummy_sink.h" +#include "planner/operator/sip/logical_semi_masker.h" +#include "planner/planner.h" + +using namespace lbug::binder; +using namespace lbug::common; + +namespace lbug { +namespace planner { + +void Planner::appendNodeSemiMask(SemiMaskTargetType targetType, const NodeExpression& node, + LogicalPlan& plan) { + auto semiMasker = std::make_shared(SemiMaskKeyType::NODE, targetType, + node.getInternalID(), node.getTableIDs(), plan.getLastOperator()); + semiMasker->computeFactorizedSchema(); + plan.setLastOperator(semiMasker); +} + +void Planner::appendDummySink(LogicalPlan& plan) { + auto dummySink = std::make_shared(plan.getLastOperator()); + dummySink->computeFactorizedSchema(); + plan.setLastOperator(std::move(dummySink)); +} + +// Create a plan with a root semi masker for given node and node predicate. 
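+// The plan scans only the properties referenced by the predicate, filters on the predicate,
+// builds a semi mask over the surviving node IDs, and ends in a dummy sink so it can run as a
+// standalone pipeline whose effect is populating the mask.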
+LogicalPlan Planner::getNodeSemiMaskPlan(SemiMaskTargetType targetType, const NodeExpression& node, + std::shared_ptr nodePredicate) { + auto plan = LogicalPlan(); + auto prevCollection = enterNewPropertyExprCollection(); + auto collector = PropertyExprCollector(); + collector.visit(nodePredicate); + for (auto& expr : ExpressionUtil::removeDuplication(collector.getPropertyExprs())) { + auto& propExpr = expr->constCast(); + propertyExprCollection.addProperties(propExpr.getVariableName(), expr); + } + appendScanNodeTable(node.getInternalID(), node.getTableIDs(), getProperties(node), plan); + appendFilter(nodePredicate, plan); + exitPropertyExprCollection(std::move(prevCollection)); + appendNodeSemiMask(targetType, node, plan); + appendDummySink(plan); + return plan; +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/plan_port_db.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/plan_port_db.cpp new file mode 100644 index 0000000000..38fa79ee23 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/plan_port_db.cpp @@ -0,0 +1,76 @@ +#include "binder/bound_export_database.h" +#include "binder/bound_import_database.h" +#include "catalog/catalog.h" +#include "common/file_system/virtual_file_system.h" +#include "common/string_utils.h" +#include "function/built_in_function_utils.h" +#include "planner/operator/persistent/logical_copy_to.h" +#include "planner/operator/simple/logical_export_db.h" +#include "planner/operator/simple/logical_import_db.h" +#include "planner/planner.h" +#include "transaction/transaction.h" + +using namespace lbug::binder; +using namespace lbug::storage; +using namespace lbug::catalog; +using namespace lbug::common; +using namespace lbug::transaction; + +namespace lbug { +namespace planner { + +std::vector> Planner::planExportTableData( + const BoundStatement& statement) { + std::vector> logicalOperators; + auto& boundExportDatabase = statement.constCast(); + auto fileTypeStr = FileTypeUtils::toString(boundExportDatabase.getFileType()); + StringUtils::toLower(fileTypeStr); + // TODO(Ziyi): Shouldn't these be done in Binder? 
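+    // The export function is resolved from the catalog by name, built as "COPY_" plus the file
+    // type string (e.g. COPY_CSV), and its bind() is invoked once per exported table with the
+    // output path joined from the export directory and that table's file name.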
+ std::string name = + stringFormat("COPY_{}", FileTypeUtils::toString(boundExportDatabase.getFileType())); + auto entry = + Catalog::Get(*clientContext)->getFunctionEntry(Transaction::Get(*clientContext), name); + auto func = function::BuiltInFunctionsUtils::matchFunction(name, + entry->ptrCast()); + KU_ASSERT(func != nullptr); + auto exportFunc = *func->constPtrCast(); + for (auto& exportTableData : *boundExportDatabase.getExportData()) { + auto regularQuery = exportTableData.getRegularQuery(); + KU_ASSERT(regularQuery->getStatementType() == StatementType::QUERY); + auto tablePlan = planStatement(*regularQuery); + auto path = VirtualFileSystem::GetUnsafe(*clientContext) + ->joinPath(boundExportDatabase.getFilePath(), exportTableData.fileName); + function::ExportFuncBindInput bindInput{exportTableData.columnNames, std::move(path), + boundExportDatabase.getExportOptions()}; + auto copyTo = std::make_shared(exportFunc.bind(bindInput), exportFunc, + tablePlan.getLastOperator()); + logicalOperators.push_back(std::move(copyTo)); + } + return logicalOperators; +} + +LogicalPlan Planner::planExportDatabase(const BoundStatement& statement) { + auto& boundExportDatabase = statement.constCast(); + auto logicalOperators = std::vector>(); + auto plan = LogicalPlan(); + if (!boundExportDatabase.exportSchemaOnly()) { + logicalOperators = planExportTableData(statement); + } + auto exportDatabase = + std::make_shared(boundExportDatabase.getBoundFileInfo()->copy(), + std::move(logicalOperators), boundExportDatabase.exportSchemaOnly()); + plan.setLastOperator(std::move(exportDatabase)); + return plan; +} + +LogicalPlan Planner::planImportDatabase(const BoundStatement& statement) { + auto& boundImportDatabase = statement.constCast(); + auto plan = LogicalPlan(); + auto importDatabase = std::make_shared(boundImportDatabase.getQuery(), + boundImportDatabase.getIndexQuery()); + plan.setLastOperator(std::move(importDatabase)); + return plan; +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/plan_projection.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/plan_projection.cpp new file mode 100644 index 0000000000..7ee2c8a6b6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/plan_projection.cpp @@ -0,0 +1,85 @@ +#include "binder/expression_visitor.h" +#include "binder/query/return_with_clause/bound_projection_body.h" +#include "planner/planner.h" + +using namespace lbug::binder; + +namespace lbug { +namespace planner { + +void Planner::planProjectionBody(const BoundProjectionBody* projectionBody, LogicalPlan& plan) { + auto expressionsToProject = projectionBody->getProjectionExpressions(); + if (expressionsToProject.empty()) { + return; + } + if (plan.isEmpty()) { // e.g. RETURN 1, COUNT(2) + appendDummyScan(plan); + } + auto expressionsToAggregate = projectionBody->getAggregateExpressions(); + auto expressionsToGroupBy = projectionBody->getGroupByExpressions(); + if (!expressionsToAggregate.empty()) { + planAggregate(expressionsToAggregate, expressionsToGroupBy, plan); + } + // We might order by an expression that is not in projection list, so after order by we + // always need to append a projection. + // If distinct is presented in projection list, we need to first append project to evaluate the + // list, then take the distinct. + // Order by should always be the last operator (except for skip/limit) because other operators + // will break the order. 
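+    // planOrderBy (below) first projects the return expressions plus any ORDER BY keys that are
+    // not already in the projection list; the appendProjection that follows it then narrows the
+    // output back to the declared return columns.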
+ if (projectionBody->isDistinct() && projectionBody->hasOrderByExpressions()) { + appendProjection(expressionsToProject, plan); + appendDistinct(expressionsToProject, plan); + planOrderBy(expressionsToProject, projectionBody->getOrderByExpressions(), + projectionBody->getSortingOrders(), plan); + appendProjection(expressionsToProject, plan); + } else if (projectionBody->isDistinct()) { + appendProjection(expressionsToProject, plan); + appendDistinct(expressionsToProject, plan); + } else if (projectionBody->hasOrderByExpressions()) { + planOrderBy(expressionsToProject, projectionBody->getOrderByExpressions(), + projectionBody->getSortingOrders(), plan); + appendProjection(expressionsToProject, plan); + } else { + appendProjection(expressionsToProject, plan); + } + if (projectionBody->hasSkipOrLimit()) { + appendMultiplicityReducer(plan); + appendLimit(projectionBody->getSkipNumber(), projectionBody->getLimitNumber(), plan); + } +} + +void Planner::planAggregate(const expression_vector& expressionsToAggregate, + const expression_vector& expressionsToGroupBy, LogicalPlan& plan) { + KU_ASSERT(!expressionsToAggregate.empty()); + expression_vector expressionsToProject; + for (auto& expressionToAggregate : expressionsToAggregate) { + if (ExpressionChildrenCollector::collectChildren(*expressionToAggregate) + .empty()) { // skip COUNT(*) + continue; + } + expressionsToProject.push_back(expressionToAggregate->getChild(0)); + } + for (auto& expressionToGroupBy : expressionsToGroupBy) { + expressionsToProject.push_back(expressionToGroupBy); + } + appendProjection(expressionsToProject, plan); + appendAggregate(expressionsToGroupBy, expressionsToAggregate, plan); +} + +void Planner::planOrderBy(const binder::expression_vector& expressionsToProject, + const binder::expression_vector& expressionsToOrderBy, const std::vector& isAscOrders, + LogicalPlan& plan) { + auto expressionsToProjectBeforeOrderBy = expressionsToProject; + auto expressionsToProjectSet = + expression_set{expressionsToProject.begin(), expressionsToProject.end()}; + for (auto& expression : expressionsToOrderBy) { + if (!expressionsToProjectSet.contains(expression)) { + expressionsToProjectBeforeOrderBy.push_back(expression); + } + } + appendProjection(expressionsToProjectBeforeOrderBy, plan); + appendOrderBy(expressionsToOrderBy, isAscOrders, plan); +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/plan_read.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/plan_read.cpp new file mode 100644 index 0000000000..9f793ab5a6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/plan_read.cpp @@ -0,0 +1,143 @@ +#include "binder/expression_visitor.h" +#include "binder/query/reading_clause/bound_load_from.h" +#include "binder/query/reading_clause/bound_match_clause.h" +#include "binder/query/reading_clause/bound_table_function_call.h" +#include "planner/planner.h" + +using namespace lbug::binder; +using namespace lbug::common; + +namespace lbug { +namespace planner { + +void Planner::planReadingClause(const BoundReadingClause& readingClause, LogicalPlan& plan) { + switch (readingClause.getClauseType()) { + case ClauseType::MATCH: { + planMatchClause(readingClause, plan); + } break; + case ClauseType::UNWIND: { + planUnwindClause(readingClause, plan); + } break; + case ClauseType::TABLE_FUNCTION_CALL: { + planTableFunctionCall(readingClause, plan); + } break; + case ClauseType::LOAD_FROM: { + planLoadFrom(readingClause, plan); + } break; + default: + KU_UNREACHABLE; 
+ } +} + +void Planner::planMatchClause(const BoundReadingClause& readingClause, LogicalPlan& plan) { + auto& boundMatchClause = readingClause.constCast(); + auto queryGraphCollection = boundMatchClause.getQueryGraphCollection(); + auto predicates = boundMatchClause.getConjunctivePredicates(); + switch (boundMatchClause.getMatchClauseType()) { + case MatchClauseType::MATCH: { + if (plan.isEmpty()) { + auto info = QueryGraphPlanningInfo(); + info.predicates = predicates; + info.hint = boundMatchClause.getHint(); + plan = planQueryGraphCollection(*queryGraphCollection, info); + } else { + planRegularMatch(*queryGraphCollection, predicates, plan, boundMatchClause.getHint()); + } + } break; + case MatchClauseType::OPTIONAL_MATCH: { + planOptionalMatch(*queryGraphCollection, predicates, plan, boundMatchClause.getHint()); + } break; + default: + KU_UNREACHABLE; + } +} + +void Planner::planUnwindClause(const BoundReadingClause& boundReadingClause, LogicalPlan& plan) { + if (plan.isEmpty()) { // UNWIND [1, 2, 3, 4] AS x RETURN x + appendDummyScan(plan); + } + appendUnwind(boundReadingClause, plan); +} + +class PredicatesDependencyAnalyzer { +public: + explicit PredicatesDependencyAnalyzer(const expression_vector& outputColumns) { + for (auto& column : outputColumns) { + columnNameSet.insert(column->getUniqueName()); + } + } + + void analyze(const expression_vector& predicates) { + predicatesDependsOnlyOnOutputColumns.clear(); + predicatesWithOtherDependencies.clear(); + for (auto& predicate : predicates) { + if (hasExternalDependency(predicate)) { + predicatesWithOtherDependencies.push_back(predicate); + } else { + predicatesDependsOnlyOnOutputColumns.push_back(predicate); + } + } + } + +private: + bool hasExternalDependency(const std::shared_ptr& expression) { + auto collector = DependentVarNameCollector(); + collector.visit(expression); + for (auto& name : collector.getVarNames()) { + if (!columnNameSet.contains(name)) { + return true; + } + } + return false; + } + +public: + expression_vector predicatesDependsOnlyOnOutputColumns; + expression_vector predicatesWithOtherDependencies; + +private: + std::unordered_set columnNameSet; +}; + +void Planner::planTableFunctionCall(const BoundReadingClause& readingClause, LogicalPlan& plan) { + auto& boundCall = readingClause.constCast(); + auto analyzer = PredicatesDependencyAnalyzer(boundCall.getBindData()->columns); + analyzer.analyze(boundCall.getConjunctivePredicates()); + KU_ASSERT(boundCall.getTableFunc().getLogicalPlanFunc); + boundCall.getTableFunc().getLogicalPlanFunc(this, readingClause, + analyzer.predicatesDependsOnlyOnOutputColumns, plan); + if (!analyzer.predicatesWithOtherDependencies.empty()) { + appendFilters(analyzer.predicatesWithOtherDependencies, plan); + } +} + +void Planner::planLoadFrom(const BoundReadingClause& readingClause, LogicalPlan& plan) { + auto& loadFrom = readingClause.constCast(); + auto analyzer = PredicatesDependencyAnalyzer(loadFrom.getInfo()->bindData->columns); + analyzer.analyze(loadFrom.getConjunctivePredicates()); + auto op = getTableFunctionCall(*loadFrom.getInfo()); + planReadOp(std::move(op), analyzer.predicatesDependsOnlyOnOutputColumns, plan); + if (!analyzer.predicatesWithOtherDependencies.empty()) { + appendFilters(analyzer.predicatesWithOtherDependencies, plan); + } +} + +void Planner::planReadOp(std::shared_ptr op, const expression_vector& predicates, + LogicalPlan& plan) { + if (!plan.isEmpty()) { + auto tmpPlan = LogicalPlan(); + tmpPlan.setLastOperator(std::move(op)); + if (!predicates.empty()) { + 
appendFilters(predicates, tmpPlan); + } + appendCrossProduct(plan, tmpPlan, plan); + } else { + plan.setLastOperator(std::move(op)); + if (!predicates.empty()) { + appendFilters(predicates, plan); + } + } +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/plan_single_query.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/plan_single_query.cpp new file mode 100644 index 0000000000..a8112350cd --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/plan_single_query.cpp @@ -0,0 +1,48 @@ +#include "binder/expression/property_expression.h" +#include "binder/visitor/property_collector.h" +#include "planner/planner.h" + +using namespace lbug::binder; + +namespace lbug { +namespace planner { + +// Note: we cannot append ResultCollector for plans enumerated for single query before there could +// be a UNION on top which requires further flatten. So we delay ResultCollector appending to +// enumerate regular query level. +LogicalPlan Planner::planSingleQuery(const NormalizedSingleQuery& singleQuery) { + auto propertyCollector = PropertyCollector(); + propertyCollector.visitSingleQuery(singleQuery); + auto properties = propertyCollector.getProperties(); + for (auto& expr : propertyCollector.getProperties()) { + auto& property = expr->constCast(); + propertyExprCollection.addProperties(property.getVariableName(), expr); + } + context.resetState(); + auto plan = LogicalPlan(); + for (auto i = 0u; i < singleQuery.getNumQueryParts(); ++i) { + planQueryPart(*singleQuery.getQueryPart(i), plan); + } + return plan; +} + +void Planner::planQueryPart(const NormalizedQueryPart& queryPart, LogicalPlan& plan) { + // plan read + for (auto i = 0u; i < queryPart.getNumReadingClause(); i++) { + planReadingClause(*queryPart.getReadingClause(i), plan); + } + // plan update + for (auto i = 0u; i < queryPart.getNumUpdatingClause(); ++i) { + planUpdatingClause(*queryPart.getUpdatingClause(i), plan); + } + // plan projection + if (queryPart.hasProjectionBody()) { + planProjectionBody(queryPart.getProjectionBody(), plan); + if (queryPart.hasProjectionBodyPredicate()) { + appendFilter(queryPart.getProjectionBodyPredicate(), plan); + } + } +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/plan_subquery.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/plan_subquery.cpp new file mode 100644 index 0000000000..0c5c29689b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/plan_subquery.cpp @@ -0,0 +1,308 @@ +#include "binder/expression/expression_util.h" +#include "binder/expression/subquery_expression.h" +#include "binder/expression_visitor.h" +#include "planner/operator/factorization/flatten_resolver.h" +#include "planner/planner.h" + +using namespace lbug::binder; +using namespace lbug::common; + +namespace lbug { +namespace planner { + +static expression_vector getDependentExprs(std::shared_ptr expr, const Schema& schema) { + auto analyzer = GroupDependencyAnalyzer(true /* collectDependentExpr */, schema); + analyzer.visit(expr); + return analyzer.getDependentExprs(); +} + +expression_vector Planner::getCorrelatedExprs(const QueryGraphCollection& collection, + const expression_vector& predicates, Schema* outerSchema) { + expression_vector result; + for (auto& predicate : predicates) { + for (auto& expression : getDependentExprs(predicate, *outerSchema)) { + result.push_back(expression); + } + } + for (auto& node : collection.getQueryNodes()) { + if 
(outerSchema->isExpressionInScope(*node->getInternalID())) { + result.push_back(node->getInternalID()); + } + } + return ExpressionUtil::removeDuplication(result); +} + +class SubqueryPredicatePullUpAnalyzer { +public: + SubqueryPredicatePullUpAnalyzer(const Schema& schema, + const QueryGraphCollection& queryGraphCollection) + : schema{schema}, queryGraphCollection{queryGraphCollection} {} + + bool analyze(const expression_vector& predicates) { + expression_vector correlatedPredicates; + for (auto& predicate : predicates) { + if (getDependentExprs(predicate, schema).empty()) { + nonCorrelatedPredicates.push_back(predicate); + } else { + correlatedPredicates.push_back(predicate); + } + } + for (auto predicate : correlatedPredicates) { + auto [left, right] = analyze(predicate); + if (left == nullptr) { + return false; + } + joinConditions.emplace_back(left, right); + } + for (auto& node : queryGraphCollection.getQueryNodes()) { + if (schema.isExpressionInScope(*node->getInternalID())) { + joinConditions.emplace_back(node->getInternalID(), node->getInternalID()); + } + } + return true; + } + + expression_vector getNonCorrelatedPredicates() const { return nonCorrelatedPredicates; } + std::vector getJoinConditions() const { return joinConditions; } + + expression_vector getCorrelatedInternalIDs() const { + expression_vector exprs; + for (auto& node : queryGraphCollection.getQueryNodes()) { + if (schema.isExpressionInScope(*node->getInternalID())) { + exprs.push_back(node->getInternalID()); + } + } + return exprs; + } + +private: + expression_pair analyze(std::shared_ptr predicate) { + if (predicate->expressionType != common::ExpressionType::EQUALS) { + return {nullptr, nullptr}; + } + auto left = predicate->getChild(0); + auto right = predicate->getChild(1); + if (isUnnestableJoinCondition(*left, *right)) { + return {left, right}; + } + if (isUnnestableJoinCondition(*right, *left)) { + return {right, left}; + } + return {nullptr, nullptr}; + } + + bool isUnnestableJoinCondition(const Expression& left, const Expression& right) { + return right.expressionType == ExpressionType::PROPERTY && + schema.isExpressionInScope(left) && !schema.isExpressionInScope(right); + } + +private: + const Schema& schema; + const QueryGraphCollection& queryGraphCollection; + + expression_vector nonCorrelatedPredicates; + std::vector joinConditions; +}; + +void Planner::planOptionalMatch(const QueryGraphCollection& queryGraphCollection, + const expression_vector& predicates, LogicalPlan& leftPlan, + std::shared_ptr hint) { + planOptionalMatch(queryGraphCollection, predicates, nullptr /* mark */, leftPlan, + std::move(hint)); +} + +void Planner::planOptionalMatch(const QueryGraphCollection& queryGraphCollection, + const expression_vector& predicates, std::shared_ptr mark, LogicalPlan& leftPlan, + std::shared_ptr hint) { + expression_vector correlatedExprs; + if (!leftPlan.isEmpty()) { + correlatedExprs = + getCorrelatedExprs(queryGraphCollection, predicates, leftPlan.getSchema()); + } + auto info = QueryGraphPlanningInfo(); + info.hint = hint; + if (leftPlan.isEmpty()) { + // Optional match is the first clause, e.g. 
OPTIONAL MATCH RETURN * + info.predicates = predicates; + auto plan = planQueryGraphCollection(queryGraphCollection, info); + leftPlan.setLastOperator(plan.getLastOperator()); + appendOptionalAccumulate(mark, leftPlan); + return; + } + if (correlatedExprs.empty()) { + // Plan uncorrelated subquery (think of this as a CTE) + info.predicates = predicates; + auto rightPlan = planQueryGraphCollection(queryGraphCollection, info); + if (leftPlan.hasUpdate()) { + appendAccOptionalCrossProduct(mark, leftPlan, rightPlan, leftPlan); + } else { + appendOptionalCrossProduct(mark, leftPlan, rightPlan, leftPlan); + } + return; + } + // Plan correlated subquery + info.corrExprsCard = leftPlan.getCardinality(); + auto analyzer = SubqueryPredicatePullUpAnalyzer(*leftPlan.getSchema(), queryGraphCollection); + std::vector joinConditions; + LogicalPlan rightPlan; + if (analyzer.analyze(predicates)) { + // Unnest as left join + info.subqueryType = SubqueryPlanningType::UNNEST_CORRELATED; + info.corrExprs = analyzer.getCorrelatedInternalIDs(); + info.predicates = analyzer.getNonCorrelatedPredicates(); + rightPlan = planQueryGraphCollectionInNewContext(queryGraphCollection, info); + joinConditions = analyzer.getJoinConditions(); + } else { + // Unnest as expression scan + distinct & inner join + info.subqueryType = SubqueryPlanningType::CORRELATED; + info.corrExprs = correlatedExprs; + info.predicates = predicates; + for (auto& expr : correlatedExprs) { + joinConditions.emplace_back(expr, expr); + } + rightPlan = planQueryGraphCollectionInNewContext(queryGraphCollection, info); + appendAccumulate(correlatedExprs, leftPlan); + } + if (leftPlan.hasUpdate()) { + appendAccHashJoin(joinConditions, JoinType::LEFT, mark, leftPlan, rightPlan, leftPlan); + } else { + appendHashJoin(joinConditions, JoinType::LEFT, mark, leftPlan, rightPlan, leftPlan); + } +} + +void Planner::planRegularMatch(const QueryGraphCollection& queryGraphCollection, + const expression_vector& predicates, LogicalPlan& leftPlan, + std::shared_ptr hint) { + expression_vector predicatesToPushDown, predicatesToPullUp; + // E.g. MATCH (a) WITH COUNT(*) AS s MATCH (b) WHERE b.age > s + // "b.age > s" should be pulled up after both MATCH clauses are joined. + for (auto& predicate : predicates) { + if (getDependentExprs(predicate, *leftPlan.getSchema()).empty()) { + predicatesToPushDown.push_back(predicate); + } else { + predicatesToPullUp.push_back(predicate); + } + } + auto correlatedExprs = + getCorrelatedExprs(queryGraphCollection, predicatesToPushDown, leftPlan.getSchema()); + auto joinNodeIDs = + ExpressionUtil::getExpressionsWithDataType(correlatedExprs, LogicalTypeID::INTERNAL_ID); + auto info = QueryGraphPlanningInfo(); + info.predicates = predicatesToPushDown; + info.hint = hint; + if (joinNodeIDs.empty()) { + info.subqueryType = SubqueryPlanningType::NONE; + auto rightPlan = planQueryGraphCollectionInNewContext(queryGraphCollection, info); + if (leftPlan.hasUpdate()) { + appendCrossProduct(rightPlan, leftPlan, leftPlan); + } else { + appendCrossProduct(leftPlan, rightPlan, leftPlan); + } + } else { + // TODO(Xiyang): there is a question regarding if we want to plan as a correlated subquery + // Multi-part query is actually CTE and CTE can be considered as a subquery but does not + // scan from outer. 
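+        // Annotation (illustrative sketch): when the new MATCH shares node variables with
+        // what has already been planned, e.g.
+        //   MATCH (a)-[:Knows]->(b) WITH a MATCH (a)-[:Likes]->(c) RETURN c
+        // the shared internal node IDs collected above become the join keys; the new
+        // pattern is planned in a fresh context and joined back, with the operand order
+        // below swapped when the left plan already contains updates.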
+ info.subqueryType = SubqueryPlanningType::UNNEST_CORRELATED; + info.corrExprs = joinNodeIDs; + info.corrExprsCard = leftPlan.getCardinality(); + auto rightPlan = planQueryGraphCollectionInNewContext(queryGraphCollection, info); + if (leftPlan.hasUpdate()) { + appendHashJoin(joinNodeIDs, JoinType::INNER, rightPlan, leftPlan, leftPlan); + } else { + appendHashJoin(joinNodeIDs, JoinType::INNER, leftPlan, rightPlan, leftPlan); + } + } + for (auto& predicate : predicatesToPullUp) { + appendFilter(predicate, leftPlan); + } +} + +void Planner::planSubquery(const std::shared_ptr& expression, LogicalPlan& outerPlan) { + KU_ASSERT(expression->expressionType == ExpressionType::SUBQUERY); + auto subquery = expression->ptrCast(); + auto correlatedExprs = getDependentExprs(expression, *outerPlan.getSchema()); + auto predicates = subquery->getPredicatesSplitOnAnd(); + LogicalPlan innerPlan; + auto info = QueryGraphPlanningInfo(); + info.hint = subquery->getHint(); + if (correlatedExprs.empty()) { + // Plan uncorrelated subquery + info.subqueryType = SubqueryPlanningType::NONE; + info.predicates = predicates; + innerPlan = + planQueryGraphCollectionInNewContext(*subquery->getQueryGraphCollection(), info); + expression_vector emptyHashKeys; + auto projectExprs = expression_vector{subquery->getProjectionExpr()}; + switch (subquery->getSubqueryType()) { + case common::SubqueryType::EXISTS: { + auto aggregates = expression_vector{subquery->getCountStarExpr()}; + appendAggregate(emptyHashKeys, aggregates, innerPlan); + appendProjection(projectExprs, innerPlan); + } break; + case common::SubqueryType::COUNT: { + appendAggregate(emptyHashKeys, projectExprs, innerPlan); + } break; + default: + KU_UNREACHABLE; + } + appendCrossProduct(outerPlan, innerPlan, outerPlan); + return; + } + // Plan correlated subquery + info.corrExprsCard = outerPlan.getCardinality(); + auto analyzer = SubqueryPredicatePullUpAnalyzer(*outerPlan.getSchema(), + *subquery->getQueryGraphCollection()); + std::vector joinConditions; + if (analyzer.analyze(predicates)) { + // Unnest as inner join + info.subqueryType = SubqueryPlanningType::UNNEST_CORRELATED; + info.corrExprs = analyzer.getCorrelatedInternalIDs(); + info.predicates = analyzer.getNonCorrelatedPredicates(); + innerPlan = + planQueryGraphCollectionInNewContext(*subquery->getQueryGraphCollection(), info); + joinConditions = analyzer.getJoinConditions(); + } else { + // Unnest as expression scan + distinct & inner join + info.subqueryType = SubqueryPlanningType::CORRELATED; + info.corrExprs = correlatedExprs; + info.predicates = predicates; + for (auto& expr : correlatedExprs) { + joinConditions.emplace_back(expr, expr); + } + innerPlan = + planQueryGraphCollectionInNewContext(*subquery->getQueryGraphCollection(), info); + appendAccumulate(correlatedExprs, outerPlan); + } + switch (subquery->getSubqueryType()) { + case common::SubqueryType::EXISTS: { + appendMarkJoin(joinConditions, expression, outerPlan, innerPlan, outerPlan); + } break; + case common::SubqueryType::COUNT: { + expression_vector hashKeys; + for (auto& joinCondition : joinConditions) { + hashKeys.push_back(joinCondition.second); + } + appendAggregate(hashKeys, expression_vector{subquery->getProjectionExpr()}, innerPlan); + appendHashJoin(joinConditions, common::JoinType::COUNT, nullptr, outerPlan, innerPlan, + outerPlan); + } break; + default: + KU_UNREACHABLE; + } +} + +void Planner::planSubqueryIfNecessary(std::shared_ptr expression, LogicalPlan& plan) { + auto collector = SubqueryExprCollector(); + 
collector.visit(expression); + if (collector.hasSubquery()) { + for (auto& expr : collector.getSubqueryExprs()) { + if (plan.getSchema()->isExpressionInScope(*expr)) { + continue; + } + planSubquery(expr, plan); + } + } +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/plan_update.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/plan_update.cpp new file mode 100644 index 0000000000..eea378b6c6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/plan/plan_update.cpp @@ -0,0 +1,137 @@ +#include "binder/query/updating_clause/bound_delete_clause.h" +#include "binder/query/updating_clause/bound_insert_clause.h" +#include "binder/query/updating_clause/bound_merge_clause.h" +#include "binder/query/updating_clause/bound_set_clause.h" +#include "planner/operator/persistent/logical_merge.h" +#include "planner/planner.h" + +using namespace lbug::common; +using namespace lbug::binder; + +namespace lbug { +namespace planner { + +void Planner::planUpdatingClause(const BoundUpdatingClause& updatingClause, LogicalPlan& plan) { + switch (updatingClause.getClauseType()) { + case ClauseType::INSERT: { + planInsertClause(updatingClause, plan); + return; + } + case ClauseType::MERGE: { + planMergeClause(updatingClause, plan); + return; + } + case ClauseType::SET: { + planSetClause(updatingClause, plan); + return; + } + case ClauseType::DELETE_: { + planDeleteClause(updatingClause, plan); + return; + } + default: + KU_UNREACHABLE; + } +} + +void Planner::planInsertClause(const BoundUpdatingClause& updatingClause, LogicalPlan& plan) { + auto& insertClause = updatingClause.constCast(); + if (plan.isEmpty()) { // E.g. CREATE (a:Person {age:20}) + appendDummyScan(plan); + } else { + appendAccumulate(plan); + } + if (insertClause.hasNodeInfo()) { + appendInsertNode(insertClause.getNodeInfos(), plan); + } + if (insertClause.hasRelInfo()) { + appendInsertRel(insertClause.getRelInfos(), plan); + } +} + +void Planner::planMergeClause(const BoundUpdatingClause& updatingClause, LogicalPlan& plan) { + auto& mergeClause = updatingClause.constCast(); + expression_vector predicates; + if (mergeClause.hasPredicate()) { + predicates = mergeClause.getPredicate()->splitOnAND(); + } + // Collect merge hash keys. See LogicalMerge for details. 
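+    // Annotation (illustrative): for a clause such as
+    //   MATCH (a:Person) MERGE (a)-[:LivesIn]->(c:City {name: a.city})
+    // the column value a.city is kept as a key (literals and parameters are skipped
+    // below) and a's internal ID, already in scope from the MATCH, is appended too; the
+    // existence mark filled by the optional match below lets LogicalMerge tell rows that
+    // matched (ON MATCH path) from rows that still need the ON CREATE path.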
+ expression_vector keys; + for (auto& expr : mergeClause.getColumnDataExprs()) { + if (expr->expressionType == ExpressionType::LITERAL || + expr->expressionType == ExpressionType::PARAMETER) { + continue; + } + keys.push_back(expr); + } + if (!plan.isEmpty()) { + for (auto& node : mergeClause.getQueryGraphCollection()->getQueryNodes()) { + if (plan.getSchema()->isExpressionInScope(*node->getInternalID())) { + keys.push_back(node->getInternalID()); + } + } + } + auto existenceMark = mergeClause.getExistenceMark(); + planOptionalMatch(*mergeClause.getQueryGraphCollection(), predicates, existenceMark, plan, + nullptr /* hint */); + auto merge = std::make_shared(existenceMark, keys, plan.getLastOperator()); + if (mergeClause.hasInsertNodeInfo()) { + for (auto& info : mergeClause.getInsertNodeInfos()) { + merge->addInsertNodeInfo(createLogicalInsertInfo(info)->copy()); + } + } + if (mergeClause.hasInsertRelInfo()) { + for (auto& info : mergeClause.getInsertRelInfos()) { + merge->addInsertRelInfo(createLogicalInsertInfo(info)->copy()); + } + } + if (mergeClause.hasOnCreateSetNodeInfo()) { + for (auto& info : mergeClause.getOnCreateSetNodeInfos()) { + merge->addOnCreateSetNodeInfo(info.copy()); + } + } + if (mergeClause.hasOnCreateSetRelInfo()) { + for (auto& info : mergeClause.getOnCreateSetRelInfos()) { + merge->addOnCreateSetRelInfo(info.copy()); + } + } + if (mergeClause.hasOnMatchSetNodeInfo()) { + for (auto& info : mergeClause.getOnMatchSetNodeInfos()) { + merge->addOnMatchSetNodeInfo(info.copy()); + } + } + if (mergeClause.hasOnMatchSetRelInfo()) { + for (auto& info : mergeClause.getOnMatchSetRelInfos()) { + merge->addOnMatchSetRelInfo(info.copy()); + } + } + appendFlattens(merge->getGroupsPosToFlatten(), plan); + merge->setChild(0, plan.getLastOperator()); + merge->computeFactorizedSchema(); + plan.setLastOperator(merge); +} + +void Planner::planSetClause(const BoundUpdatingClause& updatingClause, LogicalPlan& plan) { + appendAccumulate(plan); + auto& setClause = updatingClause.constCast(); + if (setClause.hasNodeInfo()) { + appendSetProperty(setClause.getNodeInfos(), plan); + } + if (setClause.hasRelInfo()) { + appendSetProperty(setClause.getRelInfos(), plan); + } +} + +void Planner::planDeleteClause(const BoundUpdatingClause& updatingClause, LogicalPlan& plan) { + appendAccumulate(plan); + auto& deleteClause = updatingClause.constCast(); + if (deleteClause.hasRelInfo()) { + appendDelete(deleteClause.getRelInfos(), plan); + } + if (deleteClause.hasNodeInfo()) { + appendDelete(deleteClause.getNodeInfos(), plan); + } +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/planner.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/planner.cpp new file mode 100644 index 0000000000..0f13e7c5f5 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/planner.cpp @@ -0,0 +1,129 @@ +#include "planner/planner.h" + +#include "main/client_context.h" +#include "main/database.h" + +using namespace lbug::binder; +using namespace lbug::catalog; +using namespace lbug::common; +using namespace lbug::storage; + +namespace lbug { +namespace planner { + +bool QueryGraphPlanningInfo::containsCorrExpr(const Expression& expr) const { + for (auto& corrExpr : corrExprs) { + if (*corrExpr == expr) { + return true; + } + } + return false; +} + +expression_vector PropertyExprCollection::getProperties(const Expression& pattern) const { + if (!patternNameToProperties.contains(pattern.getUniqueName())) { + return binder::expression_vector{}; + } + return 
patternNameToProperties.at(pattern.getUniqueName()); +} + +expression_vector PropertyExprCollection::getProperties() const { + expression_vector result; + for (auto& [_, exprs] : patternNameToProperties) { + for (auto& expr : exprs) { + result.push_back(expr); + } + } + return result; +} + +void PropertyExprCollection::addProperties(const std::string& patternName, + std::shared_ptr property) { + if (!patternNameToProperties.contains(patternName)) { + patternNameToProperties.insert({patternName, expression_vector{}}); + } + for (auto& p : patternNameToProperties.at(patternName)) { + if (*p == *property) { + return; + } + } + patternNameToProperties.at(patternName).push_back(property); +} + +void PropertyExprCollection::clear() { + patternNameToProperties.clear(); +} + +Planner::Planner(main::ClientContext* clientContext) + : clientContext{clientContext}, cardinalityEstimator{clientContext}, context{}, + plannerExtensions{clientContext->getDatabase()->getPlannerExtensions()} {} + +LogicalPlan Planner::planStatement(const BoundStatement& statement) { + switch (statement.getStatementType()) { + case StatementType::QUERY: { + return planQuery(statement); + } + case StatementType::CREATE_TABLE: { + return planCreateTable(statement); + } + case StatementType::CREATE_SEQUENCE: { + return planCreateSequence(statement); + } + case StatementType::CREATE_TYPE: { + return planCreateType(statement); + } + case StatementType::COPY_FROM: { + return planCopyFrom(statement); + } + case StatementType::COPY_TO: { + return planCopyTo(statement); + } + case StatementType::DROP: { + return planDrop(statement); + } + case StatementType::ALTER: { + return planAlter(statement); + } + case StatementType::STANDALONE_CALL: { + return planStandaloneCall(statement); + } + case StatementType::STANDALONE_CALL_FUNCTION: { + return planStandaloneCallFunction(statement); + } + case StatementType::EXPLAIN: { + return planExplain(statement); + } + case StatementType::CREATE_MACRO: { + return planCreateMacro(statement); + } + case StatementType::TRANSACTION: { + return planTransaction(statement); + } + case StatementType::EXTENSION: { + return planExtension(statement); + } + case StatementType::EXPORT_DATABASE: { + return planExportDatabase(statement); + } + case StatementType::IMPORT_DATABASE: { + return planImportDatabase(statement); + } + case StatementType::ATTACH_DATABASE: { + return planAttachDatabase(statement); + } + case StatementType::DETACH_DATABASE: { + return planDetachDatabase(statement); + } + case StatementType::USE_DATABASE: { + return planUseDatabase(statement); + } + case StatementType::EXTENSION_CLAUSE: { + return planExtensionClause(statement); + } + default: + KU_UNREACHABLE; + } +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/query_planner.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/query_planner.cpp new file mode 100644 index 0000000000..d149bdd065 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/query_planner.cpp @@ -0,0 +1,73 @@ +#include "binder/query/bound_regular_query.h" +#include "planner/operator/logical_union.h" +#include "planner/planner.h" + +using namespace lbug::binder; +using namespace lbug::common; + +namespace lbug { +namespace planner { + +LogicalPlan Planner::planQuery(const BoundStatement& boundStatement) { + auto& regularQuery = boundStatement.constCast(); + if (regularQuery.getNumSingleQueries() == 1) { + return planSingleQuery(*regularQuery.getSingleQuery(0)); + } + std::vector childrenPlans; + for (auto 
i = 0u; i < regularQuery.getNumSingleQueries(); i++) { + childrenPlans.push_back(planSingleQuery(*regularQuery.getSingleQuery(i))); + } + auto exprs = regularQuery.getStatementResult()->getColumns(); + return createUnionPlan(childrenPlans, exprs, regularQuery.getIsUnionAll(0)); +} + +LogicalPlan Planner::createUnionPlan(std::vector& childrenPlans, + const expression_vector& expressions, bool isUnionAll) { + KU_ASSERT(!childrenPlans.empty()); + auto plan = LogicalPlan(); + std::vector> children; + children.reserve(childrenPlans.size()); + for (auto& childPlan : childrenPlans) { + children.push_back(childPlan.getLastOperator()); + } + // we compute the schema based on first child + auto union_ = std::make_shared(expressions, std::move(children)); + for (auto i = 0u; i < childrenPlans.size(); ++i) { + appendFlattens(union_->getGroupsPosToFlatten(i), childrenPlans[i]); + union_->setChild(i, childrenPlans[i].getLastOperator()); + } + union_->computeFactorizedSchema(); + plan.setLastOperator(union_); + if (!isUnionAll && !expressions.empty()) { + appendDistinct(expressions, plan); + } + return plan; +} + +expression_vector Planner::getProperties(const Expression& pattern) const { + KU_ASSERT(pattern.expressionType == ExpressionType::PATTERN); + return propertyExprCollection.getProperties(pattern); +} + +JoinOrderEnumeratorContext Planner::enterNewContext() { + auto prevContext = std::move(context); + context = JoinOrderEnumeratorContext(); + return prevContext; +} + +void Planner::exitContext(JoinOrderEnumeratorContext prevContext) { + context = std::move(prevContext); +} + +PropertyExprCollection Planner::enterNewPropertyExprCollection() { + auto prevCollection = std::move(propertyExprCollection); + propertyExprCollection = PropertyExprCollection(); + return prevCollection; +} + +void Planner::exitPropertyExprCollection(PropertyExprCollection collection) { + propertyExprCollection = std::move(collection); +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/planner/subplans_table.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/subplans_table.cpp new file mode 100644 index 0000000000..48e8335360 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/planner/subplans_table.cpp @@ -0,0 +1,113 @@ +#include "planner/subplans_table.h" + +using namespace lbug::binder; + +namespace lbug { +namespace planner { + +SubgraphPlans::SubgraphPlans(const SubqueryGraph& subqueryGraph) { + for (auto i = 0u; i < subqueryGraph.queryGraph.getNumQueryNodes(); ++i) { + if (subqueryGraph.queryNodesSelector[i]) { + nodeIDsToEncode.push_back(subqueryGraph.queryGraph.getQueryNode(i)->getInternalID()); + } + } + maxCost = UINT64_MAX; +} + +void SubgraphPlans::addPlan(LogicalPlan plan) { + if (plans.size() > MAX_NUM_PLANS) { + return; + } + auto planCode = encodePlan(plan); + if (!encodedPlan2PlanIdx.contains(planCode)) { + encodedPlan2PlanIdx.insert({planCode, plans.size()}); + if (maxCost == UINT64_MAX || plan.getCost() > maxCost) { // update max cost + maxCost = plan.getCost(); + } + plans.push_back(std::move(plan)); + } else { + auto planIdx = encodedPlan2PlanIdx.at(planCode); + if (plan.getCost() < plans[planIdx].getCost()) { + if (plans[planIdx].getCost() == maxCost) { // update max cost + maxCost = 0; + for (auto& plan_ : plans) { + if (plan_.getCost() > maxCost) { + maxCost = plan_.getCost(); + } + } + } + plans[planIdx] = std::move(plan); + } + } +} + +std::bitset SubgraphPlans::encodePlan(const LogicalPlan& plan) { + auto schema = plan.getSchema(); + std::bitset 
result; + result.reset(); + for (auto i = 0u; i < nodeIDsToEncode.size(); ++i) { + result[i] = schema->getGroup(schema->getGroupPos(*nodeIDsToEncode[i]))->isFlat(); + } + return result; +} + +std::vector DPLevel::getSubqueryGraphs() { + std::vector result; + for (auto& [subGraph, _] : subgraph2Plans) { + result.push_back(subGraph); + } + return result; +} + +void DPLevel::addPlan(const SubqueryGraph& subqueryGraph, LogicalPlan plan) { + if (subgraph2Plans.size() > MAX_NUM_SUBGRAPH) { + return; + } + if (!contains(subqueryGraph)) { + subgraph2Plans.insert({subqueryGraph, SubgraphPlans(subqueryGraph)}); + } + subgraph2Plans.at(subqueryGraph).addPlan(std::move(plan)); +} + +void SubPlansTable::resize(uint32_t newSize) { + auto prevSize = dpLevels.size(); + dpLevels.resize(newSize); + for (auto i = prevSize; i < newSize; ++i) { + dpLevels[i] = DPLevel(); + } +} + +uint64_t SubPlansTable::getMaxCost(const SubqueryGraph& subqueryGraph) const { + return containSubgraphPlans(subqueryGraph) ? + getDPLevel(subqueryGraph).getSubgraphPlans(subqueryGraph).getMaxCost() : + UINT64_MAX; +} + +bool SubPlansTable::containSubgraphPlans(const SubqueryGraph& subqueryGraph) const { + return getDPLevel(subqueryGraph).contains(subqueryGraph); +} + +const std::vector& SubPlansTable::getSubgraphPlans( + const SubqueryGraph& subqueryGraph) const { + auto& dpLevel = getDPLevel(subqueryGraph); + KU_ASSERT(dpLevel.contains(subqueryGraph)); + return dpLevel.getSubgraphPlans(subqueryGraph).getPlans(); +} + +std::vector SubPlansTable::getSubqueryGraphs(uint32_t level) { + return dpLevels[level].getSubqueryGraphs(); +} + +void SubPlansTable::addPlan(const SubqueryGraph& subqueryGraph, LogicalPlan plan) { + auto& dpLevel = getDPLevelUnsafe(subqueryGraph); + dpLevel.addPlan(subqueryGraph, std::move(plan)); +} + +void SubPlansTable::clear() { + for (auto& dpLevel : dpLevels) { + dpLevel.clear(); + } +} + +} // namespace planner +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/CMakeLists.txt new file mode 100644 index 0000000000..3b9a277895 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/CMakeLists.txt @@ -0,0 +1,13 @@ +add_subdirectory(map) +add_subdirectory(operator) +add_subdirectory(result) + +add_library(lbug_processor + OBJECT + warning_context.cpp + processor.cpp + processor_task.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/CMakeLists.txt new file mode 100644 index 0000000000..3bd69011e4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/CMakeLists.txt @@ -0,0 +1,51 @@ +add_library(lbug_processor_mapper + OBJECT + create_arrow_result_collector.cpp + create_factorized_table_scan.cpp + create_result_collector.cpp + expression_mapper.cpp + map_acc_hash_join.cpp + map_accumulate.cpp + map_aggregate.cpp + map_standalone_call.cpp + map_table_function_call.cpp + map_copy_to.cpp + map_copy_from.cpp + map_insert.cpp + map_create_macro.cpp + map_cross_product.cpp + map_ddl.cpp + map_delete.cpp + map_distinct.cpp + map_explain.cpp + map_expressions_scan.cpp + map_dummy_scan.cpp + map_dummy_sink.cpp + map_empty_result.cpp + map_extend.cpp + map_filter.cpp + map_flatten.cpp + map_hash_join.cpp + map_index_scan_node.cpp + map_intersect.cpp + map_label_filter.cpp + map_limit.cpp + map_merge.cpp + map_multiplicity_reducer.cpp + map_noop.cpp + 
map_order_by.cpp + map_path_property_probe.cpp + map_projection.cpp + map_recursive_extend.cpp + map_scan_node_table.cpp + map_semi_masker.cpp + map_set.cpp + map_simple.cpp + map_transaction.cpp + map_union.cpp + map_unwind.cpp + plan_mapper.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/create_arrow_result_collector.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/create_arrow_result_collector.cpp new file mode 100644 index 0000000000..1a044b6a0e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/create_arrow_result_collector.cpp @@ -0,0 +1,29 @@ +#include "processor/operator/arrow_result_collector.h" +#include "processor/plan_mapper.h" + +using namespace lbug::common; + +namespace lbug { +namespace processor { + +std::unique_ptr PlanMapper::createArrowResultCollector( + ArrowResultConfig arrowConfig, const binder::expression_vector& expressions, + planner::Schema* schema, std::unique_ptr prevOperator) { + std::vector columnDataPos; + std::vector columnTypes; + for (auto& expr : expressions) { + columnDataPos.push_back(getDataPos(*expr, *schema)); + columnTypes.push_back(expr->getDataType().copy()); + } + auto sharedState = std::make_shared(); + auto opInfo = + ArrowResultCollectorInfo(arrowConfig.chunkSize, columnDataPos, std::move(columnTypes)); + auto printInfo = OPPrintInfo::EmptyInfo(); + auto op = std::make_unique(sharedState, std::move(opInfo), + std::move(prevOperator), getOperatorID(), std::move(printInfo)); + op->setDescriptor(std::make_unique(schema)); + return op; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/create_factorized_table_scan.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/create_factorized_table_scan.cpp new file mode 100644 index 0000000000..702d2569e7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/create_factorized_table_scan.cpp @@ -0,0 +1,90 @@ +#include "processor/operator/table_function_call.h" +#include "processor/operator/table_scan/ftable_scan_function.h" +#include "processor/plan_mapper.h" + +using namespace lbug::common; +using namespace lbug::planner; +using namespace lbug::binder; +using namespace lbug::function; + +namespace lbug { +namespace processor { + +std::unique_ptr PlanMapper::createFTableScan(const expression_vector& exprs, + std::vector colIndices, const Schema* schema, + std::shared_ptr table, uint64_t maxMorselSize, physical_op_vector_t children) { + std::vector outPosV; + if (!exprs.empty()) { + KU_ASSERT(schema); + outPosV = getDataPos(exprs, *schema); + } + auto function = FTableScan::getFunction(); + auto bindData = + std::make_unique(table, std::move(colIndices), maxMorselSize); + auto info = TableFunctionCallInfo(); + info.function = *function->copy(); + info.bindData = std::move(bindData); + info.outPosV = std::move(outPosV); + auto initInput = TableFuncInitSharedStateInput(info.bindData.get(), executionContext); + auto sharedState = info.function.initSharedStateFunc(initInput); + auto printInfo = std::make_unique(function->name, exprs); + auto result = std::make_unique(std::move(info), sharedState, getOperatorID(), + std::move(printInfo)); + for (auto& child : children) { + result->addChild(std::move(child)); + } + return result; +} + +std::unique_ptr PlanMapper::createFTableScan(const expression_vector& exprs, + const std::vector& colIndices, const Schema* schema, + std::shared_ptr table, uint64_t maxMorselSize) { + 
physical_op_vector_t children; + return createFTableScan(exprs, colIndices, schema, std::move(table), maxMorselSize, + std::move(children)); +} + +std::unique_ptr PlanMapper::createEmptyFTableScan( + std::shared_ptr table, uint64_t maxMorselSize, physical_op_vector_t children) { + return createFTableScan(expression_vector{}, std::vector{}, nullptr /* schema */, + std::move(table), maxMorselSize, std::move(children)); +} + +std::unique_ptr PlanMapper::createEmptyFTableScan( + std::shared_ptr table, uint64_t maxMorselSize, + std::unique_ptr child) { + physical_op_vector_t children; + children.push_back(std::move(child)); + return createFTableScan(expression_vector{}, std::vector{}, nullptr /* schema */, + std::move(table), maxMorselSize, std::move(children)); +} + +std::unique_ptr PlanMapper::createEmptyFTableScan( + std::shared_ptr table, uint64_t maxMorselSize) { + physical_op_vector_t children; + return createFTableScan(expression_vector{}, std::vector{}, nullptr /* schema */, + std::move(table), maxMorselSize, std::move(children)); +} + +std::unique_ptr PlanMapper::createFTableScanAligned( + const expression_vector& exprs, const Schema* schema, std::shared_ptr table, + uint64_t maxMorselSize, physical_op_vector_t children) { + std::vector colIndices; + colIndices.reserve(exprs.size()); + for (auto i = 0u; i < exprs.size(); ++i) { + colIndices.push_back(i); + } + return createFTableScan(exprs, std::move(colIndices), schema, std::move(table), maxMorselSize, + std::move(children)); +} + +std::unique_ptr PlanMapper::createFTableScanAligned( + const expression_vector& exprs, const Schema* schema, std::shared_ptr table, + uint64_t maxMorselSize) { + physical_op_vector_t children; + return createFTableScanAligned(exprs, schema, std::move(table), maxMorselSize, + std::move(children)); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/create_result_collector.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/create_result_collector.cpp new file mode 100644 index 0000000000..1111e31c9e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/create_result_collector.cpp @@ -0,0 +1,38 @@ +#include "processor/operator/result_collector.h" +#include "processor/plan_mapper.h" +#include "processor/result/factorized_table_util.h" +#include "storage/buffer_manager/memory_manager.h" + +using namespace lbug::common; +using namespace lbug::planner; +using namespace lbug::binder; + +namespace lbug { +namespace processor { + +std::unique_ptr PlanMapper::createResultCollector(AccumulateType accumulateType, + const expression_vector& expressions, Schema* schema, + std::unique_ptr prevOperator) { + std::vector payloadsPos; + for (auto& expr : expressions) { + payloadsPos.push_back(getDataPos(*expr, *schema)); + } + auto tableSchema = FactorizedTableUtils::createFTableSchema(expressions, *schema); + if (accumulateType == AccumulateType::OPTIONAL_) { + auto columnSchema = ColumnSchema(false /* isUnFlat */, INVALID_DATA_CHUNK_POS, + LogicalTypeUtils::getRowLayoutSize(LogicalType::BOOL())); + tableSchema.appendColumn(std::move(columnSchema)); + } + auto table = std::make_shared(storage::MemoryManager::Get(*clientContext), + tableSchema.copy()); + auto sharedState = std::make_shared(std::move(table)); + auto opInfo = ResultCollectorInfo(accumulateType, std::move(tableSchema), payloadsPos); + auto printInfo = std::make_unique(expressions, accumulateType); + auto op = std::make_unique(std::move(opInfo), std::move(sharedState), + 
std::move(prevOperator), getOperatorID(), std::move(printInfo)); + op->setDescriptor(std::make_unique(schema)); + return op; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/expression_mapper.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/expression_mapper.cpp new file mode 100644 index 0000000000..08e68aadd0 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/expression_mapper.cpp @@ -0,0 +1,215 @@ +#include "processor/expression_mapper.h" + +#include "binder/expression/case_expression.h" +#include "binder/expression/expression_util.h" +#include "binder/expression/lambda_expression.h" +#include "binder/expression/literal_expression.h" +#include "binder/expression/node_expression.h" +#include "binder/expression/parameter_expression.h" +#include "binder/expression/rel_expression.h" +#include "binder/expression_visitor.h" // IWYU pragma: keep (used in assert) +#include "common/exception/not_implemented.h" +#include "common/string_format.h" +#include "expression_evaluator/case_evaluator.h" +#include "expression_evaluator/function_evaluator.h" +#include "expression_evaluator/lambda_evaluator.h" +#include "expression_evaluator/literal_evaluator.h" +#include "expression_evaluator/path_evaluator.h" +#include "expression_evaluator/pattern_evaluator.h" +#include "expression_evaluator/reference_evaluator.h" +#include "planner/operator/schema.h" + +using namespace lbug::binder; +using namespace lbug::common; +using namespace lbug::evaluator; +using namespace lbug::planner; + +namespace lbug { +namespace processor { + +static bool canEvaluateAsFunction(ExpressionType expressionType) { + switch (expressionType) { + case ExpressionType::OR: + case ExpressionType::XOR: + case ExpressionType::AND: + case ExpressionType::NOT: + case ExpressionType::EQUALS: + case ExpressionType::NOT_EQUALS: + case ExpressionType::GREATER_THAN: + case ExpressionType::GREATER_THAN_EQUALS: + case ExpressionType::LESS_THAN: + case ExpressionType::LESS_THAN_EQUALS: + case ExpressionType::IS_NULL: + case ExpressionType::IS_NOT_NULL: + case ExpressionType::FUNCTION: + return true; + default: + return false; + } +} + +std::unique_ptr ExpressionMapper::getEvaluator( + std::shared_ptr expression) { + if (schema == nullptr) { + return getConstantEvaluator(std::move(expression)); + } + auto expressionType = expression->expressionType; + if (schema->isExpressionInScope(*expression)) { + return getReferenceEvaluator(std::move(expression)); + } else if (ExpressionType::LITERAL == expressionType) { + return getLiteralEvaluator(std::move(expression)); + } else if (ExpressionUtil::isNodePattern(*expression)) { + return getNodeEvaluator(std::move(expression)); + } else if (ExpressionUtil::isRelPattern(*expression)) { + return getRelEvaluator(std::move(expression)); + } else if (expressionType == ExpressionType::PATH) { + return getPathEvaluator(std::move(expression)); + } else if (expressionType == ExpressionType::PARAMETER) { + return getParameterEvaluator(std::move(expression)); + } else if (expressionType == ExpressionType::CASE_ELSE) { + return getCaseEvaluator(std::move(expression)); + } else if (canEvaluateAsFunction(expressionType)) { + return getFunctionEvaluator(std::move(expression)); + } else if (parentEvaluator != nullptr) { + return getLambdaParamEvaluator(std::move(expression)); + } else { + // LCOV_EXCL_START + throw NotImplementedException(stringFormat("Cannot evaluate expression with type {}.", + 
ExpressionTypeUtil::toString(expressionType))); + // LCOV_EXCL_STOP + } +} + +std::unique_ptr ExpressionMapper::getConstantEvaluator( + std::shared_ptr expression) { + KU_ASSERT(ConstantExpressionVisitor::isConstant(*expression)); + auto expressionType = expression->expressionType; + if (ExpressionType::LITERAL == expressionType) { + return getLiteralEvaluator(std::move(expression)); + } else if (ExpressionType::CASE_ELSE == expressionType) { + return getCaseEvaluator(std::move(expression)); + } else if (canEvaluateAsFunction(expressionType)) { + return getFunctionEvaluator(std::move(expression)); + } else { + // LCOV_EXCL_START + throw NotImplementedException(stringFormat("Cannot evaluate expression with type {}.", + ExpressionTypeUtil::toString(expressionType))); + // LCOV_EXCL_STOP + } +} + +std::unique_ptr ExpressionMapper::getLiteralEvaluator( + std::shared_ptr expression) { + auto& literalExpression = expression->constCast(); + return std::make_unique(std::move(expression), + literalExpression.getValue()); +} + +std::unique_ptr ExpressionMapper::getParameterEvaluator( + std::shared_ptr expression) { + auto& parameterExpression = expression->constCast(); + return std::make_unique(std::move(expression), + parameterExpression.getValue()); +} + +std::unique_ptr ExpressionMapper::getReferenceEvaluator( + std::shared_ptr expression) const { + KU_ASSERT(schema != nullptr); + auto vectorPos = DataPos(schema->getExpressionPos(*expression)); + auto expressionGroup = schema->getGroup(expression->getUniqueName()); + return std::make_unique(std::move(expression), + expressionGroup->isFlat(), vectorPos); +} + +std::unique_ptr ExpressionMapper::getLambdaParamEvaluator( + std::shared_ptr expression) { + return std::make_unique(std::move(expression)); +} + +std::unique_ptr ExpressionMapper::getCaseEvaluator( + std::shared_ptr expression) { + auto caseExpression = reinterpret_cast(expression.get()); + std::vector alternativeEvaluators; + for (auto i = 0u; i < caseExpression->getNumCaseAlternatives(); ++i) { + auto alternative = caseExpression->getCaseAlternative(i); + auto whenEvaluator = getEvaluator(alternative->whenExpression); + auto thenEvaluator = getEvaluator(alternative->thenExpression); + alternativeEvaluators.push_back( + CaseAlternativeEvaluator(std::move(whenEvaluator), std::move(thenEvaluator))); + } + auto elseEvaluator = getEvaluator(caseExpression->getElseExpression()); + return std::make_unique(std::move(expression), + std::move(alternativeEvaluators), std::move(elseEvaluator)); +} + +std::unique_ptr ExpressionMapper::getFunctionEvaluator( + std::shared_ptr expression) { + evaluator_vector_t childrenEvaluators; + if (expression->getNumChildren() == 2 && + expression->getChild(1)->expressionType == ExpressionType::LAMBDA) { + childrenEvaluators.push_back(getEvaluator(expression->getChild(0))); + auto result = + std::make_unique(expression, std::move(childrenEvaluators)); + auto recursiveExprMapper = ExpressionMapper(schema, result.get()); + auto& lambdaExpr = expression->getChild(1)->constCast(); + result->setLambdaRootEvaluator( + recursiveExprMapper.getEvaluator(lambdaExpr.getFunctionExpr())); + return result; + } + childrenEvaluators = getEvaluators(expression->getChildren()); + return std::make_unique(std::move(expression), + std::move(childrenEvaluators)); +} + +std::unique_ptr ExpressionMapper::getNodeEvaluator( + std::shared_ptr expression) { + auto node = expression->constPtrCast(); + expression_vector children; + children.push_back(node->getInternalID()); + 
children.push_back(node->getLabelExpression()); + for (auto& property : node->getPropertyExpressions()) { + children.push_back(property); + } + auto childrenEvaluators = getEvaluators(children); + return std::make_unique(std::move(expression), + std::move(childrenEvaluators)); +} + +std::unique_ptr ExpressionMapper::getRelEvaluator( + std::shared_ptr expression) { + auto rel = expression->constPtrCast(); + expression_vector children; + children.push_back(rel->getSrcNode()->getInternalID()); + children.push_back(rel->getDstNode()->getInternalID()); + children.push_back(rel->getLabelExpression()); + for (auto& property : rel->getPropertyExpressions()) { + children.push_back(property); + } + auto childrenEvaluators = getEvaluators(children); + if (rel->hasDirectionExpr()) { + auto directionEvaluator = getEvaluator(rel->getDirectionExpr()); + return std::make_unique(std::move(expression), + std::move(childrenEvaluators), std::move(directionEvaluator)); + } + return std::make_unique(std::move(expression), + std::move(childrenEvaluators)); +} + +std::unique_ptr ExpressionMapper::getPathEvaluator( + std::shared_ptr expression) { + auto children = getEvaluators(expression->getChildren()); + return std::make_unique(std::move(expression), std::move(children)); +} + +std::vector> ExpressionMapper::getEvaluators( + const expression_vector& expressions) { + std::vector> evaluators; + evaluators.reserve(expressions.size()); + for (auto& expression : expressions) { + evaluators.push_back(getEvaluator(expression)); + } + return evaluators; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_acc_hash_join.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_acc_hash_join.cpp new file mode 100644 index 0000000000..72970ad41f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_acc_hash_join.cpp @@ -0,0 +1,24 @@ +#include "processor/plan_mapper.h" + +using namespace lbug::planner; + +namespace lbug { +namespace processor { + +static PhysicalOperator* getTableScan(const PhysicalOperator* joinRoot) { + auto op = joinRoot->getChild(0); + while (op->getOperatorType() != PhysicalOperatorType::TABLE_FUNCTION_CALL) { + KU_ASSERT(op->getNumChildren() != 0); + op = op->getChild(0); + } + return op; +} + +void PlanMapper::mapSIPJoin(PhysicalOperator* joinRoot) { + auto tableScan = getTableScan(joinRoot); + auto resultCollector = tableScan->moveUnaryChild(); + joinRoot->addChild(std::move(resultCollector)); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_accumulate.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_accumulate.cpp new file mode 100644 index 0000000000..edc222dff8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_accumulate.cpp @@ -0,0 +1,33 @@ +#include "common/system_config.h" +#include "planner/operator/logical_accumulate.h" +#include "processor/operator/result_collector.h" +#include "processor/plan_mapper.h" + +using namespace lbug::planner; +using namespace lbug::common; + +namespace lbug { +namespace processor { + +std::unique_ptr PlanMapper::mapAccumulate( + const LogicalOperator* logicalOperator) { + const auto& acc = logicalOperator->constCast(); + auto outSchema = acc.getSchema(); + auto inSchema = acc.getChild(0)->getSchema(); + auto prevOperator = mapOperator(acc.getChild(0).get()); + auto expressions = acc.getPayloads(); + auto resultCollector = createResultCollector(acc.getAccumulateType(), 
expressions, inSchema, + std::move(prevOperator)); + auto table = resultCollector->getResultFTable(); + auto maxMorselSize = table->hasUnflatCol() ? 1 : DEFAULT_VECTOR_CAPACITY; + if (acc.hasMark()) { + expressions.push_back(acc.getMark()); + } + physical_op_vector_t children; + children.push_back(std::move(resultCollector)); + return createFTableScanAligned(expressions, outSchema, table, maxMorselSize, + std::move(children)); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_aggregate.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_aggregate.cpp new file mode 100644 index 0000000000..5d01d37473 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_aggregate.cpp @@ -0,0 +1,227 @@ +#include "binder/expression/aggregate_function_expression.h" +#include "common/copy_constructors.h" +#include "common/types/types.h" +#include "planner/operator/logical_aggregate.h" +#include "processor/operator/aggregate/hash_aggregate.h" +#include "processor/operator/aggregate/hash_aggregate_scan.h" +#include "processor/operator/aggregate/simple_aggregate.h" +#include "processor/operator/aggregate/simple_aggregate_scan.h" +#include "processor/plan_mapper.h" +#include "processor/result/result_set_descriptor.h" + +using namespace lbug::binder; +using namespace lbug::common; +using namespace lbug::function; +using namespace lbug::planner; + +namespace lbug { +namespace processor { + +static std::vector getAggregateInputInfos(const expression_vector& keys, + const expression_vector& aggregates, const Schema& schema) { + // Collect unFlat groups from + std::unordered_set groupByGroupPosSet; + for (auto& expression : keys) { + groupByGroupPosSet.insert(schema.getGroupPos(*expression)); + } + std::unordered_set unFlatAggregateGroupPosSet; + for (auto groupPos : schema.getGroupsPosInScope()) { + if (groupByGroupPosSet.contains(groupPos)) { + continue; + } + if (schema.getGroup(groupPos)->isFlat()) { + continue; + } + unFlatAggregateGroupPosSet.insert(groupPos); + } + std::vector result; + for (auto& expression : aggregates) { + auto aggregateVectorPos = DataPos::getInvalidPos(); + if (expression->getNumChildren() != 0) { // COUNT(*) has no children + auto child = expression->getChild(0); + aggregateVectorPos = DataPos{schema.getExpressionPos(*child)}; + } + std::vector multiplicityChunksPos; + for (auto& groupPos : unFlatAggregateGroupPosSet) { + if (groupPos != aggregateVectorPos.dataChunkPos) { + multiplicityChunksPos.push_back(groupPos); + } + } + auto aggExpr = expression->constPtrCast(); + auto distinctAggKeyType = aggExpr->isDistinct() ? 
+ expression->getChild(0)->getDataType().copy() : + LogicalType::ANY(); + result.emplace_back(aggregateVectorPos, std::move(multiplicityChunksPos), + std::move(distinctAggKeyType)); + } + return result; +} + +static expression_vector getKeyExpressions(const expression_vector& expressions, + const Schema& schema, bool isFlat) { + expression_vector result; + for (auto& expression : expressions) { + if (schema.getGroup(schema.getGroupPos(*expression))->isFlat() == isFlat) { + result.emplace_back(expression); + } + } + return result; +} + +static std::vector getAggFunctions(const expression_vector& aggregates) { + std::vector aggregateFunctions; + for (auto& expression : aggregates) { + auto aggExpr = expression->constPtrCast(); + aggregateFunctions.push_back(aggExpr->getFunction().copy()); + } + return aggregateFunctions; +} + +static void writeAggResultWithNullToVector(ValueVector& vector, uint64_t pos, + AggregateState* aggregateState) { + auto isNull = aggregateState->constCast().isNull; + vector.setNull(pos, isNull); + if (!isNull) { + aggregateState->writeToVector(&vector, pos); + } +} + +static void writeAggResultWithoutNullToVector(ValueVector& vector, uint64_t pos, + AggregateState* aggregateState) { + vector.setNull(pos, false); + aggregateState->writeToVector(&vector, pos); +} + +static std::vector getMoveAggResultToVectorFuncs( + std::vector& aggregateFunctions) { + std::vector moveAggResultToVectorFuncs; + for (auto& aggregateFunction : aggregateFunctions) { + if (aggregateFunction.needToHandleNulls) { + moveAggResultToVectorFuncs.push_back(writeAggResultWithoutNullToVector); + } else { + moveAggResultToVectorFuncs.push_back(writeAggResultWithNullToVector); + } + } + return moveAggResultToVectorFuncs; +} + +std::unique_ptr PlanMapper::mapAggregate(const LogicalOperator* logicalOperator) { + auto& agg = logicalOperator->constCast(); + auto aggregates = agg.getAggregates(); + auto outSchema = agg.getSchema(); + auto child = agg.getChild(0).get(); + auto inSchema = child->getSchema(); + auto prevOperator = mapOperator(child); + if (agg.hasKeys()) { + return createHashAggregate(agg.getKeys(), agg.getDependentKeys(), aggregates, inSchema, + outSchema, std::move(prevOperator)); + } + auto aggFunctions = getAggFunctions(aggregates); + auto aggOutputPos = getDataPos(aggregates, *outSchema); + auto aggregateInputInfos = getAggregateInputInfos(agg.getAllKeys(), aggregates, *inSchema); + auto sharedState = + make_shared(clientContext, aggFunctions, aggregateInputInfos); + auto printInfo = std::make_unique(aggregates); + auto aggregate = make_unique(sharedState, std::move(aggFunctions), + copyVector(aggregateInputInfos), std::move(prevOperator), getOperatorID(), + printInfo->copy()); + aggregate->setDescriptor(std::make_unique(inSchema)); + auto finalizer = std::make_unique(sharedState, + std::move(aggregateInputInfos), getOperatorID(), printInfo->copy()); + finalizer->addChild(std::move(aggregate)); + aggFunctions = getAggFunctions(aggregates); + auto scan = std::make_unique(sharedState, + AggregateScanInfo{std::move(aggOutputPos), getMoveAggResultToVectorFuncs(aggFunctions)}, + getOperatorID(), printInfo->copy()); + scan->addChild(std::move(finalizer)); + return scan; +} + +static FactorizedTableSchema getFactorizedTableSchema(const expression_vector& flatKeys, + const expression_vector& unFlatKeys, const expression_vector& payloads, + const std::vector& aggregateFunctions) { + auto isUnFlat = false; + auto groupID = 0u; + auto tableSchema = FactorizedTableSchema(); + for (auto& flatKey : 
flatKeys) { + auto size = LogicalTypeUtils::getRowLayoutSize(flatKey->dataType); + tableSchema.appendColumn(ColumnSchema(isUnFlat, groupID, size)); + } + for (auto& unFlatKey : unFlatKeys) { + auto size = LogicalTypeUtils::getRowLayoutSize(unFlatKey->dataType); + tableSchema.appendColumn(ColumnSchema(isUnFlat, groupID, size)); + } + for (auto& payload : payloads) { + auto size = LogicalTypeUtils::getRowLayoutSize(payload->dataType); + tableSchema.appendColumn(ColumnSchema(isUnFlat, groupID, size)); + } + for (auto& aggregateFunc : aggregateFunctions) { + tableSchema.appendColumn( + ColumnSchema(isUnFlat, groupID, aggregateFunc.getAggregateStateSize())); + } + tableSchema.appendColumn(ColumnSchema(isUnFlat, groupID, sizeof(hash_t))); + return tableSchema; +} + +std::unique_ptr PlanMapper::createDistinctHashAggregate( + const expression_vector& keys, const expression_vector& payloads, Schema* inSchema, + Schema* outSchema, std::unique_ptr prevOperator) { + return createHashAggregate(keys, payloads, expression_vector{} /* aggregates */, inSchema, + outSchema, std::move(prevOperator)); +} + +// Payloads are also group by keys except that they are functional dependent on keys so we don't +// need to hash or compare payloads. +std::unique_ptr PlanMapper::createHashAggregate(const expression_vector& keys, + const expression_vector& payloads, const expression_vector& aggregates, Schema* inSchema, + Schema* outSchema, std::unique_ptr prevOperator) { + // Create hash aggregate + auto aggFunctions = getAggFunctions(aggregates); + expression_vector allKeys; + allKeys.insert(allKeys.end(), keys.begin(), keys.end()); + allKeys.insert(allKeys.end(), payloads.begin(), payloads.end()); + auto aggregateInputInfos = getAggregateInputInfos(allKeys, aggregates, *inSchema); + auto flatKeys = getKeyExpressions(keys, *inSchema, true /* isFlat */); + auto unFlatKeys = getKeyExpressions(keys, *inSchema, false /* isFlat */); + std::vector keyTypes, payloadTypes; + for (auto& key : flatKeys) { + keyTypes.push_back(key->getDataType().copy()); + } + for (auto& key : unFlatKeys) { + keyTypes.push_back(key->getDataType().copy()); + } + for (auto& payload : payloads) { + payloadTypes.push_back(payload->getDataType().copy()); + } + auto tableSchema = getFactorizedTableSchema(flatKeys, unFlatKeys, payloads, aggFunctions); + HashAggregateInfo aggregateInfo{getDataPos(flatKeys, *inSchema), + getDataPos(unFlatKeys, *inSchema), getDataPos(payloads, *inSchema), std::move(tableSchema)}; + + auto sharedState = + std::make_shared(clientContext, std::move(aggregateInfo), + aggFunctions, aggregateInputInfos, std::move(keyTypes), std::move(payloadTypes)); + auto printInfo = std::make_unique(allKeys, aggregates); + auto aggregate = make_unique(sharedState, std::move(aggFunctions), + std::move(aggregateInputInfos), std::move(prevOperator), getOperatorID(), + printInfo->copy()); + aggregate->setDescriptor(std::make_unique(inSchema)); + // Create AggScan. 
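+    // The scan below reads the flat keys, unflat keys, payloads and aggregate results back
+    // out of the shared aggregate state once the finalizer child has completed.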
+ expression_vector outputExpressions; + outputExpressions.insert(outputExpressions.end(), flatKeys.begin(), flatKeys.end()); + outputExpressions.insert(outputExpressions.end(), unFlatKeys.begin(), unFlatKeys.end()); + outputExpressions.insert(outputExpressions.end(), payloads.begin(), payloads.end()); + auto aggOutputPos = getDataPos(aggregates, *outSchema); + auto finalizer = + std::make_unique(sharedState, getOperatorID(), printInfo->copy()); + finalizer->addChild(std::move(aggregate)); + aggFunctions = getAggFunctions(aggregates); + auto scan = + std::make_unique(sharedState, getDataPos(outputExpressions, *outSchema), + AggregateScanInfo{std::move(aggOutputPos), getMoveAggResultToVectorFuncs(aggFunctions)}, + getOperatorID(), printInfo->copy()); + scan->addChild(std::move(finalizer)); + return scan; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_copy_from.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_copy_from.cpp new file mode 100644 index 0000000000..243b14eb27 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_copy_from.cpp @@ -0,0 +1,165 @@ +#include "catalog/catalog_entry/rel_group_catalog_entry.h" +#include "planner/operator/logical_partitioner.h" +#include "planner/operator/persistent/logical_copy_from.h" +#include "processor/expression_mapper.h" +#include "processor/operator/index_lookup.h" +#include "processor/operator/partitioner.h" +#include "processor/operator/persistent/copy_rel_batch_insert.h" +#include "processor/operator/persistent/node_batch_insert.h" +#include "processor/operator/persistent/rel_batch_insert.h" +#include "processor/operator/table_function_call.h" +#include "processor/plan_mapper.h" +#include "processor/result/factorized_table_util.h" +#include "processor/warning_context.h" +#include "storage/storage_manager.h" +#include "storage/table/node_table.h" +#include "storage/table/rel_table.h" + +using namespace lbug::binder; +using namespace lbug::catalog; +using namespace lbug::common; +using namespace lbug::planner; +using namespace lbug::storage; + +namespace lbug { +namespace processor { + +std::unique_ptr PlanMapper::mapCopyFrom(const LogicalOperator* logicalOperator) { + const auto& copyFrom = logicalOperator->constCast(); + WarningContext::Get(*clientContext) + ->setIgnoreErrorsForCurrentQuery(copyFrom.getInfo()->getIgnoreErrorsOption()); + switch (copyFrom.getInfo()->tableType) { + case TableType::NODE: { + return mapCopyNodeFrom(logicalOperator); + } + case TableType::REL: { + return mapCopyRelFrom(logicalOperator); + } + default: + KU_UNREACHABLE; + } +} + +std::unique_ptr PlanMapper::mapCopyNodeFrom( + const LogicalOperator* logicalOperator) { + auto& copyFrom = logicalOperator->constCast(); + const auto copyFromInfo = copyFrom.getInfo(); + const auto outFSchema = copyFrom.getSchema(); + auto prevOperator = mapOperator(copyFrom.getChild(0).get()); + auto fTable = + FactorizedTableUtils::getSingleStringColumnFTable(MemoryManager::Get(*clientContext)); + + auto sharedState = std::make_shared(fTable); + if (prevOperator->getOperatorType() == PhysicalOperatorType::TABLE_FUNCTION_CALL) { + const auto call = prevOperator->ptrCast(); + sharedState->tableFuncSharedState = call->getSharedState().get(); + } + std::vector> columnEvaluators; + auto exprMapper = ExpressionMapper(outFSchema); + for (auto& expr : copyFromInfo->columnExprs) { + columnEvaluators.push_back(exprMapper.getEvaluator(expr)); + } + std::vector warningColumnTypes; + for (auto& 
column : copyFromInfo->getWarningColumns()) { + warningColumnTypes.push_back(column->getDataType().copy()); + } + auto info = std::make_unique(copyFromInfo->tableName, + std::move(warningColumnTypes), std::move(columnEvaluators), + copyFromInfo->columnEvaluateTypes); + auto printInfo = std::make_unique(copyFromInfo->tableName); + auto batchInsert = std::make_unique(std::move(info), std::move(sharedState), + std::move(prevOperator), getOperatorID(), std::move(printInfo)); + batchInsert->setDescriptor(std::make_unique(copyFrom.getSchema())); + return batchInsert; +} + +std::unique_ptr PlanMapper::mapPartitioner( + const LogicalOperator* logicalOperator) { + auto& logicalPartitioner = logicalOperator->constCast(); + auto prevOperator = mapOperator(logicalPartitioner.getChild(0).get()); + auto outFSchema = logicalPartitioner.getSchema(); + auto& copyFromInfo = logicalPartitioner.copyFromInfo; + auto& extraInfo = copyFromInfo.extraInfo->constCast(); + PartitionerInfo partitionerInfo; + partitionerInfo.relOffsetDataPos = + getDataPos(*logicalPartitioner.getInfo().offset, *outFSchema); + partitionerInfo.infos.reserve(logicalPartitioner.getInfo().getNumInfos()); + for (auto i = 0u; i < logicalPartitioner.getInfo().getNumInfos(); i++) { + partitionerInfo.infos.emplace_back(logicalPartitioner.getInfo().getInfo(i).keyIdx, + PartitionerFunctions::partitionRelData); + } + std::vector columnTypes; + evaluator::evaluator_vector_t columnEvaluators; + auto exprMapper = ExpressionMapper(outFSchema); + for (auto& expr : copyFromInfo.columnExprs) { + columnTypes.push_back(expr->getDataType().copy()); + columnEvaluators.push_back(exprMapper.getEvaluator(expr)); + } + for (auto idx : extraInfo.internalIDColumnIndices) { + columnTypes[idx] = LogicalType::INTERNAL_ID(); + } + auto dataInfo = PartitionerDataInfo(copyFromInfo.tableName, extraInfo.fromTableName, + extraInfo.toTableName, LogicalType::copy(columnTypes), std::move(columnEvaluators), + copyFromInfo.columnEvaluateTypes); + auto sharedState = + std::make_shared(*MemoryManager::Get(*clientContext)); + expression_vector expressions; + for (auto& info : partitionerInfo.infos) { + expressions.push_back(copyFromInfo.columnExprs[info.keyIdx]); + } + auto printInfo = std::make_unique(expressions); + auto partitioner = + std::make_unique(std::move(partitionerInfo), std::move(dataInfo), + std::move(sharedState), std::move(prevOperator), getOperatorID(), std::move(printInfo)); + partitioner->setDescriptor(std::make_unique(outFSchema)); + return partitioner; +} + +std::unique_ptr PlanMapper::mapCopyRelFrom( + const LogicalOperator* logicalOperator) { + auto& copyFrom = logicalOperator->constCast(); + const auto copyFromInfo = copyFrom.getInfo(); + auto partitioner = mapOperator(copyFrom.getChild(0).get()); + KU_ASSERT(partitioner->getOperatorType() == PhysicalOperatorType::PARTITIONER); + auto partitionerSharedState = partitioner->ptrCast()->getSharedState(); + const auto catalog = Catalog::Get(*clientContext); + const auto transaction = transaction::Transaction::Get(*clientContext); + auto extraInfo = copyFromInfo->extraInfo->constCast(); + auto fromTableID = + catalog->getTableCatalogEntry(transaction, extraInfo.fromTableName)->getTableID(); + auto toTableID = + catalog->getTableCatalogEntry(transaction, extraInfo.toTableName)->getTableID(); + std::vector warningColumnTypes; + for (auto& column : copyFromInfo->getWarningColumns()) { + warningColumnTypes.push_back(column->getDataType().copy()); + } + auto fTable = + 
FactorizedTableUtils::getSingleStringColumnFTable(MemoryManager::Get(*clientContext)); + auto batchInsertSharedState = std::make_shared(fTable); + // If the table entry doesn't exist, assume both directions + std::vector directions = {RelDataDirection::FWD, RelDataDirection::BWD}; + if (catalog->containsTable(transaction, copyFromInfo->tableName)) { + const auto& relGroupEntry = + catalog->getTableCatalogEntry(transaction, copyFromInfo->tableName) + ->constCast(); + directions = relGroupEntry.getRelDataDirections(); + } + + auto sink = std::make_unique(fTable, getOperatorID()); + for (auto direction : directions) { + auto insertInfo = std::make_unique(copyFromInfo->tableName, + copyVector(warningColumnTypes), fromTableID, toTableID, direction); + auto printInfo = std::make_unique(copyFromInfo->tableName); + auto progress = std::make_shared(); + auto batchInsert = std::make_unique(std::move(insertInfo), + partitionerSharedState, batchInsertSharedState, getOperatorID(), std::move(printInfo), + progress, std::make_unique()); + batchInsert->setDescriptor(std::make_unique(copyFrom.getSchema())); + sink->addChild(std::move(batchInsert)); + } + sink->addChild(std::move(partitioner)); + return sink; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_copy_to.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_copy_to.cpp new file mode 100644 index 0000000000..21619ea40b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_copy_to.cpp @@ -0,0 +1,43 @@ +#include "planner/operator/persistent/logical_copy_to.h" +#include "processor/operator/persistent/copy_to.h" +#include "processor/plan_mapper.h" +#include "storage/buffer_manager/memory_manager.h" + +using namespace lbug::common; +using namespace lbug::planner; +using namespace lbug::storage; + +namespace lbug { +namespace processor { + +std::unique_ptr PlanMapper::mapCopyTo(const LogicalOperator* logicalOperator) { + auto& logicalCopyTo = logicalOperator->constCast(); + auto childSchema = logicalOperator->getChild(0)->getSchema(); + auto prevOperator = mapOperator(logicalOperator->getChild(0).get()); + std::vector vectorsToCopyPos; + std::vector isFlat; + std::vector types; + for (auto& expression : childSchema->getExpressionsInScope()) { + vectorsToCopyPos.emplace_back(childSchema->getExpressionPos(*expression)); + isFlat.push_back(childSchema->getGroup(expression)->isFlat()); + types.push_back(expression->dataType.copy()); + } + auto exportFunc = logicalCopyTo.getExportFunc(); + auto bindData = logicalCopyTo.getBindData(); + // TODO(Xiyang): Query: COPY (RETURN null) TO '/tmp/1.parquet', the datatype of the first + // column is ANY, should we solve the type at binder? 
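+    // Types resolved from the child schema are handed to the export function's bind data so
+    // the writer knows the output column layout.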
+ bindData->setDataType(std::move(types)); + auto sharedState = exportFunc.createSharedState(); + auto info = + CopyToInfo{exportFunc, std::move(bindData), std::move(vectorsToCopyPos), std::move(isFlat)}; + auto printInfo = + std::make_unique(info.bindData->columnNames, info.bindData->fileName); + auto copyTo = std::make_unique(std::move(info), std::move(sharedState), + std::move(prevOperator), getOperatorID(), std::move(printInfo)); + copyTo->setDescriptor(std::make_unique(childSchema)); + return createEmptyFTableScan(FactorizedTable::EmptyTable(MemoryManager::Get(*clientContext)), 0, + std::move(copyTo)); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_create_macro.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_create_macro.cpp new file mode 100644 index 0000000000..cb7efec3b7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_create_macro.cpp @@ -0,0 +1,25 @@ +#include "planner/operator/logical_create_macro.h" +#include "processor/operator/macro/create_macro.h" +#include "processor/plan_mapper.h" +#include "processor/result/factorized_table_util.h" +#include "storage/buffer_manager/memory_manager.h" + +using namespace lbug::planner; + +namespace lbug { +namespace processor { + +std::unique_ptr PlanMapper::mapCreateMacro( + const LogicalOperator* logicalOperator) { + auto& logicalCreateMacro = logicalOperator->constCast(); + auto createMacroInfo = + CreateMacroInfo(logicalCreateMacro.getMacroName(), logicalCreateMacro.getMacro()); + auto printInfo = std::make_unique(createMacroInfo.macroName); + auto messageTable = FactorizedTableUtils::getSingleStringColumnFTable( + storage::MemoryManager::Get(*clientContext)); + return std::make_unique(std::move(createMacroInfo), std::move(messageTable), + getOperatorID(), std::move(printInfo)); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_cross_product.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_cross_product.cpp new file mode 100644 index 0000000000..57142499a0 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_cross_product.cpp @@ -0,0 +1,50 @@ +#include "common/system_config.h" +#include "planner/operator/logical_cross_product.h" +#include "processor/operator/cross_product.h" +#include "processor/plan_mapper.h" + +using namespace lbug::common; +using namespace lbug::planner; + +namespace lbug { +namespace processor { + +std::unique_ptr PlanMapper::mapCrossProduct( + const LogicalOperator* logicalOperator) { + auto& logicalCrossProduct = logicalOperator->constCast(); + auto outSchema = logicalCrossProduct.getSchema(); + auto buildChild = logicalCrossProduct.getChild(1); + // map build side + auto buildSchema = buildChild->getSchema(); + auto buildSidePrevOperator = mapOperator(buildChild.get()); + auto expressions = buildSchema->getExpressionsInScope(); + auto resultCollector = createResultCollector(logicalCrossProduct.getAccumulateType(), + expressions, buildSchema, std::move(buildSidePrevOperator)); + // map probe side + auto probeSidePrevOperator = mapOperator(logicalCrossProduct.getChild(0).get()); + std::vector outVecPos; + std::vector colIndicesToScan; + if (logicalCrossProduct.hasMark()) { + expressions.push_back(logicalCrossProduct.getMark()); + } + for (auto i = 0u; i < expressions.size(); ++i) { + auto expression = expressions[i]; + outVecPos.emplace_back(outSchema->getExpressionPos(*expression)); + 
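+        // Column i of the materialized build-side table feeds the i-th output vector position.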
colIndicesToScan.push_back(i); + } + auto info = CrossProductInfo(std::move(outVecPos), std::move(colIndicesToScan)); + auto table = resultCollector->getResultFTable(); + auto maxMorselSize = table->hasUnflatCol() ? 1 : DEFAULT_VECTOR_CAPACITY; + auto localState = CrossProductLocalState(table, maxMorselSize); + auto printInfo = std::make_unique(); + auto crossProduct = std::make_unique(std::move(info), std::move(localState), + std::move(probeSidePrevOperator), getOperatorID(), std::move(printInfo)); + crossProduct->addChild(std::move(resultCollector)); + if (logicalCrossProduct.getSIPInfo().direction == SIPDirection::PROBE_TO_BUILD) { + mapSIPJoin(crossProduct.get()); + } + return crossProduct; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_ddl.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_ddl.cpp new file mode 100644 index 0000000000..a1a8d50a92 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_ddl.cpp @@ -0,0 +1,83 @@ +#include "planner/operator/ddl/logical_alter.h" +#include "planner/operator/ddl/logical_create_sequence.h" +#include "planner/operator/ddl/logical_create_table.h" +#include "planner/operator/ddl/logical_create_type.h" +#include "planner/operator/ddl/logical_drop.h" +#include "processor/expression_mapper.h" +#include "processor/operator/ddl/alter.h" +#include "processor/operator/ddl/create_sequence.h" +#include "processor/operator/ddl/create_table.h" +#include "processor/operator/ddl/create_type.h" +#include "processor/operator/ddl/drop.h" +#include "processor/plan_mapper.h" +#include "processor/result/factorized_table_util.h" +#include "storage/buffer_manager/memory_manager.h" + +using namespace lbug::binder; +using namespace lbug::common; +using namespace lbug::planner; + +namespace lbug { +namespace processor { + +std::unique_ptr PlanMapper::mapCreateTable( + const LogicalOperator* logicalOperator) { + auto& createTable = logicalOperator->constCast(); + auto printInfo = std::make_unique(createTable.getInfo()->copy()); + auto messageTable = FactorizedTableUtils::getSingleStringColumnFTable( + storage::MemoryManager::Get(*clientContext)); + auto sharedState = std::make_shared(); + return std::make_unique(createTable.getInfo()->copy(), messageTable, sharedState, + getOperatorID(), std::move(printInfo)); +} + +std::unique_ptr PlanMapper::mapCreateType( + const LogicalOperator* logicalOperator) { + auto& createType = logicalOperator->constCast(); + auto typeName = createType.getExpressionsForPrinting(); + auto printInfo = + std::make_unique(typeName, createType.getType().toString()); + auto messageTable = FactorizedTableUtils::getSingleStringColumnFTable( + storage::MemoryManager::Get(*clientContext)); + return std::make_unique(typeName, createType.getType().copy(), + std::move(messageTable), getOperatorID(), std::move(printInfo)); +} + +std::unique_ptr PlanMapper::mapCreateSequence( + const LogicalOperator* logicalOperator) { + auto& createSequence = logicalOperator->constCast(); + auto printInfo = + std::make_unique(createSequence.getInfo().sequenceName); + auto messageTable = FactorizedTableUtils::getSingleStringColumnFTable( + storage::MemoryManager::Get(*clientContext)); + return std::make_unique(createSequence.getInfo(), std::move(messageTable), + getOperatorID(), std::move(printInfo)); +} + +std::unique_ptr PlanMapper::mapDrop(const LogicalOperator* logicalOperator) { + auto& drop = logicalOperator->constCast(); + auto& dropInfo = drop.getDropInfo(); + auto 
printInfo = std::make_unique(drop.getDropInfo().name); + auto messageTable = FactorizedTableUtils::getSingleStringColumnFTable( + storage::MemoryManager::Get(*clientContext)); + return std::make_unique(dropInfo, std::move(messageTable), getOperatorID(), + std::move(printInfo)); +} + +std::unique_ptr PlanMapper::mapAlter(const LogicalOperator* logicalOperator) { + auto& alter = logicalOperator->constCast(); + std::unique_ptr defaultValueEvaluator; + auto exprMapper = ExpressionMapper(alter.getSchema()); + if (alter.getInfo()->alterType == AlterType::ADD_PROPERTY) { + auto& addPropInfo = alter.getInfo()->extraInfo->constCast(); + defaultValueEvaluator = exprMapper.getEvaluator(addPropInfo.boundDefault); + } + auto printInfo = std::make_unique(alter.getInfo()->copy()); + auto messageTable = FactorizedTableUtils::getSingleStringColumnFTable( + storage::MemoryManager::Get(*clientContext)); + return std::make_unique(alter.getInfo()->copy(), std::move(defaultValueEvaluator), + std::move(messageTable), getOperatorID(), std::move(printInfo)); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_delete.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_delete.cpp new file mode 100644 index 0000000000..776e3d0bc7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_delete.cpp @@ -0,0 +1,174 @@ +#include "binder/expression/node_expression.h" +#include "binder/expression/rel_expression.h" +#include "catalog/catalog_entry/node_table_catalog_entry.h" +#include "planner/operator/persistent/logical_delete.h" +#include "processor/operator/persistent/delete.h" +#include "processor/plan_mapper.h" +#include "storage/storage_manager.h" + +using namespace lbug::binder; +using namespace lbug::catalog; +using namespace lbug::common; +using namespace lbug::planner; +using namespace lbug::storage; + +namespace lbug { +namespace processor { + +std::vector getFwdRelTables(table_id_t nodeTableID, const main::ClientContext* context) { + std::vector result; + auto transaction = transaction::Transaction::Get(*context); + for (const auto entry : Catalog::Get(*context)->getRelGroupEntries(transaction, false)) { + auto& relGroupEntry = entry->constCast(); + for (auto& relEntryInfo : relGroupEntry.getRelEntryInfos()) { + const auto srcTableID = relEntryInfo.nodePair.srcTableID; + if (srcTableID == nodeTableID) { + const auto relTable = StorageManager::Get(*context)->getTable(relEntryInfo.oid); + result.push_back(relTable->ptrCast()); + } + } + } + return result; +} + +std::vector getBwdRelTables(table_id_t nodeTableID, const main::ClientContext* context) { + std::vector result; + auto transaction = transaction::Transaction::Get(*context); + for (const auto entry : Catalog::Get(*context)->getRelGroupEntries(transaction, false)) { + auto& relGroupEntry = entry->constCast(); + for (auto& relEntryInfo : relGroupEntry.getRelEntryInfos()) { + const auto dstTableID = relEntryInfo.nodePair.dstTableID; + if (dstTableID == nodeTableID) { + const auto relTable = StorageManager::Get(*context)->getTable(relEntryInfo.oid); + result.push_back(relTable->ptrCast()); + } + } + } + return result; +} + +NodeTableDeleteInfo PlanMapper::getNodeTableDeleteInfo(const TableCatalogEntry& entry, + DataPos pkPos) const { + auto storageManager = StorageManager::Get(*clientContext); + auto tableID = entry.getTableID(); + auto table = storageManager->getTable(tableID)->ptrCast(); + std::unordered_set fwdRelTables; + std::unordered_set bwdRelTables; + auto& 
nodeEntry = entry.constCast(); + for (auto relTable : getFwdRelTables(nodeEntry.getTableID(), clientContext)) { + fwdRelTables.insert(relTable); + } + for (auto relTable : getBwdRelTables(nodeEntry.getTableID(), clientContext)) { + bwdRelTables.insert(relTable); + } + return NodeTableDeleteInfo(table, std::move(fwdRelTables), std::move(bwdRelTables), pkPos); +} + +std::unique_ptr PlanMapper::getNodeDeleteExecutor( + const BoundDeleteInfo& boundInfo, const Schema& schema) const { + KU_ASSERT(boundInfo.tableType == TableType::NODE); + auto& node = boundInfo.pattern->constCast(); + auto nodeIDPos = getDataPos(*node.getInternalID(), schema); + auto info = NodeDeleteInfo(boundInfo.deleteType, nodeIDPos); + if (node.isEmpty()) { + return std::make_unique(std::move(info)); + } + if (node.isMultiLabeled()) { + table_id_map_t tableInfos; + for (auto entry : node.getEntries()) { + auto tableID = entry->getTableID(); + auto pkPos = getDataPos(*node.getPrimaryKey(tableID), schema); + tableInfos.insert({tableID, getNodeTableDeleteInfo(*entry, pkPos)}); + } + return std::make_unique(std::move(info), + std::move(tableInfos)); + } + KU_ASSERT(node.getNumEntries() == 1); + auto entry = node.getEntry(0); + auto pkPos = getDataPos(*node.getPrimaryKey(entry->getTableID()), schema); + auto extraInfo = getNodeTableDeleteInfo(*entry, pkPos); + return std::make_unique(std::move(info), std::move(extraInfo)); +} + +std::unique_ptr PlanMapper::mapDelete(const LogicalOperator* logicalOperator) { + auto delete_ = logicalOperator->constPtrCast(); + switch (delete_->getTableType()) { + case TableType::NODE: { + return mapDeleteNode(logicalOperator); + } + case TableType::REL: { + return mapDeleteRel(logicalOperator); + } + default: + KU_UNREACHABLE; + } +} + +std::unique_ptr PlanMapper::mapDeleteNode( + const LogicalOperator* logicalOperator) { + auto delete_ = logicalOperator->constPtrCast(); + auto inSchema = delete_->getChild(0)->getSchema(); + auto prevOperator = mapOperator(logicalOperator->getChild(0).get()); + std::vector> executors; + for (auto& info : delete_->getInfos()) { + executors.push_back(getNodeDeleteExecutor(info, *inSchema)); + } + expression_vector patterns; + for (auto& info : delete_->getInfos()) { + patterns.push_back(info.pattern); + } + auto printInfo = + std::make_unique(patterns, delete_->getInfos()[0].deleteType); + return std::make_unique(std::move(executors), std::move(prevOperator), + getOperatorID(), std::move(printInfo)); +} + +std::unique_ptr PlanMapper::getRelDeleteExecutor( + const BoundDeleteInfo& boundInfo, const Schema& schema) const { + auto& rel = boundInfo.pattern->constCast(); + if (rel.isEmpty()) { + return std::make_unique(); + } + auto relIDPos = getDataPos(*rel.getInternalID(), schema); + auto srcNodeIDPos = getDataPos(*rel.getSrcNode()->getInternalID(), schema); + auto dstNodeIDPos = getDataPos(*rel.getDstNode()->getInternalID(), schema); + auto info = RelDeleteInfo(srcNodeIDPos, dstNodeIDPos, relIDPos); + auto storageManager = StorageManager::Get(*clientContext); + if (rel.isMultiLabeled()) { + table_id_map_t tableIDToTableMap; + for (auto entry : rel.getEntries()) { + auto& relGroupEntry = entry->constCast(); + for (auto& relEntryInfo : relGroupEntry.getRelEntryInfos()) { + auto table = storageManager->getTable(relEntryInfo.oid); + tableIDToTableMap.insert({table->getTableID(), table->ptrCast()}); + } + } + return std::make_unique(std::move(tableIDToTableMap), + std::move(info)); + } + KU_ASSERT(rel.getNumEntries() == 1); + auto& entry = rel.getEntry(0)->constCast(); + 
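+    // Single-labeled rel pattern: resolve the one backing rel table and build a single-table delete executor.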
auto relEntryInfo = entry.getSingleRelEntryInfo(); + auto table = storageManager->getTable(relEntryInfo.oid)->ptrCast(); + return std::make_unique(table, std::move(info)); +} + +std::unique_ptr PlanMapper::mapDeleteRel(const LogicalOperator* logicalOperator) { + auto delete_ = logicalOperator->constPtrCast(); + auto inSchema = delete_->getChild(0)->getSchema(); + auto prevOperator = mapOperator(logicalOperator->getChild(0).get()); + std::vector> executors; + for (auto& info : delete_->getInfos()) { + executors.push_back(getRelDeleteExecutor(info, *inSchema)); + } + expression_vector patterns; + for (auto& info : delete_->getInfos()) { + patterns.push_back(info.pattern); + } + auto printInfo = std::make_unique(patterns); + return std::make_unique(std::move(executors), std::move(prevOperator), + getOperatorID(), std::move(printInfo)); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_distinct.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_distinct.cpp new file mode 100644 index 0000000000..77a5c27974 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_distinct.cpp @@ -0,0 +1,37 @@ +#include "planner/operator/logical_distinct.h" +#include "processor/operator/aggregate/hash_aggregate.h" +#include "processor/plan_mapper.h" + +using namespace lbug::common; +using namespace lbug::planner; + +namespace lbug { +namespace processor { + +std::unique_ptr PlanMapper::mapDistinct(const LogicalOperator* logicalOperator) { + auto distinct = logicalOperator->constPtrCast(); + auto child = distinct->getChild(0).get(); + auto outSchema = distinct->getSchema(); + auto inSchema = child->getSchema(); + auto prevOperator = mapOperator(child); + uint64_t limitNum = 0; + if (distinct->hasLimitNum()) { + limitNum += distinct->getLimitNum(); + } + if (distinct->hasSkipNum()) { + limitNum += distinct->getSkipNum(); + } + if (limitNum == 0) { + limitNum = UINT64_MAX; + } + auto op = createDistinctHashAggregate(distinct->getKeys(), distinct->getPayloads(), inSchema, + outSchema, std::move(prevOperator)); + auto hashAggregate = op->getChild(0)->getChild(0)->ptrCast(); + hashAggregate->getSharedState()->setLimitNumber(limitNum); + auto printInfo = static_cast(hashAggregate->getPrintInfo()); + const_cast(printInfo)->limitNum = limitNum; + return op; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_dummy_scan.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_dummy_scan.cpp new file mode 100644 index 0000000000..0540466211 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_dummy_scan.cpp @@ -0,0 +1,36 @@ +#include "planner/operator/scan/logical_dummy_scan.h" +#include "processor/expression_mapper.h" +#include "processor/plan_mapper.h" +#include "storage/buffer_manager/memory_manager.h" + +using namespace lbug::common; +using namespace lbug::planner; + +namespace lbug { +namespace processor { + +std::unique_ptr PlanMapper::mapDummyScan(const LogicalOperator*) { + auto inSchema = std::make_unique(); + auto expression = LogicalDummyScan::getDummyExpression(); + auto tableSchema = FactorizedTableSchema(); + // TODO(Ziyi): remove vectors when we have done the refactor of dataChunk. 
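+    // The dummy scan evaluates one constant expression statically, appends the single value to a
+    // one-column factorized table and exposes it through an FTable scan.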
+ std::vector> vectors; + std::vector vectorsToAppend; + auto columnSchema = ColumnSchema(false, 0 /* groupID */, + LogicalTypeUtils::getRowLayoutSize(expression->dataType)); + tableSchema.appendColumn(std::move(columnSchema)); + auto exprMapper = ExpressionMapper(inSchema.get()); + auto expressionEvaluator = exprMapper.getEvaluator(expression); + auto memoryManager = storage::MemoryManager::Get(*clientContext); + // expression can be evaluated statically and does not require an actual resultset to init + expressionEvaluator->init(ResultSet(0) /* dummy resultset */, clientContext); + expressionEvaluator->evaluate(); + vectors.push_back(expressionEvaluator->resultVector); + vectorsToAppend.push_back(expressionEvaluator->resultVector.get()); + auto table = std::make_shared(memoryManager, std::move(tableSchema)); + table->append(vectorsToAppend); + return createEmptyFTableScan(table, 1 /* maxMorselSize */); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_dummy_sink.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_dummy_sink.cpp new file mode 100644 index 0000000000..92bdaf08b0 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_dummy_sink.cpp @@ -0,0 +1,17 @@ +#include "processor/plan_mapper.h" + +using namespace lbug::planner; + +namespace lbug { +namespace processor { + +std::unique_ptr PlanMapper::mapDummySink(const LogicalOperator* logicalOperator) { + auto child = mapOperator(logicalOperator->getChild(0).get()); + auto descriptor = std::make_unique(logicalOperator->getSchema()); + auto sink = std::make_unique(std::move(child), getOperatorID()); + sink->setDescriptor(std::move(descriptor)); + return sink; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_empty_result.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_empty_result.cpp new file mode 100644 index 0000000000..6cbfd10fc1 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_empty_result.cpp @@ -0,0 +1,15 @@ +#include "processor/operator/empty_result.h" +#include "processor/plan_mapper.h" + +using namespace lbug::planner; + +namespace lbug { +namespace processor { + +std::unique_ptr PlanMapper::mapEmptyResult(const LogicalOperator*) { + auto printInfo = std::make_unique(); + return std::make_unique(getOperatorID(), std::move(printInfo)); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_explain.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_explain.cpp new file mode 100644 index 0000000000..13f8ab298e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_explain.cpp @@ -0,0 +1,48 @@ +#include "common/profiler.h" +#include "main/plan_printer.h" +#include "planner/operator/logical_explain.h" +#include "planner/operator/logical_plan.h" +#include "processor/operator/profile.h" +#include "processor/plan_mapper.h" +#include "processor/result/factorized_table_util.h" +#include "storage/buffer_manager/memory_manager.h" + +using namespace lbug::common; +using namespace lbug::planner; +using namespace lbug::binder; + +namespace lbug { +namespace processor { + +std::unique_ptr PlanMapper::mapExplain(const LogicalOperator* logicalOperator) { + auto& logicalExplain = logicalOperator->constCast(); + auto root = mapOperator(logicalExplain.getChild(0).get()); + if (!root->isSink()) { + auto inSchema = logicalExplain.getChild(0)->getSchema(); + 
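+        // Wrap non-sink roots in a result collector so the plan being explained or profiled always ends in a sink.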
root = createResultCollector(AccumulateType::REGULAR, + logicalExplain.getInnerResultColumns(), inSchema, std::move(root)); + } + auto memoryManager = storage::MemoryManager::Get(*clientContext); + auto messageTable = FactorizedTableUtils::getSingleStringColumnFTable(memoryManager); + if (logicalExplain.getExplainType() == ExplainType::PROFILE) { + auto profile = std::make_unique(ProfileInfo{}, std::move(messageTable), + getOperatorID(), OPPrintInfo::EmptyInfo()); + profile->addChild(std::move(root)); + return profile; + } + if (logicalExplain.getExplainType() == ExplainType::PHYSICAL_PLAN) { + auto plan = std::make_unique(std::move(root)); + auto profiler = std::make_unique(); + auto explainStr = main::PlanPrinter::printPlanToOstream(plan.get(), profiler.get()).str(); + FactorizedTableUtils::appendStringToTable(messageTable.get(), explainStr, memoryManager); + return std::make_unique(std::move(messageTable), getOperatorID()); + } + auto plan = LogicalPlan(); + plan.setLastOperator(logicalExplain.getChild(0)); + auto explainStr = main::PlanPrinter::printPlanToOstream(&plan).str(); + FactorizedTableUtils::appendStringToTable(messageTable.get(), explainStr, memoryManager); + return std::make_unique(std::move(messageTable), getOperatorID()); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_expressions_scan.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_expressions_scan.cpp new file mode 100644 index 0000000000..304e34e35a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_expressions_scan.cpp @@ -0,0 +1,41 @@ +#include "common/system_config.h" +#include "planner/operator/logical_accumulate.h" +#include "planner/operator/scan/logical_expressions_scan.h" +#include "processor/operator/result_collector.h" +#include "processor/plan_mapper.h" + +using namespace lbug::common; +using namespace lbug::binder; +using namespace lbug::planner; + +namespace lbug { +namespace processor { + +std::unique_ptr PlanMapper::mapExpressionsScan( + const LogicalOperator* logicalOperator) { + auto& expressionsScan = logicalOperator->constCast(); + auto outerAccumulate = expressionsScan.getOuterAccumulate()->ptrCast(); + expression_map materializedExpressionToColIdx; + auto materializedExpressions = outerAccumulate->getPayloads(); + for (auto i = 0u; i < materializedExpressions.size(); ++i) { + materializedExpressionToColIdx.insert({materializedExpressions[i], i}); + } + auto expressionsToScan = expressionsScan.getExpressions(); + std::vector colIndicesToScan; + for (auto& expression : expressionsToScan) { + KU_ASSERT(materializedExpressionToColIdx.contains(expression)); + colIndicesToScan.push_back(materializedExpressionToColIdx.at(expression)); + } + auto schema = expressionsScan.getSchema(); + KU_ASSERT(logicalOpToPhysicalOpMap.contains(outerAccumulate)); + auto physicalOp = logicalOpToPhysicalOpMap.at(outerAccumulate); + KU_ASSERT(physicalOp->getOperatorType() == PhysicalOperatorType::TABLE_FUNCTION_CALL); + KU_ASSERT(physicalOp->getChild(0)->getOperatorType() == PhysicalOperatorType::RESULT_COLLECTOR); + auto resultCollector = physicalOp->getChild(0)->ptrCast(); + auto table = resultCollector->getResultFTable(); + return createFTableScan(expressionsToScan, colIndicesToScan, schema, table, + DEFAULT_VECTOR_CAPACITY /* maxMorselSize */); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_extend.cpp 
b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_extend.cpp new file mode 100644 index 0000000000..be0f8f74a9 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_extend.cpp @@ -0,0 +1,187 @@ +#include "binder/binder.h" +#include "binder/expression/property_expression.h" +#include "binder/expression_binder.h" +#include "common/enums/extend_direction_util.h" +#include "main/client_context.h" +#include "planner/operator/extend/logical_extend.h" +#include "processor/operator/scan/scan_multi_rel_tables.h" +#include "processor/operator/scan/scan_rel_table.h" +#include "processor/plan_mapper.h" +#include "storage/storage_manager.h" + +using namespace lbug::binder; +using namespace lbug::common; +using namespace lbug::planner; +using namespace lbug::storage; +using namespace lbug::catalog; + +namespace lbug { +namespace processor { + +static ScanRelTableInfo getRelTableScanInfo(const TableCatalogEntry& tableEntry, + RelDataDirection direction, RelTable* relTable, bool shouldScanNbrID, + const expression_vector& properties, const std::vector& columnPredicates, + main::ClientContext* clientContext) { + std::vector columnPredicateSets = copyVector(columnPredicates); + if (!columnPredicateSets.empty()) { + // Since we insert a nbr column. We need to pad an empty nbr column predicate set. + columnPredicateSets.insert(columnPredicateSets.begin(), ColumnPredicateSet()); + } + auto tableInfo = ScanRelTableInfo(relTable, std::move(columnPredicateSets), direction); + // We always should scan nbrID from relTable. This is not a property in the schema label, so + // cannot be bound to a column in the front-end. + auto nbrColumnID = shouldScanNbrID ? NBR_ID_COLUMN_ID : INVALID_COLUMN_ID; + tableInfo.addColumnInfo(nbrColumnID, ColumnCaster(LogicalType::INTERNAL_ID())); + auto binder = Binder(clientContext); + auto expressionBinder = ExpressionBinder(&binder, clientContext); + for (auto& expr : properties) { + auto& property = expr->constCast(); + if (property.hasProperty(tableEntry.getTableID())) { + auto propertyName = property.getPropertyName(); + auto& columnType = tableEntry.getProperty(propertyName).getType(); + auto columnCaster = ColumnCaster(columnType.copy()); + if (property.getDataType() != columnType) { + auto columnExpr = std::make_shared(property); + columnExpr->dataType = columnType.copy(); + columnCaster.setCastExpr( + expressionBinder.forceCast(columnExpr, property.getDataType())); + } + tableInfo.addColumnInfo(tableEntry.getColumnID(propertyName), std::move(columnCaster)); + } else { + tableInfo.addColumnInfo(INVALID_COLUMN_ID, ColumnCaster(LogicalType::ANY())); + } + } + return tableInfo; +} + +static bool isRelTableQualifies(ExtendDirection direction, table_id_t srcTableID, + table_id_t dstTableID, table_id_t boundNodeTableID, const table_id_set_t& nbrTableISet) { + switch (direction) { + case ExtendDirection::FWD: { + return srcTableID == boundNodeTableID && nbrTableISet.contains(dstTableID); + } + case ExtendDirection::BWD: { + return dstTableID == boundNodeTableID && nbrTableISet.contains(srcTableID); + } + default: + KU_UNREACHABLE; + } +} + +static std::vector populateRelTableCollectionScanner(table_id_t boundNodeTableID, + const table_id_set_t& nbrTableISet, const RelGroupCatalogEntry& entry, + ExtendDirection extendDirection, bool shouldScanNbrID, const expression_vector& properties, + const std::vector& columnPredicates, main::ClientContext* clientContext) { + std::vector scanInfos; + const auto storageManager = StorageManager::Get(*clientContext); + for 
(auto& info : entry.getRelEntryInfos()) { + auto srcTableID = info.nodePair.srcTableID; + auto dstTableID = info.nodePair.dstTableID; + auto relTable = storageManager->getTable(info.oid)->ptrCast(); + switch (extendDirection) { + case ExtendDirection::FWD: { + if (isRelTableQualifies(ExtendDirection::FWD, srcTableID, dstTableID, boundNodeTableID, + nbrTableISet)) { + scanInfos.push_back(getRelTableScanInfo(entry, RelDataDirection::FWD, relTable, + shouldScanNbrID, properties, columnPredicates, clientContext)); + } + } break; + case ExtendDirection::BWD: { + if (isRelTableQualifies(ExtendDirection::BWD, srcTableID, dstTableID, boundNodeTableID, + nbrTableISet)) { + scanInfos.push_back(getRelTableScanInfo(entry, RelDataDirection::BWD, relTable, + shouldScanNbrID, properties, columnPredicates, clientContext)); + } + } break; + case ExtendDirection::BOTH: { + if (isRelTableQualifies(ExtendDirection::FWD, srcTableID, dstTableID, boundNodeTableID, + nbrTableISet)) { + scanInfos.push_back(getRelTableScanInfo(entry, RelDataDirection::FWD, relTable, + shouldScanNbrID, properties, columnPredicates, clientContext)); + } + if (isRelTableQualifies(ExtendDirection::BWD, srcTableID, dstTableID, boundNodeTableID, + nbrTableISet)) { + scanInfos.push_back(getRelTableScanInfo(entry, RelDataDirection::BWD, relTable, + shouldScanNbrID, properties, columnPredicates, clientContext)); + } + } break; + default: + KU_UNREACHABLE; + } + } + return scanInfos; +} + +static bool scanSingleRelTable(const RelExpression& rel, const NodeExpression& boundNode, + ExtendDirection extendDirection) { + return !rel.isMultiLabeled() && !boundNode.isMultiLabeled() && + extendDirection != ExtendDirection::BOTH; +} + +std::unique_ptr PlanMapper::mapExtend(const LogicalOperator* logicalOperator) { + auto extend = logicalOperator->constPtrCast(); + auto outFSchema = extend->getSchema(); + auto inFSchema = extend->getChild(0)->getSchema(); + auto boundNode = extend->getBoundNode(); + auto nbrNode = extend->getNbrNode(); + auto rel = extend->getRel(); + auto extendDirection = extend->getDirection(); + auto prevOperator = mapOperator(logicalOperator->getChild(0).get()); + auto inNodeIDPos = getDataPos(*boundNode->getInternalID(), *inFSchema); + std::vector outVectorsPos; + auto outNodeIDPos = getDataPos(*nbrNode->getInternalID(), *outFSchema); + outVectorsPos.push_back(outNodeIDPos); + for (auto& expression : extend->getProperties()) { + outVectorsPos.push_back(getDataPos(*expression, *outFSchema)); + } + auto scanInfo = ScanOpInfo(inNodeIDPos, outVectorsPos); + std::vector tableNames; + auto storageManager = StorageManager::Get(*clientContext); + for (auto entry : rel->getEntries()) { + tableNames.push_back(entry->getName()); + } + auto printInfo = std::make_unique(tableNames, extend->getProperties(), + boundNode, rel, nbrNode, extendDirection, rel->getVariableName()); + if (scanSingleRelTable(*rel, *boundNode, extendDirection)) { + KU_ASSERT(rel->getNumEntries() == 1); + auto entry = rel->getEntry(0)->ptrCast(); + auto relDataDirection = ExtendDirectionUtil::getRelDataDirection(extendDirection); + auto entryInfo = entry->getSingleRelEntryInfo(); + auto relTable = storageManager->getTable(entryInfo.oid)->ptrCast(); + auto scanRelInfo = + getRelTableScanInfo(*entry, relDataDirection, relTable, extend->shouldScanNbrID(), + extend->getProperties(), extend->getPropertyPredicates(), clientContext); + return std::make_unique(std::move(scanInfo), std::move(scanRelInfo), + std::move(prevOperator), getOperatorID(), printInfo->copy()); + } + // 
map to generic extend + auto directionInfo = DirectionInfo(); + directionInfo.extendFromSource = extend->extendFromSourceNode(); + if (rel->hasDirectionExpr()) { + directionInfo.directionPos = getDataPos(*rel->getDirectionExpr(), *outFSchema); + } + table_id_map_t scanners; + for (auto boundNodeTableID : boundNode->getTableIDs()) { + for (auto entry : rel->getEntries()) { + auto& relGroupEntry = entry->constCast(); + auto scanInfos = + populateRelTableCollectionScanner(boundNodeTableID, nbrNode->getTableIDsSet(), + relGroupEntry, extendDirection, extend->shouldScanNbrID(), + extend->getProperties(), extend->getPropertyPredicates(), clientContext); + if (scanInfos.empty()) { + continue; + } + if (scanners.contains(boundNodeTableID)) { + scanners.at(boundNodeTableID).addRelInfos(std::move(scanInfos)); + } else { + scanners.insert( + {boundNodeTableID, RelTableCollectionScanner(std::move(scanInfos))}); + } + } + } + return std::make_unique(std::move(scanInfo), std::move(directionInfo), + std::move(scanners), std::move(prevOperator), getOperatorID(), printInfo->copy()); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_filter.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_filter.cpp new file mode 100644 index 0000000000..e37fdbed14 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_filter.cpp @@ -0,0 +1,23 @@ +#include "planner/operator/logical_filter.h" +#include "processor/expression_mapper.h" +#include "processor/operator/filter.h" +#include "processor/plan_mapper.h" + +using namespace lbug::planner; + +namespace lbug { +namespace processor { + +std::unique_ptr PlanMapper::mapFilter(const LogicalOperator* logicalOperator) { + auto& logicalFilter = logicalOperator->constCast(); + auto inSchema = logicalFilter.getChild(0)->getSchema(); + auto prevOperator = mapOperator(logicalOperator->getChild(0).get()); + auto exprMapper = ExpressionMapper(inSchema); + auto physicalRootExpr = exprMapper.getEvaluator(logicalFilter.getPredicate()); + auto printInfo = std::make_unique(logicalFilter.getPredicate()); + return make_unique(std::move(physicalRootExpr), logicalFilter.getGroupPosToSelect(), + std::move(prevOperator), getOperatorID(), std::move(printInfo)); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_flatten.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_flatten.cpp new file mode 100644 index 0000000000..dba5cf83d0 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_flatten.cpp @@ -0,0 +1,20 @@ +#include "planner/operator/logical_flatten.h" +#include "processor/operator/flatten.h" +#include "processor/plan_mapper.h" + +using namespace lbug::planner; + +namespace lbug { +namespace processor { + +std::unique_ptr PlanMapper::mapFlatten(const LogicalOperator* logicalOperator) { + auto& flatten = logicalOperator->constCast(); + auto prevOperator = mapOperator(logicalOperator->getChild(0).get()); + // todo (Xiyang): add print info for flatten + auto printInfo = std::make_unique(); + return make_unique(flatten.getGroupPos(), std::move(prevOperator), getOperatorID(), + std::move(printInfo)); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_hash_join.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_hash_join.cpp new file mode 100644 index 0000000000..69148c9390 --- /dev/null +++ 
b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_hash_join.cpp @@ -0,0 +1,125 @@ +#include "binder/expression/expression_util.h" +#include "planner/operator/logical_hash_join.h" +#include "processor/operator/hash_join/hash_join_build.h" +#include "processor/operator/hash_join/hash_join_probe.h" +#include "processor/plan_mapper.h" +#include "storage/buffer_manager/memory_manager.h" + +using namespace lbug::binder; +using namespace lbug::planner; +using namespace lbug::common; + +namespace lbug { +namespace processor { + +HashJoinBuildInfo PlanMapper::createHashBuildInfo(const Schema& buildSideSchema, + const expression_vector& keys, const expression_vector& payloads) { + f_group_pos_set keyGroupPosSet; + std::vector keysPos; + std::vector fStateTypes; + std::vector payloadsPos; + auto tableSchema = FactorizedTableSchema(); + for (auto& key : keys) { + auto pos = DataPos(buildSideSchema.getExpressionPos(*key)); + keyGroupPosSet.insert(pos.dataChunkPos); + // Keys are always stored in flat column. + auto columnSchema = ColumnSchema(false /* isUnFlat */, pos.dataChunkPos, + LogicalTypeUtils::getRowLayoutSize(key->dataType)); + tableSchema.appendColumn(std::move(columnSchema)); + keysPos.push_back(pos); + fStateTypes.push_back(buildSideSchema.getGroup(pos.dataChunkPos)->isFlat() ? + FStateType::FLAT : + FStateType::UNFLAT); + } + for (auto& payload : payloads) { + auto pos = DataPos(buildSideSchema.getExpressionPos(*payload)); + if (keyGroupPosSet.contains(pos.dataChunkPos) || + buildSideSchema.getGroup(pos.dataChunkPos)->isFlat()) { + // Payloads need to be stored in flat column in 2 cases + // 1. payload is in the same chunk as a key. Since keys are always stored as flat, + // payloads must also be stored as flat. + // 2. payload is in flat chunk + auto columnSchema = ColumnSchema(false /* isUnFlat */, pos.dataChunkPos, + LogicalTypeUtils::getRowLayoutSize(payload->dataType)); + tableSchema.appendColumn(std::move(columnSchema)); + } else { + auto columnSchema = + ColumnSchema(true /* isUnFlat */, pos.dataChunkPos, sizeof(overflow_value_t)); + tableSchema.appendColumn(std::move(columnSchema)); + } + payloadsPos.push_back(pos); + } + auto hashValueColumn = ColumnSchema(false /* isUnFlat */, INVALID_DATA_CHUNK_POS, + LogicalTypeUtils::getRowLayoutSize(LogicalType::HASH())); + tableSchema.appendColumn(std::move(hashValueColumn)); + auto pointerColumn = ColumnSchema(false /* isUnFlat */, INVALID_DATA_CHUNK_POS, + LogicalTypeUtils::getRowLayoutSize(LogicalType::INT64())); + tableSchema.appendColumn(std::move(pointerColumn)); + return HashJoinBuildInfo(std::move(keysPos), std::move(fStateTypes), std::move(payloadsPos), + std::move(tableSchema)); +} + +std::unique_ptr PlanMapper::mapHashJoin(const LogicalOperator* logicalOperator) { + auto hashJoin = logicalOperator->constPtrCast(); + auto outSchema = hashJoin->getSchema(); + auto buildSchema = hashJoin->getChild(1)->getSchema(); + std::unique_ptr probeSidePrevOperator; + std::unique_ptr buildSidePrevOperator; + // Map the side into which semi mask is passed first. 
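+    // The SIP dependency decides the mapping order: when the probe side depends on the build side,
+    // the build subtree is mapped first.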
+ if (hashJoin->getSIPInfo().dependency == SIPDependency::PROBE_DEPENDS_ON_BUILD) { + buildSidePrevOperator = mapOperator(hashJoin->getChild(1).get()); + probeSidePrevOperator = mapOperator(hashJoin->getChild(0).get()); + } else { + probeSidePrevOperator = mapOperator(hashJoin->getChild(0).get()); + buildSidePrevOperator = mapOperator(hashJoin->getChild(1).get()); + } + expression_vector probeKeys; + expression_vector buildKeys; + for (auto& [probeKey, buildKey] : hashJoin->getJoinConditions()) { + probeKeys.push_back(probeKey); + buildKeys.push_back(buildKey); + } + auto buildKeyTypes = ExpressionUtil::getDataTypes(buildKeys); + auto payloads = + ExpressionUtil::excludeExpressions(hashJoin->getExpressionsToMaterialize(), probeKeys); + // Create build + auto buildInfo = createHashBuildInfo(*buildSchema, buildKeys, payloads); + auto globalHashTable = + std::make_unique(*storage::MemoryManager::Get(*clientContext), + LogicalType::copy(buildKeyTypes), buildInfo.tableSchema.copy()); + auto sharedState = std::make_shared(std::move(globalHashTable)); + auto buildPrintInfo = std::make_unique(buildKeys, payloads); + auto hashJoinBuild = std::make_unique(PhysicalOperatorType::HASH_JOIN_BUILD, + sharedState, std::move(buildInfo), std::move(buildSidePrevOperator), getOperatorID(), + buildPrintInfo->copy()); + hashJoinBuild->setDescriptor(std::make_unique(buildSchema)); + // Create probe + std::vector probeKeysDataPos; + for (auto& probeKey : probeKeys) { + probeKeysDataPos.emplace_back(outSchema->getExpressionPos(*probeKey)); + } + std::vector probePayloadsOutPos; + for (auto& payload : payloads) { + probePayloadsOutPos.emplace_back(outSchema->getExpressionPos(*payload)); + } + ProbeDataInfo probeDataInfo(probeKeysDataPos, probePayloadsOutPos); + if (hashJoin->hasMark()) { + auto mark = hashJoin->getMark(); + auto markOutputPos = DataPos(outSchema->getExpressionPos(*mark)); + probeDataInfo.markDataPos = markOutputPos; + } else { + probeDataInfo.markDataPos = DataPos::getInvalidPos(); + } + auto probePrintInfo = std::make_unique(probeKeys); + auto hashJoinProbe = make_unique(sharedState, hashJoin->getJoinType(), + hashJoin->requireFlatProbeKeys(), probeDataInfo, std::move(probeSidePrevOperator), + getOperatorID(), probePrintInfo->copy()); + hashJoinProbe->addChild(std::move(hashJoinBuild)); + if (hashJoin->getSIPInfo().direction == SIPDirection::PROBE_TO_BUILD) { + mapSIPJoin(hashJoinProbe.get()); + } + return hashJoinProbe; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_index_scan_node.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_index_scan_node.cpp new file mode 100644 index 0000000000..b37d8bec74 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_index_scan_node.cpp @@ -0,0 +1,41 @@ +#include "main/client_context.h" +#include "planner/operator/scan/logical_index_look_up.h" +#include "processor/expression_mapper.h" +#include "processor/operator/index_lookup.h" +#include "processor/plan_mapper.h" +#include "storage/storage_manager.h" +#include "storage/table/node_table.h" + +using namespace lbug::planner; + +namespace lbug { +namespace processor { + +std::unique_ptr PlanMapper::mapIndexLookup( + const LogicalOperator* logicalOperator) { + auto& logicalIndexScan = logicalOperator->constCast(); + auto outSchema = logicalIndexScan.getSchema(); + auto child = logicalOperator->getChild(0).get(); + auto prevOperator = mapOperator(child); + auto storageManager = 
storage::StorageManager::Get(*clientContext); + auto exprMapper = ExpressionMapper(child->getSchema()); + std::vector indexLookupInfos; + for (auto i = 0u; i < logicalIndexScan.getNumInfos(); ++i) { + auto& info = logicalIndexScan.getInfo(i); + auto nodeTable = storageManager->getTable(info.nodeTableID)->ptrCast(); + auto offsetPos = DataPos(outSchema->getExpressionPos(*info.offset)); + auto keyEvaluator = exprMapper.getEvaluator(info.key); + indexLookupInfos.emplace_back(nodeTable, std::move(keyEvaluator), offsetPos); + } + auto warningDataPos = getDataPos(logicalIndexScan.getInfo(0).warningExprs, *outSchema); + binder::expression_vector expressions; + for (auto i = 0u; i < logicalIndexScan.getNumInfos(); ++i) { + expressions.push_back(logicalIndexScan.getInfo(i).offset); + } + auto printInfo = std::make_unique(expressions); + return std::make_unique(std::move(indexLookupInfos), std::move(warningDataPos), + std::move(prevOperator), getOperatorID(), std::move(printInfo)); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_insert.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_insert.cpp new file mode 100644 index 0000000000..4cc5350685 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_insert.cpp @@ -0,0 +1,109 @@ +#include "binder/expression/rel_expression.h" +#include "main/client_context.h" +#include "planner/operator/persistent/logical_insert.h" +#include "processor/expression_mapper.h" +#include "processor/operator/persistent/insert.h" +#include "processor/plan_mapper.h" +#include "storage/storage_manager.h" + +using namespace lbug::evaluator; +using namespace lbug::planner; +using namespace lbug::storage; +using namespace lbug::catalog; +using namespace lbug::common; +using namespace lbug::binder; + +namespace lbug { +namespace processor { + +static std::vector populateReturnColumnsPos(const LogicalInsertInfo& info, + const Schema& schema) { + std::vector result; + for (auto i = 0u; i < info.columnDataExprs.size(); ++i) { + if (info.isReturnColumnExprs[i]) { + result.emplace_back(schema.getExpressionPos(*info.columnExprs[i])); + } else { + result.push_back(DataPos::getInvalidPos()); + } + } + return result; +} + +NodeInsertExecutor PlanMapper::getNodeInsertExecutor(const LogicalInsertInfo* boundInfo, + const Schema& inSchema, const Schema& outSchema) const { + auto& node = boundInfo->pattern->constCast(); + auto nodeIDPos = getDataPos(*node.getInternalID(), outSchema); + auto columnsPos = populateReturnColumnsPos(*boundInfo, outSchema); + auto info = NodeInsertInfo(nodeIDPos, columnsPos, boundInfo->conflictAction); + auto storageManager = StorageManager::Get(*clientContext); + KU_ASSERT(node.getNumEntries() == 1); + ; + auto table = storageManager->getTable(node.getEntry(0)->getTableID())->ptrCast(); + evaluator_vector_t evaluators; + auto exprMapper = ExpressionMapper(&inSchema); + for (auto& expr : boundInfo->columnDataExprs) { + evaluators.push_back(exprMapper.getEvaluator(expr)); + } + auto tableInfo = NodeTableInsertInfo(table, std::move(evaluators)); + return NodeInsertExecutor(std::move(info), std::move(tableInfo)); +} + +RelInsertExecutor PlanMapper::getRelInsertExecutor(const LogicalInsertInfo* boundInfo, + const Schema& inSchema, const Schema& outSchema) const { + auto& rel = boundInfo->pattern->constCast(); + auto srcNode = rel.getSrcNode(); + auto dstNode = rel.getDstNode(); + auto srcNodeIDPos = getDataPos(*srcNode->getInternalID(), inSchema); + auto dstNodeIDPos = 
getDataPos(*dstNode->getInternalID(), inSchema); + auto columnsPos = populateReturnColumnsPos(*boundInfo, outSchema); + auto info = RelInsertInfo(srcNodeIDPos, dstNodeIDPos, std::move(columnsPos)); + auto storageManager = StorageManager::Get(*clientContext); + KU_ASSERT(srcNode->getNumEntries() == 1 && dstNode->getNumEntries() == 1); + auto srcTableID = srcNode->getEntry(0)->getTableID(); + auto dstTableID = dstNode->getEntry(0)->getTableID(); + KU_ASSERT(rel.getNumEntries() == 1); + auto& relGroupEntry = rel.getEntry(0)->constCast(); + auto relEntryInfo = relGroupEntry.getRelEntryInfo(srcTableID, dstTableID); + auto table = storageManager->getTable(relEntryInfo->oid)->ptrCast(); + evaluator_vector_t evaluators; + auto exprMapper = ExpressionMapper(&outSchema); + for (auto& expr : boundInfo->columnDataExprs) { + evaluators.push_back(exprMapper.getEvaluator(expr)); + } + auto tableInfo = RelTableInsertInfo(table, std::move(evaluators)); + return RelInsertExecutor(std::move(info), std::move(tableInfo)); +} + +std::unique_ptr PlanMapper::mapInsert(const LogicalOperator* logicalOperator) { + auto& logicalInsert = logicalOperator->constCast(); + auto inSchema = logicalInsert.getChild(0)->getSchema(); + auto outSchema = logicalInsert.getSchema(); + auto prevOperator = mapOperator(logicalOperator->getChild(0).get()); + std::vector nodeExecutors; + std::vector relExecutors; + for (auto& info : logicalInsert.getInfos()) { + switch (info.tableType) { + case TableType::NODE: { + nodeExecutors.push_back(getNodeInsertExecutor(&info, *inSchema, *outSchema)); + } break; + case TableType::REL: { + relExecutors.push_back(getRelInsertExecutor(&info, *inSchema, *outSchema)); + } break; + default: + KU_UNREACHABLE; + } + } + expression_vector expressions; + for (auto& info : logicalInsert.getInfos()) { + for (auto& expr : info.columnExprs) { + expressions.push_back(expr); + } + } + auto printInfo = + std::make_unique(expressions, logicalInsert.getInfos()[0].conflictAction); + return std::make_unique(std::move(nodeExecutors), std::move(relExecutors), + std::move(prevOperator), getOperatorID(), std::move(printInfo)); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_intersect.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_intersect.cpp new file mode 100644 index 0000000000..46f001b0d2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_intersect.cpp @@ -0,0 +1,69 @@ +#include "binder/expression/expression_util.h" +#include "planner/operator/logical_intersect.h" +#include "processor/operator/intersect/intersect.h" +#include "processor/operator/intersect/intersect_build.h" +#include "processor/plan_mapper.h" +#include "storage/buffer_manager/memory_manager.h" + +using namespace lbug::binder; +using namespace lbug::planner; +using namespace lbug::common; + +namespace lbug { +namespace processor { + +std::unique_ptr PlanMapper::mapIntersect(const LogicalOperator* logicalOperator) { + auto logicalIntersect = logicalOperator->constPtrCast(); + auto intersectNodeID = logicalIntersect->getIntersectNodeID(); + auto outSchema = logicalIntersect->getSchema(); + std::vector> sharedStates; + std::vector intersectDataInfos; + // Map build side children. 
+ std::vector> buildChildren; + for (auto i = 1u; i < logicalIntersect->getNumChildren(); i++) { + auto keyNodeID = logicalIntersect->getKeyNodeID(i - 1); + auto keys = expression_vector{keyNodeID}; + auto buildSchema = logicalIntersect->getChild(i)->getSchema(); + auto buildPrevOperator = mapOperator(logicalIntersect->getChild(i).get()); + auto payloadExpressions = + ExpressionUtil::excludeExpressions(buildSchema->getExpressionsInScope(), keys); + auto buildInfo = createHashBuildInfo(*buildSchema, keys, payloadExpressions); + auto globalHashTable = + std::make_unique(*storage::MemoryManager::Get(*clientContext), + ExpressionUtil::getDataTypes(keys), buildInfo.tableSchema.copy()); + auto sharedState = std::make_shared(std::move(globalHashTable)); + sharedStates.push_back(sharedState); + auto printInfo = std::make_unique(keys, payloadExpressions); + auto build = std::make_unique(sharedState, std::move(buildInfo), + std::move(buildPrevOperator), getOperatorID(), std::move(printInfo)); + build->setDescriptor(std::make_unique(buildSchema)); + buildChildren.push_back(std::move(build)); + // Collect intersect info + std::vector vectorsToScanPos; + auto expressionsToScan = ExpressionUtil::excludeExpressions( + buildSchema->getExpressionsInScope(), {keyNodeID, intersectNodeID}); + for (auto& expression : expressionsToScan) { + vectorsToScanPos.emplace_back(outSchema->getExpressionPos(*expression)); + } + IntersectDataInfo info{DataPos(outSchema->getExpressionPos(*keyNodeID)), vectorsToScanPos}; + intersectDataInfos.push_back(info); + } + // Map probe side child. + auto probeChild = mapOperator(logicalIntersect->getChild(0).get()); + // Map intersect. + auto outputDataPos = + DataPos(outSchema->getExpressionPos(*logicalIntersect->getIntersectNodeID())); + auto printInfo = std::make_unique(intersectNodeID); + auto intersect = make_unique(outputDataPos, intersectDataInfos, sharedStates, + std::move(probeChild), getOperatorID(), std::move(printInfo)); + for (auto& child : buildChildren) { + intersect->addChild(std::move(child)); + } + if (logicalIntersect->getSIPInfo().direction == SIPDirection::PROBE_TO_BUILD) { + mapSIPJoin(intersect.get()); + } + return intersect; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_label_filter.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_label_filter.cpp new file mode 100644 index 0000000000..26a3abf91c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_label_filter.cpp @@ -0,0 +1,24 @@ +#include "planner/operator/logical_node_label_filter.h" +#include "processor/operator/filter.h" +#include "processor/plan_mapper.h" + +using namespace lbug::planner; + +namespace lbug { +namespace processor { + +std::unique_ptr PlanMapper::mapNodeLabelFilter( + const LogicalOperator* logicalOperator) { + auto& logicalLabelFilter = logicalOperator->constCast(); + auto prevOperator = mapOperator(logicalOperator->getChild(0).get()); + auto schema = logicalOperator->getSchema(); + auto nbrNodeVectorPos = DataPos(schema->getExpressionPos(*logicalLabelFilter.getNodeID())); + auto filterInfo = + std::make_unique(nbrNodeVectorPos, logicalLabelFilter.getTableIDSet()); + auto printInfo = std::make_unique(); + return std::make_unique(std::move(filterInfo), std::move(prevOperator), + getOperatorID(), std::move(printInfo)); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_limit.cpp 
b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_limit.cpp new file mode 100644 index 0000000000..d52e78ae07 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_limit.cpp @@ -0,0 +1,50 @@ +#include "binder/expression/expression_util.h" +#include "common/exception/message.h" +#include "common/exception/runtime.h" +#include "planner/operator/logical_limit.h" +#include "processor/operator/limit.h" +#include "processor/operator/skip.h" +#include "processor/plan_mapper.h" + +using namespace lbug::binder; +using namespace lbug::planner; +using namespace lbug::common; + +namespace lbug { +namespace processor { + +std::unique_ptr PlanMapper::mapLimit(const LogicalOperator* logicalOperator) { + auto& logicalLimit = logicalOperator->constCast(); + auto prevOperator = mapOperator(logicalOperator->getChild(0).get()); + auto dataChunkToSelectPos = logicalLimit.getGroupPosToSelect(); + auto groupsPotToLimit = logicalLimit.getGroupsPosToLimit(); + std::unique_ptr lastOperator = std::move(prevOperator); + if (logicalLimit.hasSkipNum()) { + auto skipExpr = logicalLimit.getSkipNum(); + if (!ExpressionUtil::canEvaluateAsLiteral(*skipExpr)) { + throw RuntimeException{ + ExceptionMessage::invalidSkipLimitParam(skipExpr->toString(), "skip")}; + } + auto skipNum = ExpressionUtil::evaluateAsSkipLimit(*skipExpr); + auto printInfo = std::make_unique(skipNum); + lastOperator = make_unique(skipNum, std::make_shared(0), + dataChunkToSelectPos, groupsPotToLimit, std::move(lastOperator), getOperatorID(), + printInfo->copy()); + } + if (logicalLimit.hasLimitNum()) { + auto limitExpr = logicalLimit.getLimitNum(); + if (!ExpressionUtil::canEvaluateAsLiteral(*limitExpr)) { + throw RuntimeException{ + ExceptionMessage::invalidSkipLimitParam(limitExpr->toString(), "limit")}; + } + auto limitNum = ExpressionUtil::evaluateAsSkipLimit(*limitExpr); + auto printInfo = std::make_unique(limitNum); + lastOperator = make_unique(limitNum, std::make_shared(0), + dataChunkToSelectPos, groupsPotToLimit, std::move(lastOperator), getOperatorID(), + printInfo->copy()); + } + return lastOperator; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_merge.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_merge.cpp new file mode 100644 index 0000000000..e3d5866b1a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_merge.cpp @@ -0,0 +1,117 @@ +#include "planner/operator/persistent/logical_merge.h" +#include "processor/operator/persistent/merge.h" +#include "processor/plan_mapper.h" +#include + +using namespace lbug::planner; + +namespace lbug { +namespace processor { + +static FactorizedTableSchema getFactorizedTableSchema(const binder::expression_vector& keys, + uint64_t numNodeInsertExecutors, uint64_t numRelInsertExecutors) { + auto tableSchema = FactorizedTableSchema(); + auto isUnFlat = false; + auto groupID = 0u; + for (auto& key : keys) { + auto size = common::LogicalTypeUtils::getRowLayoutSize(key->dataType); + tableSchema.appendColumn(ColumnSchema(isUnFlat, groupID, size)); + } + auto numNodeIDFields = numNodeInsertExecutors + numRelInsertExecutors; + for (auto i = 0u; i < numNodeIDFields; i++) { + tableSchema.appendColumn(ColumnSchema(isUnFlat, groupID, sizeof(common::nodeID_t))); + } + tableSchema.appendColumn(ColumnSchema(isUnFlat, groupID, sizeof(common::hash_t))); + return tableSchema; +} + +std::unique_ptr PlanMapper::mapMerge(const LogicalOperator* logicalOperator) { + auto& logicalMerge = 
logicalOperator->constCast(); + auto outSchema = logicalMerge.getSchema(); + auto inSchema = logicalMerge.getChild(0)->getSchema(); + auto prevOperator = mapOperator(logicalOperator->getChild(0).get()); + auto existenceMarkPos = getDataPos(*logicalMerge.getExistenceMark(), *inSchema); + std::vector nodeInsertExecutors; + for (auto& info : logicalMerge.getInsertNodeInfos()) { + nodeInsertExecutors.push_back(getNodeInsertExecutor(&info, *inSchema, *outSchema)); + } + std::vector relInsertExecutors; + for (auto& info : logicalMerge.getInsertRelInfos()) { + relInsertExecutors.push_back(getRelInsertExecutor(&info, *inSchema, *outSchema)); + } + std::vector> onCreateNodeSetExecutors; + for (auto& info : logicalMerge.getOnCreateSetNodeInfos()) { + onCreateNodeSetExecutors.push_back(getNodeSetExecutor(info, *inSchema)); + } + std::vector> onCreateRelSetExecutors; + for (auto& info : logicalMerge.getOnCreateSetRelInfos()) { + onCreateRelSetExecutors.push_back(getRelSetExecutor(info, *inSchema)); + } + std::vector> onMatchNodeSetExecutors; + common::executor_info executorInfo; + for (auto i = 0u; i < logicalMerge.getOnMatchSetNodeInfos().size(); i++) { + auto& info = logicalMerge.getOnMatchSetNodeInfos()[i]; + for (auto j = 0u; j < logicalMerge.getInsertNodeInfos().size(); j++) { + if (*info.pattern == *logicalMerge.getInsertNodeInfos()[j].pattern) { + executorInfo.emplace(j, i); + } + } + onMatchNodeSetExecutors.push_back(getNodeSetExecutor(info, *inSchema)); + } + std::vector> onMatchRelSetExecutors; + for (auto i = 0u; i < logicalMerge.getOnMatchSetRelInfos().size(); i++) { + auto& info = logicalMerge.getOnMatchSetRelInfos()[i]; + for (auto j = 0u; j < logicalMerge.getInsertRelInfos().size(); j++) { + if (*info.pattern == *logicalMerge.getInsertRelInfos()[j].pattern) { + executorInfo.emplace(j + logicalMerge.getInsertNodeInfos().size(), + i + logicalMerge.getOnMatchSetNodeInfos().size()); + } + } + onMatchRelSetExecutors.push_back(getRelSetExecutor(info, *inSchema)); + } + binder::expression_vector expressions; + for (auto& info : logicalMerge.getInsertNodeInfos()) { + for (auto& expr : info.columnExprs) { + expressions.push_back(expr); + } + } + for (auto& info : logicalMerge.getInsertRelInfos()) { + for (auto& expr : info.columnExprs) { + expressions.push_back(expr); + } + } + std::vector onCreateOperation; + for (auto& info : logicalMerge.getOnCreateSetRelInfos()) { + onCreateOperation.emplace_back(info.column, info.columnData); + } + for (auto& info : logicalMerge.getOnCreateSetNodeInfos()) { + onCreateOperation.emplace_back(info.column, info.columnData); + } + std::vector onMatchOperation; + for (auto& info : logicalMerge.getOnMatchSetRelInfos()) { + onMatchOperation.emplace_back(info.column, info.columnData); + } + for (auto& info : logicalMerge.getOnMatchSetNodeInfos()) { + onMatchOperation.emplace_back(info.column, info.columnData); + } + auto printInfo = + std::make_unique(expressions, onCreateOperation, onMatchOperation); + std::vector> keyEvaluators; + auto expressionMapper = ExpressionMapper(inSchema); + for (auto& key : logicalMerge.getKeys()) { + keyEvaluators.push_back(expressionMapper.getEvaluator(key)); + } + + MergeInfo mergeInfo{std::move(keyEvaluators), + getFactorizedTableSchema(logicalMerge.getKeys(), + logicalMerge.getOnMatchSetNodeInfos().size(), + logicalMerge.getOnMatchSetRelInfos().size()), + std::move(executorInfo), existenceMarkPos}; + return std::make_unique(std::move(nodeInsertExecutors), std::move(relInsertExecutors), + std::move(onCreateNodeSetExecutors), 
std::move(onCreateRelSetExecutors),
+        std::move(onMatchNodeSetExecutors), std::move(onMatchRelSetExecutors), std::move(mergeInfo),
+        std::move(prevOperator), getOperatorID(), std::move(printInfo));
+}
+
+} // namespace processor
+} // namespace lbug
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_multiplicity_reducer.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_multiplicity_reducer.cpp
new file mode 100644
index 0000000000..05f5b60aaf
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_multiplicity_reducer.cpp
@@ -0,0 +1,18 @@
+#include "processor/operator/multiplicity_reducer.h"
+#include "processor/plan_mapper.h"
+
+using namespace lbug::planner;
+
+namespace lbug {
+namespace processor {
+
+std::unique_ptr<PhysicalOperator> PlanMapper::mapMultiplicityReducer(
+    const LogicalOperator* logicalOperator) {
+    auto prevOperator = mapOperator(logicalOperator->getChild(0).get());
+    auto printInfo = std::make_unique<OPPrintInfo>();
+    return std::make_unique<MultiplicityReducer>(std::move(prevOperator), getOperatorID(),
+        std::move(printInfo));
+}
+
+} // namespace processor
+} // namespace lbug
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_noop.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_noop.cpp
new file mode 100644
index 0000000000..865d5d2c97
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_noop.cpp
@@ -0,0 +1,34 @@
+#include "planner/operator/logical_noop.h"
+#include "processor/plan_mapper.h"
+
+using namespace lbug::planner;
+
+namespace lbug {
+namespace processor {
+
+std::unique_ptr<PhysicalOperator> PlanMapper::mapNoop(const LogicalOperator* logicalOperator) {
+    std::vector<std::unique_ptr<PhysicalOperator>> children;
+    for (auto child : logicalOperator->getChildren()) {
+        children.push_back(mapOperator(child.get()));
+    }
+    auto noop = logicalOperator->constPtrCast<LogicalNoop>();
+    auto idx = noop->getMessageChildIdx();
+    KU_ASSERT(idx < children.size());
+    auto child = children[idx].get();
+    // LCOV_EXCL_START
+    if (!child->isSink()) {
+        throw common::InternalException(
+            common::stringFormat("Trying to propagate result table from a non sink operator. 
This " + "should never happen.")); + } + // LCOV_EXCL_STOP + auto fTable = child->ptrCast()->getResultFTable(); + auto op = std::make_unique(fTable, getOperatorID()); + for (auto& childOp : children) { + op->addChild(std::move(childOp)); + } + return op; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_order_by.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_order_by.cpp new file mode 100644 index 0000000000..06a3550d56 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_order_by.cpp @@ -0,0 +1,107 @@ +#include "binder/expression/expression_util.h" +#include "common/exception/message.h" +#include "common/exception/runtime.h" +#include "planner/operator/logical_order_by.h" +#include "processor/operator/order_by/order_by.h" +#include "processor/operator/order_by/order_by_merge.h" +#include "processor/operator/order_by/order_by_scan.h" +#include "processor/operator/order_by/top_k.h" +#include "processor/operator/order_by/top_k_scanner.h" +#include "processor/plan_mapper.h" + +using namespace lbug::binder; +using namespace lbug::common; +using namespace lbug::planner; + +namespace lbug { +namespace processor { + +std::unique_ptr PlanMapper::mapOrderBy(const LogicalOperator* logicalOperator) { + auto& logicalOrderBy = logicalOperator->constCast(); + auto outSchema = logicalOrderBy.getSchema(); + auto inSchema = logicalOrderBy.getChild(0)->getSchema(); + auto prevOperator = mapOperator(logicalOrderBy.getChild(0).get()); + auto keyExpressions = logicalOrderBy.getExpressionsToOrderBy(); + auto payloadExpressions = inSchema->getExpressionsInScope(); + std::vector payloadsPos; + std::vector payloadTypes; + expression_map payloadToColIdx; + auto payloadSchema = FactorizedTableSchema(); + auto mayContainUnFlatKey = inSchema->getNumGroups() == 1; + for (auto i = 0u; i < payloadExpressions.size(); ++i) { + auto expression = payloadExpressions[i]; + auto [dataChunkPos, vectorPos] = inSchema->getExpressionPos(*expression); + payloadsPos.emplace_back(dataChunkPos, vectorPos); + payloadTypes.push_back(expression->dataType.copy()); + if (!inSchema->getGroup(dataChunkPos)->isFlat() && !mayContainUnFlatKey) { + // payload is unFlat and not in the same group as keys + auto columnSchema = + ColumnSchema(true /* isUnFlat */, dataChunkPos, sizeof(overflow_value_t)); + payloadSchema.appendColumn(std::move(columnSchema)); + } else { + auto columnSchema = ColumnSchema(false /* isUnFlat */, dataChunkPos, + LogicalTypeUtils::getRowLayoutSize(expression->getDataType())); + payloadSchema.appendColumn(std::move(columnSchema)); + } + payloadToColIdx.insert({expression, i}); + } + std::vector keysPos; + std::vector keyTypes; + std::vector keyInPayloadPos; + for (auto& expression : keyExpressions) { + keysPos.emplace_back(inSchema->getExpressionPos(*expression)); + keyTypes.push_back(expression->getDataType().copy()); + KU_ASSERT(payloadToColIdx.contains(expression)); + keyInPayloadPos.push_back(payloadToColIdx.at(expression)); + } + std::vector outPos; + for (auto& expression : payloadExpressions) { + outPos.emplace_back(outSchema->getExpressionPos(*expression)); + } + auto orderByDataInfo = OrderByDataInfo(keysPos, payloadsPos, LogicalType::copy(keyTypes), + LogicalType::copy(payloadTypes), logicalOrderBy.getIsAscOrders(), std::move(payloadSchema), + std::move(keyInPayloadPos)); + if (logicalOrderBy.hasLimitNum()) { + auto limitExpr = logicalOrderBy.getLimitNum(); + if (!ExpressionUtil::canEvaluateAsLiteral(*limitExpr)) 
{ + throw RuntimeException{ + ExceptionMessage::invalidSkipLimitParam(limitExpr->toString(), "limit")}; + } + auto limitNum = ExpressionUtil::evaluateAsSkipLimit(*limitExpr); + uint64_t skipNum = 0; + if (logicalOrderBy.hasSkipNum()) { + auto skipExpr = logicalOrderBy.getSkipNum(); + if (!ExpressionUtil::canEvaluateAsLiteral(*skipExpr)) { + throw RuntimeException{ + ExceptionMessage::invalidSkipLimitParam(skipExpr->toString(), "skip")}; + } + skipNum = ExpressionUtil::evaluateAsSkipLimit(*skipExpr); + } + auto topKSharedState = std::make_shared(); + auto printInfo = + std::make_unique(keyExpressions, payloadExpressions, skipNum, limitNum); + auto topK = make_unique(std::move(orderByDataInfo), topKSharedState, skipNum, + limitNum, std::move(prevOperator), getOperatorID(), printInfo->copy()); + topK->setDescriptor(std::make_unique(inSchema)); + auto scan = + std::make_unique(outPos, topKSharedState, getOperatorID(), printInfo->copy()); + scan->addChild(std::move(topK)); + return scan; + } + auto orderBySharedState = std::make_shared(); + auto printInfo = std::make_unique(keyExpressions, payloadExpressions); + auto orderBy = make_unique(std::move(orderByDataInfo), orderBySharedState, + std::move(prevOperator), getOperatorID(), printInfo->copy()); + orderBy->setDescriptor(std::make_unique(inSchema)); + auto dispatcher = std::make_shared(); + auto orderByMerge = make_unique(orderBySharedState, std::move(dispatcher), + getOperatorID(), printInfo->copy()); + orderByMerge->addChild(std::move(orderBy)); + auto scan = std::make_unique(outPos, orderBySharedState, getOperatorID(), + printInfo->copy()); + scan->addChild(std::move(orderByMerge)); + return scan; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_path_property_probe.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_path_property_probe.cpp new file mode 100644 index 0000000000..e8f45a5a65 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_path_property_probe.cpp @@ -0,0 +1,170 @@ +#include "binder/expression/expression_util.h" +#include "binder/expression/property_expression.h" +#include "catalog/catalog_entry/rel_group_catalog_entry.h" +#include "catalog/catalog_entry/table_catalog_entry.h" +#include "common/string_utils.h" +#include "planner/operator/extend/logical_recursive_extend.h" +#include "planner/operator/logical_path_property_probe.h" +#include "processor/operator/hash_join/hash_join_build.h" +#include "processor/operator/path_property_probe.h" +#include "processor/operator/recursive_extend.h" +#include "processor/plan_mapper.h" + +using namespace lbug::binder; +using namespace lbug::common; +using namespace lbug::planner; + +namespace lbug { +namespace processor { + +static std::pair, std::vector> getColIdxToScan( + const expression_vector& payloads, uint32_t numKeys, const LogicalType& structType) { + std::unordered_map propertyNameToColumnIdx; + for (auto i = 0u; i < payloads.size(); ++i) { + KU_ASSERT(payloads[i]->expressionType == ExpressionType::PROPERTY); + auto propertyName = payloads[i]->ptrCast()->getPropertyName(); + StringUtils::toUpper(propertyName); + propertyNameToColumnIdx.insert({propertyName, i + numKeys}); + } + const auto& structFields = StructType::getFields(structType); + std::vector structFieldIndices; + std::vector colIndices; + for (auto i = 0u; i < structFields.size(); ++i) { + auto field = structFields[i].copy(); + auto fieldName = StringUtils::getUpper(field.getName()); + if 
(propertyNameToColumnIdx.contains(fieldName)) { + structFieldIndices.push_back(i); + colIndices.push_back(propertyNameToColumnIdx.at(fieldName)); + } + } + return std::make_pair(std::move(structFieldIndices), std::move(colIndices)); +} + +std::unique_ptr PlanMapper::mapPathPropertyProbe( + const LogicalOperator* logicalOperator) { + auto& logicalProbe = logicalOperator->constCast(); + if (logicalProbe.getJoinType() == RecursiveJoinType::TRACK_NONE) { + return mapOperator(logicalProbe.getChild(0).get()); + } + auto rel = logicalProbe.getRel(); + auto recursiveInfo = rel->getRecursiveInfo(); + std::vector nodeFieldIndices; + std::vector nodeTableColumnIndices; + std::shared_ptr nodeBuildSharedState = nullptr; + std::unique_ptr nodeBuild = nullptr; + // Map build node property + if (logicalProbe.getNodeChild() != nullptr) { + auto nodeBuildPrevOperator = mapOperator(logicalProbe.getNodeChild().get()); + auto nodeBuildSchema = logicalProbe.getNodeChild()->getSchema(); + auto nodeKeys = expression_vector{recursiveInfo->node->getInternalID()}; + auto nodeKeyTypes = ExpressionUtil::getDataTypes(nodeKeys); + auto nodePayloads = + ExpressionUtil::excludeExpressions(nodeBuildSchema->getExpressionsInScope(), nodeKeys); + auto nodeBuildInfo = createHashBuildInfo(*nodeBuildSchema, nodeKeys, nodePayloads); + auto nodeHashTable = + std::make_unique(*storage::MemoryManager::Get(*clientContext), + std::move(nodeKeyTypes), nodeBuildInfo.tableSchema.copy()); + nodeBuildSharedState = std::make_shared(std::move(nodeHashTable)); + nodeBuild = make_unique(PhysicalOperatorType::HASH_JOIN_BUILD, + nodeBuildSharedState, std::move(nodeBuildInfo), std::move(nodeBuildPrevOperator), + getOperatorID(), std::make_unique()); + nodeBuild->setDescriptor(std::make_unique(nodeBuildSchema)); + auto [fieldIndices, columnIndices] = getColIdxToScan(nodePayloads, nodeKeys.size(), + ListType::getChildType( + StructType::getField(rel->getDataType(), InternalKeyword::NODES).getType())); + nodeFieldIndices = std::move(fieldIndices); + nodeTableColumnIndices = std::move(columnIndices); + } + std::vector relFieldIndices; + std::vector relTableColumnIndices; + std::shared_ptr relBuildSharedState = nullptr; + std::unique_ptr relBuild = nullptr; + // Map build rel property + if (logicalProbe.getRelChild() != nullptr) { + auto relBuildPrvOperator = mapOperator(logicalProbe.getRelChild().get()); + auto relBuildSchema = logicalProbe.getRelChild()->getSchema(); + auto relKeys = expression_vector{recursiveInfo->rel->getInternalID()}; + auto relKeyTypes = ExpressionUtil::getDataTypes(relKeys); + auto relPayloads = + ExpressionUtil::excludeExpressions(relBuildSchema->getExpressionsInScope(), relKeys); + auto relBuildInfo = createHashBuildInfo(*relBuildSchema, relKeys, relPayloads); + auto relHashTable = + std::make_unique(*storage::MemoryManager::Get(*clientContext), + std::move(relKeyTypes), relBuildInfo.tableSchema.copy()); + relBuildSharedState = std::make_shared(std::move(relHashTable)); + relBuild = std::make_unique(PhysicalOperatorType::HASH_JOIN_BUILD, + relBuildSharedState, std::move(relBuildInfo), std::move(relBuildPrvOperator), + getOperatorID(), std::make_unique()); + relBuild->setDescriptor(std::make_unique(relBuildSchema)); + auto [fieldIndices, columnIndices] = getColIdxToScan(relPayloads, relKeys.size(), + ListType::getChildType( + StructType::getField(rel->getDataType(), InternalKeyword::RELS).getType())); + relFieldIndices = std::move(fieldIndices); + relTableColumnIndices = std::move(columnIndices); + } + // Map child + auto 
logicalChild = logicalOperator->getChild(0).get(); + auto prevOperator = mapOperator(logicalChild); + if (logicalChild->getOperatorType() == LogicalOperatorType::SEMI_MASKER) { + // Create a pipeline to populate semi mask. Pipeline source is the scan of recursive extend + // result, and pipeline sink is a dummy operator that does not materialize anything. + auto dummySink = std::make_unique(std::move(prevOperator), getOperatorID()); + dummySink->setDescriptor(std::make_unique(logicalChild->getSchema())); + auto extend = logicalChild->getChild(0)->ptrCast(); + auto columns = extend->getResultColumns(); + auto physicalCall = logicalOpToPhysicalOpMap.at(extend)->ptrCast(); + physical_op_vector_t children; + children.push_back(std::move(dummySink)); + prevOperator = createFTableScanAligned(columns, extend->getSchema(), + physicalCall->getSharedState()->factorizedTablePool.getGlobalTable(), + DEFAULT_VECTOR_CAPACITY, std::move(children)); + } + // Map probe + auto pathProbeInfo = PathPropertyProbeInfo(); + auto schema = logicalProbe.getSchema(); + pathProbeInfo.pathPos = getDataPos(*rel, *schema); + if (logicalProbe.getPathEdgeIDs() != nullptr) { + pathProbeInfo.leftNodeIDPos = getDataPos(*rel->getLeftNode()->getInternalID(), *schema); + pathProbeInfo.rightNodeIDPos = getDataPos(*rel->getRightNode()->getInternalID(), *schema); + pathProbeInfo.inputNodeIDsPos = getDataPos(*logicalProbe.getPathNodeIDs(), *schema); + pathProbeInfo.inputEdgeIDsPos = getDataPos(*logicalProbe.getPathEdgeIDs(), *schema); + pathProbeInfo.extendFromLeft = logicalProbe.extendFromLeft; + pathProbeInfo.extendDirection = logicalProbe.direction; + if (logicalProbe.direction == ExtendDirection::BOTH) { + pathProbeInfo.directionPos = + getDataPos(*recursiveInfo->bindData->directionExpr, *schema); + } + for (auto entry : recursiveInfo->node->getEntries()) { + pathProbeInfo.tableIDToName.insert({entry->getTableID(), entry->getName()}); + } + for (auto& entry : recursiveInfo->rel->getEntries()) { + auto& relGroupEntry = entry->constCast(); + for (auto& relEntryInfo : relGroupEntry.getRelEntryInfos()) { + pathProbeInfo.tableIDToName.insert({relEntryInfo.oid, entry->getName()}); + } + } + } + pathProbeInfo.nodeFieldIndices = nodeFieldIndices; + pathProbeInfo.relFieldIndices = relFieldIndices; + pathProbeInfo.nodeTableColumnIndices = nodeTableColumnIndices; + pathProbeInfo.relTableColumnIndices = relTableColumnIndices; + pathProbeInfo.extendFromLeft = logicalProbe.extendFromLeft; + auto pathProbeSharedState = + std::make_shared(nodeBuildSharedState, relBuildSharedState); + auto printInfo = std::make_unique(); + auto pathPropertyProbe = std::make_unique(std::move(pathProbeInfo), + pathProbeSharedState, std::move(prevOperator), getOperatorID(), std::move(printInfo)); + if (nodeBuild != nullptr) { + pathPropertyProbe->addChild(std::move(nodeBuild)); + } + if (relBuild != nullptr) { + pathPropertyProbe->addChild(std::move(relBuild)); + } + if (logicalProbe.getSIPInfo().direction == SIPDirection::PROBE_TO_BUILD) { + mapSIPJoin(pathPropertyProbe.get()); + } + return pathPropertyProbe; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_projection.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_projection.cpp new file mode 100644 index 0000000000..f104f889f1 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_projection.cpp @@ -0,0 +1,30 @@ +#include "planner/operator/logical_projection.h" +#include "processor/expression_mapper.h" +#include 
"processor/operator/projection.h" +#include "processor/plan_mapper.h" + +using namespace lbug::planner; + +namespace lbug { +namespace processor { + +std::unique_ptr PlanMapper::mapProjection( + const LogicalOperator* logicalOperator) { + auto& logicalProjection = logicalOperator->constCast(); + auto outSchema = logicalProjection.getSchema(); + auto inSchema = logicalProjection.getChild(0)->getSchema(); + auto prevOperator = mapOperator(logicalOperator->getChild(0).get()); + auto printInfo = + std::make_unique(logicalProjection.getExpressionsToProject()); + auto info = ProjectionInfo(); + info.discardedChunkIndices = logicalProjection.getDiscardedGroupsPos(); + auto exprMapper = ExpressionMapper(inSchema); + for (auto& expr : logicalProjection.getExpressionsToProject()) { + info.addEvaluator(exprMapper.getEvaluator(expr), getDataPos(*expr, *outSchema)); + } + return make_unique(std::move(info), std::move(prevOperator), getOperatorID(), + std::move(printInfo)); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_recursive_extend.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_recursive_extend.cpp new file mode 100644 index 0000000000..04b61911fb --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_recursive_extend.cpp @@ -0,0 +1,73 @@ +#include "binder/expression/node_expression.h" +#include "graph/on_disk_graph.h" +#include "planner/operator/extend/logical_recursive_extend.h" +#include "planner/operator/sip/logical_semi_masker.h" +#include "processor/operator/recursive_extend.h" +#include "processor/plan_mapper.h" + +using namespace lbug::planner; +using namespace lbug::graph; +using namespace lbug::binder; +using namespace lbug::common; + +namespace lbug { +namespace processor { + +std::unique_ptr createNodeOffsetMaskMap(const Expression& expr, + PlanMapper* mapper) { + auto& node = expr.constCast(); + auto maskMap = std::make_unique(); + for (auto tableID : node.getTableIDs()) { + maskMap->addMask(tableID, mapper->createSemiMask(tableID)); + } + return maskMap; +} + +std::unique_ptr PlanMapper::mapRecursiveExtend( + const LogicalOperator* logicalOperator) { + auto& extend = logicalOperator->constCast(); + auto& bindData = extend.getBindData(); + auto columns = extend.getResultColumns(); + auto tableSchema = createFlatFTableSchema(columns, *extend.getSchema()); + auto table = std::make_shared(storage::MemoryManager::Get(*clientContext), + tableSchema.copy()); + auto graph = std::make_unique(clientContext, bindData.graphEntry.copy()); + auto sharedState = + std::make_shared(table, std::move(graph), extend.getLimitNum()); + if (extend.hasInputNodeMask()) { + sharedState->setInputNodeMask(createNodeOffsetMaskMap(*bindData.nodeInput, this)); + } + if (extend.hasOutputNodeMask()) { + sharedState->setOutputNodeMask(createNodeOffsetMaskMap(*bindData.nodeOutput, this)); + } + auto printInfo = + std::make_unique(extend.getFunction().getFunctionName()); + auto recursiveExtend = std::make_unique(extend.getFunction().copy(), bindData, + sharedState, getOperatorID(), std::move(printInfo)); + // Map node predicate pipeline + if (extend.hasNodePredicate()) { + addOperatorMapping(logicalOperator, recursiveExtend.get()); + sharedState->setPathNodeMask(std::make_unique()); + auto maskMap = sharedState->getPathNodeMaskMap(); + KU_ASSERT(extend.getNumChildren() == 1); + auto logicalRoot = extend.getChild(0); + KU_ASSERT(logicalRoot->getNumChildren() == 1 && + logicalRoot->getChild(0)->getOperatorType() == 
LogicalOperatorType::SEMI_MASKER); + auto logicalSemiMasker = logicalRoot->getChild(0)->ptrCast(); + logicalSemiMasker->addTarget(logicalOperator); + for (auto tableID : logicalSemiMasker->getNodeTableIDs()) { + maskMap->addMask(tableID, createSemiMask(tableID)); + } + auto root = mapOperator(logicalRoot.get()); + recursiveExtend->addChild(std::move(root)); + eraseOperatorMapping(logicalOperator); + } + logicalOpToPhysicalOpMap.insert({logicalOperator, recursiveExtend.get()}); + physical_op_vector_t children; + children.push_back(std::move(recursiveExtend)); + return createFTableScanAligned(columns, extend.getSchema(), table, DEFAULT_VECTOR_CAPACITY, + std::move(children)); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_scan_node_table.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_scan_node_table.cpp new file mode 100644 index 0000000000..d4a680b860 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_scan_node_table.cpp @@ -0,0 +1,94 @@ +#include "binder/binder.h" +#include "binder/expression/property_expression.h" +#include "binder/expression_binder.h" +#include "common/mask.h" +#include "planner/operator/scan/logical_scan_node_table.h" +#include "processor/expression_mapper.h" +#include "processor/operator/scan/primary_key_scan_node_table.h" +#include "processor/operator/scan/scan_node_table.h" +#include "processor/plan_mapper.h" +#include "storage/storage_manager.h" + +using namespace lbug::binder; +using namespace lbug::common; +using namespace lbug::planner; + +namespace lbug { +namespace processor { + +std::unique_ptr PlanMapper::mapScanNodeTable( + const LogicalOperator* logicalOperator) { + auto storageManager = storage::StorageManager::Get(*clientContext); + auto catalog = catalog::Catalog::Get(*clientContext); + auto transaction = transaction::Transaction::Get(*clientContext); + auto& scan = logicalOperator->constCast(); + const auto outSchema = scan.getSchema(); + auto nodeIDPos = getDataPos(*scan.getNodeID(), *outSchema); + std::vector outVectorsPos; + for (auto& expression : scan.getProperties()) { + outVectorsPos.emplace_back(getDataPos(*expression, *outSchema)); + } + auto scanInfo = ScanOpInfo(nodeIDPos, outVectorsPos); + const auto tableIDs = scan.getTableIDs(); + std::vector tableNames; + std::vector tableInfos; + auto binder = Binder(clientContext); + auto expressionBinder = ExpressionBinder(&binder, clientContext); + for (const auto& tableID : tableIDs) { + auto tableEntry = catalog->getTableCatalogEntry(transaction, tableID); + tableNames.push_back(tableEntry->getName()); + auto table = storageManager->getTable(tableID)->ptrCast(); + auto tableInfo = ScanNodeTableInfo(table, copyVector(scan.getPropertyPredicates())); + for (auto& expr : scan.getProperties()) { + auto& property = expr->constCast(); + if (property.hasProperty(tableEntry->getTableID())) { + auto propertyName = property.getPropertyName(); + auto& columnType = tableEntry->getProperty(propertyName).getType(); + auto columnCaster = ColumnCaster(columnType.copy()); + if (property.getDataType() != columnType) { + auto columnExpr = std::make_shared(property); + columnExpr->dataType = columnType.copy(); + columnCaster.setCastExpr( + expressionBinder.forceCast(columnExpr, property.getDataType())); + } + tableInfo.addColumnInfo(tableEntry->getColumnID(propertyName), + std::move(columnCaster)); + } else { + tableInfo.addColumnInfo(INVALID_COLUMN_ID, ColumnCaster(LogicalType::ANY())); + } + } + 
tableInfos.push_back(std::move(tableInfo)); + } + std::vector> sharedStates; + for (auto& tableID : tableIDs) { + auto table = storageManager->getTable(tableID)->ptrCast(); + auto semiMask = SemiMaskUtil::createMask(table->getNumTotalRows(transaction)); + sharedStates.push_back(std::make_shared(std::move(semiMask))); + } + auto alias = scan.getNodeID()->cast().getRawVariableName(); + std::unique_ptr result; + switch (scan.getScanType()) { + case LogicalScanNodeTableType::SCAN: { + auto printInfo = + std::make_unique(tableNames, alias, scan.getProperties()); + auto progressSharedState = std::make_shared(); + return std::make_unique(std::move(scanInfo), std::move(tableInfos), + std::move(sharedStates), getOperatorID(), std::move(printInfo), progressSharedState); + } + case LogicalScanNodeTableType::PRIMARY_KEY_SCAN: { + auto& primaryKeyScanInfo = scan.getExtraInfo()->constCast(); + auto exprMapper = ExpressionMapper(outSchema); + auto evaluator = exprMapper.getEvaluator(primaryKeyScanInfo.key); + auto sharedState = std::make_shared(tableInfos.size()); + auto printInfo = std::make_unique(scan.getProperties(), + primaryKeyScanInfo.key->toString(), alias); + return std::make_unique(std::move(scanInfo), std::move(tableInfos), + std::move(evaluator), std::move(sharedState), getOperatorID(), std::move(printInfo)); + } + default: + KU_UNREACHABLE; + } +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_semi_masker.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_semi_masker.cpp new file mode 100644 index 0000000000..8d92b03411 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_semi_masker.cpp @@ -0,0 +1,131 @@ +#include "function/gds/gds.h" +#include "planner/operator/sip/logical_semi_masker.h" +#include "processor/operator/recursive_extend.h" +#include "processor/operator/scan/scan_node_table.h" +#include "processor/operator/semi_masker.h" +#include "processor/operator/table_function_call.h" +#include "processor/plan_mapper.h" + +using namespace lbug::common; +using namespace lbug::planner; + +namespace lbug { +namespace processor { + +// masksPerTable is collected from semiMasker. +// maskPerTable is collected from target operator, i.e. GDS, scan, ... +// Normally the two maps should have the same tableIDs. +// An exception is in GDS with filtered projected graph, multiple semiMasker will work on +// the same target GDS operator so masksPerTable may have fewer tableIDs. 
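// Illustrative sketch (not part of this patch): the subset relationship described in
// the comment above, using stand-in types (ToyMask, plain std::unordered_map, table
// ids 0 and 1) rather than the real SemiMask / table_id_map_t. A masker that only
// produces keys for table 0 still initializes correctly against a target that tracks
// tables 0 and 1; the extra target entry is simply never enabled. The loop mirrors
// initMask below.

#include <cassert>
#include <cstdint>
#include <unordered_map>
#include <vector>

struct ToyMask {
    bool enabled = false;
    void enable() { enabled = true; }
};

int main() {
    // Target operator side: one mask per table it scans (tables 0 and 1).
    std::unordered_map<uint64_t, ToyMask*> maskPerTable;
    ToyMask m0, m1;
    maskPerTable.emplace(0, &m0);
    maskPerTable.emplace(1, &m1);
    // Semi masker side: this particular masker only covers table 0 (the filtered
    // projected-graph case), so its map has fewer table ids than the target's.
    std::unordered_map<uint64_t, std::vector<ToyMask*>> masksPerTable;
    masksPerTable.emplace(0, std::vector<ToyMask*>{});
    // Same wiring as initMask: every masker table must exist on the target, but the
    // target may track more tables than this masker contributes to.
    for (auto& [tableID, masks] : masksPerTable) {
        assert(maskPerTable.contains(tableID));
        auto* mask = maskPerTable.at(tableID);
        mask->enable();
        masks.push_back(mask);
    }
    assert(m0.enabled && !m1.enabled);
    return 0;
}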
+static void initMask(table_id_map_t>& masksPerTable, + const table_id_map_t& maskPerTable) { + for (auto& [tableID, masks] : masksPerTable) { + KU_ASSERT(maskPerTable.contains(tableID)); + auto mask = maskPerTable.at(tableID); + mask->enable(); + masks.emplace_back(mask); + } +} + +std::unique_ptr PlanMapper::mapSemiMasker( + const LogicalOperator* logicalOperator) { + const auto& semiMasker = logicalOperator->constCast(); + const auto inSchema = semiMasker.getChild(0)->getSchema(); + auto prevOperator = mapOperator(logicalOperator->getChild(0).get()); + const auto tableIDs = semiMasker.getNodeTableIDs(); + table_id_map_t> masksPerTable; + for (auto tableID : tableIDs) { + masksPerTable.insert({tableID, std::vector{}}); + } + std::vector operatorNames; + for (auto& op : semiMasker.getTargetOperators()) { + const auto physicalOp = logicalOpToPhysicalOpMap.at(op); + operatorNames.push_back(PhysicalOperatorUtils::operatorToString(physicalOp)); + switch (physicalOp->getOperatorType()) { + case PhysicalOperatorType::SCAN_NODE_TABLE: { + KU_ASSERT(semiMasker.getTargetType() == SemiMaskTargetType::SCAN_NODE); + auto scan = physicalOp->ptrCast(); + initMask(masksPerTable, scan->getSemiMasks()); + } break; + case PhysicalOperatorType::TABLE_FUNCTION_CALL: { + auto sharedState = physicalOp->ptrCast()->getSharedState(); + switch (semiMasker.getTargetType()) { + case SemiMaskTargetType::GDS_GRAPH_NODE: { + auto funcSharedState = sharedState->ptrCast(); + initMask(masksPerTable, funcSharedState->getGraphNodeMaskMap()->getMasks()); + } break; + case SemiMaskTargetType::SCAN_NODE: { + auto tableFunc = physicalOp->ptrCast(); + initMask(masksPerTable, tableFunc->getSharedState()->getSemiMasks()); + } break; + default: + KU_UNREACHABLE; + } + } break; + case PhysicalOperatorType::RECURSIVE_EXTEND: { + auto sharedState = physicalOp->ptrCast()->getSharedState(); + NodeOffsetMaskMap* maskMap = nullptr; + switch (semiMasker.getTargetType()) { + case SemiMaskTargetType::RECURSIVE_EXTEND_INPUT_NODE: { + maskMap = sharedState->getInputNodeMaskMap(); + } break; + case SemiMaskTargetType::RECURSIVE_EXTEND_OUTPUT_NODE: { + maskMap = sharedState->getOutputNodeMaskMap(); + } break; + case SemiMaskTargetType::RECURSIVE_EXTEND_PATH_NODE: { + maskMap = sharedState->getPathNodeMaskMap(); + } break; + default: + KU_UNREACHABLE; + } + KU_ASSERT(maskMap != nullptr); + initMask(masksPerTable, maskMap->getMasks()); + } break; + default: + KU_UNREACHABLE; + } + } + auto keyPos = DataPos(inSchema->getExpressionPos(*semiMasker.getKey())); + auto sharedState = std::make_shared(std::move(masksPerTable)); + auto printInfo = std::make_unique(operatorNames); + switch (semiMasker.getKeyType()) { + case SemiMaskKeyType::NODE: { + if (tableIDs.size() > 1) { + return std::make_unique(keyPos, sharedState, + std::move(prevOperator), getOperatorID(), std::move(printInfo)); + } else { + return std::make_unique(keyPos, sharedState, + std::move(prevOperator), getOperatorID(), std::move(printInfo)); + } + } + case SemiMaskKeyType::PATH: { + auto& extraInfo = semiMasker.getExtraKeyInfo()->constCast(); + if (tableIDs.size() > 1) { + return std::make_unique(keyPos, sharedState, + std::move(prevOperator), getOperatorID(), std::move(printInfo), + extraInfo.direction); + } else { + return std::make_unique(keyPos, sharedState, + std::move(prevOperator), getOperatorID(), std::move(printInfo), + extraInfo.direction); + } + } + case SemiMaskKeyType::NODE_ID_LIST: { + auto& extraInfo = semiMasker.getExtraKeyInfo()->constCast(); + auto srcIDPos = 
getDataPos(*extraInfo.srcNodeID, *inSchema); + auto dstIDPos = getDataPos(*extraInfo.dstNodeID, *inSchema); + if (tableIDs.size() > 1) { + return std::make_unique(keyPos, srcIDPos, dstIDPos, + sharedState, std::move(prevOperator), getOperatorID(), std::move(printInfo)); + } else { + return std::make_unique(keyPos, srcIDPos, dstIDPos, + sharedState, std::move(prevOperator), getOperatorID(), std::move(printInfo)); + } + } + default: + KU_UNREACHABLE; + } +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_set.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_set.cpp new file mode 100644 index 0000000000..ffdc8fd9ff --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_set.cpp @@ -0,0 +1,170 @@ +#include "binder/expression/property_expression.h" +#include "binder/expression/rel_expression.h" +#include "main/client_context.h" +#include "planner/operator/persistent/logical_set.h" +#include "processor/expression_mapper.h" +#include "processor/operator/persistent/set.h" +#include "processor/plan_mapper.h" +#include "storage/storage_manager.h" +#include "storage/table/table.h" + +using namespace lbug::binder; +using namespace lbug::common; +using namespace lbug::catalog; +using namespace lbug::planner; +using namespace lbug::evaluator; +using namespace lbug::transaction; +using namespace lbug::storage; + +namespace lbug { +namespace processor { + +static column_id_t getColumnID(const TableCatalogEntry& entry, + const PropertyExpression& propertyExpr) { + auto columnID = INVALID_COLUMN_ID; + if (propertyExpr.hasProperty(entry.getTableID())) { + columnID = entry.getColumnID(propertyExpr.getPropertyName()); + } + return columnID; +} + +static NodeTableSetInfo getNodeTableSetInfo(const TableCatalogEntry& entry, const Expression& expr, + StorageManager* storageManager) { + auto table = storageManager->getTable(entry.getTableID())->ptrCast(); + auto columnID = getColumnID(entry, expr.constCast()); + return NodeTableSetInfo(table, columnID); +} + +static RelTableSetInfo getRelTableSetInfo(const RelGroupCatalogEntry& entry, table_id_t srcTableID, + table_id_t dstTableID, const Expression& expr, StorageManager* storageManager) { + auto relEntryInfo = entry.getRelEntryInfo(srcTableID, dstTableID); + auto table = storageManager->getTable(relEntryInfo->oid)->ptrCast(); + auto columnID = getColumnID(entry, expr.constCast()); + return RelTableSetInfo(table, columnID); +} + +std::unique_ptr PlanMapper::getNodeSetExecutor( + const BoundSetPropertyInfo& boundInfo, const Schema& schema) const { + auto& node = boundInfo.pattern->constCast(); + auto nodeIDPos = getDataPos(*node.getInternalID(), schema); + auto& property = boundInfo.column->constCast(); + auto columnVectorPos = DataPos::getInvalidPos(); + if (schema.isExpressionInScope(property)) { + columnVectorPos = getDataPos(property, schema); + } + auto exprMapper = ExpressionMapper(&schema); + auto evaluator = exprMapper.getEvaluator(boundInfo.columnData); + auto setInfo = NodeSetInfo(nodeIDPos, columnVectorPos, std::move(evaluator)); + if (node.isMultiLabeled()) { + table_id_map_t tableInfos; + for (auto entry : node.getEntries()) { + auto tableID = entry->getTableID(); + auto tableInfo = + getNodeTableSetInfo(*entry, property, StorageManager::Get(*clientContext)); + if (tableInfo.columnID == INVALID_COLUMN_ID) { + continue; + } + tableInfos.insert({tableID, std::move(tableInfo)}); + } + return std::make_unique(std::move(setInfo), + std::move(tableInfos)); + } + 
KU_ASSERT(node.getNumEntries() == 1); + auto tableInfo = + getNodeTableSetInfo(*node.getEntry(0), property, StorageManager::Get(*clientContext)); + return std::make_unique(std::move(setInfo), std::move(tableInfo)); +} + +std::unique_ptr PlanMapper::mapSetProperty( + const LogicalOperator* logicalOperator) { + auto set = logicalOperator->constPtrCast(); + switch (set->getTableType()) { + case TableType::NODE: { + return mapSetNodeProperty(logicalOperator); + } + case TableType::REL: { + return mapSetRelProperty(logicalOperator); + } + default: + KU_UNREACHABLE; + } +} + +std::unique_ptr PlanMapper::mapSetNodeProperty( + const LogicalOperator* logicalOperator) { + auto set = logicalOperator->constPtrCast(); + auto inSchema = set->getChild(0)->getSchema(); + auto prevOperator = mapOperator(logicalOperator->getChild(0).get()); + std::vector> executors; + for (auto& info : set->getInfos()) { + executors.push_back(getNodeSetExecutor(info, *inSchema)); + } + std::vector expressions; + for (auto& info : set->getInfos()) { + expressions.emplace_back(info.column, info.columnData); + } + auto printInfo = std::make_unique(expressions); + return std::make_unique(std::move(executors), std::move(prevOperator), + getOperatorID(), std::move(printInfo)); +} + +std::unique_ptr PlanMapper::getRelSetExecutor(const BoundSetPropertyInfo& boundInfo, + const Schema& schema) const { + auto& rel = boundInfo.pattern->constCast(); + auto srcNodeIDPos = getDataPos(*rel.getSrcNode()->getInternalID(), schema); + auto dstNodeIDPos = getDataPos(*rel.getDstNode()->getInternalID(), schema); + auto relIDPos = getDataPos(*rel.getInternalID(), schema); + auto& property = boundInfo.column->constCast(); + auto columnVectorPos = DataPos::getInvalidPos(); + if (schema.isExpressionInScope(property)) { + columnVectorPos = getDataPos(property, schema); + } + auto exprMapper = ExpressionMapper(&schema); + auto evaluator = exprMapper.getEvaluator(boundInfo.columnData); + auto info = + RelSetInfo(srcNodeIDPos, dstNodeIDPos, relIDPos, columnVectorPos, std::move(evaluator)); + if (rel.isMultiLabeled()) { + table_id_map_t tableInfos; + for (auto entry : rel.getEntries()) { + auto& relGroupEntry = entry->constCast(); + for (auto& relEntryInfo : relGroupEntry.getRelEntryInfos()) { + auto srcTableID = relEntryInfo.nodePair.srcTableID; + auto dstTableID = relEntryInfo.nodePair.dstTableID; + auto tableInfo = getRelTableSetInfo(relGroupEntry, srcTableID, dstTableID, property, + StorageManager::Get(*clientContext)); + if (tableInfo.columnID == INVALID_COLUMN_ID) { + continue; + } + tableInfos.insert({tableInfo.table->getTableID(), std::move(tableInfo)}); + } + } + return std::make_unique(std::move(info), std::move(tableInfos)); + } + KU_ASSERT(rel.getNumEntries() == 1); + auto& relGroupEntry = rel.getEntry(0)->constCast(); + auto fromToNodePair = relGroupEntry.getSingleRelEntryInfo().nodePair; + auto tableInfo = getRelTableSetInfo(relGroupEntry, fromToNodePair.srcTableID, + fromToNodePair.dstTableID, property, StorageManager::Get(*clientContext)); + return std::make_unique(std::move(info), std::move(tableInfo)); +} + +std::unique_ptr PlanMapper::mapSetRelProperty( + const LogicalOperator* logicalOperator) { + auto set = logicalOperator->constPtrCast(); + auto inSchema = set->getChild(0)->getSchema(); + auto prevOperator = mapOperator(logicalOperator->getChild(0).get()); + std::vector> executors; + for (auto& info : set->getInfos()) { + executors.push_back(getRelSetExecutor(info, *inSchema)); + } + std::vector expressions; + for (auto& info : 
set->getInfos()) { + expressions.emplace_back(info.column, info.columnData); + } + auto printInfo = std::make_unique(expressions); + return std::make_unique(std::move(executors), std::move(prevOperator), + getOperatorID(), std::move(printInfo)); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_simple.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_simple.cpp new file mode 100644 index 0000000000..b50652be12 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_simple.cpp @@ -0,0 +1,151 @@ +#include "common/exception/runtime.h" +#include "common/file_system/virtual_file_system.h" +#include "extension/mapper_extension.h" +#include "main/client_context.h" +#include "planner/operator/simple/logical_attach_database.h" +#include "planner/operator/simple/logical_detach_database.h" +#include "planner/operator/simple/logical_export_db.h" +#include "planner/operator/simple/logical_extension.h" +#include "planner/operator/simple/logical_import_db.h" +#include "planner/operator/simple/logical_use_database.h" +#include "processor/operator/persistent/copy_to.h" +#include "processor/operator/simple/attach_database.h" +#include "processor/operator/simple/detach_database.h" +#include "processor/operator/simple/export_db.h" +#include "processor/operator/simple/import_db.h" +#include "processor/operator/simple/install_extension.h" +#include "processor/operator/simple/load_extension.h" +#include "processor/operator/simple/uninstall_extension.h" +#include "processor/operator/simple/use_database.h" +#include "processor/plan_mapper.h" +#include "processor/result/factorized_table_util.h" +#include "storage/buffer_manager/memory_manager.h" + +namespace lbug { +namespace processor { + +using namespace lbug::planner; +using namespace lbug::common; +using namespace lbug::storage; +using namespace lbug::extension; + +std::unique_ptr PlanMapper::mapUseDatabase( + const LogicalOperator* logicalOperator) { + auto useDatabase = logicalOperator->constPtrCast(); + auto printInfo = std::make_unique(useDatabase->getDBName()); + auto messageTable = + FactorizedTableUtils::getSingleStringColumnFTable(MemoryManager::Get(*clientContext)); + return std::make_unique(useDatabase->getDBName(), std::move(messageTable), + getOperatorID(), std::move(printInfo)); +} + +std::unique_ptr PlanMapper::mapAttachDatabase( + const LogicalOperator* logicalOperator) { + auto attachDatabase = logicalOperator->constPtrCast(); + auto info = attachDatabase->getAttachInfo(); + auto printInfo = std::make_unique(info.dbAlias, info.dbPath); + auto messageTable = + FactorizedTableUtils::getSingleStringColumnFTable(MemoryManager::Get(*clientContext)); + return std::make_unique(std::move(info), std::move(messageTable), + getOperatorID(), std::move(printInfo)); +} + +std::unique_ptr PlanMapper::mapDetachDatabase( + const LogicalOperator* logicalOperator) { + auto detachDatabase = logicalOperator->constPtrCast(); + auto printInfo = std::make_unique(); + auto messageTable = + FactorizedTableUtils::getSingleStringColumnFTable(MemoryManager::Get(*clientContext)); + return std::make_unique(detachDatabase->getDBName(), std::move(messageTable), + getOperatorID(), std::move(printInfo)); +} + +static void exportDatabaseCollectParallelFlags(const std::unique_ptr& sink) { + auto exportDB = sink->getChild(0)->ptrCast(); + for (auto i = 1u; i < sink->getNumChildren(); ++i) { + const auto& tableFuncCall = sink->getChild(i); + KU_ASSERT_UNCONDITIONAL( + 
tableFuncCall->getChild(0)->getOperatorType() == PhysicalOperatorType::COPY_TO); + const auto& [file, parallelFlag] = + tableFuncCall->getChild(0)->ptrCast()->getParallelFlag(); + exportDB->addToParallelReaderMap(file, parallelFlag); + } +} + +std::unique_ptr PlanMapper::mapExportDatabase( + const LogicalOperator* logicalOperator) { + auto exportDatabase = logicalOperator->constPtrCast(); + auto fs = VirtualFileSystem::GetUnsafe(*clientContext); + auto boundFileInfo = exportDatabase->getBoundFileInfo(); + KU_ASSERT(boundFileInfo->filePaths.size() == 1); + auto filePath = boundFileInfo->filePaths[0]; + if (fs->fileOrPathExists(filePath, clientContext)) { + throw RuntimeException(stringFormat("Directory {} already exists.", filePath)); + } + fs->createDir(filePath); + auto printInfo = std::make_unique(filePath, boundFileInfo->options); + auto messageTable = + FactorizedTableUtils::getSingleStringColumnFTable(MemoryManager::Get(*clientContext)); + auto exportDB = std::make_unique(boundFileInfo->copy(), + exportDatabase->isSchemaOnly(), messageTable, getOperatorID(), std::move(printInfo)); + auto sink = std::make_unique(messageTable, getOperatorID()); + sink->addChild(std::move(exportDB)); + for (auto child : exportDatabase->getChildren()) { + sink->addChild(mapOperator(child.get())); + } + exportDatabaseCollectParallelFlags(sink); + return sink; +} + +std::unique_ptr PlanMapper::mapImportDatabase( + const LogicalOperator* logicalOperator) { + auto importDatabase = logicalOperator->constPtrCast(); + auto printInfo = std::make_unique(); + auto messageTable = + FactorizedTableUtils::getSingleStringColumnFTable(MemoryManager::Get(*clientContext)); + return std::make_unique(importDatabase->getQuery(), importDatabase->getIndexQuery(), + std::move(messageTable), getOperatorID(), std::move(printInfo)); +} + +std::unique_ptr PlanMapper::mapExtension(const LogicalOperator* logicalOperator) { + auto logicalExtension = logicalOperator->constPtrCast(); + auto& auxInfo = logicalExtension->getAuxInfo(); + auto path = auxInfo.path; + auto messageTable = + FactorizedTableUtils::getSingleStringColumnFTable(MemoryManager::Get(*clientContext)); + switch (auxInfo.action) { + case ExtensionAction::INSTALL: { + auto installAuxInfo = auxInfo.contCast(); + InstallExtensionInfo info{path, installAuxInfo.extensionRepo, installAuxInfo.forceInstall}; + auto printInfo = std::make_unique(path); + return std::make_unique(std::move(info), std::move(messageTable), + getOperatorID(), std::move(printInfo)); + } + case ExtensionAction::UNINSTALL: { + auto printInfo = std::make_unique(path); + return std::make_unique(path, std::move(messageTable), getOperatorID(), + std::move(printInfo)); + } + case ExtensionAction::LOAD: { + auto printInfo = std::make_unique(path); + return std::make_unique(path, std::move(messageTable), getOperatorID(), + std::move(printInfo)); + } + default: + KU_UNREACHABLE; + } +} + +std::unique_ptr PlanMapper::mapExtensionClause( + const LogicalOperator* logicalOperator) { + for (auto& mapperExtension : mapperExtensions) { + auto physicalOP = mapperExtension->map(logicalOperator, clientContext, getOperatorID()); + if (physicalOP) { + return physicalOP; + } + } + KU_UNREACHABLE; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_standalone_call.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_standalone_call.cpp new file mode 100644 index 0000000000..6bee7c7923 --- /dev/null +++ 
b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_standalone_call.cpp @@ -0,0 +1,26 @@ +#include "binder/expression/literal_expression.h" +#include "main/db_config.h" +#include "planner/operator/logical_standalone_call.h" +#include "processor/operator/standalone_call.h" +#include "processor/plan_mapper.h" + +using namespace lbug::planner; + +namespace lbug { +namespace processor { + +std::unique_ptr PlanMapper::mapStandaloneCall( + const LogicalOperator* logicalOperator) { + auto logicalStandaloneCall = logicalOperator->constPtrCast(); + auto optionValue = + logicalStandaloneCall->getOptionValue()->constPtrCast(); + auto standaloneCallInfo = + StandaloneCallInfo(logicalStandaloneCall->getOption(), optionValue->getValue()); + auto printInfo = + std::make_unique(logicalStandaloneCall->getOption()->name); + return std::make_unique(std::move(standaloneCallInfo), getOperatorID(), + std::move(printInfo)); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_table_function_call.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_table_function_call.cpp new file mode 100644 index 0000000000..bb943e33c1 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_table_function_call.cpp @@ -0,0 +1,21 @@ +#include "planner/operator/logical_table_function_call.h" +#include "processor/plan_mapper.h" + +using namespace lbug::planner; +using namespace lbug::common; + +namespace lbug { +namespace processor { + +std::unique_ptr PlanMapper::mapTableFunctionCall( + const LogicalOperator* logicalOperator) { + auto& call = logicalOperator->constCast(); + auto getPhysicalPlanFunc = call.getTableFunc().getPhysicalPlanFunc; + KU_ASSERT(getPhysicalPlanFunc); + auto res = getPhysicalPlanFunc(this, logicalOperator); + logicalOpToPhysicalOpMap.insert({logicalOperator, res.get()}); + return res; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_transaction.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_transaction.cpp new file mode 100644 index 0000000000..1732738f1a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_transaction.cpp @@ -0,0 +1,20 @@ +#include "planner/operator/logical_transaction.h" +#include "processor/operator/transaction.h" +#include "processor/plan_mapper.h" + +using namespace lbug::planner; + +namespace lbug { +namespace processor { + +std::unique_ptr PlanMapper::mapTransaction( + const LogicalOperator* logicalOperator) { + auto& logicalTransaction = logicalOperator->constCast(); + auto printInfo = + std::make_unique(logicalTransaction.getTransactionAction()); + return std::make_unique(logicalTransaction.getTransactionAction(), getOperatorID(), + std::move(printInfo)); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_union.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_union.cpp new file mode 100644 index 0000000000..3d38c12546 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_union.cpp @@ -0,0 +1,49 @@ +#include "common/system_config.h" +#include "planner/operator/logical_union.h" +#include "processor/operator/table_scan/union_all_scan.h" +#include "processor/plan_mapper.h" + +using namespace lbug::common; +using namespace lbug::planner; + +namespace lbug { +namespace processor { + +std::unique_ptr PlanMapper::mapUnionAll(const LogicalOperator* logicalOperator) { + auto& logicalUnionAll = 
logicalOperator->constCast(); + auto outSchema = logicalUnionAll.getSchema(); + // append result collectors to each child + std::vector> prevOperators; + std::vector> tables; + for (auto i = 0u; i < logicalOperator->getNumChildren(); ++i) { + auto child = logicalOperator->getChild(i); + auto childSchema = logicalUnionAll.getSchemaBeforeUnion(i); + auto prevOperator = mapOperator(child.get()); + auto resultCollector = createResultCollector(AccumulateType::REGULAR, + childSchema->getExpressionsInScope(), childSchema, std::move(prevOperator)); + tables.push_back(resultCollector->getResultFTable()); + prevOperators.push_back(std::move(resultCollector)); + } + // append union all + std::vector outputPositions; + std::vector columnIndices; + auto expressionsToUnion = logicalUnionAll.getExpressionsToUnion(); + for (auto i = 0u; i < expressionsToUnion.size(); ++i) { + auto expression = expressionsToUnion[i]; + outputPositions.emplace_back(outSchema->getExpressionPos(*expression)); + columnIndices.push_back(i); + } + auto info = UnionAllScanInfo(std::move(outputPositions), std::move(columnIndices)); + auto maxMorselSize = tables[0]->hasUnflatCol() ? 1 : DEFAULT_VECTOR_CAPACITY; + auto unionSharedState = make_shared(std::move(tables), maxMorselSize); + auto printInfo = std::make_unique(expressionsToUnion); + auto scan = make_unique(std::move(info), unionSharedState, getOperatorID(), + std::move(printInfo)); + for (auto& child : prevOperators) { + scan->addChild(std::move(child)); + } + return scan; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_unwind.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_unwind.cpp new file mode 100644 index 0000000000..d7a114e878 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/map_unwind.cpp @@ -0,0 +1,31 @@ +#include "planner/operator/logical_unwind.h" +#include "processor/expression_mapper.h" +#include "processor/operator/physical_operator.h" +#include "processor/operator/unwind.h" +#include "processor/plan_mapper.h" + +using namespace lbug::common; +using namespace lbug::planner; + +namespace lbug { +namespace processor { + +std::unique_ptr PlanMapper::mapUnwind(const LogicalOperator* logicalOperator) { + auto& unwind = logicalOperator->constCast(); + auto outSchema = unwind.getSchema(); + auto inSchema = unwind.getChild(0)->getSchema(); + auto prevOperator = mapOperator(logicalOperator->getChild(0).get()); + auto dataPos = DataPos(outSchema->getExpressionPos(*unwind.getOutExpr())); + auto exprMapper = ExpressionMapper(inSchema); + auto evaluator = exprMapper.getEvaluator(unwind.getInExpr()); + DataPos idPos; + if (unwind.hasIDExpr()) { + idPos = getDataPos(*unwind.getIDExpr(), *outSchema); + } + auto printInfo = std::make_unique(unwind.getInExpr(), unwind.getOutExpr()); + return std::make_unique(dataPos, idPos, std::move(evaluator), std::move(prevOperator), + getOperatorID(), std::move(printInfo)); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/plan_mapper.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/plan_mapper.cpp new file mode 100644 index 0000000000..3211606d76 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/map/plan_mapper.cpp @@ -0,0 +1,234 @@ +#include "processor/plan_mapper.h" + +#include "main/client_context.h" +#include "main/database.h" +#include "planner/operator/logical_plan.h" +#include "processor/operator/profile.h" +#include "storage/storage_manager.h" 
+#include "storage/table/node_table.h" + +using namespace lbug::binder; +using namespace lbug::common; +using namespace lbug::planner; +using namespace lbug::storage; + +namespace lbug { +namespace processor { + +PlanMapper::PlanMapper(ExecutionContext* executionContext) + : executionContext{executionContext}, physicalOperatorID{0} { + clientContext = executionContext->clientContext; + mapperExtensions = clientContext->getDatabase()->getMapperExtensions(); +} + +std::unique_ptr PlanMapper::getPhysicalPlan(const LogicalPlan* logicalPlan, + const expression_vector& expressions, main::QueryResultType resultType, + ArrowResultConfig arrowConfig) { + auto root = mapOperator(logicalPlan->getLastOperator().get()); + if (!root->isSink()) { + if (resultType == main::QueryResultType::ARROW) { + root = createArrowResultCollector(arrowConfig, expressions, logicalPlan->getSchema(), + std::move(root)); + } else { + root = createResultCollector(AccumulateType::REGULAR, expressions, + logicalPlan->getSchema(), std::move(root)); + } + } + auto physicalPlan = std::make_unique(std::move(root)); + if (logicalPlan->isProfile()) { + physicalPlan->lastOperator->ptrCast()->setPhysicalPlan(physicalPlan.get()); + } + return physicalPlan; +} + +std::unique_ptr PlanMapper::mapOperator(const LogicalOperator* logicalOperator) { + std::unique_ptr physicalOperator; + switch (logicalOperator->getOperatorType()) { + case LogicalOperatorType::ACCUMULATE: { + physicalOperator = mapAccumulate(logicalOperator); + } break; + case LogicalOperatorType::AGGREGATE: { + physicalOperator = mapAggregate(logicalOperator); + } break; + case LogicalOperatorType::ALTER: { + physicalOperator = mapAlter(logicalOperator); + } break; + case LogicalOperatorType::ATTACH_DATABASE: { + physicalOperator = mapAttachDatabase(logicalOperator); + } break; + case LogicalOperatorType::COPY_FROM: { + physicalOperator = mapCopyFrom(logicalOperator); + } break; + case LogicalOperatorType::COPY_TO: { + physicalOperator = mapCopyTo(logicalOperator); + } break; + case LogicalOperatorType::CREATE_MACRO: { + physicalOperator = mapCreateMacro(logicalOperator); + } break; + case LogicalOperatorType::CREATE_SEQUENCE: { + physicalOperator = mapCreateSequence(logicalOperator); + } break; + case LogicalOperatorType::CREATE_TABLE: { + physicalOperator = mapCreateTable(logicalOperator); + } break; + case LogicalOperatorType::CREATE_TYPE: { + physicalOperator = mapCreateType(logicalOperator); + } break; + case LogicalOperatorType::CROSS_PRODUCT: { + physicalOperator = mapCrossProduct(logicalOperator); + } break; + case LogicalOperatorType::DELETE: { + physicalOperator = mapDelete(logicalOperator); + } break; + case LogicalOperatorType::DETACH_DATABASE: { + physicalOperator = mapDetachDatabase(logicalOperator); + } break; + case LogicalOperatorType::DISTINCT: { + physicalOperator = mapDistinct(logicalOperator); + } break; + case LogicalOperatorType::DROP: { + physicalOperator = mapDrop(logicalOperator); + } break; + case LogicalOperatorType::DUMMY_SCAN: { + physicalOperator = mapDummyScan(logicalOperator); + } break; + case LogicalOperatorType::DUMMY_SINK: { + physicalOperator = mapDummySink(logicalOperator); + } break; + case LogicalOperatorType::EMPTY_RESULT: { + physicalOperator = mapEmptyResult(logicalOperator); + } break; + case LogicalOperatorType::EXPLAIN: { + physicalOperator = mapExplain(logicalOperator); + } break; + case LogicalOperatorType::EXPRESSIONS_SCAN: { + physicalOperator = mapExpressionsScan(logicalOperator); + } break; + case 
LogicalOperatorType::EXTEND: { + physicalOperator = mapExtend(logicalOperator); + } break; + case LogicalOperatorType::EXTENSION: { + physicalOperator = mapExtension(logicalOperator); + } break; + case LogicalOperatorType::EXPORT_DATABASE: { + physicalOperator = mapExportDatabase(logicalOperator); + } break; + case LogicalOperatorType::FLATTEN: { + physicalOperator = mapFlatten(logicalOperator); + } break; + case LogicalOperatorType::FILTER: { + physicalOperator = mapFilter(logicalOperator); + } break; + case LogicalOperatorType::HASH_JOIN: { + physicalOperator = mapHashJoin(logicalOperator); + } break; + case LogicalOperatorType::IMPORT_DATABASE: { + physicalOperator = mapImportDatabase(logicalOperator); + } break; + case LogicalOperatorType::INDEX_LOOK_UP: { + physicalOperator = mapIndexLookup(logicalOperator); + } break; + case LogicalOperatorType::INTERSECT: { + physicalOperator = mapIntersect(logicalOperator); + } break; + case LogicalOperatorType::INSERT: { + physicalOperator = mapInsert(logicalOperator); + } break; + case LogicalOperatorType::LIMIT: { + physicalOperator = mapLimit(logicalOperator); + } break; + case LogicalOperatorType::MERGE: { + physicalOperator = mapMerge(logicalOperator); + } break; + case LogicalOperatorType::MULTIPLICITY_REDUCER: { + physicalOperator = mapMultiplicityReducer(logicalOperator); + } break; + case LogicalOperatorType::NODE_LABEL_FILTER: { + physicalOperator = mapNodeLabelFilter(logicalOperator); + } break; + case LogicalOperatorType::NOOP: { + physicalOperator = mapNoop(logicalOperator); + } break; + case LogicalOperatorType::ORDER_BY: { + physicalOperator = mapOrderBy(logicalOperator); + } break; + case LogicalOperatorType::PARTITIONER: { + physicalOperator = mapPartitioner(logicalOperator); + } break; + case LogicalOperatorType::PATH_PROPERTY_PROBE: { + physicalOperator = mapPathPropertyProbe(logicalOperator); + } break; + case LogicalOperatorType::PROJECTION: { + physicalOperator = mapProjection(logicalOperator); + } break; + case LogicalOperatorType::RECURSIVE_EXTEND: { + physicalOperator = mapRecursiveExtend(logicalOperator); + } break; + case LogicalOperatorType::SCAN_NODE_TABLE: { + physicalOperator = mapScanNodeTable(logicalOperator); + } break; + case LogicalOperatorType::SEMI_MASKER: { + physicalOperator = mapSemiMasker(logicalOperator); + } break; + case LogicalOperatorType::SET_PROPERTY: { + physicalOperator = mapSetProperty(logicalOperator); + } break; + case LogicalOperatorType::STANDALONE_CALL: { + physicalOperator = mapStandaloneCall(logicalOperator); + } break; + case LogicalOperatorType::TABLE_FUNCTION_CALL: { + physicalOperator = mapTableFunctionCall(logicalOperator); + } break; + case LogicalOperatorType::TRANSACTION: { + physicalOperator = mapTransaction(logicalOperator); + } break; + case LogicalOperatorType::UNION_ALL: { + physicalOperator = mapUnionAll(logicalOperator); + } break; + case LogicalOperatorType::UNWIND: { + physicalOperator = mapUnwind(logicalOperator); + } break; + case LogicalOperatorType::USE_DATABASE: { + physicalOperator = mapUseDatabase(logicalOperator); + } break; + case LogicalOperatorType::EXTENSION_CLAUSE: { + physicalOperator = mapExtensionClause(logicalOperator); + } break; + default: + KU_UNREACHABLE; + } + if (!logicalOpToPhysicalOpMap.contains(logicalOperator)) { + logicalOpToPhysicalOpMap.insert({logicalOperator, physicalOperator.get()}); + } + return physicalOperator; +} + +std::vector PlanMapper::getDataPos(const expression_vector& expressions, + const Schema& schema) { + std::vector result; + 
for (auto& expression : expressions) { + result.emplace_back(getDataPos(*expression, schema)); + } + return result; +} + +FactorizedTableSchema PlanMapper::createFlatFTableSchema(const expression_vector& expressions, + const Schema& schema) { + auto tableSchema = FactorizedTableSchema(); + for (auto& expr : expressions) { + auto dataPos = getDataPos(*expr, schema); + auto columnSchema = ColumnSchema(false /* isUnFlat */, dataPos.dataChunkPos, + LogicalTypeUtils::getRowLayoutSize(expr->getDataType())); + tableSchema.appendColumn(std::move(columnSchema)); + } + return tableSchema; +} + +std::unique_ptr PlanMapper::createSemiMask(table_id_t tableID) const { + auto table = StorageManager::Get(*clientContext)->getTable(tableID)->ptrCast(); + return SemiMaskUtil::createMask( + table->getNumTotalRows(transaction::Transaction::Get(*clientContext))); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/CMakeLists.txt new file mode 100644 index 0000000000..629659bc19 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/CMakeLists.txt @@ -0,0 +1,41 @@ +add_subdirectory(aggregate) +add_subdirectory(ddl) +add_subdirectory(hash_join) +add_subdirectory(intersect) +add_subdirectory(order_by) +add_subdirectory(persistent) +add_subdirectory(scan) +add_subdirectory(simple) +add_subdirectory(table_scan) +add_subdirectory(macro) + +add_library(lbug_processor_operator + OBJECT + arrow_result_collector.cpp + base_partitioner_shared_state.cpp + cross_product.cpp + empty_result.cpp + filter.cpp + filtering_operator.cpp + flatten.cpp + index_lookup.cpp + limit.cpp + multiplicity_reducer.cpp + partitioner.cpp + path_property_probe.cpp + physical_operator.cpp + projection.cpp + profile.cpp + recursive_extend.cpp + result_collector.cpp + semi_masker.cpp + sink.cpp + skip.cpp + standalone_call.cpp + table_function_call.cpp + transaction.cpp + unwind.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/aggregate/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/aggregate/CMakeLists.txt new file mode 100644 index 0000000000..edfe077b23 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/aggregate/CMakeLists.txt @@ -0,0 +1,13 @@ +add_library(lbug_processor_operator_aggregate + OBJECT + aggregate_hash_table.cpp + base_aggregate.cpp + base_aggregate_scan.cpp + hash_aggregate.cpp + hash_aggregate_scan.cpp + simple_aggregate.cpp + simple_aggregate_scan.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/aggregate/aggregate_hash_table.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/aggregate/aggregate_hash_table.cpp new file mode 100644 index 0000000000..c74d3932ba --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/aggregate/aggregate_hash_table.cpp @@ -0,0 +1,881 @@ +#include "processor/operator/aggregate/aggregate_hash_table.h" + +#include +#include + +#include "common/assert.h" +#include "common/constants.h" +#include "common/data_chunk/data_chunk_state.h" +#include "common/data_chunk/sel_vector.h" +#include "common/in_mem_overflow_buffer.h" +#include "common/type_utils.h" +#include "common/types/types.h" +#include "common/utils.h" +#include "common/vector/value_vector.h" +#include 
"processor/operator/aggregate/aggregate_input.h" +#include "processor/result/factorized_table.h" +#include "processor/result/factorized_table_schema.h" + +using namespace lbug::common; +using namespace lbug::function; +using namespace lbug::storage; + +namespace lbug { +namespace processor { + +AggregateHashTable::AggregateHashTable(MemoryManager& memoryManager, + std::vector keyTypes, std::vector payloadTypes, + const std::vector& aggregateFunctions, + const std::vector& distinctAggKeyTypes, uint64_t numEntriesToAllocate, + FactorizedTableSchema tableSchema) + : BaseHashTable{memoryManager, std::move(keyTypes)}, payloadTypes{std::move(payloadTypes)} { + initializeFT(aggregateFunctions, std::move(tableSchema)); + initializeHashTable(numEntriesToAllocate); + KU_ASSERT(aggregateFunctions.size() == distinctAggKeyTypes.size()); + distinctHashTables.reserve(this->aggregateFunctions.size()); + distinctHashEntriesProcessed.resize(this->aggregateFunctions.size()); + for (auto i = 0u; i < this->aggregateFunctions.size(); ++i) { + std::unique_ptr distinctHT; + if (this->aggregateFunctions[i].isDistinct) { + distinctHT = AggregateHashTableUtils::createDistinctHashTable(memoryManager, + this->keyTypes, distinctAggKeyTypes[i]); + } else { + distinctHT = nullptr; + } + distinctHashTables.push_back(std::move(distinctHT)); + } + initializeTmpVectors(); +} + +uint64_t AggregateHashTable::append(const std::vector& keyVectors, + const std::vector& dependentKeyVectors, const DataChunkState* leadingState, + const std::vector& aggregateInputs, uint64_t resultSetMultiplicity) { + const auto numFlatTuples = leadingState->getSelVector().getSelSize(); + resizeHashTableIfNecessary(numFlatTuples); + computeVectorHashes(keyVectors); + findHashSlots(keyVectors, dependentKeyVectors, leadingState); + updateAggStates(keyVectors, aggregateInputs, resultSetMultiplicity, leadingState); + return numFlatTuples; +} + +hash_t getHash(const FactorizedTable& table, ft_tuple_idx_t tupleIdx) { + return *(hash_t*)(table.getTuple(tupleIdx) + table.getTableSchema()->getColOffset( + table.getTableSchema()->getNumColumns() - 1)); +} + +void AggregateHashTable::merge(FactorizedTable&& table) { + KU_ASSERT(*table.getTableSchema() == *getTableSchema()); + resizeHashTableIfNecessary(table.getNumTuples()); + + uint64_t startTupleIdx = 0; + while (startTupleIdx < table.getNumTuples()) { + auto numTuplesToScan = + std::min(table.getNumTuples() - startTupleIdx, DEFAULT_VECTOR_CAPACITY); + findHashSlots(table, startTupleIdx, numTuplesToScan); + auto aggregateStateOffset = aggStateColOffsetInFT; + for (auto& aggregateFunction : aggregateFunctions) { + // We'll update the distinct state at the end. + // The distinct data gets merged separately, and only after the main data so that + // we can guarantee that there is a group available in teh hash table for any given + // distinct tuple. 
+ if (!aggregateFunction.isDistinct) { + for (auto i = 0u; i < numTuplesToScan; i++) { + aggregateFunction.combineState(hashSlotsToUpdateAggState[i]->getEntry() + + aggregateStateOffset, + table.getTuple(startTupleIdx + i) + aggregateStateOffset, + factorizedTable->getInMemOverflowBuffer()); + } + } + aggregateStateOffset += aggregateFunction.getAggregateStateSize(); + } + startTupleIdx += numTuplesToScan; + } +} + +void AggregateHashTable::mergeDistinctAggregateInfo() { + auto state = std::make_shared(); + std::vector> vectors; + std::vector vectorPtrs; + vectors.reserve(keyTypes.size() + 1); + vectorPtrs.reserve(keyTypes.size() + 1); + for (auto& keyType : keyTypes) { + vectors.emplace_back(std::make_unique(keyType.copy(), memoryManager, state)); + vectorPtrs.emplace_back(vectors.back().get()); + } + vectors.emplace_back(nullptr); + vectorPtrs.emplace_back(nullptr); + + auto aggStateColOffset = aggStateColOffsetInFT; + for (size_t distinctIdx = 0; distinctIdx < distinctHashTables.size(); distinctIdx++) { + auto& distinctHashTable = distinctHashTables[distinctIdx]; + if (distinctHashTable) { + // Distinct key type is always the last key type in the distinct table + vectors.back() = std::make_unique( + distinctHashTable->keyTypes.back().copy(), memoryManager, state); + vectorPtrs.back() = vectors.back().get(); + + // process 2048 at a time, beginning with the first unprocessed tuple + while (distinctHashEntriesProcessed[distinctIdx] < distinctHashTable->getNumEntries()) { + // Scan everything but the hash column, which isn't needed as we need hashes for + // everything but the distinct key type + std::vector colIdxToScan(vectorPtrs.size()); + std::iota(colIdxToScan.begin(), colIdxToScan.end(), 0); + + auto numTuplesToScan = std::min(DEFAULT_VECTOR_CAPACITY, + distinctHashTable->getNumEntries() - distinctHashEntriesProcessed[distinctIdx]); + state->getSelVectorUnsafe().setSelSize(numTuplesToScan); + distinctHashTable->factorizedTable->scan(vectorPtrs, + distinctHashEntriesProcessed[distinctIdx], numTuplesToScan, colIdxToScan); + // Compute hashes for the scanned keys but not the distinct key (i.e. the groups in + // the main hash table) + computeVectorHashes(std::span(const_cast(vectorPtrs.data()), + vectorPtrs.size() - 1)); + for (size_t i = 0; i < numTuplesToScan; i++) { + // Find slot in current hash table corresponding to the entry in the distinct + // hash table + // updatePosState may not check nulls, and we avoid inserting any nulls into + // distinct hash tables + if (vectors.back()->isNull(i)) { + continue; + } + auto hash = hashVector->getValue(i); + // matchFTEntries expects that the index is the same as the index in the vector + // Doesn't need to be done each time, but maybe this can be vectorized so that + // we can find the correct slots for all of them simultaneously? + mayMatchIdxes[0] = i; + auto entry = findEntry(hash, [&](auto entry) { + HashSlot slot{hash, entry}; + hashSlotsToUpdateAggState[i] = &slot; + return matchFTEntries( + std::span(const_cast(vectorPtrs.data()), + vectorPtrs.size() - 1), + 1, 0) == 0; + }); + KU_ASSERT(entry != nullptr); + aggregateFunctions[distinctIdx].updatePosState(entry + aggStateColOffset, + vectors.back().get() /*aggregateVector*/, + 1 /* Distinct aggregate should ignore multiplicity since they are known to + be non-distinct. 
*/ + , + i, factorizedTable->getInMemOverflowBuffer()); + } + distinctHashEntriesProcessed[distinctIdx] += numTuplesToScan; + } + } + aggStateColOffset += aggregateFunctions[distinctIdx].getAggregateStateSize(); + } +} + +void AggregateHashTable::finalizeAggregateStates() { + if (!aggregateFunctions.empty()) { + for (auto i = 0u; i < getNumEntries(); ++i) { + auto entry = factorizedTable->getTuple(i); + auto aggregateStatesOffset = aggStateColOffsetInFT; + for (auto& aggregateFunction : aggregateFunctions) { + aggregateFunction.finalizeState(entry + aggregateStatesOffset); + aggregateStatesOffset += aggregateFunction.getAggregateStateSize(); + } + } + } +} + +void AggregateHashTable::initializeFT(const std::vector& aggFuncs, + FactorizedTableSchema&& tableSchema) { + aggStateColIdxInFT = keyTypes.size() + payloadTypes.size(); + for (auto& dataType : keyTypes) { + numBytesForKeys += LogicalTypeUtils::getRowLayoutSize(dataType); + } + for (auto& dataType : payloadTypes) { + numBytesForDependentKeys += LogicalTypeUtils::getRowLayoutSize(dataType); + } + aggStateColOffsetInFT = numBytesForKeys + numBytesForDependentKeys; + + aggregateFunctions.reserve(aggFuncs.size()); + for (auto i = 0u; i < aggFuncs.size(); i++) { + auto& aggFunc = aggFuncs[i]; + aggregateFunctions.push_back(aggFunc.copy()); + } + hashColIdxInFT = tableSchema.getNumColumns() - 1; + hashColOffsetInFT = tableSchema.getColOffset(hashColIdxInFT); + factorizedTable = std::make_unique(memoryManager, std::move(tableSchema)); +} + +void AggregateHashTable::initializeHashTable(uint64_t numEntriesToAllocate) { + auto numHashSlotsPerBlock = prevPowerOfTwo(HASH_BLOCK_SIZE / sizeof(HashSlot)); + setMaxNumHashSlots(nextPowerOfTwo(std::max(numHashSlotsPerBlock, numEntriesToAllocate))); + initSlotConstant(numHashSlotsPerBlock); + auto numDataBlocks = + maxNumHashSlots / numHashSlotsPerBlock + (maxNumHashSlots % numHashSlotsPerBlock != 0); + for (auto i = 0u; i < numDataBlocks; i++) { + hashSlotsBlocks.emplace_back(std::make_unique(memoryManager, HASH_BLOCK_SIZE)); + } +} + +void AggregateHashTable::initializeTmpVectors() { + hashSlotsToUpdateAggState = std::make_unique(DEFAULT_VECTOR_CAPACITY); + tmpValueIdxes = std::make_unique(DEFAULT_VECTOR_CAPACITY); + entryIdxesToInitialize = std::make_unique(DEFAULT_VECTOR_CAPACITY); + mayMatchIdxes = std::make_unique(DEFAULT_VECTOR_CAPACITY); + noMatchIdxes = std::make_unique(DEFAULT_VECTOR_CAPACITY); + tmpSlotIdxes = std::make_unique(DEFAULT_VECTOR_CAPACITY); +} + +uint8_t* AggregateHashTable::findEntryInDistinctHT( + const std::vector& groupByKeyVectors, hash_t hash) { + return findEntry(hash, + [&](auto entry) { return matchFlatVecWithEntry(groupByKeyVectors, entry); }); +} + +void AggregateHashTable::resize(uint64_t newSize) { + setMaxNumHashSlots(newSize); + addDataBlocksIfNecessary(maxNumHashSlots); + for (auto& block : hashSlotsBlocks) { + block->resetToZero(); + } + factorizedTable->forEach( + [&](auto tuple) { fillHashSlot(*(hash_t*)(tuple + hashColOffsetInFT), tuple); }); +} + +uint64_t AggregateHashTable::matchFTEntries(std::span keyVectors, + uint64_t numMayMatches, uint64_t numNoMatches) { + auto colIdx = 0u; + for (const auto& keyVector : keyVectors) { + if (keyVector->state->isFlat()) { + numMayMatches = + matchFlatVecWithFTColumn(keyVector, numMayMatches, numNoMatches, colIdx++); + } else { + numMayMatches = + matchUnFlatVecWithFTColumn(keyVector, numMayMatches, numNoMatches, colIdx++); + } + } + return numNoMatches; +} + +uint64_t AggregateHashTable::matchFTEntries(const 
FactorizedTable& srcTable, uint64_t startOffset, + uint64_t numMayMatches, uint64_t numNoMatches) { + for (auto colIdx = 0u; colIdx < keyTypes.size(); colIdx++) { + auto colOffset = getTableSchema()->getColOffset(colIdx); + uint64_t mayMatchIdx = 0; + for (auto i = 0u; i < numMayMatches; i++) { + auto idx = mayMatchIdxes[i]; + auto& slot = *hashSlotsToUpdateAggState[idx]; + auto isEntryKeyNull = factorizedTable->isNonOverflowColNull( + slot.getEntry() + getTableSchema()->getNullMapOffset(), colIdx); + auto isSrcEntryKeyNull = srcTable.isNonOverflowColNull(startOffset + idx, colIdx); + if (isEntryKeyNull && isSrcEntryKeyNull) { + mayMatchIdxes[mayMatchIdx++] = idx; + continue; + } else if (isEntryKeyNull != isSrcEntryKeyNull) { + noMatchIdxes[numNoMatches++] = idx; + continue; + } + if (ftCompareEntryFuncs[colIdx](srcTable.getTuple(startOffset + idx) + colOffset, + hashSlotsToUpdateAggState[idx]->getEntry() + colOffset, keyTypes[colIdx])) { + mayMatchIdxes[mayMatchIdx++] = idx; + } else { + noMatchIdxes[numNoMatches++] = idx; + } + } + numMayMatches = mayMatchIdx; + } + return numNoMatches; +} + +void AggregateHashTable::initializeFTEntries(const std::vector& keyVectors, + const std::vector& dependentKeyVectors, uint64_t numFTEntriesToInitialize) { + auto colIdx = 0u; + for (auto keyVector : keyVectors) { + if (keyVector->state->isFlat()) { + initializeFTEntryWithFlatVec(keyVector, numFTEntriesToInitialize, colIdx++); + } else { + initializeFTEntryWithUnFlatVec(keyVector, numFTEntriesToInitialize, colIdx++); + } + } + for (auto dependentKeyVector : dependentKeyVectors) { + if (dependentKeyVector->state->isFlat()) { + initializeFTEntryWithFlatVec(dependentKeyVector, numFTEntriesToInitialize, colIdx++); + } else { + initializeFTEntryWithUnFlatVec(dependentKeyVector, numFTEntriesToInitialize, colIdx++); + } + } + for (auto i = 0u; i < numFTEntriesToInitialize; i++) { + auto entryIdx = entryIdxesToInitialize[i]; + auto& slot = *hashSlotsToUpdateAggState[entryIdx]; + fillEntryWithInitialNullAggregateState(*factorizedTable, slot.getEntry()); + // Fill the hashValue in the ftEntry. + factorizedTable->updateFlatCellNoNull(slot.getEntry(), hashColIdxInFT, + hashVector->getData() + hashVector->getNumBytesPerValue() * entryIdx); + } +} + +void AggregateHashTable::initializeFTEntries(const FactorizedTable& sourceTable, + uint64_t sourceStartOffset, uint64_t numFTEntriesToInitialize) { + // auto colIdx = 0u; + for (size_t i = 0; i < numFTEntriesToInitialize; i++) { + auto idx = entryIdxesToInitialize[i]; + auto& slot = *hashSlotsToUpdateAggState[idx]; + auto sourcePos = sourceStartOffset + idx; + memcpy(slot.getEntry(), sourceTable.getTuple(sourcePos), + getTableSchema()->getNumBytesPerTuple()); + } + for (auto i = 0u; i < numFTEntriesToInitialize; i++) { + auto entryIdx = entryIdxesToInitialize[i]; + auto& slot = *hashSlotsToUpdateAggState[entryIdx]; + fillEntryWithInitialNullAggregateState(*factorizedTable, slot.getEntry()); + // Fill the hashValue in the ftEntry. 
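        // Keeping the 64-bit hash materialized in the table's trailing column lets later
        // passes (resize() re-bucketing and the merge path's getHash()) place a tuple in a
        // hash slot again without recomputing the hash from its key columns.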
+ factorizedTable->updateFlatCellNoNull(slot.getEntry(), hashColIdxInFT, + hashVector->getData() + hashVector->getNumBytesPerValue() * entryIdx); + } +} + +uint64_t AggregateHashTable::matchUnFlatVecWithFTColumn(const ValueVector* vector, + uint64_t numMayMatches, uint64_t& numNoMatches, uint32_t colIdx) { + KU_ASSERT(!vector->state->isFlat()); + auto& schema = *getTableSchema(); + auto colOffset = schema.getColOffset(colIdx); + uint64_t mayMatchIdx = 0; + if (vector->hasNoNullsGuarantee()) { + for (auto i = 0u; i < numMayMatches; i++) { + auto idx = mayMatchIdxes[i]; + auto& slot = *hashSlotsToUpdateAggState[idx]; + auto isEntryKeyNull = factorizedTable->isNonOverflowColNull( + slot.getEntry() + schema.getNullMapOffset(), colIdx); + if (isEntryKeyNull) { + noMatchIdxes[numNoMatches++] = idx; + continue; + } + if (compareEntryFuncs[colIdx](vector, idx, slot.getEntry() + colOffset)) { + mayMatchIdxes[mayMatchIdx++] = idx; + } else { + noMatchIdxes[numNoMatches++] = idx; + } + } + } else { + for (auto i = 0u; i < numMayMatches; i++) { + auto idx = mayMatchIdxes[i]; + auto isKeyVectorNull = vector->isNull(idx); + auto& slot = *hashSlotsToUpdateAggState[idx]; + auto isEntryKeyNull = factorizedTable->isNonOverflowColNull( + slot.getEntry() + schema.getNullMapOffset(), colIdx); + if (isKeyVectorNull && isEntryKeyNull) { + mayMatchIdxes[mayMatchIdx++] = idx; + continue; + } else if (isKeyVectorNull != isEntryKeyNull) { + noMatchIdxes[numNoMatches++] = idx; + continue; + } + + if (compareEntryFuncs[colIdx](vector, idx, slot.getEntry() + colOffset)) { + mayMatchIdxes[mayMatchIdx++] = idx; + } else { + noMatchIdxes[numNoMatches++] = idx; + } + } + } + return mayMatchIdx; +} + +uint64_t AggregateHashTable::matchFlatVecWithFTColumn(const ValueVector* vector, + uint64_t numMayMatches, uint64_t& numNoMatches, uint32_t colIdx) { + KU_ASSERT(vector->state->isFlat()); + auto colOffset = getTableSchema()->getColOffset(colIdx); + uint64_t mayMatchIdx = 0; + auto pos = vector->state->getSelVector()[0]; + auto isVectorNull = vector->isNull(pos); + for (auto i = 0u; i < numMayMatches; i++) { + auto idx = mayMatchIdxes[i]; + auto& slot = *hashSlotsToUpdateAggState[idx]; + auto isEntryKeyNull = factorizedTable->isNonOverflowColNull( + slot.getEntry() + getTableSchema()->getNullMapOffset(), colIdx); + if (isEntryKeyNull && isVectorNull) { + mayMatchIdxes[mayMatchIdx++] = idx; + continue; + } else if (isEntryKeyNull != isVectorNull) { + noMatchIdxes[numNoMatches++] = idx; + continue; + } + if (compareEntryFuncs[colIdx](vector, pos, + hashSlotsToUpdateAggState[idx]->getEntry() + colOffset)) { + mayMatchIdxes[mayMatchIdx++] = idx; + } else { + noMatchIdxes[numNoMatches++] = idx; + } + } + return mayMatchIdx; +} + +void AggregateHashTable::initializeFTEntryWithFlatVec(ValueVector* flatVector, + uint64_t numEntriesToInitialize, uint32_t colIdx) { + KU_ASSERT(flatVector->state->isFlat()); + auto colOffset = getTableSchema()->getColOffset(colIdx); + auto pos = flatVector->state->getSelVector()[0]; + if (flatVector->isNull(pos)) { + for (auto i = 0u; i < numEntriesToInitialize; i++) { + auto idx = entryIdxesToInitialize[i]; + auto& slot = *hashSlotsToUpdateAggState[idx]; + factorizedTable->setNonOverflowColNull( + slot.getEntry() + getTableSchema()->getNullMapOffset(), colIdx); + } + } else { + for (auto i = 0u; i < numEntriesToInitialize; i++) { + auto idx = entryIdxesToInitialize[i]; + auto& slot = *hashSlotsToUpdateAggState[idx]; + flatVector->copyToRowData(pos, slot.getEntry() + colOffset, + 
factorizedTable->getInMemOverflowBuffer()); + } + } +} + +void AggregateHashTable::initializeFTEntryWithUnFlatVec(ValueVector* unFlatVector, + uint64_t numEntriesToInitialize, uint32_t colIdx) { + KU_ASSERT(!unFlatVector->state->isFlat()); + auto colOffset = factorizedTable->getTableSchema()->getColOffset(colIdx); + if (unFlatVector->hasNoNullsGuarantee()) { + for (auto i = 0u; i < numEntriesToInitialize; i++) { + auto entryIdx = entryIdxesToInitialize[i]; + auto& slot = *hashSlotsToUpdateAggState[entryIdx]; + unFlatVector->copyToRowData(entryIdx, slot.getEntry() + colOffset, + factorizedTable->getInMemOverflowBuffer()); + } + } else { + for (auto i = 0u; i < numEntriesToInitialize; i++) { + auto entryIdx = entryIdxesToInitialize[i]; + auto& slot = *hashSlotsToUpdateAggState[entryIdx]; + factorizedTable->updateFlatCell(slot.getEntry(), colIdx, unFlatVector, entryIdx); + } + } +} + +uint8_t* AggregateHashTable::createEntryInDistinctHT( + const std::vector& groupByHashKeyVectors, hash_t hash) { + auto entry = factorizedTable->appendEmptyTuple(); + for (auto i = 0u; i < groupByHashKeyVectors.size(); i++) { + factorizedTable->updateFlatCell(entry, i, groupByHashKeyVectors[i], + groupByHashKeyVectors[i]->state->getSelVector()[0]); + } + factorizedTable->updateFlatCellNoNull(entry, hashColIdxInFT, &hash); + fillEntryWithInitialNullAggregateState(*factorizedTable, entry); + fillHashSlot(hash, entry); + return entry; +} + +void AggregateHashTable::increaseSlotIdx(uint64_t& slotIdx) const { + slotIdx = (slotIdx + 1) % maxNumHashSlots; +} + +void AggregateHashTable::initTmpHashSlotsAndIdxes(const FactorizedTable& sourceTable, + uint64_t startOffset, uint64_t numTuples) { + for (size_t i = 0; i < numTuples; i++) { + tmpValueIdxes[i] = i; + auto hash = getHash(sourceTable, startOffset + i); + hashVector->setValue(i, hash); + tmpSlotIdxes[i] = getSlotIdxForHash(hash); + hashSlotsToUpdateAggState[i] = getHashSlot(tmpSlotIdxes[i]); + } +} + +void AggregateHashTable::initTmpHashSlotsAndIdxes() { + auto& hashSelVector = hashVector->state->getSelVector(); + if (hashSelVector.isUnfiltered()) { + for (auto i = 0u; i < hashSelVector.getSelSize(); i++) { + tmpValueIdxes[i] = i; + tmpSlotIdxes[i] = getSlotIdxForHash(hashVector->getValue(i)); + hashSlotsToUpdateAggState[i] = getHashSlot(tmpSlotIdxes[i]); + } + } else { + for (auto i = 0u; i < hashSelVector.getSelSize(); i++) { + auto pos = hashSelVector[i]; + tmpValueIdxes[i] = pos; + tmpSlotIdxes[pos] = getSlotIdxForHash(hashVector->getValue(pos)); + hashSlotsToUpdateAggState[pos] = getHashSlot(tmpSlotIdxes[pos]); + } + } +} + +void AggregateHashTable::increaseHashSlotIdxes(uint64_t numNoMatches) { + for (auto i = 0u; i < numNoMatches; i++) { + auto idx = noMatchIdxes[i]; + increaseSlotIdx(tmpSlotIdxes[idx]); + hashSlotsToUpdateAggState[idx] = getHashSlot(tmpSlotIdxes[idx]); + } +} + +void AggregateHashTable::findHashSlots(const std::vector& keyVectors, + const std::vector& dependentKeyVectors, const DataChunkState* leadingState) { + initTmpHashSlotsAndIdxes(); + auto numEntriesToFindHashSlots = leadingState->getSelSize(); + KU_ASSERT(getNumEntries() + numEntriesToFindHashSlots < maxNumHashSlots); + while (numEntriesToFindHashSlots > 0) { + uint64_t numFTEntriesToUpdate = 0; + uint64_t numMayMatches = 0; + uint64_t numNoMatches = 0; + for (auto i = 0u; i < numEntriesToFindHashSlots; i++) { + auto idx = tmpValueIdxes[i]; + auto hash = hashVector->getValue(idx); + auto slot = hashSlotsToUpdateAggState[idx]; + if (slot->getEntry() == nullptr) { + 
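                // Empty slot: claim it for this key by appending a fresh tuple to the
                // factorized table, and queue the index so the key columns are copied in
                // by initializeFTEntries() below.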
entryIdxesToInitialize[numFTEntriesToUpdate++] = idx; + *slot = HashSlot(hash, factorizedTable->appendEmptyTuple()); + } else if (slot->checkFingerprint(hash)) { + mayMatchIdxes[numMayMatches++] = idx; + } else { + noMatchIdxes[numNoMatches++] = idx; + } + } + initializeFTEntries(keyVectors, dependentKeyVectors, numFTEntriesToUpdate); + numNoMatches = matchFTEntries(constSpan(keyVectors), numMayMatches, numNoMatches); + increaseHashSlotIdxes(numNoMatches); + KU_ASSERT(numNoMatches <= numEntriesToFindHashSlots); + numEntriesToFindHashSlots = numNoMatches; + memcpy(tmpValueIdxes.get(), noMatchIdxes.get(), numNoMatches * sizeof(uint64_t)); + } +} + +void AggregateHashTable::findHashSlots(const FactorizedTable& srcTable, uint64_t startOffset, + uint64_t numEntriesToFindHashSlots) { + initTmpHashSlotsAndIdxes(srcTable, startOffset, numEntriesToFindHashSlots); + KU_ASSERT(getNumEntries() + numEntriesToFindHashSlots < maxNumHashSlots); + while (numEntriesToFindHashSlots > 0) { + uint64_t numFTEntriesToUpdate = 0; + uint64_t numMayMatches = 0; + uint64_t numNoMatches = 0; + for (auto i = 0u; i < numEntriesToFindHashSlots; i++) { + auto idx = tmpValueIdxes[i]; + auto slot = hashSlotsToUpdateAggState[idx]; + auto hash = hashVector->getValue(idx); + if (slot->getEntry() == nullptr) { + entryIdxesToInitialize[numFTEntriesToUpdate++] = idx; + *slot = HashSlot(hash, factorizedTable->appendEmptyTuple()); + } else if (slot->checkFingerprint(hash)) { + mayMatchIdxes[numMayMatches++] = idx; + } else { + noMatchIdxes[numNoMatches++] = idx; + } + } + initializeFTEntries(srcTable, startOffset, numFTEntriesToUpdate); + numNoMatches = matchFTEntries(srcTable, startOffset, numMayMatches, numNoMatches); + increaseHashSlotIdxes(numNoMatches); + KU_ASSERT(numNoMatches <= numEntriesToFindHashSlots); + numEntriesToFindHashSlots = numNoMatches; + memcpy(tmpValueIdxes.get(), noMatchIdxes.get(), numNoMatches * sizeof(uint64_t)); + } +} + +void AggregateHashTable::appendDistinct(const std::vector& keyVectors, + ValueVector* aggregateVector, const DataChunkState* leadingState) { + std::vector distinctKeyVectors(keyVectors); + distinctKeyVectors.push_back(aggregateVector); + // The aggregateVector's state is either the same as the leading state (doesn't matter which we + // use), flat (where we must pass the original leading state), or unflat while the leadingState + // is flat (where we must pass the aggregateVector's state) + auto distinctLeadingState = + aggregateVector->state->isFlat() ? 
leadingState : aggregateVector->state.get(); + append(distinctKeyVectors, distinctLeadingState, std::vector{}, + 1 /*multiplicity*/); +} + +void AggregateHashTable::updateAggState(const std::vector& keyVectors, + AggregateFunction& aggregateFunction, ValueVector* aggVector, uint64_t multiplicity, + uint32_t aggStateOffset, const DataChunkState* leadingState) { + // There may be a mix of flat and unflat states, but any unflat states will be the same + bool allFlat = leadingState->isFlat(); + if (!aggVector) { + updateNullAggVectorState(*leadingState, aggregateFunction, multiplicity, aggStateOffset); + } else if (aggVector->state->isFlat() && allFlat) { + updateBothFlatAggVectorState(aggregateFunction, aggVector, multiplicity, aggStateOffset); + } else if (aggVector->state->isFlat()) { + updateFlatUnFlatKeyFlatAggVectorState(*leadingState, aggregateFunction, aggVector, + multiplicity, aggStateOffset); + } else if (allFlat) { + updateFlatKeyUnFlatAggVectorState(keyVectors, aggregateFunction, aggVector, multiplicity, + aggStateOffset); + } else if (aggVector->state.get() == leadingState) { + updateBothUnFlatSameDCAggVectorState(aggregateFunction, aggVector, multiplicity, + aggStateOffset); + } else { + updateBothUnFlatDifferentDCAggVectorState(*leadingState, aggregateFunction, aggVector, + multiplicity, aggStateOffset); + } +} + +void AggregateHashTable::updateAggStates(const std::vector& keyVectors, + const std::vector& aggregateInputs, uint64_t resultSetMultiplicity, + const DataChunkState* leadingState) { + auto aggregateStateOffset = aggStateColOffsetInFT; + for (auto i = 0u; i < aggregateFunctions.size(); i++) { + if (!aggregateFunctions[i].isDistinct) { + auto multiplicity = resultSetMultiplicity; + for (auto& dataChunk : aggregateInputs[i].multiplicityChunks) { + multiplicity *= dataChunk->state->getSelVector().getSelSize(); + } + updateAggState(keyVectors, aggregateFunctions[i], aggregateInputs[i].aggregateVector, + multiplicity, aggregateStateOffset, leadingState); + aggregateStateOffset += aggregateFunctions[i].getAggregateStateSize(); + } else { + // If a function is distinct we still need to insert the value into the distinct + // hash table + distinctHashTables[i]->appendDistinct(keyVectors, aggregateInputs[i].aggregateVector, + leadingState); + } + } +} + +void AggregateHashTable::fillEntryWithInitialNullAggregateState(FactorizedTable& table, + uint8_t* entry) { + for (auto i = 0u; i < aggregateFunctions.size(); i++) { + table.updateFlatCellNoNull(entry, aggStateColIdxInFT + i, + (void*)aggregateFunctions[i].getInitialNullAggregateState()); + } +} + +void AggregateHashTable::fillHashSlot(hash_t hash, uint8_t* groupByKeysAndAggregateStateBuffer) { + auto slotIdx = getSlotIdxForHash(hash); + auto hashSlot = getHashSlot(slotIdx); + while (true) { + if (hashSlot->getEntry()) { + increaseSlotIdx(slotIdx); + hashSlot = getHashSlot(slotIdx); + continue; + } + break; + } + *hashSlot = HashSlot(hash, groupByKeysAndAggregateStateBuffer); +} + +void AggregateHashTable::addDataBlocksIfNecessary(uint64_t maxNumHashSlots) { + auto numHashSlotsPerBlock = static_cast(1) << numSlotsPerBlockLog2; + auto numHashSlotsBlocksNeeded = + (maxNumHashSlots + numHashSlotsPerBlock - 1) / numHashSlotsPerBlock; + while (hashSlotsBlocks.size() < numHashSlotsBlocksNeeded) { + hashSlotsBlocks.emplace_back(std::make_unique(memoryManager, HASH_BLOCK_SIZE)); + } +} + +void AggregateHashTable::resizeHashTableIfNecessary(uint32_t maxNumDistinctHashKeys) { + if (factorizedTable->getNumTuples() + maxNumDistinctHashKeys 
> maxNumHashSlots || + static_cast(factorizedTable->getNumTuples()) + maxNumDistinctHashKeys > + static_cast(maxNumHashSlots) / DEFAULT_HT_LOAD_FACTOR) { + resize(std::max(factorizedTable->getNumTuples() + maxNumDistinctHashKeys, maxNumHashSlots) * + DEFAULT_HT_LOAD_FACTOR); + } +} + +void AggregateHashTable::updateNullAggVectorState(const DataChunkState& keyState, + AggregateFunction& aggregateFunction, uint64_t multiplicity, uint32_t aggStateOffset) { + if (keyState.isFlat()) { + auto pos = keyState.getSelVector()[0]; + aggregateFunction.updatePosState(hashSlotsToUpdateAggState[pos]->getEntry() + + aggStateOffset, + nullptr, multiplicity, 0 /* dummy pos */, factorizedTable->getInMemOverflowBuffer()); + } else { + keyState.getSelVector().forEach([&](auto pos) { + aggregateFunction.updatePosState( + hashSlotsToUpdateAggState[pos]->getEntry() + aggStateOffset, nullptr, multiplicity, + 0 /* dummy pos */, factorizedTable->getInMemOverflowBuffer()); + }); + } +} + +void AggregateHashTable::updateBothFlatAggVectorState(AggregateFunction& aggregateFunction, + ValueVector* aggVector, uint64_t multiplicity, uint32_t aggStateOffset) { + auto aggPos = aggVector->state->getSelVector()[0]; + if (!aggVector->isNull(aggPos)) { + aggregateFunction.updatePosState( + hashSlotsToUpdateAggState[hashVector->state->getSelVector()[0]]->getEntry() + + aggStateOffset, + aggVector, multiplicity, aggPos, factorizedTable->getInMemOverflowBuffer()); + } +} + +void AggregateHashTable::updateFlatUnFlatKeyFlatAggVectorState(const DataChunkState& unFlatKeyState, + AggregateFunction& aggregateFunction, ValueVector* aggVector, uint64_t multiplicity, + uint32_t aggStateOffset) { + auto aggPos = aggVector->state->getSelVector()[0]; + if (!aggVector->isNull(aggPos)) { + unFlatKeyState.getSelVector().forEach([&](auto pos) { + aggregateFunction.updatePosState(hashSlotsToUpdateAggState[pos]->getEntry() + + aggStateOffset, + aggVector, multiplicity, aggPos, factorizedTable->getInMemOverflowBuffer()); + }); + } +} + +void AggregateHashTable::updateFlatKeyUnFlatAggVectorState( + const std::vector& flatKeyVectors, AggregateFunction& aggregateFunction, + ValueVector* aggVector, uint64_t multiplicity, uint32_t aggStateOffset) { + auto groupByKeyPos = flatKeyVectors[0]->state->getSelVector()[0]; + aggVector->forEachNonNull([&](auto pos) { + aggregateFunction.updatePosState(hashSlotsToUpdateAggState[groupByKeyPos]->getEntry() + + aggStateOffset, + aggVector, multiplicity, pos, factorizedTable->getInMemOverflowBuffer()); + }); +} + +void AggregateHashTable::updateBothUnFlatSameDCAggVectorState(AggregateFunction& aggregateFunction, + ValueVector* aggVector, uint64_t multiplicity, uint32_t aggStateOffset) { + aggVector->forEachNonNull([&](auto pos) { + aggregateFunction.updatePosState(hashSlotsToUpdateAggState[pos]->getEntry() + + aggStateOffset, + aggVector, multiplicity, pos, factorizedTable->getInMemOverflowBuffer()); + }); +} + +void AggregateHashTable::updateBothUnFlatDifferentDCAggVectorState( + const DataChunkState& unFlatKeyState, AggregateFunction& aggregateFunction, + ValueVector* aggVector, uint64_t multiplicity, uint32_t aggStateOffset) { + unFlatKeyState.getSelVector().forEach([&](auto pos) { + aggregateFunction.updateAllState(hashSlotsToUpdateAggState[pos]->getEntry() + + aggStateOffset, + aggVector, multiplicity, factorizedTable->getInMemOverflowBuffer()); + }); +} + +FactorizedTableSchema AggregateHashTableUtils::getTableSchemaForKeys( + const std::vector& groupByKeyTypes, + const common::LogicalType& distinctKeyType) { + 
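    // Layout built below: one flat column per group-by key, then one column for the
    // distinct aggregate key, then a trailing hash_t column, matching the "hash in the
    // last column" convention used by the main aggregate table.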
auto tableSchema = FactorizedTableSchema(); + // Group by key columns + for (auto i = 0u; i < groupByKeyTypes.size(); i++) { + auto size = LogicalTypeUtils::getRowLayoutSize(groupByKeyTypes[i]); + auto columnSchema = ColumnSchema(false /* isUnFlat */, 0 /* groupID */, size); + // This isn't really necessary except in the global queues, but it's easier to just always + // set it here + columnSchema.setMayContainsNullsToTrue(); + tableSchema.appendColumn(std::move(columnSchema)); + } + // Distinct key column + auto columnSchema = ColumnSchema(false /* isUnFlat */, 0 /* groupID */, + LogicalTypeUtils::getRowLayoutSize(distinctKeyType)); + columnSchema.setMayContainsNullsToTrue(); + tableSchema.appendColumn(std::move(columnSchema)); + // Hash column + tableSchema.appendColumn(ColumnSchema(false /* isUnFlat */, 0 /* groupID */, sizeof(hash_t))); + return tableSchema; +} + +std::unique_ptr AggregateHashTableUtils::createDistinctHashTable( + MemoryManager& memoryManager, const std::vector& groupByKeyTypes, + const LogicalType& distinctKeyType) { + std::vector hashKeyTypes(groupByKeyTypes.size() + 1); + auto i = 0u; + // Group by key columns + for (; i < groupByKeyTypes.size(); i++) { + hashKeyTypes[i] = groupByKeyTypes[i].copy(); + } + // Distinct key column + hashKeyTypes[i] = distinctKeyType.copy(); + return std::make_unique(memoryManager, std::move(hashKeyTypes), + std::vector{} /* empty payload types */, 0 /* numEntriesToAllocate */, + getTableSchemaForKeys(groupByKeyTypes, distinctKeyType)); +} + +uint64_t PartitioningAggregateHashTable::append(const std::vector& keyVectors, + const std::vector& dependentKeyVectors, + const common::DataChunkState* leadingState, const std::vector& aggregateInputs, + uint64_t resultSetMultiplicity) { + const auto numFlatTuples = leadingState->getSelVector().getSelSize(); + + mergeIfFull(numFlatTuples); + + // mergeAll makes use of the hashVector, so it needs to be called before computeVectorHashes + computeVectorHashes(keyVectors); + KU_ASSERT( + hashVector->getSelVectorPtr()->getSelSize() == leadingState->getSelVector().getSelSize()); + findHashSlots(keyVectors, dependentKeyVectors, leadingState); + // Don't update distinct states since they can't be merged into the global hash tables. + // Instead we'll calculate them from scratch when merging. 
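    // Reasoning behind the comment above: partial distinct states are not mergeable. Two
    // threads may each have seen the value 42 for COUNT(DISTINCT x) of the same group, and
    // adding their per-thread counts would double-count it. So the raw (group keys,
    // distinct value) tuples are forwarded to the owning partition (see mergeIfFull, which
    // ships them via appendDistinctTuple) so they can be deduplicated there before the
    // aggregate state is updated.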
+ updateAggStates(keyVectors, aggregateInputs, resultSetMultiplicity, leadingState); + return numFlatTuples; +} + +bool outOfSpace(const AggregateHashTable& hashTable, uint64_t newNumTuples) { + return hashTable.getNumEntries() + newNumTuples > hashTable.getCapacity() || + static_cast(hashTable.getNumEntries()) + newNumTuples > + static_cast(hashTable.getCapacity()) / DEFAULT_HT_LOAD_FACTOR; +} + +void PartitioningAggregateHashTable::mergeIfFull(uint64_t tuplesToAdd, bool mergeAll) { + if (mergeAll || outOfSpace(*this, tuplesToAdd)) { + partitioningData->appendTuples(*factorizedTable, + tableSchema.getColOffset(tableSchema.getNumColumns() - 1)); + // Move overflow data into the shared state so that it isn't obliterated when we clear + // the factorized table + partitioningData->appendOverflow(std::move(*factorizedTable->getInMemOverflowBuffer())); + clear(); + } + bool anyToMerge = mergeAll; + if (!mergeAll) { + for (const auto& hashTable : distinctHashTables) { + if (hashTable && outOfSpace(*hashTable, tuplesToAdd)) { + anyToMerge = true; + break; + } + } + } + // If no distinct hash tables need to be merged, skip the setup needed for merging + if (!anyToMerge) { + return; + } + std::vector vectors; + std::vector> keyVectors; + std::vector colIdxToScan; + if (distinctHashTables.size() > 0) { + // We need the hashes of the key columns to partition them appropriately. + // These will be the same for each of the distinct hash tables since we exclude the + // distinct aggregate key Reserve inside here so that we don't unnecessarily allocate + // memory if there are no distinct hash tables + colIdxToScan.resize(keyTypes.size()); + std::iota(colIdxToScan.begin(), colIdxToScan.end(), 0); + vectors.reserve(keyTypes.size()); + keyVectors.reserve(keyTypes.size()); + auto state = std::make_shared(); + + for (const auto& keyType : keyTypes) { + keyVectors.push_back( + std::make_unique(keyType.copy(), memoryManager, state)); + vectors.push_back(keyVectors.back().get()); + } + } + for (size_t distinctIdx = 0; distinctIdx < distinctHashTables.size(); distinctIdx++) { + auto& distinctHashTable = distinctHashTables[distinctIdx]; + if (distinctHashTable && (mergeAll || outOfSpace(*distinctHashTable, tuplesToAdd))) { + auto* distinctFactorizedTable = distinctHashTables[distinctIdx]->getFactorizedTable(); + auto distinctTableSchema = distinctFactorizedTable->getTableSchema(); + uint64_t startTupleIdx = 0; + auto numTuplesToScan = + std::min(distinctFactorizedTable->getNumTuples(), DEFAULT_VECTOR_CAPACITY); + while (startTupleIdx < distinctFactorizedTable->getNumTuples()) { + distinctFactorizedTable->scan(vectors, startTupleIdx, numTuplesToScan, + colIdxToScan); + computeVectorHashes(vectors); + for (uint64_t tupleIdx = 0; tupleIdx < numTuplesToScan; tupleIdx++) { + auto* tuple = distinctFactorizedTable->getTuple(startTupleIdx + tupleIdx); + // The distinct value needs to be partitioned according to the group that + // stores its aggregate state So we need to ignore the aggregate key when + // calculating the hash for partitioning + const auto hash = hashVector->getValue(tupleIdx); + partitioningData->appendDistinctTuple(distinctIdx, + std::span(tuple, distinctTableSchema->getNumBytesPerTuple()), hash); + } + startTupleIdx += numTuplesToScan; + numTuplesToScan = std::min(distinctFactorizedTable->getNumTuples() - startTupleIdx, + DEFAULT_VECTOR_CAPACITY); + } + + partitioningData->appendOverflow( + std::move(*distinctFactorizedTable->getInMemOverflowBuffer())); + distinctHashTable->clear(); + } + } +} + +void 
AggregateHashTable::clear() { + factorizedTable->clear(); + // Clear hash table + for (auto& block : hashSlotsBlocks) { + block->resetToZero(); + } +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/aggregate/base_aggregate.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/aggregate/base_aggregate.cpp new file mode 100644 index 0000000000..4287985935 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/aggregate/base_aggregate.cpp @@ -0,0 +1,100 @@ +#include "processor/operator/aggregate/base_aggregate.h" + +#include "main/client_context.h" +#include "processor/operator/aggregate/aggregate_hash_table.h" + +using namespace lbug::function; + +namespace lbug { +namespace processor { + +size_t getNumPartitionsForParallelism(main::ClientContext* context) { + return context->getMaxNumThreadForExec(); +} + +BaseAggregateSharedState::BaseAggregateSharedState( + const std::vector& aggregateFunctions, size_t numPartitions) + : currentOffset{0}, aggregateFunctions{copyVector(aggregateFunctions)}, numThreads{0}, + // numPartitions - 1 since we want the bit width of the largest value that + // could be used to index the partitions + shiftForPartitioning{ + static_cast(sizeof(common::hash_t) * 8 - std::bit_width(numPartitions - 1))}, + readyForFinalization{false} {} + +void BaseAggregate::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* /*context*/) { + for (auto& info : aggInfos) { + auto aggregateInput = AggregateInput(); + if (info.aggVectorPos.dataChunkPos == INVALID_DATA_CHUNK_POS) { + aggregateInput.aggregateVector = nullptr; + } else { + aggregateInput.aggregateVector = resultSet->getValueVector(info.aggVectorPos).get(); + } + for (auto dataChunkPos : info.multiplicityChunksPos) { + aggregateInput.multiplicityChunks.push_back( + resultSet->getDataChunk(dataChunkPos).get()); + } + aggInputs.push_back(std::move(aggregateInput)); + } +} + +BaseAggregateSharedState::HashTableQueue::HashTableQueue(storage::MemoryManager* memoryManager, + FactorizedTableSchema tableSchema) { + headBlock = new TupleBlock(memoryManager, std::move(tableSchema)); + numTuplesPerBlock = headBlock.load()->table.getNumTuplesPerBlock(); +} + +BaseAggregateSharedState::HashTableQueue::~HashTableQueue() { + delete headBlock.load(); + TupleBlock* block = nullptr; + while (queuedTuples.pop(block)) { + delete block; + } +} + +void BaseAggregateSharedState::HashTableQueue::appendTuple(std::span tuple) { + while (true) { + auto* block = headBlock.load(); + KU_ASSERT(tuple.size() == block->table.getTableSchema()->getNumBytesPerTuple()); + auto posToWrite = block->numTuplesReserved++; + if (posToWrite < numTuplesPerBlock) { + memcpy(block->table.getTuple(posToWrite), tuple.data(), tuple.size()); + block->numTuplesWritten++; + return; + } else { + // No more space in the block, allocate and replace it + auto* newBlock = new TupleBlock(block->table.getMemoryManager(), + block->table.getTableSchema()->copy()); + if (headBlock.compare_exchange_strong(block, newBlock)) { + // TODO(bmwinger): if the queuedTuples has at least a certain size (benchmark to see + // if there's a benefit to waiting for multiple blocks) then cycle through the queue + // and flush any blocks which have been fully written + queuedTuples.push(block); + } else { + // If the block was replaced by another thread, discard the block we created and try + // again with the block allocated by the other thread + delete newBlock; + } + } + } +} + +void 
BaseAggregateSharedState::HashTableQueue::mergeInto(AggregateHashTable& hashTable) { + TupleBlock* partitionToMerge = nullptr; + auto headBlock = this->headBlock.load(); + KU_ASSERT(headBlock != nullptr); + while (queuedTuples.pop(partitionToMerge)) { + KU_ASSERT( + partitionToMerge->numTuplesWritten == partitionToMerge->table.getNumTuplesPerBlock()); + hashTable.merge(std::move(partitionToMerge->table)); + delete partitionToMerge; + } + if (headBlock->numTuplesWritten > 0) { + headBlock->table.resize(headBlock->numTuplesWritten); + hashTable.merge(std::move(headBlock->table)); + } + delete headBlock; + this->headBlock = nullptr; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/aggregate/base_aggregate_scan.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/aggregate/base_aggregate_scan.cpp new file mode 100644 index 0000000000..6cdc06f185 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/aggregate/base_aggregate_scan.cpp @@ -0,0 +1,18 @@ +#include "processor/operator/aggregate/base_aggregate_scan.h" + +using namespace lbug::common; +using namespace lbug::function; + +namespace lbug { +namespace processor { + +void BaseAggregateScan::initLocalStateInternal(ResultSet* resultSet, + ExecutionContext* /*context*/) { + for (auto& dataPos : scanInfo.aggregatesPos) { + auto valueVector = resultSet->getValueVector(dataPos); + aggregateVectors.push_back(valueVector); + } +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/aggregate/hash_aggregate.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/aggregate/hash_aggregate.cpp new file mode 100644 index 0000000000..32b5fe2327 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/aggregate/hash_aggregate.cpp @@ -0,0 +1,265 @@ +#include "processor/operator/aggregate/hash_aggregate.h" + +#include + +#include "binder/expression/expression_util.h" +#include "common/assert.h" +#include "common/types/types.h" +#include "main/client_context.h" +#include "processor/execution_context.h" +#include "processor/operator/aggregate/aggregate_hash_table.h" +#include "processor/operator/aggregate/aggregate_input.h" +#include "processor/operator/aggregate/base_aggregate.h" +#include "processor/result/factorized_table_schema.h" +#include "storage/buffer_manager/memory_manager.h" + +using namespace lbug::common; +using namespace lbug::function; +using namespace lbug::storage; + +namespace lbug { +namespace processor { + +std::string HashAggregatePrintInfo::toString() const { + std::string result = ""; + result += "Group By: "; + result += binder::ExpressionUtil::toString(keys); + if (!aggregates.empty()) { + result += ", Aggregates: "; + result += binder::ExpressionUtil::toString(aggregates); + } + if (limitNum != UINT64_MAX) { + result += ", Distinct Limit: " + std::to_string(limitNum); + } + return result; +} + +HashAggregateInfo::HashAggregateInfo(std::vector flatKeysPos, + std::vector unFlatKeysPos, std::vector dependentKeysPos, + FactorizedTableSchema tableSchema) + : flatKeysPos{std::move(flatKeysPos)}, unFlatKeysPos{std::move(unFlatKeysPos)}, + dependentKeysPos{std::move(dependentKeysPos)}, tableSchema{std::move(tableSchema)} {} + +HashAggregateInfo::HashAggregateInfo(const HashAggregateInfo& other) + : flatKeysPos{other.flatKeysPos}, unFlatKeysPos{other.unFlatKeysPos}, + dependentKeysPos{other.dependentKeysPos}, tableSchema{other.tableSchema.copy()} {} + 
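// Illustrative sketch, kept in comments because it is not part of the vendored lbug
// sources: how the shiftForPartitioning value computed in BaseAggregateSharedState's
// constructor above can be used to pick a partition from the top bits of a 64-bit hash.
// For numPartitions == 8, std::bit_width(7u) == 3, so the shift is 64 - 3 = 61 and
// hash >> 61 falls in [0, 7]. How lbug actually consumes the shift is not shown in this
// diff, so partitionIdxSketch is an assumption, not the library's implementation.
//
//     // needs <bit> and <cstdint>; assumes numPartitions >= 2 so the shift stays below 64
//     static uint64_t partitionIdxSketch(uint64_t hash, uint64_t numPartitions) {
//         const auto shift = sizeof(uint64_t) * 8 - std::bit_width(numPartitions - 1);
//         return hash >> shift; // top bits select the partition
//     }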
+HashAggregateSharedState::HashAggregateSharedState(main::ClientContext* context, + HashAggregateInfo hashAggInfo, + const std::vector& aggregateFunctions, + std::span aggregateInfos, std::vector keyTypes, + std::vector payloadTypes) + : BaseAggregateSharedState{aggregateFunctions, getNumPartitionsForParallelism(context)}, + aggInfo{std::move(hashAggInfo)}, limitNumber{common::INVALID_LIMIT}, + memoryManager{MemoryManager::Get(*context)}, + globalPartitions{getNumPartitionsForParallelism(context)} { + std::vector distinctAggregateKeyTypes; + for (auto& aggInfo : aggregateInfos) { + distinctAggregateKeyTypes.push_back(aggInfo.distinctAggKeyType.copy()); + } + + // When copying directly into factorizedTables the table's schema's internal mayContainNulls + // won't be updated and it's probably less work to just always check nulls + // Skip the last column, which is the hash column and should never contain nulls + for (size_t i = 0; i < this->aggInfo.tableSchema.getNumColumns() - 1; i++) { + this->aggInfo.tableSchema.setMayContainsNullsToTrue(i); + } + + auto& partition = globalPartitions[0]; + partition.queue = std::make_unique(MemoryManager::Get(*context), + this->aggInfo.tableSchema.copy()); + + // Always create a hash table for the first partition. Any other partitions which are non-empty + // when finalizing will create an empty copy of this table + partition.hashTable = std::make_unique(*MemoryManager::Get(*context), + std::move(keyTypes), std::move(payloadTypes), aggregateFunctions, distinctAggregateKeyTypes, + 0, this->aggInfo.tableSchema.copy()); + for (size_t functionIdx = 0; functionIdx < aggregateFunctions.size(); functionIdx++) { + auto& function = aggregateFunctions[functionIdx]; + if (function.isFunctionDistinct()) { + // Create table schema for distinct hash table + auto distinctTableSchema = FactorizedTableSchema(); + // Group by key columns + for (size_t i = 0; + i < this->aggInfo.flatKeysPos.size() + this->aggInfo.unFlatKeysPos.size(); i++) { + distinctTableSchema.appendColumn(this->aggInfo.tableSchema.getColumn(i)->copy()); + distinctTableSchema.setMayContainsNullsToTrue(i); + } + // Distinct key column + distinctTableSchema.appendColumn(ColumnSchema(false /*isUnFlat*/, 0 /*groupID*/, + LogicalTypeUtils::getRowLayoutSize( + aggregateInfos[functionIdx].distinctAggKeyType))); + distinctTableSchema.setMayContainsNullsToTrue(distinctTableSchema.getNumColumns() - 1); + // Hash column + distinctTableSchema.appendColumn( + ColumnSchema(false /* isUnFlat */, 0 /* groupID */, sizeof(hash_t))); + + partition.distinctTableQueues.emplace_back(std::make_unique( + MemoryManager::Get(*context), std::move(distinctTableSchema))); + } else { + // dummy entry so that indices line up with the aggregateFunctions + partition.distinctTableQueues.emplace_back(); + } + } + // Each partition is the same, so we create the list of distinct queues for the first partition + // and copy it to the other partitions + for (size_t i = 1; i < globalPartitions.size(); i++) { + globalPartitions[i].queue = std::make_unique(MemoryManager::Get(*context), + this->aggInfo.tableSchema.copy()); + globalPartitions[i].distinctTableQueues.resize(partition.distinctTableQueues.size()); + std::transform(partition.distinctTableQueues.begin(), partition.distinctTableQueues.end(), + globalPartitions[i].distinctTableQueues.begin(), [&](auto& q) { + if (q.get() != nullptr) { + return q->copy(); + } else { + return std::unique_ptr(); + } + }); + } +} + +std::pair HashAggregateSharedState::getNextRangeToRead() { + std::unique_lock 
lck{mtx}; + auto startOffset = currentOffset.load(); + auto numTuples = getNumTuples(); + if (startOffset >= numTuples) { + return std::make_pair(startOffset, startOffset); + } + // FactorizedTable::lookup resets the ValueVector and writes to the beginning, + // so we can't support scanning from multiple partitions at once + auto [table, tableStartOffset] = getPartitionForOffset(startOffset); + auto range = std::min(std::min(DEFAULT_VECTOR_CAPACITY, numTuples - startOffset), + table->getNumTuples() + tableStartOffset - startOffset); + currentOffset += range; + return std::make_pair(startOffset, startOffset + range); +} + +uint64_t HashAggregateSharedState::getNumTuples() const { + uint64_t numTuples = 0; + for (auto& partition : globalPartitions) { + numTuples += partition.hashTable->getNumEntries(); + } + return numTuples; +} + +void HashAggregateSharedState::finalizePartitions() { + BaseAggregateSharedState::finalizePartitions(globalPartitions, [&](auto& partition) { + if (!partition.hashTable) { + // We always initialize the hash table in the first partition + partition.hashTable = std::make_unique( + globalPartitions[0].hashTable->createEmptyCopy()); + } + // TODO(bmwinger): ideally these can be merged into a single function. + // The distinct tables need to be merged first so that they exist when the other table + // updates the agg states when it merges + for (size_t i = 0; i < partition.distinctTableQueues.size(); i++) { + if (partition.distinctTableQueues[i]) { + partition.distinctTableQueues[i]->mergeInto( + *partition.hashTable->getDistinctHashTable(i)); + } + } + partition.queue->mergeInto(*partition.hashTable); + partition.hashTable->mergeDistinctAggregateInfo(); + + partition.hashTable->finalizeAggregateStates(); + }); +} + +std::tuple HashAggregateSharedState::getPartitionForOffset( + offset_t offset) const { + auto factorizedTableStartOffset = 0; + auto partitionIdx = 0; + const auto* table = globalPartitions[partitionIdx].hashTable->getFactorizedTable(); + while (factorizedTableStartOffset + table->getNumTuples() <= offset) { + factorizedTableStartOffset += table->getNumTuples(); + table = globalPartitions[++partitionIdx].hashTable->getFactorizedTable(); + } + return std::make_tuple(table, factorizedTableStartOffset); +} + +void HashAggregateSharedState::scan(std::span entries, + std::vector& keyVectors, offset_t startOffset, offset_t numTuplesToScan, + std::vector& columnIndices) { + auto [table, tableStartOffset] = getPartitionForOffset(startOffset); + // Due to the way FactorizedTable::lookup works, it's necessary to read one partition + // at a time. 
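+    // getPartitionForOffset (above) walks the partitions, accumulating their
+    // tuple counts, until it reaches the FactorizedTable that contains
+    // startOffset; the assertion below checks that the requested range stays
+    // within that single table.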
+ KU_ASSERT(startOffset - tableStartOffset + numTuplesToScan <= table->getNumTuples()); + for (size_t pos = 0; pos < numTuplesToScan; pos++) { + auto posInTable = startOffset + pos - tableStartOffset; + entries[pos] = table->getTuple(posInTable); + } + table->lookup(keyVectors, columnIndices, entries.data(), 0, numTuplesToScan); + KU_ASSERT(true); +} + +void HashAggregateSharedState::assertFinalized() const { + RUNTIME_CHECK(for (const auto& partition + : globalPartitions) { + KU_ASSERT(partition.finalized); + KU_ASSERT(partition.queue->empty()); + }); +} + +void HashAggregateLocalState::init(HashAggregateSharedState* sharedState, ResultSet& resultSet, + main::ClientContext* context, std::vector& aggregateFunctions, + std::vector distinctKeyTypes) { + auto& info = sharedState->getAggregateInfo(); + std::vector keyDataTypes; + for (auto& pos : info.flatKeysPos) { + auto vector = resultSet.getValueVector(pos).get(); + keyVectors.push_back(vector); + keyDataTypes.push_back(vector->dataType.copy()); + } + for (auto& pos : info.unFlatKeysPos) { + auto vector = resultSet.getValueVector(pos).get(); + keyVectors.push_back(vector); + keyDataTypes.push_back(vector->dataType.copy()); + leadingState = vector->state.get(); + } + if (leadingState == nullptr) { + // All vectors are flat, so any can be the leading state + leadingState = keyVectors.front()->state.get(); + } + std::vector payloadDataTypes; + for (auto& pos : info.dependentKeysPos) { + auto vector = resultSet.getValueVector(pos).get(); + dependentKeyVectors.push_back(vector); + payloadDataTypes.push_back(vector->dataType.copy()); + } + + aggregateHashTable = std::make_unique(sharedState, + *MemoryManager::Get(*context), std::move(keyDataTypes), std::move(payloadDataTypes), + aggregateFunctions, std::move(distinctKeyTypes), info.tableSchema.copy()); +} + +uint64_t HashAggregateLocalState::append(const std::vector& aggregateInputs, + uint64_t multiplicity) const { + return aggregateHashTable->append(keyVectors, dependentKeyVectors, leadingState, + aggregateInputs, multiplicity); +} + +void HashAggregate::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) { + BaseAggregate::initLocalStateInternal(resultSet, context); + std::vector distinctAggKeyTypes; + for (auto& info : aggInfos) { + distinctAggKeyTypes.push_back(info.distinctAggKeyType.copy()); + } + localState.init(common::ku_dynamic_cast(sharedState.get()), + *resultSet, context->clientContext, aggregateFunctions, std::move(distinctAggKeyTypes)); +} + +void HashAggregate::executeInternal(ExecutionContext* context) { + while (children[0]->getNextTuple(context)) { + const auto numAppendedFlatTuples = localState.append(aggInputs, resultSet->multiplicity); + metrics->numOutputTuple.increase(numAppendedFlatTuples); + // Note: The limit count check here is only applicable to the distinct limit case. 
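+        // limitNumber is initialized to common::INVALID_LIMIT in the shared
+        // state, which is assumed here to be a sentinel large enough that
+        // plans without a distinct limit never take this break.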
+ if (localState.aggregateHashTable->getNumEntries() >= + getSharedStateReference().getLimitNumber()) { + break; + } + } + localState.aggregateHashTable->mergeIfFull(0 /*tuplesToAdd*/, true /*mergeAll*/); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/aggregate/hash_aggregate_scan.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/aggregate/hash_aggregate_scan.cpp new file mode 100644 index 0000000000..e8ab4c2bd4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/aggregate/hash_aggregate_scan.cpp @@ -0,0 +1,52 @@ +#include "processor/operator/aggregate/hash_aggregate_scan.h" + +using namespace lbug::function; + +namespace lbug { +namespace processor { + +void HashAggregateScan::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) { + BaseAggregateScan::initLocalStateInternal(resultSet, context); + for (auto& dataPos : groupByKeyVectorsPos) { + auto valueVector = resultSet->getValueVector(dataPos); + groupByKeyVectors.push_back(valueVector.get()); + } + groupByKeyVectorsColIdxes.resize(groupByKeyVectors.size()); + iota(groupByKeyVectorsColIdxes.begin(), groupByKeyVectorsColIdxes.end(), 0); +} + +bool HashAggregateScan::getNextTuplesInternal(ExecutionContext* /*context*/) { + auto [startOffset, endOffset] = sharedState->getNextRangeToRead(); + if (startOffset >= endOffset) { + return false; + } + auto numRowsToScan = endOffset - startOffset; + entries.resize(numRowsToScan); + sharedState->scan(entries, groupByKeyVectors, startOffset, numRowsToScan, + groupByKeyVectorsColIdxes); + for (auto pos = 0u; pos < numRowsToScan; ++pos) { + auto entry = entries[pos]; + auto offset = sharedState->getTableSchema()->getColOffset(groupByKeyVectors.size()); + for (auto i = 0u; i < aggregateVectors.size(); i++) { + auto vector = aggregateVectors[i]; + auto aggState = reinterpret_cast(entry + offset); + scanInfo.moveAggResultToVectorFuncs[i](*vector, pos, aggState); + offset += aggState->getStateSize(); + } + } + metrics->numOutputTuple.increase(numRowsToScan); + return true; +} + +double HashAggregateScan::getProgress(ExecutionContext* /*context*/) const { + uint64_t totalNumTuples = sharedState->getNumTuples(); + if (totalNumTuples == 0) { + return 0.0; + } else if (sharedState->getCurrentOffset() == totalNumTuples) { + return 1.0; + } + return static_cast(sharedState->getCurrentOffset()) / totalNumTuples; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/aggregate/simple_aggregate.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/aggregate/simple_aggregate.cpp new file mode 100644 index 0000000000..23079cf253 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/aggregate/simple_aggregate.cpp @@ -0,0 +1,275 @@ +#include "processor/operator/aggregate/simple_aggregate.h" + +#include +#include +#include +#include + +#include "binder/expression/expression_util.h" +#include "common/data_chunk/data_chunk_state.h" +#include "common/in_mem_overflow_buffer.h" +#include "common/system_config.h" +#include "common/types/types.h" +#include "common/vector/value_vector.h" +#include "function/aggregate_function.h" +#include "main/client_context.h" +#include "processor/execution_context.h" +#include "processor/operator/aggregate/aggregate_hash_table.h" +#include "processor/operator/aggregate/aggregate_input.h" +#include "processor/operator/aggregate/base_aggregate.h" +#include 
"processor/result/factorized_table.h" +#include "processor/result/factorized_table_schema.h" +#include "storage/buffer_manager/memory_manager.h" + +using namespace lbug::common; +using namespace lbug::function; + +namespace lbug { +namespace processor { + +std::string SimpleAggregatePrintInfo::toString() const { + std::string result = ""; + result += "Aggregate: "; + result += binder::ExpressionUtil::toString(aggregates); + return result; +} + +static bool isAnyFunctionDistinct(const std::vector& functions) { + return std::any_of(functions.begin(), functions.end(), + [&](auto& func) { return func.isDistinct; }); +} + +SimpleAggregateSharedState::SimpleAggregateSharedState(main::ClientContext* context, + const std::vector& aggregateFunctions, + const std::vector& aggInfos) + : BaseAggregateSharedState{aggregateFunctions, + // Only distinct functions need partitioning + getNumPartitionsForParallelism(context)}, + hasDistinct{isAnyFunctionDistinct(aggregateFunctions)}, + globalPartitions{hasDistinct ? getNumPartitionsForParallelism(context) : 0}, + aggregateOverflowBuffer{storage::MemoryManager::Get(*context)} { + auto mm = storage::MemoryManager::Get(*context); + for (size_t funcIdx = 0; funcIdx < this->aggregateFunctions.size(); funcIdx++) { + auto& aggregateFunction = this->aggregateFunctions[funcIdx]; + globalAggregateStates.push_back(aggregateFunction.createInitialNullAggregateState()); + partitioningData.emplace_back(this, funcIdx); + if (aggregateFunction.isDistinct) { + const auto& distinctKeyType = aggInfos[funcIdx].distinctAggKeyType; + auto schema = AggregateHashTableUtils::getTableSchemaForKeys(std::vector{}, + aggInfos[funcIdx].distinctAggKeyType); + for (auto& partition : globalPartitions) { + std::vector keyTypes(1); + keyTypes[0] = distinctKeyType.copy(); + auto hashTable = std::make_unique(*mm, std::move(keyTypes), + std::vector{} /*payloadTypes*/, std::vector{}, + std::vector{}, 0, schema.copy()); + auto queue = std::make_unique(mm, + AggregateHashTableUtils::getTableSchemaForKeys(std::vector{}, + aggInfos[funcIdx].distinctAggKeyType)); + partition.distinctTables.emplace_back(Partition::DistinctData{std::move(hashTable), + std::move(queue), aggregateFunction.createInitialNullAggregateState()}); + } + } else { + for (auto& partition : globalPartitions) { + partition.distinctTables.emplace_back(); + } + } + } +} + +void SimpleAggregateSharedState::combineAggregateStates( + const std::vector>& localAggregateStates, + common::InMemOverflowBuffer&& localOverflowBuffer) { + KU_ASSERT(localAggregateStates.size() == globalAggregateStates.size()); + std::unique_lock lck{mtx}; + for (auto i = 0u; i < aggregateFunctions.size(); ++i) { + // Distinct functions will be combined accross the partitions in + // finalizeAggregateStates + aggregateOverflowBuffer.merge(localOverflowBuffer); + if (!aggregateFunctions[i].isDistinct) { + aggregateFunctions[i].combineState( + reinterpret_cast(globalAggregateStates[i].get()), + reinterpret_cast(localAggregateStates[i].get()), + &aggregateOverflowBuffer); + } + } +} + +void SimpleAggregateSharedState::finalizeAggregateStates() { + std::unique_lock lck{mtx}; + for (auto i = 0u; i < aggregateFunctions.size(); ++i) { + if (aggregateFunctions[i].isDistinct) { + for (auto& partition : globalPartitions) { + aggregateFunctions[i].combineState(reinterpret_cast(getAggregateState(i)), + reinterpret_cast(partition.distinctTables[i].state.get()), + &aggregateOverflowBuffer); + } + } + aggregateFunctions[i].finalizeState( + 
reinterpret_cast(globalAggregateStates[i].get())); + } +} + +std::pair SimpleAggregateSharedState::getNextRangeToRead() { + std::unique_lock lck{mtx}; + if (currentOffset >= 1) { + return std::make_pair(currentOffset.load(), currentOffset.load()); + } + auto startOffset = currentOffset.load(); + currentOffset++; + return std::make_pair(startOffset, currentOffset.load()); +} + +void SimpleAggregateSharedState::SimpleAggregatePartitioningData::appendTuples( + const FactorizedTable& factorizedTable, ft_col_offset_t hashOffset) { + KU_ASSERT(sharedState->globalPartitions.size() > 0); + auto numBytesPerTuple = factorizedTable.getTableSchema()->getNumBytesPerTuple(); + for (ft_tuple_idx_t tupleIdx = 0; tupleIdx < factorizedTable.getNumTuples(); tupleIdx++) { + auto tuple = factorizedTable.getTuple(tupleIdx); + auto hash = *reinterpret_cast(tuple + hashOffset); + auto& partition = + sharedState->globalPartitions[(hash >> sharedState->shiftForPartitioning) % + sharedState->globalPartitions.size()]; + partition.distinctTables[functionIdx].queue->appendTuple( + std::span(tuple, numBytesPerTuple)); + } +} + +// LCOV_EXCL_START +void SimpleAggregateSharedState::SimpleAggregatePartitioningData::appendDistinctTuple(size_t, + std::span, common::hash_t) { + KU_UNREACHABLE; +} +// LCOV_EXCL_END + +void SimpleAggregateSharedState::SimpleAggregatePartitioningData::appendOverflow( + common::InMemOverflowBuffer&& overflowBuffer) { + sharedState->overflow.push( + std::make_unique(std::move(overflowBuffer))); +} + +void SimpleAggregateSharedState::finalizePartitions(storage::MemoryManager* memoryManager, + const std::vector& aggInfos) { + if (!hasDistinct) { + return; + } + InMemOverflowBuffer localOverflowBuffer(memoryManager); + BaseAggregateSharedState::finalizePartitions(globalPartitions, [&](auto& partition) { + for (size_t i = 0; i < partition.distinctTables.size(); i++) { + if (!aggregateFunctions[i].isDistinct) { + continue; + } + auto& [hashTable, queue, state] = partition.distinctTables[i]; + if (queue) { + KU_ASSERT(hashTable); + queue->mergeInto(*hashTable); + } + + ValueVector aggregateVector(aggInfos[i].distinctAggKeyType.copy(), memoryManager, + std::make_shared()); + const auto& ft = hashTable->getFactorizedTable(); + ft_tuple_idx_t startTupleIdx = 0; + ft_tuple_idx_t numTuplesToScan = + std::min(DEFAULT_VECTOR_CAPACITY, ft->getNumTuples() - startTupleIdx); + std::array colIdxToScan = {0}; + std::array vectors = {&aggregateVector}; + while (numTuplesToScan > 0) { + ft->scan(vectors, startTupleIdx, numTuplesToScan, colIdxToScan); + aggregateFunctions[i].updateAllState((uint8_t*)state.get(), &aggregateVector, + 1 /*multiplicity*/, &localOverflowBuffer); + startTupleIdx += numTuplesToScan; + numTuplesToScan = + std::min(DEFAULT_VECTOR_CAPACITY, ft->getNumTuples() - startTupleIdx); + } + hashTable.reset(); + queue.reset(); + } + }); + { + std::unique_lock lck{mtx}; + aggregateOverflowBuffer.merge(localOverflowBuffer); + } +} + +void SimpleAggregate::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) { + BaseAggregate::initLocalStateInternal(resultSet, context); + for (auto i = 0u; i < aggregateFunctions.size(); ++i) { + auto& func = aggregateFunctions[i]; + localAggregateStates.push_back(func.createInitialNullAggregateState()); + std::unique_ptr distinctHT; + if (func.isDistinct) { + auto mm = storage::MemoryManager::Get(*context->clientContext); + std::vector keyTypes; + keyTypes.push_back(aggInfos[i].distinctAggKeyType.copy()); + distinctHT = std::make_unique( + 
&getSharedState().partitioningData[i], *mm, std::move(keyTypes), + std::vector{} /* empty payload*/, + std::vector{}, + std::vector{} /*empty distinct keys*/, + AggregateHashTableUtils::getTableSchemaForKeys(std::vector{}, + aggInfos[i].distinctAggKeyType)); + } else { + distinctHT = nullptr; + } + distinctHashTables.push_back(std::move(distinctHT)); + }; +} + +void SimpleAggregate::executeInternal(ExecutionContext* context) { + InMemOverflowBuffer localOverflowBuffer(storage::MemoryManager::Get(*context->clientContext)); + while (children[0]->getNextTuple(context)) { + for (auto i = 0u; i < aggregateFunctions.size(); i++) { + auto aggregateFunction = &aggregateFunctions[i]; + if (aggregateFunction->isFunctionDistinct()) { + // Just add distinct value to the hash table. We'll calculate the aggregate state + // once it's been merged into the shared state + distinctHashTables[i]->appendDistinct(std::vector{}, + aggInputs[i].aggregateVector, aggInputs[i].aggregateVector->state.get()); + } else { + computeAggregate(aggregateFunction, &aggInputs[i], localAggregateStates[i].get(), + localOverflowBuffer); + } + } + } + if (getSharedState().hasDistinct) { + for (auto& hashTable : distinctHashTables) { + if (hashTable) { + hashTable->mergeIfFull(0 /*tuplesToAdd*/, true /*mergeAll*/); + } + } + } + getSharedState().combineAggregateStates(localAggregateStates, std::move(localOverflowBuffer)); +} + +void SimpleAggregate::computeAggregate(function::AggregateFunction* function, AggregateInput* input, + function::AggregateState* state, common::InMemOverflowBuffer& overflowBuffer) { + auto multiplicity = resultSet->multiplicity; + for (auto dataChunk : input->multiplicityChunks) { + multiplicity *= dataChunk->state->getSelVector().getSelSize(); + } + if (input->aggregateVector && input->aggregateVector->state->isFlat()) { + auto pos = input->aggregateVector->state->getSelVector()[0]; + if (!input->aggregateVector->isNull(pos)) { + function->updatePosState((uint8_t*)state, input->aggregateVector, multiplicity, pos, + &overflowBuffer); + } + } else { + function->updateAllState((uint8_t*)state, input->aggregateVector, multiplicity, + &overflowBuffer); + } +} + +void SimpleAggregateFinalize::finalizeInternal(ExecutionContext* /*context*/) { + sharedState->finalizeAggregateStates(); + if (metrics) { + metrics->numOutputTuple.incrementByOne(); + } +} + +void SimpleAggregateFinalize::executeInternal(ExecutionContext* context) { + KU_ASSERT(sharedState->isReadyForFinalization()); + sharedState->finalizePartitions(storage::MemoryManager::Get(*context->clientContext), aggInfos); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/aggregate/simple_aggregate_scan.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/aggregate/simple_aggregate_scan.cpp new file mode 100644 index 0000000000..925c0b73b8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/aggregate/simple_aggregate_scan.cpp @@ -0,0 +1,36 @@ +#include "processor/operator/aggregate/simple_aggregate_scan.h" + +namespace lbug { +namespace processor { + +void SimpleAggregateScan::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) { + BaseAggregateScan::initLocalStateInternal(resultSet, context); + KU_ASSERT(!scanInfo.aggregatesPos.empty()); + auto outDataChunkPos = scanInfo.aggregatesPos[0].dataChunkPos; + RUNTIME_CHECK({ + for (auto& dataPos : scanInfo.aggregatesPos) { + KU_ASSERT(dataPos.dataChunkPos == outDataChunkPos); + } + }); + 
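+    // All aggregate result vectors live in the same output data chunk (checked
+    // above), so the chunk is cached here; getNextTuplesInternal sets its
+    // selected size to 1 when the single result row is produced.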
outDataChunk = resultSet->dataChunks[outDataChunkPos].get(); +} + +bool SimpleAggregateScan::getNextTuplesInternal(ExecutionContext* /*context*/) { + auto [startOffset, endOffset] = sharedState->getNextRangeToRead(); + if (startOffset >= endOffset) { + return false; + } + // Output of simple aggregate is guaranteed to be a single value for each aggregate. + KU_ASSERT(startOffset == 0 && endOffset == 1); + for (auto i = 0u; i < aggregateVectors.size(); i++) { + scanInfo.moveAggResultToVectorFuncs[i](*aggregateVectors[i], 0 /* position to write */, + sharedState->getAggregateState(i)); + } + KU_ASSERT(!scanInfo.aggregatesPos.empty()); + outDataChunk->state->initOriginalAndSelectedSize(1); + metrics->numOutputTuple.increase(outDataChunk->state->getSelVector().getSelSize()); + return true; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/arrow_result_collector.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/arrow_result_collector.cpp new file mode 100644 index 0000000000..54142d351f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/arrow_result_collector.cpp @@ -0,0 +1,102 @@ +#include "processor/operator/arrow_result_collector.h" + +#include "common/arrow/arrow_row_batch.h" +#include "main/query_result/arrow_query_result.h" + +using namespace lbug::common; + +namespace lbug { +namespace processor { + +bool ArrowResultCollectorLocalState::advance() { + for (int64_t i = static_cast(chunks.size()) - 1; i >= 0; --i) { + chunkCursors[i]++; + if (chunkCursors[i] < chunks[i]->state->getSelSize()) { + return true; + } + chunkCursors[i] = 0; + } + return false; +} + +void ArrowResultCollectorLocalState::fillTuple() { + KU_ASSERT(tuple->len() == vectors.size()); + for (auto i = 0u; i < vectors.size(); ++i) { + auto vector = vectors[i]; + auto pos = vector->state->getSelVector()[vectorsSelPos[i]]; + auto data = vector->getData() + pos * vector->getNumBytesPerValue(); + tuple->getValue(i)->copyFromColLayout(data, vector); + } +} + +void ArrowResultCollectorLocalState::resetCursor() { + for (auto i = 0u; i < chunkCursors.size(); ++i) { + chunkCursors[i] = 0; + } +} + +void ArrowResultCollectorSharedState::merge(const std::vector& localArrays) { + std::unique_lock lck{mutex}; + for (auto i = 0u; i < localArrays.size(); ++i) { + arrays.push_back(localArrays[i]); + } +} + +void ArrowResultCollector::executeInternal(ExecutionContext* context) { + auto rowBatch = std::make_unique(info.columnTypes, info.chunkSize, + false /* fallbackExtensionTypes */); + while (children[0]->getNextTuple(context)) { + localState.resetCursor(); + while (true) { + if (!fillRowBatch(*rowBatch)) { + break; + } + localState.arrays.push_back(rowBatch->toArray(info.columnTypes)); + rowBatch = std::make_unique(info.columnTypes, info.chunkSize, + false /* fallbackExtensionTypes */); + } + } + // Handle the last rowBatch whose size can be smaller than chunk size. 
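+    // fillRowBatch returns true when a batch reaches info.chunkSize with input
+    // still pending, and false once the buffered tuples are exhausted, so a
+    // partially filled batch can remain to be flushed here.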
+ if (rowBatch->size() > 0) { + localState.arrays.push_back(rowBatch->toArray(info.columnTypes)); + } + sharedState->merge(localState.arrays); +} + +bool ArrowResultCollector::fillRowBatch(ArrowRowBatch& rowBatch) { + while (rowBatch.size() < info.chunkSize) { + localState.fillTuple(); + rowBatch.append(*localState.tuple); + if (!localState.advance()) { + return false; + } + } + return true; +} + +void ArrowResultCollector::initLocalStateInternal(ResultSet* resultSet, ExecutionContext*) { + std::unordered_map idxMap; // Map result set chunk idx to local state idx + // Populate chunks + for (auto& pos : info.payloadPositions) { + auto idx = pos.dataChunkPos; + if (idxMap.contains(idx)) { + continue; + } + idxMap.insert({idx, localState.chunks.size()}); + localState.chunks.push_back(resultSet->getDataChunk(idx).get()); + localState.chunkCursors.push_back(0); + } + // Populate vectors + for (auto& pos : info.payloadPositions) { + localState.vectors.push_back(resultSet->getValueVector(pos).get()); + localState.vectorsSelPos.push_back(localState.chunkCursors[idxMap.at(pos.dataChunkPos)]); + } + localState.tuple = std::make_unique(info.columnTypes); +} + +std::unique_ptr ArrowResultCollector::getQueryResult() const { + return std::make_unique(std::move(sharedState->arrays), info.chunkSize); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/base_partitioner_shared_state.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/base_partitioner_shared_state.cpp new file mode 100644 index 0000000000..f52cb3aa72 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/base_partitioner_shared_state.cpp @@ -0,0 +1,37 @@ +#include "processor/operator/base_partitioner_shared_state.h" + +#include "storage/table/node_table.h" +#include "transaction/transaction.h" + +namespace lbug::processor { +void PartitionerSharedState::initialize(const common::logical_type_vec_t&, + common::idx_t numPartitioners, const main::ClientContext* clientContext) { + KU_ASSERT(numPartitioners >= 1 && numPartitioners <= DIRECTIONS); + auto transaction = transaction::Transaction::Get(*clientContext); + numNodes[0] = srcNodeTable->getNumTotalRows(transaction); + if (numPartitioners > 1) { + numNodes[1] = dstNodeTable->getNumTotalRows(transaction); + } + numPartitions[0] = getNumPartitionsFromRows(numNodes[0]); + if (numPartitioners > 1) { + numPartitions[1] = getNumPartitionsFromRows(numNodes[1]); + } +} + +common::partition_idx_t PartitionerSharedState::getNextPartition(common::idx_t partitioningIdx) { + auto nextPartitionIdxToReturn = nextPartitionIdx++; + if (nextPartitionIdxToReturn >= numPartitions[partitioningIdx]) { + return common::INVALID_PARTITION_IDX; + } + return nextPartitionIdxToReturn; +} + +common::partition_idx_t PartitionerSharedState::getNumPartitionsFromRows(common::offset_t numRows) { + return (numRows + common::StorageConfig::NODE_GROUP_SIZE - 1) / + common::StorageConfig::NODE_GROUP_SIZE; +} + +void PartitionerSharedState::resetState(common::idx_t) { + nextPartitionIdx = 0; +} +} // namespace lbug::processor diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/cross_product.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/cross_product.cpp new file mode 100644 index 0000000000..47b32d0c18 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/cross_product.cpp @@ -0,0 +1,38 @@ +#include "processor/operator/cross_product.h" + +#include "common/metric.h" + +namespace lbug { 
+namespace processor { + +void CrossProduct::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* /*context*/) { + for (auto& pos : info.outVecPos) { + vectorsToScan.push_back(resultSet->getValueVector(pos).get()); + } + localState.init(); +} + +bool CrossProduct::getNextTuplesInternal(ExecutionContext* context) { + // Note: we should NOT morselize right table scanning (i.e. calling sharedState.getMorsel) + // because every thread should scan its own table. + auto table = localState.table.get(); + if (table->getNumTuples() == 0) { + return false; + } + if (localState.startIdx == table->getNumTuples()) { // no more to scan from right + if (!children[0]->getNextTuple(context)) { // fetch a new left tuple + return false; + } + localState.startIdx = 0; // reset right table scanning for a new left tuple + } + // scan from right table if there is tuple left + auto numTuplesToScan = + std::min(localState.maxMorselSize, table->getNumTuples() - localState.startIdx); + table->scan(vectorsToScan, localState.startIdx, numTuplesToScan, info.colIndicesToScan); + localState.startIdx += numTuplesToScan; + metrics->numOutputTuple.increase(numTuplesToScan); + return true; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/ddl/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/ddl/CMakeLists.txt new file mode 100644 index 0000000000..846d152fa6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/ddl/CMakeLists.txt @@ -0,0 +1,11 @@ +add_library(lbug_processor_operator_ddl + OBJECT + alter.cpp + create_table.cpp + create_type.cpp + drop.cpp + create_sequence.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/ddl/alter.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/ddl/alter.cpp new file mode 100644 index 0000000000..7c8ee0c875 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/ddl/alter.cpp @@ -0,0 +1,356 @@ +#include "processor/operator/ddl/alter.h" + +#include "catalog/catalog.h" +#include "catalog/catalog_entry/node_table_catalog_entry.h" +#include "catalog/catalog_entry/rel_group_catalog_entry.h" +#include "common/enums/alter_type.h" +#include "common/exception/binder.h" +#include "common/exception/runtime.h" +#include "processor/execution_context.h" +#include "storage/storage_manager.h" +#include "storage/table/table.h" +#include "transaction/transaction.h" + +using namespace lbug::binder; +using namespace lbug::common; +using namespace lbug::catalog; +using namespace lbug::transaction; + +namespace lbug { +namespace processor { + +void Alter::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) { + if (defaultValueEvaluator) { + defaultValueEvaluator->init(*resultSet, context->clientContext); + } +} + +void Alter::executeInternal(ExecutionContext* context) { + auto clientContext = context->clientContext; + auto catalog = Catalog::Get(*clientContext); + auto transaction = Transaction::Get(*clientContext); + if (catalog->containsTable(transaction, info.tableName)) { + auto entry = catalog->getTableCatalogEntry(transaction, info.tableName); + alterTable(clientContext, *entry, info); + } else { + throw BinderException("Table " + info.tableName + " does not exist."); + } +} + +using on_conflict_throw_action = std::function; + +static void validate(ConflictAction action, const on_conflict_throw_action& throwAction) { + switch (action) { + case 
ConflictAction::ON_CONFLICT_THROW: { + throwAction(); + } break; + case ConflictAction::ON_CONFLICT_DO_NOTHING: + break; + default: + KU_UNREACHABLE; + } +} + +static std::string propertyNotInTableMessage(const std::string& tableName, + const std::string& propertyName) { + return stringFormat("{} table does not have property {}.", tableName, propertyName); +} + +static void validatePropertyExist(ConflictAction action, const TableCatalogEntry& tableEntry, + const std::string& propertyName) { + validate(action, [&tableEntry, &propertyName]() { + if (!tableEntry.containsProperty(propertyName)) { + throw RuntimeException(propertyNotInTableMessage(tableEntry.getName(), propertyName)); + } + }); +} + +static std::string propertyInTableMessage(const std::string& tableName, + const std::string& propertyName) { + return stringFormat("{} table already has property {}.", tableName, propertyName); +} + +static void validatePropertyNotExist(ConflictAction action, const TableCatalogEntry& tableEntry, + const std::string& propertyName) { + validate(action, [&tableEntry, &propertyName] { + if (tableEntry.containsProperty(propertyName)) { + throw RuntimeException(propertyInTableMessage(tableEntry.getName(), propertyName)); + } + }); +} + +using skip_alter_on_conflict = std::function; + +static bool skipAlter(ConflictAction action, const skip_alter_on_conflict& skipAlterOnConflict) { + switch (action) { + case ConflictAction::ON_CONFLICT_THROW: + return false; + case ConflictAction::ON_CONFLICT_DO_NOTHING: + return skipAlterOnConflict(); + default: + KU_UNREACHABLE; + } +} + +static bool checkAddPropertyConflicts(const TableCatalogEntry& tableEntry, + const BoundAlterInfo& info) { + const auto& extraInfo = info.extraInfo->constCast(); + auto propertyName = extraInfo.propertyDefinition.getName(); + validatePropertyNotExist(info.onConflict, tableEntry, propertyName); + + // Eventually, we want to support non-constant default on rel tables, but it is non-trivial + // due to FWD/BWD storage + if (tableEntry.getType() == CatalogEntryType::REL_GROUP_ENTRY && + extraInfo.boundDefault->expressionType != ExpressionType::LITERAL) { + throw RuntimeException( + "Cannot set a non-constant default value when adding columns on REL tables."); + } + + return skipAlter(info.onConflict, + [&tableEntry, &propertyName]() { return tableEntry.containsProperty(propertyName); }); +} + +static bool checkDropPropertyConflicts(const TableCatalogEntry& tableEntry, + const BoundAlterInfo& info, main::ClientContext& context) { + const auto& extraInfo = info.extraInfo->constCast(); + auto propertyName = extraInfo.propertyName; + validatePropertyExist(info.onConflict, tableEntry, propertyName); + if (tableEntry.containsProperty(propertyName)) { + // Check constrains if we are going to drop a property that exists. + auto propertyID = tableEntry.getPropertyID(propertyName); + // Check primary key constraint + if (tableEntry.getTableType() == TableType::NODE && + tableEntry.constCast().getPrimaryKeyID() == propertyID) { + throw BinderException(stringFormat( + "Cannot drop property {} in table {} because it is used as primary key.", + propertyName, tableEntry.getName())); + } + // Check secondary index constraints + auto catalog = Catalog::Get(context); + auto transaction = transaction::Transaction::Get(context); + if (catalog->containsIndex(transaction, tableEntry.getTableID(), propertyID)) { + throw BinderException(stringFormat( + "Cannot drop property {} in table {} because it is used in one or more indexes. 
" + "Please remove the associated indexes before attempting to drop this property.", + propertyName, tableEntry.getName())); + } + } + return skipAlter(info.onConflict, + [&tableEntry, &propertyName]() { return !tableEntry.containsProperty(propertyName); }); +} + +static bool checkRenamePropertyConflicts(const TableCatalogEntry& tableEntry, + const BoundAlterInfo& info) { + const auto* extraInfo = info.extraInfo->constPtrCast(); + validatePropertyExist(ConflictAction::ON_CONFLICT_THROW, tableEntry, extraInfo->oldName); + validatePropertyNotExist(ConflictAction::ON_CONFLICT_THROW, tableEntry, extraInfo->newName); + return false; +} + +static bool checkRenameTableConflicts(const BoundAlterInfo& info, main::ClientContext& context) { + auto newName = info.extraInfo->constCast().newName; + auto catalog = Catalog::Get(context); + auto transaction = transaction::Transaction::Get(context); + if (catalog->containsTable(transaction, newName)) { + throw BinderException("Table " + newName + " already exists."); + } + return false; +} + +static std::string fromToInTableMessage(const std::string& relGroupName, + const std::string& fromTableName, const std::string& toTableName) { + return stringFormat("{}->{} already exists in {} table.", fromTableName, toTableName, + relGroupName); +} + +static bool checkAddFromToConflicts(const TableCatalogEntry& tableEntry, const BoundAlterInfo& info, + main::ClientContext& context) { + auto& extraInfo = info.extraInfo->constCast(); + auto& relGroupEntry = tableEntry.constCast(); + validate(info.onConflict, [&relGroupEntry, &extraInfo, &context]() { + if (relGroupEntry.hasRelEntryInfo(extraInfo.fromTableID, extraInfo.toTableID)) { + auto catalog = Catalog::Get(context); + auto transaction = transaction::Transaction::Get(context); + auto fromTableName = + catalog->getTableCatalogEntry(transaction, extraInfo.fromTableID)->getName(); + auto toTableName = + catalog->getTableCatalogEntry(transaction, extraInfo.toTableID)->getName(); + throw BinderException{ + fromToInTableMessage(relGroupEntry.getName(), fromTableName, toTableName)}; + } + }); + return skipAlter(info.onConflict, [&relGroupEntry, &extraInfo]() { + return relGroupEntry.hasRelEntryInfo(extraInfo.fromTableID, extraInfo.toTableID); + }); +} + +static std::string fromToNotInTableMessage(const std::string& relGroupName, + const std::string& fromTableName, const std::string& toTableName) { + return stringFormat("{}->{} does not exist in {} table.", fromTableName, toTableName, + relGroupName); +} + +static bool checkDropFromToConflicts(const TableCatalogEntry& tableEntry, + const BoundAlterInfo& info, main::ClientContext& context) { + auto& extraInfo = info.extraInfo->constCast(); + auto& relGroupEntry = tableEntry.constCast(); + validate(info.onConflict, [&relGroupEntry, &extraInfo, &context]() { + if (!relGroupEntry.hasRelEntryInfo(extraInfo.fromTableID, extraInfo.toTableID)) { + auto catalog = Catalog::Get(context); + auto transaction = transaction::Transaction::Get(context); + auto fromTableName = + catalog->getTableCatalogEntry(transaction, extraInfo.fromTableID)->getName(); + auto toTableName = + catalog->getTableCatalogEntry(transaction, extraInfo.toTableID)->getName(); + throw BinderException{ + fromToNotInTableMessage(relGroupEntry.getName(), fromTableName, toTableName)}; + } + }); + return skipAlter(info.onConflict, [&relGroupEntry, &extraInfo]() { + return !relGroupEntry.hasRelEntryInfo(extraInfo.fromTableID, extraInfo.toTableID); + }); +} + +void Alter::alterTable(main::ClientContext* clientContext, 
const TableCatalogEntry& entry, + const BoundAlterInfo& alterInfo) { + auto catalog = Catalog::Get(*clientContext); + auto transaction = Transaction::Get(*clientContext); + auto memoryManager = storage::MemoryManager::Get(*clientContext); + auto tableName = entry.getName(); + switch (info.alterType) { + case AlterType::ADD_PROPERTY: { + auto& extraInfo = info.extraInfo->constCast(); + auto propertyName = extraInfo.propertyDefinition.getName(); + if (checkAddPropertyConflicts(entry, info)) { + appendMessage(propertyInTableMessage(tableName, propertyName), memoryManager); + return; + } + appendMessage(stringFormat("Property {} added to table {}.", propertyName, tableName), + memoryManager); + } break; + case AlterType::DROP_PROPERTY: { + auto& extraInfo = info.extraInfo->constCast(); + auto propertyName = extraInfo.propertyName; + if (checkDropPropertyConflicts(entry, info, *clientContext)) { + appendMessage(propertyNotInTableMessage(tableName, propertyName), memoryManager); + return; + } + appendMessage( + stringFormat("Property {} has been dropped from table {}.", propertyName, tableName), + memoryManager); + } break; + case AlterType::RENAME_PROPERTY: { + // Rename property does not have IF EXISTS + checkRenamePropertyConflicts(entry, info); + auto& extraInfo = info.extraInfo->constCast(); + appendMessage( + stringFormat("Property {} renamed to {}.", extraInfo.oldName, extraInfo.newName), + memoryManager); + } break; + case AlterType::RENAME: { + // Rename table does not have IF EXISTS + checkRenameTableConflicts(info, *clientContext); + auto& extraInfo = info.extraInfo->constCast(); + appendMessage(stringFormat("Table {} renamed to {}.", tableName, extraInfo.newName), + memoryManager); + } break; + case AlterType::ADD_FROM_TO_CONNECTION: { + auto& extraInfo = info.extraInfo->constCast(); + auto fromTableName = + catalog->getTableCatalogEntry(transaction, extraInfo.fromTableID)->getName(); + auto toTableName = + catalog->getTableCatalogEntry(transaction, extraInfo.toTableID)->getName(); + if (checkAddFromToConflicts(entry, info, *clientContext)) { + appendMessage(fromToInTableMessage(tableName, fromTableName, toTableName), + memoryManager); + return; + } + appendMessage( + stringFormat("{}->{} added to table {}.", fromTableName, toTableName, tableName), + memoryManager); + } break; + case AlterType::DROP_FROM_TO_CONNECTION: { + auto& extraInfo = info.extraInfo->constCast(); + auto fromTableName = + catalog->getTableCatalogEntry(transaction, extraInfo.fromTableID)->getName(); + auto toTableName = + catalog->getTableCatalogEntry(transaction, extraInfo.toTableID)->getName(); + if (checkDropFromToConflicts(entry, info, *clientContext)) { + appendMessage(fromToNotInTableMessage(tableName, fromTableName, toTableName), + memoryManager); + return; + } + appendMessage(stringFormat("{}->{} has been dropped from table {}.", fromTableName, + toTableName, tableName), + memoryManager); + } break; + case AlterType::COMMENT: { + appendMessage(stringFormat("Comment added to table {}.", tableName), memoryManager); + } break; + default: + KU_UNREACHABLE; + } + + // Handle storage changes + const auto storageManager = storage::StorageManager::Get(*clientContext); + catalog->alterTableEntry(transaction, alterInfo); + // We don't use an optimistic allocator in this case since rollback of new columns is already + // handled by checkpoint + auto& pageAllocator = *storageManager->getDataFH()->getPageManager(); + switch (info.alterType) { + case AlterType::ADD_PROPERTY: { + auto& boundAddPropInfo = 
info.extraInfo->constCast(); + KU_ASSERT(defaultValueEvaluator); + auto* alteredEntry = catalog->getTableCatalogEntry(transaction, alterInfo.tableName); + auto& addedProp = alteredEntry->getProperty(boundAddPropInfo.propertyDefinition.getName()); + storage::TableAddColumnState state{addedProp, *defaultValueEvaluator}; + switch (alteredEntry->getTableType()) { + case TableType::NODE: { + storageManager->getTable(alteredEntry->getTableID()) + ->addColumn(transaction, state, pageAllocator); + } break; + case TableType::REL: { + for (auto& innerRelEntry : + alteredEntry->cast().getRelEntryInfos()) { + auto* relTable = storageManager->getTable(innerRelEntry.oid); + relTable->addColumn(transaction, state, pageAllocator); + } + } break; + default: { + KU_UNREACHABLE; + } + } + } break; + case AlterType::DROP_PROPERTY: { + auto* alteredEntry = catalog->getTableCatalogEntry(transaction, alterInfo.tableName); + switch (alteredEntry->getTableType()) { + case TableType::NODE: { + storageManager->getTable(alteredEntry->getTableID())->dropColumn(); + } break; + case TableType::REL: { + for (auto& innerRelEntry : + alteredEntry->cast().getRelEntryInfos()) { + auto* relTable = storageManager->getTable(innerRelEntry.oid); + relTable->dropColumn(); + } + } break; + default: { + KU_UNREACHABLE; + } + } + } break; + case AlterType::ADD_FROM_TO_CONNECTION: { + auto relGroupEntry = catalog->getTableCatalogEntry(transaction, alterInfo.tableName) + ->ptrCast(); + auto connectionInfo = alterInfo.extraInfo->constPtrCast(); + auto relEntryInfo = + relGroupEntry->getRelEntryInfo(connectionInfo->fromTableID, connectionInfo->toTableID); + storageManager->addRelTable(relGroupEntry, *relEntryInfo); + } break; + default: + break; + } +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/ddl/create_sequence.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/ddl/create_sequence.cpp new file mode 100644 index 0000000000..6c0236426c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/ddl/create_sequence.cpp @@ -0,0 +1,40 @@ +#include "processor/operator/ddl/create_sequence.h" + +#include "catalog/catalog.h" +#include "common/string_format.h" +#include "processor/execution_context.h" +#include "storage/buffer_manager/memory_manager.h" +#include "transaction/transaction.h" + +using namespace lbug::catalog; +using namespace lbug::common; + +namespace lbug { +namespace processor { + +std::string CreateSequencePrintInfo::toString() const { + return seqName; +} + +void CreateSequence::executeInternal(ExecutionContext* context) { + auto clientContext = context->clientContext; + auto catalog = Catalog::Get(*clientContext); + auto transaction = transaction::Transaction::Get(*clientContext); + auto memoryManager = storage::MemoryManager::Get(*clientContext); + if (catalog->containsSequence(transaction, info.sequenceName)) { + switch (info.onConflict) { + case ConflictAction::ON_CONFLICT_DO_NOTHING: { + appendMessage(stringFormat("Sequence {} already exists.", info.sequenceName), + memoryManager); + return; + } + default: + break; + } + } + catalog->createSequence(transaction, info); + appendMessage(stringFormat("Sequence {} has been created.", info.sequenceName), memoryManager); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/ddl/create_table.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/ddl/create_table.cpp new file mode 100644 index 
0000000000..c92af894c0 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/ddl/create_table.cpp @@ -0,0 +1,50 @@ +#include "processor/operator/ddl/create_table.h" + +#include "catalog/catalog_entry/table_catalog_entry.h" +#include "common/exception/binder.h" +#include "common/string_format.h" +#include "processor/execution_context.h" +#include "storage/storage_manager.h" + +using namespace lbug::catalog; +using namespace lbug::common; + +namespace lbug { +namespace processor { + +void CreateTable::executeInternal(ExecutionContext* context) { + auto clientContext = context->clientContext; + auto catalog = Catalog::Get(*clientContext); + auto transaction = transaction::Transaction::Get(*clientContext); + auto memoryManager = storage::MemoryManager::Get(*clientContext); + // Check conflict + if (catalog->containsTable(transaction, info.tableName)) { + switch (info.onConflict) { + case ConflictAction::ON_CONFLICT_DO_NOTHING: { + appendMessage(stringFormat("Table {} already exists.", info.tableName), memoryManager); + return; + } + case ConflictAction::ON_CONFLICT_THROW: { + throw BinderException(info.tableName + " already exists in catalog."); + } + default: + KU_UNREACHABLE; + } + } + // Create the table. + CatalogEntry* entry = nullptr; + switch (info.type) { + case CatalogEntryType::NODE_TABLE_ENTRY: + case CatalogEntryType::REL_GROUP_ENTRY: { + entry = catalog->createTableEntry(transaction, info); + } break; + default: + KU_UNREACHABLE; + } + storage::StorageManager::Get(*clientContext)->createTable(entry->ptrCast()); + appendMessage(stringFormat("Table {} has been created.", info.tableName), memoryManager); + sharedState->tableCreated = true; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/ddl/create_type.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/ddl/create_type.cpp new file mode 100644 index 0000000000..9b7358eaf7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/ddl/create_type.cpp @@ -0,0 +1,27 @@ +#include "processor/operator/ddl/create_type.h" + +#include "catalog/catalog.h" +#include "processor/execution_context.h" +#include "storage/buffer_manager/memory_manager.h" +#include "transaction/transaction.h" + +using namespace lbug::catalog; +using namespace lbug::common; + +namespace lbug { +namespace processor { + +std::string CreateTypePrintInfo::toString() const { + return typeName + " AS " + type; +} + +void CreateType::executeInternal(ExecutionContext* context) { + auto clientContext = context->clientContext; + auto transaction = transaction::Transaction::Get(*clientContext); + Catalog::Get(*clientContext)->createType(transaction, name, type.copy()); + appendMessage(stringFormat("Type {}({}) has been created.", name, type.toString()), + storage::MemoryManager::Get(*clientContext)); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/ddl/drop.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/ddl/drop.cpp new file mode 100644 index 0000000000..bde00103d8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/ddl/drop.cpp @@ -0,0 +1,134 @@ +#include "processor/operator/ddl/drop.h" + +#include "catalog/catalog.h" +#include "catalog/catalog_entry/index_catalog_entry.h" +#include "catalog/catalog_entry/rel_group_catalog_entry.h" +#include "common/exception/binder.h" +#include "common/string_format.h" +#include "main/client_context.h" +#include 
"processor/execution_context.h" +#include "storage/buffer_manager/memory_manager.h" +#include "transaction/transaction.h" + +using namespace lbug::catalog; +using namespace lbug::common; + +namespace lbug { +namespace processor { + +void Drop::executeInternal(ExecutionContext* context) { + auto clientContext = context->clientContext; + switch (dropInfo.dropType) { + case DropType::SEQUENCE: { + dropSequence(clientContext); + } break; + case DropType::TABLE: { + dropTable(clientContext); + } break; + case DropType::MACRO: { + dropMacro(clientContext); + } break; + default: + KU_UNREACHABLE; + } +} + +void Drop::dropSequence(const main::ClientContext* context) { + auto catalog = Catalog::Get(*context); + auto transaction = transaction::Transaction::Get(*context); + auto memoryManager = storage::MemoryManager::Get(*context); + if (!catalog->containsSequence(transaction, dropInfo.name)) { + auto message = stringFormat("Sequence {} does not exist.", dropInfo.name); + switch (dropInfo.conflictAction) { + case ConflictAction::ON_CONFLICT_DO_NOTHING: { + appendMessage(message, memoryManager); + return; + } + case ConflictAction::ON_CONFLICT_THROW: { + throw BinderException(message); + } + default: + KU_UNREACHABLE; + } + } + catalog->dropSequence(transaction, dropInfo.name); + appendMessage(stringFormat("Sequence {} has been dropped.", dropInfo.name), memoryManager); +} + +void Drop::dropTable(const main::ClientContext* context) { + auto catalog = Catalog::Get(*context); + auto transaction = transaction::Transaction::Get(*context); + auto memoryManager = storage::MemoryManager::Get(*context); + if (!catalog->containsTable(transaction, dropInfo.name, context->useInternalCatalogEntry())) { + auto message = stringFormat("Table {} does not exist.", dropInfo.name); + switch (dropInfo.conflictAction) { + case ConflictAction::ON_CONFLICT_DO_NOTHING: { + appendMessage(message, memoryManager); + return; + } + case ConflictAction::ON_CONFLICT_THROW: { + throw BinderException(message); + } + default: + KU_UNREACHABLE; + } + } + auto entry = catalog->getTableCatalogEntry(transaction, dropInfo.name); + switch (entry->getType()) { + case CatalogEntryType::NODE_TABLE_ENTRY: { + for (auto& indexEntry : catalog->getIndexEntries(transaction)) { + if (indexEntry->getTableID() == entry->getTableID()) { + throw BinderException(stringFormat( + "Cannot delete node table {} because it is referenced by index {}.", + entry->getName(), indexEntry->getIndexName())); + } + } + for (auto& relEntry : catalog->getRelGroupEntries(transaction)) { + if (relEntry->isParent(entry->getTableID())) { + throw BinderException(stringFormat("Cannot delete node table {} because it is " + "referenced by relationship table {}.", + entry->getName(), relEntry->getName())); + } + } + } break; + case CatalogEntryType::REL_GROUP_ENTRY: { + // Do nothing + } break; + default: + KU_UNREACHABLE; + } + catalog->dropTableEntryAndIndex(transaction, dropInfo.name); + appendMessage(stringFormat("Table {} has been dropped.", dropInfo.name), memoryManager); +} + +void Drop::dropMacro(const main::ClientContext* context) { + auto catalog = Catalog::Get(*context); + auto transaction = transaction::Transaction::Get(*context); + auto memoryManager = storage::MemoryManager::Get(*context); + handleMacroExistence(context); + catalog->dropMacro(transaction, dropInfo.name); + appendMessage(stringFormat("Macro {} has been dropped.", dropInfo.name), memoryManager); +} + +void Drop::handleMacroExistence(const main::ClientContext* context) { + auto catalog = 
Catalog::Get(*context); + auto transaction = transaction::Transaction::Get(*context); + auto memoryManager = storage::MemoryManager::Get(*context); + if (!catalog->containsMacro(transaction, dropInfo.name)) { + auto message = stringFormat("Macro {} does not exist.", dropInfo.name); + switch (dropInfo.conflictAction) { + case ConflictAction::ON_CONFLICT_DO_NOTHING: { + appendMessage(message, memoryManager); + return; + } + case ConflictAction::ON_CONFLICT_THROW: { + throw BinderException(message); + } + default: + KU_UNREACHABLE; + } + } +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/empty_result.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/empty_result.cpp new file mode 100644 index 0000000000..fb7393a64f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/empty_result.cpp @@ -0,0 +1,11 @@ +#include "processor/operator/empty_result.h" + +namespace lbug { +namespace processor { + +bool EmptyResult::getNextTuplesInternal(ExecutionContext*) { + return false; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/filter.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/filter.cpp new file mode 100644 index 0000000000..9eb882d6f4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/filter.cpp @@ -0,0 +1,70 @@ +#include "processor/operator/filter.h" + +#include "binder/expression/expression.h" // IWYU pragma: keep +#include "processor/execution_context.h" + +using namespace lbug::common; + +namespace lbug { +namespace processor { + +std::string FilterPrintInfo::toString() const { + return expression->toString(); +} + +void Filter::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) { + expressionEvaluator->init(*resultSet, context->clientContext); + if (dataChunkToSelectPos == INVALID_DATA_CHUNK_POS) { + // Filter a constant expression. Ideally we should fold all such expression at compile time. + // But there are many edge cases, so we keep this code path for robustness. 
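+        // A single-value data chunk state gives the constant expression a
+        // one-slot selection vector, so getNextTuplesInternal can evaluate it
+        // through the same select() path as a column-backed predicate.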
+ state = DataChunkState::getSingleValueDataChunkState(); + } else { + state = resultSet->dataChunks[dataChunkToSelectPos]->state; + } +} + +bool Filter::getNextTuplesInternal(ExecutionContext* context) { + bool hasAtLeastOneSelectedValue = false; + do { + restoreSelVector(*state); + if (!children[0]->getNextTuple(context)) { + return false; + } + saveSelVector(*state); + hasAtLeastOneSelectedValue = + expressionEvaluator->select(state->getSelVectorUnsafe(), !state->isFlat()); + } while (!hasAtLeastOneSelectedValue); + metrics->numOutputTuple.increase(state->getSelVector().getSelSize()); + return true; +} + +void NodeLabelFiler::initLocalStateInternal(ResultSet* /*resultSet_*/, + ExecutionContext* /*context*/) { + nodeIDVector = resultSet->getValueVector(info->nodeVectorPos).get(); +} + +bool NodeLabelFiler::getNextTuplesInternal(ExecutionContext* context) { + sel_t numSelectValue = 0; + do { + restoreSelVector(*nodeIDVector->state); + if (!children[0]->getNextTuple(context)) { + return false; + } + saveSelVector(*nodeIDVector->state); + numSelectValue = 0; + auto buffer = nodeIDVector->state->getSelVectorUnsafe().getMutableBuffer(); + for (auto i = 0u; i < nodeIDVector->state->getSelVector().getSelSize(); ++i) { + auto pos = nodeIDVector->state->getSelVector()[i]; + buffer[numSelectValue] = pos; + numSelectValue += + info->nodeLabelSet.contains(nodeIDVector->getValue(pos).tableID); + } + nodeIDVector->state->getSelVectorUnsafe().setToFiltered(); + } while (numSelectValue == 0); + nodeIDVector->state->getSelVectorUnsafe().setSelSize(numSelectValue); + metrics->numOutputTuple.increase(nodeIDVector->state->getSelVector().getSelSize()); + return true; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/filtering_operator.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/filtering_operator.cpp new file mode 100644 index 0000000000..cd60ea5896 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/filtering_operator.cpp @@ -0,0 +1,43 @@ +#include "processor/operator/filtering_operator.h" + +#include + +#include "common/data_chunk/data_chunk_state.h" +#include "common/system_config.h" + +using namespace lbug::common; + +namespace lbug { +namespace processor { + +SelVectorOverWriter::SelVectorOverWriter() { + currentSelVector = std::make_shared(common::DEFAULT_VECTOR_CAPACITY); +} + +void SelVectorOverWriter::restoreSelVector(DataChunkState& dataChunkState) const { + if (prevSelVector != nullptr) { + dataChunkState.setSelVector(prevSelVector); + } +} + +void SelVectorOverWriter::saveSelVector(DataChunkState& dataChunkState) { + if (prevSelVector == nullptr) { + prevSelVector = dataChunkState.getSelVectorShared(); + } + resetCurrentSelVector(dataChunkState.getSelVector()); + dataChunkState.setSelVector(currentSelVector); +} + +void SelVectorOverWriter::resetCurrentSelVector(const SelectionVector& selVector) { + currentSelVector->setSelSize(selVector.getSelSize()); + if (selVector.isUnfiltered()) { + currentSelVector->setToUnfiltered(); + } else { + std::memcpy(currentSelVector->getMutableBuffer().data(), + selVector.getSelectedPositions().data(), selVector.getSelSize() * sizeof(sel_t)); + currentSelVector->setToFiltered(); + } +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/flatten.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/flatten.cpp new file mode 100644 index 0000000000..a820768109 --- /dev/null +++ 
b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/flatten.cpp @@ -0,0 +1,37 @@ +#include "processor/operator/flatten.h" + +#include "common/metric.h" + +using namespace lbug::common; + +namespace lbug { +namespace processor { + +void Flatten::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* /*context*/) { + dataChunkState = resultSet->dataChunks[dataChunkToFlattenPos]->state.get(); + currentSelVector->setToFiltered(1 /* size */); + localState = std::make_unique(); +} + +bool Flatten::getNextTuplesInternal(ExecutionContext* context) { + if (localState->currentIdx == localState->sizeToFlatten) { + dataChunkState->setToUnflat(); // TODO(Xiyang): this should be part of restore/save + restoreSelVector(*dataChunkState); + if (!children[0]->getNextTuple(context)) { + return false; + } + localState->currentIdx = 0; + localState->sizeToFlatten = dataChunkState->getSelVector().getSelSize(); + saveSelVector(*dataChunkState); + dataChunkState->setToFlat(); + } + sel_t selPos = prevSelVector->operator[](localState->currentIdx++); + currentSelVector->operator[](0) = selPos; + metrics->numOutputTuple.incrementByOne(); + return true; +} + +void Flatten::resetCurrentSelVector(const SelectionVector&) {} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/hash_join/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/hash_join/CMakeLists.txt new file mode 100644 index 0000000000..67d5ed7f11 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/hash_join/CMakeLists.txt @@ -0,0 +1,9 @@ +add_library(lbug_processor_operator_hash_join + OBJECT + hash_join_build.cpp + hash_join_probe.cpp + join_hash_table.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/hash_join/hash_join_build.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/hash_join/hash_join_build.cpp new file mode 100644 index 0000000000..d0e05aeff8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/hash_join/hash_join_build.cpp @@ -0,0 +1,76 @@ +#include "processor/operator/hash_join/hash_join_build.h" + +#include "binder/expression/expression_util.h" +#include "processor/execution_context.h" +#include "storage/buffer_manager/memory_manager.h" + +using namespace lbug::common; +using namespace lbug::storage; + +namespace lbug { +namespace processor { + +std::string HashJoinBuildPrintInfo::toString() const { + std::string result = "Keys: "; + result += binder::ExpressionUtil::toString(keys); + if (!payloads.empty()) { + result += ", Payloads: "; + result += binder::ExpressionUtil::toString(payloads); + } + return result; +} + +void HashJoinSharedState::mergeLocalHashTable(JoinHashTable& localHashTable) { + std::unique_lock lck(mtx); + hashTable->merge(localHashTable); +} + +void HashJoinBuild::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) { + std::vector keyTypes; + for (auto i = 0u; i < info.keysPos.size(); ++i) { + auto vector = resultSet->getValueVector(info.keysPos[i]).get(); + keyTypes.push_back(vector->dataType.copy()); + if (info.fStateTypes[i] == common::FStateType::UNFLAT) { + setKeyState(vector->state.get()); + } + keyVectors.push_back(vector); + } + if (keyState == nullptr) { + setKeyState(keyVectors[0]->state.get()); + } + for (auto& pos : info.payloadsPos) { + payloadVectors.push_back(resultSet->getValueVector(pos).get()); + } + hashTable = 
std::make_unique(*MemoryManager::Get(*context->clientContext), + std::move(keyTypes), info.tableSchema.copy()); +} + +void HashJoinBuild::setKeyState(common::DataChunkState* state) { + if (keyState == nullptr) { + keyState = state; + } else { + KU_ASSERT(keyState == state); // two pointers should be pointing to the same state + } +} + +void HashJoinBuild::finalizeInternal(ExecutionContext* /*context*/) { + auto numTuples = sharedState->getHashTable()->getNumEntries(); + sharedState->getHashTable()->allocateHashSlots(numTuples); + sharedState->getHashTable()->buildHashSlots(); +} + +void HashJoinBuild::executeInternal(ExecutionContext* context) { + // Append thread-local tuples + while (children[0]->getNextTuple(context)) { + uint64_t numAppended = 0u; + for (auto i = 0u; i < resultSet->multiplicity; ++i) { + numAppended += appendVectors(); + } + metrics->numOutputTuple.increase(numAppended); + } + // Merge with global hash table once local tuples are all appended. + sharedState->mergeLocalHashTable(*hashTable); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/hash_join/hash_join_probe.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/hash_join/hash_join_probe.cpp new file mode 100644 index 0000000000..f59ade8bc0 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/hash_join/hash_join_probe.cpp @@ -0,0 +1,211 @@ +#include "processor/operator/hash_join/hash_join_probe.h" + +#include "binder/expression/expression_util.h" +#include "processor/execution_context.h" +#include "storage/buffer_manager/memory_manager.h" + +using namespace lbug::common; + +namespace lbug { +namespace processor { + +std::string HashJoinProbePrintInfo::toString() const { + std::string result = "Keys: "; + result += binder::ExpressionUtil::toString(keys); + return result; +} + +void HashJoinProbe::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) { + probeState = std::make_unique(); + for (auto& keyDataPos : probeDataInfo.keysDataPos) { + keyVectors.push_back(resultSet->getValueVector(keyDataPos).get()); + } + if (probeDataInfo.markDataPos.isValid()) { + markVector = resultSet->getValueVector(probeDataInfo.markDataPos).get(); + } else { + markVector = nullptr; + } + for (auto& dataPos : probeDataInfo.payloadsOutPos) { + vectorsToReadInto.push_back(resultSet->getValueVector(dataPos).get()); + } + // We only need to read nonKeys from the factorizedTable. Key columns are always kept as first k + // columns in the factorizedTable, so we skip the first k columns. + KU_ASSERT(probeDataInfo.keysDataPos.size() + probeDataInfo.getNumPayloads() + 2 == + sharedState->getHashTable()->getTableSchema()->getNumColumns()); + columnIdxsToReadFrom.resize(probeDataInfo.getNumPayloads()); + iota(columnIdxsToReadFrom.begin(), columnIdxsToReadFrom.end(), + probeDataInfo.keysDataPos.size()); + auto mm = storage::MemoryManager::Get(*context->clientContext); + hashVector = std::make_unique(LogicalType::HASH(), mm); + if (keyVectors.size() > 1) { + tmpHashVector = std::make_unique(LogicalType::HASH(), mm); + } +} + +bool HashJoinProbe::getMatchedTuplesForFlatKey(ExecutionContext* context) { + if (probeState->nextMatchedTupleIdx < probeState->matchedSelVector.getSelSize()) { + // Not all matched tuples have been shipped. Continue shipping. + return true; + } + if (probeState->probedTuples[0] == nullptr) { // No more matched tuples on the chain. 
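+ // Pull the next batch from the child, re-hash the flat keys and seed probedTuples[0] with + // the head of the matching bucket chain; matchFlatKeys below then walks that chain.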
+ // We still need to save and restore for flat input because we are discarding NULL join keys + // which changes the selected position. + // TODO(Guodong): we have potential bugs here because all keys' states should be restored. + restoreSelVector(*keyVectors[0]->state); + if (!children[0]->getNextTuple(context)) { + return false; + } + saveSelVector(*keyVectors[0]->state); + sharedState->getHashTable()->probe(keyVectors, *hashVector, hashSelVec, tmpHashVector.get(), + probeState->probedTuples.get()); + } + auto numMatchedTuples = sharedState->getHashTable()->matchFlatKeys(keyVectors, + probeState->probedTuples.get(), probeState->matchedTuples.get()); + probeState->matchedSelVector.setSelSize(numMatchedTuples); + probeState->nextMatchedTupleIdx = 0; + return true; +} + +bool HashJoinProbe::getMatchedTuplesForUnFlatKey(ExecutionContext* context) { + KU_ASSERT(keyVectors.size() == 1); + auto keyVector = keyVectors[0]; + restoreSelVector(*keyVector->state); + if (!children[0]->getNextTuple(context)) { + return false; + } + saveSelVector(*keyVector->state); + sharedState->getHashTable()->probe(keyVectors, *hashVector, hashSelVec, tmpHashVector.get(), + probeState->probedTuples.get()); + auto numMatchedTuples = + sharedState->getHashTable()->matchUnFlatKey(keyVector, probeState->probedTuples.get(), + probeState->matchedTuples.get(), probeState->matchedSelVector); + probeState->matchedSelVector.setSelSize(numMatchedTuples); + probeState->nextMatchedTupleIdx = 0; + return true; +} + +uint64_t HashJoinProbe::getInnerJoinResultForFlatKey() { + if (probeState->matchedSelVector.getSelSize() == 0) { + return 0; + } + auto numTuplesToRead = 1; + sharedState->getHashTable()->lookup(vectorsToReadInto, columnIdxsToReadFrom, + probeState->matchedTuples.get(), probeState->nextMatchedTupleIdx, numTuplesToRead); + probeState->nextMatchedTupleIdx += numTuplesToRead; + return numTuplesToRead; +} + +uint64_t HashJoinProbe::getInnerJoinResultForUnFlatKey() { + auto numTuplesToRead = probeState->matchedSelVector.getSelSize(); + if (numTuplesToRead == 0) { + return 0; + } + auto& keySelVector = keyVectors[0]->state->getSelVectorUnsafe(); + if (keySelVector.getSelSize() != numTuplesToRead) { + // Some keys have no matched tuple. So we modify selected position. + auto buffer = keySelVector.getMutableBuffer(); + for (auto i = 0u; i < numTuplesToRead; i++) { + buffer[i] = probeState->matchedSelVector[i]; + } + keySelVector.setToFiltered(numTuplesToRead); + } + sharedState->getHashTable()->lookup(vectorsToReadInto, columnIdxsToReadFrom, + probeState->matchedTuples.get(), probeState->nextMatchedTupleIdx, numTuplesToRead); + probeState->nextMatchedTupleIdx += numTuplesToRead; + return numTuplesToRead; +} + +static void writeLeftJoinMarkVector(ValueVector* markVector, bool flag) { + if (markVector == nullptr) { + return; + } + KU_ASSERT(markVector->state->getSelVector().getSelSize() == 1); + auto pos = markVector->state->getSelVector()[0]; + markVector->setValue(pos, flag); +} + +uint64_t HashJoinProbe::getLeftJoinResult() { + if (getInnerJoinResult() == 0) { + for (auto& vector : vectorsToReadInto) { + vector->setAsSingleNullEntry(); + } + // TODO(Xiyang): We have a bug in LEFT JOIN which should not discard NULL keys. To be more + // clear, NULL keys should only be discarded for probe but should not reflect on the vector. + // The following for loop is a temporary hack. 
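+ // On a probe miss the payload outputs are emitted as a single NULL entry and the flat key + // selection is forced back to size 1, so exactly one unmatched row is produced for the + // current probe key.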
+ for (auto& vector : keyVectors) { + KU_ASSERT(vector->state->isFlat()); + vector->state->getSelVectorUnsafe().setSelSize(1); + } + probeState->probedTuples[0] = nullptr; + writeLeftJoinMarkVector(markVector, false); + return 1; + } + writeLeftJoinMarkVector(markVector, true); + return 1; +} + +uint64_t HashJoinProbe::getCountJoinResult() { + KU_ASSERT(vectorsToReadInto.size() == 1); + if (getInnerJoinResult() == 0) { + auto pos = vectorsToReadInto[0]->state->getSelVector()[0]; + vectorsToReadInto[0]->setValue(pos, 0); + probeState->probedTuples[0] = nullptr; + } + return 1; +} + +uint64_t HashJoinProbe::getMarkJoinResult() { + auto markValues = (bool*)markVector->getData(); + if (markVector->state->isFlat()) { + auto pos = markVector->state->getSelVector()[0]; + markValues[pos] = probeState->matchedSelVector.getSelSize() != 0; + } else { + std::fill(markValues, markValues + DEFAULT_VECTOR_CAPACITY, false); + for (auto i = 0u; i < probeState->matchedSelVector.getSelSize(); i++) { + auto pos = probeState->matchedSelVector[i]; + markValues[pos] = true; + } + } + probeState->probedTuples[0] = nullptr; + probeState->nextMatchedTupleIdx = probeState->matchedSelVector.getSelSize(); + return 1; +} + +uint64_t HashJoinProbe::getJoinResult() { + switch (joinType) { + case JoinType::LEFT: { + return getLeftJoinResult(); + } + case JoinType::COUNT: { + return getCountJoinResult(); + } + case JoinType::MARK: { + return getMarkJoinResult(); + } + case JoinType::INNER: { + return getInnerJoinResult(); + } + default: + throw InternalException("Unimplemented join type for HashJoinProbe::getJoinResult()"); + } +} + +// The general flow of a hash join probe: +// 1) find matched tuples of probe side key from ht. +// 2) populate values from matched tuples into resultKeyDataChunk , buildSideFlatResultDataChunk +// (all flat data chunks from the build side are merged into one) and buildSideVectorPtrs (each +// VectorPtr corresponds to one unFlat build side data chunk that is appended to the resultSet). +bool HashJoinProbe::getNextTuplesInternal(ExecutionContext* context) { + uint64_t numPopulatedTuples = 0; + do { + if (!getMatchedTuples(context)) { + return false; + } + numPopulatedTuples = getJoinResult(); + } while (numPopulatedTuples == 0); + metrics->numOutputTuple.increase(numPopulatedTuples); + return true; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/hash_join/join_hash_table.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/hash_join/join_hash_table.cpp new file mode 100644 index 0000000000..81925998c6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/hash_join/join_hash_table.cpp @@ -0,0 +1,212 @@ +#include "processor/operator/hash_join/join_hash_table.h" + +#include "common/utils.h" +#include "function/hash/vector_hash_functions.h" +#include "processor/result/factorized_table.h" + +using namespace lbug::common; +using namespace lbug::storage; +using namespace lbug::function; + +namespace lbug { +namespace processor { + +JoinHashTable::JoinHashTable(MemoryManager& memoryManager, logical_type_vec_t keyTypes, + FactorizedTableSchema tableSchema) + : BaseHashTable{memoryManager, std::move(keyTypes)} { + auto numSlotsPerBlock = HASH_BLOCK_SIZE / sizeof(uint8_t*); + initSlotConstant(numSlotsPerBlock); + // Prev pointer is always the last column in the table. 
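+ // Each appended tuple reserves that column as a prev pointer: buildHashSlots() links every + // tuple into its hash slot's singly linked chain via insertEntry(), and probes later walk + // the chain through getPrevTuple().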
+ prevPtrColOffset = tableSchema.getColOffset(tableSchema.getNumColumns() - PREV_PTR_COL_IDX); + factorizedTable = std::make_unique(&memoryManager, std::move(tableSchema)); +} + +static bool discardNullFromKeys(const std::vector& vectors) { + bool hasNonNullKeys = true; + for (auto& vector : vectors) { + if (!ValueVector::discardNull(*vector)) { + hasNonNullKeys = false; + break; + } + } + return hasNonNullKeys; +} + +uint64_t JoinHashTable::appendVectors(const std::vector& keyVectors, + const std::vector& payloadVectors, DataChunkState* keyState) { + discardNullFromKeys(keyVectors); + auto numTuplesToAppend = keyState->getSelVector().getSelSize(); + auto appendInfos = factorizedTable->allocateFlatTupleBlocks(numTuplesToAppend); + computeVectorHashes(keyVectors); + auto colIdx = 0u; + for (auto& vector : keyVectors) { + appendVector(vector, appendInfos, colIdx++); + } + for (auto& vector : payloadVectors) { + appendVector(vector, appendInfos, colIdx++); + } + appendVector(hashVector.get(), appendInfos, colIdx); + factorizedTable->numTuples += numTuplesToAppend; + return numTuplesToAppend; +} + +void JoinHashTable::appendVector(ValueVector* vector, + const std::vector& appendInfos, ft_col_idx_t colIdx) { + auto numAppendedTuples = 0ul; + for (auto& blockAppendInfo : appendInfos) { + factorizedTable->copyVectorToColumn(*vector, blockAppendInfo, numAppendedTuples, colIdx); + numAppendedTuples += blockAppendInfo.numTuplesToAppend; + } +} + +static void sortSelectedPos(ValueVector* nodeIDVector) { + auto& selVector = nodeIDVector->state->getSelVectorUnsafe(); + auto size = selVector.getSelSize(); + auto buffer = selVector.getMutableBuffer(); + if (selVector.isUnfiltered()) { + std::memcpy(buffer.data(), selVector.getSelectedPositions().data(), size * sizeof(sel_t)); + selVector.setToFiltered(); + } + std::sort(buffer.begin(), buffer.begin() + size, [nodeIDVector](sel_t left, sel_t right) { + return nodeIDVector->getValue(left) < nodeIDVector->getValue(right); + }); +} + +uint64_t JoinHashTable::appendVectorWithSorting(ValueVector* keyVector, + std::vector payloadVectors) { + auto numTuplesToAppend = 1; + KU_ASSERT(keyVector->state->getSelVector().getSelSize() == 1); + // Based on the way we are planning, we assume that the first and second vectors are both + // nodeIDs from extending, while the first one is key, and the second one is payload. + auto payloadNodeIDVector = payloadVectors[0]; + auto payloadsState = payloadNodeIDVector->state.get(); + if (!payloadsState->isFlat()) { + // Sorting is only needed when the payload is unFlat (a list of values). + sortSelectedPos(payloadNodeIDVector); + } + // A single appendInfo will return from `allocateFlatTupleBlocks` when numTuplesToAppend is 1. + auto appendInfos = factorizedTable->allocateFlatTupleBlocks(numTuplesToAppend); + KU_ASSERT(appendInfos.size() == 1); + auto colIdx = 0u; + std::vector keyVectors = {keyVector}; + computeVectorHashes(keyVectors); + factorizedTable->copyVectorToColumn(*keyVector, appendInfos[0], numTuplesToAppend, colIdx++); + for (auto& vector : payloadVectors) { + factorizedTable->copyVectorToColumn(*vector, appendInfos[0], numTuplesToAppend, colIdx++); + } + factorizedTable->copyVectorToColumn(*hashVector, appendInfos[0], numTuplesToAppend, colIdx); + if (!payloadsState->isFlat()) { + // TODO(Xiyang): I can no longer recall why I set to un-filtered but this is probably wrong. + // We should set back to the un-sorted state. 
+ payloadsState->getSelVectorUnsafe().setToUnfiltered(); + } + factorizedTable->numTuples += numTuplesToAppend; + return numTuplesToAppend; +} + +void JoinHashTable::allocateHashSlots(uint64_t numTuples) { + setMaxNumHashSlots(nextPowerOfTwo(numTuples * 2)); + auto numSlotsPerBlock = (uint64_t)1 << numSlotsPerBlockLog2; + auto numBlocksNeeded = (maxNumHashSlots + numSlotsPerBlock - 1) / numSlotsPerBlock; + while (hashSlotsBlocks.size() < numBlocksNeeded) { + hashSlotsBlocks.emplace_back(std::make_unique(memoryManager, HASH_BLOCK_SIZE)); + } +} + +void JoinHashTable::buildHashSlots() { + for (auto& tupleBlock : factorizedTable->getTupleDataBlocks()) { + uint8_t* tuple = tupleBlock->getData(); + for (auto i = 0u; i < tupleBlock->numTuples; i++) { + auto lastSlotEntryInHT = insertEntry(tuple); + auto prevPtr = getPrevTuple(tuple); + memcpy(reinterpret_cast(prevPtr), reinterpret_cast(&lastSlotEntryInHT), + sizeof(uint8_t*)); + tuple += getTableSchema()->getNumBytesPerTuple(); + } + } +} + +void JoinHashTable::probe(const std::vector& keyVectors, ValueVector& hashVector, + SelectionVector& hashSelVec, ValueVector* tmpHashResultVector, uint8_t** probedTuples) { + KU_ASSERT(keyVectors.size() == keyTypes.size()); + if (getNumEntries() == 0) { + return; + } + if (!discardNullFromKeys(keyVectors)) { + return; + } + hashSelVec.setSelSize(keyVectors[0]->state->getSelVector().getSelSize()); + VectorHashFunction::computeHash(*keyVectors[0], keyVectors[0]->state->getSelVector(), + hashVector, hashSelVec); + for (auto i = 1u; i < keyVectors.size(); i++) { + hashSelVec.setSelSize(keyVectors[i]->state->getSelVector().getSelSize()); + VectorHashFunction::computeHash(*keyVectors[i], keyVectors[i]->state->getSelVector(), + *tmpHashResultVector, hashSelVec); + VectorHashFunction::combineHash(hashVector, hashSelVec, *tmpHashResultVector, hashSelVec, + hashVector, hashSelVec); + } + for (auto i = 0u; i < hashSelVec.getSelSize(); i++) { + KU_ASSERT(i < DEFAULT_VECTOR_CAPACITY); + probedTuples[i] = getTupleForHash(hashVector.getValue(hashSelVec[i])); + } +} + +sel_t JoinHashTable::matchFlatKeys(const std::vector& keyVectors, + uint8_t** probedTuples, uint8_t** matchedTuples) { + auto numMatchedTuples = 0; + while (probedTuples[0]) { + if (numMatchedTuples == DEFAULT_VECTOR_CAPACITY) { + break; + } + auto currentTuple = probedTuples[0]; + matchedTuples[numMatchedTuples] = currentTuple; + numMatchedTuples += matchFlatVecWithEntry(keyVectors, currentTuple); + probedTuples[0] = *getPrevTuple(currentTuple); + } + return numMatchedTuples; +} + +sel_t JoinHashTable::matchUnFlatKey(ValueVector* keyVector, uint8_t** probedTuples, + uint8_t** matchedTuples, SelectionVector& matchedTuplesSelVector) { + auto numMatchedTuples = 0; + for (auto i = 0u; i < keyVector->state->getSelVector().getSelSize(); ++i) { + auto pos = keyVector->state->getSelVector()[i]; + while (probedTuples[i]) { + auto currentTuple = probedTuples[i]; + auto entryCompareResult = compareEntryFuncs[0](keyVector, pos, currentTuple); + if (entryCompareResult) { + matchedTuples[numMatchedTuples] = currentTuple; + matchedTuplesSelVector[numMatchedTuples] = pos; + numMatchedTuples++; + break; + } + probedTuples[i] = *getPrevTuple(currentTuple); + } + } + return numMatchedTuples; +} + +uint8_t** JoinHashTable::findHashSlot(const uint8_t* tuple) const { + auto hash = *(hash_t*)(tuple + getHashValueColOffset()); + auto slotIdx = getSlotIdxForHash(hash); + return (uint8_t**)(hashSlotsBlocks[slotIdx >> numSlotsPerBlockLog2]->getData() + + (slotIdx & slotIdxInBlockMask) 
* sizeof(uint8_t*)); +} + +uint8_t* JoinHashTable::insertEntry(uint8_t* tuple) const { + auto slot = findHashSlot(tuple); + auto prevPtr = *slot; + *slot = tuple; + return prevPtr; +} + +void JoinHashTable::computeVectorHashes(std::vector keyVectors) { + BaseHashTable::computeVectorHashes(keyVectors); +} + +offset_t JoinHashTable::getHashValueColOffset() const { + return getTableSchema()->getColOffset(getTableSchema()->getNumColumns() - HASH_COL_IDX); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/index_lookup.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/index_lookup.cpp new file mode 100644 index 0000000000..f4b77513a9 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/index_lookup.cpp @@ -0,0 +1,161 @@ +#include "processor/operator/index_lookup.h" + +#include "binder/expression/expression_util.h" +#include "common/assert.h" +#include "common/exception/message.h" +#include "common/types/types.h" +#include "common/utils.h" +#include "common/vector/value_vector.h" +#include "processor/warning_context.h" +#include "storage/index/hash_index.h" +#include "storage/table/node_table.h" + +using namespace lbug::common; +using namespace lbug::storage; + +namespace lbug { +namespace processor { + +namespace { + +std::optional getWarningSourceData( + const std::vector& warningDataVectors, sel_t pos) { + std::optional ret; + if (!warningDataVectors.empty()) { + ret.emplace(WarningSourceData::constructFromData(warningDataVectors, + safeIntegerConversion(pos))); + } + return ret; +} + +// TODO(Guodong): Add short path for unfiltered case. +bool checkNullKey(ValueVector* keyVector, offset_t vectorOffset, + BatchInsertErrorHandler* errorHandler, const std::vector& warningDataVectors) { + bool isNull = keyVector->isNull(vectorOffset); + if (isNull) { + errorHandler->handleError(ExceptionMessage::nullPKException(), + getWarningSourceData(warningDataVectors, vectorOffset)); + } + return !isNull; +} + +struct OffsetVectorManager { + OffsetVectorManager(ValueVector* resultVector, BatchInsertErrorHandler* errorHandler) + : ignoreErrors(errorHandler->getIgnoreErrors()), resultVector(resultVector), + insertOffset(0) { + // if we are ignoring errors we may need to filter the output sel vector + if (ignoreErrors) { + resultVector->state->getSelVectorUnsafe().setToFiltered(); + } + } + + ~OffsetVectorManager() { + if (ignoreErrors) { + resultVector->state->getSelVectorUnsafe().setSelSize(insertOffset); + } + } + + void insertEntry(offset_t entry, sel_t posInKeyVector) { + auto* offsets = reinterpret_cast(resultVector->getData()); + offsets[posInKeyVector] = entry; + if (ignoreErrors) { + // if the lookup was successful we may add the current entry to the output selection + resultVector->state->getSelVectorUnsafe()[insertOffset] = posInKeyVector; + } + ++insertOffset; + } + + bool ignoreErrors; + ValueVector* resultVector; + + offset_t insertOffset; +}; + +// TODO(Guodong): Add short path for unfiltered case. 
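+// The templated helper below resolves primary-key values to internal node offsets through the +// node table's PK index; NULL keys and keys with no matching node are routed to the error +// handler, and the hasNoNullsGuarantee parameter lets the guaranteed-non-null path skip the +// per-key NULL check.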
+template +void fillOffsetArraysFromVector(transaction::Transaction* transaction, const IndexLookupInfo& info, + ValueVector* keyVector, ValueVector* resultVector, + const std::vector& warningDataVectors, BatchInsertErrorHandler* errorHandler) { + KU_ASSERT(resultVector->dataType.getPhysicalType() == PhysicalTypeID::INT64); + TypeUtils::visit( + keyVector->dataType.getPhysicalType(), + [&](T) { + auto numKeys = keyVector->state->getSelVector().getSelSize(); + + // fetch all the selection pos at the start + // since we may modify the selection vector in the middle of the lookup + std::vector lookupPos(numKeys); + for (idx_t i = 0; i < numKeys; ++i) { + lookupPos[i] = (keyVector->state->getSelVector()[i]); + } + + OffsetVectorManager resultManager{resultVector, errorHandler}; + for (auto i = 0u; i < numKeys; i++) { + auto pos = lookupPos[i]; + if constexpr (!hasNoNullsGuarantee) { + if (!checkNullKey(keyVector, pos, errorHandler, warningDataVectors)) { + continue; + } + } + offset_t lookupOffset = 0; + if (!info.nodeTable->lookupPK(transaction, keyVector, pos, lookupOffset)) { + TypeUtils::visit(keyVector->dataType, [&](type) { + errorHandler->handleError( + ExceptionMessage::nonExistentPKException( + TypeUtils::toString(keyVector->getValue(pos), keyVector)), + getWarningSourceData(warningDataVectors, pos)); + }); + } else { + resultManager.insertEntry(lookupOffset, pos); + } + } + }, + [&](auto) { KU_UNREACHABLE; }); +} +} // namespace + +std::string IndexLookupPrintInfo::toString() const { + std::string result = "Indexes: "; + result += binder::ExpressionUtil::toString(expressions); + return result; +} + +bool IndexLookup::getNextTuplesInternal(ExecutionContext* context) { + if (!children[0]->getNextTuple(context)) { + return false; + } + for (auto& info : infos) { + info.keyEvaluator->evaluate(); + lookup(transaction::Transaction::Get(*context->clientContext), info); + } + localState->errorHandler->flushStoredErrors(); + return true; +} + +void IndexLookup::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) { + auto errorHandler = std::make_unique(context, + WarningContext::Get(*context->clientContext)->getIgnoreErrorsOption()); + localState = std::make_unique(std::move(errorHandler)); + for (auto& pos : warningDataVectorPos) { + localState->warningDataVectors.push_back(resultSet->getValueVector(pos).get()); + } + for (auto& info : infos) { + info.keyEvaluator->init(*resultSet, context->clientContext); + } +} + +void IndexLookup::lookup(transaction::Transaction* transaction, const IndexLookupInfo& info) { + auto keyVector = info.keyEvaluator->resultVector.get(); + auto resultVector = resultSet->getValueVector(info.resultVectorPos).get(); + + if (keyVector->hasNoNullsGuarantee()) { + fillOffsetArraysFromVector(transaction, info, keyVector, resultVector, + localState->warningDataVectors, localState->errorHandler.get()); + } else { + fillOffsetArraysFromVector(transaction, info, keyVector, resultVector, + localState->warningDataVectors, localState->errorHandler.get()); + } +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/intersect/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/intersect/CMakeLists.txt new file mode 100644 index 0000000000..8df0db7892 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/intersect/CMakeLists.txt @@ -0,0 +1,7 @@ +add_library(lbug_processor_operator_intersect + OBJECT + intersect.cpp) + +set(ALL_OBJECT_FILES + 
${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/intersect/intersect.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/intersect/intersect.cpp new file mode 100644 index 0000000000..3eda85e8a7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/intersect/intersect.cpp @@ -0,0 +1,228 @@ +#include "processor/operator/intersect/intersect.h" + +#include + +#include "function/hash/hash_functions.h" +#include "processor/result/factorized_table.h" + +using namespace lbug::common; + +namespace lbug { +namespace processor { + +std::string IntersectPrintInfo::toString() const { + std::string result = "Key: "; + result += key->toString(); + return result; +} + +void Intersect::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* /*context*/) { + outKeyVector = resultSet->getValueVector(outputDataPos); + for (auto& dataInfo : intersectDataInfos) { + probeKeyVectors.push_back(resultSet->getValueVector(dataInfo.keyDataPos)); + std::vector columnIdxesToScanFrom; + std::vector vectorsToReadInto; + for (auto i = 0u; i < dataInfo.payloadsDataPos.size(); i++) { + auto vector = resultSet->getValueVector(dataInfo.payloadsDataPos[i]); + // Always skip the first two columns in the fTable: build key and intersect key. + // TODO(Guodong): Remove this assumption so that keys can be stored in any order. Change + // mapping logic too so that we don't need to maintain this order explicitly. + columnIdxesToScanFrom.push_back(i + 2); + vectorsToReadInto.push_back(vector.get()); + } + payloadColumnIdxesToScanFrom.push_back(columnIdxesToScanFrom); + payloadVectorsToScanInto.push_back(std::move(vectorsToReadInto)); + } + for (auto& sharedHT : sharedHTs) { + intersectSelVectors.push_back(std::make_unique(DEFAULT_VECTOR_CAPACITY)); + isIntersectListAFlatValue.push_back( + sharedHT->getHashTable()->getTableSchema()->getColumn(1)->isFlat()); + } +} + +void Intersect::probeHTs() { + std::vector> flatTuples(probeKeyVectors.size()); + hash_t hashVal = 0; + for (auto i = 0u; i < probeKeyVectors.size(); i++) { + KU_ASSERT(probeKeyVectors[i]->state->isFlat()); + probedFlatTuples[i].clear(); + if (sharedHTs[i]->getHashTable()->getNumEntries() == 0) { + continue; + } + auto key = + probeKeyVectors[i]->getValue(probeKeyVectors[i]->state->getSelVector()[0]); + function::Hash::operation(key, false, hashVal); + auto flatTuple = sharedHTs[i]->getHashTable()->getTupleForHash(hashVal); + while (flatTuple) { + if (*(nodeID_t*)flatTuple == key) { + probedFlatTuples[i].push_back(flatTuple); + } + flatTuple = *sharedHTs[i]->getHashTable()->getPrevTuple(flatTuple); + } + } +} + +void Intersect::twoWayIntersect(nodeID_t* leftNodeIDs, SelectionVector& lSelVector, + nodeID_t* rightNodeIDs, SelectionVector& rSelVector) { + KU_ASSERT(lSelVector.getSelSize() <= rSelVector.getSelSize()); + auto leftPositionBuffer = lSelVector.getMutableBuffer(); + auto rightPositionBuffer = rSelVector.getMutableBuffer(); + sel_t leftPosition = 0, rightPosition = 0; + uint64_t outputValuePosition = 0; + while (leftPosition < lSelVector.getSelSize() && rightPosition < rSelVector.getSelSize()) { + auto leftNodeID = leftNodeIDs[leftPosition]; + auto rightNodeID = rightNodeIDs[rightPosition]; + if (leftNodeID < rightNodeID) { + leftPosition++; + } else if (leftNodeID > rightNodeID) { + rightPosition++; + } else { + leftPositionBuffer[outputValuePosition] = leftPosition; + rightPositionBuffer[outputValuePosition] = rightPosition; + leftNodeIDs[outputValuePosition] = 
leftNodeID; + leftPosition++; + rightPosition++; + outputValuePosition++; + } + } + lSelVector.setToFiltered(outputValuePosition); + rSelVector.setToFiltered(outputValuePosition); +} + +static std::vector fetchListsToIntersectFromTuples( + const std::vector& tuples, const std::vector& isFlatValue) { + std::vector listsToIntersect(tuples.size()); + for (auto i = 0u; i < tuples.size(); i++) { + listsToIntersect[i] = + isFlatValue[i] ? overflow_value_t{1 /* numElements */, tuples[i] + sizeof(nodeID_t)} : + *(overflow_value_t*)(tuples[i] + sizeof(nodeID_t)); + } + return listsToIntersect; +} + +static std::vector swapSmallestListToFront(std::vector& lists) { + KU_ASSERT(lists.size() >= 2); + std::vector listIdxes(lists.size()); + iota(listIdxes.begin(), listIdxes.end(), 0); + uint32_t smallestListIdx = 0; + for (auto i = 1u; i < lists.size(); i++) { + if (lists[i].numElements < lists[smallestListIdx].numElements) { + smallestListIdx = i; + } + } + if (smallestListIdx != 0) { + std::swap(lists[smallestListIdx], lists[0]); + std::swap(listIdxes[smallestListIdx], listIdxes[0]); + } + return listIdxes; +} + +static void sliceSelVectors(const std::vector& selVectorsToSlice, + SelectionVector& slicer) { + for (auto selVec : selVectorsToSlice) { + for (auto i = 0u; i < slicer.getSelSize(); i++) { + auto pos = slicer[i]; + auto buffer = selVec->getMutableBuffer(); + buffer[i] = selVec->operator[](pos); + } + selVec->setToFiltered(slicer.getSelSize()); + } +} + +void Intersect::intersectLists(const std::vector& listsToIntersect) { + if (listsToIntersect[0].numElements == 0) { + outKeyVector->state->getSelVectorUnsafe().setSelSize(0); + return; + } + KU_ASSERT(listsToIntersect[0].numElements <= DEFAULT_VECTOR_CAPACITY); + memcpy(outKeyVector->getData(), listsToIntersect[0].value, + listsToIntersect[0].numElements * sizeof(nodeID_t)); + SelectionVector lSelVector(listsToIntersect[0].numElements); + lSelVector.setSelSize(listsToIntersect[0].numElements); + std::vector selVectorsForIntersectedLists; + intersectSelVectors[0]->setToUnfiltered(listsToIntersect[0].numElements); + selVectorsForIntersectedLists.push_back(intersectSelVectors[0].get()); + for (auto i = 0u; i < listsToIntersect.size() - 1; i++) { + intersectSelVectors[i + 1]->setToUnfiltered(listsToIntersect[i + 1].numElements); + twoWayIntersect((nodeID_t*)outKeyVector->getData(), lSelVector, + (nodeID_t*)listsToIntersect[i + 1].value, *intersectSelVectors[i + 1]); + // Here we need to slice all selVectors that have been previously intersected, as all these + // lists need to be selected synchronously to read payloads correctly. 
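+ // Example: intersecting the current result {2, 5, 9} with the next list {5, 9, 12} keeps + // only the positions of 5 and 9, and the earlier selVectors are re-sliced to those same + // surviving positions.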
+ sliceSelVectors(selVectorsForIntersectedLists, lSelVector); + lSelVector.setToUnfiltered(); + selVectorsForIntersectedLists.push_back(intersectSelVectors[i + 1].get()); + } + outKeyVector->state->getSelVectorUnsafe().setSelSize(lSelVector.getSelSize()); +} + +void Intersect::populatePayloads(const std::vector& tuples, + const std::vector& listIdxes) { + for (auto i = 0u; i < listIdxes.size(); i++) { + auto listIdx = listIdxes[i]; + sharedHTs[listIdx]->getHashTable()->getFactorizedTable()->lookup( + payloadVectorsToScanInto[listIdx], intersectSelVectors[i].get(), + payloadColumnIdxesToScanFrom[listIdx], tuples[listIdx]); + } +} + +bool Intersect::hasNextTuplesToIntersect() { + tupleIdxPerBuildSide[carryBuildSideIdx]++; + if (tupleIdxPerBuildSide[carryBuildSideIdx] == probedFlatTuples[carryBuildSideIdx].size()) { + if (carryBuildSideIdx == 0) { + return false; + } + tupleIdxPerBuildSide[carryBuildSideIdx] = 0; + carryBuildSideIdx--; + if (!hasNextTuplesToIntersect()) { + return false; + } + carryBuildSideIdx++; + } + return true; +} + +bool Intersect::getNextTuplesInternal(ExecutionContext* context) { + do { + while (carryBuildSideIdx == -1u) { + if (!children[0]->getNextTuple(context)) { + return false; + } + // For each build side, probe its HT and return a vector of matched flat tuples. + probeHTs(); + auto maxNumTuplesToIntersect = 1u; + for (auto i = 0u; i < getNumBuilds(); i++) { + maxNumTuplesToIntersect *= probedFlatTuples[i].size(); + } + if (maxNumTuplesToIntersect == 0) { + // Skip if any build side has no matches. + continue; + } + carryBuildSideIdx = getNumBuilds() - 1; + std::fill(tupleIdxPerBuildSide.begin(), tupleIdxPerBuildSide.end(), 0); + } + // Cartesian product of all flat tuples probed from all build sides. + // Notice: when there are large adjacency lists in the build side, which means the list is + // too large to fit in a single ValueVector, we end up chunking the list as multiple tuples + // in FTable. Thus, when performing the intersection, we need to perform cartesian product + // between all flat tuples probed from all build sides. 
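+ // carryBuildSideIdx and tupleIdxPerBuildSide act as a mixed-radix counter over the probed + // tuples, so every combination of chunked build-side tuples is intersected exactly once.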
+ std::vector<uint8_t*> flatTuplesToIntersect(getNumBuilds()); + for (auto i = 0u; i < getNumBuilds(); i++) { + flatTuplesToIntersect[i] = probedFlatTuples[i][tupleIdxPerBuildSide[i]]; + } + auto listsToIntersect = + fetchListsToIntersectFromTuples(flatTuplesToIntersect, isIntersectListAFlatValue); + auto listIdxes = swapSmallestListToFront(listsToIntersect); + intersectLists(listsToIntersect); + if (outKeyVector->state->getSelVector().getSelSize() != 0) { + populatePayloads(flatTuplesToIntersect, listIdxes); + } + if (!hasNextTuplesToIntersect()) { + carryBuildSideIdx = -1u; + } + } while (outKeyVector->state->getSelVector().getSelSize() == 0); + metrics->numOutputTuple.increase(outKeyVector->state->getSelVector().getSelSize()); + return true; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/limit.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/limit.cpp new file mode 100644 index 0000000000..4400480d2d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/limit.cpp @@ -0,0 +1,40 @@ +#include "processor/operator/limit.h" + +#include "common/metric.h" + +namespace lbug { +namespace processor { + +std::string LimitPrintInfo::toString() const { + return "Limit: " + std::to_string(limitNum); +} + +bool Limit::getNextTuplesInternal(ExecutionContext* context) { + // end of execution due to no more input + if (!children[0]->getNextTuple(context)) { + return false; + } + auto numTupleAvailable = resultSet->getNumTuples(dataChunksPosInScope); + auto numTupleProcessedBefore = counter->fetch_add(numTupleAvailable); + if (numTupleProcessedBefore + numTupleAvailable > limitNumber) { + int64_t numTupleToProcessInCurrentResultSet = limitNumber - numTupleProcessedBefore; + // end of execution because the limit has been reached + if (numTupleToProcessInCurrentResultSet <= 0) { + return false; + } else { + // If all dataChunks are flat, numTupleAvailable = 1 which means numTupleProcessedBefore + // = limitNumber. So execution is terminated in the above if statement.
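+ // Example: with LIMIT 100 and 80 tuples already counted globally, an unflat chunk of 2048 + // tuples is truncated by shrinking its selection vector to the remaining 20 positions.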
+ auto& dataChunkToSelect = resultSet->dataChunks[dataChunkToSelectPos]; + KU_ASSERT(!dataChunkToSelect->state->isFlat()); + dataChunkToSelect->state->getSelVectorUnsafe().setSelSize( + numTupleToProcessInCurrentResultSet); + metrics->numOutputTuple.increase(numTupleToProcessInCurrentResultSet); + } + } else { + metrics->numOutputTuple.increase(numTupleAvailable); + } + return true; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/macro/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/macro/CMakeLists.txt new file mode 100644 index 0000000000..527c46d20a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/macro/CMakeLists.txt @@ -0,0 +1,7 @@ +add_library(lbug_processor_operator_create_macro + OBJECT + create_macro.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/macro/create_macro.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/macro/create_macro.cpp new file mode 100644 index 0000000000..e92ee74c72 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/macro/create_macro.cpp @@ -0,0 +1,27 @@ +#include "processor/operator/macro/create_macro.h" + +#include "common/string_format.h" +#include "processor/execution_context.h" +#include "storage/buffer_manager/memory_manager.h" +#include "transaction/transaction.h" + +using namespace lbug::common; + +namespace lbug { +namespace processor { + +std::string CreateMacroPrintInfo::toString() const { + return macroName; +} + +void CreateMacro::executeInternal(ExecutionContext* context) { + auto clientContext = context->clientContext; + auto catalog = catalog::Catalog::Get(*clientContext); + auto transaction = transaction::Transaction::Get(*clientContext); + catalog->addScalarMacroFunction(transaction, info.macroName, info.macro->copy()); + appendMessage(stringFormat("Macro: {} has been created.", info.macroName), + storage::MemoryManager::Get(*clientContext)); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/multiplicity_reducer.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/multiplicity_reducer.cpp new file mode 100644 index 0000000000..36895f78d9 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/multiplicity_reducer.cpp @@ -0,0 +1,23 @@ +#include "processor/operator/multiplicity_reducer.h" + +namespace lbug { +namespace processor { + +bool MultiplicityReducer::getNextTuplesInternal(ExecutionContext* context) { + if (numRepeat == 0) { + restoreMultiplicity(); + if (!children[0]->getNextTuple(context)) { + return false; + } + saveMultiplicity(); + resultSet->multiplicity = 1; + } + numRepeat++; + if (numRepeat == prevMultiplicity) { + numRepeat = 0; + } + return true; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/order_by/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/order_by/CMakeLists.txt new file mode 100644 index 0000000000..3b3fce6548 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/order_by/CMakeLists.txt @@ -0,0 +1,15 @@ +add_library(lbug_processor_operator_order_by + OBJECT + key_block_merger.cpp + order_by.cpp + order_by_key_encoder.cpp + order_by_merge.cpp + order_by_scan.cpp + radix_sort.cpp + sort_state.cpp + top_k.cpp + top_k_scanner.cpp) + +set(ALL_OBJECT_FILES + 
${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/order_by/key_block_merger.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/order_by/key_block_merger.cpp new file mode 100644 index 0000000000..5714219cb5 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/order_by/key_block_merger.cpp @@ -0,0 +1,326 @@ +#include "processor/operator/order_by/key_block_merger.h" + +#include "common/system_config.h" + +using namespace lbug::common; +using namespace lbug::processor; +using namespace lbug::storage; + +namespace lbug { +namespace processor { + +static constexpr uint64_t DATA_BLOCK_SIZE = common::TEMP_PAGE_SIZE; + +MergedKeyBlocks::MergedKeyBlocks(uint32_t numBytesPerTuple, uint64_t numTuples, + MemoryManager* memoryManager) + : numBytesPerTuple{numBytesPerTuple}, + numTuplesPerBlock{(uint32_t)(DATA_BLOCK_SIZE / numBytesPerTuple)}, numTuples{numTuples}, + endTupleOffset{numTuplesPerBlock * numBytesPerTuple} { + auto numKeyBlocks = numTuples / numTuplesPerBlock + (numTuples % numTuplesPerBlock ? 1 : 0); + for (auto i = 0u; i < numKeyBlocks; i++) { + keyBlocks.emplace_back(std::make_shared(memoryManager, DATA_BLOCK_SIZE)); + } +} + +// This constructor is used to convert a keyBlock to a MergedKeyBlocks. +MergedKeyBlocks::MergedKeyBlocks(uint32_t numBytesPerTuple, std::shared_ptr keyBlock) + : numBytesPerTuple{numBytesPerTuple}, + numTuplesPerBlock{(uint32_t)(DATA_BLOCK_SIZE / numBytesPerTuple)}, + numTuples{keyBlock->numTuples}, endTupleOffset{numTuplesPerBlock * numBytesPerTuple} { + keyBlocks.emplace_back(std::move(keyBlock)); +} + +uint8_t* MergedKeyBlocks::getBlockEndTuplePtr(uint32_t blockIdx, uint64_t endTupleIdx, + uint32_t endTupleBlockIdx) const { + KU_ASSERT(blockIdx < keyBlocks.size()); + if (endTupleIdx == 0) { + return getKeyBlockBuffer(0); + } + return blockIdx == endTupleBlockIdx ? getTuple(endTupleIdx - 1) + numBytesPerTuple : + getKeyBlockBuffer(blockIdx) + endTupleOffset; +} + +BlockPtrInfo::BlockPtrInfo(uint64_t startTupleIdx, uint64_t endTupleIdx, MergedKeyBlocks* keyBlocks) + : keyBlocks{keyBlocks}, curBlockIdx{startTupleIdx / keyBlocks->getNumTuplesPerBlock()}, + endBlockIdx{endTupleIdx == 0 ? 0 : (endTupleIdx - 1) / keyBlocks->getNumTuplesPerBlock()}, + endTupleIdx{endTupleIdx} { + if (startTupleIdx == endTupleIdx) { + curTuplePtr = nullptr; + endTuplePtr = nullptr; + curBlockEndTuplePtr = nullptr; + } else { + curTuplePtr = keyBlocks->getTuple(startTupleIdx); + endTuplePtr = keyBlocks->getBlockEndTuplePtr(endBlockIdx, endTupleIdx, endBlockIdx); + curBlockEndTuplePtr = keyBlocks->getBlockEndTuplePtr(curBlockIdx, endTupleIdx, endBlockIdx); + } +} + +void BlockPtrInfo::updateTuplePtrIfNecessary() { + if (curTuplePtr == curBlockEndTuplePtr) { + curBlockIdx++; + if (curBlockIdx <= endBlockIdx) { + curTuplePtr = keyBlocks->getKeyBlockBuffer(curBlockIdx); + curBlockEndTuplePtr = + keyBlocks->getBlockEndTuplePtr(curBlockIdx, endTupleIdx, endBlockIdx); + } + } +} + +uint64_t KeyBlockMergeTask::findRightKeyBlockIdx(uint8_t* leftEndTuplePtr) const { + // Find a tuple in the right memory block such that: + // 1. The value of the current tuple is smaller than the value in leftEndTuple. + // 2. Either the value of next tuple is larger than the value in leftEndTuple or + // the current tuple is the last tuple in the right memory block. 
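+ // Binary search over [rightKeyBlockNextIdx, numTuples - 1]: return the last right-block tuple + // that still compares smaller than leftEndTuple, so the morsel's right range ends just after it.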
+ int64_t startIdx = rightKeyBlockNextIdx; + int64_t endIdx = rightKeyBlock->getNumTuples() - 1; + + while (startIdx <= endIdx) { + uint64_t curTupleIdx = (startIdx + endIdx) / 2; + uint8_t* curTuplePtr = rightKeyBlock->getTuple(curTupleIdx); + + if (keyBlockMerger.compareTuplePtr(leftEndTuplePtr, curTuplePtr)) { + if (curTupleIdx == rightKeyBlock->getNumTuples() - 1 || + !keyBlockMerger.compareTuplePtr(leftEndTuplePtr, + rightKeyBlock->getTuple(curTupleIdx + 1))) { + // If the current tuple is the last tuple or the value of next tuple is larger than + // the value of leftEndTuple, return the curTupleIdx. + return curTupleIdx; + } else { + startIdx = curTupleIdx + 1; + } + } else { + endIdx = curTupleIdx - 1; + } + } + // If such tuple doesn't exist, return -1. + return -1; +} + +std::unique_ptr KeyBlockMergeTask::getMorsel() { + // We grab a batch of tuples from the left memory block, then do a binary search on the + // right memory block to find the range of tuples to merge. + activeMorsels++; + if (rightKeyBlockNextIdx >= rightKeyBlock->getNumTuples()) { + // If there is no more tuples left in the right key block, + // just append all tuples in the left key block to the result key block. + auto keyBlockMergeMorsel = + std::make_unique(leftKeyBlockNextIdx, leftKeyBlock->getNumTuples(), + rightKeyBlock->getNumTuples(), rightKeyBlock->getNumTuples()); + leftKeyBlockNextIdx = leftKeyBlock->getNumTuples(); + return keyBlockMergeMorsel; + } + + auto leftKeyBlockStartIdx = leftKeyBlockNextIdx; + leftKeyBlockNextIdx += batch_size; + + if (leftKeyBlockNextIdx >= leftKeyBlock->getNumTuples()) { + // This is the last batch of tuples in the left key block to merge, so just merge it with + // remaining tuples of the right key block. + auto keyBlockMergeMorsel = std::make_unique(leftKeyBlockStartIdx, + std::min(leftKeyBlockNextIdx, leftKeyBlock->getNumTuples()), rightKeyBlockNextIdx, + rightKeyBlock->getNumTuples()); + rightKeyBlockNextIdx = rightKeyBlock->getNumTuples(); + return keyBlockMergeMorsel; + } else { + // Conduct a binary search to find the ending index in the right memory block. + auto leftEndIdxPtr = leftKeyBlock->getTuple(leftKeyBlockNextIdx - 1); + auto rightEndIdx = findRightKeyBlockIdx(leftEndIdxPtr); + + auto keyBlockMergeMorsel = std::make_unique(leftKeyBlockStartIdx, + std::min(leftKeyBlockNextIdx, leftKeyBlock->getNumTuples()), rightKeyBlockNextIdx, + rightEndIdx == (uint64_t)-1 ? 
rightKeyBlockNextIdx : ++rightEndIdx); + + if (rightEndIdx != (uint64_t)-1) { + rightKeyBlockNextIdx = rightEndIdx; + } + return keyBlockMergeMorsel; + } +} + +void KeyBlockMerger::mergeKeyBlocks(KeyBlockMergeMorsel& keyBlockMergeMorsel) const { + KU_ASSERT(keyBlockMergeMorsel.leftKeyBlockStartIdx < keyBlockMergeMorsel.leftKeyBlockEndIdx || + keyBlockMergeMorsel.rightKeyBlockStartIdx < keyBlockMergeMorsel.rightKeyBlockEndIdx); + + auto leftBlockPtrInfo = BlockPtrInfo(keyBlockMergeMorsel.leftKeyBlockStartIdx, + keyBlockMergeMorsel.leftKeyBlockEndIdx, + keyBlockMergeMorsel.keyBlockMergeTask->leftKeyBlock.get()); + + auto rightBlockPtrInfo = BlockPtrInfo(keyBlockMergeMorsel.rightKeyBlockStartIdx, + keyBlockMergeMorsel.rightKeyBlockEndIdx, + keyBlockMergeMorsel.keyBlockMergeTask->rightKeyBlock.get()); + + auto resultBlockPtrInfo = BlockPtrInfo(keyBlockMergeMorsel.leftKeyBlockStartIdx + + keyBlockMergeMorsel.rightKeyBlockStartIdx, + keyBlockMergeMorsel.leftKeyBlockEndIdx + keyBlockMergeMorsel.rightKeyBlockEndIdx, + keyBlockMergeMorsel.keyBlockMergeTask->resultKeyBlock.get()); + + while (leftBlockPtrInfo.hasMoreTuplesToRead() && rightBlockPtrInfo.hasMoreTuplesToRead()) { + uint64_t nextNumBytesToMerge = + std::min(std::min(leftBlockPtrInfo.getNumBytesLeftInCurBlock(), + rightBlockPtrInfo.getNumBytesLeftInCurBlock()), + resultBlockPtrInfo.getNumBytesLeftInCurBlock()); + for (auto i = 0u; i < nextNumBytesToMerge; i += numBytesPerTuple) { + if (compareTuplePtr(leftBlockPtrInfo.curTuplePtr, rightBlockPtrInfo.curTuplePtr)) { + memcpy(resultBlockPtrInfo.curTuplePtr, rightBlockPtrInfo.curTuplePtr, + numBytesPerTuple); + rightBlockPtrInfo.curTuplePtr += numBytesPerTuple; + } else { + memcpy(resultBlockPtrInfo.curTuplePtr, leftBlockPtrInfo.curTuplePtr, + numBytesPerTuple); + leftBlockPtrInfo.curTuplePtr += numBytesPerTuple; + } + resultBlockPtrInfo.curTuplePtr += numBytesPerTuple; + } + leftBlockPtrInfo.updateTuplePtrIfNecessary(); + rightBlockPtrInfo.updateTuplePtrIfNecessary(); + resultBlockPtrInfo.updateTuplePtrIfNecessary(); + } + + copyRemainingBlockDataToResult(rightBlockPtrInfo, resultBlockPtrInfo); + copyRemainingBlockDataToResult(leftBlockPtrInfo, resultBlockPtrInfo); +} + +// This function returns true if the value in the leftTuplePtr is larger than the value in the +// rightTuplePtr. +bool KeyBlockMerger::compareTuplePtrWithStringCol(uint8_t* leftTuplePtr, + uint8_t* rightTuplePtr) const { + // We can't simply use memcmp to compare tuples if there are string columns. + // We should only compare the binary strings starting from the last compared string column + // till the next string column. + uint64_t lastComparedBytes = 0; + for (auto& strKeyColInfo : strKeyColsInfo) { + auto result = memcmp(leftTuplePtr + lastComparedBytes, rightTuplePtr + lastComparedBytes, + strKeyColInfo.colOffsetInEncodedKeyBlock - lastComparedBytes + + strKeyColInfo.getEncodingSize()); + // If both sides are nulls, we can just continue to check the next string column. + auto leftStrColPtr = leftTuplePtr + strKeyColInfo.colOffsetInEncodedKeyBlock; + auto rightStrColPtr = rightTuplePtr + strKeyColInfo.colOffsetInEncodedKeyBlock; + if (OrderByKeyEncoder::isNullVal(leftStrColPtr, strKeyColInfo.isAscOrder) && + OrderByKeyEncoder::isNullVal(rightStrColPtr, strKeyColInfo.isAscOrder)) { + lastComparedBytes = + strKeyColInfo.colOffsetInEncodedKeyBlock + strKeyColInfo.getEncodingSize(); + continue; + } + + // If there is a tie, we need to compare the overflow ptr of strings values. 
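+ // Only a fixed-size prefix of each string (plus null and long/short flags) lives in the key + // block, so a prefix tie has to be resolved against the full strings stored in the backing + // factorized tables, located through the block idx/offset encoded at the end of the tuple.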
+ if (result == 0) { + // We do an optimization here to minimize the number of times that we fetch + // strings from factorizedTable. If both left and right strings are short string, + // they must equal to each other (since there are no other characters to compare for + // them). If one string is long string and the other string is short string, the + // long string must be greater than the short string. + bool isLeftStrLong = + OrderByKeyEncoder::isLongStr(leftStrColPtr, strKeyColInfo.isAscOrder); + bool isRightStrLong = + OrderByKeyEncoder::isLongStr(rightStrColPtr, strKeyColInfo.isAscOrder); + if (!isLeftStrLong && !isRightStrLong) { + continue; + } else if (isLeftStrLong && !isRightStrLong) { + return strKeyColInfo.isAscOrder; + } else if (!isLeftStrLong && isRightStrLong) { + return !strKeyColInfo.isAscOrder; + } + + auto leftTupleInfo = leftTuplePtr + numBytesToCompare; + auto rightTupleInfo = rightTuplePtr + numBytesToCompare; + auto leftBlockIdx = OrderByKeyEncoder::getEncodedFTBlockIdx(leftTupleInfo); + auto leftBlockOffset = OrderByKeyEncoder::getEncodedFTBlockOffset(leftTupleInfo); + auto rightBlockIdx = OrderByKeyEncoder::getEncodedFTBlockIdx(rightTupleInfo); + auto rightBlockOffset = OrderByKeyEncoder::getEncodedFTBlockOffset(rightTupleInfo); + + auto& leftFactorizedTable = + factorizedTables[OrderByKeyEncoder::getEncodedFTIdx(leftTupleInfo)]; + auto& rightFactorizedTable = + factorizedTables[OrderByKeyEncoder::getEncodedFTIdx(rightTupleInfo)]; + auto leftStr = leftFactorizedTable->getData(leftBlockIdx, leftBlockOffset, + strKeyColInfo.colOffsetInFT); + auto rightStr = rightFactorizedTable->getData(rightBlockIdx, + rightBlockOffset, strKeyColInfo.colOffsetInFT); + result = (leftStr == rightStr); + if (result) { + // If the tie can't be solved, we need to check the next string column. + lastComparedBytes = + strKeyColInfo.colOffsetInEncodedKeyBlock + strKeyColInfo.getEncodingSize(); + continue; + } + result = leftStr > rightStr; + return strKeyColInfo.isAscOrder == result; + } + return result > 0; + } + // The string tie can't be solved, just add the tuple in the leftMemBlock to + // resultMemBlock. + return false; +} + +void KeyBlockMerger::copyRemainingBlockDataToResult(BlockPtrInfo& blockToCopy, + BlockPtrInfo& resultBlock) const { + while (blockToCopy.curBlockIdx <= blockToCopy.endBlockIdx) { + uint64_t nextNumBytesToMerge = std::min(blockToCopy.getNumBytesLeftInCurBlock(), + resultBlock.getNumBytesLeftInCurBlock()); + for (auto i = 0u; i < nextNumBytesToMerge; i += numBytesPerTuple) { + memcpy(resultBlock.curTuplePtr, blockToCopy.curTuplePtr, numBytesPerTuple); + blockToCopy.curTuplePtr += numBytesPerTuple; + resultBlock.curTuplePtr += numBytesPerTuple; + } + blockToCopy.updateTuplePtrIfNecessary(); + resultBlock.updateTuplePtrIfNecessary(); + } +} + +std::unique_ptr KeyBlockMergeTaskDispatcher::getMorsel() { + if (isDoneMerge()) { + return nullptr; + } + std::lock_guard keyBlockMergeDispatcherLock{mtx}; + + if (!activeKeyBlockMergeTasks.empty() && activeKeyBlockMergeTasks.back()->hasMorselLeft()) { + // If there are morsels left in the lastMergeTask, just give it to the caller. + auto morsel = activeKeyBlockMergeTasks.back()->getMorsel(); + morsel->keyBlockMergeTask = activeKeyBlockMergeTasks.back(); + return morsel; + } else if (sortedKeyBlocks->size() > 1) { + // If there are no morsels left in the lastMergeTask, we just create a new merge task. 
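+ // Pop the two front sorted runs and allocate a result block large enough for both: one merge + // pass of the parallel merge sort over the per-thread sorted key blocks.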
+ auto leftKeyBlock = sortedKeyBlocks->front(); + sortedKeyBlocks->pop(); + auto rightKeyBlock = sortedKeyBlocks->front(); + sortedKeyBlocks->pop(); + auto resultKeyBlock = std::make_shared<MergedKeyBlocks>(leftKeyBlock->getNumBytesPerTuple(), + leftKeyBlock->getNumTuples() + rightKeyBlock->getNumTuples(), memoryManager); + auto newMergeTask = std::make_shared<KeyBlockMergeTask>(leftKeyBlock, rightKeyBlock, + resultKeyBlock, *keyBlockMerger); + activeKeyBlockMergeTasks.emplace_back(newMergeTask); + auto morsel = newMergeTask->getMorsel(); + morsel->keyBlockMergeTask = newMergeTask; + return morsel; + } else { + // No morsel can be given at this time; just wait for the ongoing merge + // task to finish. + return nullptr; + } +} + +void KeyBlockMergeTaskDispatcher::doneMorsel(std::unique_ptr<KeyBlockMergeMorsel> morsel) { + std::lock_guard keyBlockMergeDispatcherLock{mtx}; + // If the keyBlockMergeTask has no active or remaining morsels left, just remove it from + // the active keyBlockMergeTasks and add the result key block to the sortedKeyBlocks queue. + if ((--morsel->keyBlockMergeTask->activeMorsels) == 0 && + !morsel->keyBlockMergeTask->hasMorselLeft()) { + erase(activeKeyBlockMergeTasks, morsel->keyBlockMergeTask); + sortedKeyBlocks->emplace(morsel->keyBlockMergeTask->resultKeyBlock); + } +} + +void KeyBlockMergeTaskDispatcher::init(MemoryManager* memoryManager, + std::queue<std::shared_ptr<MergedKeyBlocks>>* sortedKeyBlocks, + std::vector factorizedTables, std::vector& strKeyColsInfo, + uint64_t numBytesPerTuple) { + KU_ASSERT(this->keyBlockMerger == nullptr); + this->memoryManager = memoryManager; + this->sortedKeyBlocks = sortedKeyBlocks; + this->keyBlockMerger = std::make_unique<KeyBlockMerger>(std::move(factorizedTables), + strKeyColsInfo, numBytesPerTuple); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/order_by/order_by.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/order_by/order_by.cpp new file mode 100644 index 0000000000..5a97ec526c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/order_by/order_by.cpp @@ -0,0 +1,46 @@ +#include "processor/operator/order_by/order_by.h" + +#include "binder/expression/expression_util.h" +#include "processor/execution_context.h" +#include "storage/buffer_manager/memory_manager.h" + +using namespace lbug::common; + +namespace lbug { +namespace processor { + +std::string OrderByPrintInfo::toString() const { + std::string result = "Order By: "; + result += binder::ExpressionUtil::toString(keys); + result += ", Expressions: "; + result += binder::ExpressionUtil::toString(payloads); + return result; +} + +void OrderBy::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) { + localState = SortLocalState(); + localState.init(info, *sharedState, storage::MemoryManager::Get(*context->clientContext)); + for (auto& dataPos : info.payloadsPos) { + payloadVectors.push_back(resultSet->getValueVector(dataPos).get()); + } + for (auto& dataPos : info.keysPos) { + orderByVectors.push_back(resultSet->getValueVector(dataPos).get()); + } +} + +void OrderBy::initGlobalStateInternal(ExecutionContext* /*context*/) { + sharedState->init(info); +} + +void OrderBy::executeInternal(ExecutionContext* context) { + // Append thread-local tuples.
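+ // Each input tuple is appended resultSet->multiplicity times so upstream tuple multiplicity + // is materialized before sorting; finalize() then hands the thread-local sorted run over to + // the shared sort state.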
+ while (children[0]->getNextTuple(context)) { + for (auto i = 0u; i < resultSet->multiplicity; i++) { + localState.append(orderByVectors, payloadVectors); + } + } + localState.finalize(*sharedState); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/order_by/order_by_key_encoder.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/order_by/order_by_key_encoder.cpp new file mode 100644 index 0000000000..572a234d6d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/order_by/order_by_key_encoder.cpp @@ -0,0 +1,413 @@ +#include "processor/operator/order_by/order_by_key_encoder.h" + +#include +#include + +#include "common/exception/runtime.h" +#include "common/string_format.h" +#include "common/utils.h" +#include "storage/storage_utils.h" + +using namespace lbug::common; +using namespace lbug::storage; + +namespace lbug { +namespace processor { +static constexpr uint64_t DATA_BLOCK_SIZE = common::TEMP_PAGE_SIZE; + +OrderByKeyEncoder::OrderByKeyEncoder(const OrderByDataInfo& orderByDataInfo, + MemoryManager* memoryManager, uint8_t ftIdx, uint32_t numTuplesPerBlockInFT, + uint32_t numBytesPerTuple) + : memoryManager{memoryManager}, isAscOrder{orderByDataInfo.isAscOrder}, + numBytesPerTuple{numBytesPerTuple}, ftIdx{ftIdx}, + numTuplesPerBlockInFT{numTuplesPerBlockInFT}, swapBytes{isLittleEndian()} { + if (numTuplesPerBlockInFT > MAX_FT_BLOCK_OFFSET) { + throw RuntimeException( + "The number of tuples per block of factorizedTable exceeds the maximum blockOffset!"); + } + keyBlocks.emplace_back(std::make_unique(memoryManager, DATA_BLOCK_SIZE)); + KU_ASSERT(this->numBytesPerTuple == getNumBytesPerTuple()); + maxNumTuplesPerBlock = DATA_BLOCK_SIZE / numBytesPerTuple; + if (maxNumTuplesPerBlock <= 0) { + throw RuntimeException( + stringFormat("TupleSize({} bytes) is larger than the LARGE_PAGE_SIZE({} bytes)", + numBytesPerTuple, DATA_BLOCK_SIZE)); + } + encodeFunctions.reserve(orderByDataInfo.keysPos.size()); + for (auto& type : orderByDataInfo.keyTypes) { + encode_function_t encodeFunction; + getEncodingFunction(type.getPhysicalType(), encodeFunction); + encodeFunctions.push_back(std::move(encodeFunction)); + } +} + +void OrderByKeyEncoder::encodeKeys(const std::vector& orderByKeys) { + uint32_t numEntries = orderByKeys[0]->state->getSelVector().getSelSize(); + uint32_t encodedTuples = 0; + while (numEntries > 0) { + allocateMemoryIfFull(); + uint32_t numEntriesToEncode = + std::min(numEntries, maxNumTuplesPerBlock - getNumTuplesInCurBlock()); + auto tuplePtr = + keyBlocks.back()->getData() + keyBlocks.back()->numTuples * numBytesPerTuple; + uint32_t tuplePtrOffset = 0; + for (auto keyColIdx = 0u; keyColIdx < orderByKeys.size(); keyColIdx++) { + encodeVector(orderByKeys[keyColIdx], tuplePtr + tuplePtrOffset, encodedTuples, + numEntriesToEncode, keyColIdx); + tuplePtrOffset += getEncodingSize(orderByKeys[keyColIdx]->dataType); + } + encodeFTIdx(numEntriesToEncode, tuplePtr + tuplePtrOffset); + encodedTuples += numEntriesToEncode; + keyBlocks.back()->numTuples += numEntriesToEncode; + numEntries -= numEntriesToEncode; + } +} + +uint32_t OrderByKeyEncoder::getNumBytesPerTuple(const std::vector& keyVectors) { + uint32_t result = 0u; + for (auto& vector : keyVectors) { + result += getEncodingSize(vector->dataType); + } + result += 8; + return result; +} + +uint32_t OrderByKeyEncoder::getEncodingSize(const LogicalType& dataType) { + // Add one more byte for null flag. 
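+    // Resulting widths, for intuition (using the 12-byte prefix mentioned below):
+    //   INT64  -> 1 (null flag) + 8 (value)                    =  9 bytes
+    //   STRING -> 1 (null flag) + 12 (prefix) + 1 (long/short) = 14 bytes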
+ switch (dataType.getPhysicalType()) { + case PhysicalTypeID::STRING: + // 1 byte for null flag + 1 byte to indicate long/short string + 12 bytes for string prefix + return 2 + ku_string_t::SHORT_STR_LENGTH; + default: + return 1 + storage::StorageUtils::getDataTypeSize(dataType); + } +} + +void OrderByKeyEncoder::flipBytesIfNecessary(uint32_t keyColIdx, uint8_t* tuplePtr, + uint32_t numEntriesToEncode, LogicalType& type) { + if (!isAscOrder[keyColIdx]) { + auto encodingSize = getEncodingSize(type); + // If the current column is in desc order, flip all bytes. + for (auto i = 0u; i < numEntriesToEncode; i++) { + for (auto byte = 0u; byte < encodingSize; ++byte) { + *(tuplePtr + byte) = ~*(tuplePtr + byte); + } + tuplePtr += numBytesPerTuple; + } + } +} + +void OrderByKeyEncoder::encodeFlatVector(ValueVector* vector, uint8_t* tuplePtr, + uint32_t keyColIdx) { + auto pos = vector->state->getSelVector()[0]; + if (vector->isNull(pos)) { + for (auto j = 0u; j < getEncodingSize(vector->dataType); j++) { + *(tuplePtr + j) = UINT8_MAX; + } + } else { + *tuplePtr = 0; + encodeFunctions[keyColIdx](vector->getData() + pos * vector->getNumBytesPerValue(), + tuplePtr + 1, swapBytes); + } +} + +void OrderByKeyEncoder::encodeUnflatVector(ValueVector* vector, uint8_t* tuplePtr, + uint32_t encodedTuples, uint32_t numEntriesToEncode, uint32_t keyColIdx) { + if (vector->state->getSelVector().isUnfiltered()) { + auto value = vector->getData() + encodedTuples * vector->getNumBytesPerValue(); + if (vector->hasNoNullsGuarantee()) { + for (auto i = 0u; i < numEntriesToEncode; i++) { + *tuplePtr = 0; + encodeFunctions[keyColIdx](value, tuplePtr + 1, swapBytes); + tuplePtr += numBytesPerTuple; + value += vector->getNumBytesPerValue(); + } + } else { + for (auto i = 0u; i < numEntriesToEncode; i++) { + if (vector->isNull(encodedTuples + i)) { + for (auto j = 0u; j < getEncodingSize(vector->dataType); j++) { + *(tuplePtr + j) = UINT8_MAX; + } + } else { + *tuplePtr = 0; + encodeFunctions[keyColIdx](value, tuplePtr + 1, swapBytes); + } + tuplePtr += numBytesPerTuple; + value += vector->getNumBytesPerValue(); + } + } + } else { + if (vector->hasNoNullsGuarantee()) { + for (auto i = 0u; i < numEntriesToEncode; i++) { + *tuplePtr = 0; + encodeFunctions[keyColIdx](vector->getData() + + vector->state->getSelVector()[i + encodedTuples] * + vector->getNumBytesPerValue(), + tuplePtr + 1, swapBytes); + tuplePtr += numBytesPerTuple; + } + } else { + for (auto i = 0u; i < numEntriesToEncode; i++) { + auto pos = vector->state->getSelVector()[i + encodedTuples]; + if (vector->isNull(pos)) { + for (auto j = 0u; j < getEncodingSize(vector->dataType); j++) { + *(tuplePtr + j) = UINT8_MAX; + } + } else { + *tuplePtr = 0; + encodeFunctions[keyColIdx](vector->getData() + + pos * vector->getNumBytesPerValue(), + tuplePtr + 1, swapBytes); + } + tuplePtr += numBytesPerTuple; + } + } + } +} + +void OrderByKeyEncoder::encodeVector(ValueVector* vector, uint8_t* tuplePtr, uint32_t encodedTuples, + uint32_t numEntriesToEncode, uint32_t keyColIdx) { + if (vector->state->isFlat()) { + encodeFlatVector(vector, tuplePtr, keyColIdx); + } else { + encodeUnflatVector(vector, tuplePtr, encodedTuples, numEntriesToEncode, keyColIdx); + } + flipBytesIfNecessary(keyColIdx, tuplePtr, numEntriesToEncode, vector->dataType); +} + +void OrderByKeyEncoder::encodeFTIdx(uint32_t numEntriesToEncode, uint8_t* tupleInfoPtr) { + uint32_t numUpdatedFTInfoEntries = 0; + while (numUpdatedFTInfoEntries < numEntriesToEncode) { + auto nextBatchOfEntries = 
std::min(numEntriesToEncode - numUpdatedFTInfoEntries, + numTuplesPerBlockInFT - ftBlockOffset); + for (auto i = 0u; i < nextBatchOfEntries; i++) { + *(uint32_t*)tupleInfoPtr = ftBlockIdx; + *(uint32_t*)(tupleInfoPtr + 4) = ftBlockOffset; + *(uint8_t*)(tupleInfoPtr + 7) = ftIdx; + tupleInfoPtr += numBytesPerTuple; + ftBlockOffset++; + } + numUpdatedFTInfoEntries += nextBatchOfEntries; + if (ftBlockOffset == numTuplesPerBlockInFT) { + ftBlockIdx++; + ftBlockOffset = 0; + } + } +} + +void OrderByKeyEncoder::allocateMemoryIfFull() { + if (getNumTuplesInCurBlock() == maxNumTuplesPerBlock) { + keyBlocks.emplace_back(std::make_shared(memoryManager, DATA_BLOCK_SIZE)); + } +} + +void OrderByKeyEncoder::getEncodingFunction(PhysicalTypeID physicalType, encode_function_t& func) { + switch (physicalType) { + case PhysicalTypeID::BOOL: { + func = encodeTemplate; + return; + } + case PhysicalTypeID::INT64: { + func = encodeTemplate; + return; + } + case PhysicalTypeID::INT32: { + func = encodeTemplate; + return; + } + case PhysicalTypeID::INT16: { + func = encodeTemplate; + return; + } + case PhysicalTypeID::INT8: { + func = encodeTemplate; + return; + } + case PhysicalTypeID::UINT64: { + func = encodeTemplate; + return; + } + case PhysicalTypeID::UINT32: { + func = encodeTemplate; + return; + } + case PhysicalTypeID::UINT16: { + func = encodeTemplate; + return; + } + case PhysicalTypeID::UINT8: { + func = encodeTemplate; + return; + } + case PhysicalTypeID::INT128: { + func = encodeTemplate; + return; + } + case PhysicalTypeID::DOUBLE: { + func = encodeTemplate; + return; + } + case PhysicalTypeID::FLOAT: { + func = encodeTemplate; + return; + } + case PhysicalTypeID::STRING: { + func = encodeTemplate; + return; + } + case PhysicalTypeID::INTERVAL: { + func = encodeTemplate; + return; + } + case PhysicalTypeID::UINT128: { + func = encodeTemplate; + return; + } + default: + KU_UNREACHABLE; + } +} + +template<> +void OrderByKeyEncoder::encodeData(int8_t data, uint8_t* resultPtr, bool /*swapBytes*/) { + memcpy(resultPtr, (void*)&data, sizeof(data)); + resultPtr[0] = flipSign(resultPtr[0]); +} + +template<> +void OrderByKeyEncoder::encodeData(int16_t data, uint8_t* resultPtr, bool swapBytes) { + if (swapBytes) { + data = BSWAP16(data); + } + memcpy(resultPtr, (void*)&data, sizeof(data)); + resultPtr[0] = flipSign(resultPtr[0]); +} + +template<> +void OrderByKeyEncoder::encodeData(int32_t data, uint8_t* resultPtr, bool swapBytes) { + if (swapBytes) { + data = BSWAP32(data); + } + memcpy(resultPtr, (void*)&data, sizeof(data)); + resultPtr[0] = flipSign(resultPtr[0]); +} + +template<> +void OrderByKeyEncoder::encodeData(int64_t data, uint8_t* resultPtr, bool swapBytes) { + if (swapBytes) { + data = BSWAP64(data); + } + memcpy(resultPtr, (void*)&data, sizeof(data)); + resultPtr[0] = flipSign(resultPtr[0]); +} + +template<> +void OrderByKeyEncoder::encodeData(uint8_t data, uint8_t* resultPtr, bool /*swapBytes*/) { + memcpy(resultPtr, (void*)&data, sizeof(data)); +} + +template<> +void OrderByKeyEncoder::encodeData(uint16_t data, uint8_t* resultPtr, bool swapBytes) { + if (swapBytes) { + data = BSWAP16(data); + } + memcpy(resultPtr, (void*)&data, sizeof(data)); +} + +template<> +void OrderByKeyEncoder::encodeData(uint32_t data, uint8_t* resultPtr, bool swapBytes) { + if (swapBytes) { + data = BSWAP32(data); + } + memcpy(resultPtr, (void*)&data, sizeof(data)); +} + +template<> +void OrderByKeyEncoder::encodeData(uint64_t data, uint8_t* resultPtr, bool swapBytes) { + if (swapBytes) { + data = BSWAP64(data); + } + 
memcpy(resultPtr, (void*)&data, sizeof(data)); +} + +template<> +void OrderByKeyEncoder::encodeData(common::int128_t data, uint8_t* resultPtr, bool swapBytes) { + encodeData(data.high, resultPtr, swapBytes); + encodeData(data.low, resultPtr + sizeof(data.high), swapBytes); +} + +template<> +void OrderByKeyEncoder::encodeData(common::uint128_t data, uint8_t* resultPtr, bool swapBytes) { + encodeData(data.high, resultPtr, swapBytes); + encodeData(data.low, resultPtr + sizeof(data.high), swapBytes); +} + +template<> +void OrderByKeyEncoder::encodeData(bool data, uint8_t* resultPtr, bool /*swapBytes*/) { + uint8_t val = data ? 1 : 0; + memcpy(resultPtr, (void*)&val, sizeof(data)); +} + +template<> +void OrderByKeyEncoder::encodeData(double data, uint8_t* resultPtr, bool swapBytes) { + memcpy(resultPtr, &data, sizeof(data)); + uint64_t* dataBytes = (uint64_t*)resultPtr; + if (swapBytes) { + *dataBytes = BSWAP64(*dataBytes); + } + if (data < (double)0) { + *dataBytes = ~*dataBytes; + } else { + resultPtr[0] = flipSign(resultPtr[0]); + } +} + +template<> +void OrderByKeyEncoder::encodeData(date_t data, uint8_t* resultPtr, bool swapBytes) { + encodeData(data.days, resultPtr, swapBytes); +} + +template<> +void OrderByKeyEncoder::encodeData(timestamp_t data, uint8_t* resultPtr, bool swapBytes) { + encodeData(data.value, resultPtr, swapBytes); +} + +template<> +void OrderByKeyEncoder::encodeData(interval_t data, uint8_t* resultPtr, bool swapBytes) { + int64_t months = 0, days = 0, micros = 0; + Interval::normalizeIntervalEntries(data, months, days, micros); + encodeData((int32_t)months, resultPtr, swapBytes); + resultPtr += sizeof(data.months); + encodeData((int32_t)days, resultPtr, swapBytes); + resultPtr += sizeof(data.days); + encodeData(micros, resultPtr, swapBytes); +} + +template<> +void OrderByKeyEncoder::encodeData(ku_string_t data, uint8_t* resultPtr, bool /*swapBytes*/) { + // Only encode the prefix of ku_string. + memcpy(resultPtr, (void*)data.getAsString().c_str(), + std::min((uint32_t)ku_string_t::SHORT_STR_LENGTH, data.len)); + if (ku_string_t::isShortString(data.len)) { + memset(resultPtr + data.len, '\0', ku_string_t::SHORT_STR_LENGTH + 1 - data.len); + } else { + resultPtr[12] = UINT8_MAX; + } +} + +template<> +void OrderByKeyEncoder::encodeData(float data, uint8_t* resultPtr, bool swapBytes) { + memcpy(resultPtr, &data, sizeof(data)); + uint32_t* dataBytes = (uint32_t*)resultPtr; + if (swapBytes) { + *dataBytes = BSWAP32(*dataBytes); + } + if (data < (float)0) { + *dataBytes = ~*dataBytes; + } else { + resultPtr[0] = flipSign(resultPtr[0]); + } +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/order_by/order_by_merge.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/order_by/order_by_merge.cpp new file mode 100644 index 0000000000..592dba8e2c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/order_by/order_by_merge.cpp @@ -0,0 +1,42 @@ +#include "processor/operator/order_by/order_by_merge.h" + +#include + +#include "common/constants.h" +#include "processor/execution_context.h" +#include "storage/buffer_manager/memory_manager.h" + +using namespace lbug::common; + +namespace lbug { +namespace processor { + +void OrderByMerge::initLocalStateInternal(ResultSet* /*resultSet*/, ExecutionContext* /*context*/) { + // OrderByMerge is the only sink operator in a pipeline and only modifies the + // sharedState by merging sortedKeyBlocks, So we don't need to initialize the resultSet. 
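+    // Per-thread setup: every thread owns its own KeyBlockMerger over the shared payload tables,
+    // while the shared dispatcher (used in executeInternal below) hands out merge morsels that
+    // describe which pair of sorted runs to merge next.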
+ localMerger = make_unique(sharedState->getPayloadTables(), + sharedState->getStrKeyColInfo(), sharedState->getNumBytesPerTuple()); +} + +void OrderByMerge::executeInternal(ExecutionContext* /*context*/) { + while (!sharedDispatcher->isDoneMerge()) { + auto keyBlockMergeMorsel = sharedDispatcher->getMorsel(); + if (keyBlockMergeMorsel == nullptr) { + std::this_thread::sleep_for( + std::chrono::microseconds(THREAD_SLEEP_TIME_WHEN_WAITING_IN_MICROS)); + continue; + } + localMerger->mergeKeyBlocks(*keyBlockMergeMorsel); + sharedDispatcher->doneMorsel(std::move(keyBlockMergeMorsel)); + } +} + +void OrderByMerge::initGlobalStateInternal(ExecutionContext* context) { + // TODO(Ziyi): directly feed sharedState to merger and dispatcher. + sharedDispatcher->init(storage::MemoryManager::Get(*context->clientContext), + sharedState->getSortedKeyBlocks(), sharedState->getPayloadTables(), + sharedState->getStrKeyColInfo(), sharedState->getNumBytesPerTuple()); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/order_by/order_by_scan.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/order_by/order_by_scan.cpp new file mode 100644 index 0000000000..9f0cff3b97 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/order_by/order_by_scan.cpp @@ -0,0 +1,45 @@ +#include "processor/operator/order_by/order_by_scan.h" + +#include "common/metric.h" + +using namespace lbug::common; + +namespace lbug { +namespace processor { + +void OrderByScanLocalState::init(std::vector& outVectorPos, SortSharedState& sharedState, + ResultSet& resultSet) { + for (auto& dataPos : outVectorPos) { + vectorsToRead.push_back(resultSet.getValueVector(dataPos).get()); + } + payloadScanner = std::make_unique(sharedState.getMergedKeyBlock(), + sharedState.getPayloadTables()); + numTuples = 0; + for (auto& table : sharedState.getPayloadTables()) { + numTuples += table->getNumTuples(); + } + numTuplesRead = 0; +} + +void OrderByScan::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* /*context*/) { + localState->init(outVectorPos, *sharedState, *resultSet); +} + +bool OrderByScan::getNextTuplesInternal(ExecutionContext* /*context*/) { + // If there is no more tuples to read, just return false. 
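+    // scan() materializes the next batch of payload tuples in sorted order into the output
+    // vectors registered in init() and returns how many tuples it produced; zero means the
+    // merged key block has been fully consumed.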
+ auto numTuplesRead = localState->scan(); + metrics->numOutputTuple.increase(numTuplesRead); + return numTuplesRead != 0; +} + +double OrderByScan::getProgress(ExecutionContext* /*context*/) const { + if (localState->numTuples == 0) { + return 0.0; + } else if (localState->numTuplesRead == localState->numTuples) { + return 1.0; + } + return static_cast(localState->numTuplesRead) / localState->numTuples; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/order_by/radix_sort.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/order_by/radix_sort.cpp new file mode 100644 index 0000000000..d05d6e8cb3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/order_by/radix_sort.cpp @@ -0,0 +1,283 @@ +#include "processor/operator/order_by/radix_sort.h" + +#include + +#include "common/system_config.h" +#include "function/comparison/comparison_functions.h" + +using namespace lbug::common; + +namespace lbug { +namespace processor { + +static constexpr uint16_t COUNTING_ARRAY_SIZE = 256; +static constexpr uint64_t DATA_BLOCK_SIZE = common::TEMP_PAGE_SIZE; + +RadixSort::RadixSort(storage::MemoryManager* memoryManager, FactorizedTable& factorizedTable, + OrderByKeyEncoder& orderByKeyEncoder, std::vector strKeyColsInfo) + : tmpSortingResultBlock{std::make_unique(memoryManager, DATA_BLOCK_SIZE)}, + tmpTuplePtrSortingBlock{std::make_unique(memoryManager, DATA_BLOCK_SIZE)}, + factorizedTable{factorizedTable}, strKeyColsInfo{std::move(strKeyColsInfo)}, + numBytesPerTuple{orderByKeyEncoder.getNumBytesPerTuple()}, + numBytesToRadixSort{numBytesPerTuple - 8} {} + +void RadixSort::sortSingleKeyBlock(const DataBlock& keyBlock) { + auto numBytesSorted = 0ul; + auto numTuplesInKeyBlock = keyBlock.numTuples; + std::queue ties; + // We need to sort the whole keyBlock for the first radix sort, so just mark all tuples as a + // tie. + ties.push(TieRange{0, numTuplesInKeyBlock - 1}); + for (auto i = 0u; i < strKeyColsInfo.size(); i++) { + const auto numBytesToSort = strKeyColsInfo[i].colOffsetInEncodedKeyBlock - numBytesSorted + + strKeyColsInfo[i].getEncodingSize(); + const auto numOfTies = ties.size(); + for (auto j = 0u; j < numOfTies; j++) { + auto keyBlockTie = ties.front(); + ties.pop(); + radixSort(keyBlock.getData() + keyBlockTie.startingTupleIdx * numBytesPerTuple, + keyBlockTie.getNumTuples(), numBytesSorted, numBytesToSort); + + auto newTiesInKeyBlock = + findTies(keyBlock.getData() + keyBlockTie.startingTupleIdx * numBytesPerTuple + + numBytesSorted, + keyBlockTie.getNumTuples(), numBytesToSort, keyBlockTie.startingTupleIdx); + for (auto& newTieInKeyBlock : newTiesInKeyBlock) { + solveStringTies(newTieInKeyBlock, + keyBlock.getData() + newTieInKeyBlock.startingTupleIdx * numBytesPerTuple, ties, + strKeyColsInfo[i]); + } + } + if (ties.empty()) { + return; + } + numBytesSorted += numBytesToSort; + } + + if (numBytesSorted < numBytesPerTuple) { + while (!ties.empty()) { + auto tie = ties.front(); + ties.pop(); + radixSort(keyBlock.getData() + tie.startingTupleIdx * numBytesPerTuple, + tie.getNumTuples(), numBytesSorted, numBytesToRadixSort - numBytesSorted); + } + } +} + +void RadixSort::radixSort(uint8_t* keyBlockPtr, uint32_t numTuplesToSort, uint32_t numBytesSorted, + uint32_t numBytesToSort) { + // We use radixSortLSD which sorts from the least significant byte to the most significant byte. 
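+    // Worked example, for intuition only: with numBytesToSort == 2 and key bytes
+    //   t0 = {0x01, 0xFF}, t1 = {0x01, 0x00}, t2 = {0x00, 0xAA},
+    // pass 1 counting-sorts on the least significant byte, giving t1, t2, t0, and pass 2
+    // (stable) on the most significant byte gives t2, t1, t0, i.e. the memcmp order
+    // {0x00,0xAA} < {0x01,0x00} < {0x01,0xFF}.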
+ auto tmpKeyBlockPtr = tmpSortingResultBlock->getData(); + keyBlockPtr += numBytesSorted; + tmpKeyBlockPtr += numBytesSorted; + uint32_t count[COUNTING_ARRAY_SIZE]; + auto isInTmpBlock = false; + for (auto curByteIdx = 1ul; curByteIdx <= numBytesToSort; curByteIdx++) { + memset(count, 0, COUNTING_ARRAY_SIZE * sizeof(uint32_t)); + auto sourcePtr = isInTmpBlock ? tmpKeyBlockPtr : keyBlockPtr; + auto targetPtr = isInTmpBlock ? keyBlockPtr : tmpKeyBlockPtr; + auto curByteOffset = numBytesToSort - curByteIdx; + auto sortBytePtr = sourcePtr + curByteOffset; + // counting sort + for (auto j = 0ul; j < numTuplesToSort; j++) { + count[*sortBytePtr]++; + sortBytePtr += numBytesPerTuple; + } + auto maxCounter = count[0]; + for (auto val = 1ul; val < COUNTING_ARRAY_SIZE; val++) { + maxCounter = std::max(count[val], maxCounter); + count[val] = count[val] + count[val - 1]; + } + // If all bytes have the same value (tie), continue on the next byte. + if (maxCounter == numTuplesToSort) { + continue; + } + // Reorder the data based on the count array. + auto sourceTuplePtr = sourcePtr + (numTuplesToSort - 1) * numBytesPerTuple; + for (auto j = 0ul; j < numTuplesToSort; j++) { + auto targetTupleNum = --count[*(sourceTuplePtr + curByteOffset)]; + memcpy(targetPtr + targetTupleNum * numBytesPerTuple - numBytesSorted, + sourceTuplePtr - numBytesSorted, numBytesPerTuple); + sourceTuplePtr -= numBytesPerTuple; + } + isInTmpBlock = !isInTmpBlock; + } + // If the data is in the tmp block, copy the data from tmp block back. + if (isInTmpBlock) { + memcpy(keyBlockPtr - numBytesSorted, tmpKeyBlockPtr - numBytesSorted, + numTuplesToSort * numBytesPerTuple); + } +} + +std::vector RadixSort::findTies(uint8_t* keyBlockPtr, uint32_t numTuplesToFindTies, + uint32_t numBytesToSort, uint32_t baseTupleIdx) const { + std::vector newTiesInKeyBlock; + auto iTuplePtr = keyBlockPtr; + for (auto i = 0u; i < numTuplesToFindTies - 1; i++) { + auto j = i + 1; + auto jTuplePtr = iTuplePtr + numBytesPerTuple; + for (; j < numTuplesToFindTies; j++) { + if (memcmp(iTuplePtr, jTuplePtr, numBytesToSort) != 0) { + break; + } + jTuplePtr += numBytesPerTuple; + } + j--; + if (i != j) { + newTiesInKeyBlock.emplace_back(TieRange(i + baseTupleIdx, j + baseTupleIdx)); + } + iTuplePtr = jTuplePtr; + i = j; + } + return newTiesInKeyBlock; +} + +void RadixSort::fillTmpTuplePtrSortingBlock(TieRange& keyBlockTie, uint8_t* keyBlockPtr) { + auto tmpTuplePtrSortingBlockPtr = (uint8_t**)tmpTuplePtrSortingBlock->getData(); + for (auto i = 0ul; i < keyBlockTie.getNumTuples(); i++) { + tmpTuplePtrSortingBlockPtr[i] = keyBlockPtr; + keyBlockPtr += numBytesPerTuple; + } +} + +void RadixSort::reOrderKeyBlock(TieRange& keyBlockTie, uint8_t* keyBlockPtr) { + auto tmpTuplePtrSortingBlockPtr = (uint8_t**)tmpTuplePtrSortingBlock->getData(); + auto tmpKeyBlockPtr = tmpSortingResultBlock->getData(); + for (auto i = 0ul; i < keyBlockTie.getNumTuples(); i++) { + memcpy(tmpKeyBlockPtr, tmpTuplePtrSortingBlockPtr[i], numBytesPerTuple); + tmpKeyBlockPtr += numBytesPerTuple; + } + memcpy(keyBlockPtr, tmpSortingResultBlock->getData(), + keyBlockTie.getNumTuples() * numBytesPerTuple); +} + +template +void RadixSort::findStringTies(TieRange& keyBlockTie, uint8_t* keyBlockPtr, + std::queue& ties, StrKeyColInfo& keyColInfo) { + auto iTuplePtr = keyBlockPtr; + for (auto i = keyBlockTie.startingTupleIdx; i < keyBlockTie.endingTupleIdx; i++) { + bool isIValNull = OrderByKeyEncoder::isNullVal( + iTuplePtr + keyColInfo.colOffsetInEncodedKeyBlock, keyColInfo.isAscOrder); + // This 
variable will only be used when the current column is a string column. Otherwise, + // we just set this variable to false. + bool isIStringLong = OrderByKeyEncoder::isLongStr( + iTuplePtr + keyColInfo.colOffsetInEncodedKeyBlock, keyColInfo.isAscOrder); + TYPE iValue = + isIValNull ? + TYPE() : + factorizedTable.getData( + OrderByKeyEncoder::getEncodedFTBlockIdx(iTuplePtr + numBytesToRadixSort), + OrderByKeyEncoder::getEncodedFTBlockOffset(iTuplePtr + numBytesToRadixSort), + keyColInfo.colOffsetInFT); + auto j = i + 1; + auto jTuplePtr = iTuplePtr + numBytesPerTuple; + for (; j <= keyBlockTie.endingTupleIdx; j++) { + auto jTupleInfoPtr = jTuplePtr + numBytesToRadixSort; + bool isJValNull = OrderByKeyEncoder::isNullVal( + jTuplePtr + keyColInfo.colOffsetInEncodedKeyBlock, keyColInfo.isAscOrder); + if (isIValNull && isJValNull) { + // If the left value and the right value are nulls, we can just continue on + // the next tuple. + jTupleInfoPtr += numBytesPerTuple; + continue; + } else if (isIValNull || isJValNull) { + // If only one value is null, we can just conclude that those two values are + // not equal. + break; + } + if constexpr (std::is_same::value) { + // We do an optimization here to minimize the number of times that we fetch + // tuples from factorizedTable. If both left and right string are short, they + // must equal to each other (since they have the same prefix). If one string is + // short and the other string is long, then they must not equal to each other. + bool isJStringLong = OrderByKeyEncoder::isLongStr( + jTuplePtr + keyColInfo.colOffsetInEncodedKeyBlock, keyColInfo.isAscOrder); + if (!isIStringLong && !isJStringLong) { + jTupleInfoPtr += numBytesPerTuple; + continue; + } else if (isIStringLong != isJStringLong) { + break; + } + } + + uint8_t result = UINT8_MAX; + function::NotEquals::operation(iValue, + factorizedTable.getData( + OrderByKeyEncoder::getEncodedFTBlockIdx(jTupleInfoPtr), + OrderByKeyEncoder::getEncodedFTBlockOffset(jTupleInfoPtr), + keyColInfo.colOffsetInFT), + result, nullptr /* leftVector */, nullptr /* rightVector */); + if (result) { + break; + } + jTuplePtr += numBytesPerTuple; + } + j--; + if (i != j) { + ties.push(TieRange(i, j)); + } + i = j; + iTuplePtr = jTuplePtr; + } +} + +void RadixSort::solveStringTies(TieRange& keyBlockTie, uint8_t* keyBlockPtr, + std::queue& ties, StrKeyColInfo& keyColInfo) { + fillTmpTuplePtrSortingBlock(keyBlockTie, keyBlockPtr); + auto tmpTuplePtrSortingBlockPtr = (uint8_t**)tmpTuplePtrSortingBlock->getData(); + std::sort(tmpTuplePtrSortingBlockPtr, tmpTuplePtrSortingBlockPtr + keyBlockTie.getNumTuples(), + [this, keyColInfo](const uint8_t* leftPtr, const uint8_t* rightPtr) -> bool { + auto isLeftNull = OrderByKeyEncoder::isNullVal( + leftPtr + keyColInfo.colOffsetInEncodedKeyBlock, keyColInfo.isAscOrder); + auto isRightNull = OrderByKeyEncoder::isNullVal( + rightPtr + keyColInfo.colOffsetInEncodedKeyBlock, keyColInfo.isAscOrder); + // Handle null value comparison. + if (isLeftNull && isRightNull) { + // If left and right strings are both null, we can't conclude that the left string + // is smaller than the right string. + return false; + } else if (isLeftNull) { + return !keyColInfo.isAscOrder; + } else if (isRightNull) { + return keyColInfo.isAscOrder; + } + + // We only need to fetch the actual strings from the + // factorizedTable when both left and right strings are long string. 
+ auto isLeftLongStr = OrderByKeyEncoder::isLongStr( + leftPtr + keyColInfo.colOffsetInEncodedKeyBlock, keyColInfo.isAscOrder); + auto isRightLongStr = OrderByKeyEncoder::isLongStr( + rightPtr + keyColInfo.colOffsetInEncodedKeyBlock, keyColInfo.isAscOrder); + if (!isLeftLongStr && !isRightLongStr) { + // If left and right are both short string and have the same prefix, we can't + // conclude that the left string is smaller than the right string. + return false; + } else if (isLeftLongStr && !isRightLongStr) { + // If left string is a long string and right string is a short string, we can + // conclude that the left string must be greater than the right string. + return !keyColInfo.isAscOrder; + } else if (isRightLongStr && !isLeftLongStr) { + // If right string is a long string and left string is a short string, we can + // conclude that the right string must be greater than the left string. + return keyColInfo.isAscOrder; + } + auto leftTupleInfoPtr = leftPtr + numBytesToRadixSort; + auto rightTupleInfoPtr = rightPtr + numBytesToRadixSort; + const auto leftBlockIdx = OrderByKeyEncoder::getEncodedFTBlockIdx(leftTupleInfoPtr); + const auto leftBlockOffset = + OrderByKeyEncoder::getEncodedFTBlockOffset(leftTupleInfoPtr); + const auto rightBlockIdx = OrderByKeyEncoder::getEncodedFTBlockIdx(rightTupleInfoPtr); + const auto rightBlockOffset = + OrderByKeyEncoder::getEncodedFTBlockOffset(rightTupleInfoPtr); + auto leftStr = factorizedTable.getData(leftBlockIdx, leftBlockOffset, + keyColInfo.colOffsetInFT); + auto rightStr = factorizedTable.getData(rightBlockIdx, rightBlockOffset, + keyColInfo.colOffsetInFT); + return keyColInfo.isAscOrder ? leftStr < rightStr : leftStr > rightStr; + }); + reOrderKeyBlock(keyBlockTie, keyBlockPtr); + findStringTies(keyBlockTie, keyBlockPtr, ties, keyColInfo); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/order_by/sort_state.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/order_by/sort_state.cpp new file mode 100644 index 0000000000..099caa754e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/order_by/sort_state.cpp @@ -0,0 +1,198 @@ +#include "processor/operator/order_by/sort_state.h" + +#include + +#include "common/constants.h" +#include "common/system_config.h" + +using namespace lbug::common; + +namespace lbug { +namespace processor { + +void SortSharedState::init(const OrderByDataInfo& orderByDataInfo) { + auto encodedKeyBlockColOffset = 0ul; + for (auto i = 0u; i < orderByDataInfo.keysPos.size(); ++i) { + const auto& dataType = orderByDataInfo.keyTypes[i]; + if (PhysicalTypeID::STRING == dataType.getPhysicalType()) { + // If this is a string column, we need to find the factorizedTable offset for this + // column. 
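+            // The resulting StrKeyColInfo bundles the column's offset inside the payload
+            // factorizedTable, its offset inside the encoded key block, and its sort direction;
+            // RadixSort and KeyBlockMerger use it to resolve prefix ties against the full strings.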
+ auto ftColIdx = orderByDataInfo.keyInPayloadPos[i]; + strKeyColsInfo.emplace_back(orderByDataInfo.payloadTableSchema.getColOffset(ftColIdx), + encodedKeyBlockColOffset, orderByDataInfo.isAscOrder[i]); + } + encodedKeyBlockColOffset += OrderByKeyEncoder::getEncodingSize(dataType); + } + numBytesPerTuple = encodedKeyBlockColOffset + OrderByConstants::NUM_BYTES_FOR_PAYLOAD_IDX; +} + +std::pair SortSharedState::getLocalPayloadTable( + storage::MemoryManager& memoryManager, const FactorizedTableSchema& payloadTableSchema) { + std::unique_lock lck{mtx}; + auto payloadTable = + std::make_unique(&memoryManager, payloadTableSchema.copy()); + auto result = std::make_pair(nextTableIdx++, payloadTable.get()); + payloadTables.push_back(std::move(payloadTable)); + return result; +} + +void SortSharedState::appendLocalSortedKeyBlock( + const std::shared_ptr& mergedDataBlocks) { + std::unique_lock lck{mtx}; + sortedKeyBlocks->emplace(mergedDataBlocks); +} + +void SortSharedState::combineFTHasNoNullGuarantee() { + for (auto i = 1u; i < payloadTables.size(); i++) { + payloadTables[0]->mergeMayContainNulls(*payloadTables[i]); + } +} + +std::vector SortSharedState::getPayloadTables() const { + std::vector payloadTablesToReturn; + payloadTablesToReturn.reserve(payloadTables.size()); + for (auto& payloadTable : payloadTables) { + payloadTablesToReturn.push_back(payloadTable.get()); + } + return payloadTablesToReturn; +} + +void SortLocalState::init(const OrderByDataInfo& orderByDataInfo, SortSharedState& sharedState, + storage::MemoryManager* memoryManager) { + auto [idx, table] = + sharedState.getLocalPayloadTable(*memoryManager, orderByDataInfo.payloadTableSchema); + globalIdx = idx; + payloadTable = table; + orderByKeyEncoder = std::make_unique(orderByDataInfo, memoryManager, + globalIdx, payloadTable->getNumTuplesPerBlock(), sharedState.getNumBytesPerTuple()); + radixSorter = std::make_unique(memoryManager, *payloadTable, *orderByKeyEncoder, + sharedState.getStrKeyColInfo()); +} + +void SortLocalState::append(const std::vector& keyVectors, + const std::vector& payloadVectors) { + orderByKeyEncoder->encodeKeys(keyVectors); + payloadTable->append(payloadVectors); +} + +void SortLocalState::finalize(lbug::processor::SortSharedState& sharedState) { + for (auto& keyBlock : orderByKeyEncoder->getKeyBlocks()) { + if (keyBlock->numTuples > 0) { + radixSorter->sortSingleKeyBlock(*keyBlock); + sharedState.appendLocalSortedKeyBlock( + make_shared(orderByKeyEncoder->getNumBytesPerTuple(), keyBlock)); + } + } + orderByKeyEncoder->clear(); +} + +PayloadScanner::PayloadScanner(MergedKeyBlocks* keyBlockToScan, + std::vector payloadTables, uint64_t skipNumber, uint64_t limitNumber) + : keyBlockToScan{keyBlockToScan}, payloadTables{std::move(payloadTables)}, + limitNumber{limitNumber} { + if (this->keyBlockToScan == nullptr || this->keyBlockToScan->getNumTuples() == 0) { + nextTupleIdxToReadInMergedKeyBlock = 0; + endTuplesIdxToReadInMergedKeyBlock = 0; + return; + } + payloadIdxOffset = + this->keyBlockToScan->getNumBytesPerTuple() - OrderByConstants::NUM_BYTES_FOR_PAYLOAD_IDX; + colsToScan = std::vector(this->payloadTables[0]->getTableSchema()->getNumColumns()); + iota(colsToScan.begin(), colsToScan.end(), 0); + hasUnflatColInPayload = this->payloadTables[0]->hasUnflatCol(); + if (!hasUnflatColInPayload) { + tuplesToRead = std::make_unique(DEFAULT_VECTOR_CAPACITY); + } + nextTupleIdxToReadInMergedKeyBlock = skipNumber == UINT64_MAX ? 0 : skipNumber; + endTuplesIdxToReadInMergedKeyBlock = + limitNumber == UINT64_MAX ? 
this->keyBlockToScan->getNumTuples() : + std::min(nextTupleIdxToReadInMergedKeyBlock + limitNumber, + this->keyBlockToScan->getNumTuples()); + blockPtrInfo = std::make_unique(nextTupleIdxToReadInMergedKeyBlock, + endTuplesIdxToReadInMergedKeyBlock, this->keyBlockToScan); +} + +uint64_t PayloadScanner::scan(std::vector vectorsToRead) { + if (limitNumber <= 0 || + nextTupleIdxToReadInMergedKeyBlock >= endTuplesIdxToReadInMergedKeyBlock) { + return 0; + } + if (scanSingleTuple(vectorsToRead)) { + auto payloadInfo = blockPtrInfo->curTuplePtr + payloadIdxOffset; + auto blockIdx = OrderByKeyEncoder::getEncodedFTBlockIdx(payloadInfo); + auto blockOffset = OrderByKeyEncoder::getEncodedFTBlockOffset(payloadInfo); + auto payloadTable = payloadTables[OrderByKeyEncoder::getEncodedFTIdx(payloadInfo)]; + payloadTable->scan(vectorsToRead, + blockIdx * payloadTable->getNumTuplesPerBlock() + blockOffset, 1 /* numTuples */); + blockPtrInfo->curTuplePtr += keyBlockToScan->getNumBytesPerTuple(); + blockPtrInfo->updateTuplePtrIfNecessary(); + nextTupleIdxToReadInMergedKeyBlock++; + applyLimitOnResultVectors(vectorsToRead); + return 1; + } else { + auto numTuplesToRead = std::min(DEFAULT_VECTOR_CAPACITY, + endTuplesIdxToReadInMergedKeyBlock - nextTupleIdxToReadInMergedKeyBlock); + auto numTuplesRead = 0u; + while (numTuplesRead < numTuplesToRead) { + auto numTuplesToReadInCurBlock = std::min(numTuplesToRead - numTuplesRead, + blockPtrInfo->getNumTuplesLeftInCurBlock()); + for (auto i = 0u; i < numTuplesToReadInCurBlock; i++) { + auto payloadInfo = blockPtrInfo->curTuplePtr + payloadIdxOffset; + auto blockIdx = OrderByKeyEncoder::getEncodedFTBlockIdx(payloadInfo); + auto blockOffset = OrderByKeyEncoder::getEncodedFTBlockOffset(payloadInfo); + auto ft = payloadTables[OrderByKeyEncoder::getEncodedFTIdx(payloadInfo)]; + tuplesToRead[numTuplesRead + i] = + ft->getTuple(blockIdx * ft->getNumTuplesPerBlock() + blockOffset); + blockPtrInfo->curTuplePtr += keyBlockToScan->getNumBytesPerTuple(); + } + blockPtrInfo->updateTuplePtrIfNecessary(); + numTuplesRead += numTuplesToReadInCurBlock; + } + // TODO(Ziyi): This is a hacky way of using factorizedTable::lookup function, + // since the tuples in tuplesToRead may not belong to factorizedTable0. The + // lookup function doesn't perform a check on whether it holds all the tuples in + // tuplesToRead. We should optimize this lookup function in the orderByScan + // optimization PR. + payloadTables[0]->lookup(vectorsToRead, colsToScan, tuplesToRead.get(), 0, numTuplesToRead); + nextTupleIdxToReadInMergedKeyBlock += numTuplesToRead; + return numTuplesRead; + } +} + +bool PayloadScanner::scanSingleTuple(std::vector vectorsToRead) const { + // If there is an unflat col in factorizedTable or flat vector in vectorsToRead, we can only + // read one tuple at a time. Otherwise, we can read min(DEFAULT_VECTOR_CAPACITY, + // numTuplesRemainingInMemBlock) tuples. + bool hasFlatVectorToRead = false; + for (auto& vector : vectorsToRead) { + if (vector->state->isFlat()) { + hasFlatVectorToRead = true; + } + } + return hasUnflatColInPayload || hasFlatVectorToRead; +} + +void PayloadScanner::applyLimitOnResultVectors(std::vector vectorsToRead) { + // The query doesn't contain a limit clause. + if (limitNumber == UINT64_MAX) { + return; + } + // Otherwise, we have to figure out the number of tuples in current batch exceeds the limit + // number. 
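+    // Example, for intuition: with LIMIT 10 and an unflat batch of 2048 tuples, the selection
+    // size is clamped to 10 and limitNumber drops to 0, so the next scan() call returns 0
+    // immediately (2048 is the typical DEFAULT_VECTOR_CAPACITY; the exact capacity does not
+    // matter for the clamping logic).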
+ common::ValueVector* unflatVector = nullptr; + for (auto& vector : vectorsToRead) { + if (!vector->state->isFlat()) { + unflatVector = vector; + } + } + if (unflatVector != nullptr) { + unflatVector->state->getSelVectorUnsafe().setSelSize( + std::min(limitNumber, (uint64_t)unflatVector->state->getSelVector().getSelSize())); + limitNumber -= unflatVector->state->getSelVector().getSelSize(); + } else { + limitNumber--; + } +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/order_by/top_k.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/order_by/top_k.cpp new file mode 100644 index 0000000000..2f019e67b4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/order_by/top_k.cpp @@ -0,0 +1,289 @@ +#include "processor/operator/order_by/top_k.h" + +#include "binder/expression/expression_util.h" +#include "common/constants.h" +#include "common/system_config.h" +#include "common/type_utils.h" +#include "function/binary_function_executor.h" +#include "function/comparison/comparison_functions.h" +#include "processor/execution_context.h" +#include "storage/buffer_manager/memory_manager.h" + +using namespace lbug::common; + +namespace lbug { +namespace processor { + +std::string TopKPrintInfo::toString() const { + std::string result = "Order By: "; + result += binder::ExpressionUtil::toString(keys); + result += ", Expressions: "; + result += binder::ExpressionUtil::toString(payloads); + result += ", Skip: "; + result += std::to_string(skipNum); + result += ", Limit: "; + result += std::to_string(limitNum); + return result; +} + +TopKSortState::TopKSortState() : numTuples{0}, memoryManager{nullptr} { + orderByLocalState = std::make_unique(); + orderBySharedState = std::make_unique(); +} + +void TopKSortState::init(const OrderByDataInfo& orderByDataInfo, + storage::MemoryManager* memoryManager_) { + this->memoryManager = memoryManager_; + orderBySharedState->init(orderByDataInfo); + orderByLocalState->init(orderByDataInfo, *orderBySharedState, memoryManager_); + numTuples = 0; +} + +void TopKSortState::append(const std::vector& keyVectors, + const std::vector& payloadVectors) { + numTuples += keyVectors[0]->state->getSelVector().getSelSize(); + orderByLocalState->append(keyVectors, payloadVectors); +} + +void TopKSortState::finalize() { + orderByLocalState->finalize(*orderBySharedState); + auto merger = std::make_unique(orderBySharedState->getPayloadTables(), + orderBySharedState->getStrKeyColInfo(), orderBySharedState->getNumBytesPerTuple()); + auto dispatcher = std::make_unique(); + dispatcher->init(memoryManager, orderBySharedState->getSortedKeyBlocks(), + orderBySharedState->getPayloadTables(), orderBySharedState->getStrKeyColInfo(), + orderBySharedState->getNumBytesPerTuple()); + while (!dispatcher->isDoneMerge()) { + auto keyBlockMergeMorsel = dispatcher->getMorsel(); + merger->mergeKeyBlocks(*keyBlockMergeMorsel); + dispatcher->doneMorsel(std::move(keyBlockMergeMorsel)); + } +} + +void TopKBuffer::init(storage::MemoryManager* memoryManager_, uint64_t skipNumber, + uint64_t limitNumber) { + this->memoryManager = memoryManager_; + sortState->init(*orderByDataInfo, memoryManager_); + this->skip = skipNumber; + this->limit = limitNumber; + initVectors(); + initCompareFuncs(); +} + +void TopKBuffer::append(const std::vector& keyVectors, + const std::vector& payloadVectors) { + auto originalSelState = keyVectors[0]->state->getSelVectorShared(); + if (hasBoundaryValue && 
!compareBoundaryValue(keyVectors)) { + keyVectors[0]->state->setSelVector(originalSelState); + return; + } + sortState->append(keyVectors, payloadVectors); + keyVectors[0]->state->setSelVector(originalSelState); +} + +void TopKBuffer::reduce() { + auto reduceThreshold = std::max(OrderByConfig::MIN_SIZE_TO_REDUCE, + OrderByConstants::MIN_LIMIT_RATIO_TO_REDUCE * (limit + skip)); + if (sortState->getNumTuples() < reduceThreshold) { + return; + } + sortState->finalize(); + auto newSortState = std::make_unique(); + newSortState->init(*orderByDataInfo, memoryManager); + auto scanner = sortState->getScanner(0, skip + limit); + while (true) { + auto numTuplesScanned = scanner->scan(payloadVecsToScan); + if (numTuplesScanned == 0) { + setBoundaryValue(); + break; + } + newSortState->append(keyVecsToScan, payloadVecsToScan); + std::swap(payloadVecsToScan, lastPayloadVecsToScan); + std::swap(keyVecsToScan, lastKeyVecsToScan); + } + sortState = std::move(newSortState); +} + +void TopKBuffer::merge(TopKBuffer* other) { + other->finalize(); + if (other->sortState->getSharedState()->getSortedKeyBlocks()->empty()) { + return; + } + auto scanner = other->sortState->getScanner(0, skip + limit); + while (scanner->scan(payloadVecsToScan) > 0) { + sortState->append(keyVecsToScan, payloadVecsToScan); + } + reduce(); +} + +void TopKBuffer::initVectors() { + auto payloadUnflatState = std::make_shared(); + auto payloadFlatState = common::DataChunkState::getSingleValueDataChunkState(); + auto lastPayloadUnflatState = std::make_shared(); + auto lastPayloadFlatState = common::DataChunkState::getSingleValueDataChunkState(); + for (auto i = 0u; i < orderByDataInfo->payloadTypes.size(); i++) { + auto type = &orderByDataInfo->payloadTypes[i]; + auto payloadVec = std::make_unique(type->copy(), memoryManager); + auto lastPayloadVec = std::make_unique(type->copy(), memoryManager); + if (orderByDataInfo->payloadTableSchema.getColumn(i)->isFlat()) { + payloadVec->setState(payloadFlatState); + lastPayloadVec->setState(lastPayloadFlatState); + } else { + payloadVec->setState(payloadUnflatState); + lastPayloadVec->setState(lastPayloadUnflatState); + } + payloadVecsToScan.push_back(payloadVec.get()); + lastPayloadVecsToScan.push_back(lastPayloadVec.get()); + tmpVectors.push_back(std::move(payloadVec)); + tmpVectors.push_back(std::move(lastPayloadVec)); + } + auto boundaryState = common::DataChunkState::getSingleValueDataChunkState(); + for (auto i = 0u; i < orderByDataInfo->keyTypes.size(); ++i) { + auto type = &orderByDataInfo->keyTypes[i]; + auto boundaryVec = std::make_unique(type->copy(), memoryManager); + boundaryVec->setState(boundaryState); + boundaryVecs.push_back(std::move(boundaryVec)); + auto posInPayload = orderByDataInfo->keyInPayloadPos[i]; + keyVecsToScan.push_back(payloadVecsToScan[posInPayload]); + lastKeyVecsToScan.push_back(lastPayloadVecsToScan[posInPayload]); + } +} + +template +void TopKBuffer::getSelectComparisonFunction(common::PhysicalTypeID typeID, + vector_select_comparison_func& selectFunc) { + common::TypeUtils::visit( + typeID, + [&selectFunc]( + T) { selectFunc = function::BinaryFunctionExecutor::selectComparison; }, + [](auto) { KU_UNREACHABLE; }); +} + +void TopKBuffer::initCompareFuncs() { + compareFuncs.reserve(orderByDataInfo->isAscOrder.size()); + equalsFuncs.reserve(orderByDataInfo->isAscOrder.size()); + vector_select_comparison_func compareFunc; + vector_select_comparison_func equalsFunc; + for (auto i = 0u; i < orderByDataInfo->isAscOrder.size(); i++) { + auto physicalType = 
orderByDataInfo->keyTypes[i].getPhysicalType(); + if (orderByDataInfo->isAscOrder[i]) { + getSelectComparisonFunction(physicalType, compareFunc); + } else { + getSelectComparisonFunction(physicalType, compareFunc); + } + getSelectComparisonFunction(physicalType, equalsFunc); + compareFuncs.push_back(compareFunc); + equalsFuncs.push_back(equalsFunc); + } +} + +void TopKBuffer::setBoundaryValue() { + for (auto i = 0u; i < boundaryVecs.size(); i++) { + auto boundaryVec = boundaryVecs[i].get(); + auto dstData = boundaryVec->getData() + + boundaryVec->getNumBytesPerValue() * boundaryVec->state->getSelVector()[0]; + auto srcVector = lastKeyVecsToScan[i]; + auto srcData = + srcVector->getData() + + srcVector->getNumBytesPerValue() * + srcVector->state->getSelVector()[srcVector->state->getSelVector().getSelSize() - 1]; + boundaryVec->copyFromVectorData(dstData, srcVector, srcData); + hasBoundaryValue = true; + } +} + +bool TopKBuffer::compareBoundaryValue(const std::vector& keyVectors) { + if (keyVectors[0]->state->isFlat()) { + return compareFlatKeys(0 /* startKeyVectorIdxToCompare */, keyVectors); + } else { + compareUnflatKeys(0 /* startKeyVectorIdxToCompare */, keyVectors); + return keyVectors[0]->state->getSelVector().getSelSize() > 0; + } +} + +bool TopKBuffer::compareFlatKeys(idx_t vectorIdxToCompare, std::vector keyVectors) { + KU_ASSERT(!keyVectors.empty()); + auto selVector = std::make_shared(common::DEFAULT_VECTOR_CAPACITY); + selVector->setToFiltered(); + + if (vectorIdxToCompare < keyVectors.size() - 1 && + equalsFuncs[vectorIdxToCompare](*keyVectors[vectorIdxToCompare], + *boundaryVecs[vectorIdxToCompare], *selVector, nullptr /* dataPtr */)) { + return compareFlatKeys(vectorIdxToCompare + 1, std::move(keyVectors)); + } else { + return compareFuncs[vectorIdxToCompare](*keyVectors[vectorIdxToCompare], + *boundaryVecs[vectorIdxToCompare], *selVector, nullptr /* dataPtr */); + } +} + +void TopKBuffer::compareUnflatKeys(idx_t vectorIdxToCompare, std::vector keyVectors) { + auto compareSelVector = + std::make_shared(common::DEFAULT_VECTOR_CAPACITY); + compareSelVector->setToFiltered(); + compareFuncs[vectorIdxToCompare](*keyVectors[vectorIdxToCompare], + *boundaryVecs[vectorIdxToCompare], *compareSelVector, nullptr /* dataPtr */); + if (vectorIdxToCompare != keyVectors.size() - 1) { + auto equalsSelVector = + std::make_shared(common::DEFAULT_VECTOR_CAPACITY); + equalsSelVector->setToFiltered(); + if (equalsFuncs[vectorIdxToCompare](*keyVectors[vectorIdxToCompare], + *boundaryVecs[vectorIdxToCompare], *equalsSelVector, nullptr /* dataPtr */)) { + keyVectors[vectorIdxToCompare]->state->setSelVector(equalsSelVector); + compareUnflatKeys(vectorIdxToCompare + 1, keyVectors); + appendSelState(compareSelVector.get(), equalsSelVector.get()); + } + } + keyVectors[vectorIdxToCompare]->state->setSelVector(std::move(compareSelVector)); +} + +void TopKBuffer::appendSelState(common::SelectionVector* selVector, + common::SelectionVector* selVectorToAppend) { + for (auto i = 0u; i < selVectorToAppend->getSelSize(); i++) { + selVector->operator[](selVector->getSelSize() + i) = selVectorToAppend->operator[](i); + } + selVector->incrementSelSize(selVectorToAppend->getSelSize()); +} + +void TopKLocalState::init(const OrderByDataInfo& orderByDataInfo, + storage::MemoryManager* memoryManager, ResultSet& /*resultSet*/, uint64_t skipNumber, + uint64_t limitNumber) { + buffer = std::make_unique(orderByDataInfo); + buffer->init(memoryManager, skipNumber, limitNumber); +} + +// 
NOLINTNEXTLINE(readability-make-member-function-const): Semantically non-const. +void TopKLocalState::append(const std::vector& keyVectors, + const std::vector& payloadVectors) { + buffer->append(keyVectors, payloadVectors); + buffer->reduce(); +} + +void TopK::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) { + localState = TopKLocalState(); + localState.init(info, storage::MemoryManager::Get(*context->clientContext), *resultSet, + skipNumber, limitNumber); + for (auto& dataPos : info.payloadsPos) { + payloadVectors.push_back(resultSet->getValueVector(dataPos).get()); + } + for (auto& dataPos : info.keysPos) { + orderByVectors.push_back(resultSet->getValueVector(dataPos).get()); + } +} + +void TopK::initGlobalStateInternal(ExecutionContext* context) { + sharedState->init(info, storage::MemoryManager::Get(*context->clientContext), skipNumber, + limitNumber); +} + +void TopK::executeInternal(ExecutionContext* context) { + while (children[0]->getNextTuple(context)) { + for (auto i = 0u; i < resultSet->multiplicity; i++) { + localState.append(orderByVectors, payloadVectors); + } + } + localState.finalize(); + sharedState->mergeLocalState(&localState); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/order_by/top_k_scanner.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/order_by/top_k_scanner.cpp new file mode 100644 index 0000000000..239578bdbe --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/order_by/top_k_scanner.cpp @@ -0,0 +1,26 @@ +#include "processor/operator/order_by/top_k_scanner.h" + +namespace lbug { +namespace processor { + +void TopKLocalScanState::init(std::vector& outVectorPos, TopKSharedState& sharedState, + ResultSet& resultSet) { + for (auto& pos : outVectorPos) { + vectorsToScan.push_back(resultSet.getValueVector(pos).get()); + } + payloadScanner = sharedState.buffer->getScanner(); +} + +void TopKScan::initLocalStateInternal(lbug::processor::ResultSet* resultSet, + lbug::processor::ExecutionContext* /*context*/) { + localState->init(outVectorPos, *sharedState, *resultSet); +} + +bool TopKScan::getNextTuplesInternal(ExecutionContext* /*context*/) { + auto numTuplesRead = localState->scan(); + metrics->numOutputTuple.increase(numTuplesRead); + return numTuplesRead != 0; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/partitioner.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/partitioner.cpp new file mode 100644 index 0000000000..db60d46f16 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/partitioner.cpp @@ -0,0 +1,186 @@ +#include "processor/operator/partitioner.h" + +#include "binder/expression/expression_util.h" +#include "processor/execution_context.h" +#include "storage/storage_manager.h" +#include "storage/table/node_table.h" +#include "storage/table/rel_table.h" +#include "transaction/transaction.h" + +using namespace lbug::common; +using namespace lbug::storage; + +namespace lbug { +namespace processor { + +std::string PartitionerPrintInfo::toString() const { + std::string result = "Indexes: "; + result += binder::ExpressionUtil::toString(expressions); + return result; +} + +void PartitionerFunctions::partitionRelData(ValueVector* key, ValueVector* partitionIdxes) { + KU_ASSERT(key->state == partitionIdxes->state && + key->dataType.getPhysicalType() == PhysicalTypeID::INT64); + for (auto i = 0u; i < 
key->state->getSelVector().getSelSize(); i++) { + const auto pos = key->state->getSelVector()[i]; + const partition_idx_t partitionIdx = + key->getValue(pos) >> StorageConfig::NODE_GROUP_SIZE_LOG2; + partitionIdxes->setValue(pos, partitionIdx); + } +} + +void CopyPartitionerSharedState::initialize(const logical_type_vec_t& columnTypes, + idx_t numPartitioners, const main::ClientContext* clientContext) { + PartitionerSharedState::initialize(columnTypes, numPartitioners, clientContext); + Partitioner::initializePartitioningStates(columnTypes, partitioningBuffers, numPartitions, + numPartitioners); +} + +void CopyPartitionerSharedState::merge( + const std::vector>& localPartitioningStates) { + std::unique_lock xLck{mtx}; + KU_ASSERT(partitioningBuffers.size() == localPartitioningStates.size()); + for (auto partitioningIdx = 0u; partitioningIdx < partitioningBuffers.size(); + partitioningIdx++) { + partitioningBuffers[partitioningIdx]->merge(*localPartitioningStates[partitioningIdx]); + } +} + +void CopyPartitionerSharedState::resetState(common::idx_t partitioningIdx) { + PartitionerSharedState::resetState(partitioningIdx); + partitioningBuffers[partitioningIdx].reset(); +} + +void PartitioningBuffer::merge(const PartitioningBuffer& localPartitioningState) const { + KU_ASSERT(partitions.size() == localPartitioningState.partitions.size()); + for (auto partitionIdx = 0u; partitionIdx < partitions.size(); partitionIdx++) { + auto& sharedPartition = partitions[partitionIdx]; + auto& localPartition = localPartitioningState.partitions[partitionIdx]; + sharedPartition->merge(*localPartition); + } +} + +Partitioner::Partitioner(PartitionerInfo info, PartitionerDataInfo dataInfo, + std::shared_ptr sharedState, + std::unique_ptr child, uint32_t id, std::unique_ptr printInfo) + : Sink{type_, std::move(child), id, std::move(printInfo)}, dataInfo{std::move(dataInfo)}, + info{std::move(info)}, sharedState{std::move(sharedState)} { + partitionIdxes = std::make_unique(LogicalTypeID::INT64); +} + +void Partitioner::initGlobalStateInternal(ExecutionContext* context) { + const auto clientContext = context->clientContext; + // If initialization is required + if (!sharedState->srcNodeTable) { + auto storageManager = StorageManager::Get(*clientContext); + auto catalog = catalog::Catalog::Get(*clientContext); + auto transaction = transaction::Transaction::Get(*clientContext); + auto fromTableID = + catalog->getTableCatalogEntry(transaction, dataInfo.fromTableName)->getTableID(); + auto toTableID = + catalog->getTableCatalogEntry(transaction, dataInfo.toTableName)->getTableID(); + sharedState->srcNodeTable = storageManager->getTable(fromTableID)->ptrCast(); + sharedState->dstNodeTable = storageManager->getTable(toTableID)->ptrCast(); + auto& relGroupEntry = catalog->getTableCatalogEntry(transaction, dataInfo.tableName) + ->constCast(); + auto relEntryInfo = relGroupEntry.getRelEntryInfo(fromTableID, toTableID); + sharedState->relTable = storageManager->getTable(relEntryInfo->oid)->ptrCast(); + } + sharedState->initialize(dataInfo.columnTypes, info.infos.size(), clientContext); +} + +void Partitioner::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) { + localState = std::make_unique(); + initializePartitioningStates(dataInfo.columnTypes, localState->partitioningBuffers, + sharedState->numPartitions, info.infos.size()); + for (const auto& evaluator : dataInfo.columnEvaluators) { + evaluator->init(*resultSet, context->clientContext); + } +} + +DataChunk Partitioner::constructDataChunk(const 
std::shared_ptr& state) const { + const auto numColumns = dataInfo.columnEvaluators.size(); + DataChunk dataChunk(numColumns, state); + for (auto i = 0u; i < numColumns; ++i) { + auto& evaluator = dataInfo.columnEvaluators[i]; + dataChunk.insert(i, evaluator->resultVector); + } + return dataChunk; +} + +void Partitioner::initializePartitioningStates(const logical_type_vec_t& columnTypes, + std::vector>& partitioningBuffers, + const std::array& numPartitions, + idx_t numPartitioners) { + partitioningBuffers.resize(numPartitioners); + for (auto partitioningIdx = 0u; partitioningIdx < numPartitioners; partitioningIdx++) { + const auto numPartition = numPartitions[partitioningIdx]; + auto partitioningBuffer = std::make_unique(); + partitioningBuffer->partitions.reserve(numPartition); + for (auto i = 0u; i < numPartition; i++) { + partitioningBuffer->partitions.push_back( + std::make_unique(LogicalType::copy(columnTypes))); + } + partitioningBuffers[partitioningIdx] = std::move(partitioningBuffer); + } +} + +void Partitioner::executeInternal(ExecutionContext* context) { + const auto relOffsetVector = resultSet->getValueVector(info.relOffsetDataPos); + while (children[0]->getNextTuple(context)) { + KU_ASSERT(dataInfo.columnEvaluators.size() >= 1); + const auto numRels = relOffsetVector->state->getSelVector().getSelSize(); + evaluateExpressions(numRels); + auto currentRelOffset = sharedState->relTable->reserveRelOffsets(numRels); + for (auto i = 0u; i < numRels; i++) { + const auto pos = relOffsetVector->state->getSelVector()[i]; + relOffsetVector->setValue(pos, currentRelOffset++); + } + for (auto partitioningIdx = 0u; partitioningIdx < info.infos.size(); partitioningIdx++) { + auto& partitionInfo = info.infos[partitioningIdx]; + auto keyVector = dataInfo.columnEvaluators[partitionInfo.keyIdx]->resultVector; + partitionIdxes->state = keyVector->state; + partitionInfo.partitionerFunc(keyVector.get(), partitionIdxes.get()); + auto chunkToCopyFrom = constructDataChunk(keyVector->state); + copyDataToPartitions(*MemoryManager::Get(*context->clientContext), partitioningIdx, + chunkToCopyFrom); + } + } + sharedState->merge(localState->partitioningBuffers); +} + +void Partitioner::evaluateExpressions(uint64_t numRels) const { + for (auto i = 0u; i < dataInfo.evaluateTypes.size(); ++i) { + auto evaluator = dataInfo.columnEvaluators[i].get(); + switch (dataInfo.evaluateTypes[i]) { + case ColumnEvaluateType::DEFAULT: { + evaluator->evaluate(numRels); + } break; + default: { + evaluator->evaluate(); + } + } + } +} + +void Partitioner::copyDataToPartitions(MemoryManager& memoryManager, + partition_idx_t partitioningIdx, const DataChunk& chunkToCopyFrom) const { + std::vector vectorsToAppend; + vectorsToAppend.reserve(chunkToCopyFrom.getNumValueVectors()); + for (auto j = 0u; j < chunkToCopyFrom.getNumValueVectors(); j++) { + vectorsToAppend.push_back(&chunkToCopyFrom.getValueVectorMutable(j)); + } + for (auto i = 0u; i < chunkToCopyFrom.state->getSelVector().getSelSize(); i++) { + const auto posToCopyFrom = chunkToCopyFrom.state->getSelVector()[i]; + const auto partitionIdx = partitionIdxes->getValue(posToCopyFrom); + KU_ASSERT( + partitionIdx < localState->getPartitioningBuffer(partitioningIdx)->partitions.size()); + const auto& partition = + localState->getPartitioningBuffer(partitioningIdx)->partitions[partitionIdx]; + partition->append(memoryManager, vectorsToAppend, i, 1); + } +} + +} // namespace processor +} // namespace lbug diff --git 
a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/path_property_probe.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/path_property_probe.cpp new file mode 100644 index 0000000000..4528d7d4b0 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/path_property_probe.cpp @@ -0,0 +1,327 @@ +#include "processor/operator/path_property_probe.h" + +#include "common/constants.h" +#include "function/hash/hash_functions.h" + +using namespace lbug::common; + +namespace lbug { +namespace processor { + +void PathPropertyProbe::initLocalStateInternal(ResultSet* /*resultSet_*/, + ExecutionContext* /*context*/) { + localState = PathPropertyProbeLocalState(); + auto pathVector = resultSet->getValueVector(info.pathPos); + pathNodesVector = StructVector::getFieldVectorRaw(*pathVector, InternalKeyword::NODES); + pathRelsVector = StructVector::getFieldVectorRaw(*pathVector, InternalKeyword::RELS); + auto nodesDataVector = ListVector::getDataVector(pathNodesVector); + auto relsDataVector = ListVector::getDataVector(pathRelsVector); + pathNodeIDsDataVector = StructVector::getFieldVectorRaw(*nodesDataVector, InternalKeyword::ID); + pathNodeLabelsDataVector = + StructVector::getFieldVectorRaw(*nodesDataVector, InternalKeyword::LABEL); + pathRelIDsDataVector = StructVector::getFieldVectorRaw(*relsDataVector, InternalKeyword::ID); + pathRelLabelsDataVector = + StructVector::getFieldVectorRaw(*relsDataVector, InternalKeyword::LABEL); + pathSrcNodeIDsDataVector = + StructVector::getFieldVectorRaw(*relsDataVector, InternalKeyword::SRC); + pathDstNodeIDsDataVector = + StructVector::getFieldVectorRaw(*relsDataVector, InternalKeyword::DST); + for (auto fieldIdx : info.nodeFieldIndices) { + pathNodesPropertyDataVectors.push_back( + StructVector::getFieldVector(nodesDataVector, fieldIdx).get()); + } + for (auto fieldIdx : info.relFieldIndices) { + pathRelsPropertyDataVectors.push_back( + StructVector::getFieldVector(relsDataVector, fieldIdx).get()); + } + if (info.leftNodeIDPos.isValid()) { + inputLeftNodeIDVector = resultSet->getValueVector(info.leftNodeIDPos).get(); + inputRightNodeIDVector = resultSet->getValueVector(info.rightNodeIDPos).get(); + inputNodeIDsVector = resultSet->getValueVector(info.inputNodeIDsPos).get(); + inputRelIDsVector = resultSet->getValueVector(info.inputEdgeIDsPos).get(); + if (info.directionPos.isValid()) { + inputDirectionVector = resultSet->getValueVector(info.directionPos).get(); + } + } +} + +static void copyListEntry(const ValueVector& srcVector, ValueVector* dstVector) { + auto& selVector = srcVector.state->getSelVector(); + for (auto i = 0u; i < selVector.getSelSize(); ++i) { + auto pos = selVector[i]; + auto entry = srcVector.getValue(pos); + ListVector::addList(dstVector, entry.size); + dstVector->setValue(pos, entry); + } +} + +static void copyInternalID(ValueVector* srcVector, ValueVector* dstIDVector, + ValueVector* dstLabelVector, + const std::unordered_map& tableIDToName) { + auto srcDataVector = ListVector::getDataVector(srcVector); + for (auto i = 0u; i < ListVector::getDataVectorSize(srcVector); ++i) { + auto id = srcDataVector->getValue(i); + dstIDVector->setValue(i, id); + StringVector::addString(dstLabelVector, i, tableIDToName.at(id.tableID)); + } +} + +// TODO(Xiyang): revisit me. Instead of printing in src->dst order. Maybe left->right order make +// more sense. +// Truth table. 
+// Edge Direction | ExtendFromSource | Result +// FWD | T | T +// FWD | F | F +// BWD | T | F +// BWD | F | T +static bool isCorrectOrder(ValueVector* vector, sel_t pos, bool extendFromSource) { + if (extendFromSource) { + return vector->getValue(pos); + } + return !vector->getValue(pos); +} + +static void writeSrcDstNodeIDs(nodeID_t fromID, nodeID_t toID, ValueVector* directionDataVector, + ValueVector* srcNodeIDsDataVector, ValueVector* dstNodeIDsDataVector, sel_t pos, bool flag) { + if (isCorrectOrder(directionDataVector, pos, flag)) { + srcNodeIDsDataVector->setValue(pos, fromID); + dstNodeIDsDataVector->setValue(pos, toID); + } else { + srcNodeIDsDataVector->setValue(pos, toID); + dstNodeIDsDataVector->setValue(pos, fromID); + } +} + +bool PathPropertyProbe::getNextTuplesInternal(ExecutionContext* context) { + if (!children[0]->getNextTuple(context)) { + return false; + } + auto sizeProbed = 0u; + // Copy node IDs + if (inputNodeIDsVector != nullptr) { + pathNodesVector->resetAuxiliaryBuffer(); + copyListEntry(*inputNodeIDsVector, pathNodesVector); + copyInternalID(inputNodeIDsVector, pathNodeIDsDataVector, pathNodeLabelsDataVector, + info.tableIDToName); + } + // Scan node properties + if (sharedState->nodeHashTableState != nullptr) { + auto nodeHashTable = sharedState->nodeHashTableState->getHashTable(); + auto nodeDataSize = ListVector::getDataVectorSize(pathNodesVector); + while (sizeProbed < nodeDataSize) { + auto sizeToProbe = + std::min(DEFAULT_VECTOR_CAPACITY, nodeDataSize - sizeProbed); + probe(nodeHashTable, sizeProbed, sizeToProbe, pathNodeIDsDataVector, + pathNodesPropertyDataVectors, info.nodeTableColumnIndices); + sizeProbed += sizeToProbe; + } + } + // Copy rel IDs + if (inputRelIDsVector != nullptr) { + pathRelsVector->resetAuxiliaryBuffer(); + copyListEntry(*inputRelIDsVector, pathRelsVector); + copyInternalID(inputRelIDsVector, pathRelIDsDataVector, pathRelLabelsDataVector, + info.tableIDToName); + } + // Scan rel property + if (sharedState->relHashTableState != nullptr) { + auto relHashTable = sharedState->relHashTableState->getHashTable(); + auto relDataSize = ListVector::getDataVectorSize(pathRelsVector); + sizeProbed = 0u; + while (sizeProbed < relDataSize) { + auto sizeToProbe = + std::min(DEFAULT_VECTOR_CAPACITY, relDataSize - sizeProbed); + probe(relHashTable, sizeProbed, sizeToProbe, pathRelIDsDataVector, + pathRelsPropertyDataVectors, info.relTableColumnIndices); + sizeProbed += sizeToProbe; + } + } + if (inputNodeIDsVector == nullptr || inputRelIDsVector == nullptr) { + return true; + } + auto& selVector = inputNodeIDsVector->state->getSelVector(); + auto inputNodeIDsDataVector = ListVector::getDataVector(inputNodeIDsVector); + // Copy rel src&dst IDs + + switch (info.extendDirection) { + case ExtendDirection::FWD: { + if (info.extendFromLeft) { + // Example graph src->1->2->3->dst + // Input: src, dst, [1, 2, 3] + // Output: + // - srcIDs [src, 1, 2, 3] + // - dstIDs [1, 2, 3, dst] + for (auto i = 0u; i < selVector.getSelSize(); ++i) { + auto leftNodeID = inputLeftNodeIDVector->getValue(i); + auto rightNodeID = inputRightNodeIDVector->getValue(i); + auto nodeListEntry = inputNodeIDsVector->getValue(i); + auto relListEntry = inputRelIDsVector->getValue(i); + if (relListEntry.size == 0) { + continue; + } + for (auto j = 0u; j < nodeListEntry.size; ++j) { + auto id = inputNodeIDsDataVector->getValue(nodeListEntry.offset + j); + pathSrcNodeIDsDataVector->setValue(relListEntry.offset + j + 1, id); + pathDstNodeIDsDataVector->setValue(relListEntry.offset + j, 
id); + } + pathSrcNodeIDsDataVector->setValue(relListEntry.offset, leftNodeID); + pathDstNodeIDsDataVector->setValue(relListEntry.offset + relListEntry.size - 1, + rightNodeID); + } + } else { + // Example graph src<-1<-2<-3<-dst + // Input: src, dst, [1, 2, 3] + // Output: + // - srcIDs [1, 2, 3, dst] + // - dstIDs [src, 1, 2, 3] + for (auto i = 0u; i < selVector.getSelSize(); ++i) { + auto leftNodeID = inputLeftNodeIDVector->getValue(i); + auto rightNodeID = inputRightNodeIDVector->getValue(i); + auto nodeListEntry = inputNodeIDsVector->getValue(i); + auto relListEntry = inputRelIDsVector->getValue(i); + if (relListEntry.size == 0) { + continue; + } + for (auto j = 0u; j < nodeListEntry.size; ++j) { + auto id = inputNodeIDsDataVector->getValue(nodeListEntry.offset + j); + pathSrcNodeIDsDataVector->setValue(relListEntry.offset + j, id); + pathDstNodeIDsDataVector->setValue(relListEntry.offset + j + 1, id); + } + pathSrcNodeIDsDataVector->setValue(relListEntry.offset + relListEntry.size - 1, + rightNodeID); + pathDstNodeIDsDataVector->setValue(relListEntry.offset, leftNodeID); + } + } + } break; + case common::ExtendDirection::BWD: { + if (info.extendFromLeft) { + // Example graph src<-1<-2<-3<-dst + // Input: src, dst, [1, 2, 3] + // Output: + // - srcIDs [1, 2, 3, dst] + // - dstIDs [src, 1, 2, 3] + for (auto i = 0u; i < selVector.getSelSize(); ++i) { + auto leftNodeID = inputLeftNodeIDVector->getValue(i); + auto rightNodeID = inputRightNodeIDVector->getValue(i); + auto nodeListEntry = inputNodeIDsVector->getValue(i); + auto relListEntry = inputRelIDsVector->getValue(i); + if (relListEntry.size == 0) { + continue; + } + for (auto j = 0u; j < nodeListEntry.size; ++j) { + auto id = inputNodeIDsDataVector->getValue(nodeListEntry.offset + j); + pathSrcNodeIDsDataVector->setValue(relListEntry.offset + j, id); + pathDstNodeIDsDataVector->setValue(relListEntry.offset + j + 1, id); + } + pathSrcNodeIDsDataVector->setValue(relListEntry.offset + relListEntry.size - 1, + rightNodeID); + pathDstNodeIDsDataVector->setValue(relListEntry.offset, leftNodeID); + } + } else { + // Example graph src->1->2->3->dst + // Input: src, dst, [1, 2, 3] + // Output: + // - srcIDs [src, 1, 2, 3] + // - dstIDs [1, 2, 3, dst] + for (auto i = 0u; i < selVector.getSelSize(); ++i) { + auto leftNodeID = inputLeftNodeIDVector->getValue(i); + auto rightNodeID = inputRightNodeIDVector->getValue(i); + auto nodeListEntry = inputNodeIDsVector->getValue(i); + auto relListEntry = inputRelIDsVector->getValue(i); + if (relListEntry.size == 0) { + continue; + } + for (auto j = 0u; j < nodeListEntry.size; ++j) { + auto id = inputNodeIDsDataVector->getValue(nodeListEntry.offset + j); + pathSrcNodeIDsDataVector->setValue(relListEntry.offset + j + 1, id); + pathDstNodeIDsDataVector->setValue(relListEntry.offset + j, id); + } + pathSrcNodeIDsDataVector->setValue(relListEntry.offset, leftNodeID); + pathDstNodeIDsDataVector->setValue(relListEntry.offset + relListEntry.size - 1, + rightNodeID); + } + } + } break; + case common::ExtendDirection::BOTH: { + auto directionDataVector = ListVector::getDataVector(inputDirectionVector); + for (auto i = 0u; i < selVector.getSelSize(); ++i) { + auto leftNodeID = inputLeftNodeIDVector->getValue(i); + auto rightNodeID = inputRightNodeIDVector->getValue(i); + auto nodeListEntry = inputNodeIDsVector->getValue(i); + auto relListEntry = inputRelIDsVector->getValue(i); + if (relListEntry.size == 0) { + continue; + } + if (nodeListEntry.size == 0) { + KU_ASSERT(relListEntry.size == 1); + if 
(isCorrectOrder(directionDataVector, relListEntry.offset, info.extendFromLeft)) { + pathSrcNodeIDsDataVector->setValue(relListEntry.offset, leftNodeID); + pathDstNodeIDsDataVector->setValue(relListEntry.offset, rightNodeID); + } else { + pathSrcNodeIDsDataVector->setValue(relListEntry.offset, rightNodeID); + pathDstNodeIDsDataVector->setValue(relListEntry.offset, leftNodeID); + } + continue; + } + for (auto j = 0u; j < nodeListEntry.size - 1; ++j) { + auto from = inputNodeIDsDataVector->getValue(nodeListEntry.offset + j); + auto to = inputNodeIDsDataVector->getValue(nodeListEntry.offset + j + 1); + writeSrcDstNodeIDs(from, to, directionDataVector, pathSrcNodeIDsDataVector, + pathDstNodeIDsDataVector, relListEntry.offset + j + 1, info.extendFromLeft); + } + writeSrcDstNodeIDs(leftNodeID, + inputNodeIDsDataVector->getValue(nodeListEntry.offset), + directionDataVector, pathSrcNodeIDsDataVector, pathDstNodeIDsDataVector, + relListEntry.offset, info.extendFromLeft); + writeSrcDstNodeIDs(inputNodeIDsDataVector->getValue( + nodeListEntry.offset + nodeListEntry.size - 1), + rightNodeID, directionDataVector, pathSrcNodeIDsDataVector, + pathDstNodeIDsDataVector, relListEntry.offset + relListEntry.size - 1, + info.extendFromLeft); + } + } break; + default: + KU_UNREACHABLE; + } + return true; +} + +void PathPropertyProbe::probe(lbug::processor::JoinHashTable* hashTable, uint64_t sizeProbed, + uint64_t sizeToProbe, ValueVector* idVector, const std::vector& propertyVectors, + const std::vector& colIndicesToScan) const { + // Hash + for (auto i = 0u; i < sizeToProbe; ++i) { + function::Hash::operation(idVector->getValue(sizeProbed + i), + localState.hashes[i]); + } + // Probe hash + for (auto i = 0u; i < sizeToProbe; ++i) { + localState.probedTuples[i] = hashTable->getTupleForHash(localState.hashes[i]); + } + // Match value + for (auto i = 0u; i < sizeToProbe; ++i) { + while (localState.probedTuples[i]) { + auto currentTuple = localState.probedTuples[i]; + if (*(internalID_t*)currentTuple == idVector->getValue(sizeProbed + i)) { + localState.matchedTuples[i] = currentTuple; + break; + } + localState.probedTuples[i] = *hashTable->getPrevTuple(currentTuple); + } + KU_ASSERT(localState.matchedTuples[i] != nullptr); + } + // Scan table + auto factorizedTable = hashTable->getFactorizedTable(); + for (auto i = 0u; i < sizeToProbe; ++i) { + auto tuple = localState.matchedTuples[i]; + for (auto j = 0u; j < propertyVectors.size(); ++j) { + auto propertyVector = propertyVectors[j]; + auto colIdx = colIndicesToScan[j]; + factorizedTable->readFlatColToFlatVector(tuple, colIdx, *propertyVector, + sizeProbed + i); + } + } +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/CMakeLists.txt new file mode 100644 index 0000000000..969da73e37 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/CMakeLists.txt @@ -0,0 +1,24 @@ +add_subdirectory(reader) +add_subdirectory(writer/parquet) + +add_library(lbug_processor_operator_persistent + OBJECT + batch_insert_error_handler.cpp + node_batch_insert.cpp + node_batch_insert_error_handler.cpp + copy_rel_batch_insert.cpp + rel_batch_insert.cpp + copy_to.cpp + delete.cpp + delete_executor.cpp + index_builder.cpp + insert.cpp + insert_executor.cpp + merge.cpp + set.cpp + set_executor.cpp +) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git 
a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/batch_insert_error_handler.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/batch_insert_error_handler.cpp new file mode 100644 index 0000000000..c6fafd8e4a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/batch_insert_error_handler.cpp @@ -0,0 +1,92 @@ +#include "processor/operator/persistent/batch_insert_error_handler.h" + +#include "common/exception/copy.h" +#include "common/uniq_lock.h" +#include "main/client_context.h" +#include "processor/execution_context.h" +#include "processor/warning_context.h" + +using namespace lbug::common; + +namespace lbug { +namespace processor { + +BatchInsertCachedError::BatchInsertCachedError(std::string message, + const std::optional& warningData) + : message(std::move(message)), warningData(warningData) {} + +BatchInsertErrorHandler::BatchInsertErrorHandler(ExecutionContext* context, bool ignoreErrors, + std::shared_ptr sharedErrorCounter, std::mutex* sharedErrorCounterMtx) + : ignoreErrors(ignoreErrors), + warningLimit( + std::min(context->clientContext->getClientConfig()->warningLimit, LOCAL_WARNING_LIMIT)), + context(context), currentInsertIdx(0), sharedErrorCounterMtx(sharedErrorCounterMtx), + sharedErrorCounter(std::move(sharedErrorCounter)) {} + +void BatchInsertErrorHandler::addNewVectorsIfNeeded() { + KU_ASSERT(currentInsertIdx <= cachedErrors.size()); + if (currentInsertIdx == cachedErrors.size()) { + cachedErrors.emplace_back(); + } +} + +bool BatchInsertErrorHandler::getIgnoreErrors() const { + return ignoreErrors; +} + +void BatchInsertErrorHandler::handleError(std::string message, + const std::optional& warningData) { + handleError(BatchInsertCachedError{std::move(message), warningData}); +} + +void BatchInsertErrorHandler::handleError(BatchInsertCachedError error) { + if (!ignoreErrors) { + throw common::CopyException(error.message); + } + + if (getNumErrors() >= warningLimit) { + flushStoredErrors(); + } + + addNewVectorsIfNeeded(); + cachedErrors[currentInsertIdx] = std::move(error); + ++currentInsertIdx; +} + +void BatchInsertErrorHandler::flushStoredErrors() { + std::vector unpopulatedErrors; + + for (row_idx_t i = 0; i < getNumErrors(); ++i) { + auto& error = cachedErrors[i]; + CopyFromFileError warningToAdd{std::move(error.message), {}, false}; + if (error.warningData.has_value()) { + warningToAdd.completedLine = true; + warningToAdd.warningData = error.warningData.value(); + } + unpopulatedErrors.push_back(warningToAdd); + } + + if (!unpopulatedErrors.empty()) { + KU_ASSERT(ignoreErrors); + WarningContext::Get(*context->clientContext)->appendWarningMessages(unpopulatedErrors); + } + + if (!unpopulatedErrors.empty() && sharedErrorCounter != nullptr) { + KU_ASSERT(sharedErrorCounterMtx); + common::UniqLock lockGuard{*sharedErrorCounterMtx}; + *sharedErrorCounter += unpopulatedErrors.size(); + } + + clearErrors(); +} + +void BatchInsertErrorHandler::clearErrors() { + currentInsertIdx = 0; +} + +row_idx_t BatchInsertErrorHandler::getNumErrors() const { + return currentInsertIdx; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/copy_rel_batch_insert.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/copy_rel_batch_insert.cpp new file mode 100644 index 0000000000..5fabcde2ac --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/copy_rel_batch_insert.cpp @@ -0,0 +1,92 @@ +#include 
"processor/operator/persistent/copy_rel_batch_insert.h" + +#include "storage/storage_utils.h" +#include "storage/table/csr_chunked_node_group.h" + +namespace lbug { +namespace processor { + +static void setOffsetToWithinNodeGroup(storage::ColumnChunkData& chunk, + common::offset_t startOffset) { + KU_ASSERT(chunk.getDataType().getPhysicalType() == common::PhysicalTypeID::INTERNAL_ID); + const auto offsets = reinterpret_cast(chunk.getData()); + for (auto i = 0u; i < chunk.getNumValues(); i++) { + offsets[i] -= startOffset; + } +} + +std::unique_ptr CopyRelBatchInsert::initExecutionState( + const PartitionerSharedState& partitionerSharedState, const RelBatchInsertInfo& relInfo, + common::node_group_idx_t nodeGroupIdx) { + auto executionState = std::make_unique(); + executionState->partitioningBuffer = + partitionerSharedState.constCast().getPartitionBuffer( + relInfo.partitioningIdx, nodeGroupIdx); + const auto startNodeOffset = storage::StorageUtils::getStartOffsetOfNodeGroup(nodeGroupIdx); + for (auto& chunkedGroup : executionState->partitioningBuffer->getChunkedGroups()) { + setOffsetToWithinNodeGroup(chunkedGroup->getColumnChunk(relInfo.boundNodeOffsetColumnID), + startNodeOffset); + } + return executionState; +} + +void CopyRelBatchInsert::populateCSRLengthsInternal(const storage::InMemChunkedCSRHeader& csrHeader, + common::offset_t numNodes, storage::InMemChunkedNodeGroupCollection& partition, + common::column_id_t boundNodeOffsetColumn) { + KU_ASSERT(numNodes == csrHeader.length->getNumValues() && + numNodes == csrHeader.offset->getNumValues()); + const auto lengthData = reinterpret_cast(csrHeader.length->getData()); + std::fill(lengthData, lengthData + numNodes, 0); + for (auto& chunkedGroup : partition.getChunkedGroups()) { + auto& offsetChunk = chunkedGroup->getColumnChunk(boundNodeOffsetColumn); + for (auto i = 0u; i < offsetChunk.getNumValues(); i++) { + const auto nodeOffset = offsetChunk.getValue(i); + KU_ASSERT(nodeOffset < numNodes); + lengthData[nodeOffset]++; + } + } +} + +void CopyRelBatchInsert::populateCSRLengths(RelBatchInsertExecutionState& executionState, + storage::InMemChunkedCSRHeader& csrHeader, common::offset_t numNodes, + const RelBatchInsertInfo& relInfo) { + auto& copyRelExecutionState = executionState.cast(); + populateCSRLengthsInternal(csrHeader, numNodes, *copyRelExecutionState.partitioningBuffer, + relInfo.boundNodeOffsetColumnID); +} + +void CopyRelBatchInsert::setRowIdxFromCSROffsets(storage::ColumnChunkData& rowIdxChunk, + storage::ColumnChunkData& csrOffsetChunk) { + KU_ASSERT(rowIdxChunk.getDataType().getPhysicalType() == common::PhysicalTypeID::INTERNAL_ID); + for (auto i = 0u; i < rowIdxChunk.getNumValues(); i++) { + const auto nodeOffset = rowIdxChunk.getValue(i); + const auto csrOffset = csrOffsetChunk.getValue(nodeOffset); + rowIdxChunk.setValue(csrOffset, i); + // Increment current csr offset for nodeOffset by 1. + csrOffsetChunk.setValue(csrOffset + 1, nodeOffset); + } +} + +void CopyRelBatchInsert::finalizeStartCSROffsets(RelBatchInsertExecutionState& executionState, + storage::InMemChunkedCSRHeader& csrHeader, const RelBatchInsertInfo& relInfo) { + auto& copyRelExecutionState = executionState.cast(); + for (auto& chunkedGroup : copyRelExecutionState.partitioningBuffer->getChunkedGroups()) { + auto& offsetChunk = chunkedGroup->getColumnChunk(relInfo.boundNodeOffsetColumnID); + // We reuse bound node offset column to store row idx for each rel in the node group. 
+ setRowIdxFromCSROffsets(offsetChunk, *csrHeader.offset); + } +} + +void CopyRelBatchInsert::writeToTable(RelBatchInsertExecutionState& executionState, + const storage::InMemChunkedCSRHeader&, const RelBatchInsertLocalState& localState, + BatchInsertSharedState& sharedState, const RelBatchInsertInfo& relInfo) { + auto& copyRelExecutionState = executionState.cast(); + for (auto& chunkedGroup : copyRelExecutionState.partitioningBuffer->getChunkedGroups()) { + sharedState.incrementNumRows(chunkedGroup->getNumRows()); + // we reused the bound node offset column to store row idx + // the row idx column determines which rows to write each entry in the chunked group to + localState.chunkedGroup->write(*chunkedGroup, relInfo.boundNodeOffsetColumnID); + } +} +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/copy_to.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/copy_to.cpp new file mode 100644 index 0000000000..f304679288 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/copy_to.cpp @@ -0,0 +1,44 @@ +#include "processor/operator/persistent/copy_to.h" + +#include "processor/execution_context.h" + +namespace lbug { +namespace processor { + +std::string CopyToPrintInfo::toString() const { + std::string result = ""; + result += "Export: "; + for (auto& name : columnNames) { + result += name + ", "; + } + result += "To: " + fileName; + return result; +} + +void CopyTo::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) { + localState.exportFuncLocalState = + info.exportFunc.initLocalState(*context->clientContext, *info.bindData, info.isFlatVec); + localState.inputVectors.reserve(info.inputVectorPoses.size()); + for (auto& inputVectorPos : info.inputVectorPoses) { + localState.inputVectors.push_back(resultSet->getValueVector(inputVectorPos)); + } +} + +void CopyTo::initGlobalStateInternal(lbug::processor::ExecutionContext* context) { + sharedState->init(*context->clientContext, *info.bindData); +} + +void CopyTo::finalize(ExecutionContext* /*context*/) { + info.exportFunc.finalize(*sharedState); +} + +void CopyTo::executeInternal(processor::ExecutionContext* context) { + while (children[0]->getNextTuple(context)) { + info.exportFunc.sink(*sharedState, *localState.exportFuncLocalState, *info.bindData, + localState.inputVectors); + } + info.exportFunc.combine(*sharedState, *localState.exportFuncLocalState); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/delete.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/delete.cpp new file mode 100644 index 0000000000..a943a769c4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/delete.cpp @@ -0,0 +1,64 @@ +#include "processor/operator/persistent/delete.h" + +namespace lbug { +namespace processor { + +std::string DeleteNodePrintInfo::toString() const { + std::string result = "Type: "; + switch (deleteType) { + case common::DeleteNodeType::DELETE: + result += "Delete Nodes"; + break; + case common::DeleteNodeType::DETACH_DELETE: + result += "Detach Delete Nodes"; + break; + } + result += ", From: "; + for (const auto& expression : expressions) { + result += expression->toString() + ", "; + } + return result; +} + +std::string DeleteRelPrintInfo::toString() const { + std::string result = "Type: Delete Relationships, From: "; + for (const auto& expression : expressions) { + 
result += expression->toString() + ", "; + } + return result; +} + +void DeleteNode::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) { + for (auto& executor : executors) { + executor->init(resultSet, context); + } +} + +bool DeleteNode::getNextTuplesInternal(ExecutionContext* context) { + if (!children[0]->getNextTuple(context)) { + return false; + } + for (auto& executor : executors) { + executor->delete_(context); + } + return true; +} + +void DeleteRel::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) { + for (auto& executor : executors) { + executor->init(resultSet, context); + } +} + +bool DeleteRel::getNextTuplesInternal(ExecutionContext* context) { + if (!children[0]->getNextTuple(context)) { + return false; + } + for (auto& executor : executors) { + executor->delete_(context); + } + return true; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/delete_executor.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/delete_executor.cpp new file mode 100644 index 0000000000..734c80173e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/delete_executor.cpp @@ -0,0 +1,157 @@ +#include "processor/operator/persistent/delete_executor.h" + +#include + +#include "common/assert.h" +#include "common/exception/message.h" +#include "common/vector/value_vector.h" +#include "processor/execution_context.h" +#include "storage/table/rel_table.h" + +using namespace lbug::common; +using namespace lbug::storage; +using namespace lbug::transaction; + +namespace lbug { +namespace processor { + +void NodeDeleteInfo::init(const ResultSet& resultSet) { + nodeIDVector = resultSet.getValueVector(nodeIDPos).get(); +} + +void NodeTableDeleteInfo::init(const ResultSet& resultSet) { + pkVector = resultSet.getValueVector(pkPos).get(); +} + +static void throwDeleteNodeWithConnectedEdgesError(const std::string& tableName, + offset_t nodeOffset, RelDataDirection direction) { + throw RuntimeException(ExceptionMessage::violateDeleteNodeWithConnectedEdgesConstraint( + tableName, std::to_string(nodeOffset), RelDirectionUtils::relDirectionToString(direction))); +} + +void NodeTableDeleteInfo::deleteFromRelTable(Transaction* transaction, + ValueVector* nodeIDVector) const { + for (auto& relTable : fwdRelTables) { + relTable->throwIfNodeHasRels(transaction, RelDataDirection::FWD, nodeIDVector, + throwDeleteNodeWithConnectedEdgesError); + } + for (auto& relTable : bwdRelTables) { + relTable->throwIfNodeHasRels(transaction, RelDataDirection::BWD, nodeIDVector, + throwDeleteNodeWithConnectedEdgesError); + } +} + +void NodeTableDeleteInfo::detachDeleteFromRelTable(Transaction* transaction, + RelTableDeleteState* detachDeleteState) const { + for (auto& relTable : fwdRelTables) { + detachDeleteState->detachDeleteDirection = RelDataDirection::FWD; + relTable->detachDelete(transaction, detachDeleteState); + } + for (auto& relTable : bwdRelTables) { + detachDeleteState->detachDeleteDirection = RelDataDirection::BWD; + relTable->detachDelete(transaction, detachDeleteState); + } +} + +void NodeDeleteExecutor::init(ResultSet* resultSet, ExecutionContext*) { + info.init(*resultSet); + if (info.deleteType == DeleteNodeType::DETACH_DELETE) { + const auto tempSharedState = std::make_shared(); + dstNodeIDVector = std::make_unique(LogicalType::INTERNAL_ID()); + relIDVector = std::make_unique(LogicalType::INTERNAL_ID()); + dstNodeIDVector->setState(tempSharedState); 
+ relIDVector->setState(tempSharedState); + detachDeleteState = std::make_unique(*info.nodeIDVector, + *dstNodeIDVector, *relIDVector); + } +} + +void SingleLabelNodeDeleteExecutor::init(ResultSet* resultSet, ExecutionContext* context) { + NodeDeleteExecutor::init(resultSet, context); + tableInfo.init(*resultSet); +} + +void SingleLabelNodeDeleteExecutor::delete_(ExecutionContext* context) { + KU_ASSERT(tableInfo.pkVector->state == info.nodeIDVector->state); + auto deleteState = + std::make_unique(*info.nodeIDVector, *tableInfo.pkVector); + auto transaction = Transaction::Get(*context->clientContext); + if (!tableInfo.table->delete_(transaction, *deleteState)) { + return; + } + switch (info.deleteType) { + case DeleteNodeType::DELETE: { + tableInfo.deleteFromRelTable(transaction, info.nodeIDVector); + } break; + case DeleteNodeType::DETACH_DELETE: { + tableInfo.detachDeleteFromRelTable(transaction, detachDeleteState.get()); + } break; + default: + KU_UNREACHABLE; + } +} + +void MultiLabelNodeDeleteExecutor::init(ResultSet* resultSet, ExecutionContext* context) { + NodeDeleteExecutor::init(resultSet, context); + for (auto& [_, tableInfo] : tableInfos) { + tableInfo.init(*resultSet); + } +} + +void MultiLabelNodeDeleteExecutor::delete_(ExecutionContext* context) { + auto& nodeIDSelVector = info.nodeIDVector->state->getSelVector(); + KU_ASSERT(nodeIDSelVector.getSelSize() == 1); + const auto pos = nodeIDSelVector[0]; + if (info.nodeIDVector->isNull(pos)) { + return; + } + const auto nodeID = info.nodeIDVector->getValue(pos); + const auto& tableInfo = tableInfos.at(nodeID.tableID); + auto deleteState = + std::make_unique(*info.nodeIDVector, *tableInfo.pkVector); + auto transaction = Transaction::Get(*context->clientContext); + if (!tableInfo.table->delete_(transaction, *deleteState)) { + return; + } + switch (info.deleteType) { + case DeleteNodeType::DELETE: { + tableInfo.deleteFromRelTable(transaction, info.nodeIDVector); + } break; + case DeleteNodeType::DETACH_DELETE: { + tableInfo.detachDeleteFromRelTable(transaction, detachDeleteState.get()); + } break; + default: + KU_UNREACHABLE; + } +} + +void RelDeleteInfo::init(const ResultSet& resultSet) { + srcNodeIDVector = resultSet.getValueVector(srcNodeIDPos).get(); + dstNodeIDVector = resultSet.getValueVector(dstNodeIDPos).get(); + relIDVector = resultSet.getValueVector(relIDPos).get(); +} + +void RelDeleteExecutor::init(ResultSet* resultSet, ExecutionContext*) { + info.init(*resultSet); +} + +void SingleLabelRelDeleteExecutor::delete_(ExecutionContext* context) { + auto deleteState = std::make_unique(*info.srcNodeIDVector, + *info.dstNodeIDVector, *info.relIDVector); + table->delete_(Transaction::Get(*context->clientContext), *deleteState); +} + +void MultiLabelRelDeleteExecutor::delete_(ExecutionContext* context) { + auto& idSelVector = info.relIDVector->state->getSelVector(); + KU_ASSERT(idSelVector.getSelSize() == 1); + const auto pos = idSelVector[0]; + const auto relID = info.relIDVector->getValue(pos); + KU_ASSERT(tableIDToTableMap.contains(relID.tableID)); + auto table = tableIDToTableMap.at(relID.tableID); + auto deleteState = std::make_unique(*info.srcNodeIDVector, + *info.dstNodeIDVector, *info.relIDVector); + table->delete_(Transaction::Get(*context->clientContext), *deleteState); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/index_builder.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/index_builder.cpp new file mode 100644 
index 0000000000..35fed7e4e5 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/index_builder.cpp @@ -0,0 +1,210 @@ +#include "processor/operator/persistent/index_builder.h" + +#include + +#include "common/assert.h" +#include "common/cast.h" +#include "common/exception/copy.h" +#include "common/exception/message.h" +#include "common/type_utils.h" +#include "common/types/ku_string.h" +#include "storage/index/hash_index_utils.h" +#include "storage/table/node_table.h" +#include "storage/table/string_chunk_data.h" + +namespace lbug { +namespace processor { + +using namespace lbug::common; +using namespace lbug::storage; + +template +bool IndexBufferWithWarningData::full() const { + return indexBuffer.full() || (warningDataBuffer != nullptr && warningDataBuffer->full()); +} + +template +void IndexBufferWithWarningData::append(T key, common::offset_t value, + OptionalWarningSourceData&& warningData) { + indexBuffer.push_back(std::make_pair(key, value)); + if (warningData.has_value()) { + if (warningDataBuffer == nullptr) { + warningDataBuffer = std::make_unique(); + } + warningDataBuffer->push_back(warningData.value()); + } +} + +IndexBuilderGlobalQueues::IndexBuilderGlobalQueues(transaction::Transaction* transaction, + NodeTable* nodeTable) + : nodeTable(nodeTable), transaction{transaction} { + TypeUtils::visit( + pkTypeID(), [&](ku_string_t) { queues.emplace>(); }, + [&](T) { queues.emplace>(); }, [](auto) { KU_UNREACHABLE; }); +} + +PhysicalTypeID IndexBuilderGlobalQueues::pkTypeID() const { + return nodeTable->getPKIndex()->keyTypeID(); +} + +void IndexBuilderGlobalQueues::consume(NodeBatchInsertErrorHandler& errorHandler) { + for (auto index = 0u; index < NUM_HASH_INDEXES; index++) { + maybeConsumeIndex(index, errorHandler); + } +} + +void IndexBuilderGlobalQueues::maybeConsumeIndex(size_t index, + NodeBatchInsertErrorHandler& errorHandler) { + auto& pkIndex = nodeTable->getPKIndex()->cast(); + if (!pkIndex.tryLockTypedIndex(index)) { + return; + } + + std::visit( + [&](auto&& queues) { + using T = std::decay_t; + auto lck = pkIndex.adoptLockOfTypedIndex(index); + IndexBufferWithWarningData bufferWithWarningData; + while (queues.array[index].pop(bufferWithWarningData)) { + auto& buffer = bufferWithWarningData.indexBuffer; + auto& warningDataBuffer = bufferWithWarningData.warningDataBuffer; + uint64_t insertBufferOffset = 0; + while (insertBufferOffset < buffer.size()) { + auto numValuesInserted = pkIndex.appendWithIndexPosNoLock(transaction, buffer, + insertBufferOffset, index, + [&](offset_t offset) { return nodeTable->isVisible(transaction, offset); }); + if (numValuesInserted < buffer.size() - insertBufferOffset) { + const auto& erroneousEntry = buffer[insertBufferOffset + numValuesInserted]; + OptionalWarningSourceData erroneousEntryWarningData; + if (warningDataBuffer != nullptr) { + erroneousEntryWarningData = + (*warningDataBuffer)[insertBufferOffset + numValuesInserted]; + } + errorHandler.handleError( + IndexBuilderError{.message = ExceptionMessage::duplicatePKException( + TypeUtils::toString(erroneousEntry.first)), + .key = erroneousEntry.first, + .nodeID = + nodeID_t{ + erroneousEntry.second, + nodeTable->getTableID(), + }, + .warningData = erroneousEntryWarningData}); + insertBufferOffset += 1; // skip the erroneous index then continue + } + insertBufferOffset += numValuesInserted; + } + } + return; + }, + std::move(queues)); +} + +IndexBuilderLocalBuffers::IndexBuilderLocalBuffers(IndexBuilderGlobalQueues& globalQueues) + : 
globalQueues(&globalQueues) { + TypeUtils::visit( + globalQueues.pkTypeID(), + [&](ku_string_t) { buffers = std::make_unique>(); }, + [&](T) { buffers = std::make_unique>(); }, + [](auto) { KU_UNREACHABLE; }); +} + +void IndexBuilderLocalBuffers::flush(NodeBatchInsertErrorHandler& errorHandler) { + std::visit( + [&](auto&& buffers) { + for (auto i = 0u; i < buffers->size(); i++) { + globalQueues->insert(i, std::move((*buffers)[i]), errorHandler); + } + }, + buffers); +} + +IndexBuilder::IndexBuilder(std::shared_ptr sharedState) + : sharedState(std::move(sharedState)), localBuffers(this->sharedState->globalQueues) {} + +void IndexBuilderSharedState::quitProducer() { + if (producers.fetch_sub(1, std::memory_order_relaxed) == 1) { + done.store(true, std::memory_order_relaxed); + } +} + +static OptionalWarningSourceData getWarningDataFromChunks( + const std::vector& chunks, common::idx_t posInChunk) { + OptionalWarningSourceData ret; + if (!chunks.empty()) { + ret = WarningSourceData::constructFromData(chunks, posInChunk); + } + return ret; +} + +void IndexBuilder::insert(const ColumnChunkData& chunk, + const std::vector& warningData, offset_t nodeOffset, offset_t numNodes, + NodeBatchInsertErrorHandler& errorHandler) { + TypeUtils::visit( + chunk.getDataType().getPhysicalType(), + [&](T) { + for (auto i = 0u; i < numNodes; i++) { + if (checkNonNullConstraint(chunk, warningData, nodeOffset, i, errorHandler)) { + auto value = chunk.getValue(i); + localBuffers.insert(value, nodeOffset + i, + getWarningDataFromChunks(warningData, i), errorHandler); + } + } + }, + [&](ku_string_t) { + auto& stringColumnChunk = ku_dynamic_cast(chunk); + for (auto i = 0u; i < numNodes; i++) { + if (checkNonNullConstraint(chunk, warningData, nodeOffset, i, errorHandler)) { + auto value = stringColumnChunk.getValue(i); + localBuffers.insert(std::move(value), nodeOffset + i, + getWarningDataFromChunks(warningData, i), errorHandler); + } + } + }, + [&](auto) { + throw CopyException(ExceptionMessage::invalidPKType(chunk.getDataType().toString())); + }); +} + +void IndexBuilder::finishedProducing(NodeBatchInsertErrorHandler& errorHandler) { + localBuffers.flush(errorHandler); + sharedState->consume(errorHandler); + while (!sharedState->isDone()) { + std::this_thread::sleep_for(std::chrono::microseconds(500)); + sharedState->consume(errorHandler); + } +} + +void IndexBuilder::finalize(ExecutionContext* /*context*/, + NodeBatchInsertErrorHandler& errorHandler) { + // Flush anything added by last node group. 
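// Sketch of the flush/consume hand-off used by the index builder: producers
// flush thread-local buffers into per-shard global queues, and whichever
// thread wins a shard's try-lock drains it, mirroring the try-lock pattern of
// maybeConsumeIndex above. Class and constant names are illustrative
// stand-ins, not part of this codebase.
#include <array>
#include <cstdint>
#include <deque>
#include <mutex>
#include <string>
#include <utility>
#include <vector>

class ShardedPkQueues {
public:
    static constexpr size_t NUM_SHARDS = 8; // stands in for NUM_HASH_INDEXES

    void push(size_t shard, std::vector<std::pair<std::string, uint64_t>> buffer) {
        std::lock_guard lck{shards[shard].mtx};
        shards[shard].pending.push_back(std::move(buffer));
    }

    // Drain every shard whose lock can be taken without blocking; shards held
    // by another thread are skipped and picked up on a later call.
    template<typename InsertFn>
    void consume(InsertFn&& insert) {
        for (auto& shard : shards) {
            std::unique_lock lck{shard.mtx, std::try_to_lock};
            if (!lck.owns_lock()) {
                continue;
            }
            while (!shard.pending.empty()) {
                for (auto& [key, offset] : shard.pending.front()) {
                    insert(key, offset);
                }
                shard.pending.pop_front();
            }
        }
    }

private:
    struct Shard {
        std::mutex mtx;
        std::deque<std::vector<std::pair<std::string, uint64_t>>> pending;
    };
    std::array<Shard, NUM_SHARDS> shards;
};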
+ localBuffers.flush(errorHandler); + + sharedState->consume(errorHandler); +} + +bool IndexBuilder::checkNonNullConstraint(const ColumnChunkData& chunk, + const std::vector& warningData, offset_t nodeOffset, + offset_t chunkOffset, NodeBatchInsertErrorHandler& errorHandler) { + const auto* nullChunk = chunk.getNullData(); + if (nullChunk->isNull(chunkOffset)) { + TypeUtils::visit( + chunk.getDataType().getPhysicalType(), + [&](struct_entry_t) { + // primary key cannot be struct + KU_UNREACHABLE; + }, + [&](T) { + errorHandler.handleError({.message = ExceptionMessage::nullPKException(), + .key = {}, + .nodeID = + nodeID_t{nodeOffset + chunkOffset, sharedState->nodeTable->getTableID()}, + .warningData = getWarningDataFromChunks(warningData, chunkOffset)}); + }); + return false; + } + return true; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/insert.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/insert.cpp new file mode 100644 index 0000000000..2e457efe9c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/insert.cpp @@ -0,0 +1,42 @@ +#include "processor/operator/persistent/insert.h" + +#include "binder/expression/expression_util.h" + +using namespace lbug::common; +using namespace lbug::storage; + +namespace lbug { +namespace processor { + +std::string InsertPrintInfo::toString() const { + std::string result = "Expressions: "; + result += binder::ExpressionUtil::toString(expressions); + result += ", Action: "; + result += ConflictActionUtil::toString(action); + return result; +} + +void Insert::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) { + for (auto& executor : nodeExecutors) { + executor.init(resultSet, context); + } + for (auto& executor : relExecutors) { + executor.init(resultSet, context); + } +} + +bool Insert::getNextTuplesInternal(ExecutionContext* context) { + if (!children[0]->getNextTuple(context)) { + return false; + } + for (auto& executor : nodeExecutors) { + executor.insert(context->clientContext); + } + for (auto& executor : relExecutors) { + executor.insert(context->clientContext); + } + return true; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/insert_executor.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/insert_executor.cpp new file mode 100644 index 0000000000..5caad25d9d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/insert_executor.cpp @@ -0,0 +1,193 @@ +#include "processor/operator/persistent/insert_executor.h" + +#include "transaction/transaction.h" + +using namespace lbug::common; +using namespace lbug::transaction; + +namespace lbug { +namespace processor { + +void NodeInsertInfo::init(const ResultSet& resultSet) { + nodeIDVector = resultSet.getValueVector(nodeIDPos).get(); + for (auto& pos : columnsPos) { + if (pos.isValid()) { + columnVectors.push_back(resultSet.getValueVector(pos).get()); + } else { + columnVectors.push_back(nullptr); + } + } +} + +void NodeInsertInfo::updateNodeID(nodeID_t nodeID) const { + KU_ASSERT(nodeIDVector->state->getSelVector().getSelSize() == 1); + auto pos = nodeIDVector->state->getSelVector()[0]; + nodeIDVector->setNull(pos, false); + nodeIDVector->setValue(pos, nodeID); +} + +nodeID_t NodeInsertInfo::getNodeID() const { + auto& nodeIDSelVector = nodeIDVector->state->getSelVector(); + 
KU_ASSERT(nodeIDSelVector.getSelSize() == 1); + if (nodeIDVector->isNull(nodeIDSelVector[0])) { + return {INVALID_OFFSET, INVALID_TABLE_ID}; + } + return nodeIDVector->getValue(nodeIDSelVector[0]); +} + +void NodeTableInsertInfo::init(const ResultSet& resultSet, main::ClientContext* context) { + for (auto& evaluator : columnDataEvaluators) { + evaluator->init(resultSet, context); + columnDataVectors.push_back(evaluator->resultVector.get()); + } + pkVector = columnDataVectors[table->getPKColumnID()]; +} + +void NodeInsertExecutor::init(ResultSet* resultSet, const ExecutionContext* context) { + info.init(*resultSet); + tableInfo.init(*resultSet, context->clientContext); +} + +static void writeColumnVector(ValueVector* columnVector, const ValueVector* dataVector) { + auto& columnSelVector = columnVector->state->getSelVector(); + auto& dataSelVector = dataVector->state->getSelVector(); + KU_ASSERT(columnSelVector.getSelSize() == 1 && dataSelVector.getSelSize() == 1); + auto columnPos = columnSelVector[0]; + auto dataPos = dataSelVector[0]; + if (dataVector->isNull(dataPos)) { + columnVector->setNull(columnPos, true); + } else { + columnVector->setNull(columnPos, false); + columnVector->copyFromVectorData(columnPos, dataVector, dataPos); + } +} + +// TODO(Guodong/Xiyang): think we can reference data vector instead of copy. +static void writeColumnVectors(const std::vector& columnVectors, + const std::vector& dataVectors) { + KU_ASSERT(columnVectors.size() == dataVectors.size()); + for (auto i = 0u; i < columnVectors.size(); ++i) { + if (columnVectors[i] == nullptr) { // No need to project + continue; + } + writeColumnVector(columnVectors[i], dataVectors[i]); + } +} + +static void writeColumnVectorsToNull(const std::vector& columnVectors) { + for (auto i = 0u; i < columnVectors.size(); ++i) { + auto columnVector = columnVectors[i]; + if (columnVector == nullptr) { // No need to project + continue; + } + auto& columnSelVector = columnVector->state->getSelVector(); + KU_ASSERT(columnSelVector.getSelSize() == 1); + columnVector->setNull(columnSelVector[0], true); + } +} + +void NodeInsertExecutor::setNodeIDVectorToNonNull() const { + info.nodeIDVector->setNull(info.nodeIDVector->state->getSelVector()[0], false); +} + +nodeID_t NodeInsertExecutor::insert(main::ClientContext* context) { + for (auto& evaluator : tableInfo.columnDataEvaluators) { + evaluator->evaluate(); + } + auto transaction = Transaction::Get(*context); + if (checkConflict(transaction)) { + return info.getNodeID(); + } + auto insertState = std::make_unique(*info.nodeIDVector, + *tableInfo.pkVector, tableInfo.columnDataVectors); + tableInfo.table->initInsertState(context, *insertState); + tableInfo.table->insert(transaction, *insertState); + writeColumnVectors(info.columnVectors, tableInfo.columnDataVectors); + return info.getNodeID(); +} + +void NodeInsertExecutor::skipInsert() const { + for (auto& evaluator : tableInfo.columnDataEvaluators) { + evaluator->evaluate(); + } + info.nodeIDVector->setNull(info.nodeIDVector->state->getSelVector()[0], false); + writeColumnVectors(info.columnVectors, tableInfo.columnDataVectors); +} + +bool NodeInsertExecutor::checkConflict(const Transaction* transaction) const { + if (info.conflictAction == ConflictAction::ON_CONFLICT_DO_NOTHING) { + auto offset = + tableInfo.table->validateUniquenessConstraint(transaction, tableInfo.columnDataVectors); + if (offset != INVALID_OFFSET) { + // Conflict. Skip insertion. 
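// Sketch of the ON_CONFLICT_DO_NOTHING path taken here: probe the primary-key
// index first and, on a hit, reuse the existing node offset (the updateNodeID
// call below) instead of inserting a duplicate. A plain std::unordered_map
// stands in for the PK index; the struct and method names are illustrative.
#include <cstdint>
#include <string>
#include <unordered_map>

struct TinyNodeTable {
    std::unordered_map<std::string, uint64_t> pkIndex;
    uint64_t nextOffset = 0;

    // Returns the node offset for `pk`; inserts only when the key is absent.
    uint64_t insertDoNothingOnConflict(const std::string& pk) {
        if (auto it = pkIndex.find(pk); it != pkIndex.end()) {
            return it->second; // conflict: skip insertion, keep existing node
        }
        const auto offset = nextOffset++;
        pkIndex.emplace(pk, offset);
        return offset;
    }
};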
+ info.updateNodeID({offset, tableInfo.table->getTableID()}); + return true; + } + } + return false; +} + +void RelInsertInfo::init(const ResultSet& resultSet) { + srcNodeIDVector = resultSet.getValueVector(srcNodeIDPos).get(); + dstNodeIDVector = resultSet.getValueVector(dstNodeIDPos).get(); + for (auto& pos : columnsPos) { + if (pos.isValid()) { + columnVectors.push_back(resultSet.getValueVector(pos).get()); + } else { + columnVectors.push_back(nullptr); + } + } +} + +void RelTableInsertInfo::init(const ResultSet& resultSet, main::ClientContext* context) { + for (auto& evaluator : columnDataEvaluators) { + evaluator->init(resultSet, context); + columnDataVectors.push_back(evaluator->resultVector.get()); + } +} + +internalID_t RelTableInsertInfo::getRelID() const { + auto relIDVector = columnDataVectors[0]; + auto& nodeIDSelVector = relIDVector->state->getSelVector(); + KU_ASSERT(nodeIDSelVector.getSelSize() == 1); + if (relIDVector->isNull(nodeIDSelVector[0])) { + return {INVALID_OFFSET, INVALID_TABLE_ID}; + } + return relIDVector->getValue(nodeIDSelVector[0]); +} + +void RelInsertExecutor::init(ResultSet* resultSet, const ExecutionContext* context) { + info.init(*resultSet); + tableInfo.init(*resultSet, context->clientContext); +} + +internalID_t RelInsertExecutor::insert(main::ClientContext* context) { + KU_ASSERT(info.srcNodeIDVector->state->getSelVector().getSelSize() == 1); + KU_ASSERT(info.dstNodeIDVector->state->getSelVector().getSelSize() == 1); + auto srcNodeIDPos = info.srcNodeIDVector->state->getSelVector()[0]; + auto dstNodeIDPos = info.dstNodeIDVector->state->getSelVector()[0]; + if (info.srcNodeIDVector->isNull(srcNodeIDPos) || info.dstNodeIDVector->isNull(dstNodeIDPos)) { + // No need to insert. + writeColumnVectorsToNull(info.columnVectors); + return tableInfo.getRelID(); + } + for (auto i = 1u; i < tableInfo.columnDataEvaluators.size(); ++i) { + tableInfo.columnDataEvaluators[i]->evaluate(); + } + auto insertState = std::make_unique(*info.srcNodeIDVector, + *info.dstNodeIDVector, tableInfo.columnDataVectors); + tableInfo.table->initInsertState(context, *insertState); + tableInfo.table->insert(Transaction::Get(*context), *insertState); + writeColumnVectors(info.columnVectors, tableInfo.columnDataVectors); + return tableInfo.getRelID(); +} + +void RelInsertExecutor::skipInsert() const { + for (auto i = 1u; i < tableInfo.columnDataEvaluators.size(); ++i) { + tableInfo.columnDataEvaluators[i]->evaluate(); + } + writeColumnVectors(info.columnVectors, tableInfo.columnDataVectors); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/merge.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/merge.cpp new file mode 100644 index 0000000000..605b9007b7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/merge.cpp @@ -0,0 +1,143 @@ +#include "processor/operator/persistent/merge.h" + +#include "binder/expression/expression_util.h" +#include "main/client_context.h" + +namespace lbug { +namespace processor { + +std::string MergePrintInfo::toString() const { + std::string result = "Pattern: "; + result += binder::ExpressionUtil::toString(pattern); + if (!onMatch.empty()) { + result += ", ON MATCH SET: " + binder::ExpressionUtil::toString(onMatch); + } + if (!onCreate.empty()) { + result += ", ON CREATE SET: " + binder::ExpressionUtil::toString(onCreate); + } + return result; +} + +void Merge::initLocalStateInternal(ResultSet* resultSet_, ExecutionContext* 
context) { + for (auto& executor : nodeInsertExecutors) { + executor.init(resultSet, context); + } + for (auto& executor : relInsertExecutors) { + executor.init(resultSet, context); + } + for (auto& executor : onCreateNodeSetExecutors) { + executor->init(resultSet, context); + } + for (auto& executor : onCreateRelSetExecutors) { + executor->init(resultSet, context); + } + for (auto& executor : onMatchNodeSetExecutors) { + executor->init(resultSet, context); + } + for (auto& executor : onMatchRelSetExecutors) { + executor->init(resultSet, context); + } + for (auto& evaluator : info.keyEvaluators) { + evaluator->init(*resultSet_, context->clientContext); + } + localState.init(*resultSet, context->clientContext, info); +} + +void MergeLocalState::init(ResultSet& resultSet, main::ClientContext* context, MergeInfo& info) { + std::vector types; + for (auto& evaluator : info.keyEvaluators) { + auto keyVector = evaluator->resultVector.get(); + types.push_back(keyVector->dataType.copy()); + keyVectors.push_back(keyVector); + } + // TODO: remove types + hashTable = std::make_unique(*storage::MemoryManager::Get(*context), + std::move(types), std::move(info.tableSchema)); + existenceVector = resultSet.getValueVector(info.existenceMark).get(); +} + +bool MergeLocalState::patternExists() const { + KU_ASSERT(existenceVector->state->getSelVector().getSelSize() == 1); + auto pos = existenceVector->state->getSelVector()[0]; + return existenceVector->getValue(pos); +} + +void Merge::executeOnMatch(ExecutionContext* context) { + for (auto& executor : onMatchNodeSetExecutors) { + executor->set(context); + } + for (auto& executor : onMatchRelSetExecutors) { + executor->set(context); + } +} + +void Merge::executeOnCreatedPattern(PatternCreationInfo& patternCreationInfo, + ExecutionContext* context) { + for (auto& executor : nodeInsertExecutors) { + executor.skipInsert(); + } + for (auto& executor : relInsertExecutors) { + executor.skipInsert(); + } + for (auto i = 0u; i < onMatchNodeSetExecutors.size(); i++) { + auto& executor = onMatchNodeSetExecutors[i]; + auto nodeIDToSet = patternCreationInfo.getPatternID(i); + executor->setNodeID(nodeIDToSet); + executor->set(context); + } + for (auto i = 0u; i < onMatchRelSetExecutors.size(); i++) { + auto& executor = onMatchRelSetExecutors[i]; + auto relIDToSet = patternCreationInfo.getPatternID(i + onMatchNodeSetExecutors.size()); + executor->setRelID(relIDToSet); + executor->set(context); + } +} + +void Merge::executeOnNewPattern(PatternCreationInfo& patternCreationInfo, + ExecutionContext* context) { + // do insert and on create + for (auto i = 0u; i < nodeInsertExecutors.size(); i++) { + auto& executor = nodeInsertExecutors[i]; + executor.setNodeIDVectorToNonNull(); + auto nodeID = executor.insert(context->clientContext); + patternCreationInfo.updateID(i, info.executorInfo, nodeID); + } + for (auto i = 0u; i < relInsertExecutors.size(); i++) { + auto& executor = relInsertExecutors[i]; + auto relID = executor.insert(context->clientContext); + patternCreationInfo.updateID(i + nodeInsertExecutors.size(), info.executorInfo, relID); + } + for (auto& executor : onCreateNodeSetExecutors) { + executor->set(context); + } + for (auto& executor : onCreateRelSetExecutors) { + executor->set(context); + } +} + +void Merge::executeNoMatch(ExecutionContext* context) { + for (auto& evaluator : info.keyEvaluators) { + evaluator->evaluate(); + } + auto patternCreationInfo = localState.getPatternCreationInfo(); + if (patternCreationInfo.hasCreated) { + 
executeOnCreatedPattern(patternCreationInfo, context); + } else { + executeOnNewPattern(patternCreationInfo, context); + } +} + +bool Merge::getNextTuplesInternal(ExecutionContext* context) { + if (!children[0]->getNextTuple(context)) { + return false; + } + if (localState.patternExists()) { + executeOnMatch(context); + } else { + executeNoMatch(context); + } + return true; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/node_batch_insert.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/node_batch_insert.cpp new file mode 100644 index 0000000000..31fb3c9841 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/node_batch_insert.cpp @@ -0,0 +1,313 @@ +#include "processor/operator/persistent/node_batch_insert.h" + +#include "catalog/catalog_entry/node_table_catalog_entry.h" +#include "common/cast.h" +#include "common/finally_wrapper.h" +#include "common/string_format.h" +#include "processor/execution_context.h" +#include "processor/operator/persistent/index_builder.h" +#include "processor/result/factorized_table_util.h" +#include "processor/warning_context.h" +#include "storage/buffer_manager/memory_manager.h" +#include "storage/local_storage/local_storage.h" +#include "storage/storage_manager.h" +#include "storage/table/chunked_node_group.h" +#include "storage/table/node_table.h" +#include "transaction/transaction.h" + +using namespace lbug::catalog; +using namespace lbug::common; +using namespace lbug::storage; +using namespace lbug::transaction; + +namespace lbug { +namespace processor { + +std::string NodeBatchInsertPrintInfo::toString() const { + std::string result = "Table Name: "; + result += tableName; + return result; +} + +void NodeBatchInsertSharedState::initPKIndex(const ExecutionContext* context) { + uint64_t numRows = 0; + if (tableFuncSharedState != nullptr) { + numRows = tableFuncSharedState->getNumRows(); + } + auto* nodeTable = ku_dynamic_cast(table); + nodeTable->getPKIndex()->bulkReserve(numRows); + globalIndexBuilder = IndexBuilder(std::make_shared( + Transaction::Get(*context->clientContext), nodeTable)); +} + +void NodeBatchInsert::initGlobalStateInternal(ExecutionContext* context) { + auto clientContext = context->clientContext; + auto catalog = Catalog::Get(*clientContext); + auto transaction = Transaction::Get(*clientContext); + auto nodeTableEntry = catalog->getTableCatalogEntry(transaction, info->tableName) + ->ptrCast(); + auto nodeTable = StorageManager::Get(*clientContext)->getTable(nodeTableEntry->getTableID()); + const auto& pkDefinition = nodeTableEntry->getPrimaryKeyDefinition(); + auto pkColumnID = nodeTableEntry->getColumnID(pkDefinition.getName()); + // Init info + info->compressionEnabled = StorageManager::Get(*clientContext)->compressionEnabled(); + auto dataColumnIdx = 0u; + for (auto& property : nodeTableEntry->getProperties()) { + info->columnTypes.push_back(property.getType().copy()); + info->insertColumnIDs.push_back(nodeTableEntry->getColumnID(property.getName())); + info->outputDataColumns.push_back(dataColumnIdx++); + } + for (auto& type : info->warningColumnTypes) { + info->columnTypes.push_back(type.copy()); + info->warningDataColumns.push_back(dataColumnIdx++); + } + // Init shared state + auto nodeSharedState = sharedState->ptrCast(); + nodeSharedState->table = nodeTable; + nodeSharedState->pkColumnID = pkColumnID; + nodeSharedState->pkType = pkDefinition.getType().copy(); + 
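// The initPKIndex call below reserves capacity in the primary-key hash index
// for the source row count (bulkReserve) before threads start inserting. A
// minimal sketch of the idea, with std::unordered_map standing in for the
// hash index; names here are illustrative only.
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

static void bulkLoadKeys(std::unordered_map<std::string, uint64_t>& pkIndex,
    const std::vector<std::string>& keys) {
    // Reserving once avoids repeated rehashing while the bulk load appends
    // one entry per source row.
    pkIndex.reserve(keys.size());
    for (uint64_t offset = 0; offset < keys.size(); ++offset) {
        pkIndex.emplace(keys[offset], offset);
    }
}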
nodeSharedState->initPKIndex(context); +} + +void NodeBatchInsert::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) { + const auto nodeInfo = info->ptrCast(); + const auto numColumns = nodeInfo->columnEvaluators.size(); + + const auto nodeSharedState = ku_dynamic_cast(sharedState.get()); + localState = std::make_unique( + std::span{nodeInfo->columnTypes.begin(), nodeInfo->outputDataColumns.size()}); + const auto nodeLocalState = localState->ptrCast(); + KU_ASSERT(nodeSharedState->globalIndexBuilder); + nodeLocalState->localIndexBuilder = nodeSharedState->globalIndexBuilder->clone(); + nodeLocalState->errorHandler = createErrorHandler(context); + nodeLocalState->optimisticAllocator = + Transaction::Get(*context->clientContext)->getLocalStorage()->addOptimisticAllocator(); + + nodeLocalState->columnVectors.resize(numColumns); + + for (auto i = 0u; i < numColumns; ++i) { + auto& evaluator = nodeInfo->columnEvaluators[i]; + evaluator->init(*resultSet, context->clientContext); + nodeLocalState->columnVectors[i] = evaluator->resultVector.get(); + } + nodeLocalState->chunkedGroup = + std::make_unique(*MemoryManager::Get(*context->clientContext), + nodeInfo->columnTypes, info->compressionEnabled, StorageConfig::NODE_GROUP_SIZE, 0); + KU_ASSERT(resultSet->dataChunks[0]); + nodeLocalState->columnState = resultSet->dataChunks[0]->state; +} + +void NodeBatchInsert::executeInternal(ExecutionContext* context) { + const auto clientContext = context->clientContext; + std::optional token; + auto nodeLocalState = localState->ptrCast(); + if (nodeLocalState->localIndexBuilder) { + token = nodeLocalState->localIndexBuilder->getProducerToken(); + } + auto transaction = Transaction::Get(*clientContext); + while (children[0]->getNextTuple(context)) { + const auto originalSelVector = nodeLocalState->columnState->getSelVectorShared(); + // Evaluate expressions if needed. 
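// Sketch of the fill-and-flush loop that copyToNodeGroup (further below)
// implements: append as many tuples as fit into the in-memory chunked group
// and flush it whenever it reaches capacity. Container and constant names are
// illustrative, not from this file.
#include <algorithm>
#include <cstdint>
#include <vector>

struct FixedCapacityGroup {
    static constexpr uint64_t CAPACITY = 2048; // stands in for NODE_GROUP_SIZE
    std::vector<int64_t> rows;

    // Appends up to `count` values starting at `from`; returns how many fit.
    uint64_t append(const std::vector<int64_t>& input, uint64_t from, uint64_t count) {
        const auto toCopy = std::min<uint64_t>(count, CAPACITY - rows.size());
        rows.insert(rows.end(), input.begin() + from, input.begin() + from + toCopy);
        return toCopy;
    }
    bool isFull() const { return rows.size() == CAPACITY; }
    void reset() { rows.clear(); }
};

template<typename FlushFn>
void copyBatchToGroup(const std::vector<int64_t>& batch, FixedCapacityGroup& group,
    FlushFn&& flush) {
    uint64_t appended = 0;
    while (appended < batch.size()) {
        appended += group.append(batch, appended, batch.size() - appended);
        if (group.isFull()) {
            flush(group); // corresponds to writeAndResetNodeGroup
            group.reset();
        }
    }
}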
+ const auto numTuples = nodeLocalState->columnState->getSelVector().getSelSize(); + evaluateExpressions(numTuples); + copyToNodeGroup(transaction, MemoryManager::Get(*clientContext)), + nodeLocalState->columnState->setSelVector(originalSelVector); + } + if (nodeLocalState->chunkedGroup->getNumRows() > 0) { + appendIncompleteNodeGroup(transaction, std::move(nodeLocalState->chunkedGroup), + nodeLocalState->localIndexBuilder, MemoryManager::Get(*context->clientContext)); + } + if (nodeLocalState->localIndexBuilder) { + KU_ASSERT(token); + token->quit(); + + KU_ASSERT(nodeLocalState->errorHandler.has_value()); + nodeLocalState->localIndexBuilder->finishedProducing(nodeLocalState->errorHandler.value()); + nodeLocalState->errorHandler->flushStoredErrors(); + } + const auto nodeInfo = info->ptrCast(); + sharedState->table->cast().mergeStats(nodeInfo->insertColumnIDs, + nodeLocalState->stats); +} + +void NodeBatchInsert::evaluateExpressions(uint64_t numTuples) const { + const auto nodeInfo = info->ptrCast(); + for (auto i = 0u; i < nodeInfo->evaluateTypes.size(); ++i) { + switch (nodeInfo->evaluateTypes[i]) { + case ColumnEvaluateType::DEFAULT: { + nodeInfo->columnEvaluators[i]->evaluate(numTuples); + } break; + case ColumnEvaluateType::CAST: { + nodeInfo->columnEvaluators[i]->evaluate(); + } break; + default: + break; + } + } +} + +void NodeBatchInsert::copyToNodeGroup(transaction::Transaction* transaction, + MemoryManager* mm) const { + auto numAppendedTuples = 0ul; + const auto nodeLocalState = ku_dynamic_cast(localState.get()); + const auto numTuplesToAppend = nodeLocalState->columnState->getSelVector().getSelSize(); + while (numAppendedTuples < numTuplesToAppend) { + const auto numAppendedTuplesInNodeGroup = + nodeLocalState->chunkedGroup->append(nodeLocalState->columnVectors, numAppendedTuples, + numTuplesToAppend - numAppendedTuples); + numAppendedTuples += numAppendedTuplesInNodeGroup; + if (nodeLocalState->chunkedGroup->isFull()) { + writeAndResetNodeGroup(transaction, nodeLocalState->chunkedGroup, + nodeLocalState->localIndexBuilder, mm, *nodeLocalState->optimisticAllocator); + } + } + const auto nodeInfo = info->ptrCast(); + nodeLocalState->stats.update(nodeLocalState->columnVectors, nodeInfo->outputDataColumns.size()); + sharedState->incrementNumRows(numAppendedTuples); +} + +NodeBatchInsertErrorHandler NodeBatchInsert::createErrorHandler(ExecutionContext* context) const { + const auto nodeSharedState = ku_dynamic_cast(sharedState.get()); + auto* nodeTable = ku_dynamic_cast(sharedState->table); + return NodeBatchInsertErrorHandler{context, nodeSharedState->pkType.getLogicalTypeID(), + nodeTable, WarningContext::Get(*context->clientContext)->getIgnoreErrorsOption(), + sharedState->numErroredRows, &sharedState->erroredRowMutex}; +} + +void NodeBatchInsert::clearToIndex(MemoryManager* mm, + std::unique_ptr& nodeGroup, offset_t startIndexInGroup) const { + // Create a new chunked node group and move the unwritten values to it + // TODO(bmwinger): Can probably re-use the chunk and shift the values + const auto oldNodeGroup = std::move(nodeGroup); + const auto nodeInfo = info->ptrCast(); + nodeGroup = std::make_unique(*mm, nodeInfo->columnTypes, + nodeInfo->compressionEnabled, StorageConfig::NODE_GROUP_SIZE, 0); + nodeGroup->append(*oldNodeGroup, startIndexInGroup, + oldNodeGroup->getNumRows() - startIndexInGroup); +} + +void NodeBatchInsert::writeAndResetNodeGroup(transaction::Transaction* transaction, + std::unique_ptr& nodeGroup, std::optional& indexBuilder, + MemoryManager* mm, 
PageAllocator& pageAllocator) const { + const auto nodeLocalState = localState->ptrCast(); + KU_ASSERT(nodeLocalState->errorHandler.has_value()); + writeAndResetNodeGroup(transaction, nodeGroup, indexBuilder, mm, + nodeLocalState->errorHandler.value(), pageAllocator); +} + +void NodeBatchInsert::writeAndResetNodeGroup(transaction::Transaction* transaction, + std::unique_ptr& nodeGroup, std::optional& indexBuilder, + MemoryManager* mm, NodeBatchInsertErrorHandler& errorHandler, + PageAllocator& pageAllocator) const { + const auto nodeSharedState = ku_dynamic_cast(sharedState.get()); + const auto nodeTable = ku_dynamic_cast(sharedState->table); + + uint64_t nodeOffset{}; + uint64_t numRowsWritten{}; + { + // The chunked group in batch insert may contain extra data to populate error messages + // When we append to the table we only want the main data so this class is used to slice the + // original chunked group + // The slice must be restored even if an exception is thrown to prevent other threads from + // reading invalid data + InMemChunkedNodeGroup sliceToWriteToDisk{*nodeGroup, info->outputDataColumns}; + FinallyWrapper sliceRestorer{ + [&]() { nodeGroup->merge(sliceToWriteToDisk, info->outputDataColumns); }}; + std::tie(nodeOffset, numRowsWritten) = nodeTable->appendToLastNodeGroup(transaction, + info->insertColumnIDs, sliceToWriteToDisk, pageAllocator); + } + + if (indexBuilder) { + std::vector warningChunkData; + for (const auto warningDataColumn : info->warningDataColumns) { + warningChunkData.push_back(&nodeGroup->getColumnChunk(warningDataColumn)); + } + indexBuilder->insert(nodeGroup->getColumnChunk(nodeSharedState->pkColumnID), + warningChunkData, nodeOffset, numRowsWritten, errorHandler); + } + if (numRowsWritten == nodeGroup->getNumRows()) { + nodeGroup->resetToEmpty(); + } else { + clearToIndex(mm, nodeGroup, numRowsWritten); + } +} + +void NodeBatchInsert::appendIncompleteNodeGroup(transaction::Transaction* transaction, + std::unique_ptr localNodeGroup, + std::optional& indexBuilder, MemoryManager* mm) const { + std::unique_lock xLck{sharedState->mtx}; + const auto nodeLocalState = ku_dynamic_cast(localState.get()); + const auto nodeSharedState = ku_dynamic_cast(sharedState.get()); + if (!nodeSharedState->sharedNodeGroup) { + nodeSharedState->sharedNodeGroup = std::move(localNodeGroup); + return; + } + uint64_t numNodesAppended = 0; + while (numNodesAppended < localNodeGroup->getNumRows()) { + if (nodeSharedState->sharedNodeGroup->isFull()) { + writeAndResetNodeGroup(transaction, nodeSharedState->sharedNodeGroup, indexBuilder, mm, + *nodeLocalState->optimisticAllocator); + } + numNodesAppended += nodeSharedState->sharedNodeGroup->append(*localNodeGroup, + numNodesAppended /* offsetInNodeGroup */, + localNodeGroup->getNumRows() - numNodesAppended); + } + KU_ASSERT(numNodesAppended == localNodeGroup->getNumRows()); +} + +void NodeBatchInsert::finalize(ExecutionContext* context) { + KU_ASSERT(localState == nullptr); + const auto nodeSharedState = ku_dynamic_cast(sharedState.get()); + auto errorHandler = createErrorHandler(context); + auto clientContext = context->clientContext; + auto transaction = Transaction::Get(*clientContext); + auto& pageAllocator = *transaction->getLocalStorage()->addOptimisticAllocator(); + if (nodeSharedState->sharedNodeGroup) { + while (nodeSharedState->sharedNodeGroup->getNumRows() > 0) { + writeAndResetNodeGroup(transaction, nodeSharedState->sharedNodeGroup, + nodeSharedState->globalIndexBuilder, MemoryManager::Get(*clientContext), + errorHandler, 
pageAllocator); + } + } + if (nodeSharedState->globalIndexBuilder) { + nodeSharedState->globalIndexBuilder->finalize(context, errorHandler); + errorHandler.flushStoredErrors(); + } + + auto& nodeTable = nodeSharedState->table->cast(); + for (auto& index : nodeTable.getIndexes()) { + index.finalize(clientContext); + } + // we want to flush all index errors before children call finalize + // as the children (if they are table function calls) are responsible for populating the errors + // and sending it to the warning context + PhysicalOperator::finalize(context); + + // if the child is a subquery it will not send the errors to the warning context + // sends any remaining warnings in this case + // if the child is a table function call it will have already sent the warnings so this line + // will do nothing + WarningContext::Get(*clientContext)->defaultPopulateAllWarnings(context->queryID); +} + +void NodeBatchInsert::finalizeInternal(ExecutionContext* context) { + auto outputMsg = stringFormat("{} tuples have been copied to the {} table.", + sharedState->getNumRows() - sharedState->getNumErroredRows(), info->tableName); + auto clientContext = context->clientContext; + FactorizedTableUtils::appendStringToTable(sharedState->fTable.get(), outputMsg, + MemoryManager::Get(*clientContext)); + + const auto warningCount = + WarningContext::Get(*clientContext)->getWarningCount(context->queryID); + if (warningCount > 0) { + auto warningMsg = + stringFormat("{} warnings encountered during copy. Use 'CALL " + "show_warnings() RETURN *' to view the actual warnings. Query ID: {}", + warningCount, context->queryID); + FactorizedTableUtils::appendStringToTable(sharedState->fTable.get(), warningMsg, + MemoryManager::Get(*clientContext)); + } +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/node_batch_insert_error_handler.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/node_batch_insert_error_handler.cpp new file mode 100644 index 0000000000..b39caf0ec4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/node_batch_insert_error_handler.cpp @@ -0,0 +1,37 @@ +#include "processor/operator/persistent/node_batch_insert_error_handler.h" + +#include "processor/execution_context.h" +#include "storage/table/node_table.h" + +using namespace lbug::common; + +namespace lbug { +namespace processor { + +NodeBatchInsertErrorHandler::NodeBatchInsertErrorHandler(ExecutionContext* context, + LogicalTypeID pkType, storage::NodeTable* nodeTable, bool ignoreErrors, + std::shared_ptr sharedErrorCounter, std::mutex* sharedErrorCounterMtx) + : nodeTable(nodeTable), context(context), + keyVector(std::make_shared(pkType, + storage::MemoryManager::Get(*context->clientContext))), + offsetVector(std::make_shared(LogicalTypeID::INTERNAL_ID, + storage::MemoryManager::Get(*context->clientContext))), + baseErrorHandler(context, ignoreErrors, sharedErrorCounter, sharedErrorCounterMtx) { + keyVector->state = DataChunkState::getSingleValueDataChunkState(); + offsetVector->state = DataChunkState::getSingleValueDataChunkState(); +} + +void NodeBatchInsertErrorHandler::deleteCurrentErroneousRow() { + storage::NodeTableDeleteState deleteState{ + *offsetVector, + *keyVector, + }; + nodeTable->delete_(transaction::Transaction::Get(*context->clientContext), deleteState); +} + +void NodeBatchInsertErrorHandler::flushStoredErrors() { + baseErrorHandler.flushStoredErrors(); +} + +} // namespace processor +} // 
namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/CMakeLists.txt new file mode 100644 index 0000000000..85bac2e6ae --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/CMakeLists.txt @@ -0,0 +1,13 @@ +add_subdirectory(csv) +add_subdirectory(npy) +add_subdirectory(parquet) + +add_library(lbug_processor_operator_persistent_reader + OBJECT + copy_from_error.cpp + file_error_handler.cpp + reader_bind_utils.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/copy_from_error.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/copy_from_error.cpp new file mode 100644 index 0000000000..a7f3981116 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/copy_from_error.cpp @@ -0,0 +1,76 @@ +#include "processor/operator/persistent/reader/copy_from_error.h" + +#include "common/vector/value_vector.h" +#include "storage/table/column_chunk_data.h" + +using namespace lbug::common; + +namespace lbug { +namespace processor { + +template +static PhysicalTypeID getPhysicalTypeFromDataSource(T* data) { + if constexpr (std::is_same_v) { + return data->getDataType().getPhysicalType(); + } else if constexpr (std::is_same_v) { + return data->dataType.getPhysicalType(); + } else { + KU_UNREACHABLE; + } +} + +template +WarningSourceData::DataType getValueFromData(T* data, common::idx_t pos) { + // avoid using TypeUtils::visit here to avoid the overhead from constructing a capturing lambda + switch (getPhysicalTypeFromDataSource(data)) { + case common::PhysicalTypeID::UINT64: + return data->template getValue(pos); + case common::PhysicalTypeID::UINT32: + return data->template getValue(pos); + default: + KU_UNREACHABLE; + } +} + +WarningSourceData::WarningSourceData(uint64_t numValues) : numValues(numValues) { + KU_ASSERT(numValues <= values.size()); +} + +template +WarningSourceData WarningSourceData::constructFromData(const std::vector& data, + common::idx_t pos) { + KU_ASSERT(data.size() >= CopyConstants::SHARED_WARNING_DATA_NUM_COLUMNS && + data.size() <= CopyConstants::MAX_NUM_WARNING_DATA_COLUMNS); + WarningSourceData ret{data.size()}; + for (idx_t i = 0; i < data.size(); ++i) { + ret.values[i] = getValueFromData(data[i], pos); + } + return ret; +} + +uint64_t WarningSourceData::getBlockIdx() const { + return std::get(values[BLOCK_IDX_IDX]); +} +uint32_t WarningSourceData::getOffsetInBlock() const { + return std::get(values[OFFSET_IN_BLOCK_IDX]); +} + +template WarningSourceData WarningSourceData::constructFromData( + const std::vector& data, common::idx_t pos); +template WarningSourceData WarningSourceData::constructFromData( + const std::vector& data, common::idx_t pos); + +CopyFromFileError::CopyFromFileError(std::string message, WarningSourceData warningData, + bool completedLine, bool mustThrow) + : message(std::move(message)), completedLine(completedLine), warningData(warningData), + mustThrow(mustThrow) {} + +bool CopyFromFileError::operator<(const CopyFromFileError& o) const { + if (warningData.getBlockIdx() == o.warningData.getBlockIdx()) { + return warningData.getOffsetInBlock() < o.warningData.getOffsetInBlock(); + } + return warningData.getBlockIdx() < o.warningData.getBlockIdx(); +} + +} // namespace processor +} // namespace lbug diff --git 
a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/csv/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/csv/CMakeLists.txt new file mode 100644 index 0000000000..23592dfd1d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/csv/CMakeLists.txt @@ -0,0 +1,11 @@ +add_library(lbug_processor_operator_csv_reader + OBJECT + base_csv_reader.cpp + driver.cpp + parallel_csv_reader.cpp + serial_csv_reader.cpp + dialect_detection.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/csv/base_csv_reader.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/csv/base_csv_reader.cpp new file mode 100644 index 0000000000..019d878139 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/csv/base_csv_reader.cpp @@ -0,0 +1,588 @@ +#include "processor/operator/persistent/reader/csv/base_csv_reader.h" + +#include + +#include "common/file_system/virtual_file_system.h" +#include "common/string_format.h" +#include "common/string_utils.h" +#include "common/system_message.h" +#include "common/utils.h" +#include "main/client_context.h" +#include "processor/operator/persistent/reader/csv/driver.h" +#include "processor/operator/persistent/reader/file_error_handler.h" +#include "utf8proc_wrapper.h" + +using namespace lbug::common; + +namespace lbug { +namespace processor { + +// TODO(Royi) for performance reasons we may want to reduce the number of fields here since each +// field is essentially an extra column during copy +struct CSVWarningSourceData { + CSVWarningSourceData() = default; + static CSVWarningSourceData constructFrom(const processor::WarningSourceData& warningData); + + uint64_t startByteOffset; + uint64_t endByteOffset; + uint64_t blockIdx; + uint32_t offsetInBlock; + common::idx_t fileIdx; +}; + +CSVWarningSourceData CSVWarningSourceData::constructFrom( + const processor::WarningSourceData& warningData) { + KU_ASSERT(warningData.numValues == CopyConstants::CSV_WARNING_DATA_NUM_COLUMNS); + + CSVWarningSourceData ret{}; + warningData.dumpTo(ret.blockIdx, ret.offsetInBlock, ret.startByteOffset, ret.endByteOffset, + ret.fileIdx); + return ret; +} + +BaseCSVReader::BaseCSVReader(const std::string& filePath, common::idx_t fileIdx, + common::CSVOption option, CSVColumnInfo columnInfo, main::ClientContext* context, + LocalFileErrorHandler* errorHandler) + : context{context}, option{std::move(option)}, columnInfo{std::move(columnInfo)}, + currentBlockIdx(0), numRowsInCurrentBlock(0), curRowIdx(0), numErrors(0), buffer{nullptr}, + bufferIdx(0), bufferSize{0}, position{0}, lineContext(), osFileOffset{0}, fileIdx(fileIdx), + errorHandler(errorHandler), rowEmpty{false} { + fileInfo = VirtualFileSystem::GetUnsafe(*context)->openFile(filePath, + FileOpenFlags(FileFlags::READ_ONLY +#ifdef _WIN32 + | FileFlags::BINARY +#endif + ), + context); +} + +bool BaseCSVReader::isEOF() const { + return getFileOffset() >= fileInfo->getFileSize(); +} + +uint64_t BaseCSVReader::getFileSize() { + return fileInfo->getFileSize(); +} + +template +bool BaseCSVReader::addValue(Driver& driver, uint64_t rowNum, column_id_t columnIdx, + std::string_view strVal, std::vector& escapePositions) { + std::string valueToAdd; + // insert the line number into the chunk + if (!escapePositions.empty()) { + // remove escape characters (if any) + std::string newVal = ""; + uint64_t 
prevPos = 0; + for (auto i = 0u; i < escapePositions.size(); i++) { + auto nextPos = escapePositions[i]; + newVal += strVal.substr(prevPos, nextPos - prevPos); + prevPos = nextPos + 1; + } + newVal += strVal.substr(prevPos, strVal.size() - prevPos); + escapePositions.clear(); + valueToAdd = newVal; + } else { + valueToAdd = strVal; + } + if (!utf8proc::Utf8Proc::isValid(valueToAdd.data(), valueToAdd.length())) { + handleCopyException("Invalid UTF8-encoded string.", true /* mustThrow */); + } + return driver.addValue(rowNum, columnIdx, valueToAdd); +} + +struct SkipRowDriver { + DriverType driverType = DriverType::SKIP_ROW; + explicit SkipRowDriver(uint64_t skipNum) : skipNum{skipNum} {} + bool done(uint64_t rowNum) const { return rowNum >= skipNum; } + bool addRow(uint64_t, column_id_t, std::optional) { return true; } + bool addValue(uint64_t, column_id_t, std::string_view) { return true; } + + uint64_t skipNum; +}; + +BaseCSVReader::parse_result_t BaseCSVReader::handleFirstBlock() { + uint64_t numRowsRead = 0; + uint64_t numErrors = 0; + readBOM(); + if (option.skipNum > 0) { + SkipRowDriver driver{option.skipNum}; + const auto parseResult = parseCSV(driver); + numRowsRead += parseResult.first; + numErrors += parseResult.second; + } + if (option.hasHeader) { + const auto parseResult = readHeader(); + numRowsRead += parseResult.first; + numErrors += parseResult.second; + } + return {numRowsRead, numErrors}; +} + +void BaseCSVReader::readBOM() { + if (!maybeReadBuffer(nullptr)) { + return; + } + if (bufferSize >= 3 && buffer[0] == '\xEF' && buffer[1] == '\xBB' && buffer[2] == '\xBF') { + position = 3; + } +} + +// Dummy driver that just skips a row. +struct HeaderDriver { + DriverType driverType = DriverType::HEADER; + bool done(uint64_t) { return true; } + bool addRow(uint64_t, column_id_t, std::optional) { return true; } + bool addValue(uint64_t, column_id_t, std::string_view) { return true; } +}; + +void BaseCSVReader::resetNumRowsInCurrentBlock() { + numRowsInCurrentBlock = 0; +} + +void BaseCSVReader::increaseNumRowsInCurrentBlock(uint64_t numRows, uint64_t numErrors) { + numRowsInCurrentBlock += numRows + numErrors; +} + +uint64_t BaseCSVReader::getNumRowsInCurrentBlock() const { + return numRowsInCurrentBlock; +} + +uint32_t BaseCSVReader::getRowOffsetInCurrentBlock() const { + return safeIntegerConversion(numRowsInCurrentBlock + curRowIdx + numErrors); +} + +BaseCSVReader::parse_result_t BaseCSVReader::readHeader() { + HeaderDriver driver; + return parseCSV(driver); +} + +bool BaseCSVReader::readBuffer(uint64_t* start) { + std::unique_ptr oldBuffer = std::move(buffer); + + // the remaining part of the last buffer + uint64_t remaining = 0; + if (start != nullptr) { + KU_ASSERT(*start <= bufferSize); + remaining = bufferSize - *start; + } + + uint64_t bufferReadSize = CopyConstants::INITIAL_BUFFER_SIZE; + while (remaining > bufferReadSize) { + bufferReadSize *= 2; + } + + buffer = std::unique_ptr(new char[bufferReadSize + remaining + 1]()); + if (remaining > 0) { + // remaining from last buffer: copy it here + KU_ASSERT(start != nullptr); + memcpy(buffer.get(), oldBuffer.get() + *start, remaining); + } + auto readCount = fileInfo->readFile(buffer.get() + remaining, bufferReadSize); + if (readCount == -1) { + // LCOV_EXCL_START + lineContext.setEndOfLine(getFileOffset()); + handleCopyException(stringFormat("Could not read from file: {}", posixErrMessage()), true); + // LCOV_EXCL_STOP + } + + // Update buffer size in a way so that the invariant osFileOffset >= bufferSize is never broken 
+ // This is needed because in the serial CSV reader the progressFunc can call getFileOffset from + // a different thread + bufferSize = remaining; + osFileOffset += readCount; + bufferSize += readCount; + + buffer[bufferSize] = '\0'; + if (start != nullptr) { + *start = 0; + } + position = remaining; + ++bufferIdx; + return readCount > 0; +} + +std::string BaseCSVReader::reconstructLine(uint64_t startPosition, uint64_t endPosition, + bool completeLine) { + KU_ASSERT(endPosition >= startPosition); + + std::string res; + // For cases where we cannot perform a seek (e.g. compressed file system) we just return an + // empty string + if (fileInfo->canPerformSeek()) { + res.resize(endPosition - startPosition); + fileInfo->readFromFile(res.data(), res.size(), startPosition); + + const char* incompleteLineSuffix = completeLine ? "" : "..."; + res += incompleteLineSuffix; + } + + return StringUtils::ltrimNewlines(StringUtils::rtrimNewlines(res)); +} + +void BaseCSVReader::skipCurrentLine() { + do { + for (; position < bufferSize; ++position) { + if (isNewLine(buffer[position])) { + while (position < bufferSize && isNewLine(buffer[position])) { + ++position; + } + return; + } + } + } while (maybeReadBuffer(nullptr)); +} + +void BaseCSVReader::handleCopyException(const std::string& message, bool mustThrow) { + auto endByteOffset = lineContext.endByteOffset; + if (!lineContext.isCompleteLine) { + endByteOffset = getFileOffset(); + } + CopyFromFileError error{message, + WarningSourceData::constructFrom(currentBlockIdx, getRowOffsetInCurrentBlock(), + lineContext.startByteOffset, endByteOffset, fileIdx), + lineContext.isCompleteLine, mustThrow}; + errorHandler->handleError(error); + + // if we reach here it means we are ignoring the error + ++numErrors; +} + +template +static std::optional getOptionalWarningData( + const CSVColumnInfo& columnInfo, const CSVOption& option, + WarningSourceData&& warningSourceData) { + std::optional warningData; + + // we only care about populating the extra warning data when actually parsing the CSV + // and not when performing actions like sniffing + if constexpr (std::is_same_v || + std::is_same_v) { + // For now we only populate extra warning data when IGNORE_ERRORS is enabled + if (option.ignoreErrors) { + KU_ASSERT( + columnInfo.numWarningDataColumns == CopyConstants::CSV_WARNING_DATA_NUM_COLUMNS); + warningData.emplace(warningSourceData, columnInfo.numColumns); + } + } + return warningData; +} + +WarningSourceData BaseCSVReader::getWarningSourceData() const { + return WarningSourceData::constructFrom(currentBlockIdx, getRowOffsetInCurrentBlock(), + lineContext.startByteOffset, lineContext.endByteOffset, fileIdx); +} + +template +BaseCSVReader::parse_result_t BaseCSVReader::parseCSV(Driver& driver) { + KU_ASSERT(nullptr != errorHandler); + + // used for parsing algorithm + curRowIdx = 0; + numErrors = 0; + + while (true) { + column_id_t column = 0; + auto start = position.load(); + bool hasQuotes = false; + std::vector escapePositions; + lineContext.setNewLine(getFileOffset()); + + // read values into the buffer (if any) + if (!maybeReadBuffer(&start)) { + return {curRowIdx, numErrors}; + } + + // start parsing the first value + goto value_start; + value_start: + // state: value_start + // this state parses the first character of a value + if (buffer[position] == option.quoteChar) { + [[unlikely]] + // quote: actual value starts in the next position + // move to in_quotes state + start = position + 1; + hasQuotes = true; + goto in_quotes; + } else { + // no quote, 
move to normal parsing state + start = position; + hasQuotes = false; + goto normal; + } + normal: + // state: normal parsing state + // this state parses the remainder of a non-quoted value until we reach a delimiter or + // newline + do { + for (; position < bufferSize; position++) { + if (buffer[position] == option.delimiter) { + // delimiter: end the value and add it to the chunk + goto add_value; + } else if (isNewLine(buffer[position])) { + // newline: add row + goto add_row; + } + } + } while (readBuffer(&start)); + + [[unlikely]] + // file ends during normal scan: go to end state + goto final_state; + add_value: + // We get here after we have a delimiter. + KU_ASSERT(buffer[position] == option.delimiter || + buffer[position] == CopyConstants::DEFAULT_CSV_LIST_END_CHAR); + // Trim one character if we have quotes. + if (!addValue(driver, curRowIdx, column, + std::string_view(buffer.get() + start, position - start - hasQuotes), + escapePositions)) { + goto ignore_error; + } + column++; + + // Move past the delimiter. + ++position; + // Adjust start for MaybeReadBuffer. + start = position; + if (!maybeReadBuffer(&start)) { + [[unlikely]] + // File ends right after delimiter, go to final state + goto final_state; + } + goto value_start; + add_row: { + // We get here after we have a newline. + KU_ASSERT(isNewLine(buffer[position])); + lineContext.setEndOfLine(getFileOffset()); + bool isCarriageReturn = buffer[position] == '\r'; + if (!addValue(driver, curRowIdx, column, + std::string_view(buffer.get() + start, position - start - hasQuotes), + escapePositions)) { + goto ignore_error; + } + column++; + + curRowIdx += driver.addRow(curRowIdx, column, + getOptionalWarningData(columnInfo, option, getWarningSourceData())); + + column = 0; + position++; + // Adjust start for ReadBuffer. + start = position; + lineContext.setNewLine(getFileOffset()); + if (!maybeReadBuffer(&start)) { + // File ends right after newline, go to final state. + goto final_state; + } + if (isCarriageReturn) { + // \r newline, go to special state that parses an optional \n afterwards + goto carriage_return; + } else { + if (driver.done(curRowIdx)) { + return {curRowIdx, numErrors}; + } + goto value_start; + } + } + in_quotes: + // this state parses the remainder of a quoted value. 
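parseCSV is a goto-based state machine over a refillable buffer. A much-reduced sketch of the value_start / normal / in_quotes / unquote transitions for a single in-memory record, with doubled quotes as the escape; the function name and the simplifications are illustrative, not the lbug reader API:

#include <cassert>
#include <string>
#include <string_view>
#include <vector>

// Splits one record into fields, handling quoting and doubled-quote escapes only
// (no newlines, no configurable dialect, no buffer refills).
static std::vector<std::string> splitRecord(std::string_view line, char delim = ',', char quote = '"') {
    enum class State { ValueStart, Normal, InQuotes, Unquote };
    std::vector<std::string> fields;
    std::string current;
    State state = State::ValueStart;
    for (char c : line) {
        switch (state) {
        case State::ValueStart:                      // first character of a value
            if (c == quote) { state = State::InQuotes; }
            else if (c == delim) { fields.push_back(""); }
            else { current.push_back(c); state = State::Normal; }
            break;
        case State::Normal:                          // unquoted value until the delimiter
            if (c == delim) { fields.push_back(current); current.clear(); state = State::ValueStart; }
            else { current.push_back(c); }
            break;
        case State::InQuotes:                        // inside a quoted value
            if (c == quote) { state = State::Unquote; }
            else { current.push_back(c); }
            break;
        case State::Unquote:                         // just saw a closing quote
            if (c == quote) { current.push_back(quote); state = State::InQuotes; } // escaped quote
            else { fields.push_back(current); current.clear(); state = State::ValueStart; }
            break;
        }
    }
    fields.push_back(current);                       // final_state: flush the last value
    return fields;
}

int main() {
    auto fields = splitRecord(R"(a,"b,c","say ""hi""",)");
    assert(fields.size() == 4);
    assert(fields[1] == "b,c");
    assert(fields[2] == "say \"hi\"");
    assert(fields[3].empty());
    return 0;
}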
+ position++; + do { + for (; position < bufferSize; position++) { + if (driver.driverType == DriverType::SNIFF_CSV_DIALECT) { + auto& sniffDriver = reinterpret_cast(driver); + sniffDriver.setEverQuoted(); + } + if (buffer[position] == option.quoteChar) { + // quote: move to unquoted state + goto unquote; + } else if (buffer[position] == option.escapeChar) { + // escape: store the escaped position and move to handle_escape state + escapePositions.push_back(position - start); + goto handle_escape; + } else if (isNewLine(buffer[position])) { + [[unlikely]] if (!handleQuotedNewline()) { goto ignore_error; } + } + } + } while (readBuffer(&start)); + [[unlikely]] + // still in quoted state at the end of the file, error: + lineContext.setEndOfLine(getFileOffset()); + if (driver.driverType == DriverType::SNIFF_CSV_DIALECT) { + auto& sniffDriver = reinterpret_cast(driver); + sniffDriver.setError(); + } else { + handleCopyException("unterminated quotes."); + } + // we are ignoring this error, skip current row and restart state machine + goto ignore_error; + unquote: + KU_ASSERT(hasQuotes && buffer[position] == option.quoteChar); + // this state handles the state directly after we unquote + // in this state we expect either another quote (entering the quoted state again, and + // escaping the quote) or a delimiter/newline, ending the current value and moving on to the + // next value + position++; + if (!maybeReadBuffer(&start)) { + // file ends right after unquote, go to final state + goto final_state; + } + if (buffer[position] == option.quoteChar && + (!option.escapeChar || option.escapeChar == option.quoteChar)) { + // the escapeChar is used correctly, record this for DialectSniff + if (driver.driverType == DriverType::SNIFF_CSV_DIALECT) { + auto& sniffDriver = reinterpret_cast(driver); + sniffDriver.setEverEscaped(); + } + // escaped quote, return to quoted state and store escape position + escapePositions.push_back(position - start); + goto in_quotes; + } else if (buffer[position] == option.delimiter || + buffer[position] == CopyConstants::DEFAULT_CSV_LIST_END_CHAR) { + // delimiter, add value + goto add_value; + } else if (isNewLine(buffer[position])) { + goto add_row; + } else { + if (driver.driverType == DriverType::SNIFF_CSV_DIALECT) { + auto& sniffDriver = reinterpret_cast(driver); + sniffDriver.setError(); + } else { + [[unlikely]] handleCopyException("quote should be followed by " + "end of file, end of value, end of " + "row or another quote."); + } + goto ignore_error; + } + handle_escape: + // state: handle_escape + // escape should be followed by a quote or another escape character + position++; + if (!maybeReadBuffer(&start)) { + [[unlikely]] lineContext.setEndOfLine(getFileOffset()); + if (driver.driverType == DriverType::SNIFF_CSV_DIALECT) { + auto& sniffDriver = reinterpret_cast(driver); + sniffDriver.setError(); + } else { + handleCopyException("escape at end of file."); + } + goto ignore_error; + } + if (buffer[position] != option.quoteChar && buffer[position] != option.escapeChar) { + ++position; // consume the invalid char + if (driver.driverType == DriverType::SNIFF_CSV_DIALECT) { + auto& sniffDriver = reinterpret_cast(driver); + sniffDriver.setError(); + } else { + [[unlikely]] handleCopyException( + "neither QUOTE nor ESCAPE is proceeded by ESCAPE."); + } + goto ignore_error; + } + // the escapeChar is used correctly, record this for DialectSniff + if (driver.driverType == DriverType::SNIFF_CSV_DIALECT) { + auto& sniffDriver = reinterpret_cast(driver); + 
sniffDriver.setEverEscaped(); + } + // escape was followed by quote or escape, go back to quoted state + goto in_quotes; + carriage_return: + // this stage optionally skips a newline (\n) character, which allows \r\n to be interpreted + // as a single line + + // position points to the character after the carriage return. + if (buffer[position] == '\n') { + // newline after carriage return: skip + // increase position by 1 and move start to the new position + start = ++position; + if (!maybeReadBuffer(&start)) { + // file ends right after newline, go to final state + goto final_state; + } + } + if (driver.done(curRowIdx)) { + return {curRowIdx, numErrors}; + } + + goto value_start; + final_state: + // We get here when the file ends. + // If we were mid-value, add the remaining value to the chunk. + lineContext.setEndOfLine(getFileOffset()); + if (position > start) { + // Add remaining value to chunk. + if (!addValue(driver, curRowIdx, column, + std::string_view(buffer.get() + start, position - start - hasQuotes), + escapePositions)) { + return {curRowIdx, numErrors}; + } + column++; + } + if (column > 0) { + curRowIdx += driver.addRow(curRowIdx, column, + getOptionalWarningData(columnInfo, option, getWarningSourceData())); + } + return {curRowIdx, numErrors}; + ignore_error: + // we skip the current row then restart the state machine to continue parsing + skipCurrentLine(); + if (driver.done(curRowIdx)) { + return {curRowIdx, numErrors}; + } + continue; + } + KU_UNREACHABLE; +} + +column_id_t BaseCSVReader::appendWarningDataColumns(std::vector& resultColumnNames, + std::vector& resultColumnTypes, const common::FileScanInfo& fileScanInfo) { + const bool ignoreErrors = fileScanInfo.getOption(CopyConstants::IGNORE_ERRORS_OPTION_NAME, + CopyConstants::DEFAULT_IGNORE_ERRORS); + column_id_t numWarningDataColumns = 0; + if (ignoreErrors) { + numWarningDataColumns = CopyConstants::CSV_WARNING_DATA_NUM_COLUMNS; + for (idx_t i = 0; i < CopyConstants::CSV_WARNING_DATA_NUM_COLUMNS; ++i) { + resultColumnNames.emplace_back(CopyConstants::CSV_WARNING_DATA_COLUMN_NAMES[i]); + resultColumnTypes.emplace_back(CopyConstants::CSV_WARNING_DATA_COLUMN_TYPES[i]); + } + } + return numWarningDataColumns; +} + +PopulatedCopyFromError BaseCSVReader::basePopulateErrorFunc(CopyFromFileError error, + const SharedFileErrorHandler* sharedErrorHandler, BaseCSVReader* reader, std::string filePath) { + const auto warningData = CSVWarningSourceData::constructFrom(error.warningData); + const auto lineNumber = + sharedErrorHandler->getLineNumber(warningData.blockIdx, warningData.offsetInBlock); + return PopulatedCopyFromError{ + .message = std::move(error.message), + .filePath = std::move(filePath), + .skippedLineOrRecord = reader->reconstructLine(warningData.startByteOffset, + warningData.endByteOffset, error.completedLine), + .lineNumber = lineNumber, + }; +} + +common::idx_t BaseCSVReader::getFileIdxFunc(const CopyFromFileError& error) { + return CSVWarningSourceData::constructFrom(error.warningData).fileIdx; +} + +template BaseCSVReader::parse_result_t BaseCSVReader::parseCSV( + ParallelParsingDriver&); +template BaseCSVReader::parse_result_t BaseCSVReader::parseCSV( + SerialParsingDriver&); +template BaseCSVReader::parse_result_t BaseCSVReader::parseCSV( + SniffCSVNameAndTypeDriver&); +template BaseCSVReader::parse_result_t BaseCSVReader::parseCSV( + SniffCSVDialectDriver&); +template BaseCSVReader::parse_result_t BaseCSVReader::parseCSV( + SniffCSVHeaderDriver&); + +uint64_t BaseCSVReader::getFileOffset() const { + 
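// Logical file position = bytes consumed from the OS (osFileOffset) minus the
// bytes still unparsed in the buffer (bufferSize - position). Worked example:
// after one 1024-byte read with 100 bytes parsed, osFileOffset = 1024,
// bufferSize = 1024, position = 100, so the offset returned below is
// 1024 - 1024 + 100 = 100. readBuffer() adds readCount to both osFileOffset and
// bufferSize, which is what keeps the asserted invariant osFileOffset >= bufferSize.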
KU_ASSERT(osFileOffset >= bufferSize); + return osFileOffset - bufferSize + position; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/csv/dialect_detection.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/csv/dialect_detection.cpp new file mode 100644 index 0000000000..9e108f3f59 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/csv/dialect_detection.cpp @@ -0,0 +1,46 @@ +#include "processor/operator/persistent/reader/csv/dialect_detection.h" + +namespace lbug { +namespace processor { + +std::vector generateDialectOptions(const common::CSVOption& option) { + std::vector options; + std::string delimiters = ""; + std::string quoteChars = ""; + std::string escapeChars = ""; + + if (option.setDelim) { + delimiters += option.delimiter; + } else { + delimiters.assign(common::CopyConstants::DEFAULT_CSV_DELIMITER_SEARCH_SPACE.begin(), + common::CopyConstants::DEFAULT_CSV_DELIMITER_SEARCH_SPACE.end()); + } + + if (option.setQuote) { + quoteChars += option.quoteChar; + } else { + quoteChars.resize(common::CopyConstants::DEFAULT_CSV_QUOTE_SEARCH_SPACE.size()); + quoteChars.assign(common::CopyConstants::DEFAULT_CSV_QUOTE_SEARCH_SPACE.begin(), + common::CopyConstants::DEFAULT_CSV_QUOTE_SEARCH_SPACE.end()); + } + + if (option.setEscape) { + escapeChars += option.escapeChar; + } else { + escapeChars.assign(common::CopyConstants::DEFAULT_CSV_ESCAPE_SEARCH_SPACE.begin(), + common::CopyConstants::DEFAULT_CSV_ESCAPE_SEARCH_SPACE.end()); + } + + for (auto& delim : delimiters) { + for (auto& quote : quoteChars) { + for (auto& escape : escapeChars) { + DialectOption option{delim, quote, escape}; + options.push_back(option); + } + } + } + return options; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/csv/driver.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/csv/driver.cpp new file mode 100644 index 0000000000..158a775e62 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/csv/driver.cpp @@ -0,0 +1,309 @@ +#include "processor/operator/persistent/reader/csv/driver.h" + +#include "common/string_format.h" +#include "common/system_config.h" +#include "function/cast/functions/cast_from_string_functions.h" +#include "processor/operator/persistent/reader/csv/parallel_csv_reader.h" +#include "processor/operator/persistent/reader/csv/serial_csv_reader.h" +#include "utf8proc_wrapper.h" + +using namespace lbug::common; + +namespace lbug { +namespace processor { + +ParsingDriver::ParsingDriver(common::DataChunk& chunk, DriverType type /* = DriverType::PARSING */) + : driverType(type), chunk(chunk), rowEmpty(false) {} + +bool ParsingDriver::done(uint64_t rowNum) { + return rowNum >= DEFAULT_VECTOR_CAPACITY || doneEarly(); +} + +bool ParsingDriver::addValue(uint64_t rowNum, common::column_id_t columnIdx, + std::string_view value) { + uint64_t length = value.length(); + if (length == 0 && columnIdx == 0) { + rowEmpty = true; + } else { + rowEmpty = false; + } + BaseCSVReader* reader = getReader(); + + if (columnIdx == reader->getNumColumns() && length == 0) { + // skip a single trailing delimiter in last columnIdx + return true; + } + if (columnIdx >= reader->getNumColumns()) { + reader->handleCopyException( + stringFormat("expected {} values per row, but got more.", reader->getNumColumns())); + return 
false; + } + if (reader->skipColumn(columnIdx)) { + return true; + } + try { + function::CastString::copyStringToVector(&chunk.getValueVectorMutable(columnIdx), rowNum, + value, &reader->option); + } catch (ConversionException& e) { + reader->handleCopyException(e.what()); + return false; + } + + return true; +} + +bool ParsingDriver::addRow(uint64_t rowNum, common::column_id_t columnCount, + std::optional warningDataWithColumnInfo) { + BaseCSVReader* reader = getReader(); + if (rowEmpty) { + rowEmpty = false; + if (reader->getNumColumns() != 1) { + return false; + } + // Otherwise, treat it as null. + } + if (columnCount < reader->getNumColumns()) { + // Column number mismatch. + reader->handleCopyException(stringFormat("expected {} values per row, but got {}.", + reader->getNumColumns(), columnCount)); + return false; + } + + if (warningDataWithColumnInfo.has_value()) { + const auto warningDataStartColumn = warningDataWithColumnInfo->warningDataStartColumnIdx; + const auto numWarningDataColumns = warningDataWithColumnInfo->data.numValues; + KU_ASSERT(numWarningDataColumns == CopyConstants::CSV_WARNING_DATA_NUM_COLUMNS); + for (idx_t i = 0; i < numWarningDataColumns; ++i) { + const auto& warningData = warningDataWithColumnInfo->data.values[i]; + const auto columnIdx = warningDataStartColumn + i; + KU_ASSERT(columnIdx < chunk.getNumValueVectors()); + auto& vectorToSet = chunk.getValueVectorMutable(columnIdx); + std::visit( + [&vectorToSet, rowNum]( + auto warningDataField) { vectorToSet.setValue(rowNum, warningDataField); }, + warningData); + } + } + return true; +} + +ParallelParsingDriver::ParallelParsingDriver(common::DataChunk& chunk, ParallelCSVReader* reader) + : ParsingDriver(chunk, DriverType::PARALLEL), reader(reader) {} + +bool ParallelParsingDriver::doneEarly() { + return reader->finishedBlock(); +} + +BaseCSVReader* ParallelParsingDriver::getReader() { + return reader; +} + +SerialParsingDriver::SerialParsingDriver(common::DataChunk& chunk, SerialCSVReader* reader, + DriverType type /*= DriverType::SERIAL*/) + : ParsingDriver(chunk, type), reader(reader) {} + +bool SerialParsingDriver::doneEarly() { + return false; +} + +BaseCSVReader* SerialParsingDriver::getReader() { + return reader; +} + +common::DataChunk& getDummyDataChunk() { + static common::DataChunk dummyChunk = DataChunk(); // static ensures it's created only once + return dummyChunk; +} + +SniffCSVDialectDriver::SniffCSVDialectDriver(SerialCSVReader* reader) + : SerialParsingDriver(getDummyDataChunk(), reader, DriverType::SNIFF_CSV_DIALECT) { + auto& csvOption = reader->getCSVOption(); + columnCounts = std::vector(csvOption.sampleSize, 0); +} + +bool SniffCSVDialectDriver::addValue(uint64_t /*rowNum*/, common::column_id_t columnIdx, + std::string_view value) { + uint64_t length = value.length(); + if (length == 0 && columnIdx == 0) { + rowEmpty = true; + } else { + rowEmpty = false; + } + if (columnIdx == reader->getNumColumns() && length == 0) { + // skip a single trailing delimiter in last columnIdx + return true; + } + currentColumnCount++; + return true; +} + +bool SniffCSVDialectDriver::addRow(uint64_t /*rowNum*/, common::column_id_t /*columnCount*/, + std::optional /*warningData*/) { + auto& csvOption = reader->getCSVOption(); + if (rowEmpty) { + rowEmpty = false; + if (reader->getNumColumns() != 1) { + currentColumnCount = 0; + return false; + } + // Otherwise, treat it as null. 
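SniffCSVDialectDriver records one column count per sampled row for each candidate dialect produced by generateDialectOptions, and a dialect is kept only when those counts agree. A much-reduced sketch of that idea for the delimiter alone, ignoring quoting and escapes; all names here are illustrative:

#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

static std::size_t countColumns(const std::string& row, char delim) {
    std::size_t columns = 1;
    for (char c : row) {
        if (c == delim) { ++columns; }
    }
    return columns;
}

// A candidate delimiter is plausible when every sampled row splits into the
// same number of columns and that number is greater than one.
static char sniffDelimiter(const std::vector<std::string>& sample, const std::string& candidates) {
    for (char delim : candidates) {
        std::size_t expected = countColumns(sample.front(), delim);
        bool consistent = expected > 1;
        for (const auto& row : sample) {
            consistent = consistent && countColumns(row, delim) == expected;
        }
        if (consistent) { return delim; }
    }
    return ','; // fall back to the default delimiter
}

int main() {
    std::vector<std::string> sample = {"id|name|age", "1|alice|30", "2|bob|41"};
    assert(sniffDelimiter(sample, ",;|\t") == '|');
    return 0;
}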
+ } + if (resultPosition < csvOption.sampleSize) { + columnCounts[resultPosition] = currentColumnCount; + currentColumnCount = 0; + resultPosition++; + } + return true; +} + +bool SniffCSVDialectDriver::done(uint64_t rowNum) const { + auto& csvOption = reader->getCSVOption(); + return (csvOption.hasHeader ? 1 : 0) + csvOption.sampleSize <= rowNum; +} + +void SniffCSVDialectDriver::reset() { + columnCounts = std::vector(columnCounts.size(), 0); + currentColumnCount = 0; + error = false; + resultPosition = 0; + everQuoted = false; + everEscaped = false; +} + +SniffCSVNameAndTypeDriver::SniffCSVNameAndTypeDriver(SerialCSVReader* reader, + const function::ExtraScanTableFuncBindInput* bindInput) + : SerialParsingDriver(getDummyDataChunk(), reader, DriverType::SNIFF_CSV_NAME_AND_TYPE) { + if (bindInput != nullptr) { + for (auto i = 0u; i < bindInput->expectedColumnNames.size(); i++) { + columns.push_back( + {bindInput->expectedColumnNames[i], bindInput->expectedColumnTypes[i].copy()}); + sniffType.push_back(false); + } + } +} + +bool SniffCSVNameAndTypeDriver::done(uint64_t rowNum) { + auto& csvOption = reader->getCSVOption(); + bool finished = (csvOption.hasHeader ? 1 : 0) + csvOption.sampleSize <= rowNum; + // if the csv only has one row + if (finished && rowNum <= 1 && csvOption.autoDetection && !csvOption.setHeader) { + for (auto columnIdx = 0u; columnIdx < firstRow.size(); ++columnIdx) { + auto value = firstRow[columnIdx]; + if (!utf8proc::Utf8Proc::isValid(value.data(), value.length())) { + reader->handleCopyException("Invalid UTF8-encoded string.", true /* mustThrow */); + } + std::string columnName = std::string(value); + LogicalType columnType = function::inferMinimalTypeFromString(value); + columns[columnIdx].first = columnName; + columns[columnIdx].second = std::move(columnType); + } + } + return finished; +} + +bool SniffCSVNameAndTypeDriver::addValue(uint64_t rowNum, common::column_id_t columnIdx, + std::string_view value) { + uint64_t length = value.length(); + if (length == 0 && columnIdx == 0) { + rowEmpty = true; + } else { + rowEmpty = false; + } + if (columnIdx == reader->getNumColumns() && length == 0) { + // skip a single trailing delimiter in last columnIdx + return true; + } + auto& csvOption = reader->getCSVOption(); + if (columns.size() < columnIdx + 1 && csvOption.hasHeader && rowNum > 0) { + reader->handleCopyException( + stringFormat("expected {} values per row, but got more.", reader->getNumColumns())); + } + while (columns.size() < columnIdx + 1) { + columns.emplace_back(stringFormat("column{}", columns.size()), LogicalType::ANY()); + sniffType.push_back(true); + } + if (rowNum == 0 && csvOption.hasHeader) { + // reading the header + std::string columnName(value); + LogicalType columnType(LogicalTypeID::ANY); + auto it = value.rfind(':'); + if (it != std::string_view::npos) { + try { + columnType = LogicalType::convertFromString(std::string(value.substr(it + 1)), + reader->getClientContext()); + columnName = std::string(value.substr(0, it)); + sniffType[columnIdx] = false; + } catch (const Exception&) { // NOLINT(bugprone-empty-catch): + // This is how we check for a suitable + // datatype name. + // Didn't parse, just use the whole name. 
+ } + } + columns[columnIdx].first = columnName; + columns[columnIdx].second = std::move(columnType); + } else if (sniffType[columnIdx] && + (rowNum != 0 || !csvOption.autoDetection || csvOption.setHeader)) { + // reading the body + LogicalType combinedType; + columns[columnIdx].second = LogicalTypeUtils::combineTypes(columns[columnIdx].second, + function::inferMinimalTypeFromString(value)); + if (columns[columnIdx].second.getLogicalTypeID() == LogicalTypeID::STRING) { + sniffType[columnIdx] = false; + } + } else if (sniffType[columnIdx] && + (rowNum == 0 && csvOption.autoDetection && !csvOption.setHeader)) { + // store the first line for later use + firstRow.push_back(std::string{value}); + } + + return true; +} + +SniffCSVHeaderDriver::SniffCSVHeaderDriver(SerialCSVReader* reader, + const std::vector>& typeDetected) + : SerialParsingDriver(getDummyDataChunk(), reader, DriverType::SNIFF_CSV_HEADER) { + for (auto i = 0u; i < typeDetected.size(); i++) { + columns.push_back({typeDetected[i].first, typeDetected[i].second.copy()}); + } +} + +bool SniffCSVHeaderDriver::addValue(uint64_t /*rowNum*/, common::column_id_t columnIdx, + std::string_view value) { + uint64_t length = value.length(); + if (length == 0 && columnIdx == 0) { + rowEmpty = true; + } else { + rowEmpty = false; + } + if (columnIdx == reader->getNumColumns() && length == 0) { + // skip a single trailing delimiter in last columnIdx + return true; + } + + // reading the header + LogicalType columnType(LogicalTypeID::ANY); + + columnType = function::inferMinimalTypeFromString(value); + + // Store the value to Header vector for potential later use. + header.push_back({std::string(value), columnType.copy()}); + + // If we already determined has a header, just skip + if (detectedHeader) { + return true; + } + + // If any of the column in the first row cannot be casted to its expected type, we have a + // header. 
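The check below treats the first row as a header when a cell cannot be read as the type already inferred for its column. A toy version of that heuristic, restricted to integer columns; isInteger and looksLikeHeader are illustrative helpers, not lbug functions:

#include <cassert>
#include <charconv>
#include <cstddef>
#include <string>
#include <string_view>
#include <vector>

// Does the whole cell parse as a signed integer?
static bool isInteger(std::string_view cell) {
    long long value = 0;
    auto [ptr, ec] = std::from_chars(cell.data(), cell.data() + cell.size(), value);
    return ec == std::errc() && ptr == cell.data() + cell.size();
}

// The first row is a header when some column was detected as numeric from the
// body rows but the corresponding first-row cell does not parse as a number.
static bool looksLikeHeader(const std::vector<std::string>& firstRow,
    const std::vector<bool>& columnIsNumeric) {
    for (std::size_t i = 0; i < firstRow.size(); ++i) {
        if (columnIsNumeric[i] && !isInteger(firstRow[i])) {
            return true;
        }
    }
    return false;
}

int main() {
    std::vector<bool> columnIsNumeric = {true, false}; // inferred from the data rows
    assert(looksLikeHeader({"id", "name"}, columnIsNumeric)); // "id" is not a number
    assert(!looksLikeHeader({"1", "alice"}, columnIsNumeric)); // plain data row
    return 0;
}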
+ if (columnType.getLogicalTypeID() == LogicalTypeID::STRING && + columnType.getLogicalTypeID() != columns[columnIdx].second.getLogicalTypeID() && + LogicalTypeID::BLOB != columns[columnIdx].second.getLogicalTypeID() && + LogicalTypeID::UNION != columns[columnIdx].second.getLogicalTypeID()) { + detectedHeader = true; + } + + return true; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/csv/parallel_csv_reader.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/csv/parallel_csv_reader.cpp new file mode 100644 index 0000000000..14039d29d1 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/csv/parallel_csv_reader.cpp @@ -0,0 +1,334 @@ +#include "processor/operator/persistent/reader/csv/parallel_csv_reader.h" + +#include "binder/binder.h" +#include "function/table/bind_data.h" +#include "processor/execution_context.h" +#include "processor/operator/persistent/reader/csv/serial_csv_reader.h" +#include "processor/operator/persistent/reader/reader_bind_utils.h" + +#if defined(_WIN32) +#include +#endif + +#include "common/string_format.h" +#include "common/system_message.h" +#include "function/table/table_function.h" +#include "processor/operator/persistent/reader/csv/driver.h" + +using namespace lbug::common; +using namespace lbug::function; + +namespace lbug { +namespace processor { + +ParallelCSVReader::ParallelCSVReader(const std::string& filePath, idx_t fileIdx, CSVOption option, + CSVColumnInfo columnInfo, main::ClientContext* context, LocalFileErrorHandler* errorHandler) + : BaseCSVReader{filePath, fileIdx, std::move(option), std::move(columnInfo), context, + errorHandler} {} + +bool ParallelCSVReader::hasMoreToRead() const { + // If we haven't started the first block yet or are done our block, get the next block. + return buffer != nullptr && !finishedBlock(); +} + +uint64_t ParallelCSVReader::parseBlock(block_idx_t blockIdx, DataChunk& resultChunk) { + currentBlockIdx = blockIdx; + resetNumRowsInCurrentBlock(); + seekToBlockStart(); + if (blockIdx == 0) { + readBOM(); + if (option.hasHeader) { + const auto [numRowsRead, numErrors] = readHeader(); + errorHandler->setHeaderNumRows(numRowsRead + numErrors); + } + } + if (finishedBlock()) { + return 0; + } + ParallelParsingDriver driver(resultChunk, this); + const auto [numRowsRead, numErrors] = parseCSV(driver); + increaseNumRowsInCurrentBlock(numRowsRead, numErrors); + return numRowsRead; +} + +void ParallelCSVReader::reportFinishedBlock() { + errorHandler->reportFinishedBlock(currentBlockIdx, getNumRowsInCurrentBlock()); +} + +uint64_t ParallelCSVReader::continueBlock(DataChunk& resultChunk) { + KU_ASSERT(hasMoreToRead()); + ParallelParsingDriver driver(resultChunk, this); + const auto [numRowsParsed, numErrors] = parseCSV(driver); + increaseNumRowsInCurrentBlock(numRowsParsed, numErrors); + return numRowsParsed; +} + +void ParallelCSVReader::seekToBlockStart() { + // Seek to the proper location in the file. + if (fileInfo->seek(currentBlockIdx * CopyConstants::PARALLEL_BLOCK_SIZE, SEEK_SET) == -1) { + // LCOV_EXCL_START + handleCopyException( + stringFormat("Failed to seek to block {}: {}", currentBlockIdx, posixErrMessage()), + true); + // LCOV_EXCL_STOP + } + osFileOffset = currentBlockIdx * CopyConstants::PARALLEL_BLOCK_SIZE; + + if (currentBlockIdx == 0) { + // First block doesn't search for a newline. + return; + } + + // Reset the buffer. 
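After seeking to currentBlockIdx * PARALLEL_BLOCK_SIZE, every block except the first scans forward to the character after the next newline so that a worker never starts mid-row; quoted newlines are rejected by handleQuotedNewline precisely because this scan cannot see quoting. A buffer-free sketch of that alignment step; alignToRowStart and kBlockSize are illustrative names:

#include <cassert>
#include <cstddef>
#include <string_view>

// Given the raw byte offset where a block begins, return the offset of the
// first full row inside it. Block 0 starts at 0; every other block skips
// ahead to the character after the next newline ("\n" or "\r\n").
static std::size_t alignToRowStart(std::string_view data, std::size_t blockStart) {
    if (blockStart == 0) {
        return 0;
    }
    std::size_t pos = blockStart;
    while (pos < data.size() && data[pos] != '\n' && data[pos] != '\r') {
        ++pos;
    }
    if (pos < data.size() && data[pos] == '\r') {
        ++pos;
    }
    if (pos < data.size() && data[pos] == '\n') {
        ++pos;
    }
    return pos;
}

int main() {
    std::string_view csv = "id,name\n1,alice\n2,bob\n";
    constexpr std::size_t kBlockSize = 10; // stands in for CopyConstants::PARALLEL_BLOCK_SIZE
    assert(alignToRowStart(csv, 0 * kBlockSize) == 0);  // first block: no alignment needed
    assert(alignToRowStart(csv, 1 * kBlockSize) == 16); // lands inside "1,alice", skips to "2,bob"
    return 0;
}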
+ position = 0; + bufferSize = 0; + buffer.reset(); + if (!readBuffer(nullptr)) { + return; + } + + // Find the start of the next line. + do { + for (; position < bufferSize; position++) { + if (buffer[position] == '\r') { + position++; + if (!maybeReadBuffer(nullptr)) { + return; + } + if (buffer[position] == '\n') { + position++; + } + return; + } else if (buffer[position] == '\n') { + position++; + return; + } + } + } while (readBuffer(nullptr)); +} + +bool ParallelCSVReader::handleQuotedNewline() { + lineContext.setEndOfLine(getFileOffset()); + handleCopyException("Quoted newlines are not supported in parallel CSV reader." + " Please specify PARALLEL=FALSE in the options."); + return false; +} + +bool ParallelCSVReader::finishedBlock() const { + // Only stop if we've ventured into the next block by at least a byte. + // Use `>` because `position` points to just past the newline right now. + return getFileOffset() > (currentBlockIdx + 1) * CopyConstants::PARALLEL_BLOCK_SIZE; +} + +ParallelCSVScanSharedState::ParallelCSVScanSharedState(FileScanInfo fileScanInfo, uint64_t numRows, + main::ClientContext* context, CSVOption csvOption, CSVColumnInfo columnInfo) + : ScanFileWithProgressSharedState{std::move(fileScanInfo), numRows, context}, + csvOption{std::move(csvOption)}, columnInfo{std::move(columnInfo)}, numBlocksReadByFiles{0} { + errorHandlers.reserve(this->fileScanInfo.getNumFiles()); + for (idx_t i = 0; i < this->fileScanInfo.getNumFiles(); ++i) { + errorHandlers.emplace_back(i, &mtx); + } + populateErrorFunc = constructPopulateFunc(); + for (auto& errorHandler : errorHandlers) { + errorHandler.setPopulateErrorFunc(populateErrorFunc); + } +} + +populate_func_t ParallelCSVScanSharedState::constructPopulateFunc() { + const auto numFiles = fileScanInfo.getNumFiles(); + auto localErrorHandlers = std::vector>(numFiles); + auto readers = std::vector>(numFiles); + for (idx_t i = 0; i < numFiles; ++i) { + // If we run into errors while reconstructing lines they should be unrecoverable + localErrorHandlers[i] = + std::make_shared(&errorHandlers[i], false, context); + readers[i] = std::make_shared(fileScanInfo.filePaths[i], i, + csvOption.copy(), columnInfo.copy(), context, localErrorHandlers[i].get()); + } + return [this, movedErrorHandlers = std::move(localErrorHandlers), + movedReaders = std::move(readers)](CopyFromFileError error, + idx_t fileIdx) -> PopulatedCopyFromError { + return BaseCSVReader::basePopulateErrorFunc(std::move(error), &errorHandlers[fileIdx], + movedReaders[fileIdx].get(), fileScanInfo.getFilePath(fileIdx)); + }; +} + +void ParallelCSVScanSharedState::setFileComplete(uint64_t completedFileIdx) { + std::lock_guard guard{mtx}; + if (completedFileIdx == fileIdx) { + numBlocksReadByFiles += blockIdx; + blockIdx = 0; + fileIdx++; + } +} + +static offset_t tableFunc(const TableFuncInput& input, TableFuncOutput& output) { + auto& outputChunk = output.dataChunk; + + auto localState = input.localState->ptrCast(); + auto sharedState = input.sharedState->ptrCast(); + + do { + if (localState->reader != nullptr) { + if (localState->reader->hasMoreToRead()) { + auto result = localState->reader->continueBlock(outputChunk); + outputChunk.state->getSelVectorUnsafe().setSelSize(result); + if (result > 0) { + return result; + } + } + localState->reader->reportFinishedBlock(); + } + auto [fileIdx, blockIdx] = sharedState->getNext(); + if (fileIdx == UINT64_MAX) { + return 0; + } + if (fileIdx != localState->fileIdx) { + localState->fileIdx = fileIdx; + localState->errorHandler = + 
std::make_unique(&sharedState->errorHandlers[fileIdx], + sharedState->csvOption.ignoreErrors, sharedState->context, true); + localState->reader = + std::make_unique(sharedState->fileScanInfo.filePaths[fileIdx], + fileIdx, sharedState->csvOption.copy(), sharedState->columnInfo.copy(), + sharedState->context, localState->errorHandler.get()); + } + auto numRowsRead = localState->reader->parseBlock(blockIdx, outputChunk); + + // if there are any pending errors to throw, stop the parsing + // the actual error will be thrown during finalize + if (!sharedState->csvOption.ignoreErrors && + sharedState->errorHandlers[fileIdx].getNumCachedErrors() > 0) { + numRowsRead = 0; + } + + outputChunk.state->getSelVectorUnsafe().setSelSize(numRowsRead); + if (numRowsRead > 0) { + return numRowsRead; + } + if (localState->reader->isEOF()) { + localState->reader->reportFinishedBlock(); + localState->errorHandler->finalize(); + sharedState->setFileComplete(localState->fileIdx); + localState->reader = nullptr; + localState->errorHandler = nullptr; + } + } while (true); +} + +static std::unique_ptr bindFunc(main::ClientContext* context, + const TableFuncBindInput* input) { + auto scanInput = ku_dynamic_cast(input->extraInput.get()); + bool detectedHeader = false; + + DialectOption detectedDialect; + auto csvOption = CSVReaderConfig::construct(scanInput->fileScanInfo.options).option; + detectedDialect.doDialectDetection = csvOption.autoDetection; + + std::vector detectedColumnNames; + std::vector detectedColumnTypes; + SerialCSVScan::bindColumns(scanInput, detectedColumnNames, detectedColumnTypes, detectedDialect, + detectedHeader, context); + + std::vector resultColumnNames; + std::vector resultColumnTypes; + ReaderBindUtils::resolveColumns(scanInput->expectedColumnNames, detectedColumnNames, + resultColumnNames, scanInput->expectedColumnTypes, detectedColumnTypes, resultColumnTypes); + + if (csvOption.autoDetection) { + std::string quote(1, detectedDialect.quoteChar); + std::string delim(1, detectedDialect.delimiter); + std::string escape(1, detectedDialect.escapeChar); + scanInput->fileScanInfo.options.insert_or_assign("ESCAPE", + Value(LogicalType::STRING(), escape)); + scanInput->fileScanInfo.options.insert_or_assign("QUOTE", + Value(LogicalType::STRING(), quote)); + scanInput->fileScanInfo.options.insert_or_assign("DELIM", + Value(LogicalType::STRING(), delim)); + } + + if (!csvOption.setHeader && csvOption.autoDetection && detectedHeader) { + scanInput->fileScanInfo.options.insert_or_assign("HEADER", Value(detectedHeader)); + } + + resultColumnNames = + TableFunction::extractYieldVariables(resultColumnNames, input->yieldVariables); + auto resultColumns = input->binder->createVariables(resultColumnNames, resultColumnTypes); + std::vector warningColumnNames; + std::vector warningColumnTypes; + const column_id_t numWarningDataColumns = BaseCSVReader::appendWarningDataColumns( + warningColumnNames, warningColumnTypes, scanInput->fileScanInfo); + auto warningColumns = + input->binder->createInvisibleVariables(warningColumnNames, warningColumnTypes); + for (auto& column : warningColumns) { + resultColumns.push_back(column); + } + return std::make_unique(std::move(resultColumns), 0 /* numRows */, + scanInput->fileScanInfo.copy(), context, numWarningDataColumns); +} + +static std::unique_ptr initSharedState( + const TableFuncInitSharedStateInput& input) { + auto bindData = input.bindData->constPtrCast(); + auto csvOption = CSVReaderConfig::construct(bindData->fileScanInfo.options).option; + auto columnInfo = 
CSVColumnInfo(bindData->getNumColumns() - bindData->numWarningDataColumns, + bindData->getColumnSkips(), bindData->numWarningDataColumns); + auto sharedState = std::make_unique(bindData->fileScanInfo.copy(), + 0 /* numRows */, bindData->context, csvOption.copy(), columnInfo.copy()); + + for (idx_t i = 0; i < sharedState->fileScanInfo.getNumFiles(); ++i) { + auto filePath = sharedState->fileScanInfo.filePaths[i]; + auto reader = std::make_unique(filePath, i, csvOption.copy(), + columnInfo.copy(), bindData->context, nullptr); + sharedState->totalSize += reader->getFileSize(); + } + + return sharedState; +} + +static std::unique_ptr initLocalState(const TableFuncInitLocalStateInput&) { + auto localState = std::make_unique(); + localState->fileIdx = std::numeric_limitsfileIdx)>::max(); + return localState; +} + +static double progressFunc(TableFuncSharedState* sharedState) { + auto state = sharedState->ptrCast(); + if (state->fileIdx >= state->fileScanInfo.getNumFiles()) { + return 1.0; + } + if (state->totalSize == 0) { + return 0.0; + } + uint64_t totalReadSize = + (state->numBlocksReadByFiles + state->blockIdx) * CopyConstants::PARALLEL_BLOCK_SIZE; + if (totalReadSize > state->totalSize) { + return 1.0; + } + return static_cast(totalReadSize) / state->totalSize; +} + +static void finalizeFunc(const ExecutionContext* ctx, TableFuncSharedState* sharedState) { + auto state = ku_dynamic_cast(sharedState); + for (idx_t i = 0; i < state->fileScanInfo.getNumFiles(); ++i) { + state->errorHandlers[i].throwCachedErrorsIfNeeded(); + } + WarningContext::Get(*ctx->clientContext) + ->populateWarnings(ctx->queryID, state->populateErrorFunc, BaseCSVReader::getFileIdxFunc); +} + +function_set ParallelCSVScan::getFunctionSet() { + function_set functionSet; + auto function = std::make_unique(name, std::vector{LogicalTypeID::STRING}); + function->tableFunc = tableFunc; + function->bindFunc = bindFunc; + function->initSharedStateFunc = initSharedState; + function->initLocalStateFunc = initLocalState; + function->progressFunc = progressFunc; + function->finalizeFunc = finalizeFunc; + functionSet.push_back(std::move(function)); + return functionSet; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/csv/serial_csv_reader.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/csv/serial_csv_reader.cpp new file mode 100644 index 0000000000..e0c98418be --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/csv/serial_csv_reader.cpp @@ -0,0 +1,472 @@ +#include "processor/operator/persistent/reader/csv/serial_csv_reader.h" + +#include "binder/binder.h" +#include "function/table/bind_data.h" +#include "function/table/table_function.h" +#include "processor/execution_context.h" +#include "processor/operator/persistent/reader/csv/driver.h" +#include "processor/operator/persistent/reader/reader_bind_utils.h" + +using namespace lbug::common; +using namespace lbug::function; + +namespace lbug { +namespace processor { + +SerialCSVReader::SerialCSVReader(const std::string& filePath, idx_t fileIdx, CSVOption option, + CSVColumnInfo columnInfo, main::ClientContext* context, LocalFileErrorHandler* errorHandler, + const ExtraScanTableFuncBindInput* bindInput) + : BaseCSVReader{filePath, fileIdx, std::move(option), std::move(columnInfo), context, + errorHandler}, + bindInput{bindInput} {} + +std::vector> SerialCSVReader::sniffCSV( + DialectOption& detectedDialect, bool& detectedHeader) 
{ + auto csvOption = CSVReaderConfig::construct(bindInput->fileScanInfo.options).option; + readBOM(); + + if (detectedDialect.doDialectDetection) { + detectedDialect = detectDialect(); + } + + SniffCSVNameAndTypeDriver driver{this, bindInput}; + parseCSV(driver); + + for (auto& i : driver.columns) { + // purge null types + i.second = LogicalTypeUtils::purgeAny(i.second, LogicalType::STRING()); + } + // Do header detection IFF user didn't set header AND user didn't turn off auto detection + if (!csvOption.setHeader && csvOption.autoDetection) { + detectedHeader = detectHeader(driver.columns); + } + + // finalize the columns; rename duplicate names + std::map names; + for (auto& i : driver.columns) { + // Suppose name "col" already exists + // Let N be the number of times it exists + // rename to "col" + "_{N}" + // ideally "col_{N}" shouldn't exist, but if it already exists M times (due to user + // declaration), rename to "col_{N}" + "_{M}" repeat until no match exists + while (names.contains(i.first)) { + names[i.first]++; + i.first += "_" + std::to_string(names[i.first]); + } + names[i.first]; + // purge null types + i.second = LogicalTypeUtils::purgeAny(i.second, LogicalType::STRING()); + } + return std::move(driver.columns); +} + +uint64_t SerialCSVReader::parseBlock(block_idx_t blockIdx, DataChunk& resultChunk) { + KU_ASSERT(nullptr != errorHandler); + + if (blockIdx != currentBlockIdx) { + resetNumRowsInCurrentBlock(); + } + currentBlockIdx = blockIdx; + if (blockIdx == 0) { + const auto [numRowsRead, numErrors] = handleFirstBlock(); + errorHandler->setHeaderNumRows(numRowsRead + numErrors); + } + SerialParsingDriver driver(resultChunk, this); + const auto [numRowsRead, numErrors] = parseCSV(driver); + errorHandler->reportFinishedBlock(blockIdx, numRowsRead + numErrors); + resultChunk.state->getSelVectorUnsafe().setSelSize(numRowsRead); + increaseNumRowsInCurrentBlock(numRowsRead, numErrors); + return numRowsRead; +} + +SerialCSVScanSharedState::SerialCSVScanSharedState(FileScanInfo fileScanInfo, uint64_t numRows, + main::ClientContext* context, CSVOption csvOption, CSVColumnInfo columnInfo, uint64_t queryID) + : ScanFileWithProgressSharedState{std::move(fileScanInfo), numRows, context}, + csvOption{std::move(csvOption)}, columnInfo{std::move(columnInfo)}, totalReadSizeByFile{0}, + queryID(queryID), populateErrorFunc(constructPopulateFunc()) { + std::lock_guard lck{mtx}; + initReader(context); +} + +populate_func_t SerialCSVScanSharedState::constructPopulateFunc() const { + return [this](CopyFromFileError error, idx_t fileIdx) -> PopulatedCopyFromError { + return BaseCSVReader::basePopulateErrorFunc(std::move(error), sharedErrorHandler.get(), + reader.get(), fileScanInfo.getFilePath(fileIdx)); + }; +} + +void SerialCSVScanSharedState::read(DataChunk& outputChunk) { + std::lock_guard lck{mtx}; + do { + if (fileIdx >= fileScanInfo.getNumFiles()) { + return; + } + uint64_t numRows = reader->parseBlock(reader->getFileOffset() == 0 ? 
0 : 1, outputChunk); + if (numRows > 0) { + return; + } + totalReadSizeByFile += reader->getFileSize(); + finalizeReader(context); + fileIdx++; + initReader(context); + } while (true); +} + +void SerialCSVScanSharedState::finalizeReader(main::ClientContext* context) const { + if (localErrorHandler) { + localErrorHandler->finalize(); + } + if (sharedErrorHandler) { + sharedErrorHandler->throwCachedErrorsIfNeeded(); + WarningContext::Get(*context)->populateWarnings(queryID, populateErrorFunc, + BaseCSVReader::getFileIdxFunc); + } +} + +void SerialCSVScanSharedState::initReader(main::ClientContext* context) { + if (fileIdx < fileScanInfo.getNumFiles()) { + sharedErrorHandler = + std::make_unique(fileIdx, nullptr, populateErrorFunc); + localErrorHandler = std::make_unique(sharedErrorHandler.get(), + csvOption.ignoreErrors, context); + reader = std::make_unique(fileScanInfo.filePaths[fileIdx], fileIdx, + csvOption.copy(), columnInfo.copy(), context, localErrorHandler.get()); + } +} + +static offset_t tableFunc(const TableFuncInput& input, TableFuncOutput& output) { + auto serialCSVScanSharedState = ku_dynamic_cast(input.sharedState); + serialCSVScanSharedState->read(output.dataChunk); + return output.dataChunk.state->getSelVector().getSelSize(); +} + +static void bindColumnsFromFile(const ExtraScanTableFuncBindInput* bindInput, uint32_t fileIdx, + std::vector& columnNames, std::vector& columnTypes, + DialectOption& detectedDialect, bool& detectedHeader, main::ClientContext* context) { + auto csvOption = CSVReaderConfig::construct(bindInput->fileScanInfo.options).option; + auto columnInfo = CSVColumnInfo(bindInput->expectedColumnNames.size() /* numColumns */, + {} /* columnSkips */, {} /*warningDataColumns*/); + SharedFileErrorHandler sharedErrorHandler{fileIdx, nullptr}; + // We don't want to cache CSV errors encountered during sniffing, they will be re-encountered + // when actually parsing + LocalFileErrorHandler errorHandler{&sharedErrorHandler, csvOption.ignoreErrors, context, false}; + auto csvReader = SerialCSVReader(bindInput->fileScanInfo.filePaths[fileIdx], fileIdx, + csvOption.copy(), columnInfo.copy(), context, &errorHandler, bindInput); + sharedErrorHandler.setPopulateErrorFunc( + [&sharedErrorHandler, &csvReader, bindInput](CopyFromFileError error, + idx_t fileIdx) -> PopulatedCopyFromError { + return BaseCSVReader::basePopulateErrorFunc(std::move(error), &sharedErrorHandler, + &csvReader, bindInput->fileScanInfo.filePaths[fileIdx]); + }); + auto sniffedColumns = csvReader.sniffCSV(detectedDialect, detectedHeader); + sharedErrorHandler.throwCachedErrorsIfNeeded(); + for (auto& [name, type] : sniffedColumns) { + columnNames.push_back(name); + columnTypes.push_back(type.copy()); + } +} + +void SerialCSVScan::bindColumns(const ExtraScanTableFuncBindInput* bindInput, + std::vector& columnNames, std::vector& columnTypes, + DialectOption& detectedDialect, bool& detectedHeader, main::ClientContext* context) { + KU_ASSERT(bindInput->fileScanInfo.getNumFiles() > 0); + bindColumnsFromFile(bindInput, 0, columnNames, columnTypes, detectedDialect, detectedHeader, + context); + for (auto i = 1u; i < bindInput->fileScanInfo.getNumFiles(); ++i) { + std::vector tmpColumnNames; + std::vector tmpColumnTypes; + bindColumnsFromFile(bindInput, i, tmpColumnNames, tmpColumnTypes, detectedDialect, + detectedHeader, context); + ReaderBindUtils::validateNumColumns(columnTypes.size(), tmpColumnTypes.size()); + } +} + +static std::unique_ptr bindFunc(main::ClientContext* context, + const TableFuncBindInput* 
input) { + auto scanInput = ku_dynamic_cast(input->extraInput.get()); + if (scanInput->expectedColumnTypes.size() > 0) { + scanInput->fileScanInfo.options.insert_or_assign("SAMPLE_SIZE", + Value((int64_t)0)); // only scan headers + } + + bool detectedHeader = false; + + DialectOption detectedDialect; + auto csvOption = CSVReaderConfig::construct(scanInput->fileScanInfo.options).option; + detectedDialect.doDialectDetection = csvOption.autoDetection; + + std::vector detectedColumnNames; + std::vector detectedColumnTypes; + SerialCSVScan::bindColumns(scanInput, detectedColumnNames, detectedColumnTypes, detectedDialect, + detectedHeader, context); + + std::vector resultColumnNames; + std::vector resultColumnTypes; + ReaderBindUtils::resolveColumns(scanInput->expectedColumnNames, detectedColumnNames, + resultColumnNames, scanInput->expectedColumnTypes, detectedColumnTypes, resultColumnTypes); + + if (detectedDialect.doDialectDetection) { + std::string quote(1, detectedDialect.quoteChar); + std::string delim(1, detectedDialect.delimiter); + std::string escape(1, detectedDialect.escapeChar); + scanInput->fileScanInfo.options.insert_or_assign("ESCAPE", + Value(LogicalType::STRING(), escape)); + scanInput->fileScanInfo.options.insert_or_assign("QUOTE", + Value(LogicalType::STRING(), quote)); + scanInput->fileScanInfo.options.insert_or_assign("DELIM", + Value(LogicalType::STRING(), delim)); + } + + if (!csvOption.setHeader && csvOption.autoDetection && detectedHeader) { + scanInput->fileScanInfo.options.insert_or_assign("HEADER", Value(detectedHeader)); + } + + resultColumnNames = + TableFunction::extractYieldVariables(resultColumnNames, input->yieldVariables); + auto resultColumns = input->binder->createVariables(resultColumnNames, resultColumnTypes); + std::vector warningColumnNames; + std::vector warningColumnTypes; + const column_id_t numWarningDataColumns = BaseCSVReader::appendWarningDataColumns( + warningColumnNames, warningColumnTypes, scanInput->fileScanInfo); + auto warningColumns = + input->binder->createInvisibleVariables(warningColumnNames, warningColumnTypes); + for (auto& column : warningColumns) { + resultColumns.push_back(column); + } + return std::make_unique(std::move(resultColumns), 0 /* numRows */, + scanInput->fileScanInfo.copy(), context, numWarningDataColumns); +} + +static std::unique_ptr initSharedState( + const TableFuncInitSharedStateInput& input) { + auto bindData = input.bindData->constPtrCast(); + auto csvOption = CSVReaderConfig::construct(bindData->fileScanInfo.options).option; + auto columnInfo = CSVColumnInfo(bindData->getNumColumns() - bindData->numWarningDataColumns, + bindData->getColumnSkips(), bindData->numWarningDataColumns); + auto sharedState = + std::make_unique(bindData->fileScanInfo.copy(), 0 /* numRows */, + bindData->context, csvOption.copy(), columnInfo.copy(), input.context->queryID); + for (idx_t i = 0; i < sharedState->fileScanInfo.filePaths.size(); ++i) { + const auto& filePath = sharedState->fileScanInfo.filePaths[i]; + auto reader = std::make_unique(filePath, i, csvOption.copy(), + columnInfo.copy(), sharedState->context, nullptr); + sharedState->totalSize += reader->getFileSize(); + } + return sharedState; +} + +static void finalizeFunc(const ExecutionContext* ctx, TableFuncSharedState* sharedState) { + auto state = ku_dynamic_cast(sharedState); + state->finalizeReader(ctx->clientContext); +} + +static double progressFunc(TableFuncSharedState* sharedState) { + auto state = ku_dynamic_cast(sharedState); + if (state->totalSize == 0) { + return 
0.0; + } else if (state->fileIdx >= state->fileScanInfo.getNumFiles()) { + return 1.0; + } + std::lock_guard lck{state->mtx}; + uint64_t totalReadSize = state->totalReadSizeByFile + state->reader->getFileOffset(); + return static_cast(totalReadSize) / state->totalSize; +} + +function_set SerialCSVScan::getFunctionSet() { + function_set functionSet; + auto function = std::make_unique(name, std::vector{LogicalTypeID::STRING}); + function->tableFunc = tableFunc; + function->bindFunc = bindFunc; + function->initSharedStateFunc = initSharedState; + function->initLocalStateFunc = TableFunction::initEmptyLocalState; + function->progressFunc = progressFunc; + function->finalizeFunc = finalizeFunc; + functionSet.push_back(std::move(function)); + return functionSet; +} + +void SerialCSVReader::resetReaderState() { + // Reset file position to the beginning. + fileInfo->reset(); + buffer.reset(); + bufferSize = 0; + position = 0; + osFileOffset = 0; + bufferIdx = 0; + lineContext.setNewLine(getFileOffset()); + + readBOM(); +} + +DialectOption SerialCSVReader::detectDialect() { + // Extract a sample of rows from the file for dialect detection. + SniffCSVDialectDriver driver{this}; + + // Generate dialect options based on the non-user-specified options. + auto dialectSearchSpace = generateDialectOptions(option); + + // Save default for dialect not found situation. + DialectOption defaultOption{option.delimiter, option.quoteChar, option.escapeChar}; + + idx_t bestConsistentRows = 0; + idx_t maxColumnsFound = 0; + idx_t minIgnoredRows = 0; + std::vector validDialects; + std::vector finalDialects; + for (auto& dialectOption : dialectSearchSpace) { + bool notExpected = false; + // Load current dialect option. + option.delimiter = dialectOption.delimiter; + option.quoteChar = dialectOption.quoteChar; + option.escapeChar = dialectOption.escapeChar; + // reset Driver. + driver.reset(); + // Try parsing it with current dialect. + parseCSV(driver); + // Reset the file position and buffer to start reading from the beginning after detection. + resetReaderState(); + // If never unquoting quoted values or any other error during the parsing, discard this + // dialect. + if (driver.getError()) { + continue; + } + + idx_t ignoredRows = 0; + idx_t consistentRows = 0; + idx_t numCols = driver.getResultPosition() == 0 ? 1 : driver.getColumnCount(0); + dialectOption.everQuoted = driver.getEverQuoted(); + dialectOption.everEscaped = driver.getEverEscaped(); + + // If the columns didn't match the user input columns number. 
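+        // NOTE (assumption): getNumColumns() is non-zero only when the caller supplied an
+        // expected schema, so a candidate dialect whose detected column count disagrees with
+        // that schema is rejected outright, however consistently it parses.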
+ if (getNumColumns() != 0 && getNumColumns() != numCols) { + continue; + } + + for (auto row = 0u; row < driver.getResultPosition(); row++) { + if (getNumColumns() != 0 && getNumColumns() != driver.getColumnCount(row)) { + notExpected = true; + break; + } + if (numCols < driver.getColumnCount(row)) { + numCols = driver.getColumnCount(row); + consistentRows = 1; + } else if (driver.getColumnCount(row) == numCols) { + consistentRows++; + } else { + ignoredRows++; + } + } + + if (notExpected) { + continue; + } + + auto moreValues = consistentRows > bestConsistentRows && numCols >= maxColumnsFound; + auto singleColumnBefore = + maxColumnsFound < 2 && numCols > maxColumnsFound * validDialects.size(); + auto moreThanOneRow = consistentRows > 1; + auto moreThanOneColumn = numCols > 1; + + if (singleColumnBefore || moreValues || moreThanOneColumn) { + if (maxColumnsFound == numCols && ignoredRows > minIgnoredRows) { + continue; + } + if (!validDialects.empty() && validDialects.front().everQuoted && + !dialectOption.everQuoted) { + // Give preference to quoted dialect. + continue; + } + + if (!validDialects.empty() && validDialects.front().everEscaped && + !dialectOption.everEscaped) { + // Give preference to Escaped dialect. + continue; + } + + if (consistentRows >= bestConsistentRows) { + bestConsistentRows = consistentRows; + maxColumnsFound = numCols; + minIgnoredRows = ignoredRows; + validDialects.clear(); + validDialects.emplace_back(dialectOption); + } + } + + if (moreThanOneRow && moreThanOneColumn && numCols == maxColumnsFound) { + bool same_quote = false; + for (auto& validDialect : validDialects) { + if (validDialect.quoteChar == dialectOption.quoteChar) { + same_quote = true; + } + } + + if (!same_quote) { + validDialects.push_back(dialectOption); + } + } + } + + // If we have multiple validDialect with quotes set, we will give the preference to ones + // that have actually quoted values. + if (!validDialects.empty()) { + for (auto& validDialect : validDialects) { + if (validDialect.everQuoted) { + finalDialects.clear(); + finalDialects.emplace_back(validDialect); + break; + } + finalDialects.emplace_back(validDialect); + } + } + + // If the Dialect we found doesn't need Quote, we use empty as QuoteChar. + if (!finalDialects.empty() && !finalDialects[0].everQuoted && !option.setQuote) { + finalDialects[0].quoteChar = '\0'; + } + // If the Dialect we found doesn't need Escape, we use empty as EscapeChar. + if (!finalDialects.empty() && !finalDialects[0].everEscaped && !option.setEscape) { + finalDialects[0].escapeChar = '\0'; + } + + // Apply the detected dialect to the CSV options. + if (!finalDialects.empty()) { + option.delimiter = finalDialects[0].delimiter; + option.quoteChar = finalDialects[0].quoteChar; + option.escapeChar = finalDialects[0].escapeChar; + } else { + option.delimiter = defaultOption.delimiter; + option.quoteChar = defaultOption.quoteChar; + option.escapeChar = defaultOption.escapeChar; + } + + DialectOption ret{option.delimiter, option.quoteChar, option.escapeChar}; + return ret; +} + +bool SerialCSVReader::detectHeader( + std::vector>& detectedTypes) { + // Reset the file position and buffer to start reading from the beginning after detection. + resetReaderState(); + SniffCSVHeaderDriver sniffHeaderDriver{this, detectedTypes}; + readBOM(); + parseCSV(sniffHeaderDriver); + resetReaderState(); + // In this case, User didn't set Header, but we detected a Header, use the detected header to + // set the name and type. 
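+    // Illustrative example (assumption): given a first line `id,name,score` followed by rows
+    // that sniff as INT64/STRING/DOUBLE, the first line is treated as a header and its values
+    // become the column names in the loop below.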
+ if (sniffHeaderDriver.detectedHeader) { + // If the detected header has fewer columns that expected, treat it as if no header was + // detected + if (sniffHeaderDriver.header.size() < detectedTypes.size()) { + sniffHeaderDriver.detectedHeader = false; + return false; + } + + for (auto i = 0u; i < detectedTypes.size(); i++) { + detectedTypes[i].first = sniffHeaderDriver.header[i].first; + } + } + return sniffHeaderDriver.detectedHeader; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/file_error_handler.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/file_error_handler.cpp new file mode 100644 index 0000000000..7bf0b28c2c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/file_error_handler.cpp @@ -0,0 +1,237 @@ +#include "processor/operator/persistent/reader/file_error_handler.h" + +#include + +#include "common/assert.h" +#include "common/exception/copy.h" +#include "common/string_format.h" +#include "main/client_context.h" + +namespace lbug { +using namespace common; +namespace processor { + +void LineContext::setNewLine(uint64_t start) { + startByteOffset = start; + isCompleteLine = false; +} + +void LineContext::setEndOfLine(uint64_t end) { + endByteOffset = end; + isCompleteLine = true; +} + +SharedFileErrorHandler::SharedFileErrorHandler(common::idx_t fileIdx, std::mutex* sharedMtx, + populate_func_t populateErrorFunc) + : mtx(sharedMtx), fileIdx(fileIdx), populateErrorFunc(std::move(populateErrorFunc)), + headerNumRows(0) {} + +uint64_t SharedFileErrorHandler::getNumCachedErrors() { + auto lockGuard = lock(); + return cachedErrors.size(); +} + +void SharedFileErrorHandler::tryCacheError(CopyFromFileError error, const common::UniqLock&) { + if (cachedErrors.size() < MAX_CACHED_ERROR_COUNT) { + cachedErrors.push_back(std::move(error)); + } +} + +void SharedFileErrorHandler::handleError(CopyFromFileError error) { + auto lockGuard = lock(); + if (error.mustThrow) { + throwError(error); + } + + const auto blockIdx = error.warningData.getBlockIdx(); + if (blockIdx >= linesPerBlock.size()) { + linesPerBlock.resize(blockIdx + 1); + } + + // throwing of the error is not done when in the middle of parsing blocks + // so we cache the error to be thrown later + tryCacheError(std::move(error), lockGuard); +} + +void SharedFileErrorHandler::throwCachedErrorsIfNeeded() { + auto lockGuard = lock(); + tryThrowFirstCachedError(); +} + +void SharedFileErrorHandler::tryThrowFirstCachedError() { + if (cachedErrors.empty()) { + return; + } + + // we sort the cached errors to report the one with the earliest line number + std::sort(cachedErrors.begin(), cachedErrors.end()); + + const auto error = *cachedErrors.cbegin(); + KU_ASSERT(!error.mustThrow); + + const bool errorIsThrowable = canGetLineNumber(error.warningData.getBlockIdx()); + if (errorIsThrowable) { + throwError(error); + } +} + +namespace { +std::string getFilePathMessage(std::string_view filePath) { + static constexpr std::string_view invalidFilePath = ""; + return filePath == invalidFilePath ? std::string{} : + common::stringFormat(" in file {}", filePath); +} + +std::string getLineNumberMessage(uint64_t lineNumber) { + static constexpr uint64_t invalidLineNumber = 0; + return lineNumber == invalidLineNumber ? 
std::string{} : + common::stringFormat(" on line {}", lineNumber); +} + +std::string getSkippedLineMessage(std::string_view skippedLineOrRecord) { + static constexpr std::string_view emptySkippedLine = ""; + return skippedLineOrRecord == emptySkippedLine ? + std::string{} : + common::stringFormat(" Line/record containing the error: '{}'", skippedLineOrRecord); +} +} // namespace + +std::string SharedFileErrorHandler::getErrorMessage(PopulatedCopyFromError populatedError) const { + return common::stringFormat("Error{}{}: {}{}", getFilePathMessage(populatedError.filePath), + getLineNumberMessage(populatedError.lineNumber), populatedError.message, + getSkippedLineMessage(populatedError.skippedLineOrRecord)); +} + +void SharedFileErrorHandler::throwError(CopyFromFileError error) const { + KU_ASSERT(populateErrorFunc); + throw CopyException(getErrorMessage(populateErrorFunc(std::move(error), fileIdx))); +} + +common::UniqLock SharedFileErrorHandler::lock() { + if (mtx) { + return common::UniqLock{*mtx}; + } + return common::UniqLock{}; +} + +bool SharedFileErrorHandler::canGetLineNumber(uint64_t blockIdx) const { + if (blockIdx > linesPerBlock.size()) { + return false; + } + for (uint64_t i = 0; i < blockIdx; ++i) { + // the line count for a block is empty if it hasn't finished being parsed + if (!linesPerBlock[i].doneParsingBlock) { + return false; + } + } + return true; +} + +void SharedFileErrorHandler::setPopulateErrorFunc(populate_func_t newPopulateErrorFunc) { + populateErrorFunc = newPopulateErrorFunc; +} + +uint64_t SharedFileErrorHandler::getLineNumber(uint64_t blockIdx, + uint64_t numRowsReadInBlock) const { + // 1-indexed + uint64_t res = numRowsReadInBlock + headerNumRows + 1; + for (uint64_t i = 0; i < blockIdx; ++i) { + KU_ASSERT(i < linesPerBlock.size()); + res += linesPerBlock[i].numLines; + } + return res; +} + +void SharedFileErrorHandler::setHeaderNumRows(uint64_t numRows) { + if (numRows == headerNumRows) { + return; + } + auto lockGuard = lock(); + headerNumRows = numRows; +} + +void SharedFileErrorHandler::updateLineNumberInfo( + const std::map& newLinesPerBlock, bool canThrowCachedError) { + const auto lockGuard = lock(); + + if (!newLinesPerBlock.empty()) { + const auto maxNewBlockIdx = newLinesPerBlock.rbegin()->first; + if (maxNewBlockIdx >= linesPerBlock.size()) { + linesPerBlock.resize(maxNewBlockIdx + 1); + } + + for (const auto& [blockIdx, linesInBlock] : newLinesPerBlock) { + auto& currentBlock = linesPerBlock[blockIdx]; + currentBlock.numLines += linesInBlock.numLines; + currentBlock.doneParsingBlock = + currentBlock.doneParsingBlock || linesInBlock.doneParsingBlock; + } + } + + if (canThrowCachedError) { + tryThrowFirstCachedError(); + } +} + +LocalFileErrorHandler::LocalFileErrorHandler(SharedFileErrorHandler* sharedErrorHandler, + bool ignoreErrors, main::ClientContext* context, bool cacheErrors) + : sharedErrorHandler(sharedErrorHandler), context(context), + maxCachedErrorCount( + std::min(this->context->getClientConfig()->warningLimit, LOCAL_WARNING_LIMIT)), + ignoreErrors(ignoreErrors), cacheIgnoredErrors(cacheErrors) {} + +void LocalFileErrorHandler::handleError(CopyFromFileError error) { + if (error.mustThrow || !ignoreErrors) { + // we delegate throwing to the shared error handler + sharedErrorHandler->handleError(std::move(error)); + return; + } + + KU_ASSERT(cachedErrors.size() <= maxCachedErrorCount); + if (cachedErrors.size() == maxCachedErrorCount) { + flushCachedErrors(); + } + + if (cacheIgnoredErrors) { + cachedErrors.push_back(std::move(error)); 
+ } +} + +void LocalFileErrorHandler::reportFinishedBlock(uint64_t blockIdx, uint64_t numRowsRead) { + linesPerBlock[blockIdx].numLines += numRowsRead; + linesPerBlock[blockIdx].doneParsingBlock = true; + if (linesPerBlock.size() >= maxCachedErrorCount) { + flushCachedErrors(); + } +} + +void LocalFileErrorHandler::setHeaderNumRows(uint64_t numRows) { + sharedErrorHandler->setHeaderNumRows(numRows); +} + +LocalFileErrorHandler::~LocalFileErrorHandler() { + // we don't want to throw in the destructor + // so we leave throwing for later in the parsing stage or during finalize + flushCachedErrors(false); +} + +void LocalFileErrorHandler::finalize(bool canThrowCachedError) { + flushCachedErrors(canThrowCachedError); +} + +void LocalFileErrorHandler::flushCachedErrors(bool canThrowCachedError) { + if (!linesPerBlock.empty()) { + // clear linesPerBlock first so that it is empty even if updateLineNumberInfo() throws + decltype(linesPerBlock) oldLinesPerBlock; + oldLinesPerBlock.swap(linesPerBlock); + sharedErrorHandler->updateLineNumberInfo(oldLinesPerBlock, canThrowCachedError); + } + + if (!cachedErrors.empty()) { + WarningContext::Get(*context)->appendWarningMessages(cachedErrors); + cachedErrors.clear(); + } +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/npy/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/npy/CMakeLists.txt new file mode 100644 index 0000000000..c4faf3ac3f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/npy/CMakeLists.txt @@ -0,0 +1,7 @@ +add_library(lbug_processor_operator_npy_reader + OBJECT + npy_reader.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/npy/npy_reader.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/npy/npy_reader.cpp new file mode 100644 index 0000000000..8a5bd23290 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/npy/npy_reader.cpp @@ -0,0 +1,356 @@ +#include "processor/operator/persistent/reader/npy/npy_reader.h" + +#include +#include + +#include "binder/binder.h" +#include "common/exception/binder.h" +#include "processor/execution_context.h" +#include "processor/operator/persistent/reader/reader_bind_utils.h" +#include "processor/warning_context.h" + +#ifdef _WIN32 +#include "common/exception/buffer_manager.h" +#include +#include +#include +#include +#else +#include +#include +#endif +#include "common/exception/copy.h" +#include "common/string_format.h" +#include "common/utils.h" +#include "function/table/bind_data.h" +#include "function/table/bind_input.h" +#include "function/table/table_function.h" +#include "pyparse.h" +#include "storage/storage_utils.h" + +using namespace lbug::common; +using namespace lbug::storage; +using namespace lbug::function; + +namespace lbug { +namespace processor { + +NpyReader::NpyReader(const std::string& filePath) + : filePath{filePath}, dataOffset{0}, type{LogicalTypeID::ANY} { + fd = open(filePath.c_str(), O_RDONLY); + if (fd == -1) { + throw CopyException("Failed to open NPY file."); + } + struct stat fileStatus {}; + fstat(fd, &fileStatus); + fileSize = fileStatus.st_size; + +#ifdef _WIN32 + DWORD low = (DWORD)(fileSize & 0xFFFFFFFFL); + DWORD high = (DWORD)((fileSize >> 32) & 0xFFFFFFFFL); + auto handle = + CreateFileMappingW((HANDLE)_get_osfhandle(fd), NULL, 
PAGE_READONLY, high, low, NULL); + if (handle == NULL) { + throw BufferManagerException( + stringFormat("CreateFileMapping for size {} failed with error code {}: {}.", fileSize, + GetLastError(), std::system_category().message(GetLastError()))); + } + + mmapRegion = MapViewOfFile(handle, FILE_MAP_READ, 0, 0, fileSize); + CloseHandle(handle); + if (mmapRegion == NULL) { + throw BufferManagerException( + stringFormat("MapViewOfFile for size {} failed with error code {}: {}.", fileSize, + GetLastError(), std::system_category().message(GetLastError()))); + } +#else + mmapRegion = mmap(nullptr, fileSize, PROT_READ, MAP_SHARED, fd, 0); + if (mmapRegion == MAP_FAILED) { + throw CopyException("Failed to mmap NPY file."); + } +#endif + parseHeader(); +} + +NpyReader::~NpyReader() { +#ifdef _WIN32 + UnmapViewOfFile(mmapRegion); +#else + munmap(mmapRegion, fileSize); +#endif + close(fd); +} + +size_t NpyReader::getNumElementsPerRow() const { + size_t numElements = 1; + for (size_t i = 1; i < shape.size(); ++i) { + numElements *= shape[i]; + } + return numElements; +} + +uint8_t* NpyReader::getPointerToRow(size_t row) const { + if (row >= getNumRows()) { + return nullptr; + } + return ( + uint8_t*)((char*)mmapRegion + dataOffset + + row * getNumElementsPerRow() * StorageUtils::getDataTypeSize(LogicalType{type})); +} + +void NpyReader::parseHeader() { + // The first 6 bytes are a magic string: exactly \x93NUMPY + char* magicString = (char*)mmapRegion; + const char* expectedMagicString = "\x93NUMPY"; + if (memcmp(magicString, expectedMagicString, 6) != 0) { + throw CopyException("Invalid NPY file"); + } + + // The next 1 byte is an unsigned byte: the major version number of the file + // format, e.g. x01. + char* majorVersion = magicString + 6; + if (*majorVersion != 1) { + throw CopyException("Unsupported NPY file version."); + } + // The next 1 byte is an unsigned byte: the minor version number of the file + // format, e.g. x00. Note: the version of the file format is not tied to the + // version of the numpy package. + char* minorVersion = majorVersion + 1; + if (*minorVersion != 0) { + throw CopyException("Unsupported NPY file version."); + } + // The next 2 bytes form a little-endian unsigned short int: the length of + // the header data HEADER_LEN. + auto headerLength = *(unsigned short int*)(minorVersion + 1); + if (!isLittleEndian()) { + headerLength = ((headerLength & 0xff00) >> 8) | ((headerLength & 0x00ff) << 8); + } + + // The next HEADER_LEN bytes form the header data describing the array's + // format. It is an ASCII string which contains a Python literal expression + // of a dictionary. It is terminated by a newline ('n') and padded with + // spaces ('x20') to make the total length of the magic string + 4 + + // HEADER_LEN be evenly divisible by 16 for alignment purposes. 
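+    // NPY v1.0 header layout: 6-byte magic "\x93NUMPY", one major and one minor version byte,
+    // a 2-byte little-endian HEADER_LEN, then the header dict itself, e.g.
+    // "{'descr': '<f8', 'fortran_order': False, 'shape': (1000, 3), }" padded with spaces.
+    // Hence the fixed prefix accounted for below is strlen("\x93NUMPY") + 4 = 10 bytes.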
+ auto metaInfoLength = strlen(expectedMagicString) + 4; + char* header = (char*)mmapRegion + metaInfoLength; + auto headerEnd = std::find(header, header + headerLength, '}'); + + std::string headerString(header, headerEnd + 1); + std::unordered_map headerMap = + pyparse::parse_dict(headerString, {"descr", "fortran_order", "shape"}); + auto isFortranOrder = pyparse::parse_bool(headerMap["fortran_order"]); + if (isFortranOrder) { + throw CopyException("Fortran-order NPY files are not currently supported."); + } + auto descr = pyparse::parse_str(headerMap["descr"]); + parseType(descr); + auto shapeV = pyparse::parse_tuple(headerMap["shape"]); + for (auto const& item : shapeV) { + shape.emplace_back(std::stoul(item)); + } + dataOffset = metaInfoLength + headerLength; +} + +void NpyReader::parseType(std::string descr) { + if (descr[0] == '<' || descr[0] == '>') { + // Data type endianness is specified + auto machineEndianness = isLittleEndian() ? "<" : ">"; + if (descr[0] != machineEndianness[0]) { + throw CopyException( + "The endianness of the file does not match the machine's endianness."); + } + descr = descr.substr(1); + } + if (descr[0] == '|' || descr[0] == '=') { + // Data type endianness is not applicable or native + descr = descr.substr(1); + } + if (descr == "f8") { + type = LogicalTypeID::DOUBLE; + } else if (descr == "f4") { + type = LogicalTypeID::FLOAT; + } else if (descr == "i8") { + type = LogicalTypeID::INT64; + } else if (descr == "i4") { + type = LogicalTypeID::INT32; + } else if (descr == "i2") { + type = LogicalTypeID::INT16; + } else { + throw CopyException("Unsupported data type: " + descr); + } +} + +void NpyReader::validate(const LogicalType& type_, offset_t numRows) { + auto numNodesInFile = getNumRows(); + if (numNodesInFile == 0) { + throw CopyException(stringFormat("Number of rows in npy file {} is 0.", filePath)); + } + if (numNodesInFile != numRows) { + throw CopyException("Number of rows in npy files is not equal to each other."); + } + // TODO(Guodong): Set npy reader data type to ARRAY, so we can simplify checks here. 
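+    // Summary of the checks below: a file with shape (N,) must map to a scalar column of the
+    // matching type, while a multi-dimensional file must map to an ARRAY column whose child
+    // type matches and whose fixed length equals the flattened element count per row; any
+    // other combination raises a CopyException.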
+    if (type_.getLogicalTypeID() == this->type) {
+        if (getNumElementsPerRow() != 1) {
+            throw CopyException(stringFormat("Cannot copy a vector property in npy file {} to a "
+                                             "scalar property.",
+                filePath));
+        }
+        return;
+    } else if (type_.getLogicalTypeID() == LogicalTypeID::ARRAY) {
+        if (this->type != ArrayType::getChildType(type_).getLogicalTypeID()) {
+            throw CopyException(stringFormat("The type of npy file {} does not "
+                                             "match the expected type.",
+                filePath));
+        }
+        if (getNumElementsPerRow() != ArrayType::getNumElements(type_)) {
+            throw CopyException(
+                stringFormat("The shape of {} does not match {}.", filePath, type_.toString()));
+        }
+        return;
+    } else {
+        throw CopyException(stringFormat("The type of npy file {} does not "
+                                         "match the expected type.",
+            filePath));
+    }
+}
+
+void NpyReader::readBlock(block_idx_t blockIdx, ValueVector* vectorToRead) const {
+    uint64_t rowNumber = DEFAULT_VECTOR_CAPACITY * blockIdx;
+    auto numRows = getNumRows();
+    if (rowNumber >= numRows) {
+        vectorToRead->state->getSelVectorUnsafe().setSelSize(0);
+    } else {
+        auto rowPointer = getPointerToRow(rowNumber);
+        auto numRowsToRead = std::min(DEFAULT_VECTOR_CAPACITY, getNumRows() - rowNumber);
+        const auto& rowType = vectorToRead->dataType;
+        if (rowType.getLogicalTypeID() == LogicalTypeID::ARRAY) {
+            auto numValuesPerRow = ArrayType::getNumElements(rowType);
+            for (auto i = 0u; i < numRowsToRead; i++) {
+                auto listEntry = ListVector::addList(vectorToRead, numValuesPerRow);
+                vectorToRead->setValue(i, listEntry);
+            }
+            auto dataVector = ListVector::getDataVector(vectorToRead);
+            memcpy(dataVector->getData(), rowPointer,
+                numRowsToRead * numValuesPerRow * dataVector->getNumBytesPerValue());
+            vectorToRead->state->getSelVectorUnsafe().setSelSize(numRowsToRead);
+        } else {
+            memcpy(vectorToRead->getData(), rowPointer,
+                numRowsToRead * vectorToRead->getNumBytesPerValue());
+            vectorToRead->state->getSelVectorUnsafe().setSelSize(numRowsToRead);
+        }
+    }
+}
+
+NpyMultiFileReader::NpyMultiFileReader(const std::vector<std::string>& filePaths) {
+    for (auto& file : filePaths) {
+        fileReaders.push_back(std::make_unique<NpyReader>(file));
+    }
+}
+
+void NpyMultiFileReader::readBlock(block_idx_t blockIdx, DataChunk& dataChunkToRead) const {
+    for (auto i = 0u; i < fileReaders.size(); i++) {
+        fileReaders[i]->readBlock(blockIdx, &dataChunkToRead.getValueVectorMutable(i));
+    }
+}
+
+NpyScanSharedState::NpyScanSharedState(FileScanInfo fileScanInfo, uint64_t numRows)
+    : ScanFileSharedState{std::move(fileScanInfo), numRows} {
+    npyMultiFileReader = std::make_unique<NpyMultiFileReader>(this->fileScanInfo.filePaths);
+}
+
+static offset_t tableFunc(const TableFuncInput& input, TableFuncOutput& output) {
+    auto sharedState = reinterpret_cast<NpyScanSharedState*>(input.sharedState);
+    auto [_, blockIdx] = sharedState->getNext();
+    sharedState->npyMultiFileReader->readBlock(blockIdx, output.dataChunk);
+    return output.dataChunk.state->getSelVector().getSelSize();
+}
+
+static LogicalType bindColumnType(const NpyReader& reader) {
+    if (reader.getShape().size() == 1) {
+        return LogicalType(reader.getType());
+    }
+    // For columns whose type is a multi-dimension array of size n*m,
+    // we flatten the row data into a 1-d array with size 1*k where k = n*m
+    return LogicalType::ARRAY(LogicalType(reader.getType()), reader.getNumElementsPerRow());
+}
+
+static void bindColumns(const FileScanInfo& fileScanInfo, uint32_t fileIdx,
+    std::vector<std::string>& columnNames, std::vector<LogicalType>& columnTypes) {
+    auto reader = NpyReader(fileScanInfo.filePaths[fileIdx]); // TODO: double check
+    auto columnName =
std::string("column" + std::to_string(fileIdx)); + auto columnType = bindColumnType(reader); + columnNames.push_back(columnName); + columnTypes.push_back(std::move(columnType)); +} + +static void bindColumns(const FileScanInfo& fileScanInfo, std::vector& columnNames, + std::vector& columnTypes) { + KU_ASSERT(fileScanInfo.getNumFiles() > 0); + bindColumns(fileScanInfo, 0, columnNames, columnTypes); + for (auto i = 1u; i < fileScanInfo.getNumFiles(); ++i) { + std::vector tmpColumnNames; + std::vector tmpColumnTypes; + bindColumns(fileScanInfo, i, tmpColumnNames, tmpColumnTypes); + ReaderBindUtils::validateNumColumns(1, tmpColumnTypes.size()); + columnNames.push_back(tmpColumnNames[0]); + columnTypes.push_back(std::move(tmpColumnTypes[0])); + } +} + +static std::unique_ptr bindFunc(main::ClientContext* context, + const TableFuncBindInput* input) { + auto scanInput = ku_dynamic_cast(input->extraInput.get()); + if (scanInput->fileScanInfo.options.size() > 1 || + (scanInput->fileScanInfo.options.size() == 1 && + !scanInput->fileScanInfo.options.contains(CopyConstants::IGNORE_ERRORS_OPTION_NAME))) { + throw BinderException{"Copy from numpy cannot have options other than IGNORE_ERRORS."}; + } + std::vector detectedColumnNames; + std::vector detectedColumnTypes; + bindColumns(scanInput->fileScanInfo, detectedColumnNames, detectedColumnTypes); + std::vector resultColumnNames; + std::vector resultColumnTypes; + ReaderBindUtils::resolveColumns(scanInput->expectedColumnNames, detectedColumnNames, + resultColumnNames, scanInput->expectedColumnTypes, detectedColumnTypes, resultColumnTypes); + auto config = scanInput->fileScanInfo.copy(); + KU_ASSERT(!config.filePaths.empty() && config.getNumFiles() == resultColumnNames.size()); + row_idx_t numRows = 0; + for (auto i = 0u; i < config.getNumFiles(); i++) { + auto reader = make_unique(config.filePaths[i]); + if (i == 0) { + numRows = reader->getNumRows(); + } + reader->validate(resultColumnTypes[i], numRows); + } + resultColumnNames = + TableFunction::extractYieldVariables(resultColumnNames, input->yieldVariables); + auto columns = input->binder->createVariables(resultColumnNames, resultColumnTypes); + return std::make_unique(columns, numRows, scanInput->fileScanInfo.copy(), + context); +} + +static std::unique_ptr initSharedState( + const TableFuncInitSharedStateInput& input) { + auto bindData = input.bindData->constPtrCast(); + auto reader = make_unique(bindData->fileScanInfo.filePaths[0]); + return std::make_unique(bindData->fileScanInfo.copy(), bindData->numRows); +} + +static void finalizeFunc(const ExecutionContext* ctx, TableFuncSharedState*) { + processor::WarningContext::Get(*ctx->clientContext)->defaultPopulateAllWarnings(ctx->queryID); +} + +function_set NpyScanFunction::getFunctionSet() { + function_set functionSet; + auto function = std::make_unique(name, std::vector{LogicalTypeID::STRING}); + function->tableFunc = tableFunc; + function->bindFunc = bindFunc; + function->initSharedStateFunc = initSharedState; + function->initLocalStateFunc = TableFunction::initEmptyLocalState; + function->finalizeFunc = finalizeFunc; + functionSet.push_back(std::move(function)); + return functionSet; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/parquet/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/parquet/CMakeLists.txt new file mode 100644 index 0000000000..3f54037d06 --- /dev/null +++ 
b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/parquet/CMakeLists.txt @@ -0,0 +1,15 @@ +add_library(lbug_processor_operator_parquet_reader + OBJECT + boolean_column_reader.cpp + column_reader.cpp + parquet_reader.cpp + interval_column_reader.cpp + struct_column_reader.cpp + string_column_reader.cpp + list_column_reader.cpp + parquet_timestamp.cpp + uuid_column_reader.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/parquet/boolean_column_reader.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/parquet/boolean_column_reader.cpp new file mode 100644 index 0000000000..9c14dce343 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/parquet/boolean_column_reader.cpp @@ -0,0 +1,27 @@ +#include "processor/operator/persistent/reader/parquet/boolean_column_reader.h" + +namespace lbug { +namespace processor { + +void BooleanColumnReader::initializeRead(uint64_t rowGroupIdx, + const std::vector& columns, + lbug_apache::thrift::protocol::TProtocol& protocol) { + bytePos = 0; + TemplatedColumnReader::initializeRead(rowGroupIdx, columns, + protocol); +} + +bool BooleanParquetValueConversion::plainRead(ByteBuffer& plainData, ColumnReader& reader) { + plainData.available(1); + auto& bytePos = reinterpret_cast(reader).bytePos; + bool ret = (*plainData.ptr >> bytePos) & 1; + bytePos++; + if (bytePos == 8) { + bytePos = 0; + plainData.inc(1); + } + return ret; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/parquet/column_reader.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/parquet/column_reader.cpp new file mode 100644 index 0000000000..b8ad9864a0 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/parquet/column_reader.cpp @@ -0,0 +1,585 @@ +#include "processor/operator/persistent/reader/parquet/column_reader.h" + +#include + +#include "brotli/decode.h" +#include "common/assert.h" +#include "common/exception/not_implemented.h" +#include "common/exception/runtime.h" +#include "common/types/date_t.h" +#include "lz4.hpp" +#include "miniz_wrapper.hpp" +#include "processor/operator/persistent/reader/parquet/boolean_column_reader.h" +#include "processor/operator/persistent/reader/parquet/callback_column_reader.h" +#include "processor/operator/persistent/reader/parquet/interval_column_reader.h" +#include "processor/operator/persistent/reader/parquet/parquet_timestamp.h" +#include "processor/operator/persistent/reader/parquet/string_column_reader.h" +#include "processor/operator/persistent/reader/parquet/templated_column_reader.h" +#include "processor/operator/persistent/reader/parquet/uuid_column_reader.h" +#include "snappy.h" +#include "zstd.h" + +using namespace lbug::common; + +namespace lbug { +namespace processor { + +using lbug_parquet::format::CompressionCodec; +using lbug_parquet::format::ConvertedType; +using lbug_parquet::format::Encoding; +using lbug_parquet::format::PageType; +using lbug_parquet::format::Type; + +ColumnReader::ColumnReader(ParquetReader& reader, LogicalType type, + const lbug_parquet::format::SchemaElement& schema, idx_t fileIdx, uint64_t maxDefinition, + uint64_t maxRepeat) + : schema{schema}, fileIdx{fileIdx}, maxDefine{maxDefinition}, maxRepeat{maxRepeat}, + reader{reader}, type{std::move(type)}, protocol(nullptr), 
pageRowsAvailable{0}, + groupRowsAvailable(0), chunkReadOffset(0) {} + +void ColumnReader::initializeRead(uint64_t /*rowGroupIdx*/, + const std::vector& columns, + lbug_apache::thrift::protocol::TProtocol& protocol) { + KU_ASSERT(fileIdx < columns.size()); + chunk = &columns[fileIdx]; + this->protocol = &protocol; + KU_ASSERT(chunk); + KU_ASSERT(chunk->__isset.meta_data); + + if (chunk->__isset.file_path) { + throw std::runtime_error("Only inlined data files are supported (no references)"); + } + + // ugh. sometimes there is an extra offset for the dict. sometimes it's wrong. + chunkReadOffset = chunk->meta_data.data_page_offset; + if (chunk->meta_data.__isset.dictionary_page_offset && + chunk->meta_data.dictionary_page_offset >= 4) { + // this assumes the data pages follow the dict pages directly. + chunkReadOffset = chunk->meta_data.dictionary_page_offset; + } + groupRowsAvailable = chunk->meta_data.num_values; +} + +void ColumnReader::registerPrefetch(ThriftFileTransport& transport, bool allowMerge) { + if (chunk) { + uint64_t size = chunk->meta_data.total_compressed_size; + transport.RegisterPrefetch(fileOffset(), size, allowMerge); + } +} + +uint64_t ColumnReader::fileOffset() const { + if (!chunk) { + throw std::runtime_error("fileOffset called on ColumnReader with no chunk"); + } + auto minOffset = UINT64_MAX; + if (chunk->meta_data.__isset.dictionary_page_offset) { + minOffset = std::min(minOffset, chunk->meta_data.dictionary_page_offset); + } + if (chunk->meta_data.__isset.index_page_offset) { + minOffset = std::min(minOffset, chunk->meta_data.index_page_offset); + } + minOffset = std::min(minOffset, chunk->meta_data.data_page_offset); + + return minOffset; +} + +void ColumnReader::applyPendingSkips(uint64_t numValues) { + pendingSkips -= numValues; + + dummyDefine.zero(); + dummyRepeat.zero(); + + // TODO this can be optimized, for example we dont actually have to bitunpack offsets + std::unique_ptr dummyResult = + std::make_unique(type.copy()); + + uint64_t remaining = numValues; + uint64_t numValuesRead = 0; + + while (remaining) { + auto numValuesToRead = std::min(remaining, common::DEFAULT_VECTOR_CAPACITY); + numValuesRead += + read(numValuesToRead, noneFilter, dummyDefine.ptr, dummyRepeat.ptr, dummyResult.get()); + remaining -= numValuesToRead; + } + + if (numValuesRead != numValues) { + throw std::runtime_error("Row count mismatch when skipping rows"); + } +} + +uint64_t ColumnReader::read(uint64_t numValues, parquet_filter_t& filter, uint8_t* defineOut, + uint8_t* repeatOut, common::ValueVector* resultOut) { + // we need to reset the location because multiple column readers share the same protocol + auto& trans = reinterpret_cast(*protocol->getTransport()); + trans.SetLocation(chunkReadOffset); + + // Perform any skips that were not applied yet. 
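+    // applyPendingSkips() above materializes deferred skips by decoding the skipped values
+    // into a throw-away vector, which (presumably) keeps the repetition/definition decoders
+    // and the file offset in sync before real data is read.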
+ if (pendingSkips > 0) { + applyPendingSkips(pendingSkips); + } + + uint64_t resultOffset = 0; + auto toRead = numValues; + + while (toRead > 0) { + while (pageRowsAvailable == 0) { + prepareRead(filter); + } + + KU_ASSERT(block); + auto readNow = std::min(toRead, pageRowsAvailable); + + KU_ASSERT(readNow <= common::DEFAULT_VECTOR_CAPACITY); + + if (hasRepeats()) { + KU_ASSERT(repeatedDecoder); + repeatedDecoder->GetBatch(repeatOut + resultOffset, readNow); + } + + if (hasDefines()) { + KU_ASSERT(defineDecoder); + defineDecoder->GetBatch(defineOut + resultOffset, readNow); + } + + uint64_t nullCount = 0; + + if ((dictDecoder || dbpDecoder || rleDecoder) && hasDefines()) { + // we need the null count because the dictionary offsets have no entries for nulls + for (auto i = 0u; i < readNow; i++) { + if (defineOut[i + resultOffset] != maxDefine) { + nullCount++; + } + } + } + + if (dictDecoder) { + offsetBuffer.resize(sizeof(uint32_t) * (readNow - nullCount)); + dictDecoder->GetBatch(offsetBuffer.ptr, readNow - nullCount); + offsets(reinterpret_cast(offsetBuffer.ptr), defineOut, readNow, filter, + resultOffset, resultOut); + } else if (dbpDecoder) { + // TODO keep this in the state + auto readBuf = std::make_shared(); + + switch (type.getPhysicalType()) { + case common::PhysicalTypeID::INT32: + readBuf->resize(sizeof(int32_t) * (readNow - nullCount)); + dbpDecoder->GetBatch(readBuf->ptr, readNow - nullCount); + break; + case common::PhysicalTypeID::INT64: + readBuf->resize(sizeof(int64_t) * (readNow - nullCount)); + dbpDecoder->GetBatch(readBuf->ptr, readNow - nullCount); + break; + default: + throw std::runtime_error("DELTA_BINARY_PACKED should only be INT32 or INT64"); + } + // Plain() will put NULLs in the right place + plain(readBuf, defineOut, readNow, filter, resultOffset, resultOut); + } else if (rleDecoder) { + // RLE encoding for boolean + KU_ASSERT(type.getLogicalTypeID() == common::LogicalTypeID::BOOL); + auto readBuf = std::make_shared(); + readBuf->resize(sizeof(bool) * (readNow - nullCount)); + rleDecoder->GetBatch(readBuf->ptr, readNow - nullCount); + plainTemplated>(readBuf, defineOut, readNow, + filter, resultOffset, resultOut); + } else { + plain(block, defineOut, readNow, filter, resultOffset, resultOut); + } + + resultOffset += readNow; + pageRowsAvailable -= readNow; + toRead -= readNow; + } + groupRowsAvailable -= numValues; + chunkReadOffset = trans.GetLocation(); + + return numValues; +} + +std::unique_ptr ColumnReader::createReader(ParquetReader& reader, + common::LogicalType type, const lbug_parquet::format::SchemaElement& schema, uint64_t fileIdx, + uint64_t maxDefine, uint64_t maxRepeat) { + switch (type.getLogicalTypeID()) { + case common::LogicalTypeID::BOOL: + return std::make_unique(reader, std::move(type), schema, fileIdx, + maxDefine, maxRepeat); + case common::LogicalTypeID::INT8: + return std::make_unique< + TemplatedColumnReader>>(reader, + std::move(type), schema, fileIdx, maxDefine, maxRepeat); + case common::LogicalTypeID::INT16: + return std::make_unique< + TemplatedColumnReader>>(reader, + std::move(type), schema, fileIdx, maxDefine, maxRepeat); + case common::LogicalTypeID::INT32: + return std::make_unique< + TemplatedColumnReader>>(reader, + std::move(type), schema, fileIdx, maxDefine, maxRepeat); + case common::LogicalTypeID::SERIAL: + case common::LogicalTypeID::INT64: + return std::make_unique< + TemplatedColumnReader>>(reader, + std::move(type), schema, fileIdx, maxDefine, maxRepeat); + case common::LogicalTypeID::UINT8: + return 
std::make_unique< + TemplatedColumnReader>>(reader, + std::move(type), schema, fileIdx, maxDefine, maxRepeat); + case common::LogicalTypeID::UINT16: + return std::make_unique< + TemplatedColumnReader>>(reader, + std::move(type), schema, fileIdx, maxDefine, maxRepeat); + case common::LogicalTypeID::UINT32: + return std::make_unique< + TemplatedColumnReader>>(reader, + std::move(type), schema, fileIdx, maxDefine, maxRepeat); + case common::LogicalTypeID::UINT64: + return std::make_unique< + TemplatedColumnReader>>(reader, + std::move(type), schema, fileIdx, maxDefine, maxRepeat); + case common::LogicalTypeID::FLOAT: + return std::make_unique< + TemplatedColumnReader>>(reader, + std::move(type), schema, fileIdx, maxDefine, maxRepeat); + case common::LogicalTypeID::DOUBLE: + return std::make_unique< + TemplatedColumnReader>>(reader, + std::move(type), schema, fileIdx, maxDefine, maxRepeat); + case common::LogicalTypeID::DATE: + return std::make_unique< + CallbackColumnReader>( + reader, std::move(type), schema, fileIdx, maxDefine, maxRepeat); + case common::LogicalTypeID::BLOB: + case common::LogicalTypeID::STRING: + return std::make_unique(reader, std::move(type), schema, fileIdx, + maxDefine, maxRepeat); + case common::LogicalTypeID::INTERVAL: + return std::make_unique(reader, std::move(type), schema, fileIdx, + maxDefine, maxRepeat); + case common::LogicalTypeID::TIMESTAMP_TZ: + case common::LogicalTypeID::TIMESTAMP: + return createTimestampReader(reader, std::move(type), schema, fileIdx, maxDefine, + maxRepeat); + case common::LogicalTypeID::UUID: + return std::make_unique(reader, std::move(type), schema, fileIdx, + maxDefine, maxRepeat); + default: + KU_UNREACHABLE; + } +} + +void ColumnReader::prepareRead(parquet_filter_t& /*filter*/) { + dictDecoder.reset(); + defineDecoder.reset(); + block.reset(); + lbug_parquet::format::PageHeader pageHdr; + pageHdr.read(protocol); + + switch (pageHdr.type) { + case PageType::DATA_PAGE_V2: + preparePageV2(pageHdr); + prepareDataPage(pageHdr); + break; + case PageType::DATA_PAGE: + preparePage(pageHdr); + prepareDataPage(pageHdr); + break; + case PageType::DICTIONARY_PAGE: + preparePage(pageHdr); + dictionary(block, pageHdr.dictionary_page_header.num_values); + break; + default: + break; // ignore INDEX page type and any other custom extensions + } + resetPage(); +} + +void ColumnReader::allocateBlock(uint64_t size) { + if (!block) { + block = std::make_shared(size); + } else { + block->resize(size); + } +} + +void ColumnReader::allocateCompressed(uint64_t size) { + compressedBuffer.resize(size); +} + +static void brotliDecompress(uint8_t* dst, size_t dstSize, const uint8_t* src, size_t srcSize) { + auto instance = BrotliDecoderCreateInstance(nullptr /* alloc_func */, nullptr /* free_func */, + nullptr /* opaque */); + BrotliDecoderResult oneshotResult{}; + do { + oneshotResult = + BrotliDecoderDecompressStream(instance, &srcSize, &src, &dstSize, &dst, nullptr); + } while (srcSize != 0 || oneshotResult != BROTLI_DECODER_RESULT_SUCCESS); + BrotliDecoderDestroyInstance(instance); +} + +void ColumnReader::decompressInternal(lbug_parquet::format::CompressionCodec::type codec, + const uint8_t* src, uint64_t srcSize, uint8_t* dst, uint64_t dstSize) { + switch (codec) { + case CompressionCodec::UNCOMPRESSED: + throw common::CopyException("Parquet data unexpectedly uncompressed"); + case CompressionCodec::GZIP: { + MiniZStream s; + s.Decompress(reinterpret_cast(src), srcSize, reinterpret_cast(dst), + dstSize); + } break; + case CompressionCodec::SNAPPY: { + { 
+ size_t uncompressedSize = 0; + auto res = lbug_snappy::GetUncompressedLength(reinterpret_cast(src), + srcSize, &uncompressedSize); + // LCOV_EXCL_START + if (!res) { + throw common::RuntimeException{"Failed to decompress parquet file."}; + } + if (uncompressedSize != (size_t)dstSize) { + throw common::RuntimeException{ + "Snappy decompression failure: Uncompressed data size mismatch"}; + } + // LCOV_EXCL_STOP + } + auto res = lbug_snappy::RawUncompress(reinterpret_cast(src), srcSize, + reinterpret_cast(dst)); + // LCOV_EXCL_START + if (!res) { + throw common::RuntimeException{"Snappy decompression failed."}; + } + // LCOV_EXCL_STOP + } break; + case CompressionCodec::ZSTD: { + auto res = lbug_zstd::ZSTD_decompress(dst, dstSize, src, srcSize); + // LCOV_EXCL_START + if (lbug_zstd::ZSTD_isError(res) || res != (size_t)dstSize) { + throw common::RuntimeException{"ZSTD decompression failed."}; + } + // LCOV_EXCL_STOP + } break; + case CompressionCodec::BROTLI: { + brotliDecompress(dst, dstSize, src, srcSize); + } break; + case CompressionCodec::LZ4_RAW: { + auto res = lbug_lz4::LZ4_decompress_safe(reinterpret_cast(src), + reinterpret_cast(dst), srcSize, dstSize); + // LCOV_EXCL_START + if (res != (int64_t)dstSize) { + throw common::RuntimeException{"LZ4 decompression failed."}; + } + // LCOV_EXCL_STOP + } break; + default: { + // LCOV_EXCL_START + std::stringstream codec_name; + codec_name << codec; + throw common::CopyException("Unsupported compression codec \"" + codec_name.str() + + "\". Supported options are uncompressed, gzip, snappy or zstd"); + // LCOV_EXCL_STOP + } + } +} + +void ColumnReader::preparePageV2(lbug_parquet::format::PageHeader& pageHdr) { + KU_ASSERT(pageHdr.type == PageType::DATA_PAGE_V2); + + auto& trans = reinterpret_cast(*protocol->getTransport()); + + allocateBlock(pageHdr.uncompressed_page_size + 1); + bool uncompressed = false; + if (pageHdr.data_page_header_v2.__isset.is_compressed && + !pageHdr.data_page_header_v2.is_compressed) { + uncompressed = true; + } + if (chunk->meta_data.codec == CompressionCodec::UNCOMPRESSED) { + if (pageHdr.compressed_page_size != pageHdr.uncompressed_page_size) { + throw std::runtime_error("Page size mismatch"); + } + uncompressed = true; + } + if (uncompressed) { + trans.read(block->ptr, pageHdr.compressed_page_size); + return; + } + + // copy repeats & defines as-is because FOR SOME REASON they are uncompressed + auto uncompressedBytes = pageHdr.data_page_header_v2.repetition_levels_byte_length + + pageHdr.data_page_header_v2.definition_levels_byte_length; + trans.read(block->ptr, uncompressedBytes); + + auto compressedBytes = pageHdr.compressed_page_size - uncompressedBytes; + + allocateCompressed(compressedBytes); + trans.read(compressedBuffer.ptr, compressedBytes); + + decompressInternal(chunk->meta_data.codec, compressedBuffer.ptr, compressedBytes, + block->ptr + uncompressedBytes, pageHdr.uncompressed_page_size - uncompressedBytes); +} + +void ColumnReader::preparePage(lbug_parquet::format::PageHeader& pageHdr) { + auto& trans = reinterpret_cast(*protocol->getTransport()); + + allocateBlock(pageHdr.uncompressed_page_size + 1); + if (chunk->meta_data.codec == CompressionCodec::UNCOMPRESSED) { + if (pageHdr.compressed_page_size != pageHdr.uncompressed_page_size) { + throw std::runtime_error("Page size mismatch"); + } + trans.read((uint8_t*)block->ptr, pageHdr.compressed_page_size); + return; + } + + allocateCompressed(pageHdr.compressed_page_size + 1); + trans.read((uint8_t*)compressedBuffer.ptr, pageHdr.compressed_page_size); + 
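+    // Unlike preparePageV2(), which reads the repetition/definition levels uncompressed, this
+    // path inflates the entire page into `block` using the codec recorded in the column-chunk
+    // metadata.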
+ decompressInternal(chunk->meta_data.codec, compressedBuffer.ptr, pageHdr.compressed_page_size, + block->ptr, pageHdr.uncompressed_page_size); +} + +void ColumnReader::prepareDataPage(lbug_parquet::format::PageHeader& pageHdr) { + if (pageHdr.type == PageType::DATA_PAGE && !pageHdr.__isset.data_page_header) { + throw std::runtime_error("Missing data page header from data page"); + } + if (pageHdr.type == PageType::DATA_PAGE_V2 && !pageHdr.__isset.data_page_header_v2) { + throw std::runtime_error("Missing data page header from data page v2"); + } + + bool isV1 = pageHdr.type == PageType::DATA_PAGE; + bool isV2 = pageHdr.type == PageType::DATA_PAGE_V2; + auto& v1Header = pageHdr.data_page_header; + auto& v2Header = pageHdr.data_page_header_v2; + + pageRowsAvailable = isV1 ? v1Header.num_values : v2Header.num_values; + auto pageEncoding = isV1 ? v1Header.encoding : v2Header.encoding; + + if (hasRepeats()) { + uint32_t repLength = + isV1 ? block->read() : v2Header.repetition_levels_byte_length; + block->available(repLength); + repeatedDecoder = std::make_unique(block->ptr, repLength, + RleBpDecoder::ComputeBitWidth(maxRepeat)); + block->inc(repLength); + } else if (isV2 && v2Header.repetition_levels_byte_length > 0) { + block->inc(v2Header.repetition_levels_byte_length); + } + + if (hasDefines()) { + auto defLen = isV1 ? block->read() : v2Header.definition_levels_byte_length; + block->available(defLen); + defineDecoder = std::make_unique(block->ptr, defLen, + RleBpDecoder::ComputeBitWidth(maxDefine)); + block->inc(defLen); + } else if (isV2 && v2Header.definition_levels_byte_length > 0) { + block->inc(v2Header.definition_levels_byte_length); + } + + switch (pageEncoding) { + case Encoding::RLE_DICTIONARY: + case Encoding::PLAIN_DICTIONARY: { + // where is it otherwise?? + auto dictWidth = block->read(); + // TODO somehow dict_width can be 0 ? 
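+ // The byte read above is the bit width of the RLE/bit-packed dictionary indices
+ // (a width of 0 can occur when every index is 0); the remaining page bytes are
+ // handed to the dictionary decoder as-is.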
+ dictDecoder = std::make_unique(block->ptr, block->len, dictWidth); + block->inc(block->len); + break; + } + case Encoding::RLE: { + if (type.getLogicalTypeID() != common::LogicalTypeID::BOOL) { + throw std::runtime_error("RLE encoding is only supported for boolean data"); + } + block->inc(sizeof(uint32_t)); + rleDecoder = std::make_unique(block->ptr, block->len, 1); + break; + } + case Encoding::DELTA_BINARY_PACKED: { + dbpDecoder = std::make_unique(block->ptr, block->len); + block->inc(block->len); + break; + } + case Encoding::DELTA_LENGTH_BYTE_ARRAY: + case Encoding::DELTA_BYTE_ARRAY: { + KU_UNREACHABLE; + } + case Encoding::PLAIN: + // nothing to do here, will be read directly below + break; + default: + throw common::NotImplementedException("Parquet: unsupported page encoding"); + } +} + +uint64_t ColumnReader::getTotalCompressedSize() { + if (!chunk) { + return 0; + } + return chunk->meta_data.total_compressed_size; +} + +std::unique_ptr ColumnReader::createTimestampReader(ParquetReader& reader, + common::LogicalType type, const lbug_parquet::format::SchemaElement& schema, uint64_t fileIdx, + uint64_t maxDefine, uint64_t maxRepeat) { + switch (schema.type) { + case Type::INT96: { + return std::make_unique>(reader, std::move(type), schema, + fileIdx, maxDefine, maxRepeat); + } + case Type::INT64: { + if (schema.__isset.logicalType && schema.logicalType.__isset.TIMESTAMP) { + if (schema.logicalType.TIMESTAMP.unit.__isset.MILLIS) { + return std::make_unique>(reader, std::move(type), + schema, fileIdx, maxDefine, maxRepeat); + } else if (schema.logicalType.TIMESTAMP.unit.__isset.MICROS) { + return std::make_unique>(reader, + std::move(type), schema, fileIdx, maxDefine, maxRepeat); + } else if (schema.logicalType.TIMESTAMP.unit.__isset.NANOS) { + return std::make_unique>(reader, std::move(type), + schema, fileIdx, maxDefine, maxRepeat); + } + // LCOV_EXCL_START + } else if (schema.__isset.converted_type) { + // For legacy compatibility. 
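+ // converted_type is the pre-LogicalType annotation; legacy writers only
+ // distinguish millisecond and microsecond precision here.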
+ switch (schema.converted_type) { + case ConvertedType::TIMESTAMP_MICROS: + return std::make_unique>(reader, + std::move(type), schema, fileIdx, maxDefine, maxRepeat); + case ConvertedType::TIMESTAMP_MILLIS: + return std::make_unique>(reader, std::move(type), + schema, fileIdx, maxDefine, maxRepeat); + default: + KU_UNREACHABLE; + } + // LCOV_EXCL_STOP + } + KU_UNREACHABLE; + } + default: { + KU_UNREACHABLE; + } + } +} + +const uint64_t ParquetDecodeUtils::BITPACK_MASKS[] = {0, 1, 3, 7, 15, 31, 63, 127, 255, 511, 1023, + 2047, 4095, 8191, 16383, 32767, 65535, 131071, 262143, 524287, 1048575, 2097151, 4194303, + 8388607, 16777215, 33554431, 67108863, 134217727, 268435455, 536870911, 1073741823, 2147483647, + 4294967295, 8589934591, 17179869183, 34359738367, 68719476735, 137438953471, 274877906943, + 549755813887, 1099511627775, 2199023255551, 4398046511103, 8796093022207, 17592186044415, + 35184372088831, 70368744177663, 140737488355327, 281474976710655, 562949953421311, + 1125899906842623, 2251799813685247, 4503599627370495, 9007199254740991, 18014398509481983, + 36028797018963967, 72057594037927935, 144115188075855871, 288230376151711743, + 576460752303423487, 1152921504606846975, 2305843009213693951, 4611686018427387903, + 9223372036854775807, 18446744073709551615ULL}; + +const uint64_t ParquetDecodeUtils::BITPACK_MASKS_SIZE = + sizeof(ParquetDecodeUtils::BITPACK_MASKS) / sizeof(uint64_t); + +const uint8_t ParquetDecodeUtils::BITPACK_DLEN = 8; + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/parquet/interval_column_reader.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/parquet/interval_column_reader.cpp new file mode 100644 index 0000000000..8f8e05974a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/parquet/interval_column_reader.cpp @@ -0,0 +1,34 @@ +#include "processor/operator/persistent/reader/parquet/interval_column_reader.h" + +namespace lbug { +namespace processor { + +common::interval_t IntervalValueConversion::readParquetInterval(const char* input) { + common::interval_t result; + auto inputData = reinterpret_cast(input); + result.months = inputData[0]; + result.days = inputData[1]; + result.micros = int64_t(inputData[2]) * 1000; + return result; +} + +common::interval_t IntervalValueConversion::plainRead(ByteBuffer& plainData, + ColumnReader& /*reader*/) { + auto intervalLen = common::ParquetConstants::PARQUET_INTERVAL_SIZE; + plainData.available(intervalLen); + auto res = readParquetInterval(reinterpret_cast(plainData.ptr)); + plainData.inc(intervalLen); + return res; +} + +void IntervalColumnReader::dictionary(const std::shared_ptr& dictionaryData, + uint64_t numEntries) { + allocateDict(numEntries * sizeof(common::interval_t)); + auto dict_ptr = reinterpret_cast(this->dict->ptr); + for (auto i = 0u; i < numEntries; i++) { + dict_ptr[i] = IntervalValueConversion::plainRead(*dictionaryData, *this); + } +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/parquet/list_column_reader.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/parquet/list_column_reader.cpp new file mode 100644 index 0000000000..1a60abbe2d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/parquet/list_column_reader.cpp @@ -0,0 +1,131 @@ +#include 
"processor/operator/persistent/reader/parquet/list_column_reader.h" + +namespace lbug { +namespace processor { + +ListColumnReader::ListColumnReader(ParquetReader& reader, common::LogicalType type, + const lbug_parquet::format::SchemaElement& schema, uint64_t schemaIdx, uint64_t maxDefine, + uint64_t maxRepeat, std::unique_ptr childColumnReader, + storage::MemoryManager* memoryManager) + : ColumnReader(reader, std::move(type), schema, schemaIdx, maxDefine, maxRepeat), + childColumnReader(std::move(childColumnReader)), overflowChildCount(0) { + childDefines.resize(common::DEFAULT_VECTOR_CAPACITY); + childRepeats.resize(common::DEFAULT_VECTOR_CAPACITY); + childDefinesPtr = (uint8_t*)childDefines.ptr; + childRepeatsPtr = (uint8_t*)childRepeats.ptr; + childFilter.set(); + vectorToRead = std::make_unique( + common::ListType::getChildType(this->type).copy(), memoryManager); +} + +void ListColumnReader::applyPendingSkips(uint64_t numValues) { + pendingSkips -= numValues; + auto defineOut = std::unique_ptr(new uint8_t[numValues]); + auto repeatOut = std::unique_ptr(new uint8_t[numValues]); + uint64_t remaining = numValues; + uint64_t numValuesRead = 0; + while (remaining) { + auto result_out = std::make_unique(type.copy()); + parquet_filter_t filter; + auto to_read = std::min(remaining, common::DEFAULT_VECTOR_CAPACITY); + numValuesRead += read(to_read, filter, defineOut.get(), repeatOut.get(), result_out.get()); + remaining -= to_read; + } + + if (numValuesRead != numValues) { + throw common::CopyException("Not all skips done!"); + } +} + +uint64_t ListColumnReader::read(uint64_t numValues, parquet_filter_t& /*filter*/, + uint8_t* defineOut, uint8_t* repeatOut, common::ValueVector* resultOut) { + common::offset_t resultOffset = 0; + auto resultPtr = reinterpret_cast(resultOut->getData()); + + if (pendingSkips > 0) { + applyPendingSkips(pendingSkips); + } + + // if an individual list is longer than STANDARD_VECTOR_SIZE we actually have to loop the child + // read to fill it + bool finished = false; + while (!finished) { + uint64_t childActualNumValues = 0; + + // check if we have any overflow from a previous read + if (overflowChildCount == 0) { + // we don't: read elements from the child reader + childDefines.zero(); + childRepeats.zero(); + // we don't know in advance how many values to read because of the beautiful + // repetition/definition setup we just read (up to) a vector from the child column, and + // see if we have read enough if we have not read enough, we read another vector if we + // have read enough, we leave any unhandled elements in the overflow vector for a + // subsequent read + auto childReqNumValues = std::min(common::DEFAULT_VECTOR_CAPACITY, + childColumnReader->getGroupRowsAvailable()); + childActualNumValues = childColumnReader->read(childReqNumValues, childFilter, + childDefinesPtr, childRepeatsPtr, vectorToRead.get()); + } else { + childActualNumValues = overflowChildCount; + overflowChildCount = 0; + } + + if (childActualNumValues == 0) { + // no more elements available: we are done + break; + } + auto currentChunkOffset = common::ListVector::getDataVectorSize(resultOut); + + // hard-won piece of code this, modify at your own risk + // the intuition is that we have to only collapse values into lists that are repeated *on + // this level* the rest is pretty much handed up as-is as a single-valued list or NULL + uint64_t childIdx = 0; + for (childIdx = 0; childIdx < childActualNumValues; childIdx++) { + if (childRepeatsPtr[childIdx] == maxRepeat) { + // value repeats on 
this level, append + KU_ASSERT(resultOffset > 0); + resultPtr[resultOffset - 1].size++; + continue; + } + + if (resultOffset >= numValues) { + // we ran out of output space + finished = true; + break; + } + if (childDefinesPtr[childIdx] >= maxDefine) { + resultOut->setNull(resultOffset, false); + // value has been defined down the stack, hence its NOT NULL + resultPtr[resultOffset].offset = childIdx + currentChunkOffset; + resultPtr[resultOffset].size = 1; + } else if (childDefinesPtr[childIdx] == maxDefine - 1) { + resultOut->setNull(resultOffset, false); + resultPtr[resultOffset].offset = childIdx + currentChunkOffset; + resultPtr[resultOffset].size = 0; + } else { + resultOut->setNull(resultOffset, true); + resultPtr[resultOffset].offset = 0; + resultPtr[resultOffset].size = 0; + } + + repeatOut[resultOffset] = childRepeatsPtr[childIdx]; + defineOut[resultOffset] = childDefinesPtr[childIdx]; + + resultOffset++; + } + common::ListVector::appendDataVector(resultOut, vectorToRead.get(), childIdx); + if (childIdx < childActualNumValues && resultOffset == numValues) { + common::ListVector::sliceDataVector(vectorToRead.get(), childIdx, childActualNumValues); + overflowChildCount = childActualNumValues - childIdx; + for (auto repdefIdx = 0u; repdefIdx < overflowChildCount; repdefIdx++) { + childDefinesPtr[repdefIdx] = childDefinesPtr[childIdx + repdefIdx]; + childRepeatsPtr[repdefIdx] = childRepeatsPtr[childIdx + repdefIdx]; + } + } + } + return resultOffset; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/parquet/parquet_reader.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/parquet/parquet_reader.cpp new file mode 100644 index 0000000000..0f905fd87c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/parquet/parquet_reader.cpp @@ -0,0 +1,756 @@ +#include "processor/operator/persistent/reader/parquet/parquet_reader.h" + +#include "binder/binder.h" +#include "common/exception/binder.h" +#include "common/exception/copy.h" +#include "common/file_system/virtual_file_system.h" +#include "common/string_format.h" +#include "function/table/bind_data.h" +#include "function/table/bind_input.h" +#include "function/table/table_function.h" +#include "processor/execution_context.h" +#include "processor/operator/persistent/reader/parquet/list_column_reader.h" +#include "processor/operator/persistent/reader/parquet/struct_column_reader.h" +#include "processor/operator/persistent/reader/parquet/thrift_tools.h" +#include "processor/operator/persistent/reader/reader_bind_utils.h" +#include "processor/warning_context.h" + +using namespace lbug_parquet::format; + +namespace lbug { +namespace processor { + +using namespace lbug::function; +using namespace lbug::common; + +ParquetReader::ParquetReader(std::string filePath, std::vector columnSkips, + main::ClientContext* context) + : filePath{std::move(filePath)}, columnSkips(std::move(columnSkips)), context{context} { + initMetadata(); +} + +void ParquetReader::initializeScan(ParquetReaderScanState& state, + std::vector groups_to_read, VirtualFileSystem* vfs) { + state.currentGroup = -1; + state.finished = false; + state.groupOffset = 0; + state.groupIdxList = std::move(groups_to_read); + if (!state.fileInfo || state.fileInfo->path != filePath) { + state.prefetchMode = false; + state.fileInfo = + vfs->openFile(filePath, common::FileOpenFlags(FileFlags::READ_ONLY), context); + } + + state.thriftFileProto = 
createThriftProtocol(state.fileInfo.get(), state.prefetchMode); + state.rootReader = createReader(); + state.defineBuf.resize(DEFAULT_VECTOR_CAPACITY); + state.repeatBuf.resize(DEFAULT_VECTOR_CAPACITY); +} + +bool ParquetReader::scanInternal(ParquetReaderScanState& state, DataChunk& result) { + if (state.finished) { + return false; + } + + // see if we have to switch to the next row group in the parquet file + if (state.currentGroup < 0 || (int64_t)state.groupOffset >= getGroup(state).num_rows) { + state.currentGroup++; + state.groupOffset = 0; + + auto& trans = ku_dynamic_cast(*state.thriftFileProto->getTransport()); + trans.ClearPrefetch(); + state.currentGroupPrefetched = false; + + if ((uint64_t)state.currentGroup == state.groupIdxList.size()) { + state.finished = true; + return false; + } + + uint64_t toScanCompressedBytes = 0; + for (auto colIdx = 0u; colIdx < result.getNumValueVectors(); colIdx++) { + prepareRowGroupBuffer(state, colIdx); + + auto fileColIdx = colIdx; + + auto rootReader = ku_dynamic_cast(state.rootReader.get()); + toScanCompressedBytes += + rootReader->getChildReader(fileColIdx)->getTotalCompressedSize(); + } + + auto& group = getGroup(state); + if (state.prefetchMode && state.groupOffset != (uint64_t)group.num_rows) { + + uint64_t totalRowGroupSpan = getGroupSpan(state); + + double scanPercentage = (double)(toScanCompressedBytes) / totalRowGroupSpan; + + // LCOV_EXCL_START + if (toScanCompressedBytes > totalRowGroupSpan) { + throw CopyException("Malformed parquet file: sum of total compressed bytes " + "of columns seems incorrect"); + } + // LCOV_EXCL_STOP + + if (scanPercentage > ParquetReaderPrefetchConfig::WHOLE_GROUP_PREFETCH_MINIMUM_SCAN) { + // Prefetch the whole row group + if (!state.currentGroupPrefetched) { + auto totalCompressedSize = getGroupCompressedSize(state); + if (totalCompressedSize > 0) { + trans.Prefetch(getGroupOffset(state), totalRowGroupSpan); + } + state.currentGroupPrefetched = true; + } + } else { + // Prefetch column-wise. 
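+ // The scanned columns cover too small a fraction of the row group to justify
+ // fetching it whole, so register a lazy prefetch per projected column and issue
+ // the registered ranges together afterwards.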
+ for (auto colIdx = 0u; colIdx < result.getNumValueVectors(); colIdx++) { + auto fileColIdx = colIdx; + auto rootReader = ku_dynamic_cast(state.rootReader.get()); + + rootReader->getChildReader(fileColIdx) + ->registerPrefetch(trans, true /* lazy fetch */); + } + trans.FinalizeRegistration(); + trans.PrefetchRegistered(); + } + } + return true; + } + + auto thisOutputChunkRows = + std::min(DEFAULT_VECTOR_CAPACITY, getGroup(state).num_rows - state.groupOffset); + result.state->getSelVectorUnsafe().setSelSize(thisOutputChunkRows); + + if (thisOutputChunkRows == 0) { + state.finished = true; + return false; // end of last group, we are done + } + + // we evaluate simple table filters directly in this scan, so we can skip decoding column data + // that's never going to be relevant + parquet_filter_t filterMask; + filterMask.set(); + + // mask out unused part of bitset + for (auto i = thisOutputChunkRows; i < DEFAULT_VECTOR_CAPACITY; i++) { + filterMask.set(i, false); + } + + state.defineBuf.zero(); + state.repeatBuf.zero(); + + auto definePtr = (uint8_t*)state.defineBuf.ptr; + auto repeatPtr = (uint8_t*)state.repeatBuf.ptr; + + auto rootReader = ku_dynamic_cast(state.rootReader.get()); + for (auto colIdx = 0u; colIdx < result.getNumValueVectors(); colIdx++) { + if (!columnSkips.empty() && columnSkips[colIdx]) { + continue; + } + auto fileColIdx = colIdx; + auto& resultVector = result.getValueVectorMutable(colIdx); + auto childReader = rootReader->getChildReader(fileColIdx); + auto rowsRead = childReader->read(resultVector.state->getSelVector().getSelSize(), + filterMask, definePtr, repeatPtr, &resultVector); + // LCOV_EXCL_START + if (rowsRead != result.state->getSelVector().getSelSize()) { + throw CopyException( + stringFormat("Mismatch in parquet read for column {}, expected {} rows, got {}", + fileColIdx, result.state->getSelVector().getSelSize(), rowsRead)); + } + // LCOV_EXCL_STOP + } + + state.groupOffset += thisOutputChunkRows; + return true; +} + +void ParquetReader::scan(processor::ParquetReaderScanState& state, DataChunk& result) { + while (scanInternal(state, result)) { + if (result.state->getSelVector().getSelSize() > 0) { + break; + } + } +} + +void ParquetReader::initMetadata() { + auto fileInfo = VirtualFileSystem::GetUnsafe(*context)->openFile(filePath, + FileOpenFlags(FileFlags::READ_ONLY), context); + auto proto = createThriftProtocol(fileInfo.get(), false); + auto& transport = ku_dynamic_cast(*proto->getTransport()); + auto fileSize = transport.GetSize(); + // LCOV_EXCL_START + if (fileSize < 12) { + throw CopyException{ + stringFormat("File {} is too small to be a Parquet file", filePath.c_str())}; + } + // LCOV_EXCL_STOP + + ResizeableBuffer buf; + buf.resize(8); + buf.zero(); + + transport.SetLocation(fileSize - 8); + transport.read((uint8_t*)buf.ptr, 8); + + // LCOV_EXCL_START + if (memcmp(buf.ptr + 4, "PAR1", 4) != 0) { + if (memcmp(buf.ptr + 4, "PARE", 4) == 0) { + throw CopyException{stringFormat( + "Encrypted Parquet files are not supported for file {}", fileInfo->path.c_str())}; + } + throw CopyException{ + stringFormat("No magic bytes found at the end of file {}", fileInfo->path.c_str())}; + } + // LCOV_EXCL_STOP + // Read four-byte footer length from just before the end magic bytes. 
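+ // `buf` now holds the last 8 bytes of the file, which the Parquet format lays out as
+ //
+ //     [ Thrift-encoded FileMetaData | footer length (4 bytes, LE) | "PAR1" ]
+ //                                      ^ buf[0..3]                  ^ buf[4..7]
+ //
+ // so the metadata starts footerLen + 8 bytes before the end of the file.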
+ auto footerLen = *reinterpret_cast(buf.ptr); + // LCOV_EXCL_START + if (footerLen == 0 || fileSize < 12 + footerLen) { + throw CopyException{stringFormat("Footer length error in file {}", fileInfo->path.c_str())}; + } + // LCOV_EXCL_STOP + auto metadataPos = fileSize - (footerLen + 8); + transport.SetLocation(metadataPos); + transport.Prefetch(metadataPos, footerLen); + + metadata = std::make_unique(); + metadata->read(proto.get()); +} + +std::unique_ptr ParquetReader::createReaderRecursive(uint64_t depth, + uint64_t maxDefine, uint64_t maxRepeat, uint64_t& nextSchemaIdx, uint64_t& nextFileIdx) { + KU_ASSERT(nextSchemaIdx < metadata->schema.size()); + auto& sEle = metadata->schema[nextSchemaIdx]; + auto thisIdx = nextSchemaIdx; + + auto repetition_type = FieldRepetitionType::REQUIRED; + if (sEle.__isset.repetition_type && thisIdx > 0) { + repetition_type = sEle.repetition_type; + } + if (repetition_type != FieldRepetitionType::REQUIRED) { + maxDefine++; + } + if (repetition_type == FieldRepetitionType::REPEATED) { + maxRepeat++; + } + if (sEle.__isset.num_children && sEle.num_children > 0) { + std::vector structFields; + std::vector> childrenReaders; + uint64_t cIdx = 0; + while (cIdx < (uint64_t)sEle.num_children) { + nextSchemaIdx++; + auto& childEle = metadata->schema[nextSchemaIdx]; + auto childReader = + createReaderRecursive(depth + 1, maxDefine, maxRepeat, nextSchemaIdx, nextFileIdx); + structFields.emplace_back(childEle.name, childReader->getDataType().copy()); + childrenReaders.push_back(std::move(childReader)); + cIdx++; + } + KU_ASSERT(!structFields.empty()); + std::unique_ptr result; + LogicalType resultType; + + bool isRepeated = repetition_type == FieldRepetitionType::REPEATED; + bool isList = sEle.__isset.converted_type && sEle.converted_type == ConvertedType::LIST; + bool isMap = sEle.__isset.converted_type && sEle.converted_type == ConvertedType::MAP; + bool isMapKV = + sEle.__isset.converted_type && sEle.converted_type == ConvertedType::MAP_KEY_VALUE; + if (!isMapKV && thisIdx > 0) { + // check if the parent node of this is a map + auto& parentEle = metadata->schema[thisIdx - 1]; + bool parentIsMap = + parentEle.__isset.converted_type && parentEle.converted_type == ConvertedType::MAP; + bool parentHasChildren = parentEle.__isset.num_children && parentEle.num_children == 1; + isMapKV = parentIsMap && parentHasChildren; + } + + if (isMapKV) { + // LCOV_EXCL_START + if (structFields.size() != 2) { + throw CopyException{"MAP_KEY_VALUE requires two children"}; + } + if (!isRepeated) { + throw CopyException{"MAP_KEY_VALUE needs to be repeated"}; + } + // LCOV_EXCL_STOP + auto structType = LogicalType::STRUCT(std::move(structFields)); + resultType = LogicalType(LogicalTypeID::MAP, + std::make_unique(std::move(structType))); + + auto structReader = std::make_unique(*this, + ListType::getChildType(resultType).copy(), sEle, thisIdx, maxDefine - 1, + maxRepeat - 1, std::move(childrenReaders)); + return std::make_unique(*this, std::move(resultType), sEle, thisIdx, + maxDefine, maxRepeat, std::move(structReader), + storage::MemoryManager::Get(*context)); + } + + if (structFields.size() > 1 || (!isList && !isMap && !isRepeated)) { + resultType = LogicalType::STRUCT(std::move(structFields)); + result = std::make_unique(*this, resultType.copy(), sEle, thisIdx, + maxDefine, maxRepeat, std::move(childrenReaders)); + } else { + // if we have a struct with only a single type, pull up + resultType = structFields[0].getType().copy(); + result = std::move(childrenReaders[0]); + } + if 
(isRepeated) { + resultType = LogicalType::LIST(resultType.copy()); + return std::make_unique(*this, std::move(resultType), sEle, thisIdx, + maxDefine, maxRepeat, std::move(result), storage::MemoryManager::Get(*context)); + } + return result; + } else { + // LCOV_EXCL_START + if (!sEle.__isset.type) { + throw CopyException{"Node has neither num_children nor type set - this " + "violates the Parquet spec (corrupted file)"}; + } + // LCOV_EXCL_STOP + if (sEle.repetition_type == FieldRepetitionType::REPEATED) { + auto derivedType = deriveLogicalType(sEle); + auto listType = LogicalType::LIST(derivedType.copy()); + auto elementReader = ColumnReader::createReader(*this, std::move(derivedType), sEle, + nextFileIdx++, maxDefine, maxRepeat); + return std::make_unique(*this, std::move(listType), sEle, thisIdx, + maxDefine, maxRepeat, std::move(elementReader), + storage::MemoryManager::Get(*context)); + } + // TODO check return value of derive type or should we only do this on read() + return ColumnReader::createReader(*this, deriveLogicalType(sEle), sEle, nextFileIdx++, + maxDefine, maxRepeat); + } +} + +std::unique_ptr ParquetReader::createReader() { + uint64_t nextSchemaIdx = 0; + uint64_t nextFileIdx = 0; + + // LCOV_EXCL_START + if (metadata->schema.empty()) { + throw CopyException{"Parquet reader: no schema elements found"}; + } + if (metadata->schema[0].num_children == 0) { + throw CopyException{"Parquet reader: root schema element has no children"}; + } + // LCOV_EXCL_STOP + auto rootReader = createReaderRecursive(0, 0, 0, nextSchemaIdx, nextFileIdx); + // LCOV_EXCL_START + if (rootReader->getDataType().getPhysicalType() != PhysicalTypeID::STRUCT) { + throw CopyException{"Root element of Parquet file must be a struct"}; + } + // LCOV_EXCL_STOP + for (auto& field : StructType::getFields(rootReader->getDataType())) { + columnNames.push_back(field.getName()); + columnTypes.push_back(field.getType().copy()); + } + + KU_ASSERT(nextSchemaIdx == metadata->schema.size() - 1); + KU_ASSERT( + metadata->row_groups.empty() || nextFileIdx == metadata->row_groups[0].columns.size()); + return rootReader; +} + +void ParquetReader::prepareRowGroupBuffer(ParquetReaderScanState& state, uint64_t /*colIdx*/) { + auto& group = getGroup(state); + state.rootReader->initializeRead(state.groupIdxList[state.currentGroup], group.columns, + *state.thriftFileProto); +} + +uint64_t ParquetReader::getGroupSpan(ParquetReaderScanState& state) { + auto& group = getGroup(state); + uint64_t min_offset = UINT64_MAX; + uint64_t max_offset = 0; + for (auto& column_chunk : group.columns) { + // Set the min offset + auto current_min_offset = UINT64_MAX; + if (column_chunk.meta_data.__isset.dictionary_page_offset) { + current_min_offset = std::min(current_min_offset, + column_chunk.meta_data.dictionary_page_offset); + } + if (column_chunk.meta_data.__isset.index_page_offset) { + current_min_offset = + std::min(current_min_offset, column_chunk.meta_data.index_page_offset); + } + current_min_offset = + std::min(current_min_offset, column_chunk.meta_data.data_page_offset); + min_offset = std::min(current_min_offset, min_offset); + max_offset = std::max(max_offset, + column_chunk.meta_data.total_compressed_size + current_min_offset); + } + + return max_offset - min_offset; +} + +LogicalType ParquetReader::deriveLogicalType(const lbug_parquet::format::SchemaElement& s_ele) { + // inner node + if (s_ele.type == Type::FIXED_LEN_BYTE_ARRAY && !s_ele.__isset.type_length) { + // LCOV_EXCL_START + throw CopyException("FIXED_LEN_BYTE_ARRAY 
requires length to be set"); + // LCOV_EXCL_STOP + } + if (s_ele.__isset.logicalType && s_ele.logicalType.__isset.UUID && + s_ele.type == Type::FIXED_LEN_BYTE_ARRAY) { + return LogicalType::UUID(); + } + if (s_ele.__isset.converted_type) { + switch (s_ele.converted_type) { + case ConvertedType::INT_8: + if (s_ele.type == Type::INT32) { + return LogicalType::INT8(); + } else { + // LCOV_EXCL_START + throw CopyException{"INT8 converted type can only be set for value of Type::INT32"}; + // LCOV_EXCL_STOP + } + case ConvertedType::INT_16: + if (s_ele.type == Type::INT32) { + return LogicalType::INT16(); + } else { + // LCOV_EXCL_START + throw CopyException{ + "INT16 converted type can only be set for value of Type::INT32"}; + // LCOV_EXCL_STOP + } + case ConvertedType::INT_32: + if (s_ele.type == Type::INT32) { + return LogicalType::INT32(); + } else { + // LCOV_EXCL_START + throw CopyException{ + "INT32 converted type can only be set for value of Type::INT32"}; + // LCOV_EXCL_STOP + } + case ConvertedType::INT_64: + if (s_ele.type == Type::INT64) { + return LogicalType::INT64(); + } else { + // LCOV_EXCL_START + throw CopyException{ + "INT64 converted type can only be set for value of Type::INT64"}; + // LCOV_EXCL_STOP + } + case ConvertedType::UINT_8: + if (s_ele.type == Type::INT32) { + return LogicalType::UINT8(); + } else { + // LCOV_EXCL_START + throw CopyException{ + "UINT8 converted type can only be set for value of Type::INT32"}; + // LCOV_EXCL_STOP + } + case ConvertedType::UINT_16: + if (s_ele.type == Type::INT32) { + return LogicalType::UINT16(); + } else { + // LCOV_EXCL_START + throw CopyException{ + "UINT16 converted type can only be set for value of Type::INT32"}; + // LCOV_EXCL_STOP + } + case ConvertedType::UINT_32: + if (s_ele.type == Type::INT32) { + return LogicalType::UINT32(); + } else { + // LCOV_EXCL_START + throw CopyException{ + "UINT32 converted type can only be set for value of Type::INT32"}; + // LCOV_EXCL_STOP + } + case ConvertedType::UINT_64: + if (s_ele.type == Type::INT64) { + return LogicalType::UINT64(); + } else { + // LCOV_EXCL_START + throw CopyException{ + "UINT64 converted type can only be set for value of Type::INT64"}; + // LCOV_EXCL_STOP + } + case ConvertedType::DATE: + if (s_ele.type == Type::INT32) { + return LogicalType::DATE(); + } else { + // LCOV_EXCL_START + throw CopyException{"DATE converted type can only be set for value of Type::INT32"}; + // LCOV_EXCL_STOP + } + case ConvertedType::TIMESTAMP_MICROS: + case ConvertedType::TIMESTAMP_MILLIS: + if (s_ele.type == Type::INT64) { + return LogicalType::TIMESTAMP(); + } else { + // LCOV_EXCL_START + throw CopyException( + "TIMESTAMP converted type can only be set for value of Type::INT64"); + // LCOV_EXCL_STOP + } + case ConvertedType::INTERVAL: { + return LogicalType::INTERVAL(); + } + case ConvertedType::UTF8: + switch (s_ele.type) { + case Type::BYTE_ARRAY: + case Type::FIXED_LEN_BYTE_ARRAY: + return LogicalType::STRING(); + // LCOV_EXCL_START + default: + throw CopyException( + "UTF8 converted type can only be set for Type::(FIXED_LEN_)BYTE_ARRAY"); + // LCOV_EXCL_STOP + } + case ConvertedType::SERIAL: + if (s_ele.type == Type::INT64) { + return LogicalType::SERIAL(); + } else { + // LCOV_EXCL_START + throw CopyException{ + "SERIAL converted type can only be set for value of Type::INT64"}; + // LCOV_EXCL_STOP + } + + default: + // LCOV_EXCL_START + throw CopyException{"Unsupported converted type"}; + // LCOV_EXCL_STOP + } + } else { + // no converted type set + // use default type for each 
physical type + switch (s_ele.type) { + case Type::BOOLEAN: + return LogicalType::BOOL(); + case Type::INT32: + return LogicalType::INT32(); + case Type::INT64: + return LogicalType::INT64(); + case Type::INT96: + return LogicalType::TIMESTAMP(); + case Type::FLOAT: + return LogicalType::FLOAT(); + case Type::DOUBLE: + return LogicalType::DOUBLE(); + case Type::BYTE_ARRAY: + case Type::FIXED_LEN_BYTE_ARRAY: + // TODO(Ziyi): Support parquet copy option(binary_as_string). + return LogicalType::BLOB(); + default: + return LogicalType(LogicalTypeID::ANY); + } + } +} + +uint64_t ParquetReader::getGroupCompressedSize(ParquetReaderScanState& state) { + auto& group = getGroup(state); + auto total_compressed_size = group.total_compressed_size; + + uint64_t calc_compressed_size = 0; + + // If the global total_compressed_size is not set, we can still calculate it + if (group.total_compressed_size == 0) { + for (auto& column_chunk : group.columns) { + calc_compressed_size += column_chunk.meta_data.total_compressed_size; + } + } + + // LCOV_EXCL_START + if (total_compressed_size != 0 && calc_compressed_size != 0 && + (uint64_t)total_compressed_size != calc_compressed_size) { + throw CopyException( + "mismatch between calculated compressed size and reported compressed size"); + } + // LCOV_EXCL_STOP + + return total_compressed_size ? total_compressed_size : calc_compressed_size; +} + +uint64_t ParquetReader::getGroupOffset(ParquetReaderScanState& state) { + auto& group = getGroup(state); + uint64_t minOffset = UINT64_MAX; + + for (auto& column_chunk : group.columns) { + if (column_chunk.meta_data.__isset.dictionary_page_offset) { + minOffset = + std::min(minOffset, column_chunk.meta_data.dictionary_page_offset); + } + if (column_chunk.meta_data.__isset.index_page_offset) { + minOffset = std::min(minOffset, column_chunk.meta_data.index_page_offset); + } + minOffset = std::min(minOffset, column_chunk.meta_data.data_page_offset); + } + + return minOffset; +} + +ParquetScanSharedState::ParquetScanSharedState(FileScanInfo fileScanInfo, uint64_t numRows, + main::ClientContext* context, std::vector columnSkips) + : ScanFileWithProgressSharedState{std::move(fileScanInfo), numRows, context}, + columnSkips{columnSkips} { + readers.push_back(std::make_unique(this->fileScanInfo.filePaths[fileIdx], + columnSkips, context)); + totalRowsGroups = 0; + for (auto i = fileIdx.load(); i < this->fileScanInfo.getNumFiles(); i++) { + auto reader = + std::make_unique(this->fileScanInfo.filePaths[i], columnSkips, context); + totalRowsGroups += reader->getNumRowsGroups(); + } + numBlocksReadByFiles = 0; +} + +static bool parquetSharedStateNext(ParquetScanLocalState& localState, + ParquetScanSharedState& sharedState) { + std::lock_guard mtx{sharedState.mtx}; + while (true) { + if (sharedState.fileIdx >= sharedState.fileScanInfo.getNumFiles()) { + return false; + } + if (sharedState.blockIdx < sharedState.readers[sharedState.fileIdx]->getNumRowsGroups()) { + localState.reader = sharedState.readers[sharedState.fileIdx].get(); + localState.reader->initializeScan(*localState.state, {sharedState.blockIdx}, + VirtualFileSystem::GetUnsafe(*sharedState.context)); + sharedState.blockIdx++; + return true; + } else { + sharedState.numBlocksReadByFiles += + sharedState.readers[sharedState.fileIdx]->getNumRowsGroups(); + sharedState.blockIdx = 0; + sharedState.fileIdx++; + if (sharedState.fileIdx >= sharedState.fileScanInfo.getNumFiles()) { + return false; + } + sharedState.readers.push_back(std::make_unique( + 
sharedState.fileScanInfo.filePaths[sharedState.fileIdx], sharedState.columnSkips, + sharedState.context)); + continue; + } + } +} + +static offset_t tableFunc(const TableFuncInput& input, TableFuncOutput& output) { + auto& outputChunk = output.dataChunk; + if (input.localState == nullptr) { + return 0; + } + auto parquetScanLocalState = ku_dynamic_cast(input.localState); + auto parquetScanSharedState = ku_dynamic_cast(input.sharedState); + do { + parquetScanLocalState->reader->scan(*parquetScanLocalState->state, outputChunk); + if (outputChunk.state->getSelVector().getSelSize() > 0) { + return outputChunk.state->getSelVector().getSelSize(); + } + if (!parquetSharedStateNext(*parquetScanLocalState, *parquetScanSharedState)) { + return outputChunk.state->getSelVector().getSelSize(); + } + } while (true); +} + +static void bindColumns(const ExtraScanTableFuncBindInput* bindInput, uint32_t fileIdx, + std::vector& columnNames, std::vector& columnTypes, + main::ClientContext* context) { + auto reader = ParquetReader(bindInput->fileScanInfo.filePaths[fileIdx], {}, context); + auto state = std::make_unique(); + reader.initializeScan(*state, std::vector{}, VirtualFileSystem::GetUnsafe(*context)); + for (auto i = 0u; i < reader.getNumColumns(); ++i) { + columnNames.push_back(reader.getColumnName(i)); + columnTypes.push_back(reader.getColumnType(i).copy()); + } +} + +static void bindColumns(const ExtraScanTableFuncBindInput* bindInput, + std::vector& columnNames, std::vector& columnTypes, + main::ClientContext* context) { + KU_ASSERT(bindInput->fileScanInfo.getNumFiles() > 0); + bindColumns(bindInput, 0 /* fileIdx */, columnNames, columnTypes, context); + for (auto i = 1u; i < bindInput->fileScanInfo.getNumFiles(); ++i) { + std::vector tmpColumnNames; + std::vector tmpColumnTypes; + bindColumns(bindInput, i, tmpColumnNames, tmpColumnTypes, context); + ReaderBindUtils::validateNumColumns(columnTypes.size(), tmpColumnTypes.size()); + ReaderBindUtils::validateColumnTypes(columnNames, columnTypes, tmpColumnTypes); + } +} + +static row_idx_t getNumRows(std::vector filePaths, uint64_t numColumns, + main::ClientContext* context) { + std::vector dummyColumnSkips(false, numColumns); + row_idx_t numRows = 0; + for (const auto& path : filePaths) { + auto reader = std::make_unique(path, dummyColumnSkips, context); + numRows += reader->getMetadata()->num_rows; + } + return numRows; +} + +static std::unique_ptr bindFunc(main::ClientContext* context, + const TableFuncBindInput* input) { + auto scanInput = ku_dynamic_cast(input->extraInput.get()); + const auto& options = scanInput->fileScanInfo.options; + if (options.size() > 1 || + (options.size() == 1 && !options.contains(CopyConstants::IGNORE_ERRORS_OPTION_NAME))) { + throw BinderException{"Copy from Parquet cannot have options other than IGNORE_ERRORS."}; + } + std::vector detectedColumnNames; + std::vector detectedColumnTypes; + bindColumns(scanInput, detectedColumnNames, detectedColumnTypes, context); + if (!scanInput->expectedColumnNames.empty()) { + ReaderBindUtils::validateNumColumns(scanInput->expectedColumnNames.size(), + detectedColumnNames.size()); + detectedColumnNames = scanInput->expectedColumnNames; + } + + detectedColumnNames = + TableFunction::extractYieldVariables(detectedColumnNames, input->yieldVariables); + auto resultColumns = input->binder->createVariables(detectedColumnNames, detectedColumnTypes); + auto bindData = std::make_unique(std::move(resultColumns), + getNumRows(scanInput->fileScanInfo.filePaths, detectedColumnNames.size(), 
context), + scanInput->fileScanInfo.copy(), context); + return bindData; +} + +static std::unique_ptr initSharedState( + const TableFuncInitSharedStateInput& input) { + auto bindData = input.bindData->constPtrCast(); + return std::make_unique(bindData->fileScanInfo.copy(), + bindData->numRows, bindData->context, bindData->getColumnSkips()); +} + +static std::unique_ptr initLocalState( + const TableFuncInitLocalStateInput& input) { + auto sharedState = input.sharedState.ptrCast(); + auto localState = std::make_unique(); + if (!parquetSharedStateNext(*localState, *sharedState)) { + return nullptr; + } + return localState; +} + +static double progressFunc(TableFuncSharedState* sharedState) { + auto state = sharedState->ptrCast(); + if (state->fileIdx >= state->fileScanInfo.getNumFiles()) { + return 1.0; + } + if (state->totalRowsGroups == 0) { + return 0.0; + } + uint64_t totalReadSize = state->numBlocksReadByFiles + state->blockIdx; + return static_cast(totalReadSize) / state->totalRowsGroups; +} + +static void finalizeFunc(const ExecutionContext* ctx, TableFuncSharedState*) { + WarningContext::Get(*ctx->clientContext)->defaultPopulateAllWarnings(ctx->queryID); +} + +function_set ParquetScanFunction::getFunctionSet() { + function_set functionSet; + auto function = std::make_unique(name, std::vector{LogicalTypeID::STRING}); + function->tableFunc = tableFunc; + function->bindFunc = bindFunc; + function->initSharedStateFunc = initSharedState; + function->initLocalStateFunc = initLocalState; + function->progressFunc = progressFunc; + function->finalizeFunc = finalizeFunc; + functionSet.push_back(std::move(function)); + return functionSet; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/parquet/parquet_timestamp.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/parquet/parquet_timestamp.cpp new file mode 100644 index 0000000000..f236ac9461 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/parquet/parquet_timestamp.cpp @@ -0,0 +1,38 @@ +#include "processor/operator/persistent/reader/parquet/parquet_timestamp.h" + +#include + +namespace lbug { +namespace processor { + +common::timestamp_t ParquetTimeStampUtils::impalaTimestampToTimestamp(const Int96& rawTS) { + auto impalaUS = impalaTimestampToMicroseconds(rawTS); + return common::Timestamp::fromEpochMicroSeconds(impalaUS); +} + +common::timestamp_t ParquetTimeStampUtils::parquetTimestampMicrosToTimestamp(const int64_t& rawTS) { + return common::Timestamp::fromEpochMicroSeconds(rawTS); +} + +common::timestamp_t ParquetTimeStampUtils::parquetTimestampMsToTimestamp(const int64_t& rawTS) { + return common::Timestamp::fromEpochMilliSeconds(rawTS); +} + +common::timestamp_t ParquetTimeStampUtils::parquetTimestampNsToTimestamp(const int64_t& rawTS) { + return common::Timestamp::fromEpochNanoSeconds(rawTS); +} + +int64_t ParquetTimeStampUtils::impalaTimestampToMicroseconds(const Int96& impalaTimestamp) { + int64_t daysSinceEpoch = impalaTimestamp.value[2] - JULIAN_TO_UNIX_EPOCH_DAYS; + int64_t nanoSeconds = 0; + memcpy(&nanoSeconds, &impalaTimestamp.value, sizeof(nanoSeconds)); + auto microseconds = nanoSeconds / NANOSECONDS_PER_MICRO; + return daysSinceEpoch * MICROSECONDS_PER_DAY + microseconds; +} + +common::date_t ParquetTimeStampUtils::parquetIntToDate(const int32_t& raw_date) { + return common::date_t(raw_date); +} + +} // namespace processor +} // namespace lbug diff --git 
a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/parquet/string_column_reader.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/parquet/string_column_reader.cpp new file mode 100644 index 0000000000..5ad98d1d4a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/parquet/string_column_reader.cpp @@ -0,0 +1,93 @@ +#include "processor/operator/persistent/reader/parquet/string_column_reader.h" + +#include "common/types/blob.h" +#include "common/types/ku_string.h" +#include "parquet_types.h" +#include "utf8proc_wrapper.h" + +using lbug_parquet::format::Type; + +namespace lbug { +namespace processor { + +StringColumnReader::StringColumnReader(ParquetReader& reader, common::LogicalType type, + const lbug_parquet::format::SchemaElement& schema, uint64_t schemaIdx, uint64_t maxDefine, + uint64_t maxRepeat) + : TemplatedColumnReader(reader, + std::move(type), schema, schemaIdx, maxDefine, maxRepeat) { + fixedWidthStringLength = 0; + if (schema.type == Type::FIXED_LEN_BYTE_ARRAY) { + KU_ASSERT(schema.__isset.type_length); + fixedWidthStringLength = schema.type_length; + } +} + +uint32_t StringColumnReader::verifyString(const char* strData, uint32_t strLen, + const bool isVarchar) { + if (!isVarchar) { + return strLen; + } + // verify if a string is actually UTF8, and if there are no null bytes in the middle of the + // string technically Parquet should guarantee this, but reality is often disappointing + auto reason = utf8proc::UnicodeInvalidReason::INVALID_UNICODE; + size_t pos = 0; + auto utf_type = utf8proc::Utf8Proc::analyze(strData, strLen, &reason, &pos); + if (utf_type == utf8proc::UnicodeType::INVALID) { + throw common::CopyException{ + "Invalid string encoding found in Parquet file: value \"" + + common::Blob::toString(reinterpret_cast(strData), strLen) + + "\" is not valid UTF8!"}; + } + return strLen; +} + +uint32_t StringColumnReader::verifyString(const char* strData, uint32_t strLen) { + return verifyString(strData, strLen, + getDataType().getLogicalTypeID() == common::LogicalTypeID::STRING); +} + +void StringColumnReader::dictionary(const std::shared_ptr& data, + uint64_t numEntries) { + dict = data; + dictStrs = std::unique_ptr(new common::ku_string_t[numEntries]); + for (auto dictIdx = 0u; dictIdx < numEntries; dictIdx++) { + auto strLen = fixedWidthStringLength == 0 ? dict->read() : fixedWidthStringLength; + dict->available(strLen); + + auto dict_str = reinterpret_cast(dict->ptr); + auto actual_str_len = verifyString(dict_str, strLen); + dictStrs[dictIdx].setFromRawStr(dict_str, actual_str_len); + dict->inc(strLen); + } +} + +common::ku_string_t StringParquetValueConversion::dictRead(ByteBuffer& /*dict*/, uint32_t& offset, + ColumnReader& reader) { + auto& dictStrings = reinterpret_cast(reader).dictStrs; + return dictStrings[offset]; +} + +common::ku_string_t StringParquetValueConversion::plainRead(ByteBuffer& plainData, + ColumnReader& reader) { + auto& scr = reinterpret_cast(reader); + uint32_t strLen = + scr.fixedWidthStringLength == 0 ? 
plainData.read() : scr.fixedWidthStringLength; + plainData.available(strLen); + auto plainStr = reinterpret_cast(plainData.ptr); + auto actualStrLen = + reinterpret_cast(reader).verifyString(plainStr, strLen); + auto retStr = common::ku_string_t(); + retStr.setFromRawStr(plainStr, actualStrLen); + plainData.inc(strLen); + return retStr; +} + +void StringParquetValueConversion::plainSkip(ByteBuffer& plainData, ColumnReader& reader) { + auto& scr = reinterpret_cast(reader); + uint32_t strLen = + scr.fixedWidthStringLength == 0 ? plainData.read() : scr.fixedWidthStringLength; + plainData.inc(strLen); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/parquet/struct_column_reader.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/parquet/struct_column_reader.cpp new file mode 100644 index 0000000000..975c82769b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/parquet/struct_column_reader.cpp @@ -0,0 +1,99 @@ +#include "processor/operator/persistent/reader/parquet/struct_column_reader.h" + +namespace lbug { +namespace processor { + +StructColumnReader::StructColumnReader(ParquetReader& reader, common::LogicalType type, + const lbug_parquet::format::SchemaElement& schema, uint64_t schemaIdx, uint64_t maxDefine, + uint64_t maxRepeat, std::vector> childReaders) + : ColumnReader(reader, std::move(type), schema, schemaIdx, maxDefine, maxRepeat), + childReaders(std::move(childReaders)) { + KU_ASSERT(this->type.getPhysicalType() == common::PhysicalTypeID::STRUCT); +} + +ColumnReader* StructColumnReader::getChildReader(uint64_t childIdx) { + KU_ASSERT(childIdx < childReaders.size()); + return childReaders[childIdx].get(); +} + +void StructColumnReader::initializeRead(uint64_t rowGroupIdx, + const std::vector& columns, + lbug_apache::thrift::protocol::TProtocol& protocol) { + for (auto& child : childReaders) { + child->initializeRead(rowGroupIdx, columns, protocol); + } +} + +uint64_t StructColumnReader::getTotalCompressedSize() { + uint64_t size = 0; + for (auto& child : childReaders) { + size += child->getTotalCompressedSize(); + } + return size; +} + +void StructColumnReader::registerPrefetch(ThriftFileTransport& transport, bool allow_merge) { + for (auto& child : childReaders) { + child->registerPrefetch(transport, allow_merge); + } +} + +uint64_t StructColumnReader::read(uint64_t numValuesToRead, parquet_filter_t& filter, + uint8_t* define_out, uint8_t* repeat_out, common::ValueVector* result) { + auto& fieldVectors = common::StructVector::getFieldVectors(result); + KU_ASSERT(common::StructType::getNumFields(type) == fieldVectors.size()); + if (pendingSkips > 0) { + applyPendingSkips(pendingSkips); + } + + uint64_t numValuesRead = numValuesToRead; + for (auto i = 0u; i < fieldVectors.size(); i++) { + auto numValuesChildrenRead = childReaders[i]->read(numValuesToRead, filter, define_out, + repeat_out, fieldVectors[i].get()); + if (i == 0) { + numValuesRead = numValuesChildrenRead; + } else if (numValuesRead != numValuesChildrenRead) { + throw std::runtime_error("Struct child row count mismatch"); + } + } + for (auto i = 0u; i < numValuesRead; i++) { + result->setNull(i, define_out[i] < maxDefine); + } + + return numValuesRead; +} + +void StructColumnReader::skip(uint64_t num_values) { + for (auto& child_reader : childReaders) { + child_reader->skip(num_values); + } +} + +static bool TypeHasExactRowCount(const common::LogicalType& type) { + 
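+ // A reader has an "exact" row count when the number of values it reports equals the
+ // number of top-level rows: lists and maps do not (their leaves are repeated), a
+ // struct does if at least one of its fields does, and plain columns always do.
+ // getGroupRowsAvailable() below relies on this to pick a child whose count can be trusted.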
switch (type.getLogicalTypeID()) { + case common::LogicalTypeID::LIST: + case common::LogicalTypeID::MAP: + return false; + case common::LogicalTypeID::STRUCT: + for (auto kv : common::StructType::getFieldTypes(type)) { + if (TypeHasExactRowCount(*kv)) { + return true; + } + } + return false; + default: + return true; + } +} + +uint64_t StructColumnReader::getGroupRowsAvailable() { + for (auto i = 0u; i < childReaders.size(); i++) { + if (TypeHasExactRowCount(childReaders[i]->getDataType())) { + return childReaders[i]->getGroupRowsAvailable(); + } + } + return childReaders[0]->getGroupRowsAvailable(); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/parquet/uuid_column_reader.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/parquet/uuid_column_reader.cpp new file mode 100644 index 0000000000..683dd629cd --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/parquet/uuid_column_reader.cpp @@ -0,0 +1,41 @@ +#include "processor/operator/persistent/reader/parquet/uuid_column_reader.h" + +namespace lbug { +namespace processor { + +common::ku_uuid_t UUIDValueConversion::ReadParquetUUID(const uint8_t* input) { + common::ku_uuid_t result{}; + result.value.low = 0; + uint64_t unsignedUpper = 0; + for (auto i = 0u; i < sizeof(uint64_t); i++) { + unsignedUpper <<= 8; + unsignedUpper += input[i]; + } + for (auto i = sizeof(uint64_t); i < sizeof(common::ku_uuid_t); i++) { + result.value.low <<= 8; + result.value.low += input[i]; + } + result.value.high = unsignedUpper; + result.value.high ^= (int64_t(1) << 63); + return result; +} + +common::ku_uuid_t UUIDValueConversion::plainRead(ByteBuffer& bufferData, ColumnReader& /*reader*/) { + auto uuidLen = sizeof(common::ku_uuid_t); + bufferData.available(uuidLen); + auto res = ReadParquetUUID(reinterpret_cast(bufferData.ptr)); + bufferData.inc(uuidLen); + return res; +} + +void UUIDColumnReader::dictionary(const std::shared_ptr& dictionaryData, + uint64_t numEntries) { + allocateDict(numEntries * sizeof(common::ku_uuid_t)); + auto dictPtr = reinterpret_cast(this->dict->ptr); + for (auto i = 0u; i < numEntries; i++) { + dictPtr[i] = UUIDValueConversion::plainRead(*dictionaryData, *this); + } +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/reader_bind_utils.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/reader_bind_utils.cpp new file mode 100644 index 0000000000..fe5b00467e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/reader/reader_bind_utils.cpp @@ -0,0 +1,51 @@ +#include "processor/operator/persistent/reader/reader_bind_utils.h" + +#include "common/exception/binder.h" +#include "common/string_format.h" + +using namespace lbug::common; + +namespace lbug { +namespace processor { + +void ReaderBindUtils::validateNumColumns(uint32_t expectedNumber, uint32_t detectedNumber) { + if (detectedNumber == 0) { + return; // Empty CSV. Continue processing. + } + if (expectedNumber != detectedNumber) { + throw common::BinderException(common::stringFormat( + "Number of columns mismatch. 
Expected {} but got {}.", expectedNumber, detectedNumber)); + } +} + +void ReaderBindUtils::validateColumnTypes(const std::vector& columnNames, + const std::vector& expectedColumnTypes, + const std::vector& detectedColumnTypes) { + KU_ASSERT(expectedColumnTypes.size() == detectedColumnTypes.size()); + for (auto i = 0u; i < expectedColumnTypes.size(); ++i) { + if (expectedColumnTypes[i] != detectedColumnTypes[i]) { + throw common::BinderException(common::stringFormat( + "Column `{}` type mismatch. Expected {} but got {}.", columnNames[i], + expectedColumnTypes[i].toString(), detectedColumnTypes[i].toString())); + } + } +} + +void ReaderBindUtils::resolveColumns(const std::vector& expectedColumnNames, + const std::vector& detectedColumnNames, + std::vector& resultColumnNames, + const std::vector& expectedColumnTypes, + const std::vector& detectedColumnTypes, + std::vector& resultColumnTypes) { + if (expectedColumnTypes.empty()) { + resultColumnNames = detectedColumnNames; + resultColumnTypes = LogicalType::copy(detectedColumnTypes); + } else { + validateNumColumns(expectedColumnTypes.size(), detectedColumnTypes.size()); + resultColumnNames = expectedColumnNames; + resultColumnTypes = LogicalType::copy(expectedColumnTypes); + } +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/rel_batch_insert.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/rel_batch_insert.cpp new file mode 100644 index 0000000000..4893425cdf --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/rel_batch_insert.cpp @@ -0,0 +1,279 @@ +#include "processor/operator/persistent/rel_batch_insert.h" + +#include "catalog/catalog.h" +#include "common/cast.h" +#include "common/exception/copy.h" +#include "common/exception/message.h" +#include "common/string_format.h" +#include "common/task_system/progress_bar.h" +#include "processor/execution_context.h" +#include "processor/result/factorized_table_util.h" +#include "processor/warning_context.h" +#include "storage/local_storage/local_storage.h" +#include "storage/storage_manager.h" +#include "storage/storage_utils.h" +#include "storage/table/chunked_node_group.h" +#include "storage/table/column_chunk_data.h" +#include "storage/table/csr_chunked_node_group.h" +#include "storage/table/rel_table.h" + +using namespace lbug::catalog; +using namespace lbug::common; +using namespace lbug::storage; + +namespace lbug { +namespace processor { + +std::string RelBatchInsertPrintInfo::toString() const { + std::string result = "Table Name: "; + result += tableName; + return result; +} + +void RelBatchInsert::initLocalStateInternal(ResultSet*, ExecutionContext* context) { + localState = std::make_unique(); + const auto relInfo = info->ptrCast(); + localState->chunkedGroup = + std::make_unique(*MemoryManager::Get(*context->clientContext), + relInfo->columnTypes, relInfo->compressionEnabled, 0, 0); + const auto transaction = transaction::Transaction::Get(*context->clientContext); + localState->optimisticAllocator = transaction->getLocalStorage()->addOptimisticAllocator(); + const auto clientContext = context->clientContext; + const auto catalog = Catalog::Get(*clientContext); + const auto catalogEntry = catalog->getTableCatalogEntry(transaction, info->tableName); + const auto& relGroupEntry = catalogEntry->constCast(); + auto tableID = relGroupEntry.getRelEntryInfo(relInfo->fromTableID, relInfo->toTableID)->oid; + auto nbrTableID = 
RelDirectionUtils::getNbrTableID(relInfo->direction, relInfo->fromTableID, + relInfo->toTableID); + // TODO(Guodong): Get rid of the hard-coded nbr and rel column ID 0/1. + localState->chunkedGroup->getColumnChunk(0).cast().setTableID(nbrTableID); + localState->chunkedGroup->getColumnChunk(1).cast().setTableID(tableID); + const auto relLocalState = localState->ptrCast(); + relLocalState->dummyAllNullDataChunk = std::make_unique(relInfo->columnTypes.size()); + for (auto i = 0u; i < relInfo->columnTypes.size(); i++) { + auto valueVector = std::make_shared(relInfo->columnTypes[i].copy(), + MemoryManager::Get(*context->clientContext)); + valueVector->setAllNull(); + relLocalState->dummyAllNullDataChunk->insert(i, std::move(valueVector)); + } +} + +void RelBatchInsert::initGlobalStateInternal(ExecutionContext* context) { + const auto relBatchInsertInfo = info->ptrCast(); + const auto clientContext = context->clientContext; + const auto catalog = Catalog::Get(*clientContext); + const auto transaction = transaction::Transaction::Get(*clientContext); + const auto catalogEntry = catalog->getTableCatalogEntry(transaction, info->tableName); + const auto& relGroupEntry = catalogEntry->constCast(); + // Init info + info->compressionEnabled = StorageManager::Get(*clientContext)->compressionEnabled(); + auto dataColumnIdx = 0u; + // Handle internal id column + info->columnTypes.push_back(LogicalType::INTERNAL_ID()); + info->insertColumnIDs.push_back(0); + info->outputDataColumns.push_back(dataColumnIdx++); + for (auto& property : relGroupEntry.getProperties()) { + info->columnTypes.push_back(property.getType().copy()); + info->insertColumnIDs.push_back(relGroupEntry.getColumnID(property.getName())); + info->outputDataColumns.push_back(dataColumnIdx++); + } + for (auto& type : info->warningColumnTypes) { + info->columnTypes.push_back(type.copy()); + info->warningDataColumns.push_back(dataColumnIdx++); + } + relBatchInsertInfo->partitioningIdx = + relBatchInsertInfo->direction == RelDataDirection::FWD ? 0 : 1; + relBatchInsertInfo->boundNodeOffsetColumnID = + relBatchInsertInfo->direction == RelDataDirection::FWD ? 0 : 1; + // Init shared state + sharedState->table = partitionerSharedState->relTable; + progressSharedState = std::make_shared(); + progressSharedState->partitionsDone = 0; + progressSharedState->partitionsTotal = + partitionerSharedState->getNumPartitions(relBatchInsertInfo->partitioningIdx); +} + +void RelBatchInsert::executeInternal(ExecutionContext* context) { + const auto relInfo = info->ptrCast(); + const auto relTable = sharedState->table->ptrCast(); + const auto relLocalState = localState->ptrCast(); + const auto clientContext = context->clientContext; + const auto catalog = Catalog::Get(*clientContext); + const auto transaction = transaction::Transaction::Get(*clientContext); + const auto& relGroupEntry = catalog->getTableCatalogEntry(transaction, relInfo->tableName) + ->constCast(); + while (true) { + relLocalState->nodeGroupIdx = + partitionerSharedState->getNextPartition(relInfo->partitioningIdx); + if (relLocalState->nodeGroupIdx == INVALID_PARTITION_IDX) { + // No more partitions left in the partitioning buffer. + break; + } + ++progressSharedState->partitionsDone; + // TODO(Guodong): We need to handle the concurrency between COPY and other insertions + // into the same node group. 
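+ // Each partition handed out by the partitioner maps to one CSR node group, i.e. one
+ // contiguous range of bound-node offsets; all rels for that range are appended below
+ // before the next partition is claimed.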
+ auto& nodeGroup = + relTable + ->getOrCreateNodeGroup(transaction, relLocalState->nodeGroupIdx, relInfo->direction) + ->cast(); + appendNodeGroup(relGroupEntry, *MemoryManager::Get(*clientContext), transaction, nodeGroup, + *relInfo, *relLocalState); + updateProgress(context); + } +} + +static void appendNewChunkedGroup(MemoryManager& mm, transaction::Transaction* transaction, + const std::vector& columnIDs, InMemChunkedCSRNodeGroup& chunkedGroup, + RelTable& relTable, CSRNodeGroup& nodeGroup, RelDataDirection direction, + PageAllocator& pageAllocator) { + const bool isNewNodeGroup = nodeGroup.isEmpty(); + const CSRNodeGroupScanSource source = isNewNodeGroup ? + CSRNodeGroupScanSource::COMMITTED_PERSISTENT : + CSRNodeGroupScanSource::COMMITTED_IN_MEMORY; + // since each thread operates on distinct node groups + // We don't need a lock here (to ensure the insert info and append agree on the number of rows + // in the node group) + relTable.pushInsertInfo(transaction, direction, nodeGroup, chunkedGroup.getNumRows(), source); + if (isNewNodeGroup) { + auto flushedChunkedGroup = chunkedGroup.flush(transaction, pageAllocator); + + // If there are deleted columns that haven't been vacuumed yet + // we need to add extra columns to the chunked group + // to ensure that the number of columns is consistent with the rest of the node group + auto persistentChunkedGroup = std::make_unique(mm, + flushedChunkedGroup->cast(), nodeGroup.getDataTypes(), columnIDs); + + nodeGroup.setPersistentChunkedGroup(std::move(persistentChunkedGroup)); + } else { + nodeGroup.appendChunkedCSRGroup(transaction, columnIDs, chunkedGroup); + } +} + +void RelBatchInsert::appendNodeGroup(const RelGroupCatalogEntry& relGroupEntry, MemoryManager& mm, + transaction::Transaction* transaction, CSRNodeGroup& nodeGroup, + const RelBatchInsertInfo& relInfo, const RelBatchInsertLocalState& localState) { + const auto nodeGroupIdx = localState.nodeGroupIdx; + const auto startNodeOffset = storage::StorageUtils::getStartOffsetOfNodeGroup(nodeGroupIdx); + auto executionState = impl->initExecutionState(*partitionerSharedState, relInfo, nodeGroupIdx); + // Calculate num of source nodes in this node group. + // This will be used to set the num of values of the node group. + const auto numNodes = std::min(StorageConfig::NODE_GROUP_SIZE, + partitionerSharedState->getNumNodes(relInfo.partitioningIdx) - startNodeOffset); + // We optimistically flush new node group directly to disk in gapped CSR format. + // There is no benefit of leaving gaps for existing node groups, which is kept in memory. + const auto leaveGaps = nodeGroup.isEmpty(); + populateCSRHeader(relGroupEntry, *executionState, startNodeOffset, relInfo, localState, + numNodes, leaveGaps); + const auto& csrHeader = + ku_dynamic_cast(*localState.chunkedGroup).getCSRHeader(); + impl->writeToTable(*executionState, csrHeader, localState, *sharedState, relInfo); + // Reset num of rows in the chunked group to fill gaps at the end of the node group. 
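// Illustrative example (numbers are hypothetical): if the CSR header gives
// the last bound node an end offset of 1000 but only 900 relationship rows
// were appended, the remaining 100 slots are gaps. They are filled below
// with the all-null dummy chunk, at most DEFAULT_VECTOR_CAPACITY rows per
// append, until the chunked group holds exactly maxSize rows.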
+ const auto maxSize = csrHeader.getEndCSROffset(numNodes - 1); + auto numGapsAtEnd = maxSize - localState.chunkedGroup->getNumRows(); + KU_ASSERT(localState.chunkedGroup->getCapacity() >= maxSize); + while (numGapsAtEnd > 0) { + const auto numGapsToFill = std::min(numGapsAtEnd, DEFAULT_VECTOR_CAPACITY); + localState.dummyAllNullDataChunk->state->getSelVectorUnsafe().setSelSize(numGapsToFill); + std::vector dummyVectors; + for (auto i = 0u; i < relInfo.columnTypes.size(); i++) { + dummyVectors.push_back(&localState.dummyAllNullDataChunk->getValueVectorMutable(i)); + } + const auto numGapsFilled = localState.chunkedGroup->append(dummyVectors, 0, numGapsToFill); + KU_ASSERT(numGapsFilled == numGapsToFill); + numGapsAtEnd -= numGapsFilled; + } + KU_ASSERT(localState.chunkedGroup->getNumRows() == maxSize); + + auto* relTable = sharedState->table->ptrCast(); + + InMemChunkedCSRNodeGroup sliceToWriteToDisk{ + ku_dynamic_cast(*localState.chunkedGroup), + relInfo.outputDataColumns}; + appendNewChunkedGroup(mm, transaction, relInfo.insertColumnIDs, sliceToWriteToDisk, *relTable, + nodeGroup, relInfo.direction, *localState.optimisticAllocator); + ku_dynamic_cast(*localState.chunkedGroup) + .mergeChunkedCSRGroup(sliceToWriteToDisk, relInfo.outputDataColumns); + + localState.chunkedGroup->resetToEmpty(); +} + +void RelBatchInsertImpl::finalizeStartCSROffsets(RelBatchInsertExecutionState&, + storage::InMemChunkedCSRHeader& csrHeader, const RelBatchInsertInfo&) { + csrHeader.populateEndCSROffsetFromStartAndLength(); +} + +void RelBatchInsert::populateCSRHeader(const RelGroupCatalogEntry& relGroupEntry, + RelBatchInsertExecutionState& executionState, offset_t startNodeOffset, + const RelBatchInsertInfo& relInfo, const RelBatchInsertLocalState& localState, + offset_t numNodes, bool leaveGaps) { + auto& csrNodeGroup = ku_dynamic_cast(*localState.chunkedGroup); + auto& csrHeader = csrNodeGroup.getCSRHeader(); + csrHeader.setNumValues(numNodes); + // Populate lengths for each node and check multiplicity constraint. + impl->populateCSRLengths(executionState, csrHeader, numNodes, relInfo); + checkRelMultiplicityConstraint(relGroupEntry, csrHeader, startNodeOffset, relInfo); + const auto rightCSROffsetOfRegions = csrHeader.populateStartCSROffsetsFromLength(leaveGaps); + impl->finalizeStartCSROffsets(executionState, csrHeader, relInfo); + csrHeader.finalizeCSRRegionEndOffsets(rightCSROffsetOfRegions); + // Resize csr data column chunks. 
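// The chunks are sized to the end CSR offset of the last node, i.e. the
// total number of CSR slots including any gaps left by the start-offset
// population above, and then reset to all-null so that slots that are never
// written (the gaps) remain null.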
+ localState.chunkedGroup->resizeChunks(csrHeader.getEndCSROffset(numNodes - 1)); + localState.chunkedGroup->resetToAllNull(); + KU_ASSERT(csrHeader.sanityCheck()); +} + +void RelBatchInsert::checkRelMultiplicityConstraint(const RelGroupCatalogEntry& relGroupEntry, + const InMemChunkedCSRHeader& csrHeader, offset_t startNodeOffset, + const RelBatchInsertInfo& relInfo) { + if (!relGroupEntry.isSingleMultiplicity(relInfo.direction)) { + return; + } + for (auto i = 0u; i < csrHeader.length->getNumValues(); i++) { + if (csrHeader.length->getValue(i) > 1) { + throw CopyException(ExceptionMessage::violateRelMultiplicityConstraint( + relInfo.tableName, std::to_string(i + startNodeOffset), + RelDirectionUtils::relDirectionToString(relInfo.direction))); + } + } +} + +void RelBatchInsert::finalizeInternal(ExecutionContext* context) { + const auto relInfo = info->ptrCast(); + if (relInfo->direction == RelDataDirection::FWD) { + KU_ASSERT(relInfo->partitioningIdx == 0); + + auto outputMsg = stringFormat("{} tuples have been copied to the {} table.", + sharedState->getNumRows(), relInfo->tableName); + auto clientContext = context->clientContext; + FactorizedTableUtils::appendStringToTable(sharedState->fTable.get(), outputMsg, + MemoryManager::Get(*clientContext)); + + auto warningContext = WarningContext::Get(*context->clientContext); + const auto warningCount = warningContext->getWarningCount(context->queryID); + if (warningCount > 0) { + auto warningMsg = + stringFormat("{} warnings encountered during copy. Use 'CALL " + "show_warnings() RETURN *' to view the actual warnings. Query ID: {}", + warningCount, context->queryID); + FactorizedTableUtils::appendStringToTable(sharedState->fTable.get(), warningMsg, + MemoryManager::Get(*context->clientContext)); + warningContext->defaultPopulateAllWarnings(context->queryID); + } + } + sharedState->numRows.store(0); + sharedState->table->cast().setHasChanges(); + partitionerSharedState->resetState(relInfo->partitioningIdx); +} + +void RelBatchInsert::updateProgress(const ExecutionContext* context) const { + auto progressBar = ProgressBar::Get(*context->clientContext); + if (progressSharedState->partitionsTotal == 0) { + progressBar->updateProgress(context->queryID, 0); + } else { + double progress = static_cast(progressSharedState->partitionsDone) / + static_cast(progressSharedState->partitionsTotal); + progressBar->updateProgress(context->queryID, progress); + } +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/set.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/set.cpp new file mode 100644 index 0000000000..fd21fda709 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/set.cpp @@ -0,0 +1,46 @@ +#include "processor/operator/persistent/set.h" + +#include "binder/expression/expression_util.h" + +namespace lbug { +namespace processor { + +void SetNodeProperty::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) { + for (auto& executor : executors) { + executor->init(resultSet, context); + } +} + +bool SetNodeProperty::getNextTuplesInternal(ExecutionContext* context) { + if (!children[0]->getNextTuple(context)) { + return false; + } + for (auto& executor : executors) { + executor->set(context); + } + return true; +} + +void SetRelProperty::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) { + for (auto& executor : executors) { + executor->init(resultSet, context); + } +} + +bool 
SetRelProperty::getNextTuplesInternal(ExecutionContext* context) { + if (!children[0]->getNextTuple(context)) { + return false; + } + for (auto& executor : executors) { + executor->set(context); + } + return true; +} + +std::string SetPropertyPrintInfo::toString() const { + std::string result = "Properties: "; + result += binder::ExpressionUtil::toString(expressions); + return result; +} +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/set_executor.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/set_executor.cpp new file mode 100644 index 0000000000..90920219f8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/set_executor.cpp @@ -0,0 +1,141 @@ +#include "processor/operator/persistent/set_executor.h" + +#include "transaction/transaction.h" + +using namespace lbug::common; + +namespace lbug { +namespace processor { + +void NodeSetInfo::init(const ResultSet& resultSet, main::ClientContext* context) { + nodeIDVector = resultSet.getValueVector(nodeIDPos).get(); + if (columnVectorPos.isValid()) { + columnVector = resultSet.getValueVector(columnVectorPos).get(); + } + evaluator->init(resultSet, context); + columnDataVector = evaluator->resultVector.get(); +} + +void NodeSetExecutor::init(ResultSet* resultSet, ExecutionContext* context) { + info.init(*resultSet, context->clientContext); +} + +void NodeSetExecutor::setNodeID(nodeID_t nodeID) const { + info.nodeIDVector->setValue(info.nodeIDVector->state->getSelVector()[0], nodeID); +} + +static void writeColumnUpdateResult(ValueVector* idVector, ValueVector* columnVector, + ValueVector* dataVector) { + auto& idSelVector = idVector->state->getSelVector(); + auto& columnSelVector = columnVector->state->getSelVector(); + auto& dataSelVector = dataVector->state->getSelVector(); + KU_ASSERT(idSelVector.getSelSize() == 1); + if (idVector->isNull(idSelVector[0])) { // No update happened. + return; + } + KU_ASSERT(dataSelVector.getSelSize() == 1); + if (dataVector->isNull(dataSelVector[0])) { // Update to NULL + columnVector->setNull(dataSelVector[0], true); + return; + } + columnVector->setNull(columnSelVector[0], false); + columnVector->copyFromVectorData(columnSelVector[0], dataVector, dataSelVector[0]); +} + +void SingleLabelNodeSetExecutor::set(ExecutionContext* context) { + if (tableInfo.columnID == INVALID_COLUMN_ID) { + // Not a valid column. Set projected column to null. 
+ if (info.columnVectorPos.isValid()) { + info.columnVector->setNull(info.columnDataVector->state->getSelVector()[0], true); + } + return; + } + info.evaluator->evaluate(); + auto updateState = std::make_unique(tableInfo.columnID, + *info.nodeIDVector, *info.columnDataVector); + tableInfo.table->initUpdateState(context->clientContext, *updateState); + tableInfo.table->update(transaction::Transaction::Get(*context->clientContext), *updateState); + if (info.columnVectorPos.isValid()) { + writeColumnUpdateResult(info.nodeIDVector, info.columnVector, info.columnDataVector); + } +} + +void MultiLabelNodeSetExecutor::set(ExecutionContext* context) { + info.evaluator->evaluate(); + auto& nodeIDSelVector = info.nodeIDVector->state->getSelVector(); + KU_ASSERT(nodeIDSelVector.getSelSize() == 1); + auto nodeIDPos = nodeIDSelVector[0]; + auto& nodeID = info.nodeIDVector->getValue(nodeIDPos); + if (!tableInfos.contains(nodeID.tableID)) { + if (info.columnVectorPos.isValid()) { + info.columnVector->setNull(info.columnDataVector->state->getSelVector()[0], true); + } + return; + } + auto& tableInfo = tableInfos.at(nodeID.tableID); + auto updateState = std::make_unique(tableInfo.columnID, + *info.nodeIDVector, *info.columnDataVector); + tableInfo.table->initUpdateState(context->clientContext, *updateState); + tableInfo.table->update(transaction::Transaction::Get(*context->clientContext), *updateState); + if (info.columnVectorPos.isValid()) { + writeColumnUpdateResult(info.nodeIDVector, info.columnVector, info.columnDataVector); + } +} + +void RelSetInfo::init(const ResultSet& resultSet, main::ClientContext* context) { + srcNodeIDVector = resultSet.getValueVector(srcNodeIDPos).get(); + dstNodeIDVector = resultSet.getValueVector(dstNodeIDPos).get(); + relIDVector = resultSet.getValueVector(relIDPos).get(); + if (columnVectorPos.isValid()) { + columnVector = resultSet.getValueVector(columnVectorPos).get(); + } + evaluator->init(resultSet, context); + columnDataVector = evaluator->resultVector.get(); +} + +void RelSetExecutor::init(ResultSet* resultSet, ExecutionContext* context) { + info.init(*resultSet, context->clientContext); +} + +void RelSetExecutor::setRelID(nodeID_t relID) const { + info.relIDVector->setValue(info.relIDVector->state->getSelVector()[0], relID); +} + +void SingleLabelRelSetExecutor::set(ExecutionContext* context) { + if (tableInfo.columnID == INVALID_COLUMN_ID) { + if (info.columnVectorPos.isValid()) { + info.columnVector->setNull(info.columnDataVector->state->getSelVector()[0], true); + } + return; + } + info.evaluator->evaluate(); + auto updateState = std::make_unique(tableInfo.columnID, + *info.srcNodeIDVector, *info.dstNodeIDVector, *info.relIDVector, *info.columnDataVector); + tableInfo.table->update(transaction::Transaction::Get(*context->clientContext), *updateState); + if (info.columnVectorPos.isValid()) { + writeColumnUpdateResult(info.relIDVector, info.columnVector, info.columnDataVector); + } +} + +void MultiLabelRelSetExecutor::set(ExecutionContext* context) { + info.evaluator->evaluate(); + auto& idSelVector = info.relIDVector->state->getSelVector(); + KU_ASSERT(idSelVector.getSelSize() == 1); + auto relID = info.relIDVector->getValue(idSelVector[0]); + if (!tableInfos.contains(relID.tableID)) { + if (info.columnVectorPos.isValid()) { + info.columnVector->setNull(info.columnDataVector->state->getSelVector()[0], true); + } + return; + } + auto& tableInfo = tableInfos.at(relID.tableID); + auto updateState = std::make_unique(tableInfo.columnID, + *info.srcNodeIDVector, 
*info.dstNodeIDVector, *info.relIDVector, *info.columnDataVector); + tableInfo.table->update(transaction::Transaction::Get(*context->clientContext), *updateState); + if (info.columnVectorPos.isValid()) { + writeColumnUpdateResult(info.relIDVector, info.columnVector, info.columnDataVector); + } +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/writer/parquet/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/writer/parquet/CMakeLists.txt new file mode 100644 index 0000000000..20840df2d9 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/writer/parquet/CMakeLists.txt @@ -0,0 +1,18 @@ +add_library(lbug_processor_operator_parquet_writer + OBJECT + basic_column_writer.cpp + boolean_column_writer.cpp + column_writer.cpp + interval_column_writer.cpp + struct_column_writer.cpp + string_column_writer.cpp + list_column_writer.cpp + parquet_writer.cpp + parquet_rle_bp_encoder.cpp + uuid_column_writer.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) + +target_link_libraries(lbug_processor_operator_parquet_writer PUBLIC parquet) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/writer/parquet/basic_column_writer.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/writer/parquet/basic_column_writer.cpp new file mode 100644 index 0000000000..603c98dd35 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/writer/parquet/basic_column_writer.cpp @@ -0,0 +1,319 @@ +#include "processor/operator/persistent/writer/parquet/basic_column_writer.h" + +#include "common/constants.h" +#include "common/exception/runtime.h" +#include "function/cast/functions/numeric_limits.h" +#include "processor/operator/persistent/reader/parquet/parquet_rle_bp_decoder.h" +#include "processor/operator/persistent/writer//parquet/parquet_rle_bp_encoder.h" +#include "processor/operator/persistent/writer/parquet/parquet_writer.h" + +namespace lbug { +namespace processor { + +using namespace lbug_parquet::format; +using namespace lbug::common; + +std::unique_ptr BasicColumnWriter::initializeWriteState( + lbug_parquet::format::RowGroup& rowGroup) { + auto result = std::make_unique(rowGroup, rowGroup.columns.size()); + registerToRowGroup(rowGroup); + return result; +} + +void BasicColumnWriter::prepare(ColumnWriterState& stateToPrepare, ColumnWriterState* parent, + common::ValueVector* vector, uint64_t count) { + auto& state = reinterpret_cast(stateToPrepare); + auto& colChunk = state.rowGroup.columns[state.colIdx]; + + uint64_t start = 0; + auto vcount = parent ? 
parent->definitionLevels.size() - state.definitionLevels.size() : count; + auto parentIdx = state.definitionLevels.size(); + handleRepeatLevels(state, parent); + handleDefineLevels(state, parent, vector, count, maxDefine, maxDefine - 1); + + auto vectorIdx = 0u; + for (auto i = start; i < vcount; i++) { + auto& pageInfo = state.pageInfo.back(); + pageInfo.rowCount++; + colChunk.meta_data.num_values++; + if (parent && !parent->isEmpty.empty() && parent->isEmpty[parentIdx + i]) { + pageInfo.emptyCount++; + continue; + } + if (!vector->isNull(vectorIdx)) { + pageInfo.estimatedPageSize += getRowSize(vector, vectorIdx, state); + if (pageInfo.estimatedPageSize >= ParquetConstants::MAX_UNCOMPRESSED_PAGE_SIZE) { + PageInformation newInfo; + newInfo.offset = pageInfo.offset + pageInfo.rowCount; + state.pageInfo.push_back(newInfo); + } + } + vectorIdx++; + } +} + +void BasicColumnWriter::beginWrite(ColumnWriterState& writerState) { + auto& state = reinterpret_cast(writerState); + + // Set up the page write info. + state.statsState = initializeStatsState(); + for (auto pageIdx = 0u; pageIdx < state.pageInfo.size(); pageIdx++) { + auto& pageInfo = state.pageInfo[pageIdx]; + if (pageInfo.rowCount == 0) { + KU_ASSERT(pageIdx + 1 == state.pageInfo.size()); + state.pageInfo.erase(state.pageInfo.begin() + pageIdx); + break; + } + PageWriteInformation writeInfo; + // Set up the header. + auto& hdr = writeInfo.pageHeader; + hdr.compressed_page_size = 0; + hdr.uncompressed_page_size = 0; + hdr.type = PageType::DATA_PAGE; + hdr.__isset.data_page_header = true; + + hdr.data_page_header.num_values = pageInfo.rowCount; + hdr.data_page_header.encoding = getEncoding(state); + hdr.data_page_header.definition_level_encoding = Encoding::RLE; + hdr.data_page_header.repetition_level_encoding = Encoding::RLE; + + writeInfo.bufferWriter = std::make_shared(); + writeInfo.writer = std::make_unique(writeInfo.bufferWriter); + writeInfo.writeCount = pageInfo.emptyCount; + writeInfo.maxWriteCount = pageInfo.rowCount; + writeInfo.pageState = initializePageState(state); + + writeInfo.compressedSize = 0; + writeInfo.compressedData = nullptr; + + state.writeInfo.push_back(std::move(writeInfo)); + } + + nextPage(state); +} + +void BasicColumnWriter::write(ColumnWriterState& writerState, common::ValueVector* vector, + uint64_t count) { + auto& state = reinterpret_cast(writerState); + + uint64_t remaining = count; + uint64_t offset = 0; + while (remaining > 0) { + auto& writeInfo = state.writeInfo[state.currentPage - 1]; + KU_ASSERT(writeInfo.bufferWriter != nullptr); + auto writeCount = + std::min(remaining, writeInfo.maxWriteCount - writeInfo.writeCount); + + writeVector(*writeInfo.writer, state.statsState.get(), writeInfo.pageState.get(), vector, + offset, offset + writeCount); + + writeInfo.writeCount += writeCount; + if (writeInfo.writeCount == writeInfo.maxWriteCount) { + nextPage(state); + } + offset += writeCount; + remaining -= writeCount; + } +} + +void BasicColumnWriter::finalizeWrite(ColumnWriterState& writerState) { + auto& state = reinterpret_cast(writerState); + auto& columnChunk = state.rowGroup.columns[state.colIdx]; + + // Flush the last page (if any remains). + flushPage(state); + + auto startOffset = writer.getOffset(); + auto pageOffset = startOffset; + // Flush the dictionary. 
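// Column chunk layout when a dictionary is present:
//   [dictionary page][data page 0][data page 1]...
// dictionary_page_offset points at the dictionary page and data_page_offset
// (set just below) at the first data page, which is why pageOffset is
// advanced by the dictionary's compressed size before being recorded.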
+ if (hasDictionary(state)) { + columnChunk.meta_data.statistics.distinct_count = dictionarySize(state); + columnChunk.meta_data.statistics.__isset.distinct_count = true; + columnChunk.meta_data.dictionary_page_offset = pageOffset; + columnChunk.meta_data.__isset.dictionary_page_offset = true; + flushDictionary(state, state.statsState.get()); + pageOffset += state.writeInfo[0].compressedSize; + } + + // Record the start position of the pages for this column. + columnChunk.meta_data.data_page_offset = pageOffset; + setParquetStatistics(state, columnChunk); + + // write the individual pages to disk + uint64_t totalUncompressedSize = 0; + for (auto& write_info : state.writeInfo) { + KU_ASSERT(write_info.pageHeader.uncompressed_page_size > 0); + auto header_start_offset = writer.getOffset(); + write_info.pageHeader.write(writer.getProtocol()); + // total uncompressed size in the column chunk includes the header size (!) + totalUncompressedSize += writer.getOffset() - header_start_offset; + totalUncompressedSize += write_info.pageHeader.uncompressed_page_size; + writer.write(write_info.compressedData, write_info.compressedSize); + } + columnChunk.meta_data.total_compressed_size = writer.getOffset() - startOffset; + columnChunk.meta_data.total_uncompressed_size = totalUncompressedSize; +} + +void BasicColumnWriter::writeLevels(Serializer& serializer, const std::vector& levels, + uint64_t maxValue, uint64_t startOffset, uint64_t count) { + if (levels.empty() || count == 0) { + return; + } + + // Write the levels using the RLE-BP encoding. + auto bitWidth = RleBpDecoder::ComputeBitWidth((maxValue)); + RleBpEncoder rleEncoder(bitWidth); + + rleEncoder.beginPrepare(levels[startOffset]); + for (auto i = startOffset + 1; i < startOffset + count; i++) { + rleEncoder.prepareValue(levels[i]); + } + rleEncoder.finishPrepare(); + + // Start off by writing the byte count as a uint32_t. + serializer.write(rleEncoder.getByteCount()); + rleEncoder.beginWrite(levels[startOffset]); + for (auto i = startOffset + 1; i < startOffset + count; i++) { + rleEncoder.writeValue(serializer, levels[i]); + } + rleEncoder.finishWrite(serializer); +} + +void BasicColumnWriter::nextPage(BasicColumnWriterState& state) { + if (state.currentPage > 0) { + // Need to flush the current page. 
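// Flushing finalizes the page that was being filled: the buffered levels
// and values become the uncompressed payload, the page header sizes are
// filled in, and the payload is compressed with the writer's codec (see
// flushPage below) before the next page's levels are written.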
+ flushPage(state); + } + if (state.currentPage >= state.writeInfo.size()) { + state.currentPage = state.writeInfo.size() + 1; + return; + } + auto& pageInfo = state.pageInfo[state.currentPage]; + auto& writeInfo = state.writeInfo[state.currentPage]; + state.currentPage++; + + // write the repetition levels + writeLevels(*writeInfo.writer, state.repetitionLevels, maxRepeat, pageInfo.offset, + pageInfo.rowCount); + + // write the definition levels + writeLevels(*writeInfo.writer, state.definitionLevels, maxDefine, pageInfo.offset, + pageInfo.rowCount); +} + +void BasicColumnWriter::flushPage(BasicColumnWriterState& state) { + KU_ASSERT(state.currentPage > 0); + if (state.currentPage > state.writeInfo.size()) { + return; + } + + // compress the page info + auto& writeInfo = state.writeInfo[state.currentPage - 1]; + auto& bufferedWriter = *writeInfo.bufferWriter; + auto& hdr = writeInfo.pageHeader; + + flushPageState(*writeInfo.writer, writeInfo.pageState.get()); + + // now that we have finished writing the data we know the uncompressed size + if (bufferedWriter.getSize() > uint64_t(function::NumericLimits::maximum())) { + throw common::RuntimeException{common::stringFormat( + "Parquet writer: %d uncompressed page size out of range for type integer", + bufferedWriter.getSize())}; + } + hdr.uncompressed_page_size = bufferedWriter.getSize(); + + // compress the data + compressPage(bufferedWriter, writeInfo.compressedSize, writeInfo.compressedData, + writeInfo.compressedBuf); + hdr.compressed_page_size = writeInfo.compressedSize; + KU_ASSERT(hdr.uncompressed_page_size > 0); + KU_ASSERT(hdr.compressed_page_size > 0); + + if (writeInfo.compressedBuf) { + // if the data has been compressed, we no longer need the compressed data + KU_ASSERT(writeInfo.compressedBuf.get() == writeInfo.compressedData); + writeInfo.bufferWriter.reset(); + } +} + +void BasicColumnWriter::writeDictionary(BasicColumnWriterState& state, + std::unique_ptr bufferedSerializer, uint64_t rowCount) { + KU_ASSERT(bufferedSerializer); + KU_ASSERT(bufferedSerializer->getSize() > 0); + + // write the dictionary page header + PageWriteInformation writeInfo; + // set up the header + auto& hdr = writeInfo.pageHeader; + hdr.uncompressed_page_size = bufferedSerializer->getSize(); + hdr.type = PageType::DICTIONARY_PAGE; + hdr.__isset.dictionary_page_header = true; + + hdr.dictionary_page_header.encoding = Encoding::PLAIN; + hdr.dictionary_page_header.is_sorted = false; + hdr.dictionary_page_header.num_values = rowCount; + + writeInfo.bufferWriter = std::move(bufferedSerializer); + writeInfo.writer = std::make_unique(writeInfo.bufferWriter); + writeInfo.writeCount = 0; + writeInfo.maxWriteCount = 0; + + // compress the contents of the dictionary page + compressPage(*writeInfo.bufferWriter, writeInfo.compressedSize, writeInfo.compressedData, + writeInfo.compressedBuf); + hdr.compressed_page_size = writeInfo.compressedSize; + + // insert the dictionary page as the first page to write for this column + state.writeInfo.insert(state.writeInfo.begin(), std::move(writeInfo)); +} + +void BasicColumnWriter::setParquetStatistics(BasicColumnWriterState& state, + lbug_parquet::format::ColumnChunk& column) { + if (maxRepeat == 0) { + column.meta_data.statistics.null_count = nullCount; + column.meta_data.statistics.__isset.null_count = true; + column.meta_data.__isset.statistics = true; + } + // set min/max/min_value/max_value + // this code is not going to win any beauty contests, but well + auto min = state.statsState->getMin(); + if (!min.empty()) { + 
column.meta_data.statistics.min = std::move(min); + column.meta_data.statistics.__isset.min = true; + column.meta_data.__isset.statistics = true; + } + auto max = state.statsState->getMax(); + if (!max.empty()) { + column.meta_data.statistics.max = std::move(max); + column.meta_data.statistics.__isset.max = true; + column.meta_data.__isset.statistics = true; + } + auto min_value = state.statsState->getMinValue(); + if (!min_value.empty()) { + column.meta_data.statistics.min_value = std::move(min_value); + column.meta_data.statistics.__isset.min_value = true; + column.meta_data.__isset.statistics = true; + } + auto max_value = state.statsState->getMaxValue(); + if (!max_value.empty()) { + column.meta_data.statistics.max_value = std::move(max_value); + column.meta_data.statistics.__isset.max_value = true; + column.meta_data.__isset.statistics = true; + } + for (const auto& write_info : state.writeInfo) { + column.meta_data.encodings.push_back(write_info.pageHeader.data_page_header.encoding); + } +} + +void BasicColumnWriter::registerToRowGroup(lbug_parquet::format::RowGroup& rowGroup) { + ColumnChunk column_chunk; + column_chunk.__isset.meta_data = true; + column_chunk.meta_data.codec = writer.getCodec(); + column_chunk.meta_data.path_in_schema = schemaPath; + column_chunk.meta_data.num_values = 0; + column_chunk.meta_data.type = writer.getParquetType(schemaIdx); + rowGroup.columns.push_back(std::move(column_chunk)); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/writer/parquet/boolean_column_writer.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/writer/parquet/boolean_column_writer.cpp new file mode 100644 index 0000000000..cd6bdd4760 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/writer/parquet/boolean_column_writer.cpp @@ -0,0 +1,45 @@ +#include "processor/operator/persistent/writer/parquet/boolean_column_writer.h" + +#include "common/serializer/serializer.h" + +namespace lbug { +namespace processor { + +void BooleanColumnWriter::writeVector(common::Serializer& temp_writer, + ColumnWriterStatistics* writerStatistics, ColumnWriterPageState* writerPageState, + common::ValueVector* vector, uint64_t chunkStart, uint64_t chunkEnd) { + auto stats = reinterpret_cast(writerStatistics); + auto state = reinterpret_cast(writerPageState); + for (auto r = chunkStart; r < chunkEnd; r++) { + auto pos = getVectorPos(vector, r); + if (!vector->isNull(pos)) { + // only encode if non-null + if (vector->getValue(pos)) { + stats->max = true; + state->byte |= 1 << state->bytePos; + } else { + stats->min = false; + } + state->bytePos++; + + if (state->bytePos == 8) { + temp_writer.write(state->byte); + state->byte = 0; + state->bytePos = 0; + } + } + } +} + +void BooleanColumnWriter::flushPageState(common::Serializer& temp_writer, + ColumnWriterPageState* writerPageState) { + auto state = reinterpret_cast(writerPageState); + if (state->bytePos > 0) { + temp_writer.write(state->byte); + state->byte = 0; + state->bytePos = 0; + } +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/writer/parquet/column_writer.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/writer/parquet/column_writer.cpp new file mode 100644 index 0000000000..b59719ff80 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/writer/parquet/column_writer.cpp @@ -0,0 
+1,378 @@ +#include "processor/operator/persistent/writer/parquet/column_writer.h" + +#include "common/exception/runtime.h" +#include "common/string_format.h" +#include "function/cast/functions/numeric_limits.h" +#include "lz4.hpp" +#include "miniz_wrapper.hpp" +#include "processor/operator/persistent/writer/parquet/boolean_column_writer.h" +#include "processor/operator/persistent/writer/parquet/interval_column_writer.h" +#include "processor/operator/persistent/writer/parquet/list_column_writer.h" +#include "processor/operator/persistent/writer/parquet/parquet_writer.h" +#include "processor/operator/persistent/writer/parquet/standard_column_writer.h" +#include "processor/operator/persistent/writer/parquet/string_column_writer.h" +#include "processor/operator/persistent/writer/parquet/struct_column_writer.h" +#include "processor/operator/persistent/writer/parquet/uuid_column_writer.h" +#include "snappy.h" +#include "zstd.h" + +namespace lbug { +namespace processor { + +using namespace lbug_parquet::format; +using namespace lbug::common; + +struct ParquetInt128Operator { + template + static inline TGT Operation(SRC input) { + return Int128_t::cast(input); + } + + template + static inline std::unique_ptr initializeStats() { + return std::make_unique(); + } + + template + static void handleStats(ColumnWriterStatistics* /*stats*/, SRC /*source*/, TGT /*target*/) {} +}; + +struct ParquetTimestampNSOperator : public BaseParquetOperator { + template + static TGT Operation(SRC input) { + return Timestamp::fromEpochNanoSeconds(input).value; + } +}; + +struct ParquetTimestampSOperator : public BaseParquetOperator { + template + static TGT Operation(SRC input) { + return Timestamp::fromEpochSeconds(input).value; + } +}; + +ColumnWriter::ColumnWriter(ParquetWriter& writer, uint64_t schemaIdx, + std::vector schemaPath, uint64_t maxRepeat, uint64_t maxDefine, bool canHaveNulls) + : writer{writer}, schemaIdx{schemaIdx}, schemaPath{std::move(schemaPath)}, maxRepeat{maxRepeat}, + maxDefine{maxDefine}, canHaveNulls{canHaveNulls}, nullCount{0} {} + +std::unique_ptr ColumnWriter::createWriterRecursive( + std::vector& schemas, ParquetWriter& writer, + const LogicalType& type, const std::string& name, std::vector schemaPathToCreate, + storage::MemoryManager* mm, uint64_t maxRepeatToCreate, uint64_t maxDefineToCreate, + bool canHaveNullsToCreate) { + auto nullType = + canHaveNullsToCreate ? FieldRepetitionType::OPTIONAL : FieldRepetitionType::REQUIRED; + if (!canHaveNullsToCreate) { + maxDefineToCreate--; + } + auto schemaIdx = schemas.size(); + switch (type.getLogicalTypeID()) { + case LogicalTypeID::UNION: + case LogicalTypeID::STRUCT: { + const auto& fields = StructType::getFields(type); + // set up the schema element for this struct + lbug_parquet::format::SchemaElement schema_element; + schema_element.repetition_type = nullType; + schema_element.num_children = fields.size(); + schema_element.__isset.num_children = true; + schema_element.__isset.type = false; + schema_element.__isset.repetition_type = true; + schema_element.name = name; + schemas.push_back(std::move(schema_element)); + schemaPathToCreate.push_back(name); + + // Construct the child types recursively. 
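// Each optional nesting level adds one definition level, so the children are
// created with maxDefineToCreate + 1: the extra level lets a child tell
// "the parent struct is null" apart from "only this field is null" when
// definition levels are emitted in handleDefineLevels.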
+ std::vector> childWriters; + childWriters.reserve(fields.size()); + for (auto& field : fields) { + childWriters.push_back(createWriterRecursive(schemas, writer, field.getType(), + field.getName(), schemaPathToCreate, mm, maxRepeatToCreate, maxDefineToCreate + 1)); + } + return std::make_unique(writer, schemaIdx, + std::move(schemaPathToCreate), maxRepeatToCreate, maxDefineToCreate, + std::move(childWriters), canHaveNullsToCreate); + } + case LogicalTypeID::ARRAY: + case LogicalTypeID::LIST: { + const auto& childType = ListType::getChildType(type); + // Set up the two schema elements for the list + // for some reason we only set the converted type in the OPTIONAL element + // first an OPTIONAL element. + lbug_parquet::format::SchemaElement optionalElem; + optionalElem.repetition_type = nullType; + optionalElem.num_children = 1; + optionalElem.converted_type = ConvertedType::LIST; + optionalElem.__isset.num_children = true; + optionalElem.__isset.type = false; + optionalElem.__isset.repetition_type = true; + optionalElem.__isset.converted_type = true; + optionalElem.name = name; + schemas.push_back(std::move(optionalElem)); + schemaPathToCreate.push_back(name); + + // Then a REPEATED element. + lbug_parquet::format::SchemaElement repeatedElem; + repeatedElem.repetition_type = FieldRepetitionType::REPEATED; + repeatedElem.num_children = 1; + repeatedElem.__isset.num_children = true; + repeatedElem.__isset.type = false; + repeatedElem.__isset.repetition_type = true; + repeatedElem.name = "list"; + schemas.push_back(std::move(repeatedElem)); + schemaPathToCreate.emplace_back("list"); + + auto child_writer = createWriterRecursive(schemas, writer, childType, "element", + schemaPathToCreate, mm, maxRepeatToCreate + 1, maxDefineToCreate + 2); + return std::make_unique(writer, schemaIdx, std::move(schemaPathToCreate), + maxRepeatToCreate, maxDefineToCreate, std::move(child_writer), canHaveNullsToCreate); + } + case LogicalTypeID::MAP: { + // Maps are stored as follows in parquet: + // group (MAP) { + // repeated group key_value { + // required key; + // value; + // } + // } + lbug_parquet::format::SchemaElement topElement; + topElement.repetition_type = nullType; + topElement.num_children = 1; + topElement.converted_type = ConvertedType::MAP; + topElement.__isset.repetition_type = true; + topElement.__isset.num_children = true; + topElement.__isset.converted_type = true; + topElement.__isset.type = false; + topElement.name = name; + schemas.push_back(std::move(topElement)); + schemaPathToCreate.push_back(name); + + // key_value element + lbug_parquet::format::SchemaElement kv_element; + kv_element.repetition_type = FieldRepetitionType::REPEATED; + kv_element.num_children = 2; + kv_element.__isset.repetition_type = true; + kv_element.__isset.num_children = true; + kv_element.__isset.type = false; + kv_element.name = "key_value"; + schemas.push_back(std::move(kv_element)); + schemaPathToCreate.emplace_back("key_value"); + + // Construct the child types recursively. 
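// The key child is created with canHaveNulls = false (the `i != 0` argument
// below), since Parquet map keys must be non-null; a null key later
// surfaces as the "map key column is not allowed to contain NULL values"
// error raised in handleDefineLevels.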
+ std::vector kvTypes; + kvTypes.push_back(MapType::getKeyType(type).copy()); + kvTypes.push_back(MapType::getValueType(type).copy()); + std::vector kvNames{"key", "value"}; + std::vector> childrenWriters; + childrenWriters.reserve(2); + for (auto i = 0u; i < 2; i++) { + auto childWriter = createWriterRecursive(schemas, writer, kvTypes[i], kvNames[i], + schemaPathToCreate, mm, maxRepeatToCreate + 1, maxDefineToCreate + 2, i != 0); + childrenWriters.push_back(std::move(childWriter)); + } + auto structWriter = std::make_unique(writer, schemaIdx, + schemaPathToCreate, maxRepeatToCreate, maxDefineToCreate, std::move(childrenWriters), + canHaveNullsToCreate); + return std::make_unique(writer, schemaIdx, schemaPathToCreate, + maxRepeatToCreate, maxDefineToCreate, std::move(structWriter), canHaveNullsToCreate); + } + default: { + SchemaElement schemaElement; + schemaElement.type = ParquetWriter::convertToParquetType(type); + schemaElement.repetition_type = nullType; + schemaElement.__isset.num_children = false; + schemaElement.__isset.type = true; + schemaElement.__isset.repetition_type = true; + schemaElement.name = name; + ParquetWriter::setSchemaProperties(type, schemaElement); + schemas.push_back(std::move(schemaElement)); + schemaPathToCreate.push_back(name); + + switch (type.getLogicalTypeID()) { + case LogicalTypeID::BOOL: + return std::make_unique(writer, schemaIdx, + std::move(schemaPathToCreate), maxRepeatToCreate, maxDefineToCreate, + canHaveNullsToCreate); + case LogicalTypeID::INT8: + return std::make_unique>(writer, schemaIdx, + std::move(schemaPathToCreate), maxRepeatToCreate, maxDefineToCreate, + canHaveNullsToCreate); + case LogicalTypeID::INT16: + return std::make_unique>(writer, schemaIdx, + std::move(schemaPathToCreate), maxRepeatToCreate, maxDefineToCreate, + canHaveNullsToCreate); + case LogicalTypeID::INT32: + case LogicalTypeID::DATE: + return std::make_unique>(writer, schemaIdx, + std::move(schemaPathToCreate), maxRepeatToCreate, maxDefineToCreate, + canHaveNullsToCreate); + case LogicalTypeID::TIMESTAMP_TZ: + case LogicalTypeID::TIMESTAMP_MS: + case LogicalTypeID::TIMESTAMP: + case LogicalTypeID::SERIAL: + case LogicalTypeID::INT64: + return std::make_unique>(writer, schemaIdx, + std::move(schemaPathToCreate), maxRepeatToCreate, maxDefineToCreate, + canHaveNullsToCreate); + case LogicalTypeID::TIMESTAMP_NS: + return make_unique>( + writer, schemaIdx, std::move(schemaPathToCreate), maxRepeatToCreate, + maxDefineToCreate, canHaveNullsToCreate); + case LogicalTypeID::TIMESTAMP_SEC: + return make_unique>( + writer, schemaIdx, std::move(schemaPathToCreate), maxRepeatToCreate, + maxDefineToCreate, canHaveNullsToCreate); + case LogicalTypeID::INT128: + return std::make_unique>( + writer, schemaIdx, std::move(schemaPathToCreate), maxRepeatToCreate, + maxDefineToCreate, canHaveNullsToCreate); + case LogicalTypeID::UINT8: + return std::make_unique>(writer, schemaIdx, + std::move(schemaPathToCreate), maxRepeatToCreate, maxDefineToCreate, + canHaveNullsToCreate); + case LogicalTypeID::UINT16: + return std::make_unique>(writer, schemaIdx, + std::move(schemaPathToCreate), maxRepeatToCreate, maxDefineToCreate, + canHaveNullsToCreate); + case LogicalTypeID::UINT32: + return std::make_unique>(writer, schemaIdx, + std::move(schemaPathToCreate), maxRepeatToCreate, maxDefineToCreate, + canHaveNullsToCreate); + case LogicalTypeID::UINT64: + return std::make_unique>(writer, schemaIdx, + std::move(schemaPathToCreate), maxRepeatToCreate, maxDefineToCreate, + canHaveNullsToCreate); + case 
LogicalTypeID::FLOAT: + return std::make_unique>(writer, schemaIdx, + std::move(schemaPathToCreate), maxRepeatToCreate, maxDefineToCreate, + canHaveNullsToCreate); + case LogicalTypeID::DOUBLE: + return std::make_unique>(writer, schemaIdx, + std::move(schemaPathToCreate), maxRepeatToCreate, maxDefineToCreate, + canHaveNullsToCreate); + case LogicalTypeID::BLOB: + case LogicalTypeID::STRING: + return std::make_unique(writer, schemaIdx, + std::move(schemaPathToCreate), maxRepeatToCreate, maxDefineToCreate, + canHaveNullsToCreate, mm); + case LogicalTypeID::INTERVAL: + return std::make_unique(writer, schemaIdx, + std::move(schemaPathToCreate), maxRepeatToCreate, maxDefineToCreate, + canHaveNullsToCreate); + case LogicalTypeID::UUID: + return std::make_unique(writer, schemaIdx, + std::move(schemaPathToCreate), maxRepeatToCreate, maxDefineToCreate, + canHaveNullsToCreate); + default: + KU_UNREACHABLE; + } + } + } +} + +void ColumnWriter::handleRepeatLevels(ColumnWriterState& stateToHandle, ColumnWriterState* parent) { + if (!parent) { + // no repeat levels without a parent node + return; + } + while (stateToHandle.repetitionLevels.size() < parent->repetitionLevels.size()) { + stateToHandle.repetitionLevels.push_back( + parent->repetitionLevels[stateToHandle.repetitionLevels.size()]); + } +} + +void ColumnWriter::handleDefineLevels(ColumnWriterState& state, ColumnWriterState* parent, + common::ValueVector* vector, uint64_t count, uint16_t defineValue, uint16_t nullValue) { + if (parent) { + // parent node: inherit definition level from the parent + uint64_t vectorIdx = 0; + while (state.definitionLevels.size() < parent->definitionLevels.size()) { + auto currentIdx = state.definitionLevels.size(); + if (parent->definitionLevels[currentIdx] != ParquetConstants::PARQUET_DEFINE_VALID) { + state.definitionLevels.push_back(parent->definitionLevels[currentIdx]); + } else if (!vector->isNull(getVectorPos(vector, vectorIdx))) { + state.definitionLevels.push_back(defineValue); + } else { + if (!canHaveNulls) { + throw RuntimeException( + "Parquet writer: map key column is not allowed to contain NULL values"); + } + nullCount++; + state.definitionLevels.push_back(nullValue); + } + if (parent->isEmpty.empty() || !parent->isEmpty[currentIdx]) { + vectorIdx++; + } + } + } else { + // no parent: set definition levels only from this validity mask + for (auto i = 0u; i < count; i++) { + if (!vector->isNull(getVectorPos(vector, i))) { + state.definitionLevels.push_back(defineValue); + } else { + if (!canHaveNulls) { + throw RuntimeException( + "Parquet writer: map key column is not allowed to contain NULL values"); + } + nullCount++; + state.definitionLevels.push_back(nullValue); + } + } + } +} + +void ColumnWriter::compressPage(common::BufferWriter& bufferedSerializer, size_t& compressedSize, + uint8_t*& compressedData, std::unique_ptr& compressedBuf) { + switch (writer.getCodec()) { + case CompressionCodec::UNCOMPRESSED: { + compressedSize = bufferedSerializer.getSize(); + compressedData = bufferedSerializer.getBlobData(); + } break; + case CompressionCodec::SNAPPY: { + compressedSize = lbug_snappy::MaxCompressedLength(bufferedSerializer.getSize()); + compressedBuf = std::unique_ptr(new uint8_t[compressedSize]); + lbug_snappy::RawCompress(reinterpret_cast(bufferedSerializer.getBlobData()), + bufferedSerializer.getSize(), reinterpret_cast(compressedBuf.get()), + &compressedSize); + compressedData = compressedBuf.get(); + KU_ASSERT(compressedSize <= lbug_snappy::MaxCompressedLength(bufferedSerializer.getSize())); 
+ } break; + case CompressionCodec::ZSTD: { + compressedSize = lbug_zstd::ZSTD_compressBound(bufferedSerializer.getSize()); + compressedBuf = std::unique_ptr(new uint8_t[compressedSize]); + compressedSize = lbug_zstd::ZSTD_compress((void*)compressedBuf.get(), compressedSize, + reinterpret_cast(bufferedSerializer.getBlobData()), + bufferedSerializer.getSize(), ZSTD_CLEVEL_DEFAULT); + compressedData = compressedBuf.get(); + } break; + case CompressionCodec::GZIP: { + MiniZStream stream; + compressedSize = stream.MaxCompressedLength(bufferedSerializer.getSize()); + compressedBuf = std::unique_ptr(new uint8_t[compressedSize]); + stream.Compress(reinterpret_cast(bufferedSerializer.getBlobData()), + bufferedSerializer.getSize(), reinterpret_cast(compressedBuf.get()), + &compressedSize); + compressedData = compressedBuf.get(); + } break; + case CompressionCodec::LZ4_RAW: { + compressedSize = lbug_lz4::LZ4_compressBound(bufferedSerializer.getSize()); + compressedBuf = std::unique_ptr(new uint8_t[compressedSize]); + compressedSize = lbug_lz4::LZ4_compress_default( + reinterpret_cast(bufferedSerializer.getBlobData()), + reinterpret_cast(compressedBuf.get()), bufferedSerializer.getSize(), + compressedSize); + compressedData = compressedBuf.get(); + } break; + default: + KU_UNREACHABLE; + } + + if (compressedSize > uint64_t(function::NumericLimits::maximum())) { + throw RuntimeException( + stringFormat("Parquet writer: {} compressed page size out of range for type integer", + bufferedSerializer.getSize())); + } +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/writer/parquet/interval_column_writer.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/writer/parquet/interval_column_writer.cpp new file mode 100644 index 0000000000..6c4b141e39 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/writer/parquet/interval_column_writer.cpp @@ -0,0 +1,39 @@ +#include "processor/operator/persistent/writer/parquet/interval_column_writer.h" + +#include "common/exception/runtime.h" +#include "common/serializer/serializer.h" + +namespace lbug { +namespace processor { + +void IntervalColumnWriter::writeParquetInterval(common::interval_t input, uint8_t* result) { + if (input.days < 0 || input.months < 0 || input.micros < 0) { + throw common::RuntimeException{"Parquet files do not support negative intervals"}; + } + uint32_t dataToStore = 0; + dataToStore = input.months; + memcpy(result, &dataToStore, sizeof(dataToStore)); + dataToStore = input.days; + result += sizeof(dataToStore); + memcpy(result, &dataToStore, sizeof(dataToStore)); + dataToStore = input.micros / 1000; + result += sizeof(dataToStore); + memcpy(result, &dataToStore, sizeof(dataToStore)); +} + +void IntervalColumnWriter::writeVector(common::Serializer& bufferedSerializer, + ColumnWriterStatistics* /*state*/, ColumnWriterPageState* /*pageState*/, + common::ValueVector* vector, uint64_t chunkStart, uint64_t chunkEnd) { + uint8_t tmpIntervalBuf[common::ParquetConstants::PARQUET_INTERVAL_SIZE]; + for (auto r = chunkStart; r < chunkEnd; r++) { + auto pos = getVectorPos(vector, r); + if (!vector->isNull(pos)) { + writeParquetInterval(vector->getValue(pos), tmpIntervalBuf); + bufferedSerializer.write(tmpIntervalBuf, + common::ParquetConstants::PARQUET_INTERVAL_SIZE); + } + } +} + +} // namespace processor +} // namespace lbug diff --git 
a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/writer/parquet/list_column_writer.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/writer/parquet/list_column_writer.cpp new file mode 100644 index 0000000000..dd3e1718b2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/writer/parquet/list_column_writer.cpp @@ -0,0 +1,109 @@ +#include "processor/operator/persistent/writer/parquet/list_column_writer.h" + +#include "common/constants.h" +#include "common/exception/runtime.h" + +namespace lbug { +namespace processor { + +using namespace lbug_parquet::format; + +std::unique_ptr ListColumnWriter::initializeWriteState( + lbug_parquet::format::RowGroup& rowGroup) { + auto result = std::make_unique(rowGroup, rowGroup.columns.size()); + result->childState = childWriter->initializeWriteState(rowGroup); + return result; +} + +bool ListColumnWriter::hasAnalyze() { + return childWriter->hasAnalyze(); +} + +void ListColumnWriter::analyze(ColumnWriterState& writerState, ColumnWriterState* /*parent*/, + common::ValueVector* vector, uint64_t /*count*/) { + auto& state = reinterpret_cast(writerState); + childWriter->analyze(*state.childState, &writerState, common::ListVector::getDataVector(vector), + common::ListVector::getDataVectorSize(vector)); +} + +void ListColumnWriter::finalizeAnalyze(ColumnWriterState& writerState) { + auto& state = reinterpret_cast(writerState); + childWriter->finalizeAnalyze(*state.childState); +} + +void ListColumnWriter::prepare(ColumnWriterState& writerState, ColumnWriterState* parent, + common::ValueVector* vector, uint64_t count) { + auto& state = reinterpret_cast(writerState); + + // Write definition levels and repeats. + uint64_t start = 0; + auto vcount = parent ? parent->definitionLevels.size() - state.parentIdx : count; + uint64_t vectorIdx = 0; + for (auto i = start; i < vcount; i++) { + auto parentIdx = state.parentIdx + i; + if (parent && !parent->isEmpty.empty() && parent->isEmpty[parentIdx]) { + state.definitionLevels.push_back(parent->definitionLevels[parentIdx]); + state.repetitionLevels.push_back(parent->repetitionLevels[parentIdx]); + state.isEmpty.push_back(true); + continue; + } + auto firstRepeatLevel = parent && !parent->repetitionLevels.empty() ? 
+ parent->repetitionLevels[parentIdx] : + maxRepeat; + auto pos = getVectorPos(vector, vectorIdx); + if (parent && + parent->definitionLevels[parentIdx] != common::ParquetConstants::PARQUET_DEFINE_VALID) { + state.definitionLevels.push_back(parent->definitionLevels[parentIdx]); + state.repetitionLevels.push_back(firstRepeatLevel); + state.isEmpty.push_back(true); + } else if (!vector->isNull(pos)) { + auto listEntry = vector->getValue(pos); + // push the repetition levels + if (listEntry.size == 0) { + state.definitionLevels.push_back(maxDefine); + state.isEmpty.push_back(true); + } else { + state.definitionLevels.push_back(common::ParquetConstants::PARQUET_DEFINE_VALID); + state.isEmpty.push_back(false); + } + state.repetitionLevels.push_back(firstRepeatLevel); + for (auto k = 1u; k < listEntry.size; k++) { + state.repetitionLevels.push_back(maxRepeat + 1); + state.definitionLevels.push_back(common::ParquetConstants::PARQUET_DEFINE_VALID); + state.isEmpty.push_back(false); + } + } else { + if (!canHaveNulls) { + throw common::RuntimeException( + "Parquet writer: map key column is not allowed to contain NULL values"); + } + state.definitionLevels.push_back(maxDefine - 1); + state.repetitionLevels.push_back(firstRepeatLevel); + state.isEmpty.push_back(true); + } + vectorIdx++; + } + state.parentIdx += vcount; + childWriter->prepare(*state.childState, &writerState, common::ListVector::getDataVector(vector), + common::ListVector::getDataVectorSize(vector)); +} + +void ListColumnWriter::beginWrite(ColumnWriterState& state_p) { + auto& state = reinterpret_cast(state_p); + childWriter->beginWrite(*state.childState); +} + +void ListColumnWriter::write(ColumnWriterState& writerState, common::ValueVector* vector, + uint64_t /*count*/) { + auto& state = reinterpret_cast(writerState); + childWriter->write(*state.childState, common::ListVector::getDataVector(vector), + common::ListVector::getDataVectorSize(vector)); +} + +void ListColumnWriter::finalizeWrite(ColumnWriterState& writerState) { + auto& state = reinterpret_cast(writerState); + childWriter->finalizeWrite(*state.childState); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/writer/parquet/parquet_rle_bp_encoder.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/writer/parquet/parquet_rle_bp_encoder.cpp new file mode 100644 index 0000000000..363b5ee3b5 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/writer/parquet/parquet_rle_bp_encoder.cpp @@ -0,0 +1,112 @@ +#include "processor/operator/persistent/writer/parquet/parquet_rle_bp_encoder.h" + +#include "common/assert.h" + +namespace lbug { +namespace processor { + +static void varintEncode(uint32_t val, common::Serializer& ser) { + do { + uint8_t byte = val & 127; + val >>= 7; + if (val != 0) { + byte |= 128; + } + ser.write(byte); + } while (val != 0); +} + +uint8_t RleBpEncoder::getVarintSize(uint32_t val) { + uint8_t res = 0; + do { + val >>= 7; + res++; + } while (val != 0); + return res; +} + +RleBpEncoder::RleBpEncoder(uint32_t bit_width) + : byteWidth((bit_width + 7) / 8), byteCount(uint64_t(-1)), runCount(uint64_t(-1)), + currentRunCount(0), lastValue(0) {} + +// we always RLE everything (for now) +void RleBpEncoder::beginPrepare(uint32_t first_value) { + byteCount = 0; + runCount = 1; + currentRunCount = 1; + lastValue = first_value; +} + +void RleBpEncoder::finishRun() { + // last value, or value has changed + // write out the current run + 
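// The run header is the run length shifted left by one: in Parquet's
// RLE/bit-packing hybrid encoding the low bit distinguishes RLE runs (0)
// from bit-packed runs (1), and this encoder only emits RLE runs. Worked
// example with hypothetical numbers: bit width 3 gives byteWidth 1, so a
// run of the value 7 repeated 3 times costs varint(3 << 1) = 0x06 for the
// header plus a single value byte 0x07, i.e. 2 bytes in total.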
byteCount += getVarintSize(currentRunCount << 1) + byteWidth; + currentRunCount = 1; + runCount++; +} + +void RleBpEncoder::prepareValue(uint32_t value) { + if (value != lastValue) { + finishRun(); + lastValue = value; + } else { + currentRunCount++; + } +} + +void RleBpEncoder::finishPrepare() { + finishRun(); +} + +uint64_t RleBpEncoder::getByteCount() const { + KU_ASSERT(byteCount != uint64_t(-1)); + return byteCount; +} + +void RleBpEncoder::beginWrite(uint32_t first_value) { + // start the RLE runs + lastValue = first_value; + currentRunCount = 1; +} + +void RleBpEncoder::writeRun(common::Serializer& writer) { + // write the header of the run + varintEncode(currentRunCount << 1, writer); + // now write the value + KU_ASSERT(lastValue >> (byteWidth * 8) == 0); + switch (byteWidth) { + case 1: + writer.write(lastValue); + break; + case 2: + writer.write(lastValue); + break; + case 3: + writer.write(lastValue & 0xFF); + writer.write((lastValue >> 8) & 0xFF); + writer.write((lastValue >> 16) & 0xFF); + break; + case 4: + writer.write(lastValue); + break; + default: + KU_UNREACHABLE; + } + currentRunCount = 1; +} + +void RleBpEncoder::writeValue(common::Serializer& writer, uint32_t value) { + if (value != lastValue) { + writeRun(writer); + lastValue = value; + } else { + currentRunCount++; + } +} + +void RleBpEncoder::finishWrite(common::Serializer& writer) { + writeRun(writer); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/writer/parquet/parquet_writer.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/writer/parquet/parquet_writer.cpp new file mode 100644 index 0000000000..a8ed8de768 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/writer/parquet/parquet_writer.cpp @@ -0,0 +1,304 @@ +#include "processor/operator/persistent/writer/parquet/parquet_writer.h" + +#include "common/constants.h" +#include "common/data_chunk/data_chunk.h" +#include "common/exception/runtime.h" +#include "common/file_system/virtual_file_system.h" +#include "common/system_config.h" +#include "main/client_context.h" +#include "protocol/TCompactProtocol.h" +#include "storage/buffer_manager/memory_manager.h" + +namespace lbug { +namespace processor { + +using namespace lbug_parquet::format; +using namespace lbug::common; + +ParquetWriter::ParquetWriter(std::string fileName, std::vector types, + std::vector columnNames, lbug_parquet::format::CompressionCodec::type codec, + main::ClientContext* context) + : fileName{std::move(fileName)}, types{std::move(types)}, columnNames{std::move(columnNames)}, + codec{codec}, fileOffset{0}, mm{storage::MemoryManager::Get(*context)} { + fileInfo = VirtualFileSystem::GetUnsafe(*context)->openFile(this->fileName, + FileOpenFlags(FileFlags::WRITE | FileFlags::CREATE_AND_TRUNCATE_IF_EXISTS), context); + // Parquet files start with the string "PAR1". 
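// Overall file layout produced by this writer (leading magic here, the rest
// in flush()/finalize()):
//   "PAR1" | row groups (column chunk pages) | FileMetaData (Thrift compact)
//   | 4-byte little-endian footer length | "PAR1"
// Readers locate the metadata by reading the trailing magic and footer
// length at the end of the file and seeking back that many bytes.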
+ fileInfo->writeFile(reinterpret_cast(ParquetConstants::PARQUET_MAGIC_WORDS), + strlen(ParquetConstants::PARQUET_MAGIC_WORDS), fileOffset); + fileOffset += strlen(ParquetConstants::PARQUET_MAGIC_WORDS); + lbug_apache::thrift::protocol::TCompactProtocolFactoryT tprotoFactory; + protocol = tprotoFactory.getProtocol( + std::make_shared(fileInfo.get(), fileOffset)); + + fileMetaData.num_rows = 0; + fileMetaData.version = 1; + + fileMetaData.__isset.created_by = true; + fileMetaData.created_by = "LBUG"; + + fileMetaData.schema.resize(1); + + // populate root schema object + fileMetaData.schema[0].name = "lbug_schema"; + fileMetaData.schema[0].num_children = this->types.size(); + fileMetaData.schema[0].__isset.num_children = true; + fileMetaData.schema[0].repetition_type = lbug_parquet::format::FieldRepetitionType::REQUIRED; + fileMetaData.schema[0].__isset.repetition_type = true; + + std::vector schemaPath; + for (auto i = 0u; i < this->types.size(); i++) { + columnWriters.push_back(ColumnWriter::createWriterRecursive(fileMetaData.schema, *this, + this->types[i], this->columnNames[i], schemaPath, mm)); + } +} + +Type::type ParquetWriter::convertToParquetType(const LogicalType& type) { + switch (type.getLogicalTypeID()) { + case LogicalTypeID::BOOL: + return Type::BOOLEAN; + case LogicalTypeID::INT8: + case LogicalTypeID::INT16: + case LogicalTypeID::INT32: + case LogicalTypeID::UINT8: + case LogicalTypeID::UINT16: + case LogicalTypeID::UINT32: + case LogicalTypeID::DATE: + return Type::INT32; + case LogicalTypeID::UINT64: + case LogicalTypeID::INT64: + case LogicalTypeID::SERIAL: + case LogicalTypeID::TIMESTAMP_TZ: + case LogicalTypeID::TIMESTAMP_NS: + case LogicalTypeID::TIMESTAMP_MS: + case LogicalTypeID::TIMESTAMP_SEC: + case LogicalTypeID::TIMESTAMP: + return Type::INT64; + case LogicalTypeID::FLOAT: + return Type::FLOAT; + case LogicalTypeID::INT128: + case LogicalTypeID::DOUBLE: + return Type::DOUBLE; + case LogicalTypeID::BLOB: + case LogicalTypeID::STRING: + return Type::BYTE_ARRAY; + case LogicalTypeID::UUID: + case LogicalTypeID::INTERVAL: + return Type::FIXED_LEN_BYTE_ARRAY; + default: + throw RuntimeException{ + stringFormat("Writing a column with type: {} to parquet is not supported.", + LogicalTypeUtils::toString(type.getLogicalTypeID()))}; + } +} + +void ParquetWriter::setSchemaProperties(const LogicalType& type, SchemaElement& schemaElement) { + switch (type.getLogicalTypeID()) { + case LogicalTypeID::INT8: { + schemaElement.converted_type = ConvertedType::INT_8; + schemaElement.__isset.converted_type = true; + } break; + case LogicalTypeID::INT16: { + schemaElement.converted_type = ConvertedType::INT_16; + schemaElement.__isset.converted_type = true; + } break; + case LogicalTypeID::INT32: { + schemaElement.converted_type = ConvertedType::INT_32; + schemaElement.__isset.converted_type = true; + } break; + case LogicalTypeID::INT64: { + schemaElement.converted_type = ConvertedType::INT_64; + schemaElement.__isset.converted_type = true; + } break; + case LogicalTypeID::UINT8: { + schemaElement.converted_type = ConvertedType::UINT_8; + schemaElement.__isset.converted_type = true; + } break; + case LogicalTypeID::UINT16: { + schemaElement.converted_type = ConvertedType::UINT_16; + schemaElement.__isset.converted_type = true; + } break; + case LogicalTypeID::UINT32: { + schemaElement.converted_type = ConvertedType::UINT_32; + schemaElement.__isset.converted_type = true; + } break; + case LogicalTypeID::UINT64: { + schemaElement.converted_type = ConvertedType::UINT_64; + 
schemaElement.__isset.converted_type = true; + } break; + case LogicalTypeID::DATE: { + schemaElement.converted_type = ConvertedType::DATE; + schemaElement.__isset.converted_type = true; + } break; + case LogicalTypeID::STRING: { + schemaElement.converted_type = ConvertedType::UTF8; + schemaElement.__isset.converted_type = true; + } break; + case LogicalTypeID::INTERVAL: { + schemaElement.type_length = common::ParquetConstants::PARQUET_INTERVAL_SIZE; + schemaElement.converted_type = ConvertedType::INTERVAL; + schemaElement.__isset.type_length = true; + schemaElement.__isset.converted_type = true; + } break; + case LogicalTypeID::TIMESTAMP_TZ: + case LogicalTypeID::TIMESTAMP_NS: + case LogicalTypeID::TIMESTAMP_SEC: + case LogicalTypeID::TIMESTAMP: { + schemaElement.converted_type = ConvertedType::TIMESTAMP_MICROS; + schemaElement.__isset.converted_type = true; + schemaElement.__isset.logicalType = true; + schemaElement.logicalType.__isset.TIMESTAMP = true; + schemaElement.logicalType.TIMESTAMP.isAdjustedToUTC = + (type.getLogicalTypeID() == LogicalTypeID::TIMESTAMP_TZ); + schemaElement.logicalType.TIMESTAMP.unit.__isset.MICROS = true; + } break; + case LogicalTypeID::TIMESTAMP_MS: { + schemaElement.converted_type = ConvertedType::TIMESTAMP_MILLIS; + schemaElement.__isset.converted_type = true; + schemaElement.__isset.logicalType = true; + schemaElement.logicalType.__isset.TIMESTAMP = true; + schemaElement.logicalType.TIMESTAMP.isAdjustedToUTC = false; + schemaElement.logicalType.TIMESTAMP.unit.__isset.MILLIS = true; + } break; + case LogicalTypeID::SERIAL: { + schemaElement.converted_type = ConvertedType::SERIAL; + schemaElement.__isset.converted_type = true; + } break; + case LogicalTypeID::UUID: { + schemaElement.type_length = common::ParquetConstants::PARQUET_UUID_SIZE; + schemaElement.__isset.type_length = true; + schemaElement.__isset.logicalType = true; + schemaElement.logicalType.__isset.UUID = true; + } break; + default: + break; + } +} + +void ParquetWriter::flush(FactorizedTable& ft) { + if (ft.getNumTuples() == 0) { + return; + } + + PreparedRowGroup preparedRowGroup; + prepareRowGroup(ft, preparedRowGroup); + flushRowGroup(preparedRowGroup); + ft.clear(); +} + +void ParquetWriter::prepareRowGroup(FactorizedTable& ft, PreparedRowGroup& result) { + // set up a new row group for this chunk collection + auto& row_group = result.rowGroup; + row_group.num_rows = ft.getTotalNumFlatTuples(); + row_group.total_byte_size = row_group.num_rows * ft.getTableSchema()->getNumBytesPerTuple(); + row_group.__isset.file_offset = true; + + auto& states = result.states; + // iterate over each of the columns of the chunk collection and write them + KU_ASSERT(ft.getTableSchema()->getNumColumns() == columnWriters.size()); + std::vector> writerStates; + std::unique_ptr unflatDataChunkToRead = + std::make_unique(ft.getTableSchema()->getNumUnFlatColumns()); + std::unique_ptr flatDataChunkToRead = std::make_unique( + ft.getTableSchema()->getNumFlatColumns(), DataChunkState::getSingleValueDataChunkState()); + std::vector vectorsToRead; + vectorsToRead.reserve(columnWriters.size()); + auto numFlatVectors = 0; + for (auto i = 0u; i < columnWriters.size(); i++) { + writerStates.emplace_back(columnWriters[i]->initializeWriteState(row_group)); + auto vector = std::make_unique(types[i].copy(), mm); + vectorsToRead.push_back(vector.get()); + if (ft.getTableSchema()->getColumn(i)->isFlat()) { + flatDataChunkToRead->insert(numFlatVectors, std::move(vector)); + numFlatVectors++; + } else { + 
unflatDataChunkToRead->insert(i - numFlatVectors, std::move(vector)); + } + } + uint64_t numTuplesRead = 0u; + while (numTuplesRead < ft.getNumTuples()) { + readFromFT(ft, vectorsToRead, numTuplesRead); + for (auto i = 0u; i < columnWriters.size(); i++) { + if (columnWriters[i]->hasAnalyze()) { + columnWriters[i]->analyze(*writerStates[i], nullptr, vectorsToRead[i], + getNumTuples(unflatDataChunkToRead.get())); + } + } + } + + for (auto i = 0u; i < columnWriters.size(); i++) { + if (columnWriters[i]->hasAnalyze()) { + columnWriters[i]->finalizeAnalyze(*writerStates[i]); + } + } + + numTuplesRead = 0u; + while (numTuplesRead < ft.getNumTuples()) { + readFromFT(ft, vectorsToRead, numTuplesRead); + for (auto i = 0u; i < columnWriters.size(); i++) { + columnWriters[i]->prepare(*writerStates[i], nullptr, vectorsToRead[i], + getNumTuples(unflatDataChunkToRead.get())); + } + } + + for (auto i = 0u; i < columnWriters.size(); i++) { + columnWriters[i]->beginWrite(*writerStates[i]); + } + + numTuplesRead = 0u; + while (numTuplesRead < ft.getNumTuples()) { + readFromFT(ft, vectorsToRead, numTuplesRead); + for (auto i = 0u; i < columnWriters.size(); i++) { + columnWriters[i]->write(*writerStates[i], vectorsToRead[i], + getNumTuples(unflatDataChunkToRead.get())); + } + } + + for (auto& write_state : writerStates) { + states.push_back(std::move(write_state)); + } +} + +void ParquetWriter::flushRowGroup(PreparedRowGroup& rowGroup) { + std::lock_guard glock(lock); + auto& parquetRowGroup = rowGroup.rowGroup; + auto& states = rowGroup.states; + if (states.empty()) { + throw RuntimeException("Attempting to flush a row group with no rows"); + } + parquetRowGroup.file_offset = fileOffset; + for (auto i = 0u; i < states.size(); i++) { + auto write_state = std::move(states[i]); + columnWriters[i]->finalizeWrite(*write_state); + } + + // Append the row group to the file meta data. + fileMetaData.row_groups.push_back(parquetRowGroup); + fileMetaData.num_rows += parquetRowGroup.num_rows; +} + +void ParquetWriter::readFromFT(FactorizedTable& ft, std::vector vectorsToRead, + uint64_t& numTuplesRead) { + auto numTuplesToRead = + ft.getTableSchema()->getNumUnFlatColumns() != 0 ? + 1 : + std::min(ft.getNumTuples() - numTuplesRead, DEFAULT_VECTOR_CAPACITY); + ft.scan(vectorsToRead, numTuplesRead, numTuplesToRead); + numTuplesRead += numTuplesToRead; +} + +void ParquetWriter::finalize() { + auto startOffset = fileOffset; + fileMetaData.write(protocol.get()); + uint32_t metadataSize = fileOffset - startOffset; + fileInfo->writeFile(reinterpret_cast(&metadataSize), sizeof(metadataSize), + fileOffset); + fileOffset += sizeof(uint32_t); + + // Parquet files also end with the string "PAR1". 
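+    // Readers locate the footer from the last 8 bytes of the file: the 4-byte metadata
+    // length written above followed by this closing magic.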
+ fileInfo->writeFile(reinterpret_cast(ParquetConstants::PARQUET_MAGIC_WORDS), + strlen(ParquetConstants::PARQUET_MAGIC_WORDS), fileOffset); + fileOffset += strlen(ParquetConstants::PARQUET_MAGIC_WORDS); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/writer/parquet/string_column_writer.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/writer/parquet/string_column_writer.cpp new file mode 100644 index 0000000000..d9ac727b2b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/writer/parquet/string_column_writer.cpp @@ -0,0 +1,210 @@ +#include "processor/operator/persistent/writer/parquet/string_column_writer.h" + +#include "common/constants.h" +#include "function/comparison/comparison_functions.h" +#include "function/hash/hash_functions.h" +#include "processor/operator/persistent/reader/parquet/parquet_rle_bp_decoder.h" + +namespace lbug { +namespace processor { + +using namespace lbug::common; +using namespace lbug_parquet::format; + +std::size_t StringHash::operator()(const ku_string_t& k) const { + hash_t result = 0; + function::Hash::operation(k, result); + return result; +} + +bool StringEquality::operator()(const ku_string_t& a, const ku_string_t& b) const { + uint8_t result = 0; + function::Equals::operation(a, b, result, nullptr /* leftVector */, nullptr /* rightVector */); + return result; +} + +void StringStatisticsState::update(const ku_string_t& val) { + if (valuesTooBig) { + return; + } + if (val.len > ParquetConstants::MAX_STRING_STATISTICS_SIZE) { + // we avoid gathering stats when individual string values are too large + // this is because the statistics are copied into the Parquet file meta data in + // uncompressed format ideally we avoid placing several mega or giga-byte long strings + // there we put a threshold of 10KB, if we see strings that exceed this threshold we + // avoid gathering stats + valuesTooBig = true; + min = std::string(); + max = std::string(); + return; + } + if (!hasStats || val.getAsString() < min) { + min = val.getAsString(); + } + if (!hasStats || val.getAsString() > max) { + max = val.getAsString(); + } + hasStats = true; +} + +std::unique_ptr StringColumnWriter::initializeWriteState(RowGroup& rowGroup) { + auto result = std::make_unique(rowGroup, rowGroup.columns.size(), mm); + registerToRowGroup(rowGroup); + return result; +} + +void StringColumnWriter::analyze(ColumnWriterState& writerState, ColumnWriterState* parent, + ValueVector* vector, uint64_t count) { + auto& state = reinterpret_cast(writerState); + uint64_t vcount = + parent ? parent->definitionLevels.size() - state.definitionLevels.size() : count; + uint64_t parentIdx = state.definitionLevels.size(); + uint64_t vectorIdx = 0; + uint32_t newValueIdx = state.dictionary.size(); + uint32_t lastValueIdx = -1; + uint64_t runLen = 0; + uint64_t runCount = 0; + for (auto i = 0u; i < vcount; i++) { + if (parent && !parent->isEmpty.empty() && parent->isEmpty[parentIdx + i]) { + continue; + } + auto pos = getVectorPos(vector, vectorIdx); + if (!vector->isNull(pos)) { + runLen++; + const auto& value = vector->getValue(pos); + // Try to insert into the dictionary. If it's already there, we get back the value + // index. 
+ ku_string_t valueToInsert; + StringVector::copyToRowData(vector, pos, reinterpret_cast(&valueToInsert), + state.overflowBuffer.get()); + auto found = state.dictionary.insert( + string_map_t::value_type(valueToInsert, newValueIdx)); + state.estimatedPlainSize += value.len + ParquetConstants::STRING_LENGTH_SIZE; + if (found.second) { + // String didn't exist yet in the dictionary. + newValueIdx++; + state.estimatedDictPageSize += + value.len + ParquetConstants::MAX_DICTIONARY_KEY_SIZE; + } + // If the value changed, we will encode it in the page. + if (lastValueIdx != found.first->second) { + // we will add the value index size later, when we know the total number of keys + state.estimatedRlePagesSize += RleBpEncoder::getVarintSize(runLen); + runLen = 0; + runCount++; + lastValueIdx = found.first->second; + } + } + vectorIdx++; + } + // Add the costs of keys sizes. We don't know yet how many bytes the keys need as we haven't + // seen all the values. therefore we use an over-estimation of + state.estimatedRlePagesSize += ParquetConstants::MAX_DICTIONARY_KEY_SIZE * runCount; +} + +void StringColumnWriter::finalizeAnalyze(ColumnWriterState& writerState) { + auto& state = reinterpret_cast(writerState); + + // Check if a dictionary will require more space than a plain write, or if the dictionary + // page is going to be too large. + if (state.estimatedDictPageSize > ParquetConstants::MAX_UNCOMPRESSED_DICT_PAGE_SIZE || + state.estimatedRlePagesSize + state.estimatedDictPageSize > state.estimatedPlainSize) { + // Clearing the dictionary signals a plain write. + state.dictionary.clear(); + state.keyBitWidth = 0; + } else { + state.keyBitWidth = RleBpDecoder::ComputeBitWidth(state.dictionary.size()); + } +} + +void StringColumnWriter::writeVector(common::Serializer& serializer, + ColumnWriterStatistics* statsToWrite, ColumnWriterPageState* writerPageState, + ValueVector* vector, uint64_t chunkStart, uint64_t chunkEnd) { + auto pageState = reinterpret_cast(writerPageState); + auto stats = reinterpret_cast(statsToWrite); + + if (pageState->isDictionaryEncoded()) { + // Dictionary based page. + for (auto r = chunkStart; r < chunkEnd; r++) { + auto pos = getVectorPos(vector, r); + if (vector->isNull(pos)) { + continue; + } + auto value_index = pageState->dictionary.at(vector->getValue(pos)); + if (!pageState->writtenValue) { + // Write the bit-width as a one-byte entry. + serializer.write(pageState->bitWidth); + // Now begin writing the actual value. 
+ pageState->encoder.beginWrite(value_index); + pageState->writtenValue = true; + } else { + pageState->encoder.writeValue(serializer, value_index); + } + } + } else { + for (auto r = chunkStart; r < chunkEnd; r++) { + auto pos = getVectorPos(vector, r); + if (vector->isNull(pos)) { + continue; + } + auto& str = vector->getValue(pos); + stats->update(str); + serializer.write(str.len); + serializer.write(str.getData(), str.len); + } + } +} + +void StringColumnWriter::flushPageState(common::Serializer& serializer, + ColumnWriterPageState* writerPageState) { + auto pageState = reinterpret_cast(writerPageState); + if (pageState->bitWidth != 0) { + if (!pageState->writtenValue) { + // all values are null + // just write the bit width + serializer.write(pageState->bitWidth); + return; + } + pageState->encoder.finishWrite(serializer); + } +} + +void StringColumnWriter::flushDictionary(BasicColumnWriterState& writerState, + ColumnWriterStatistics* writerStats) { + auto stats = reinterpret_cast(writerStats); + auto& state = reinterpret_cast(writerState); + if (!state.isDictionaryEncoded()) { + return; + } + // First we need to sort the values in index order. + auto values = std::vector(state.dictionary.size()); + for (const auto& entry : state.dictionary) { + KU_ASSERT(values[entry.second].len == 0); + values[entry.second] = entry.first; + } + // First write the contents of the dictionary page to a temporary buffer. + auto bufferedSerializer = std::make_unique(); + for (auto r = 0u; r < values.size(); r++) { + auto& value = values[r]; + // Update the statistics. + stats->update(value); + // Write this string value to the dictionary. + bufferedSerializer->write(value.len); + bufferedSerializer->write(value.getData(), value.len); + } + // Flush the dictionary page and add it to the to-be-written pages. 
+ writeDictionary(state, std::move(bufferedSerializer), values.size()); +} + +uint64_t StringColumnWriter::getRowSize(ValueVector* vector, uint64_t index, + BasicColumnWriterState& writerState) { + auto& state = reinterpret_cast(writerState); + if (state.isDictionaryEncoded()) { + return (state.keyBitWidth + 7) / 8; + } else { + return vector->getValue(getVectorPos(vector, index)).len; + } +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/writer/parquet/struct_column_writer.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/writer/parquet/struct_column_writer.cpp new file mode 100644 index 0000000000..317233278a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/writer/parquet/struct_column_writer.cpp @@ -0,0 +1,100 @@ +#include "processor/operator/persistent/writer/parquet/struct_column_writer.h" + +#include "common/constants.h" +#include "common/vector/value_vector.h" + +namespace lbug { +namespace processor { + +using namespace lbug_parquet::format; +using namespace lbug::common; + +std::unique_ptr StructColumnWriter::initializeWriteState( + lbug_parquet::format::RowGroup& rowGroup) { + auto result = std::make_unique(rowGroup, rowGroup.columns.size()); + + result->childStates.reserve(childWriters.size()); + for (auto& child_writer : childWriters) { + result->childStates.push_back(child_writer->initializeWriteState(rowGroup)); + } + return result; +} + +bool StructColumnWriter::hasAnalyze() { + for (auto& child_writer : childWriters) { + if (child_writer->hasAnalyze()) { + return true; + } + } + return false; +} + +void StructColumnWriter::analyze(ColumnWriterState& state_p, ColumnWriterState* /*parent*/, + ValueVector* vector, uint64_t count) { + auto& state = reinterpret_cast(state_p); + auto& childVectors = StructVector::getFieldVectors(vector); + for (auto child_idx = 0u; child_idx < childWriters.size(); child_idx++) { + // Need to check again. It might be that just one child needs it but the rest not + if (childWriters[child_idx]->hasAnalyze()) { + childWriters[child_idx]->analyze(*state.childStates[child_idx], &state_p, + childVectors[child_idx].get(), count); + } + } +} + +void StructColumnWriter::finalizeAnalyze(ColumnWriterState& state_p) { + auto& state = reinterpret_cast(state_p); + for (auto child_idx = 0u; child_idx < childWriters.size(); child_idx++) { + // Need to check again. 
It might be that just one child needs it but the rest not + if (childWriters[child_idx]->hasAnalyze()) { + childWriters[child_idx]->finalizeAnalyze(*state.childStates[child_idx]); + } + } +} + +void StructColumnWriter::prepare(ColumnWriterState& state_p, ColumnWriterState* parent, + ValueVector* vector, uint64_t count) { + auto& state = reinterpret_cast(state_p); + if (parent) { + // propagate empty entries from the parent + while (state.isEmpty.size() < parent->isEmpty.size()) { + state.isEmpty.push_back(parent->isEmpty[state.isEmpty.size()]); + } + } + handleRepeatLevels(state_p, parent); + handleDefineLevels(state_p, parent, vector, count, ParquetConstants::PARQUET_DEFINE_VALID, + maxDefine - 1); + auto& child_vectors = StructVector::getFieldVectors(vector); + for (auto child_idx = 0u; child_idx < childWriters.size(); child_idx++) { + childWriters[child_idx]->prepare(*state.childStates[child_idx], &state_p, + child_vectors[child_idx].get(), count); + } +} + +void StructColumnWriter::beginWrite(ColumnWriterState& state_p) { + auto& state = reinterpret_cast(state_p); + for (auto child_idx = 0u; child_idx < childWriters.size(); child_idx++) { + childWriters[child_idx]->beginWrite(*state.childStates[child_idx]); + } +} + +void StructColumnWriter::write(ColumnWriterState& state_p, ValueVector* vector, uint64_t count) { + auto& state = reinterpret_cast(state_p); + auto& child_vectors = StructVector::getFieldVectors(vector); + for (auto child_idx = 0u; child_idx < childWriters.size(); child_idx++) { + childWriters[child_idx]->write(*state.childStates[child_idx], + child_vectors[child_idx].get(), count); + } +} + +void StructColumnWriter::finalizeWrite(ColumnWriterState& state_p) { + auto& state = reinterpret_cast(state_p); + for (auto child_idx = 0u; child_idx < childWriters.size(); child_idx++) { + // we add the null count of the struct to the null count of the children + childWriters[child_idx]->nullCount += nullCount; + childWriters[child_idx]->finalizeWrite(*state.childStates[child_idx]); + } +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/writer/parquet/uuid_column_writer.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/writer/parquet/uuid_column_writer.cpp new file mode 100644 index 0000000000..1287676751 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/persistent/writer/parquet/uuid_column_writer.cpp @@ -0,0 +1,38 @@ +#include "processor/operator/persistent/writer/parquet/uuid_column_writer.h" + +#include "common/constants.h" +#include "common/serializer/serializer.h" +#include "common/types/uuid.h" + +namespace lbug { +namespace processor { + +static void writeParquetUUID(common::ku_uuid_t input, uint8_t* result) { + uint64_t high_bytes = input.value.high ^ (int64_t(1) << 63); + uint64_t low_bytes = input.value.low; + + for (auto i = 0u; i < sizeof(uint64_t); i++) { + auto shift_count = (sizeof(uint64_t) - i - 1) * 8; + result[i] = (high_bytes >> shift_count) & 0xFF; + } + for (auto i = 0u; i < sizeof(uint64_t); i++) { + auto shift_count = (sizeof(uint64_t) - i - 1) * 8; + result[sizeof(uint64_t) + i] = (low_bytes >> shift_count) & 0xFF; + } +} + +void UUIDColumnWriter::writeVector(common::Serializer& bufferedSerializer, + ColumnWriterStatistics* /*state*/, ColumnWriterPageState* /*pageState*/, + common::ValueVector* vector, uint64_t chunkStart, uint64_t chunkEnd) { + uint8_t buffer[common::ParquetConstants::PARQUET_UUID_SIZE]; + for (auto i = chunkStart; 
i < chunkEnd; i++) { + auto pos = getVectorPos(vector, i); + if (!vector->isNull(pos)) { + writeParquetUUID(vector->getValue(pos), buffer); + bufferedSerializer.write(buffer, common::ParquetConstants::PARQUET_UUID_SIZE); + } + } +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/physical_operator.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/physical_operator.cpp new file mode 100644 index 0000000000..b4c165ff79 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/physical_operator.cpp @@ -0,0 +1,256 @@ +#include "processor/operator/physical_operator.h" + +#include "common/exception/interrupt.h" +#include "common/exception/runtime.h" +#include "common/task_system/progress_bar.h" +#include "main/client_context.h" +#include "processor/execution_context.h" + +using namespace lbug::common; + +namespace lbug { +namespace processor { +// LCOV_EXCL_START +std::string PhysicalOperatorUtils::operatorTypeToString(PhysicalOperatorType operatorType) { + switch (operatorType) { + case PhysicalOperatorType::ALTER: + return "ALTER"; + case PhysicalOperatorType::AGGREGATE: + return "AGGREGATE"; + case PhysicalOperatorType::AGGREGATE_FINALIZE: + return "AGGREGATE_FINALIZE"; + case PhysicalOperatorType::AGGREGATE_SCAN: + return "AGGREGATE_SCAN"; + case PhysicalOperatorType::ATTACH_DATABASE: + return "ATTACH_DATABASE"; + case PhysicalOperatorType::BATCH_INSERT: + return "BATCH_INSERT"; + case PhysicalOperatorType::COPY_TO: + return "COPY_TO"; + case PhysicalOperatorType::CREATE_MACRO: + return "CREATE_MACRO"; + case PhysicalOperatorType::CREATE_SEQUENCE: + return "CREATE_SEQUENCE"; + case PhysicalOperatorType::CREATE_TABLE: + return "CREATE_TABLE"; + case PhysicalOperatorType::CREATE_TYPE: + return "CREATE_TYPE"; + case PhysicalOperatorType::CROSS_PRODUCT: + return "CROSS_PRODUCT"; + case PhysicalOperatorType::DETACH_DATABASE: + return "DETACH_DATABASE"; + case PhysicalOperatorType::DELETE_: + return "DELETE"; + case PhysicalOperatorType::DROP: + return "DROP"; + case PhysicalOperatorType::DUMMY_SINK: + return "DUMMY_SINK"; + case PhysicalOperatorType::DUMMY_SIMPLE_SINK: + return "DUMMY_SIMPLE_SINK"; + case PhysicalOperatorType::EMPTY_RESULT: + return "EMPTY_RESULT"; + case PhysicalOperatorType::EXPORT_DATABASE: + return "EXPORT_DATABASE"; + case PhysicalOperatorType::FILTER: + return "FILTER"; + case PhysicalOperatorType::FLATTEN: + return "FLATTEN"; + case PhysicalOperatorType::HASH_JOIN_BUILD: + return "HASH_JOIN_BUILD"; + case PhysicalOperatorType::HASH_JOIN_PROBE: + return "HASH_JOIN_PROBE"; + case PhysicalOperatorType::IMPORT_DATABASE: + return "IMPORT_DATABASE"; + case PhysicalOperatorType::INDEX_LOOKUP: + return "INDEX_LOOKUP"; + case PhysicalOperatorType::INSERT: + return "INSERT"; + case PhysicalOperatorType::INTERSECT_BUILD: + return "INTERSECT_BUILD"; + case PhysicalOperatorType::INTERSECT: + return "INTERSECT"; + case PhysicalOperatorType::INSTALL_EXTENSION: + return "INSTALL_EXTENSION"; + case PhysicalOperatorType::LIMIT: + return "LIMIT"; + case PhysicalOperatorType::LOAD_EXTENSION: + return "LOAD_EXTENSION"; + case PhysicalOperatorType::MERGE: + return "MERGE"; + case PhysicalOperatorType::MULTIPLICITY_REDUCER: + return "MULTIPLICITY_REDUCER"; + case PhysicalOperatorType::PARTITIONER: + return "PARTITIONER"; + case PhysicalOperatorType::PATH_PROPERTY_PROBE: + return "PATH_PROPERTY_PROBE"; + case PhysicalOperatorType::PRIMARY_KEY_SCAN_NODE_TABLE: + return "PRIMARY_KEY_SCAN_NODE_TABLE"; + 
case PhysicalOperatorType::PROJECTION: + return "PROJECTION"; + case PhysicalOperatorType::PROFILE: + return "PROFILE"; + case PhysicalOperatorType::RECURSIVE_EXTEND: + return "RECURSIVE_EXTEND"; + case PhysicalOperatorType::RESULT_COLLECTOR: + return "RESULT_COLLECTOR"; + case PhysicalOperatorType::SCAN_NODE_TABLE: + return "SCAN_NODE_TABLE"; + case PhysicalOperatorType::SCAN_REL_TABLE: + return "SCAN_REL_TABLE"; + case PhysicalOperatorType::SEMI_MASKER: + return "SEMI_MASKER"; + case PhysicalOperatorType::SET_PROPERTY: + return "SET_PROPERTY"; + case PhysicalOperatorType::SKIP: + return "SKIP"; + case PhysicalOperatorType::STANDALONE_CALL: + return "STANDALONE_CALL"; + case PhysicalOperatorType::TABLE_FUNCTION_CALL: + return "TABLE_FUNCTION_CALL"; + case PhysicalOperatorType::TOP_K: + return "TOP_K"; + case PhysicalOperatorType::TOP_K_SCAN: + return "TOP_K_SCAN"; + case PhysicalOperatorType::TRANSACTION: + return "TRANSACTION"; + case PhysicalOperatorType::ORDER_BY: + return "ORDER_BY"; + case PhysicalOperatorType::ORDER_BY_MERGE: + return "ORDER_BY_MERGE"; + case PhysicalOperatorType::ORDER_BY_SCAN: + return "ORDER_BY_SCAN"; + case PhysicalOperatorType::UNION_ALL_SCAN: + return "UNION_ALL_SCAN"; + case PhysicalOperatorType::UNWIND: + return "UNWIND"; + case PhysicalOperatorType::USE_DATABASE: + return "USE_DATABASE"; + case PhysicalOperatorType::UNINSTALL_EXTENSION: + return "UNINSTALL_EXTENSION"; + default: + throw RuntimeException("Unknown physical operator type."); + } +} + +std::string PhysicalOperatorUtils::operatorToString(const PhysicalOperator* physicalOp) { + return operatorTypeToString(physicalOp->getOperatorType()) + "[" + + std::to_string(physicalOp->getOperatorID()) + "]"; +} +// LCOV_EXCL_STOP + +PhysicalOperator::PhysicalOperator(PhysicalOperatorType operatorType, + std::unique_ptr child, uint32_t id, std::unique_ptr printInfo) + : PhysicalOperator{operatorType, id, std::move(printInfo)} { + children.push_back(std::move(child)); +} + +PhysicalOperator::PhysicalOperator(PhysicalOperatorType operatorType, + std::unique_ptr left, std::unique_ptr right, uint32_t id, + std::unique_ptr printInfo) + : PhysicalOperator{operatorType, id, std::move(printInfo)} { + children.push_back(std::move(left)); + children.push_back(std::move(right)); +} + +PhysicalOperator::PhysicalOperator(PhysicalOperatorType operatorType, physical_op_vector_t children, + uint32_t id, std::unique_ptr printInfo) + : PhysicalOperator{operatorType, id, std::move(printInfo)} { + for (auto& child : children) { + this->children.push_back(std::move(child)); + } +} + +std::unique_ptr PhysicalOperator::moveUnaryChild() { + KU_ASSERT(children.size() == 1); + auto result = std::move(children[0]); + children.clear(); + return result; +} + +void PhysicalOperator::initGlobalState(ExecutionContext* context) { + if (!isSource()) { + children[0]->initGlobalState(context); + } + initGlobalStateInternal(context); +} + +void PhysicalOperator::initLocalState(ResultSet* resultSet_, ExecutionContext* context) { + if (!isSource()) { + children[0]->initLocalState(resultSet_, context); + } + resultSet = resultSet_; + registerProfilingMetrics(context->profiler); + initLocalStateInternal(resultSet_, context); +} + +bool PhysicalOperator::getNextTuple(ExecutionContext* context) { + if (context->clientContext->interrupted()) { + throw InterruptException{}; + } +#ifdef __SINGLE_THREADED__ + // In single-threaded mode, the timeout cannot be checked in the main thread + // because the main thread is blocked by the task execution. 
Instead, we + // check the timeout in the processor. The timeout handling may still be + // delayed, but it is better than checking it at the end of each task. + // This is the best we can do now because SIGALRM is not cross-platform. + if (context->clientContext->hasTimeout()) { + if (context->clientContext->getTimeoutRemainingInMS() == 0) { + throw InterruptException{}; + } + } +#endif + metrics->executionTime.start(); + auto result = getNextTuplesInternal(context); + ProgressBar::Get(*context->clientContext) + ->updateProgress(context->queryID, getProgress(context)); + metrics->executionTime.stop(); + return result; +} + +void PhysicalOperator::finalize(ExecutionContext* context) { + if (!isSource()) { + children[0]->finalize(context); + } + finalizeInternal(context); +} + +void PhysicalOperator::registerProfilingMetrics(Profiler* profiler) { + auto executionTime = profiler->registerTimeMetric(getTimeMetricKey()); + auto numOutputTuple = profiler->registerNumericMetric(getNumTupleMetricKey()); + metrics = std::make_unique(*executionTime, *numOutputTuple); +} + +double PhysicalOperator::getExecutionTime(Profiler& profiler) const { + auto executionTime = profiler.sumAllTimeMetricsWithKey(getTimeMetricKey()); + if (!isSource()) { + executionTime -= profiler.sumAllTimeMetricsWithKey(children[0]->getTimeMetricKey()); + } + return executionTime; +} + +uint64_t PhysicalOperator::getNumOutputTuples(Profiler& profiler) const { + return profiler.sumAllNumericMetricsWithKey(getNumTupleMetricKey()); +} + +std::unordered_map PhysicalOperator::getProfilerKeyValAttributes( + Profiler& profiler) const { + std::unordered_map result; + result.insert({"ExecutionTime", std::to_string(getExecutionTime(profiler))}); + result.insert({"NumOutputTuples", std::to_string(getNumOutputTuples(profiler))}); + return result; +} + +std::vector PhysicalOperator::getProfilerAttributes(Profiler& profiler) const { + std::vector result; + for (auto& [key, val] : getProfilerKeyValAttributes(profiler)) { + result.emplace_back(key + ": " + std::move(val)); + } + return result; +} + +double PhysicalOperator::getProgress(ExecutionContext* /*context*/) const { + return 0; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/profile.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/profile.cpp new file mode 100644 index 0000000000..140477929d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/profile.cpp @@ -0,0 +1,19 @@ +#include "processor/operator/profile.h" + +#include "main/plan_printer.h" +#include "processor/execution_context.h" +#include "storage/buffer_manager/memory_manager.h" + +using namespace lbug::common; + +namespace lbug { +namespace processor { + +void Profile::executeInternal(ExecutionContext* context) { + const auto planInString = + main::PlanPrinter::printPlanToOstream(info.physicalPlan, context->profiler).str(); + appendMessage(planInString, storage::MemoryManager::Get(*context->clientContext)); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/projection.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/projection.cpp new file mode 100644 index 0000000000..d9e24c1f56 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/projection.cpp @@ -0,0 +1,53 @@ +#include "processor/operator/projection.h" + +#include "binder/expression/expression_util.h" +#include "processor/execution_context.h" + +using namespace 
lbug::evaluator; + +namespace lbug { +namespace processor { + +std::string ProjectionPrintInfo::toString() const { + std::string result = "Expressions: "; + result += binder::ExpressionUtil::toString(expressions); + return result; +} + +void Projection::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) { + for (auto i = 0u; i < info.evaluators.size(); ++i) { + auto& expressionEvaluator = *info.evaluators[i]; + expressionEvaluator.init(*resultSet, context->clientContext); + auto [dataChunkPos, vectorPos] = info.exprsOutputPos[i]; + auto dataChunk = resultSet->dataChunks[dataChunkPos]; + dataChunk->valueVectors[vectorPos] = expressionEvaluator.resultVector; + } +} + +bool Projection::getNextTuplesInternal(ExecutionContext* context) { + restoreMultiplicity(); + if (!children[0]->getNextTuple(context)) { + return false; + } + saveMultiplicity(); + for (auto& evaluator : info.evaluators) { + evaluator->evaluate(); + } + if (!info.discardedChunkIndices.empty()) { + resultSet->multiplicity *= + resultSet->getNumTuplesWithoutMultiplicity(info.discardedChunkIndices); + } + // The if statement is added to avoid the cost of calculating numTuples when metric is disabled. + if (metrics->numOutputTuple.enabled) [[unlikely]] { + if (info.activeChunkIndices.empty()) { + // In COUNT(*) case we are projecting away everything and only track multiplicity + metrics->numOutputTuple.increase(resultSet->multiplicity); + } else { + metrics->numOutputTuple.increase(resultSet->getNumTuples(info.activeChunkIndices)); + } + } + return true; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/recursive_extend.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/recursive_extend.cpp new file mode 100644 index 0000000000..d7863d71e9 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/recursive_extend.cpp @@ -0,0 +1,160 @@ +#include "processor/operator/recursive_extend.h" + +#include "binder/expression/node_expression.h" +#include "binder/expression/property_expression.h" +#include "common/task_system/progress_bar.h" +#include "function/gds/compute.h" +#include "function/gds/gds_function_collection.h" +#include "function/gds/gds_utils.h" +#include "processor/execution_context.h" +#include "transaction/transaction.h" + +using namespace lbug::common; +using namespace lbug::binder; +using namespace lbug::function; + +namespace lbug { +namespace processor { + +// All recursive join computation have the same vertex compute. This vertex compute writes +// result (could be dst, length or path) from a dst node ID to given source node ID. +class RJVertexCompute : public VertexCompute { +public: + RJVertexCompute(storage::MemoryManager* mm, RecursiveExtendSharedState* sharedState, + std::unique_ptr writer, table_id_set_t nbrTableIDSet) + : mm{mm}, sharedState{sharedState}, writer{std::move(writer)}, + nbrTableIDSet{std::move(nbrTableIDSet)} { + localFT = sharedState->factorizedTablePool.claimLocalTable(mm); + } + ~RJVertexCompute() override { sharedState->factorizedTablePool.returnLocalTable(localFT); } + + bool beginOnTable(table_id_t tableID) override { + // Nbr node table IDs might be different from graph node table IDs + // See comment below in RecursiveExtend::executeInternal. 
+ if (!nbrTableIDSet.contains(tableID)) { + return false; + } + writer->beginWriting(tableID); + return true; + } + + void vertexCompute(offset_t startOffset, offset_t endOffset, table_id_t tableID) override { + for (auto i = startOffset; i < endOffset; ++i) { + if (sharedState->exceedLimit()) { + return; + } + auto nodeID = nodeID_t{i, tableID}; + writer->write(*localFT, nodeID, sharedState->counter.get()); + } + } + + void vertexCompute(table_id_t tableID) override { + writer->write(*localFT, tableID, sharedState->counter.get()); + } + + std::unique_ptr copy() override { + return std::make_unique(mm, sharedState, writer->copy(), nbrTableIDSet); + } + +private: + storage::MemoryManager* mm; + // Shared state storing ftables to materialize output. + RecursiveExtendSharedState* sharedState; + FactorizedTable* localFT; + std::unique_ptr writer; + table_id_set_t nbrTableIDSet; +}; + +static double getRJProgress(offset_t totalNumNodes, offset_t completedNumNodes) { + if (totalNumNodes == 0) { + return 0; + } + return (double)completedNumNodes / totalNumNodes; +} + +static bool requireRelID(const RJAlgorithm& function) { + if (function.getFunctionName() == WeightedSPPathsFunction::name || + function.getFunctionName() == SingleSPPathsFunction::name || + function.getFunctionName() == AllSPPathsFunction::name || + function.getFunctionName() == AllWeightedSPPathsFunction::name || + function.getFunctionName() == VarLenJoinsFunction::name) { + return true; + } + return false; +} + +void RecursiveExtend::executeInternal(ExecutionContext* context) { + auto clientContext = context->clientContext; + auto transaction = transaction::Transaction::Get(*clientContext); + auto progressBar = ProgressBar::Get(*clientContext); + auto graph = sharedState->graph.get(); + auto inputNodeMaskMap = sharedState->getInputNodeMaskMap(); + offset_t totalNumNodes = 0; + if (inputNodeMaskMap != nullptr) { + totalNumNodes = inputNodeMaskMap->getNumMaskedNode(); + } else { + for (auto& tableID : graph->getNodeTableIDs()) { + totalNumNodes += graph->getMaxOffset(transaction, tableID); + } + } + std::vector propertyNames; + if (requireRelID(*function)) { + propertyNames.push_back(InternalKeyword::ID); + } + if (bindData.weightPropertyExpr != nullptr) { + propertyNames.push_back( + bindData.weightPropertyExpr->ptrCast()->getPropertyName()); + } + offset_t completedNumNodes = 0; + auto inputNodeTableIDSet = bindData.nodeInput->constCast().getTableIDsSet(); + for (auto& tableID : graph->getNodeTableIDs()) { + // Input node table IDs could be different from graph node table IDs, e.g. + // Given schema, student-knows->student, teacher-knows->teacher + // MATCH (a:student)-[e*knows]->(b:student) + // the graph node table IDs will include both student and teacher. 
+ if (!inputNodeTableIDSet.contains(tableID)) { + continue; + } + auto calcFunc = [tableID, propertyNames, graph, context, this](offset_t offset) { + auto clientContext = context->clientContext; + auto computeState = function->getComputeState(context, bindData, sharedState.get()); + auto sourceNodeID = nodeID_t{offset, tableID}; + computeState->initSource(sourceNodeID); + GDSUtils::runRecursiveJoinEdgeCompute(context, *computeState, graph, + bindData.extendDirection, bindData.upperBound, sharedState->getOutputNodeMaskMap(), + propertyNames); + auto writer = function->getOutputWriter(context, bindData, *computeState, sourceNodeID, + sharedState.get()); + auto vertexCompute = std::make_unique( + storage::MemoryManager::Get(*clientContext), sharedState.get(), writer->copy(), + bindData.nodeOutput->constCast().getTableIDsSet()); + GDSUtils::runVertexCompute(context, computeState->frontierPair->getState(), graph, + *vertexCompute); + }; + auto maxOffset = graph->getMaxOffset(transaction, tableID); + if (inputNodeMaskMap && inputNodeMaskMap->getOffsetMask(tableID)->isEnabled()) { + for (const auto& offset : + inputNodeMaskMap->getOffsetMask(tableID)->range(0, maxOffset)) { + calcFunc(offset); + progressBar->updateProgress(context->queryID, + getRJProgress(totalNumNodes, completedNumNodes++)); + if (sharedState->exceedLimit()) { + break; + } + } + } else { + for (auto offset = 0u; offset < maxOffset; ++offset) { + calcFunc(offset); + progressBar->updateProgress(context->queryID, + getRJProgress(totalNumNodes, completedNumNodes++)); + if (sharedState->exceedLimit()) { + break; + } + } + } + } + sharedState->factorizedTablePool.mergeLocalTables(); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/result_collector.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/result_collector.cpp new file mode 100644 index 0000000000..071ab981d6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/result_collector.cpp @@ -0,0 +1,94 @@ +#include "processor/operator/result_collector.h" + +#include "binder/expression/expression_util.h" +#include "main/query_result/materialized_query_result.h" +#include "processor/execution_context.h" +#include "storage/buffer_manager/memory_manager.h" + +using namespace lbug::common; +using namespace lbug::storage; + +namespace lbug { +namespace processor { + +std::string ResultCollectorPrintInfo::toString() const { + std::string result = ""; + if (accumulateType == AccumulateType::OPTIONAL_) { + result += "Type: " + AccumulateTypeUtil::toString(accumulateType); + } + result += ",Expressions: "; + result += binder::ExpressionUtil::toString(expressions); + return result; +} + +void ResultCollector::initNecessaryLocalState(ResultSet* resultSet, ExecutionContext* context) { + payloadVectors.reserve(info.payloadPositions.size()); + for (auto& pos : info.payloadPositions) { + auto vec = resultSet->getValueVector(pos).get(); + payloadVectors.push_back(vec); + payloadAndMarkVectors.push_back(vec); + } + if (info.accumulateType == AccumulateType::OPTIONAL_) { + markVector = std::make_unique(LogicalType::BOOL(), + MemoryManager::Get(*context->clientContext)); + markVector->state = DataChunkState::getSingleValueDataChunkState(); + markVector->setValue(0, true); + payloadAndMarkVectors.push_back(markVector.get()); + } +} + +void ResultCollector::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) { + initNecessaryLocalState(resultSet, context); + localTable = 
std::make_unique(MemoryManager::Get(*context->clientContext), + info.tableSchema.copy()); +} + +void ResultCollector::executeInternal(ExecutionContext* context) { + while (children[0]->getNextTuple(context)) { + if (!payloadVectors.empty()) { + for (auto i = 0u; i < resultSet->multiplicity; i++) { + localTable->append(payloadAndMarkVectors); + } + } + } + if (!payloadVectors.empty()) { + metrics->numOutputTuple.increase(localTable->getTotalNumFlatTuples()); + sharedState->mergeLocalTable(*localTable); + } +} + +void ResultCollector::finalizeInternal(ExecutionContext* context) { + switch (info.accumulateType) { + case AccumulateType::OPTIONAL_: { + auto localResultSet = getResultSet(MemoryManager::Get(*context->clientContext)); + initNecessaryLocalState(localResultSet.get(), context); + // We should remove currIdx completely as some of the code still relies on currIdx = -1 to + // check if the state if unFlat or not. This should no longer be necessary. + // TODO(Ziyi): add an interface in factorized table + auto table = sharedState->getTable(); + auto tableSchema = table->getTableSchema(); + for (auto i = 0u; i < payloadVectors.size(); ++i) { + auto columnSchema = tableSchema->getColumn(i); + if (columnSchema->isFlat()) { + payloadVectors[i]->state->setToFlat(); + } + } + if (table->isEmpty()) { + for (auto& vector : payloadVectors) { + vector->setAsSingleNullEntry(); + } + markVector->setValue(0, false); + table->append(payloadAndMarkVectors); + } + } + default: + break; + } +} + +std::unique_ptr ResultCollector::getQueryResult() const { + return std::make_unique(sharedState->getTable()); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/scan/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/scan/CMakeLists.txt new file mode 100644 index 0000000000..0de0db6334 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/scan/CMakeLists.txt @@ -0,0 +1,11 @@ +add_library(lbug_processor_operator_scan + OBJECT + primary_key_scan_node_table.cpp + scan_multi_rel_tables.cpp + scan_node_table.cpp + scan_rel_table.cpp + scan_table.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/scan/primary_key_scan_node_table.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/scan/primary_key_scan_node_table.cpp new file mode 100644 index 0000000000..e2290b8ac6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/scan/primary_key_scan_node_table.cpp @@ -0,0 +1,75 @@ +#include "processor/operator/scan/primary_key_scan_node_table.h" + +#include "binder/expression/expression_util.h" +#include "processor/execution_context.h" + +using namespace lbug::common; +using namespace lbug::storage; + +namespace lbug { +namespace processor { + +std::string PrimaryKeyScanPrintInfo::toString() const { + std::string result = "Key: "; + result += key; + if (!alias.empty()) { + result += ",Alias: "; + result += alias; + } + result += ", Expressions: "; + result += binder::ExpressionUtil::toString(expressions); + return result; +} + +idx_t PrimaryKeyScanSharedState::getTableIdx() { + std::unique_lock lck{mtx}; + if (cursor < numTables) { + return cursor++; + } + return numTables; +} + +void PrimaryKeyScanNodeTable::initLocalStateInternal(ResultSet* resultSet, + ExecutionContext* context) { + ScanTable::initLocalStateInternal(resultSet, context); + auto nodeIDVector = 
resultSet->getValueVector(opInfo.nodeIDPos).get(); + scanState = std::make_unique(nodeIDVector, std::vector{}, + nodeIDVector->state); + indexEvaluator->init(*resultSet, context->clientContext); +} + +bool PrimaryKeyScanNodeTable::getNextTuplesInternal(ExecutionContext* context) { + auto transaction = transaction::Transaction::Get(*context->clientContext); + auto tableIdx = sharedState->getTableIdx(); + if (tableIdx >= tableInfos.size()) { + return false; + } + KU_ASSERT(tableIdx < tableInfos.size()); + auto& tableInfo = tableInfos[tableIdx]; + // Look up index + indexEvaluator->evaluate(); + auto indexVector = indexEvaluator->resultVector.get(); + auto& selVector = indexVector->state->getSelVector(); + KU_ASSERT(selVector.getSelSize() == 1); + auto pos = selVector.getSelectedPositions()[0]; + if (indexVector->isNull(pos)) { + return false; + } + offset_t nodeOffset = 0; + auto& table = tableInfo.table->cast(); + if (!table.lookupPK(transaction, indexVector, pos, nodeOffset)) { + return false; + } + auto nodeID = nodeID_t{nodeOffset, table.getTableID()}; + scanState->nodeIDVector->setValue(pos, nodeID); + // Look up properties + tableInfo.initScanState(*scanState, outVectors, context->clientContext); + table.initScanState(transaction, *scanState, nodeID.tableID, nodeOffset); + auto succeeded = table.lookup(transaction, *scanState); + tableInfo.castColumns(); + metrics->numOutputTuple.incrementByOne(); + return succeeded; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/scan/scan_multi_rel_tables.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/scan/scan_multi_rel_tables.cpp new file mode 100644 index 0000000000..d9c75aa91e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/scan/scan_multi_rel_tables.cpp @@ -0,0 +1,109 @@ +#include "processor/operator/scan/scan_multi_rel_tables.h" + +#include "processor/execution_context.h" +#include "storage/local_storage/local_storage.h" + +using namespace lbug::common; +using namespace lbug::storage; +using namespace lbug::transaction; + +namespace lbug { +namespace processor { + +bool DirectionInfo::needFlip(RelDataDirection relDataDirection) const { + if (extendFromSource && relDataDirection == RelDataDirection::BWD) { + return true; + } + if (!extendFromSource && relDataDirection == RelDataDirection::FWD) { + return true; + } + return false; +} + +bool RelTableCollectionScanner::scan(main::ClientContext* context, RelTableScanState& scanState, + const std::vector& outVectors) { + auto transaction = Transaction::Get(*context); + while (true) { + auto& relInfo = relInfos[currentTableIdx]; + if (relInfo.table->scan(transaction, scanState)) { + auto& selVector = scanState.outState->getSelVector(); + if (directionVector != nullptr) { + for (auto i = 0u; i < selVector.getSelSize(); ++i) { + directionVector->setValue(selVector[i], directionValues[currentTableIdx]); + } + } + if (selVector.getSelSize() > 0) { + relInfo.castColumns(); + return true; + } + } else { + currentTableIdx = nextTableIdx; + if (currentTableIdx == relInfos.size()) { + return false; + } + auto& currentInfo = relInfos[currentTableIdx]; + currentInfo.initScanState(scanState, outVectors, context); + currentInfo.table->initScanState(transaction, scanState, currentTableIdx == 0); + nextTableIdx++; + } + } +} + +void ScanMultiRelTable::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) { + ScanTable::initLocalStateInternal(resultSet, context); + auto clientContext = 
context->clientContext; + boundNodeIDVector = resultSet->getValueVector(opInfo.nodeIDPos).get(); + auto nbrNodeIDVector = outVectors[0]; + scanState = std::make_unique(*MemoryManager::Get(*clientContext), + boundNodeIDVector, outVectors, nbrNodeIDVector->state); + for (auto& [_, scanner] : scanners) { + for (auto& relInfo : scanner.relInfos) { + if (directionInfo.directionPos.isValid()) { + scanner.directionVector = + resultSet->getValueVector(directionInfo.directionPos).get(); + scanner.directionValues.push_back(directionInfo.needFlip(relInfo.direction)); + } + } + } + currentScanner = nullptr; +} + +bool ScanMultiRelTable::getNextTuplesInternal(ExecutionContext* context) { + while (true) { + if (currentScanner != nullptr && + currentScanner->scan(context->clientContext, *scanState, outVectors)) { + metrics->numOutputTuple.increase(scanState->outState->getSelVector().getSelSize()); + return true; + } + if (!children[0]->getNextTuple(context)) { + resetState(); + return false; + } + const auto currentIdx = boundNodeIDVector->state->getSelVector()[0]; + if (boundNodeIDVector->isNull(currentIdx)) { + currentScanner = nullptr; + continue; + } + auto nodeID = boundNodeIDVector->getValue(currentIdx); + initCurrentScanner(nodeID); + } +} + +void ScanMultiRelTable::resetState() { + currentScanner = nullptr; + for (auto& [_, scanner] : scanners) { + scanner.resetState(); + } +} + +void ScanMultiRelTable::initCurrentScanner(const nodeID_t& nodeID) { + if (scanners.contains(nodeID.tableID)) { + currentScanner = &scanners.at(nodeID.tableID); + currentScanner->resetState(); + } else { + currentScanner = nullptr; + } +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/scan/scan_node_table.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/scan/scan_node_table.cpp new file mode 100644 index 0000000000..95db1c06b0 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/scan/scan_node_table.cpp @@ -0,0 +1,142 @@ +#include "processor/operator/scan/scan_node_table.h" + +#include "binder/expression/expression_util.h" +#include "processor/execution_context.h" +#include "storage/local_storage/local_node_table.h" +#include "storage/local_storage/local_storage.h" + +using namespace lbug::common; +using namespace lbug::storage; + +namespace lbug { +namespace processor { + +std::string ScanNodeTablePrintInfo::toString() const { + std::string result = "Tables: "; + for (auto& tableName : tableNames) { + result += tableName; + if (tableName != tableNames.back()) { + result += ", "; + } + } + if (!alias.empty()) { + result += ",Alias: "; + result += alias; + } + if (!properties.empty()) { + result += ",Properties: "; + result += binder::ExpressionUtil::toString(properties); + } + return result; +} + +void ScanNodeTableSharedState::initialize(const transaction::Transaction* transaction, + NodeTable* table, ScanNodeTableProgressSharedState& progressSharedState) { + this->table = table; + this->currentCommittedGroupIdx = 0; + this->currentUnCommittedGroupIdx = 0; + this->numCommittedNodeGroups = table->getNumCommittedNodeGroups(); + if (transaction->isWriteTransaction()) { + if (const auto localTable = + transaction->getLocalStorage()->getLocalTable(this->table->getTableID())) { + auto& localNodeTable = localTable->cast(); + this->numUnCommittedNodeGroups = localNodeTable.getNumNodeGroups(); + } + } + progressSharedState.numGroups += numCommittedNodeGroups; +} + +void ScanNodeTableSharedState::nextMorsel(NodeTableScanState& 
scanState, + ScanNodeTableProgressSharedState& progressSharedState) { + std::unique_lock lck{mtx}; + if (currentCommittedGroupIdx < numCommittedNodeGroups) { + scanState.nodeGroupIdx = currentCommittedGroupIdx++; + progressSharedState.numGroupsScanned++; + scanState.source = TableScanSource::COMMITTED; + return; + } + if (currentUnCommittedGroupIdx < numUnCommittedNodeGroups) { + scanState.nodeGroupIdx = currentUnCommittedGroupIdx++; + scanState.source = TableScanSource::UNCOMMITTED; + return; + } + scanState.source = TableScanSource::NONE; +} + +table_id_map_t ScanNodeTable::getSemiMasks() const { + table_id_map_t result; + KU_ASSERT(tableInfos.size() == sharedStates.size()); + for (auto i = 0u; i < sharedStates.size(); ++i) { + result.insert({tableInfos[i].table->getTableID(), sharedStates[i]->getSemiMask()}); + } + return result; +} + +void ScanNodeTableInfo::initScanState(TableScanState& scanState, + const std::vector& outVectors, main::ClientContext* context) { + auto transaction = transaction::Transaction::Get(*context); + scanState.setToTable(transaction, table, columnIDs, copyVector(columnPredicates)); + initScanStateVectors(scanState, outVectors, MemoryManager::Get(*context)); +} + +void ScanNodeTable::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) { + ScanTable::initLocalStateInternal(resultSet, context); + auto nodeIDVector = resultSet->getValueVector(opInfo.nodeIDPos).get(); + scanState = std::make_unique(nodeIDVector, outVectors, nodeIDVector->state); + currentTableIdx = 0; + initCurrentTable(context); +} + +void ScanNodeTable::initCurrentTable(ExecutionContext* context) { + auto& currentInfo = tableInfos[currentTableIdx]; + currentInfo.initScanState(*scanState, outVectors, context->clientContext); + scanState->semiMask = sharedStates[currentTableIdx]->getSemiMask(); +} + +void ScanNodeTable::initGlobalStateInternal(ExecutionContext* context) { + KU_ASSERT(sharedStates.size() == tableInfos.size()); + for (auto i = 0u; i < tableInfos.size(); i++) { + sharedStates[i]->initialize(transaction::Transaction::Get(*context->clientContext), + tableInfos[i].table->ptrCast(), *progressSharedState); + } +} + +bool ScanNodeTable::getNextTuplesInternal(ExecutionContext* context) { + const auto transaction = transaction::Transaction::Get(*context->clientContext); + while (currentTableIdx < tableInfos.size()) { + auto& info = tableInfos[currentTableIdx]; + while (info.table->scan(transaction, *scanState)) { + const auto outputSize = scanState->outState->getSelVector().getSelSize(); + if (outputSize > 0) { + info.castColumns(); + scanState->outState->setToUnflat(); + metrics->numOutputTuple.increase(outputSize); + return true; + } + } + sharedStates[currentTableIdx]->nextMorsel(*scanState, *progressSharedState); + if (scanState->source == TableScanSource::NONE) { + currentTableIdx++; + if (currentTableIdx < tableInfos.size()) { + initCurrentTable(context); + } + } else { + info.table->initScanState(transaction, *scanState); + } + } + return false; +} + +double ScanNodeTable::getProgress(ExecutionContext* /*context*/) const { + if (currentTableIdx >= tableInfos.size()) { + return 1.0; + } + if (progressSharedState->numGroups == 0) { + return 0.0; + } + return static_cast(progressSharedState->numGroupsScanned) / + progressSharedState->numGroups; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/scan/scan_rel_table.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/scan/scan_rel_table.cpp new 
file mode 100644 index 0000000000..6b933167f2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/scan/scan_rel_table.cpp @@ -0,0 +1,93 @@ +#include "processor/operator/scan/scan_rel_table.h" + +#include "binder/expression/expression_util.h" +#include "processor/execution_context.h" +#include "storage/local_storage/local_rel_table.h" + +using namespace lbug::common; +using namespace lbug::storage; + +namespace lbug { +namespace processor { + +std::string ScanRelTablePrintInfo::toString() const { + std::string result = "Tables: "; + for (auto& tableName : tableNames) { + result += tableName; + if (tableName != tableNames.back()) { + result += ", "; + } + } + if (!alias.empty()) { + result += ",Alias: "; + result += alias; + } + result += ",Direction: ("; + result += boundNode->toString(); + result += ")"; + switch (direction) { + case ExtendDirection::FWD: { + result += "-["; + result += rel->detailsToString(); + result += "]->"; + } break; + case ExtendDirection::BWD: { + result += "<-["; + result += rel->detailsToString(); + result += "]-"; + } break; + case ExtendDirection::BOTH: { + result += "<-["; + result += rel->detailsToString(); + result += "]->"; + } break; + default: + KU_UNREACHABLE; + } + result += "("; + result += nbrNode->toString(); + result += ")"; + if (!properties.empty()) { + result += ",Properties: "; + result += binder::ExpressionUtil::toString(properties); + } + return result; +} + +void ScanRelTableInfo::initScanState(TableScanState& scanState, + const std::vector& outVectors, main::ClientContext* context) { + auto transaction = transaction::Transaction::Get(*context); + scanState.setToTable(transaction, table, columnIDs, copyVector(columnPredicates), direction); + initScanStateVectors(scanState, outVectors, MemoryManager::Get(*context)); +} + +void ScanRelTable::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) { + ScanTable::initLocalStateInternal(resultSet, context); + auto clientContext = context->clientContext; + auto boundNodeIDVector = resultSet->getValueVector(opInfo.nodeIDPos).get(); + auto nbrNodeIDVector = outVectors[0]; + scanState = std::make_unique(*MemoryManager::Get(*clientContext), + boundNodeIDVector, outVectors, nbrNodeIDVector->state); + tableInfo.initScanState(*scanState, outVectors, clientContext); +} + +bool ScanRelTable::getNextTuplesInternal(ExecutionContext* context) { + const auto transaction = transaction::Transaction::Get(*context->clientContext); + while (true) { + while (tableInfo.table->scan(transaction, *scanState)) { + const auto outputSize = scanState->outState->getSelVector().getSelSize(); + if (outputSize > 0) { + // No need to perform column cast because this is single table scan. 
+ metrics->numOutputTuple.increase(outputSize); + return true; + } + } + if (!children[0]->getNextTuple(context)) { + return false; + } + tableInfo.table->initScanState(transaction, *scanState); + } +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/scan/scan_table.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/scan/scan_table.cpp new file mode 100644 index 0000000000..6e8ca2b007 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/scan/scan_table.cpp @@ -0,0 +1,69 @@ +#include "processor/operator/scan/scan_table.h" + +#include "binder/expression/scalar_function_expression.h" + +using namespace lbug::common; +using namespace lbug::storage; + +namespace lbug { +namespace processor { + +void ColumnCaster::init(ValueVector* vectorAfterCasting, storage::MemoryManager* memoryManager) { + this->vectorAfterCasting = vectorAfterCasting; + vectorBeforeCasting = std::make_shared(columnType.copy(), memoryManager); + vectorBeforeCasting->setState(vectorAfterCasting->state); + funcInputVectors = {vectorBeforeCasting}; + funcInputSelVectors = {&vectorBeforeCasting->state->getSelVectorUnsafe()}; +} + +void ColumnCaster::cast() { + auto& funcExpr = castExpr->constCast(); + funcExpr.getFunction().execFunc(funcInputVectors, funcInputSelVectors, *vectorAfterCasting, + &vectorAfterCasting->state->getSelVectorUnsafe(), funcExpr.getBindData()); +} + +void ScanTableInfo::castColumns() { + for (auto& caster : columnCasters) { + if (caster.hasCast()) { + caster.cast(); + } + } +} + +void ScanTableInfo::addColumnInfo(column_id_t columnID, ColumnCaster caster) { + if (caster.hasCast()) { + hasColumnCaster = true; + } + columnIDs.push_back(columnID); + columnCasters.push_back(std::move(caster)); +} + +void ScanTableInfo::initScanStateVectors(TableScanState& scanState, + const std::vector& outVectors, MemoryManager* memoryManager) { + if (!hasColumnCaster) { + // Fast path + scanState.outputVectors = outVectors; + return; + } + scanState.outputVectors.clear(); + for (auto i = 0u; i < columnCasters.size(); ++i) { + auto& caster = columnCasters[i]; + auto vector = outVectors[i]; + if (!caster.hasCast()) { + // No need to cast + scanState.outputVectors.push_back(vector); + } else { + caster.init(vector, memoryManager); + scanState.outputVectors.push_back(caster.getVectorBeforeCasting()); + } + } +} + +void ScanTable::initLocalStateInternal(ResultSet*, ExecutionContext*) { + for (auto& pos : opInfo.outVectorsPos) { + outVectors.push_back(resultSet->getValueVector(pos).get()); + } +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/semi_masker.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/semi_masker.cpp new file mode 100644 index 0000000000..f4e486f811 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/semi_masker.cpp @@ -0,0 +1,221 @@ +#include "processor/operator/semi_masker.h" + +#include "common/constants.h" +#include "common/roaring_mask.h" +#include "processor/execution_context.h" + +using namespace lbug::common; +using namespace lbug::storage; + +namespace lbug { +namespace processor { + +SemiMaskerLocalState* SemiMaskerSharedState::appendLocalState() { + auto localInfo = std::make_unique(); + bool isSingle = masksPerTable.size() == 1; + for (const auto& [tableID, vector] : masksPerTable) { + auto& mask = vector.front(); + auto newOne = SemiMaskUtil::createMask(mask->getMaxOffset()); + if (isSingle) { + 
localInfo->singleTableRef = newOne.get(); + } + localInfo->localMasksPerTable.insert({tableID, std::move(newOne)}); + } + std::unique_lock lock{mtx}; + localInfos.push_back(std::move(localInfo)); + return localInfos[localInfos.size() - 1].get(); +} + +void SemiMaskerSharedState::mergeToGlobal() { + for (const auto& [tableID, globalVector] : masksPerTable) { + if (globalVector.front()->getMaxOffset() > std::numeric_limits::max()) { + std::vector masks; + for (const auto& localInfo : localInfos) { + const auto& mask = localInfo->localMasksPerTable.at(tableID); + auto mask64 = static_cast(mask.get()); + if (!mask64->roaring->isEmpty()) { + masks.push_back(mask64->roaring.get()); + } + } + auto mergedMask = std::make_shared( + roaring::Roaring64Map::fastunion(masks.size(), + const_cast(masks.data()))); + for (const auto& item : globalVector) { + auto mask64 = static_cast(item); + mask64->roaring = mergedMask; + } + } else { + std::vector masks; + for (const auto& localInfo : localInfos) { + const auto& mask = localInfo->localMasksPerTable.at(tableID); + auto mask32 = static_cast(mask.get()); + if (!mask32->roaring->isEmpty()) { + masks.push_back(mask32->roaring.get()); + } + } + auto mergedMask = std::make_shared(roaring::Roaring::fastunion( + masks.size(), const_cast(masks.data()))); + for (const auto& item : globalVector) { + auto mask32 = static_cast(item); + mask32->roaring = mergedMask; + } + } + } +} + +std::string SemiMaskerPrintInfo::toString() const { + std::string result = "Operators: "; + for (const auto& op : operatorNames) { + result += op; + if (&op != &operatorNames.back()) { + result += ", "; + } + } + return result; +} + +void BaseSemiMasker::initLocalStateInternal(ResultSet* resultSet, ExecutionContext*) { + keyVector = resultSet->getValueVector(keyPos).get(); + localState = sharedState->appendLocalState(); +} + +void BaseSemiMasker::finalizeInternal(ExecutionContext* /*context*/) { + sharedState->mergeToGlobal(); +} + +bool SingleTableSemiMasker::getNextTuplesInternal(ExecutionContext* context) { + if (!children[0]->getNextTuple(context)) { + return false; + } + auto& selVector = keyVector->state->getSelVector(); + for (auto i = 0u; i < selVector.getSelSize(); i++) { + auto pos = selVector[i]; + auto nodeID = keyVector->getValue(pos); + localState->maskSingleTable(nodeID.offset); + } + metrics->numOutputTuple.increase(selVector.getSelSize()); + return true; +} + +bool MultiTableSemiMasker::getNextTuplesInternal(ExecutionContext* context) { + if (!children[0]->getNextTuple(context)) { + return false; + } + auto& selVector = keyVector->state->getSelVector(); + for (auto i = 0u; i < selVector.getSelSize(); i++) { + auto pos = selVector[i]; + auto nodeID = keyVector->getValue(pos); + localState->maskMultiTable(nodeID); + } + metrics->numOutputTuple.increase(selVector.getSelSize()); + return true; +} + +void NodeIDsSemiMask::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) { + BaseSemiMasker::initLocalStateInternal(resultSet, context); + srcNodeIDVector = resultSet->getValueVector(srcNodeIDPos).get(); + dstNodeIDVector = resultSet->getValueVector(dstNodeIDPos).get(); +} + +bool NodeIDsSingleTableSemiMasker::getNextTuplesInternal(ExecutionContext* context) { + if (!children[0]->getNextTuple(context)) { + return false; + } + auto& selVector = keyVector->state->getSelVector(); + KU_ASSERT(keyVector->state == srcNodeIDVector->state); + KU_ASSERT(keyVector->state == dstNodeIDVector->state); + auto keyDataVector = ListVector::getDataVector(keyVector); + for (auto 
i = 0u; i < selVector.getSelSize(); ++i) { + auto pos = selVector[i]; + localState->maskSingleTable(srcNodeIDVector->getValue(pos).offset); + localState->maskSingleTable(dstNodeIDVector->getValue(pos).offset); + auto [offset, size] = keyVector->getValue(pos); + for (auto j = 0u; j < size; ++j) { + localState->maskSingleTable(keyDataVector->getValue(offset + j).offset); + } + } + return true; +} + +bool NodeIDsMultipleTableSemiMasker::getNextTuplesInternal(ExecutionContext* context) { + if (!children[0]->getNextTuple(context)) { + return false; + } + auto& selVector = keyVector->state->getSelVector(); + KU_ASSERT(keyVector->state == srcNodeIDVector->state); + KU_ASSERT(keyVector->state == dstNodeIDVector->state); + auto keyDataVector = ListVector::getDataVector(keyVector); + for (auto i = 0u; i < selVector.getSelSize(); ++i) { + auto pos = selVector[i]; + localState->maskMultiTable(srcNodeIDVector->getValue(pos)); + localState->maskMultiTable(dstNodeIDVector->getValue(pos)); + auto [offset, size] = keyVector->getValue(pos); + for (auto j = 0u; j < size; ++j) { + localState->maskMultiTable(keyDataVector->getValue(offset + j)); + } + } + return true; +} + +void PathSemiMasker::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) { + BaseSemiMasker::initLocalStateInternal(resultSet, context); + auto pathRelsFieldIdx = StructType::getFieldIdx(keyVector->dataType, InternalKeyword::RELS); + pathRelsVector = StructVector::getFieldVector(keyVector, pathRelsFieldIdx).get(); + auto pathRelsDataVector = ListVector::getDataVector(pathRelsVector); + auto pathRelsSrcIDFieldIdx = + StructType::getFieldIdx(pathRelsDataVector->dataType, InternalKeyword::SRC); + pathRelsSrcIDDataVector = + StructVector::getFieldVector(pathRelsDataVector, pathRelsSrcIDFieldIdx).get(); + auto pathRelsDstIDFieldIdx = + StructType::getFieldIdx(pathRelsDataVector->dataType, InternalKeyword::DST); + pathRelsDstIDDataVector = + StructVector::getFieldVector(pathRelsDataVector, pathRelsDstIDFieldIdx).get(); +} + +bool PathSingleTableSemiMasker::getNextTuplesInternal(ExecutionContext* context) { + if (!children[0]->getNextTuple(context)) { + return false; + } + auto& selVector = keyVector->state->getSelVector(); + // for both direction, we should deal with direction based on the actual direction of the edge + for (auto i = 0u; i < selVector.getSelSize(); i++) { + auto [offset, size] = pathRelsVector->getValue(selVector[i]); + for (auto j = 0u; j < size; ++j) { + auto pos = offset + j; + if (direction == ExtendDirection::FWD || direction == ExtendDirection::BOTH) { + auto srcNodeID = pathRelsSrcIDDataVector->getValue(pos); + localState->maskSingleTable(srcNodeID.offset); + } + if (direction == ExtendDirection::BWD || direction == ExtendDirection::BOTH) { + auto dstNodeID = pathRelsDstIDDataVector->getValue(pos); + localState->maskSingleTable(dstNodeID.offset); + } + } + } + return true; +} + +bool PathMultipleTableSemiMasker::getNextTuplesInternal(ExecutionContext* context) { + if (!children[0]->getNextTuple(context)) { + return false; + } + auto& selVector = pathRelsVector->state->getSelVector(); + for (auto i = 0u; i < selVector.getSelSize(); i++) { + auto [offset, size] = pathRelsVector->getValue(selVector[i]); + for (auto j = 0u; j < size; ++j) { + auto pos = offset + j; + if (direction == ExtendDirection::FWD || direction == ExtendDirection::BOTH) { + auto srcNodeID = pathRelsSrcIDDataVector->getValue(pos); + localState->maskMultiTable(srcNodeID); + } + if (direction == ExtendDirection::BWD || direction == 
ExtendDirection::BOTH) { + auto dstNodeID = pathRelsDstIDDataVector->getValue(pos); + localState->maskMultiTable(dstNodeID); + } + } + } + return true; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/simple/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/simple/CMakeLists.txt new file mode 100644 index 0000000000..728015d314 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/simple/CMakeLists.txt @@ -0,0 +1,14 @@ +add_library(lbug_processor_operator_simple + OBJECT + attach_database.cpp + detach_database.cpp + install_extension.cpp + load_extension.cpp + import_db.cpp + export_db.cpp + use_database.cpp + uninstall_extension.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/simple/attach_database.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/simple/attach_database.cpp new file mode 100644 index 0000000000..98554dbfaf --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/simple/attach_database.cpp @@ -0,0 +1,64 @@ +#include "processor/operator/simple/attach_database.h" + +#include "common/exception/runtime.h" +#include "common/string_utils.h" +#include "main/attached_database.h" +#include "main/client_context.h" +#include "main/database.h" +#include "main/database_manager.h" +#include "processor/execution_context.h" +#include "storage/storage_extension.h" +#include "storage/storage_manager.h" + +namespace lbug { +namespace processor { + +std::string AttachDatabasePrintInfo::toString() const { + std::string result = "Database: "; + if (!dbName.empty()) { + result += dbName; + } else { + result += dbPath; + } + return result; +} + +static std::string attachMessage() { + return "Attached database successfully."; +} + +void AttachDatabase::executeInternal(ExecutionContext* context) { + auto client = context->clientContext; + auto databaseManager = main::DatabaseManager::Get(*client); + auto memoryManager = storage::MemoryManager::Get(*client); + if (common::StringUtils::getUpper(attachInfo.dbType) == common::ATTACHED_LBUG_DB_TYPE) { + auto db = std::make_unique(attachInfo.dbPath, + attachInfo.dbAlias, common::ATTACHED_LBUG_DB_TYPE, client); + client->setDefaultDatabase(db.get()); + databaseManager->registerAttachedDatabase(std::move(db)); + appendMessage(attachMessage(), memoryManager); + return; + } + for (auto& storageExtension : client->getDatabase()->getStorageExtensions()) { + if (storageExtension->canHandleDB(attachInfo.dbType)) { + auto db = storageExtension->attach(attachInfo.dbAlias, attachInfo.dbPath, client, + attachInfo.options); + databaseManager->registerAttachedDatabase(std::move(db)); + appendMessage(attachMessage(), memoryManager); + return; + } + } + auto errMsg = common::stringFormat("No loaded extension can handle database type: {}.", + attachInfo.dbType); + if (attachInfo.dbType == "duckdb") { + errMsg += "\nDid you forget to load duckdb extension?\nYou can load it by: load " + "extension duckdb;"; + } else if (attachInfo.dbType == "postgres") { + errMsg += "\nDid you forget to load postgres extension?\nYou can load it by: load " + "extension postgres;"; + } + throw common::RuntimeException{errMsg}; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/simple/detach_database.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/simple/detach_database.cpp new file 
mode 100644 index 0000000000..004c04a195 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/simple/detach_database.cpp @@ -0,0 +1,28 @@ +#include "processor/operator/simple/detach_database.h" + +#include "main/client_context.h" +#include "main/database.h" +#include "main/database_manager.h" +#include "processor/execution_context.h" +#include "storage/buffer_manager/memory_manager.h" + +namespace lbug { +namespace processor { + +std::string DetatchDatabasePrintInfo::toString() const { + return "Database: " + name; +} + +void DetachDatabase::executeInternal(ExecutionContext* context) { + auto clientContext = context->clientContext; + auto dbManager = main::DatabaseManager::Get(*clientContext); + if (dbManager->hasAttachedDatabase(dbName) && + dbManager->getAttachedDatabase(dbName)->getDBType() == common::ATTACHED_LBUG_DB_TYPE) { + clientContext->setDefaultDatabase(nullptr /* defaultDatabase */); + } + dbManager->detachDatabase(dbName); + appendMessage("Detached database successfully.", storage::MemoryManager::Get(*clientContext)); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/simple/export_db.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/simple/export_db.cpp new file mode 100644 index 0000000000..a9a24eaee6 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/simple/export_db.cpp @@ -0,0 +1,203 @@ +#include "processor/operator/simple/export_db.h" + +#include + +#include "catalog/catalog.h" +#include "catalog/catalog_entry/index_catalog_entry.h" +#include "catalog/catalog_entry/node_table_catalog_entry.h" +#include "catalog/catalog_entry/rel_group_catalog_entry.h" +#include "catalog/catalog_entry/sequence_catalog_entry.h" +#include "common/copier_config/csv_reader_config.h" +#include "common/file_system/virtual_file_system.h" +#include "common/string_utils.h" +#include "extension/extension_manager.h" +#include "function/scalar_macro_function.h" +#include "main/client_context.h" +#include "processor/execution_context.h" +#include "storage/buffer_manager/memory_manager.h" + +using namespace lbug::common; +using namespace lbug::transaction; +using namespace lbug::catalog; +using namespace lbug::main; + +namespace lbug { +namespace processor { + +using std::stringstream; + +std::string ExportDBPrintInfo::toString() const { + std::string result = "Export To: "; + result += filePath; + if (!options.empty()) { + result += ",Options: "; + auto it = options.begin(); + for (auto i = 0u; it != options.end(); ++it, ++i) { + result += it->first + "=" + it->second.toString(); + if (i < options.size() - 1) { + result += ", "; + } + } + } + return result; +} + +static void writeStringStreamToFile(ClientContext* context, const std::string& ssString, + const std::string& path) { + const auto fileInfo = VirtualFileSystem::GetUnsafe(*context)->openFile(path, + FileOpenFlags(FileFlags::WRITE | FileFlags::CREATE_IF_NOT_EXISTS), context); + fileInfo->writeFile(reinterpret_cast(ssString.c_str()), ssString.size(), + 0 /* offset */); +} + +static std::string getTablePropertyDefinitions(const TableCatalogEntry* entry) { + std::string columns; + auto properties = entry->getProperties(); + auto propertyIdx = 0u; + for (auto& property : properties) { + propertyIdx++; + if (property.getType() == LogicalType::INTERNAL_ID()) { + continue; + } + columns += "`" + property.getName() + "`"; + columns += propertyIdx == properties.size() ? 
"" : ","; + } + return columns; +} + +static void writeCopyNodeStatement(stringstream& ss, const TableCatalogEntry* entry, + const FileScanInfo* info, + const std::unordered_map*>& canUseParallelReader) { + const auto csvConfig = CSVReaderConfig::construct(info->options); + // TODO(Ziyi): We should pass fileName from binder phase to here. + auto fileName = entry->getName() + "." + StringUtils::getLower(info->fileTypeInfo.fileTypeStr); + std::string columns = getTablePropertyDefinitions(entry); + bool useParallelReader = true; + if (canUseParallelReader.contains(fileName)) { + useParallelReader = canUseParallelReader.at(fileName)->load(); + } + auto copyOptionsCypher = CSVOption::toCypher(csvConfig.option.toOptionsMap(useParallelReader)); + if (columns.empty()) { + ss << stringFormat("COPY `{}` FROM \"{}\" {};\n", entry->getName(), fileName, + copyOptionsCypher); + } else { + ss << stringFormat("COPY `{}` ({}) FROM \"{}\" {};\n", entry->getName(), columns, fileName, + copyOptionsCypher); + } +} + +static void writeCopyRelStatement(stringstream& ss, const ClientContext* context, + const TableCatalogEntry* entry, const FileScanInfo* info, + const std::unordered_map*>& canUseParallelReader) { + const auto csvConfig = CSVReaderConfig::construct(info->options); + std::string columns = getTablePropertyDefinitions(entry); + auto transaction = Transaction::Get(*context); + const auto catalog = Catalog::Get(*context); + for (auto& entryInfo : entry->constCast().getRelEntryInfos()) { + auto fromTableName = + catalog->getTableCatalogEntry(transaction, entryInfo.nodePair.srcTableID)->getName(); + auto toTableName = + catalog->getTableCatalogEntry(transaction, entryInfo.nodePair.dstTableID)->getName(); + // TODO(Ziyi): We should pass fileName from binder phase to here. 
+ auto fileName = stringFormat("{}_{}_{}.{}", entry->getName(), fromTableName, toTableName, + StringUtils::getLower(info->fileTypeInfo.fileTypeStr)); + bool useParallelReader = true; + if (canUseParallelReader.contains(fileName)) { + useParallelReader = canUseParallelReader.at(fileName)->load(); + } + auto copyOptionsMap = csvConfig.option.toOptionsMap(useParallelReader); + copyOptionsMap["from"] = stringFormat("'{}'", fromTableName); + copyOptionsMap["to"] = stringFormat("'{}'", toTableName); + auto copyOptions = CSVOption::toCypher(copyOptionsMap); + if (columns.empty()) { + ss << stringFormat("COPY `{}` FROM \"{}\" {};\n", entry->getName(), fileName, + copyOptions); + } else { + ss << stringFormat("COPY `{}` ({}) FROM \"{}\" {};\n", entry->getName(), columns, + fileName, copyOptions); + } + } +} + +static void exportLoadedExtensions(stringstream& ss, const ClientContext* clientContext) { + auto extensionCypher = extension::ExtensionManager::Get(*clientContext)->toCypher(); + if (!extensionCypher.empty()) { + ss << extensionCypher << std::endl; + } +} + +std::string getSchemaCypher(ClientContext* clientContext) { + stringstream ss; + exportLoadedExtensions(ss, clientContext); + const auto catalog = Catalog::Get(*clientContext); + auto transaction = Transaction::Get(*clientContext); + ToCypherInfo toCypherInfo; + for (const auto& nodeTableEntry : + catalog->getNodeTableEntries(transaction, false /* useInternal */)) { + ss << nodeTableEntry->toCypher(toCypherInfo) << std::endl; + } + RelGroupToCypherInfo relTableToCypherInfo{clientContext}; + for (const auto& entry : catalog->getRelGroupEntries(transaction, false /* useInternal */)) { + ss << entry->toCypher(relTableToCypherInfo) << std::endl; + } + RelGroupToCypherInfo relGroupToCypherInfo{clientContext}; + for (const auto sequenceEntry : catalog->getSequenceEntries(transaction)) { + ss << sequenceEntry->toCypher(relGroupToCypherInfo) << std::endl; + } + for (auto macroName : catalog->getMacroNames(transaction)) { + ss << catalog->getScalarMacroFunction(transaction, macroName)->toCypher(macroName) + << std::endl; + } + return ss.str(); +} + +std::string getCopyCypher(const ClientContext* context, const FileScanInfo* boundFileInfo, + const std::unordered_map*>& canUseParallelReader) { + stringstream ss; + auto transaction = Transaction::Get(*context); + const auto catalog = Catalog::Get(*context); + for (const auto& nodeTableEntry : + catalog->getNodeTableEntries(transaction, false /* useInternal */)) { + writeCopyNodeStatement(ss, nodeTableEntry, boundFileInfo, canUseParallelReader); + } + for (const auto& entry : catalog->getRelGroupEntries(transaction, false /* useInternal */)) { + writeCopyRelStatement(ss, context, entry, boundFileInfo, canUseParallelReader); + } + return ss.str(); +} + +std::string getIndexCypher(ClientContext* clientContext, const FileScanInfo& exportFileInfo) { + stringstream ss; + IndexToCypherInfo info{clientContext, exportFileInfo}; + auto transaction = Transaction::Get(*clientContext); + auto catalog = Catalog::Get(*clientContext); + for (auto entry : catalog->getIndexEntries(transaction)) { + auto indexCypher = entry->toCypher(info); + if (!indexCypher.empty()) { + ss << indexCypher << std::endl; + } + } + return ss.str(); +} + +void ExportDB::executeInternal(ExecutionContext* context) { + const auto clientContext = context->clientContext; + // write the schema.cypher file + writeStringStreamToFile(clientContext, getSchemaCypher(clientContext), + boundFileInfo.filePaths[0] + "/" + 
PortDBConstants::SCHEMA_FILE_NAME); + if (schemaOnly) { + return; + } + // write the copy.cypher file + // for every table, we write COPY FROM statement + writeStringStreamToFile(clientContext, + getCopyCypher(clientContext, &boundFileInfo, sharedState->canUseParallelReader), + boundFileInfo.filePaths[0] + "/" + PortDBConstants::COPY_FILE_NAME); + // write the index.cypher file + writeStringStreamToFile(clientContext, getIndexCypher(clientContext, boundFileInfo), + boundFileInfo.filePaths[0] + "/" + PortDBConstants::INDEX_FILE_NAME); + appendMessage("Exported database successfully.", storage::MemoryManager::Get(*clientContext)); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/simple/import_db.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/simple/import_db.cpp new file mode 100644 index 0000000000..bb2bab9673 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/simple/import_db.cpp @@ -0,0 +1,51 @@ +#include "processor/operator/simple/import_db.h" + +#include "common/exception/runtime.h" +#include "main/client_context.h" +#include "processor/execution_context.h" +#include "storage/buffer_manager/memory_manager.h" +#include "transaction/transaction_context.h" + +using namespace lbug::common; +using namespace lbug::transaction; +using namespace lbug::catalog; + +namespace lbug { +namespace processor { + +static void validateQueryResult(main::QueryResult* queryResult) { + auto currentResult = queryResult; + while (currentResult) { + if (!currentResult->isSuccess()) { + throw RuntimeException("Import database failed: " + currentResult->getErrorMessage()); + } + currentResult = currentResult->getNextQueryResult(); + } +} + +void ImportDB::executeInternal(ExecutionContext* context) { + auto clientContext = context->clientContext; + if (query.empty()) { // Export empty database. + appendMessage("Imported database successfully.", + storage::MemoryManager::Get(*clientContext)); + return; + } + // TODO(Guodong): this is special for "Import database". Should refactor after we support + // multiple DDL and COPY statements in a single transaction. + // Currently, we split multiple query statements into single query and execute them one by one, + // each with an auto transaction. 
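+ // Commit the transaction opened for the IMPORT statement itself (if any) so that each
+ // generated schema/copy statement below runs in its own auto transaction.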
+ auto transactionContext = transaction::TransactionContext::Get(*clientContext); + if (transactionContext->hasActiveTransaction()) { + transactionContext->commit(); + } + auto res = clientContext->queryNoLock(query); + validateQueryResult(res.get()); + if (!indexQuery.empty()) { + res = clientContext->queryNoLock(indexQuery); + validateQueryResult(res.get()); + } + appendMessage("Imported database successfully.", storage::MemoryManager::Get(*clientContext)); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/simple/install_extension.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/simple/install_extension.cpp new file mode 100644 index 0000000000..7eb634dec0 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/simple/install_extension.cpp @@ -0,0 +1,44 @@ +#include "processor/operator/simple/install_extension.h" + +#include "common/string_format.h" +#include "processor/execution_context.h" +#include "storage/buffer_manager/memory_manager.h" + +namespace lbug { +namespace processor { + +using namespace lbug::common; +using namespace lbug::extension; + +void InstallExtension::setOutputMessage(bool installed, storage::MemoryManager* memoryManager) { + if (info.forceInstall) { + appendMessage( + stringFormat("Extension: {} updated from the repo: {}.", info.name, info.repo), + memoryManager); + return; + } + if (installed) { + appendMessage( + stringFormat("Extension: {} installed from the repo: {}.", info.name, info.repo), + memoryManager); + } else { + appendMessage( + stringFormat( + "Extension: {} is already installed.\nTo update it, you can run: UPDATE {}.", + info.name, info.name), + memoryManager); + } +} + +void InstallExtension::executeInternal(ExecutionContext* context) { + auto clientContext = context->clientContext; + ExtensionInstaller installer{info, *clientContext}; + bool installResult = installer.install(); + setOutputMessage(installResult, storage::MemoryManager::Get(*clientContext)); + if (info.forceInstall) { + KU_ASSERT(installResult); + } +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/simple/load_extension.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/simple/load_extension.cpp new file mode 100644 index 0000000000..ad9a8b25b7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/simple/load_extension.cpp @@ -0,0 +1,27 @@ +#include "processor/operator/simple/load_extension.h" + +#include "extension/extension_manager.h" +#include "main/client_context.h" +#include "processor/execution_context.h" +#include "storage/buffer_manager/memory_manager.h" + +using namespace lbug::common; + +namespace lbug { +namespace processor { + +using namespace lbug::extension; + +std::string LoadExtensionPrintInfo::toString() const { + return "Load " + extensionName; +} + +void LoadExtension::executeInternal(ExecutionContext* context) { + auto clientContext = context->clientContext; + ExtensionManager::Get(*clientContext)->loadExtension(path, clientContext); + appendMessage(stringFormat("Extension: {} has been loaded.", path), + storage::MemoryManager::Get(*clientContext)); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/simple/uninstall_extension.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/simple/uninstall_extension.cpp new file mode 100644 index 0000000000..14e5ca15b6 --- /dev/null +++ 
b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/simple/uninstall_extension.cpp @@ -0,0 +1,39 @@ +#include "processor/operator/simple/uninstall_extension.h" + +#include "common/exception/runtime.h" +#include "common/file_system/virtual_file_system.h" +#include "common/string_format.h" +#include "extension/extension.h" +#include "main/client_context.h" +#include "processor/execution_context.h" +#include "storage/buffer_manager/memory_manager.h" + +namespace lbug { +namespace processor { + +using namespace lbug::common; +using namespace lbug::extension; + +void UninstallExtension::executeInternal(ExecutionContext* context) { + auto clientContext = context->clientContext; + auto vfs = VirtualFileSystem::GetUnsafe(*clientContext); + auto localLibFilePath = ExtensionUtils::getLocalPathForExtensionLib(clientContext, path); + if (!vfs->fileOrPathExists(localLibFilePath)) { + throw RuntimeException{ + stringFormat("Can not uninstall extension: {} since it has not been installed.", path)}; + } + std::error_code errCode; + if (!std::filesystem::remove_all( + extension::ExtensionUtils::getLocalDirForExtension(clientContext, path), errCode)) { + // LCOV_EXCL_START + throw RuntimeException{ + stringFormat("An error occurred while uninstalling extension: {}. Error: {}.", path, + errCode.message())}; + // LCOV_EXCL_STOP + } + appendMessage(stringFormat("Extension: {} has been uninstalled", path), + storage::MemoryManager::Get(*clientContext)); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/simple/use_database.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/simple/use_database.cpp new file mode 100644 index 0000000000..11f561f683 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/simple/use_database.cpp @@ -0,0 +1,25 @@ +#include "processor/operator/simple/use_database.h" + +#include "main/client_context.h" +#include "main/database_manager.h" +#include "processor/execution_context.h" +#include "storage/buffer_manager/memory_manager.h" + +namespace lbug { +namespace processor { + +void UseDatabase::executeInternal(ExecutionContext* context) { + auto dbManager = main::DatabaseManager::Get(*context->clientContext); + dbManager->setDefaultDatabase(dbName); + appendMessage("Used database successfully.", + storage::MemoryManager::Get(*context->clientContext)); +} + +std::string UseDatabasePrintInfo::toString() const { + std::string result = "Database: "; + result += dbName; + return result; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/sink.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/sink.cpp new file mode 100644 index 0000000000..b8fbbbe5b1 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/sink.cpp @@ -0,0 +1,26 @@ +#include "processor/operator/sink.h" + +#include "main/query_result/materialized_query_result.h" +#include "processor/result/factorized_table_util.h" + +namespace lbug { +namespace processor { + +std::unique_ptr Sink::getResultSet(storage::MemoryManager* memoryManager) { + if (resultSetDescriptor == nullptr) { + // Some pipeline does not need a resultSet, e.g. 
OrderByMerge + return std::unique_ptr(); + } + return std::make_unique(resultSetDescriptor.get(), memoryManager); +} + +std::unique_ptr SimpleSink::getQueryResult() const { + return std::make_unique(messageTable); +} + +void SimpleSink::appendMessage(const std::string& msg, storage::MemoryManager* memoryManager) { + FactorizedTableUtils::appendStringToTable(messageTable.get(), msg, memoryManager); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/skip.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/skip.cpp new file mode 100644 index 0000000000..82ae73e2cc --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/skip.cpp @@ -0,0 +1,61 @@ +#include "processor/operator/skip.h" + +#include "processor/execution_context.h" + +namespace lbug { +namespace processor { + +void Skip::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* /*context*/) { + dataChunkToSelect = resultSet->dataChunks[dataChunkToSelectPos]; +} + +std::string SkipPrintInfo::toString() const { + std::string result = "Skip: "; + result += std::to_string(number); + return result; +} + +bool Skip::getNextTuplesInternal(ExecutionContext* context) { + auto numTupleSkippedBefore = 0u; + auto numTuplesAvailable = 1u; + do { + restoreSelVector(*dataChunkToSelect->state); + // end of execution due to no more input + if (!children[0]->getNextTuple(context)) { + return false; + } + saveSelVector(*dataChunkToSelect->state); + numTuplesAvailable = resultSet->getNumTuples(dataChunksPosInScope); + numTupleSkippedBefore = counter->fetch_add(numTuplesAvailable); + } while (numTupleSkippedBefore + numTuplesAvailable <= skipNumber); + auto numTupleToSkipInCurrentResultSet = (int64_t)(skipNumber - numTupleSkippedBefore); + if (numTupleToSkipInCurrentResultSet <= 0) { + // Other thread has finished skipping. Process everything in current result set. + metrics->numOutputTuple.increase(numTuplesAvailable); + } else { + // If all dataChunks are flat, numTupleAvailable = 1 which means numTupleSkippedBefore = + // skipNumber. So execution is handled in above if statement. 
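+ // Drop the first numTupleToSkipInCurrentResultSet positions by shifting the selection
+ // vector left (materializing the positions first if it was unfiltered) and shrinking its
+ // size; the remaining tuples are passed through to the parent operator unchanged.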
+ KU_ASSERT(!dataChunkToSelect->state->isFlat()); + auto buffer = dataChunkToSelect->state->getSelVectorUnsafe().getMutableBuffer(); + if (dataChunkToSelect->state->getSelVector().isUnfiltered()) { + for (uint64_t i = numTupleToSkipInCurrentResultSet; + i < dataChunkToSelect->state->getSelVector().getSelSize(); ++i) { + buffer[i - numTupleToSkipInCurrentResultSet] = i; + } + dataChunkToSelect->state->getSelVectorUnsafe().setToFiltered(); + } else { + for (uint64_t i = numTupleToSkipInCurrentResultSet; + i < dataChunkToSelect->state->getSelVector().getSelSize(); ++i) { + buffer[i - numTupleToSkipInCurrentResultSet] = buffer[i]; + } + } + dataChunkToSelect->state->getSelVectorUnsafe().setSelSize( + dataChunkToSelect->state->getSelVector().getSelSize() - + numTupleToSkipInCurrentResultSet); + metrics->numOutputTuple.increase(dataChunkToSelect->state->getSelVector().getSelSize()); + } + return true; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/standalone_call.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/standalone_call.cpp new file mode 100644 index 0000000000..841611ec54 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/standalone_call.cpp @@ -0,0 +1,39 @@ +#include "processor/operator/standalone_call.h" + +#include "common/cast.h" +#include "main/client_context.h" +#include "main/db_config.h" +#include "processor/execution_context.h" + +namespace lbug { +namespace processor { + +std::string StandaloneCallPrintInfo::toString() const { + std::string result = "Function: "; + result += functionName; + return result; +} + +bool StandaloneCall::getNextTuplesInternal(ExecutionContext* context) { + if (standaloneCallInfo.hasExecuted) { + return false; + } + standaloneCallInfo.hasExecuted = true; + switch (standaloneCallInfo.option->optionType) { + case main::OptionType::CONFIGURATION: { + const auto configurationOption = + common::ku_dynamic_cast(standaloneCallInfo.option); + configurationOption->setContext(context->clientContext, standaloneCallInfo.optionValue); + break; + } + case main::OptionType::EXTENSION: + context->clientContext->setExtensionOption(standaloneCallInfo.option->name, + standaloneCallInfo.optionValue); + break; + } + metrics->numOutputTuple.incrementByOne(); + return true; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/table_function_call.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/table_function_call.cpp new file mode 100644 index 0000000000..2744e0a170 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/table_function_call.cpp @@ -0,0 +1,57 @@ +#include "processor/operator/table_function_call.h" + +#include "binder/expression/expression_util.h" +#include "processor/execution_context.h" + +using namespace lbug::common; +using namespace lbug::function; + +namespace lbug { +namespace processor { + +std::string TableFunctionCallPrintInfo::toString() const { + std::string result = "Function: "; + result += funcName; + if (!exprs.empty()) { + result += ", Expressions: "; + result += binder::ExpressionUtil::toString(exprs); + } + return result; +} + +void TableFunctionCall::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) { + auto initLocalStateInput = + TableFuncInitLocalStateInput(*sharedState, *info.bindData, context->clientContext); + localState = info.function.initLocalStateFunc(initLocalStateInput); + funcInput = 
std::make_unique(info.bindData.get(), localState.get(), + sharedState.get(), context); + auto initOutputInput = TableFuncInitOutputInput(info.outPosV, *resultSet); + // Technically we should make all table function has its own initOutputFunc. But since most + // table function is using initSingleDataChunkScanOutput. For simplicity, we assume if no + // initOutputFunc provided then we use to initSingleDataChunkScanOutput. + if (info.function.initOutputFunc == nullptr) { + funcOutput = TableFunction::initSingleDataChunkScanOutput(initOutputInput); + } else { + funcOutput = info.function.initOutputFunc(initOutputInput); + } +} + +bool TableFunctionCall::getNextTuplesInternal(ExecutionContext* context) { + funcOutput->resetState(); + funcInput->bindData->evaluateParams(context->clientContext); + auto numTuplesScanned = info.function.tableFunc(*funcInput, *funcOutput); + funcOutput->setOutputSize(numTuplesScanned); + metrics->numOutputTuple.increase(numTuplesScanned); + return numTuplesScanned != 0; +} + +void TableFunctionCall::finalizeInternal(ExecutionContext* context) { + info.function.finalizeFunc(context, sharedState.get()); +} + +double TableFunctionCall::getProgress(ExecutionContext* /*context*/) const { + return info.function.progressFunc(sharedState.get()); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/table_scan/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/table_scan/CMakeLists.txt new file mode 100644 index 0000000000..adde3e3577 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/table_scan/CMakeLists.txt @@ -0,0 +1,8 @@ +add_library(lbug_processor_operator_table_scan + OBJECT + ftable_scan_function.cpp + union_all_scan.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/table_scan/ftable_scan_function.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/table_scan/ftable_scan_function.cpp new file mode 100644 index 0000000000..11780d56c4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/table_scan/ftable_scan_function.cpp @@ -0,0 +1,78 @@ +#include "processor/operator/table_scan/ftable_scan_function.h" + +#include "function/table/simple_table_function.h" +#include "processor/result/factorized_table.h" + +using namespace lbug::common; +using namespace lbug::function; + +namespace lbug { +namespace processor { + +struct FTableScanSharedState final : public SimpleTableFuncSharedState { + std::shared_ptr table; + uint64_t morselSize; + offset_t nextTupleIdx; + + FTableScanSharedState(std::shared_ptr table, uint64_t morselSize) + : SimpleTableFuncSharedState{table->getNumTuples()}, table{std::move(table)}, + morselSize{morselSize}, nextTupleIdx{0} {} + + TableFuncMorsel getMorsel() override { + std::unique_lock lck{mtx}; + auto numTuplesToScan = std::min(morselSize, table->getNumTuples() - nextTupleIdx); + auto morsel = TableFuncMorsel(nextTupleIdx, nextTupleIdx + numTuplesToScan); + nextTupleIdx += numTuplesToScan; + return morsel; + } +}; + +// FTableScan has an exceptional output where vectors can be in different dataChunks. So we give +// a dummy dataChunk during initialization and never use it. 
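+// initFTableScanOutput (below) collects the real output vectors from the ResultSet, and
+// tableFunc lets FactorizedTable::scan write into them directly.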
+struct FTableScanTableFuncOutput : TableFuncOutput { + std::vector vectors; + + explicit FTableScanTableFuncOutput(std::vector vectors) + : TableFuncOutput(common::DataChunk{} /* dummy DataChunk */), vectors{std::move(vectors)} {} +}; + +static std::unique_ptr initFTableScanOutput( + const TableFuncInitOutputInput& input) { + std::vector vectors; + for (auto i = 0u; i < input.outColumnPositions.size(); ++i) { + vectors.push_back(input.resultSet.getValueVector(input.outColumnPositions[i]).get()); + } + return std::make_unique(std::move(vectors)); +} + +static offset_t tableFunc(const TableFuncInput& input, TableFuncOutput& output) { + auto sharedState = ku_dynamic_cast(input.sharedState); + auto bindData = ku_dynamic_cast(input.bindData); + auto morsel = sharedState->getMorsel(); + if (morsel.endOffset <= morsel.startOffset) { + return 0; + } + auto numTuples = morsel.endOffset - morsel.startOffset; + auto& output_ = ku_dynamic_cast(output); + sharedState->table->scan(output_.vectors, morsel.startOffset, numTuples, + bindData->columnIndices); + return numTuples; +} + +static std::unique_ptr initSharedState( + const TableFuncInitSharedStateInput& input) { + auto bindData = ku_dynamic_cast(input.bindData); + return std::make_unique(bindData->table, bindData->morselSize); +} + +std::unique_ptr FTableScan::getFunction() { + auto function = std::make_unique(name, std::vector{}); + function->tableFunc = tableFunc; + function->initSharedStateFunc = initSharedState; + function->initLocalStateFunc = TableFunction::initEmptyLocalState; + function->initOutputFunc = initFTableScanOutput; + return function; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/table_scan/union_all_scan.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/table_scan/union_all_scan.cpp new file mode 100644 index 0000000000..292df1bf34 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/table_scan/union_all_scan.cpp @@ -0,0 +1,63 @@ +#include "processor/operator/table_scan/union_all_scan.h" + +#include + +#include "binder/expression/expression_util.h" +#include "common/metric.h" + +using namespace lbug::common; + +namespace lbug { +namespace processor { + +std::string UnionAllScanPrintInfo::toString() const { + std::string result = "Expressions: "; + result += binder::ExpressionUtil::toString(expressions); + return result; +} + +std::unique_ptr UnionAllScanSharedState::getMorsel() { + std::unique_lock lck{mtx}; + if (tableIdx == tables.size()) { // No more to scan. + return std::make_unique(nullptr /* table */, 0, 0); + } + auto morsel = getMorselNoLock(tables[tableIdx].get()); + // Fetch next table if current table has nothing to scan. + while (morsel->numTuples == 0) { + tableIdx++; + nextTupleIdxToScan = 0; + if (tableIdx == tables.size()) { // No more to scan. 
+ return std::make_unique(nullptr /* table */, 0, 0); + } + morsel = getMorselNoLock(tables[tableIdx].get()); + } + return morsel; +} + +std::unique_ptr UnionAllScanSharedState::getMorselNoLock( + FactorizedTable* table) { + auto numTuplesToScan = std::min(maxMorselSize, table->getNumTuples() - nextTupleIdxToScan); + auto morsel = std::make_unique(table, nextTupleIdxToScan, numTuplesToScan); + nextTupleIdxToScan += numTuplesToScan; + return morsel; +} + +void UnionAllScan::initLocalStateInternal(ResultSet* /*resultSet_*/, + ExecutionContext* /*context*/) { + for (auto& dataPos : info.outputPositions) { + vectors.push_back(resultSet->getValueVector(dataPos).get()); + } +} + +bool UnionAllScan::getNextTuplesInternal(ExecutionContext* /*context*/) { + auto morsel = sharedState->getMorsel(); + if (morsel->numTuples == 0) { + return false; + } + morsel->table->scan(vectors, morsel->startTupleIdx, morsel->numTuples, info.columnIndices); + metrics->numOutputTuple.increase(morsel->numTuples); + return true; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/transaction.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/transaction.cpp new file mode 100644 index 0000000000..991bf38f41 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/transaction.cpp @@ -0,0 +1,82 @@ +#include "processor/operator/transaction.h" + +#include "common/exception/transaction_manager.h" +#include "processor/execution_context.h" +#include "transaction/transaction_context.h" +#include "transaction/transaction_manager.h" + +using namespace lbug::common; +using namespace lbug::transaction; + +namespace lbug { +namespace processor { + +std::string TransactionPrintInfo::toString() const { + std::string result = "Action: "; + result += TransactionActionUtils::toString(action); + return result; +} + +bool Transaction::getNextTuplesInternal(ExecutionContext* context) { + if (hasExecuted) { + return false; + } + hasExecuted = true; + auto clientContext = context->clientContext; + auto transactionContext = TransactionContext::Get(*clientContext); + validateActiveTransaction(*transactionContext); + switch (transactionAction) { + case TransactionAction::BEGIN_READ: { + transactionContext->beginReadTransaction(); + } break; + case TransactionAction::BEGIN_WRITE: { + transactionContext->beginWriteTransaction(); + } break; + case TransactionAction::COMMIT: { + transactionContext->commit(); + } break; + case TransactionAction::ROLLBACK: { + transactionContext->rollback(); + } break; + case TransactionAction::CHECKPOINT: { + TransactionManager::Get(*clientContext)->checkpoint(*clientContext); + } break; + default: { + KU_UNREACHABLE; + } + } + return true; +} + +void Transaction::validateActiveTransaction(const TransactionContext& context) const { + switch (transactionAction) { + case TransactionAction::BEGIN_READ: + case TransactionAction::BEGIN_WRITE: { + if (context.hasActiveTransaction()) { + throw TransactionManagerException( + "Connection already has an active transaction. Cannot start a transaction within " + "another one. 
For concurrent multiple transactions, please open other " + "connections."); + } + } break; + case TransactionAction::COMMIT: + case TransactionAction::ROLLBACK: { + if (!context.hasActiveTransaction()) { + throw TransactionManagerException(stringFormat("No active transaction for {}.", + TransactionActionUtils::toString(transactionAction))); + } + } break; + case TransactionAction::CHECKPOINT: { + if (context.hasActiveTransaction()) { + throw TransactionManagerException(stringFormat("Found active transaction for {}.", + TransactionActionUtils::toString(transactionAction))); + } + } break; + default: { + KU_UNREACHABLE; + } + } +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/unwind.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/unwind.cpp new file mode 100644 index 0000000000..d136bb3713 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/operator/unwind.cpp @@ -0,0 +1,79 @@ +#include "processor/operator/unwind.h" + +#include "binder/expression/expression.h" // IWYU pragma: keep +#include "common/system_config.h" +#include "processor/execution_context.h" + +using namespace lbug::common; + +namespace lbug { +namespace processor { + +std::string UnwindPrintInfo::toString() const { + std::string result = "Unwind: "; + result += inExpression->toString(); + result += ", As: "; + result += outExpression->toString(); + return result; +} + +void Unwind::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) { + expressionEvaluator->init(*resultSet, context->clientContext); + outValueVector = resultSet->getValueVector(outDataPos); + if (idPos.isValid()) { + idVector = resultSet->getValueVector(idPos).get(); + } +} + +bool Unwind::hasMoreToRead() const { + return listEntry.offset != INVALID_OFFSET && listEntry.size > startIndex; +} + +void Unwind::copyTuplesToOutVector(uint64_t startPos, uint64_t endPos) const { + auto listDataVector = ListVector::getDataVector(expressionEvaluator->resultVector.get()); + auto listPos = listEntry.offset + startPos; + for (auto i = 0u; i < endPos - startPos; i++) { + outValueVector->copyFromVectorData(i, listDataVector, listPos++); + } + if (idVector != nullptr) { + KU_ASSERT(listDataVector->dataType.getLogicalTypeID() == common::LogicalTypeID::NODE); + auto idFieldVector = StructVector::getFieldVector(listDataVector, 0); + listPos = listEntry.offset + startPos; + for (auto i = 0u; i < endPos - startPos; i++) { + idVector->copyFromVectorData(i, idFieldVector.get(), listPos++); + } + } +} + +bool Unwind::getNextTuplesInternal(ExecutionContext* context) { + if (hasMoreToRead()) { + auto totalElementsCopy = + std::min(DEFAULT_VECTOR_CAPACITY, (uint64_t)listEntry.size - startIndex); + copyTuplesToOutVector(startIndex, (totalElementsCopy + startIndex)); + startIndex += totalElementsCopy; + outValueVector->state->initOriginalAndSelectedSize(totalElementsCopy); + return true; + } + do { + if (!children[0]->getNextTuple(context)) { + return false; + } + expressionEvaluator->evaluate(); + auto pos = expressionEvaluator->resultVector->state->getSelVector()[0]; + if (expressionEvaluator->resultVector->isNull(pos)) { + outValueVector->state->getSelVectorUnsafe().setSelSize(0); + continue; + } + listEntry = expressionEvaluator->resultVector->getValue(pos); + startIndex = 0; + auto totalElementsCopy = std::min(DEFAULT_VECTOR_CAPACITY, (uint64_t)listEntry.size); + copyTuplesToOutVector(0, totalElementsCopy); + startIndex += totalElementsCopy; + 
outValueVector->state->initOriginalAndSelectedSize(startIndex); + } while (outValueVector->state->getSelVector().getSelSize() == 0); + metrics->numOutputTuple.increase(outValueVector->state->getSelVector().getSelSize()); + return true; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/processor.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/processor.cpp new file mode 100644 index 0000000000..a422433152 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/processor.cpp @@ -0,0 +1,82 @@ +#include "processor/processor.h" + +#include "common/task_system/progress_bar.h" +#include "main/query_result.h" +#include "processor/operator/sink.h" +#include "processor/physical_plan.h" +#include "processor/processor_task.h" + +using namespace lbug::common; +using namespace lbug::storage; + +namespace lbug { +namespace processor { + +#if defined(__APPLE__) +QueryProcessor::QueryProcessor(uint64_t numThreads, uint32_t threadQos) { + taskScheduler = std::make_unique(numThreads, threadQos); +} +#else +QueryProcessor::QueryProcessor(uint64_t numThreads) { + taskScheduler = std::make_unique(numThreads); +} +#endif + +std::unique_ptr QueryProcessor::execute(PhysicalPlan* physicalPlan, + ExecutionContext* context) { + auto lastOperator = physicalPlan->lastOperator.get(); + // The root pipeline(task) consists of operators and its prevOperator only, because we + // expect to have linear plans. For binary operators, e.g., HashJoin, we keep probe and its + // prevOperator in the same pipeline, and decompose build and its prevOperator into another + // one. + auto sink = lastOperator->ptrCast(); + auto task = std::make_shared(sink, context); + for (auto i = (int64_t)sink->getNumChildren() - 1; i >= 0; --i) { + decomposePlanIntoTask(sink->getChild(i), task.get(), context); + } + initTask(task.get()); + auto progressBar = ProgressBar::Get(*context->clientContext); + progressBar->startProgress(context->queryID); + taskScheduler->scheduleTaskAndWaitOrError(task, context); + progressBar->endProgress(context->queryID); + return sink->getQueryResult(); +} + +void QueryProcessor::decomposePlanIntoTask(PhysicalOperator* op, Task* task, + ExecutionContext* context) { + if (op->isSource()) { + ProgressBar::Get(*context->clientContext)->addPipeline(); + } + if (op->isSink()) { + auto childTask = std::make_unique(ku_dynamic_cast(op), context); + for (auto i = (int64_t)op->getNumChildren() - 1; i >= 0; --i) { + decomposePlanIntoTask(op->getChild(i), childTask.get(), context); + } + task->addChildTask(std::move(childTask)); + } else { + // Schedule the right most side (e.g., build side of the hash join) first. 
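+ // Recursing over the children in reverse order visits the right-most child first; a child
+ // that is itself a sink (e.g. a hash join build) is split into a child task above, while
+ // non-sink children stay in the current pipeline.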
+ for (auto i = (int64_t)op->getNumChildren() - 1; i >= 0; --i) { + decomposePlanIntoTask(op->getChild(i), task, context); + } + } +} + +void QueryProcessor::initTask(Task* task) { + auto processorTask = ku_dynamic_cast(task); + PhysicalOperator* op = processorTask->sink; + while (!op->isSource()) { + if (!op->isParallel()) { + task->setSingleThreadedTask(); + } + op = op->getChild(0); + } + if (!op->isParallel()) { + task->setSingleThreadedTask(); + } + for (auto& child : task->children) { + initTask(child.get()); + } +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/processor_task.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/processor_task.cpp new file mode 100644 index 0000000000..61ec9797ca --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/processor_task.cpp @@ -0,0 +1,44 @@ +#include "processor/processor_task.h" + +#include "common/task_system/progress_bar.h" +#include "main/client_context.h" +#include "main/settings.h" +#include "processor/execution_context.h" +#include "storage/buffer_manager/memory_manager.h" + +using namespace lbug::common; + +namespace lbug { +namespace processor { + +ProcessorTask::ProcessorTask(Sink* sink, ExecutionContext* executionContext) + : Task{executionContext->clientContext->getCurrentSetting(main::ThreadsSetting::name) + .getValue()}, + sharedStateInitialized{false}, sink{sink}, executionContext{executionContext} {} + +void ProcessorTask::run() { + // We need the lock when cloning because multiple threads can be accessing to clone, + // which is not thread safe + lock_t lck{taskMtx}; + if (!sharedStateInitialized) { + sink->initGlobalState(executionContext); + sharedStateInitialized = true; + } + auto taskRoot = sink->copy(); + lck.unlock(); + auto resultSet = + sink->getResultSet(storage::MemoryManager::Get(*executionContext->clientContext)); + taskRoot->ptrCast()->execute(resultSet.get(), executionContext); +} + +void ProcessorTask::finalize() { + ProgressBar::Get(*executionContext->clientContext)->finishPipeline(executionContext->queryID); + sink->finalize(executionContext); +} + +bool ProcessorTask::terminate() { + return sink->terminate(); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/result/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/result/CMakeLists.txt new file mode 100644 index 0000000000..e57cca0234 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/result/CMakeLists.txt @@ -0,0 +1,16 @@ +add_library(lbug_processor_result + OBJECT + base_hash_table.cpp + factorized_table.cpp + factorized_table_pool.cpp + factorized_table_schema.cpp + factorized_table_util.cpp + flat_tuple.cpp + pattern_creation_info_table.cpp + result_set.cpp + result_set_descriptor.cpp + ) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/result/base_hash_table.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/result/base_hash_table.cpp new file mode 100644 index 0000000000..efe5f1f95b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/result/base_hash_table.cpp @@ -0,0 +1,305 @@ +#include "processor/result/base_hash_table.h" + +#include "math.h" + +#include "common/constants.h" +#include "common/null_buffer.h" +#include "common/type_utils.h" +#include "common/types/ku_list.h" +#include "common/types/types.h" +#include "common/utils.h" +#include "function/comparison/comparison_functions.h" +#include 
"function/hash/vector_hash_functions.h" + +using namespace lbug::common; +using namespace lbug::function; + +namespace lbug { +namespace processor { + +BaseHashTable::BaseHashTable(storage::MemoryManager& memoryManager, logical_type_vec_t keyTypes) + : maxNumHashSlots{0}, numSlotsPerBlockLog2{0}, slotIdxInBlockMask{0}, + memoryManager{&memoryManager}, keyTypes{std::move(keyTypes)} { + initCompareFuncs(); + initTmpHashVector(); +} + +void BaseHashTable::setMaxNumHashSlots(uint64_t newSize) { + maxNumHashSlots = newSize; +} + +void BaseHashTable::computeVectorHashes(std::span keyVectors) { + hashVector->state = keyVectors[0]->state; + VectorHashFunction::computeHash(*keyVectors[0], keyVectors[0]->state->getSelVector(), + *hashVector.get(), hashVector->state->getSelVector()); + + for (size_t startVecIdx = 1; startVecIdx < keyVectors.size(); startVecIdx++) { + auto keyVector = keyVectors[startVecIdx]; + tmpHashResultVector->state = keyVector->state; + tmpHashCombineResultVector->state = keyVector->state; + VectorHashFunction::computeHash(*keyVector, keyVector->state->getSelVector(), + *tmpHashResultVector, tmpHashResultVector->state->getSelVector()); + tmpHashCombineResultVector->state = + !tmpHashResultVector->state->isFlat() ? tmpHashResultVector->state : hashVector->state; + VectorHashFunction::combineHash(*hashVector, hashVector->state->getSelVector(), + *tmpHashResultVector, tmpHashResultVector->state->getSelVector(), + *tmpHashCombineResultVector, tmpHashCombineResultVector->state->getSelVector()); + hashVector.swap(tmpHashCombineResultVector); + } +} + +template +static bool compareEntry(const common::ValueVector* vector, uint32_t vectorPos, + const uint8_t* entry) { + uint8_t result = 0; + auto key = vector->getData() + vectorPos * vector->getNumBytesPerValue(); + function::Equals::operation(*(T*)key, *(T*)entry, result, nullptr /* leftVector */, + nullptr /* rightVector */); + return result != 0; +} + +template +static bool factorizedTableCompareEntry(const uint8_t* entry1, const uint8_t* entry2, + const LogicalType&) { + return function::Equals::operation(*(T*)entry1, *(T*)entry2); +} + +static ft_compare_function_t getFactorizedTableCompareEntryFunc(const LogicalType& type); + +template<> +bool factorizedTableCompareEntry(const uint8_t* entry1, const uint8_t* entry2, + const LogicalType& type) { + const auto* list1 = reinterpret_cast(entry1); + const auto* list2 = reinterpret_cast(entry2); + if (list1->size != list2->size) { + return false; + } + const auto& childType = ListType::getChildType(type); + const auto childSize = LogicalTypeUtils::getRowLayoutSize(childType); + const auto nullPtr1 = reinterpret_cast(list1->overflowPtr); + const auto nullPtr2 = reinterpret_cast(list2->overflowPtr); + const auto dataPtr1 = nullPtr1 + NullBuffer::getNumBytesForNullValues(list1->size); + const auto dataPtr2 = nullPtr2 + NullBuffer::getNumBytesForNullValues(list2->size); + auto compareFunc = getFactorizedTableCompareEntryFunc(childType); + for (size_t index = 0; index < list1->size; index++) { + const bool child1IsNull = NullBuffer::isNull(nullPtr1, index); + const bool child2IsNull = NullBuffer::isNull(nullPtr2, index); + if (child1IsNull != child2IsNull) { + return false; + } + if (!child1IsNull && !child2IsNull && + !compareFunc(dataPtr1 + index * childSize, dataPtr2 + index * childSize, childType)) { + return false; + } + } + return true; +} + +const uint8_t* getFTStructFirstField(const uint8_t* structEntry, uint64_t numFields) { + return structEntry + 
common::NullBuffer::getNumBytesForNullValues(numFields); +} + +const uint8_t* getFTStructNodeID(const uint8_t* structEntry, const LogicalType& type) { + return getFTStructFirstField(structEntry, common::StructType::getNumFields(type)); +} + +const uint8_t* getFTStructRelID(const uint8_t* structEntry, const LogicalType& type) { + return getFTStructFirstField(structEntry, common::StructType::getNumFields(type)) + + sizeof(common::internalID_t) * 2 + sizeof(common::ku_string_t); +} + +static bool compareFTNodeEntry(const uint8_t* entry1, const uint8_t* entry2, + const LogicalType& type) { + return factorizedTableCompareEntry(getFTStructNodeID(entry1, type), + getFTStructNodeID(entry2, type), + type /*not actually used; should really be the type of the field*/); +} + +static bool compareFTRelEntry(const uint8_t* entry1, const uint8_t* entry2, + const LogicalType& type) { + return factorizedTableCompareEntry(getFTStructRelID(entry1, type), + getFTStructRelID(entry2, type), + type /*not actually used; should really be the type of the field*/); +} + +template<> +bool factorizedTableCompareEntry(const uint8_t* entry1, const uint8_t* entry2, + const LogicalType& type) { + const auto numFields = StructType::getNumFields(type); + auto entryToCompare1 = getFTStructFirstField(entry1, numFields); + auto entryToCompare2 = getFTStructFirstField(entry2, numFields); + for (auto i = 0u; i < numFields; i++) { + const auto isNullInEntry1 = NullBuffer::isNull(entry1, i); + const auto isNullInEntry2 = NullBuffer::isNull(entry2, i); + if (isNullInEntry1 != isNullInEntry2) { + return false; + } + const auto& fieldType = StructType::getFieldType(type, i); + ft_compare_function_t compareFunc = getFactorizedTableCompareEntryFunc(fieldType); + // If both not null, compare the value. 
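// --- Illustrative sketch (not part of this patch) --------------------------
// The row layout the field-wise comparison walks: a struct entry begins with
// a null header covering its fields, followed by the fields' fixed-size
// row-layout bytes packed back to back. The helpers below are hypothetical
// and assume one bit per field in the header; the library's NullBuffer may
// pack nulls differently.
#include <cstddef>
#include <numeric>
#include <vector>

inline std::size_t nullHeaderBytes(std::size_t numFields) {
    return (numFields + 7) / 8; // assumed bit-per-field packing, rounded up to bytes
}

inline std::size_t fieldOffsetInEntry(const std::vector<std::size_t>& fieldSizes,
    std::size_t fieldIdx) {
    // Offset of field `fieldIdx`, measured from the start of the struct entry.
    return nullHeaderBytes(fieldSizes.size()) +
           std::accumulate(fieldSizes.begin(), fieldSizes.begin() + fieldIdx, std::size_t{0});
}
// ----------------------------------------------------------------------------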
+ if (!isNullInEntry1 && !compareFunc(entryToCompare1, entryToCompare2, fieldType)) { + return false; + } + const auto fieldSize = LogicalTypeUtils::getRowLayoutSize(fieldType); + entryToCompare1 += fieldSize; + entryToCompare2 += fieldSize; + } + return true; +} + +static compare_function_t getCompareEntryFunc(const LogicalType& type); + +template<> +[[maybe_unused]] bool compareEntry(const common::ValueVector* vector, + uint32_t vectorPos, const uint8_t* entry) { + auto dataVector = ListVector::getDataVector(vector); + auto listToCompare = vector->getValue(vectorPos); + auto listEntry = reinterpret_cast(entry); + auto entryNullBytes = reinterpret_cast(listEntry->overflowPtr); + auto entryValues = entryNullBytes + NullBuffer::getNumBytesForNullValues(listEntry->size); + auto rowLayoutSize = LogicalTypeUtils::getRowLayoutSize(dataVector->dataType); + compare_function_t compareFunc = getCompareEntryFunc(dataVector->dataType); + if (listToCompare.size != listEntry->size) { + return false; + } + for (auto i = 0u; i < listEntry->size; i++) { + const bool entryChildIsNull = NullBuffer::isNull(entryNullBytes, i); + const bool vectorChildIsNull = dataVector->isNull(listToCompare.offset + i); + if (entryChildIsNull != vectorChildIsNull) { + return false; + } + if (!entryChildIsNull && !vectorChildIsNull && + !compareFunc(dataVector, listToCompare.offset + i, entryValues)) { + return false; + } + entryValues += rowLayoutSize; + } + return true; +} + +static bool compareNodeEntry(const common::ValueVector* vector, uint32_t vectorPos, + const uint8_t* entry) { + KU_ASSERT(0 == common::StructType::getFieldIdx(vector->dataType, common::InternalKeyword::ID)); + auto idVector = common::StructVector::getFieldVector(vector, 0).get(); + return compareEntry(idVector, vectorPos, + getFTStructNodeID(entry, vector->dataType)); +} + +static bool compareRelEntry(const common::ValueVector* vector, uint32_t vectorPos, + const uint8_t* entry) { + KU_ASSERT(3 == common::StructType::getFieldIdx(vector->dataType, common::InternalKeyword::ID)); + auto idVector = common::StructVector::getFieldVector(vector, 3).get(); + return compareEntry(idVector, vectorPos, + getFTStructRelID(entry, vector->dataType)); +} + +template<> +[[maybe_unused]] bool compareEntry(const common::ValueVector* vector, + uint32_t vectorPos, const uint8_t* entry) { + auto numFields = StructType::getNumFields(vector->dataType); + auto entryToCompare = getFTStructFirstField(entry, numFields); + for (auto i = 0u; i < numFields; i++) { + auto isNullInEntry = NullBuffer::isNull(entry, i); + auto fieldVector = StructVector::getFieldVector(vector, i); + // Firstly check null on left and right side. + if (isNullInEntry != fieldVector->isNull(vectorPos)) { + return false; + } + compare_function_t compareFunc = getCompareEntryFunc(fieldVector->dataType); + // If both not null, compare the value. 
+ if (!isNullInEntry && !compareFunc(fieldVector.get(), vectorPos, entryToCompare)) { + return false; + } + entryToCompare += LogicalTypeUtils::getRowLayoutSize(fieldVector->dataType); + } + return true; +} + +static compare_function_t getCompareEntryFunc(const LogicalType& type) { + compare_function_t func; + switch (type.getLogicalTypeID()) { + case LogicalTypeID::NODE: { + func = compareNodeEntry; + } break; + case LogicalTypeID::REL: { + func = compareRelEntry; + } break; + default: { + TypeUtils::visit( + type.getPhysicalType(), [&](T) { func = compareEntry; }, + [](auto) { KU_UNREACHABLE; }); + } + } + return func; +} + +static ft_compare_function_t getFactorizedTableCompareEntryFunc(const LogicalType& type) { + ft_compare_function_t func; + switch (type.getLogicalTypeID()) { + case LogicalTypeID::NODE: { + func = compareFTNodeEntry; + } break; + case LogicalTypeID::REL: { + func = compareFTRelEntry; + } break; + default: { + TypeUtils::visit( + type.getPhysicalType(), + [&](T) { func = factorizedTableCompareEntry; }, + [](auto) { KU_UNREACHABLE; }); + } + } + return func; +} + +void BaseHashTable::initSlotConstant(uint64_t numSlotsPerBlock) { + numSlotsPerBlockLog2 = std::log2(numSlotsPerBlock); + slotIdxInBlockMask = + common::BitmaskUtils::all1sMaskForLeastSignificantBits(numSlotsPerBlockLog2); +} + +// ! This function will only be used by distinct aggregate and hashJoin, which assumes that all +// keyVectors are flat. +bool BaseHashTable::matchFlatVecWithEntry(const std::vector& keyVectors, + const uint8_t* entry) { + for (auto i = 0u; i < keyVectors.size(); i++) { + auto keyVector = keyVectors[i]; + KU_ASSERT(keyVector->state->isFlat()); + KU_ASSERT(keyVector->state->getSelVector().getSelSize() == 1); + auto pos = keyVector->state->getSelVector()[0]; + auto isKeyVectorNull = keyVector->isNull(pos); + auto isEntryKeyNull = + factorizedTable->isNonOverflowColNull(entry + getTableSchema()->getNullMapOffset(), i); + // If either key or entry is null, we shouldn't compare the value of keyVector and + // entry. 
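// --- Illustrative sketch (not part of this patch) --------------------------
// The null semantics used when matching key vectors against table entries:
// two nulls count as equal (so DISTINCT / GROUP BY collapse them), a null
// never matches a non-null, and only non-null pairs fall through to the
// per-type value comparison. classifyNulls is a hypothetical helper.
enum class NullMatch { Match, Mismatch, CompareValues };

inline NullMatch classifyNulls(bool keyIsNull, bool entryIsNull) {
    if (keyIsNull && entryIsNull) {
        return NullMatch::Match; // both null: keys considered equal for this column
    }
    if (keyIsNull != entryIsNull) {
        return NullMatch::Mismatch; // null vs non-null can never match
    }
    return NullMatch::CompareValues; // both non-null: defer to the comparator
}
// ----------------------------------------------------------------------------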
+ if (isKeyVectorNull && isEntryKeyNull) { + continue; + } else if (isKeyVectorNull != isEntryKeyNull) { + return false; + } + if (!compareEntryFuncs[i](keyVector, pos, entry + getTableSchema()->getColOffset(i))) { + return false; + } + } + return true; +} + +void BaseHashTable::initCompareFuncs() { + compareEntryFuncs.reserve(keyTypes.size()); + for (auto i = 0u; i < keyTypes.size(); ++i) { + compareEntryFuncs.push_back(getCompareEntryFunc(keyTypes[i])); + ftCompareEntryFuncs.push_back(getFactorizedTableCompareEntryFunc(keyTypes[i])); + } +} + +void BaseHashTable::initTmpHashVector() { + hashState = std::make_shared(); + hashState->setToFlat(); + hashVector = std::make_unique(LogicalType::HASH(), memoryManager); + hashVector->state = hashState; + tmpHashResultVector = std::make_unique(LogicalType::HASH(), memoryManager); + tmpHashCombineResultVector = std::make_unique(LogicalType::HASH(), memoryManager); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/result/factorized_table.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/result/factorized_table.cpp new file mode 100644 index 0000000000..e3d466baee --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/result/factorized_table.cpp @@ -0,0 +1,770 @@ +#include "processor/result/factorized_table.h" + +#include + +#include "common/assert.h" +#include "common/exception/runtime.h" +#include "common/null_buffer.h" +#include "common/vector/value_vector.h" +#include "storage/buffer_manager/memory_manager.h" + +using namespace lbug::common; +using namespace lbug::storage; + +namespace lbug { +namespace processor { + +DataBlock::DataBlock(storage::MemoryManager* mm, uint64_t size) : numTuples{0}, freeSize{size} { + block = mm->allocateBuffer(true /* initializeToZero */, size); +} + +DataBlock::~DataBlock() = default; + +uint8_t* DataBlock::getData() const { + return block->getBuffer().data(); +} +std::span DataBlock::getSizedData() const { + return block->getBuffer(); +} +uint8_t* DataBlock::getWritableData() const { + return block->getBuffer().last(freeSize).data(); +} +void DataBlock::resetNumTuplesAndFreeSize() { + freeSize = block->getBuffer().size(); + numTuples = 0; +} +void DataBlock::resetToZero() { + memset(block->getBuffer().data(), 0, block->getBuffer().size()); +} + +void DataBlock::preventDestruction() { + block->preventDestruction(); +} + +void DataBlock::copyTuples(DataBlock* blockToCopyFrom, ft_tuple_idx_t tupleIdxToCopyFrom, + DataBlock* blockToCopyInto, ft_tuple_idx_t tupleIdxToCopyTo, uint32_t numTuplesToCopy, + uint32_t numBytesPerTuple) { + for (auto i = 0u; i < numTuplesToCopy; i++) { + memcpy(blockToCopyInto->getData() + (tupleIdxToCopyTo * numBytesPerTuple), + blockToCopyFrom->getData() + (tupleIdxToCopyFrom * numBytesPerTuple), numBytesPerTuple); + tupleIdxToCopyFrom++; + tupleIdxToCopyTo++; + } + blockToCopyInto->numTuples += numTuplesToCopy; + blockToCopyInto->freeSize -= (numTuplesToCopy * numBytesPerTuple); +} + +void DataBlockCollection::merge(DataBlockCollection& other) { + if (blocks.empty()) { + append(std::move(other.blocks)); + return; + } + // Pop up the old last block first, and then push back blocks from other into the vector. + auto oldLastBlock = std::move(blocks.back()); + blocks.pop_back(); + append(std::move(other.blocks)); + // Insert back tuples in the old last block to the new last block. 
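// --- Illustrative sketch (not part of this patch) --------------------------
// The tuple bookkeeping behind the merge: first fill the free tail of the new
// last block from the old last block, then compact whatever is left to the
// front of the old block and keep it as the new tail. splitOldLastBlock is a
// hypothetical helper that only shows the arithmetic.
#include <algorithm>
#include <cstdint>

struct MergeSplit {
    uint64_t movedToNewLast; // tuples appended into the new last block
    uint64_t keptInOldLast;  // tuples compacted to the front of the old block
};

inline MergeSplit splitOldLastBlock(uint64_t tuplesPerBlock, uint64_t tuplesInNewLast,
    uint64_t tuplesInOldLast) {
    auto moved = std::min(tuplesPerBlock - tuplesInNewLast, tuplesInOldLast);
    return MergeSplit{moved, tuplesInOldLast - moved};
}
// ----------------------------------------------------------------------------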
+ auto newLastBlock = blocks.back().get(); + auto numTuplesToAppendIntoNewLastBlock = + std::min(numTuplesPerBlock - newLastBlock->numTuples, oldLastBlock->numTuples); + DataBlock::copyTuples(oldLastBlock.get(), 0, newLastBlock, newLastBlock->numTuples, + numTuplesToAppendIntoNewLastBlock, numBytesPerTuple); + // If any tuples left in the old last block, shift them to the beginning, and push the old last + // block back. + auto numTuplesLeftForNewBlock = oldLastBlock->numTuples - numTuplesToAppendIntoNewLastBlock; + if (numTuplesLeftForNewBlock > 0) { + auto tupleIdxInOldLastBlock = numTuplesToAppendIntoNewLastBlock; + oldLastBlock->resetNumTuplesAndFreeSize(); + DataBlock::copyTuples(oldLastBlock.get(), tupleIdxInOldLastBlock, oldLastBlock.get(), 0, + numTuplesLeftForNewBlock, numBytesPerTuple); + blocks.push_back(std::move(oldLastBlock)); + } +} + +FactorizedTable::FactorizedTable(MemoryManager* memoryManager, FactorizedTableSchema tableSchema) + : memoryManager{memoryManager}, tableSchema{std::move(tableSchema)}, numTuples{0} { + if (!this->tableSchema.isEmpty()) { + inMemOverflowBuffer = std::make_unique(memoryManager); + auto numBytesPerTuple = this->tableSchema.getNumBytesPerTuple(); + if (numBytesPerTuple > TEMP_PAGE_SIZE) { + // I realize it's unlikely to trigger this case because the fixed size part for + // a column is always small. A quick calculation, assume average column size is 16 bytes + // then we need more than 16K column to test this. I choose to throw exception until + // we encounter a use case. + throw RuntimeException( + "Trying to allocate for a large tuple of size greater than 256KB. " + "Allocation is disabled for performance reason."); + } + flatTupleBlockSize = TEMP_PAGE_SIZE; + numFlatTuplesPerBlock = flatTupleBlockSize / numBytesPerTuple; + flatTupleBlockCollection = + std::make_unique(numBytesPerTuple, numFlatTuplesPerBlock); + unFlatTupleBlockCollection = std::make_unique(); + } +} + +FactorizedTable::~FactorizedTable() { + if (!preventDestruction) { + return; + } + flatTupleBlockCollection->preventDestruction(); + unFlatTupleBlockCollection->preventDestruction(); + inMemOverflowBuffer->preventDestruction(); +} + +void FactorizedTable::append(const std::vector& vectors) { + auto numTuplesToAppend = computeNumTuplesToAppend(vectors); + auto appendInfos = allocateFlatTupleBlocks(numTuplesToAppend); + for (auto i = 0u; i < vectors.size(); i++) { + auto numAppendedTuples = 0ul; + for (auto& blockAppendInfo : appendInfos) { + copyVectorToColumn(*vectors[i], blockAppendInfo, numAppendedTuples, i); + numAppendedTuples += blockAppendInfo.numTuplesToAppend; + } + KU_ASSERT(numAppendedTuples == numTuplesToAppend); + } + numTuples += numTuplesToAppend; +} + +void FactorizedTable::resize(uint64_t numTuples) { + if (numTuples > this->numTuples) { + auto numTuplesToAdd = numTuples - this->numTuples; + auto numBytesPerTuple = tableSchema.getNumBytesPerTuple(); + while (flatTupleBlockCollection->needAllocation(numTuplesToAdd * numBytesPerTuple)) { + auto newBlock = std::make_unique(memoryManager, flatTupleBlockSize); + flatTupleBlockCollection->append(std::move(newBlock)); + auto numTuplesToAddInBlock = + std::min(static_cast(numTuplesToAdd), numFlatTuplesPerBlock); + auto block = flatTupleBlockCollection->getLastBlock(); + block->freeSize -= numBytesPerTuple * numTuplesToAddInBlock; + block->numTuples += numTuplesToAddInBlock; + numTuplesToAdd -= numTuplesToAddInBlock; + } + KU_ASSERT(numTuplesToAdd < numFlatTuplesPerBlock); + auto block = 
flatTupleBlockCollection->getLastBlock(); + block->freeSize -= numBytesPerTuple * numTuplesToAdd; + block->numTuples += numTuplesToAdd; + } else { + auto numTuplesRemaining = numTuples; + KU_ASSERT(flatTupleBlockCollection->getBlocks().size() == 1); + // TODO: It always adds to the end, so this will leave empty blocks in the middle if it's + // reused + for (auto& block : flatTupleBlockCollection->getBlocks()) { + block->numTuples = + std::min(static_cast(numTuplesRemaining), numFlatTuplesPerBlock); + block->freeSize = + block->getSizedData().size() - block->numTuples * tableSchema.getNumBytesPerTuple(); + numTuplesRemaining -= block->numTuples; + } + KU_ASSERT(numTuplesRemaining == 0); + } + this->numTuples = numTuples; +} +uint8_t* FactorizedTable::appendEmptyTuple() { + auto numBytesPerTuple = tableSchema.getNumBytesPerTuple(); + if (flatTupleBlockCollection->needAllocation(numBytesPerTuple)) { + auto newBlock = std::make_unique(memoryManager, flatTupleBlockSize); + flatTupleBlockCollection->append(std::move(newBlock)); + } + auto block = flatTupleBlockCollection->getLastBlock(); + uint8_t* tuplePtr = block->getWritableData(); + block->freeSize -= numBytesPerTuple; + block->numTuples++; + numTuples++; + return tuplePtr; +} + +void FactorizedTable::scan(std::span vectors, ft_tuple_idx_t tupleIdx, + uint64_t numTuplesToScan, std::span colIdxesToScan) const { + KU_ASSERT(tupleIdx + numTuplesToScan <= numTuples); + KU_ASSERT(vectors.size() == colIdxesToScan.size()); + std::unique_ptr tuplesToRead = std::make_unique(numTuplesToScan); + for (auto i = 0u; i < numTuplesToScan; i++) { + tuplesToRead[i] = getTuple(tupleIdx + i); + } + lookup(vectors, colIdxesToScan, tuplesToRead.get(), 0 /* startPos */, numTuplesToScan); +} + +void FactorizedTable::lookup(std::span vectors, + std::span colIdxesToScan, uint8_t** tuplesToRead, uint64_t startPos, + uint64_t numTuplesToRead) const { + KU_ASSERT(vectors.size() == colIdxesToScan.size()); + for (auto i = 0u; i < colIdxesToScan.size(); i++) { + auto vector = vectors[i]; + // TODO(Xiyang/Ziyi): we should set up a rule about when to reset. Should it be in operator? + vector->resetAuxiliaryBuffer(); + ft_col_idx_t colIdx = colIdxesToScan[i]; + if (tableSchema.getColumn(colIdx)->isFlat()) { + KU_ASSERT(!(vector->state->isFlat() && numTuplesToRead > 1)); + readFlatCol(tuplesToRead + startPos, colIdx, *vector, numTuplesToRead); + } else { + // If the caller wants to read an unflat column from factorizedTable, the vector + // must be unflat and the numTuplesToScan should be 1. 
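// --- Illustrative sketch (not part of this patch) --------------------------
// Why an unflat column can only be scanned one tuple at a time: a flat column
// stores one row-layout value inline per tuple, whereas an unflat column
// stores a single {numElements, pointer} handle per tuple that fans out into
// an overflow block, so one row already fills a whole unflat vector.
// ToyOverflowValue and readHandle are hypothetical stand-ins.
#include <cstdint>

struct ToyOverflowValue {
    uint64_t numElements; // how many values this single row fans out into
    uint8_t* value;       // start of those values inside the overflow block
};

inline void readHandle(const ToyOverflowValue& handle, uint64_t bytesPerValue,
    void (*copyValue)(uint64_t dstPos, const uint8_t* src)) {
    auto src = handle.value;
    for (uint64_t i = 0; i < handle.numElements; i++) {
        copyValue(i, src); // one stored row expands into numElements vector slots
        src += bytesPerValue;
    }
}
// ----------------------------------------------------------------------------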
+ KU_ASSERT(!vector->state->isFlat() && numTuplesToRead == 1); + readUnflatCol(tuplesToRead + startPos, colIdx, *vector); + } + } +} + +void FactorizedTable::lookup(std::vector& vectors, const SelectionVector* selVector, + std::vector& colIdxesToScan, uint8_t* tupleToRead) const { + KU_ASSERT(vectors.size() == colIdxesToScan.size()); + for (auto i = 0u; i < colIdxesToScan.size(); i++) { + ft_col_idx_t colIdx = colIdxesToScan[i]; + if (tableSchema.getColumn(colIdx)->isFlat()) { + readFlatCol(&tupleToRead, colIdx, *vectors[i], 1); + } else { + readUnflatCol(tupleToRead, *selVector, colIdx, *vectors[i]); + } + } +} + +void FactorizedTable::lookup(std::vector& vectors, + std::vector& colIdxesToScan, std::vector& tupleIdxesToRead, + uint64_t startPos, uint64_t numTuplesToRead) const { + KU_ASSERT(vectors.size() == colIdxesToScan.size()); + auto tuplesToRead = std::make_unique(tupleIdxesToRead.size()); + KU_ASSERT(numTuplesToRead > 0); + for (auto i = 0u; i < numTuplesToRead; i++) { + tuplesToRead[i] = getTuple(tupleIdxesToRead[i + startPos]); + } + lookup(vectors, colIdxesToScan, tuplesToRead.get(), 0 /* startPos */, numTuplesToRead); +} + +void FactorizedTable::mergeMayContainNulls(FactorizedTable& other) { + for (auto i = 0u; i < other.tableSchema.getNumColumns(); i++) { + if (!other.hasNoNullGuarantee(i)) { + tableSchema.setMayContainsNullsToTrue(i); + } + } +} + +void FactorizedTable::merge(FactorizedTable& other) { + KU_ASSERT(tableSchema == other.tableSchema); + if (other.numTuples == 0) { + return; + } + mergeMayContainNulls(other); + unFlatTupleBlockCollection->append(std::move(other.unFlatTupleBlockCollection)); + flatTupleBlockCollection->merge(*other.flatTupleBlockCollection); + inMemOverflowBuffer->merge(*other.inMemOverflowBuffer); + numTuples += other.numTuples; +} + +bool FactorizedTable::hasUnflatCol() const { + std::vector colIdxes(tableSchema.getNumColumns()); + iota(colIdxes.begin(), colIdxes.end(), 0); + return hasUnflatCol(colIdxes); +} + +uint64_t FactorizedTable::getTotalNumFlatTuples() const { + auto totalNumFlatTuples = 0ul; + for (auto i = 0u; i < getNumTuples(); i++) { + totalNumFlatTuples += getNumFlatTuples(i); + } + return totalNumFlatTuples; +} + +uint64_t FactorizedTable::getNumFlatTuples(ft_tuple_idx_t tupleIdx) const { + std::unordered_map calculatedGroups; + uint64_t numFlatTuples = 1; + auto tupleBuffer = getTuple(tupleIdx); + for (auto i = 0u; i < tableSchema.getNumColumns(); i++) { + auto column = tableSchema.getColumn(i); + auto groupID = column->getGroupID(); + if (!calculatedGroups.contains(groupID)) { + calculatedGroups[groupID] = true; + numFlatTuples *= column->isFlat() ? 
1 : ((overflow_value_t*)tupleBuffer)->numElements; + } + tupleBuffer += column->getNumBytes(); + } + return numFlatTuples; +} + +uint8_t* FactorizedTable::getTuple(ft_tuple_idx_t tupleIdx) const { + KU_ASSERT(tupleIdx < numTuples); + auto [blockIdx, tupleIdxInBlock] = getBlockIdxAndTupleIdxInBlock(tupleIdx); + auto buffer = flatTupleBlockCollection->getBlock(blockIdx)->getSizedData(); + // Check that the end of the block doesn't overflow the buffer + KU_ASSERT((tupleIdxInBlock + 1) * tableSchema.getNumBytesPerTuple() <= buffer.size()); + return buffer.data() + tupleIdxInBlock * tableSchema.getNumBytesPerTuple(); +} + +void FactorizedTable::updateFlatCell(uint8_t* tuplePtr, ft_col_idx_t colIdx, + ValueVector* valueVector, uint32_t pos) { + auto nullBuffer = tuplePtr + tableSchema.getNullMapOffset(); + if (valueVector->isNull(pos)) { + setNonOverflowColNull(nullBuffer, colIdx); + } else { + valueVector->copyToRowData(pos, tuplePtr + tableSchema.getColOffset(colIdx), + inMemOverflowBuffer.get()); + NullBuffer::setNoNull(nullBuffer, colIdx); + } +} + +bool FactorizedTable::isOverflowColNull(const uint8_t* nullBuffer, ft_tuple_idx_t tupleIdx, + ft_col_idx_t colIdx) const { + KU_ASSERT(colIdx < tableSchema.getNumColumns()); + if (tableSchema.getColumn(colIdx)->hasNoNullGuarantee()) { + return false; + } + return NullBuffer::isNull(nullBuffer, tupleIdx); +} + +bool FactorizedTable::isNonOverflowColNull(const uint8_t* nullBuffer, ft_col_idx_t colIdx) const { + KU_ASSERT(colIdx < tableSchema.getNumColumns()); + if (tableSchema.getColumn(colIdx)->hasNoNullGuarantee()) { + return false; + } + return NullBuffer::isNull(nullBuffer, colIdx); +} + +bool FactorizedTable::isNonOverflowColNull(ft_tuple_idx_t tupleIdx, ft_col_idx_t colIdx) const { + KU_ASSERT(colIdx < tableSchema.getNumColumns()); + if (tableSchema.getColumn(colIdx)->hasNoNullGuarantee()) { + return false; + } + return NullBuffer::isNull(getTuple(tupleIdx) + tableSchema.getNullMapOffset(), colIdx); +} + +void FactorizedTable::setNonOverflowColNull(uint8_t* nullBuffer, ft_col_idx_t colIdx) { + NullBuffer::setNull(nullBuffer, colIdx); + tableSchema.setMayContainsNullsToTrue(colIdx); +} + +void FactorizedTable::clear() { + numTuples = 0; + flatTupleBlockCollection = std::make_unique( + tableSchema.getNumBytesPerTuple(), numFlatTuplesPerBlock); + unFlatTupleBlockCollection = std::make_unique(); + inMemOverflowBuffer->resetBuffer(); +} + +void FactorizedTable::setOverflowColNull(uint8_t* nullBuffer, ft_col_idx_t colIdx, + ft_tuple_idx_t tupleIdx) { + NullBuffer::setNull(nullBuffer, tupleIdx); + tableSchema.setMayContainsNullsToTrue(colIdx); +} + +// TODO(Guodong): change this function to not use dataChunkPos in ColumnSchema. +uint64_t FactorizedTable::computeNumTuplesToAppend( + const std::vector& vectorsToAppend) const { + KU_ASSERT(!vectorsToAppend.empty()); + auto numTuplesToAppend = 1ul; + for (auto i = 0u; i < vectorsToAppend.size(); i++) { + // If the caller tries to append an unflat vector to a flat column in the + // factorizedTable, the factorizedTable needs to flatten that vector. + if (tableSchema.getColumn(i)->isFlat() && !vectorsToAppend[i]->state->isFlat()) { + // The caller is not allowed to append multiple unFlat columns from different + // datachunks to multiple flat columns in the factorizedTable. 
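// --- Illustrative sketch (not part of this patch) --------------------------
// The rule this loop encodes: appending an unflat vector into a flat column
// flattens it, so the append produces one tuple per selected value; every
// other combination contributes a single tuple. Only one unflat source may
// drive the flattening, which is why taking that vector's selection size is
// enough. tuplesProducedByAppend is a hypothetical helper.
#include <cstdint>

inline uint64_t tuplesProducedByAppend(bool columnIsFlat, bool vectorIsFlat,
    uint64_t selectedSize) {
    return (columnIsFlat && !vectorIsFlat) ? selectedSize : 1;
}
// ----------------------------------------------------------------------------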
+ numTuplesToAppend = vectorsToAppend[i]->state->getSelVector().getSelSize(); + } + } + return numTuplesToAppend; +} + +std::vector FactorizedTable::allocateFlatTupleBlocks( + uint64_t numTuplesToAppend) { + auto numBytesPerTuple = tableSchema.getNumBytesPerTuple(); + std::vector appendingInfos; + while (numTuplesToAppend > 0) { + if (flatTupleBlockCollection->needAllocation(numBytesPerTuple)) { + auto newBlock = std::make_unique(memoryManager, flatTupleBlockSize); + flatTupleBlockCollection->append(std::move(newBlock)); + } + auto block = flatTupleBlockCollection->getLastBlock(); + auto numTuplesToAppendInCurBlock = + std::min(numTuplesToAppend, block->freeSize / numBytesPerTuple); + appendingInfos.emplace_back(block->getWritableData(), numTuplesToAppendInCurBlock); + block->freeSize -= numTuplesToAppendInCurBlock * numBytesPerTuple; + block->numTuples += numTuplesToAppendInCurBlock; + numTuplesToAppend -= numTuplesToAppendInCurBlock; + } + return appendingInfos; +} + +uint64_t getDataBlockSize(uint32_t numBytes) { + if (numBytes < TEMP_PAGE_SIZE) { + return TEMP_PAGE_SIZE; + } + return numBytes + 1; +} + +uint8_t* FactorizedTable::allocateUnflatTupleBlock(uint32_t numBytes) { + if (unFlatTupleBlockCollection->isEmpty()) { + auto newBlock = std::make_unique(memoryManager, getDataBlockSize(numBytes)); + unFlatTupleBlockCollection->append(std::move(newBlock)); + } + auto lastBlock = unFlatTupleBlockCollection->getLastBlock(); + if (lastBlock->freeSize > numBytes) { + auto writableData = lastBlock->getWritableData(); + lastBlock->freeSize -= numBytes; + return writableData; + } + auto newBlock = std::make_unique(memoryManager, getDataBlockSize(numBytes)); + unFlatTupleBlockCollection->append(std::move(newBlock)); + lastBlock = unFlatTupleBlockCollection->getLastBlock(); + lastBlock->freeSize -= numBytes; + return lastBlock->getData(); +} + +void FactorizedTable::copyFlatVectorToFlatColumn(const ValueVector& vector, + const BlockAppendingInfo& blockAppendInfo, ft_col_idx_t colIdx) { + auto valuePositionInVectorToAppend = vector.state->getSelVector()[0]; + auto colOffsetInDataBlock = tableSchema.getColOffset(colIdx); + auto dstDataPtr = blockAppendInfo.data; + for (auto i = 0u; i < blockAppendInfo.numTuplesToAppend; i++) { + if (vector.isNull(valuePositionInVectorToAppend)) { + setNonOverflowColNull(dstDataPtr + tableSchema.getNullMapOffset(), colIdx); + } else { + vector.copyToRowData(valuePositionInVectorToAppend, dstDataPtr + colOffsetInDataBlock, + inMemOverflowBuffer.get()); + } + dstDataPtr += tableSchema.getNumBytesPerTuple(); + } +} + +void FactorizedTable::copyUnflatVectorToFlatColumn(const ValueVector& vector, + const BlockAppendingInfo& blockAppendInfo, uint64_t numAppendedTuples, ft_col_idx_t colIdx) { + auto byteOffsetOfColumnInTuple = tableSchema.getColOffset(colIdx); + auto dstTuple = blockAppendInfo.data; + if (vector.state->getSelVector().isUnfiltered()) { + if (vector.hasNoNullsGuarantee()) { + for (auto i = 0u; i < blockAppendInfo.numTuplesToAppend; i++) { + vector.copyToRowData(numAppendedTuples + i, dstTuple + byteOffsetOfColumnInTuple, + inMemOverflowBuffer.get()); + dstTuple += tableSchema.getNumBytesPerTuple(); + } + } else { + for (auto i = 0u; i < blockAppendInfo.numTuplesToAppend; i++) { + if (vector.isNull(numAppendedTuples + i)) { + setNonOverflowColNull(dstTuple + tableSchema.getNullMapOffset(), colIdx); + } else { + vector.copyToRowData(numAppendedTuples + i, + dstTuple + byteOffsetOfColumnInTuple, inMemOverflowBuffer.get()); + } + dstTuple += 
tableSchema.getNumBytesPerTuple(); + } + } + } else { + if (vector.hasNoNullsGuarantee()) { + for (auto i = 0u; i < blockAppendInfo.numTuplesToAppend; i++) { + vector.copyToRowData(vector.state->getSelVector()[numAppendedTuples + i], + dstTuple + byteOffsetOfColumnInTuple, inMemOverflowBuffer.get()); + dstTuple += tableSchema.getNumBytesPerTuple(); + } + } else { + for (auto i = 0u; i < blockAppendInfo.numTuplesToAppend; i++) { + auto pos = vector.state->getSelVector()[numAppendedTuples + i]; + if (vector.isNull(pos)) { + setNonOverflowColNull(dstTuple + tableSchema.getNullMapOffset(), colIdx); + } else { + vector.copyToRowData(pos, dstTuple + byteOffsetOfColumnInTuple, + inMemOverflowBuffer.get()); + } + dstTuple += tableSchema.getNumBytesPerTuple(); + } + } + } +} + +// For an unflat column, only an unflat vector is allowed to copy from, for the column, we only +// store an overflow_value_t, which contains a pointer to the overflow dataBlock in the +// factorizedTable. NullMasks are stored inside the overflow buffer. +void FactorizedTable::copyVectorToUnflatColumn(const ValueVector& vector, + const BlockAppendingInfo& blockAppendInfo, ft_col_idx_t colIdx) { + KU_ASSERT(!vector.state->isFlat()); + auto unflatTupleValue = appendVectorToUnflatTupleBlocks(vector, colIdx); + auto blockPtr = blockAppendInfo.data + tableSchema.getColOffset(colIdx); + for (auto i = 0u; i < blockAppendInfo.numTuplesToAppend; i++) { + memcpy(blockPtr, (uint8_t*)&unflatTupleValue, sizeof(overflow_value_t)); + blockPtr += tableSchema.getNumBytesPerTuple(); + } +} + +void FactorizedTable::copyVectorToColumn(const ValueVector& vector, + const BlockAppendingInfo& blockAppendInfo, uint64_t numAppendedTuples, ft_col_idx_t colIdx) { + if (tableSchema.getColumn(colIdx)->isFlat()) { + copyVectorToFlatColumn(vector, blockAppendInfo, numAppendedTuples, colIdx); + } else { + copyVectorToUnflatColumn(vector, blockAppendInfo, colIdx); + } +} + +overflow_value_t FactorizedTable::appendVectorToUnflatTupleBlocks(const ValueVector& vector, + ft_col_idx_t colIdx) { + KU_ASSERT(!vector.state->isFlat()); + auto numFlatTuplesInVector = vector.state->getSelVector().getSelSize(); + auto numBytesPerValue = LogicalTypeUtils::getRowLayoutSize(vector.dataType); + auto numBytesForData = numBytesPerValue * numFlatTuplesInVector; + auto overflowBlockBuffer = allocateUnflatTupleBlock( + numBytesForData + NullBuffer::getNumBytesForNullValues(numFlatTuplesInVector)); + if (vector.state->getSelVector().isUnfiltered()) { + if (vector.hasNoNullsGuarantee()) { + auto dstDataBuffer = overflowBlockBuffer; + for (auto i = 0u; i < numFlatTuplesInVector; i++) { + vector.copyToRowData(i, dstDataBuffer, inMemOverflowBuffer.get()); + dstDataBuffer += numBytesPerValue; + } + } else { + auto dstDataBuffer = overflowBlockBuffer; + for (auto i = 0u; i < numFlatTuplesInVector; i++) { + if (vector.isNull(i)) { + setOverflowColNull(overflowBlockBuffer + numBytesForData, colIdx, i); + } else { + vector.copyToRowData(i, dstDataBuffer, inMemOverflowBuffer.get()); + } + dstDataBuffer += numBytesPerValue; + } + } + } else { + if (vector.hasNoNullsGuarantee()) { + auto dstDataBuffer = overflowBlockBuffer; + for (auto i = 0u; i < numFlatTuplesInVector; i++) { + vector.copyToRowData(vector.state->getSelVector()[i], dstDataBuffer, + inMemOverflowBuffer.get()); + dstDataBuffer += numBytesPerValue; + } + } else { + auto dstDataBuffer = overflowBlockBuffer; + for (auto i = 0u; i < numFlatTuplesInVector; i++) { + auto pos = vector.state->getSelVector()[i]; + if 
(vector.isNull(pos)) { + setOverflowColNull(overflowBlockBuffer + numBytesForData, colIdx, i); + } else { + vector.copyToRowData(pos, dstDataBuffer, inMemOverflowBuffer.get()); + } + dstDataBuffer += numBytesPerValue; + } + } + } + return overflow_value_t{numFlatTuplesInVector, overflowBlockBuffer}; +} + +void FactorizedTable::readUnflatCol(uint8_t** tuplesToRead, ft_col_idx_t colIdx, + ValueVector& vector) const { + auto overflowColValue = + *(overflow_value_t*)(tuplesToRead[0] + tableSchema.getColOffset(colIdx)); + KU_ASSERT(vector.state->getSelVector().isUnfiltered()); + auto numBytesPerValue = LogicalTypeUtils::getRowLayoutSize(vector.dataType); + if (hasNoNullGuarantee(colIdx)) { + vector.setAllNonNull(); + auto val = overflowColValue.value; + for (auto i = 0u; i < overflowColValue.numElements; i++) { + vector.copyFromRowData(i, val); + val += numBytesPerValue; + } + } else { + auto overflowColNullData = + overflowColValue.value + overflowColValue.numElements * numBytesPerValue; + auto overflowColData = overflowColValue.value; + for (auto i = 0u; i < overflowColValue.numElements; i++) { + if (isOverflowColNull(overflowColNullData, i, colIdx)) { + vector.setNull(i, true); + } else { + vector.setNull(i, false); + vector.copyFromRowData(i, overflowColData); + } + overflowColData += numBytesPerValue; + } + } + vector.state->getSelVectorUnsafe().setSelSize(overflowColValue.numElements); +} + +void FactorizedTable::readUnflatCol(const uint8_t* tupleToRead, const SelectionVector& selVector, + ft_col_idx_t colIdx, ValueVector& vector) const { + auto vectorOverflowValue = *(overflow_value_t*)(tupleToRead + tableSchema.getColOffset(colIdx)); + KU_ASSERT(vector.state->getSelVector().isUnfiltered()); + if (hasNoNullGuarantee(colIdx)) { + vector.setAllNonNull(); + auto val = vectorOverflowValue.value; + for (auto i = 0u; i < vectorOverflowValue.numElements; i++) { + auto pos = selVector[i]; + vector.copyFromRowData(i, val + (pos * vector.getNumBytesPerValue())); + } + } else { + for (auto i = 0u; i < vectorOverflowValue.numElements; i++) { + auto pos = selVector[i]; + if (isOverflowColNull(vectorOverflowValue.value + vectorOverflowValue.numElements * + vector.getNumBytesPerValue(), + pos, colIdx)) { + vector.setNull(i, true); + } else { + vector.setNull(i, false); + vector.copyFromRowData(i, + vectorOverflowValue.value + pos * vector.getNumBytesPerValue()); + } + } + } + vector.state->getSelVectorUnsafe().setSelSize(selVector.getSelSize()); +} + +void FactorizedTable::readFlatColToFlatVector(uint8_t* tupleToRead, ft_col_idx_t colIdx, + ValueVector& vector, sel_t pos) const { + if (isNonOverflowColNull(tupleToRead + tableSchema.getNullMapOffset(), colIdx)) { + vector.setNull(pos, true); + } else { + vector.setNull(pos, false); + vector.copyFromRowData(pos, tupleToRead + tableSchema.getColOffset(colIdx)); + } +} + +void FactorizedTable::readFlatCol(uint8_t** tuplesToRead, ft_col_idx_t colIdx, ValueVector& vector, + uint64_t numTuplesToRead) const { + if (vector.state->isFlat()) { + auto pos = vector.state->getSelVector()[0]; + readFlatColToFlatVector(tuplesToRead[0], colIdx, vector, pos); + } else { + readFlatColToUnflatVector(tuplesToRead, colIdx, vector, numTuplesToRead); + } +} + +void FactorizedTable::readFlatColToUnflatVector(uint8_t** tuplesToRead, ft_col_idx_t colIdx, + ValueVector& vector, uint64_t numTuplesToRead) const { + vector.state->getSelVectorUnsafe().setSelSize(numTuplesToRead); + if (hasNoNullGuarantee(colIdx)) { + vector.setAllNonNull(); + for (auto i = 0u; i < numTuplesToRead; 
i++) { + auto positionInVectorToWrite = vector.state->getSelVector()[i]; + auto srcData = tuplesToRead[i] + tableSchema.getColOffset(colIdx); + vector.copyFromRowData(positionInVectorToWrite, srcData); + } + } else { + for (auto i = 0u; i < numTuplesToRead; i++) { + auto positionInVectorToWrite = vector.state->getSelVector()[i]; + auto dataBuffer = tuplesToRead[i]; + if (isNonOverflowColNull(dataBuffer + tableSchema.getNullMapOffset(), colIdx)) { + vector.setNull(positionInVectorToWrite, true); + } else { + vector.setNull(positionInVectorToWrite, false); + vector.copyFromRowData(positionInVectorToWrite, + dataBuffer + tableSchema.getColOffset(colIdx)); + } + } + } +} + +FactorizedTableIterator::FactorizedTableIterator(FactorizedTable& factorizedTable) + : factorizedTable{factorizedTable}, currentTupleBuffer{nullptr}, numFlatTuples{0}, + nextFlatTupleIdx{0}, nextTupleIdx{1} { + resetState(); +} + +void FactorizedTableIterator::getNext(FlatTuple& tuple) { + // Go to the next tuple if we have iterated all the flat tuples of the current tuple. + if (nextFlatTupleIdx >= numFlatTuples) { + currentTupleBuffer = factorizedTable.getTuple(nextTupleIdx); + numFlatTuples = factorizedTable.getNumFlatTuples(nextTupleIdx); + nextFlatTupleIdx = 0; + updateNumElementsInDataChunk(); + nextTupleIdx++; + } + for (auto i = 0ul; i < factorizedTable.getTableSchema()->getNumColumns(); i++) { + auto column = factorizedTable.getTableSchema()->getColumn(i); + if (column->isFlat()) { + readFlatColToFlatTuple(i, currentTupleBuffer, tuple); + } else { + readUnflatColToFlatTuple(i, currentTupleBuffer, tuple); + } + } + updateFlatTuplePositionsInDataChunk(); + nextFlatTupleIdx++; +} + +void FactorizedTableIterator::resetState() { + numFlatTuples = 0; + nextFlatTupleIdx = 0; + nextTupleIdx = 1; + if (factorizedTable.getNumTuples()) { + currentTupleBuffer = factorizedTable.getTuple(0); + numFlatTuples = factorizedTable.getNumFlatTuples(0); + updateNumElementsInDataChunk(); + updateInvalidEntriesInFlatTuplePositionsInDataChunk(); + } +} + +void FactorizedTableIterator::readUnflatColToFlatTuple(ft_col_idx_t colIdx, uint8_t* valueBuffer, + FlatTuple& tuple) { + auto overflowValue = + (overflow_value_t*)(valueBuffer + factorizedTable.getTableSchema()->getColOffset(colIdx)); + auto groupID = factorizedTable.getTableSchema()->getColumn(colIdx)->getGroupID(); + auto tupleSizeInOverflowBuffer = + LogicalTypeUtils::getRowLayoutSize(tuple[colIdx].getDataType()); + valueBuffer = overflowValue->value + + tupleSizeInOverflowBuffer * flatTuplePositionsInDataChunk[groupID].first; + auto isNull = factorizedTable.isOverflowColNull( + overflowValue->value + tupleSizeInOverflowBuffer * overflowValue->numElements, + flatTuplePositionsInDataChunk[groupID].first, colIdx); + tuple[colIdx].setNull(isNull); + if (!isNull) { + tuple[colIdx].copyFromRowLayout(valueBuffer); + } +} + +void FactorizedTableIterator::readFlatColToFlatTuple(ft_col_idx_t colIdx, uint8_t* valueBuffer, + FlatTuple& tuple) { + auto isNull = factorizedTable.isNonOverflowColNull( + valueBuffer + factorizedTable.getTableSchema()->getNullMapOffset(), colIdx); + tuple[colIdx].setNull(isNull); + if (!isNull) { + tuple[colIdx].copyFromRowLayout( + valueBuffer + factorizedTable.getTableSchema()->getColOffset(colIdx)); + } +} + +void FactorizedTableIterator::updateInvalidEntriesInFlatTuplePositionsInDataChunk() { + for (auto i = 0u; i < flatTuplePositionsInDataChunk.size(); i++) { + bool isValidEntry = false; + for (auto j = 0u; j < factorizedTable.getTableSchema()->getNumColumns(); 
j++) { + if (factorizedTable.getTableSchema()->getColumn(j)->getGroupID() == i) { + isValidEntry = true; + break; + } + } + if (!isValidEntry) { + flatTuplePositionsInDataChunk[i] = std::make_pair(UINT64_MAX, UINT64_MAX); + } + } +} + +void FactorizedTableIterator::updateNumElementsInDataChunk() { + auto colOffsetInTupleBuffer = 0ul; + for (auto i = 0u; i < factorizedTable.getTableSchema()->getNumColumns(); i++) { + auto column = factorizedTable.getTableSchema()->getColumn(i); + auto groupID = column->getGroupID(); + // If this is an unflat column, the number of elements is stored in the + // overflow_value_t struct. Otherwise, the number of elements is 1. + auto numElementsInDataChunk = + column->isFlat() ? + 1 : + ((overflow_value_t*)(currentTupleBuffer + colOffsetInTupleBuffer))->numElements; + if (groupID >= flatTuplePositionsInDataChunk.size()) { + flatTuplePositionsInDataChunk.resize(groupID + 1); + } + flatTuplePositionsInDataChunk[groupID] = + std::make_pair(0 /* nextIdxToReadInDataChunk */, numElementsInDataChunk); + colOffsetInTupleBuffer += column->getNumBytes(); + } +} + +void FactorizedTableIterator::updateFlatTuplePositionsInDataChunk() { + for (auto i = 0u; i < flatTuplePositionsInDataChunk.size(); i++) { + if (!isValidDataChunkPos(i)) { + continue; + } + flatTuplePositionsInDataChunk.at(i).first++; + // If we have output all elements in the current column, we reset the + // nextIdxToReadInDataChunk in the current column to 0. + if (flatTuplePositionsInDataChunk.at(i).first >= + flatTuplePositionsInDataChunk.at(i).second) { + flatTuplePositionsInDataChunk.at(i).first = 0; + } else { + // If the current dataChunk is not full, then we don't need to update the next + // dataChunk. + break; + } + } +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/result/factorized_table_pool.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/result/factorized_table_pool.cpp new file mode 100644 index 0000000000..afa313815b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/result/factorized_table_pool.cpp @@ -0,0 +1,31 @@ +#include "processor/result/factorized_table_pool.h" + +namespace lbug { +namespace processor { + +FactorizedTable* FactorizedTablePool::claimLocalTable(storage::MemoryManager* mm) { + std::unique_lock lck{mtx}; + if (availableLocalTables.empty()) { + auto table = std::make_shared(mm, globalTable->getTableSchema()->copy()); + localTables.push_back(table); + availableLocalTables.push(table.get()); + } + auto result = availableLocalTables.top(); + availableLocalTables.pop(); + return result; +} + +void FactorizedTablePool::returnLocalTable(FactorizedTable* table) { + std::unique_lock lck{mtx}; + availableLocalTables.push(table); +} + +void FactorizedTablePool::mergeLocalTables() { + std::unique_lock lck{mtx}; + for (auto& localTable : localTables) { + globalTable->merge(*localTable); + } +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/result/factorized_table_schema.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/result/factorized_table_schema.cpp new file mode 100644 index 0000000000..c92cc2475e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/result/factorized_table_schema.cpp @@ -0,0 +1,66 @@ +#include "processor/result/factorized_table_schema.h" + +#include "common/null_buffer.h" + +using namespace lbug::common; + +namespace lbug { +namespace processor { + +ColumnSchema::ColumnSchema(const ColumnSchema& other) { + 
isUnFlat = other.isUnFlat; + groupID = other.groupID; + numBytes = other.numBytes; + mayContainNulls = other.mayContainNulls; +} + +FactorizedTableSchema::FactorizedTableSchema(const FactorizedTableSchema& other) { + for (auto i = 0u; i < other.columns.size(); ++i) { + appendColumn(other.columns[i].copy()); + } +} + +void FactorizedTableSchema::appendColumn(ColumnSchema column) { + numBytesForDataPerTuple += column.getNumBytes(); + columns.push_back(std::move(column)); + colOffsets.push_back( + colOffsets.empty() ? 0 : colOffsets.back() + getColumn(columns.size() - 2)->getNumBytes()); + numBytesForNullMapPerTuple = NullBuffer::getNumBytesForNullValues(getNumColumns()); + numBytesPerTuple = numBytesForDataPerTuple + numBytesForNullMapPerTuple; +} + +bool FactorizedTableSchema::operator==(const FactorizedTableSchema& other) const { + if (columns.size() != other.columns.size()) { + return false; + } + for (auto i = 0u; i < columns.size(); i++) { + if (columns[i] != other.columns[i]) { + return false; + } + } + return numBytesForDataPerTuple == other.numBytesForDataPerTuple && numBytesForNullMapPerTuple && + other.numBytesForNullMapPerTuple; +} + +uint64_t FactorizedTableSchema::getNumFlatColumns() const { + auto numFlatColumns = 0u; + for (auto& column : columns) { + if (column.isFlat()) { + numFlatColumns++; + } + } + return numFlatColumns; +} + +uint64_t FactorizedTableSchema::getNumUnFlatColumns() const { + auto numUnflatColumns = 0u; + for (auto& column : columns) { + if (!column.isFlat()) { + numUnflatColumns++; + } + } + return numUnflatColumns; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/result/factorized_table_util.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/result/factorized_table_util.cpp new file mode 100644 index 0000000000..da7b7ac3af --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/result/factorized_table_util.cpp @@ -0,0 +1,70 @@ +#include "processor/result/factorized_table_util.h" + +using namespace lbug::storage; +using namespace lbug::common; +using namespace lbug::binder; +using namespace lbug::planner; + +namespace lbug { +namespace processor { + +FactorizedTableSchema FactorizedTableUtils::createFTableSchema(const expression_vector& exprs, + const Schema& schema) { + auto tableSchema = FactorizedTableSchema(); + std::unordered_set groupIDSet; + for (auto& e : exprs) { + auto groupPos = schema.getExpressionPos(*e).first; + auto group = schema.getGroup(groupPos); + if (group->isFlat()) { + auto column = + ColumnSchema(false, groupPos, LogicalTypeUtils::getRowLayoutSize(e->getDataType())); + tableSchema.appendColumn(std::move(column)); + } else { + auto column = ColumnSchema(true, groupPos, (uint32_t)sizeof(overflow_value_t)); + tableSchema.appendColumn(std::move(column)); + } + } + return tableSchema; +} + +FactorizedTableSchema FactorizedTableUtils::createFlatTableSchema( + std::vector columnTypes) { + auto tableSchema = FactorizedTableSchema(); + for (auto& type : columnTypes) { + auto column = ColumnSchema(false /* isUnFlat */, 0 /* groupID */, + LogicalTypeUtils::getRowLayoutSize(type)); + tableSchema.appendColumn(std::move(column)); + } + return tableSchema; +} + +void FactorizedTableUtils::appendStringToTable(FactorizedTable* factorizedTable, + const std::string& outputMsg, MemoryManager* memoryManager) { + auto outputMsgVector = std::make_shared(LogicalTypeID::STRING, memoryManager); + outputMsgVector->state = DataChunkState::getSingleValueDataChunkState(); + auto outputKUStr 
= ku_string_t(); + outputKUStr.overflowPtr = + reinterpret_cast(StringVector::getInMemOverflowBuffer(outputMsgVector.get()) + ->allocateSpace(outputMsg.length())); + outputKUStr.set(outputMsg); + outputMsgVector->setValue(0, outputKUStr); + factorizedTable->append(std::vector{outputMsgVector.get()}); +} + +std::shared_ptr FactorizedTableUtils::getFactorizedTableForOutputMsg( + const std::string& outputMsg, MemoryManager* memoryManager) { + auto table = getSingleStringColumnFTable(memoryManager); + appendStringToTable(table.get(), outputMsg, memoryManager); + return table; +} + +std::shared_ptr FactorizedTableUtils::getSingleStringColumnFTable( + MemoryManager* mm) { + std::vector typeVec; + typeVec.push_back(LogicalType::STRING()); + auto fTableSchema = createFlatTableSchema(std::move(typeVec)); + return std::make_shared(mm, std::move(fTableSchema)); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/result/flat_tuple.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/result/flat_tuple.cpp new file mode 100644 index 0000000000..04d7d7d9e1 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/result/flat_tuple.cpp @@ -0,0 +1,99 @@ +#include "processor/result/flat_tuple.h" + +#include + +#include "common/exception/runtime.h" +#include "common/string_format.h" +#include "common/types/value/value.h" +#include "utf8proc.h" +#include "utf8proc_wrapper.h" + +using namespace lbug::utf8proc; +using namespace lbug::common; + +namespace lbug { +namespace processor { + +FlatTuple::FlatTuple(const std::vector& types) { + for (auto& type : types) { + values.emplace_back(Value::createDefaultValue(type)); + } +} + +uint32_t FlatTuple::len() const { + return values.size(); +} + +static void checkOutOfRange(idx_t idx, idx_t size) { + if (idx >= size) { + throw RuntimeException(stringFormat( + "ValIdx is out of range. 
Number of values in flatTuple: {}, valIdx: {}.", size, idx)); + } +} + +Value* FlatTuple::getValue(uint32_t idx) { + checkOutOfRange(idx, len()); + return &values[idx]; +} + +Value& FlatTuple::operator[](idx_t idx) { + checkOutOfRange(idx, len()); + return values[idx]; +} + +const Value& FlatTuple::operator[](idx_t idx) const { + checkOutOfRange(idx, len()); + return values[idx]; +} + +std::string FlatTuple::toString() const { + std::string result; + for (auto i = 0ul; i < values.size(); i++) { + if (i != 0) { + result += "|"; + } + result += values[i].toString(); + } + result += "\n"; + return result; +} + +std::string FlatTuple::toString(const std::vector& colsWidth, + const std::string& delimiter, const uint32_t maxWidth) { + std::ostringstream result; + for (auto i = 0ul; i < values.size(); i++) { + auto value = values[i].toString(); + auto fieldLen = 0u; + auto cutoff = 0u, cutoffLen = 0u; + for (auto iter = 0u; iter < value.length();) { + auto width = Utf8Proc::renderWidth(value.c_str(), iter); + if (fieldLen + 3 > maxWidth && cutoff == 0) { + cutoff = iter; + cutoffLen = fieldLen; + } + fieldLen += width; + iter = utf8proc_next_grapheme(value.c_str(), value.length(), iter); + } + if (fieldLen > maxWidth) { + value = value.substr(0, cutoff) + "..."; + fieldLen = cutoffLen + 3; + } + if (colsWidth[i] != 0) { + value = " " + std::move(value) + " "; + fieldLen += 2; + } + fieldLen = std::min(fieldLen, maxWidth + 2); + if (colsWidth[i] != 0) { + result << value << std::string(colsWidth[i] - fieldLen, ' '); + } else { + result << value; + } + if (i != values.size() - 1) { + result << delimiter; + } + } + return result.str(); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/result/pattern_creation_info_table.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/result/pattern_creation_info_table.cpp new file mode 100644 index 0000000000..376ae3ffaf --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/result/pattern_creation_info_table.cpp @@ -0,0 +1,57 @@ +#include "processor/result/pattern_creation_info_table.h" + +namespace lbug { +namespace processor { + +void PatternCreationInfo::updateID(common::executor_id_t executorID, + common::executor_info executorInfo, common::nodeID_t nodeID) const { + if (!executorInfo.contains(executorID)) { + return; + } + auto ftColIndex = executorInfo.at(executorID); + *(common::nodeID_t*)(tuple + ftColIndex * sizeof(common::nodeID_t)) = nodeID; +} + +PatternCreationInfoTable::PatternCreationInfoTable(storage::MemoryManager& memoryManager, + std::vector keyTypes, FactorizedTableSchema tableSchema) + : AggregateHashTable{memoryManager, copyVector(keyTypes), std::vector{}, + std::vector{} /* empty aggregates */, + std::vector{} /* empty distinct agg key*/, + 0 /* numEntriesToAllocate */, tableSchema.copy()}, + tuple{nullptr}, idColOffset{tableSchema.getColOffset(keyTypes.size())} {} + +PatternCreationInfo PatternCreationInfoTable::getPatternCreationInfo( + const std::vector& keyVectors) { + auto hasCreated = true; + if (keyVectors.size() == 0) { + // Constant keys, we can simply use one tuple to store all information + if (factorizedTable->getNumTuples() == 0) { + tuple = factorizedTable->appendEmptyTuple(); + hasCreated = false; + } + KU_ASSERT(factorizedTable->getNumTuples() == 1); + return PatternCreationInfo{tuple, hasCreated}; + } else { + resizeHashTableIfNecessary(1); + computeVectorHashes(keyVectors); + findHashSlots(keyVectors, std::vector{}, keyVectors[0]->state.get()); + 
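// --- Illustrative sketch (not part of this patch) --------------------------
// The get-or-create contract implemented here: look the key up in the hash
// table; if an entry already exists it is reused (the pattern was created
// earlier), otherwise the slot lookup has just appended a fresh entry that
// the caller must initialize. The map-based helper below only mirrors that
// contract and is not the library's hash table.
#include <cstdint>
#include <string>
#include <unordered_map>

struct GetOrCreateResult {
    uint8_t* entry;
    bool alreadyExisted;
};

inline GetOrCreateResult getOrCreate(std::unordered_map<std::string, uint8_t*>& table,
    const std::string& key, uint8_t* freshEntry) {
    auto it = table.find(key);
    if (it != table.end()) {
        return {it->second, true}; // pattern already created: reuse its entry
    }
    table.emplace(key, freshEntry);
    return {freshEntry, false}; // new pattern: caller initializes the entry
}
// ----------------------------------------------------------------------------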
hasCreated = tuple != nullptr; + auto idTuple = tuple == nullptr ? + factorizedTable->getTuple(factorizedTable->getNumTuples() - 1) : + tuple; + return PatternCreationInfo{idTuple + idColOffset, hasCreated}; + } +} + +uint64_t PatternCreationInfoTable::matchFTEntries(std::span keyVectors, + uint64_t numMayMatches, uint64_t numNoMatches) { + numNoMatches = AggregateHashTable::matchFTEntries(keyVectors, numMayMatches, numNoMatches); + KU_ASSERT(numMayMatches <= 1); + // If we found the entry for the target key, we set tuple to the key tuple. Otherwise, simply + // set tuple to nullptr. + tuple = numMayMatches != 0 ? hashSlotsToUpdateAggState[mayMatchIdxes[0]]->getEntry() : nullptr; + return numNoMatches; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/result/result_set.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/result/result_set.cpp new file mode 100644 index 0000000000..56daf68f67 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/result/result_set.cpp @@ -0,0 +1,40 @@ +#include "processor/result/result_set.h" + +using namespace lbug::common; + +namespace lbug { +namespace processor { + +ResultSet::ResultSet(ResultSetDescriptor* resultSetDescriptor, + storage::MemoryManager* memoryManager) + : multiplicity{1} { + auto numDataChunks = resultSetDescriptor->dataChunkDescriptors.size(); + dataChunks.resize(numDataChunks); + for (auto i = 0u; i < numDataChunks; ++i) { + auto dataChunkDescriptor = resultSetDescriptor->dataChunkDescriptors[i].get(); + auto numValueVectors = dataChunkDescriptor->logicalTypes.size(); + auto dataChunk = std::make_unique(numValueVectors); + if (dataChunkDescriptor->isSingleState) { + dataChunk->state = DataChunkState::getSingleValueDataChunkState(); + } + for (auto j = 0u; j < numValueVectors; ++j) { + auto vector = std::make_shared(dataChunkDescriptor->logicalTypes[j].copy(), + memoryManager); + dataChunk->insert(j, std::move(vector)); + } + insert(i, std::move(dataChunk)); + } +} + +uint64_t ResultSet::getNumTuplesWithoutMultiplicity( + const std::unordered_set& dataChunksPosInScope) { + KU_ASSERT(!dataChunksPosInScope.empty()); + uint64_t numTuples = 1; + for (auto& dataChunkPos : dataChunksPosInScope) { + numTuples *= dataChunks[dataChunkPos]->state->getSelVector().getSelSize(); + } + return numTuples; +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/result/result_set_descriptor.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/result/result_set_descriptor.cpp new file mode 100644 index 0000000000..ab65971cf5 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/result/result_set_descriptor.cpp @@ -0,0 +1,30 @@ +#include "processor/result/result_set_descriptor.h" + +#include "planner/operator/schema.h" + +namespace lbug { +namespace processor { + +ResultSetDescriptor::ResultSetDescriptor(planner::Schema* schema) { + for (auto i = 0u; i < schema->getNumGroups(); ++i) { + auto group = schema->getGroup(i); + auto dataChunkDescriptor = std::make_unique(group->isSingleState()); + for (auto& expression : group->getExpressions()) { + dataChunkDescriptor->logicalTypes.push_back(expression->getDataType().copy()); + } + dataChunkDescriptors.push_back(std::move(dataChunkDescriptor)); + } +} + +std::unique_ptr ResultSetDescriptor::copy() const { + std::vector> dataChunkDescriptorsCopy; + dataChunkDescriptorsCopy.reserve(dataChunkDescriptors.size()); + for (auto& dataChunkDescriptor : dataChunkDescriptors) { + 
dataChunkDescriptorsCopy.push_back( + std::make_unique(*dataChunkDescriptor)); + } + return std::make_unique(std::move(dataChunkDescriptorsCopy)); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/processor/warning_context.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/warning_context.cpp new file mode 100644 index 0000000000..54ced3d65f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/processor/warning_context.cpp @@ -0,0 +1,95 @@ +#include "processor/warning_context.h" + +#include "common/assert.h" +#include "common/uniq_lock.h" +#include "main/client_context.h" + +using namespace lbug::common; + +namespace lbug { +namespace processor { + +static PopulatedCopyFromError defaultPopulateFunc(CopyFromFileError error, common::idx_t) { + return PopulatedCopyFromError{ + .message = std::move(error.message), + .filePath = "", + .skippedLineOrRecord = "", + .lineNumber = 0, + }; +} + +static idx_t defaultGetFileIdxFunc(const CopyFromFileError&) { + return 0; +} + +WarningContext::WarningContext(main::ClientConfig* clientConfig) + : clientConfig{clientConfig}, queryWarningCount{0}, numStoredWarnings{0}, + ignoreErrorsOption{false} {} + +void WarningContext::appendWarningMessages(const std::vector& messages) { + UniqLock lock{mtx}; + + queryWarningCount += messages.size(); + + for (const auto& message : messages) { + if (numStoredWarnings >= clientConfig->warningLimit) { + break; + } + unpopulatedWarnings.push_back(message); + ++numStoredWarnings; + } +} + +const std::vector& WarningContext::getPopulatedWarnings() const { + // if there are still unpopulated warnings when we try to get populated warnings something is + // probably wrong + KU_ASSERT(unpopulatedWarnings.empty()); + return populatedWarnings; +} + +void WarningContext::defaultPopulateAllWarnings(uint64_t queryID) { + populateWarnings(queryID); +} + +void WarningContext::populateWarnings(uint64_t queryID, populate_func_t populateFunc, + get_file_idx_func_t getFileIdxFunc) { + if (!populateFunc) { + // if no populate functor is provided we default to just copying the message over + // and leaving the CSV fields unpopulated + populateFunc = defaultPopulateFunc; + } + if (!getFileIdxFunc) { + getFileIdxFunc = defaultGetFileIdxFunc; + } + for (auto& warning : unpopulatedWarnings) { + const auto fileIdx = getFileIdxFunc(warning); + populatedWarnings.emplace_back(populateFunc(std::move(warning), fileIdx), queryID); + } + unpopulatedWarnings.clear(); +} + +void WarningContext::clearPopulatedWarnings() { + populatedWarnings.clear(); + numStoredWarnings = 0; +} + +uint64_t WarningContext::getWarningCount(uint64_t) { + auto ret = queryWarningCount; + queryWarningCount = 0; + return ret; +} + +void WarningContext::setIgnoreErrorsForCurrentQuery(bool ignoreErrors) { + ignoreErrorsOption = ignoreErrors; +} + +bool WarningContext::getIgnoreErrorsOption() const { + return ignoreErrorsOption; +} + +WarningContext* WarningContext::Get(const main::ClientContext& context) { + return context.warningContext.get(); +} + +} // namespace processor +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/CMakeLists.txt new file mode 100644 index 0000000000..0ce8c3ff18 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/CMakeLists.txt @@ -0,0 +1,31 @@ +add_subdirectory(buffer_manager) +add_subdirectory(compression) +add_subdirectory(local_storage) +add_subdirectory(predicate) +add_subdirectory(index) 
+add_subdirectory(stats) +add_subdirectory(table) +add_subdirectory(wal) + +add_library(lbug_storage + OBJECT + checkpointer.cpp + database_header.cpp + disk_array.cpp + disk_array_collection.cpp + file_db_id_utils.cpp + file_handle.cpp + free_space_manager.cpp + optimistic_allocator.cpp + overflow_file.cpp + page_manager.cpp + shadow_file.cpp + shadow_utils.cpp + storage_manager.cpp + storage_utils.cpp + storage_version_info.cpp + undo_buffer.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/buffer_manager/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/buffer_manager/CMakeLists.txt new file mode 100644 index 0000000000..d51720ce5e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/buffer_manager/CMakeLists.txt @@ -0,0 +1,10 @@ +add_library(lbug_storage_buffer_manager + OBJECT + vm_region.cpp + buffer_manager.cpp + memory_manager.cpp + spiller.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/buffer_manager/buffer_manager.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/buffer_manager/buffer_manager.cpp new file mode 100644 index 0000000000..52a32b599e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/buffer_manager/buffer_manager.cpp @@ -0,0 +1,516 @@ +#include "storage/buffer_manager/buffer_manager.h" + +#include +#include +#include +#include +#include +#include + +#include "common/assert.h" +#include "common/constants.h" +#include "common/exception/buffer_manager.h" +#include "common/file_system/local_file_system.h" +#include "common/file_system/virtual_file_system.h" +#include "common/types/types.h" +#include "main/db_config.h" +#include "storage/buffer_manager/spiller.h" +#include "storage/file_handle.h" +#include "storage/table/column_chunk_data.h" +#include + +#if defined(_WIN32) +#include + +#include +#include +#include +#include +#include +#endif + +using namespace lbug::common; + +namespace lbug { +namespace storage { + +bool EvictionQueue::insert(uint32_t fileIndex, page_idx_t pageIndex) { + EvictionCandidate candidate{fileIndex, pageIndex}; + while (size < capacity) { + // Weak is fine since spurious failure is acceptable. + // The slot can always be filled later. 
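The comment above leans on a standard C++ property: compare_exchange_weak may fail spuriously even when the expected value matches, which is acceptable here because the loop simply moves on and the slot can be claimed later. A minimal, self-contained sketch of the same claim pattern (the names and types below are illustrative, not part of this file):

#include <atomic>
#include <cstddef>

// Claim the first free slot in a fixed array of atomics, where 0 means "empty".
// A spurious compare_exchange_weak failure just makes us try the next slot,
// mirroring how EvictionQueue::insert tolerates weak-CAS failures.
static bool claimSlot(std::atomic<int>* slots, std::size_t capacity, int value) {
    for (std::size_t i = 0; i < capacity; ++i) {
        int expected = 0;
        if (slots[i].compare_exchange_weak(expected, value)) {
            return true; // slot claimed
        }
    }
    return false; // caller may retry the whole pass
}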
+ auto emptyCandidate = EMPTY; + if (data[insertCursor.fetch_add(1, std::memory_order_relaxed) % capacity] + .compare_exchange_weak(emptyCandidate, candidate)) { + size++; + return true; + } + } + return false; +} + +std::span, EvictionQueue::BATCH_SIZE> EvictionQueue::next() { + return std::span, BATCH_SIZE>( + data.get() + ((evictionCursor += BATCH_SIZE) % capacity), BATCH_SIZE); +} + +void EvictionQueue::clear(std::atomic& candidate) { + auto nonEmpty = candidate.load(); + if (nonEmpty != EMPTY && candidate.compare_exchange_strong(nonEmpty, EMPTY)) { + size--; + return; + } + KU_UNREACHABLE; +} + +BufferManager::BufferManager(const std::string& databasePath, const std::string& spillToDiskPath, + uint64_t bufferPoolSize, uint64_t maxDBSize, VirtualFileSystem* vfs, bool readOnly) + : bufferPoolSize{bufferPoolSize}, evictionQueue{bufferPoolSize / LBUG_PAGE_SIZE}, + usedMemory{evictionQueue.getCapacity() * sizeof(EvictionCandidate)}, vfs{vfs} { + verifySizeParams(bufferPoolSize, maxDBSize); +#if !BM_MALLOC + vmRegions[0] = std::make_unique(REGULAR_PAGE, maxDBSize); + vmRegions[1] = std::make_unique(TEMP_PAGE, bufferPoolSize); +#endif + + // TODO(bmwinger): It may be better to spill to disk in a different location for remote file + // systems, or even in general. + // Ideally we want to spill to disk in some temporary location such + // as /var/tmp (not /tmp since that may be backed by memory). However we also need to be able to + // support multiple databases spilling at once (can't be the same file), and handle different + // platforms. + if (!readOnly && !main::DBConfig::isDBPathInMemory(databasePath) && + dynamic_cast(vfs->findFileSystem(spillToDiskPath))) { + spiller = std::make_unique(spillToDiskPath, *this, vfs); + } +} + +void BufferManager::verifySizeParams(uint64_t bufferPoolSize, uint64_t maxDBSize) { + if (bufferPoolSize < LBUG_PAGE_SIZE) { + throw BufferManagerException(stringFormat( + "The given buffer pool size should be at least {} bytes.", LBUG_PAGE_SIZE)); + } + // We require at least two page groups, one for the main data file, and one for the shadow file. + if (maxDBSize < 2 * LBUG_PAGE_SIZE * StorageConstants::PAGE_GROUP_SIZE) { + throw BufferManagerException( + "The given max db size should be at least " + + std::to_string(2 * LBUG_PAGE_SIZE * StorageConstants::PAGE_GROUP_SIZE) + " bytes."); + } + if ((maxDBSize & (maxDBSize - 1)) != 0) { + throw BufferManagerException("The given max db size should be a power of 2."); + } +} + +// Important Note: Pin returns a raw pointer to the frame. This is potentially very dangerous and +// trusts the caller is going to protect this memory space. +// Important responsibilities for the caller are: +// (1) The caller should know the page size and not read/write beyond these boundaries. +// (2) If the given FileHandle is not a (temporary) in-memory file and the caller writes to the +// frame, caller should make sure to call setFrameDirty to let the BufferManager know that the page +// should be flushed to disk if it is evicted. +// (3) If multiple threads are writing to the page, they should coordinate separately because they +// both get access to the same piece of memory. 
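Before the definition below, a minimal sketch of that caller contract may help. It assumes a page that already belongs to the file; setFrameDirty is the call named in point (2), but its exact receiver and signature are an assumption here, as is the helper's name:

#include <cstring>

// Hypothetical caller honoring the pin()/unpin() contract described above.
void zeroPageExample(BufferManager& bm, FileHandle& fileHandle, page_idx_t pageIdx) {
    uint8_t* frame = bm.pin(fileHandle, pageIdx, PageReadPolicy::READ_PAGE);
    // (1) stay inside the page boundary
    std::memset(frame, 0, fileHandle.getPageSize());
    // (2) mark the frame dirty so an eviction flushes it back to disk
    fileHandle.setFrameDirty(pageIdx); // assumed API, see point (2) above
    // release the pin once done; concurrent writers coordinate separately, per point (3)
    bm.unpin(fileHandle, pageIdx);
}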
+uint8_t* BufferManager::pin(FileHandle& fileHandle, page_idx_t pageIdx, + PageReadPolicy pageReadPolicy) { + auto pageState = fileHandle.getPageState(pageIdx); + while (true) { + auto currStateAndVersion = pageState->getStateAndVersion(); + switch (PageState::getState(currStateAndVersion)) { + case PageState::EVICTED: { + if (pageState->tryLock(currStateAndVersion)) { + if (!claimAFrame(fileHandle, pageIdx, pageReadPolicy)) { + pageState->resetToEvicted(); + throw BufferManagerException("Unable to allocate memory! The buffer pool is " + "full and no memory could be freed!"); + } + if (!evictionQueue.insert(fileHandle.getFileIndex(), pageIdx)) { + throw BufferManagerException( + "Eviction queue is full! This should be impossible."); + } +#if BM_MALLOC + KU_ASSERT(pageState->getPage()); + return pageState->getPage(); +#else + return getFrame(fileHandle, pageIdx); +#endif + } + } break; + case PageState::UNLOCKED: + case PageState::MARKED: { + if (pageState->tryLock(currStateAndVersion)) { + return getFrame(fileHandle, pageIdx); + } + } break; + case PageState::LOCKED: { + continue; + } + default: { + KU_UNREACHABLE; + } + } + } +} + +#if defined(WIN32) +class AccessViolation : public std::exception { +public: + AccessViolation(const uint8_t* location) : location{location} {} + + const uint8_t* location; +}; + +class ScopedTranslator { + const _se_translator_function old; + +public: + ScopedTranslator(_se_translator_function newTranslator) + : old{_set_se_translator(newTranslator)} {} + ~ScopedTranslator() { _set_se_translator(old); } +}; + +void handleAccessViolation(unsigned int exceptionCode, PEXCEPTION_POINTERS exceptionRecord) { + if (exceptionCode == EXCEPTION_ACCESS_VIOLATION + // exception was from a read + && exceptionRecord->ExceptionRecord->ExceptionInformation[0] == 0) [[likely]] { + throw AccessViolation( + (const uint8_t*)exceptionRecord->ExceptionRecord->ExceptionInformation[1]); + } + // Needs to not be an Exception so that it can't be caught by regular exception handling + // And it seems like throwing integer error codes is treated similarly to hardware + // exceptions with /EHa + throw exceptionCode; +} +#endif + +// Returns true if the function completes successfully +inline bool try_func(const std::function& func, uint8_t* frame, + const std::array, 2>& vmRegions [[maybe_unused]], + PageSizeClass pageSizeClass [[maybe_unused]], [[maybe_unused]] PageState* pageState) { +#if BM_MALLOC + if (frame == nullptr) { + return false; + } + pageState->addReader(); +#endif + +#if defined(_WIN32) && !BM_MALLOC + try { +#endif + func(frame); +#if defined(_WIN32) && !BM_MALLOC + } catch (AccessViolation& exc) { + // If we encounter an access violation within the VM region, + // the page was decommitted by another thread + // and is no longer valid memory + if (vmRegions[pageSizeClass]->contains(exc.location)) { + return false; + } else { + throw EXCEPTION_ACCESS_VIOLATION; + } + } +#endif +#if BM_MALLOC + pageState->removeReader(); +#endif + return true; +} + +void BufferManager::optimisticRead(FileHandle& fileHandle, page_idx_t pageIdx, + const std::function& func) { + auto pageState = fileHandle.getPageState(pageIdx); +#if defined(_WIN32) + // Change the Structured Exception handling just for the scope of this function + auto translator = ScopedTranslator(handleAccessViolation); +#endif + while (true) { + auto currStateAndVersion = pageState->getStateAndVersion(); + switch (PageState::getState(currStateAndVersion)) { + case PageState::UNLOCKED: { + if (!try_func(func,
getFrame(fileHandle, pageIdx), vmRegions, + fileHandle.getPageSizeClass(), pageState)) { + continue; + } + if (pageState->getStateAndVersion() == currStateAndVersion) { + return; + } + } break; + case PageState::MARKED: { + // If the page is marked, we try to switch to unlocked. + pageState->tryClearMark(currStateAndVersion); + continue; + } + case PageState::EVICTED: { + pin(fileHandle, pageIdx, PageReadPolicy::READ_PAGE); + unpin(fileHandle, pageIdx); + } break; + default: { + // When locked, continue the spinning. + continue; + } + } + } +} + +void BufferManager::unpin(FileHandle& fileHandle, page_idx_t pageIdx) { + auto pageState = fileHandle.getPageState(pageIdx); + pageState->unlock(); +} + +// evicts up to 64 pages and returns the space reclaimed +uint64_t BufferManager::evictPages() { + std::array*, EvictionQueue::BATCH_SIZE> evictionCandidates{}; + size_t evictablePages = 0; + uint64_t claimedMemory = 0; + + // Try each page at least twice. + // E.g. if the vast majority of pages are unmarked and unlocked, + // the first pass will mark them and the second pass, if insufficient marked pages + // are found, will evict the first batch. + // Using the eviction queue's cursor means that we fail after the same number of total attempts, + // regardless of how many threads are trying to evict. + auto startCursor = evictionQueue.getEvictionCursor(); + auto failureLimit = evictionQueue.getCapacity() * 2; + while (evictablePages == 0 && evictionQueue.getEvictionCursor() - startCursor < failureLimit) { + for (auto& candidate : evictionQueue.next()) { + auto evictionCandidate = candidate.load(); + if (evictionCandidate == EvictionQueue::EMPTY) { + continue; + } + KU_ASSERT(evictionCandidate.fileIdx < fileHandles.size()); + auto* pageState = + fileHandles[evictionCandidate.fileIdx]->getPageState(evictionCandidate.pageIdx); + auto pageStateAndVersion = pageState->getStateAndVersion(); + if (!evictionCandidate.isEvictable(pageStateAndVersion)) { + if (evictionCandidate.isSecondChanceEvictable(pageStateAndVersion)) { + pageState->tryMark(pageStateAndVersion); + } + continue; + } + evictionCandidates[evictablePages++] = &candidate; + } + } + + for (size_t i = 0; i < evictablePages; i++) { + claimedMemory += tryEvictPage(*evictionCandidates[i]); + } + return claimedMemory; +} + +void BufferManager::removeEvictedCandidates() { + auto startCursor = evictionQueue.getEvictionCursor(); + while (evictionQueue.getEvictionCursor() - startCursor < evictionQueue.getCapacity()) { + for (auto& candidate : evictionQueue.next()) { + auto evictionCandidate = candidate.load(); + if (evictionCandidate == EvictionQueue::EMPTY) { + continue; + } + KU_ASSERT(evictionCandidate.fileIdx < fileHandles.size()); + auto* pageState = + fileHandles[evictionCandidate.fileIdx]->getPageState(evictionCandidate.pageIdx); + auto pageStateAndVersion = pageState->getStateAndVersion(); + if (PageState::getState(pageStateAndVersion) == PageState::EVICTED) { + evictionQueue.clear(candidate); + } + } + } +} + +// This function tries to load the given page into a frame. Due to our design of mmap, each page is +// uniquely mapped to a frame. Thus, claiming a frame is equivalent to ensuring enough physical +// memory is available. +// First, we reserve the memory for the page, which increments the atomic counter `usedMemory`. +// Then, we check if there is enough memory available. If not, we evict pages until we have enough +// or we can find no more pages to be evicted. +// Lastly, we double check if the needed memory is available. 
If not, we free the memory we reserved +// and return false, otherwise, we load the page to its corresponding frame and return true. +bool BufferManager::claimAFrame(FileHandle& fileHandle, page_idx_t pageIdx, + PageReadPolicy pageReadPolicy) { + page_offset_t pageSizeToClaim = fileHandle.getPageSize(); + if (!reserve(pageSizeToClaim)) { + return false; + } +#if _WIN32 && !BM_MALLOC + // Committing in this context means reserving physical memory/page file space for a segment of + // virtual memory. On Linux/Unix this is automatic when you write to the memory address. + auto result = + VirtualAlloc(getFrame(fileHandle, pageIdx), pageSizeToClaim, MEM_COMMIT, PAGE_READWRITE); + if (result == NULL) { + throw BufferManagerException( + stringFormat("VirtualAlloc MEM_COMMIT failed with error code {}: {}.", GetLastError(), + std::system_category().message(GetLastError()))); + } +#endif + cachePageIntoFrame(fileHandle, pageIdx, pageReadPolicy); + return true; +} + +bool BufferManager::reserve(uint64_t sizeToReserve) { + // Reserve the memory for the page. + usedMemory += sizeToReserve; + uint64_t totalClaimedMemory = 0; + uint64_t nonEvictableClaimedMemory = 0; + const auto needMoreMemory = [&]() { + // The only time we should exceed the buffer pool size should be when threads are currently + // attempting to reserve space and have pre-allocated space. So if we've claimed enough + // space for what we're trying to reserve, then we can continue even if the current total is + // higher than the buffer pool size as we should never actually exceed the buffer pool size. + return sizeToReserve > totalClaimedMemory && + // usedMemory - totalClaimedMemory could underflow + usedMemory > bufferPoolSize.load() - totalClaimedMemory; + }; + uint8_t failedCount = 0; + // Evict pages if necessary until we have enough memory. + while (needMoreMemory()) { + uint64_t memoryClaimed = 0; + // Avoid reducing the evictable memory below 1/2 at first to reduce thrashing if most of the + // memory is non-evictable + if (!spiller || usedMemory - nonEvictableMemory > bufferPoolSize / 2) { + memoryClaimed = evictPages(); + } else { + auto [_memoryClaimed, nowEvictableMemory] = spiller->claimNextGroup(); + memoryClaimed = _memoryClaimed; + nonEvictableClaimedMemory += _memoryClaimed; + nonEvictableMemory -= nowEvictableMemory; + // If we're unable to claim anything from the spiller, fall back to evicting pages + // We may also need to evict pages if the spiller just unpins BM pages + if (memoryClaimed == 0 || nowEvictableMemory > 0) { + memoryClaimed = evictPages(); + } + } + if (memoryClaimed == 0 && needMoreMemory()) { + if (failedCount++ < 2) { + // If we failed to find any memory to free, try waiting briefly for other threads to + // stop using memory + std::this_thread::sleep_for(std::chrono::milliseconds(5)); + } else { + // Cannot find more pages to be evicted. Free the memory we reserved and return + // false. 
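One detail of reserve() worth spelling out is the "usedMemory - totalClaimedMemory could underflow" note above: with unsigned counters the naive form of the comparison can wrap around, which is why needMoreMemory() keeps the subtraction on the bufferPoolSize side. A standalone illustration with made-up values:

#include <cassert>
#include <cstdint>

void underflowExample() {
    // Illustrative numbers only; they are not taken from the buffer manager.
    uint64_t usedMemory = 10, totalClaimedMemory = 30, bufferPoolSize = 100;
    // Naive form: 10 - 30 wraps to 2^64 - 20 and compares as "greater", a false positive.
    bool naive = (usedMemory - totalClaimedMemory) > bufferPoolSize;
    // Form used above (valid while totalClaimedMemory <= bufferPoolSize): 10 > 70 is false.
    bool rearranged = usedMemory > bufferPoolSize - totalClaimedMemory;
    assert(naive && !rearranged);
}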
+ freeUsedMemory(sizeToReserve + totalClaimedMemory); + nonEvictableMemory -= nonEvictableClaimedMemory; + return false; + } + } + totalClaimedMemory += memoryClaimed; + } + // Have enough memory available now + if (totalClaimedMemory > 0) { + freeUsedMemory(totalClaimedMemory); + nonEvictableMemory -= nonEvictableClaimedMemory; + } + return true; +} + +uint64_t BufferManager::tryEvictPage(std::atomic& _candidate) { + auto candidate = _candidate.load(); + // Page must have been evicted by another thread already + if (candidate.pageIdx == INVALID_PAGE_IDX) { + return 0; + } + auto& pageState = *fileHandles[candidate.fileIdx]->getPageState(candidate.pageIdx); + auto currStateAndVersion = pageState.getStateAndVersion(); + // We check if the page is evictable again. Note that if the page's state or version has + // changed after the check, `tryLock` will fail, and we will abort the eviction of this page. + if (!candidate.isEvictable(currStateAndVersion) || !pageState.tryLock(currStateAndVersion)) { + return 0; + } + // The pageState was locked, but another thread already evicted this candidate and unlocked it + // before the lock occurred + if (_candidate.load() != candidate +#if BM_MALLOC + // When the pageState is locked, optimisticReads will wait, so at this point no new + // optimistic reads will begin and thus it is safe to free the buffer at this point + || pageState.getReaderCount() > 0 +#endif + ) { + pageState.unlockUnchanged(); + return 0; + } + if (fileHandles[candidate.fileIdx]->isInMemoryMode()) { + // Cannot flush pages under in memory mode. + return 0; + } + // At this point, the page is LOCKED, and we have exclusive access to the eviction candidate. + // Next, flush out the frame into the file page if the frame + // is dirty. Finally remove the page from the frame and reset the page to EVICTED. 
+ auto& fileHandle = *fileHandles[candidate.fileIdx]; + fileHandle.flushPageIfDirtyWithoutLock(candidate.pageIdx); + auto numBytesFreed = fileHandle.getPageSize(); + releaseFrameForPage(fileHandle, candidate.pageIdx); + pageState.resetToEvicted(); + evictionQueue.clear(_candidate); + return numBytesFreed; +} + +void BufferManager::cachePageIntoFrame(FileHandle& fileHandle, page_idx_t pageIdx, + PageReadPolicy pageReadPolicy) { + auto pageState = fileHandle.getPageState(pageIdx); + pageState->clearDirty(); +#if BM_MALLOC + pageState->allocatePage(fileHandle.getPageSize()); + if (pageReadPolicy == PageReadPolicy::READ_PAGE) { + fileHandle.readPageFromDisk(pageState->getPage(), pageIdx); + } +#else + if (pageReadPolicy == PageReadPolicy::READ_PAGE) { + fileHandle.readPageFromDisk(getFrame(fileHandle, pageIdx), pageIdx); + } +#endif +} + +void BufferManager::removeFilePagesFromFrames(FileHandle& fileHandle) { + for (auto pageIdx = 0u; pageIdx < fileHandle.getNumPages(); ++pageIdx) { + removePageFromFrame(fileHandle, pageIdx, false /* do not flush */); + } +} + +void BufferManager::updateFrameIfPageIsInFrameWithoutLock(file_idx_t fileIdx, + const uint8_t* newPage, page_idx_t pageIdx) { + KU_ASSERT(fileIdx < fileHandles.size()); + auto& fileHandle = *fileHandles[fileIdx]; + auto state = fileHandle.getPageState(pageIdx); + if (state && state->getState() != PageState::EVICTED) { + memcpy(getFrame(fileHandle, pageIdx), newPage, LBUG_PAGE_SIZE); + } +} + +void BufferManager::removePageFromFrameIfNecessary(FileHandle& fileHandle, page_idx_t pageIdx) { + if (pageIdx >= fileHandle.getNumPages()) { + return; + } + removePageFromFrame(fileHandle, pageIdx, false /* do not flush */); +} + +// NOTE: We assume the page is not pinned (locked) here. +void BufferManager::removePageFromFrame(FileHandle& fileHandle, page_idx_t pageIdx, + bool shouldFlush) { + auto pageState = fileHandle.getPageState(pageIdx); + if (PageState::getState(pageState->getStateAndVersion()) == PageState::EVICTED) { + return; + } + pageState->spinLock(pageState->getStateAndVersion()); + if (shouldFlush) { + fileHandle.flushPageIfDirtyWithoutLock(pageIdx); + } + releaseFrameForPage(fileHandle, pageIdx); + freeUsedMemory(fileHandle.getPageSize()); + pageState->resetToEvicted(); +} + +uint64_t BufferManager::freeUsedMemory(uint64_t size) { + KU_ASSERT(usedMemory.load() >= size); + return usedMemory.fetch_sub(size); +} + +void BufferManager::resetSpiller(std::string spillPath) { + if (spillPath.empty()) { + // Disable spilling to disk; + spiller = nullptr; + } else { + spiller = std::make_unique(spillPath, *this, vfs); + } +} + +BufferManager::~BufferManager() = default; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/buffer_manager/memory_manager.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/buffer_manager/memory_manager.cpp new file mode 100644 index 0000000000..b05f8a7d4a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/buffer_manager/memory_manager.cpp @@ -0,0 +1,115 @@ +#include "storage/buffer_manager/memory_manager.h" + +#include + +#include "common/exception/buffer_manager.h" +#include "common/file_system/virtual_file_system.h" +#include "common/types/types.h" +#include "main/client_context.h" +#include "main/database.h" +#include "storage/buffer_manager/buffer_manager.h" +#include "storage/file_handle.h" + +using namespace lbug::common; + +namespace lbug { +namespace storage { + +MemoryBuffer::MemoryBuffer(MemoryManager* mm, page_idx_t pageIdx, uint8_t* 
buffer, uint64_t size) + : buffer{buffer, static_cast(size)}, mm{mm}, pageIdx{pageIdx}, evicted{false} {} + +MemoryBuffer::~MemoryBuffer() { + if (buffer.data() != nullptr && !evicted) { + mm->freeBlock(pageIdx, buffer); + mm->updateUsedMemoryForFreedBlock(pageIdx, buffer); + buffer = std::span(); + } +} + +SpillResult MemoryBuffer::setSpilledToDisk(uint64_t filePosition) { + mm->freeBlock(pageIdx, buffer); + // reinterpret_cast isn't allowed here, but we shouldn't leave the invalid pointer and + // still want to store the size + buffer = std::span(static_cast(nullptr), buffer.size()); + evicted = true; + this->filePosition = filePosition; + if (pageIdx == INVALID_PAGE_IDX) { + return SpillResult{buffer.size(), 0}; + } else { + return SpillResult{0, buffer.size()}; + } +} + +void MemoryBuffer::prepareLoadFromDisk() { + KU_ASSERT(buffer.data() == nullptr && evicted); + buffer = mm->mallocBuffer(false, buffer.size()); + evicted = false; +} + +MemoryManager::MemoryManager(BufferManager* bm, VirtualFileSystem* vfs) : bm{bm} { + pageSize = TEMP_PAGE_SIZE; + fh = bm->getFileHandle("mm-256KB", FileHandle::O_IN_MEM_TEMP_FILE, vfs, nullptr); +} + +std::span MemoryManager::mallocBuffer(bool initializeToZero, uint64_t size) { + if (!bm->reserve(size)) { + throw BufferManagerException( + "Unable to allocate memory! The buffer pool is full and no memory could be freed!"); + } + void* buffer = nullptr; + bm->nonEvictableMemory += size; + if (initializeToZero) { + buffer = calloc(size, 1); + } else { + buffer = malloc(size); + } + return std::span(static_cast(buffer), size); +} + +std::unique_ptr MemoryManager::allocateBuffer(bool initializeToZero, uint64_t size) { + if (size != TEMP_PAGE_SIZE) [[unlikely]] { + auto buffer = mallocBuffer(initializeToZero, size); + return std::make_unique(this, INVALID_PAGE_IDX, buffer.data(), size); + } + page_idx_t pageIdx = INVALID_PAGE_IDX; + { + std::scoped_lock lock(allocatorLock); + if (freePages.empty()) { + pageIdx = fh->addNewPage(); + } else { + pageIdx = freePages.top(); + freePages.pop(); + } + } + auto buffer = bm->pin(*fh, pageIdx, PageReadPolicy::DONT_READ_PAGE); + auto memoryBuffer = std::make_unique(this, pageIdx, buffer); + if (initializeToZero) { + memset(memoryBuffer->getBuffer().data(), 0, pageSize); + } + return memoryBuffer; +} + +void MemoryManager::freeBlock(page_idx_t pageIdx, std::span buffer) { + if (pageIdx == INVALID_PAGE_IDX) { + std::free(buffer.data()); + } else { + bm->unpin(*fh, pageIdx); + } +} + +void MemoryManager::updateUsedMemoryForFreedBlock(page_idx_t pageIdx, std::span buffer) { + if (pageIdx == INVALID_PAGE_IDX) { + bm->freeUsedMemory(buffer.size()); + bm->nonEvictableMemory -= buffer.size(); + } else { + std::unique_lock lock(allocatorLock); + freePages.push(pageIdx); + } +} + +MemoryManager* MemoryManager::Get(const main::ClientContext& context) { + return context.getDatabase()->getMemoryManager(); +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/buffer_manager/spiller.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/buffer_manager/spiller.cpp new file mode 100644 index 0000000000..14ac3a01c5 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/buffer_manager/spiller.cpp @@ -0,0 +1,115 @@ +#include "storage/buffer_manager/spiller.h" + +#include + +#include "common/assert.h" +#include "common/exception/io.h" +#include "common/file_system/virtual_file_system.h" +#include "common/types/types.h" +#include "storage/buffer_manager/buffer_manager.h" +#include 
"storage/buffer_manager/memory_manager.h" +#include "storage/file_handle.h" +#include "storage/table/chunked_node_group.h" +#include "storage/table/column_chunk_data.h" + +namespace lbug { +namespace storage { + +Spiller::Spiller(std::string tmpFilePath, BufferManager& bufferManager, + common::VirtualFileSystem* vfs) + : tmpFilePath{std::move(tmpFilePath)}, bufferManager{bufferManager}, vfs{vfs}, dataFH{nullptr} { + // Clear the file if it already existed (e.g. from a previous run which + // failed to clean up). + vfs->removeFileIfExists(this->tmpFilePath); +} + +FileHandle* Spiller::getOrCreateDataFH() const { + if (dataFH.load()) { + return dataFH; + } + std::unique_lock lock(fileCreationMutex); + // Another thread may have created the file while the lock was being acquired + if (dataFH.load()) { + return dataFH; + } + const_cast(this)->dataFH = bufferManager.getFileHandle(tmpFilePath, + FileHandle::O_PERSISTENT_FILE_CREATE_NOT_EXISTS, vfs, nullptr); + return dataFH; +} + +FileHandle* Spiller::getDataFH() const { + if (dataFH.load()) { + return dataFH; + } + return nullptr; +} + +void Spiller::addUnusedChunk(InMemChunkedNodeGroup* nodeGroup) { + std::unique_lock lock(partitionerGroupsMtx); + fullPartitionerGroups.insert(nodeGroup); +} + +void Spiller::clearUnusedChunk(InMemChunkedNodeGroup* nodeGroup) { + std::unique_lock lock(partitionerGroupsMtx); + auto entry = fullPartitionerGroups.find(nodeGroup); + if (entry != fullPartitionerGroups.end()) { + fullPartitionerGroups.erase(entry); + } +} + +Spiller::~Spiller() { + // This should be safe as long as the VFS is always using a local file system and the VFS is + // destroyed after the buffer manager + try { + vfs->removeFileIfExists(this->tmpFilePath); + } catch (common::IOException&) {} // NOLINT +} + +SpillResult Spiller::spillToDisk(ColumnChunkData& chunk) const { + auto& buffer = *chunk.buffer; + KU_ASSERT(!buffer.evicted); + auto dataFH = getOrCreateDataFH(); + auto pageSize = dataFH->getPageSize(); + auto numPages = (buffer.buffer.size_bytes() + pageSize - 1) / pageSize; + auto startPage = dataFH->addNewPages(numPages); + dataFH->writePagesToFile(buffer.buffer.data(), buffer.buffer.size_bytes(), startPage); + return buffer.setSpilledToDisk(startPage * pageSize); +} + +void Spiller::loadFromDisk(ColumnChunkData& chunk) const { + auto& buffer = *chunk.buffer; + if (buffer.evicted) { + buffer.prepareLoadFromDisk(); + auto dataFH = getDataFH(); + KU_ASSERT(dataFH); + dataFH->getFileInfo()->readFromFile(buffer.buffer.data(), buffer.buffer.size(), + buffer.filePosition); + } +} + +SpillResult Spiller::claimNextGroup() { + InMemChunkedNodeGroup* groupToFlush = nullptr; + { + std::unique_lock lock(partitionerGroupsMtx); + if (!fullPartitionerGroups.empty()) { + auto groupToFlushEntry = fullPartitionerGroups.begin(); + groupToFlush = *groupToFlushEntry; + fullPartitionerGroups.erase(groupToFlushEntry); + } + } + if (groupToFlush == nullptr) { + return SpillResult{}; + } + return groupToFlush->spillToDisk(); +} + +// NOLINTNEXTLINE(readability-make-member-function-const): Function shouldn't be re-ordered +void Spiller::clearFile() { + auto curDataFH = getDataFH(); + if (curDataFH) { + curDataFH->getFileInfo()->truncate(0); + } +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/buffer_manager/vm_region.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/buffer_manager/vm_region.cpp new file mode 100644 index 0000000000..44769c3bae --- /dev/null +++ 
b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/buffer_manager/vm_region.cpp @@ -0,0 +1,91 @@ +#include "storage/buffer_manager/vm_region.h" + +#include "common/string_format.h" +#include "common/system_config.h" +#include "common/system_message.h" + +#ifdef _WIN32 +#include +#include +#include +#else +#include +#endif + +#include "common/exception/buffer_manager.h" + +using namespace lbug::common; + +namespace lbug { +namespace storage { + +VMRegion::VMRegion(PageSizeClass pageSizeClass, uint64_t maxRegionSize) : numFrameGroups{0} { + if (maxRegionSize > static_cast(-1)) { + throw BufferManagerException("maxRegionSize is beyond the max available mmap region size."); + } + frameSize = pageSizeClass == REGULAR_PAGE ? LBUG_PAGE_SIZE : TEMP_PAGE_SIZE; + const auto numBytesForFrameGroup = frameSize * StorageConstants::PAGE_GROUP_SIZE; + maxNumFrameGroups = (maxRegionSize + numBytesForFrameGroup - 1) / numBytesForFrameGroup; +#ifdef _WIN32 + region = (uint8_t*)VirtualAlloc(NULL, getMaxRegionSize(), MEM_RESERVE, PAGE_READWRITE); + if (region == NULL) { + throw BufferManagerException(stringFormat( + "VirtualAlloc for size {} failed with error code {}: {}.", getMaxRegionSize(), + GetLastError(), std::system_category().message(GetLastError()))); + } +#else + // Create a private anonymous mapping. The mapping is not shared with other processes and not + // backed by any file, and its content are initialized to zero. + region = static_cast(mmap(NULL, getMaxRegionSize(), PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANONYMOUS | MAP_NORESERVE, -1 /* fd */, 0 /* offset */)); + if (region == MAP_FAILED) { + throw BufferManagerException( + "Mmap for size " + std::to_string(getMaxRegionSize()) + " failed."); + } +#endif +} + +VMRegion::~VMRegion() { +#ifdef _WIN32 + VirtualFree(region, 0, MEM_RELEASE); +#else + munmap(region, getMaxRegionSize()); +#endif +} + +void VMRegion::releaseFrame(frame_idx_t frameIdx) const { +#ifdef _WIN32 + // TODO: VirtualAlloc(..., MEM_RESET, ...) 
may be faster + // See https://arvid.io/2018/04/02/memory-mapping-on-windows/#1 + // Not sure what the differences are + if (!VirtualFree(getFrame(frameIdx), frameSize, MEM_DECOMMIT)) { + auto code = GetLastError(); + throw BufferManagerException(stringFormat( + "Releasing physical memory associated with a frame failed with error code {}: {}.", + code, systemErrMessage(code))); + } + +#else + int error = madvise(getFrame(frameIdx), frameSize, MADV_DONTNEED); + if (error != 0) { + // LCOV_EXCL_START + throw BufferManagerException(stringFormat( + "Releasing physical memory associated with a frame failed with error code {}: {}.", + error, posixErrMessage())); + // LCOV_EXCL_STOP + } +#endif +} + +frame_group_idx_t VMRegion::addNewFrameGroup() { + std::unique_lock xLck{mtx}; + if (numFrameGroups >= maxNumFrameGroups) { + // LCOV_EXCL_START + throw BufferManagerException("No more frame groups can be added to the allocator."); + // LCOV_EXCL_STOP + } + return numFrameGroups++; +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/checkpointer.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/checkpointer.cpp new file mode 100644 index 0000000000..b9d8dcad81 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/checkpointer.cpp @@ -0,0 +1,224 @@ +#include "storage/checkpointer.h" + +#include "catalog/catalog.h" +#include "common/file_system/file_system.h" +#include "common/file_system/virtual_file_system.h" +#include "common/serializer/buffered_file.h" +#include "common/serializer/deserializer.h" +#include "common/serializer/in_mem_file_writer.h" +#include "extension/extension_manager.h" +#include "main/client_context.h" +#include "main/db_config.h" +#include "storage/buffer_manager/buffer_manager.h" +#include "storage/database_header.h" +#include "storage/shadow_utils.h" +#include "storage/storage_manager.h" +#include "storage/wal/local_wal.h" + +namespace lbug { +namespace storage { + +Checkpointer::Checkpointer(main::ClientContext& clientContext) + : clientContext{clientContext}, + isInMemory{main::DBConfig::isDBPathInMemory(clientContext.getDatabasePath())} {} + +Checkpointer::~Checkpointer() = default; + +PageRange Checkpointer::serializeCatalog(const catalog::Catalog& catalog, + StorageManager& storageManager) { + auto catalogWriter = + std::make_shared(*MemoryManager::Get(clientContext)); + common::Serializer catalogSerializer(catalogWriter); + catalog.serialize(catalogSerializer); + auto pageAllocator = storageManager.getDataFH()->getPageManager(); + return catalogWriter->flush(*pageAllocator, storageManager.getShadowFile()); +} + +PageRange Checkpointer::serializeMetadata(const catalog::Catalog& catalog, + StorageManager& storageManager) { + auto metadataWriter = + std::make_shared(*MemoryManager::Get(clientContext)); + common::Serializer metadataSerializer(metadataWriter); + storageManager.serialize(catalog, metadataSerializer); + + // We need to preallocate the pages for the page manager before we actually serialize it, + // this is because the page manager needs to track the pages used for itself. + // The number of pages needed for the page manager should only decrease after making an + // additional allocation, so we just calculate the number of pages needed to serialize the + // current state of the page manager. + // Thus, it is possible that we allocate an extra page that we won't end up writing to when we + // flush the metadata writer. 
This may cause a discrepancy between the number of tracked pages + // and the number of physical pages in the file but shouldn't cause any actual incorrect + // behavior in the database. + auto& pageManager = *storageManager.getDataFH()->getPageManager(); + const auto pagesForPageManager = pageManager.estimatePagesNeededForSerialize(); + auto pageAllocator = storageManager.getDataFH()->getPageManager(); + const auto allocatedPages = pageAllocator->allocatePageRange( + metadataWriter->getNumPagesToFlush() + pagesForPageManager); + pageManager.serialize(metadataSerializer); + + metadataWriter->flush(allocatedPages, pageAllocator->getDataFH(), + storageManager.getShadowFile()); + return allocatedPages; +} + +void Checkpointer::writeCheckpoint() { + if (isInMemory) { + return; + } + + auto databaseHeader = + *StorageManager::Get(clientContext)->getOrInitDatabaseHeader(clientContext); + // Checkpoint storage. Note that we first checkpoint storage before serializing the catalog, as + // checkpointing storage may overwrite columnIDs in the catalog. + bool hasStorageChanges = checkpointStorage(); + serializeCatalogAndMetadata(databaseHeader, hasStorageChanges); + writeDatabaseHeader(databaseHeader); + logCheckpointAndApplyShadowPages(); + + // This function will evict all pages that were freed during this checkpoint + // It must be called before we remove all evicted candidates from the BM + // Or else the evicted pages may end up appearing multiple times in the eviction queue + auto storageManager = StorageManager::Get(clientContext); + storageManager->finalizeCheckpoint(); + // When a page is freed by the FSM, it evicts it from the BM. However, if the page is freed, + // then reused over and over, it can be appended to the eviction queue multiple times. To + // prevent multiple entries of the same page from existing in the eviction queue, at the end of + // each checkpoint we remove any already-evicted pages. 
+ auto bufferManager = MemoryManager::Get(clientContext)->getBufferManager(); + bufferManager->removeEvictedCandidates(); + + catalog::Catalog::Get(clientContext)->resetVersion(); + auto* dataFH = storageManager->getDataFH(); + dataFH->getPageManager()->resetVersion(); + storageManager->getWAL().reset(); + storageManager->getShadowFile().reset(); +} + +bool Checkpointer::checkpointStorage() { + const auto storageManager = StorageManager::Get(clientContext); + auto pageAllocator = storageManager->getDataFH()->getPageManager(); + return storageManager->checkpoint(&clientContext, *pageAllocator); +} + +void Checkpointer::serializeCatalogAndMetadata(DatabaseHeader& databaseHeader, + bool hasStorageChanges) { + const auto storageManager = StorageManager::Get(clientContext); + const auto catalog = catalog::Catalog::Get(clientContext); + auto* dataFH = storageManager->getDataFH(); + + // Serialize the catalog if there are changes + if (databaseHeader.catalogPageRange.startPageIdx == common::INVALID_PAGE_IDX || + catalog->changedSinceLastCheckpoint()) { + databaseHeader.updateCatalogPageRange(*dataFH->getPageManager(), + serializeCatalog(*catalog, *storageManager)); + } + // Serialize the storage metadata if there are changes + if (databaseHeader.metadataPageRange.startPageIdx == common::INVALID_PAGE_IDX || + hasStorageChanges || catalog->changedSinceLastCheckpoint() || + dataFH->getPageManager()->changedSinceLastCheckpoint()) { + // We must free the existing metadata page range before serializing + // So that the freed pages are serialized by the FSM + databaseHeader.freeMetadataPageRange(*dataFH->getPageManager()); + databaseHeader.metadataPageRange = serializeMetadata(*catalog, *storageManager); + } +} + +void Checkpointer::writeDatabaseHeader(const DatabaseHeader& header) { + auto headerWriter = + std::make_shared(*MemoryManager::Get(clientContext)); + common::Serializer headerSerializer(headerWriter); + header.serialize(headerSerializer); + auto headerPage = headerWriter->getPage(0); + + const auto storageManager = StorageManager::Get(clientContext); + auto dataFH = storageManager->getDataFH(); + auto& shadowFile = storageManager->getShadowFile(); + auto shadowHeader = ShadowUtils::createShadowVersionIfNecessaryAndPinPage( + common::StorageConstants::DB_HEADER_PAGE_IDX, true /* skipReadingOriginalPage */, *dataFH, + shadowFile); + memcpy(shadowHeader.frame, headerPage.data(), common::LBUG_PAGE_SIZE); + shadowFile.getShadowingFH().unpinPage(shadowHeader.shadowPage); + + // Update the in-memory database header with the new version + StorageManager::Get(clientContext)->setDatabaseHeader(std::make_unique(header)); +} + +void Checkpointer::logCheckpointAndApplyShadowPages() { + const auto storageManager = StorageManager::Get(clientContext); + auto& shadowFile = storageManager->getShadowFile(); + // Flush the shadow file. + shadowFile.flushAll(clientContext); + auto wal = WAL::Get(clientContext); + // Log the checkpoint to the WAL and flush WAL. This indicates that all shadow pages and + // files (snapshots of catalog and metadata) have been written to disk. The part that is not + // done is to replace them with the original pages or catalog and metadata files. If the + // system crashes before this point, the WAL can still be used to recover the system to a + // state where the checkpoint can be redone. + wal->logAndFlushCheckpoint(&clientContext); + shadowFile.applyShadowPages(clientContext); + // Clear the wal and also shadowing files. 
+ auto bufferManager = MemoryManager::Get(clientContext)->getBufferManager(); + wal->clear(); + shadowFile.clear(*bufferManager); +} + +void Checkpointer::rollback() { + if (isInMemory) { + return; + } + const auto storageManager = StorageManager::Get(clientContext); + auto catalog = catalog::Catalog::Get(clientContext); + // Any pages freed during the checkpoint are no longer freed + storageManager->rollbackCheckpoint(*catalog); +} + +bool Checkpointer::canAutoCheckpoint(const main::ClientContext& clientContext, + const transaction::Transaction& transaction) { + if (clientContext.isInMemory()) { + return false; + } + if (!clientContext.getDBConfig()->autoCheckpoint) { + return false; + } + if (transaction.isRecovery()) { + // Recovery transactions are not allowed to trigger auto checkpoint. + return false; + } + auto wal = WAL::Get(clientContext); + const auto expectedSize = transaction.getLocalWAL().getSize() + wal->getFileSize(); + return expectedSize > clientContext.getDBConfig()->checkpointThreshold; +} + +void Checkpointer::readCheckpoint() { + auto storageManager = StorageManager::Get(clientContext); + storageManager->initDataFileHandle(common::VirtualFileSystem::GetUnsafe(clientContext), + &clientContext); + if (!isInMemory && storageManager->getDataFH()->getNumPages() > 0) { + readCheckpoint(&clientContext, catalog::Catalog::Get(clientContext), storageManager); + } + extension::ExtensionManager::Get(clientContext)->autoLoadLinkedExtensions(&clientContext); +} + +void Checkpointer::readCheckpoint(main::ClientContext* context, catalog::Catalog* catalog, + StorageManager* storageManager) { + auto fileInfo = storageManager->getDataFH()->getFileInfo(); + auto reader = std::make_unique(*fileInfo); + common::Deserializer deSer(std::move(reader)); + auto currentHeader = std::make_unique(DatabaseHeader::deserialize(deSer)); + // If the catalog page range is invalid, it means there is no catalog to read; thus, the + // database is empty. 
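The deserialization below locates the catalog and metadata purely by page index: a page range's byte offset in the data file is startPageIdx * LBUG_PAGE_SIZE, mirroring how the checkpoint writer laid them out. A trivial worked sketch, assuming a 4 KiB page purely for illustration (the real value is the common::LBUG_PAGE_SIZE constant defined elsewhere):

#include <cstdint>

// Assumed page size for the example only.
constexpr uint64_t kExamplePageSize = 4096;

constexpr uint64_t pageRangeFileOffset(uint64_t startPageIdx) {
    return startPageIdx * kExamplePageSize;
}
// e.g. a catalog page range starting at page 3 is read from byte offset 12288.
static_assert(pageRangeFileOffset(3) == 12288, "3 * 4096");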
+ if (currentHeader->catalogPageRange.startPageIdx != common::INVALID_PAGE_IDX) { + deSer.getReader()->cast()->resetReadOffset( + currentHeader->catalogPageRange.startPageIdx * common::LBUG_PAGE_SIZE); + catalog->deserialize(deSer); + deSer.getReader()->cast()->resetReadOffset( + currentHeader->metadataPageRange.startPageIdx * common::LBUG_PAGE_SIZE); + storageManager->deserialize(context, catalog, deSer); + storageManager->getDataFH()->getPageManager()->deserialize(deSer); + } + storageManager->setDatabaseHeader(std::move(currentHeader)); +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/compression/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/compression/CMakeLists.txt new file mode 100644 index 0000000000..86c36f06ce --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/compression/CMakeLists.txt @@ -0,0 +1,12 @@ +add_library(lbug_storage_compression + OBJECT + compression.cpp + float_compression.cpp + bitpacking_int128.cpp + bitpacking_utils.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) + +target_link_libraries(lbug_storage_compression PRIVATE fastpfor libalp) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/compression/bitpacking_int128.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/compression/bitpacking_int128.cpp new file mode 100644 index 0000000000..f630c4c0d4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/compression/bitpacking_int128.cpp @@ -0,0 +1,171 @@ +// Adapted from +// https://github.com/duckdb/duckdb/blob/main/src/storage/compression/bitpacking_hugeint.cpp + +#include "storage/compression/bitpacking_int128.h" + +#include "storage/compression/bitpacking_utils.h" +#include "storage/compression/compression.h" + +namespace lbug::storage { + +//===--------------------------------------------------------------------===// +// Unpacking +//===--------------------------------------------------------------------===// + +static void unpackLast(const uint32_t* __restrict in, common::int128_t* __restrict out, + uint16_t delta) { + const uint8_t LAST_IDX = 31; + const uint16_t SHIFT = (delta * 31) % 32; + out[LAST_IDX] = in[0] >> SHIFT; + if (delta > 32) { + out[LAST_IDX] |= static_cast(in[1]) << (32 - SHIFT); + } + if (delta > 64) { + out[LAST_IDX] |= static_cast(in[2]) << (64 - SHIFT); + } + if (delta > 96) { + out[LAST_IDX] |= static_cast(in[3]) << (96 - SHIFT); + } +} + +// Unpacks for specific deltas +static void unpackDelta0(common::int128_t* __restrict out) { + for (uint8_t i = 0; i < 32; ++i) { + out[i] = 0; + } +} + +static void unpackDelta32(const uint32_t* __restrict in, common::int128_t* __restrict out) { + for (uint8_t k = 0; k < 32; ++k) { + out[k] = static_cast(in[k]); + } +} + +static void unpackDelta64(const uint32_t* __restrict in, common::int128_t* __restrict out) { + for (uint8_t i = 0; i < 32; ++i) { + const uint8_t OFFSET = i * 2; + out[i] = in[OFFSET]; + out[i] |= static_cast(in[OFFSET + 1]) << 32; + } +} + +static void unpackDelta96(const uint32_t* __restrict in, common::int128_t* __restrict out) { + for (uint8_t i = 0; i < 32; ++i) { + const uint8_t OFFSET = i * 3; + out[i] = in[OFFSET]; + out[i] |= static_cast(in[OFFSET + 1]) << 32; + out[i] |= static_cast(in[OFFSET + 2]) << 64; + } +} + +static void unpackDelta128(const uint32_t* __restrict in, common::int128_t* __restrict out) { + for (uint8_t i = 0; i < 32; ++i) { + const uint8_t OFFSET = i * 4; + out[i] = in[OFFSET]; + out[i] |= static_cast(in[OFFSET + 1]) << 32; + out[i] 
|= static_cast(in[OFFSET + 2]) << 64; + out[i] |= static_cast(in[OFFSET + 3]) << 96; + } +} + +//===--------------------------------------------------------------------===// +// Packing +//===--------------------------------------------------------------------===// + +// Packs for specific deltas +static void packDelta32(const common::int128_t* __restrict in, uint32_t* __restrict out) { + for (uint8_t i = 0; i < 32; ++i) { + out[i] = static_cast(in[i]); + } +} + +static void packDelta64(const common::int128_t* __restrict in, uint32_t* __restrict out) { + for (uint8_t i = 0; i < 32; ++i) { + const uint8_t OFFSET = 2 * i; + out[OFFSET] = static_cast(in[i]); + out[OFFSET + 1] = static_cast(in[i] >> 32); + } +} + +static void packDelta96(const common::int128_t* __restrict in, uint32_t* __restrict out) { + for (uint8_t i = 0; i < 32; ++i) { + const uint8_t OFFSET = 3 * i; + out[OFFSET] = static_cast(in[i]); + out[OFFSET + 1] = static_cast(in[i] >> 32); + out[OFFSET + 2] = static_cast(in[i] >> 64); + } +} + +static void packDelta128(const common::int128_t* __restrict in, uint32_t* __restrict out) { + for (uint8_t i = 0; i < 32; ++i) { + const uint8_t OFFSET = 4 * i; + out[OFFSET] = static_cast(in[i]); + out[OFFSET + 1] = static_cast(in[i] >> 32); + out[OFFSET + 2] = static_cast(in[i] >> 64); + out[OFFSET + 3] = static_cast(in[i] >> 96); + } +} + +//===--------------------------------------------------------------------===// +// HugeIntPacker +//===--------------------------------------------------------------------===// + +void Int128Packer::pack(const common::int128_t* __restrict in, uint32_t* __restrict out, + uint8_t width) { + KU_ASSERT(width <= 128); + switch (width) { + case 0: + break; + case 32: + packDelta32(in, out); + break; + case 64: + packDelta64(in, out); + break; + case 96: + packDelta96(in, out); + break; + case 128: + packDelta128(in, out); + break; + default: + for (common::idx_t oindex = 0; oindex < IntegerBitpacking::CHUNK_SIZE; + ++oindex) { + BitpackingUtils::packSingle(in[oindex], + reinterpret_cast(out), width, oindex); + } + } +} + +void Int128Packer::unpack(const uint32_t* __restrict in, common::int128_t* __restrict out, + uint8_t width) { + KU_ASSERT(width <= 128); + switch (width) { + case 0: + unpackDelta0(out); + break; + case 32: + unpackDelta32(in, out); + break; + case 64: + unpackDelta64(in, out); + break; + case 96: + unpackDelta96(in, out); + break; + case 128: + unpackDelta128(in, out); + break; + default: + for (common::idx_t oindex = 0; oindex < IntegerBitpacking::CHUNK_SIZE - 1; + ++oindex) { + BitpackingUtils::unpackSingle(reinterpret_cast(in), + out + oindex, width, oindex); + } + unpackLast(in + +(IntegerBitpacking::CHUNK_SIZE - 1) * width / + (sizeof(uint32_t) * 8), + out, width); + } +} + +} // namespace lbug::storage diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/compression/bitpacking_utils.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/compression/bitpacking_utils.cpp new file mode 100644 index 0000000000..6c4177952a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/compression/bitpacking_utils.cpp @@ -0,0 +1,140 @@ +// packSingle/unpackSingle are adapted from +// https://github.com/duckdb/duckdb/blob/main/src/storage/compression/bitpacking_hugeint.cpp + +#include "storage/compression/bitpacking_utils.h" + +#include "common/utils.h" + +namespace lbug::storage { +namespace { +template +void unpackSingleField(const CompressedType* __restrict in, UncompressedType* __restrict out, + uint16_t delta, uint16_t shiftRight) { + 
static constexpr size_t compressedFieldSizeBits = sizeof(CompressedType) * 8; + + if constexpr (compressed_field == 0) { + *out = static_cast(in[0]) >> shiftRight; + } else { + unpackSingleField(in, out, delta, shiftRight); + KU_ASSERT( + sizeof(UncompressedType) * 8 > compressed_field * compressedFieldSizeBits - shiftRight); + *out |= static_cast(in[compressed_field]) + << (compressed_field * compressedFieldSizeBits - shiftRight); + } +} + +template +void unpackSingleValueInPlace(const CompressedType* __restrict in, UncompressedType* __restrict out, + uint16_t delta, uint16_t shiftRight) { + static_assert(sizeof(UncompressedType) <= 4 * sizeof(CompressedType)); + + static constexpr size_t compressedFieldSizeBits = sizeof(CompressedType) * 8; + + if (delta + shiftRight <= compressedFieldSizeBits) { + unpackSingleField<0>(in, out, delta, shiftRight); + } else if (delta + shiftRight > compressedFieldSizeBits && + delta + shiftRight <= 2 * compressedFieldSizeBits) { + unpackSingleField<1>(in, out, delta, shiftRight); + } else if (delta + shiftRight > 2 * compressedFieldSizeBits && + delta + shiftRight <= 3 * compressedFieldSizeBits) { + unpackSingleField<2>(in, out, delta, shiftRight); + } else if (delta + shiftRight > 3 * compressedFieldSizeBits && + delta + shiftRight <= 4 * compressedFieldSizeBits) { + unpackSingleField<3>(in, out, delta, shiftRight); + } else if (delta + shiftRight > 4 * compressedFieldSizeBits) { + unpackSingleField<4>(in, out, delta, shiftRight); + } + + // we previously copy over the entire most significant field + // zero out the bits that are not actually part of the compressed value + *out &= common::BitmaskUtils::all1sMaskForLeastSignificantBits(delta); +} + +template +void setValueForBitsMatchingMask(CompressedType& out, UncompressedType unshiftedValue, + UncompressedType unshiftedMask, size_t shift) { + CompressedType valueToSet = 0; + CompressedType mask = 0; + if constexpr (shiftRight) { + valueToSet = static_cast((unshiftedValue & unshiftedMask) >> shift); + mask = static_cast(unshiftedMask >> shift); + } else { + valueToSet = static_cast((unshiftedValue & unshiftedMask) << shift); + mask = static_cast(unshiftedMask << shift); + } + const CompressedType bitsToSet = valueToSet & mask; + const CompressedType bitsToClear = ~mask | valueToSet; + out = (out | bitsToSet) & bitsToClear; +} + +template +void packSingleField(const UncompressedType in, CompressedType* __restrict out, uint16_t delta, + uint16_t shiftLeft, UncompressedType mask) { + static constexpr size_t compressedFieldSizeBits = sizeof(CompressedType) * 8; + + if constexpr (compressed_field == 0) { + setValueForBitsMatchingMask(out[0], in, mask, shiftLeft); + } else { + packSingleField(in, out, delta, shiftLeft, mask); + KU_ASSERT( + sizeof(UncompressedType) * 8 > compressed_field * compressedFieldSizeBits - shiftLeft); + + setValueForBitsMatchingMask(out[compressed_field], in, mask, + (compressed_field * compressedFieldSizeBits - shiftLeft)); + } +} + +template +void packSingleImpl(const UncompressedType in, CompressedType* __restrict out, uint16_t delta, + uint16_t shiftLeft, UncompressedType mask) { + static_assert(sizeof(UncompressedType) <= 4 * sizeof(CompressedType)); + + static constexpr size_t compressedFieldSizeBits = sizeof(CompressedType) * 8; + + if (delta + shiftLeft <= compressedFieldSizeBits) { + packSingleField<0>(in, out, delta, shiftLeft, mask); + } else if (delta + shiftLeft > compressedFieldSizeBits && + delta + shiftLeft <= 2 * compressedFieldSizeBits) { + packSingleField<1>(in, 
out, delta, shiftLeft, mask); + } else if (delta + shiftLeft > 2 * compressedFieldSizeBits && + delta + shiftLeft <= 3 * compressedFieldSizeBits) { + packSingleField<2>(in, out, delta, shiftLeft, mask); + } else if (delta + shiftLeft > 3 * compressedFieldSizeBits && + delta + shiftLeft <= 4 * compressedFieldSizeBits) { + packSingleField<3>(in, out, delta, shiftLeft, mask); + } else if (delta + shiftLeft > 4 * compressedFieldSizeBits) { + packSingleField<4>(in, out, delta, shiftLeft, mask); + } +} +} // namespace + +template +void BitpackingUtils::unpackSingle(const uint8_t* __restrict srcCursor, + UncompressedType* __restrict dst, uint16_t bitWidth, size_t srcOffset) { + const size_t srcBufferOffset = srcOffset * bitWidth / sizeOfCompressedTypeBits; + const size_t shiftRight = srcOffset * bitWidth % sizeOfCompressedTypeBits; + + const auto* castedSrcCursor = + reinterpret_cast(srcCursor) + srcBufferOffset; + unpackSingleValueInPlace(castedSrcCursor, dst, bitWidth, shiftRight); +} + +template +void BitpackingUtils::packSingle(const UncompressedType src, + uint8_t* __restrict dstBuffer, uint16_t bitWidth, size_t dstOffset) { + const size_t dstBufferOffset = dstOffset * bitWidth / sizeOfCompressedTypeBits; + const size_t shiftLeft = dstOffset * bitWidth % sizeOfCompressedTypeBits; + + auto* castedDstBuffer = reinterpret_cast(dstBuffer) + dstBufferOffset; + packSingleImpl(src, castedDstBuffer, bitWidth, shiftLeft, + common::BitmaskUtils::all1sMaskForLeastSignificantBits(bitWidth)); +} + +template struct BitpackingUtils; +template struct BitpackingUtils; +template struct BitpackingUtils; +template struct BitpackingUtils; +template struct BitpackingUtils; +} // namespace lbug::storage diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/compression/compression.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/compression/compression.cpp new file mode 100644 index 0000000000..ffaf2ff363 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/compression/compression.cpp @@ -0,0 +1,1251 @@ +#include "storage/compression/compression.h" + +#include +#include +#include +#include + +#include "common/assert.h" +#include "common/exception/not_implemented.h" +#include "common/exception/storage.h" +#include "common/null_mask.h" +#include "common/serializer/deserializer.h" +#include "common/serializer/serializer.h" +#include "common/type_utils.h" +#include "common/types/ku_string.h" +#include "common/types/types.h" +#include "common/vector/value_vector.h" +#include "fastpfor/bitpackinghelpers.h" +#include "storage/compression/bitpacking_int128.h" +#include "storage/compression/bitpacking_utils.h" +#include "storage/compression/float_compression.h" +#include "storage/compression/sign_extend.h" +#include "storage/storage_utils.h" +#include "storage/table/column_chunk_data.h" +#include + +using namespace lbug::common; + +namespace lbug { +namespace storage { + +template +auto getTypedMinMax(std::span data, const NullMask* nullMask, uint64_t nullMaskOffset) { + std::optional min, max; + KU_ASSERT(data.size() > 0); + if (!nullMask || nullMask->hasNoNullsGuarantee()) { + auto [minRaw, maxRaw] = std::minmax_element(data.begin(), data.end()); + min = StorageValue(*minRaw); + max = StorageValue(*maxRaw); + } else { + for (uint64_t i = 0; i < data.size(); i++) { + if (!nullMask->isNull(nullMaskOffset + i)) { + if (!min || data[i] < min->get()) { + min = StorageValue(data[i]); + } + if (!max || data[i] > max->get()) { + max = StorageValue(data[i]); + } + } + } + } + return std::make_pair(min, max); +} + 
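Stepping back to the BitpackingUtils helpers in bitpacking_utils.cpp above: packSingle/unpackSingle address a single value by converting its element index into a compressed-word index plus a bit shift, and the value may then straddle up to four consecutive words. A small sketch of that index arithmetic, assuming 32-bit compressed words as in the Int128Packer path (the helper below is illustrative, not part of the library):

#include <cstddef>
#include <cstdint>

struct PackedPosition {
    std::size_t wordIndex; // which 32-bit compressed word the value starts in
    std::size_t bitShift;  // bit offset of the value inside that word
};

// Mirrors the dstBufferOffset/shiftLeft computation in BitpackingUtils::packSingle above.
static PackedPosition locatePackedValue(std::size_t valueIndex, uint16_t bitWidth) {
    constexpr std::size_t wordBits = sizeof(uint32_t) * 8;
    return {valueIndex * bitWidth / wordBits, valueIndex * bitWidth % wordBits};
}
// e.g. bitWidth = 5, valueIndex = 7: the value starts at absolute bit 35,
// i.e. word 1 with a shift of 3, and fits entirely inside that word.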
+uint32_t getDataTypeSizeInChunk(const common::LogicalType& dataType) { + return getDataTypeSizeInChunk(dataType.getPhysicalType()); +} + +uint32_t getDataTypeSizeInChunk(const common::PhysicalTypeID& dataType) { + switch (dataType) { + case PhysicalTypeID::STRING: + case PhysicalTypeID::ARRAY: + case PhysicalTypeID::LIST: + case PhysicalTypeID::STRUCT: { + return 0; + } + case PhysicalTypeID::INTERNAL_ID: { + return sizeof(offset_t); + } + default: { + auto size = PhysicalTypeUtils::getFixedTypeSize(dataType); + KU_ASSERT(size <= LBUG_PAGE_SIZE); + return size; + } + } +} + +ALPMetadata::ALPMetadata(const alp::state& alpState, common::PhysicalTypeID physicalType) + : exp(alpState.exp), fac(alpState.fac), exceptionCount(alpState.exceptions_count) { + const size_t physicalTypeSize = PhysicalTypeUtils::getFixedTypeSize(physicalType); + + // to get the exception capacity we find the number of bytes needed to store the current + // exception count, take the smallest power of 2 greater than or equal to that value + // or the size of one page (whichever is larger) + // then find how many exceptions fit in that size + exceptionCapacity = + static_cast(std::bit_ceil(alpState.exceptions_count * physicalTypeSize)) / + physicalTypeSize; +} + +void ALPMetadata::serialize(common::Serializer& serializer) const { + serializer.write(exp); + serializer.write(fac); + serializer.write(exceptionCount); + serializer.write(exceptionCapacity); +} + +ALPMetadata ALPMetadata::deserialize(common::Deserializer& deserializer) { + ALPMetadata ret; + deserializer.deserializeValue(ret.exp); + deserializer.deserializeValue(ret.fac); + deserializer.deserializeValue(ret.exceptionCount); + deserializer.deserializeValue(ret.exceptionCapacity); + return ret; +} + +std::unique_ptr ALPMetadata::copy() { + return std::make_unique(*this); +} + +CompressionMetadata::CompressionMetadata(StorageValue min, StorageValue max, + CompressionType compression, const alp::state& state, StorageValue minEncoded, + StorageValue maxEncoded, common::PhysicalTypeID physicalType) + : min(min), max(max), compression(compression), + extraMetadata(std::make_unique(state, physicalType)) { + if (compression == CompressionType::ALP) { + children.emplace_back(minEncoded, maxEncoded, + minEncoded == maxEncoded ? 
CompressionType::CONSTANT : + CompressionType::INTEGER_BITPACKING); + } +} + +const CompressionMetadata& CompressionMetadata::getChild(offset_t idx) const { + KU_ASSERT(idx < getChildCount(compression)); + return children[idx]; +} + +CompressionMetadata::CompressionMetadata(const CompressionMetadata& o) + : min{o.min}, max{o.max}, compression{o.compression}, children{o.children} { + if (o.extraMetadata.has_value()) { + this->extraMetadata = o.extraMetadata.value()->copy(); + } +} + +CompressionMetadata& CompressionMetadata::operator=(const CompressionMetadata& o) { + if (this != &o) { + min = o.min; + max = o.max; + compression = o.compression; + if (o.extraMetadata.has_value()) { + extraMetadata = o.extraMetadata.value()->copy(); + } else { + extraMetadata = {}; + } + children = o.children; + } + return *this; +} + +void CompressionMetadata::serialize(Serializer& serializer) const { + serializer.write(min); + serializer.write(max); + serializer.write(compression); + + if (compression == CompressionType::ALP) { + floatMetadata()->serialize(serializer); + } + + KU_ASSERT(children.size() == getChildCount(compression)); + for (size_t i = 0; i < children.size(); ++i) { + children[i].serialize(serializer); + } +} + +CompressionMetadata CompressionMetadata::deserialize(common::Deserializer& deserializer) { + StorageValue min{}; + StorageValue max{}; + CompressionType compressionType{}; + deserializer.deserializeValue(min); + deserializer.deserializeValue(max); + deserializer.deserializeValue(compressionType); + + CompressionMetadata ret(min, max, compressionType); + + if (compressionType == CompressionType::ALP) { + auto alpMetadata = std::make_unique(ALPMetadata::deserialize(deserializer)); + ret.extraMetadata = std::move(alpMetadata); + } + + for (size_t i = 0; i < getChildCount(compressionType); ++i) { + ret.children.push_back(deserialize(deserializer)); + } + + return ret; +} + +bool CompressionMetadata::canAlwaysUpdateInPlace() const { + switch (compression) { + case CompressionType::BOOLEAN_BITPACKING: + case CompressionType::UNCOMPRESSED: { + return true; + } + case CompressionType::CONSTANT: + case CompressionType::ALP: + case CompressionType::INTEGER_BITPACKING: { + return false; + } + default: { + throw common::StorageException( + "Unknown compression type with ID " + std::to_string((uint8_t)compression)); + } + } +} + +bool CompressionMetadata::canUpdateInPlace(const uint8_t* data, uint32_t pos, uint64_t numValues, + PhysicalTypeID physicalType, InPlaceUpdateLocalState& localUpdateState, + const std::optional& nullMask) const { + if (canAlwaysUpdateInPlace()) { + return true; + } + switch (compression) { + case CompressionType::CONSTANT: { + // Value can be updated in place only if it is identical to the value already stored. 
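            // In other words, the check below degenerates to: every non-null value in
            // [pos, pos + numValues) must equal the stored constant. For BOOL the values
            // live in a packed bitmask and are compared bit by bit; for the other
            // fixed-size types a memcmp against the constant kept in metadata.min is enough.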
+ switch (physicalType) { + case PhysicalTypeID::BOOL: { + for (uint64_t i = pos; i < pos + numValues; i++) { + if (nullMask && nullMask->isNull(i)) { + continue; + } + if (NullMask::isNull(reinterpret_cast(data), i) != + static_cast(min.unsignedInt)) { + return false; + } + } + return true; + } + default: { + for (uint64_t i = pos; i < pos + numValues; i++) { + if (nullMask && nullMask->isNull(i)) { + continue; + } + auto size = getDataTypeSizeInChunk(physicalType); + if (memcmp(data + i * size, &min.unsignedInt, size) != 0) { + return false; + } + } + return true; + } + } + } + case CompressionType::BOOLEAN_BITPACKING: + case CompressionType::UNCOMPRESSED: { + return true; + } + case CompressionType::ALP: { + return TypeUtils::visit( + physicalType, + [&](T) { + auto values = std::span(reinterpret_cast(data) + pos, numValues); + return FloatCompression::canUpdateInPlace(values, *this, localUpdateState, + std::move(nullMask), pos); + }, + [&](auto) { + throw common::StorageException("Attempted to read from a column chunk which " + "uses float compression but does " + "not have a supported physical type: " + + PhysicalTypeUtils::toString(physicalType)); + return false; + }); + } + case CompressionType::INTEGER_BITPACKING: { + auto cdata = const_cast(data); + return TypeUtils::visit( + physicalType, + [&](T) { + auto values = std::span(reinterpret_cast(cdata) + pos, numValues); + return IntegerBitpacking::canUpdateInPlace(values, *this, std::move(nullMask), + pos); + }, + [&](internalID_t) { + auto values = + std::span(reinterpret_cast(cdata) + pos, numValues); + return IntegerBitpacking::canUpdateInPlace(values, *this, + std::move(nullMask), pos); + }, + [&](auto) { + throw common::StorageException("Attempted to read from a column chunk which " + "uses integer bitpacking but does " + "not have a supported integer physical type: " + + PhysicalTypeUtils::toString(physicalType)); + return false; + }); + } + default: { + throw common::StorageException( + "Unknown compression type with ID " + std::to_string((uint8_t)compression)); + } + } +} + +uint64_t CompressionMetadata::numValues(uint64_t pageSize, const LogicalType& dataType) const { + return numValues(pageSize, dataType.getPhysicalType()); +} + +uint64_t CompressionMetadata::numValues(uint64_t pageSize, common::PhysicalTypeID dataType) const { + switch (compression) { + case CompressionType::CONSTANT: { + return std::numeric_limits::max(); + } + case CompressionType::UNCOMPRESSED: { + return Uncompressed::numValues(pageSize, dataType); + } + case CompressionType::INTEGER_BITPACKING: { + switch (dataType) { + case PhysicalTypeID::INT128: + return IntegerBitpacking::numValues(pageSize, *this); + case PhysicalTypeID::INT64: + return IntegerBitpacking::numValues(pageSize, *this); + case PhysicalTypeID::INT32: + return IntegerBitpacking::numValues(pageSize, *this); + case PhysicalTypeID::INT16: + return IntegerBitpacking::numValues(pageSize, *this); + case PhysicalTypeID::INT8: + return IntegerBitpacking::numValues(pageSize, *this); + case PhysicalTypeID::INTERNAL_ID: + case PhysicalTypeID::UINT64: + return IntegerBitpacking::numValues(pageSize, *this); + case PhysicalTypeID::UINT32: + return IntegerBitpacking::numValues(pageSize, *this); + case PhysicalTypeID::UINT16: + return IntegerBitpacking::numValues(pageSize, *this); + case PhysicalTypeID::UINT8: + return IntegerBitpacking::numValues(pageSize, *this); + default: { + throw common::StorageException( + "Attempted to read from a column chunk which uses integer bitpacking but does " + "not " 
+ "have a supported integer physical type: " + + PhysicalTypeUtils::toString(dataType)); + } + } + } + case CompressionType::ALP: { + switch (dataType) { + case PhysicalTypeID::DOUBLE: { + return FloatCompression::numValues(pageSize, *this); + } + case PhysicalTypeID::FLOAT: { + return FloatCompression::numValues(pageSize, *this); + } + default: { + throw common::StorageException( + "Attempted to read from a column chunk which uses float compression but does " + "not " + "have a supported physical type: " + + PhysicalTypeUtils::toString(dataType)); + } + } + } + case CompressionType::BOOLEAN_BITPACKING: { + return BooleanBitpacking::numValues(pageSize); + } + default: { + throw common::StorageException( + "Unknown compression type with ID " + std::to_string((uint8_t)compression)); + } + } +} + +size_t CompressionMetadata::getChildCount(CompressionType compressionType) { + switch (compressionType) { + case CompressionType::ALP: { + return 1; + } + default: { + return 0; + } + } +} + +std::optional ConstantCompression::analyze(const ColumnChunkData& chunk) { + switch (chunk.getDataType().getPhysicalType()) { + // Only values that can fit in the CompressionMetadata's data field can use constant + // compression + case PhysicalTypeID::BOOL: { + if (chunk.getCapacity() == 0) { + return std::optional( + CompressionMetadata(StorageValue(0), StorageValue(0), CompressionType::CONSTANT)); + } + auto firstValue = chunk.getValue(0); + + // TODO(bmwinger): This could be optimized. We could do bytewise comparison with memcmp, + // but we need to make sure to stop at the end of the values to avoid false positives + for (auto i = 1u; i < chunk.getNumValues(); i++) { + // If any value is different from the first one, we can't use constant compression + if (firstValue != chunk.getValue(i)) { + return std::nullopt; + } + } + auto value = StorageValue(firstValue); + return std::optional(CompressionMetadata(value, value, CompressionType::CONSTANT)); + } + case PhysicalTypeID::INTERNAL_ID: + case PhysicalTypeID::DOUBLE: + case PhysicalTypeID::FLOAT: + case PhysicalTypeID::UINT8: + case PhysicalTypeID::UINT16: + case PhysicalTypeID::UINT32: + case PhysicalTypeID::UINT64: + case PhysicalTypeID::INT8: + case PhysicalTypeID::INT16: + case PhysicalTypeID::INT32: + case PhysicalTypeID::INT64: + case PhysicalTypeID::INT128: { + uint8_t size = chunk.getNumBytesPerValue(); + StorageValue value{}; + KU_ASSERT(size <= sizeof(value.unsignedInt)); + // If there are no values, or only one value, we will always use constant compression + // since the loop won't execute + for (auto i = 1u; i < chunk.getNumValues(); i++) { + // If any value is different from the first one, we can't use constant compression + if (std::memcmp(chunk.getData(), chunk.getData() + i * size, size) != 0) { + return std::nullopt; + } + } + if (chunk.getNumValues() > 0) { + std::memcpy(&value.unsignedInt, chunk.getData(), size); + } + return std::optional(CompressionMetadata(value, value, CompressionType::CONSTANT)); + } + default: { + return std::optional(); + } + } +} + +uint64_t Uncompressed::numValues(uint64_t dataSize, common::PhysicalTypeID physicalType) { + uint32_t numBytesPerValue = getDataTypeSizeInChunk(physicalType); + return numBytesPerValue == 0 ? 
UINT64_MAX : dataSize / numBytesPerValue; +} + +uint64_t Uncompressed::numValues(uint64_t dataSize, const common::LogicalType& logicalType) { + return numValues(dataSize, logicalType.getPhysicalType()); +} + +std::string CompressionMetadata::toString(const PhysicalTypeID physicalType) const { + switch (compression) { + case CompressionType::UNCOMPRESSED: { + return "UNCOMPRESSED"; + } + case CompressionType::ALP: { + uint8_t bitWidth = TypeUtils::visit( + physicalType, + [&](T) { + static constexpr common::idx_t BITPACKING_CHILD_IDX = 0; + return IntegerBitpacking::EncodedType>::getPackingInfo( + getChild(BITPACKING_CHILD_IDX)) + .bitWidth; + }, + [](auto) -> uint8_t { KU_UNREACHABLE; }); + return stringFormat("FLOAT_COMPRESSION[{}], {} Exceptions", bitWidth, + floatMetadata()->exceptionCount); + } + case CompressionType::INTEGER_BITPACKING: { + uint8_t bitWidth = TypeUtils::visit( + physicalType, + [&](common::internalID_t) { + return IntegerBitpacking::getPackingInfo(*this).bitWidth; + }, + [](bool) -> uint8_t { KU_UNREACHABLE; }, + [&]( + T) { return IntegerBitpacking::getPackingInfo(*this).bitWidth; }, + [](auto) -> uint8_t { KU_UNREACHABLE; }); + return stringFormat("INTEGER_BITPACKING[{}]", bitWidth); + } + case CompressionType::BOOLEAN_BITPACKING: { + return "BOOLEAN_BITPACKING"; + } + case CompressionType::CONSTANT: { + return "CONSTANT"; + } + default: { + KU_UNREACHABLE; + } + } +} + +void ConstantCompression::decompressValues(uint8_t* dstBuffer, uint64_t dstOffset, + uint64_t numValues, common::PhysicalTypeID physicalType, uint32_t numBytesPerValue, + const CompressionMetadata& metadata) { + auto start = dstBuffer + dstOffset * numBytesPerValue; + auto end = dstBuffer + (dstOffset + numValues) * numBytesPerValue; + + TypeUtils::visit( + physicalType, + [&](common::internalID_t) { + std::fill(reinterpret_cast(start), reinterpret_cast(end), + metadata.min.get()); + }, + [&] + requires(numeric_utils::IsIntegral || std::floating_point) + (T) { + std::fill(reinterpret_cast(start), reinterpret_cast(end), + metadata.min.get()); + }, + [&](auto) { + throw NotImplementedException("CONSTANT compression is not implemented for type " + + PhysicalTypeUtils::toString(physicalType)); + }); +} + +void ConstantCompression::decompressFromPage(const uint8_t* /*srcBuffer*/, uint64_t /*srcOffset*/, + uint8_t* dstBuffer, uint64_t dstOffset, uint64_t numValues, + const CompressionMetadata& metadata) const { + return decompressValues(dstBuffer, dstOffset, numValues, dataType, numBytesPerValue, metadata); +} + +void ConstantCompression::copyFromPage(const uint8_t* srcBuffer, uint64_t srcOffset, + uint8_t* dstBuffer, uint64_t dstOffset, uint64_t numValues, + const CompressionMetadata& metadata) const { + if (dataType == common::PhysicalTypeID::BOOL) { + common::NullMask::setNullRange(reinterpret_cast(dstBuffer), dstOffset, numValues, + metadata.min.unsignedInt); + } else { + decompressFromPage(srcBuffer, srcOffset, dstBuffer, dstOffset, numValues, metadata); + } +} + +template +inline T abs(T value); + +template + requires std::is_unsigned_v +inline T abs(T value) { + return value; +} + +template + requires std::is_signed_v +inline T abs(T value) { + return std::abs(value); +} + +template<> +inline int128_t abs(int128_t value) { + return value >= 0 ? 
value : -value; +} + +template +BitpackInfo IntegerBitpacking::getPackingInfo(const CompressionMetadata& metadata) { + auto max = metadata.max.get(); + auto min = metadata.min.get(); + bool hasNegative = false; + T offset = 0; + uint8_t bitWidth = 0; + // Frame of reference encoding is only used when values are either all positive or all + // negative, and when we will save at least 1 bit per value. when the chunk was first + // compressed + if (min > 0 && max > 0 && + numeric_utils::bitWidth((U)(max - min)) < numeric_utils::bitWidth((U)max)) { + offset = min; + bitWidth = static_cast(numeric_utils::bitWidth((U)(max - min))); + hasNegative = false; + } else if (min < 0 && max < 0 && + numeric_utils::bitWidth((U)(min - max)) < numeric_utils::bitWidth((U)max)) { + offset = (U)max; + bitWidth = static_cast(numeric_utils::bitWidth((U)(min - max))) + 1; + // This is somewhat suboptimal since we know that the values are all negative + // We could use an offset equal to the minimum, but values which are all negative are + // probably going to grow in the negative direction, leading to many re-compressions + // when inserting + hasNegative = true; + } else if (min < 0) { + bitWidth = + static_cast(numeric_utils::bitWidth((U)std::max(abs(min), abs(max)))) + + 1; + hasNegative = true; + } else { + bitWidth = + static_cast(numeric_utils::bitWidth((U)std::max(abs(min), abs(max)))); + hasNegative = false; + } + return BitpackInfo{bitWidth, hasNegative, offset}; +} + +template +bool IntegerBitpacking::canUpdateInPlace(std::span values, + const CompressionMetadata& metadata, const std::optional& nullMask, + uint64_t nullMaskOffset) { + auto info = getPackingInfo(metadata); + auto [min, max] = getTypedMinMax(values, nullMask ? &*nullMask : nullptr, nullMaskOffset); + KU_ASSERT((min && max) || (!min && !max)); + // If all values are null update can trivially be done in-place + if (!min) { + return true; + } + auto newMetadata = + CompressionMetadata(StorageValue(std::min(metadata.min.get(), min->template get())), + StorageValue(std::max(metadata.max.get(), max->template get())), + metadata.compression); + auto newInfo = getPackingInfo(newMetadata); + + if (info.bitWidth != newInfo.bitWidth || info.hasNegative != newInfo.hasNegative || + info.offset != newInfo.offset) { + return false; + } + return true; +} + +template +void fastunpack(const uint8_t* in, T* out, uint32_t bitWidth) { + if constexpr (std::is_same_v, int32_t> || + std::is_same_v, int64_t>) { + FastPForLib::fastunpack((const uint32_t*)in, out, bitWidth); + } else if constexpr (std::is_same_v, int16_t>) { + FastPForLib::fastunpack((const uint16_t*)in, out, bitWidth); + } else if constexpr (std::is_same_v, int8_t>) { + FastPForLib::fastunpack((const uint8_t*)in, out, bitWidth); + } else { + static_assert(std::is_same_v, int128_t>); + Int128Packer::unpack(reinterpret_cast(in), out, bitWidth); + } +} + +template +void fastpack(const T* in, uint8_t* out, uint8_t bitWidth) { + if constexpr (std::is_same_v, int32_t> || + std::is_same_v, int64_t>) { + FastPForLib::fastpack(in, (uint32_t*)out, bitWidth); + } else if constexpr (std::is_same_v, int16_t>) { + FastPForLib::fastpack(in, (uint16_t*)out, bitWidth); + } else if constexpr (std::is_same_v, int8_t>) { + FastPForLib::fastpack(in, (uint8_t*)out, bitWidth); + } else { + static_assert(std::is_same_v, int128_t>); + Int128Packer::pack(in, reinterpret_cast(out), bitWidth); + } +} + +template +void IntegerBitpacking::setPartialChunkInPlace(const uint8_t* srcBuffer, offset_t posInSrc, + uint8_t* dstBuffer, 
offset_t posInDst, offset_t numValues, const BitpackInfo& header) const { + U tmpChunk[CHUNK_SIZE]; + copyValuesToTempChunkWithOffset(reinterpret_cast(srcBuffer) + posInSrc, tmpChunk, + header, numValues); + packPartialChunk(tmpChunk, dstBuffer, posInDst, header, numValues); +} + +template +void IntegerBitpacking::setValuesFromUncompressed(const uint8_t* srcBuffer, offset_t posInSrc, + uint8_t* dstBuffer, offset_t posInDst, offset_t numValues, const CompressionMetadata& metadata, + const NullMask* nullMask) const { + KU_UNUSED(nullMask); + + auto header = getPackingInfo(metadata); + + // Null values will usually be 0, which will not be able to be stored if there is a + // non-zero offset However we don't care about the value stored for null values + // Currently they will be mangled by storage+recovery (underflow in the subtraction + // below) + KU_ASSERT(numValues == static_cast(std::ranges::count_if( + std::ranges::iota_view{posInSrc, posInSrc + numValues}, + [srcBuffer, &metadata, nullMask](offset_t i) { + auto value = reinterpret_cast(srcBuffer)[i]; + return (nullMask && nullMask->isNull(i)) || + canUpdateInPlace(std::span(&value, 1), metadata); + }))); + + // Data can be considered to be stored in aligned chunks of 32 values + // with a size of 32 * bitWidth bits, + // or bitWidth 32-bit values (we cast the buffer to a uint32_t* later). + + // update unaligned values in the first chunk + auto valuesInFirstChunk = std::min(CHUNK_SIZE - (posInDst % CHUNK_SIZE), numValues); + offset_t dstIndex = posInDst; + if (valuesInFirstChunk < CHUNK_SIZE) { + // update unaligned values in the last chunk + setPartialChunkInPlace(srcBuffer, posInSrc, dstBuffer, posInDst, valuesInFirstChunk, + header); + dstIndex += valuesInFirstChunk; + } + + // update chunk-aligned values using fastpack/unpack + for (; dstIndex + CHUNK_SIZE <= posInDst + numValues; dstIndex += CHUNK_SIZE) { + U chunk[CHUNK_SIZE]; + + const size_t chunkIndexOffsetInSrc = posInSrc + dstIndex - posInDst; + copyValuesToTempChunkWithOffset(reinterpret_cast(srcBuffer) + + chunkIndexOffsetInSrc, + chunk, header, CHUNK_SIZE); + + const offset_t dstOffsetBytes = dstIndex * header.bitWidth / 8; + fastpack(chunk, dstBuffer + dstOffsetBytes, header.bitWidth); + } + + // update unaligned values in the last chunk + const auto lastChunkIndexOffset = dstIndex - posInDst; + const size_t unalignedValuesToPack = numValues - lastChunkIndexOffset; + if (unalignedValuesToPack > 0) { + setPartialChunkInPlace(srcBuffer, posInSrc + lastChunkIndexOffset, dstBuffer, + posInDst + lastChunkIndexOffset, unalignedValuesToPack, header); + } +} + +template +void IntegerBitpacking::getValues(const uint8_t* chunkStart, uint8_t pos, uint8_t* dst, + uint8_t numValuesToRead, const BitpackInfo& header) const { + const size_t maxReadIndex = pos + numValuesToRead; + KU_ASSERT(maxReadIndex <= CHUNK_SIZE); + + for (size_t i = pos; i < maxReadIndex; i++) { + // Always use unsigned version of unpacker to prevent sign-bit filling when right + // shifting + U& out = reinterpret_cast(dst)[i - pos]; + BitpackingUtils::unpackSingle(chunkStart, &out, header.bitWidth, i); + + if (header.hasNegative && header.bitWidth > 0) { + SignExtend((uint8_t*)&out, header.bitWidth); + } + + if (header.offset != 0) { + reinterpret_cast(out) += header.offset; + } + } +} + +template +void IntegerBitpacking::packPartialChunk(const U* srcBuffer, uint8_t* dstBuffer, size_t posInDst, + BitpackInfo info, size_t numValuesToPack) const { + for (size_t i = 0; i < numValuesToPack; ++i) { + 
BitpackingUtils::packSingle(srcBuffer[i], dstBuffer, info.bitWidth, i + posInDst); + } +} + +template +void IntegerBitpacking::copyValuesToTempChunkWithOffset(const U* srcBuffer, U* tmpBuffer, + BitpackInfo info, size_t numValuesToCopy) const { + for (auto j = 0u; j < numValuesToCopy; j++) { + tmpBuffer[j] = static_cast((T)(srcBuffer[j]) - info.offset); + } +} + +template +uint64_t IntegerBitpacking::compressNextPage(const uint8_t*& srcBuffer, + uint64_t numValuesRemaining, uint8_t* dstBuffer, uint64_t dstBufferSize, + const CompressionMetadata& metadata) const { + // TODO(bmwinger): this is hacky; we need a better system for dynamically choosing between + // algorithms when compressing + if (metadata.compression == CompressionType::UNCOMPRESSED) { + return Uncompressed(sizeof(T)).compressNextPage(srcBuffer, numValuesRemaining, dstBuffer, + dstBufferSize, metadata); + } + KU_ASSERT(metadata.compression == CompressionType::INTEGER_BITPACKING); + auto info = getPackingInfo(metadata); + auto bitWidth = info.bitWidth; + + if (bitWidth == 0) { + return 0; + } + auto numValuesToCompress = std::min(numValuesRemaining, numValues(dstBufferSize, info)); + // Round up to nearest byte + auto sizeToCompress = + numValuesToCompress * bitWidth / 8 + (numValuesToCompress * bitWidth % 8 != 0); + KU_ASSERT(dstBufferSize >= CHUNK_SIZE); + KU_ASSERT(dstBufferSize >= sizeToCompress); + // This might overflow the source buffer if there are fewer values remaining than the chunk + // size so we stop at the end of the last full chunk and use a temporary array to avoid + // overflow. + if (info.offset == 0) { + auto lastFullChunkEnd = numValuesToCompress - numValuesToCompress % CHUNK_SIZE; + for (auto i = 0ull; i < lastFullChunkEnd; i += CHUNK_SIZE) { + fastpack(reinterpret_cast(srcBuffer) + i, dstBuffer + i * bitWidth / 8, + bitWidth); + } + // Pack last partial chunk, avoiding overflows + const size_t remainingNumValues = numValuesToCompress % CHUNK_SIZE; + if (remainingNumValues > 0) { + packPartialChunk(reinterpret_cast(srcBuffer) + lastFullChunkEnd, + dstBuffer + lastFullChunkEnd * bitWidth / 8, 0, info, remainingNumValues); + } + } else { + U tmp[CHUNK_SIZE]; + auto lastFullChunkEnd = numValuesToCompress - numValuesToCompress % CHUNK_SIZE; + for (auto i = 0ull; i < lastFullChunkEnd; i += CHUNK_SIZE) { + copyValuesToTempChunkWithOffset(reinterpret_cast(srcBuffer) + i, tmp, info, + CHUNK_SIZE); + fastpack(tmp, dstBuffer + i * bitWidth / 8, bitWidth); + } + // Pack last partial chunk, avoiding overflows + auto remainingValues = numValuesToCompress % CHUNK_SIZE; + if (remainingValues > 0) { + copyValuesToTempChunkWithOffset(reinterpret_cast(srcBuffer) + + lastFullChunkEnd, + tmp, info, remainingValues); + packPartialChunk(tmp, dstBuffer + lastFullChunkEnd * bitWidth / 8, 0, info, + remainingValues); + } + } + srcBuffer += numValuesToCompress * sizeof(U); + return sizeToCompress; +} + +template +void IntegerBitpacking::decompressFromPage(const uint8_t* srcBuffer, uint64_t srcOffset, + uint8_t* dstBuffer, uint64_t dstOffset, uint64_t numValues, + const CompressionMetadata& metadata) const { + auto info = getPackingInfo(metadata); + + auto srcCursor = getChunkStart(srcBuffer, srcOffset, info.bitWidth); + auto valuesInFirstChunk = std::min(CHUNK_SIZE - (srcOffset % CHUNK_SIZE), numValues); + auto bytesPerChunk = CHUNK_SIZE / 8 * info.bitWidth; + auto dstIndex = dstOffset; + + // Copy values which aren't aligned to the start of the chunk + if (valuesInFirstChunk < CHUNK_SIZE) { + getValues(srcCursor, srcOffset % 
CHUNK_SIZE, dstBuffer + dstIndex * sizeof(U), + valuesInFirstChunk, info); + if (numValues == valuesInFirstChunk) { + return; + } + // Start at the end of the first partial chunk + srcCursor += bytesPerChunk; + dstIndex += valuesInFirstChunk; + } + + // Use fastunpack to directly unpack the full-sized chunks + for (; dstIndex + CHUNK_SIZE <= dstOffset + numValues; dstIndex += CHUNK_SIZE) { + fastunpack(srcCursor, (U*)dstBuffer + dstIndex, info.bitWidth); + if (info.hasNegative && info.bitWidth > 0) { + SignExtend(dstBuffer + dstIndex * sizeof(U), info.bitWidth); + } + if (info.offset != 0) { + for (auto i = 0u; i < CHUNK_SIZE; i++) { + ((T*)dstBuffer)[dstIndex + i] += info.offset; + } + } + srcCursor += bytesPerChunk; + } + // Copy remaining values from within the last chunk. + if (dstIndex < dstOffset + numValues) { + getValues(srcCursor, 0, dstBuffer + dstIndex * sizeof(U), dstOffset + numValues - dstIndex, + info); + } +} + +template class IntegerBitpacking; +template class IntegerBitpacking; +template class IntegerBitpacking; +template class IntegerBitpacking; +template class IntegerBitpacking; +template class IntegerBitpacking; +template class IntegerBitpacking; +template class IntegerBitpacking; +template class IntegerBitpacking; + +void BooleanBitpacking::setValuesFromUncompressed(const uint8_t* srcBuffer, offset_t srcOffset, + uint8_t* dstBuffer, offset_t dstOffset, offset_t numValues, + const CompressionMetadata& /*metadata*/, const NullMask* /*nullMask*/) const { + for (auto i = 0u; i < numValues; i++) { + NullMask::setNull((uint64_t*)dstBuffer, dstOffset + i, ((bool*)srcBuffer)[srcOffset + i]); + } +} + +uint64_t BooleanBitpacking::compressNextPage(const uint8_t*& srcBuffer, uint64_t numValuesRemaining, + uint8_t* dstBuffer, uint64_t dstBufferSize, const CompressionMetadata& /*metadata*/) const { + // TODO(bmwinger): Optimize, e.g. using an integer bitpacking function + auto numValuesToCompress = std::min(numValuesRemaining, numValues(dstBufferSize)); + for (auto i = 0ull; i < numValuesToCompress; i++) { + NullMask::setNull((uint64_t*)dstBuffer, i, srcBuffer[i]); + } + srcBuffer += numValuesToCompress / 8; + // Will be a multiple of 8 except for the last iteration + return numValuesToCompress / 8 + (bool)(numValuesToCompress % 8); +} + +void BooleanBitpacking::decompressFromPage(const uint8_t* srcBuffer, uint64_t srcOffset, + uint8_t* dstBuffer, uint64_t dstOffset, uint64_t numValues, + const CompressionMetadata& /*metadata*/) const { + // TODO(bmwinger): Optimize, e.g. 
using an integer bitpacking function + for (auto i = 0ull; i < numValues; i++) { + ((bool*)dstBuffer)[dstOffset + i] = NullMask::isNull((uint64_t*)srcBuffer, srcOffset + i); + } +} + +void BooleanBitpacking::copyFromPage(const uint8_t* srcBuffer, uint64_t srcOffset, + uint8_t* dstBuffer, uint64_t dstOffset, uint64_t numValues, + const CompressionMetadata& /*metadata*/) const { + NullMask::copyNullMask(reinterpret_cast(srcBuffer), srcOffset, + reinterpret_cast(dstBuffer), dstOffset, numValues); +} + +void ReadCompressedValuesFromPageToVector::operator()(const uint8_t* frame, PageCursor& pageCursor, + common::ValueVector* resultVector, uint32_t posInVector, uint64_t numValuesToRead, + const CompressionMetadata& metadata) { + switch (metadata.compression) { + case CompressionType::CONSTANT: + return constant.decompressFromPage(frame, pageCursor.elemPosInPage, resultVector->getData(), + posInVector, numValuesToRead, metadata); + case CompressionType::UNCOMPRESSED: + return uncompressed.decompressFromPage(frame, pageCursor.elemPosInPage, + resultVector->getData(), posInVector, numValuesToRead, metadata); + case CompressionType::ALP: { + switch (physicalType) { + case PhysicalTypeID::DOUBLE: { + return FloatCompression().decompressFromPage(frame, pageCursor.elemPosInPage, + resultVector->getData(), posInVector, numValuesToRead, metadata); + } + case PhysicalTypeID::FLOAT: { + return FloatCompression().decompressFromPage(frame, pageCursor.elemPosInPage, + resultVector->getData(), posInVector, numValuesToRead, metadata); + } + default: { + throw NotImplementedException("Float Compression is not implemented for type " + + PhysicalTypeUtils::toString(physicalType)); + } + } + } + case CompressionType::INTEGER_BITPACKING: { + switch (physicalType) { + case PhysicalTypeID::INT128: { + return IntegerBitpacking().decompressFromPage(frame, pageCursor.elemPosInPage, + resultVector->getData(), posInVector, numValuesToRead, metadata); + } + case PhysicalTypeID::INT64: { + return IntegerBitpacking().decompressFromPage(frame, pageCursor.elemPosInPage, + resultVector->getData(), posInVector, numValuesToRead, metadata); + } + case PhysicalTypeID::INT32: { + return IntegerBitpacking().decompressFromPage(frame, pageCursor.elemPosInPage, + resultVector->getData(), posInVector, numValuesToRead, metadata); + } + case PhysicalTypeID::INT16: { + return IntegerBitpacking().decompressFromPage(frame, pageCursor.elemPosInPage, + resultVector->getData(), posInVector, numValuesToRead, metadata); + } + case PhysicalTypeID::INT8: { + return IntegerBitpacking().decompressFromPage(frame, pageCursor.elemPosInPage, + resultVector->getData(), posInVector, numValuesToRead, metadata); + } + case PhysicalTypeID::INTERNAL_ID: + case PhysicalTypeID::UINT64: { + return IntegerBitpacking().decompressFromPage(frame, pageCursor.elemPosInPage, + resultVector->getData(), posInVector, numValuesToRead, metadata); + } + case PhysicalTypeID::UINT32: { + return IntegerBitpacking().decompressFromPage(frame, pageCursor.elemPosInPage, + resultVector->getData(), posInVector, numValuesToRead, metadata); + } + case PhysicalTypeID::UINT16: { + return IntegerBitpacking().decompressFromPage(frame, pageCursor.elemPosInPage, + resultVector->getData(), posInVector, numValuesToRead, metadata); + } + case PhysicalTypeID::UINT8: { + return IntegerBitpacking().decompressFromPage(frame, pageCursor.elemPosInPage, + resultVector->getData(), posInVector, numValuesToRead, metadata); + } + default: { + throw NotImplementedException("INTEGER_BITPACKING is not 
implemented for type " + + PhysicalTypeUtils::toString(physicalType)); + } + } + } + case CompressionType::BOOLEAN_BITPACKING: + return booleanBitpacking.decompressFromPage(frame, pageCursor.elemPosInPage, + resultVector->getData(), posInVector, numValuesToRead, metadata); + default: + KU_UNREACHABLE; + } +} + +void ReadCompressedValuesFromPage::operator()(const uint8_t* frame, PageCursor& pageCursor, + uint8_t* result, uint32_t startPosInResult, uint64_t numValuesToRead, + const CompressionMetadata& metadata) { + switch (metadata.compression) { + case CompressionType::CONSTANT: + return constant.copyFromPage(frame, pageCursor.elemPosInPage, result, startPosInResult, + numValuesToRead, metadata); + case CompressionType::UNCOMPRESSED: + return uncompressed.decompressFromPage(frame, pageCursor.elemPosInPage, result, + startPosInResult, numValuesToRead, metadata); + case CompressionType::ALP: { + switch (physicalType) { + case PhysicalTypeID::DOUBLE: { + return FloatCompression().decompressFromPage(frame, pageCursor.elemPosInPage, + result, startPosInResult, numValuesToRead, metadata); + } + case PhysicalTypeID::FLOAT: { + return FloatCompression().decompressFromPage(frame, pageCursor.elemPosInPage, + result, startPosInResult, numValuesToRead, metadata); + } + default: { + throw NotImplementedException("Float Compression is not implemented for type " + + PhysicalTypeUtils::toString(physicalType)); + } + } + } + case CompressionType::INTEGER_BITPACKING: { + switch (physicalType) { + case PhysicalTypeID::INT128: { + return IntegerBitpacking().decompressFromPage(frame, pageCursor.elemPosInPage, + result, startPosInResult, numValuesToRead, metadata); + } + case PhysicalTypeID::INT64: { + return IntegerBitpacking().decompressFromPage(frame, pageCursor.elemPosInPage, + result, startPosInResult, numValuesToRead, metadata); + } + case PhysicalTypeID::INT32: { + return IntegerBitpacking().decompressFromPage(frame, pageCursor.elemPosInPage, + result, startPosInResult, numValuesToRead, metadata); + } + case PhysicalTypeID::INT16: { + return IntegerBitpacking().decompressFromPage(frame, pageCursor.elemPosInPage, + result, startPosInResult, numValuesToRead, metadata); + } + case PhysicalTypeID::INT8: { + return IntegerBitpacking().decompressFromPage(frame, pageCursor.elemPosInPage, + result, startPosInResult, numValuesToRead, metadata); + } + case PhysicalTypeID::INTERNAL_ID: + case PhysicalTypeID::UINT64: { + return IntegerBitpacking().decompressFromPage(frame, pageCursor.elemPosInPage, + result, startPosInResult, numValuesToRead, metadata); + } + case PhysicalTypeID::UINT32: { + return IntegerBitpacking().decompressFromPage(frame, pageCursor.elemPosInPage, + result, startPosInResult, numValuesToRead, metadata); + } + case PhysicalTypeID::UINT16: { + return IntegerBitpacking().decompressFromPage(frame, pageCursor.elemPosInPage, + result, startPosInResult, numValuesToRead, metadata); + } + case PhysicalTypeID::UINT8: { + return IntegerBitpacking().decompressFromPage(frame, pageCursor.elemPosInPage, + result, startPosInResult, numValuesToRead, metadata); + } + default: { + throw NotImplementedException("INTEGER_BITPACKING is not implemented for type " + + PhysicalTypeUtils::toString(physicalType)); + } + } + } + case CompressionType::BOOLEAN_BITPACKING: + // Reading into ColumnChunks should be done without decompressing for booleans + return booleanBitpacking.copyFromPage(frame, pageCursor.elemPosInPage, result, + startPosInResult, numValuesToRead, metadata); + default: + KU_UNREACHABLE; + } +} + +void 
WriteCompressedValuesToPage::operator()(uint8_t* frame, uint16_t posInFrame, + const uint8_t* data, offset_t dataOffset, offset_t numValues, + const CompressionMetadata& metadata, const NullMask* nullMask) { + switch (metadata.compression) { + case CompressionType::CONSTANT: + return constant.setValuesFromUncompressed(data, dataOffset, frame, posInFrame, numValues, + metadata, nullMask); + case CompressionType::UNCOMPRESSED: + return uncompressed.setValuesFromUncompressed(data, dataOffset, frame, posInFrame, + numValues, metadata, nullMask); + case CompressionType::INTEGER_BITPACKING: { + return TypeUtils::visit(physicalType, [&](T) { + if constexpr (std::same_as) { + throw NotImplementedException( + "INTEGER_BITPACKING is not implemented for type bool"); + } else if constexpr (std::same_as) { + IntegerBitpacking().setValuesFromUncompressed(data, dataOffset, frame, + posInFrame, numValues, metadata, nullMask); + } else if constexpr (numeric_utils::IsIntegral) { + return IntegerBitpacking().setValuesFromUncompressed(data, dataOffset, frame, + posInFrame, numValues, metadata, nullMask); + } else { + throw NotImplementedException("INTEGER_BITPACKING is not implemented for type " + + PhysicalTypeUtils::toString(physicalType)); + } + }); + } + case CompressionType::ALP: { + return TypeUtils::visit(physicalType, [&](T) { + if constexpr (std::is_floating_point_v) { + FloatCompression().setValuesFromUncompressed(data, dataOffset, frame, posInFrame, + numValues, metadata, nullMask); + } else { + throw NotImplementedException("FLOAT_COMPRESSION is not implemented for type " + + PhysicalTypeUtils::toString(physicalType)); + } + }); + } + case CompressionType::BOOLEAN_BITPACKING: + return booleanBitpacking.copyFromPage(data, dataOffset, frame, posInFrame, numValues, + metadata); + + default: + KU_UNREACHABLE; + } +} + +void WriteCompressedValuesToPage::operator()(uint8_t* frame, uint16_t posInFrame, + common::ValueVector* vector, uint32_t posInVector, offset_t numValues, + const CompressionMetadata& metadata) { + if (metadata.compression == CompressionType::BOOLEAN_BITPACKING) { + booleanBitpacking.setValuesFromUncompressed(vector->getData(), posInVector, frame, + posInFrame, numValues, metadata, &vector->getNullMask()); + } else { + (*this)(frame, posInFrame, vector->getData(), posInVector, 1, metadata, + &vector->getNullMask()); + } +} + +std::optional StorageValue::readFromVector(const common::ValueVector& vector, + common::offset_t posInVector) { + return TypeUtils::visit( + vector.dataType.getPhysicalType(), + // TODO(bmwinger): concept for supported storagevalue types + [&]( + T) { return std::make_optional(StorageValue(vector.getValue(posInVector))); }, + [](auto) { return std::optional(); }); +} + +bool StorageValue::gt(const StorageValue& other, common::PhysicalTypeID type) const { + switch (type) { + case common::PhysicalTypeID::BOOL: + case common::PhysicalTypeID::LIST: + case common::PhysicalTypeID::ARRAY: + case common::PhysicalTypeID::INTERNAL_ID: + case common::PhysicalTypeID::STRING: + case common::PhysicalTypeID::UINT64: + case common::PhysicalTypeID::UINT32: + case common::PhysicalTypeID::UINT16: + case common::PhysicalTypeID::UINT8: + return this->unsignedInt > other.unsignedInt; + case common::PhysicalTypeID::INT128: + return this->signedInt128 > other.signedInt128; + case common::PhysicalTypeID::INT64: + case common::PhysicalTypeID::INT32: + case common::PhysicalTypeID::INT16: + case common::PhysicalTypeID::INT8: + return this->signedInt > other.signedInt; + case 
common::PhysicalTypeID::FLOAT: + case common::PhysicalTypeID::DOUBLE: + return this->floatVal > other.floatVal; + default: + KU_UNREACHABLE; + } +} + +std::pair, std::optional> getMinMaxStorageValue( + const uint8_t* data, uint64_t offset, uint64_t numValues, PhysicalTypeID physicalType, + const NullMask* nullMask, bool valueRequiredIfUnsupported) { + std::pair, std::optional> returnValue; + + TypeUtils::visit( + physicalType, + [&](bool) { + if (numValues > 0) { + const auto boolData = reinterpret_cast(data); + if (!nullMask || nullMask->hasNoNullsGuarantee()) { + auto [minRaw, maxRaw] = NullMask::getMinMax(boolData, offset, numValues); + returnValue = std::make_pair(std::optional(StorageValue(minRaw)), + std::optional(StorageValue(maxRaw))); + } else { + std::optional min, max; + for (size_t i = offset; i < offset + numValues; i++) { + if (!nullMask || !nullMask->isNull(i)) { + auto boolValue = NullMask::isNull(boolData, i); + if (!max || boolValue > max->get()) { + max = boolValue; + } + if (!min || boolValue < min->get()) { + min = boolValue; + } + } + } + returnValue = std::make_pair(min, max); + } + } + }, + [&](T) + requires(numeric_utils::IsIntegral || std::floating_point) + { + if (numValues > 0) { + auto typedData = std::span(reinterpret_cast(data) + offset, numValues); + returnValue = getTypedMinMax(typedData, nullMask ? &*nullMask : nullptr, offset); + } + }, + [&](T) + requires(std::same_as) + { + if (numValues > 0) { + const auto typedData = + std::span(reinterpret_cast(data) + offset, numValues); + returnValue = getTypedMinMax(typedData, nullMask ? &*nullMask : nullptr, offset); + } + }, + [&](T) + requires(std::same_as || std::same_as || + std::same_as || std::same_as || + std::same_as) + { + if (valueRequiredIfUnsupported) { + // For unsupported types on the first copy, + // they need a non-optional value to distinguish them + // from supported types where every value is null + returnValue.first = std::numeric_limits::min(); + returnValue.second = std::numeric_limits::max(); + } + }); + return returnValue; +} +std::pair, std::optional> getMinMaxStorageValue( + const ColumnChunkData& data, uint64_t offset, uint64_t numValues, PhysicalTypeID physicalType, + bool valueRequiredIfUnsupported) { + auto nullMask = data.getNullMask(); + return getMinMaxStorageValue(data.getData(), offset, numValues, physicalType, + nullMask ? 
&*nullMask : nullptr, valueRequiredIfUnsupported); +} + +std::pair, std::optional> getMinMaxStorageValue( + const ValueVector& data, uint64_t offset, uint64_t numValues, PhysicalTypeID physicalType, + bool valueRequiredIfUnsupported) { + std::pair, std::optional> returnValue; + auto& nullMask = data.getNullMask(); + + TypeUtils::visit( + physicalType, + [&](T) + requires(numeric_utils::IsIntegral || std::floating_point) + { + if (numValues > 0) { + auto typedData = + std::span(reinterpret_cast(data.getData()) + offset, numValues); + returnValue = getTypedMinMax(typedData, &nullMask, offset); + } + }, + [&](T) + requires(std::same_as) + { + if (numValues > 0) { + const auto typedData = std::span( + reinterpret_cast(data.getData()) + offset, numValues); + returnValue = getTypedMinMax(typedData, &nullMask, offset); + } + }, + [&](T) + requires(std::same_as || std::same_as || + std::same_as || std::same_as || + std::same_as) + { + if (valueRequiredIfUnsupported) { + // For unsupported types on the first copy, + // they need a non-optional value to distinguish them + // from supported types where every value is null + returnValue.first = std::numeric_limits::min(); + returnValue.second = std::numeric_limits::max(); + } + }); + return returnValue; +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/compression/float_compression.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/compression/float_compression.cpp new file mode 100644 index 0000000000..4bfc605604 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/compression/float_compression.cpp @@ -0,0 +1,248 @@ +#include "storage/compression/float_compression.h" + +#include "alp/encode.hpp" +#include "common/system_config.h" +#include "common/utils.h" +#include + +namespace lbug { +namespace storage { + +namespace { +static constexpr common::idx_t BITPACKING_CHILD_IDX = 0; + +template +common::LogicalType getBitpackingLogicalType() { + if constexpr (std::is_same_v) { + return common::LogicalType::INT32(); + } else { + return common::LogicalType::INT64(); + } +} +} // namespace + +template +size_t EncodeException::numPagesFromExceptions(size_t exceptionCount) { + return common::ceilDiv(static_cast(exceptionCount), + common::LBUG_PAGE_SIZE / sizeInBytes()); +} + +template +size_t EncodeException::exceptionBytesPerPage() { + return common::LBUG_PAGE_SIZE / sizeInBytes() * sizeInBytes(); +} + +template +bool EncodeException::operator<(const EncodeException& o) const { + return posInChunk < o.posInChunk; +} + +template +EncodeException EncodeExceptionView::getValue(common::offset_t elementOffset) const { + EncodeException ret{}; + const auto* const elementAddress = bytes + elementOffset * decltype(ret)::sizeInBytes(); + std::memcpy(&ret.value, elementAddress, sizeof(ret.value)); + std::memcpy(&ret.posInChunk, elementAddress + sizeof(ret.value), sizeof(ret.posInChunk)); + return ret; +} + +template +void EncodeExceptionView::setValue(EncodeException exception, + common::offset_t elementOffset) { + auto* const elementAddress = bytes + elementOffset * decltype(exception)::sizeInBytes(); + std::memcpy(elementAddress, &exception.value, sizeof(exception.value)); + std::memcpy(elementAddress + sizeof(exception.value), &exception.posInChunk, + sizeof(exception.posInChunk)); +} + +template +FloatCompression::FloatCompression() + : constantEncodedFloatBitpacker(getBitpackingLogicalType()), encodedFloatBitpacker() {} + +template +uint64_t FloatCompression::compressNextPage(const uint8_t*&, 
uint64_t, uint8_t*, uint64_t, + const struct CompressionMetadata&) const { + KU_UNREACHABLE; +} + +template +uint64_t FloatCompression::compressNextPageWithExceptions(const uint8_t*& srcBuffer, + uint64_t srcOffset, uint64_t numValuesRemaining, uint8_t* dstBuffer, uint64_t dstBufferSize, + EncodeExceptionView exceptionBuffer, [[maybe_unused]] uint64_t exceptionBufferSize, + uint64_t& exceptionCount, const struct CompressionMetadata& metadata) const { + KU_ASSERT(metadata.compression == CompressionType::ALP); + + const size_t numValuesToCompress = + std::min(numValuesRemaining, numValues(dstBufferSize, metadata)); + + std::vector integerEncodedValues(numValuesToCompress); + for (size_t posInPage = 0; posInPage < numValuesToCompress; ++posInPage) { + const auto floatValue = reinterpret_cast(srcBuffer)[posInPage]; + const auto* floatMetadata = metadata.floatMetadata(); + const EncodedType encodedValue = + alp::AlpEncode::encode_value(floatValue, floatMetadata->fac, floatMetadata->exp); + const double decodedValue = + alp::AlpDecode::decode_value(encodedValue, floatMetadata->fac, floatMetadata->exp); + + if (floatValue != decodedValue) { + KU_ASSERT( + (exceptionCount + 1) * EncodeException::sizeInBytes() <= exceptionBufferSize); + exceptionBuffer.setValue( + {.value = floatValue, + .posInChunk = common::safeIntegerConversion(srcOffset + posInPage)}, + exceptionCount); + + // We don't need to replace with 1st successful encode as the integer bitpacking + // metadata is already populated + ++exceptionCount; + } else { + integerEncodedValues[posInPage] = encodedValue; + } + } + srcBuffer += numValuesToCompress * sizeof(T); + + const auto* castedIntegerEncodedBuffer = + reinterpret_cast(integerEncodedValues.data()); + const auto compressedIntegerSize = + getEncodedFloatBitpacker(metadata).compressNextPage(castedIntegerEncodedBuffer, + numValuesToCompress, dstBuffer, dstBufferSize, metadata.getChild(BITPACKING_CHILD_IDX)); + + // zero out unused parts of the page + memset(dstBuffer + compressedIntegerSize, 0, dstBufferSize - compressedIntegerSize); + + // since we already do the zeroing we return the size of the whole page + return dstBufferSize; +} + +template +uint64_t FloatCompression::numValues(uint64_t dataSize, const CompressionMetadata& metadata) { + return metadata.getChild(BITPACKING_CHILD_IDX) + .numValues(dataSize, getBitpackingLogicalType()); +} + +template +void FloatCompression::decompressFromPage(const uint8_t* srcBuffer, uint64_t srcOffset, + uint8_t* dstBuffer, uint64_t dstOffset, uint64_t numValues, + const struct CompressionMetadata& metadata) const { + + // use dstBuffer for unpacking the ALP encoded values then decode them in place + getEncodedFloatBitpacker(metadata).decompressFromPage(srcBuffer, srcOffset, dstBuffer, + dstOffset, numValues, metadata.getChild(BITPACKING_CHILD_IDX)); + + static_assert(sizeof(EncodedType) == sizeof(T)); + auto* integerEncodedValues = reinterpret_cast(dstBuffer); + for (size_t i = 0; i < numValues; ++i) { + reinterpret_cast(dstBuffer)[dstOffset + i] = + alp::AlpDecode::decode_value(integerEncodedValues[dstOffset + i], + metadata.floatMetadata()->fac, metadata.floatMetadata()->exp); + } +} + +template +void FloatCompression::setValuesFromUncompressed(const uint8_t* srcBuffer, + common::offset_t srcOffset, uint8_t* dstBuffer, common::offset_t dstOffset, + common::offset_t numValues, const CompressionMetadata& metadata, + const common::NullMask* nullMask) const { + // each individual value that is being updated should be able to be updated in place 
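    // (The assertion below re-runs canUpdateInPlace on every non-null value being
    // written, i.e. each value must still encode within the existing bit width and
    // exception capacity; updates that would not fit are expected to have been
    // rejected by the caller before reaching this point.)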
+ RUNTIME_CHECK(InPlaceUpdateLocalState localUpdateState{}); + KU_ASSERT(numValues == + static_cast( + std::ranges::count_if(std::ranges::iota_view{srcOffset, srcOffset + numValues}, + [&localUpdateState, srcBuffer, &metadata, nullMask](common::offset_t i) { + auto value = reinterpret_cast(srcBuffer)[i]; + return (nullMask && nullMask->isNull(i)) || + canUpdateInPlace(std::span(&value, 1), metadata, localUpdateState); + }))); + + std::vector integerEncodedValues(numValues); + for (size_t i = 0; i < numValues; ++i) { + const size_t posInSrc = i + srcOffset; + + const auto floatValue = reinterpret_cast(srcBuffer)[posInSrc]; + const EncodedType encodedValue = alp::AlpEncode::encode_value(floatValue, + metadata.floatMetadata()->fac, metadata.floatMetadata()->exp); + integerEncodedValues[i] = encodedValue; + } + + getEncodedFloatBitpacker(metadata).setValuesFromUncompressed( + reinterpret_cast(integerEncodedValues.data()), 0, dstBuffer, dstOffset, + numValues, metadata.getChild(BITPACKING_CHILD_IDX), nullMask); +} + +template +const CompressionAlg& FloatCompression::getEncodedFloatBitpacker( + const CompressionMetadata& metadata) const { + if (metadata.getChild(BITPACKING_CHILD_IDX).isConstant()) { + return constantEncodedFloatBitpacker; + } else { + return encodedFloatBitpacker; + } +} + +template +BitpackInfo::EncodedType> FloatCompression::getBitpackInfo( + const CompressionMetadata& metadata) { + const auto& bitpackMetadata = metadata.getChild(BITPACKING_CHILD_IDX); + if (bitpackMetadata.isConstant()) { + const auto constValue = bitpackMetadata.min.get(); + return {.bitWidth = 0, .hasNegative = (constValue < 0), .offset = constValue}; + } else { + return IntegerBitpacking::getPackingInfo(bitpackMetadata); + } +} + +template +bool FloatCompression::canUpdateInPlace(std::span value, + const CompressionMetadata& metadata, InPlaceUpdateLocalState& localUpdateState, + const std::optional& nullMask, uint64_t nullMaskOffset) { + size_t newExceptionCount = 0; + std::vector encodedValues(value.size()); + const auto bitpackingInfo = getBitpackInfo(metadata); + const auto* floatMetadata = metadata.floatMetadata(); + for (size_t i = 0; i < value.size(); ++i) { + if (nullMask && nullMask->isNull(nullMaskOffset + i)) { + continue; + } + + const auto floatValue = value[i]; + const EncodedType encodedValue = + alp::AlpEncode::encode_value(floatValue, floatMetadata->fac, floatMetadata->exp); + const double decodedValue = + alp::AlpDecode::decode_value(encodedValue, floatMetadata->fac, floatMetadata->exp); + if (floatValue != decodedValue) { + ++newExceptionCount; + encodedValues[i] = bitpackingInfo.offset; + } else { + encodedValues[i] = encodedValue; + } + } + localUpdateState.floatState.newExceptionCount += newExceptionCount; + const size_t totalExceptionCount = + floatMetadata->exceptionCount + localUpdateState.floatState.newExceptionCount; + const bool exceptionsOK = totalExceptionCount <= floatMetadata->exceptionCapacity; + + return exceptionsOK && + metadata.getChild(BITPACKING_CHILD_IDX) + .canUpdateInPlace(reinterpret_cast(encodedValues.data()), 0, + encodedValues.size(), getBitpackingLogicalType().getPhysicalType(), + localUpdateState); +} + +template +common::page_idx_t FloatCompression::getNumDataPages(common::page_idx_t numTotalPages, + const CompressionMetadata& compMeta) { + return numTotalPages - + EncodeException::numPagesFromExceptions(compMeta.floatMetadata()->exceptionCapacity); +} + +template class FloatCompression; +template class FloatCompression; + +template struct EncodeException; 
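Stepping back, the float path in this file follows the ALP pattern: each float/double is mapped to an integer through a per-chunk (exponent, factor) pair, the integer stream is handed to the integer bitpacker, and any value that does not decode back bit-exactly is kept verbatim as a positional exception. A minimal sketch of that encode/verify/record loop, with simplified power-of-ten helpers standing in for alp::AlpEncode/alp::AlpDecode (illustrative only, not the real ALP formulas):

#include <cmath>
#include <cstdint>
#include <vector>

// Simplified stand-ins for the ALP encode/decode pair: scale by a power of ten,
// round to an integer, decode by dividing again. The real codec derives the
// (exponent, factor) pair during chunk analysis.
static int64_t encodeValue(double v, int exp) {
    return std::llround(v * std::pow(10.0, exp));
}
static double decodeValue(int64_t enc, int exp) {
    return static_cast<double>(enc) / std::pow(10.0, exp);
}

struct FloatException {
    double value;
    uint32_t posInChunk;
};

// Values that survive the round trip go into the integer stream (later
// bit-packed); the rest are recorded as lossless (value, position) exceptions.
static void encodeChunk(const std::vector<double>& values, int exp,
    std::vector<int64_t>& encoded, std::vector<FloatException>& exceptions) {
    encoded.assign(values.size(), 0);
    for (size_t i = 0; i < values.size(); i++) {
        const int64_t enc = encodeValue(values[i], exp);
        if (decodeValue(enc, exp) != values[i]) {
            exceptions.push_back({values[i], static_cast<uint32_t>(i)});
        } else {
            encoded[i] = enc;
        }
    }
}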
+template struct EncodeException; + +template struct EncodeExceptionView; +template struct EncodeExceptionView; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/database_header.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/database_header.cpp new file mode 100644 index 0000000000..444f12253b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/database_header.cpp @@ -0,0 +1,117 @@ +#include "storage/database_header.h" + +#include + +#include "common/exception/runtime.h" +#include "common/file_system/file_info.h" +#include "common/serializer/buffered_file.h" +#include "common/serializer/deserializer.h" +#include "common/serializer/serializer.h" +#include "common/system_config.h" +#include "main/client_context.h" +#include "storage/page_manager.h" +#include "storage/storage_version_info.h" + +namespace lbug::storage { +static void validateStorageVersion(common::Deserializer& deSer) { + std::string key; + deSer.validateDebuggingInfo(key, "storage_version"); + storage_version_t savedStorageVersion = 0; + deSer.deserializeValue(savedStorageVersion); + const auto storageVersion = StorageVersionInfo::getStorageVersion(); + if (savedStorageVersion != storageVersion) { + // TODO(Guodong): Add a test case for this. + throw common::RuntimeException( + common::stringFormat("Trying to read a database file with a different version. " + "Database file version: {}, Current build storage version: {}", + savedStorageVersion, storageVersion)); + } +} + +static void validateMagicBytes(common::Deserializer& deSer) { + std::string key; + deSer.validateDebuggingInfo(key, "magic"); + const auto numMagicBytes = strlen(StorageVersionInfo::MAGIC_BYTES); + uint8_t magicBytes[4]; + for (auto i = 0u; i < numMagicBytes; i++) { + deSer.deserializeValue(magicBytes[i]); + } + if (memcmp(magicBytes, StorageVersionInfo::MAGIC_BYTES, numMagicBytes) != 0) { + throw common::RuntimeException( + "Unable to open database. 
The file is not a valid Lbug database file!"); + } +} + +void DatabaseHeader::updateCatalogPageRange(PageManager& pageManager, PageRange newPageRange) { + if (catalogPageRange.startPageIdx != common::INVALID_PAGE_IDX) { + pageManager.freePageRange(catalogPageRange); + } + catalogPageRange = newPageRange; +} + +void DatabaseHeader::freeMetadataPageRange(PageManager& pageManager) const { + if (metadataPageRange.startPageIdx != common::INVALID_PAGE_IDX) { + pageManager.freePageRange(metadataPageRange); + } +} + +static void writeMagicBytes(common::Serializer& serializer) { + serializer.writeDebuggingInfo("magic"); + const auto numMagicBytes = strlen(StorageVersionInfo::MAGIC_BYTES); + for (auto i = 0u; i < numMagicBytes; i++) { + serializer.serializeValue(StorageVersionInfo::MAGIC_BYTES[i]); + } +} + +void DatabaseHeader::serialize(common::Serializer& ser) const { + writeMagicBytes(ser); + ser.writeDebuggingInfo("storage_version"); + ser.serializeValue(StorageVersionInfo::getStorageVersion()); + ser.writeDebuggingInfo("catalog"); + ser.serializeValue(catalogPageRange.startPageIdx); + ser.serializeValue(catalogPageRange.numPages); + ser.writeDebuggingInfo("metadata"); + ser.serializeValue(metadataPageRange.startPageIdx); + ser.serializeValue(metadataPageRange.numPages); + ser.writeDebuggingInfo("databaseID"); + ser.serializeValue(databaseID.value); +} + +DatabaseHeader DatabaseHeader::deserialize(common::Deserializer& deSer) { + validateMagicBytes(deSer); + validateStorageVersion(deSer); + PageRange catalogPageRange{}, metaPageRange{}; + common::ku_uuid_t databaseID{}; + std::string key; + deSer.validateDebuggingInfo(key, "catalog"); + deSer.deserializeValue(catalogPageRange.startPageIdx); + deSer.deserializeValue(catalogPageRange.numPages); + deSer.validateDebuggingInfo(key, "metadata"); + deSer.deserializeValue(metaPageRange.startPageIdx); + deSer.deserializeValue(metaPageRange.numPages); + deSer.validateDebuggingInfo(key, "databaseID"); + deSer.deserializeValue(databaseID.value); + return {catalogPageRange, metaPageRange, databaseID}; +} + +DatabaseHeader DatabaseHeader::createInitialHeader(common::RandomEngine* randomEngine) { + // We generate a random UUID to act as the database ID + return DatabaseHeader{{}, {}, common::UUID::generateRandomUUID(randomEngine)}; +} + +std::optional DatabaseHeader::readDatabaseHeader(common::FileInfo& dataFileInfo) { + if (dataFileInfo.getFileSize() < common::LBUG_PAGE_SIZE) { + // If the data file hasn't been written to there is no existing database header + return std::nullopt; + } + auto reader = std::make_unique(dataFileInfo); + common::Deserializer deSer(std::move(reader)); + try { + return DatabaseHeader::deserialize(deSer); + } catch (const common::RuntimeException&) { + // It is possible we optimistically write to the database file before the first checkpoint + // In this case the magic bytes check will fail and we assume there is no existing header + return std::nullopt; + } +} +} // namespace lbug::storage diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/disk_array.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/disk_array.cpp new file mode 100644 index 0000000000..e31ad8f5df --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/disk_array.cpp @@ -0,0 +1,401 @@ +#include "storage/disk_array.h" + +#include "common/exception/runtime.h" +#include "common/string_format.h" +#include "common/types/types.h" +#include "storage/file_handle.h" +#include "storage/shadow_file.h" +#include "storage/shadow_utils.h" +#include 
"storage/storage_utils.h" +#include "transaction/transaction.h" + +using namespace lbug::common; +using namespace lbug::transaction; + +namespace lbug { +namespace storage { + +// Header can be read or write since it just needs the sizes +static PageCursor getAPIdxAndOffsetInAP(const PageStorageInfo& info, uint64_t idx) { + // We assume that `numElementsPerPageLog2`, `elementPageOffsetMask`, + // `alignedElementSizeLog2` are never modified throughout transactional updates, thus, we + // directly use them from header here. + const page_idx_t apIdx = idx / info.numElementsPerPage; + const uint32_t byteOffsetInAP = (idx % info.numElementsPerPage) * info.alignedElementSize; + return PageCursor{apIdx, byteOffsetInAP}; +} + +PageStorageInfo::PageStorageInfo(uint64_t elementSize) + : alignedElementSize{std::bit_ceil(elementSize)}, + numElementsPerPage{LBUG_PAGE_SIZE / alignedElementSize} { + KU_ASSERT(elementSize <= LBUG_PAGE_SIZE); +} + +PIPWrapper::PIPWrapper(const FileHandle& fileHandle, page_idx_t pipPageIdx) + : pipPageIdx(pipPageIdx) { + fileHandle.readPageFromDisk(reinterpret_cast(&pipContents), pipPageIdx); +} + +DiskArrayInternal::DiskArrayInternal(FileHandle& fileHandle, + const DiskArrayHeader& headerForReadTrx, DiskArrayHeader& headerForWriteTrx, + ShadowFile* shadowFile, uint64_t elementSize, bool bypassShadowing) + : storageInfo{elementSize}, fileHandle(fileHandle), header{headerForReadTrx}, + headerForWriteTrx{headerForWriteTrx}, hasTransactionalUpdates{false}, shadowFile{shadowFile}, + lastAPPageIdx{INVALID_PAGE_IDX}, lastPageOnDisk{INVALID_PAGE_IDX} { + if (this->header.firstPIPPageIdx != ShadowUtils::NULL_PAGE_IDX) { + pips.emplace_back(fileHandle, header.firstPIPPageIdx); + while (pips[pips.size() - 1].pipContents.nextPipPageIdx != ShadowUtils::NULL_PAGE_IDX) { + pips.emplace_back(fileHandle, pips[pips.size() - 1].pipContents.nextPipPageIdx); + } + } + // If bypassing the WAL is disabled, just leave the lastPageOnDisk as invalid, as then all pages + // will be treated as updates to existing ones + if (bypassShadowing) { + updateLastPageOnDisk(); + } +} + +void DiskArrayInternal::updateLastPageOnDisk() { + auto numElements = getNumElementsNoLock(TransactionType::READ_ONLY); + if (numElements > 0) { + auto apCursor = getAPIdxAndOffsetInAP(storageInfo, numElements - 1); + lastPageOnDisk = getAPPageIdxNoLock(apCursor.pageIdx, TransactionType::READ_ONLY); + } else { + lastPageOnDisk = 0; + } +} + +uint64_t DiskArrayInternal::getNumElements(TransactionType trxType) { + std::shared_lock sLck{diskArraySharedMtx}; + return getNumElementsNoLock(trxType); +} + +bool DiskArrayInternal::checkOutOfBoundAccess(TransactionType trxType, uint64_t idx) const { + auto currentNumElements = getNumElementsNoLock(trxType); + if (idx >= currentNumElements) { + // LCOV_EXCL_START + throw RuntimeException(stringFormat( + "idx: {} of the DiskArray to be accessed is >= numElements in DiskArray{}.", idx, + currentNumElements)); + // LCOV_EXCL_STOP + } + return true; +} + +void DiskArrayInternal::get(uint64_t idx, const Transaction* transaction, + std::span val) { + std::shared_lock sLck{diskArraySharedMtx}; + KU_ASSERT(checkOutOfBoundAccess(transaction->getType(), idx)); + auto apCursor = getAPIdxAndOffsetInAP(storageInfo, idx); + page_idx_t apPageIdx = getAPPageIdxNoLock(apCursor.pageIdx, transaction->getType()); + if (transaction->getType() != TransactionType::CHECKPOINT || !hasTransactionalUpdates || + apPageIdx > lastPageOnDisk || + !shadowFile->hasShadowPage(fileHandle.getFileIndex(), apPageIdx)) { + 
fileHandle.optimisticReadPage(apPageIdx, [&](const uint8_t* frame) -> void { + memcpy(val.data(), frame + apCursor.elemPosInPage, val.size()); + }); + } else { + ShadowUtils::readShadowVersionOfPage(fileHandle, apPageIdx, *shadowFile, + [&val, &apCursor](const uint8_t* frame) -> void { + memcpy(val.data(), frame + apCursor.elemPosInPage, val.size()); + }); + } +} + +void DiskArrayInternal::updatePage(uint64_t pageIdx, bool isNewPage, + std::function updateOp) { + // Pages which are new to this transaction are written directly to the file + // Pages which previously existed are written to the WAL file + if (pageIdx <= lastPageOnDisk) { + // This may still be used to create new pages since bypassing the WAL is currently optional + // and if disabled lastPageOnDisk will be INVALID_PAGE_IDX (and the above comparison will + // always be true) + ShadowUtils::updatePage(fileHandle, pageIdx, isNewPage, *shadowFile, updateOp); + } else { + const auto frame = fileHandle.pinPage(pageIdx, + isNewPage ? PageReadPolicy::DONT_READ_PAGE : PageReadPolicy::READ_PAGE); + updateOp(frame); + fileHandle.setLockedPageDirty(pageIdx); + fileHandle.unpinPage(pageIdx); + } +} + +void DiskArrayInternal::update(const Transaction* transaction, uint64_t idx, + std::span val) { + std::unique_lock xLck{diskArraySharedMtx}; + hasTransactionalUpdates = true; + KU_ASSERT(checkOutOfBoundAccess(transaction->getType(), idx)); + auto apCursor = getAPIdxAndOffsetInAP(storageInfo, idx); + // TODO: We are currently supporting only DiskArrays that can grow in size and not + // those that can shrink in size. That is why we can use + // getAPPageIdxNoLock(apIdx, Transaction::WRITE) directly to compute the physical page Idx + // because any apIdx is guaranteed to be either in an existing PIP or a new PIP we added, which + // getAPPageIdxNoLock will correctly locate: this function simply searches an existing PIP if + // apIdx < numAPs stored in "previous" PIP; otherwise one of the newly inserted PIPs stored in + // pipPageIdxsOfInsertedPIPs. If within a single transaction we could grow or shrink, then + // getAPPageIdxNoLock logic needs to change to give the same guarantee (e.g., an apIdx = 0, may + // no longer to be guaranteed to be in pips[0].) + page_idx_t apPageIdx = getAPPageIdxNoLock(apCursor.pageIdx, transaction->getType()); + updatePage(apPageIdx, false /*isNewPage=*/, [&apCursor, &val](uint8_t* frame) -> void { + memcpy(frame + apCursor.elemPosInPage, val.data(), val.size()); + }); +} + +uint64_t DiskArrayInternal::resize(PageAllocator& pageAllocator, const Transaction* transaction, + uint64_t newNumElements, std::span defaultVal) { + std::unique_lock xLck{diskArraySharedMtx}; + auto it = iter_mut(defaultVal.size()); + auto originalNumElements = getNumElementsNoLock(transaction->getType()); + while (it.size() < newNumElements) { + it.pushBack(pageAllocator, transaction, defaultVal); + } + return originalNumElements; +} + +void DiskArrayInternal::setNextPIPPageIDxOfPIPNoLock(uint64_t pipIdxOfPreviousPIP, + page_idx_t nextPIPPageIdx) { + // This happens if the first pip is being inserted, in which case we need to change the header. 
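// A minimal standalone sketch of the cursor arithmetic used by getAPIdxAndOffsetInAP /
// PageStorageInfo above: element sizes are rounded up to a power of two so a page holds a
// whole number of elements, and an element index then splits into (array page index, byte
// offset in page) with one divide and one modulo. The 4096-byte page size below is an
// illustrative assumption, not necessarily the library's LBUG_PAGE_SIZE.
#include <bit>
#include <cstdint>

namespace sketch {
constexpr std::uint64_t kPageSize = 4096;

struct Cursor {
    std::uint64_t pageIdx;
    std::uint64_t byteOffsetInPage;
};

constexpr Cursor elementCursor(std::uint64_t elementSize, std::uint64_t idx) {
    const std::uint64_t alignedElementSize = std::bit_ceil(elementSize);
    const std::uint64_t numElementsPerPage = kPageSize / alignedElementSize;
    return Cursor{idx / numElementsPerPage, (idx % numElementsPerPage) * alignedElementSize};
}

// 24-byte elements are stored in 32-byte aligned slots, so 128 fit per 4 KiB page:
// element 130 lives on array page 1 at byte offset 2 * 32.
static_assert(elementCursor(24, 130).pageIdx == 1);
static_assert(elementCursor(24, 130).byteOffsetInPage == 64);
} // namespace sketch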
+ if (pipIdxOfPreviousPIP == UINT64_MAX) { + headerForWriteTrx.firstPIPPageIdx = nextPIPPageIdx; + } else if (pips.empty()) { + pipUpdates.newPIPs[pipIdxOfPreviousPIP].pipContents.nextPipPageIdx = nextPIPPageIdx; + } else { + if (!pipUpdates.updatedLastPIP.has_value()) { + pipUpdates.updatedLastPIP = std::make_optional(pips[pipIdxOfPreviousPIP]); + } + if (pipIdxOfPreviousPIP == pips.size() - 1) { + pipUpdates.updatedLastPIP->pipContents.nextPipPageIdx = nextPIPPageIdx; + } else { + KU_ASSERT(pipIdxOfPreviousPIP >= pips.size() && + pipUpdates.newPIPs.size() > pipIdxOfPreviousPIP - pips.size()); + pipUpdates.newPIPs[pipIdxOfPreviousPIP - pips.size()].pipContents.nextPipPageIdx = + nextPIPPageIdx; + } + } +} + +page_idx_t DiskArrayInternal::getAPPageIdxNoLock(page_idx_t apIdx, TransactionType trxType) { + auto [pipIdx, offsetInPIP] = StorageUtils::getQuotientRemainder(apIdx, NUM_PAGE_IDXS_PER_PIP); + if ((trxType != TransactionType::CHECKPOINT) || !hasPIPUpdatesNoLock(pipIdx)) { + return pips[pipIdx].pipContents.pageIdxs[offsetInPIP]; + } else if (pipIdx == pips.size() - 1 && pipUpdates.updatedLastPIP) { + return pipUpdates.updatedLastPIP->pipContents.pageIdxs[offsetInPIP]; + } else { + KU_ASSERT(pipIdx >= pips.size() && pipIdx - pips.size() < pipUpdates.newPIPs.size()); + return pipUpdates.newPIPs[pipIdx - pips.size()].pipContents.pageIdxs[offsetInPIP]; + } +} + +page_idx_t DiskArrayInternal::getUpdatedPageIdxOfPipNoLock(uint64_t pipIdx) { + if (pipIdx < pips.size()) { + return pips[pipIdx].pipPageIdx; + } + return pipUpdates.newPIPs[pipIdx - pips.size()].pipPageIdx; +} + +void DiskArrayInternal::clearWALPageVersionAndRemovePageFromFrameIfNecessary(page_idx_t pageIdx) { + shadowFile->clearShadowPage(fileHandle.getFileIndex(), pageIdx); + fileHandle.removePageFromFrameIfNecessary(pageIdx); +} + +void DiskArrayInternal::checkpointOrRollbackInMemoryIfNecessaryNoLock(bool isCheckpoint) { + if (!hasTransactionalUpdates) { + return; + } + if (pipUpdates.updatedLastPIP.has_value()) { + // Note: This should not cause a memory leak because PIPWrapper is a struct. So we + // should overwrite the previous PIPWrapper's memory. + if (isCheckpoint) { + pips.back() = *pipUpdates.updatedLastPIP; + } + clearWALPageVersionAndRemovePageFromFrameIfNecessary(pips.back().pipPageIdx); + } + + for (auto& newPIP : pipUpdates.newPIPs) { + clearWALPageVersionAndRemovePageFromFrameIfNecessary(newPIP.pipPageIdx); + if (isCheckpoint) { + pips.emplace_back(newPIP); + } + } + // Note that we already updated the header to its correct state above. 
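// A standalone sketch of the two-level lookup performed by getAPPageIdxNoLock on its
// read-only path: an array-page index splits into (PIP index, offset inside that PIP),
// and the physical page index is then read from the selected PIP. The capacity constant
// below is illustrative only; the real value is NUM_PAGE_IDXS_PER_PIP.
#include <cstdint>
#include <vector>

namespace sketch {
constexpr std::uint64_t kPageIdxsPerPip = 1023; // illustrative: one 4 KiB PIP page minus its next-PIP pointer

struct Pip {
    std::vector<std::uint32_t> pageIdxs; // physical page of each array page described by this PIP
};

inline std::uint32_t arrayPageToPhysicalPage(const std::vector<Pip>& pips, std::uint64_t apIdx) {
    const std::uint64_t pipIdx = apIdx / kPageIdxsPerPip;      // which PIP describes this array page
    const std::uint64_t offsetInPip = apIdx % kPageIdxsPerPip; // slot inside that PIP
    return pips[pipIdx].pageIdxs[offsetInPip];
}
} // namespace sketch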
+ pipUpdates.clear(); + hasTransactionalUpdates = false; + if (isCheckpoint && lastPageOnDisk != INVALID_PAGE_IDX) { + updateLastPageOnDisk(); + } +} + +void DiskArrayInternal::checkpoint() { + if (pipUpdates.updatedLastPIP.has_value()) { + ShadowUtils::updatePage(fileHandle, pipUpdates.updatedLastPIP->pipPageIdx, true, + *shadowFile, [&](auto* frame) { + memcpy(frame, &pipUpdates.updatedLastPIP->pipContents, sizeof(PIP)); + }); + } + for (auto& newPIP : pipUpdates.newPIPs) { + ShadowUtils::updatePage(fileHandle, newPIP.pipPageIdx, true, *shadowFile, + [&](auto* frame) { memcpy(frame, &newPIP.pipContents, sizeof(PIP)); }); + } +} + +void DiskArrayInternal::reclaimStorage(PageAllocator& pageAllocator) const { + for (auto& pip : pips) { + for (auto pageIdx : pip.pipContents.pageIdxs) { + if (pageIdx != ShadowUtils::NULL_PAGE_IDX) { + pageAllocator.freePage(pageIdx); + } + } + if (pip.pipPageIdx != ShadowUtils::NULL_PAGE_IDX) { + pageAllocator.freePage(pip.pipPageIdx); + } + } +} + +bool DiskArrayInternal::hasPIPUpdatesNoLock(uint64_t pipIdx) const { + // This is a request to a pipIdx > pips.size(). Since pips.size() is the original number of pips + // we started with before the write transaction is updated, we return true, i.e., this PIP is + // an "updated" PIP and should be read from the WAL version. + if (pipIdx >= pips.size()) { + return true; + } + return (pipIdx == pips.size() - 1) && pipUpdates.updatedLastPIP; +} + +std::pair +DiskArrayInternal::getAPPageIdxAndAddAPToPIPIfNecessaryForWriteTrxNoLock( + PageAllocator& pageAllocator, const Transaction* transaction, page_idx_t apIdx) { + if (apIdx == getNumAPs(headerForWriteTrx) - 1 && lastAPPageIdx != INVALID_PAGE_IDX) { + return std::make_pair(lastAPPageIdx, false /*not a new page*/); + } else if (apIdx < getNumAPs(headerForWriteTrx)) { + // If the apIdx of the array page is < numAPs, we do not have to + // add a new array page, so directly return the pageIdx of the apIdx. + return std::make_pair(getAPPageIdxNoLock(apIdx, transaction->getType()), + false /* is not inserting a new ap page */); + } else { + // apIdx even if it's being inserted should never be > updatedDiskArrayHeader->numAPs. + KU_ASSERT(apIdx == getNumAPs(headerForWriteTrx)); + // We need to add a new AP. This may further cause a new pip to be inserted, which is + // handled by the if/else-if/else branch below. + page_idx_t newAPPageIdx = pageAllocator.allocatePage(); + // We need to create a new array page and then add its apPageIdx (newAPPageIdx variable) to + // an appropriate PIP. + auto pipIdxAndOffsetOfNewAP = + StorageUtils::getQuotientRemainder(apIdx, NUM_PAGE_IDXS_PER_PIP); + uint64_t pipIdx = pipIdxAndOffsetOfNewAP.first; + uint64_t offsetOfNewAPInPIP = pipIdxAndOffsetOfNewAP.second; + if (pipIdx < pips.size()) { + KU_ASSERT(pipIdx == pips.size() - 1); + // We do not need to insert a new pip and we need to add newAPPageIdx to a PIP that + // existed before this transaction started. + if (!pipUpdates.updatedLastPIP.has_value()) { + pipUpdates.updatedLastPIP = std::make_optional(pips[pipIdx]); + } + pipUpdates.updatedLastPIP->pipContents.pageIdxs[offsetOfNewAPInPIP] = newAPPageIdx; + } else if ((pipIdx - pips.size()) < pipUpdates.newPIPs.size()) { + // We do not need to insert a new PIP and we need to add newAPPageIdx to a new PIP that + // already got created after this transaction started. 
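// The branch implemented by getAPPageIdxAndAddAPToPIPIfNecessaryForWriteTrxNoLock can be
// summarised as: given the PIP index that a newly allocated array page maps to, decide
// whether it lands in the last pre-existing PIP, in a PIP this transaction already
// created, or in a PIP that still has to be allocated. A hedged standalone sketch of just
// that classification (the real code additionally wires up nextPipPageIdx and the header):
#include <cstddef>
#include <cstdint>

namespace sketch {
enum class PipTarget { existingLastPip, pipNewInThisTransaction, brandNewPip };

constexpr PipTarget classifyPipForNewArrayPage(std::uint64_t pipIdx, std::size_t numExistingPips,
    std::size_t numPipsCreatedInTransaction) {
    if (pipIdx < numExistingPips) {
        // Can only ever be the last pre-existing PIP, since the array grows at the end.
        return PipTarget::existingLastPip;
    }
    if (pipIdx - numExistingPips < numPipsCreatedInTransaction) {
        return PipTarget::pipNewInThisTransaction;
    }
    return PipTarget::brandNewPip;
}

static_assert(classifyPipForNewArrayPage(2, 3, 0) == PipTarget::existingLastPip);
static_assert(classifyPipForNewArrayPage(3, 3, 1) == PipTarget::pipNewInThisTransaction);
static_assert(classifyPipForNewArrayPage(4, 3, 1) == PipTarget::brandNewPip);
} // namespace sketch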
+ auto& pip = pipUpdates.newPIPs[pipIdx - pips.size()]; + pip.pipContents.pageIdxs[offsetOfNewAPInPIP] = newAPPageIdx; + } else { + // We need to create a new PIP and make the previous PIP (or the header) point to it. + page_idx_t pipPageIdx = pageAllocator.allocatePage(); + pipUpdates.newPIPs.emplace_back(pipPageIdx); + uint64_t pipIdxOfPreviousPIP = pipIdx - 1; + setNextPIPPageIDxOfPIPNoLock(pipIdxOfPreviousPIP, pipPageIdx); + pipUpdates.newPIPs.back().pipContents.pageIdxs[offsetOfNewAPInPIP] = newAPPageIdx; + } + return std::make_pair(newAPPageIdx, true /* inserting a new ap page */); + } +} + +DiskArrayInternal::WriteIterator& DiskArrayInternal::WriteIterator::seek(size_t newIdx) { + KU_ASSERT(newIdx < diskArray.headerForWriteTrx.numElements); + auto oldPageIdx = apCursor.pageIdx; + idx = newIdx; + apCursor = getAPIdxAndOffsetInAP(diskArray.storageInfo, idx); + if (oldPageIdx != apCursor.pageIdx) { + page_idx_t apPageIdx = diskArray.getAPPageIdxNoLock(apCursor.pageIdx, TRX_TYPE); + getPage(apPageIdx, false /*isNewlyAdded*/); + } + return *this; +} + +void DiskArrayInternal::WriteIterator::pushBack(PageAllocator& pageAllocator, + const Transaction* transaction, std::span val) { + idx = diskArray.headerForWriteTrx.numElements; + auto oldPageIdx = apCursor.pageIdx; + apCursor = getAPIdxAndOffsetInAP(diskArray.storageInfo, idx); + // If this would add a new page, pin new page and update PIP + auto [apPageIdx, isNewlyAdded] = + diskArray.getAPPageIdxAndAddAPToPIPIfNecessaryForWriteTrxNoLock(pageAllocator, transaction, + apCursor.pageIdx); + diskArray.lastAPPageIdx = apPageIdx; + // Used to calculate the number of APs, so it must be updated after the PIPs are. + diskArray.headerForWriteTrx.numElements++; + if (isNewlyAdded || shadowPageAndFrame.originalPage == INVALID_PAGE_IDX || + apCursor.pageIdx != oldPageIdx) { + getPage(apPageIdx, isNewlyAdded); + } + memcpy(operator*().data(), val.data(), val.size()); +} + +void DiskArrayInternal::WriteIterator::unpin() { + if (shadowPageAndFrame.shadowPage != INVALID_PAGE_IDX) { + // unpin current page + diskArray.shadowFile->getShadowingFH().unpinPage(shadowPageAndFrame.shadowPage); + shadowPageAndFrame.shadowPage = INVALID_PAGE_IDX; + } else if (shadowPageAndFrame.originalPage != INVALID_PAGE_IDX) { + diskArray.fileHandle.setLockedPageDirty(shadowPageAndFrame.originalPage); + diskArray.fileHandle.unpinPage(shadowPageAndFrame.originalPage); + shadowPageAndFrame.originalPage = INVALID_PAGE_IDX; + } +} + +void DiskArrayInternal::WriteIterator::getPage(page_idx_t newPageIdx, bool isNewlyAdded) { + unpin(); + if (newPageIdx <= diskArray.lastPageOnDisk) { + // Pin new page + shadowPageAndFrame = ShadowUtils::createShadowVersionIfNecessaryAndPinPage(newPageIdx, + isNewlyAdded, diskArray.fileHandle, *diskArray.shadowFile); + } else { + shadowPageAndFrame.frame = diskArray.fileHandle.pinPage(newPageIdx, + isNewlyAdded ? PageReadPolicy::DONT_READ_PAGE : PageReadPolicy::READ_PAGE); + shadowPageAndFrame.originalPage = newPageIdx; + shadowPageAndFrame.shadowPage = INVALID_PAGE_IDX; + } +} + +DiskArrayInternal::WriteIterator DiskArrayInternal::iter_mut(uint64_t valueSize) { + return WriteIterator(valueSize, *this); +} + +page_idx_t DiskArrayInternal::getAPIdx(uint64_t idx) const { + return getAPIdxAndOffsetInAP(storageInfo, idx).pageIdx; +} + +// [] operator to be used when building an InMemDiskArrayBuilder without transactional updates. 
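// The write iterator above keeps a frame pinned until the cursor moves to a different
// array page, so sequential pushBack calls only switch pages once per page worth of
// elements. A small standalone sketch of that batching effect (4 KiB pages and 32-byte
// aligned elements are illustrative assumptions, not the library's constants):
#include <cstdint>

namespace sketch {
constexpr std::uint64_t kElementsPerPage = 4096 / 32;

// Number of page switches (unpin + pin pairs) needed to append `numAppends` elements
// starting from an empty array, when a page stays pinned while the cursor remains on it.
constexpr std::uint64_t pageSwitchesForSequentialAppends(std::uint64_t numAppends) {
    return (numAppends + kElementsPerPage - 1) / kElementsPerPage;
}

static_assert(pageSwitchesForSequentialAppends(1) == 1);
static_assert(pageSwitchesForSequentialAppends(128) == 1);
static_assert(pageSwitchesForSequentialAppends(129) == 2);
} // namespace sketch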
+// This changes the contents directly in memory and not on disk (nor on the wal) +uint8_t* BlockVectorInternal::operator[](uint64_t idx) const { + auto apCursor = getAPIdxAndOffsetInAP(storageInfo, idx); + KU_ASSERT(apCursor.pageIdx < inMemArrayPages.size()); + return inMemArrayPages[apCursor.pageIdx]->getData() + apCursor.elemPosInPage; +} + +void BlockVectorInternal::resize(uint64_t newNumElements, + const element_construct_func_t& defaultConstructor) { + auto oldNumElements = numElements; + KU_ASSERT(newNumElements >= oldNumElements); + uint64_t oldNumArrayPages = inMemArrayPages.size(); + uint64_t newNumArrayPages = getNumArrayPagesNeededForElements(newNumElements); + for (auto i = oldNumArrayPages; i < newNumArrayPages; ++i) { + inMemArrayPages.emplace_back( + memoryManager.allocateBuffer(true /*initializeToZero*/, LBUG_PAGE_SIZE)); + } + for (uint64_t i = 0; i < newNumElements - oldNumElements; i++) { + auto* dest = operator[](oldNumElements + i); + defaultConstructor(dest); + } + numElements = newNumElements; +} +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/disk_array_collection.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/disk_array_collection.cpp new file mode 100644 index 0000000000..80a67ffc96 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/disk_array_collection.cpp @@ -0,0 +1,115 @@ +#include "storage/disk_array_collection.h" + +#include "common/system_config.h" +#include "common/types/types.h" +#include "storage/file_handle.h" +#include "storage/shadow_utils.h" + +using namespace lbug::common; + +namespace lbug { +namespace storage { + +DiskArrayCollection::DiskArrayCollection(FileHandle& fileHandle, ShadowFile& shadowFile, + bool bypassShadowing) + : fileHandle(fileHandle), shadowFile{shadowFile}, bypassShadowing{bypassShadowing}, + numHeaders{0} { + headersForReadTrx.push_back(std::make_unique()); + headersForWriteTrx.push_back(std::make_unique()); + headerPagesOnDisk = 0; +} + +DiskArrayCollection::DiskArrayCollection(FileHandle& fileHandle, ShadowFile& shadowFile, + page_idx_t firstHeaderPage, bool bypassShadowing) + : fileHandle(fileHandle), shadowFile{shadowFile}, bypassShadowing{bypassShadowing}, + numHeaders{0} { + // Read headers from disk + page_idx_t headerPageIdx = firstHeaderPage; + do { + fileHandle.optimisticReadPage(headerPageIdx, [&](auto* frame) { + const auto page = reinterpret_cast(frame); + headersForReadTrx.push_back(std::make_unique(*page)); + headersForWriteTrx.push_back(std::make_unique(*page)); + headerPageIdx = page->nextHeaderPage; + numHeaders += page->numHeaders; + }); + } while (headerPageIdx != INVALID_PAGE_IDX); + headerPagesOnDisk = headersForReadTrx.size(); +} + +void DiskArrayCollection::checkpoint(page_idx_t firstHeaderPage, PageAllocator& pageAllocator) { + // Write headers to disk + page_idx_t headerPage = firstHeaderPage; + for (page_idx_t indexInMemory = 0; indexInMemory < headersForWriteTrx.size(); indexInMemory++) { + if (headersForWriteTrx[indexInMemory]->nextHeaderPage == INVALID_PAGE_IDX && + indexInMemory < headersForWriteTrx.size() - 1) { + // This is the first time checkpointing the next disk array, allocate a page for its + // header + populateNextHeaderPage(pageAllocator, indexInMemory); + } + + // Only update if the headers for the given page have changed + // Or if the page has not yet been written + if (indexInMemory >= headerPagesOnDisk || + *headersForWriteTrx[indexInMemory] != *headersForReadTrx[indexInMemory]) { + 
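// The DiskArrayCollection constructor above rebuilds its state by following a singly
// linked list of header pages: each page stores the index of the next one, and the chain
// ends at an invalid page index. A standalone sketch of that traversal over an in-memory
// stand-in for the file (the constant and struct shapes are illustrative, not the on-disk
// layout):
#include <cstdint>
#include <unordered_map>
#include <vector>

namespace sketch {
constexpr std::uint32_t kInvalidPage = 0xFFFFFFFF;

struct HeaderPageStub {
    std::uint32_t nextHeaderPage = kInvalidPage;
    std::uint32_t numHeaders = 0;
};

// Returns the header pages in chain order, starting from `firstHeaderPage`.
inline std::vector<HeaderPageStub> readHeaderChain(
    const std::unordered_map<std::uint32_t, HeaderPageStub>& pages, std::uint32_t firstHeaderPage) {
    std::vector<HeaderPageStub> chain;
    std::uint32_t pageIdx = firstHeaderPage;
    do {
        const HeaderPageStub& page = pages.at(pageIdx); // the first page is always read
        chain.push_back(page);
        pageIdx = page.nextHeaderPage;
    } while (pageIdx != kInvalidPage);
    return chain;
}
} // namespace sketch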
ShadowUtils::updatePage(*pageAllocator.getDataFH(), headerPage, + true /*writing full page*/, shadowFile, [&](auto* frame) { + memcpy(frame, headersForWriteTrx[indexInMemory].get(), sizeof(HeaderPage)); + if constexpr (sizeof(HeaderPage) < LBUG_PAGE_SIZE) { + // Zero remaining data in the page + std::fill(frame + sizeof(HeaderPage), frame + LBUG_PAGE_SIZE, 0); + } + }); + } + headerPage = headersForWriteTrx[indexInMemory]->nextHeaderPage; + } + headerPagesOnDisk = headersForWriteTrx.size(); +} + +void DiskArrayCollection::populateNextHeaderPage(PageAllocator& pageAllocator, + common::page_idx_t indexInMemory) { + auto nextHeaderPage = pageAllocator.allocatePage(); + headersForWriteTrx[indexInMemory]->nextHeaderPage = nextHeaderPage; + // We can't really roll back the structural changes in the PKIndex (the disk arrays are + // created in the destructor and there are a fixed number which does not change after that + // point), so we apply those to the version that would otherwise be identical to the one on + // disk + headersForReadTrx[indexInMemory]->nextHeaderPage = nextHeaderPage; +} + +size_t DiskArrayCollection::addDiskArray() { + auto oldSize = numHeaders++; + // This may not be the last header page. If we rollback there may be header pages which are + // empty + auto pageIdx = numHeaders % HeaderPage::NUM_HEADERS_PER_PAGE; + if (pageIdx >= headersForWriteTrx.size()) { + + headersForWriteTrx.emplace_back(std::make_unique()); + // Also add a new read header page as we need to pass read headers to the disk arrays + // Newly added read headers will be empty until checkpointing + headersForReadTrx.emplace_back(std::make_unique()); + } + + auto& headerPage = *headersForWriteTrx[pageIdx]; + KU_ASSERT(headerPage.numHeaders < HeaderPage::NUM_HEADERS_PER_PAGE); + auto indexInPage = headerPage.numHeaders; + headerPage.headers[indexInPage] = DiskArrayHeader(); + headerPage.numHeaders++; + headersForReadTrx[pageIdx]->numHeaders++; + return oldSize; +} + +void DiskArrayCollection::reclaimStorage(PageAllocator& pageAllocator, + common::page_idx_t firstHeaderPage) const { + auto headerPage = firstHeaderPage; + for (page_idx_t indexInMemory = 0; indexInMemory < headersForReadTrx.size(); indexInMemory++) { + if (headerPage == INVALID_PAGE_IDX) { + break; + } + pageAllocator.freePage(headerPage); + headerPage = headersForReadTrx[indexInMemory]->nextHeaderPage; + } +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/file_db_id_utils.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/file_db_id_utils.cpp new file mode 100644 index 0000000000..83fe9ea181 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/file_db_id_utils.cpp @@ -0,0 +1,20 @@ +#include "storage/file_db_id_utils.h" + +#include "common/exception/runtime.h" + +namespace lbug::storage { +void FileDBIDUtils::verifyDatabaseID(const common::FileInfo& fileInfo, + common::ku_uuid_t expectedDatabaseID, common::ku_uuid_t databaseID) { + if (expectedDatabaseID.value != databaseID.value) { + throw common::RuntimeException(common::stringFormat( + "Database ID for temporary file '{}' does not match the current database. This file " + "may have been left behind from a previous database with the same name. 
If it is safe " + "to do so, please delete this file and restart the database.", + fileInfo.path)); + } +} + +void FileDBIDUtils::writeDatabaseID(common::Serializer& ser, common::ku_uuid_t databaseID) { + ser.write(databaseID); +} +} // namespace lbug::storage diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/file_handle.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/file_handle.cpp new file mode 100644 index 0000000000..679ac7a8d7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/file_handle.cpp @@ -0,0 +1,177 @@ +#include "storage/file_handle.h" + +#include + +#include "common/file_system/virtual_file_system.h" +#include "storage/buffer_manager/buffer_manager.h" + +using namespace lbug::common; + +namespace lbug { +namespace storage { + +FileHandle::FileHandle(const std::string& path, uint8_t fhFlags, BufferManager* bm, + uint32_t fileIndex, VirtualFileSystem* vfs, main::ClientContext* context) + : fhFlags{fhFlags}, fileIndex{fileIndex}, numPages{0}, pageCapacity{0}, bm{bm}, + pageSizeClass{isNewTmpFile() && isLargePaged() ? TEMP_PAGE : REGULAR_PAGE}, pageStates{0, 0}, + frameGroupIdxes{0, 0}, pageManager(std::make_unique(this)) { + if (isNewTmpFile()) { + constructTmpFileHandle(path); + } else { + constructPersistentFileHandle(path, vfs, context); + } + pageStates = ConcurrentVector{numPages, pageCapacity}; + frameGroupIdxes = ConcurrentVector{getNumPageGroups(), getNumPageGroups()}; + for (auto i = 0u; i < frameGroupIdxes.size(); i++) { + frameGroupIdxes[i] = bm->addNewFrameGroup(pageSizeClass); + } +} + +void FileHandle::constructPersistentFileHandle(const std::string& path, VirtualFileSystem* vfs, + main::ClientContext* context) { + FileOpenFlags openFlags{0}; + if (isReadOnlyFile()) { + openFlags.flags = FileFlags::READ_ONLY; + openFlags.lockType = isLockRequired() ? FileLockType::READ_LOCK : FileLockType::NO_LOCK; + } else { + openFlags.flags = + FileFlags::WRITE | FileFlags::READ_ONLY | + ((createFileIfNotExists()) ? FileFlags::CREATE_IF_NOT_EXISTS : 0x00000000); + openFlags.lockType = isLockRequired() ? 
FileLockType::WRITE_LOCK : FileLockType::NO_LOCK; + } + fileInfo = vfs->openFile(path, openFlags, context); + const auto fileLength = fileInfo->getFileSize(); + numPages = ceil(static_cast(fileLength) / static_cast(getPageSize())); + pageCapacity = 0; + while (pageCapacity < numPages) { + pageCapacity += StorageConstants::PAGE_GROUP_SIZE; + } +} + +void FileHandle::constructTmpFileHandle(const std::string& path) { + fileInfo = std::make_unique(path, nullptr); + numPages = 0; + pageCapacity = 0; +} + +page_idx_t FileHandle::addNewPage() { + return addNewPages(1 /* numNewPages */); +} + +page_idx_t FileHandle::addNewPages(page_idx_t numNewPages) { + std::unique_lock lck{fhSharedMutex, std::defer_lock_t{}}; + while (!lck.try_lock()) {} + const auto numPagesBeforeChange = numPages.load(); + for (auto i = 0u; i < numNewPages; i++) { + addNewPageWithoutLock(); + } + return numPagesBeforeChange; +} + +page_idx_t FileHandle::addNewPageWithoutLock() { + if (numPages == pageCapacity) { + addNewPageGroupWithoutLock(); + } + pageStates[numPages].resetToEvicted(); + const auto pageIdx = numPages++; + if (isInMemoryMode()) { + bm->pin(*this, pageIdx, PageReadPolicy::DONT_READ_PAGE); + } + return pageIdx; +} + +void FileHandle::addNewPageGroupWithoutLock() { + pageCapacity += StorageConstants::PAGE_GROUP_SIZE; + pageStates.resize(pageCapacity); + frameGroupIdxes.push_back(bm->addNewFrameGroup(pageSizeClass)); +} + +uint8_t* FileHandle::pinPage(page_idx_t pageIdx, PageReadPolicy readPolicy) { + if (isInMemoryMode()) { + // Already pinned. + return bm->getFrame(*this, pageIdx); + } + return bm->pin(*this, pageIdx, readPolicy); +} + +void FileHandle::optimisticReadPage(page_idx_t pageIdx, + const std::function& readOp) { + if (isInMemoryMode()) { + KU_ASSERT( + PageState::getState(getPageState(pageIdx)->getStateAndVersion()) == PageState::LOCKED); + const auto frame = bm->getFrame(*this, pageIdx); + readOp(frame); + } else { + bm->optimisticRead(*this, pageIdx, readOp); + } +} + +void FileHandle::unpinPage(page_idx_t pageIdx) { + bm->unpin(*this, pageIdx); +} + +void FileHandle::resetToZeroPagesAndPageCapacity() { + removePageIdxAndTruncateIfNecessary(0 /* pageIdx */); + if (isInMemoryMode()) { + for (auto i = 0u; i < numPages; i++) { + bm->unpin(*this, i); + } + } else { + fileInfo->truncate(0 /* size */); + } +} + +uint8_t* FileHandle::getFrame(page_idx_t pageIdx) { + KU_ASSERT(pageIdx < numPages); + return bm->getFrame(*this, pageIdx); +} + +void FileHandle::removePageIdxAndTruncateIfNecessary(page_idx_t pageIdx) { + std::unique_lock xLck{fhSharedMutex}; + if (numPages <= pageIdx) { + return; + } + numPages = pageIdx; + pageStates.resize(numPages); + const auto numPageGroups = getNumPageGroups(); + if (numPageGroups == frameGroupIdxes.size()) { + return; + } + KU_ASSERT(numPageGroups < frameGroupIdxes.size()); + frameGroupIdxes.resize(numPageGroups); + pageCapacity = numPageGroups * StorageConstants::PAGE_GROUP_SIZE; +} + +void FileHandle::removePageFromFrameIfNecessary(page_idx_t pageIdx) { + bm->removePageFromFrameIfNecessary(*this, pageIdx); +} + +void FileHandle::flushAllDirtyPagesInFrames() { + for (auto pageIdx = 0u; pageIdx < numPages; ++pageIdx) { + flushPageIfDirtyWithoutLock(pageIdx); + } +} + +void FileHandle::flushPageIfDirtyWithoutLock(page_idx_t pageIdx) { + auto pageState = getPageState(pageIdx); + if (!isInMemoryMode() && pageState->isDirty()) { + fileInfo->writeFile(getFrame(pageIdx), getPageSize(), pageIdx * getPageSize()); + pageState->clearDirtyWithoutLock(); + } +} + +void 
FileHandle::writePagesToFile(const uint8_t* buffer, uint64_t size, page_idx_t startPageIdx) { + if (isInMemoryMode()) { + auto pageSize = getPageSize(); + for (uint64_t i = 0; i < size; i += pageSize) { + const auto frame = getFrame(startPageIdx + i / pageSize); + memcpy(frame, buffer + i, std::min(pageSize, size - i)); + } + } else { + fileInfo->writeFile(buffer, size, startPageIdx * getPageSize()); + } +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/free_space_manager.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/free_space_manager.cpp new file mode 100644 index 0000000000..963f4d953d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/free_space_manager.cpp @@ -0,0 +1,303 @@ +#include "storage/free_space_manager.h" + +#include "common/serializer/deserializer.h" +#include "common/serializer/in_mem_file_writer.h" +#include "common/serializer/serializer.h" +#include "common/utils.h" +#include "storage/buffer_manager/buffer_manager.h" +#include "storage/file_handle.h" +#include "storage/page_range.h" + +namespace lbug::storage { +static FreeSpaceManager::sorted_free_list_t& getFreeList( + std::vector& freeLists, common::idx_t level) { + if (level >= freeLists.size()) { + freeLists.resize(level + 1, + FreeSpaceManager::sorted_free_list_t{&FreeSpaceManager::entryCmp}); + } + return freeLists[level]; +} + +FreeSpaceManager::FreeSpaceManager() : freeLists{}, numEntries(0), needClearEvictedEntries(false){}; + +common::idx_t FreeSpaceManager::getLevel(common::page_idx_t numPages) { + // level is exponent of largest power of 2 that is <= numPages + // e.g. 2 -> level 1, 5 -> level 2 + KU_ASSERT(numPages > 0); + return common::CountZeros::Trailing(std::bit_floor(numPages)); +} + +bool FreeSpaceManager::entryCmp(const PageRange& a, const PageRange& b) { + return a.numPages == b.numPages ? 
a.startPageIdx < b.startPageIdx : a.numPages < b.numPages; +} + +void FreeSpaceManager::addFreePages(PageRange entry) { + KU_ASSERT(entry.numPages > 0); + const auto entryLevel = getLevel(entry.numPages); + KU_ASSERT(!getFreeList(freeLists, entryLevel).contains(entry)); + getFreeList(freeLists, entryLevel).insert(entry); + ++numEntries; +} + +void FreeSpaceManager::evictAndAddFreePages(FileHandle* fileHandle, PageRange entry) { + evictPages(fileHandle, entry); + addFreePages(entry); +} + +void FreeSpaceManager::addUncheckpointedFreePages(PageRange entry) { + uncheckpointedFreePageRanges.push_back(entry); +} + +void FreeSpaceManager::rollbackCheckpoint() { + uncheckpointedFreePageRanges.clear(); +} + +// This also removes the chunk from the free space manager +std::optional FreeSpaceManager::popFreePages(common::page_idx_t numPages) { + if (numPages > 0) { + auto levelToSearch = getLevel(numPages); + for (; levelToSearch < freeLists.size(); ++levelToSearch) { + auto& curList = freeLists[levelToSearch]; + auto entryIt = curList.lower_bound(PageRange{0, numPages}); + if (entryIt != curList.end()) { + auto entry = *entryIt; + curList.erase(entryIt); + --numEntries; + return splitPageRange(entry, numPages); + } + } + } + return std::nullopt; +} + +PageRange FreeSpaceManager::splitPageRange(PageRange chunk, common::page_idx_t numRequiredPages) { + KU_ASSERT(chunk.numPages >= numRequiredPages); + PageRange ret{chunk.startPageIdx, numRequiredPages}; + if (numRequiredPages < chunk.numPages) { + PageRange remainingEntry{chunk.startPageIdx + numRequiredPages, + chunk.numPages - numRequiredPages}; + addFreePages(remainingEntry); + } + return ret; +} + +struct SerializePagesUsedTracker { + common::page_idx_t numPagesUsed; + uint64_t numBytesUsedInPage; + + void updatePagesUsed(uint64_t numBytesToAdd) { + if (numBytesUsedInPage + numBytesToAdd > common::InMemFileWriter::getPageSize()) { + ++numPagesUsed; + numBytesUsedInPage = 0; + } + numBytesUsedInPage += numBytesToAdd; + } + + template + void processValue(T) { + updatePagesUsed(sizeof(T)); + } + + void processDebuggingInfo(const std::string& value) { + updatePagesUsed(sizeof(uint64_t) + value.size()); + } +}; + +struct ValueSerializer { + common::Serializer& ser; + + template + void processValue(T value) { + ser.write(value); + } + + void processDebuggingInfo(const std::string& value) { ser.writeDebuggingInfo(value); } +}; + +template +static common::row_idx_t serializeCheckpointedEntries( + const std::vector& freeLists, ValueProcessor& ser) { + auto entryIt = FreeEntryIterator{freeLists}; + common::row_idx_t numWrittenEntries = 0; + while (!entryIt.done()) { + const auto entry = *entryIt; + ser.processValue(entry.startPageIdx); + ser.processValue(entry.numPages); + ++entryIt; + ++numWrittenEntries; + } + return numWrittenEntries; +} + +template +static common::row_idx_t serializeUncheckpointedEntries( + const FreeSpaceManager::free_list_t& uncheckpointedEntries, ValueProcessor& ser) { + for (const auto& entry : uncheckpointedEntries) { + ser.processValue(entry.startPageIdx); + ser.processValue(entry.numPages); + } + return uncheckpointedEntries.size(); +} + +template +void FreeSpaceManager::serializeInternal(ValueProcessor& ser) const { + // we also serialize uncheckpointed entries as serialize() may be called before + // finalizeCheckpoint() + ser.processDebuggingInfo("page_manager"); + const auto numEntries = getNumEntries() + uncheckpointedFreePageRanges.size(); + ser.processDebuggingInfo("numEntries"); + ser.processValue(numEntries); + 
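// A hedged standalone sketch of the free-list lookup implemented by getLevel /
// popFreePages / splitPageRange above: ranges are bucketed by the exponent of the largest
// power of two not exceeding their length, a request searches that bucket and every larger
// one for the first range that is big enough, and any surplus pages go back onto the free
// lists. Types and bucket layout here are illustrative, not the library's.
#include <bit>
#include <cstdint>
#include <optional>
#include <set>
#include <vector>

namespace sketch {
struct Range {
    std::uint32_t start;
    std::uint32_t numPages;
};

// Order by length first, then by start page, matching entryCmp above.
struct RangeCmp {
    bool operator()(const Range& a, const Range& b) const {
        return a.numPages == b.numPages ? a.start < b.start : a.numPages < b.numPages;
    }
};

using FreeLists = std::vector<std::set<Range, RangeCmp>>;

inline std::size_t level(std::uint32_t numPages) {
    return static_cast<std::size_t>(std::countr_zero(std::bit_floor(numPages))); // 2 -> 1, 5 -> 2
}

inline std::optional<Range> popFreePages(FreeLists& lists, std::uint32_t numPages) {
    for (std::size_t lvl = level(numPages); lvl < lists.size(); ++lvl) {
        auto it = lists[lvl].lower_bound(Range{0, numPages}); // first range with >= numPages
        if (it == lists[lvl].end()) {
            continue;
        }
        Range found = *it;
        lists[lvl].erase(it);
        if (found.numPages > numPages) { // return the surplus to its own bucket
            Range rest{found.start + numPages, found.numPages - numPages};
            lists[level(rest.numPages)].insert(rest);
        }
        return Range{found.start, numPages};
    }
    return std::nullopt;
}
} // namespace sketch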
ser.processDebuggingInfo("entries"); + [[maybe_unused]] const auto numCheckpointedEntries = + serializeCheckpointedEntries(freeLists, ser); + [[maybe_unused]] const auto numUncheckpointedEntries = + serializeUncheckpointedEntries(uncheckpointedFreePageRanges, ser); + KU_ASSERT(numCheckpointedEntries + numUncheckpointedEntries == numEntries); +} + +common::page_idx_t FreeSpaceManager::getMaxNumPagesForSerialization() const { + SerializePagesUsedTracker ser{}; + serializeInternal(ser); + return ser.numPagesUsed + (ser.numBytesUsedInPage > 0); +} + +void FreeSpaceManager::serialize(common::Serializer& ser) const { + ValueSerializer serWrapper{.ser = ser}; + serializeInternal(serWrapper); +} + +void FreeSpaceManager::deserialize(common::Deserializer& deSer) { + std::string key; + + deSer.validateDebuggingInfo(key, "page_manager"); + deSer.validateDebuggingInfo(key, "numEntries"); + common::row_idx_t numEntries{}; + deSer.deserializeValue(numEntries); + + deSer.validateDebuggingInfo(key, "entries"); + for (common::row_idx_t i = 0; i < numEntries; ++i) { + PageRange entry{}; + deSer.deserializeValue(entry.startPageIdx); + deSer.deserializeValue(entry.numPages); + addFreePages(entry); + } +} + +void FreeSpaceManager::evictPages(FileHandle* fileHandle, const PageRange& entry) { + needClearEvictedEntries = true; + for (uint64_t i = 0; i < entry.numPages; ++i) { + const auto pageIdx = entry.startPageIdx + i; + fileHandle->removePageFromFrameIfNecessary(pageIdx); + } +} + +void FreeSpaceManager::finalizeCheckpoint(FileHandle* fileHandle) { + // evict pages before they're added to the free list + for (const auto& entry : uncheckpointedFreePageRanges) { + evictPages(fileHandle, entry); + } + + mergePageRanges(std::move(uncheckpointedFreePageRanges), fileHandle); + uncheckpointedFreePageRanges.clear(); +} + +void FreeSpaceManager::resetFreeLists() { + freeLists.clear(); + numEntries = 0; +} + +void FreeSpaceManager::mergePageRanges(free_list_t newInitialEntries, FileHandle* fileHandle) { + free_list_t allEntries = std::move(newInitialEntries); + for (const auto& freeList : freeLists) { + allEntries.insert(allEntries.end(), freeList.begin(), freeList.end()); + } + if (allEntries.empty()) { + return; + } + + resetFreeLists(); + std::sort(allEntries.begin(), allEntries.end(), [](const auto& entryA, const auto& entryB) { + return entryA.startPageIdx < entryB.startPageIdx; + }); + + PageRange prevEntry = allEntries[0]; + for (common::row_idx_t i = 1; i < allEntries.size(); ++i) { + const auto& entry = allEntries[i]; + KU_ASSERT(prevEntry.startPageIdx + prevEntry.numPages <= entry.startPageIdx); + if (prevEntry.startPageIdx + prevEntry.numPages == entry.startPageIdx) { + prevEntry.numPages += entry.numPages; + } else { + addFreePages(prevEntry); + prevEntry = entry; + } + } + handleLastPageRange(prevEntry, fileHandle); +} + +void FreeSpaceManager::handleLastPageRange(PageRange pageRange, FileHandle* fileHandle) { + if (pageRange.startPageIdx + pageRange.numPages == fileHandle->getNumPages()) { + fileHandle->removePageIdxAndTruncateIfNecessary(pageRange.startPageIdx); + } else { + addFreePages(pageRange); + } +} + +common::row_idx_t FreeSpaceManager::getNumEntries() const { + return numEntries; +} + +std::vector FreeSpaceManager::getEntries(common::row_idx_t startOffset, + common::row_idx_t endOffset) const { + KU_ASSERT(endOffset >= startOffset); + std::vector ret; + FreeEntryIterator it{freeLists}; + it.advance(startOffset); + while (ret.size() < endOffset - startOffset) { + KU_ASSERT(!it.done()); + 
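// finalizeCheckpoint / mergePageRanges above coalesce the free lists by sorting every
// range by its start page and merging neighbours that touch; only the merged result is
// re-inserted, and a range ending exactly at the end of the file is handled separately by
// truncation. A standalone sketch of just the coalescing step:
#include <algorithm>
#include <cstdint>
#include <vector>

namespace sketch {
struct PageSpan {
    std::uint32_t start;
    std::uint32_t numPages;
};

inline std::vector<PageSpan> coalesce(std::vector<PageSpan> spans) {
    std::vector<PageSpan> merged;
    if (spans.empty()) {
        return merged;
    }
    std::sort(spans.begin(), spans.end(),
        [](const PageSpan& a, const PageSpan& b) { return a.start < b.start; });
    PageSpan prev = spans.front();
    for (std::size_t i = 1; i < spans.size(); ++i) {
        if (prev.start + prev.numPages == spans[i].start) {
            prev.numPages += spans[i].numPages; // adjacent: grow the previous span
        } else {
            merged.push_back(prev);
            prev = spans[i];
        }
    }
    merged.push_back(prev);
    return merged;
}
} // namespace sketch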
ret.push_back(*it); + ++it; + } + return ret; +} + +void FreeSpaceManager::clearEvictedBufferManagerEntriesIfNeeded(BufferManager* bufferManager) { + if (needClearEvictedEntries) { + bufferManager->removeEvictedCandidates(); + needClearEvictedEntries = false; + } +} + +void FreeEntryIterator::advance(common::row_idx_t numEntries) { + for (common::row_idx_t i = 0; i < numEntries; ++i) { + ++(*this); + } +} + +void FreeEntryIterator::operator++() { + KU_ASSERT(freeListIdx < freeLists.size()); + ++freeListIt; + if (freeListIt == freeLists[freeListIdx].end()) { + ++freeListIdx; + advanceFreeListIdx(); + } +} + +bool FreeEntryIterator::done() const { + return freeListIdx >= freeLists.size(); +} + +void FreeEntryIterator::advanceFreeListIdx() { + for (; freeListIdx < freeLists.size(); ++freeListIdx) { + if (!freeLists[freeListIdx].empty()) { + freeListIt = freeLists[freeListIdx].begin(); + break; + } + } +} + +PageRange FreeEntryIterator::operator*() const { + KU_ASSERT(freeListIdx < freeLists.size() && freeListIt != freeLists[freeListIdx].end()); + return *freeListIt; +} + +} // namespace lbug::storage diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/index/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/index/CMakeLists.txt new file mode 100644 index 0000000000..494c6421a7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/index/CMakeLists.txt @@ -0,0 +1,9 @@ +add_library(lbug_storage_index + OBJECT + hash_index.cpp + in_mem_hash_index.cpp + index.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/index/hash_index.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/index/hash_index.cpp new file mode 100644 index 0000000000..cc76445a69 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/index/hash_index.cpp @@ -0,0 +1,735 @@ +#include "storage/index/hash_index.h" + +#include + +#include "common/assert.h" +#include "common/exception/message.h" +#include "common/serializer/deserializer.h" +#include "common/types/int128_t.h" +#include "common/types/ku_string.h" +#include "common/types/types.h" +#include "common/types/uint128_t.h" +#include "main/client_context.h" +#include "storage/disk_array.h" +#include "storage/disk_array_collection.h" +#include "storage/file_handle.h" +#include "storage/index/hash_index_header.h" +#include "storage/index/hash_index_slot.h" +#include "storage/index/hash_index_utils.h" +#include "storage/index/in_mem_hash_index.h" +#include "storage/local_storage/local_hash_index.h" +#include "storage/overflow_file.h" +#include "storage/shadow_utils.h" +#include "storage/storage_manager.h" +#include "transaction/transaction.h" + +using namespace lbug::common; +using namespace lbug::transaction; + +namespace lbug { +namespace storage { + +template +HashIndex::HashIndex(MemoryManager& memoryManager, OverflowFileHandle* overflowFileHandle, + DiskArrayCollection& diskArrays, uint64_t indexPos, ShadowFile* shadowFile, + const HashIndexHeader& indexHeaderForReadTrx, HashIndexHeader& indexHeaderForWriteTrx) + : shadowFile{shadowFile}, headerPageIdx{0}, overflowFileHandle{overflowFileHandle}, + localStorage{std::make_unique>(memoryManager, overflowFileHandle)}, + indexHeaderForReadTrx{indexHeaderForReadTrx}, indexHeaderForWriteTrx{indexHeaderForWriteTrx}, + memoryManager{memoryManager} { + pSlots = diskArrays.getDiskArray(indexPos); + oSlots = diskArrays.getDiskArray(NUM_HASH_INDEXES + indexPos); +} + +template +void HashIndex::deleteFromPersistentIndex(const 
Transaction* transaction, Key key, + visible_func isVisible) { + auto& header = this->indexHeaderForWriteTrx; + if (header.numEntries == 0) { + return; + } + auto hashValue = HashIndexUtils::hash(key); + auto fingerprint = HashIndexUtils::getFingerprintForHash(hashValue); + auto iter = + getSlotIterator(HashIndexUtils::getPrimarySlotIdForHash(header, hashValue), transaction); + do { + auto entryPos = findMatchedEntryInSlot(transaction, iter.slot, key, fingerprint, isVisible); + if (entryPos != SlotHeader::INVALID_ENTRY_POS) { + iter.slot.header.setEntryInvalid(entryPos); + updateSlot(transaction, iter.slotInfo, iter.slot); + header.numEntries--; + } + } while (nextChainedSlot(transaction, iter)); +} + +template<> +inline hash_t HashIndex::hashStored(const Transaction* transaction, + const ku_string_t& key) const { + hash_t hash = 0; + const auto str = overflowFileHandle->readString(transaction->getType(), key); + function::Hash::operation(str, hash); + return hash; +} + +template +bool HashIndex::checkpoint(PageAllocator& pageAllocator) { + if (localStorage->hasUpdates()) { + auto transaction = &DUMMY_CHECKPOINT_TRANSACTION; + auto netInserts = localStorage->getNetInserts(); + if (netInserts > 0) { + reserve(pageAllocator, transaction, netInserts); + } + localStorage->applyLocalChanges( + [&](Key) { + // TODO(Guodong/Ben): FIX-ME. We should vacuum the index during checkpoint. + // DO NOTHING. + }, + [&](const auto& insertions) { + mergeBulkInserts(pageAllocator, transaction, insertions); + }); + pSlots->checkpoint(); + oSlots->checkpoint(); + return true; + } + pSlots->checkpoint(); + oSlots->checkpoint(); + return false; +} + +template +bool HashIndex::checkpointInMemory() { + if (!localStorage->hasUpdates()) { + return false; + } + pSlots->checkpointInMemoryIfNecessary(); + oSlots->checkpointInMemoryIfNecessary(); + localStorage->clear(); + if constexpr (std::same_as) { + overflowFileHandle->checkpointInMemory(); + } + return true; +} + +template +bool HashIndex::rollbackInMemory() { + if (!localStorage->hasUpdates()) { + return false; + } + pSlots->rollbackInMemoryIfNecessary(); + oSlots->rollbackInMemoryIfNecessary(); + localStorage->clear(); + return true; +} + +template +void HashIndex::rollbackCheckpoint() { + pSlots->rollbackInMemoryIfNecessary(); + oSlots->rollbackInMemoryIfNecessary(); +} + +template +void HashIndex::reclaimStorage(PageAllocator& pageAllocator) { + pSlots->reclaimStorage(pageAllocator); + oSlots->reclaimStorage(pageAllocator); +} + +template +void HashIndex::splitSlots(PageAllocator& pageAllocator, const Transaction* transaction, + HashIndexHeader& header, slot_id_t numSlotsToSplit) { + auto originalSlotIterator = pSlots->iter_mut(); + auto newSlotIterator = pSlots->iter_mut(); + auto overflowSlotIterator = oSlots->iter_mut(); + // The overflow slot iterators will hang if they access the same page + // So instead buffer new overflow slots here and append them at the end + std::vector newOverflowSlots; + + auto getNextOvfSlot = [&](slot_id_t nextOvfSlotId) { + if (nextOvfSlotId >= oSlots->getNumElements()) { + return &newOverflowSlots[nextOvfSlotId - oSlots->getNumElements()]; + } else { + return &*overflowSlotIterator.seek(nextOvfSlotId); + } + }; + + for (slot_id_t i = 0; i < numSlotsToSplit; i++) { + auto* newSlot = &*newSlotIterator.pushBack(pageAllocator, transaction, OnDiskSlotType()); + entry_pos_t newEntryPos = 0; + OnDiskSlotType* originalSlot = &*originalSlotIterator.seek(header.nextSplitSlotId); + do { + for (entry_pos_t originalEntryPos = 0; 
originalEntryPos < PERSISTENT_SLOT_CAPACITY; + originalEntryPos++) { + if (!originalSlot->header.isEntryValid(originalEntryPos)) { + continue; // Skip invalid entries. + } + if (newEntryPos >= PERSISTENT_SLOT_CAPACITY) { + newSlot->header.nextOvfSlotId = + newOverflowSlots.size() + oSlots->getNumElements(); + newOverflowSlots.emplace_back(); + newSlot = &newOverflowSlots.back(); + newEntryPos = 0; + } + // Copy entry from old slot to new slot + const auto& key = originalSlot->entries[originalEntryPos].key; + const hash_t hash = this->hashStored(transaction, key); + const auto newSlotId = hash & header.higherLevelHashMask; + if (newSlotId != header.nextSplitSlotId) { + KU_ASSERT(newSlotId == newSlotIterator.idx()); + newSlot->entries[newEntryPos] = originalSlot->entries[originalEntryPos]; + newSlot->header.setEntryValid(newEntryPos, + originalSlot->header.fingerprints[originalEntryPos]); + originalSlot->header.setEntryInvalid(originalEntryPos); + newEntryPos++; + } + } + } while (originalSlot->header.nextOvfSlotId != SlotHeader::INVALID_OVERFLOW_SLOT_ID && + (originalSlot = getNextOvfSlot(originalSlot->header.nextOvfSlotId))); + header.incrementNextSplitSlotId(); + } + for (auto&& slot : newOverflowSlots) { + overflowSlotIterator.pushBack(pageAllocator, transaction, std::move(slot)); + } +} + +template +std::vector::OnDiskSlotType>> +HashIndex::getChainedSlots(const Transaction* transaction, slot_id_t pSlotId) { + std::vector> slots; + SlotInfo slotInfo{pSlotId, SlotType::PRIMARY}; + while (slotInfo.slotType == SlotType::PRIMARY || + slotInfo.slotId != SlotHeader::INVALID_OVERFLOW_SLOT_ID) { + auto slot = getSlot(transaction, slotInfo); + slots.emplace_back(slotInfo, slot); + slotInfo.slotId = slot.header.nextOvfSlotId; + slotInfo.slotType = SlotType::OVF; + } + return slots; +} + +template +void HashIndex::reserve(PageAllocator& pageAllocator, const Transaction* transaction, + uint64_t newEntries) { + slot_id_t numRequiredEntries = + HashIndexUtils::getNumRequiredEntries(this->indexHeaderForWriteTrx.numEntries + newEntries); + // Can be no fewer slots than the current level requires + auto numRequiredSlots = + std::max((numRequiredEntries + PERSISTENT_SLOT_CAPACITY - 1) / PERSISTENT_SLOT_CAPACITY, + static_cast(1ul << this->indexHeaderForWriteTrx.currentLevel)); + // Always start with at least one page worth of slots. + // This guarantees that when splitting the source and destination slot are never on the same + // page, which allows safe use of multiple disk array iterators. 
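// splitSlots above is classic linear hashing: only the slot named by nextSplitSlotId is
// rehashed per step, and each of its entries either stays put or moves to the brand-new
// slot appended at the end, depending on one extra hash bit. A hedged standalone sketch
// of the standard placement rule this relies on (mask handling is illustrative; the
// library's HashIndexUtils is not reproduced here):
#include <cstdint>

namespace sketch {
struct LinearHashState {
    std::uint64_t levelHashMask;       // (1 << currentLevel) - 1
    std::uint64_t higherLevelHashMask; // (1 << (currentLevel + 1)) - 1
    std::uint64_t nextSplitSlotId;     // next primary slot to be split
};

constexpr std::uint64_t primarySlotForHash(const LinearHashState& s, std::uint64_t hash) {
    std::uint64_t slotId = hash & s.levelHashMask;
    if (slotId < s.nextSplitSlotId) {
        // This slot has already been split in the current round, so one more hash bit
        // decides between it and its image slot at the end of the table.
        slotId = hash & s.higherLevelHashMask;
    }
    return slotId;
}

// With currentLevel = 2 and two slots already split, hash 0b1101 maps to slot 0b101 = 5:
// its low two bits give 1, which falls below nextSplitSlotId, so three bits are used.
static_assert(primarySlotForHash({0b11, 0b111, 2}, 0b1101) == 0b101);
} // namespace sketch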
+ numRequiredSlots = std::max(numRequiredSlots, LBUG_PAGE_SIZE / pSlots->getAlignedElementSize()); + // If there are no entries, we can just re-size the number of primary slots and re-calculate the + // levels + if (this->indexHeaderForWriteTrx.numEntries == 0) { + pSlots->resize(pageAllocator, transaction, numRequiredSlots); + + auto numSlotsOfCurrentLevel = 1u << this->indexHeaderForWriteTrx.currentLevel; + while ((numSlotsOfCurrentLevel << 1) <= numRequiredSlots) { + this->indexHeaderForWriteTrx.incrementLevel(); + numSlotsOfCurrentLevel <<= 1; + } + if (numRequiredSlots >= numSlotsOfCurrentLevel) { + this->indexHeaderForWriteTrx.nextSplitSlotId = + numRequiredSlots - numSlotsOfCurrentLevel; + } + } else { + splitSlots(pageAllocator, transaction, this->indexHeaderForWriteTrx, + numRequiredSlots - pSlots->getNumElements(transaction->getType())); + } +} + +template +void HashIndex::sortEntries(const Transaction* transaction, + const InMemHashIndex& insertLocalStorage, + typename InMemHashIndex::SlotIterator& slotToMerge, + std::vector& entries) { + do { + auto numEntries = slotToMerge.slot->header.numEntries(); + for (auto entryPos = 0u; entryPos < numEntries; entryPos++) { + const auto* entry = &slotToMerge.slot->entries[entryPos]; + const auto hash = hashStored(transaction, entry->key); + const auto primarySlot = + HashIndexUtils::getPrimarySlotIdForHash(indexHeaderForWriteTrx, hash); + entries.push_back(HashIndexEntryView{primarySlot, + slotToMerge.slot->header.fingerprints[entryPos], entry}); + } + } while (insertLocalStorage.nextChainedSlot(slotToMerge)); + std::sort(entries.begin(), entries.end(), [&](auto entry1, auto entry2) -> bool { + // Sort based on the entry's disk slot ID so that the first slot is at the end. + // Sorting is done reversed so that we can process from the back of the list, + // using the size to track the remaining entries + return entry1.diskSlotId > entry2.diskSlotId; + }); +} + +template +void HashIndex::mergeBulkInserts(PageAllocator& pageAllocator, const Transaction* transaction, + const InMemHashIndex& insertLocalStorage) { + // TODO: Ideally we can split slots at the same time that we insert new ones + // Compute the new number of primary slots, and iterate over each slot, determining if it + // needs to be split (and how many times, which is complicated) and insert/rehash each element + // one by one. Rehashed entries should be copied into a new slot in-memory, and then that new + // slot (with the entries from the respective slot in the local storage) should be processed + // immediately to avoid increasing memory usage (caching one page of slots at a time since split + // slots usually get rehashed to a new page). + // + // On the other hand, two passes may not be significantly slower than one + // TODO: one pass would also reduce locking when frames are unpinned, + // which is useful if this can be parallelized + reserve(pageAllocator, transaction, insertLocalStorage.size()); + // RUNTIME_CHECK(auto originalNumEntries = this->indexHeaderForWriteTrx.numEntries); + + // Storing as many slots in-memory as on-disk shouldn't be necessary (for one, it makes memory + // usage an issue as we may need significantly more memory to store the slots that we would + // otherwise). Instead, when merging here, we can re-hash and split each in-memory slot (into + // temporary vector buffers instead of slots for improved performance) and then merge each of + // those one at a time into the disk slots. 
That will keep the low memory requirements and still + // let us update each on-disk slot one at a time. + + auto diskSlotIterator = pSlots->iter_mut(); + // TODO: Use a separate random access iterator and one that's sequential for adding new overflow + // slots All new slots will be sequential and benefit from caching, but for existing randomly + // accessed slots we just benefit from the interface. However, the two iterators would not be + // able to pin the same page simultaneously + // Alternatively, cache new slots in memory and pushBack them at the end like in splitSlots + auto diskOverflowSlotIterator = oSlots->iter_mut(); + + // Store sorted slot positions. Re-use to avoid re-allocating memory + // TODO: Unify implementations to make sure this matches the size used by the disk array + constexpr size_t NUM_SLOTS_PER_PAGE = + LBUG_PAGE_SIZE / DiskArray::getAlignedElementSize(); + std::array, NUM_SLOTS_PER_PAGE> partitionedEntries; + // Sort entries for a page of slots at a time, then move vertically and process all entries + // which map to a given page on disk, then horizontally to the next page in the set. These pages + // may not be consecutive, but we reduce the memory overhead for storing the information about + // the sorted data and still just process each page once. + for (uint64_t localSlotId = 0; localSlotId < insertLocalStorage.numPrimarySlots(); + localSlotId += NUM_SLOTS_PER_PAGE) { + for (size_t i = 0; + i < NUM_SLOTS_PER_PAGE && localSlotId + i < insertLocalStorage.numPrimarySlots(); + i++) { + auto localSlot = + typename InMemHashIndex::SlotIterator(localSlotId + i, &insertLocalStorage); + partitionedEntries[i].clear(); + // Results are sorted in reverse, so we can process the end first and pop_back to remove + // them from the vector + sortEntries(transaction, insertLocalStorage, localSlot, partitionedEntries[i]); + } + // Repeat until there are no unprocessed partitions in partitionedEntries + // This will run at most NUM_SLOTS_PER_PAGE times the number of entries + std::bitset done; + while (!done.all()) { + std::optional diskSlotPage; + for (size_t i = 0; i < NUM_SLOTS_PER_PAGE; i++) { + if (!done[i] && !partitionedEntries[i].empty()) { + auto diskSlotId = partitionedEntries[i].back().diskSlotId; + if (!diskSlotPage) { + diskSlotPage = diskSlotId / NUM_SLOTS_PER_PAGE; + } + if (diskSlotId / NUM_SLOTS_PER_PAGE == diskSlotPage) { + auto merged = mergeSlot(pageAllocator, transaction, partitionedEntries[i], + diskSlotIterator, diskOverflowSlotIterator, diskSlotId); + KU_ASSERT(merged <= partitionedEntries[i].size()); + partitionedEntries[i].resize(partitionedEntries[i].size() - merged); + if (partitionedEntries[i].empty()) { + done[i] = true; + } + } + } else { + done[i] = true; + } + } + } + } + // TODO(Guodong): Fix this assertion statement which doesn't count the entries in + // deleteLocalStorage. 
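// mergeBulkInserts above avoids random page access by sorting each batch of in-memory
// entries by their target on-disk slot (in reverse, so the back of the vector is the next
// slot to process) and then draining every entry that lands on one disk page before moving
// on. A standalone sketch of that drain step over plain integers standing in for disk slot
// ids (the slots-per-page constant is illustrative):
#include <algorithm>
#include <cstdint>
#include <vector>

namespace sketch {
constexpr std::uint64_t kSlotsPerPage = 16;

// Sort target slot ids so that the smallest id sits at the back of the vector.
inline void sortForMerging(std::vector<std::uint64_t>& targetSlotIds) {
    std::sort(targetSlotIds.begin(), targetSlotIds.end(),
        [](std::uint64_t a, std::uint64_t b) { return a > b; });
}

// Pop and count every pending entry whose target slot lives on the same disk page as the
// entry currently at the back; the caller repeats this until the vector is empty.
inline std::size_t drainOnePage(std::vector<std::uint64_t>& targetSlotIds) {
    if (targetSlotIds.empty()) {
        return 0;
    }
    const std::uint64_t page = targetSlotIds.back() / kSlotsPerPage;
    std::size_t drained = 0;
    while (!targetSlotIds.empty() && targetSlotIds.back() / kSlotsPerPage == page) {
        targetSlotIds.pop_back();
        ++drained;
    }
    return drained;
}
} // namespace sketch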
+ // KU_ASSERT(originalNumEntries + insertLocalStorage.getIndexHeader().numEntries == + // indexHeaderForWriteTrx.numEntries); +} + +template +size_t HashIndex::mergeSlot(PageAllocator& pageAllocator, const Transaction* transaction, + const std::vector& slotToMerge, + typename DiskArray::WriteIterator& diskSlotIterator, + typename DiskArray::WriteIterator& diskOverflowSlotIterator, + slot_id_t diskSlotId) { + slot_id_t diskEntryPos = 0u; + // mergeSlot should only be called when there is at least one entry for the given disk slot id + // in the slot to merge + OnDiskSlotType* diskSlot = &*diskSlotIterator.seek(diskSlotId); + KU_ASSERT(diskSlot->header.nextOvfSlotId == SlotHeader::INVALID_OVERFLOW_SLOT_ID || + diskOverflowSlotIterator.size() > diskSlot->header.nextOvfSlotId); + // Merge slot from local storage to an existing slot. + size_t merged = 0; + for (auto it = std::rbegin(slotToMerge); it != std::rend(slotToMerge); ++it) { + if (it->diskSlotId != diskSlotId) { + return merged; + } + // Find the next empty entry or add a new slot if there are no more entries + while (diskSlot->header.isEntryValid(diskEntryPos) || + diskEntryPos >= PERSISTENT_SLOT_CAPACITY) { + diskEntryPos++; + if (diskEntryPos >= PERSISTENT_SLOT_CAPACITY) { + if (diskSlot->header.nextOvfSlotId == SlotHeader::INVALID_OVERFLOW_SLOT_ID) { + // If there are no more disk slots in this chain, we need to add one + diskSlot->header.nextOvfSlotId = diskOverflowSlotIterator.size(); + // This may invalidate diskSlot + diskOverflowSlotIterator.pushBack(pageAllocator, transaction, OnDiskSlotType()); + KU_ASSERT( + diskSlot->header.nextOvfSlotId == SlotHeader::INVALID_OVERFLOW_SLOT_ID || + diskOverflowSlotIterator.size() > diskSlot->header.nextOvfSlotId); + } else { + diskOverflowSlotIterator.seek(diskSlot->header.nextOvfSlotId); + KU_ASSERT( + diskSlot->header.nextOvfSlotId == SlotHeader::INVALID_OVERFLOW_SLOT_ID || + diskOverflowSlotIterator.size() > diskSlot->header.nextOvfSlotId); + } + diskSlot = &*diskOverflowSlotIterator; + // Check to make sure we're not looping + KU_ASSERT(diskOverflowSlotIterator.idx() != diskSlot->header.nextOvfSlotId); + diskEntryPos = 0; + } + } + KU_ASSERT(diskEntryPos < PERSISTENT_SLOT_CAPACITY); + if constexpr (std::is_same_v) { + auto* inMemEntry = it->entry; + auto kuString = overflowFileHandle->writeString(&pageAllocator, inMemEntry->key); + diskSlot->entries[diskEntryPos] = SlotEntry{kuString, inMemEntry->value}; + } else { + diskSlot->entries[diskEntryPos] = *it->entry; + } + diskSlot->header.setEntryValid(diskEntryPos, it->fingerprint); + KU_ASSERT([&]() { + const auto& key = it->entry->key; + const auto hash = hashStored(transaction, key); + const auto primarySlot = + HashIndexUtils::getPrimarySlotIdForHash(indexHeaderForWriteTrx, hash); + KU_ASSERT(it->fingerprint == HashIndexUtils::getFingerprintForHash(hash)); + KU_ASSERT(primarySlot == diskSlotId); + return true; + }()); + indexHeaderForWriteTrx.numEntries++; + diskEntryPos++; + merged++; + } + return merged; +} + +template +void HashIndex::bulkReserve(uint64_t newEntries) { + return localStorage->reserveInserts(newEntries); +} + +template +HashIndex::~HashIndex() = default; + +template<> +bool HashIndex::equals(const transaction::Transaction* transaction, + std::string_view keyToLookup, const common::ku_string_t& keyInEntry) const { + if (!HashIndexUtils::areStringPrefixAndLenEqual(keyToLookup, keyInEntry)) { + return false; + } + if (keyInEntry.len <= common::ku_string_t::PREFIX_LENGTH) { + // For strings shorter than PREFIX_LENGTH, 
the result must be true. + return true; + } else if (keyInEntry.len <= common::ku_string_t::SHORT_STR_LENGTH) { + // For short strings, whose lengths are larger than PREFIX_LENGTH, check if their + // actual values are equal. + return memcmp(keyToLookup.data(), keyInEntry.prefix, keyInEntry.len) == 0; + } else { + // For long strings, compare with overflow data + return overflowFileHandle->equals(transaction->getType(), keyToLookup, keyInEntry); + } +} + +template class HashIndex; +template class HashIndex; +template class HashIndex; +template class HashIndex; +template class HashIndex; +template class HashIndex; +template class HashIndex; +template class HashIndex; +template class HashIndex; +template class HashIndex; +template class HashIndex; +template class HashIndex; +template class HashIndex; + +std::unique_ptr PrimaryKeyIndexStorageInfo::deserialize( + std::unique_ptr reader) { + page_idx_t firstHeaderPage = INVALID_PAGE_IDX; + page_idx_t overflowHeaderPage = INVALID_PAGE_IDX; + Deserializer deSer(std::move(reader)); + deSer.deserializeValue(firstHeaderPage); + deSer.deserializeValue(overflowHeaderPage); + return std::make_unique(firstHeaderPage, overflowHeaderPage); +} + +std::unique_ptr PrimaryKeyIndex::createNewIndex(IndexInfo indexInfo, + bool inMemMode, MemoryManager& memoryManager, PageAllocator& pageAllocator, + ShadowFile* shadowFile) { + return std::make_unique(std::move(indexInfo), + std::make_unique(), inMemMode, memoryManager, pageAllocator, + shadowFile); +} + +PrimaryKeyIndex::PrimaryKeyIndex(IndexInfo indexInfo, std::unique_ptr storageInfo, + bool inMemMode, MemoryManager& memoryManager, PageAllocator& pageAllocator, + ShadowFile* shadowFile) + : Index{std::move(indexInfo), std::move(storageInfo)}, shadowFile{*shadowFile} { + auto& hashIndexStorageInfo = this->storageInfo->cast(); + if (hashIndexStorageInfo.firstHeaderPage == INVALID_PAGE_IDX) { + KU_ASSERT(hashIndexStorageInfo.overflowHeaderPage == INVALID_PAGE_IDX); + hashIndexHeadersForReadTrx.resize(NUM_HASH_INDEXES); + hashIndexHeadersForWriteTrx.resize(NUM_HASH_INDEXES); + hashIndexDiskArrays = std::make_unique(*pageAllocator.getDataFH(), + *shadowFile, true /*bypassShadowing*/); + // Each index has a primary slot array and an overflow slot array + for (size_t i = 0; i < NUM_HASH_INDEXES * 2; i++) { + hashIndexDiskArrays->addDiskArray(); + } + } else { + size_t headerIdx = 0; + for (size_t headerPageIdx = 0; headerPageIdx < INDEX_HEADER_PAGES; headerPageIdx++) { + pageAllocator.getDataFH()->optimisticReadPage( + hashIndexStorageInfo.firstHeaderPage + headerPageIdx, [&](auto* frame) { + const auto onDiskHeaders = reinterpret_cast(frame); + for (size_t i = 0; i < INDEX_HEADERS_PER_PAGE && headerIdx < NUM_HASH_INDEXES; + i++) { + hashIndexHeadersForReadTrx.emplace_back(onDiskHeaders[i]); + headerIdx++; + } + }); + } + hashIndexHeadersForWriteTrx.assign(hashIndexHeadersForReadTrx.begin(), + hashIndexHeadersForReadTrx.end()); + KU_ASSERT(headerIdx == NUM_HASH_INDEXES); + hashIndexDiskArrays = std::make_unique(*pageAllocator.getDataFH(), + *shadowFile, + hashIndexStorageInfo.firstHeaderPage + + INDEX_HEADER_PAGES /*firstHeaderPage for the DAC follows the index header pages*/, + true /*bypassShadowing*/); + } + initOverflowAndSubIndices(inMemMode, memoryManager, pageAllocator, hashIndexStorageInfo); +} + +void PrimaryKeyIndex::initOverflowAndSubIndices(bool inMemMode, MemoryManager& mm, + PageAllocator& pageAllocator, PrimaryKeyIndexStorageInfo& storageInfo) { + KU_ASSERT(indexInfo.keyDataTypes.size() == 1); + if 
(indexInfo.keyDataTypes[0] == PhysicalTypeID::STRING) { + if (inMemMode) { + overflowFile = std::make_unique(mm); + } else { + overflowFile = std::make_unique(pageAllocator.getDataFH(), mm, + &shadowFile, storageInfo.overflowHeaderPage); + } + } + hashIndices.reserve(NUM_HASH_INDEXES); + TypeUtils::visit( + indexInfo.keyDataTypes[0], + [&](ku_string_t) { + for (auto i = 0u; i < NUM_HASH_INDEXES; i++) { + hashIndices.push_back(std::make_unique>(mm, + overflowFile->addHandle(), *hashIndexDiskArrays, i, &shadowFile, + hashIndexHeadersForReadTrx[i], hashIndexHeadersForWriteTrx[i])); + } + }, + [&](T) { + for (auto i = 0u; i < NUM_HASH_INDEXES; i++) { + hashIndices.push_back(std::make_unique>(mm, nullptr, + *hashIndexDiskArrays, i, &shadowFile, hashIndexHeadersForReadTrx[i], + hashIndexHeadersForWriteTrx[i])); + } + }, + [&](auto) { KU_UNREACHABLE; }); +} + +bool PrimaryKeyIndex::lookup(const Transaction* trx, ValueVector* keyVector, uint64_t vectorPos, + offset_t& result, visible_func isVisible) { + bool retVal = false; + KU_ASSERT(indexInfo.keyDataTypes.size() == 1); + TypeUtils::visit( + indexInfo.keyDataTypes[0], + [&](T) { + T key = keyVector->getValue(vectorPos); + retVal = lookup(trx, key, result, isVisible); + }, + [](auto) { KU_UNREACHABLE; }); + return retVal; +} + +void PrimaryKeyIndex::commitInsert(Transaction* transaction, const ValueVector& nodeIDVector, + const std::vector& indexVectors, Index::InsertState& insertState) { + KU_ASSERT(indexVectors.size() == 1); + const auto& pkVector = *indexVectors[0]; + const auto& pkInsertState = insertState.cast(); + for (auto i = 0u; i < nodeIDVector.state->getSelSize(); i++) { + const auto nodeIDPos = nodeIDVector.state->getSelVector()[i]; + const auto offset = nodeIDVector.readNodeOffset(nodeIDPos); + const auto pkPos = pkVector.state->getSelVector()[i]; + if (pkVector.isNull(pkPos)) { + throw RuntimeException(ExceptionMessage::nullPKException()); + } + if (!insert(transaction, &pkVector, pkPos, offset, pkInsertState.isVisible)) { + throw RuntimeException( + ExceptionMessage::duplicatePKException(pkVector.getAsValue(pkPos)->toString())); + } + } +} + +bool PrimaryKeyIndex::insert(const Transaction* transaction, const ValueVector* keyVector, + uint64_t vectorPos, offset_t value, visible_func isVisible) { + bool result = false; + KU_ASSERT(indexInfo.keyDataTypes.size() == 1); + TypeUtils::visit( + indexInfo.keyDataTypes[0], + [&](T) { + T key = keyVector->getValue(vectorPos); + result = insert(transaction, key, value, isVisible); + }, + [](auto) { KU_UNREACHABLE; }); + return result; +} + +void PrimaryKeyIndex::delete_(ValueVector* keyVector) { + KU_ASSERT(indexInfo.keyDataTypes.size() == 1); + TypeUtils::visit( + indexInfo.keyDataTypes[0], + [&](T) { + for (auto i = 0u; i < keyVector->state->getSelVector().getSelSize(); i++) { + auto pos = keyVector->state->getSelVector()[i]; + if (keyVector->isNull(pos)) { + continue; + } + auto key = keyVector->getValue(pos); + delete_(key); + } + }, + [](auto) { KU_UNREACHABLE; }); +} + +void PrimaryKeyIndex::checkpointInMemory() { + bool indexChanged = false; + for (auto i = 0u; i < NUM_HASH_INDEXES; i++) { + if (hashIndices[i]->checkpointInMemory()) { + indexChanged = true; + } + } + if (indexChanged) { + for (size_t i = 0; i < NUM_HASH_INDEXES; i++) { + hashIndexHeadersForReadTrx[i] = hashIndexHeadersForWriteTrx[i]; + } + hashIndexDiskArrays->checkpointInMemory(); + } + if (overflowFile) { + overflowFile->checkpointInMemory(); + } +} + +void PrimaryKeyIndex::writeHeaders(PageAllocator& pageAllocator) 
const { + size_t headerIdx = 0; + auto& hashIndexStorageInfo = storageInfo->cast(); + if (hashIndexStorageInfo.firstHeaderPage == INVALID_PAGE_IDX) { + const auto allocatedPages = pageAllocator.allocatePageRange( + NUM_HEADER_PAGES + 1 /*first DiskArrayCollection header page*/); + hashIndexStorageInfo.firstHeaderPage = allocatedPages.startPageIdx; + } + for (size_t headerPageIdx = 0; headerPageIdx < INDEX_HEADER_PAGES; headerPageIdx++) { + ShadowUtils::updatePage(*pageAllocator.getDataFH(), + hashIndexStorageInfo.firstHeaderPage + headerPageIdx, + true /*writing all the data to the page; no need to read original*/, shadowFile, + [&](auto* frame) { + const auto onDiskFrame = reinterpret_cast(frame); + for (size_t i = 0; i < INDEX_HEADERS_PER_PAGE && headerIdx < NUM_HASH_INDEXES; + i++) { + hashIndexHeadersForWriteTrx[headerIdx++].write(onDiskFrame[i]); + } + }); + } + KU_ASSERT(headerIdx == NUM_HASH_INDEXES); +} + +void PrimaryKeyIndex::rollbackCheckpoint() { + for (idx_t i = 0; i < NUM_HASH_INDEXES; ++i) { + hashIndices[i]->rollbackCheckpoint(); + } + hashIndexDiskArrays->rollbackCheckpoint(); + hashIndexHeadersForWriteTrx.assign(hashIndexHeadersForReadTrx.begin(), + hashIndexHeadersForReadTrx.end()); + if (overflowFile) { + overflowFile->rollbackInMemory(); + } +} + +static void updateOverflowHeaderPageIfNeeded(IndexStorageInfo* storageInfo, + OverflowFile* overflowFile) { + auto& hashIndexStorageInfo = storageInfo->cast(); + if (hashIndexStorageInfo.overflowHeaderPage == INVALID_PAGE_IDX) { + hashIndexStorageInfo.overflowHeaderPage = overflowFile->getHeaderPageIdx(); + } +} + +void PrimaryKeyIndex::checkpoint(main::ClientContext*, storage::PageAllocator& pageAllocator) { + bool indexChanged = false; + for (auto i = 0u; i < NUM_HASH_INDEXES; i++) { + if (hashIndices[i]->checkpoint(pageAllocator)) { + indexChanged = true; + } + } + if (indexChanged) { + writeHeaders(pageAllocator); + hashIndexDiskArrays->checkpoint(getDiskArrayFirstHeaderPage(), pageAllocator); + } + if (overflowFile) { + overflowFile->checkpoint(pageAllocator); + updateOverflowHeaderPageIfNeeded(storageInfo.get(), overflowFile.get()); + } + // Make sure that changes which bypassed the WAL are written. + // There is no other mechanism for enforcing that they are flushed + // and they will be dropped when the file handle is destroyed. 
+ // TODO: Should eventually be moved into the disk array when the disk array can + // generally handle bypassing the WAL, but should only be run once per file, not once per + // disk array + pageAllocator.getDataFH()->flushAllDirtyPagesInFrames(); + checkpointInMemory(); +} + +PrimaryKeyIndex::~PrimaryKeyIndex() = default; + +std::unique_ptr PrimaryKeyIndex::load(main::ClientContext* context, + StorageManager* storageManager, IndexInfo indexInfo, std::span storageInfoBuffer) { + auto storageInfoBufferReader = + std::make_unique(storageInfoBuffer.data(), storageInfoBuffer.size()); + auto storageInfo = PrimaryKeyIndexStorageInfo::deserialize(std::move(storageInfoBufferReader)); + return std::make_unique(indexInfo, std::move(storageInfo), + storageManager->isInMemory(), *MemoryManager::Get(*context), + *storageManager->getDataFH()->getPageManager(), &storageManager->getShadowFile()); +} + +void PrimaryKeyIndex::reclaimStorage(PageAllocator& pageAllocator) const { + for (auto& hashIndex : hashIndices) { + hashIndex->reclaimStorage(pageAllocator); + } + hashIndexDiskArrays->reclaimStorage(pageAllocator, getDiskArrayFirstHeaderPage()); + if (overflowFile) { + overflowFile->reclaimStorage(pageAllocator); + } + const auto firstHeaderPage = getFirstHeaderPage(); + if (firstHeaderPage != INVALID_PAGE_IDX) { + pageAllocator.freePageRange({getFirstHeaderPage(), NUM_HEADER_PAGES}); + } +} + +page_idx_t PrimaryKeyIndex::getDiskArrayFirstHeaderPage() const { + const auto firstHeaderPage = getFirstHeaderPage(); + return firstHeaderPage == INVALID_PAGE_IDX ? INVALID_PAGE_IDX : + firstHeaderPage + NUM_HEADER_PAGES; +} + +page_idx_t PrimaryKeyIndex::getFirstHeaderPage() const { + return storageInfo->cast().firstHeaderPage; +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/index/in_mem_hash_index.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/index/in_mem_hash_index.cpp new file mode 100644 index 0000000000..d532d9be43 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/index/in_mem_hash_index.cpp @@ -0,0 +1,289 @@ +#include "storage/index/in_mem_hash_index.h" + +#include +#include + +#include "common/types/ku_string.h" +#include "common/types/types.h" +#include "storage/buffer_manager/memory_manager.h" +#include "storage/disk_array.h" +#include "storage/index/hash_index_header.h" +#include "storage/index/hash_index_slot.h" +#include "storage/index/hash_index_utils.h" +#include "storage/overflow_file.h" + +using namespace lbug::common; + +namespace lbug { +namespace storage { + +template +InMemHashIndex::InMemHashIndex(MemoryManager& memoryManager, + OverflowFileHandle* overflowFileHandle) + : overflowFileHandle(overflowFileHandle), + pSlots{std::make_unique>(memoryManager)}, + oSlots{std::make_unique>(memoryManager)}, indexHeader{}, + memoryManager{memoryManager}, numFreeSlots{0} { + // Match HashIndex in allocating at least one page of slots so that we don't split within the + // same page + allocateSlots(LBUG_PAGE_SIZE / pSlots->getAlignedElementSize()); +} + +template +void InMemHashIndex::clear() { + indexHeader = HashIndexHeader(); + pSlots = std::make_unique>(memoryManager); + oSlots = std::make_unique>(memoryManager); + allocateSlots(LBUG_PAGE_SIZE / pSlots->getAlignedElementSize()); +} + +template +void InMemHashIndex::allocateSlots(uint32_t newNumSlots) { + // Allocate memory before updating the header in case the memory allocation fails + auto existingSlots = pSlots->size(); + if (newNumSlots > existingSlots) { + 
allocatePSlots(newNumSlots - existingSlots); + } + auto numSlotsOfCurrentLevel = 1u << this->indexHeader.currentLevel; + while ((numSlotsOfCurrentLevel << 1) <= newNumSlots) { + this->indexHeader.incrementLevel(); + numSlotsOfCurrentLevel <<= 1; + } + if (newNumSlots >= numSlotsOfCurrentLevel) { + this->indexHeader.nextSplitSlotId = newNumSlots - numSlotsOfCurrentLevel; + } +} + +template +void InMemHashIndex::reserve(uint32_t numEntries_) { + slot_id_t numRequiredEntries = HashIndexUtils::getNumRequiredEntries(numEntries_); + auto numRequiredSlots = (numRequiredEntries + SLOT_CAPACITY - 1) / SLOT_CAPACITY; + if (numRequiredSlots <= pSlots->size()) { + return; + } + if (indexHeader.numEntries == 0) { + allocateSlots(numRequiredSlots); + } else { + while (pSlots->size() < numRequiredSlots) { + splitSlot(); + } + } +} + +template +uint64_t InMemHashIndex::countSlots(SlotIterator iter) const { + if (iter.slotInfo.slotType == SlotType::OVF && + iter.slotInfo.slotId == SlotHeader::INVALID_OVERFLOW_SLOT_ID) { + return 0; + } + uint64_t count = 1; + while (nextChainedSlot(iter)) { + count++; + } + return count; +} + +template +void InMemHashIndex::reserveOverflowSlots(uint64_t totalSlotsRequired) { + // Make sure we have enough free slots to do the split without having to allocate more + // Any unused ones will just stay in the free slot chain and will be reused later + SlotInfo newSlot{oSlots->size(), SlotType::OVF}; + if (totalSlotsRequired > numFreeSlots) { + oSlots->resize(oSlots->size() + totalSlotsRequired - numFreeSlots); + for (uint64_t i = 0; i < totalSlotsRequired - numFreeSlots; i++) { + auto slot = getSlot(newSlot); + addFreeOverflowSlot(*slot, newSlot); + newSlot.slotId++; + } + } +} + +template +void InMemHashIndex::splitSlot() { + // Add new slot + allocatePSlots(1); + + // Rehash the entries in the slot to split + SlotIterator originalSlot(indexHeader.nextSplitSlotId, this); + // Reserve enough overflow slots to be able to finish splitting without allocating any new + // memory Otherwise we run the risk of leaving the hash index in an invalid state if we fail to + // allocate a new overflow slot + // TODO(bmwinger): If we split slots backwards instead of forwards we would need to reserve just + // one slot Since we could then reclaim slots that have been emptied. 
That would require making + // the slots doubly-linked + reserveOverflowSlots(countSlots(originalSlot)); + + // Use a separate iterator to track the first empty position so that the gapless entries can + // be maintained + SlotIterator originalSlotForInsert(indexHeader.nextSplitSlotId, this); + auto entryPosToInsert = 0u; + SlotIterator newSlot(pSlots->size() - 1, this); + entry_pos_t newSlotPos = 0; + bool gaps = false; + do { + for (auto entryPos = 0u; entryPos < SLOT_CAPACITY; entryPos++) { + if (!originalSlot.slot->header.isEntryValid(entryPos)) { + // Check that this function leaves no gaps + KU_ASSERT(originalSlot.slot->header.numEntries() == + std::countr_one(originalSlot.slot->header.validityMask)); + // There should be no gaps, so when we encounter an invalid entry we can return + // early + reclaimOverflowSlots(originalSlotForInsert); + indexHeader.incrementNextSplitSlotId(); + return; + } + const auto& entry = originalSlot.slot->entries[entryPos]; + const auto& hash = this->hashStored(originalSlot.slot->entries[entryPos].key); + const auto fingerprint = HashIndexUtils::getFingerprintForHash(hash); + const auto newSlotId = hash & indexHeader.higherLevelHashMask; + if (newSlotId != indexHeader.nextSplitSlotId) { + if (newSlotPos >= SLOT_CAPACITY) { + auto newOvfSlotId = allocateAOSlot(); + newSlot.slot->header.nextOvfSlotId = newOvfSlotId; + [[maybe_unused]] bool hadNextSlot = nextChainedSlot(newSlot); + KU_ASSERT(hadNextSlot); + newSlotPos = 0; + } + newSlot.slot->entries[newSlotPos] = entry; + newSlot.slot->header.setEntryValid(newSlotPos, fingerprint); + originalSlot.slot->header.setEntryInvalid(entryPos); + newSlotPos++; + gaps = true; + } else if (gaps) { + // If we have created a gap previously, move the entry to the first gap to avoid + // leaving gaps + while (originalSlotForInsert.slot->header.isEntryValid(entryPosToInsert)) { + entryPosToInsert++; + if (entryPosToInsert >= SLOT_CAPACITY) { + entryPosToInsert = 0; + // There should always be another slot since we can't split more entries + // than there were to begin with + [[maybe_unused]] bool hadNextSlot = nextChainedSlot(originalSlotForInsert); + KU_ASSERT(hadNextSlot); + } + } + originalSlotForInsert.slot->entries[entryPosToInsert] = entry; + originalSlotForInsert.slot->header.setEntryValid(entryPosToInsert, fingerprint); + originalSlot.slot->header.setEntryInvalid(entryPos); + } + } + KU_ASSERT(originalSlot.slot->header.numEntries() == + std::countr_one(originalSlot.slot->header.validityMask)); + } while (nextChainedSlot(originalSlot)); + + reclaimOverflowSlots(originalSlotForInsert); + indexHeader.incrementNextSplitSlotId(); +} + +template +void InMemHashIndex::addFreeOverflowSlot(InMemSlotType& overflowSlot, SlotInfo slotInfo) { + // This function should only be called on slots that can be directly inserted into the free slot + // list + KU_ASSERT(slotInfo.slotId != SlotHeader::INVALID_OVERFLOW_SLOT_ID); + KU_ASSERT(overflowSlot.header.nextOvfSlotId == SlotHeader::INVALID_OVERFLOW_SLOT_ID); + KU_ASSERT(slotInfo.slotType == SlotType::OVF); + overflowSlot.header.nextOvfSlotId = indexHeader.firstFreeOverflowSlotId; + indexHeader.firstFreeOverflowSlotId = slotInfo.slotId; + numFreeSlots++; +} + +template +void InMemHashIndex::reclaimOverflowSlots(SlotIterator iter) { + // Reclaim empty overflow slots at the end of the chain. 
+ // This saves the cost of having to iterate over them, and reduces memory usage by letting them + // be used instead of allocating new slots + if (iter.slot->header.nextOvfSlotId != SlotHeader::INVALID_OVERFLOW_SLOT_ID) { + // Skip past the last non-empty entry + InMemSlotType* lastNonEmptySlot = iter.slot; + while (iter.slot->header.numEntries() > 0 || iter.slotInfo.slotType == SlotType::PRIMARY) { + lastNonEmptySlot = iter.slot; + if (!nextChainedSlot(iter)) { + iter.slotInfo = HashIndexUtils::INVALID_OVF_INFO; + break; + } + } + lastNonEmptySlot->header.nextOvfSlotId = SlotHeader::INVALID_OVERFLOW_SLOT_ID; + while (iter.slotInfo != HashIndexUtils::INVALID_OVF_INFO) { + // Remove empty overflow slots from slot chain + KU_ASSERT(iter.slot->header.numEntries() == 0); + auto slotInfo = iter.slotInfo; + auto slot = clearNextOverflowAndAdvanceIter(iter); + if (slotInfo.slotType == SlotType::OVF) { + // Insert empty slot into free slot chain + addFreeOverflowSlot(*slot, slotInfo); + } + } + } +} + +template +InMemHashIndex::InMemSlotType* InMemHashIndex::clearNextOverflowAndAdvanceIter( + SlotIterator& iter) { + auto originalSlot = iter.slot; + auto nextOverflowSlot = iter.slot->header.nextOvfSlotId; + iter.slot->header.nextOvfSlotId = SlotHeader::INVALID_OVERFLOW_SLOT_ID; + iter.slotInfo = SlotInfo{nextOverflowSlot, SlotType::OVF}; + if (nextOverflowSlot != SlotHeader::INVALID_OVERFLOW_SLOT_ID) { + iter.slot = getSlot(iter.slotInfo); + } + return originalSlot; +} + +template +uint32_t InMemHashIndex::allocatePSlots(uint32_t numSlotsToAllocate) { + auto oldNumSlots = pSlots->size(); + auto newNumSlots = oldNumSlots + numSlotsToAllocate; + pSlots->resize(newNumSlots); + return oldNumSlots; +} + +template +uint32_t InMemHashIndex::allocateAOSlot() { + if (indexHeader.firstFreeOverflowSlotId == SlotHeader::INVALID_OVERFLOW_SLOT_ID) { + auto oldNumSlots = oSlots->size(); + auto newNumSlots = oldNumSlots + 1; + oSlots->resize(newNumSlots); + return oldNumSlots; + } else { + auto freeOSlotId = indexHeader.firstFreeOverflowSlotId; + auto& slot = (*oSlots)[freeOSlotId]; + // Remove slot from the free slot chain + indexHeader.firstFreeOverflowSlotId = slot.header.nextOvfSlotId; + KU_ASSERT(slot.header.numEntries() == 0); + slot.header.nextOvfSlotId = SlotHeader::INVALID_OVERFLOW_SLOT_ID; + KU_ASSERT(numFreeSlots > 0); + numFreeSlots--; + return freeOSlotId; + } +} + +template +InMemHashIndex::InMemSlotType* InMemHashIndex::getSlot(const SlotInfo& slotInfo) const { + if (slotInfo.slotType == SlotType::PRIMARY) { + return &pSlots->operator[](slotInfo.slotId); + } else { + return &oSlots->operator[](slotInfo.slotId); + } +} + +template +common::hash_t InMemHashIndex::hashStored(const InMemHashIndex::OwnedType& key) const { + return HashIndexUtils::hash(key); +} + +template class InMemHashIndex; +template class InMemHashIndex; +template class InMemHashIndex; +template class InMemHashIndex; +template class InMemHashIndex; +template class InMemHashIndex; +template class InMemHashIndex; +template class InMemHashIndex; +template class InMemHashIndex; +template class InMemHashIndex; +template class InMemHashIndex; +template class InMemHashIndex; +template class InMemHashIndex; + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/index/index.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/index/index.cpp new file mode 100644 index 0000000000..e242b9b3ae --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/index/index.cpp @@ -0,0 +1,106 @@ 
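Aside: the slot-split logic in InMemHashIndex::splitSlot above follows a classic linear-hashing scheme, where a hash is first reduced with the current-level mask and re-reduced with the next-level mask only for slots that have already been split. Below is a minimal standalone sketch of that slot-selection rule; the names (LinearHashState, slotIdForHash) are simplified stand-ins and only loosely mirror the HashIndexHeader fields (currentLevel, nextSplitSlotId, higherLevelHashMask) used by the code in this patch.

#include <cstdint>

// Simplified linear-hashing directory: 2^level base slots, of which
// slots in [0, nextSplitSlotId) have already been split.
struct LinearHashState {
    uint64_t level = 1;           // number of hash bits currently in use
    uint64_t nextSplitSlotId = 0; // next primary slot scheduled for a split
};

// Pick the primary slot for a hash value. Slots below nextSplitSlotId have
// already been split, so they are addressed with one extra hash bit.
inline uint64_t slotIdForHash(const LinearHashState& s, uint64_t hash) {
    const uint64_t lowMask = (1ull << s.level) - 1;
    const uint64_t highMask = (1ull << (s.level + 1)) - 1;
    uint64_t slotId = hash & lowMask;
    if (slotId < s.nextSplitSlotId) {
        slotId = hash & highMask; // already split: use the finer mask
    }
    return slotId;
}

// Advance the split pointer after one slot has been rehashed; once every
// slot of the current level has been split, move up one level.
inline void incrementNextSplitSlotId(LinearHashState& s) {
    if (++s.nextSplitSlotId >= (1ull << s.level)) {
        s.nextSplitSlotId = 0;
        ++s.level;
    }
}

In splitSlot above, an entry whose hash under the higher-level mask no longer equals nextSplitSlotId is moved to the freshly appended slot (the upper-half image of the slot being split), and entries that stay are compacted forward so the slot chain remains gapless.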
+#include "storage/index/index.h" + +#include "common/exception/runtime.h" +#include "common/serializer/deserializer.h" +#include "common/serializer/serializer.h" +#include "main/client_context.h" +#include "storage/storage_manager.h" + +namespace lbug { +namespace storage { + +IndexStorageInfo::~IndexStorageInfo() = default; + +Index::InsertState::~InsertState() = default; + +Index::UpdateState::~UpdateState() = default; + +Index::DeleteState::~DeleteState() = default; + +Index::~Index() = default; + +bool Index::isBuiltOnColumn(common::column_id_t columnID) const { + auto it = std::find(indexInfo.columnIDs.begin(), indexInfo.columnIDs.end(), columnID); + return it != indexInfo.columnIDs.end(); +} + +void IndexInfo::serialize(common::Serializer& ser) const { + ser.write(name); + ser.write(indexType); + ser.write(tableID); + ser.serializeVector(columnIDs); + ser.serializeVector(keyDataTypes); + ser.write(isPrimary); + ser.write(isBuiltin); +} + +IndexInfo IndexInfo::deserialize(common::Deserializer& deSer) { + std::string name; + std::string indexType; + common::table_id_t tableID = common::INVALID_TABLE_ID; + std::vector columnIDs; + std::vector keyDataTypes; + bool isPrimary = false; + bool isBuiltin = false; + deSer.deserializeValue(name); + deSer.deserializeValue(indexType); + deSer.deserializeValue(tableID); + deSer.deserializeVector(columnIDs); + deSer.deserializeVector(keyDataTypes); + deSer.deserializeValue(isPrimary); + deSer.deserializeValue(isBuiltin); + return IndexInfo{std::move(name), std::move(indexType), tableID, std::move(columnIDs), + std::move(keyDataTypes), isPrimary, isBuiltin}; +} + +std::shared_ptr IndexStorageInfo::serialize() const { + return std::make_shared(0 /*maximumSize*/); +} + +void Index::serialize(common::Serializer& ser) const { + indexInfo.serialize(ser); + auto bufferedWriter = storageInfo->serialize(); + ser.write(bufferedWriter->getSize()); + ser.write(bufferedWriter->getData().data.get(), bufferedWriter->getSize()); +} + +IndexHolder::IndexHolder(std::unique_ptr loadedIndex) + : indexInfo{loadedIndex->getIndexInfo()}, storageInfoBuffer{nullptr}, storageInfoBufferSize{0}, + loaded{true}, index{std::move(loadedIndex)} {} + +IndexHolder::IndexHolder(IndexInfo indexInfo, std::unique_ptr storageInfoBuffer, + uint32_t storageInfoBufferSize) + : indexInfo{std::move(indexInfo)}, storageInfoBuffer{std::move(storageInfoBuffer)}, + storageInfoBufferSize{storageInfoBufferSize}, loaded{false}, index{nullptr} {} + +void IndexHolder::serialize(common::Serializer& ser) const { + if (loaded) { + KU_ASSERT(index); + index->serialize(ser); + } else { + indexInfo.serialize(ser); + ser.write(storageInfoBufferSize); + if (storageInfoBufferSize > 0) { + KU_ASSERT(storageInfoBuffer); + ser.write(storageInfoBuffer.get(), storageInfoBufferSize); + } + } +} + +void IndexHolder::load(main::ClientContext* context, StorageManager* storageManager) { + if (loaded) { + return; + } + KU_ASSERT(!index); + KU_ASSERT(storageInfoBuffer); + auto indexTypeOptional = StorageManager::Get(*context)->getIndexType(indexInfo.indexType); + if (!indexTypeOptional.has_value()) { + throw common::RuntimeException("No index type with name: " + indexInfo.indexType); + } + index = indexTypeOptional.value().get().loadFunc(context, storageManager, indexInfo, + std::span(storageInfoBuffer.get(), storageInfoBufferSize)); + loaded = true; +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/local_storage/CMakeLists.txt 
b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/local_storage/CMakeLists.txt new file mode 100644 index 0000000000..a30accd400 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/local_storage/CMakeLists.txt @@ -0,0 +1,9 @@ +add_library(lbug_storage_local_storage + OBJECT + local_node_table.cpp + local_rel_table.cpp + local_storage.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/local_storage/local_node_table.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/local_storage/local_node_table.cpp new file mode 100644 index 0000000000..5d71d5526d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/local_storage/local_node_table.cpp @@ -0,0 +1,129 @@ +#include "storage/local_storage/local_node_table.h" + +#include "catalog/catalog_entry/node_table_catalog_entry.h" +#include "common/cast.h" +#include "common/exception/message.h" +#include "common/types/types.h" +#include "common/types/value/value.h" +#include "storage/index/hash_index.h" +#include "storage/storage_utils.h" +#include "storage/table/node_table.h" + +using namespace lbug::common; +using namespace lbug::transaction; + +namespace lbug { +namespace storage { + +std::vector LocalNodeTable::getNodeTableColumnTypes( + const catalog::TableCatalogEntry& table) { + std::vector types; + for (auto& property : table.getProperties()) { + types.push_back(property.getType().copy()); + } + return types; +} + +LocalNodeTable::LocalNodeTable(const catalog::TableCatalogEntry* tableEntry, Table& table, + MemoryManager& mm) + : LocalTable{table}, overflowFileHandle(nullptr), + nodeGroups{mm, getNodeTableColumnTypes(*tableEntry), false /*enableCompression*/} { + initLocalHashIndex(mm); + startOffset = table.getNumTotalRows(nullptr /* transaction */); +} + +void LocalNodeTable::initLocalHashIndex(MemoryManager& mm) { + auto& nodeTable = ku_dynamic_cast(table); + overflowFile = std::make_unique(mm); + overflowFileHandle = overflowFile->addHandle(); + hashIndex = std::make_unique(mm, + nodeTable.getColumn(nodeTable.getPKColumnID()).getDataType().getPhysicalType(), + overflowFileHandle); +} + +bool LocalNodeTable::isVisible(const Transaction* transaction, offset_t offset) const { + auto [nodeGroupIdx, offsetInGroup] = + StorageUtils::getNodeGroupIdxAndOffsetInChunk(offset - startOffset); + auto* nodeGroup = nodeGroups.getNodeGroup(nodeGroupIdx); + if (nodeGroup->isDeleted(transaction, offsetInGroup)) { + return false; + } + return nodeGroup->isInserted(transaction, offsetInGroup); +} + +offset_t LocalNodeTable::validateUniquenessConstraint(const Transaction* transaction, + const ValueVector& pkVector) const { + KU_ASSERT(pkVector.state->getSelVector().getSelSize() == 1); + return hashIndex->lookup(pkVector, + [&](offset_t offset_) { return isVisible(transaction, offset_); }); +} + +bool LocalNodeTable::insert(Transaction* transaction, TableInsertState& insertState) { + auto& nodeInsertState = insertState.constCast(); + const auto nodeOffset = startOffset + nodeGroups.getNumTotalRows(); + KU_ASSERT(nodeInsertState.pkVector.state->getSelVector().getSelSize() == 1); + if (!hashIndex->insert(nodeInsertState.pkVector, nodeOffset, + [&](offset_t offset) { return isVisible(transaction, offset); })) { + const auto val = + nodeInsertState.pkVector.getAsValue(nodeInsertState.pkVector.state->getSelVector()[0]); + throw RuntimeException(ExceptionMessage::duplicatePKException(val->toString())); + } + const auto nodeIDPos = + 
nodeInsertState.nodeIDVector.state->getSelVector().getSelectedPositions()[0]; + nodeInsertState.nodeIDVector.setValue(nodeIDPos, internalID_t{nodeOffset, table.getTableID()}); + nodeGroups.append(&DUMMY_TRANSACTION, insertState.propertyVectors); + return true; +} + +bool LocalNodeTable::update(Transaction* transaction, TableUpdateState& updateState) { + KU_ASSERT(transaction->isDummy()); + const auto& nodeUpdateState = updateState.cast(); + KU_ASSERT(nodeUpdateState.nodeIDVector.state->getSelVector().getSelSize() == 1); + const auto pos = nodeUpdateState.nodeIDVector.state->getSelVector()[0]; + const auto offset = nodeUpdateState.nodeIDVector.readNodeOffset(pos); + KU_ASSERT(nodeUpdateState.columnID != table.cast().getPKColumnID()); + KU_ASSERT(offset >= startOffset); + const auto [nodeGroupIdx, rowIdxInGroup] = + StorageUtils::getQuotientRemainder(offset - startOffset, StorageConfig::NODE_GROUP_SIZE); + const auto nodeGroup = nodeGroups.getNodeGroup(nodeGroupIdx); + nodeGroup->update(transaction, rowIdxInGroup, nodeUpdateState.columnID, + nodeUpdateState.propertyVector); + return true; +} + +bool LocalNodeTable::delete_(Transaction* transaction, TableDeleteState& deleteState) { + KU_ASSERT(transaction->isDummy()); + const auto& nodeDeleteState = deleteState.cast(); + KU_ASSERT(nodeDeleteState.nodeIDVector.state->getSelVector().getSelSize() == 1); + const auto pos = nodeDeleteState.nodeIDVector.state->getSelVector()[0]; + const auto offset = nodeDeleteState.nodeIDVector.readNodeOffset(pos); + KU_ASSERT(offset >= startOffset); + hashIndex->delete_(nodeDeleteState.pkVector); + const auto [nodeGroupIdx, rowIdxInGroup] = + StorageUtils::getQuotientRemainder(offset - startOffset, StorageConfig::NODE_GROUP_SIZE); + const auto nodeGroup = nodeGroups.getNodeGroup(nodeGroupIdx); + return nodeGroup->delete_(transaction, rowIdxInGroup); +} + +bool LocalNodeTable::addColumn(TableAddColumnState& addColumnState) { + nodeGroups.addColumn(addColumnState); + return true; +} + +void LocalNodeTable::clear(MemoryManager& mm) { + auto& nodeTable = ku_dynamic_cast(table); + hashIndex = std::make_unique(mm, + nodeTable.getColumn(nodeTable.getPKColumnID()).getDataType().getPhysicalType(), + overflowFileHandle); + nodeGroups.clear(); +} + +bool LocalNodeTable::lookupPK(const Transaction* transaction, const ValueVector* keyVector, + sel_t pos, offset_t& result) const { + result = hashIndex->lookup(*keyVector, pos, + [&](offset_t offset) { return isVisible(transaction, offset); }); + return result != INVALID_OFFSET; +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/local_storage/local_rel_table.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/local_storage/local_rel_table.cpp new file mode 100644 index 0000000000..9952f2618e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/local_storage/local_rel_table.cpp @@ -0,0 +1,298 @@ +#include "storage/local_storage/local_rel_table.h" + +#include +#include + +#include "common/enums/rel_direction.h" +#include "storage/table/rel_table.h" +#include "transaction/transaction.h" + +using namespace lbug::common; +using namespace lbug::transaction; + +namespace lbug { +namespace storage { + +static std::vector getTypesForLocalRelTable(const catalog::TableCatalogEntry& table) { + std::vector types; + types.reserve(table.getNumProperties() + 2); + // Pre-append src and dst node ID columns. 
+ types.push_back(LogicalType::INTERNAL_ID()); + types.push_back(LogicalType::INTERNAL_ID()); + for (auto& property : table.getProperties()) { + types.push_back(property.getType().copy()); + } + return types; +} + +LocalRelTable::LocalRelTable(const catalog::TableCatalogEntry* tableEntry, const Table& table, + MemoryManager& mm) + : LocalTable{table} { + localNodeGroup = std::make_unique(mm, 0, false, + getTypesForLocalRelTable(*tableEntry), INVALID_ROW_IDX); + const auto& relTable = table.cast(); + for (auto relDirection : relTable.getStorageDirections()) { + directedIndices.emplace_back(relDirection); + } +} + +bool LocalRelTable::insert(Transaction*, TableInsertState& state) { + const auto& insertState = state.cast(); + + std::vector rowIndicesToInsertTo; + for (auto& directedIndex : directedIndices) { + const auto& nodeIDVector = insertState.getBoundNodeIDVector(directedIndex.direction); + KU_ASSERT(nodeIDVector.state->getSelVector().getSelSize() == 1); + auto nodePos = nodeIDVector.state->getSelVector()[0]; + if (nodeIDVector.isNull(nodePos)) { + return false; + } + auto nodeOffset = nodeIDVector.readNodeOffset(nodePos); + rowIndicesToInsertTo.push_back(&directedIndex.index[nodeOffset]); + } + + const auto numRowsInLocalTable = localNodeGroup->getNumRows(); + const auto relOffset = StorageConstants::MAX_NUM_ROWS_IN_TABLE + numRowsInLocalTable; + const auto relIDVector = insertState.propertyVectors[0]; + KU_ASSERT(relIDVector->dataType.getPhysicalType() == PhysicalTypeID::INTERNAL_ID); + const auto relIDPos = relIDVector->state->getSelVector()[0]; + relIDVector->setValue(relIDPos, internalID_t{relOffset, table.getTableID()}); + relIDVector->setNull(relIDPos, false); + std::vector insertVectors; + insertVectors.push_back(&insertState.srcNodeIDVector); + insertVectors.push_back(&insertState.dstNodeIDVector); + for (auto i = 0u; i < insertState.propertyVectors.size(); i++) { + insertVectors.push_back(insertState.propertyVectors[i]); + } + const auto numRowsToAppend = insertState.srcNodeIDVector.state->getSelVector().getSelSize(); + localNodeGroup->append(&DUMMY_TRANSACTION, insertVectors, 0, numRowsToAppend); + + for (auto* rowIndexToInsertTo : rowIndicesToInsertTo) { + rowIndexToInsertTo->push_back(numRowsInLocalTable); + } + + return true; +} + +bool LocalRelTable::update(Transaction* transaction, TableUpdateState& state) { + KU_ASSERT(transaction->isDummy()); + const auto& updateState = state.cast(); + + std::vector rowIndicesToUpdate; + for (auto& directedIndex : directedIndices) { + const auto& nodeIDVector = updateState.getBoundNodeIDVector(directedIndex.direction); + KU_ASSERT(nodeIDVector.state->getSelVector().getSelSize() == 1); + auto nodePos = nodeIDVector.state->getSelVector()[0]; + if (nodeIDVector.isNull(nodePos)) { + return false; + } + auto nodeOffset = nodeIDVector.readNodeOffset(nodePos); + rowIndicesToUpdate.push_back(&directedIndex.index[nodeOffset]); + } + + const auto relIDPos = updateState.relIDVector.state->getSelVector()[0]; + if (updateState.relIDVector.isNull(relIDPos)) { + return false; + } + const auto relOffset = updateState.relIDVector.readNodeOffset(relIDPos); + const auto matchedRow = findMatchingRow(transaction, rowIndicesToUpdate, relOffset); + if (matchedRow == INVALID_ROW_IDX) { + return false; + } + KU_ASSERT(updateState.columnID != NBR_ID_COLUMN_ID); + localNodeGroup->update(transaction, matchedRow, + rewriteLocalColumnID(RelDataDirection::FWD /* This is a dummy direction */, + updateState.columnID), + updateState.propertyVector); + return true; +} + 
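Aside: LocalRelTable keeps one adjacency index per storage direction, mapping a bound node offset to the local rows that touch it; insert above records each appended row under both endpoints, and findMatchingRow further down sorts and intersects those per-endpoint row lists before scanning the rel ID column. A small self-contained sketch of that bookkeeping, assuming simplified type aliases and hypothetical helper names (recordLocalRel, candidateRows) rather than the actual class API:

#include <algorithm>
#include <cstdint>
#include <iterator>
#include <unordered_map>
#include <vector>

using offset_t = uint64_t;
using row_idx_t = uint64_t;
using row_idx_vec_t = std::vector<row_idx_t>;

// One index per direction: bound node offset -> rows of the uncommitted
// node group that are attached to that node.
using directed_index_t = std::unordered_map<offset_t, row_idx_vec_t>;

// Record a newly appended local row under both of its endpoints, mirroring
// what insert does with rowIndicesToInsertTo.
inline void recordLocalRel(directed_index_t& fwd, directed_index_t& bwd,
    offset_t srcOffset, offset_t dstOffset, row_idx_t localRow) {
    fwd[srcOffset].push_back(localRow);
    bwd[dstOffset].push_back(localRow);
}

// Candidate rows for a (src, dst) pair are the rows present in both
// endpoint lists; sorting first lets std::set_intersection do the work.
inline row_idx_vec_t candidateRows(row_idx_vec_t lhs, row_idx_vec_t rhs) {
    std::sort(lhs.begin(), lhs.end());
    std::sort(rhs.begin(), rhs.end());
    row_idx_vec_t out;
    std::set_intersection(lhs.begin(), lhs.end(), rhs.begin(), rhs.end(),
        std::back_inserter(out));
    return out;
}

The surviving candidates still have to be confirmed against the stored rel ID, which is why findMatchingRow scans the rel ID column over the intersected rows in vector-sized batches.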
+bool LocalRelTable::delete_(Transaction* transaction, TableDeleteState& state) { + const auto& deleteState = state.cast(); + + std::vector rowIndicesToDeleteFrom; + auto& directedIndex = + directedIndices[RelDirectionUtils::relDirectionToKeyIdx(deleteState.detachDeleteDirection)]; + auto& reverseDirectedIndex = directedIndices[RelDirectionUtils::relDirectionToKeyIdx( + RelDirectionUtils::getOppositeDirection(deleteState.detachDeleteDirection))]; + std::vector> directedIndicesAndNodeIDVectors; + auto directedIndexPos = + RelDirectionUtils::relDirectionToKeyIdx(deleteState.detachDeleteDirection); + if (directedIndexPos < directedIndices.size()) { + directedIndicesAndNodeIDVectors.emplace_back(directedIndex, deleteState.srcNodeIDVector); + } + auto reverseDirectedIndexPos = RelDirectionUtils::relDirectionToKeyIdx( + RelDirectionUtils::getOppositeDirection(deleteState.detachDeleteDirection)); + if (reverseDirectedIndexPos < directedIndices.size()) { + directedIndicesAndNodeIDVectors.emplace_back(reverseDirectedIndex, + deleteState.dstNodeIDVector); + } + for (auto& [csrIndex, nodeIDVector] : directedIndicesAndNodeIDVectors) { + KU_ASSERT(nodeIDVector.state->getSelVector().getSelSize() == 1); + auto nodePos = nodeIDVector.state->getSelVector()[0]; + if (nodeIDVector.isNull(nodePos)) { + return false; + } + auto nodeOffset = nodeIDVector.readNodeOffset(nodePos); + KU_ASSERT(csrIndex.index.contains(nodeOffset)); + rowIndicesToDeleteFrom.push_back(&csrIndex.index[nodeOffset]); + } + + const auto relIDPos = deleteState.relIDVector.state->getSelVector()[0]; + if (deleteState.relIDVector.isNull(relIDPos)) { + return false; + } + const auto relOffset = deleteState.relIDVector.readNodeOffset(relIDPos); + const auto matchedRow = findMatchingRow(transaction, rowIndicesToDeleteFrom, relOffset); + if (matchedRow == INVALID_ROW_IDX) { + return false; + } + + for (auto* rowIndexToDeleteFrom : rowIndicesToDeleteFrom) { + std::erase(*rowIndexToDeleteFrom, matchedRow); + } + return true; +} + +bool LocalRelTable::addColumn(TableAddColumnState& addColumnState) { + localNodeGroup->addColumn(addColumnState, nullptr /* FileHandle */, + nullptr /* newColumnStats */); + return true; +} + +bool LocalRelTable::checkIfNodeHasRels(ValueVector* srcNodeIDVector, + RelDataDirection direction) const { + KU_ASSERT(srcNodeIDVector->state->isFlat()); + const auto nodeIDPos = srcNodeIDVector->state->getSelVector()[0]; + const auto nodeOffset = srcNodeIDVector->getValue(nodeIDPos).offset; + const auto& directedIndex = + directedIndices[RelDirectionUtils::relDirectionToKeyIdx(direction)].index; + return (directedIndex.contains(nodeOffset) && !directedIndex.at(nodeOffset).empty()); +} + +void LocalRelTable::initializeScan(TableScanState& state) { + auto& relScanState = state.cast(); + KU_ASSERT(relScanState.source == TableScanSource::UNCOMMITTED); + KU_ASSERT(relScanState.localTableScanState); + auto& localScanState = *relScanState.localTableScanState; + localScanState.rowIndices.clear(); + localScanState.nextRowToScan = 0; +} + +std::vector LocalRelTable::rewriteLocalColumnIDs(RelDataDirection direction, + const std::vector& columnIDs) { + std::vector localColumnIDs; + localColumnIDs.reserve(columnIDs.size()); + for (auto i = 0u; i < columnIDs.size(); i++) { + const auto columnID = columnIDs[i]; + localColumnIDs.push_back(rewriteLocalColumnID(direction, columnID)); + } + return localColumnIDs; +} + +column_id_t LocalRelTable::rewriteLocalColumnID(RelDataDirection direction, column_id_t columnID) { + return columnID == 
NBR_ID_COLUMN_ID ? direction == RelDataDirection::FWD ? + LOCAL_NBR_NODE_ID_COLUMN_ID : + LOCAL_BOUND_NODE_ID_COLUMN_ID : + columnID + 1; +} + +bool LocalRelTable::scan(const Transaction* transaction, TableScanState& state) const { + auto& relScanState = state.cast(); + KU_ASSERT(relScanState.localTableScanState); + auto& localScanState = *relScanState.localTableScanState; + while (true) { + if (relScanState.currBoundNodeIdx >= relScanState.cachedBoundNodeSelVector.getSelSize()) { + return false; + } + const auto boundNodePos = + relScanState.cachedBoundNodeSelVector[relScanState.currBoundNodeIdx]; + const auto boundNodeOffset = relScanState.nodeIDVector->readNodeOffset(boundNodePos); + auto& localCSRIndex = + directedIndices[RelDirectionUtils::relDirectionToKeyIdx(relScanState.direction)].index; + if (localScanState.rowIndices.empty() && localCSRIndex.contains(boundNodeOffset)) { + localScanState.rowIndices = localCSRIndex.at(boundNodeOffset); + localScanState.nextRowToScan = 0; + KU_ASSERT( + std::is_sorted(localScanState.rowIndices.begin(), localScanState.rowIndices.end())); + } + KU_ASSERT(localScanState.rowIndices.size() >= localScanState.nextRowToScan); + const auto numToScan = + std::min(localScanState.rowIndices.size() - localScanState.nextRowToScan, + DEFAULT_VECTOR_CAPACITY); + if (numToScan == 0) { + relScanState.currBoundNodeIdx++; + localScanState.nextRowToScan = 0; + localScanState.rowIndices.clear(); + continue; + } + for (auto i = 0u; i < numToScan; i++) { + localScanState.rowIdxVector->setValue(i, + localScanState.rowIndices[localScanState.nextRowToScan + i]); + } + localScanState.rowIdxVector->state->getSelVectorUnsafe().setSelSize(numToScan); + [[maybe_unused]] auto lookupRes = + localNodeGroup->lookupMultiple(transaction, localScanState); + localScanState.nextRowToScan += numToScan; + relScanState.setNodeIDVectorToFlat( + relScanState.cachedBoundNodeSelVector[relScanState.currBoundNodeIdx]); + return true; + } +} + +static std::unique_ptr setupLocalTableScanState(DataChunk& scanChunk, + std::span intersectRows) { + const std::vector columnIDs{LOCAL_REL_ID_COLUMN_ID}; + auto scanState = std::make_unique(nullptr, + std::vector{&scanChunk.getValueVectorMutable(0)}, scanChunk.state); + scanState->columnIDs = columnIDs; + scanState->nodeGroupScanState->chunkStates.resize(columnIDs.size()); + scanChunk.state->getSelVectorUnsafe().setSelSize(intersectRows.size()); + for (uint64_t i = 0; i < intersectRows.size(); i++) { + scanState->rowIdxVector->setValue(i, intersectRows[i]); + } + return scanState; +} + +row_idx_t LocalRelTable::findMatchingRow(const Transaction* transaction, + const std::vector& rowIndicesToCheck, offset_t relOffset) const { + for (auto* rowIndex : rowIndicesToCheck) { + std::sort(rowIndex->begin(), rowIndex->end()); + } + std::vector intersectRows = + std::accumulate(rowIndicesToCheck.begin(), rowIndicesToCheck.end(), *rowIndicesToCheck[0], + [](row_idx_vec_t curIntersection, row_idx_vec_t* rowIndex) -> row_idx_vec_t { + row_idx_vec_t ret; + std::set_intersection(curIntersection.begin(), curIntersection.end(), + rowIndex->begin(), rowIndex->end(), std::back_inserter(ret)); + return ret; + }); + // Loop over relID column chunks to find the relID. 
+ const auto numVectorsToScan = + ceilDiv(static_cast(intersectRows.size()), DEFAULT_VECTOR_CAPACITY); + for (uint64_t vectorIdx = 0; vectorIdx < numVectorsToScan; ++vectorIdx) { + DataChunk scanChunk(1); + scanChunk.insert(0, std::make_shared(LogicalType::INTERNAL_ID())); + + const uint64_t startRowToScan = vectorIdx * DEFAULT_VECTOR_CAPACITY; + const auto endRowToScan = std::min(startRowToScan + DEFAULT_VECTOR_CAPACITY, + static_cast(intersectRows.size())); + std::span currentRowsToCheck{intersectRows.begin() + startRowToScan, + intersectRows.begin() + endRowToScan}; + const auto scanState = setupLocalTableScanState(scanChunk, currentRowsToCheck); + + [[maybe_unused]] auto lookupRes = localNodeGroup->lookupMultiple(transaction, *scanState); + const auto scannedRelIDVector = scanState->outputVectors[0]; + KU_ASSERT( + scannedRelIDVector->state->getSelVector().getSelSize() == currentRowsToCheck.size()); + for (auto i = 0u; i < currentRowsToCheck.size(); i++) { + if (scannedRelIDVector->getValue(i).offset == relOffset) { + return currentRowsToCheck[i]; + } + } + } + return INVALID_ROW_IDX; +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/local_storage/local_storage.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/local_storage/local_storage.cpp new file mode 100644 index 0000000000..e808192117 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/local_storage/local_storage.cpp @@ -0,0 +1,95 @@ +#include "storage/local_storage/local_storage.h" + +#include "storage/local_storage/local_node_table.h" +#include "storage/local_storage/local_rel_table.h" +#include "storage/local_storage/local_table.h" +#include "storage/storage_manager.h" +#include "storage/table/rel_table.h" +#include "storage/table/table.h" + +using namespace lbug::common; +using namespace lbug::transaction; + +namespace lbug { +namespace storage { + +LocalTable* LocalStorage::getOrCreateLocalTable(Table& table) { + const auto tableID = table.getTableID(); + auto catalog = catalog::Catalog::Get(clientContext); + auto transaction = transaction::Transaction::Get(clientContext); + auto& mm = *MemoryManager::Get(clientContext); + if (!tables.contains(tableID)) { + switch (table.getTableType()) { + case TableType::NODE: { + auto tableEntry = catalog->getTableCatalogEntry(transaction, table.getTableID()); + tables[tableID] = std::make_unique(tableEntry, table, mm); + } break; + case TableType::REL: { + // We have to fetch the rel group entry from the catalog to based on the relGroupID. 
+ auto tableEntry = + catalog->getTableCatalogEntry(transaction, table.cast().getRelGroupID()); + tables[tableID] = std::make_unique(tableEntry, table, mm); + } break; + default: + KU_UNREACHABLE; + } + } + return tables.at(tableID).get(); +} + +LocalTable* LocalStorage::getLocalTable(table_id_t tableID) const { + if (tables.contains(tableID)) { + return tables.at(tableID).get(); + } + return nullptr; +} + +PageAllocator* LocalStorage::addOptimisticAllocator() { + auto* dataFH = StorageManager::Get(clientContext)->getDataFH(); + if (dataFH->isInMemoryMode()) { + return dataFH->getPageManager(); + } + UniqLock lck{mtx}; + optimisticAllocators.emplace_back( + std::make_unique(*dataFH->getPageManager())); + return optimisticAllocators.back().get(); +} + +void LocalStorage::commit() { + auto catalog = catalog::Catalog::Get(clientContext); + auto transaction = transaction::Transaction::Get(clientContext); + auto storageManager = StorageManager::Get(clientContext); + for (auto& [tableID, localTable] : tables) { + if (localTable->getTableType() == TableType::NODE) { + const auto tableEntry = catalog->getTableCatalogEntry(transaction, tableID); + const auto table = storageManager->getTable(tableID); + table->commit(&clientContext, tableEntry, localTable.get()); + } + } + for (auto& [tableID, localTable] : tables) { + if (localTable->getTableType() == TableType::REL) { + const auto table = storageManager->getTable(tableID); + const auto tableEntry = + catalog->getTableCatalogEntry(transaction, table->cast().getRelGroupID()); + table->commit(&clientContext, tableEntry, localTable.get()); + } + } + for (auto& optimisticAllocator : optimisticAllocators) { + optimisticAllocator->commit(); + } +} + +void LocalStorage::rollback() { + auto mm = MemoryManager::Get(clientContext); + for (auto& [_, localTable] : tables) { + localTable->clear(*mm); + } + for (auto& optimisticAllocator : optimisticAllocators) { + optimisticAllocator->rollback(); + } + auto* bufferManager = mm->getBufferManager(); + PageManager::Get(clientContext)->clearEvictedBMEntriesIfNeeded(bufferManager); +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/optimistic_allocator.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/optimistic_allocator.cpp new file mode 100644 index 0000000000..0060d8fb5a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/optimistic_allocator.cpp @@ -0,0 +1,31 @@ +#include "storage/optimistic_allocator.h" + +#include "storage/page_manager.h" + +namespace lbug::storage { +OptimisticAllocator::OptimisticAllocator(PageManager& pageManager) + : PageAllocator(pageManager.getDataFH()), pageManager(pageManager) {} + +PageRange OptimisticAllocator::allocatePageRange(common::page_idx_t numPages) { + auto pageRange = pageManager.allocatePageRange(numPages); + if (numPages > 0) { + optimisticallyAllocatedPages.push_back(pageRange); + } + return pageRange; +} + +void OptimisticAllocator::freePageRange(PageRange block) { + pageManager.freePageRange(block); +} + +void OptimisticAllocator::rollback() { + for (const auto& entry : optimisticallyAllocatedPages) { + pageManager.freeImmediatelyRewritablePageRange(pageManager.getDataFH(), entry); + } + optimisticallyAllocatedPages.clear(); +} + +void OptimisticAllocator::commit() { + optimisticallyAllocatedPages.clear(); +} +} // namespace lbug::storage diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/overflow_file.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/overflow_file.cpp new file mode 100644 index 
0000000000..a9ac2d51ab --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/overflow_file.cpp @@ -0,0 +1,282 @@ +#include "storage/overflow_file.h" + +#include + +#include "common/type_utils.h" +#include "common/types/types.h" +#include "storage/buffer_manager/memory_manager.h" +#include "storage/file_handle.h" +#include "storage/shadow_utils.h" +#include "storage/storage_utils.h" +#include "transaction/transaction.h" + +using namespace lbug::transaction; +using namespace lbug::common; + +namespace lbug { +namespace storage { + +std::string OverflowFileHandle::readString(TransactionType trxType, const ku_string_t& str) const { + if (ku_string_t::isShortString(str.len)) { + return str.getAsShortString(); + } + PageCursor cursor; + TypeUtils::decodeOverflowPtr(str.overflowPtr, cursor.pageIdx, cursor.elemPosInPage); + std::string retVal; + retVal.reserve(str.len); + int32_t remainingLength = str.len; + while (remainingLength > 0) { + auto numBytesToReadInPage = + std::min(static_cast(remainingLength), END_OF_PAGE - cursor.elemPosInPage); + auto startPosInSrc = retVal.size(); + read(trxType, cursor.pageIdx, [&](uint8_t* frame) { + // Replace rather than append, since optimistic read may call the function multiple + // times + retVal.replace(startPosInSrc, numBytesToReadInPage, + std::string_view(reinterpret_cast(frame) + cursor.elemPosInPage, + numBytesToReadInPage)); + cursor.pageIdx = *reinterpret_cast(frame + END_OF_PAGE); + }); + remainingLength -= numBytesToReadInPage; + // After the first page we always start reading from the beginning of the page. + cursor.elemPosInPage = 0; + } + return retVal; +} + +bool OverflowFileHandle::equals(TransactionType trxType, std::string_view keyToLookup, + const ku_string_t& keyInEntry) const { + PageCursor cursor; + TypeUtils::decodeOverflowPtr(keyInEntry.overflowPtr, cursor.pageIdx, cursor.elemPosInPage); + auto lengthRead = 0u; + while (lengthRead < keyInEntry.len) { + auto numBytesToCheckInPage = std::min(static_cast(keyInEntry.len) - lengthRead, + END_OF_PAGE - cursor.elemPosInPage); + bool equal = true; + read(trxType, cursor.pageIdx, [&](auto* frame) { + equal = memcmp(keyToLookup.data() + lengthRead, frame + cursor.elemPosInPage, + numBytesToCheckInPage) == 0; + // Update the next page index + cursor.pageIdx = *reinterpret_cast(frame + END_OF_PAGE); + }); + if (!equal) { + return false; + } + cursor.elemPosInPage = 0; + lengthRead += numBytesToCheckInPage; + } + return true; +} + +uint8_t* OverflowFileHandle::addANewPage(PageAllocator* pageAllocator) { + page_idx_t newPageIdx = overflowFile.getNewPageIdx(pageAllocator); + if (pageWriteCache.size() > 0) { + memcpy(pageWriteCache[nextPosToWriteTo.pageIdx].buffer->getData() + END_OF_PAGE, + &newPageIdx, sizeof(page_idx_t)); + } + if (startPageIdx == INVALID_PAGE_IDX) { + startPageIdx = newPageIdx; + } + pageWriteCache.emplace(newPageIdx, + CachedPage{.buffer = overflowFile.memoryManager.allocateBuffer(true /*initializeToZero*/, + LBUG_PAGE_SIZE), + .newPage = true}); + nextPosToWriteTo.elemPosInPage = 0; + nextPosToWriteTo.pageIdx = newPageIdx; + return pageWriteCache[newPageIdx].buffer->getData(); +} + +void OverflowFileHandle::setStringOverflow(PageAllocator* pageAllocator, const char* srcRawString, + uint64_t len, ku_string_t& diskDstString) { + if (len <= ku_string_t::SHORT_STR_LENGTH) { + return; + } + overflowFile.headerChanged = true; + uint8_t* pageToWrite = nullptr; + if (nextPosToWriteTo.pageIdx == INVALID_PAGE_IDX) { + pageToWrite = addANewPage(pageAllocator); + } else { + auto 
cached = pageWriteCache.find(nextPosToWriteTo.pageIdx); + if (cached != pageWriteCache.end()) { + pageToWrite = cached->second.buffer->getData(); + } else { + overflowFile.readFromDisk(TransactionType::CHECKPOINT, nextPosToWriteTo.pageIdx, + [&](auto* frame) { + auto page = overflowFile.memoryManager.allocateBuffer( + false /*initializeToZero*/, LBUG_PAGE_SIZE); + memcpy(page->getData(), frame, LBUG_PAGE_SIZE); + pageToWrite = page->getData(); + pageWriteCache.emplace(nextPosToWriteTo.pageIdx, + CachedPage{.buffer = std::move(page), .newPage = false}); + }); + } + } + int32_t remainingLength = len; + TypeUtils::encodeOverflowPtr(diskDstString.overflowPtr, nextPosToWriteTo.pageIdx, + nextPosToWriteTo.elemPosInPage); + while (remainingLength > 0) { + auto bytesWritten = len - remainingLength; + auto numBytesToWriteInPage = std::min(static_cast(remainingLength), + END_OF_PAGE - nextPosToWriteTo.elemPosInPage); + memcpy(pageToWrite + nextPosToWriteTo.elemPosInPage, srcRawString + bytesWritten, + numBytesToWriteInPage); + remainingLength -= numBytesToWriteInPage; + nextPosToWriteTo.elemPosInPage += numBytesToWriteInPage; + if (nextPosToWriteTo.elemPosInPage >= END_OF_PAGE) { + pageToWrite = addANewPage(pageAllocator); + } + } +} + +ku_string_t OverflowFileHandle::writeString(PageAllocator* pageAllocator, + std::string_view rawString) { + ku_string_t result; + result.len = rawString.length(); + auto shortStrLen = ku_string_t::SHORT_STR_LENGTH; + auto inlineLen = std::min(shortStrLen, static_cast(result.len)); + memcpy(result.prefix, rawString.data(), inlineLen); + setStringOverflow(pageAllocator, rawString.data(), rawString.length(), result); + return result; +} + +void OverflowFileHandle::checkpoint() { + for (auto& [pageIndex, page] : pageWriteCache) { + overflowFile.writePageToDisk(pageIndex, page.buffer->getData(), page.newPage); + } +} + +void OverflowFileHandle::reclaimStorage(PageAllocator& pageAllocator) { + if (startPageIdx == INVALID_PAGE_IDX) { + return; + } + + auto pageIdx = startPageIdx; + while (true) { + if (pageIdx == 0 || pageIdx == INVALID_PAGE_IDX) [[unlikely]] { + throw RuntimeException( + "The overflow file has been corrupted, this should never happen."); + } + pageAllocator.freePage(pageIdx); + + if (pageIdx == nextPosToWriteTo.pageIdx) { + break; + } + + // reclaimStorage() is only called after the hash index is checkpointed + // so the page write cache should always be cleared + KU_ASSERT(!pageWriteCache.contains(pageIdx)); + overflowFile.readFromDisk(TransactionType::CHECKPOINT, pageIdx, [&pageIdx](auto* frame) { + pageIdx = *reinterpret_cast(frame + END_OF_PAGE); + }); + } +} + +void OverflowFileHandle::read(TransactionType trxType, page_idx_t pageIdx, + const std::function& func) const { + auto cachedPage = pageWriteCache.find(pageIdx); + if (cachedPage != pageWriteCache.end()) { + return func(cachedPage->second.buffer->getData()); + } + overflowFile.readFromDisk(trxType, pageIdx, func); +} + +OverflowFile::OverflowFile(FileHandle* fileHandle, MemoryManager& memoryManager, + ShadowFile* shadowFile, page_idx_t headerPageIdx) + : fileHandle{fileHandle}, shadowFile{shadowFile}, memoryManager{memoryManager}, + headerChanged{false}, headerPageIdx{headerPageIdx} { + KU_ASSERT(shadowFile); + if (headerPageIdx != INVALID_PAGE_IDX) { + readFromDisk(TransactionType::READ_ONLY, headerPageIdx, + [&](auto* frame) { memcpy(&header, frame, sizeof(header)); }); + } else { + header = StringOverflowFileHeader(); + } +} + +OverflowFile::OverflowFile(MemoryManager& memoryManager) + : 
fileHandle{nullptr}, shadowFile{nullptr}, memoryManager{memoryManager}, headerChanged{false}, + headerPageIdx{INVALID_PAGE_IDX} { + // Reserve a page for the header + this->headerPageIdx = getNewPageIdx(nullptr); + header = StringOverflowFileHeader(); +} + +common::page_idx_t OverflowFile::getNewPageIdx(PageAllocator* pageAllocator) { + // If this isn't the first call reserving the page header, then the header flag must be set + // prior to this + if (pageAllocator) { + return pageAllocator->allocatePage(); + } else { + return pageCounter.fetch_add(1); + } +} + +void OverflowFile::readFromDisk(TransactionType trxType, page_idx_t pageIdx, + const std::function& func) const { + KU_ASSERT(shadowFile); + auto [fileHandleToPin, pageIdxToPin] = ShadowUtils::getFileHandleAndPhysicalPageIdxToPin( + *fileHandle, pageIdx, *shadowFile, trxType); + fileHandleToPin->optimisticReadPage(pageIdxToPin, func); +} + +void OverflowFile::writePageToDisk(page_idx_t pageIdx, uint8_t* data, bool newPage) const { + if (newPage) { + KU_ASSERT(fileHandle); + KU_ASSERT(!fileHandle->isInMemoryMode()); + fileHandle->writePageToFile(data, pageIdx); + } else { + KU_ASSERT(shadowFile); + ShadowUtils::updatePage(*fileHandle, pageIdx, true /* overwriting entire page*/, + *shadowFile, [&](auto* frame) { memcpy(frame, data, LBUG_PAGE_SIZE); }); + } +} + +void OverflowFile::checkpoint(PageAllocator& pageAllocator) { + KU_ASSERT(fileHandle); + if (headerPageIdx == INVALID_PAGE_IDX) { + // Reserve a page for the header + this->headerPageIdx = getNewPageIdx(&pageAllocator); + headerChanged = true; + } + // TODO(bmwinger): Ideally this could be done separately and in parallel by each HashIndex + // However fileHandle->addNewPages needs to be called beforehand, + // but after each HashIndex::prepareCommit has written to the in-memory pages + for (auto& handle : handles) { + handle->checkpoint(); + } + if (headerChanged) { + uint8_t page[LBUG_PAGE_SIZE]; + memcpy(page, &header, sizeof(header)); + // Zero free space at the end of the header page + std::fill(page + sizeof(header), page + LBUG_PAGE_SIZE, 0); + writePageToDisk(headerPageIdx + HEADER_PAGE_IDX, page, false /*newPage*/); + } +} + +void OverflowFile::checkpointInMemory() { + headerChanged = false; +} + +void OverflowFile::rollbackInMemory() { + KU_ASSERT(getFileHandle()->getNumPages() <= INVALID_PAGE_IDX); + if (getFileHandle()->getNumPages() > headerPageIdx) { + readFromDisk(TransactionType::READ_ONLY, headerPageIdx, + [&](auto* frame) { memcpy(&header, frame, sizeof(header)); }); + } + for (auto i = 0u; i < handles.size(); i++) { + auto& handle = handles[i]; + handle->rollbackInMemory(header.entries[i].cursor); + } +} + +void OverflowFile::reclaimStorage(PageAllocator& pageAllocator) const { + for (auto& handle : handles) { + handle->reclaimStorage(pageAllocator); + } + if (headerPageIdx != INVALID_PAGE_IDX) { + pageAllocator.freePage(headerPageIdx); + } +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/page_manager.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/page_manager.cpp new file mode 100644 index 0000000000..89f764a629 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/page_manager.cpp @@ -0,0 +1,66 @@ +#include "storage/page_manager.h" + +#include "common/uniq_lock.h" +#include "storage/file_handle.h" +#include "storage/storage_manager.h" + +namespace lbug::storage { +static constexpr bool ENABLE_FSM = true; + +PageRange PageManager::allocatePageRange(common::page_idx_t numPages) { + if 
constexpr (ENABLE_FSM) { + common::UniqLock lck{mtx}; + auto allocatedFreeChunk = freeSpaceManager->popFreePages(numPages); + if (allocatedFreeChunk.has_value()) { + ++version; + return {*allocatedFreeChunk}; + } + } + auto startPageIdx = fileHandle->addNewPages(numPages); + KU_ASSERT(fileHandle->getNumPages() >= startPageIdx + numPages); + return PageRange(startPageIdx, numPages); +} + +void PageManager::freePageRange(PageRange entry) { + if constexpr (ENABLE_FSM) { + common::UniqLock lck{mtx}; + // Freed pages cannot be immediately reused to ensure checkpoint recovery works + // Instead they are reusable after the end of the next checkpoint + freeSpaceManager->addUncheckpointedFreePages(entry); + ++version; + } +} + +common::page_idx_t PageManager::estimatePagesNeededForSerialize() { + return freeSpaceManager->getMaxNumPagesForSerialization(); +} + +void PageManager::freeImmediatelyRewritablePageRange(FileHandle* fileHandle, PageRange entry) { + if constexpr (ENABLE_FSM) { + common::UniqLock lck{mtx}; + freeSpaceManager->evictAndAddFreePages(fileHandle, entry); + ++version; + } +} + +void PageManager::serialize(common::Serializer& serializer) { + freeSpaceManager->serialize(serializer); +} + +void PageManager::deserialize(common::Deserializer& deSer) { + freeSpaceManager->deserialize(deSer); +} + +void PageManager::finalizeCheckpoint() { + freeSpaceManager->finalizeCheckpoint(fileHandle); +} + +void PageManager::clearEvictedBMEntriesIfNeeded(BufferManager* bufferManager) { + freeSpaceManager->clearEvictedBufferManagerEntriesIfNeeded(bufferManager); +} + +PageManager* PageManager::Get(const main::ClientContext& context) { + return StorageManager::Get(context)->getDataFH()->getPageManager(); +} + +} // namespace lbug::storage diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/predicate/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/predicate/CMakeLists.txt new file mode 100644 index 0000000000..a23fedfd36 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/predicate/CMakeLists.txt @@ -0,0 +1,9 @@ +add_library(lbug_storage_predicate + OBJECT + null_predicate.cpp + column_predicate.cpp + constant_predicate.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/predicate/column_predicate.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/predicate/column_predicate.cpp new file mode 100644 index 0000000000..212b91e9cd --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/predicate/column_predicate.cpp @@ -0,0 +1,123 @@ +#include "storage/predicate/column_predicate.h" + +#include "binder/expression/literal_expression.h" +#include "binder/expression/scalar_function_expression.h" +#include "storage/predicate/constant_predicate.h" +#include "storage/predicate/null_predicate.h" + +using namespace lbug::binder; +using namespace lbug::common; + +namespace lbug { +namespace storage { + +ZoneMapCheckResult ColumnPredicateSet::checkZoneMap(const MergedColumnChunkStats& stats) const { + for (auto& predicate : predicates) { + if (predicate->checkZoneMap(stats) == ZoneMapCheckResult::SKIP_SCAN) { + return ZoneMapCheckResult::SKIP_SCAN; + } + } + return ZoneMapCheckResult::ALWAYS_SCAN; +} + +std::string ColumnPredicateSet::toString() const { + if (predicates.empty()) { + return {}; + } + auto result = predicates[0]->toString(); + for (auto i = 1u; i < predicates.size(); ++i) { + result += stringFormat(" AND {}", predicates[i]->toString()); + } + return result; +} + +static bool 
isColumnRef(ExpressionType type) { + return type == ExpressionType::PROPERTY || type == ExpressionType::VARIABLE; +} + +static bool isCastedColumnRef(const Expression& expr) { + if (expr.expressionType == ExpressionType::FUNCTION) { + const auto& funcExpr = expr.constCast(); + if (funcExpr.getFunction().name.starts_with("CAST")) { + KU_ASSERT(funcExpr.getNumChildren() > 0); + return isColumnRef(funcExpr.getChild(0)->expressionType); + } + } + return false; +} + +static bool isColumnOrCastedColumnRef(const Expression& expr) { + return isColumnRef(expr.expressionType) || isCastedColumnRef(expr); +} + +static bool isColumnRefConstantPair(const Expression& left, const Expression& right) { + return isColumnOrCastedColumnRef(left) && right.expressionType == ExpressionType::LITERAL; +} + +static bool columnMatchesExprChild(const Expression& column, const Expression& expr) { + return (expr.getNumChildren() > 0 && column == *expr.getChild(0)); +} + +static std::unique_ptr tryConvertToConstColumnPredicate(const Expression& column, + const Expression& predicate) { + if (isColumnRefConstantPair(*predicate.getChild(0), *predicate.getChild(1))) { + if (column != *predicate.getChild(0) && + !columnMatchesExprChild(column, *predicate.getChild(0))) { + return nullptr; + } + auto value = predicate.getChild(1)->constCast().getValue(); + return std::make_unique(column.toString(), + predicate.expressionType, value); + } else if (isColumnRefConstantPair(*predicate.getChild(1), *predicate.getChild(0))) { + if (column != *predicate.getChild(1) && + !columnMatchesExprChild(column, *predicate.getChild(1))) { + return nullptr; + } + auto value = predicate.getChild(0)->constCast().getValue(); + auto expressionType = + ExpressionTypeUtil::reverseComparisonDirection(predicate.expressionType); + return std::make_unique(column.toString(), expressionType, value); + } + // Not a predicate that runs on this property. 
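+    // The predicate is not of the form <column ref> <cmp> <literal> in either
+    // operand order, so no zone-map predicate can be derived from it; the caller is
+    // expected to drop it and rely on a regular scan instead.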
+ return nullptr; +} + +static std::unique_ptr tryConvertToIsNull(const Expression& column, + const Expression& predicate) { + // we only convert simple predicates + if (isColumnOrCastedColumnRef(*predicate.getChild(0)) && column == *predicate.getChild(0)) { + return std::make_unique(column.toString(), ExpressionType::IS_NULL); + } + return nullptr; +} + +static std::unique_ptr tryConvertToIsNotNull(const Expression& column, + const Expression& predicate) { + if (isColumnOrCastedColumnRef(*predicate.getChild(0)) && column == *predicate.getChild(0)) { + return std::make_unique(column.toString(), + ExpressionType::IS_NOT_NULL); + } + return nullptr; +} + +std::unique_ptr ColumnPredicateUtil::tryConvert(const Expression& property, + const Expression& predicate) { + if (ExpressionTypeUtil::isComparison(predicate.expressionType)) { + return tryConvertToConstColumnPredicate(property, predicate); + } + switch (predicate.expressionType) { + case common::ExpressionType::IS_NULL: + return tryConvertToIsNull(property, predicate); + case common::ExpressionType::IS_NOT_NULL: + return tryConvertToIsNotNull(property, predicate); + default: + return nullptr; + } +} + +std::string ColumnPredicate::toString() { + return stringFormat("{} {}", columnName, ExpressionTypeUtil::toParsableString(expressionType)); +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/predicate/constant_predicate.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/predicate/constant_predicate.cpp new file mode 100644 index 0000000000..700284b8c9 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/predicate/constant_predicate.cpp @@ -0,0 +1,95 @@ +#include "storage/predicate/constant_predicate.h" + +#include "common/type_utils.h" +#include "function/comparison/comparison_functions.h" +#include "storage/compression/compression.h" +#include "storage/table/column_chunk_stats.h" + +using namespace lbug::common; +using namespace lbug::function; + +namespace lbug { +namespace storage { + +template +bool inRange(T min, T max, T val) { + auto a = GreaterThanEquals::operation(val, min); + auto b = LessThanEquals::operation(val, max); + return a && b; +} + +template +ZoneMapCheckResult checkZoneMapSwitch(const MergedColumnChunkStats& mergedStats, + ExpressionType expressionType, const Value& value) { + // If the chunk is casted from a non-storage value type + // The stats will be empty, skip the zone map check in this case + if (mergedStats.stats.min.has_value() && mergedStats.stats.max.has_value()) { + auto max = mergedStats.stats.max->get(); + auto min = mergedStats.stats.min->get(); + auto constant = value.getValue(); + switch (expressionType) { + case ExpressionType::EQUALS: { + if (!inRange(min, max, constant)) { + return ZoneMapCheckResult::SKIP_SCAN; + } + } break; + case ExpressionType::NOT_EQUALS: { + if (Equals::operation(constant, min) && Equals::operation(constant, max)) { + return ZoneMapCheckResult::SKIP_SCAN; + } + } break; + case ExpressionType::GREATER_THAN: { + if (GreaterThanEquals::operation(constant, max)) { + return ZoneMapCheckResult::SKIP_SCAN; + } + } break; + case ExpressionType::GREATER_THAN_EQUALS: { + if (GreaterThan::operation(constant, max)) { + return ZoneMapCheckResult::SKIP_SCAN; + } + } break; + case ExpressionType::LESS_THAN: { + if (LessThanEquals::operation(constant, min)) { + return ZoneMapCheckResult::SKIP_SCAN; + } + } break; + case ExpressionType::LESS_THAN_EQUALS: { + if (LessThan::operation(constant, min)) { + return 
ZoneMapCheckResult::SKIP_SCAN; + } + } break; + default: + KU_UNREACHABLE; + } + } + return ZoneMapCheckResult::ALWAYS_SCAN; +} + +ZoneMapCheckResult ColumnConstantPredicate::checkZoneMap( + const MergedColumnChunkStats& stats) const { + auto physicalType = value.getDataType().getPhysicalType(); + return TypeUtils::visit( + physicalType, + [&](T) { return checkZoneMapSwitch(stats, expressionType, value); }, + [&](auto) { return ZoneMapCheckResult::ALWAYS_SCAN; }); +} + +std::string ColumnConstantPredicate::toString() { + std::string valStr; + if (value.getDataType().getPhysicalType() == PhysicalTypeID::STRING || + value.getDataType().getPhysicalType() == PhysicalTypeID::LIST || + value.getDataType().getPhysicalType() == PhysicalTypeID::ARRAY || + value.getDataType().getPhysicalType() == PhysicalTypeID::STRUCT || + value.getDataType().getLogicalTypeID() == LogicalTypeID::UUID || + value.getDataType().getLogicalTypeID() == LogicalTypeID::TIMESTAMP || + value.getDataType().getLogicalTypeID() == LogicalTypeID::DATE || + value.getDataType().getLogicalTypeID() == LogicalTypeID::INTERVAL) { + valStr = stringFormat("'{}'", value.toString()); + } else { + valStr = value.toString(); + } + return stringFormat("{} {}", ColumnPredicate::toString(), valStr); +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/predicate/null_predicate.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/predicate/null_predicate.cpp new file mode 100644 index 0000000000..dfafd18c87 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/predicate/null_predicate.cpp @@ -0,0 +1,15 @@ +#include "storage/predicate/null_predicate.h" + +#include "storage/table/column_chunk_stats.h" + +namespace lbug::storage { +common::ZoneMapCheckResult ColumnNullPredicate::checkZoneMap( + const MergedColumnChunkStats& mergedStats) const { + const bool statToCheck = (expressionType == common::ExpressionType::IS_NULL) ? + mergedStats.guaranteedNoNulls : + mergedStats.guaranteedAllNulls; + return statToCheck ? 
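+               // An IS_NULL predicate can skip the chunk only when the merged stats
+               // guarantee it contains no nulls; IS_NOT_NULL only when it is
+               // guaranteed to be all nulls.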
common::ZoneMapCheckResult::SKIP_SCAN : + common::ZoneMapCheckResult::ALWAYS_SCAN; +} + +} // namespace lbug::storage diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/shadow_file.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/shadow_file.cpp new file mode 100644 index 0000000000..8189581831 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/shadow_file.cpp @@ -0,0 +1,195 @@ +#include "storage/shadow_file.h" + +#include "common/exception/io.h" +#include "common/file_system/virtual_file_system.h" +#include "common/serializer/buffered_file.h" +#include "common/serializer/deserializer.h" +#include "common/serializer/serializer.h" +#include "main/client_context.h" +#include "main/db_config.h" +#include "storage/buffer_manager/buffer_manager.h" +#include "storage/buffer_manager/memory_manager.h" +#include "storage/database_header.h" +#include "storage/file_db_id_utils.h" +#include "storage/file_handle.h" +#include "storage/storage_manager.h" + +using namespace lbug::common; +using namespace lbug::main; + +namespace lbug { +namespace storage { + +void ShadowPageRecord::serialize(Serializer& serializer) const { + serializer.write(originalFileIdx); + serializer.write(originalPageIdx); +} + +ShadowPageRecord ShadowPageRecord::deserialize(Deserializer& deserializer) { + file_idx_t originalFileIdx = INVALID_FILE_IDX; + page_idx_t originalPageIdx = INVALID_PAGE_IDX; + deserializer.deserializeValue(originalFileIdx); + deserializer.deserializeValue(originalPageIdx); + return ShadowPageRecord{originalFileIdx, originalPageIdx}; +} + +ShadowFile::ShadowFile(BufferManager& bm, VirtualFileSystem* vfs, const std::string& databasePath) + : bm{bm}, shadowFilePath{StorageUtils::getShadowFilePath(databasePath)}, vfs{vfs}, + shadowingFH{nullptr} { + KU_ASSERT(vfs); +} + +void ShadowFile::clearShadowPage(file_idx_t originalFile, page_idx_t originalPage) { + if (hasShadowPage(originalFile, originalPage)) { + shadowPagesMap.at(originalFile).erase(originalPage); + if (shadowPagesMap.at(originalFile).empty()) { + shadowPagesMap.erase(originalFile); + } + } +} + +page_idx_t ShadowFile::getOrCreateShadowPage(file_idx_t originalFile, page_idx_t originalPage) { + if (hasShadowPage(originalFile, originalPage)) { + return shadowPagesMap[originalFile][originalPage]; + } + const auto shadowPageIdx = getOrCreateShadowingFH()->addNewPage(); + shadowPagesMap[originalFile][originalPage] = shadowPageIdx; + shadowPageRecords.push_back({originalFile, originalPage}); + return shadowPageIdx; +} + +page_idx_t ShadowFile::getShadowPage(file_idx_t originalFile, page_idx_t originalPage) const { + KU_ASSERT(hasShadowPage(originalFile, originalPage)); + return shadowPagesMap.at(originalFile).at(originalPage); +} + +void ShadowFile::applyShadowPages(ClientContext& context) const { + const auto pageBuffer = std::make_unique(LBUG_PAGE_SIZE); + page_idx_t shadowPageIdx = 1; // Skip header page. + auto dataFileInfo = StorageManager::Get(context)->getDataFH()->getFileInfo(); + KU_ASSERT(shadowingFH); + for (const auto& record : shadowPageRecords) { + shadowingFH->readPageFromDisk(pageBuffer.get(), shadowPageIdx++); + dataFileInfo->writeFile(pageBuffer.get(), LBUG_PAGE_SIZE, + record.originalPageIdx * LBUG_PAGE_SIZE); + // NOTE: We're not taking lock here, as we assume this is only called with a single thread. 
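+        // Besides rewriting the page in the data file, refresh the buffer manager
+        // frame of the original page (if it happens to be cached) so in-memory
+        // readers do not keep seeing the pre-checkpoint contents.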
+ MemoryManager::Get(context)->getBufferManager()->updateFrameIfPageIsInFrameWithoutLock( + record.originalFileIdx, pageBuffer.get(), record.originalPageIdx); + } + dataFileInfo->syncFile(); +} + +static ku_uuid_t getOldDatabaseID(FileInfo& dataFileInfo) { + auto oldHeader = DatabaseHeader::readDatabaseHeader(dataFileInfo); + if (!oldHeader.has_value()) { + throw InternalException("Found a shadow file for database {} but no valid database header. " + "The database is corrupted, please recreate it."); + } + return oldHeader->databaseID; +} + +void ShadowFile::replayShadowPageRecords(ClientContext& context) { + if (context.getDBConfig()->readOnly) { + throw RuntimeException("Couldn't replay shadow pages under read-only mode. Please re-open " + "the database with read-write mode to replay shadow pages."); + } + auto vfs = VirtualFileSystem::GetUnsafe(context); + auto shadowFilePath = StorageUtils::getShadowFilePath(context.getDatabasePath()); + auto shadowFileInfo = vfs->openFile(shadowFilePath, FileOpenFlags(FileFlags::READ_ONLY)); + + std::unique_ptr dataFileInfo; + try { + dataFileInfo = vfs->openFile(context.getDatabasePath(), + FileOpenFlags{FileFlags::WRITE | FileFlags::READ_ONLY, FileLockType::WRITE_LOCK}); + } catch (IOException& e) { + throw RuntimeException(stringFormat( + "Found shadow file {} but no corresponding database file. This file " + "may have been left behind from a previous database with the same name. If it is safe " + "to do so, please delete this file and restart the database.", + shadowFilePath)); + } + + ShadowFileHeader header; + const auto headerBuffer = std::make_unique(LBUG_PAGE_SIZE); + shadowFileInfo->readFromFile(headerBuffer.get(), LBUG_PAGE_SIZE, 0); + memcpy(&header, headerBuffer.get(), sizeof(ShadowFileHeader)); + + // When replaying the shadow file we haven't read the database ID from the database + // header yet + // So we need to do it separately here to verify the shadow file matches the database + auto oldDatabaseID = getOldDatabaseID(*dataFileInfo); + FileDBIDUtils::verifyDatabaseID(*shadowFileInfo, oldDatabaseID, header.databaseID); + + std::vector shadowPageRecords; + shadowPageRecords.reserve(header.numShadowPages); + auto reader = std::make_unique(*shadowFileInfo); + reader->resetReadOffset((header.numShadowPages + 1) * LBUG_PAGE_SIZE); + Deserializer deSer(std::move(reader)); + deSer.deserializeVector(shadowPageRecords); + + const auto pageBuffer = std::make_unique(LBUG_PAGE_SIZE); + page_idx_t shadowPageIdx = 1; + for (const auto& record : shadowPageRecords) { + shadowFileInfo->readFromFile(pageBuffer.get(), LBUG_PAGE_SIZE, + shadowPageIdx * LBUG_PAGE_SIZE); + dataFileInfo->writeFile(pageBuffer.get(), LBUG_PAGE_SIZE, + record.originalPageIdx * LBUG_PAGE_SIZE); + shadowPageIdx++; + } +} + +void ShadowFile::flushAll(main::ClientContext& context) const { + // Write header page to file. + ShadowFileHeader header; + header.numShadowPages = shadowPageRecords.size(); + header.databaseID = StorageManager::Get(context)->getOrInitDatabaseID(context); + const auto headerBuffer = std::make_unique(LBUG_PAGE_SIZE); + memcpy(headerBuffer.get(), &header, sizeof(ShadowFileHeader)); + KU_ASSERT(shadowingFH && !shadowingFH->isInMemoryMode()); + shadowingFH->writePageToFile(headerBuffer.get(), 0); + // Flush shadow pages to file. + shadowingFH->flushAllDirtyPagesInFrames(); + // Append shadow page records to the end of the file. 
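+    // Resulting shadow-file layout, as implied by this function together with
+    // replayShadowPageRecords:
+    //   page 0                                -> ShadowFileHeader (numShadowPages, databaseID)
+    //   pages 1 .. numShadowPages             -> shadow copies of original data-file pages
+    //   offset (numShadowPages + 1) * LBUG_PAGE_SIZE
+    //                                         -> serialized vector of ShadowPageRecord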
+ const auto writer = std::make_shared(*shadowingFH->getFileInfo()); + writer->setFileOffset(shadowingFH->getNumPages() * LBUG_PAGE_SIZE); + Serializer ser(writer); + KU_ASSERT(shadowPageRecords.size() + 1 == shadowingFH->getNumPages()); + ser.serializeVector(shadowPageRecords); + writer->flush(); + // Sync the file to disk. + writer->sync(); +} + +void ShadowFile::clear(BufferManager& bm) { + KU_ASSERT(shadowingFH); + // TODO(Guodong): We should remove shadow file here. This requires changes: + // 1. We need to make shadow file not going through BM. + // 2. We need to remove fileHandles held in BM, so that BM only keeps FH for the data file. + bm.removeFilePagesFromFrames(*shadowingFH); + shadowingFH->resetToZeroPagesAndPageCapacity(); + shadowPagesMap.clear(); + shadowPageRecords.clear(); + // Reserve header page. + shadowingFH->addNewPage(); +} + +void ShadowFile::reset() { + shadowingFH->resetFileInfo(); + shadowingFH = nullptr; + vfs->removeFileIfExists(shadowFilePath); +} + +FileHandle* ShadowFile::getOrCreateShadowingFH() { + if (!shadowingFH) { + shadowingFH = bm.getFileHandle(shadowFilePath, + FileHandle::O_PERSISTENT_FILE_CREATE_NOT_EXISTS, vfs, nullptr); + if (shadowingFH->getNumPages() == 0) { + // Reserve the first page for the header. + shadowingFH->addNewPage(); + } + } + return shadowingFH; +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/shadow_utils.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/shadow_utils.cpp new file mode 100644 index 0000000000..0a3d61b56d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/shadow_utils.cpp @@ -0,0 +1,88 @@ +#include "storage/shadow_utils.h" + +#include "storage/file_handle.h" +#include "storage/shadow_file.h" +#include "transaction/transaction.h" + +using namespace lbug::common; + +namespace lbug { +namespace storage { + +ShadowPageAndFrame ShadowUtils::createShadowVersionIfNecessaryAndPinPage(page_idx_t originalPage, + bool skipReadingOriginalPage, FileHandle& fileHandle, ShadowFile& shadowFile) { + KU_ASSERT(!fileHandle.isInMemoryMode()); + const auto hasShadowPage = shadowFile.hasShadowPage(fileHandle.getFileIndex(), originalPage); + auto shadowPage = shadowFile.getOrCreateShadowPage(fileHandle.getFileIndex(), originalPage); + uint8_t* shadowFrame = nullptr; + try { + if (hasShadowPage) { + shadowFrame = + shadowFile.getShadowingFH().pinPage(shadowPage, PageReadPolicy::READ_PAGE); + } else { + shadowFrame = + shadowFile.getShadowingFH().pinPage(shadowPage, PageReadPolicy::DONT_READ_PAGE); + if (!skipReadingOriginalPage) { + fileHandle.optimisticReadPage(originalPage, [&](const uint8_t* frame) -> void { + memcpy(shadowFrame, frame, LBUG_PAGE_SIZE); + }); + } + } + // The shadow page existing already does not mean that it's already dirty + // It may have been flushed to disk to free memory and then read again + shadowFile.getShadowingFH().setLockedPageDirty(shadowPage); + } catch (Exception&) { + throw; + } + return {originalPage, shadowPage, shadowFrame}; +} + +std::pair ShadowUtils::getFileHandleAndPhysicalPageIdxToPin( + FileHandle& fileHandle, page_idx_t pageIdx, const ShadowFile& shadowFile, + transaction::TransactionType trxType) { + if (trxType == transaction::TransactionType::CHECKPOINT && + shadowFile.hasShadowPage(fileHandle.getFileIndex(), pageIdx)) { + return std::make_pair(&shadowFile.getShadowingFH(), + shadowFile.getShadowPage(fileHandle.getFileIndex(), pageIdx)); + } + return std::make_pair(&fileHandle, pageIdx); +} + +void 
unpinShadowPage(page_idx_t originalPageIdx, page_idx_t shadowPageIdx, + const ShadowFile& shadowFile) { + KU_ASSERT(originalPageIdx != INVALID_PAGE_IDX && shadowPageIdx != INVALID_PAGE_IDX); + KU_UNUSED(originalPageIdx); + shadowFile.getShadowingFH().unpinPage(shadowPageIdx); +} + +void ShadowUtils::updatePage(FileHandle& fileHandle, page_idx_t originalPageIdx, + bool skipReadingOriginalPage, ShadowFile& shadowFile, + const std::function& updateOp) { + KU_ASSERT(!fileHandle.isInMemoryMode()); + const auto shadowPageIdxAndFrame = createShadowVersionIfNecessaryAndPinPage(originalPageIdx, + skipReadingOriginalPage, fileHandle, shadowFile); + try { + updateOp(shadowPageIdxAndFrame.frame); + } catch (Exception&) { + unpinShadowPage(shadowPageIdxAndFrame.originalPage, shadowPageIdxAndFrame.shadowPage, + shadowFile); + throw; + } + unpinShadowPage(shadowPageIdxAndFrame.originalPage, shadowPageIdxAndFrame.shadowPage, + shadowFile); +} + +void ShadowUtils::readShadowVersionOfPage(const FileHandle& fileHandle, page_idx_t originalPageIdx, + const ShadowFile& shadowFile, const std::function& readOp) { + KU_ASSERT(!fileHandle.isInMemoryMode()); + KU_ASSERT(shadowFile.hasShadowPage(fileHandle.getFileIndex(), originalPageIdx)); + const page_idx_t shadowPageIdx = + shadowFile.getShadowPage(fileHandle.getFileIndex(), originalPageIdx); + const auto frame = + shadowFile.getShadowingFH().pinPage(shadowPageIdx, PageReadPolicy::READ_PAGE); + readOp(frame); + unpinShadowPage(originalPageIdx, shadowPageIdx, shadowFile); +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/stats/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/stats/CMakeLists.txt new file mode 100644 index 0000000000..353a920ba1 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/stats/CMakeLists.txt @@ -0,0 +1,9 @@ +add_library(lbug_storage_stats + OBJECT + column_stats.cpp + hyperloglog.cpp + table_stats.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/stats/column_stats.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/stats/column_stats.cpp new file mode 100644 index 0000000000..ce1b4992ba --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/stats/column_stats.cpp @@ -0,0 +1,32 @@ +#include "storage/stats/column_stats.h" + +#include "function/hash/vector_hash_functions.h" + +namespace lbug { +namespace storage { + +ColumnStats::ColumnStats(const common::LogicalType& dataType) : hashes{nullptr} { + if (!common::LogicalTypeUtils::isNested(dataType)) { + hll.emplace(); + } +} + +void ColumnStats::update(const common::ValueVector* vector) { + if (hll) { + if (!hashes) { + hashes = std::make_unique(common::LogicalTypeID::UINT64); + } + hashes->state = vector->state; + function::VectorHashFunction::computeHash(*vector, vector->state->getSelVector(), *hashes, + hashes->state->getSelVector()); + KU_ASSERT(hashes->hasNoNullsGuarantee()); + for (auto i = 0u; i < hashes->state->getSelVector().getSelSize(); i++) { + hll->insertElement(hashes->getValue(i)); + } + hashes->state = nullptr; + hashes->setAllNonNull(); + } +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/stats/hyperloglog.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/stats/hyperloglog.cpp new file mode 100644 index 0000000000..5916275939 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/stats/hyperloglog.cpp @@ -0,0 +1,90 @@ +#include 
"storage/stats/hyperloglog.h" + +#include + +#include "common/serializer/deserializer.h" +#include "common/serializer/serializer.h" + +namespace lbug { +namespace storage { + +common::cardinality_t HyperLogLog::count() const { + uint32_t c[Q + 2] = {0}; + extractCounts(c); + return static_cast(estimateCardinality(c)); +} + +void HyperLogLog::merge(const HyperLogLog& other) { + for (auto i = 0u; i < M; ++i) { + update(i, other.k[i]); + } +} + +void HyperLogLog::extractCounts(uint32_t* c) const { + for (auto i = 0u; i < M; ++i) { + c[k[i]]++; + } +} + +//! Taken from redis code +static double HLLSigma(double x) { + if (x == 1.) { + return std::numeric_limits::infinity(); + } + double z_prime = NAN; + double y = 1; + double z = x; + do { + x *= x; + z_prime = z; + z += x * y; + y += y; + } while (z_prime != z); + return z; +} + +//! Taken from redis code +static double HLLTau(double x) { + if (x == 0. || x == 1.) { + return 0.; + } + double z_prime = NAN; + double y = 1.0; + double z = 1 - x; + do { + x = sqrt(x); + z_prime = z; + y *= 0.5; + z -= pow(1 - x, 2) * y; + } while (z_prime != z); + return z / 3; +} + +int64_t HyperLogLog::estimateCardinality(const uint32_t* c) { + auto z = M * HLLTau((static_cast(M) - c[Q]) / static_cast(M)); + + for (auto k = Q; k >= 1; --k) { + z += c[k]; + z *= 0.5; + } + + z += M * HLLSigma(c[0] / static_cast(M)); + + return llroundl(ALPHA * M * M / z); +} + +void HyperLogLog::serialize(common::Serializer& serializer) const { + serializer.writeDebuggingInfo("hll_data"); + serializer.serializeArray(k); +} + +HyperLogLog HyperLogLog::deserialize(common::Deserializer& deserializer) { + HyperLogLog result; + std::string info; + deserializer.validateDebuggingInfo(info, "hll_data"); + deserializer.deserializeArray(result.k); + return result; +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/stats/table_stats.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/stats/table_stats.cpp new file mode 100644 index 0000000000..9091d39f1d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/stats/table_stats.cpp @@ -0,0 +1,64 @@ +#include "storage/stats/table_stats.h" + +#include "common/serializer/deserializer.h" +#include "common/serializer/serializer.h" + +namespace lbug { +namespace storage { + +TableStats::TableStats(std::span dataTypes) : cardinality{0} { + for (const auto& dataType : dataTypes) { + columnStats.emplace_back(dataType); + } +} + +TableStats::TableStats(const TableStats& other) : cardinality{other.cardinality} { + columnStats.reserve(other.columnStats.size()); + for (auto i = 0u; i < other.columnStats.size(); ++i) { + columnStats.emplace_back(other.columnStats[i].copy()); + } +} + +void TableStats::update(const std::vector& vectors, size_t numColumns) { + std::vector dummyColumnIDs; + for (auto i = 0u; i < vectors.size(); ++i) { + dummyColumnIDs.push_back(i); + } + update(dummyColumnIDs, vectors, numColumns); +} + +void TableStats::update(const std::vector& columnIDs, + const std::vector& vectors, size_t numColumns) { + KU_ASSERT(columnIDs.size() == vectors.size()); + size_t numColumnsToUpdate = std::min(numColumns, vectors.size()); + + for (auto i = 0u; i < numColumnsToUpdate; ++i) { + auto columnID = columnIDs[i]; + KU_ASSERT(columnID < columnStats.size()); + columnStats[columnID].update(vectors[i]); + } + const auto numValues = vectors[0]->state->getSelVector().getSelSize(); + for (auto i = 1u; i < numColumnsToUpdate; ++i) { + KU_ASSERT(vectors[i]->state->getSelVector().getSelSize() == 
numValues); + } + incrementCardinality(numValues); +} + +void TableStats::serialize(common::Serializer& serializer) const { + serializer.writeDebuggingInfo("cardinality"); + serializer.write(cardinality); + serializer.writeDebuggingInfo("column_stats"); + serializer.serializeVector(columnStats); +} + +TableStats TableStats::deserialize(common::Deserializer& deserializer) { + std::string info; + deserializer.validateDebuggingInfo(info, "cardinality"); + deserializer.deserializeValue(cardinality); + deserializer.validateDebuggingInfo(info, "column_stats"); + deserializer.deserializeVector(columnStats); + return *this; +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/storage_manager.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/storage_manager.cpp new file mode 100644 index 0000000000..16e4c1161f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/storage_manager.cpp @@ -0,0 +1,318 @@ +#include "storage/storage_manager.h" + +#include "catalog/catalog_entry/node_table_catalog_entry.h" +#include "catalog/catalog_entry/rel_group_catalog_entry.h" +#include "common/file_system/virtual_file_system.h" +#include "common/random_engine.h" +#include "common/serializer/in_mem_file_writer.h" +#include "main/attached_database.h" +#include "main/client_context.h" +#include "main/database.h" +#include "main/db_config.h" +#include "storage/buffer_manager/buffer_manager.h" +#include "storage/buffer_manager/memory_manager.h" +#include "storage/checkpointer.h" +#include "storage/table/node_table.h" +#include "storage/table/rel_table.h" +#include "storage/wal/wal_replayer.h" +#include "transaction/transaction.h" + +using namespace lbug::catalog; +using namespace lbug::common; +using namespace lbug::transaction; + +namespace lbug { +namespace storage { + +StorageManager::StorageManager(const std::string& databasePath, bool readOnly, bool enableChecksums, + MemoryManager& memoryManager, bool enableCompression, VirtualFileSystem* vfs) + : databasePath{databasePath}, readOnly{readOnly}, dataFH{nullptr}, memoryManager{memoryManager}, + enableCompression{enableCompression} { + wal = std::make_unique(databasePath, readOnly, enableChecksums, vfs); + shadowFile = + std::make_unique(*memoryManager.getBufferManager(), vfs, this->databasePath); + inMemory = main::DBConfig::isDBPathInMemory(databasePath); + registerIndexType(PrimaryKeyIndex::getIndexType()); +} + +StorageManager::~StorageManager() = default; + +void StorageManager::initDataFileHandle(VirtualFileSystem* vfs, main::ClientContext* context) { + if (inMemory) { + dataFH = memoryManager.getBufferManager()->getFileHandle(databasePath, + FileHandle::O_PERSISTENT_FILE_IN_MEM, vfs, context); + } else { + auto flag = readOnly ? FileHandle::O_PERSISTENT_FILE_READ_ONLY : + FileHandle::O_PERSISTENT_FILE_CREATE_NOT_EXISTS; + flag |= FileHandle::O_LOCKED_PERSISTENT_FILE; + dataFH = memoryManager.getBufferManager()->getFileHandle(databasePath, flag, vfs, context); + if (dataFH->getNumPages() == 0) { + if (!readOnly) { + // Reserve the first page for the database header. + dataFH->getPageManager()->allocatePage(); + // Write a dummy database header page. 
+ const auto* initialHeader = getOrInitDatabaseHeader(*context); + auto headerWriter = + std::make_shared(*MemoryManager::Get(*context)); + Serializer headerSerializer(headerWriter); + initialHeader->serialize(headerSerializer); + dataFH->getFileInfo()->writeFile(headerWriter->getPage(0).data(), LBUG_PAGE_SIZE, + StorageConstants::DB_HEADER_PAGE_IDX); + dataFH->getFileInfo()->syncFile(); + } + } + } +} + +Table* StorageManager::getTable(table_id_t tableID) { + std::lock_guard lck{mtx}; + KU_ASSERT(tables.contains(tableID)); + return tables.at(tableID).get(); +} + +void StorageManager::recover(main::ClientContext& clientContext, bool throwOnWalReplayFailure, + bool enableChecksums) { + const auto walReplayer = std::make_unique(clientContext); + walReplayer->replay(throwOnWalReplayFailure, enableChecksums); +} + +void StorageManager::createNodeTable(NodeTableCatalogEntry* entry) { + tables[entry->getTableID()] = std::make_unique(this, entry, &memoryManager); +} + +// TODO(Guodong): This API is added since storageManager doesn't provide an API to add a single +// rel table. We may have to refactor the existing StorageManager::createTable(TableCatalogEntry* +// entry). +void StorageManager::addRelTable(RelGroupCatalogEntry* entry, const RelTableCatalogInfo& info) { + tables[info.oid] = std::make_unique(entry, info.nodePair.srcTableID, + info.nodePair.dstTableID, this, &memoryManager); +} + +void StorageManager::createRelTableGroup(RelGroupCatalogEntry* entry) { + for (auto& info : entry->getRelEntryInfos()) { + addRelTable(entry, info); + } +} + +void StorageManager::createTable(TableCatalogEntry* entry) { + std::lock_guard lck{mtx}; + switch (entry->getType()) { + case CatalogEntryType::NODE_TABLE_ENTRY: { + createNodeTable(entry->ptrCast()); + } break; + case CatalogEntryType::REL_GROUP_ENTRY: { + createRelTableGroup(entry->ptrCast()); + } break; + default: { + KU_UNREACHABLE; + } + } +} + +WAL& StorageManager::getWAL() const { + KU_ASSERT(wal); + return *wal; +} + +ShadowFile& StorageManager::getShadowFile() const { + KU_ASSERT(shadowFile); + return *shadowFile; +} + +void StorageManager::reclaimDroppedTables(const Catalog& catalog) { + std::vector droppedTables; + for (const auto& [tableID, table] : tables) { + switch (table->getTableType()) { + case TableType::NODE: { + if (!catalog.containsTable(&DUMMY_CHECKPOINT_TRANSACTION, tableID, true)) { + table->reclaimStorage(*dataFH->getPageManager()); + droppedTables.push_back(tableID); + } + } break; + case TableType::REL: { + auto& relTable = table->cast(); + auto relGroupID = relTable.getRelGroupID(); + if (!catalog.containsTable(&DUMMY_CHECKPOINT_TRANSACTION, relGroupID, true)) { + table->reclaimStorage(*dataFH->getPageManager()); + droppedTables.push_back(tableID); + } else { + auto relGroupEntry = + catalog.getTableCatalogEntry(&DUMMY_CHECKPOINT_TRANSACTION, relGroupID); + if (!relGroupEntry->cast().getRelEntryInfo( + relTable.getFromNodeTableID(), relTable.getToNodeTableID())) { + table->reclaimStorage(*dataFH->getPageManager()); + droppedTables.push_back(tableID); + } + } + } + default: { + // DO NOTHING. 
+ } + } + } + for (auto tableID : droppedTables) { + tables.erase(tableID); + } +} + +bool StorageManager::checkpoint(main::ClientContext* context, PageAllocator& pageAllocator) { + bool hasChanges = false; + const auto catalog = Catalog::Get(*context); + const auto nodeTableEntries = catalog->getNodeTableEntries(&DUMMY_CHECKPOINT_TRANSACTION); + const auto relGroupEntries = catalog->getRelGroupEntries(&DUMMY_CHECKPOINT_TRANSACTION); + + for (const auto entry : nodeTableEntries) { + if (!tables.contains(entry->getTableID())) { + throw RuntimeException(stringFormat( + "Checkpoint failed: table {} not found in storage manager.", entry->getName())); + } + hasChanges = + tables.at(entry->getTableID())->checkpoint(context, entry, pageAllocator) || hasChanges; + } + for (const auto entry : relGroupEntries) { + for (auto& info : entry->getRelEntryInfos()) { + if (!tables.contains(info.oid)) { + throw RuntimeException(stringFormat( + "Checkpoint failed: table {} not found in storage manager.", entry->getName())); + } + hasChanges = + tables.at(info.oid)->checkpoint(context, entry, pageAllocator) || hasChanges; + } + entry->vacuumColumnIDs(1); + } + reclaimDroppedTables(*catalog); + return hasChanges; +} + +void StorageManager::finalizeCheckpoint() { + dataFH->getPageManager()->finalizeCheckpoint(); +} + +void StorageManager::rollbackCheckpoint(const Catalog& catalog) { + std::lock_guard lck{mtx}; + const auto nodeTableEntries = catalog.getNodeTableEntries(&DUMMY_CHECKPOINT_TRANSACTION); + for (const auto tableEntry : nodeTableEntries) { + KU_ASSERT(tables.contains(tableEntry->getTableID())); + tables.at(tableEntry->getTableID())->rollbackCheckpoint(); + } + dataFH->getPageManager()->rollbackCheckpoint(); +} + +std::optional> StorageManager::getIndexType( + const std::string& typeName) const { + for (auto& indexType : registeredIndexTypes) { + if (StringUtils::caseInsensitiveEquals(indexType.typeName, typeName)) { + return indexType; + } + } + return std::nullopt; +} + +void StorageManager::serialize(const Catalog& catalog, Serializer& ser) { + std::lock_guard lck{mtx}; + auto nodeTableEntries = catalog.getNodeTableEntries(&DUMMY_CHECKPOINT_TRANSACTION); + auto relGroupEntries = catalog.getRelGroupEntries(&DUMMY_CHECKPOINT_TRANSACTION); + std::sort(nodeTableEntries.begin(), nodeTableEntries.end(), + [](const auto& a, const auto& b) { return a->getTableID() < b->getTableID(); }); + std::sort(relGroupEntries.begin(), relGroupEntries.end(), + [](const auto& a, const auto& b) { return a->getTableID() < b->getTableID(); }); + ser.writeDebuggingInfo("num_node_tables"); + ser.write(nodeTableEntries.size()); + for (const auto tableEntry : nodeTableEntries) { + KU_ASSERT(tables.contains(tableEntry->getTableID())); + ser.writeDebuggingInfo("table_id"); + ser.write(tableEntry->getTableID()); + tables.at(tableEntry->getTableID())->serialize(ser); + } + ser.writeDebuggingInfo("num_rel_groups"); + ser.write(relGroupEntries.size()); + for (const auto entry : relGroupEntries) { + const auto& relGroupEntry = entry->cast(); + ser.writeDebuggingInfo("rel_group_id"); + ser.write(relGroupEntry.getTableID()); + ser.writeDebuggingInfo("num_inner_rel_tables"); + ser.write(relGroupEntry.getNumRelTables()); + for (auto& info : relGroupEntry.getRelEntryInfos()) { + KU_ASSERT(tables.contains(info.oid)); + info.serialize(ser); + tables.at(info.oid)->serialize(ser); + } + } +} + +void StorageManager::deserialize(main::ClientContext* context, const Catalog* catalog, + Deserializer& deSer) { + std::string key; + 
deSer.validateDebuggingInfo(key, "num_node_tables"); + uint64_t numNodeTables = 0; + deSer.deserializeValue(numNodeTables); + for (auto i = 0u; i < numNodeTables; i++) { + deSer.validateDebuggingInfo(key, "table_id"); + table_id_t tableID = INVALID_TABLE_ID; + deSer.deserializeValue(tableID); + if (!catalog->containsTable(&DUMMY_TRANSACTION, tableID)) { + throw RuntimeException( + stringFormat("Load table failed: table {} doesn't exist in catalog.", tableID)); + } + KU_ASSERT(!tables.contains(tableID)); + auto tableEntry = catalog->getTableCatalogEntry(&DUMMY_TRANSACTION, tableID) + ->ptrCast(); + tables[tableID] = std::make_unique(this, tableEntry, &memoryManager); + tables[tableID]->deserialize(context, this, deSer); + } + deSer.validateDebuggingInfo(key, "num_rel_groups"); + uint64_t numRelGroups = 0; + deSer.deserializeValue(numRelGroups); + for (auto i = 0u; i < numRelGroups; i++) { + deSer.validateDebuggingInfo(key, "rel_group_id"); + table_id_t relGroupID = INVALID_TABLE_ID; + deSer.deserializeValue(relGroupID); + if (!catalog->containsTable(&DUMMY_TRANSACTION, relGroupID)) { + throw RuntimeException( + stringFormat("Load table failed: table {} doesn't exist in catalog.", relGroupID)); + } + deSer.validateDebuggingInfo(key, "num_inner_rel_tables"); + uint64_t numInnerRelTables = 0; + deSer.deserializeValue(numInnerRelTables); + auto relGroupEntry = catalog->getTableCatalogEntry(&DUMMY_TRANSACTION, relGroupID) + ->ptrCast(); + for (auto k = 0u; k < numInnerRelTables; k++) { + RelTableCatalogInfo info = RelTableCatalogInfo::deserialize(deSer); + KU_ASSERT(!tables.contains(info.oid)); + tables[info.oid] = std::make_unique(relGroupEntry, info.nodePair.srcTableID, + info.nodePair.dstTableID, this, &memoryManager); + tables.at(info.oid)->deserialize(context, this, deSer); + } + } +} + +common::ku_uuid_t StorageManager::getOrInitDatabaseID(const main::ClientContext& clientContext) { + return getOrInitDatabaseHeader(clientContext)->databaseID; +} + +const storage::DatabaseHeader* StorageManager::getOrInitDatabaseHeader( + const main::ClientContext& clientContext) { + if (databaseHeader == nullptr) { + // We should only create the database header if a persistent one doesn't exist + KU_ASSERT(std::nullopt == DatabaseHeader::readDatabaseHeader(*dataFH->getFileInfo())); + databaseHeader = std::make_unique( + DatabaseHeader::createInitialHeader(RandomEngine::Get(clientContext))); + } + return databaseHeader.get(); +} + +void StorageManager::setDatabaseHeader(std::unique_ptr header) { + KU_ASSERT(!databaseHeader || header->databaseID.value == databaseHeader->databaseID.value); + databaseHeader = std::move(header); +} + +StorageManager* StorageManager::Get(const main::ClientContext& context) { + if (context.getAttachedDatabase()) { + return context.getAttachedDatabase()->getStorageManager(); + } else { + return context.getDatabase()->getStorageManager(); + } +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/storage_utils.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/storage_utils.cpp new file mode 100644 index 0000000000..d9e585d886 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/storage_utils.cpp @@ -0,0 +1,93 @@ +#include "storage/storage_utils.h" + +#include + +#include "common/null_buffer.h" +#include "common/string_format.h" +#include "common/types/ku_list.h" +#include "common/types/ku_string.h" +#include "common/types/types.h" +#include "main/client_context.h" +#include "main/db_config.h" +#include "main/settings.h" 
+ +using namespace lbug::common; + +namespace lbug { +namespace storage { + +std::string StorageUtils::getColumnName(const std::string& propertyName, ColumnType type, + const std::string& prefix) { + switch (type) { + case ColumnType::DATA: { + return stringFormat("{}_data", propertyName); + } + case ColumnType::NULL_MASK: { + return stringFormat("{}_null", propertyName); + } + case ColumnType::INDEX: { + return stringFormat("{}_index", propertyName); + } + case ColumnType::OFFSET: { + return stringFormat("{}_offset", propertyName); + } + case ColumnType::CSR_OFFSET: { + return stringFormat("{}_csr_offset", prefix); + } + case ColumnType::CSR_LENGTH: { + return stringFormat("{}_csr_length", prefix); + } + case ColumnType::STRUCT_CHILD: { + return stringFormat("{}_{}_child", propertyName, prefix); + } + default: { + if (prefix.empty()) { + return propertyName; + } + return stringFormat("{}_{}", prefix, propertyName); + } + } +} + +std::string StorageUtils::expandPath(const main::ClientContext* context, const std::string& path) { + if (main::DBConfig::isDBPathInMemory(path)) { + return path; + } + auto fullPath = path; + // Handle '~' for home directory expansion + if (path.starts_with('~')) { + fullPath = + context->getCurrentSetting(main::HomeDirectorySetting::name).getValue() + + fullPath.substr(1); + } + // Normalize the path to resolve '.' and '..' + std::filesystem::path normalizedPath = std::filesystem::absolute(fullPath).lexically_normal(); + return normalizedPath.string(); +} + +uint32_t StorageUtils::getDataTypeSize(const LogicalType& type) { + switch (type.getPhysicalType()) { + case PhysicalTypeID::STRING: { + return sizeof(ku_string_t); + } + case PhysicalTypeID::ARRAY: + case PhysicalTypeID::LIST: { + return sizeof(ku_list_t); + } + case PhysicalTypeID::STRUCT: { + uint32_t size = 0; + const auto fieldsTypes = StructType::getFieldTypes(type); + for (const auto& fieldType : fieldsTypes) { + size += getDataTypeSize(*fieldType); + } + size += NullBuffer::getNumBytesForNullValues(fieldsTypes.size()); + return size; + } + default: { + return PhysicalTypeUtils::getFixedTypeSize(type.getPhysicalType()); + } + } +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/storage_version_info.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/storage_version_info.cpp new file mode 100644 index 0000000000..3ec3f9c863 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/storage_version_info.cpp @@ -0,0 +1,23 @@ +#include "storage/storage_version_info.h" + +namespace lbug { +namespace storage { + +storage_version_t StorageVersionInfo::getStorageVersion() { + auto storageVersionInfo = getStorageVersionInfo(); + if (!storageVersionInfo.contains(LBUG_CMAKE_VERSION)) { + // If the current LBUG_CMAKE_VERSION is not in the map, + // then we must run the newest version of lbug + // LCOV_EXCL_START + storage_version_t maxVersion = 0; + for (auto& [_, versionNumber] : storageVersionInfo) { + maxVersion = std::max(maxVersion, versionNumber); + } + return maxVersion; + // LCOV_EXCL_STOP + } + return storageVersionInfo.at(LBUG_CMAKE_VERSION); +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/CMakeLists.txt new file mode 100644 index 0000000000..4d7e1938e4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/CMakeLists.txt @@ -0,0 +1,39 @@ +add_library(lbug_storage_store + OBJECT + 
chunked_node_group.cpp + column.cpp + column_chunk.cpp + column_chunk_data.cpp + column_chunk_stats.cpp + csr_chunked_node_group.cpp + csr_node_group.cpp + column_reader_writer.cpp + column_chunk_data.cpp + column_chunk_metadata.cpp + compression_flush_buffer.cpp + dictionary_chunk.cpp + dictionary_column.cpp + in_mem_chunked_node_group_collection.cpp + in_memory_exception_chunk.cpp + lazy_segment_scanner.cpp + list_chunk_data.cpp + list_column.cpp + node_group.cpp + node_group_collection.cpp + node_table.cpp + null_column.cpp + rel_table.cpp + rel_table_data.cpp + string_chunk_data.cpp + string_column.cpp + struct_chunk_data.cpp + struct_column.cpp + table.cpp + update_info.cpp + version_info.cpp + version_record_handler.cpp +) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/chunked_node_group.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/chunked_node_group.cpp new file mode 100644 index 0000000000..c7647d3349 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/chunked_node_group.cpp @@ -0,0 +1,723 @@ +#include "storage/table/chunked_node_group.h" + +#include + +#include "common/assert.h" +#include "common/types/types.h" +#include "storage/buffer_manager/buffer_manager.h" +#include "storage/buffer_manager/memory_manager.h" +#include "storage/buffer_manager/spiller.h" +#include "storage/enums/residency_state.h" +#include "storage/page_allocator.h" +#include "storage/table/column.h" +#include "storage/table/column_chunk.h" +#include "storage/table/column_chunk_data.h" +#include "storage/table/column_chunk_scanner.h" +#include "storage/table/node_table.h" + +using namespace lbug::common; +using namespace lbug::transaction; + +namespace lbug { +namespace storage { + +template +static void handleAppendException(std::vector>& chunks, uint64_t numRows) { + // After an exception is thrown other threads may continue to work on this chunked group for a + // while before they are interrupted + // Although the changes will eventually be rolled back + // We reset the state of the chunk so later changes won't corrupt any data + // Due to the numValues in column chunks not matching the number of rows + for (const auto& chunk : chunks) { + chunk->truncate(numRows); + } + std::rethrow_exception(std::current_exception()); +} + +ChunkedNodeGroup::ChunkedNodeGroup(std::vector> chunks, + row_idx_t startRowIdx, NodeGroupDataFormat format) + : format{format}, startRowIdx{startRowIdx}, chunks{std::move(chunks)} { + KU_ASSERT(!this->chunks.empty()); + residencyState = this->chunks[0]->getResidencyState(); + numRows = this->chunks[0]->getNumValues(); + capacity = numRows; + for (auto columnID = 1u; columnID < this->chunks.size(); columnID++) { + KU_ASSERT(this->chunks[columnID]->getNumValues() == numRows); + KU_ASSERT(this->chunks[columnID]->getResidencyState() == residencyState); + } +} + +ChunkedNodeGroup::ChunkedNodeGroup(ChunkedNodeGroup& base, + const std::vector& selectedColumns) + : format{base.format}, residencyState{base.residencyState}, startRowIdx{base.startRowIdx}, + capacity{base.capacity}, numRows{base.numRows.load()} { + chunks.resize(selectedColumns.size()); + for (auto i = 0u; i < selectedColumns.size(); i++) { + auto columnID = selectedColumns[i]; + KU_ASSERT(columnID < base.getNumColumns()); + chunks[i] = base.moveColumnChunk(columnID); + } +} + +ChunkedNodeGroup::ChunkedNodeGroup(InMemChunkedNodeGroup& base, + const std::vector& selectedColumns, NodeGroupDataFormat format) + : 
format{format}, residencyState{ResidencyState::IN_MEMORY}, startRowIdx{base.getStartRowIdx()}, + capacity{base.getCapacity()}, numRows{base.getNumRows()} { + chunks.resize(selectedColumns.size()); + for (auto i = 0u; i < selectedColumns.size(); i++) { + auto columnID = selectedColumns[i]; + KU_ASSERT(columnID < base.getNumColumns()); + chunks[i] = std::make_unique(true /*enableCompression*/, + base.moveColumnChunk(columnID)); + } +} + +ChunkedNodeGroup::ChunkedNodeGroup(MemoryManager& mm, const std::vector& columnTypes, + bool enableCompression, uint64_t capacity, row_idx_t startRowIdx, ResidencyState residencyState, + NodeGroupDataFormat format) + : format{format}, residencyState{residencyState}, startRowIdx{startRowIdx}, capacity{capacity}, + numRows{0} { + chunks.reserve(columnTypes.size()); + for (auto& type : columnTypes) { + chunks.push_back(std::make_unique(mm, type.copy(), capacity, enableCompression, + residencyState)); + } +} + +ChunkedNodeGroup::ChunkedNodeGroup(MemoryManager& mm, ChunkedNodeGroup& base, + std::span columnTypes, std::span baseColumnIDs) + : format{base.format}, residencyState{base.residencyState}, startRowIdx{base.startRowIdx}, + capacity{base.capacity}, numRows{base.numRows.load()}, + versionInfo(std::move(base.versionInfo)) { + bool enableCompression = false; + KU_ASSERT(!baseColumnIDs.empty()); + + chunks.resize(columnTypes.size()); + + KU_ASSERT(base.getNumColumns() == baseColumnIDs.size()); + for (column_id_t i = 0; i < baseColumnIDs.size(); ++i) { + auto baseColumnID = baseColumnIDs[i]; + KU_ASSERT(baseColumnID < chunks.size()); + chunks[baseColumnID] = base.moveColumnChunk(i); + enableCompression = chunks[baseColumnID]->isCompressionEnabled(); + KU_ASSERT(chunks[baseColumnID]->getDataType().getPhysicalType() == + columnTypes[baseColumnID].getPhysicalType()); + } + + for (column_id_t i = 0; i < columnTypes.size(); ++i) { + if (chunks[i] == nullptr) { + chunks[i] = std::make_unique(mm, columnTypes[i].copy(), 0, + enableCompression, ResidencyState::IN_MEMORY); + } + } +} + +void ChunkedNodeGroup::resetNumRowsFromChunks() { + KU_ASSERT(residencyState == ResidencyState::ON_DISK); + KU_ASSERT(!chunks.empty()); + numRows = getColumnChunk(0).getNumValues(); + capacity = numRows; + for (auto i = 1u; i < getNumColumns(); i++) { + KU_ASSERT(numRows == getColumnChunk(i).getNumValues()); + } +} + +void ChunkedNodeGroup::resetVersionAndUpdateInfo() { + if (versionInfo) { + versionInfo.reset(); + } + for (const auto& chunk : chunks) { + chunk->resetUpdateInfo(); + } +} + +void ChunkedNodeGroup::truncate(const offset_t numRows_) { + KU_ASSERT(numRows >= numRows_); + for (const auto& chunk : chunks) { + chunk->truncate(numRows_); + } + numRows = numRows_; +} + +void InMemChunkedNodeGroup::setNumRows(const offset_t numRows_) { + for (const auto& chunk : chunks) { + chunk->setNumValues(numRows_); + } + numRows = numRows_; +} + +uint64_t ChunkedNodeGroup::append(const Transaction* transaction, + const std::vector& columnVectors, row_idx_t startRowInVectors, + uint64_t numValuesToAppend) { + KU_ASSERT(residencyState != ResidencyState::ON_DISK); + KU_ASSERT(columnVectors.size() == chunks.size()); + const auto numRowsToAppendInChunk = std::min(numValuesToAppend, capacity - numRows); + try { + for (auto i = 0u; i < columnVectors.size(); i++) { + const auto columnVector = columnVectors[i]; + chunks[i]->append(columnVector, columnVector->state->getSelVector().slice( + startRowInVectors, numRowsToAppendInChunk)); + } + } catch ([[maybe_unused]] std::exception& e) { + 
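+        // On a partial failure, handleAppendException truncates every chunk back to
+        // the pre-append row count so the column chunks cannot be left with
+        // mismatched value counts, then rethrows the original exception.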
handleAppendException(chunks, numRows); + } + if (transaction->shouldAppendToUndoBuffer()) { + if (!versionInfo) { + versionInfo = std::make_unique(); + } + versionInfo->append(transaction->getID(), numRows, numRowsToAppendInChunk); + } + numRows += numRowsToAppendInChunk; + return numRowsToAppendInChunk; +} + +offset_t ChunkedNodeGroup::append(const Transaction* transaction, + const std::vector& columnIDs, const ChunkedNodeGroup& other, + offset_t offsetInOtherNodeGroup, offset_t numRowsToAppend) { + KU_ASSERT(residencyState == ResidencyState::IN_MEMORY); + KU_ASSERT(other.chunks.size() == chunks.size()); + std::vector chunksToAppend(other.chunks.size()); + for (auto i = 0u; i < chunks.size(); i++) { + chunksToAppend[i] = other.chunks[i].get(); + } + return append(transaction, columnIDs, chunksToAppend, offsetInOtherNodeGroup, numRowsToAppend); +} + +offset_t ChunkedNodeGroup::append(const Transaction* transaction, + const std::vector& columnIDs, const InMemChunkedNodeGroup& other, + offset_t offsetInOtherNodeGroup, offset_t numRowsToAppend) { + KU_ASSERT(residencyState == ResidencyState::IN_MEMORY); + KU_ASSERT(other.chunks.size() == chunks.size()); + std::vector chunksToAppend(other.chunks.size()); + for (auto i = 0u; i < chunks.size(); i++) { + chunksToAppend[i] = other.chunks[i].get(); + } + return append(transaction, columnIDs, chunksToAppend, offsetInOtherNodeGroup, numRowsToAppend); +} + +offset_t ChunkedNodeGroup::append(const Transaction* transaction, + const std::vector& columnIDs, std::span other, + offset_t offsetInOtherNodeGroup, offset_t numRowsToAppend) { + KU_ASSERT(residencyState == ResidencyState::IN_MEMORY); + KU_ASSERT(other.size() == columnIDs.size()); + const auto numToAppendInChunkedGroup = std::min(numRowsToAppend, capacity - numRows); + try { + for (auto i = 0u; i < columnIDs.size(); i++) { + auto columnID = columnIDs[i]; + KU_ASSERT(columnID < chunks.size()); + chunks[columnID]->append(other[i], offsetInOtherNodeGroup, numToAppendInChunkedGroup); + } + } catch ([[maybe_unused]] std::exception& e) { + handleAppendException(chunks, numRows); + } + if (transaction->getID() != Transaction::DUMMY_TRANSACTION_ID) { + if (!versionInfo) { + versionInfo = std::make_unique(); + } + versionInfo->append(transaction->getID(), numRows, numToAppendInChunkedGroup); + } + numRows += numToAppendInChunkedGroup; + return numToAppendInChunkedGroup; +} + +offset_t ChunkedNodeGroup::append(const Transaction* transaction, + const std::vector& columnIDs, std::span other, + offset_t offsetInOtherNodeGroup, offset_t numRowsToAppend) { + KU_ASSERT(residencyState == ResidencyState::IN_MEMORY); + KU_ASSERT(other.size() == columnIDs.size()); + const auto numToAppendInChunkedGroup = std::min(numRowsToAppend, capacity - numRows); + try { + for (auto i = 0u; i < columnIDs.size(); i++) { + auto columnID = columnIDs[i]; + KU_ASSERT(columnID < chunks.size()); + chunks[columnID]->append(other[i], offsetInOtherNodeGroup, numToAppendInChunkedGroup); + } + } catch ([[maybe_unused]] std::exception& e) { + handleAppendException(chunks, numRows); + } + if (transaction->shouldAppendToUndoBuffer()) { + if (!versionInfo) { + versionInfo = std::make_unique(); + } + versionInfo->append(transaction->getID(), numRows, numToAppendInChunkedGroup); + } + numRows += numToAppendInChunkedGroup; + return numToAppendInChunkedGroup; +} + +void InMemChunkedNodeGroup::write(const InMemChunkedNodeGroup& data, column_id_t offsetColumnID) { + KU_ASSERT(data.chunks.size() == chunks.size() + 1); + auto& offsetChunk = 
data.chunks[offsetColumnID]; + column_id_t columnID = 0, chunkIdx = 0; + for (auto i = 0u; i < data.chunks.size(); i++) { + if (i == offsetColumnID) { + columnID++; + continue; + } + KU_ASSERT(columnID < data.chunks.size()); + writeToColumnChunk(chunkIdx, columnID, data.chunks, *offsetChunk); + chunkIdx++; + columnID++; + } + numRows = chunks[0]->getNumValues(); + for (auto i = 1u; i < chunks.size(); i++) { + KU_ASSERT(numRows == chunks[i]->getNumValues()); + } +} + +static ZoneMapCheckResult getZoneMapResult(const TableScanState& scanState, + const std::vector>& chunks) { + if (!scanState.columnPredicateSets.empty()) { + for (auto i = 0u; i < scanState.columnIDs.size(); i++) { + const auto columnID = scanState.columnIDs[i]; + if (columnID == INVALID_COLUMN_ID || columnID == ROW_IDX_COLUMN_ID) { + continue; + } + + KU_ASSERT(i < scanState.columnPredicateSets.size()); + if (chunks[columnID]->hasUpdates()) { + // With updates, we need to merge with update data for the correct stats, which can + // be slow if there are lots of updates. We defer this for now. + return ZoneMapCheckResult::ALWAYS_SCAN; + } + const auto columnZoneMapResult = scanState.columnPredicateSets[i].checkZoneMap( + chunks[columnID]->getMergedColumnChunkStats()); + if (columnZoneMapResult == ZoneMapCheckResult::SKIP_SCAN) { + return ZoneMapCheckResult::SKIP_SCAN; + } + } + } + return ZoneMapCheckResult::ALWAYS_SCAN; +} + +void ChunkedNodeGroup::scan(const Transaction* transaction, const TableScanState& scanState, + const NodeGroupScanState& nodeGroupScanState, offset_t rowIdxInGroup, + length_t numRowsToScan) const { + KU_ASSERT(rowIdxInGroup + numRowsToScan <= numRows); + auto& anchorSelVector = scanState.outState->getSelVectorUnsafe(); + if (getZoneMapResult(scanState, chunks) == ZoneMapCheckResult::SKIP_SCAN) { + anchorSelVector.setToFiltered(0); + return; + } + + if (versionInfo) { + versionInfo->getSelVectorToScan(transaction->getStartTS(), transaction->getID(), + anchorSelVector, rowIdxInGroup, numRowsToScan); + } else { + anchorSelVector.setToUnfiltered(numRowsToScan); + } + + if (anchorSelVector.getSelSize() > 0) { + for (auto i = 0u; i < scanState.columnIDs.size(); i++) { + const auto columnID = scanState.columnIDs[i]; + if (columnID == INVALID_COLUMN_ID) { + scanState.outputVectors[i]->setAllNull(); + continue; + } + if (columnID == ROW_IDX_COLUMN_ID) { + for (auto rowIdx = 0u; rowIdx < numRowsToScan; rowIdx++) { + scanState.rowIdxVector->setValue(rowIdx, + rowIdx + rowIdxInGroup + startRowIdx); + } + continue; + } + KU_ASSERT(columnID < chunks.size()); + chunks[columnID]->scan(transaction, nodeGroupScanState.chunkStates[i], + *scanState.outputVectors[i], rowIdxInGroup, numRowsToScan); + } + } +} + +template +void ChunkedNodeGroup::scanCommitted(Transaction* transaction, TableScanState& scanState, + InMemChunkedNodeGroup& output) const { + if (residencyState != SCAN_RESIDENCY_STATE) { + return; + } + for (auto i = 0u; i < scanState.columnIDs.size(); i++) { + const auto columnID = scanState.columnIDs[i]; + chunks[columnID]->scanCommitted(transaction, + scanState.nodeGroupScanState->chunkStates[i], output.getColumnChunk(i)); + } +} + +template void ChunkedNodeGroup::scanCommitted(Transaction* transaction, + TableScanState& scanState, InMemChunkedNodeGroup& output) const; +template void ChunkedNodeGroup::scanCommitted(Transaction* transaction, + TableScanState& scanState, InMemChunkedNodeGroup& output) const; + +bool ChunkedNodeGroup::hasDeletions(const Transaction* transaction) const { + return versionInfo && 
versionInfo->hasDeletions(transaction); +} + +row_idx_t ChunkedNodeGroup::getNumUpdatedRows(const Transaction* transaction, + column_id_t columnID) { + return getColumnChunk(columnID).getNumUpdatedRows(transaction); +} + +bool ChunkedNodeGroup::lookup(const Transaction* transaction, const TableScanState& state, + const NodeGroupScanState& nodeGroupScanState, offset_t rowIdxInChunk, sel_t posInOutput) const { + KU_ASSERT(rowIdxInChunk + 1 <= numRows); + const bool hasValuesToRead = versionInfo ? versionInfo->isSelected(transaction->getStartTS(), + transaction->getID(), rowIdxInChunk) : + true; + if (!hasValuesToRead) { + return false; + } + for (auto i = 0u; i < state.columnIDs.size(); i++) { + const auto columnID = state.columnIDs[i]; + if (columnID == INVALID_COLUMN_ID) { + state.outputVectors[i]->setAllNull(); + continue; + } + if (columnID == ROW_IDX_COLUMN_ID) { + state.rowIdxVector->setValue( + state.rowIdxVector->state->getSelVector()[posInOutput], + rowIdxInChunk + startRowIdx); + continue; + } + KU_ASSERT(columnID < chunks.size()); + KU_ASSERT(i < nodeGroupScanState.chunkStates.size()); + chunks[columnID]->lookup(transaction, nodeGroupScanState.chunkStates[i], rowIdxInChunk, + *state.outputVectors[i], state.outputVectors[i]->state->getSelVector()[posInOutput]); + } + return true; +} + +void ChunkedNodeGroup::update(const Transaction* transaction, row_idx_t rowIdxInChunk, + column_id_t columnID, const ValueVector& propertyVector) { + getColumnChunk(columnID).update(transaction, rowIdxInChunk, propertyVector); +} + +bool ChunkedNodeGroup::delete_(const Transaction* transaction, row_idx_t rowIdxInChunk) { + if (!versionInfo) { + versionInfo = std::make_unique(); + } + return versionInfo->delete_(transaction->getID(), rowIdxInChunk); +} + +void ChunkedNodeGroup::addColumn(MemoryManager& mm, const TableAddColumnState& addColumnState, + bool enableCompression, PageAllocator* pageAllocator, ColumnStats* newColumnStats) { + auto& dataType = addColumnState.propertyDefinition.getType(); + chunks.push_back(std::make_unique(mm, dataType.copy(), capacity, enableCompression, + ResidencyState::IN_MEMORY)); + auto numExistingRows = getNumRows(); + chunks.back()->populateWithDefaultVal(addColumnState.defaultEvaluator, numExistingRows, + newColumnStats); + if (residencyState == ResidencyState::ON_DISK) { + KU_ASSERT(pageAllocator); + chunks.back()->flush(*pageAllocator); + } +} + +bool ChunkedNodeGroup::isDeleted(const Transaction* transaction, row_idx_t rowInChunk) const { + if (!versionInfo) { + return false; + } + return versionInfo->isDeleted(transaction, rowInChunk); +} + +bool ChunkedNodeGroup::isInserted(const Transaction* transaction, row_idx_t rowInChunk) const { + if (!versionInfo) { + return rowInChunk < getNumRows(); + } + return versionInfo->isInserted(transaction, rowInChunk); +} + +bool ChunkedNodeGroup::hasAnyUpdates(const Transaction* transaction, column_id_t columnID, + row_idx_t startRow, length_t numRowsToCheck) const { + return getColumnChunk(columnID).hasUpdates(transaction, startRow, numRowsToCheck); +} + +row_idx_t ChunkedNodeGroup::getNumDeletions(const Transaction* transaction, row_idx_t startRow, + length_t numRowsToCheck) const { + if (versionInfo) { + return versionInfo->getNumDeletions(transaction, startRow, numRowsToCheck); + } + return 0; +} + +std::unique_ptr InMemChunkedNodeGroup::flushInternal(ColumnChunkData& chunk, + PageAllocator& pageAllocator) { + // Finalize is necessary prior to splitting for strings and lists so that pruned values + // don't have an impact 
on the number/size of segments It should not be necessary after + // splitting since the function is used to prune unused values (or duplicated dictionary + // entries in the case of strings) and those will never be introduced when splitting. + chunk.finalize(); + if (chunk.shouldSplit()) { + auto splitSegments = chunk.split(true /*new segments are always the max size if possible*/); + std::vector> flushedSegments; + flushedSegments.reserve(splitSegments.size()); + for (auto& segment : splitSegments) { + // TODO(bmwinger): This should be removed when splitting works predictively instead of + // backtracking if we copy too many values + // It's only needed to prune values from string/list chunks which were truncated + segment->finalize(); + flushedSegments.push_back(Column::flushChunkData(*segment, pageAllocator)); + } + return std::make_unique(chunk.isCompressionEnabled(), + std::move(flushedSegments)); + } else { + return std::make_unique(chunk.isCompressionEnabled(), + Column::flushChunkData(chunk, pageAllocator)); + } +} + +std::unique_ptr InMemChunkedNodeGroup::flush(Transaction* transaction, + PageAllocator& pageAllocator) { + std::vector> flushedChunks(getNumColumns()); + for (auto i = 0u; i < getNumColumns(); i++) { + flushedChunks[i] = flushInternal(getColumnChunk(i), pageAllocator); + } + auto flushedChunkedGroup = + std::make_unique(std::move(flushedChunks), 0 /*startRowIdx*/); + flushedChunkedGroup->versionInfo = std::make_unique(); + KU_ASSERT(flushedChunkedGroup->getNumRows() == numRows); + flushedChunkedGroup->versionInfo->append(transaction->getID(), 0, numRows); + return flushedChunkedGroup; +} + +std::unique_ptr ChunkedNodeGroup::flushEmpty(MemoryManager& mm, + const std::vector& columnTypes, bool enableCompression, uint64_t capacity, + common::row_idx_t startRowIdx, PageAllocator& pageAllocator) { + auto emptyGroup = std::make_unique(mm, columnTypes, enableCompression, + capacity, startRowIdx, ResidencyState::IN_MEMORY); + for (auto i = 0u; i < columnTypes.size(); i++) { + emptyGroup->getColumnChunk(i).flush(pageAllocator); + } + // Reset residencyState and numRows after flushing. + emptyGroup->residencyState = ResidencyState::ON_DISK; + return emptyGroup; +} + +uint64_t ChunkedNodeGroup::getEstimatedMemoryUsage() const { + if (residencyState == ResidencyState::ON_DISK) { + return 0; + } + uint64_t memoryUsage = 0; + for (const auto& chunk : chunks) { + memoryUsage += chunk->getEstimatedMemoryUsage(); + } + return memoryUsage; +} + +bool ChunkedNodeGroup::hasUpdates() const { + for (const auto& chunk : chunks) { + if (chunk->hasUpdates()) { + return true; + } + } + return false; +} + +// NOLINTNEXTLINE(readability-make-member-function-const): Semantically non-const. +void ChunkedNodeGroup::commitInsert(row_idx_t startRow, row_idx_t numRowsToCommit, + transaction_t commitTS) { + versionInfo->commitInsert(startRow, numRowsToCommit, commitTS); +} + +void ChunkedNodeGroup::rollbackInsert(row_idx_t startRow, row_idx_t numRows_, transaction_t) { + if (startRow == 0) { + truncate(0); + versionInfo.reset(); + return; + } + if (startRow >= numRows) { + // Nothing to rollback. + return; + } + versionInfo->rollbackInsert(startRow, numRows_); + numRows = startRow; +} + +// NOLINTNEXTLINE(readability-make-member-function-const): Semantically non-const. 
+void ChunkedNodeGroup::commitDelete(row_idx_t startRow, row_idx_t numRows_, + transaction_t commitTS) { + versionInfo->commitDelete(startRow, numRows_, commitTS); +} + +// NOLINTNEXTLINE(readability-make-member-function-const): Semantically non-const. +void ChunkedNodeGroup::rollbackDelete(row_idx_t startRow, row_idx_t numRows_, transaction_t) { + versionInfo->rollbackDelete(startRow, numRows_); +} + +void ChunkedNodeGroup::reclaimStorage(PageAllocator& pageAllocator) const { + for (auto& columnChunk : chunks) { + if (columnChunk) { + columnChunk->reclaimStorage(pageAllocator); + } + } +} + +void ChunkedNodeGroup::serialize(Serializer& serializer) const { + KU_ASSERT(residencyState == ResidencyState::ON_DISK); + serializer.writeDebuggingInfo("chunks"); + serializer.serializeVectorOfPtrs(chunks); + serializer.writeDebuggingInfo("startRowIdx"); + serializer.write(startRowIdx); + serializer.writeDebuggingInfo("has_version_info"); + serializer.write(versionInfo != nullptr); + if (versionInfo) { + serializer.writeDebuggingInfo("version_info"); + versionInfo->serialize(serializer); + } +} + +std::unique_ptr ChunkedNodeGroup::deserialize(MemoryManager& memoryManager, + Deserializer& deSer) { + std::string key; + std::vector> chunks; + bool hasVersions = false; + row_idx_t startRowIdx = 0; + deSer.validateDebuggingInfo(key, "chunks"); + deSer.deserializeVectorOfPtrs(chunks, + [&](Deserializer& deser) { return ColumnChunk::deserialize(memoryManager, deser); }); + deSer.validateDebuggingInfo(key, "startRowIdx"); + deSer.deserializeValue(startRowIdx); + auto chunkedGroup = std::make_unique(std::move(chunks), startRowIdx); + deSer.validateDebuggingInfo(key, "has_version_info"); + deSer.deserializeValue(hasVersions); + if (hasVersions) { + deSer.validateDebuggingInfo(key, "version_info"); + chunkedGroup->versionInfo = VersionInfo::deserialize(deSer); + } + return chunkedGroup; +} + +InMemChunkedNodeGroup::InMemChunkedNodeGroup(MemoryManager& mm, + const std::vector& columnTypes, bool enableCompression, uint64_t capacity, + common::row_idx_t startRowIdx) + : startRowIdx{startRowIdx}, numRows{0}, capacity{capacity}, dataInUse{true} { + chunks.reserve(columnTypes.size()); + for (auto& type : columnTypes) { + chunks.push_back(ColumnChunkFactory::createColumnChunkData(mm, type.copy(), + enableCompression, capacity, ResidencyState::IN_MEMORY)); + } +} + +InMemChunkedNodeGroup::InMemChunkedNodeGroup(std::vector>&& chunks, + row_idx_t startRowIdx) + : startRowIdx{startRowIdx}, numRows{chunks[0]->getNumValues()}, capacity{numRows}, + chunks{std::move(chunks)}, dataInUse{true} { + KU_ASSERT(!this->chunks.empty()); + for (auto columnID = 1u; columnID < this->chunks.size(); columnID++) { + KU_ASSERT(this->chunks[columnID]->getNumValues() == numRows); + } +} + +void InMemChunkedNodeGroup::setUnused(const MemoryManager& mm) { + dataInUse = false; + mm.getBufferManager()->getSpillerOrSkip([&](auto& spiller) { spiller.addUnusedChunk(this); }); +} + +void InMemChunkedNodeGroup::loadFromDisk(const MemoryManager& mm) { + mm.getBufferManager()->getSpillerOrSkip([&](auto& spiller) { + std::unique_lock lock{spillToDiskMutex}; + // Prevent buffer manager from being able to spill this chunk to disk + spiller.clearUnusedChunk(this); + for (auto& chunk : chunks) { + chunk->loadFromDisk(); + } + dataInUse = true; + }); +} + +SpillResult InMemChunkedNodeGroup::spillToDisk() { + uint64_t reclaimedSpace = 0; + uint64_t nowEvictableMemory = 0; + std::unique_lock lock{spillToDiskMutex}; + // Its possible that the chunk may be loaded 
and marked as in-use between when it is selected to + // be spilled to disk and actually spilled + if (!dataInUse) { + // These are groups from the partitioner which specifically are internalID columns and thus + // don't have a null column or any other sort of child column. That being said, it may be a + // good idea to make the interface more generic, which would open up the possibility of + // spilling to disk during node table copies too. + for (size_t i = 0; i < getNumColumns(); i++) { + auto [reclaimed, nowEvictable] = getColumnChunk(i).spillToDisk(); + reclaimedSpace += reclaimed; + nowEvictableMemory += nowEvictable; + } + } + return SpillResult{reclaimedSpace, nowEvictableMemory}; +} + +void InMemChunkedNodeGroup::resetToEmpty() { + numRows = 0; + for (const auto& chunk : chunks) { + chunk->resetToEmpty(); + } +} + +void InMemChunkedNodeGroup::resetToAllNull() const { + for (const auto& chunk : chunks) { + chunk->resetToAllNull(); + } +} + +void InMemChunkedNodeGroup::resizeChunks(const uint64_t newSize) { + if (newSize <= capacity) { + return; + } + for (auto& chunk : chunks) { + chunk->resize(newSize); + } + capacity = newSize; +} + +uint64_t InMemChunkedNodeGroup::append(const std::vector& columnVectors, + row_idx_t startRowInVectors, uint64_t numValuesToAppend) { + KU_ASSERT(columnVectors.size() == chunks.size()); + const auto numRowsToAppendInChunk = std::min(numValuesToAppend, capacity - numRows); + try { + for (auto i = 0u; i < columnVectors.size(); i++) { + const auto columnVector = columnVectors[i]; + chunks[i]->append(columnVector, columnVector->state->getSelVector().slice( + startRowInVectors, numRowsToAppendInChunk)); + } + } catch ([[maybe_unused]] std::exception& e) { + handleAppendException(chunks, numRows); + } + numRows += numRowsToAppendInChunk; + return numRowsToAppendInChunk; +} + +offset_t InMemChunkedNodeGroup::append(const InMemChunkedNodeGroup& other, + offset_t offsetInOtherNodeGroup, offset_t numRowsToAppend) { + KU_ASSERT(other.chunks.size() == chunks.size()); + const auto numToAppendInChunkedGroup = std::min(numRowsToAppend, capacity - numRows); + try { + for (auto i = 0u; i < other.getNumColumns(); i++) { + chunks[i]->append(other.chunks[i].get(), offsetInOtherNodeGroup, + numToAppendInChunkedGroup); + } + } catch ([[maybe_unused]] std::exception& e) { + handleAppendException(chunks, numRows); + } + numRows += numToAppendInChunkedGroup; + return numToAppendInChunkedGroup; +} + +void InMemChunkedNodeGroup::merge(InMemChunkedNodeGroup& base, + const std::vector& columnsToMergeInto) { + KU_ASSERT(base.getNumColumns() == columnsToMergeInto.size()); + for (idx_t i = 0; i < base.getNumColumns(); ++i) { + KU_ASSERT(columnsToMergeInto[i] < chunks.size()); + chunks[columnsToMergeInto[i]] = base.moveColumnChunk(i); + } +} + +InMemChunkedNodeGroup::InMemChunkedNodeGroup(InMemChunkedNodeGroup& base, + const std::vector& selectedColumns) + : startRowIdx{base.getStartRowIdx()}, numRows{base.getNumRows()}, capacity{base.getCapacity()}, + dataInUse{true} { + chunks.resize(selectedColumns.size()); + for (auto i = 0u; i < selectedColumns.size(); i++) { + auto columnID = selectedColumns[i]; + KU_ASSERT(columnID < base.getNumColumns()); + chunks[i] = base.moveColumnChunk(columnID); + } +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/column.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/column.cpp new file mode 100644 index 0000000000..96496598bc --- /dev/null +++ 
b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/column.cpp @@ -0,0 +1,580 @@ +#include "storage/table/column.h" + +#include +#include +#include + +#include "common/assert.h" +#include "common/data_chunk/sel_vector.h" +#include "common/null_mask.h" +#include "common/system_config.h" +#include "common/types/types.h" +#include "common/vector/value_vector.h" +#include "storage/buffer_manager/memory_manager.h" +#include "storage/compression/compression.h" +#include "storage/file_handle.h" +#include "storage/page_allocator.h" +#include "storage/page_manager.h" +#include "storage/storage_utils.h" +#include "storage/table/column_chunk.h" +#include "storage/table/column_chunk_data.h" +#include "storage/table/list_column.h" +#include "storage/table/null_column.h" +#include "storage/table/string_column.h" +#include "storage/table/struct_column.h" +#include + +using namespace lbug::catalog; +using namespace lbug::common; +using namespace lbug::evaluator; + +namespace lbug { +namespace storage { + +struct ReadInternalIDValuesToVector { + ReadInternalIDValuesToVector() : compressedReader{LogicalType(LogicalTypeID::INTERNAL_ID)} {} + void operator()(const uint8_t* frame, PageCursor& pageCursor, ValueVector* resultVector, + uint32_t posInVector, uint32_t numValuesToRead, const CompressionMetadata& metadata) { + KU_ASSERT(resultVector->dataType.getPhysicalType() == PhysicalTypeID::INTERNAL_ID); + + KU_ASSERT(numValuesToRead <= DEFAULT_VECTOR_CAPACITY); + offset_t offsetBuffer[DEFAULT_VECTOR_CAPACITY]; + + compressedReader(frame, pageCursor, reinterpret_cast(offsetBuffer), 0, + numValuesToRead, metadata); + auto resultData = reinterpret_cast(resultVector->getData()); + for (auto i = 0u; i < numValuesToRead; i++) { + resultData[posInVector + i].offset = offsetBuffer[i]; + } + } + +private: + ReadCompressedValuesFromPage compressedReader; +}; + +struct WriteInternalIDValuesToPage { + WriteInternalIDValuesToPage() : compressedWriter{LogicalType(LogicalTypeID::INTERNAL_ID)} {} + void operator()(uint8_t* frame, uint16_t posInFrame, const uint8_t* data, uint32_t dataOffset, + offset_t numValues, const CompressionMetadata& metadata, const NullMask* nullMask) { + compressedWriter(frame, posInFrame, data, dataOffset, numValues, metadata, nullMask); + } + void operator()(uint8_t* frame, uint16_t posInFrame, ValueVector* vector, + uint32_t offsetInVector, offset_t numValues, const CompressionMetadata& metadata) { + KU_ASSERT(vector->dataType.getPhysicalType() == PhysicalTypeID::INTERNAL_ID); + compressedWriter(frame, posInFrame, + reinterpret_cast( + &vector->getValue(offsetInVector).offset), + 0 /*dataOffset*/, numValues, metadata); + } + +private: + WriteCompressedValuesToPage compressedWriter; +}; + +static read_values_to_vector_func_t getReadValuesToVectorFunc(const LogicalType& logicalType) { + switch (logicalType.getLogicalTypeID()) { + case LogicalTypeID::INTERNAL_ID: + return ReadInternalIDValuesToVector(); + default: + return ReadCompressedValuesFromPageToVector(logicalType); + } +} + +static write_values_func_t getWriteValuesFunc(const LogicalType& logicalType) { + switch (logicalType.getLogicalTypeID()) { + case LogicalTypeID::INTERNAL_ID: + return WriteInternalIDValuesToPage(); + default: + return WriteCompressedValuesToPage(logicalType); + } +} + +InternalIDColumn::InternalIDColumn(std::string name, FileHandle* dataFH, MemoryManager* mm, + ShadowFile* shadowFile, bool enableCompression) + : Column{std::move(name), LogicalType::INTERNAL_ID(), dataFH, mm, shadowFile, enableCompression, + false 
/*requireNullColumn*/}, + commonTableID{INVALID_TABLE_ID} {} + +void InternalIDColumn::populateCommonTableID(const ValueVector* resultVector) const { + auto nodeIDs = reinterpret_cast(resultVector->getData()); + auto& selVector = resultVector->state->getSelVector(); + for (auto i = 0u; i < selVector.getSelSize(); i++) { + const auto pos = selVector[i]; + nodeIDs[pos].tableID = commonTableID; + } +} + +Column::Column(std::string name, LogicalType dataType, FileHandle* dataFH, MemoryManager* mm, + ShadowFile* shadowFile, bool enableCompression, bool requireNullColumn) + : name{std::move(name)}, dataType{std::move(dataType)}, mm{mm}, dataFH(dataFH), + shadowFile(shadowFile), enableCompression{enableCompression}, + columnReadWriter(ColumnReadWriterFactory::createColumnReadWriter( + this->dataType.getPhysicalType(), dataFH, shadowFile)) { + readToVectorFunc = getReadValuesToVectorFunc(this->dataType); + readToPageFunc = ReadCompressedValuesFromPage(this->dataType); + writeFunc = getWriteValuesFunc(this->dataType); + if (requireNullColumn) { + auto columnName = + StorageUtils::getColumnName(this->name, StorageUtils::ColumnType::NULL_MASK, ""); + nullColumn = + std::make_unique(columnName, dataFH, mm, shadowFile, enableCompression); + } +} + +Column::Column(std::string name, PhysicalTypeID physicalType, FileHandle* dataFH, MemoryManager* mm, + ShadowFile* shadowFile, bool enableCompression, bool requireNullColumn) + : Column(std::move(name), LogicalType::ANY(physicalType), dataFH, mm, shadowFile, + enableCompression, requireNullColumn) {} + +Column::~Column() = default; + +Column* Column::getNullColumn() const { + return nullColumn.get(); +} + +void Column::populateExtraChunkState(SegmentState& state) const { + if (state.metadata.compMeta.compression == CompressionType::ALP) { + if (dataType.getPhysicalType() == PhysicalTypeID::DOUBLE) { + state.alpExceptionChunk = + std::make_unique>(state, dataFH, mm, shadowFile); + } else if (dataType.getPhysicalType() == PhysicalTypeID::FLOAT) { + state.alpExceptionChunk = + std::make_unique>(state, dataFH, mm, shadowFile); + } + } +} + +std::unique_ptr Column::flushChunkData(const ColumnChunkData& chunkData, + PageAllocator& pageAllocator) { + switch (chunkData.getDataType().getPhysicalType()) { + case PhysicalTypeID::STRUCT: { + return StructColumn::flushChunkData(chunkData, pageAllocator); + } + case PhysicalTypeID::STRING: { + return StringColumn::flushChunkData(chunkData, pageAllocator); + } + case PhysicalTypeID::ARRAY: + case PhysicalTypeID::LIST: { + return ListColumn::flushChunkData(chunkData, pageAllocator); + } + default: { + return flushNonNestedChunkData(chunkData, pageAllocator); + } + } +} + +std::unique_ptr Column::flushNonNestedChunkData(const ColumnChunkData& chunkData, + PageAllocator& pageAllocator) { + auto chunkMeta = flushData(chunkData, pageAllocator); + auto flushedChunk = ColumnChunkFactory::createColumnChunkData(chunkData.getMemoryManager(), + chunkData.getDataType().copy(), chunkData.isCompressionEnabled(), chunkMeta, + chunkData.hasNullData(), true); + if (chunkData.hasNullData()) { + auto nullChunkMeta = flushData(*chunkData.getNullData(), pageAllocator); + auto nullData = std::make_unique(chunkData.getMemoryManager(), + chunkData.isCompressionEnabled(), nullChunkMeta); + flushedChunk->setNullData(std::move(nullData)); + } + return flushedChunk; +} + +ColumnChunkMetadata Column::flushData(const ColumnChunkData& chunkData, + PageAllocator& pageAllocator) { + KU_ASSERT(chunkData.sanityCheck()); + const auto preScanMetadata = 
chunkData.getMetadataToFlush(); + auto allocatedBlock = pageAllocator.allocatePageRange(preScanMetadata.getNumPages()); + return chunkData.flushBuffer(pageAllocator, allocatedBlock, preScanMetadata); +} + +void Column::scan(const ChunkState& state, offset_t startOffsetInChunk, offset_t length, + ValueVector* resultVector, uint64_t offsetInVector) const { + if (length == 0) { + return; + } + // Selection vector must be ordered, and values must be within the range of [0, length) + RUNTIME_CHECK(if (resultVector->state) { + sel_t prevValue = 0; + resultVector->state->getSelVector().forEach([&](auto i) { + KU_ASSERT(prevValue <= i); + KU_ASSERT(i < length); + prevValue = i; + }); + }); + + state.rangeSegments(startOffsetInChunk, length, + [&](auto& segmentState, auto startOffsetInSegment, auto lengthInSegment, auto dstOffset) { + scanSegment(segmentState, startOffsetInSegment, lengthInSegment, resultVector, + offsetInVector + dstOffset); + }); +} + +void Column::scanSegment(const SegmentState& state, offset_t startOffsetInSegment, + row_idx_t numValuesToScan, ValueVector* resultVector, offset_t offsetInVector) const { + if (numValuesToScan == 0) { + return; + } + KU_ASSERT(startOffsetInSegment + numValuesToScan <= state.metadata.numValues); + if (nullColumn) { + KU_ASSERT(state.nullState); + nullColumn->scanSegment(*state.nullState, startOffsetInSegment, numValuesToScan, + resultVector, offsetInVector); + } + if (getDataTypeSizeInChunk(dataType) == 0) { + return; + } + if (!resultVector->state || resultVector->state->getSelVector().isUnfiltered()) { + columnReadWriter->readCompressedValuesToVector(state, resultVector, offsetInVector, + startOffsetInSegment, numValuesToScan, readToVectorFunc); + } else { + struct Filterer { + explicit Filterer(const SelectionVector& selVector, offset_t offsetInVector) + : selVector(selVector), posInSelVector(0), offsetInVector{offsetInVector} {} + bool operator()(offset_t startIdx, offset_t endIdx) { + while (posInSelVector < selVector.getSelSize() && + (selVector[posInSelVector] < offsetInVector || + selVector[posInSelVector] - offsetInVector < startIdx)) { + posInSelVector++; + } + return (posInSelVector < selVector.getSelSize() && + isInRange(selVector[posInSelVector] - offsetInVector, startIdx, endIdx)); + } + + const SelectionVector& selVector; + offset_t posInSelVector; + offset_t offsetInVector; + }; + + columnReadWriter->readCompressedValuesToVector(state, resultVector, offsetInVector, + startOffsetInSegment, numValuesToScan, readToVectorFunc, + Filterer{resultVector->state->getSelVector(), offsetInVector}); + } +} + +void Column::scanSegment(const SegmentState& state, ColumnChunkData* outputChunk, + offset_t offsetInSegment, offset_t numValues) const { + if (numValues == 0) { + return; + } + KU_ASSERT(offsetInSegment + numValues <= state.metadata.numValues); + auto startLength = outputChunk->getNumValues(); + if (nullColumn) { + nullColumn->scanSegment(*state.nullState, outputChunk->getNullData(), offsetInSegment, + numValues); + } + + if (startLength + numValues > outputChunk->getCapacity()) { + outputChunk->resize(std::bit_ceil(startLength + numValues)); + } + + if (getDataTypeSizeInChunk(dataType) > 0) { + columnReadWriter->readCompressedValuesToPage(state, outputChunk->getData(), + outputChunk->getNumValues(), offsetInSegment, numValues, readToPageFunc); + } + outputChunk->setNumValues(startLength + numValues); +} + +void Column::scan(const ChunkState& state, ColumnChunkData* outputChunk, offset_t offsetInChunk, + offset_t numValues) const { + 
outputChunk->setNumValues(0); + [[maybe_unused]] uint64_t numValuesScanned = state.rangeSegments(offsetInChunk, numValues, + [&](auto& segmentState, auto startOffsetInSegment, auto lengthInSegment, auto) { + scanSegment(segmentState, outputChunk, startOffsetInSegment, lengthInSegment); + }); + KU_ASSERT(outputChunk->getNumValues() == numValuesScanned); +} + +void Column::scanSegment(const SegmentState& state, offset_t startOffsetInSegment, offset_t length, + uint8_t* result) const { + KU_ASSERT(startOffsetInSegment + length <= state.metadata.numValues); + columnReadWriter->readCompressedValuesToPage(state, result, 0, startOffsetInSegment, length, + readToPageFunc); +} + +void Column::lookupValue(const ChunkState& state, offset_t nodeOffset, ValueVector* resultVector, + uint32_t posInVector) const { + auto [segmentState, offsetInSegment] = state.findSegment(nodeOffset); + if (nullColumn) { + nullColumn->lookupInternal(*segmentState->nullState, offsetInSegment, resultVector, + posInVector); + } + if (!resultVector->isNull(posInVector)) { + lookupInternal(*segmentState, offsetInSegment, resultVector, posInVector); + } +} + +void Column::lookupInternal(const SegmentState& state, offset_t offsetInSegment, + ValueVector* resultVector, uint32_t posInVector) const { + columnReadWriter->readCompressedValueToVector(state, offsetInSegment, resultVector, posInVector, + readToVectorFunc); +} + +[[maybe_unused]] static bool sanityCheckForWrites(const ColumnChunkMetadata& metadata, + const LogicalType& dataType) { + if (metadata.compMeta.compression == CompressionType::ALP) { + return metadata.compMeta.children.size() != 0; + } + if (metadata.compMeta.compression == CompressionType::CONSTANT) { + return metadata.getNumDataPages(dataType.getPhysicalType()) == 0; + } + const auto numValuesPerPage = metadata.compMeta.numValues(LBUG_PAGE_SIZE, dataType); + if (numValuesPerPage == UINT64_MAX) { + return metadata.getNumDataPages(dataType.getPhysicalType()) == 0; + } + return std::ceil( + static_cast(metadata.numValues) / static_cast(numValuesPerPage)) <= + metadata.getNumDataPages(dataType.getPhysicalType()); +} + +void Column::updateStatistics(ColumnChunkMetadata& metadata, offset_t maxIndex, + const std::optional& min, const std::optional& max) const { + if (maxIndex >= metadata.numValues) { + metadata.numValues = maxIndex + 1; + KU_ASSERT(sanityCheckForWrites(metadata, dataType)); + } + // Either both or neither should be provided + KU_ASSERT((!min && !max) || (min && max)); + if (min && max) { + // If new values are outside of the existing min/max, update them + if (max->gt(metadata.compMeta.max, dataType.getPhysicalType())) { + metadata.compMeta.max = *max; + } else if (metadata.compMeta.min.gt(*min, dataType.getPhysicalType())) { + metadata.compMeta.min = *min; + } + } +} + +void Column::write(ColumnChunkData& persistentChunk, ChunkState& state, offset_t initialDstOffset, + const ColumnChunkData& data, offset_t srcOffset, length_t numValues) const { + state.rangeSegments(srcOffset, numValues, + [&](auto& segmentState, auto offsetInSegment, auto lengthInSegment, auto dstOffset) { + writeSegment(persistentChunk, segmentState, initialDstOffset + dstOffset, data, + offsetInSegment, lengthInSegment); + }); +} + +void Column::writeSegment(ColumnChunkData& persistentChunk, SegmentState& state, + offset_t dstOffsetInSegment, const ColumnChunkData& data, offset_t srcOffset, + offset_t numValues) const { + auto nullMask = data.getNullMask(); + columnReadWriter->writeValuesToPageFromBuffer(state, 
dstOffsetInSegment, data.getData(), + nullMask ? &*nullMask : nullptr, srcOffset, numValues, writeFunc); + + if (dataType.getPhysicalType() != common::PhysicalTypeID::ALP_EXCEPTION_DOUBLE && + dataType.getPhysicalType() != common::PhysicalTypeID::ALP_EXCEPTION_FLOAT) { + auto [minWritten, maxWritten] = + getMinMaxStorageValue(data, srcOffset, numValues, dataType.getPhysicalType()); + updateStatistics(persistentChunk.getMetadata(), dstOffsetInSegment + numValues - 1, + minWritten, maxWritten); + } +} + +// TODO: Do we need to adapt the offsets to this current node group? +void Column::writeValues(ChunkState& state, offset_t initialDstOffset, const uint8_t* data, + const NullMask* nullChunkData, offset_t srcOffset, offset_t numValues) const { + state.rangeSegments(srcOffset, numValues, + [&](auto& segmentState, auto offsetInSegment, auto lengthInSegment, auto dstOffset) { + writeValuesInternal(segmentState, initialDstOffset + dstOffset, data, nullChunkData, + offsetInSegment, lengthInSegment); + }); +} + +void Column::writeValuesInternal(SegmentState& state, common::offset_t dstOffsetInSegment, + const uint8_t* data, const common::NullMask* nullChunkData, common::offset_t srcOffset, + common::offset_t numValues) const { + columnReadWriter->writeValuesToPageFromBuffer(state, dstOffsetInSegment, data, nullChunkData, + srcOffset, numValues, writeFunc); +} + +// Append to the end of the chunk. +offset_t Column::appendValues(ColumnChunkData& persistentChunk, SegmentState& state, + const uint8_t* data, const NullMask* nullChunkData, offset_t numValues) const { + auto& metadata = persistentChunk.getMetadata(); + const auto startOffset = metadata.numValues; + writeValuesInternal(state, metadata.numValues, data, nullChunkData, 0 /*dataOffset*/, + numValues); + + auto [minWritten, maxWritten] = getMinMaxStorageValue(data, 0 /*offset*/, numValues, + dataType.getPhysicalType(), nullChunkData); + updateStatistics(metadata, startOffset + numValues - 1, minWritten, maxWritten); + return startOffset; +} + +bool Column::isEndOffsetOutOfPagesCapacity(const ColumnChunkMetadata& metadata, + offset_t endOffset) const { + if (metadata.compMeta.compression != CompressionType::CONSTANT && + (metadata.compMeta.numValues(LBUG_PAGE_SIZE, dataType) * + metadata.getNumDataPages(dataType.getPhysicalType())) <= endOffset) { + // Note that for constant compression, `metadata.numPages` will be equal to 0. + // Thus, this function will always return true. 
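// Reviewer note (illustrative, not part of this patch): as I read it, in-place capacity is
// estimated as numValuesPerPage * numDataPages, e.g. with 64 values per page and 3 data
// pages the check fires once endOffset reaches 192 (64 * 3 <= 192), forcing the
// out-of-place checkpoint path. Constant-compressed segments are exempted by the first
// conjunct precisely because their data-page count is 0, so 0 <= endOffset would otherwise
// always trigger an out-of-place checkpoint.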
+ return true; + } + return false; +} + +void Column::checkpointColumnChunkInPlace(SegmentState& state, + const ColumnCheckpointState& checkpointState, PageAllocator& pageAllocator) const { + for (auto& segmentCheckpointState : checkpointState.segmentCheckpointStates) { + KU_ASSERT(segmentCheckpointState.numRows > 0); + state.column->writeSegment(checkpointState.persistentData, state, + segmentCheckpointState.offsetInSegment, segmentCheckpointState.chunkData, + segmentCheckpointState.startRowInData, segmentCheckpointState.numRows); + } + // writeSegment doesn't update numValues, just the metadata + // TODO(bmwinger): either have all writes update numValues, or have writeSegment update it + // directly + checkpointState.persistentData.resetNumValuesFromMetadata(); + if (nullColumn) { + checkpointNullData(checkpointState, pageAllocator); + } +} + +void Column::checkpointNullData(const ColumnCheckpointState& checkpointState, + PageAllocator& pageAllocator) const { + std::vector nullSegmentCheckpointStates; + for (const auto& segmentCheckpointState : checkpointState.segmentCheckpointStates) { + KU_ASSERT(segmentCheckpointState.chunkData.hasNullData()); + nullSegmentCheckpointStates.emplace_back(*segmentCheckpointState.chunkData.getNullData(), + segmentCheckpointState.startRowInData, segmentCheckpointState.offsetInSegment, + segmentCheckpointState.numRows); + } + KU_ASSERT(checkpointState.persistentData.hasNullData()); + nullColumn->checkpointSegment( + ColumnCheckpointState(*checkpointState.persistentData.getNullData(), + std::move(nullSegmentCheckpointStates)), + pageAllocator, false); +} + +std::vector> Column::checkpointColumnChunkOutOfPlace( + const SegmentState& state, const ColumnCheckpointState& checkpointState, + PageAllocator& pageAllocator, bool canSplitSegment) const { + const auto numRows = std::max(checkpointState.endRowIdxToWrite, state.metadata.numValues); + checkpointState.persistentData.setToInMemory(); + checkpointState.persistentData.resize(numRows); + KU_ASSERT(checkpointState.persistentData.getNumValues() == 0); + scanSegment(state, &checkpointState.persistentData, 0, state.metadata.numValues); + state.reclaimAllocatedPages(pageAllocator); + // TODO(bmwinger): for simple compression types, we can predict whether or not we will need to + // split the segment and avoid having to re-write it multiple times + for (auto& segmentCheckpointState : checkpointState.segmentCheckpointStates) { + checkpointState.persistentData.write(&segmentCheckpointState.chunkData, + segmentCheckpointState.startRowInData, segmentCheckpointState.offsetInSegment, + segmentCheckpointState.numRows); + } + // Finalize is necessary prior to splitting for strings and lists so that pruned values don't + // have an impact on the number/size of segments It should not be necessary after splitting + // since the function is used to prune unused values (or duplicated dictionary entries in the + // case of strings) and those will never be introduced when splitting. 
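// Reviewer sketch (illustrative only, not part of this patch): the out-of-place path below
// mirrors InMemChunkedNodeGroup::flushInternal earlier in this patch -- finalize to prune
// unused values, then either split into max-sized segments and flush each one, or flush the
// rewritten chunk as a single segment:
//
//     persistentData.finalize();
//     if (canSplitSegment && persistentData.shouldSplit()) {
//         auto newSegments = persistentData.split();
//         for (auto& seg : newSegments)
//             seg->flush(pageAllocator);   // each new segment gets its own page range
//         return newSegments;              // caller splices them into the ColumnChunk
//     }
//     persistentData.flush(pageAllocator);
//     return {};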
+ checkpointState.persistentData.finalize(); + if (canSplitSegment && checkpointState.persistentData.shouldSplit()) { + auto newSegments = checkpointState.persistentData.split(); + for (auto& segment : newSegments) { + segment->flush(pageAllocator); + } + return newSegments; + } + checkpointState.persistentData.flush(pageAllocator); + return {}; +} + +bool Column::canCheckpointInPlace(const SegmentState& state, + const ColumnCheckpointState& checkpointState) const { + if (isEndOffsetOutOfPagesCapacity(checkpointState.persistentData.getMetadata(), + checkpointState.endRowIdxToWrite)) { + return false; + } + if (checkpointState.persistentData.getMetadata().compMeta.canAlwaysUpdateInPlace()) { + return true; + } + + InPlaceUpdateLocalState localUpdateState{}; + for (auto& segmentCheckpointState : checkpointState.segmentCheckpointStates) { + auto& chunkData = segmentCheckpointState.chunkData; + if (chunkData.getNumValues() != 0 && + !state.metadata.compMeta.canUpdateInPlace(chunkData.getData(), + segmentCheckpointState.startRowInData, segmentCheckpointState.numRows, + dataType.getPhysicalType(), localUpdateState, chunkData.getNullMask())) { + return false; + } + } + return true; +} + +std::vector> Column::checkpointSegment( + ColumnCheckpointState&& checkpointState, PageAllocator& pageAllocator, + bool canSplitSegment) const { + if (checkpointState.segmentCheckpointStates.empty()) { + return {}; + } + SegmentState chunkState; + checkpointState.persistentData.initializeScanState(chunkState, this); + if (canCheckpointInPlace(chunkState, checkpointState)) { + checkpointColumnChunkInPlace(chunkState, checkpointState, pageAllocator); + + if (chunkState.metadata.compMeta.compression == CompressionType::ALP) { + if (dataType.getPhysicalType() == PhysicalTypeID::DOUBLE) { + chunkState.getExceptionChunk()->finalizeAndFlushToDisk(chunkState); + } else if (dataType.getPhysicalType() == PhysicalTypeID::FLOAT) { + chunkState.getExceptionChunk()->finalizeAndFlushToDisk(chunkState); + } else { + KU_UNREACHABLE; + } + checkpointState.persistentData.getMetadata().compMeta.floatMetadata()->exceptionCount = + chunkState.metadata.compMeta.floatMetadata()->exceptionCount; + } + return {}; + } else { + return checkpointColumnChunkOutOfPlace(chunkState, checkpointState, pageAllocator, + canSplitSegment); + } +} + +std::unique_ptr ColumnFactory::createColumn(std::string name, PhysicalTypeID physicalType, + FileHandle* dataFH, MemoryManager* mm, ShadowFile* shadowFile, bool enableCompression) { + return std::make_unique(name, LogicalType::ANY(physicalType), dataFH, mm, shadowFile, + enableCompression); +} + +std::unique_ptr ColumnFactory::createColumn(std::string name, LogicalType dataType, + FileHandle* dataFH, MemoryManager* mm, ShadowFile* shadowFile, bool enableCompression) { + switch (dataType.getPhysicalType()) { + case PhysicalTypeID::BOOL: + case PhysicalTypeID::INT64: + case PhysicalTypeID::INT32: + case PhysicalTypeID::INT16: + case PhysicalTypeID::INT8: + case PhysicalTypeID::UINT64: + case PhysicalTypeID::UINT32: + case PhysicalTypeID::UINT16: + case PhysicalTypeID::UINT8: + case PhysicalTypeID::INT128: + case PhysicalTypeID::UINT128: + case PhysicalTypeID::DOUBLE: + case PhysicalTypeID::FLOAT: + case PhysicalTypeID::INTERVAL: { + return std::make_unique(name, std::move(dataType), dataFH, mm, shadowFile, + enableCompression); + } + case PhysicalTypeID::INTERNAL_ID: { + return std::make_unique(name, dataFH, mm, shadowFile, enableCompression); + } + case PhysicalTypeID::STRING: { + return 
std::make_unique(name, std::move(dataType), dataFH, mm, shadowFile, + enableCompression); + } + case PhysicalTypeID::ARRAY: + case PhysicalTypeID::LIST: { + return std::make_unique(name, std::move(dataType), dataFH, mm, shadowFile, + enableCompression); + } + case PhysicalTypeID::STRUCT: { + return std::make_unique(name, std::move(dataType), dataFH, mm, shadowFile, + enableCompression); + } + default: { + KU_UNREACHABLE; + } + } +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/column_chunk.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/column_chunk.cpp new file mode 100644 index 0000000000..06917c90dc --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/column_chunk.cpp @@ -0,0 +1,334 @@ +#include "storage/table/column_chunk.h" + +#include +#include + +#include "common/serializer/deserializer.h" +#include "common/serializer/serializer.h" +#include "common/vector/value_vector.h" +#include "main/client_context.h" +#include "storage/buffer_manager/memory_manager.h" +#include "storage/enums/residency_state.h" +#include "storage/page_allocator.h" +#include "storage/table/column.h" +#include "storage/table/column_chunk_data.h" +#include "storage/table/column_chunk_scanner.h" +#include "storage/table/combined_chunk_scanner.h" +#include "transaction/transaction.h" + +using namespace lbug::common; +using namespace lbug::transaction; + +namespace lbug { +namespace storage { + +void ChunkState::reclaimAllocatedPages(PageAllocator& pageAllocator) const { + for (auto& state : segmentStates) { + state.reclaimAllocatedPages(pageAllocator); + } +} + +std::pair ChunkState::findSegment( + common::offset_t offsetInChunk) const { + auto [iter, offsetInSegment] = genericFindSegment(std::span(segmentStates), offsetInChunk); + if (iter == std::span(segmentStates).end()) { + return std::make_pair(nullptr, 0); + } + return std::make_pair(&*iter, offsetInSegment); +} + +ColumnChunk::ColumnChunk(MemoryManager& mm, LogicalType&& dataType, uint64_t capacity, + bool enableCompression, ResidencyState residencyState, bool initializeToZero) + : enableCompression{enableCompression} { + data.push_back(ColumnChunkFactory::createColumnChunkData(mm, std::move(dataType), + enableCompression, capacity, residencyState, true, initializeToZero)); + KU_ASSERT(residencyState != ResidencyState::ON_DISK); +} + +ColumnChunk::ColumnChunk(MemoryManager& mm, LogicalType&& dataType, bool enableCompression, + ColumnChunkMetadata metadata) + : enableCompression{enableCompression} { + data.push_back(ColumnChunkFactory::createColumnChunkData(mm, std::move(dataType), + enableCompression, metadata, true, true)); +} + +ColumnChunk::ColumnChunk(bool enableCompression, std::unique_ptr data) + : enableCompression{enableCompression}, data{} { + this->data.push_back(std::move(data)); +} +ColumnChunk::ColumnChunk(bool enableCompression, + std::vector> segments) + : enableCompression{enableCompression}, data{std::move(segments)} {} + +void ColumnChunk::initializeScanState(ChunkState& state, const Column* column) const { + state.column = column; + state.segmentStates.resize(data.size()); + for (size_t i = 0; i < data.size(); i++) { + data[i]->initializeScanState(state.segmentStates[i], column); + } +} + +void ColumnChunk::scan(const Transaction* transaction, const ChunkState& state, ValueVector& output, + offset_t offsetInChunk, length_t length) const { + // Check if there is deletions or insertions. If so, update selVector based on transaction. 
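// Reviewer note (illustrative, not part of this patch): as I understand it, the scan is a
// two-step overlay -- the switch below reads the committed values (straight from the
// in-memory segments, or via Column::scan for on-disk pages), and updateInfo.scan(...)
// afterwards patches in whatever updates are visible to this transaction, including its own
// pending ones.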
+ switch (getResidencyState()) { + case ResidencyState::IN_MEMORY: { + rangeSegments(offsetInChunk, length, + [&](auto& segment, auto offsetInSegment, auto lengthInSegment, auto dstOffset) { + segment->scan(output, offsetInSegment, lengthInSegment, dstOffset); + }); + } break; + case ResidencyState::ON_DISK: { + state.column->scan(state, offsetInChunk, length, &output, 0); + } break; + default: { + KU_UNREACHABLE; + } + } + updateInfo.scan(transaction, output, offsetInChunk, length); +} + +static void scanPersistentSegments(ChunkState& chunkState, ColumnChunkScanner& output, + common::offset_t startRow, common::offset_t numRows) { + KU_ASSERT(output.getNumValues() == 0); + [[maybe_unused]] uint64_t numValuesScanned = chunkState.rangeSegments(startRow, numRows, + [&](auto& segmentState, auto offsetInSegment, auto lengthInSegment, auto) { + output.scanSegment(offsetInSegment, lengthInSegment, + [&chunkState, &segmentState](ColumnChunkData& outputChunk, offset_t offsetInSegment, + offset_t lengthInSegment) { + chunkState.column->scanSegment(segmentState, &outputChunk, offsetInSegment, + lengthInSegment); + }); + }); + KU_ASSERT(output.getNumValues() == numValuesScanned); +} + +void ColumnChunk::scanInMemSegments(ColumnChunkScanner& output, common::offset_t startRow, + common::offset_t numRows) const { + rangeSegments(startRow, numRows, + [&](auto& segment, auto offsetInSegment, auto lengthInSegment, auto) { + output.scanSegment(offsetInSegment, lengthInSegment, + [&segment](ColumnChunkData& outputChunk, offset_t offsetInSegment, + offset_t lengthInSegment) { + outputChunk.append(segment.get(), offsetInSegment, lengthInSegment); + }); + }); +} + +template +void ColumnChunk::scanCommitted(const Transaction* transaction, ChunkState& chunkState, + ColumnChunkScanner& output, row_idx_t startRow, row_idx_t numRows) const { + auto numValuesInChunk = getNumValues(); + if (numRows == INVALID_ROW_IDX || startRow + numRows > numValuesInChunk) { + numRows = numValuesInChunk - startRow; + } + if (numRows == 0 || startRow >= numValuesInChunk) { + return; + } + const auto residencyState = getResidencyState(); + if (SCAN_RESIDENCY_STATE == residencyState) { + if constexpr (SCAN_RESIDENCY_STATE == ResidencyState::ON_DISK) { + scanPersistentSegments(chunkState, output, startRow, numRows); + } else { + static_assert(SCAN_RESIDENCY_STATE == ResidencyState::IN_MEMORY); + scanInMemSegments(output, startRow, numRows); + } + output.applyCommittedUpdates(updateInfo, transaction, startRow, numRows); + } +} + +template void ColumnChunk::scanCommitted(const Transaction* transaction, + ChunkState& chunkState, ColumnChunkScanner& output, row_idx_t startRow, + row_idx_t numRows) const; +template void ColumnChunk::scanCommitted(const Transaction* transaction, + ChunkState& chunkState, ColumnChunkScanner& output, row_idx_t startRow, + row_idx_t numRows) const; + +template +void ColumnChunk::scanCommitted(const Transaction* transaction, ChunkState& chunkState, + ColumnChunkData& output, row_idx_t startRow, row_idx_t numRows) const { + CombinedChunkScanner scanner{output}; + scanCommitted(transaction, chunkState, scanner, startRow, numRows); +} + +template void ColumnChunk::scanCommitted(const Transaction* transaction, + ChunkState& chunkState, ColumnChunkData& output, row_idx_t startRow, row_idx_t numRows) const; +template void ColumnChunk::scanCommitted(const Transaction* transaction, + ChunkState& chunkState, ColumnChunkData& output, row_idx_t startRow, row_idx_t numRows) const; + +bool ColumnChunk::hasUpdates(const 
Transaction* transaction, row_idx_t startRow, + length_t numRows) const { + return updateInfo.hasUpdates(transaction, startRow, numRows); +} + +void ColumnChunk::lookup(const Transaction* transaction, const ChunkState& state, + offset_t rowInChunk, ValueVector& output, sel_t posInOutputVector) const { + switch (getResidencyState()) { + case ResidencyState::IN_MEMORY: { + rangeSegments(rowInChunk, 1, [&](auto& segment, auto offsetInSegment, auto, auto) { + segment->lookup(offsetInSegment, output, posInOutputVector); + }); + } break; + case ResidencyState::ON_DISK: { + state.column->lookupValue(state, rowInChunk, &output, posInOutputVector); + } break; + } + updateInfo.lookup(transaction, rowInChunk, output, posInOutputVector); +} + +void ColumnChunk::update(const Transaction* transaction, offset_t offsetInChunk, + const ValueVector& values) { + if (transaction->getType() == TransactionType::DUMMY) { + rangeSegments(offsetInChunk, 1, [&](auto& segment, auto offsetInSegment, auto, auto) { + segment->write(&values, values.state->getSelVector().getSelectedPositions()[0], + offsetInSegment); + }); + return; + } + + const auto vectorIdx = offsetInChunk / DEFAULT_VECTOR_CAPACITY; + const auto rowIdxInVector = offsetInChunk % DEFAULT_VECTOR_CAPACITY; + auto& vectorUpdateInfo = updateInfo.update(data.front()->getMemoryManager(), transaction, + vectorIdx, rowIdxInVector, values); + transaction->pushVectorUpdateInfo(updateInfo, vectorIdx, vectorUpdateInfo, + transaction->getID()); +} + +MergedColumnChunkStats ColumnChunk::getMergedColumnChunkStats() const { + KU_ASSERT(!updateInfo.isSet()); + auto baseStats = MergedColumnChunkStats{ColumnChunkStats{}, true, true}; + for (auto& segment : data) { + // TODO: Replace with a function that modifies the existing stats in-place? 
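// Reviewer note (illustrative pseudocode, not part of this patch; field names here are made
// up): merging just widens the chunk-level min/max with each segment's stats and tightens
// the null guarantees, conceptually:
//
//     merged.min = min(merged.min, seg.min);   // [3, 9] merged with [1, 7] -> [1, 9]
//     merged.max = max(merged.max, seg.max);
//     merged.guaranteedNoNull  &= seg.guaranteedNoNull;
//     merged.guaranteedAllNull &= seg.guaranteedAllNull;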
+ auto segmentStats = segment->getMergedColumnChunkStats(); + baseStats.merge(segmentStats, segment->getDataType().getPhysicalType()); + } + return baseStats; +} + +void ColumnChunk::serialize(Serializer& serializer) const { + serializer.writeDebuggingInfo("enable_compression"); + serializer.write(enableCompression); + serializer.write(data.size()); + for (auto& segment : data) { + segment->serialize(serializer); + } +} + +std::unique_ptr ColumnChunk::deserialize(MemoryManager& mm, Deserializer& deSer) { + std::string key; + bool enableCompression = false; + deSer.validateDebuggingInfo(key, "enable_compression"); + deSer.deserializeValue(enableCompression); + uint64_t numSegments = 0; + deSer.deserializeValue(numSegments); + std::vector> segments; + for (uint64_t i = 0; i < numSegments; i++) { + segments.push_back(ColumnChunkData::deserialize(mm, deSer)); + } + return std::make_unique(enableCompression, std::move(segments)); +} + +row_idx_t ColumnChunk::getNumUpdatedRows(const Transaction* transaction) const { + return updateInfo.getNumUpdatedRows(transaction); +} + +void ColumnChunk::reclaimStorage(PageAllocator& pageAllocator) const { + for (const auto& segment : data) { + segment->reclaimStorage(pageAllocator); + } +} + +void ColumnChunk::append(common::ValueVector* vector, const common::SelectionView& selView) { + data.back()->append(vector, selView); +} + +void ColumnChunk::append(const ColumnChunk* other, common::offset_t startPosInOtherChunk, + uint32_t numValuesToAppend) { + for (auto& otherSegment : other->data) { + if (numValuesToAppend == 0) { + return; + } + if (otherSegment->getNumValues() < startPosInOtherChunk) { + startPosInOtherChunk -= otherSegment->getNumValues(); + } else { + auto numValuesToAppendInSegment = + std::min(otherSegment->getNumValues(), uint64_t{numValuesToAppend}); + append(otherSegment.get(), startPosInOtherChunk, numValuesToAppendInSegment); + numValuesToAppend -= numValuesToAppendInSegment; + startPosInOtherChunk = 0; + } + } +} + +void ColumnChunk::append(const ColumnChunkData* other, common::offset_t startPosInOtherChunk, + uint32_t numValuesToAppend) { + data.back()->append(other, startPosInOtherChunk, numValuesToAppend); +} + +void ColumnChunk::write(Column& column, ChunkState& state, offset_t dstOffset, + const ColumnChunkData& dataToWrite, offset_t srcOffset, common::length_t numValues) { + auto segment = data.begin(); + auto offsetInSegment = dstOffset; + while (segment->get()->getNumValues() < offsetInSegment) { + offsetInSegment -= segment->get()->getNumValues(); + segment++; + } + while (numValues > 0) { + auto numValuesToWriteInSegment = + std::min(numValues, segment->get()->getNumValues()) - offsetInSegment; + column.write(*segment->get(), state, offsetInSegment, dataToWrite, srcOffset, + numValuesToWriteInSegment); + offsetInSegment = 0; + numValues -= numValuesToWriteInSegment; + srcOffset += numValuesToWriteInSegment; + } +} + +void ColumnChunk::checkpoint(Column& column, + std::vector&& chunkCheckpointStates, PageAllocator& pageAllocator) { + offset_t segmentStart = 0; + for (size_t i = 0; i < data.size(); i++) { + std::vector segmentCheckpointStates; + auto& segment = data[i]; + KU_ASSERT(segment->getResidencyState() == ResidencyState::ON_DISK); + for (auto& state : chunkCheckpointStates) { + const bool isLastSegment = (i == data.size() - 1); + if (state.startRow + state.numRows > segmentStart && + (isLastSegment || state.startRow < segmentStart + segment->getNumValues())) { + const auto startOffset = std::max(state.startRow, 
segmentStart); + // Generally, we only want to checkpoint the overlapping parts of the old segment + // and the new chunk. This is to prevent having duplicate segments. However, for the + // last old segment we allow extending it to account for any insertions we have made + // in the current checkpoint. + const auto endOffset = isLastSegment ? state.startRow + state.numRows : + std::min(state.startRow + state.numRows, + segmentStart + segment->getNumValues()); + + const auto startOffsetInSegment = startOffset - segmentStart; + const auto startRowInChunk = startOffset - state.startRow; + segmentCheckpointStates.push_back({*state.chunkData, startRowInChunk, + startOffsetInSegment, endOffset - startOffset}); + } + } + auto segmentEnd = segmentStart + segment->getNumValues(); + // If the segment was split during checkpointing we need to insert the new segments into the + // ColumnChunk + auto newSegments = column.checkpointSegment( + ColumnCheckpointState(*segment, std::move(segmentCheckpointStates)), pageAllocator); + if (!newSegments.empty()) { + auto oldSize = data.size(); + data.resize(data.size() - 1 + newSegments.size()); + std::move_backward(data.begin() + i, data.begin() + oldSize, data.end()); + for (size_t j = 0; j < newSegments.size(); j++) { + data[i + j] = std::move(newSegments[j]); + } + // We want to increment by a total of newSegments.size() but we increment i at the end + // of each loop body + i += newSegments.size() - 1; + } + segmentStart = segmentEnd; + } +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/column_chunk_data.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/column_chunk_data.cpp new file mode 100644 index 0000000000..ece8f3b970 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/column_chunk_data.cpp @@ -0,0 +1,1094 @@ +#include "storage/table/column_chunk_data.h" + +#include + +#include "common/data_chunk/sel_vector.h" +#include "common/exception/copy.h" +#include "common/null_mask.h" +#include "common/serializer/deserializer.h" +#include "common/serializer/serializer.h" +#include "common/system_config.h" +#include "common/type_utils.h" +#include "common/types/types.h" +#include "common/vector/value_vector.h" +#include "expression_evaluator/expression_evaluator.h" +#include "storage/buffer_manager/buffer_manager.h" +#include "storage/buffer_manager/memory_manager.h" +#include "storage/buffer_manager/spill_result.h" +#include "storage/buffer_manager/spiller.h" +#include "storage/compression/compression.h" +#include "storage/compression/float_compression.h" +#include "storage/enums/residency_state.h" +#include "storage/stats/column_stats.h" +#include "storage/table/column.h" +#include "storage/table/column_chunk_metadata.h" +#include "storage/table/compression_flush_buffer.h" +#include "storage/table/list_chunk_data.h" +#include "storage/table/string_chunk_data.h" +#include "storage/table/struct_chunk_data.h" + +using namespace lbug::common; +using namespace lbug::evaluator; +using namespace lbug::transaction; + +namespace lbug { +namespace storage { + +void SegmentState::reclaimAllocatedPages(PageAllocator& pageAllocator) const { + const auto& entry = metadata.pageRange; + if (entry.startPageIdx != INVALID_PAGE_IDX) { + pageAllocator.freePageRange(entry); + } + if (nullState) { + nullState->reclaimAllocatedPages(pageAllocator); + } + for (const auto& child : childrenStates) { + child.reclaimAllocatedPages(pageAllocator); + } +} + +static std::shared_ptr getCompression(const 
LogicalType& dataType, + bool enableCompression) { + if (!enableCompression) { + return std::make_shared(dataType); + } + switch (dataType.getPhysicalType()) { + case PhysicalTypeID::INT128: { + return std::make_shared>(); + } + case PhysicalTypeID::INT64: { + return std::make_shared>(); + } + case PhysicalTypeID::INT32: { + return std::make_shared>(); + } + case PhysicalTypeID::INT16: { + return std::make_shared>(); + } + case PhysicalTypeID::INT8: { + return std::make_shared>(); + } + case PhysicalTypeID::INTERNAL_ID: + case PhysicalTypeID::UINT64: { + return std::make_shared>(); + } + case PhysicalTypeID::UINT32: { + return std::make_shared>(); + } + case PhysicalTypeID::UINT16: { + return std::make_shared>(); + } + case PhysicalTypeID::UINT8: { + return std::make_shared>(); + } + case PhysicalTypeID::FLOAT: { + return std::make_shared>(); + } + case PhysicalTypeID::DOUBLE: { + return std::make_shared>(); + } + default: { + return std::make_shared(dataType); + } + } +} + +ColumnChunkData::ColumnChunkData(MemoryManager& mm, LogicalType dataType, uint64_t capacity, + bool enableCompression, ResidencyState residencyState, bool hasNullData, bool initializeToZero) + : residencyState{residencyState}, dataType{std::move(dataType)}, + enableCompression{enableCompression}, + numBytesPerValue{getDataTypeSizeInChunk(this->dataType)}, capacity{capacity}, numValues{0}, + inMemoryStats() { + if (hasNullData) { + nullData = std::make_unique(mm, capacity, enableCompression, residencyState); + } + initializeBuffer(this->dataType.getPhysicalType(), mm, initializeToZero); + initializeFunction(); +} + +ColumnChunkData::ColumnChunkData(MemoryManager& mm, LogicalType dataType, bool enableCompression, + const ColumnChunkMetadata& metadata, bool hasNullData, bool initializeToZero) + : residencyState(ResidencyState::ON_DISK), dataType{std::move(dataType)}, + enableCompression{enableCompression}, + numBytesPerValue{getDataTypeSizeInChunk(this->dataType)}, capacity{0}, + numValues{metadata.numValues}, metadata{metadata} { + if (hasNullData) { + nullData = std::make_unique(mm, enableCompression, metadata); + } + initializeBuffer(this->dataType.getPhysicalType(), mm, initializeToZero); + initializeFunction(); +} + +ColumnChunkData::ColumnChunkData(MemoryManager& mm, PhysicalTypeID dataType, bool enableCompression, + const ColumnChunkMetadata& metadata, bool hasNullData, bool initializeToZero) + : ColumnChunkData(mm, LogicalType::ANY(dataType), enableCompression, metadata, hasNullData, + initializeToZero) {} + +void ColumnChunkData::initializeBuffer(PhysicalTypeID physicalType, MemoryManager& mm, + bool initializeToZero) { + numBytesPerValue = getDataTypeSizeInChunk(physicalType); + + // Some columnChunks are much smaller than the 256KB minimum size used by allocateBuffer + // Which would lead to excessive memory use, particularly in the partitioner + buffer = mm.allocateBuffer(initializeToZero, getBufferSize(capacity)); +} + +void ColumnChunkData::initializeFunction() { + const auto compression = getCompression(dataType, enableCompression); + getMetadataFunction = GetCompressionMetadata(compression, dataType); + flushBufferFunction = initializeFlushBufferFunction(compression); +} + +ColumnChunkData::flush_buffer_func_t ColumnChunkData::initializeFlushBufferFunction( + std::shared_ptr compression) const { + switch (dataType.getPhysicalType()) { + case PhysicalTypeID::BOOL: { + // Since we compress into memory, storage is the same as fixed-sized + // values, but we need to mark it as being boolean compressed. 
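// Reviewer note (illustrative, not part of this patch): this switch picks the flush strategy
// per physical type -- BOOL uses uncompressedFlushBuffer because booleans are already
// bit-packed in memory; the integer, internal-ID, string and list/array cases use
// CompressedFlushBuffer; FLOAT/DOUBLE use CompressedFloatFlushBuffer<T> (ALP, with float
// exceptions handled separately); everything else is written uncompressed.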
+ return uncompressedFlushBuffer; + } + case PhysicalTypeID::STRING: + case PhysicalTypeID::INT64: + case PhysicalTypeID::INT32: + case PhysicalTypeID::INT16: + case PhysicalTypeID::INT8: + case PhysicalTypeID::INTERNAL_ID: + case PhysicalTypeID::ARRAY: + case PhysicalTypeID::LIST: + case PhysicalTypeID::UINT64: + case PhysicalTypeID::UINT32: + case PhysicalTypeID::UINT16: + case PhysicalTypeID::UINT8: + case PhysicalTypeID::INT128: { + return CompressedFlushBuffer(compression, dataType); + } + case PhysicalTypeID::DOUBLE: { + return CompressedFloatFlushBuffer(compression, dataType); + } + case PhysicalTypeID::FLOAT: { + return CompressedFloatFlushBuffer(compression, dataType); + } + default: { + return uncompressedFlushBuffer; + } + } +} + +void ColumnChunkData::resetToAllNull() { + KU_ASSERT(residencyState != ResidencyState::ON_DISK); + if (nullData) { + nullData->resetToAllNull(); + } + resetInMemoryStats(); +} + +void ColumnChunkData::resetToEmpty() { + KU_ASSERT(residencyState != ResidencyState::ON_DISK); + if (nullData) { + nullData->resetToEmpty(); + } + KU_ASSERT(getBufferSize() == getBufferSize(capacity)); + memset(getData(), 0x00, getBufferSize()); + numValues = 0; + resetInMemoryStats(); +} + +static void updateInMemoryStats(ColumnChunkStats& stats, const ValueVector& values, + uint64_t offset = 0, uint64_t numValues = std::numeric_limits::max()) { + const auto physicalType = values.dataType.getPhysicalType(); + const auto numValuesToCheck = std::min(numValues, values.state->getSelSize()); + stats.update(values, offset, numValuesToCheck, physicalType); +} + +static void updateInMemoryStats(ColumnChunkStats& stats, const ColumnChunkData* values, + uint64_t offset = 0, uint64_t numValues = std::numeric_limits::max()) { + const auto physicalType = values->getDataType().getPhysicalType(); + const auto numValuesToCheck = std::min(values->getNumValues(), numValues); + const auto nullMask = values->getNullMask(); + stats.update(*values, offset, numValuesToCheck, physicalType); +} + +MergedColumnChunkStats ColumnChunkData::getMergedColumnChunkStats() const { + const CompressionMetadata& onDiskMetadata = metadata.compMeta; + ColumnChunkStats stats = inMemoryStats; + const auto physicalType = getDataType().getPhysicalType(); + const bool isStorageValueType = + TypeUtils::visit(physicalType, [](T) { return StorageValueType; }); + if (isStorageValueType) { + stats.update(onDiskMetadata.min, onDiskMetadata.max, physicalType); + } + return MergedColumnChunkStats{stats, !nullData || nullData->haveNoNullsGuaranteed(), + nullData && nullData->haveAllNullsGuaranteed()}; +} + +void ColumnChunkData::updateStats(const ValueVector* vector, const SelectionView& selView) { + if (selView.isUnfiltered()) { + updateInMemoryStats(inMemoryStats, *vector); + } else { + TypeUtils::visit( + getDataType().getPhysicalType(), + [&](T) { + std::optional firstValue; + // ValueVector::firstNonNull uses the vector's builtin selection vector, not the one + // passed as an argument + selView.forEachBreakWhenFalse([&](auto i) { + if (vector->isNull(i)) { + return true; + } else { + firstValue = vector->getValue(i); + return false; + } + }); + if (!firstValue) { + return; + } + T min = *firstValue, max = *firstValue; + auto update = [&](sel_t pos) { + const auto val = vector->getValue(pos); + if (val < min) { + min = val; + } else if (val > max) { + max = val; + } + }; + if (vector->hasNoNullsGuarantee()) { + selView.forEach(update); + } else { + selView.forEach([&](auto pos) { + if (!vector->isNull(pos)) { + 
update(pos); + } + }); + } + inMemoryStats.update(StorageValue(min), StorageValue(max), + getDataType().getPhysicalType()); + }, + [](T) { static_assert(!StorageValueType); }); + } +} + +void ColumnChunkData::resetInMemoryStats() { + inMemoryStats.reset(); +} + +ColumnChunkMetadata ColumnChunkData::getMetadataToFlush() const { + KU_ASSERT(numValues <= capacity); + StorageValue minValue = {}, maxValue = {}; + if (capacity > 0) { + std::optional nullMask; + if (nullData) { + nullMask = nullData->getNullMask(); + } + auto [min, max] = + getMinMaxStorageValue(getData(), 0 /*offset*/, numValues, dataType.getPhysicalType(), + nullMask.has_value() ? &*nullMask : nullptr, true /*valueRequiredIfUnsupported*/); + minValue = min.value_or(StorageValue()); + maxValue = max.value_or(StorageValue()); + } + KU_ASSERT(getBufferSize() == getBufferSize(capacity)); + return getMetadataFunction(buffer->getBuffer(), numValues, minValue, maxValue); +} + +void ColumnChunkData::append(ValueVector* vector, const SelectionView& selView) { + KU_ASSERT(vector->dataType.getPhysicalType() == dataType.getPhysicalType()); + copyVectorToBuffer(vector, numValues, selView); + numValues += selView.getSelSize(); + updateStats(vector, selView); +} + +void ColumnChunkData::append(const ColumnChunkData* other, offset_t startPosInOtherChunk, + uint32_t numValuesToAppend) { + KU_ASSERT(other->dataType.getPhysicalType() == dataType.getPhysicalType()); + if (nullData) { + KU_ASSERT(nullData->getNumValues() == getNumValues()); + nullData->append(other->nullData.get(), startPosInOtherChunk, numValuesToAppend); + } + KU_ASSERT(numValues + numValuesToAppend <= capacity); + memcpy(getData() + numValues * numBytesPerValue, + other->getData() + startPosInOtherChunk * numBytesPerValue, + numValuesToAppend * numBytesPerValue); + numValues += numValuesToAppend; + updateInMemoryStats(inMemoryStats, other, startPosInOtherChunk, numValuesToAppend); +} + +void ColumnChunkData::flush(PageAllocator& pageAllocator) { + const auto preScanMetadata = getMetadataToFlush(); + auto allocatedEntry = pageAllocator.allocatePageRange(preScanMetadata.getNumPages()); + const auto flushedMetadata = flushBuffer(pageAllocator, allocatedEntry, preScanMetadata); + setToOnDisk(flushedMetadata); + if (nullData) { + nullData->flush(pageAllocator); + } +} + +// Note: This function is not setting child/null chunk data recursively. +void ColumnChunkData::setToOnDisk(const ColumnChunkMetadata& otherMetadata) { + residencyState = ResidencyState::ON_DISK; + capacity = 0; + // Note: We don't need to set the buffer to nullptr, as it allows ColumnChunkData to be resized. 
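// Reviewer note (illustrative, not part of this patch): swapping in a zero-size buffer below,
// rather than dropping it to nullptr, keeps `buffer` pointing at a valid allocation, so a
// later setToInMemory()/resize() can grow the chunk again with a plain allocate-and-memcpy
// and no null checks.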
+ buffer = buffer->getMemoryManager()->allocateBuffer(true, 0 /*size*/); + this->metadata = otherMetadata; + this->numValues = otherMetadata.numValues; + resetInMemoryStats(); +} + +ColumnChunkMetadata ColumnChunkData::flushBuffer(PageAllocator& pageAllocator, + const PageRange& entry, const ColumnChunkMetadata& otherMetadata) const { + const auto bufferSizeToFlush = getBufferSize(numValues); + if (!otherMetadata.compMeta.isConstant() && bufferSizeToFlush != 0) { + KU_ASSERT(bufferSizeToFlush <= buffer->getBuffer().size_bytes()); + const auto bufferToFlush = buffer->getBuffer().subspan(0, bufferSizeToFlush); + return flushBufferFunction(bufferToFlush, pageAllocator.getDataFH(), entry, otherMetadata); + } + KU_ASSERT(otherMetadata.getNumPages() == 0); + return otherMetadata; +} + +uint64_t ColumnChunkData::getBufferSize(uint64_t capacity_) const { + switch (dataType.getLogicalTypeID()) { + case LogicalTypeID::BOOL: { + // 8 values per byte, and we need a buffer size which is a + // multiple of 8 bytes. + return ceil(capacity_ / 8.0 / 8.0) * 8; + } + default: { + return numBytesPerValue * capacity_; + } + } +} + +void ColumnChunkData::initializeScanState(SegmentState& state, const Column* column) const { + if (nullData) { + KU_ASSERT(state.nullState); + nullData->initializeScanState(*state.nullState, column->getNullColumn()); + } + state.column = column; + if (residencyState == ResidencyState::ON_DISK) { + state.metadata = metadata; + state.numValuesPerPage = state.metadata.compMeta.numValues(LBUG_PAGE_SIZE, dataType); + + state.column->populateExtraChunkState(state); + } +} + +void ColumnChunkData::scan(ValueVector& output, offset_t offset, length_t length, + sel_t posInOutputVector) const { + KU_ASSERT(offset + length <= numValues); + if (nullData) { + nullData->scan(output, offset, length, posInOutputVector); + } + memcpy(output.getData() + posInOutputVector * numBytesPerValue, + getData() + offset * numBytesPerValue, numBytesPerValue * length); +} + +void ColumnChunkData::lookup(offset_t offsetInChunk, ValueVector& output, + sel_t posInOutputVector) const { + KU_ASSERT(offsetInChunk < capacity); + output.setNull(posInOutputVector, isNull(offsetInChunk)); + if (!output.isNull(posInOutputVector)) { + memcpy(output.getData() + posInOutputVector * numBytesPerValue, + getData() + offsetInChunk * numBytesPerValue, numBytesPerValue); + } +} + +void ColumnChunkData::write(ColumnChunkData* chunk, ColumnChunkData* dstOffsets, + RelMultiplicity multiplicity) { + KU_ASSERT(chunk->dataType.getPhysicalType() == dataType.getPhysicalType() && + dstOffsets->getDataType().getPhysicalType() == PhysicalTypeID::INTERNAL_ID && + chunk->getNumValues() == dstOffsets->getNumValues()); + for (auto i = 0u; i < dstOffsets->getNumValues(); i++) { + const auto dstOffset = dstOffsets->getValue(i); + KU_ASSERT(dstOffset < capacity); + memcpy(getData() + dstOffset * numBytesPerValue, chunk->getData() + i * numBytesPerValue, + numBytesPerValue); + numValues = dstOffset >= numValues ? 
dstOffset + 1 : numValues; + } + if (nullData || multiplicity == RelMultiplicity::ONE) { + for (auto i = 0u; i < dstOffsets->getNumValues(); i++) { + const auto dstOffset = dstOffsets->getValue(i); + if (multiplicity == RelMultiplicity::ONE && isNull(dstOffset)) { + throw CopyException( + stringFormat("Node with offset: {} can only have one neighbour due " + "to the MANY-ONE/ONE-ONE relationship constraint.", + dstOffset)); + } + if (nullData) { + nullData->setNull(dstOffset, chunk->isNull(i)); + } + } + } + updateInMemoryStats(inMemoryStats, chunk); +} + +// NOTE: This function is only called in LocalTable right now when +// performing out-of-place committing. LIST has a different logic for +// handling out-of-place committing as it has to be slided. However, +// this is unsafe, as this function can also be used for other purposes +// later. Thus, an assertion is added at the first line. +void ColumnChunkData::write(const ValueVector* vector, offset_t offsetInVector, + offset_t offsetInChunk) { + KU_ASSERT(dataType.getPhysicalType() != PhysicalTypeID::BOOL && + dataType.getPhysicalType() != PhysicalTypeID::LIST && + dataType.getPhysicalType() != PhysicalTypeID::ARRAY); + if (nullData) { + nullData->setNull(offsetInChunk, vector->isNull(offsetInVector)); + } + if (offsetInChunk >= numValues) { + numValues = offsetInChunk + 1; + } + if (!vector->isNull(offsetInVector)) { + memcpy(getData() + offsetInChunk * numBytesPerValue, + vector->getData() + offsetInVector * numBytesPerValue, numBytesPerValue); + } + static constexpr uint64_t numValuesToWrite = 1; + updateInMemoryStats(inMemoryStats, *vector, offsetInVector, numValuesToWrite); +} + +void ColumnChunkData::write(const ColumnChunkData* srcChunk, offset_t srcOffsetInChunk, + offset_t dstOffsetInChunk, offset_t numValuesToCopy) { + KU_ASSERT(srcChunk->dataType.getPhysicalType() == dataType.getPhysicalType()); + if ((dstOffsetInChunk + numValuesToCopy) >= numValues) { + numValues = dstOffsetInChunk + numValuesToCopy; + } + memcpy(getData() + dstOffsetInChunk * numBytesPerValue, + srcChunk->getData() + srcOffsetInChunk * numBytesPerValue, + numValuesToCopy * numBytesPerValue); + if (nullData) { + KU_ASSERT(srcChunk->getNullData()); + nullData->write(srcChunk->getNullData(), srcOffsetInChunk, dstOffsetInChunk, + numValuesToCopy); + } + updateInMemoryStats(inMemoryStats, srcChunk, srcOffsetInChunk, numValuesToCopy); +} + +void ColumnChunkData::resetNumValuesFromMetadata() { + KU_ASSERT(residencyState == ResidencyState::ON_DISK); + numValues = metadata.numValues; + if (nullData) { + nullData->resetNumValuesFromMetadata(); + // FIXME(bmwinger): not always working + // KU_ASSERT(numValues == nullData->numValues); + } +} + +void ColumnChunkData::setToInMemory() { + KU_ASSERT(residencyState == ResidencyState::ON_DISK); + KU_ASSERT(capacity == 0 && getBufferSize() == 0); + residencyState = ResidencyState::IN_MEMORY; + numValues = 0; + if (nullData) { + nullData->setToInMemory(); + } +} + +void ColumnChunkData::resize(uint64_t newCapacity) { + const auto numBytesAfterResize = getBufferSize(newCapacity); + if (numBytesAfterResize > getBufferSize()) { + auto resizedBuffer = buffer->getMemoryManager()->allocateBuffer(false, numBytesAfterResize); + auto bufferSize = getBufferSize(); + auto resizedBufferData = resizedBuffer->getBuffer().data(); + memcpy(resizedBufferData, buffer->getBuffer().data(), bufferSize); + memset(resizedBufferData + bufferSize, 0, numBytesAfterResize - bufferSize); + buffer = std::move(resizedBuffer); + } + if (nullData) { + 
nullData->resize(newCapacity); + } + if (newCapacity > capacity) { + capacity = newCapacity; + } +} + +void ColumnChunkData::resizeWithoutPreserve(uint64_t newCapacity) { + const auto numBytesAfterResize = getBufferSize(newCapacity); + if (numBytesAfterResize > getBufferSize()) { + auto resizedBuffer = buffer->getMemoryManager()->allocateBuffer(false, numBytesAfterResize); + buffer = std::move(resizedBuffer); + } + if (nullData) { + nullData->resize(newCapacity); + } + if (newCapacity > capacity) { + capacity = newCapacity; + } +} + +void ColumnChunkData::populateWithDefaultVal(ExpressionEvaluator& defaultEvaluator, + uint64_t& numValues_, ColumnStats* newColumnStats) { + auto numValuesAppended = 0u; + const auto numValuesToPopulate = numValues_; + while (numValuesAppended < numValuesToPopulate) { + const auto numValuesToAppend = + std::min(DEFAULT_VECTOR_CAPACITY, numValuesToPopulate - numValuesAppended); + defaultEvaluator.evaluate(numValuesToAppend); + auto resultVector = defaultEvaluator.resultVector.get(); + KU_ASSERT(resultVector->state->getSelVector().getSelSize() == numValuesToAppend); + append(resultVector, resultVector->state->getSelVector()); + if (newColumnStats) { + newColumnStats->update(resultVector); + } + numValuesAppended += numValuesToAppend; + } +} + +void ColumnChunkData::copyVectorToBuffer(ValueVector* vector, offset_t startPosInChunk, + const SelectionView& selView) { + auto bufferToWrite = buffer->getBuffer().data() + startPosInChunk * numBytesPerValue; + KU_ASSERT(startPosInChunk + selView.getSelSize() <= capacity); + const auto vectorDataToWriteFrom = vector->getData(); + if (nullData) { + nullData->appendNulls(vector, selView, startPosInChunk); + } + if (selView.isUnfiltered()) { + memcpy(bufferToWrite, vectorDataToWriteFrom, selView.getSelSize() * numBytesPerValue); + } else { + selView.forEach([&](auto pos) { + memcpy(bufferToWrite, vectorDataToWriteFrom + pos * numBytesPerValue, numBytesPerValue); + bufferToWrite += numBytesPerValue; + }); + } +} + +void ColumnChunkData::setNumValues(uint64_t numValues_) { + KU_ASSERT(numValues_ <= capacity); + numValues = numValues_; + if (nullData) { + nullData->setNumValues(numValues_); + } +} + +bool ColumnChunkData::numValuesSanityCheck() const { + if (nullData) { + return numValues == nullData->getNumValues(); + } + return numValues <= capacity; +} + +bool ColumnChunkData::sanityCheck() const { + if (nullData) { + return nullData->sanityCheck() && numValuesSanityCheck(); + } + return numValues <= capacity; +} + +uint64_t ColumnChunkData::getEstimatedMemoryUsage() const { + return buffer->getBuffer().size() + (nullData ? 
nullData->getEstimatedMemoryUsage() : 0); +} + +void ColumnChunkData::serialize(Serializer& serializer) const { + KU_ASSERT(residencyState == ResidencyState::ON_DISK); + serializer.writeDebuggingInfo("data_type"); + dataType.serialize(serializer); + serializer.writeDebuggingInfo("metadata"); + metadata.serialize(serializer); + serializer.writeDebuggingInfo("enable_compression"); + serializer.write(enableCompression); + serializer.writeDebuggingInfo("has_null"); + serializer.write(nullData != nullptr); + if (nullData) { + serializer.writeDebuggingInfo("null_data"); + nullData->serialize(serializer); + } +} + +std::unique_ptr ColumnChunkData::deserialize(MemoryManager& memoryManager, + Deserializer& deSer) { + std::string key; + ColumnChunkMetadata metadata; + bool enableCompression = false; + bool hasNull = false; + bool initializeToZero = true; + deSer.validateDebuggingInfo(key, "data_type"); + const auto dataType = LogicalType::deserialize(deSer); + deSer.validateDebuggingInfo(key, "metadata"); + metadata = decltype(metadata)::deserialize(deSer); + deSer.validateDebuggingInfo(key, "enable_compression"); + deSer.deserializeValue(enableCompression); + deSer.validateDebuggingInfo(key, "has_null"); + deSer.deserializeValue(hasNull); + auto chunkData = ColumnChunkFactory::createColumnChunkData(memoryManager, dataType.copy(), + enableCompression, metadata, hasNull, initializeToZero); + if (hasNull) { + deSer.validateDebuggingInfo(key, "null_data"); + chunkData->nullData = NullChunkData::deserialize(memoryManager, deSer); + } + + switch (dataType.getPhysicalType()) { + case PhysicalTypeID::STRUCT: { + StructChunkData::deserialize(deSer, *chunkData); + } break; + case PhysicalTypeID::STRING: { + StringChunkData::deserialize(deSer, *chunkData); + } break; + case PhysicalTypeID::ARRAY: + case PhysicalTypeID::LIST: { + ListChunkData::deserialize(deSer, *chunkData); + } break; + default: { + // DO NOTHING. 
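
Aside (illustrative, not part of the diff): serialize()/deserialize() above bracket every field with a debugging key ("data_type", "metadata", "enable_compression", "has_null", "null_data") that validateDebuggingInfo checks on read. Below is a minimal sketch of that tag-then-value pattern using ordinary streams; the stand-in helpers and values are hypothetical, only the idea of validating a key before consuming a field comes from the code above.

#include <cstdint>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>

// Each field is preceded by a textual key; the reader validates the key before
// consuming the value, so a mismatched layout fails loudly instead of silently.
static void writeTagged(std::ostream& out, const std::string& key, uint64_t value) {
    out << key << ' ' << value << '\n';
}

static uint64_t readTagged(std::istream& in, const std::string& expectedKey) {
    std::string key;
    uint64_t value = 0;
    in >> key >> value;
    if (key != expectedKey) {
        throw std::runtime_error("unexpected key: " + key + ", wanted " + expectedKey);
    }
    return value;
}

int main() {
    std::stringstream buf;
    writeTagged(buf, "num_values", 42);
    writeTagged(buf, "has_null", 1);

    const uint64_t numValues = readTagged(buf, "num_values");
    const bool hasNull = readTagged(buf, "has_null") != 0;
    std::cout << numValues << ' ' << hasNull << '\n';  // prints: 42 1
    return 0;
}
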
+ } + } + + return chunkData; +} + +void BoolChunkData::append(ValueVector* vector, const SelectionView& selView) { + KU_ASSERT(vector->dataType.getPhysicalType() == PhysicalTypeID::BOOL); + for (auto i = 0u; i < selView.getSelSize(); i++) { + const auto pos = selView[i]; + NullMask::setNull(getData(), numValues + i, vector->getValue(pos)); + } + if (nullData) { + nullData->appendNulls(vector, selView, numValues); + } + numValues += selView.getSelSize(); + updateStats(vector, selView); +} + +void BoolChunkData::append(const ColumnChunkData* other, offset_t startPosInOtherChunk, + uint32_t numValuesToAppend) { + NullMask::copyNullMask(other->getData(), startPosInOtherChunk, getData(), + numValues, numValuesToAppend); + if (nullData) { + nullData->append(other->getNullData(), startPosInOtherChunk, numValuesToAppend); + } + numValues += numValuesToAppend; + updateInMemoryStats(inMemoryStats, other, startPosInOtherChunk, numValuesToAppend); +} + +void BoolChunkData::scan(ValueVector& output, offset_t offset, length_t length, + sel_t posInOutputVector) const { + KU_ASSERT(offset + length <= numValues); + if (nullData) { + nullData->scan(output, offset, length, posInOutputVector); + } + for (auto i = 0u; i < length; i++) { + output.setValue(posInOutputVector + i, + NullMask::isNull(getData(), offset + i)); + } +} + +void BoolChunkData::lookup(offset_t offsetInChunk, ValueVector& output, + sel_t posInOutputVector) const { + KU_ASSERT(offsetInChunk < capacity); + output.setNull(posInOutputVector, nullData->isNull(offsetInChunk)); + if (!output.isNull(posInOutputVector)) { + output.setValue(posInOutputVector, + NullMask::isNull(getData(), offsetInChunk)); + } +} + +void BoolChunkData::write(ColumnChunkData* chunk, ColumnChunkData* dstOffsets, RelMultiplicity) { + KU_ASSERT(chunk->getDataType().getPhysicalType() == PhysicalTypeID::BOOL && + dstOffsets->getDataType().getPhysicalType() == PhysicalTypeID::INTERNAL_ID && + chunk->getNumValues() == dstOffsets->getNumValues()); + for (auto i = 0u; i < dstOffsets->getNumValues(); i++) { + const auto dstOffset = dstOffsets->getValue(i); + KU_ASSERT(dstOffset < capacity); + NullMask::setNull(getData(), dstOffset, chunk->getValue(i)); + if (nullData) { + nullData->setNull(dstOffset, chunk->getNullData()->isNull(i)); + } + numValues = dstOffset >= numValues ? dstOffset + 1 : numValues; + } + updateInMemoryStats(inMemoryStats, chunk); +} + +void BoolChunkData::write(const ValueVector* vector, offset_t offsetInVector, + offset_t offsetInChunk) { + KU_ASSERT(vector->dataType.getPhysicalType() == PhysicalTypeID::BOOL); + KU_ASSERT(offsetInChunk < capacity); + const auto valueToSet = vector->getValue(offsetInVector); + setValue(valueToSet, offsetInChunk); + if (nullData) { + nullData->write(vector, offsetInVector, offsetInChunk); + } + numValues = offsetInChunk >= numValues ? 
offsetInChunk + 1 : numValues; + if (!vector->isNull(offsetInVector)) { + inMemoryStats.update(StorageValue{valueToSet}, dataType.getPhysicalType()); + } +} + +void BoolChunkData::write(const ColumnChunkData* srcChunk, offset_t srcOffsetInChunk, + offset_t dstOffsetInChunk, offset_t numValuesToCopy) { + if (nullData) { + nullData->write(srcChunk->getNullData(), srcOffsetInChunk, dstOffsetInChunk, + numValuesToCopy); + } + if ((dstOffsetInChunk + numValuesToCopy) >= numValues) { + numValues = dstOffsetInChunk + numValuesToCopy; + } + NullMask::copyNullMask(srcChunk->getData(), srcOffsetInChunk, getData(), + dstOffsetInChunk, numValuesToCopy); + updateInMemoryStats(inMemoryStats, srcChunk, srcOffsetInChunk, numValuesToCopy); +} + +NullMask NullChunkData::getNullMask() const { + return NullMask( + std::span(getData(), ceilDiv(capacity, NullMask::NUM_BITS_PER_NULL_ENTRY)), + !noNullsGuaranteedInMem()); +} + +void NullChunkData::setNull(offset_t pos, bool isNull) { + setValue(isNull, pos); + // TODO(Guodong): Better let NullChunkData also support `append` a + // vector. +} + +void NullChunkData::write(const ValueVector* vector, offset_t offsetInVector, + offset_t offsetInChunk) { + const bool isNull = vector->isNull(offsetInVector); + setValue(isNull, offsetInChunk); +} + +void NullChunkData::write(const ColumnChunkData* srcChunk, offset_t srcOffsetInChunk, + offset_t dstOffsetInChunk, offset_t numValuesToCopy) { + if (numValuesToCopy == 0) { + return; + } + KU_ASSERT(srcChunk->getBufferSize() >= sizeof(uint64_t)); + copyFromBuffer(srcChunk->getData(), srcOffsetInChunk, dstOffsetInChunk, + numValuesToCopy); +} + +void NullChunkData::append(const ColumnChunkData* other, offset_t startOffsetInOtherChunk, + uint32_t numValuesToAppend) { + write(other, startOffsetInOtherChunk, numValues, numValuesToAppend); +} + +bool NullChunkData::haveNoNullsGuaranteed() const { + return noNullsGuaranteedInMem() && !metadata.compMeta.max.get(); +} + +bool NullChunkData::haveAllNullsGuaranteed() const { + return allNullsGuaranteedInMem() && metadata.compMeta.min.get(); +} + +void NullChunkData::serialize(Serializer& serializer) const { + KU_ASSERT(residencyState == ResidencyState::ON_DISK); + serializer.writeDebuggingInfo("null_chunk_metadata"); + metadata.serialize(serializer); +} + +std::unique_ptr NullChunkData::deserialize(MemoryManager& memoryManager, + Deserializer& deSer) { + std::string key; + ColumnChunkMetadata metadata; + deSer.validateDebuggingInfo(key, "null_chunk_metadata"); + metadata = decltype(metadata)::deserialize(deSer); + // TODO: FIX-ME. enableCompression. 
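
Aside (illustrative, not part of the diff): NullChunkData::getNullMask wraps the chunk buffer as a bit-packed mask sized with ceilDiv(capacity, NullMask::NUM_BITS_PER_NULL_ENTRY), and the no-nulls / all-nulls guarantees combine an in-memory scan of that mask with the on-disk min/max. Below is a minimal sketch of such a bit-packed mask, assuming 64 flags per uint64_t word; the type and its methods are hypothetical stand-ins, not the NullMask API.

#include <cstdint>
#include <iostream>
#include <vector>

// 64 null flags per uint64_t word; capacity is rounded up to whole words.
struct SimpleNullMask {
    std::vector<uint64_t> words;

    explicit SimpleNullMask(size_t capacity) : words((capacity + 63) / 64, 0) {}

    void setNull(size_t pos, bool isNull) {
        const uint64_t bit = uint64_t{1} << (pos % 64);
        if (isNull) {
            words[pos / 64] |= bit;
        } else {
            words[pos / 64] &= ~bit;
        }
    }

    bool isNull(size_t pos) const { return (words[pos / 64] >> (pos % 64)) & 1u; }

    // Analogous to a "no nulls guaranteed" check over the in-memory buffer.
    bool noNulls() const {
        for (const auto word : words) {
            if (word != 0) return false;
        }
        return true;
    }
};

int main() {
    SimpleNullMask mask(100);
    mask.setNull(7, true);
    std::cout << mask.isNull(7) << ' ' << mask.isNull(8) << ' ' << mask.noNulls() << '\n';  // 1 0 0
    return 0;
}
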
+ return std::make_unique(memoryManager, true, metadata); +} + +void NullChunkData::scan(ValueVector& output, offset_t offset, length_t length, + sel_t posInOutputVector) const { + output.setNullFromBits(getNullMask().getData(), offset, posInOutputVector, length); +} + +void NullChunkData::appendNulls(const ValueVector* vector, const SelectionView& selView, + offset_t startPosInChunk) { + if (selView.isUnfiltered()) { + copyFromBuffer(vector->getNullMask().getData(), 0, startPosInChunk, selView.getSelSize()); + } else { + for (auto i = 0u; i < selView.getSelSize(); i++) { + const auto pos = selView[i]; + setNull(startPosInChunk + i, vector->isNull(pos)); + } + } +} + +void InternalIDChunkData::append(ValueVector* vector, const SelectionView& selView) { + switch (vector->dataType.getPhysicalType()) { + case PhysicalTypeID::INTERNAL_ID: { + copyVectorToBuffer(vector, numValues, selView); + } break; + case PhysicalTypeID::INT64: { + copyInt64VectorToBuffer(vector, numValues, selView); + } break; + default: { + KU_UNREACHABLE; + } + } + numValues += selView.getSelSize(); +} + +void InternalIDChunkData::copyVectorToBuffer(ValueVector* vector, offset_t startPosInChunk, + const SelectionView& selView) { + KU_ASSERT(vector->dataType.getPhysicalType() == PhysicalTypeID::INTERNAL_ID); + const auto relIDsInVector = reinterpret_cast(vector->getData()); + if (commonTableID == INVALID_TABLE_ID) { + commonTableID = relIDsInVector[selView[0]].tableID; + } + for (auto i = 0u; i < selView.getSelSize(); i++) { + const auto pos = selView[i]; + if (vector->isNull(pos)) { + continue; + } + KU_ASSERT(relIDsInVector[pos].tableID == commonTableID); + memcpy(getData() + (startPosInChunk + i) * numBytesPerValue, &relIDsInVector[pos].offset, + numBytesPerValue); + } +} + +void InternalIDChunkData::copyInt64VectorToBuffer(ValueVector* vector, offset_t startPosInChunk, + const SelectionView& selView) const { + KU_ASSERT(vector->dataType.getPhysicalType() == PhysicalTypeID::INT64); + for (auto i = 0u; i < selView.getSelSize(); i++) { + const auto pos = selView[i]; + if (vector->isNull(pos)) { + continue; + } + memcpy(getData() + (startPosInChunk + i) * numBytesPerValue, + &vector->getValue(pos), numBytesPerValue); + } +} + +void InternalIDChunkData::scan(ValueVector& output, offset_t offset, length_t length, + sel_t posInOutputVector) const { + KU_ASSERT(offset + length <= numValues); + KU_ASSERT(commonTableID != INVALID_TABLE_ID); + internalID_t relID; + relID.tableID = commonTableID; + for (auto i = 0u; i < length; i++) { + relID.offset = getValue(offset + i); + output.setValue(posInOutputVector + i, relID); + } +} + +void InternalIDChunkData::lookup(offset_t offsetInChunk, ValueVector& output, + sel_t posInOutputVector) const { + KU_ASSERT(offsetInChunk < capacity); + internalID_t relID; + relID.offset = getValue(offsetInChunk); + KU_ASSERT(commonTableID != INVALID_TABLE_ID); + relID.tableID = commonTableID; + output.setValue(posInOutputVector, relID); +} + +void InternalIDChunkData::write(const ValueVector* vector, offset_t offsetInVector, + offset_t offsetInChunk) { + KU_ASSERT(vector->dataType.getPhysicalType() == PhysicalTypeID::INTERNAL_ID); + const auto relIDsInVector = reinterpret_cast(vector->getData()); + if (commonTableID == INVALID_TABLE_ID) { + commonTableID = relIDsInVector[offsetInVector].tableID; + } + KU_ASSERT(commonTableID == relIDsInVector[offsetInVector].tableID); + if (!vector->isNull(offsetInVector)) { + memcpy(getData() + offsetInChunk * numBytesPerValue, &relIDsInVector[offsetInVector].offset, 
+ numBytesPerValue); + } + if (offsetInChunk >= numValues) { + numValues = offsetInChunk + 1; + } +} + +void InternalIDChunkData::append(const ColumnChunkData* other, offset_t startPosInOtherChunk, + uint32_t numValuesToAppend) { + ColumnChunkData::append(other, startPosInOtherChunk, numValuesToAppend); + commonTableID = other->cast().commonTableID; +} + +std::optional ColumnChunkData::getNullMask() const { + return nullData ? std::optional(nullData->getNullMask()) : std::nullopt; +} + +std::unique_ptr ColumnChunkFactory::createColumnChunkData(MemoryManager& mm, + LogicalType dataType, bool enableCompression, uint64_t capacity, ResidencyState residencyState, + bool hasNullData, bool initializeToZero) { + switch (dataType.getPhysicalType()) { + case PhysicalTypeID::BOOL: { + return std::make_unique(mm, capacity, enableCompression, residencyState, + hasNullData); + } + case PhysicalTypeID::INT64: + case PhysicalTypeID::INT32: + case PhysicalTypeID::INT16: + case PhysicalTypeID::INT8: + case PhysicalTypeID::UINT64: + case PhysicalTypeID::UINT32: + case PhysicalTypeID::UINT16: + case PhysicalTypeID::UINT8: + case PhysicalTypeID::INT128: + case PhysicalTypeID::UINT128: + case PhysicalTypeID::DOUBLE: + case PhysicalTypeID::FLOAT: + case PhysicalTypeID::INTERVAL: { + return std::make_unique(mm, std::move(dataType), capacity, + enableCompression, residencyState, hasNullData, initializeToZero); + } + case PhysicalTypeID::INTERNAL_ID: { + return std::make_unique(mm, capacity, enableCompression, + residencyState); + } + case PhysicalTypeID::STRING: { + return std::make_unique(mm, std::move(dataType), capacity, + enableCompression, residencyState); + } + case PhysicalTypeID::ARRAY: + case PhysicalTypeID::LIST: { + return std::make_unique(mm, std::move(dataType), capacity, enableCompression, + residencyState); + } + case PhysicalTypeID::STRUCT: { + return std::make_unique(mm, std::move(dataType), capacity, + enableCompression, residencyState); + } + default: + KU_UNREACHABLE; + } +} + +std::unique_ptr ColumnChunkFactory::createColumnChunkData(MemoryManager& mm, + LogicalType dataType, bool enableCompression, ColumnChunkMetadata& metadata, bool hasNullData, + bool initializeToZero) { + switch (dataType.getPhysicalType()) { + case PhysicalTypeID::BOOL: { + return std::make_unique(mm, enableCompression, metadata, hasNullData); + } + case PhysicalTypeID::INT64: + case PhysicalTypeID::INT32: + case PhysicalTypeID::INT16: + case PhysicalTypeID::INT8: + case PhysicalTypeID::UINT64: + case PhysicalTypeID::UINT32: + case PhysicalTypeID::UINT16: + case PhysicalTypeID::UINT8: + case PhysicalTypeID::INT128: + case PhysicalTypeID::UINT128: + case PhysicalTypeID::DOUBLE: + case PhysicalTypeID::FLOAT: + case PhysicalTypeID::INTERVAL: { + return std::make_unique(mm, std::move(dataType), enableCompression, + metadata, hasNullData, initializeToZero); + } + // Physically, we only materialize offset of INTERNAL_ID, which is same as INT64, + case PhysicalTypeID::INTERNAL_ID: { + // INTERNAL_ID should never have nulls. 
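
Aside (illustrative, not part of the diff): InternalIDChunkData relies on every rel ID in a chunk sharing one table ID, so it materializes only the offsets, records commonTableID once, and re-attaches it in scan()/lookup(). Below is a minimal sketch of that layout with hypothetical stand-in types and names.

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

struct InternalID {
    uint64_t tableID;
    uint64_t offset;
};

// Stores only the offsets; the shared table ID is captured from the first
// appended value and re-attached when scanning values back out.
struct OffsetOnlyChunk {
    static constexpr uint64_t INVALID_TABLE_ID = UINT64_MAX;
    uint64_t commonTableID = INVALID_TABLE_ID;
    std::vector<uint64_t> offsets;

    void append(const InternalID& id) {
        if (commonTableID == INVALID_TABLE_ID) {
            commonTableID = id.tableID;
        }
        assert(id.tableID == commonTableID);  // mirrors the KU_ASSERT invariant above
        offsets.push_back(id.offset);
    }

    InternalID scan(size_t i) const { return {commonTableID, offsets[i]}; }
};

int main() {
    OffsetOnlyChunk chunk;
    chunk.append({3, 10});
    chunk.append({3, 11});
    const InternalID id = chunk.scan(1);
    std::cout << id.tableID << ':' << id.offset << '\n';  // prints: 3:11
    return 0;
}
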
+ return std::make_unique(mm, enableCompression, metadata); + } + case PhysicalTypeID::STRING: { + return std::make_unique(mm, enableCompression, metadata); + } + case PhysicalTypeID::ARRAY: + case PhysicalTypeID::LIST: { + return std::make_unique(mm, std::move(dataType), enableCompression, + metadata); + } + case PhysicalTypeID::STRUCT: { + return std::make_unique(mm, std::move(dataType), enableCompression, + metadata); + } + default: + KU_UNREACHABLE; + } +} + +bool ColumnChunkData::isNull(offset_t pos) const { + return nullData && nullData->isNull(pos); +} + +MemoryManager& ColumnChunkData::getMemoryManager() const { + return *buffer->getMemoryManager(); +} + +uint8_t* ColumnChunkData::getData() const { + return buffer->getBuffer().data(); +} +uint64_t ColumnChunkData::getBufferSize() const { + return buffer->getBuffer().size_bytes(); +} + +void ColumnChunkData::loadFromDisk() { + buffer->getMemoryManager()->getBufferManager()->getSpillerOrSkip( + [&](auto& spiller) { spiller.loadFromDisk(*this); }); +} + +SpillResult ColumnChunkData::spillToDisk() { + SpillResult spilled{}; + buffer->getMemoryManager()->getBufferManager()->getSpillerOrSkip( + [&](auto& spiller) { spilled = spiller.spillToDisk(*this); }); + return spilled; +} + +void ColumnChunkData::reclaimStorage(PageAllocator& pageAllocator) { + if (nullData) { + nullData->reclaimStorage(pageAllocator); + } + if (residencyState == ResidencyState::ON_DISK) { + if (metadata.getStartPageIdx() != INVALID_PAGE_IDX) { + pageAllocator.freePageRange(metadata.pageRange); + } + } +} + +uint64_t ColumnChunkData::getSizeOnDisk() const { + // Probably could just return the actual size from the metadata if it's on-disk, but it's not + // currently needed for on-disk segments + KU_ASSERT(ResidencyState::IN_MEMORY == residencyState); + auto metadata = getMetadataToFlush(); + uint64_t nullSize = 0; + if (nullData) { + nullSize = nullData->getSizeOnDisk(); + } + return metadata.getNumDataPages(dataType.getPhysicalType()) * common::LBUG_PAGE_SIZE + nullSize; +} + +uint64_t ColumnChunkData::getSizeOnDiskInMemoryStats() const { + // Probably could just return the actual size from the metadata if it's on-disk, but it's not + // currently needed for on-disk segments + KU_ASSERT(ResidencyState::IN_MEMORY == residencyState); + uint64_t nullSize = 0; + if (nullData) { + nullSize = nullData->getSizeOnDiskInMemoryStats(); + } + auto metadata = getMetadataFunction(buffer->getBuffer(), numValues, + inMemoryStats.min.value_or(StorageValue{}), inMemoryStats.max.value_or(StorageValue{})); + return metadata.getNumDataPages(dataType.getPhysicalType()) * common::LBUG_PAGE_SIZE + nullSize; +} + +std::vector> ColumnChunkData::split(bool targetMaxSize) const { + // FIXME(bmwinger): we either need to split recursively, or detect individual values which bring + // the size above MAX_SEGMENT_SIZE, since this will still sometimes produce segments larger than + // MAX_SEGMENT_SIZE + auto maxSegmentSize = std::max(getMinimumSizeOnDisk(), common::StorageConfig::MAX_SEGMENT_SIZE); + auto targetSize = + targetMaxSize ? 
maxSegmentSize : std::min(getSizeOnDisk() / 2, maxSegmentSize); + std::vector> newSegments; + uint64_t pos = 0; + const uint64_t chunkSize = 64; + uint64_t initialCapacity = std::min(chunkSize, numValues); + while (pos < numValues) { + std::unique_ptr newSegment = + ColumnChunkFactory::createColumnChunkData(getMemoryManager(), getDataType().copy(), + isCompressionEnabled(), initialCapacity, ResidencyState::IN_MEMORY, hasNullData()); + + while (pos < numValues && newSegment->getSizeOnDiskInMemoryStats() <= targetSize) { + if (newSegment->getNumValues() == newSegment->getCapacity()) { + newSegment->resize(newSegment->getCapacity() * 2); + } + auto numValuesToAppendInChunk = std::min(numValues - pos, chunkSize); + newSegment->append(this, pos, numValuesToAppendInChunk); + pos += numValuesToAppendInChunk; + } + if (pos < numValues && newSegment->getNumValues() > chunkSize) { + // Size exceeded target size, so we should drop the last batch added (unless they are + // the only values) + pos -= chunkSize; + newSegment->truncate(newSegment->getNumValues() - chunkSize); + } + newSegments.push_back(std::move(newSegment)); + } + return newSegments; +} + +ColumnChunkData::~ColumnChunkData() = default; + +uint64_t ColumnChunkData::getMinimumSizeOnDisk() const { + if (hasNullData() && nullData->getSizeOnDisk() > 0) { + return 2 * LBUG_PAGE_SIZE; + } + return LBUG_PAGE_SIZE; +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/column_chunk_metadata.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/column_chunk_metadata.cpp new file mode 100644 index 0000000000..042ef819b4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/column_chunk_metadata.cpp @@ -0,0 +1,241 @@ +#include "storage/table/column_chunk_metadata.h" + +#include "alp/decode.hpp" +#include "alp/encode.hpp" +#include "common/serializer/deserializer.h" +#include "common/serializer/serializer.h" +#include "common/system_config.h" +#include "common/type_utils.h" +#include "common/types/types.h" +#include "common/utils.h" +#include "storage/compression/compression.h" +#include "storage/compression/float_compression.h" + +namespace lbug::storage { +using namespace common; + +ColumnChunkMetadata GetCompressionMetadata::operator()(std::span buffer, + uint64_t numValues, StorageValue min, StorageValue max) const { + if (min == max) { + return ColumnChunkMetadata(INVALID_PAGE_IDX, 0, numValues, + CompressionMetadata(min, max, CompressionType::CONSTANT)); + } + switch (dataType.getPhysicalType()) { + case PhysicalTypeID::BOOL: { + return booleanGetMetadata(numValues, min, max); + } + case PhysicalTypeID::STRING: + case PhysicalTypeID::INT64: + case PhysicalTypeID::INT32: + case PhysicalTypeID::INT16: + case PhysicalTypeID::INT8: + case PhysicalTypeID::INTERNAL_ID: + case PhysicalTypeID::ARRAY: + case PhysicalTypeID::LIST: + case PhysicalTypeID::UINT64: + case PhysicalTypeID::UINT32: + case PhysicalTypeID::UINT16: + case PhysicalTypeID::UINT8: + case PhysicalTypeID::INT128: { + return GetBitpackingMetadata(alg, dataType)(buffer, numValues, min, max); + } + case PhysicalTypeID::DOUBLE: { + return GetFloatCompressionMetadata(alg, dataType)(buffer, numValues, min, max); + } + case PhysicalTypeID::FLOAT: { + return GetFloatCompressionMetadata(alg, dataType)(buffer, numValues, min, max); + } + default: { + return uncompressedGetMetadata(dataType.getPhysicalType(), numValues, min, max); + } + } +} + +ColumnChunkMetadata uncompressedGetMetadata(PhysicalTypeID dataType, 
uint64_t numValues, + StorageValue min, StorageValue max) { + auto numPages = 0; + if (getDataTypeSizeInChunk(dataType) > 0) { + const auto numValuesPerPage = Uncompressed::numValues(LBUG_PAGE_SIZE, dataType); + numPages = ceilDiv(numValues, numValuesPerPage); + } + return ColumnChunkMetadata(INVALID_PAGE_IDX, numPages, numValues, + CompressionMetadata(min, max, CompressionType::UNCOMPRESSED)); +} + +ColumnChunkMetadata booleanGetMetadata(uint64_t numValues, StorageValue min, StorageValue max) { + return ColumnChunkMetadata(INVALID_PAGE_IDX, + ceilDiv(ceilDiv(numValues, uint64_t{8}), LBUG_PAGE_SIZE), numValues, + CompressionMetadata(min, max, CompressionType::BOOLEAN_BITPACKING)); +} + +void ColumnChunkMetadata::serialize(common::Serializer& serializer) const { + serializer.write(pageRange.startPageIdx); + serializer.write(pageRange.numPages); + serializer.write(numValues); + compMeta.serialize(serializer); +} + +ColumnChunkMetadata ColumnChunkMetadata::deserialize(common::Deserializer& deserializer) { + ColumnChunkMetadata ret; + deserializer.deserializeValue(ret.pageRange.startPageIdx); + deserializer.deserializeValue(ret.pageRange.numPages); + deserializer.deserializeValue(ret.numValues); + ret.compMeta = decltype(ret.compMeta)::deserialize(deserializer); + + return ret; +} + +page_idx_t ColumnChunkMetadata::getNumDataPages(PhysicalTypeID dataType) const { + switch (compMeta.compression) { + case CompressionType::ALP: { + return TypeUtils::visit( + dataType, + [this](T) -> page_idx_t { + return FloatCompression::getNumDataPages(getNumPages(), compMeta); + }, + [](auto) -> page_idx_t { KU_UNREACHABLE; }); + } + default: + return getNumPages(); + } +} + +ColumnChunkMetadata GetBitpackingMetadata::operator()(std::span /*buffer*/, + uint64_t numValues, StorageValue min, StorageValue max) { + // For supported types, min and max may be null if all values are null + // Compression is supported in this case + // Unsupported types always return a dummy value (where min != max) + // so that we don't constant compress them + auto compMeta = CompressionMetadata(min, max, alg->getCompressionType()); + if (alg->getCompressionType() == CompressionType::INTEGER_BITPACKING) { + TypeUtils::visit( + dataType.getPhysicalType(), + [&](T) { + // If integer bitpacking bitwidth is the maximum, bitpacking cannot be used + // and has poor performance compared to uncompressed + if (IntegerBitpacking::getPackingInfo(compMeta).bitWidth >= sizeof(T) * 8) { + compMeta = CompressionMetadata(min, max, CompressionType::UNCOMPRESSED); + } + }, + [&](auto) {}); + } + const auto numValuesPerPage = compMeta.numValues(LBUG_PAGE_SIZE, dataType); + const auto numPages = + numValuesPerPage == UINT64_MAX ? + 0 : + numValues / numValuesPerPage + (numValues % numValuesPerPage == 0 ? 
0 : 1); + return ColumnChunkMetadata(INVALID_PAGE_IDX, numPages, numValues, compMeta); +} + +namespace { +ColumnChunkMetadata getConstantFloatMetadata(PhysicalTypeID physicalType, uint64_t numValues, + StorageValue min, StorageValue max) { + return {INVALID_PAGE_IDX, 0, numValues, + CompressionMetadata(min, max, CompressionType::CONSTANT, alp::state{}, StorageValue{0}, + StorageValue{0}, physicalType)}; +} + +template +alp::state getAlpMetadata(const T* buffer, uint64_t numValues) { + alp::state alpMetadata; + std::vector sampleBuffer(alp::config::SAMPLES_PER_ROWGROUP); + alp::AlpEncode::init(buffer, 0, numValues, sampleBuffer.data(), alpMetadata); + + if (alpMetadata.scheme == alp::SCHEME::ALP) { + if (alpMetadata.k_combinations > 1) { + alp::AlpEncode::find_best_exponent_factor_from_combinations( + alpMetadata.best_k_combinations, alpMetadata.k_combinations, buffer, numValues, + alpMetadata.fac, alpMetadata.exp); + } else { + KU_ASSERT(alpMetadata.best_k_combinations.size() == 1); + alpMetadata.exp = alpMetadata.best_k_combinations[0].first; + alpMetadata.fac = alpMetadata.best_k_combinations[0].second; + } + } + + return alpMetadata; +} + +template +CompressionMetadata createFloatMetadata(CompressionType compressionType, + PhysicalTypeID physicalType, std::span src, alp::state& alpMetadata, StorageValue min, + StorageValue max) { + using EncodedType = typename FloatCompression::EncodedType; + + offset_vec_t unsuccessfulEncodeIdxes; + std::vector floatEncodedValues(src.size()); + std::optional firstSuccessfulEncode; + size_t exceptionCount = 0; + for (offset_t i = 0; i < src.size(); ++i) { + const T& val = src[i]; + const auto encoded_value = + alp::AlpEncode::encode_value(val, alpMetadata.fac, alpMetadata.exp); + const auto decoded_value = + alp::AlpDecode::decode_value(encoded_value, alpMetadata.fac, alpMetadata.exp); + + if (val == decoded_value) { + floatEncodedValues[i] = encoded_value; + if (!firstSuccessfulEncode.has_value()) { + firstSuccessfulEncode = encoded_value; + } + } else { + unsuccessfulEncodeIdxes.push_back(i); + ++exceptionCount; + } + } + alpMetadata.exceptions_count = exceptionCount; + + if (firstSuccessfulEncode.has_value()) { + for (auto unsuccessfulEncodeIdx : unsuccessfulEncodeIdxes) { + floatEncodedValues[unsuccessfulEncodeIdx] = firstSuccessfulEncode.value(); + } + } + + const auto& [minEncoded, maxEncoded] = + std::minmax_element(floatEncodedValues.begin(), floatEncodedValues.end()); + + return CompressionMetadata(min, max, compressionType, alpMetadata, StorageValue{*minEncoded}, + StorageValue{*maxEncoded}, physicalType); +} +} // namespace + +template +ColumnChunkMetadata GetFloatCompressionMetadata::operator()(std::span buffer, + uint64_t numValues, StorageValue min, StorageValue max) { + const PhysicalTypeID physicalType = + std::same_as ? 
PhysicalTypeID::DOUBLE : PhysicalTypeID::FLOAT; + + if (min == max) { + return getConstantFloatMetadata(physicalType, numValues, min, max); + } + + if (numValues == 0) { + return uncompressedGetMetadata(physicalType, numValues, min, max); + } + + std::span castedBuffer{reinterpret_cast(buffer.data()), (size_t)numValues}; + alp::state alpMetadata = getAlpMetadata(castedBuffer.data(), numValues); + if (alpMetadata.scheme != alp::SCHEME::ALP) { + return uncompressedGetMetadata(physicalType, numValues, min, max); + } + + const auto compMeta = createFloatMetadata(alg->getCompressionType(), physicalType, castedBuffer, + alpMetadata, min, max); + const auto* floatMetadata = compMeta.floatMetadata(); + const auto exceptionCount = floatMetadata->exceptionCount; + + if (exceptionCount * FloatCompression::MAX_EXCEPTION_FACTOR >= numValues) { + return uncompressedGetMetadata(physicalType, numValues, min, max); + } + + const auto numValuesPerPage = compMeta.numValues(LBUG_PAGE_SIZE, dataType); + const auto numPagesForEncoded = ceilDiv(numValues, numValuesPerPage); + const auto numPagesForExceptions = + EncodeException::numPagesFromExceptions(floatMetadata->exceptionCapacity); + return ColumnChunkMetadata(INVALID_PAGE_IDX, numPagesForEncoded + numPagesForExceptions, + numValues, compMeta); +} + +template class GetFloatCompressionMetadata; +template class GetFloatCompressionMetadata; +} // namespace lbug::storage diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/column_chunk_stats.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/column_chunk_stats.cpp new file mode 100644 index 0000000000..4f7188e9cb --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/column_chunk_stats.cpp @@ -0,0 +1,62 @@ +#include "storage/table/column_chunk_stats.h" + +#include "common/type_utils.h" +#include "common/types/types.h" +#include "common/vector/value_vector.h" +#include "storage/table/column_chunk_data.h" + +namespace lbug { +namespace storage { + +void ColumnChunkStats::update(const ColumnChunkData& data, uint64_t offset, uint64_t numValues, + common::PhysicalTypeID physicalType) { + const bool isStorageValueType = + common::TypeUtils::visit(physicalType, [](T) { return StorageValueType; }); + if (isStorageValueType || physicalType == common::PhysicalTypeID::INTERNAL_ID) { + auto [minVal, maxVal] = getMinMaxStorageValue(data, offset, numValues, physicalType); + update(minVal, maxVal, physicalType); + } +} + +void ColumnChunkStats::update(const common::ValueVector& data, uint64_t offset, uint64_t numValues, + common::PhysicalTypeID physicalType) { + const bool isStorageValueType = + common::TypeUtils::visit(physicalType, [](T) { return StorageValueType; }); + if (isStorageValueType || physicalType == common::PhysicalTypeID::INTERNAL_ID) { + auto [minVal, maxVal] = getMinMaxStorageValue(data, offset, numValues, physicalType); + update(minVal, maxVal, physicalType); + } +} + +void ColumnChunkStats::update(std::optional newMin, + std::optional newMax, common::PhysicalTypeID dataType) { + if (!min.has_value() || (newMin.has_value() && min->gt(*newMin, dataType))) { + min = newMin; + } + if (!max.has_value() || (newMax.has_value() && newMax->gt(*max, dataType))) { + max = newMax; + } +} + +void ColumnChunkStats::update(StorageValue val, common::PhysicalTypeID dataType) { + if (!min.has_value() || min->gt(val, dataType)) { + min = val; + } + if (!max.has_value() || val.gt(*max, dataType)) { + max = val; + } +} + +void ColumnChunkStats::reset() { + *this = {}; +} + +void 
MergedColumnChunkStats::merge(const MergedColumnChunkStats& o, + common::PhysicalTypeID dataType) { + stats.update(o.stats.min, o.stats.max, dataType); + guaranteedNoNulls = guaranteedNoNulls && o.guaranteedNoNulls; + guaranteedAllNulls = guaranteedAllNulls && o.guaranteedAllNulls; +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/column_reader_writer.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/column_reader_writer.cpp new file mode 100644 index 0000000000..b2c0b015c0 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/column_reader_writer.cpp @@ -0,0 +1,457 @@ +#include "storage/table/column_reader_writer.h" + +#include "alp/encode.hpp" +#include "common/utils.h" +#include "common/vector/value_vector.h" +#include "storage/compression/float_compression.h" +#include "storage/file_handle.h" +#include "storage/shadow_utils.h" +#include "storage/storage_utils.h" +#include "storage/table/column_chunk_data.h" +#include "storage/table/column_chunk_metadata.h" +#include + +namespace lbug::storage { + +using namespace common; +using namespace transaction; + +namespace { +[[maybe_unused]] bool isPageIdxValid(page_idx_t pageIdx, const ColumnChunkMetadata& metadata) { + return (metadata.getStartPageIdx() <= pageIdx && + pageIdx < metadata.getStartPageIdx() + metadata.getNumPages()) || + (pageIdx == INVALID_PAGE_IDX && metadata.compMeta.isConstant()); +} + +template +concept WriteToPageHelper = requires(T obj, InputType input, offset_t offset, ElementType element) { + { obj.getValue(offset) } -> std::same_as; + { obj.setValue(offset, element) }; + { obj.getData() } -> std::same_as; +} && std::is_constructible_v; + +template +struct WriteToBufferHelper { + WriteToBufferHelper(const uint8_t* inputBuffer, size_t numValues) + : inputBuffer(inputBuffer), + outputBuffer(std::make_unique(numValues * sizeof(T))) {} + T getValue(offset_t offset) const { return reinterpret_cast(inputBuffer)[offset]; } + void setValue(offset_t offset, T element) { + reinterpret_cast(outputBuffer.get())[offset] = element; + } + const uint8_t* getData() const { return outputBuffer.get(); } + + const uint8_t* inputBuffer; + std::unique_ptr outputBuffer; +}; +static_assert(WriteToPageHelper, const uint8_t*, double>); +static_assert(WriteToPageHelper, const uint8_t*, float>); + +template +struct WriteToVectorHelper { + explicit WriteToVectorHelper(ValueVector* inputVec, size_t /*numValues*/) + : inputVec(inputVec), + outputVec(std::is_same_v ? 
LogicalTypeID::DOUBLE : LogicalTypeID::FLOAT) {} + T getValue(offset_t offset) const { return inputVec->getValue(offset); } + void setValue(offset_t offset, T element) { outputVec.setValue(offset, element); } + ValueVector* getData() { return &outputVec; } + + ValueVector* inputVec; + ValueVector outputVec; +}; +static_assert(WriteToPageHelper, ValueVector*, double>); +static_assert(WriteToPageHelper, ValueVector*, float>); + +template +decltype(auto) getWriteToPageBufferHelper(InputType input, size_t numValues) { + if constexpr (std::is_same_v) { + return WriteToBufferHelper(input, numValues); + } else { + return WriteToVectorHelper(input, numValues); + } +} + +template +class FloatColumnReadWriter; + +class DefaultColumnReadWriter final : public ColumnReadWriter { +public: + DefaultColumnReadWriter(FileHandle* dataFH, ShadowFile* shadowFile) + : ColumnReadWriter(dataFH, shadowFile) {} + + void readCompressedValueToPage(const SegmentState& state, common::offset_t offsetInSegment, + uint8_t* result, uint32_t offsetInResult, + const read_value_from_page_func_t& readFunc) override { + auto cursor = getPageCursorForOffsetInGroup(offsetInSegment, + state.metadata.getStartPageIdx(), state.numValuesPerPage); + readCompressedValue(state.metadata, cursor, offsetInSegment, result, + offsetInResult, readFunc); + } + + void readCompressedValueToVector(const SegmentState& state, common::offset_t offsetInSegment, + common::ValueVector* result, uint32_t offsetInResult, + const read_value_from_page_func_t& readFunc) override { + auto cursor = getPageCursorForOffsetInGroup(offsetInSegment, + state.metadata.getStartPageIdx(), state.numValuesPerPage); + readCompressedValue(state.metadata, cursor, offsetInSegment, result, + offsetInResult, readFunc); + } + + uint64_t readCompressedValuesToPage(const SegmentState& state, uint8_t* result, + uint32_t startOffsetInResult, uint64_t startOffsetInSegment, uint64_t length, + const read_values_from_page_func_t& readFunc, + const std::optional& filterFunc) override { + return readCompressedValues(state, result, startOffsetInResult, startOffsetInSegment, + length, readFunc, filterFunc); + } + + uint64_t readCompressedValuesToVector(const SegmentState& state, common::ValueVector* result, + uint32_t startOffsetInResult, uint64_t startOffsetInSegment, uint64_t length, + const read_values_from_page_func_t& readFunc, + const std::optional& filterFunc) override { + return readCompressedValues(state, result, startOffsetInResult, startOffsetInSegment, + length, readFunc, filterFunc); + } + + void writeValueToPageFromVector(SegmentState& state, common::offset_t offsetInSegment, + common::ValueVector* vectorToWriteFrom, uint32_t posInVectorToWriteFrom, + const write_values_from_vector_func_t& writeFromVectorFunc) override { + writeValuesToPage(state, offsetInSegment, vectorToWriteFrom, posInVectorToWriteFrom, 1, + writeFromVectorFunc, &vectorToWriteFrom->getNullMask()); + } + + void writeValuesToPageFromBuffer(SegmentState& state, offset_t dstOffset, const uint8_t* data, + const NullMask* nullChunkData, offset_t srcOffset, offset_t numValues, + const write_values_func_t& writeFunc) override { + writeValuesToPage(state, dstOffset, data, srcOffset, numValues, writeFunc, nullChunkData); + } + + template + void writeValuesToPage(SegmentState& state, offset_t dstOffset, InputType data, + offset_t srcOffset, offset_t numValues, + const write_values_to_page_func_t& writeFunc, + const NullMask* nullMask) { + auto numValuesWritten = 0u; + auto cursor = 
getPageCursorForOffsetInGroup(dstOffset, state.metadata.getStartPageIdx(), + state.numValuesPerPage); + while (numValuesWritten < numValues) { + KU_ASSERT( + cursor.pageIdx == INVALID_PAGE_IDX /*constant compression*/ || + cursor.pageIdx < state.metadata.getStartPageIdx() + state.metadata.getNumPages()); + auto numValuesToWriteInPage = std::min(numValues - numValuesWritten, + state.numValuesPerPage - cursor.elemPosInPage); + updatePageWithCursor(cursor, [&](auto frame, auto offsetInPage) { + if constexpr (std::is_same_v) { + writeFunc(frame, offsetInPage, data, srcOffset + numValuesWritten, + numValuesToWriteInPage, state.metadata.compMeta); + } else { + writeFunc(frame, offsetInPage, data, srcOffset + numValuesWritten, + numValuesToWriteInPage, state.metadata.compMeta, nullMask); + } + }); + numValuesWritten += numValuesToWriteInPage; + cursor.nextPage(); + } + } + + template + void readCompressedValue(const ColumnChunkMetadata& metadata, PageCursor cursor, + common::offset_t /*offsetInSegment*/, OutputType result, uint32_t offsetInResult, + const read_value_from_page_func_t& readFunc) { + + readFromPage(cursor.pageIdx, [&](uint8_t* frame) -> void { + readFunc(frame, cursor, result, offsetInResult, 1 /* numValuesToRead */, + metadata.compMeta); + }); + } + + template + uint64_t readCompressedValues(const SegmentState& state, OutputType result, + uint32_t startOffsetInResult, uint64_t startOffsetInSegment, uint64_t length, + const read_values_from_page_func_t& readFunc, + const std::optional& filterFunc) { + const ColumnChunkMetadata& chunkMeta = state.metadata; + if (length == 0) { + return 0; + } + + auto pageCursor = getPageCursorForOffsetInGroup(startOffsetInSegment, + chunkMeta.getStartPageIdx(), state.numValuesPerPage); + KU_ASSERT(isPageIdxValid(pageCursor.pageIdx, chunkMeta)); + + uint64_t numValuesScanned = 0; + while (numValuesScanned < length) { + uint64_t numValuesToScanInPage = std::min( + state.numValuesPerPage - pageCursor.elemPosInPage, length - numValuesScanned); + KU_ASSERT(isPageIdxValid(pageCursor.pageIdx, chunkMeta)); + if (!filterFunc.has_value() || + filterFunc.value()(numValuesScanned, numValuesScanned + numValuesToScanInPage)) { + + const auto readFromPageFunc = [&](uint8_t* frame) -> void { + readFunc(frame, pageCursor, result, numValuesScanned + startOffsetInResult, + numValuesToScanInPage, chunkMeta.compMeta); + }; + readFromPage(pageCursor.pageIdx, std::cref(readFromPageFunc)); + } + numValuesScanned += numValuesToScanInPage; + pageCursor.nextPage(); + } + + return numValuesScanned; + } +}; + +template +class FloatColumnReadWriter final : public ColumnReadWriter { +public: + FloatColumnReadWriter(FileHandle* dataFH, ShadowFile* shadowFile) + : ColumnReadWriter(dataFH, shadowFile), + defaultReader(std::make_unique(dataFH, shadowFile)) {} + + void readCompressedValueToPage(const SegmentState& state, common::offset_t offsetInSegment, + uint8_t* result, uint32_t offsetInResult, + const read_value_from_page_func_t& readFunc) override { + readCompressedValue(state, offsetInSegment, result, offsetInResult, readFunc); + } + + void readCompressedValueToVector(const SegmentState& state, common::offset_t offsetInSegment, + common::ValueVector* result, uint32_t offsetInResult, + const read_value_from_page_func_t& readFunc) override { + readCompressedValue(state, offsetInSegment, result, offsetInResult, readFunc); + } + + uint64_t readCompressedValuesToPage(const SegmentState& state, uint8_t* result, + uint32_t startOffsetInResult, uint64_t startOffsetInSegment, uint64_t 
length, + const read_values_from_page_func_t& readFunc, + const std::optional& filterFunc) override { + return readCompressedValues(state, result, startOffsetInResult, startOffsetInSegment, + length, readFunc, filterFunc); + } + + uint64_t readCompressedValuesToVector(const SegmentState& state, common::ValueVector* result, + uint32_t startOffsetInResult, uint64_t startOffsetInSegment, uint64_t length, + const read_values_from_page_func_t& readFunc, + const std::optional& filterFunc) override { + return readCompressedValues(state, result, startOffsetInResult, startOffsetInSegment, + length, readFunc, filterFunc); + } + + void writeValueToPageFromVector(SegmentState& state, common::offset_t offsetInSegment, + common::ValueVector* vectorToWriteFrom, uint32_t posInVectorToWriteFrom, + const write_values_from_vector_func_t& writeFromVectorFunc) override { + if (state.metadata.compMeta.compression != CompressionType::ALP) { + return defaultReader->writeValueToPageFromVector(state, offsetInSegment, + vectorToWriteFrom, posInVectorToWriteFrom, writeFromVectorFunc); + } + + writeValuesToPage(state, offsetInSegment, vectorToWriteFrom, posInVectorToWriteFrom, 1, + writeFromVectorFunc, &vectorToWriteFrom->getNullMask()); + } + + void writeValuesToPageFromBuffer(SegmentState& state, offset_t dstOffset, const uint8_t* data, + const NullMask* nullChunkData, offset_t srcOffset, offset_t numValues, + const write_values_func_t& writeFunc) override { + if (state.metadata.compMeta.compression != CompressionType::ALP) { + defaultReader->writeValuesToPageFromBuffer(state, dstOffset, data, nullChunkData, + srcOffset, numValues, writeFunc); + return; + } + + writeValuesToPage(state, dstOffset, data, srcOffset, numValues, writeFunc, nullChunkData); + } + +private: + template + void patchFloatExceptions(const SegmentState& state, offset_t startOffsetInChunk, + size_t numValuesToScan, OutputType result, offset_t startOffsetInResult, + const std::optional& filterFunc) { + auto* exceptionChunk = state.getExceptionChunkConst(); + offset_t curExceptionIdx = + exceptionChunk->findFirstExceptionAtOrPastOffset(startOffsetInChunk); + for (; curExceptionIdx < exceptionChunk->getExceptionCount(); ++curExceptionIdx) { + const auto curException = exceptionChunk->getExceptionAt(curExceptionIdx); + KU_ASSERT(curExceptionIdx == 0 || + curException.posInChunk > + exceptionChunk->getExceptionAt(curExceptionIdx - 1).posInChunk); + KU_ASSERT(curException.posInChunk >= curExceptionIdx); + if (curException.posInChunk >= startOffsetInChunk + numValuesToScan) { + break; + } + const offset_t offsetInResult = + startOffsetInResult + curException.posInChunk - startOffsetInChunk; + if (!filterFunc.has_value() || filterFunc.value()(offsetInResult, offsetInResult + 1)) { + if constexpr (std::is_same_v) { + reinterpret_cast(result)[offsetInResult] = curException.value; + } else { + static_assert(std::is_same_v>); + reinterpret_cast(result->getData())[offsetInResult] = curException.value; + } + } + } + } + + template + void readCompressedValue(const SegmentState& state, common::offset_t offsetInSegment, + OutputType result, uint32_t offsetInResult, + const read_value_from_page_func_t& readFunc) { + RUNTIME_CHECK(const ColumnChunkMetadata& metadata = state.metadata); + KU_ASSERT(metadata.compMeta.compression == CompressionType::ALP || + metadata.compMeta.compression == CompressionType::CONSTANT || + metadata.compMeta.compression == CompressionType::UNCOMPRESSED); + std::optional filterFunc{}; + readCompressedValues(state, result, offsetInResult, 
offsetInSegment, 1, readFunc, + filterFunc); + } + + template + uint64_t readCompressedValues(const SegmentState& state, OutputType result, + uint32_t startOffsetInResult, uint64_t startOffsetInSegment, uint64_t length, + const read_values_from_page_func_t& readFunc, + const std::optional& filterFunc) { + const ColumnChunkMetadata& metadata = state.metadata; + KU_ASSERT(metadata.compMeta.compression == CompressionType::ALP || + metadata.compMeta.compression == CompressionType::CONSTANT || + metadata.compMeta.compression == CompressionType::UNCOMPRESSED); + + const uint64_t numValuesScanned = + defaultReader->readCompressedValues(state, result, startOffsetInResult, + startOffsetInSegment, length, readFunc, std::optional{filterFunc}); + + if (metadata.compMeta.compression == CompressionType::ALP && numValuesScanned > 0) { + // we pass in copies of the filter func as it can hold state which may need resetting + // between scanning passes + patchFloatExceptions(state, startOffsetInSegment, length, result, startOffsetInResult, + std::optional{filterFunc}); + } + + return numValuesScanned; + } + + template + void writeValuesToPage(SegmentState& state, offset_t offsetInSegment, InputType data, + uint32_t srcOffset, size_t numValues, + const write_values_to_page_func_t& writeFunc, + const NullMask* nullMask) { + const ColumnChunkMetadata& metadata = state.metadata; + + auto* exceptionChunk = state.getExceptionChunk(); + + auto writeToPageBufferHelper = getWriteToPageBufferHelper(data, numValues); + + const auto bitpackHeader = FloatCompression::getBitpackInfo(state.metadata.compMeta); + offset_t curExceptionIdx = + exceptionChunk->findFirstExceptionAtOrPastOffset(offsetInSegment); + + const auto maxWrittenPosInChunk = offsetInSegment + numValues; + uint32_t curExceptionPosInChunk = + (curExceptionIdx < exceptionChunk->getExceptionCount()) ? + exceptionChunk->getExceptionAt(curExceptionIdx).posInChunk : + maxWrittenPosInChunk; + + for (size_t i = 0; i < numValues; ++i) { + const size_t writeOffset = offsetInSegment + i; + const size_t readOffset = srcOffset + i; + + if (nullMask && nullMask->isNull(readOffset)) { + continue; + } + + while (curExceptionPosInChunk < writeOffset) { + ++curExceptionIdx; + if (curExceptionIdx < exceptionChunk->getExceptionCount()) { + curExceptionPosInChunk = + exceptionChunk->getExceptionAt(curExceptionIdx).posInChunk; + } else { + curExceptionPosInChunk = maxWrittenPosInChunk; + } + } + + const T newValue = writeToPageBufferHelper.getValue(readOffset); + const auto* floatMetadata = metadata.compMeta.floatMetadata(); + const auto encodedValue = + alp::AlpEncode::encode_value(newValue, floatMetadata->fac, floatMetadata->exp); + const T decodedValue = alp::AlpDecode::decode_value(encodedValue, floatMetadata->fac, + floatMetadata->exp); + + bool newValueIsException = newValue != decodedValue; + writeToPageBufferHelper.setValue(i, + newValueIsException ? 
bitpackHeader.offset : newValue); + + // if the previous value was an exception + // either overwrite it (if the new value is also an exception) or remove it + if (curExceptionPosInChunk == writeOffset) { + if (newValueIsException) { + exceptionChunk->writeException( + EncodeException{newValue, safeIntegerConversion(writeOffset)}, + curExceptionIdx); + } else { + exceptionChunk->removeExceptionAt(curExceptionIdx); + } + } else if (newValueIsException) { + exceptionChunk->addException( + EncodeException{newValue, safeIntegerConversion(writeOffset)}); + } + } + + defaultReader->writeValuesToPage(state, offsetInSegment, writeToPageBufferHelper.getData(), + 0, numValues, writeFunc, nullMask); + } + + std::unique_ptr defaultReader; +}; + +} // namespace + +std::unique_ptr ColumnReadWriterFactory::createColumnReadWriter( + PhysicalTypeID dataType, FileHandle* dataFH, ShadowFile* shadowFile) { + switch (dataType) { + case PhysicalTypeID::FLOAT: + return std::make_unique>(dataFH, shadowFile); + case PhysicalTypeID::DOUBLE: + return std::make_unique>(dataFH, shadowFile); + default: + return std::make_unique(dataFH, shadowFile); + } +} + +ColumnReadWriter::ColumnReadWriter(FileHandle* dataFH, ShadowFile* shadowFile) + : dataFH(dataFH), shadowFile(shadowFile) {} + +void ColumnReadWriter::readFromPage(page_idx_t pageIdx, + const std::function& readFunc) const { + // For constant compression, call read on a nullptr since there is no data on disk and + // decompression only requires metadata + if (pageIdx == INVALID_PAGE_IDX) { + return readFunc(nullptr); + } + dataFH->optimisticReadPage(pageIdx, readFunc); +} + +void ColumnReadWriter::updatePageWithCursor(PageCursor cursor, + const std::function& writeOp) const { + if (cursor.pageIdx == INVALID_PAGE_IDX) { + writeOp(nullptr, cursor.elemPosInPage); + return; + } + KU_ASSERT(cursor.pageIdx < dataFH->getNumPages()); + + ShadowUtils::updatePage(*dataFH, cursor.pageIdx, false /*insertingNewPage*/, *shadowFile, + [&](auto frame) { writeOp(frame, cursor.elemPosInPage); }); +} + +// This function returns the page pageIdx of the page where element will be found and the pos of +// the element in the page as the offset. 
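
Aside (illustrative, not part of the diff): the comment above describes how getPageCursorForPos splits an element position into a page index and a position within that page, and getPageCursorForOffsetInGroup then shifts the page index by the segment's first page. Below is a minimal sketch of that arithmetic with hypothetical names and an example values-per-page figure.

#include <cstdint>
#include <iostream>

struct Cursor {
    uint32_t pageIdx;
    uint32_t posInPage;
};

// Split an element position into (page, position-in-page), then shift the page
// index by the first page of the segment's page range.
static Cursor cursorForOffset(uint64_t elementPos, uint64_t valuesPerPage, uint32_t firstPageIdx) {
    Cursor cursor{static_cast<uint32_t>(elementPos / valuesPerPage),
        static_cast<uint32_t>(elementPos % valuesPerPage)};
    cursor.pageIdx += firstPageIdx;
    return cursor;
}

int main() {
    // Element 2500 with 1024 values per page, in a segment starting at page 7:
    // 2500 / 1024 = 2 (remainder 452), so the cursor lands on page 9, position 452.
    const Cursor c = cursorForOffset(2500, 1024, 7);
    std::cout << c.pageIdx << ' ' << c.posInPage << '\n';  // prints: 9 452
    return 0;
}
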
+static PageCursor getPageCursorForPos(uint64_t elementPos, uint32_t numElementsPerPage) { + KU_ASSERT((elementPos / numElementsPerPage) < UINT32_MAX); + return PageCursor{static_cast(elementPos / numElementsPerPage), + static_cast(elementPos % numElementsPerPage)}; +} + +PageCursor ColumnReadWriter::getPageCursorForOffsetInGroup(offset_t offsetInSegment, + page_idx_t groupPageIdx, uint64_t numValuesPerPage) { + auto pageCursor = getPageCursorForPos(offsetInSegment, numValuesPerPage); + pageCursor.pageIdx += groupPageIdx; + return pageCursor; +} + +} // namespace lbug::storage diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/compression_flush_buffer.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/compression_flush_buffer.cpp new file mode 100644 index 0000000000..7408eddb4a --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/compression_flush_buffer.cpp @@ -0,0 +1,173 @@ +#include "storage/table/compression_flush_buffer.h" + +#include + +#include "common/types/types.h" +#include "storage/file_handle.h" +#include "storage/page_manager.h" +#include "storage/table/column_chunk_data.h" +#include + +namespace lbug::storage { +using namespace common; +using namespace transaction; + +ColumnChunkMetadata uncompressedFlushBuffer(std::span buffer, FileHandle* dataFH, + const PageRange& entry, const ColumnChunkMetadata& metadata) { + KU_ASSERT(dataFH->getNumPages() >= entry.startPageIdx + entry.numPages); + KU_ASSERT(buffer.size_bytes() <= entry.numPages * LBUG_PAGE_SIZE); + dataFH->writePagesToFile(buffer.data(), buffer.size(), entry.startPageIdx); + return ColumnChunkMetadata(entry.startPageIdx, entry.numPages, metadata.numValues, + metadata.compMeta); +} + +ColumnChunkMetadata CompressedFlushBuffer::operator()(std::span buffer, + FileHandle* dataFH, const PageRange& entry, const ColumnChunkMetadata& metadata) const { + auto valuesRemaining = metadata.numValues; + const uint8_t* bufferStart = buffer.data(); + const auto compressedBuffer = std::make_unique(LBUG_PAGE_SIZE); + auto numPages = 0u; + const auto numValuesPerPage = metadata.compMeta.numValues(LBUG_PAGE_SIZE, dataType); + KU_ASSERT(numValuesPerPage * entry.numPages >= metadata.numValues); + while (valuesRemaining > 0) { + const auto compressedSize = alg->compressNextPage(bufferStart, valuesRemaining, + compressedBuffer.get(), LBUG_PAGE_SIZE, metadata.compMeta); + // Avoid underflows (when data is compressed to nothing, numValuesPerPage may be + // UINT64_MAX) + if (numValuesPerPage > valuesRemaining) { + valuesRemaining = 0; + } else { + valuesRemaining -= numValuesPerPage; + } + if (compressedSize < LBUG_PAGE_SIZE) { + memset(compressedBuffer.get() + compressedSize, 0, LBUG_PAGE_SIZE - compressedSize); + } + KU_ASSERT(numPages < entry.numPages); + KU_ASSERT(dataFH->getNumPages() >= entry.startPageIdx + numPages); + dataFH->writePageToFile(compressedBuffer.get(), entry.startPageIdx + numPages); + numPages++; + } + // Make sure that the on-disk file is the right length + if (!dataFH->isInMemoryMode() && numPages < entry.numPages) { + memset(compressedBuffer.get(), 0, LBUG_PAGE_SIZE); + while (numPages < entry.numPages) { + dataFH->writePageToFile(compressedBuffer.get(), entry.startPageIdx + numPages); + ++numPages; + } + } + return ColumnChunkMetadata(entry.startPageIdx, entry.numPages, metadata.numValues, + metadata.compMeta); +} + +namespace { +template +std::pair, uint64_t> flushCompressedFloats(const CompressionAlg& alg, + PhysicalTypeID dataType, std::span buffer, FileHandle* dataFH, + const 
PageRange& entry, const ColumnChunkMetadata& metadata) { + const auto& castedAlg = ku_dynamic_cast&>(alg); + + const auto* floatMetadata = metadata.compMeta.floatMetadata(); + KU_ASSERT(floatMetadata->exceptionCapacity >= floatMetadata->exceptionCount); + + auto valuesRemaining = metadata.numValues; + KU_ASSERT(valuesRemaining <= buffer.size_bytes() / sizeof(T)); + + const size_t exceptionBufferSize = + EncodeException::numPagesFromExceptions(floatMetadata->exceptionCapacity) * + LBUG_PAGE_SIZE; + auto exceptionBuffer = std::make_unique(exceptionBufferSize); + std::byte* exceptionBufferCursor = reinterpret_cast(exceptionBuffer.get()); + + const auto numValuesPerPage = metadata.compMeta.numValues(LBUG_PAGE_SIZE, dataType); + KU_ASSERT(numValuesPerPage * metadata.getNumDataPages(dataType) >= metadata.numValues); + + const auto compressedBuffer = std::make_unique(LBUG_PAGE_SIZE); + const uint8_t* bufferCursor = buffer.data(); + auto numPages = 0u; + size_t remainingExceptionBufferSize = exceptionBufferSize; + RUNTIME_CHECK(size_t totalExceptionCount = 0); + + while (valuesRemaining > 0) { + uint64_t pageExceptionCount = 0; + (void)castedAlg.compressNextPageWithExceptions(bufferCursor, + metadata.numValues - valuesRemaining, valuesRemaining, compressedBuffer.get(), + LBUG_PAGE_SIZE, EncodeExceptionView{exceptionBufferCursor}, + remainingExceptionBufferSize, pageExceptionCount, metadata.compMeta); + + exceptionBufferCursor += pageExceptionCount * EncodeException::sizeInBytes(); + remainingExceptionBufferSize -= pageExceptionCount * EncodeException::sizeInBytes(); + RUNTIME_CHECK(totalExceptionCount += pageExceptionCount); + + // Avoid underflows (when data is compressed to nothing, numValuesPerPage may be + // UINT64_MAX) + if (numValuesPerPage > valuesRemaining) { + valuesRemaining = 0; + } else { + valuesRemaining -= numValuesPerPage; + } + KU_ASSERT(numPages < entry.numPages); + KU_ASSERT(dataFH->getNumPages() >= entry.startPageIdx + numPages); + dataFH->writePageToFile(compressedBuffer.get(), entry.startPageIdx + numPages); + numPages++; + } + + KU_ASSERT(totalExceptionCount == floatMetadata->exceptionCount); + + return {std::move(exceptionBuffer), exceptionBufferSize}; +} + +template +void flushALPExceptions(std::span exceptionBuffer, FileHandle* dataFH, + const PageRange& entry, const ColumnChunkMetadata& metadata) { + const auto encodedType = std::is_same_v ? 
PhysicalTypeID::ALP_EXCEPTION_FLOAT : + PhysicalTypeID::ALP_EXCEPTION_DOUBLE; + // we don't care about the min/max values for exceptions + const auto preExceptionMetadata = uncompressedGetMetadata(encodedType, + metadata.compMeta.floatMetadata()->exceptionCapacity, StorageValue{0}, StorageValue{0}); + + const auto exceptionStartPageIdx = + entry.startPageIdx + entry.numPages - preExceptionMetadata.getNumPages(); + KU_ASSERT(exceptionStartPageIdx + preExceptionMetadata.getNumPages() <= dataFH->getNumPages()); + PageRange exceptionBlock{exceptionStartPageIdx, preExceptionMetadata.getNumPages()}; + + CompressedFlushBuffer exceptionFlushBuffer{ + std::make_shared(EncodeException::sizeInBytes()), encodedType}; + (void)exceptionFlushBuffer.operator()(exceptionBuffer, dataFH, exceptionBlock, + preExceptionMetadata); +} +} // namespace + +template +CompressedFloatFlushBuffer::CompressedFloatFlushBuffer(std::shared_ptr alg, + PhysicalTypeID dataType) + : alg{std::move(alg)}, dataType{dataType} {} + +template +CompressedFloatFlushBuffer::CompressedFloatFlushBuffer(std::shared_ptr alg, + const LogicalType& dataType) + : CompressedFloatFlushBuffer(alg, dataType.getPhysicalType()) {} + +template +ColumnChunkMetadata CompressedFloatFlushBuffer::operator()(std::span buffer, + FileHandle* dataFH, const PageRange& entry, const ColumnChunkMetadata& metadata) const { + if (metadata.compMeta.compression == CompressionType::UNCOMPRESSED) { + return CompressedFlushBuffer{std::make_shared(dataType), dataType}.operator()( + buffer, dataFH, entry, metadata); + } + // FlushBuffer should not be called with constant compression + KU_ASSERT(metadata.compMeta.compression == CompressionType::ALP); + + auto [exceptionBuffer, exceptionBufferSize] = + flushCompressedFloats(*alg, dataType, buffer, dataFH, entry, metadata); + + flushALPExceptions(std::span(exceptionBuffer.get(), exceptionBufferSize), + dataFH, entry, metadata); + + return ColumnChunkMetadata(entry.startPageIdx, entry.numPages, metadata.numValues, + metadata.compMeta); +} + +template class CompressedFloatFlushBuffer; +template class CompressedFloatFlushBuffer; + +} // namespace lbug::storage diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/csr_chunked_node_group.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/csr_chunked_node_group.cpp new file mode 100644 index 0000000000..cd36b63101 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/csr_chunked_node_group.cpp @@ -0,0 +1,470 @@ +#include "storage/table/csr_chunked_node_group.h" + +#include "common/serializer/deserializer.h" +#include "common/types/types.h" +#include "storage/buffer_manager/memory_manager.h" +#include "storage/enums/residency_state.h" +#include "storage/page_allocator.h" +#include "storage/storage_utils.h" +#include "storage/table/column.h" +#include "storage/table/column_chunk.h" +#include "storage/table/column_chunk_data.h" +#include "storage/table/csr_node_group.h" +#include "transaction/transaction.h" + +using namespace lbug::common; + +namespace lbug { +namespace storage { + +CSRRegion::CSRRegion(idx_t regionIdx, idx_t level) : regionIdx{regionIdx}, level{level} { + const auto leftLeafRegion = regionIdx << level; + leftNodeOffset = leftLeafRegion << StorageConfig::CSR_LEAF_REGION_SIZE_LOG2; + rightNodeOffset = leftNodeOffset + (StorageConfig::CSR_LEAF_REGION_SIZE << level) - 1; + if (rightNodeOffset >= StorageConfig::NODE_GROUP_SIZE) { + // The max right node offset should be NODE_GROUP_SIZE - 1. 
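+ // Higher-level regions span (CSR_LEAF_REGION_SIZE << level) nodes and can therefore run
+ // past the end of the node group, so the right bound is clamped to the last node.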
+ rightNodeOffset = StorageConfig::NODE_GROUP_SIZE - 1; + } +} + +bool CSRRegion::isWithin(const CSRRegion& other) const { + if (other.level <= level) { + return false; + } + const auto leftRegionIdx = getLeftLeafRegionIdx(); + const auto rightRegionIdx = getRightLeafRegionIdx(); + const auto otherLeftRegionIdx = other.getLeftLeafRegionIdx(); + const auto otherRightRegionIdx = other.getRightLeafRegionIdx(); + return leftRegionIdx >= otherLeftRegionIdx && rightRegionIdx <= otherRightRegionIdx; +} + +CSRRegion CSRRegion::upgradeLevel(const std::vector& leafRegions, + const CSRRegion& region) { + const auto regionIdx = region.regionIdx >> 1; + CSRRegion newRegion{regionIdx, region.level + 1}; + newRegion.hasUpdates.resize(region.hasUpdates.size(), false); + const idx_t leftLeafRegionIdx = newRegion.getLeftLeafRegionIdx(); + const idx_t rightLeafRegionIdx = newRegion.getRightLeafRegionIdx(); + for (auto leafRegionIdx = leftLeafRegionIdx; leafRegionIdx <= rightLeafRegionIdx; + leafRegionIdx++) { + KU_ASSERT(leafRegionIdx < leafRegions.size()); + newRegion.sizeChange += leafRegions[leafRegionIdx].sizeChange; + newRegion.hasPersistentDeletions |= leafRegions[leafRegionIdx].hasPersistentDeletions; + newRegion.hasInsertions |= leafRegions[leafRegionIdx].hasInsertions; + for (auto columnID = 0u; columnID < leafRegions[leafRegionIdx].hasUpdates.size(); + columnID++) { + newRegion.hasUpdates[columnID] = + static_cast(newRegion.hasUpdates[columnID]) || + static_cast(leafRegions[leafRegionIdx].hasUpdates[columnID]); + } + } + return newRegion; +} + +ChunkedCSRHeader::ChunkedCSRHeader(MemoryManager& memoryManager, bool enableCompression, + uint64_t capacity, ResidencyState residencyState) { + offset = std::make_unique(memoryManager, LogicalType::UINT64(), capacity, + enableCompression, residencyState, false); + length = std::make_unique(memoryManager, LogicalType::UINT64(), capacity, + enableCompression, residencyState, false); +} + +offset_t ChunkedCSRHeader::getStartCSROffset(offset_t nodeOffset) const { + // TODO(Guodong): I think we can simplify the check here by getting rid of some of the + // conditions. + const auto numValues = offset->getNumValues(); + if (nodeOffset == 0 || numValues == 0) { + return 0; + } + if (randomLookup) { + return offset->getValue(0); + } + return offset->getValue(nodeOffset >= numValues ? (numValues - 1) : nodeOffset - 1); +} + +offset_t ChunkedCSRHeader::getEndCSROffset(offset_t nodeOffset) const { + // TODO(Guodong): I think we can simplify the check here by getting rid of some of the + // conditions. + const auto numValues = offset->getNumValues(); + if (numValues == 0) { + return 0; + } + if (randomLookup) { + return offset->getValue(nodeOffset == 0 ? 0 : 1); + } + return offset->getValue(nodeOffset >= numValues ? (numValues - 1) : nodeOffset); +} + +length_t ChunkedCSRHeader::getCSRLength(offset_t nodeOffset) const { + const auto offset = randomLookup ? 0 : nodeOffset; + return offset >= length->getNumValues() ? 
0 : length->getValue(offset); +} + +length_t ChunkedCSRHeader::getGapSize(offset_t nodeOffset) const { + return getEndCSROffset(nodeOffset) - getStartCSROffset(nodeOffset) - getCSRLength(nodeOffset); +} + +bool ChunkedCSRHeader::sanityCheck() const { + if (offset->getNumValues() != length->getNumValues()) { + return false; + } + if (offset->getNumValues() == 0) { + return true; + } + if (offset->getValue(0) < length->getValue(0)) { + return false; + } + for (auto i = 1u; i < offset->getNumValues(); i++) { + if (offset->getValue(i - 1) + length->getValue(i) > + offset->getValue(i)) { + return false; + } + } + return true; +} + +offset_vec_t ChunkedCSRHeader::populateStartCSROffsetsFromLength(bool leaveGaps) const { + const auto numNodes = length->getNumValues(); + const auto numLeafRegions = getNumRegions(); + offset_t leftCSROffset = 0; + offset_vec_t rightCSROffsetOfRegions; + rightCSROffsetOfRegions.reserve(numLeafRegions); + for (auto regionIdx = 0u; regionIdx < numLeafRegions; regionIdx++) { + CSRRegion region{regionIdx, 0 /* level*/}; + length_t numRelsInRegion = 0; + const auto rightNodeOffset = std::min(region.rightNodeOffset, numNodes - 1); + // Populate start csr offset for each node in the region. + offset->mapValues( + [&](auto& value, auto nodeOffset) { + value = leftCSROffset + numRelsInRegion; + numRelsInRegion += getCSRLength(nodeOffset); + }, + region.leftNodeOffset, rightNodeOffset); + // Update lastLeftCSROffset for next region. + leftCSROffset += numRelsInRegion; + if (leaveGaps) { + leftCSROffset += computeGapFromLength(numRelsInRegion); + } + rightCSROffsetOfRegions.push_back(leftCSROffset); + } + return rightCSROffsetOfRegions; +} + +void ChunkedCSRHeader::populateEndCSROffsetFromStartAndLength() const { + [[maybe_unused]] const auto numNodes = length->getNumValues(); + KU_ASSERT(offset->getNumValues() == numNodes); + // TODO(bmwinger): maybe there's a way of also vectorizing this for the length chunk, E.g. 
a + // forEach over two values + offset->mapValues( + [&](offset_t& offset, auto i) { offset += length->getValue(i); }); +} + +void ChunkedCSRHeader::finalizeCSRRegionEndOffsets( + const offset_vec_t& rightCSROffsetOfRegions) const { + const auto numNodes = length->getNumValues(); + const auto numLeafRegions = getNumRegions(); + KU_ASSERT(numLeafRegions == rightCSROffsetOfRegions.size()); + for (auto regionIdx = 0u; regionIdx < numLeafRegions; regionIdx++) { + CSRRegion region{regionIdx, 0 /* level*/}; + const auto rightNodeOffset = std::min(region.rightNodeOffset, numNodes - 1); + offset->setValue(rightCSROffsetOfRegions[regionIdx], rightNodeOffset); + } +} + +idx_t ChunkedCSRHeader::getNumRegions() const { + const auto numNodes = length->getNumValues(); + KU_ASSERT(offset->getNumValues() == numNodes); + return (numNodes + StorageConfig::CSR_LEAF_REGION_SIZE - 1) / + StorageConfig::CSR_LEAF_REGION_SIZE; +} + +void ChunkedCSRHeader::populateRegionCSROffsets(const CSRRegion& region, + const ChunkedCSRHeader& oldHeader) const { + KU_ASSERT(region.level <= CSRNodeGroup::DEFAULT_PACKED_CSR_INFO.calibratorTreeHeight); + const auto leftNodeOffset = region.leftNodeOffset; + const auto rightNodeOffset = region.rightNodeOffset; + const auto leftCSROffset = oldHeader.getStartCSROffset(leftNodeOffset); + const auto oldRightCSROffset = oldHeader.getEndCSROffset(rightNodeOffset); + length_t numRelsInRegion = 0u; + // TODO(bmwinger): should be able to vectorize this somewhat + for (auto i = leftNodeOffset; i <= rightNodeOffset; i++) { + numRelsInRegion += length->getValue(i); + offset->setValue(leftCSROffset + numRelsInRegion, i); + } + // We should keep the region stable and the old right CSR offset is the end of the region. + KU_ASSERT(offset->getValue(rightNodeOffset) <= oldRightCSROffset); + offset->setValue(oldRightCSROffset, rightNodeOffset); +} + +void ChunkedCSRHeader::populateEndCSROffsets(const offset_vec_t& gaps) const { + KU_ASSERT(offset->getNumValues() == length->getNumValues()); + KU_ASSERT(offset->getNumValues() == gaps.size()); + offset->mapValues([&](offset_t& offset, auto i) { offset = gaps[i]; }); +} + +length_t ChunkedCSRHeader::computeGapFromLength(length_t length) { + return StorageUtils::divideAndRoundUpTo(length, StorageConstants::PACKED_CSR_DENSITY) - length; +} + +std::unique_ptr InMemChunkedCSRNodeGroup::flush( + transaction::Transaction* transaction, PageAllocator& pageAllocator) { + auto csrOffset = flushInternal(*csrHeader.offset, pageAllocator); + auto csrLength = flushInternal(*csrHeader.length, pageAllocator); + std::vector> flushedChunks(getNumColumns()); + for (auto i = 0u; i < getNumColumns(); i++) { + flushedChunks[i] = flushInternal(getColumnChunk(i), pageAllocator); + } + ChunkedCSRHeader newCSRHeader{std::move(csrOffset), std::move(csrLength)}; + auto flushedChunkedGroup = std::make_unique(std::move(newCSRHeader), + std::move(flushedChunks), 0 /*startRowIdx*/); + flushedChunkedGroup->versionInfo = std::make_unique(); + KU_ASSERT(numRows == flushedChunkedGroup->getNumRows()); + flushedChunkedGroup->versionInfo->append(transaction->getID(), 0, numRows); + return flushedChunkedGroup; +} + +void ChunkedCSRNodeGroup ::reclaimStorage(PageAllocator& pageAllocator) const { + ChunkedNodeGroup::reclaimStorage(pageAllocator); + if (csrHeader.offset) { + csrHeader.offset->reclaimStorage(pageAllocator); + } + if (csrHeader.length) { + csrHeader.length->reclaimStorage(pageAllocator); + } +} + +void ChunkedCSRNodeGroup::scanCSRHeader(MemoryManager& memoryManager, + 
CSRNodeGroupCheckpointState& csrState) const { + if (!csrState.oldHeader) { + csrState.oldHeader = std::make_unique(memoryManager, + false /*enableCompression*/, StorageConfig::NODE_GROUP_SIZE); + } + ChunkState headerChunkState; + KU_ASSERT(csrHeader.offset->getResidencyState() == ResidencyState::ON_DISK); + KU_ASSERT(csrHeader.length->getResidencyState() == ResidencyState::ON_DISK); + csrHeader.offset->initializeScanState(headerChunkState, csrState.csrOffsetColumn); + KU_ASSERT(csrState.csrOffsetColumn && csrState.csrLengthColumn); + csrState.csrOffsetColumn->scan(headerChunkState, csrState.oldHeader->offset.get()); + csrHeader.length->initializeScanState(headerChunkState, csrState.csrLengthColumn); + csrState.csrLengthColumn->scan(headerChunkState, csrState.oldHeader->length.get()); +} + +void ChunkedCSRNodeGroup::serialize(Serializer& serializer) const { + KU_ASSERT(csrHeader.offset && csrHeader.length); + serializer.writeDebuggingInfo("csr_header_offset"); + csrHeader.offset->serialize(serializer); + serializer.writeDebuggingInfo("csr_header_length"); + csrHeader.length->serialize(serializer); + ChunkedNodeGroup::serialize(serializer); +} + +std::unique_ptr ChunkedCSRNodeGroup::deserialize(MemoryManager& memoryManager, + Deserializer& deSer) { + std::string key; + deSer.validateDebuggingInfo(key, "csr_header_offset"); + auto offset = ColumnChunk::deserialize(memoryManager, deSer); + deSer.validateDebuggingInfo(key, "csr_header_length"); + auto length = ColumnChunk::deserialize(memoryManager, deSer); + // TODO(Guodong): Rework to reuse ChunkedNodeGroup::deserialize(). + std::vector> chunks; + deSer.validateDebuggingInfo(key, "chunks"); + deSer.deserializeVectorOfPtrs(chunks, + [&](Deserializer& deser) { return ColumnChunk::deserialize(memoryManager, deser); }); + deSer.validateDebuggingInfo(key, "startRowIdx"); + row_idx_t startRowIdx = 0; + deSer.deserializeValue(startRowIdx); + auto chunkedGroup = std::make_unique( + ChunkedCSRHeader{std::move(offset), std::move(length)}, std::move(chunks), startRowIdx); + bool hasVersions = false; + deSer.validateDebuggingInfo(key, "has_version_info"); + deSer.deserializeValue(hasVersions); + if (hasVersions) { + deSer.validateDebuggingInfo(key, "version_info"); + chunkedGroup->versionInfo = VersionInfo::deserialize(deSer); + } + return chunkedGroup; +} + +ChunkedCSRNodeGroup::ChunkedCSRNodeGroup(InMemChunkedCSRNodeGroup& base, + const std::vector& selectedColumns) + : ChunkedNodeGroup{base, selectedColumns}, + csrHeader{std::make_unique(true /*enableCompression*/, + std::move(base.csrHeader.offset)), + std::make_unique(true /*enableCompression*/, + std::move(base.csrHeader.length))} {} + +void InMemChunkedCSRHeader::fillDefaultValues(const offset_t newNumValues) const { + const auto lastCSROffset = getEndCSROffset(length->getNumValues() - 1); + for (auto i = length->getNumValues(); i < newNumValues; i++) { + offset->setValue(lastCSROffset, i); + length->setValue(0, i); + } + KU_ASSERT( + offset->getNumValues() >= newNumValues && length->getNumValues() == offset->getNumValues()); +} + +InMemChunkedCSRHeader::InMemChunkedCSRHeader(MemoryManager& memoryManager, bool enableCompression, + uint64_t capacity) { + offset = ColumnChunkFactory::createColumnChunkData(memoryManager, LogicalType::UINT64(), + enableCompression, capacity, ResidencyState::IN_MEMORY, false); + length = ColumnChunkFactory::createColumnChunkData(memoryManager, LogicalType::UINT64(), + enableCompression, capacity, ResidencyState::IN_MEMORY, false); +} + +offset_t 
InMemChunkedCSRHeader::getStartCSROffset(offset_t nodeOffset) const { + // TODO(Guodong): I think we can simplify the check here by getting rid of some of the + // conditions. + const auto numValues = offset->getNumValues(); + if (nodeOffset == 0 || numValues == 0) { + return 0; + } + if (randomLookup) { + return offset->getValue(0); + } + return offset->getValue(nodeOffset >= numValues ? (numValues - 1) : nodeOffset - 1); +} + +offset_t InMemChunkedCSRHeader::getEndCSROffset(offset_t nodeOffset) const { + // TODO(Guodong): I think we can simplify the check here by getting rid of some of the + // conditions. + const auto numValues = offset->getNumValues(); + if (numValues == 0) { + return 0; + } + if (randomLookup) { + return offset->getValue(nodeOffset == 0 ? 0 : 1); + } + return offset->getValue(nodeOffset >= numValues ? (numValues - 1) : nodeOffset); +} + +length_t InMemChunkedCSRHeader::getCSRLength(offset_t nodeOffset) const { + const auto offset = randomLookup ? 0 : nodeOffset; + return offset >= length->getNumValues() ? 0 : length->getValue(offset); +} + +length_t InMemChunkedCSRHeader::getGapSize(offset_t nodeOffset) const { + return getEndCSROffset(nodeOffset) - getStartCSROffset(nodeOffset) - getCSRLength(nodeOffset); +} + +bool InMemChunkedCSRHeader::sanityCheck() const { + if (offset->getNumValues() != length->getNumValues()) { + return false; + } + if (offset->getNumValues() == 0) { + return true; + } + if (offset->getValue(0) < length->getValue(0)) { + return false; + } + for (auto i = 1u; i < offset->getNumValues(); i++) { + if (offset->getValue(i - 1) + length->getValue(i) > + offset->getValue(i)) { + return false; + } + } + return true; +} + +void InMemChunkedCSRHeader::copyFrom(const InMemChunkedCSRHeader& other) const { + KU_ASSERT(offset->getNumValues() == length->getNumValues()); + KU_ASSERT(other.offset->getNumValues() == other.length->getNumValues()); + KU_ASSERT(other.offset->getCapacity() == offset->getCapacity()); + const auto numOtherValues = other.offset->getNumValues(); + memcpy(offset->getData(), other.offset->getData(), numOtherValues * sizeof(offset_t)); + memcpy(length->getData(), other.length->getData(), numOtherValues * sizeof(length_t)); + const auto lastOffsetInOtherHeader = other.getEndCSROffset(numOtherValues); + const auto numValues = offset->getNumValues(); + for (auto i = numOtherValues; i < numValues; i++) { + offset->setValue(lastOffsetInOtherHeader, i); + length->setValue(0, i); + } +} + +offset_vec_t InMemChunkedCSRHeader::populateStartCSROffsetsFromLength(bool leaveGaps) const { + const auto numNodes = length->getNumValues(); + const auto numLeafRegions = getNumRegions(); + offset_t leftCSROffset = 0; + offset_vec_t rightCSROffsetOfRegions; + rightCSROffsetOfRegions.reserve(numLeafRegions); + for (auto regionIdx = 0u; regionIdx < numLeafRegions; regionIdx++) { + CSRRegion region{regionIdx, 0 /* level*/}; + length_t numRelsInRegion = 0; + const auto rightNodeOffset = std::min(region.rightNodeOffset, numNodes - 1); + // Populate start csr offset for each node in the region. + for (auto nodeOffset = region.leftNodeOffset; nodeOffset <= rightNodeOffset; nodeOffset++) { + offset->setValue(leftCSROffset + numRelsInRegion, nodeOffset); + numRelsInRegion += getCSRLength(nodeOffset); + } + // Update lastLeftCSROffset for next region. 
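+ // When leaveGaps is set, each region also reserves some slack after its rels so future
+ // inserts can be absorbed without shifting neighbouring regions. Illustrative example
+ // (assuming a packed CSR density of 0.8, which may differ from the actual constant): a
+ // region holding 80 rels would reserve ceil(80 / 0.8) - 80 = 20 empty slots.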
+ leftCSROffset += numRelsInRegion; + if (leaveGaps) { + leftCSROffset += computeGapFromLength(numRelsInRegion); + } + rightCSROffsetOfRegions.push_back(leftCSROffset); + } + return rightCSROffsetOfRegions; +} + +void InMemChunkedCSRHeader::populateEndCSROffsetFromStartAndLength() const { + const auto numNodes = length->getNumValues(); + KU_ASSERT(offset->getNumValues() == numNodes); + const auto csrOffsets = reinterpret_cast(offset->getData()); + const auto csrLengths = reinterpret_cast(length->getData()); + for (auto i = 0u; i < numNodes; i++) { + csrOffsets[i] = csrOffsets[i] + csrLengths[i]; + } +} + +void InMemChunkedCSRHeader::finalizeCSRRegionEndOffsets( + const offset_vec_t& rightCSROffsetOfRegions) const { + const auto numNodes = length->getNumValues(); + const auto numLeafRegions = getNumRegions(); + KU_ASSERT(numLeafRegions == rightCSROffsetOfRegions.size()); + for (auto regionIdx = 0u; regionIdx < numLeafRegions; regionIdx++) { + CSRRegion region{regionIdx, 0 /* level*/}; + const auto rightNodeOffset = std::min(region.rightNodeOffset, numNodes - 1); + offset->setValue(rightCSROffsetOfRegions[regionIdx], rightNodeOffset); + } +} + +idx_t InMemChunkedCSRHeader::getNumRegions() const { + const auto numNodes = length->getNumValues(); + KU_ASSERT(offset->getNumValues() == numNodes); + return (numNodes + StorageConfig::CSR_LEAF_REGION_SIZE - 1) / + StorageConfig::CSR_LEAF_REGION_SIZE; +} + +void InMemChunkedCSRHeader::populateRegionCSROffsets(const CSRRegion& region, + const InMemChunkedCSRHeader& oldHeader) const { + KU_ASSERT(region.level <= CSRNodeGroup::DEFAULT_PACKED_CSR_INFO.calibratorTreeHeight); + const auto leftNodeOffset = region.leftNodeOffset; + const auto rightNodeOffset = region.rightNodeOffset; + const auto leftCSROffset = oldHeader.getStartCSROffset(leftNodeOffset); + const auto oldRightCSROffset = oldHeader.getEndCSROffset(rightNodeOffset); + const auto csrOffsets = reinterpret_cast(offset->getData()); + const auto csrLengths = reinterpret_cast(length->getData()); + length_t numRelsInRegion = 0u; + for (auto i = leftNodeOffset; i <= rightNodeOffset; i++) { + numRelsInRegion += csrLengths[i]; + csrOffsets[i] = leftCSROffset + numRelsInRegion; + } + // We should keep the region stable and the old right CSR offset is the end of the region. 
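+ // Pinning the last offset back to the old region end keeps the region's total size, and
+ // hence the start offsets of all following regions, unchanged by this rewrite.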
+ KU_ASSERT(csrOffsets[rightNodeOffset] <= oldRightCSROffset); + csrOffsets[rightNodeOffset] = oldRightCSROffset; +} + +void InMemChunkedCSRHeader::populateEndCSROffsets(const offset_vec_t& gaps) const { + const auto csrOffsets = reinterpret_cast(offset->getData()); + KU_ASSERT(offset->getNumValues() == length->getNumValues()); + KU_ASSERT(offset->getNumValues() == gaps.size()); + for (auto i = 0u; i < offset->getNumValues(); i++) { + csrOffsets[i] += gaps[i]; + } +} + +length_t InMemChunkedCSRHeader::computeGapFromLength(length_t length) { + return StorageUtils::divideAndRoundUpTo(length, StorageConstants::PACKED_CSR_DENSITY) - length; +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/csr_node_group.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/csr_node_group.cpp new file mode 100644 index 0000000000..e75fb98188 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/csr_node_group.cpp @@ -0,0 +1,1210 @@ +#include "storage/table/csr_node_group.h" + +#include "common/constants.h" +#include "storage/buffer_manager/memory_manager.h" +#include "storage/storage_utils.h" +#include "storage/table/column_chunk_data.h" +#include "storage/table/csr_chunked_node_group.h" +#include "storage/table/lazy_segment_scanner.h" +#include "storage/table/rel_table.h" +#include "transaction/transaction.h" + +using namespace lbug::common; +using namespace lbug::transaction; + +namespace lbug { +namespace storage { + +bool CSRNodeGroupScanState::tryScanCachedTuples(RelTableScanState& tableScanState) { + if (numCachedRows == 0 || + tableScanState.currBoundNodeIdx >= tableScanState.cachedBoundNodeSelVector.getSelSize()) { + return false; + } + const auto boundNodeOffset = tableScanState.nodeIDVector->readNodeOffset( + tableScanState.cachedBoundNodeSelVector[tableScanState.currBoundNodeIdx]); + const auto boundNodeOffsetInGroup = boundNodeOffset % StorageConfig::NODE_GROUP_SIZE; + const auto startCSROffset = header->getStartCSROffset(boundNodeOffsetInGroup); + const auto csrLength = header->getCSRLength(boundNodeOffsetInGroup); + nextCachedRowToScan = std::max(nextCachedRowToScan, startCSROffset); + if (nextCachedRowToScan >= nextRowToScan || + nextCachedRowToScan < nextRowToScan - numCachedRows) { + // Out of the bound of cached rows. 
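+ // The cache only covers the window [nextRowToScan - numCachedRows, nextRowToScan), so
+ // rows outside that window must go through a regular scan.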
+ return false; + } + KU_ASSERT(nextCachedRowToScan >= nextRowToScan - numCachedRows); + const auto numRowsToScan = + std::min(nextRowToScan, startCSROffset + csrLength) - nextCachedRowToScan; + const auto startCachedRow = nextCachedRowToScan - (nextRowToScan - numCachedRows); + if (cachedScannedVectorsSelBitset.has_value()) { + auto cachedScannedVectorsSelBitset = *this->cachedScannedVectorsSelBitset; + auto numSelected = 0u; + tableScanState.outState->getSelVectorUnsafe().setToFiltered(); + for (auto i = 0u; i < numRowsToScan; i++) { + const auto rowIdx = startCachedRow + i; + tableScanState.outState->getSelVectorUnsafe()[numSelected] = rowIdx; + numSelected += cachedScannedVectorsSelBitset[rowIdx]; + } + tableScanState.outState->getSelVectorUnsafe().setSelSize(numSelected); + } else { + tableScanState.outState->getSelVectorUnsafe().setRange(startCachedRow, numRowsToScan); + } + tableScanState.setNodeIDVectorToFlat( + tableScanState.cachedBoundNodeSelVector[tableScanState.currBoundNodeIdx]); + nextCachedRowToScan += numRowsToScan; + if ((startCSROffset + csrLength) == nextCachedRowToScan) { + tableScanState.currBoundNodeIdx++; + nextCachedRowToScan = 0; + } + return true; +} + +void CSRNodeGroup::initializeScanState(const Transaction* transaction, + TableScanState& state) const { + auto& relScanState = state.cast(); + KU_ASSERT(relScanState.nodeGroupScanState); + auto& nodeGroupScanState = relScanState.nodeGroupScanState->cast(); + if (relScanState.nodeGroupIdx != nodeGroupIdx || relScanState.randomLookup) { + relScanState.nodeGroupIdx = nodeGroupIdx; + if (persistentChunkGroup) { + initScanForCommittedPersistent(transaction, relScanState, nodeGroupScanState); + } + } + // Switch to a new Vector of bound nodes (i.e., new csr lists) in the node group. + if (persistentChunkGroup) { + nodeGroupScanState.nextRowToScan = 0; + nodeGroupScanState.numCachedRows = 0; + nodeGroupScanState.nextCachedRowToScan = 0; + nodeGroupScanState.source = CSRNodeGroupScanSource::COMMITTED_PERSISTENT; + } else if (csrIndex) { + initScanForCommittedInMem(relScanState, nodeGroupScanState); + } else { + nodeGroupScanState.source = CSRNodeGroupScanSource::NONE; + nodeGroupScanState.nextRowToScan = 0; + } +} + +void CSRNodeGroup::initScanForCommittedPersistent(const Transaction* transaction, + RelTableScanState& relScanState, CSRNodeGroupScanState& nodeGroupScanState) const { + // Scan the csr header chunks from disk. + ChunkState offsetState, lengthState; + auto& csrChunkGroup = persistentChunkGroup->cast(); + const auto& csrHeader = csrChunkGroup.getCSRHeader(); + // We are switching to a new node group. + // Initialize the scan states of a new node group for the csr header. + csrHeader.offset->initializeScanState(offsetState, relScanState.csrOffsetColumn); + csrHeader.length->initializeScanState(lengthState, relScanState.csrLengthColumn); + nodeGroupScanState.header->offset->setNumValues(0); + nodeGroupScanState.header->length->setNumValues(0); + // Initialize the scan states of a new node group for data columns. 
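+ // Columns requested as INVALID_COLUMN_ID or ROW_IDX_COLUMN_ID have no chunk to scan in
+ // the persistent group and are skipped below.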
+ for (auto i = 0u; i < relScanState.columnIDs.size(); i++) { + if (relScanState.columnIDs[i] == INVALID_COLUMN_ID || + relScanState.columnIDs[i] == ROW_IDX_COLUMN_ID) { + continue; + } + auto& chunk = persistentChunkGroup->getColumnChunk(relScanState.columnIDs[i]); + chunk.initializeScanState(nodeGroupScanState.chunkStates[i], relScanState.columns[i]); + } + KU_ASSERT(csrHeader.offset->getNumValues() == csrHeader.length->getNumValues()); + if (relScanState.randomLookup) { + auto pos = relScanState.nodeIDVector->state->getSelVector()[0]; + auto nodeOffset = relScanState.nodeIDVector->readNodeOffset(pos); + auto offsetInGroup = nodeOffset % StorageConfig::NODE_GROUP_SIZE; + auto offsetToScanFrom = offsetInGroup == 0 ? 0 : offsetInGroup - 1; + csrHeader.offset->scanCommitted(transaction, offsetState, + *nodeGroupScanState.header->offset, offsetToScanFrom, 1); + csrHeader.length->scanCommitted(transaction, lengthState, + *nodeGroupScanState.header->length, offsetInGroup, 1); + } else { + auto numBoundNodes = csrHeader.offset->getNumValues(); + csrHeader.offset->scanCommitted(transaction, offsetState, + *nodeGroupScanState.header->offset); + csrHeader.length->scanCommitted(transaction, lengthState, + *nodeGroupScanState.header->length); + nodeGroupScanState.numTotalRows = + nodeGroupScanState.header->getStartCSROffset(numBoundNodes); + } + nodeGroupScanState.header->randomLookup = relScanState.randomLookup; +} + +void CSRNodeGroup::initScanForCommittedInMem(RelTableScanState& relScanState, + CSRNodeGroupScanState& nodeGroupScanState) { + relScanState.currBoundNodeIdx = 0; + nodeGroupScanState.source = CSRNodeGroupScanSource::COMMITTED_IN_MEMORY; + nodeGroupScanState.nextRowToScan = 0; + nodeGroupScanState.numCachedRows = 0; + nodeGroupScanState.inMemCSRList.clear(); +} + +NodeGroupScanResult CSRNodeGroup::scan(const Transaction* transaction, + TableScanState& state) const { + auto& relScanState = state.cast(); + auto& nodeGroupScanState = relScanState.nodeGroupScanState->cast(); + while (true) { + switch (nodeGroupScanState.source) { + case CSRNodeGroupScanSource::COMMITTED_PERSISTENT: { + auto result = scanCommittedPersistent(transaction, relScanState, nodeGroupScanState); + if (result == NODE_GROUP_SCAN_EMPTY_RESULT && csrIndex) { + initScanForCommittedInMem(relScanState, nodeGroupScanState); + continue; + } + return result; + } + case CSRNodeGroupScanSource::COMMITTED_IN_MEMORY: { + relScanState.resetOutVectors(); + const auto result = scanCommittedInMem(transaction, relScanState, nodeGroupScanState); + if (result == NODE_GROUP_SCAN_EMPTY_RESULT) { + relScanState.outState->getSelVectorUnsafe().setSelSize(0); + return NODE_GROUP_SCAN_EMPTY_RESULT; + } + return result; + } + case CSRNodeGroupScanSource::NONE: { + relScanState.outState->getSelVectorUnsafe().setSelSize(0); + return NODE_GROUP_SCAN_EMPTY_RESULT; + } + default: { + KU_UNREACHABLE; + } + } + } +} + +NodeGroupScanResult CSRNodeGroup::scanCommittedPersistent(const Transaction* transaction, + RelTableScanState& tableState, CSRNodeGroupScanState& nodeGroupScanState) const { + if (tableState.cachedBoundNodeSelVector.getSelSize() == 1) { + // Note that we don't apply cache when there is only one bound node. 
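+ // Caching scanned tuples is presumably only worthwhile when several bound nodes can be
+ // served from the same scanned batch, so the single-node case scans its CSR list directly.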
+ return scanCommittedPersistentWithoutCache(transaction, tableState, nodeGroupScanState); + } + return scanCommittedPersistentWithCache(transaction, tableState, nodeGroupScanState); +} + +NodeGroupScanResult CSRNodeGroup::scanCommittedPersistentWithCache(const Transaction* transaction, + RelTableScanState& tableState, CSRNodeGroupScanState& nodeGroupScanState) const { + while (true) { + while (nodeGroupScanState.tryScanCachedTuples(tableState)) { + if (tableState.outState->getSelVector().getSelSize() > 0) { + // Note: This is a dummy return value. + return NodeGroupScanResult{nodeGroupScanState.nextRowToScan, + tableState.outState->getSelVector().getSelSize()}; + } + } + if (nodeGroupScanState.nextRowToScan == nodeGroupScanState.numTotalRows || + tableState.currBoundNodeIdx >= tableState.cachedBoundNodeSelVector.getSelSize()) { + return NODE_GROUP_SCAN_EMPTY_RESULT; + } + const auto currNodeOffset = tableState.nodeIDVector->readNodeOffset( + tableState.cachedBoundNodeSelVector[tableState.currBoundNodeIdx]); + const auto offsetInGroup = currNodeOffset % StorageConfig::NODE_GROUP_SIZE; + const auto startCSROffset = nodeGroupScanState.header->getStartCSROffset(offsetInGroup); + if (startCSROffset > nodeGroupScanState.nextRowToScan) { + nodeGroupScanState.nextRowToScan = startCSROffset; + } + KU_ASSERT(nodeGroupScanState.nextRowToScan <= nodeGroupScanState.numTotalRows); + const auto numToScan = + std::min(nodeGroupScanState.numTotalRows - nodeGroupScanState.nextRowToScan, + DEFAULT_VECTOR_CAPACITY); + persistentChunkGroup->scan(transaction, tableState, nodeGroupScanState, + nodeGroupScanState.nextRowToScan, numToScan); + nodeGroupScanState.numCachedRows = numToScan; + nodeGroupScanState.nextRowToScan += numToScan; + if (tableState.outState->getSelVector().isUnfiltered()) { + nodeGroupScanState.cachedScannedVectorsSelBitset.reset(); + } else { + nodeGroupScanState.cachedScannedVectorsSelBitset = + std::bitset(); + for (auto i = 0u; i < tableState.outState->getSelVector().getSelSize(); i++) { + nodeGroupScanState.cachedScannedVectorsSelBitset->set( + tableState.outState->getSelVector()[i], true); + } + } + } +} + +NodeGroupScanResult CSRNodeGroup::scanCommittedPersistentWithoutCache( + const Transaction* transaction, RelTableScanState& tableState, + CSRNodeGroupScanState& nodeGroupScanState) const { + const auto currNodeOffset = tableState.nodeIDVector->readNodeOffset( + tableState.cachedBoundNodeSelVector[tableState.currBoundNodeIdx]); + const auto offsetInGroup = currNodeOffset % StorageConfig::NODE_GROUP_SIZE; + const auto csrListLength = nodeGroupScanState.header->getCSRLength(offsetInGroup); + if (nodeGroupScanState.nextRowToScan == csrListLength) { + return NODE_GROUP_SCAN_EMPTY_RESULT; + } + const auto startRow = nodeGroupScanState.header->getStartCSROffset(offsetInGroup) + + nodeGroupScanState.nextRowToScan; + const auto numToScan = + std::min(csrListLength - nodeGroupScanState.nextRowToScan, DEFAULT_VECTOR_CAPACITY); + persistentChunkGroup->scan(transaction, tableState, nodeGroupScanState, startRow, numToScan); + nodeGroupScanState.nextRowToScan += numToScan; + tableState.setNodeIDVectorToFlat( + tableState.cachedBoundNodeSelVector[tableState.currBoundNodeIdx]); + return NodeGroupScanResult{startRow, numToScan}; +} + +NodeGroupScanResult CSRNodeGroup::scanCommittedInMem(const Transaction* transaction, + RelTableScanState& tableState, CSRNodeGroupScanState& nodeGroupScanState) const { + while (true) { + if (tableState.currBoundNodeIdx >= 
tableState.cachedBoundNodeSelVector.getSelSize()) { + return NODE_GROUP_SCAN_EMPTY_RESULT; + } + if (nodeGroupScanState.inMemCSRList.rowIndices.empty()) { + const auto boundNodePos = + tableState.cachedBoundNodeSelVector[tableState.currBoundNodeIdx]; + const auto boundNodeOffset = tableState.nodeIDVector->readNodeOffset(boundNodePos); + const auto offsetInGroup = boundNodeOffset % StorageConfig::NODE_GROUP_SIZE; + nodeGroupScanState.inMemCSRList = csrIndex->indices[offsetInGroup]; + } + if (!nodeGroupScanState.inMemCSRList.isSequential) { + KU_ASSERT(std::is_sorted(nodeGroupScanState.inMemCSRList.rowIndices.begin(), + nodeGroupScanState.inMemCSRList.rowIndices.end())); + } + auto scanResult = + nodeGroupScanState.inMemCSRList.isSequential ? + scanCommittedInMemSequential(transaction, tableState, nodeGroupScanState) : + scanCommittedInMemRandom(transaction, tableState, nodeGroupScanState); + if (scanResult == NODE_GROUP_SCAN_EMPTY_RESULT) { + tableState.currBoundNodeIdx++; + nodeGroupScanState.nextRowToScan = 0; + nodeGroupScanState.inMemCSRList.clear(); + } else { + tableState.setNodeIDVectorToFlat( + tableState.cachedBoundNodeSelVector[tableState.currBoundNodeIdx]); + return scanResult; + } + } +} + +NodeGroupScanResult CSRNodeGroup::scanCommittedInMemSequential(const Transaction* transaction, + const RelTableScanState& tableState, CSRNodeGroupScanState& nodeGroupScanState) const { + const auto startRow = + nodeGroupScanState.inMemCSRList.rowIndices[0] + nodeGroupScanState.nextRowToScan; + auto numRows = + std::min(nodeGroupScanState.inMemCSRList.rowIndices[1] - nodeGroupScanState.nextRowToScan, + DEFAULT_VECTOR_CAPACITY); + auto [chunkIdx, startRowInChunk] = + StorageUtils::getQuotientRemainder(startRow, StorageConfig::CHUNKED_NODE_GROUP_CAPACITY); + numRows = std::min(numRows, StorageConfig::CHUNKED_NODE_GROUP_CAPACITY - startRowInChunk); + if (numRows == 0) { + return NODE_GROUP_SCAN_EMPTY_RESULT; + } + const ChunkedNodeGroup* chunkedGroup = nullptr; + { + const auto lock = chunkedGroups.lock(); + chunkedGroup = chunkedGroups.getGroup(lock, chunkIdx); + } + chunkedGroup->scan(transaction, tableState, nodeGroupScanState, startRowInChunk, numRows); + nodeGroupScanState.nextRowToScan += numRows; + return NodeGroupScanResult{startRow, numRows}; +} + +NodeGroupScanResult CSRNodeGroup::scanCommittedInMemRandom(const Transaction* transaction, + const RelTableScanState& tableState, CSRNodeGroupScanState& nodeGroupScanState) const { + const auto numRows = std::min(nodeGroupScanState.inMemCSRList.rowIndices.size() - + nodeGroupScanState.nextRowToScan, + DEFAULT_VECTOR_CAPACITY); + if (numRows == 0) { + return NODE_GROUP_SCAN_EMPTY_RESULT; + } + row_idx_t nextRow = 0; + ChunkedNodeGroup* chunkedGroup = nullptr; + node_group_idx_t currentChunkIdx = INVALID_NODE_GROUP_IDX; + sel_t numSelected = 0; + while (nextRow < numRows) { + const auto rowIdx = + nodeGroupScanState.inMemCSRList.rowIndices[nextRow + nodeGroupScanState.nextRowToScan]; + auto [chunkIdx, rowInChunk] = + StorageUtils::getQuotientRemainder(rowIdx, StorageConfig::CHUNKED_NODE_GROUP_CAPACITY); + if (chunkIdx != currentChunkIdx) { + currentChunkIdx = chunkIdx; + const auto lock = chunkedGroups.lock(); + chunkedGroup = chunkedGroups.getGroup(lock, chunkIdx); + } + KU_ASSERT(chunkedGroup); + numSelected += chunkedGroup->lookup(transaction, tableState, nodeGroupScanState, rowInChunk, + numSelected); + nextRow++; + } + nodeGroupScanState.nextRowToScan += numRows; + tableState.outState->getSelVectorUnsafe().setSelSize(numSelected); + return 
NodeGroupScanResult{0, numRows}; +} + +void CSRNodeGroup::appendChunkedCSRGroup(const Transaction* transaction, + const std::vector& columnIDs, InMemChunkedCSRNodeGroup& chunkedGroup) { + const auto& csrHeader = chunkedGroup.getCSRHeader(); + std::vector chunkedGroupForProperties(chunkedGroup.getNumColumns()); + for (auto i = 0u; i < chunkedGroup.getNumColumns(); i++) { + chunkedGroupForProperties[i] = &chunkedGroup.getColumnChunk(i); + } + auto startRow = NodeGroup::append(transaction, columnIDs, chunkedGroupForProperties, 0, + chunkedGroup.getNumRows()); + if (!csrIndex) { + csrIndex = std::make_unique(); + } + for (auto i = 0u; i < csrHeader.offset->getNumValues(); i++) { + const auto length = csrHeader.length->getValue(i); + updateCSRIndex(i, startRow, length); + startRow += length; + } +} + +void CSRNodeGroup::append(const Transaction* transaction, const std::vector& columnIDs, + offset_t boundOffsetInGroup, std::span chunks, row_idx_t startRowInChunks, + row_idx_t numRows) { + const auto startRow = + NodeGroup::append(transaction, columnIDs, chunks, startRowInChunks, numRows); + if (!csrIndex) { + csrIndex = std::make_unique(); + } + updateCSRIndex(boundOffsetInGroup, startRow, 1 /*length*/); +} + +void CSRNodeGroup::updateCSRIndex(offset_t boundNodeOffsetInGroup, row_idx_t startRow, + length_t length) const { + auto& nodeCSRIndex = csrIndex->indices[boundNodeOffsetInGroup]; + const auto isEmptyCSR = nodeCSRIndex.rowIndices.empty(); + const auto appendToEndOfCSR = + !isEmptyCSR && nodeCSRIndex.isSequential && + (nodeCSRIndex.rowIndices[0] + nodeCSRIndex.rowIndices[1] == startRow); + const bool sequential = isEmptyCSR || appendToEndOfCSR; + if (nodeCSRIndex.isSequential && !sequential) { + // Expand rowIndices for the node. + const auto csrListStartRow = nodeCSRIndex.rowIndices[0]; + const auto csrListLength = nodeCSRIndex.rowIndices[1]; + nodeCSRIndex.rowIndices.clear(); + nodeCSRIndex.rowIndices.reserve(csrListLength + length); + for (auto j = 0u; j < csrListLength; j++) { + nodeCSRIndex.rowIndices.push_back(csrListStartRow + j); + } + } + if (sequential) { + nodeCSRIndex.isSequential = true; + if (!nodeCSRIndex.rowIndices.empty()) { + KU_ASSERT(appendToEndOfCSR); + nodeCSRIndex.rowIndices[1] += length; + } else { + nodeCSRIndex.rowIndices.resize(2); + nodeCSRIndex.rowIndices[0] = startRow; + nodeCSRIndex.rowIndices[1] = length; + } + } else { + nodeCSRIndex.isSequential = false; + for (auto j = 0u; j < length; j++) { + nodeCSRIndex.rowIndices.push_back(startRow + j); + } + std::sort(nodeCSRIndex.rowIndices.begin(), nodeCSRIndex.rowIndices.end()); + } +} + +// NOLINTNEXTLINE(readability-make-member-function-const): Semantically non-const. 
+void CSRNodeGroup::update(const Transaction* transaction, CSRNodeGroupScanSource source, + row_idx_t rowIdxInGroup, column_id_t columnID, const ValueVector& propertyVector) { + switch (source) { + case CSRNodeGroupScanSource::COMMITTED_PERSISTENT: { + KU_ASSERT(persistentChunkGroup); + return persistentChunkGroup->update(transaction, rowIdxInGroup, columnID, propertyVector); + } + case CSRNodeGroupScanSource::COMMITTED_IN_MEMORY: { + KU_ASSERT(csrIndex); + auto [chunkIdx, rowInChunk] = StorageUtils::getQuotientRemainder(rowIdxInGroup, + StorageConfig::CHUNKED_NODE_GROUP_CAPACITY); + const auto lock = chunkedGroups.lock(); + const auto chunkedGroup = chunkedGroups.getGroup(lock, chunkIdx); + return chunkedGroup->update(transaction, rowInChunk, columnID, propertyVector); + } + default: { + KU_UNREACHABLE; + } + } +} + +// NOLINTNEXTLINE(readability-make-member-function-const): Semantically non-const. +bool CSRNodeGroup::delete_(const Transaction* transaction, CSRNodeGroupScanSource source, + row_idx_t rowIdxInGroup) { + switch (source) { + case CSRNodeGroupScanSource::COMMITTED_PERSISTENT: { + KU_ASSERT(persistentChunkGroup); + return persistentChunkGroup->delete_(transaction, rowIdxInGroup); + } + case CSRNodeGroupScanSource::COMMITTED_IN_MEMORY: { + KU_ASSERT(csrIndex); + auto [chunkIdx, rowInChunk] = StorageUtils::getQuotientRemainder(rowIdxInGroup, + StorageConfig::CHUNKED_NODE_GROUP_CAPACITY); + const auto lock = chunkedGroups.lock(); + const auto chunkedGroup = chunkedGroups.getGroup(lock, chunkIdx); + return chunkedGroup->delete_(transaction, rowInChunk); + } + default: { + return false; + } + } +} + +void CSRNodeGroup::addColumn(TableAddColumnState& addColumnState, PageAllocator* pageAllocator, + ColumnStats* newColumnStats) { + if (persistentChunkGroup) { + persistentChunkGroup->addColumn(mm, addColumnState, enableCompression, pageAllocator, + newColumnStats); + } + NodeGroup::addColumn(addColumnState, pageAllocator, newColumnStats); +} + +void CSRNodeGroup::serialize(Serializer& serializer) { + serializer.writeDebuggingInfo("node_group_idx"); + serializer.write(nodeGroupIdx); + serializer.writeDebuggingInfo("enable_compression"); + serializer.write(enableCompression); + serializer.writeDebuggingInfo("format"); + serializer.write(format); + serializer.writeDebuggingInfo("has_checkpointed_data"); + serializer.write(persistentChunkGroup != nullptr); + if (persistentChunkGroup) { + serializer.writeDebuggingInfo("checkpointed_data"); + persistentChunkGroup->serialize(serializer); + } +} + +void CSRNodeGroup::checkpoint(MemoryManager&, NodeGroupCheckpointState& state) { + const auto lock = chunkedGroups.lock(); + if (!persistentChunkGroup) { + checkpointInMemOnly(lock, state); + } else { + checkpointInMemAndOnDisk(lock, state); + } + checkpointDataTypesNoLock(state); +} + +void CSRNodeGroup::reclaimStorage(PageAllocator& pageAllocator, const UniqLock& lock) const { + NodeGroup::reclaimStorage(pageAllocator, lock); + if (persistentChunkGroup) { + persistentChunkGroup->reclaimStorage(pageAllocator); + } +} + +static std::unique_ptr createNewPersistentChunkGroup( + ChunkedCSRNodeGroup& oldPersistentChunkGroup, CSRNodeGroupCheckpointState& csrState) { + auto newGroup = + std::make_unique(oldPersistentChunkGroup, csrState.columnIDs); + // checkpointed columns have been moved to the new group, reclaim storage for dropped column + oldPersistentChunkGroup.reclaimStorage(csrState.pageAllocator); + return newGroup; +} + +void CSRNodeGroup::checkpointInMemAndOnDisk(const UniqLock& lock, 
NodeGroupCheckpointState& state) { + // TODO(Guodong): Should skip early here if no changes in the node group, so we avoid scanning + // the csr header. Case: No insertions/deletions in persistent chunk and no in-mem chunks. + auto& csrState = state.cast(); + // Scan old csr header from disk and construct new csr header. + persistentChunkGroup->cast().scanCSRHeader(*state.mm, csrState); + csrState.newHeader = + std::make_unique(*state.mm, false, StorageConfig::NODE_GROUP_SIZE); + // TODO(Guodong): Find max node offset in the node group. + csrState.newHeader->setNumValues(StorageConfig::NODE_GROUP_SIZE); + csrState.newHeader->copyFrom(*csrState.oldHeader); + auto leafRegions = collectLeafRegionsAndCSRLength(lock, csrState); + KU_ASSERT(std::is_sorted(leafRegions.begin(), leafRegions.end(), + [](const auto& a, const auto& b) { return a.regionIdx < b.regionIdx; })); + const auto regionsToCheckpoint = mergeRegionsToCheckpoint(csrState, leafRegions); + if (regionsToCheckpoint.empty()) { + // No csr regions need to be checkpointed, meaning nothing is updated or deleted. + // We should reset the version and update info of the persistent chunked group. + persistentChunkGroup->resetVersionAndUpdateInfo(); + if (csrState.columnIDs.size() != persistentChunkGroup->getNumColumns()) { + // The column set of the node group has changed. We need to re-create the persistent + // chunked group. + persistentChunkGroup = createNewPersistentChunkGroup( + persistentChunkGroup->cast(), csrState); + } + return; + } + if (regionsToCheckpoint.size() == 1 && + regionsToCheckpoint[0].level > DEFAULT_PACKED_CSR_INFO.calibratorTreeHeight) { + // Need to re-distribute all CSR regions in the node group. + redistributeCSRRegions(csrState, leafRegions); + } else { + for (auto& region : regionsToCheckpoint) { + csrState.newHeader->populateRegionCSROffsets(region, *csrState.oldHeader); + // The left node offset of a region should always maintain stable across length and + // offset changes. 
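+ // populateRegionCSROffsets only writes entries inside the region, and the start offset of
+ // the region lives in the previous node's slot, so it stays untouched; the assert below
+ // verifies this.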
+ KU_ASSERT(csrState.oldHeader->getStartCSROffset(region.leftNodeOffset) == + csrState.newHeader->getStartCSROffset(region.leftNodeOffset)); + } + } + + uint64_t numTuplesAfterCheckpoint = 0; + for (const auto& region : regionsToCheckpoint) { + for (auto i = region.leftNodeOffset; i <= region.rightNodeOffset; ++i) { + numTuplesAfterCheckpoint += csrState.newHeader->getCSRLength(i); + } + } + if (numTuplesAfterCheckpoint == 0) { + reclaimStorage(csrState.pageAllocator, lock); + persistentChunkGroup = nullptr; + } else { + KU_ASSERT(csrState.newHeader->sanityCheck()); + for (const auto columnID : csrState.columnIDs) { + checkpointColumn(lock, columnID, csrState, regionsToCheckpoint); + } + checkpointCSRHeaderColumns(csrState); + persistentChunkGroup = createNewPersistentChunkGroup( + persistentChunkGroup->cast(), csrState); + } + finalizeCheckpoint(lock); +} + +std::vector CSRNodeGroup::collectLeafRegionsAndCSRLength(const UniqLock& lock, + const CSRNodeGroupCheckpointState& csrState) const { + std::vector leafRegions; + constexpr auto numLeafRegions = + StorageConfig::NODE_GROUP_SIZE / StorageConfig::CSR_LEAF_REGION_SIZE; + leafRegions.reserve(numLeafRegions); + for (auto leafRegionIdx = 0u; leafRegionIdx < numLeafRegions; leafRegionIdx++) { + CSRRegion region(leafRegionIdx, 0 /*level*/); + collectRegionChangesAndUpdateHeaderLength(lock, region, csrState); + leafRegions.push_back(std::move(region)); + } + return leafRegions; +} + +void CSRNodeGroup::redistributeCSRRegions(const CSRNodeGroupCheckpointState& csrState, + const std::vector& leafRegions) { + KU_ASSERT(std::is_sorted(leafRegions.begin(), leafRegions.end(), + [](const auto& a, const auto& b) { return a.regionIdx < b.regionIdx; })); + KU_ASSERT(std::all_of(leafRegions.begin(), leafRegions.end(), + [](const CSRRegion& region) { return region.level == 0; })); + KU_UNUSED(leafRegions); + const auto rightCSROffsetOfRegions = + csrState.newHeader->populateStartCSROffsetsFromLength(true /* leaveGaps */); + csrState.newHeader->populateEndCSROffsetFromStartAndLength(); + csrState.newHeader->finalizeCSRRegionEndOffsets(rightCSROffsetOfRegions); +} + +void CSRNodeGroup::checkpointColumn(const UniqLock& lock, column_id_t columnID, + const CSRNodeGroupCheckpointState& csrState, const std::vector& regions) const { + std::vector chunkCheckpointStates; + chunkCheckpointStates.reserve(regions.size()); + for (auto& region : regions) { + if (!region.needCheckpointColumn(columnID)) { + // Skip checkpoint for the column if it has no changes in the region. + continue; + } + auto regionCheckpointStates = checkpointColumnInRegion(lock, columnID, csrState, region); + // If there are no rows to write for the region, we don't aggressively reclaim the space in + // the region, but keep deleted rows as gaps. This can happen when all rows are deleted + // within the region. 
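+ // In that case regionCheckpointStates may simply be empty, and the loop below then adds
+ // nothing for this region.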
+ for (auto& regionCheckpointState : regionCheckpointStates) { + chunkCheckpointStates.push_back(std::move(regionCheckpointState)); + } + } + persistentChunkGroup->getColumnChunk(columnID).checkpoint(*csrState.columns[columnID], + std::move(chunkCheckpointStates), csrState.pageAllocator); +} + +struct SegmentCursor { + SegmentCursor(LazySegmentScanner& scanner, offset_t leftCSROffset) + : scanner(scanner), it(scanner.begin()), curCSROffset(leftCSROffset) {} + + void advance(offset_t n) { + curCSROffset += n; + it.advance(n); + } + void operator++() { advance(1); } + + LazySegmentScanner& scanner; + LazySegmentScanner::Iterator it; + offset_t curCSROffset; +}; + +struct CheckpointReadCursor { + CheckpointReadCursor(LazySegmentScanner& scanner, offset_t leftCSROffset) + : cursor(scanner, leftCSROffset) {} + + void advance(offset_t n) { cursor.advance(n); } + void operator++() { cursor.operator++(); } + offset_t getCSROffset() const { return cursor.curCSROffset; } + + std::pair getDataToRead() { + if (cursor.it->segmentData == nullptr) { + cursor.scanner.scanSegmentIfNeeded(cursor.it.segmentIdx); + } + return {cursor.it->segmentData.get(), cursor.it.offsetInSegment}; + } + + bool canSkipRead() const { return cursor.it->segmentData == nullptr; } + + template + void rangeSegments(common::length_t length, Func func) const { + cursor.scanner.rangeSegments(cursor.it, length, std::move(func)); + } + + SegmentCursor cursor; +}; + +/** + * Writes output into multiple "segments" + * Note that the segments in the output won't necessarily match the segments being read from the + * column chunk + * Lazy writes are supported: + * - when the cursor is advanced and we haven't written to the current position a "gap" is left + * - if there is currently a gap and we perform a write we start a new segment to be written to + */ +class CheckpointWriteCursor { +public: + CheckpointWriteCursor(offset_t leftCSROffset, MemoryManager& memoryManager, + LogicalType& columnType, std::vector& outputSegments) + : segmentStartOffset(leftCSROffset), curCSROffset(leftCSROffset), + memoryManager(memoryManager), columnType(columnType), outputSegments(outputSegments) { + resetOutputChunk(); + } + + void advance(offset_t n) { curCSROffset += n; } + void operator++() { advance(1); } + offset_t getCSROffset() const { return curCSROffset; } + + void finalize() { + if (currentOutputSegment->getNumValues() > 0) { + appendCurrentSegmentToOutput(); + } + } + + ColumnChunkData& getCurrentSegmentForWrite(offset_t numValuesToWrite) { + if (segmentStartOffset + currentOutputSegment->getNumValues() < curCSROffset) { + startNewSegment(); + } + if (currentOutputSegment->getNumValues() + numValuesToWrite > + currentOutputSegment->getCapacity()) { + currentOutputSegment->resize( + std::bit_ceil(currentOutputSegment->getNumValues() + numValuesToWrite)); + } + return *currentOutputSegment; + } + + void appendToCurrentSegment(ColumnChunkData* data, offset_t srcOffset, + offset_t numValuesToAppend) { + getCurrentSegmentForWrite(numValuesToAppend).append(data, srcOffset, numValuesToAppend); + } + +private: + offset_t getInitChunkCapacity() const { return DEFAULT_VECTOR_CAPACITY; } + + void resetOutputChunk() { + currentOutputSegment = ColumnChunkFactory::createColumnChunkData(memoryManager, + columnType.copy(), false, getInitChunkCapacity(), ResidencyState::IN_MEMORY); + } + + void appendCurrentSegmentToOutput() { + outputSegments.emplace_back(std::move(currentOutputSegment), segmentStartOffset, + currentOutputSegment->getNumValues()); + } + + void 
startNewSegment() { + if (currentOutputSegment->getNumValues() > 0) { + appendCurrentSegmentToOutput(); + resetOutputChunk(); + } + segmentStartOffset = curCSROffset; + } + + offset_t segmentStartOffset; + std::unique_ptr currentOutputSegment; + offset_t curCSROffset; + + MemoryManager& memoryManager; + LogicalType& columnType; + std::vector& outputSegments; +}; + +static bool canSkipWrite(CheckpointReadCursor& readCursor, CheckpointWriteCursor& writeCursor) { + return readCursor.getCSROffset() == writeCursor.getCSROffset() && readCursor.canSkipRead(); +} + +static ChunkState scanCommittedUpdates(ColumnChunk& persistentChunk, Column* column, + LazySegmentScanner& scanner, offset_t startCSROffset, offset_t numRowsToScan) { + ChunkState chunkState; + persistentChunk.initializeScanState(chunkState, column); + persistentChunk.scanCommitted(&DUMMY_CHECKPOINT_TRANSACTION, + chunkState, scanner, startCSROffset, numRowsToScan); + return chunkState; +} + +static void writeCSRListNoPersistentDeletions(CheckpointReadCursor& readCursor, + CheckpointWriteCursor& writeCursor, offset_t oldCSRLength) { + readCursor.rangeSegments(oldCSRLength, + [&](auto& segmentData, auto offsetInSegment, auto lengthInSegment, auto) { + if (!canSkipWrite(readCursor, writeCursor)) { + [[maybe_unused]] auto [readSegmentData, readOffsetInSegment] = + readCursor.getDataToRead(); + KU_ASSERT(readSegmentData == segmentData.segmentData.get() && + readOffsetInSegment == offsetInSegment); + writeCursor.appendToCurrentSegment(segmentData.segmentData.get(), offsetInSegment, + lengthInSegment); + } + readCursor.advance(lengthInSegment); + writeCursor.advance(lengthInSegment); + }); +} + +static void writeCSRListWithPersistentDeletions(CheckpointReadCursor& readCursor, + CheckpointWriteCursor& writeCursor, offset_t oldCSRLength, + const ChunkedNodeGroup& persistentChunkGroup) { + // TODO(Guodong): Optimize the for loop away by appending in batch + for (auto i = 0u; i < oldCSRLength; i++) { + if (!persistentChunkGroup.isDeleted(&DUMMY_CHECKPOINT_TRANSACTION, + readCursor.getCSROffset())) { + if (!canSkipWrite(readCursor, writeCursor)) { + auto [segmentData, offsetInSegment] = readCursor.getDataToRead(); + writeCursor.appendToCurrentSegment(segmentData, offsetInSegment, 1); + } + ++writeCursor; + } + ++readCursor; + } +} + +static void writeInMemoryCSRInsertion(CheckpointWriteCursor& writeCursor, + const ChunkedNodeGroup& chunkedGroup, row_idx_t rowInChunk, column_id_t columnID, + ChunkState& chunkState) { + KU_ASSERT(!chunkedGroup.isDeleted(&DUMMY_CHECKPOINT_TRANSACTION, rowInChunk)); + chunkedGroup.getColumnChunk(columnID).scanCommitted( + &DUMMY_CHECKPOINT_TRANSACTION, chunkState, writeCursor.getCurrentSegmentForWrite(1), + rowInChunk, 1); + ++writeCursor; +} + +static void fillCSRGaps(CheckpointReadCursor& readCursor, CheckpointWriteCursor& writeCursor, + ColumnChunkData* dummyChunkForNulls, length_t numOldGaps, length_t numGaps) { + auto numOldGapsRemaining = numOldGaps; + auto numGapsRemaining = numGaps; + if (readCursor.getCSROffset() < writeCursor.getCSROffset()) { + // Try to advance read cursor to write cursor (if num old gaps is large enough) + const auto numGapsToAdvance = + std::min(numOldGapsRemaining, writeCursor.getCSROffset() - readCursor.getCSROffset()); + readCursor.advance(numGapsToAdvance); + numOldGapsRemaining -= numGapsToAdvance; + } + + // We can skip writes for any new gaps whose CSR offset also corresponds to an old gap + if (readCursor.getCSROffset() == writeCursor.getCSROffset()) { + auto numSkippableGaps = 
std::min(numGapsRemaining, numOldGapsRemaining); + numGapsRemaining -= numSkippableGaps; + writeCursor.advance(numSkippableGaps); + } + + while (numGapsRemaining > 0) { + const auto numGapsToFill = + std::min(numGapsRemaining, static_cast(DEFAULT_VECTOR_CAPACITY)); + dummyChunkForNulls->setNumValues(numGapsToFill); + writeCursor.appendToCurrentSegment(dummyChunkForNulls, 0, numGapsToFill); + writeCursor.advance(numGapsToFill); + numGapsRemaining -= numGapsToFill; + } + + readCursor.advance(numOldGapsRemaining); +} + +std::vector CSRNodeGroup::checkpointColumnInRegion(const UniqLock& lock, + column_id_t columnID, const CSRNodeGroupCheckpointState& csrState, + const CSRRegion& region) const { + const auto leftCSROffset = csrState.oldHeader->getStartCSROffset(region.leftNodeOffset); + KU_ASSERT(leftCSROffset == csrState.newHeader->getStartCSROffset(region.leftNodeOffset)); + const auto rightCSROffset = csrState.oldHeader->getEndCSROffset(region.rightNodeOffset); + const auto numOldRowsInRegion = rightCSROffset - leftCSROffset; + + Column* column = csrState.columns[columnID]; + LazySegmentScanner oldChunkScanner{*csrState.mm, column->getDataType().copy(), + enableCompression}; + auto chunkState = scanCommittedUpdates(persistentChunkGroup->getColumnChunk(columnID), column, + oldChunkScanner, leftCSROffset, numOldRowsInRegion); + + const auto dummyChunkForNulls = ColumnChunkFactory::createColumnChunkData(*csrState.mm, + dataTypes[columnID].copy(), false, DEFAULT_VECTOR_CAPACITY, ResidencyState::IN_MEMORY); + dummyChunkForNulls->resetToAllNull(); + + std::vector ret; + + CheckpointReadCursor readCursor{oldChunkScanner, leftCSROffset}; + CheckpointWriteCursor writeCursor{leftCSROffset, *csrState.mm, column->getDataType(), ret}; + + // Copy per csr list from old chunk and merge with new insertions into the newChunkData. + for (auto nodeOffset = region.leftNodeOffset; nodeOffset <= region.rightNodeOffset; + nodeOffset++) { + const auto oldCSRLength = csrState.oldHeader->getCSRLength(nodeOffset); + + KU_ASSERT(csrState.newHeader->getStartCSROffset(nodeOffset) == writeCursor.getCSROffset()); + KU_ASSERT(csrState.oldHeader->getStartCSROffset(nodeOffset) == readCursor.getCSROffset()); + + // Copy old csr list with updates into the new chunk. + if (!region.hasPersistentDeletions) { + writeCSRListNoPersistentDeletions(readCursor, writeCursor, oldCSRLength); + } else { + writeCSRListWithPersistentDeletions(readCursor, writeCursor, oldCSRLength, + *persistentChunkGroup); + } + // Merge in-memory insertions into the new chunk. + if (csrIndex) { + auto rows = csrIndex->indices[nodeOffset].getRows(); + // TODO(Guodong): Optimize here. if no deletions and has sequential rows, scan in + // range. + for (const auto row : rows) { + if (row == INVALID_ROW_IDX) { + continue; + } + auto [chunkIdx, rowInChunk] = StorageUtils::getQuotientRemainder(row, + StorageConfig::CHUNKED_NODE_GROUP_CAPACITY); + const auto chunkedGroup = chunkedGroups.getGroup(lock, chunkIdx); + writeInMemoryCSRInsertion(writeCursor, *chunkedGroup, rowInChunk, columnID, + chunkState); + } + } + + const length_t numGaps = csrState.newHeader->getGapSize(nodeOffset); + const length_t numOldGaps = csrState.oldHeader->getGapSize(nodeOffset); + // Gaps should only happen at the end of the CSR region. 
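+        // For example, with a hypothetical CSR_LEAF_REGION_SIZE of 64, a non-zero gap is only
+        // expected when checkpointing node offset 63, 127, 191, ... (the last node of a leaf
+        // region) or a node at the right boundary of the region being checkpointed; for every
+        // other node the next node's list starts immediately where this one's ends.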
+ KU_ASSERT(numGaps == 0 || (nodeOffset == region.rightNodeOffset - 1) || + (nodeOffset + 1) % StorageConfig::CSR_LEAF_REGION_SIZE == 0); + fillCSRGaps(readCursor, writeCursor, dummyChunkForNulls.get(), numOldGaps, numGaps); + } + writeCursor.finalize(); + KU_ASSERT(readCursor.getCSROffset() - leftCSROffset == numOldRowsInRegion); + KU_ASSERT( + writeCursor.getCSROffset() == csrState.newHeader->getEndCSROffset(region.rightNodeOffset)); + // We can't skip writing appends as they need to be flushed to disk + KU_ASSERT(readCursor.getCSROffset() == writeCursor.getCSROffset() || ret.empty() || + ret.back().startRow + ret.back().numRows == writeCursor.getCSROffset()); + return ret; +} + +void CSRNodeGroup::checkpointCSRHeaderColumns(const CSRNodeGroupCheckpointState& csrState) const { + std::vector csrOffsetChunkCheckpointStates; + const auto numNodes = csrState.newHeader->offset->getNumValues(); + KU_ASSERT(numNodes == csrState.newHeader->length->getNumValues()); + csrOffsetChunkCheckpointStates.push_back( + ChunkCheckpointState{std::move(csrState.newHeader->offset), 0, numNodes}); + persistentChunkGroup->cast().getCSRHeader().offset->checkpoint( + *csrState.csrOffsetColumn, std::move(csrOffsetChunkCheckpointStates), + csrState.pageAllocator); + std::vector csrLengthChunkCheckpointStates; + csrLengthChunkCheckpointStates.push_back( + ChunkCheckpointState{std::move(csrState.newHeader->length), 0, numNodes}); + persistentChunkGroup->cast().getCSRHeader().length->checkpoint( + *csrState.csrLengthColumn, std::move(csrLengthChunkCheckpointStates), + csrState.pageAllocator); +} + +void CSRNodeGroup::collectRegionChangesAndUpdateHeaderLength(const UniqLock& lock, + CSRRegion& region, const CSRNodeGroupCheckpointState& csrState) const { + collectInMemRegionChangesAndUpdateHeaderLength(lock, region, csrState); + collectOnDiskRegionChangesAndUpdateHeaderLength(lock, region, csrState); +} + +void CSRNodeGroup::collectInMemRegionChangesAndUpdateHeaderLength(const UniqLock& lock, + CSRRegion& region, const CSRNodeGroupCheckpointState& csrState) const { + row_idx_t numInsertionsInRegion = 0u; + if (csrIndex) { + for (auto nodeOffset = region.leftNodeOffset; nodeOffset <= region.rightNodeOffset; + nodeOffset++) { + auto rows = csrIndex->indices[nodeOffset].getRows(); + row_idx_t numInsertedRows = rows.size(); + row_idx_t numInMemDeletionsInCSR = 0; + for (auto i = 0u; i < rows.size(); i++) { + const auto row = rows[i]; + auto [chunkIdx, rowInChunk] = StorageUtils::getQuotientRemainder(row, + StorageConfig::CHUNKED_NODE_GROUP_CAPACITY); + const auto chunkedGroup = chunkedGroups.getGroup(lock, chunkIdx); + if (chunkedGroup->isDeleted(&DUMMY_CHECKPOINT_TRANSACTION, rowInChunk)) { + csrIndex->indices[nodeOffset].turnToNonSequential(); + csrIndex->indices[nodeOffset].setInvalid(i); + numInMemDeletionsInCSR++; + } + } + KU_ASSERT(numInMemDeletionsInCSR <= numInsertedRows); + numInsertedRows -= numInMemDeletionsInCSR; + const auto oldLength = csrState.oldHeader->getCSRLength(nodeOffset); + const auto newLength = oldLength + numInsertedRows; + csrState.newHeader->length->setValue(newLength, nodeOffset); + numInsertionsInRegion += numInsertedRows; + } + } + region.hasInsertions = numInsertionsInRegion > 0; + region.sizeChange += static_cast(numInsertionsInRegion); +} + +void CSRNodeGroup::collectOnDiskRegionChangesAndUpdateHeaderLength(const UniqLock&, + CSRRegion& region, const CSRNodeGroupCheckpointState& csrState) const { + collectPersistentUpdatesInRegion(region, csrState); + int64_t numDeletionsInRegion = 0u; + 
if (persistentChunkGroup) { + for (auto nodeOffset = region.leftNodeOffset; nodeOffset <= region.rightNodeOffset; + nodeOffset++) { + const auto numDeletedRows = + getNumDeletionsForNodeInPersistentData(nodeOffset, csrState); + if (numDeletedRows == 0) { + continue; + } + numDeletionsInRegion += numDeletedRows; + const auto currentLength = csrState.newHeader->getCSRLength(nodeOffset); + KU_ASSERT(currentLength >= numDeletedRows); + csrState.newHeader->length->setValue(currentLength - numDeletedRows, + nodeOffset); + } + } + region.hasPersistentDeletions = numDeletionsInRegion > 0; + region.sizeChange -= numDeletionsInRegion; +} + +void CSRNodeGroup::collectPersistentUpdatesInRegion(CSRRegion& region, + const CSRNodeGroupCheckpointState& csrState) const { + const auto leftCSROffset = csrState.oldHeader->getStartCSROffset(region.leftNodeOffset); + const auto rightCSROffset = csrState.oldHeader->getEndCSROffset(region.rightNodeOffset); + region.hasUpdates.resize(csrState.columnIDs.size(), false); + for (auto i = 0u; i < csrState.columnIDs.size(); i++) { + auto columnID = csrState.columnIDs[i]; + if (persistentChunkGroup->hasAnyUpdates(&DUMMY_CHECKPOINT_TRANSACTION, columnID, + leftCSROffset, rightCSROffset - leftCSROffset + 1)) { + region.hasUpdates[i] = true; + } + } +} + +row_idx_t CSRNodeGroup::getNumDeletionsForNodeInPersistentData(offset_t nodeOffset, + const CSRNodeGroupCheckpointState& csrState) const { + const auto length = csrState.oldHeader->getCSRLength(nodeOffset); + const auto startRow = csrState.oldHeader->getStartCSROffset(nodeOffset); + return persistentChunkGroup->getNumDeletions(&DUMMY_CHECKPOINT_TRANSACTION, startRow, length); +} + +static DataChunk initScanDataChunk(const CSRNodeGroupCheckpointState& csrState, + const std::vector& dataTypes) { + const auto scanChunkState = std::make_shared(); + DataChunk dataChunk(csrState.columnIDs.size(), scanChunkState); + for (auto i = 0u; i < csrState.columnIDs.size(); i++) { + const auto columnID = csrState.columnIDs[i]; + KU_ASSERT(columnID < dataTypes.size()); + const auto valueVector = + std::make_shared(dataTypes[columnID].copy(), csrState.mm); + dataChunk.insert(i, valueVector); + } + return dataChunk; +} + +void CSRNodeGroup::checkpointInMemOnly(const UniqLock& lock, NodeGroupCheckpointState& state) { + auto numRels = 0u; + for (auto& chunkedGroup : chunkedGroups.getAllGroups(lock)) { + numRels += chunkedGroup->getNumRows(); + } + if (numRels == 0) { + return; + } + // Construct in-mem csr header chunks. + auto& csrState = state.cast(); + csrState.newHeader = std::make_unique(*state.mm, + false /*enableCompression*/, StorageConfig::NODE_GROUP_SIZE); + const auto numNodes = csrIndex->getMaxOffsetWithRels() + 1; + csrState.newHeader->setNumValues(numNodes); + populateCSRLengthInMemOnly(lock, numNodes, csrState); + const auto rightCSROffsetsOfRegions = + csrState.newHeader->populateStartCSROffsetsFromLength(true /* leaveGap */); + csrState.newHeader->populateEndCSROffsetFromStartAndLength(); + csrState.newHeader->finalizeCSRRegionEndOffsets(rightCSROffsetsOfRegions); + + // Init scan chunk and scan state. 
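+    // The scan chunk built below holds one ValueVector per checkpointed column; rows referenced
+    // by the CSR index are looked up into it in batches of at most DEFAULT_VECTOR_CAPACITY and
+    // appended to the per-column chunks that are flushed at the end. E.g. (assuming a vector
+    // capacity of 2048, for illustration) a node with 5000 rels is handled as three lookup
+    // batches of 2048, 2048 and 904 rows.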
+ const auto numColumnsToCheckpoint = csrState.columnIDs.size(); + auto scanChunk = initScanDataChunk(csrState, dataTypes); + std::vector columns(numColumnsToCheckpoint); + for (auto i = 0u; i < numColumnsToCheckpoint; i++) { + columns[i] = csrState.columns[i]; + } + std::vector outVectors; + for (auto i = 0u; i < numColumnsToCheckpoint; i++) { + outVectors.push_back(scanChunk.valueVectors[i].get()); + } + auto scanState = std::make_unique(nullptr, outVectors, scanChunk.state); + scanState->columnIDs = csrState.columnIDs; + scanState->columns = columns; + scanState->nodeGroupScanState = + std::make_unique(csrState.columnIDs.size()); + + auto dummyChunk = initScanDataChunk(csrState, dataTypes); + for (auto i = 0u; i < dummyChunk.getNumValueVectors(); i++) { + dummyChunk.getValueVectorMutable(i).setAllNull(); + } + + // Init data chunks to be appended and flushed. + auto chunkCapacity = rightCSROffsetsOfRegions.back() + 1; + std::vector> dataChunksToFlush(numColumnsToCheckpoint); + for (auto i = 0u; i < numColumnsToCheckpoint; i++) { + const auto columnID = csrState.columnIDs[i]; + KU_ASSERT(columnID < dataTypes.size()); + dataChunksToFlush[i] = std::make_unique(*state.mm, dataTypes[columnID].copy(), + chunkCapacity, enableCompression, ResidencyState::IN_MEMORY); + } + + // Scan tuples from in mem node groups and append to data chunks to flush. + for (auto offset = 0u; offset < numNodes; offset++) { + const auto numRows = csrIndex->getNumRows(offset); + auto rows = csrIndex->indices[offset].getRows(); + auto numRowsTryAppended = 0u; + while (numRowsTryAppended < numRows) { + const auto maxNumRowsToAppend = + std::min(numRows - numRowsTryAppended, DEFAULT_VECTOR_CAPACITY); + auto numRowsToAppend = 0u; + for (auto i = 0u; i < maxNumRowsToAppend; i++) { + const auto row = rows[numRowsTryAppended + i]; + // TODO(Guodong): Should skip deleted rows here. + if (row == INVALID_ROW_IDX) { + continue; + } + scanState->rowIdxVector->setValue(numRowsToAppend++, row); + } + scanChunk.state->getSelVectorUnsafe().setSelSize(numRowsToAppend); + if (numRowsToAppend > 0) { + [[maybe_unused]] auto res = + lookupMultiple(lock, &DUMMY_CHECKPOINT_TRANSACTION, *scanState); + for (auto idx = 0u; idx < numColumnsToCheckpoint; idx++) { + dataChunksToFlush[idx]->append(scanChunk.valueVectors[idx].get(), + scanChunk.state->getSelVector()); + } + } + numRowsTryAppended += maxNumRowsToAppend; + } + auto gapSize = csrState.newHeader->getGapSize(offset); + while (gapSize > 0) { + // Gaps should only happen at the end of the CSR region. + KU_ASSERT((offset == numNodes - 1) || + (offset + 1) % StorageConfig::CSR_LEAF_REGION_SIZE == 0); + const auto numGapsToAppend = std::min(gapSize, DEFAULT_VECTOR_CAPACITY); + KU_ASSERT(dummyChunk.state->getSelVector().isUnfiltered()); + dummyChunk.state->getSelVectorUnsafe().setSelSize(numGapsToAppend); + for (auto columnID = 0u; columnID < numColumnsToCheckpoint; columnID++) { + dataChunksToFlush[columnID]->append(dummyChunk.valueVectors[columnID].get(), + dummyChunk.state->getSelVector()); + } + gapSize -= numGapsToAppend; + } + } + + // FIXME(bmwinger): this needs segmentation. Maybe this should use (or share code with) + // checkpointOutOfPlace Flush data chunks to disk. 
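+    // What gets flushed below is the per-column data plus the two CSR header chunks. As an
+    // illustration (made-up numbers, leaf region size 4): lengths [3, 1, 2, 0] with a gap of 2
+    // left after the last node give start offsets [0, 3, 4, 6] and a region end offset of 8,
+    // i.e. slots 6-7 are unused padding that later inserts can grow into without rewriting the
+    // whole node group.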
+ for (const auto& chunk : dataChunksToFlush) { + chunk->flush(csrState.pageAllocator); + } + csrState.newHeader->offset->flush(csrState.pageAllocator); + csrState.newHeader->length->flush(csrState.pageAllocator); + persistentChunkGroup = std::make_unique( + ChunkedCSRHeader(false /*enableCompression*/, std::move(*csrState.newHeader)), + std::move(dataChunksToFlush), 0); + // TODO(Guodong): Use `finalizeCheckpoint`. + chunkedGroups.clear(lock); + // Set `numRows` back to 0 is to reflect that the in mem part of the node group is empty. + numRows = 0; + csrIndex.reset(); +} + +// NOLINTNEXTLINE(readability-make-member-function-const): Semantically non-const. +void CSRNodeGroup::populateCSRLengthInMemOnly(const UniqLock& lock, offset_t numNodes, + const CSRNodeGroupCheckpointState& csrState) { + for (auto offset = 0u; offset < numNodes; offset++) { + auto rows = csrIndex->indices[offset].getRows(); + const length_t length = rows.size(); + auto lengthAfterDelete = length; + for (auto i = 0u; i < rows.size(); i++) { + const auto row = rows[i]; + auto [chunkIdx, rowInChunk] = + StorageUtils::getQuotientRemainder(row, StorageConfig::CHUNKED_NODE_GROUP_CAPACITY); + const auto chunkedGroup = chunkedGroups.getGroup(lock, chunkIdx); + const auto isDeleted = + chunkedGroup->isDeleted(&DUMMY_CHECKPOINT_TRANSACTION, rowInChunk); + if (isDeleted) { + csrIndex->indices[offset].turnToNonSequential(); + csrIndex->indices[offset].setInvalid(i); + lengthAfterDelete--; + } + } + KU_ASSERT(lengthAfterDelete <= length); + csrState.newHeader->length->setValue(lengthAfterDelete, offset); + } +} + +std::vector CSRNodeGroup::mergeRegionsToCheckpoint( + const CSRNodeGroupCheckpointState& csrState, const std::vector& leafRegions) { + KU_ASSERT(std::all_of(leafRegions.begin(), leafRegions.end(), + [](const CSRRegion& region) { return region.level == 0; })); + KU_ASSERT(std::is_sorted(leafRegions.begin(), leafRegions.end(), + [](const CSRRegion& a, const CSRRegion& b) { return a.regionIdx < b.regionIdx; })); + constexpr auto numLeafRegions = + StorageConfig::NODE_GROUP_SIZE / StorageConfig::CSR_LEAF_REGION_SIZE; + KU_ASSERT(leafRegions.size() == numLeafRegions); + std::vector mergedRegions; + idx_t leafRegionIdx = 0u; + while (leafRegionIdx < numLeafRegions) { + auto region = leafRegions[leafRegionIdx]; + if (!region.needCheckpoint()) { + leafRegionIdx++; + continue; + } + while (!isWithinDensityBound(*csrState.oldHeader, leafRegions, region)) { + region = CSRRegion::upgradeLevel(leafRegions, region); + if (region.level > DEFAULT_PACKED_CSR_INFO.calibratorTreeHeight) { + // Hit the top level already. Need to re-distribute. + return {region}; + } + } + // Skip to the next right leaf region of the found region. + leafRegionIdx = region.getRightLeafRegionIdx() + 1; + // Loop through found regions and eliminate the ones that are under the realm of the + // currently found region. 
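+        // Illustrative walk-through (hypothetical sizes): suppose leaf region 6 exceeds the
+        // level-0 density bound after its insertions. It is upgraded to its level-1 parent
+        // (leaves 6-7), then to level 2 (leaves 4-7), until newSize / capacity drops below
+        // getHighDensity(level); a region over leaves 4-5 that was added to mergedRegions
+        // earlier now lies inside 4-7, so the erase_if below drops it before 4-7 is recorded.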
+ std::erase_if(mergedRegions, [&](const CSRRegion& r) { return r.isWithin(region); }); + mergedRegions.push_back(region); + } + std::sort(mergedRegions.begin(), mergedRegions.end(), + [](const CSRRegion& a, const CSRRegion& b) { + return a.getLeftLeafRegionIdx() < b.getLeftLeafRegionIdx(); + }); + return mergedRegions; +} + +static double getHighDensity(uint64_t level) { + KU_ASSERT(level <= CSRNodeGroup::DEFAULT_PACKED_CSR_INFO.calibratorTreeHeight); + if (level == 0) { + return StorageConstants::LEAF_HIGH_CSR_DENSITY; + } + return StorageConstants::PACKED_CSR_DENSITY + + CSRNodeGroup::DEFAULT_PACKED_CSR_INFO.highDensityStep * + static_cast( + CSRNodeGroup::DEFAULT_PACKED_CSR_INFO.calibratorTreeHeight - level); +} + +bool CSRNodeGroup::isWithinDensityBound(const InMemChunkedCSRHeader& header, + const std::vector& leafRegions, const CSRRegion& region) { + int64_t oldSize = 0; + for (auto offset = region.leftNodeOffset; offset <= region.rightNodeOffset; offset++) { + oldSize += header.getCSRLength(offset); + } + int64_t sizeChange = 0; + const idx_t leftRegionIdx = region.getLeftLeafRegionIdx(); + const idx_t rightRegionIdx = region.getRightLeafRegionIdx(); + for (auto regionIdx = leftRegionIdx; regionIdx <= rightRegionIdx; regionIdx++) { + sizeChange += leafRegions[regionIdx].sizeChange; + } + KU_ASSERT(sizeChange >= 0 || sizeChange < oldSize); + const auto newSize = oldSize + sizeChange; + const auto capacity = header.getEndCSROffset(region.rightNodeOffset) - + header.getStartCSROffset(region.leftNodeOffset); + const double ratio = static_cast(newSize) / static_cast(capacity); + return ratio <= getHighDensity(region.level); +} + +void CSRNodeGroup::finalizeCheckpoint(const UniqLock& lock) { + // Clean up versions and in mem chunked groups. + if (persistentChunkGroup) { + persistentChunkGroup->resetNumRowsFromChunks(); + persistentChunkGroup->resetVersionAndUpdateInfo(); + } + chunkedGroups.clear(lock); + // Set `numRows` back to 0 is to reflect that the in mem part of the node group is empty. + numRows = 0; + csrIndex.reset(); +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/dictionary_chunk.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/dictionary_chunk.cpp new file mode 100644 index 0000000000..526819326d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/dictionary_chunk.cpp @@ -0,0 +1,127 @@ +#include "storage/table/dictionary_chunk.h" + +#include "common/constants.h" +#include "common/serializer/deserializer.h" +#include "common/serializer/serializer.h" +#include "storage/enums/residency_state.h" +#include + +using namespace lbug::common; + +namespace lbug { +namespace storage { + +// The offset chunk is able to grow beyond the node group size. +// We rely on appending to the dictionary when updating, however if the chunk is full, +// there will be no space for in-place updates. +// The data chunk doubles in size on use, but out of place updates will never need the offset +// chunk to be greater than the node group size since they remove unused entries. +// So the chunk is initialized with a size equal to 3 so that the capacity is never resized to +// exactly the node group size (which is always a power of 2), making sure there is always extra +// space for updates. 
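+// For instance, assuming (for illustration) that the offset chunk doubles whenever it grows:
+// starting from 3 the capacities are 3, 6, 12, 24, ..., i.e. 3 * 2^k, which is never an exact
+// power of two, so the capacity never lands exactly on the node-group size and an in-place
+// update always has at least one spare slot available.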
+static constexpr uint64_t INITIAL_OFFSET_CHUNK_CAPACITY = 3; + +DictionaryChunk::DictionaryChunk(MemoryManager& mm, uint64_t capacity, bool enableCompression, + ResidencyState residencyState) + : enableCompression{enableCompression}, + indexTable(0, StringOps(this) /*hash*/, StringOps(this) /*equals*/) { + // Bitpacking might save 1 bit per value with regular ascii compared to UTF-8 + stringDataChunk = ColumnChunkFactory::createColumnChunkData(mm, LogicalType::UINT8(), + false /*enableCompression*/, 0, residencyState, false /*hasNullData*/); + offsetChunk = ColumnChunkFactory::createColumnChunkData(mm, LogicalType::UINT64(), + enableCompression, std::min(capacity, INITIAL_OFFSET_CHUNK_CAPACITY), residencyState, + false /*hasNullData*/); +} + +void DictionaryChunk::resetToEmpty() { + stringDataChunk->resetToEmpty(); + offsetChunk->resetToEmpty(); + indexTable.clear(); +} + +uint64_t DictionaryChunk::getStringLength(string_index_t index) const { + if (stringDataChunk->getNumValues() == 0) { + return 0; + } + if (index + 1 < offsetChunk->getNumValues()) { + KU_ASSERT(offsetChunk->getValue(index + 1) >= + offsetChunk->getValue(index)); + return offsetChunk->getValue(index + 1) - + offsetChunk->getValue(index); + } + return stringDataChunk->getNumValues() - offsetChunk->getValue(index); +} + +DictionaryChunk::string_index_t DictionaryChunk::appendString(std::string_view val) { + const auto found = indexTable.find(val); + // If the string already exists in the dictionary, skip it and refer to the existing string + if (enableCompression && found != indexTable.end()) { + return found->index; + } + const auto leftSpace = stringDataChunk->getCapacity() - stringDataChunk->getNumValues(); + if (leftSpace < val.size()) { + stringDataChunk->resize(std::bit_ceil(stringDataChunk->getCapacity() + val.size())); + } + const auto startOffset = stringDataChunk->getNumValues(); + memcpy(stringDataChunk->getData() + startOffset, val.data(), val.size()); + stringDataChunk->setNumValues(startOffset + val.size()); + const auto index = offsetChunk->getNumValues(); + if (index >= offsetChunk->getCapacity()) { + offsetChunk->resize(offsetChunk->getCapacity() == 0 ? 
+ 2 : + (offsetChunk->getCapacity() * CHUNK_RESIZE_RATIO)); + } + offsetChunk->setValue(startOffset, index); + offsetChunk->setNumValues(index + 1); + if (enableCompression) { + indexTable.insert({static_cast(index)}); + } + return index; +} + +std::string_view DictionaryChunk::getString(string_index_t index) const { + KU_ASSERT(index < offsetChunk->getNumValues()); + const auto startOffset = offsetChunk->getValue(index); + const auto length = getStringLength(index); + return std::string_view(reinterpret_cast(stringDataChunk->getData()) + startOffset, + length); +} + +bool DictionaryChunk::sanityCheck() const { + return offsetChunk->getNumValues() <= offsetChunk->getNumValues(); +} + +void DictionaryChunk::resetNumValuesFromMetadata() { + stringDataChunk->resetNumValuesFromMetadata(); + offsetChunk->resetNumValuesFromMetadata(); +} + +uint64_t DictionaryChunk::getEstimatedMemoryUsage() const { + return stringDataChunk->getEstimatedMemoryUsage() + offsetChunk->getEstimatedMemoryUsage(); +} + +void DictionaryChunk::flush(PageAllocator& pageAllocator) { + stringDataChunk->flush(pageAllocator); + offsetChunk->flush(pageAllocator); +} + +void DictionaryChunk::serialize(Serializer& serializer) const { + serializer.writeDebuggingInfo("offset_chunk"); + offsetChunk->serialize(serializer); + serializer.writeDebuggingInfo("string_data_chunk"); + stringDataChunk->serialize(serializer); +} + +std::unique_ptr DictionaryChunk::deserialize(MemoryManager& memoryManager, + Deserializer& deSer) { + auto chunk = std::make_unique(memoryManager, 0, true, ResidencyState::ON_DISK); + std::string key; + deSer.validateDebuggingInfo(key, "offset_chunk"); + chunk->offsetChunk = ColumnChunkData::deserialize(memoryManager, deSer); + deSer.validateDebuggingInfo(key, "string_data_chunk"); + chunk->stringDataChunk = ColumnChunkData::deserialize(memoryManager, deSer); + return chunk; +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/dictionary_column.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/dictionary_column.cpp new file mode 100644 index 0000000000..11a734ba45 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/dictionary_column.cpp @@ -0,0 +1,266 @@ +#include "storage/table/dictionary_column.h" + +#include +#include + +#include "common/types/ku_string.h" +#include "common/types/types.h" +#include "common/vector/value_vector.h" +#include "storage/buffer_manager/memory_manager.h" +#include "storage/storage_utils.h" +#include "storage/table/column_chunk_data.h" +#include "storage/table/dictionary_chunk.h" +#include "storage/table/string_chunk_data.h" +#include "storage/table/string_column.h" +#include +#include + +using namespace lbug::common; +using namespace lbug::transaction; + +namespace lbug { +namespace storage { + +using string_index_t = DictionaryChunk::string_index_t; +using string_offset_t = DictionaryChunk::string_offset_t; + +DictionaryColumn::DictionaryColumn(const std::string& name, FileHandle* dataFH, MemoryManager* mm, + ShadowFile* shadowFile, bool enableCompression) { + auto dataColName = StorageUtils::getColumnName(name, StorageUtils::ColumnType::DATA, ""); + dataColumn = std::make_unique(dataColName, LogicalType::UINT8(), dataFH, mm, shadowFile, + false /*enableCompression*/, false /*requireNullColumn*/); + auto offsetColName = StorageUtils::getColumnName(name, StorageUtils::ColumnType::OFFSET, ""); + offsetColumn = std::make_unique(offsetColName, LogicalType::UINT64(), dataFH, mm, + shadowFile, 
enableCompression, false /*requireNullColumn*/); +} + +void DictionaryColumn::scan(const SegmentState& state, DictionaryChunk& dictChunk) const { + auto offsetChunk = dictChunk.getOffsetChunk(); + auto stringDataChunk = dictChunk.getStringDataChunk(); + auto initialDictSize = offsetChunk->getNumValues(); + auto initialDictDataSize = stringDataChunk->getNumValues(); + + auto& dataMetadata = + StringColumn::getChildState(state, StringColumn::ChildStateIndex::DATA).metadata; + // Make sure that the chunk is large enough + if (stringDataChunk->getNumValues() + dataMetadata.numValues > stringDataChunk->getCapacity()) { + stringDataChunk->resize( + std::bit_ceil(stringDataChunk->getNumValues() + dataMetadata.numValues)); + } + dataColumn->scanSegment(StringColumn::getChildState(state, StringColumn::ChildStateIndex::DATA), + stringDataChunk, 0, + StringColumn::getChildState(state, StringColumn::ChildStateIndex::DATA).metadata.numValues); + + auto& offsetMetadata = + StringColumn::getChildState(state, StringColumn::ChildStateIndex::OFFSET).metadata; + // Make sure that the chunk is large enough + if (offsetChunk->getNumValues() + offsetMetadata.numValues > offsetChunk->getCapacity()) { + offsetChunk->resize(std::bit_ceil(offsetChunk->getNumValues() + offsetMetadata.numValues)); + } + offsetColumn->scanSegment( + StringColumn::getChildState(state, StringColumn::ChildStateIndex::OFFSET), offsetChunk, 0, + StringColumn::getChildState(state, StringColumn::ChildStateIndex::OFFSET) + .metadata.numValues); + // Each offset needs to be incremented by the initial size of the dictionary data chunk + for (row_idx_t i = initialDictSize; i < offsetChunk->getNumValues(); i++) { + offsetChunk->setValue( + offsetChunk->getValue(i) + initialDictDataSize, i); + } +} + +template +void DictionaryColumn::scan(const SegmentState& offsetState, const SegmentState& dataState, + std::vector>& offsetsToScan, Result* result, + const ColumnChunkMetadata& indexMeta) const { + string_index_t firstOffsetToScan = 0, lastOffsetToScan = 0; + auto comp = [](auto pair1, auto pair2) { return pair1.first < pair2.first; }; + auto duplicationFactor = (double)offsetState.metadata.numValues / indexMeta.numValues; + if (duplicationFactor <= 0.5) { + // If at least 50% of strings are duplicated, sort the offsets so we can re-use scanned + // strings + std::sort(offsetsToScan.begin(), offsetsToScan.end(), comp); + firstOffsetToScan = offsetsToScan.front().first; + lastOffsetToScan = offsetsToScan.back().first; + } else { + const auto& [min, max] = + std::minmax_element(offsetsToScan.begin(), offsetsToScan.end(), comp); + firstOffsetToScan = min->first; + lastOffsetToScan = max->first; + } + // TODO(bmwinger): scan batches of adjacent values. + // Ideally we scan values together until we reach empty pages + // This would also let us use the same optimization for the data column, + // where the worst case for the current method is much worse + + // Note that the list will contain duplicates when indices are duplicated. 
+ // Each distinct value is scanned once, and re-used when writing to each output value + auto numOffsetsToScan = lastOffsetToScan - firstOffsetToScan + 1; + // One extra offset to scan for the end offset of the last string + std::vector offsets(numOffsetsToScan + 1); + scanOffsets(offsetState, offsets.data(), firstOffsetToScan, numOffsetsToScan, + dataState.metadata.numValues); + + if constexpr (std::same_as) { + auto& offsetChunk = *result->getDictionaryChunk()->getOffsetChunk(); + if (offsetChunk.getNumValues() + offsetsToScan.size() > offsetChunk.getCapacity()) { + offsetChunk.resize(std::bit_ceil(offsetChunk.getNumValues() + offsetsToScan.size())); + } + } + + for (auto pos = 0u; pos < offsetsToScan.size(); pos++) { + auto startOffset = offsets[offsetsToScan[pos].first - firstOffsetToScan]; + auto endOffset = offsets[offsetsToScan[pos].first - firstOffsetToScan + 1]; + auto lengthToScan = endOffset - startOffset; + KU_ASSERT(endOffset >= startOffset); + scanValue(dataState, startOffset, lengthToScan, result, offsetsToScan[pos].second); + // For each string which has the same index in the dictionary as the one we scanned, + // copy the scanned string to its position in the result vector + if constexpr (std::same_as) { + auto& scannedString = result->template getValue(offsetsToScan[pos].second); + while (pos + 1 < offsetsToScan.size() && + offsetsToScan[pos + 1].first == offsetsToScan[pos].first) { + pos++; + result->template setValue(offsetsToScan[pos].second, scannedString); + } + } else { + // When scanning to chunks de-duplication should be done prior to this function such + // that you can have multiple positions in the string index chunk pointing to one string + // in this dictionary chunk. + // The offset chunk cannot have multiple offsets pointing to the same data, even if + // consecutive, since that would break the mechanism for calculating the size of a + // string. + KU_ASSERT(pos == offsetsToScan.size() - 1 || + offsetsToScan[pos].first != offsetsToScan[pos + 1].first); + } + } +} + +template void DictionaryColumn::scan(const SegmentState& offsetState, + const SegmentState& dataState, + std::vector>& offsetsToScan, + common::ValueVector* result, const ColumnChunkMetadata& indexMeta) const; + +template void DictionaryColumn::scan(const SegmentState& offsetState, + const SegmentState& dataState, + std::vector>& offsetsToScan, + StringChunkData* result, const ColumnChunkMetadata& indexMeta) const; + +string_index_t DictionaryColumn::append(const DictionaryChunk& dictChunk, SegmentState& state, + std::string_view val) const { + const auto startOffset = dataColumn->appendValues(*dictChunk.getStringDataChunk(), + StringColumn::getChildState(state, StringColumn::ChildStateIndex::DATA), + reinterpret_cast(val.data()), nullptr /*nullChunkData*/, val.size()); + return offsetColumn->appendValues(*dictChunk.getOffsetChunk(), + StringColumn::getChildState(state, StringColumn::ChildStateIndex::OFFSET), + reinterpret_cast(&startOffset), nullptr /*nullChunkData*/, 1 /*numValues*/); +} + +void DictionaryColumn::scanOffsets(const SegmentState& state, + DictionaryChunk::string_offset_t* offsets, uint64_t index, uint64_t numValues, + uint64_t dataSize) const { + // We either need to read the next value, or store the maximum string offset at the end. + // Otherwise we won't know what the length of the last string is. 
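+    // Worked example (made-up values): a dictionary of 4 strings with 11 bytes of string data
+    // and start offsets [0, 3, 5, 9]. Scanning indices 1-2 reads the 3 offsets [3, 5, 9],
+    // giving lengths 2 and 4; scanning indices 2-3 hits the end of the offset column, so only
+    // [5, 9] is read and dataSize = 11 is appended as the end of the last string (lengths 4
+    // and 2).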
+ if (index + numValues < state.metadata.numValues) { + offsetColumn->scanSegment(state, index, numValues + 1, (uint8_t*)offsets); + } else { + offsetColumn->scanSegment(state, index, numValues, (uint8_t*)offsets); + offsets[numValues] = dataSize; + } +} + +void DictionaryColumn::scanValue(const SegmentState& dataState, uint64_t startOffset, + uint64_t length, ValueVector* resultVector, uint64_t offsetInVector) const { + // Add string to vector first and read directly into the vector + auto& kuString = StringVector::reserveString(resultVector, offsetInVector, length); + dataColumn->scanSegment(dataState, startOffset, length, (uint8_t*)kuString.getData()); + // Update prefix to match the scanned string data + if (!ku_string_t::isShortString(kuString.len)) { + memcpy(kuString.prefix, kuString.getData(), ku_string_t::PREFIX_LENGTH); + } +} + +void DictionaryColumn::scanValue(const SegmentState& dataState, uint64_t startOffset, + uint64_t length, StringChunkData* result, uint64_t offsetInResult) const { + auto& stringDataChunk = *result->getDictionaryChunk().getStringDataChunk(); + auto& offsetChunk = *result->getDictionaryChunk().getOffsetChunk(); + auto& indexChunk = *result->getIndexColumnChunk(); + if (stringDataChunk.getCapacity() < stringDataChunk.getNumValues() + length) { + stringDataChunk.resize(std::bit_ceil(stringDataChunk.getNumValues() + length)); + } + if (offsetChunk.getNumValues() == offsetChunk.getCapacity()) { + offsetChunk.resize(std::bit_ceil(offsetChunk.getNumValues() + 1)); + } + if (offsetInResult >= indexChunk.getCapacity()) { + indexChunk.resize(std::bit_ceil(offsetInResult + 1)); + } + dataColumn->scanSegment(dataState, startOffset, length, + stringDataChunk.getData() + stringDataChunk.getNumValues()); + indexChunk.setValue(offsetChunk.getNumValues(), offsetInResult); + offsetChunk.setValue(stringDataChunk.getNumValues(), + offsetChunk.getNumValues()); + stringDataChunk.setNumValues(stringDataChunk.getNumValues() + length); +} + +bool DictionaryColumn::canCommitInPlace(const SegmentState& state, uint64_t numNewStrings, + uint64_t totalStringLengthToAdd) const { + if (!canDataCommitInPlace( + StringColumn::getChildState(state, StringColumn::ChildStateIndex::DATA), + totalStringLengthToAdd)) { + return false; + } + if (!canOffsetCommitInPlace( + StringColumn::getChildState(state, StringColumn::ChildStateIndex::OFFSET), + StringColumn::getChildState(state, StringColumn::ChildStateIndex::DATA), numNewStrings, + totalStringLengthToAdd)) { + return false; + } + return true; +} + +bool DictionaryColumn::canDataCommitInPlace(const SegmentState& dataState, + uint64_t totalStringLengthToAdd) { + // Make sure there is sufficient space in the data chunk (not currently compressed) + auto totalStringDataAfterUpdate = dataState.metadata.numValues + totalStringLengthToAdd; + if (totalStringDataAfterUpdate > dataState.metadata.getNumPages() * LBUG_PAGE_SIZE) { + // Data cannot be updated in place + return false; + } + return true; +} + +bool DictionaryColumn::canOffsetCommitInPlace(const SegmentState& offsetState, + const SegmentState& dataState, uint64_t numNewStrings, uint64_t totalStringLengthToAdd) const { + auto totalStringOffsetsAfterUpdate = dataState.metadata.numValues + totalStringLengthToAdd; + auto offsetCapacity = + offsetState.metadata.compMeta.numValues(LBUG_PAGE_SIZE, offsetColumn->getDataType()) * + offsetState.metadata.getNumPages(); + auto numStringsAfterUpdate = offsetState.metadata.numValues + numNewStrings; + if (numStringsAfterUpdate > offsetCapacity) { + // 
Offsets cannot be updated in place + return false; + } + // Indices are limited to 32 bits but in theory could be larger than that since the offset + // column can grow beyond the node group size. + // + // E.g. one big string is written first, followed by NODE_GROUP_SIZE-1 small strings, + // which are all updated in-place many times (which may fit if the first string is large + // enough that 2^n minus the first string's size is large enough to fit the other strings, + // for some n. + // 32 bits should give plenty of space for updates. + if (numStringsAfterUpdate > std::numeric_limits::max()) [[unlikely]] { + return false; + } + if (offsetState.metadata.compMeta.canAlwaysUpdateInPlace()) { + return true; + } + InPlaceUpdateLocalState localUpdateState{}; + if (!offsetState.metadata.compMeta.canUpdateInPlace( + (const uint8_t*)&totalStringOffsetsAfterUpdate, 0 /*offset*/, 1 /*numValues*/, + offsetColumn->getDataType().getPhysicalType(), localUpdateState)) { + return false; + } + return true; +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/in_mem_chunked_node_group_collection.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/in_mem_chunked_node_group_collection.cpp new file mode 100644 index 0000000000..32ff78d440 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/in_mem_chunked_node_group_collection.cpp @@ -0,0 +1,51 @@ +#include "storage/table/in_mem_chunked_node_group_collection.h" + +#include "storage/buffer_manager/memory_manager.h" + +using namespace lbug::common; +using namespace lbug::transaction; + +namespace lbug { +namespace storage { + +void InMemChunkedNodeGroupCollection::append(MemoryManager& memoryManager, + const std::vector& vectors, row_idx_t startRowInVectors, + row_idx_t numRowsToAppend) { + if (chunkedGroups.empty()) { + chunkedGroups.push_back(std::make_unique(memoryManager, types, + false /*enableCompression*/, common::StorageConfig::CHUNKED_NODE_GROUP_CAPACITY, + 0 /*startOffset*/)); + } + row_idx_t numRowsAppended = 0; + while (numRowsAppended < numRowsToAppend) { + auto& lastChunkedGroup = chunkedGroups.back(); + auto numRowsToAppendInGroup = std::min(numRowsToAppend - numRowsAppended, + common::StorageConfig::CHUNKED_NODE_GROUP_CAPACITY - lastChunkedGroup->getNumRows()); + lastChunkedGroup->append(vectors, startRowInVectors, numRowsToAppendInGroup); + if (lastChunkedGroup->getNumRows() == common::StorageConfig::CHUNKED_NODE_GROUP_CAPACITY) { + lastChunkedGroup->setUnused(memoryManager); + chunkedGroups.push_back(std::make_unique(memoryManager, types, + false /*enableCompression*/, common::StorageConfig::CHUNKED_NODE_GROUP_CAPACITY, + 0 /* startRowIdx */)); + } + numRowsAppended += numRowsToAppendInGroup; + } +} + +void InMemChunkedNodeGroupCollection::merge(std::unique_ptr chunkedGroup) { + KU_ASSERT(chunkedGroup->getNumColumns() == types.size()); + for (auto i = 0u; i < chunkedGroup->getNumColumns(); i++) { + KU_ASSERT(chunkedGroup->getColumnChunk(i).getDataType() == types[i]); + } + chunkedGroups.push_back(std::move(chunkedGroup)); +} + +void InMemChunkedNodeGroupCollection::merge(InMemChunkedNodeGroupCollection& other) { + chunkedGroups.reserve(chunkedGroups.size() + other.chunkedGroups.size()); + for (auto& chunkedGroup : other.chunkedGroups) { + merge(std::move(chunkedGroup)); + } +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/in_memory_exception_chunk.cpp 
b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/in_memory_exception_chunk.cpp new file mode 100644 index 0000000000..96990f89f9 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/in_memory_exception_chunk.cpp @@ -0,0 +1,160 @@ +#include "storage/table/in_memory_exception_chunk.h" + +#include + +#include "common/utils.h" +#include "storage/buffer_manager/memory_manager.h" +#include "storage/compression/float_compression.h" +#include "storage/storage_utils.h" +#include "storage/table/column.h" +#include "storage/table/column_chunk_data.h" +#include + +namespace lbug::storage { + +using namespace common; +using namespace transaction; + +template +using ExceptionInBuffer = std::array::sizeInBytes()>; + +template +InMemoryExceptionChunk::InMemoryExceptionChunk(const SegmentState& state, FileHandle* dataFH, + MemoryManager* memoryManager, ShadowFile* shadowFile) + : exceptionCount(state.metadata.compMeta.floatMetadata()->exceptionCount), + finalizedExceptionCount(exceptionCount), + exceptionCapacity(state.metadata.compMeta.floatMetadata()->exceptionCapacity), + emptyMask(exceptionCapacity), + column(std::make_unique("ALPExceptionChunk", physicalType, dataFH, memoryManager, + shadowFile, false, false /*has nulls*/)) { + const auto exceptionBaseCursor = + getExceptionPageCursor(state.metadata, PageCursor{state.metadata.getStartPageIdx(), 0}, + state.metadata.compMeta.floatMetadata()->exceptionCapacity); + // for ALP exceptions we don't care about the statistics + const auto compMeta = + CompressionMetadata(StorageValue{0}, StorageValue{1}, CompressionType::UNCOMPRESSED); + const auto exceptionChunkMeta = ColumnChunkMetadata(exceptionBaseCursor.pageIdx, + safeIntegerConversion( + EncodeException::numPagesFromExceptions(exceptionCapacity)), + exceptionCapacity, compMeta); + chunkState = std::make_unique(exceptionChunkMeta, + EncodeException::exceptionBytesPerPage() / EncodeException::sizeInBytes()); + + chunkData = + std::make_unique(*memoryManager, physicalType, false, exceptionChunkMeta, + false /*all written data is non-null and nulls are kept in a separate mask in-memory*/); + chunkData->setToInMemory(); + column->scanSegment(*chunkState, chunkData.get(), 0, chunkState->metadata.numValues); +} + +template +InMemoryExceptionChunk::~InMemoryExceptionChunk() = default; + +template +void InMemoryExceptionChunk::finalizeAndFlushToDisk(SegmentState& state) { + finalize(state); + + column->writeSegment(*chunkData, *chunkState, 0, *chunkData, 0, exceptionCapacity); +} + +template +void InMemoryExceptionChunk::finalize(SegmentState& state) { + // removes holes + sorts exception chunk + finalizedExceptionCount = 0; + for (size_t i = 0; i < exceptionCount; ++i) { + if (!emptyMask.isNull(i)) { + ++finalizedExceptionCount; + if (finalizedExceptionCount - 1 == i) { + continue; + } + writeException(getExceptionAt(i), finalizedExceptionCount - 1); + } + } + + KU_ASSERT( + finalizedExceptionCount <= state.metadata.compMeta.floatMetadata()->exceptionCapacity); + state.metadata.compMeta.floatMetadata()->exceptionCount = finalizedExceptionCount; + + ExceptionInBuffer* exceptionWordBuffer = + reinterpret_cast*>(chunkData->getData()); + std::sort(exceptionWordBuffer, exceptionWordBuffer + finalizedExceptionCount, + [](ExceptionInBuffer& a, ExceptionInBuffer& b) { + return EncodeExceptionView{reinterpret_cast(&a)}.getValue() < + EncodeExceptionView{reinterpret_cast(&b)}.getValue(); + }); + std::memset(chunkData->getData() + finalizedExceptionCount * EncodeException::sizeInBytes(), + 0, 
(exceptionCount - finalizedExceptionCount) * EncodeException::sizeInBytes()); + emptyMask.setNullFromRange(0, finalizedExceptionCount, false); + emptyMask.setNullFromRange(finalizedExceptionCount, (exceptionCount - finalizedExceptionCount), + true); + exceptionCount = finalizedExceptionCount; + chunkData->setNumValues(finalizedExceptionCount); +} + +template +void InMemoryExceptionChunk::addException(EncodeException exception) { + KU_ASSERT(exceptionCount < exceptionCapacity); + ++exceptionCount; + writeException(exception, exceptionCount - 1); + emptyMask.setNull(exceptionCount - 1, false); +} + +template +void InMemoryExceptionChunk::removeExceptionAt(size_t exceptionIdx) { + // removing an exception does not free up space in the exception buffer + emptyMask.setNull(exceptionIdx, true); +} + +template +EncodeException InMemoryExceptionChunk::getExceptionAt(size_t exceptionIdx) const { + KU_ASSERT(exceptionIdx < exceptionCount); + auto bytesInBuffer = chunkData->getValue>(exceptionIdx); + return EncodeExceptionView{reinterpret_cast(&bytesInBuffer)}.getValue(); +} + +template +void InMemoryExceptionChunk::writeException(EncodeException exception, size_t exceptionIdx) { + KU_ASSERT(exceptionIdx < exceptionCount); + EncodeExceptionView{reinterpret_cast(chunkData->getData())}.setValue(exception, + exceptionIdx); +} + +template +offset_t InMemoryExceptionChunk::findFirstExceptionAtOrPastOffset(offset_t offsetInChunk) const { + // binary search for chunkOffset in exceptions + // we only search among non-finalized exceptions + + offset_t lo = 0; + offset_t hi = finalizedExceptionCount; + while (lo < hi) { + const size_t curExceptionIdx = (lo + hi) / 2; + EncodeException lastException = getExceptionAt(curExceptionIdx); + + if (lastException.posInChunk < offsetInChunk) { + lo = curExceptionIdx + 1; + } else { + hi = curExceptionIdx; + } + } + + return lo; +} + +template +PageCursor InMemoryExceptionChunk::getExceptionPageCursor(const ColumnChunkMetadata& metadata, + PageCursor pageBaseCursor, size_t exceptionCapacity) { + const size_t numExceptionPages = EncodeException::numPagesFromExceptions(exceptionCapacity); + const size_t exceptionPageOffset = metadata.getNumPages() - numExceptionPages; + KU_ASSERT(exceptionPageOffset == (page_idx_t)exceptionPageOffset); + return {pageBaseCursor.pageIdx + (page_idx_t)exceptionPageOffset, 0}; +} + +template +size_t InMemoryExceptionChunk::getExceptionCount() const { + return finalizedExceptionCount; +} + +template class InMemoryExceptionChunk; +template class InMemoryExceptionChunk; + +} // namespace lbug::storage diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/lazy_segment_scanner.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/lazy_segment_scanner.cpp new file mode 100644 index 0000000000..23cdf1510b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/lazy_segment_scanner.cpp @@ -0,0 +1,47 @@ +#include "storage/table/lazy_segment_scanner.h" + +namespace lbug::storage { +void LazySegmentScanner::Iterator::advance(common::offset_t n) { + segmentScanner.rangeSegments(*this, n, + [this](auto& segmentData, auto, auto lengthInSegment, auto) { + KU_ASSERT(segmentData.length > offsetInSegment); + if (segmentData.length - offsetInSegment == lengthInSegment) { + ++segmentIdx; + offsetInSegment = 0; + } else { + offsetInSegment += lengthInSegment; + } + }); +} + +void LazySegmentScanner::scanSegment(common::offset_t offsetInSegment, + common::offset_t segmentLength, scan_func_t newScanFunc) { + segments.emplace_back(nullptr, 
offsetInSegment, segmentLength, std::move(newScanFunc)); + numValues += segmentLength; +} + +void LazySegmentScanner::applyCommittedUpdates(const UpdateInfo& updateInfo, + const transaction::Transaction* transaction, common::offset_t startRow, + common::offset_t numRows) { + KU_ASSERT(numRows == numValues); + rangeSegments(begin(), numRows, + [&](auto& segment, common::offset_t, common::offset_t lengthInSegment, + common::offset_t offsetInChunk) { + updateInfo.iterateScan(transaction, startRow + offsetInChunk, lengthInSegment, 0, + [&](const VectorUpdateInfo& vecUpdateInfo, uint64_t i, + uint64_t posInOutput) -> void { + scanSegmentIfNeeded(segment); + segment.segmentData->write(vecUpdateInfo.data.get(), i, posInOutput, 1); + }); + }); +} + +void LazySegmentScanner::scanSegmentIfNeeded(LazySegmentData& segment) { + if (segment.segmentData == nullptr) { + segment.segmentData = ColumnChunkFactory::createColumnChunkData(mm, columnType.copy(), + enableCompression, segment.length, ResidencyState::IN_MEMORY); + + segment.scanFunc(*segment.segmentData, segment.startOffsetInSegment, segment.length); + } +} +} // namespace lbug::storage diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/list_chunk_data.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/list_chunk_data.cpp new file mode 100644 index 0000000000..c57c9efb80 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/list_chunk_data.cpp @@ -0,0 +1,462 @@ +#include "storage/table/list_chunk_data.h" + +#include + +#include "common/data_chunk/sel_vector.h" +#include "common/serializer/deserializer.h" +#include "common/serializer/serializer.h" +#include "common/types/types.h" +#include "common/vector/value_vector.h" +#include "storage/buffer_manager/memory_manager.h" +#include "storage/table/column_chunk_data.h" +#include "storage/table/list_column.h" + +using namespace lbug::common; + +namespace lbug { +namespace storage { + +ListChunkData::ListChunkData(MemoryManager& memoryManager, LogicalType dataType, uint64_t capacity, + bool enableCompression, ResidencyState residencyState) + : ColumnChunkData{memoryManager, std::move(dataType), capacity, enableCompression, + residencyState, true /*hasNullData*/} { + offsetColumnChunk = ColumnChunkFactory::createColumnChunkData(memoryManager, + LogicalType::UINT64(), enableCompression, capacity, residencyState, false /*hasNull*/); + sizeColumnChunk = ColumnChunkFactory::createColumnChunkData(memoryManager, + LogicalType::UINT32(), enableCompression, capacity, residencyState, false /*hasNull*/); + if (ListColumn::disableCompressionOnData(this->dataType)) { + enableCompression = false; + } + dataColumnChunk = ColumnChunkFactory::createColumnChunkData(memoryManager, + ListType::getChildType(this->dataType).copy(), enableCompression, 0 /* capacity */, + residencyState); + checkOffsetSortedAsc = false; + KU_ASSERT(this->dataType.getPhysicalType() == PhysicalTypeID::LIST || + this->dataType.getPhysicalType() == PhysicalTypeID::ARRAY); +} + +ListChunkData::ListChunkData(MemoryManager& memoryManager, LogicalType dataType, + bool enableCompression, const ColumnChunkMetadata& metadata) + : ColumnChunkData{memoryManager, std::move(dataType), enableCompression, metadata, + true /*hasNullData*/}, + checkOffsetSortedAsc{false} { + offsetColumnChunk = ColumnChunkFactory::createColumnChunkData(memoryManager, + LogicalType::UINT64(), enableCompression, 0, ResidencyState::ON_DISK); + sizeColumnChunk = ColumnChunkFactory::createColumnChunkData(memoryManager, + LogicalType::UINT32(), 
enableCompression, 0, ResidencyState::ON_DISK); + if (ListColumn::disableCompressionOnData(this->dataType)) { + enableCompression = false; + } + dataColumnChunk = ColumnChunkFactory::createColumnChunkData(memoryManager, + ListType::getChildType(this->dataType).copy(), enableCompression, 0 /* capacity */, + ResidencyState::ON_DISK); +} + +bool ListChunkData::isOffsetsConsecutiveAndSortedAscending(uint64_t startPos, + uint64_t endPos) const { + offset_t prevEndOffset = getListStartOffset(startPos); + for (auto i = startPos; i < endPos; i++) { + offset_t currentEndOffset = getListEndOffset(i); + auto size = getListSize(i); + prevEndOffset += size; + if (currentEndOffset != prevEndOffset) { + return false; + } + } + return true; +} + +offset_t ListChunkData::getListStartOffset(offset_t offset) const { + if (numValues == 0 || (offset != numValues && nullData->isNull(offset))) { + return 0; + } + KU_ASSERT(offset == numValues || getListEndOffset(offset) >= getListSize(offset)); + return offset == numValues ? getListEndOffset(offset - 1) : + getListEndOffset(offset) - getListSize(offset); +} + +offset_t ListChunkData::getListEndOffset(offset_t offset) const { + if (numValues == 0 || nullData->isNull(offset)) { + return 0; + } + KU_ASSERT(offset < numValues); + return offsetColumnChunk->getValue(offset); +} + +list_size_t ListChunkData::getListSize(offset_t offset) const { + if (numValues == 0 || nullData->isNull(offset)) { + return 0; + } + KU_ASSERT(offset < sizeColumnChunk->getNumValues()); + return sizeColumnChunk->getValue(offset); +} + +void ListChunkData::setOffsetChunkValue(offset_t val, offset_t pos) { + offsetColumnChunk->setValue(val, pos); + + // we will keep numValues in the main column synchronized + numValues = offsetColumnChunk->getNumValues(); +} + +void ListChunkData::append(const ColumnChunkData* other, offset_t startPosInOtherChunk, + uint32_t numValuesToAppend) { + checkOffsetSortedAsc = true; + auto& otherListChunk = other->cast(); + nullData->append(other->getNullData(), startPosInOtherChunk, numValuesToAppend); + offset_t offsetInDataChunkToAppend = dataColumnChunk->getNumValues(); + for (auto i = 0u; i < numValuesToAppend; i++) { + auto appendSize = otherListChunk.getListSize(startPosInOtherChunk + i); + sizeColumnChunk->setValue(appendSize, numValues); + offsetInDataChunkToAppend += appendSize; + setOffsetChunkValue(offsetInDataChunkToAppend, numValues); + } + dataColumnChunk->resize(offsetInDataChunkToAppend); + for (auto i = 0u; i < numValuesToAppend; i++) { + auto startOffset = otherListChunk.getListStartOffset(startPosInOtherChunk + i); + auto appendSize = otherListChunk.getListSize(startPosInOtherChunk + i); + dataColumnChunk->append(otherListChunk.dataColumnChunk.get(), startOffset, appendSize); + } + KU_ASSERT(sanityCheck()); +} + +void ListChunkData::resetToEmpty() { + ColumnChunkData::resetToEmpty(); + sizeColumnChunk->resetToEmpty(); + offsetColumnChunk->resetToEmpty(); + dataColumnChunk->resetToEmpty(); +} + +void ListChunkData::resetNumValuesFromMetadata() { + ColumnChunkData::resetNumValuesFromMetadata(); + sizeColumnChunk->resetNumValuesFromMetadata(); + offsetColumnChunk->resetNumValuesFromMetadata(); + dataColumnChunk->resetNumValuesFromMetadata(); +} + +void ListChunkData::append(ValueVector* vector, const SelectionView& selView) { + auto numToAppend = selView.getSelSize(); + auto newCapacity = capacity; + while (numValues + numToAppend >= newCapacity) { + newCapacity = std::ceil(newCapacity * 1.5); + } + if (capacity < newCapacity) { + 
resize(newCapacity); + } + offset_t nextListOffsetInChunk = dataColumnChunk->getNumValues(); + const offset_t appendBaseOffset = numValues; + for (auto i = 0u; i < selView.getSelSize(); i++) { + auto pos = selView[i]; + auto listLen = vector->isNull(pos) ? 0 : vector->getValue(pos).size; + sizeColumnChunk->setValue(listLen, appendBaseOffset + i); + + nullData->setNull(appendBaseOffset + i, vector->isNull(pos)); + + nextListOffsetInChunk += listLen; + setOffsetChunkValue(nextListOffsetInChunk, appendBaseOffset + i); + } + dataColumnChunk->resize(nextListOffsetInChunk); + auto dataVector = ListVector::getDataVector(vector); + for (auto i = 0u; i < selView.getSelSize(); i++) { + auto pos = selView[i]; + if (vector->isNull(pos)) { + continue; + } + copyListValues(vector->getValue(pos), dataVector); + } + KU_ASSERT(sanityCheck()); +} + +void ListChunkData::appendNullList() { + offset_t nextListOffsetInChunk = dataColumnChunk->getNumValues(); + const offset_t appendPosition = numValues; + sizeColumnChunk->setValue(0, appendPosition); + setOffsetChunkValue(nextListOffsetInChunk, appendPosition); + nullData->setNull(appendPosition, true); +} + +void ListChunkData::scan(ValueVector& output, offset_t offset, length_t length, + sel_t posInOutputVector) const { + KU_ASSERT(offset + length <= numValues); + if (nullData) { + nullData->scan(output, offset, length, posInOutputVector); + } + auto currentListDataSize = ListVector::getDataVectorSize(&output); + auto dataSize = 0ul; + for (auto i = 0u; i < length; i++) { + auto listSize = getListSize(offset + i); + output.setValue(posInOutputVector + i, + list_entry_t{currentListDataSize + dataSize, listSize}); + dataSize += listSize; + } + ListVector::resizeDataVector(&output, currentListDataSize + dataSize); + auto dataVector = ListVector::getDataVector(&output); + if (isOffsetsConsecutiveAndSortedAscending(offset, offset + length)) { + dataColumnChunk->scan(*dataVector, getListStartOffset(offset), dataSize, + currentListDataSize); + } else { + for (auto i = 0u; i < length; i++) { + auto startOffset = getListStartOffset(offset + i); + auto listSize = getListSize(offset + i); + dataColumnChunk->scan(*dataVector, startOffset, listSize, currentListDataSize); + currentListDataSize += listSize; + } + } +} + +void ListChunkData::lookup(offset_t offsetInChunk, ValueVector& output, + sel_t posInOutputVector) const { + KU_ASSERT(offsetInChunk < numValues); + output.setNull(posInOutputVector, nullData->isNull(offsetInChunk)); + if (output.isNull(posInOutputVector)) { + return; + } + auto startOffset = getListStartOffset(offsetInChunk); + auto listSize = getListSize(offsetInChunk); + auto dataVector = ListVector::getDataVector(&output); + auto currentListDataSize = ListVector::getDataVectorSize(&output); + ListVector::resizeDataVector(&output, currentListDataSize + listSize); + dataColumnChunk->scan(*dataVector, startOffset, listSize, currentListDataSize); + // reset offset + output.setValue(posInOutputVector, list_entry_t{currentListDataSize, listSize}); +} + +void ListChunkData::initializeScanState(SegmentState& state, const Column* column) const { + ColumnChunkData::initializeScanState(state, column); + + auto* listColumn = ku_dynamic_cast(column); + state.childrenStates.resize(CHILD_COLUMN_COUNT); + sizeColumnChunk->initializeScanState(state.childrenStates[SIZE_COLUMN_CHILD_READ_STATE_IDX], + listColumn->getSizeColumn()); + dataColumnChunk->initializeScanState(state.childrenStates[DATA_COLUMN_CHILD_READ_STATE_IDX], + listColumn->getDataColumn()); + 
offsetColumnChunk->initializeScanState(state.childrenStates[OFFSET_COLUMN_CHILD_READ_STATE_IDX], + listColumn->getOffsetColumn()); +} + +void ListChunkData::write(ColumnChunkData* chunk, ColumnChunkData* dstOffsets, RelMultiplicity) { + KU_ASSERT(chunk->getDataType().getPhysicalType() == dataType.getPhysicalType() && + dstOffsets->getDataType().getPhysicalType() == PhysicalTypeID::INTERNAL_ID && + chunk->getNumValues() == dstOffsets->getNumValues()); + checkOffsetSortedAsc = true; + offset_t currentIndex = dataColumnChunk->getNumValues(); + auto& otherListChunk = chunk->cast(); + dataColumnChunk->resize( + dataColumnChunk->getNumValues() + otherListChunk.dataColumnChunk->getNumValues()); + dataColumnChunk->append(otherListChunk.dataColumnChunk.get(), 0, + otherListChunk.dataColumnChunk->getNumValues()); + offset_t maxDstOffset = 0; + for (auto i = 0u; i < dstOffsets->getNumValues(); i++) { + auto posInChunk = dstOffsets->getValue(i); + if (posInChunk > maxDstOffset) { + maxDstOffset = posInChunk; + } + } + while (maxDstOffset >= numValues) { + appendNullList(); + } + for (auto i = 0u; i < dstOffsets->getNumValues(); i++) { + auto posInChunk = dstOffsets->getValue(i); + auto appendSize = otherListChunk.getListSize(i); + currentIndex += appendSize; + nullData->setNull(posInChunk, otherListChunk.nullData->isNull(i)); + setOffsetChunkValue(currentIndex, posInChunk); + sizeColumnChunk->setValue(appendSize, posInChunk); + } + KU_ASSERT(sanityCheck()); +} + +void ListChunkData::write(const ValueVector* vector, offset_t offsetInVector, + offset_t offsetInChunk) { + checkOffsetSortedAsc = true; + auto appendSize = + vector->isNull(offsetInVector) ? 0 : vector->getValue(offsetInVector).size; + dataColumnChunk->resize(dataColumnChunk->getNumValues() + appendSize); + while (offsetInChunk >= numValues) { + appendNullList(); + } + auto isNull = vector->isNull(offsetInVector); + nullData->setNull(offsetInChunk, isNull); + if (!isNull) { + auto dataVector = ListVector::getDataVector(vector); + copyListValues(vector->getValue(offsetInVector), dataVector); + + sizeColumnChunk->setValue(appendSize, offsetInChunk); + setOffsetChunkValue(dataColumnChunk->getNumValues(), offsetInChunk); + } + KU_ASSERT(sanityCheck()); +} + +void ListChunkData::write(const ColumnChunkData* srcChunk, offset_t srcOffsetInChunk, + offset_t dstOffsetInChunk, offset_t numValuesToCopy) { + KU_ASSERT(srcChunk->getDataType().getPhysicalType() == PhysicalTypeID::LIST || + srcChunk->getDataType().getPhysicalType() == PhysicalTypeID::ARRAY); + checkOffsetSortedAsc = true; + auto& srcListChunk = srcChunk->cast(); + auto offsetInDataChunkToAppend = dataColumnChunk->getNumValues(); + for (auto i = 0u; i < numValuesToCopy; i++) { + auto appendSize = srcListChunk.getListSize(srcOffsetInChunk + i); + offsetInDataChunkToAppend += appendSize; + sizeColumnChunk->setValue(appendSize, dstOffsetInChunk + i); + setOffsetChunkValue(offsetInDataChunkToAppend, dstOffsetInChunk + i); + nullData->setNull(dstOffsetInChunk + i, + srcListChunk.nullData->isNull(srcOffsetInChunk + i)); + } + dataColumnChunk->resize(offsetInDataChunkToAppend); + for (auto i = 0u; i < numValuesToCopy; i++) { + auto startOffsetInSrcChunk = srcListChunk.getListStartOffset(srcOffsetInChunk + i); + auto appendSize = srcListChunk.getListSize(srcOffsetInChunk + i); + dataColumnChunk->append(srcListChunk.dataColumnChunk.get(), startOffsetInSrcChunk, + appendSize); + } + KU_ASSERT(sanityCheck()); +} + +void ListChunkData::copyListValues(const list_entry_t& entry, ValueVector* dataVector) 
{ + auto numListValuesToCopy = entry.size; + auto numListValuesCopied = 0u; + + SelectionVector selVector; + selVector.setToFiltered(); + while (numListValuesCopied < numListValuesToCopy) { + auto numListValuesToCopyInBatch = + std::min(numListValuesToCopy - numListValuesCopied, DEFAULT_VECTOR_CAPACITY); + selVector.setSelSize(numListValuesToCopyInBatch); + for (auto j = 0u; j < numListValuesToCopyInBatch; j++) { + selVector[j] = entry.offset + numListValuesCopied + j; + } + dataColumnChunk->append(dataVector, selVector); + numListValuesCopied += numListValuesToCopyInBatch; + } +} + +void ListChunkData::resetOffset() { + offset_t nextListOffsetReset = 0; + for (auto i = 0u; i < numValues; i++) { + auto listSize = getListSize(i); + nextListOffsetReset += uint64_t(listSize); + setOffsetChunkValue(nextListOffsetReset, i); + sizeColumnChunk->setValue(listSize, i); + } +} + +void ListChunkData::finalize() { + // rewrite the column chunk for better scanning performance + auto newColumnChunk = ColumnChunkFactory::createColumnChunkData(getMemoryManager(), + dataType.copy(), enableCompression, capacity, ResidencyState::IN_MEMORY); + uint64_t totalListLen = dataColumnChunk->getNumValues(); + uint64_t resizeThreshold = dataColumnChunk->getCapacity() / 2; + // if the list is not very long, we do not need to rewrite + if (totalListLen < resizeThreshold) { + return; + } + // if we do not trigger random write, we do not need to rewrite + if (!checkOffsetSortedAsc) { + return; + } + // if the list is in ascending order, we do not need to rewrite + if (isOffsetsConsecutiveAndSortedAscending(0, numValues)) { + return; + } + auto& newListChunk = newColumnChunk->cast(); + newListChunk.resize(numValues); + newListChunk.getDataColumnChunk()->resize(totalListLen); + auto newDataColumnChunk = newListChunk.getDataColumnChunk(); + newDataColumnChunk->resize(totalListLen); + offset_t offsetInChunk = 0; + offset_t currentIndex = 0; + for (auto i = 0u; i < numValues; i++) { + if (nullData->isNull(i)) { + newListChunk.appendNullList(); + } else { + auto startOffset = getListStartOffset(i); + auto listSize = getListSize(i); + newDataColumnChunk->append(dataColumnChunk.get(), startOffset, listSize); + offsetInChunk += listSize; + newListChunk.nullData->setNull(currentIndex, false); + newListChunk.sizeColumnChunk->setValue(listSize, currentIndex); + newListChunk.setOffsetChunkValue(offsetInChunk, currentIndex); + } + currentIndex++; + } + KU_ASSERT(newListChunk.sanityCheck()); + // Move offsets, null, data from newListChunk to this column chunk. And release indices. 
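+    // resetFromOtherChunk also clears checkOffsetSortedAsc, since the rewritten offsets are
+    // now consecutive and in ascending order.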
+ resetFromOtherChunk(&newListChunk); +} + +void ListChunkData::resetFromOtherChunk(ListChunkData* other) { + nullData = std::move(other->nullData); + sizeColumnChunk = std::move(other->sizeColumnChunk); + dataColumnChunk = std::move(other->dataColumnChunk); + offsetColumnChunk = std::move(other->offsetColumnChunk); + numValues = other->numValues; + checkOffsetSortedAsc = false; +} + +bool ListChunkData::sanityCheck() const { + KU_ASSERT(ColumnChunkData::sanityCheck()); + KU_ASSERT(sizeColumnChunk->sanityCheck()); + KU_ASSERT(offsetColumnChunk->sanityCheck()); + KU_ASSERT(getDataColumnChunk()->sanityCheck()); + return sizeColumnChunk->getNumValues() == numValues; +} + +uint64_t ListChunkData::getEstimatedMemoryUsage() const { + return ColumnChunkData::getEstimatedMemoryUsage() + sizeColumnChunk->getEstimatedMemoryUsage() + + dataColumnChunk->getEstimatedMemoryUsage() + + offsetColumnChunk->getEstimatedMemoryUsage(); +} + +void ListChunkData::serialize(Serializer& serializer) const { + ColumnChunkData::serialize(serializer); + serializer.writeDebuggingInfo("size_column_chunk"); + sizeColumnChunk->serialize(serializer); + serializer.writeDebuggingInfo("data_column_chunk"); + dataColumnChunk->serialize(serializer); + serializer.writeDebuggingInfo("offset_column_chunk"); + offsetColumnChunk->serialize(serializer); +} + +void ListChunkData::deserialize(Deserializer& deSer, ColumnChunkData& chunkData) { + std::string key; + deSer.validateDebuggingInfo(key, "size_column_chunk"); + chunkData.cast().sizeColumnChunk = + ColumnChunkData::deserialize(chunkData.getMemoryManager(), deSer); + deSer.validateDebuggingInfo(key, "data_column_chunk"); + chunkData.cast().dataColumnChunk = + ColumnChunkData::deserialize(chunkData.getMemoryManager(), deSer); + deSer.validateDebuggingInfo(key, "offset_column_chunk"); + chunkData.cast().offsetColumnChunk = + ColumnChunkData::deserialize(chunkData.getMemoryManager(), deSer); +} + +void ListChunkData::flush(PageAllocator& pageAllocator) { + ColumnChunkData::flush(pageAllocator); + sizeColumnChunk->flush(pageAllocator); + dataColumnChunk->flush(pageAllocator); + offsetColumnChunk->flush(pageAllocator); +} + +void ListChunkData::reclaimStorage(PageAllocator& pageAllocator) { + ColumnChunkData::reclaimStorage(pageAllocator); + sizeColumnChunk->reclaimStorage(pageAllocator); + dataColumnChunk->reclaimStorage(pageAllocator); + offsetColumnChunk->reclaimStorage(pageAllocator); +} +uint64_t ListChunkData::getSizeOnDisk() const { + return ColumnChunkData::getSizeOnDisk() + sizeColumnChunk->getSizeOnDisk() + + dataColumnChunk->getSizeOnDisk() + offsetColumnChunk->getSizeOnDisk(); +} +uint64_t ListChunkData::getMinimumSizeOnDisk() const { + return ColumnChunkData::getMinimumSizeOnDisk() + sizeColumnChunk->getMinimumSizeOnDisk() + + dataColumnChunk->getMinimumSizeOnDisk() + offsetColumnChunk->getMinimumSizeOnDisk(); +} + +uint64_t ListChunkData::getSizeOnDiskInMemoryStats() const { + return ColumnChunkData::getSizeOnDiskInMemoryStats() + + sizeColumnChunk->getSizeOnDiskInMemoryStats() + + dataColumnChunk->getSizeOnDiskInMemoryStats() + + offsetColumnChunk->getSizeOnDiskInMemoryStats(); +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/list_column.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/list_column.cpp new file mode 100644 index 0000000000..995cae8509 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/list_column.cpp @@ -0,0 +1,435 @@ +#include "storage/table/list_column.h" + 
+#include + +#include "common/assert.h" +#include "common/types/types.h" +#include "common/vector/value_vector.h" +#include "storage/buffer_manager/memory_manager.h" +#include "storage/storage_utils.h" +#include "storage/table/column.h" +#include "storage/table/column_chunk.h" +#include "storage/table/column_chunk_data.h" +#include "storage/table/list_chunk_data.h" +#include "storage/table/null_column.h" +#include + +using namespace lbug::common; + +namespace lbug { +namespace storage { + +offset_t ListOffsetSizeInfo::getListStartOffset(uint64_t pos) const { + if (numTotal == 0) { + return 0; + } + return pos == numTotal ? getListEndOffset(pos - 1) : getListEndOffset(pos) - getListSize(pos); +} + +offset_t ListOffsetSizeInfo::getListEndOffset(uint64_t pos) const { + if (numTotal == 0) { + return 0; + } + KU_ASSERT(pos < offsetColumnChunk->getNumValues()); + return offsetColumnChunk->getValue(pos); +} + +list_size_t ListOffsetSizeInfo::getListSize(uint64_t pos) const { + if (numTotal == 0) { + return 0; + } + KU_ASSERT(pos < sizeColumnChunk->getNumValues()); + return sizeColumnChunk->getValue(pos); +} + +bool ListOffsetSizeInfo::isOffsetSortedAscending(uint64_t startPos, uint64_t endPos) const { + offset_t prevEndOffset = getListStartOffset(startPos); + for (auto i = startPos; i < endPos; i++) { + offset_t currentEndOffset = getListEndOffset(i); + auto size = getListSize(i); + prevEndOffset += size; + if (currentEndOffset != prevEndOffset) { + return false; + } + } + return true; +} + +ListColumn::ListColumn(std::string name, LogicalType dataType, FileHandle* dataFH, + MemoryManager* mm, ShadowFile* shadowFile, bool enableCompression) + : Column{std::move(name), std::move(dataType), dataFH, mm, shadowFile, enableCompression, + true /* requireNullColumn */} { + auto offsetColName = + StorageUtils::getColumnName(this->name, StorageUtils::ColumnType::OFFSET, "offset_"); + auto sizeColName = + StorageUtils::getColumnName(this->name, StorageUtils::ColumnType::OFFSET, ""); + auto dataColName = StorageUtils::getColumnName(this->name, StorageUtils::ColumnType::DATA, ""); + sizeColumn = std::make_unique(sizeColName, LogicalType::UINT32(), dataFH, mm, + shadowFile, enableCompression, false /*requireNullColumn*/); + offsetColumn = std::make_unique(offsetColName, LogicalType::UINT64(), dataFH, mm, + shadowFile, enableCompression, false /*requireNullColumn*/); + if (disableCompressionOnData(this->dataType)) { + enableCompression = false; + } + dataColumn = ColumnFactory::createColumn(dataColName, + ListType::getChildType(this->dataType).copy(), dataFH, mm, shadowFile, enableCompression); +} + +bool ListColumn::disableCompressionOnData(const LogicalType& dataType) { + if (dataType.getLogicalTypeID() == LogicalTypeID::ARRAY && + (ListType::getChildType(dataType).getPhysicalType() == PhysicalTypeID::FLOAT || + ListType::getChildType(dataType).getPhysicalType() == PhysicalTypeID::DOUBLE)) { + // Force disable compression for floating point types. 
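+        // Note that this branch only applies to ARRAY (fixed-size list) types; variable-size
+        // LIST data keeps the caller-provided compression setting.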
+ return true; + } + return false; +} + +std::unique_ptr ListColumn::flushChunkData(const ColumnChunkData& chunk, + PageAllocator& pageAllocator) { + auto flushedChunk = flushNonNestedChunkData(chunk, pageAllocator); + auto& listChunk = chunk.cast(); + auto& flushedListChunk = flushedChunk->cast(); + flushedListChunk.setOffsetColumnChunk( + Column::flushChunkData(*listChunk.getOffsetColumnChunk(), pageAllocator)); + flushedListChunk.setSizeColumnChunk( + Column::flushChunkData(*listChunk.getSizeColumnChunk(), pageAllocator)); + flushedListChunk.setDataColumnChunk( + Column::flushChunkData(*listChunk.getDataColumnChunk(), pageAllocator)); + return flushedChunk; +} + +void ListColumn::scanSegment(const SegmentState& state, offset_t startOffsetInChunk, + row_idx_t numValuesToScan, ValueVector* resultVector, offset_t offsetInResult) const { + if (nullColumn) { + KU_ASSERT(state.nullState); + nullColumn->scanSegment(*state.nullState, startOffsetInChunk, numValuesToScan, resultVector, + offsetInResult); + } + auto listOffsetSizeInfo = getListOffsetSizeInfo(state, startOffsetInChunk, numValuesToScan); + if (!resultVector->state || resultVector->state->getSelVector().isUnfiltered()) { + scanUnfiltered(state, resultVector, numValuesToScan, listOffsetSizeInfo, offsetInResult); + } else { + scanFiltered(state, startOffsetInChunk, resultVector, listOffsetSizeInfo, offsetInResult); + } +} + +void ListColumn::scanSegment(const SegmentState& state, ColumnChunkData* resultChunk, + common::offset_t startOffsetInSegment, common::row_idx_t numValuesToScan) const { + auto startOffsetInResult = resultChunk->getNumValues(); + Column::scanSegment(state, resultChunk, startOffsetInSegment, numValuesToScan); + if (numValuesToScan == 0) { + return; + } + // Column::scanSegment above modifies the size of the offset/size chunks before we scan + // them + // Revert this so that they scan to the correct position + // FIXME(bmwinger): there should be a better solution to this, but it will probably be removed + // later anyway + auto& listColumnChunk = resultChunk->cast(); + listColumnChunk.getOffsetColumnChunk()->setNumValues(startOffsetInResult); + listColumnChunk.getSizeColumnChunk()->setNumValues(startOffsetInResult); + + offsetColumn->scanSegment(state.childrenStates[OFFSET_COLUMN_CHILD_READ_STATE_IDX], + listColumnChunk.getOffsetColumnChunk(), startOffsetInSegment, numValuesToScan); + sizeColumn->scanSegment(state.childrenStates[SIZE_COLUMN_CHILD_READ_STATE_IDX], + listColumnChunk.getSizeColumnChunk(), startOffsetInSegment, numValuesToScan); + auto resizeNumValues = listColumnChunk.getDataColumnChunk()->getNumValues(); + bool isOffsetSortedAscending = true; + KU_ASSERT(listColumnChunk.getSizeColumnChunk()->getNumValues() == + startOffsetInResult + numValuesToScan); + offset_t prevOffset = listColumnChunk.getListStartOffset(startOffsetInResult); + for (auto i = startOffsetInResult; i < startOffsetInResult + numValuesToScan; i++) { + auto currentEndOffset = listColumnChunk.getListEndOffset(i); + auto appendSize = listColumnChunk.getListSize(i); + prevOffset += appendSize; + if (currentEndOffset != prevOffset) { + isOffsetSortedAscending = false; + } + resizeNumValues += appendSize; + } + if (isOffsetSortedAscending) { + listColumnChunk.resizeDataColumnChunk(std::bit_ceil(resizeNumValues)); + offset_t startListOffset = listColumnChunk.getListStartOffset(startOffsetInResult); + offset_t endListOffset = + listColumnChunk.getListStartOffset(startOffsetInResult + numValuesToScan); + KU_ASSERT(endListOffset >= 
startListOffset); + dataColumn->scanSegment(state.childrenStates[DATA_COLUMN_CHILD_READ_STATE_IDX], + listColumnChunk.getDataColumnChunk(), startListOffset, endListOffset - startListOffset); + } else { + listColumnChunk.resizeDataColumnChunk(std::bit_ceil(resizeNumValues)); + for (auto i = startOffsetInResult; i < startOffsetInResult + numValuesToScan; i++) { + offset_t startListOffset = listColumnChunk.getListStartOffset(i); + offset_t endListOffset = listColumnChunk.getListEndOffset(i); + dataColumn->scanSegment(state.childrenStates[DATA_COLUMN_CHILD_READ_STATE_IDX], + listColumnChunk.getDataColumnChunk(), startListOffset, + endListOffset - startListOffset); + } + } + listColumnChunk.resetOffset(); + + KU_ASSERT(listColumnChunk.sanityCheck()); +} + +void ListColumn::lookupInternal(const SegmentState& state, offset_t nodeOffset, + ValueVector* resultVector, uint32_t posInVector) const { + auto [nodeGroupIdx, offsetInChunk] = StorageUtils::getNodeGroupIdxAndOffsetInChunk(nodeOffset); + const auto listEndOffset = readOffset(state, offsetInChunk); + const auto size = readSize(state, offsetInChunk); + const auto listStartOffset = listEndOffset - size; + auto dataVector = ListVector::getDataVector(resultVector); + auto currentListDataSize = ListVector::getDataVectorSize(resultVector); + ListVector::resizeDataVector(resultVector, currentListDataSize + size); + dataColumn->scanSegment(state.childrenStates[ListChunkData::DATA_COLUMN_CHILD_READ_STATE_IDX], + listStartOffset, listEndOffset - listStartOffset, dataVector, currentListDataSize); + resultVector->setValue(posInVector, list_entry_t{currentListDataSize, size}); +} + +void ListColumn::scanUnfiltered(const SegmentState& state, ValueVector* resultVector, + uint64_t numValuesToScan, const ListOffsetSizeInfo& listOffsetInfoInStorage, + offset_t offsetInResult) const { + auto dataVector = ListVector::getDataVector(resultVector); + // Scans append to the end of the vector, so we need to start at the end of the last list + auto startOffsetInDataVector = ListVector::getDataVectorSize(resultVector); + auto offsetInDataVector = startOffsetInDataVector; + + numValuesToScan = std::min(numValuesToScan, listOffsetInfoInStorage.numTotal); + for (auto i = 0u; i < numValuesToScan; i++) { + auto listLen = listOffsetInfoInStorage.getListSize(i); + resultVector->setValue(offsetInResult + i, list_entry_t{offsetInDataVector, listLen}); + offsetInDataVector += listLen; + } + ListVector::resizeDataVector(resultVector, offsetInDataVector); + const bool checkOffsetOrder = + listOffsetInfoInStorage.isOffsetSortedAscending(0, numValuesToScan); + if (checkOffsetOrder) { + auto startListOffsetInStorage = listOffsetInfoInStorage.getListStartOffset(0); + numValuesToScan = numValuesToScan == 0 ? 
0 : numValuesToScan - 1; + auto endListOffsetInStorage = listOffsetInfoInStorage.getListEndOffset(numValuesToScan); + dataColumn->scanSegment( + state.childrenStates[ListChunkData::DATA_COLUMN_CHILD_READ_STATE_IDX], + startListOffsetInStorage, endListOffsetInStorage - startListOffsetInStorage, dataVector, + static_cast(startOffsetInDataVector /* offsetInVector */)); + } else { + offsetInDataVector = startOffsetInDataVector; + for (auto i = 0u; i < numValuesToScan; i++) { + // Nulls are scanned to the resultVector first + if (!resultVector->isNull(i)) { + auto startListOffsetInStorage = listOffsetInfoInStorage.getListStartOffset(i); + auto appendSize = listOffsetInfoInStorage.getListSize(i); + dataColumn->scanSegment(state.childrenStates[DATA_COLUMN_CHILD_READ_STATE_IDX], + startListOffsetInStorage, appendSize, dataVector, offsetInDataVector); + offsetInDataVector += appendSize; + } + } + } +} + +void ListColumn::scanFiltered(const SegmentState& state, offset_t startOffsetInSegment, + ValueVector* resultVector, const ListOffsetSizeInfo& listOffsetSizeInfo, + offset_t offsetInResult) const { + auto dataVector = ListVector::getDataVector(resultVector); + auto startOffsetInDataVector = ListVector::getDataVectorSize(resultVector); + auto offsetInDataVector = startOffsetInDataVector; + + for (sel_t i = 0; i < resultVector->state->getSelVector().getSelSize(); i++) { + auto pos = resultVector->state->getSelVector()[i]; + if (startOffsetInSegment + pos - offsetInResult < state.metadata.numValues) { + // The listOffsetSizeInfo starts with the first value being scanned, so the + // startOffsetInSegment parameter is not needed here except for the bounds check + auto listSize = listOffsetSizeInfo.getListSize(pos - offsetInResult); + resultVector->setValue(pos, list_entry_t{(offset_t)offsetInDataVector, listSize}); + offsetInDataVector += listSize; + } + } + ListVector::resizeDataVector(resultVector, offsetInDataVector); + offsetInDataVector = startOffsetInDataVector; + for (auto i = 0u; i < resultVector->state->getSelVector().getSelSize(); i++) { + auto pos = resultVector->state->getSelVector()[i]; + // Nulls are scanned to the resultVector first + if (pos >= offsetInResult && + startOffsetInSegment + pos - offsetInResult < state.metadata.numValues && + !resultVector->isNull(pos)) { + auto startOffsetInStorageToScan = + listOffsetSizeInfo.getListStartOffset(pos - offsetInResult); + auto appendSize = listOffsetSizeInfo.getListSize(pos - offsetInResult); + // If there is a selection vector for the dataVector, its selected positions are not + // being updated at all for this specific segment + KU_ASSERT(!dataVector->state || dataVector->state->getSelVector().isUnfiltered()); + dataColumn->scanSegment(state.childrenStates[DATA_COLUMN_CHILD_READ_STATE_IDX], + startOffsetInStorageToScan, appendSize, dataVector, offsetInDataVector); + offsetInDataVector += resultVector->getValue(pos).size; + } + } +} + +offset_t ListColumn::readOffset(const SegmentState& state, offset_t offsetInNodeGroup) const { + offset_t ret = INVALID_OFFSET; + const auto& offsetState = state.childrenStates[OFFSET_COLUMN_CHILD_READ_STATE_IDX]; + offsetColumn->columnReadWriter->readCompressedValueToPage(offsetState, offsetInNodeGroup, + reinterpret_cast(&ret), 0, offsetColumn->readToPageFunc); + return ret; +} + +list_size_t ListColumn::readSize(const SegmentState& readState, offset_t offsetInNodeGroup) const { + const auto& sizeState = readState.childrenStates[SIZE_COLUMN_CHILD_READ_STATE_IDX]; + offset_t value = INVALID_OFFSET; + 
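+    // The compressed size is read into an offset_t-sized buffer and narrowed to list_size_t
+    // on return.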
sizeColumn->columnReadWriter->readCompressedValueToPage(sizeState, offsetInNodeGroup, + reinterpret_cast(&value), 0, sizeColumn->readToPageFunc); + return value; +} + +ListOffsetSizeInfo ListColumn::getListOffsetSizeInfo(const SegmentState& state, + offset_t startOffsetInSegment, offset_t numOffsetsToRead) const { + auto offsetColumnChunk = ColumnChunkFactory::createColumnChunkData(*mm, LogicalType::INT64(), + enableCompression, numOffsetsToRead, ResidencyState::IN_MEMORY); + auto sizeColumnChunk = ColumnChunkFactory::createColumnChunkData(*mm, LogicalType::UINT32(), + enableCompression, numOffsetsToRead, ResidencyState::IN_MEMORY); + offsetColumn->scanSegment(state.childrenStates[OFFSET_COLUMN_CHILD_READ_STATE_IDX], + offsetColumnChunk.get(), startOffsetInSegment, numOffsetsToRead); + sizeColumn->scanSegment(state.childrenStates[SIZE_COLUMN_CHILD_READ_STATE_IDX], + sizeColumnChunk.get(), startOffsetInSegment, numOffsetsToRead); + auto numValuesScan = offsetColumnChunk->getNumValues(); + return {numValuesScan, std::move(offsetColumnChunk), std::move(sizeColumnChunk)}; +} + +static void appendDataCheckpointState( + std::vector& listDataChunkCheckpointStates, ColumnChunkData& dataChunk, + offset_t inputOffset, offset_t& outputOffset, offset_t numRows) { + if (numRows > 0) { + listDataChunkCheckpointStates.push_back( + SegmentCheckpointState{dataChunk, inputOffset, outputOffset, numRows}); + outputOffset += numRows; + } +} + +static std::vector createListDataChunkCheckpointStates( + ListChunkData& persistentListChunk, std::span segmentCheckpointStates) { + const auto persistentDataChunk = persistentListChunk.getDataColumnChunk(); + row_idx_t newListDataSize = persistentDataChunk->getNumValues(); + + std::vector listDataChunkCheckpointStates; + for (const auto& segmentCheckpointState : segmentCheckpointStates) { + // We append the data for each list entry as separate segment checkpoint states + // List entries with adjacent data are commbined into a single segment checkpoint state + const auto& listChunk = segmentCheckpointState.chunkData.cast(); + offset_t currentSegmentStartOffset = INVALID_OFFSET; + offset_t currentSegmentNumRows = 0; + for (offset_t i = 0; i < segmentCheckpointState.numRows; i++) { + if (listChunk.isNull(segmentCheckpointState.startRowInData + i)) { + // Nulls will have 0 length and start at pos 0, which will work with the logic + // below, but may create more checkpoint states than necessary + continue; + } + const auto currentListStartOffset = + listChunk.getListStartOffset(segmentCheckpointState.startRowInData + i); + const auto currentListLength = + listChunk.getListSize(segmentCheckpointState.startRowInData + i); + if (currentSegmentStartOffset + currentSegmentNumRows == currentListStartOffset) { + currentSegmentNumRows += currentListLength; + } else { + appendDataCheckpointState(listDataChunkCheckpointStates, + *listChunk.getDataColumnChunk(), currentSegmentStartOffset, newListDataSize, + currentSegmentNumRows); + currentSegmentStartOffset = currentListStartOffset; + currentSegmentNumRows = currentListLength; + } + } + appendDataCheckpointState(listDataChunkCheckpointStates, *listChunk.getDataColumnChunk(), + currentSegmentStartOffset, newListDataSize, currentSegmentNumRows); + } + + return listDataChunkCheckpointStates; +} + +std::vector> ListColumn::checkpointSegment( + ColumnCheckpointState&& checkpointState, PageAllocator& pageAllocator, + bool canSplitSegment) const { + if (checkpointState.segmentCheckpointStates.empty()) { + return {}; + } + auto& 
persistentListChunk = checkpointState.persistentData.cast(); + const auto persistentDataChunk = persistentListChunk.getDataColumnChunk(); + + auto listDataChunkCheckpointStates = createListDataChunkCheckpointStates(persistentListChunk, + checkpointState.segmentCheckpointStates); + + // First, check if we can checkpoint list data chunk in place. + SegmentState chunkState; + checkpointState.persistentData.initializeScanState(chunkState, this); + ColumnCheckpointState listDataCheckpointState(*persistentDataChunk, + std::move(listDataChunkCheckpointStates)); + const auto listDataCanCheckpointInPlace = dataColumn->canCheckpointInPlace( + chunkState.childrenStates[ListChunkData::DATA_COLUMN_CHILD_READ_STATE_IDX], + listDataCheckpointState); + if (!listDataCanCheckpointInPlace) { + // If we cannot checkpoint list data chunk in place, we need to checkpoint the whole chunk + // out of place. + return checkpointColumnChunkOutOfPlace(chunkState, checkpointState, pageAllocator, + canSplitSegment); + } + + const auto persistentListDataSize = persistentDataChunk->getNumValues(); + + // In place checkpoint for list data. + dataColumn->checkpointColumnChunkInPlace( + chunkState.childrenStates[ListChunkData::DATA_COLUMN_CHILD_READ_STATE_IDX], + listDataCheckpointState, pageAllocator); + + // Checkpoint offset data. + std::vector offsetChunkCheckpointStates; + + KU_ASSERT(std::is_sorted(checkpointState.segmentCheckpointStates.begin(), + checkpointState.segmentCheckpointStates.end(), + [](const auto& a, const auto& b) { return a.startRowInData < b.startRowInData; })); + std::vector> offsetsToWrite; + uint64_t totalAppendedListSize = 0; + for (const auto& segmentCheckpointState : checkpointState.segmentCheckpointStates) { + offsetsToWrite.push_back( + ColumnChunkFactory::createColumnChunkData(*mm, LogicalType::UINT64(), false, + segmentCheckpointState.numRows, ResidencyState::IN_MEMORY)); + const auto& listChunk = segmentCheckpointState.chunkData.cast(); + for (auto i = 0u; i < segmentCheckpointState.numRows; i++) { + // When checkpointing the data chunks we append each list in the checkpoint state to the + // end of the data This loop processes the lists in the same order, so the offsets match + // the ones used by the data chunk checkpoint + totalAppendedListSize += + listChunk.getListSize(segmentCheckpointState.startRowInData + i); + offsetsToWrite.back()->setValue( + persistentListDataSize + totalAppendedListSize, i); + } + offsetChunkCheckpointStates.push_back(SegmentCheckpointState{*offsetsToWrite.back(), 0, + segmentCheckpointState.offsetInSegment, segmentCheckpointState.numRows}); + } + + // We do not allow nested splitting of offset/size segments + offsetColumn->checkpointSegment( + ColumnCheckpointState(*persistentListChunk.getOffsetColumnChunk(), + std::move(offsetChunkCheckpointStates)), + pageAllocator, false); + + // Checkpoint size data. + std::vector sizeChunkCheckpointStates; + for (const auto& segmentCheckpointState : checkpointState.segmentCheckpointStates) { + sizeChunkCheckpointStates.push_back(SegmentCheckpointState{ + *segmentCheckpointState.chunkData.cast().getSizeColumnChunk(), + segmentCheckpointState.startRowInData, segmentCheckpointState.offsetInSegment, + segmentCheckpointState.numRows}); + } + sizeColumn->checkpointSegment(ColumnCheckpointState(*persistentListChunk.getSizeColumnChunk(), + std::move(sizeChunkCheckpointStates)), + pageAllocator, false); + // Checkpoint null data. 
+ Column::checkpointNullData(checkpointState, pageAllocator); + + KU_ASSERT(persistentListChunk.getNullData()->getNumValues() == + persistentListChunk.getOffsetColumnChunk()->getNumValues() && + persistentListChunk.getNullData()->getNumValues() == + persistentListChunk.getSizeColumnChunk()->getNumValues()); + + persistentListChunk.syncNumValues(); + return {}; +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/node_group.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/node_group.cpp new file mode 100644 index 0000000000..6ac7c473dd --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/node_group.cpp @@ -0,0 +1,785 @@ +#include "storage/table/node_group.h" + +#include "common/assert.h" +#include "common/types/types.h" +#include "common/uniq_lock.h" +#include "storage/buffer_manager/memory_manager.h" +#include "storage/enums/residency_state.h" +#include "storage/storage_utils.h" +#include "storage/table/chunked_node_group.h" +#include "storage/table/column_chunk.h" +#include "storage/table/column_chunk_scanner.h" +#include "storage/table/csr_chunked_node_group.h" +#include "storage/table/csr_node_group.h" +#include "storage/table/lazy_segment_scanner.h" +#include "storage/table/node_table.h" +#include "transaction/transaction.h" + +using namespace lbug::common; +using namespace lbug::transaction; + +namespace lbug { +namespace storage { + +row_idx_t NodeGroup::append(const Transaction* transaction, + const std::vector& columnIDs, ChunkedNodeGroup& chunkedGroup, + row_idx_t startRowIdx, row_idx_t numRowsToAppend) { + KU_ASSERT(numRowsToAppend <= chunkedGroup.getNumRows()); + std::vector chunksToAppend(chunkedGroup.getNumColumns()); + for (auto i = 0u; i < chunkedGroup.getNumColumns(); i++) { + chunksToAppend[i] = &chunkedGroup.getColumnChunk(i); + } + return append(transaction, columnIDs, chunksToAppend, startRowIdx, numRowsToAppend); +} + +row_idx_t NodeGroup::append(const Transaction* transaction, + const std::vector& columnIDs, InMemChunkedNodeGroup& chunkedGroup, + row_idx_t startRowIdx, row_idx_t numRowsToAppend) { + KU_ASSERT(numRowsToAppend <= chunkedGroup.getNumRows()); + std::vector chunksToAppend(chunkedGroup.getNumColumns()); + for (auto i = 0u; i < chunkedGroup.getNumColumns(); i++) { + chunksToAppend[i] = &chunkedGroup.getColumnChunk(i); + } + return append(transaction, columnIDs, chunksToAppend, startRowIdx, numRowsToAppend); +} + +row_idx_t NodeGroup::append(const Transaction* transaction, + const std::vector& columnIDs, std::span chunkedGroup, + row_idx_t startRowIdx, row_idx_t numRowsToAppend) { + const auto lock = chunkedGroups.lock(); + const auto numRowsBeforeAppend = getNumRows(); + if (chunkedGroups.isEmpty(lock)) { + chunkedGroups.appendGroup(lock, + std::make_unique(mm, dataTypes, enableCompression, + StorageConfig::CHUNKED_NODE_GROUP_CAPACITY, 0, ResidencyState::IN_MEMORY)); + } + row_idx_t numRowsAppended = 0u; + while (numRowsAppended < numRowsToAppend) { + auto lastChunkedGroup = chunkedGroups.getLastGroup(lock); + if (!lastChunkedGroup || lastChunkedGroup->isFullOrOnDisk()) { + chunkedGroups.appendGroup(lock, + std::make_unique(mm, dataTypes, enableCompression, + StorageConfig::CHUNKED_NODE_GROUP_CAPACITY, + numRowsBeforeAppend + numRowsAppended, ResidencyState::IN_MEMORY)); + } + lastChunkedGroup = chunkedGroups.getLastGroup(lock); + KU_ASSERT(StorageConfig::CHUNKED_NODE_GROUP_CAPACITY >= lastChunkedGroup->getNumRows()); + auto numToCopyIntoChunk = + 
StorageConfig::CHUNKED_NODE_GROUP_CAPACITY - lastChunkedGroup->getNumRows(); + const auto numToAppendInChunk = + std::min(numRowsToAppend - numRowsAppended, numToCopyIntoChunk); + lastChunkedGroup->append(transaction, columnIDs, chunkedGroup, + numRowsAppended + startRowIdx, numToAppendInChunk); + numRowsAppended += numToAppendInChunk; + } + numRows += numRowsAppended; + return numRowsBeforeAppend; +} + +row_idx_t NodeGroup::append(const Transaction* transaction, + const std::vector& columnIDs, std::span chunkedGroup, + row_idx_t startRowIdx, row_idx_t numRowsToAppend) { + const auto lock = chunkedGroups.lock(); + const auto numRowsBeforeAppend = getNumRows(); + if (chunkedGroups.isEmpty(lock)) { + chunkedGroups.appendGroup(lock, + std::make_unique(mm, dataTypes, enableCompression, + StorageConfig::CHUNKED_NODE_GROUP_CAPACITY, 0, ResidencyState::IN_MEMORY)); + } + row_idx_t numRowsAppended = 0u; + while (numRowsAppended < numRowsToAppend) { + auto lastChunkedGroup = chunkedGroups.getLastGroup(lock); + if (!lastChunkedGroup || lastChunkedGroup->isFullOrOnDisk()) { + chunkedGroups.appendGroup(lock, + std::make_unique(mm, dataTypes, enableCompression, + StorageConfig::CHUNKED_NODE_GROUP_CAPACITY, + numRowsBeforeAppend + numRowsAppended, ResidencyState::IN_MEMORY)); + } + lastChunkedGroup = chunkedGroups.getLastGroup(lock); + KU_ASSERT(StorageConfig::CHUNKED_NODE_GROUP_CAPACITY >= lastChunkedGroup->getNumRows()); + auto numToCopyIntoChunk = + StorageConfig::CHUNKED_NODE_GROUP_CAPACITY - lastChunkedGroup->getNumRows(); + const auto numToAppendInChunk = + std::min(numRowsToAppend - numRowsAppended, numToCopyIntoChunk); + lastChunkedGroup->append(transaction, columnIDs, chunkedGroup, + numRowsAppended + startRowIdx, numToAppendInChunk); + numRowsAppended += numToAppendInChunk; + } + numRows += numRowsAppended; + return numRowsBeforeAppend; +} + +void NodeGroup::append(const Transaction* transaction, const std::vector& vectors, + const row_idx_t startRowIdx, const row_idx_t numRowsToAppend) { + const auto lock = chunkedGroups.lock(); + const auto numRowsBeforeAppend = getNumRows(); + if (chunkedGroups.isEmpty(lock)) { + chunkedGroups.appendGroup(lock, + std::make_unique(mm, dataTypes, enableCompression, + StorageConfig::CHUNKED_NODE_GROUP_CAPACITY, 0 /*startOffset*/, + ResidencyState::IN_MEMORY)); + } + row_idx_t numRowsAppended = 0; + while (numRowsAppended < numRowsToAppend) { + auto lastChunkedGroup = chunkedGroups.getLastGroup(lock); + if (!lastChunkedGroup || lastChunkedGroup->isFullOrOnDisk()) { + chunkedGroups.appendGroup(lock, + std::make_unique(mm, dataTypes, enableCompression, + StorageConfig::CHUNKED_NODE_GROUP_CAPACITY, + numRowsBeforeAppend + numRowsAppended, ResidencyState::IN_MEMORY)); + } + lastChunkedGroup = chunkedGroups.getLastGroup(lock); + const auto numRowsToAppendInGroup = std::min(numRowsToAppend - numRowsAppended, + StorageConfig::CHUNKED_NODE_GROUP_CAPACITY - lastChunkedGroup->getNumRows()); + lastChunkedGroup->append(transaction, vectors, startRowIdx + numRowsAppended, + numRowsToAppendInGroup); + numRowsAppended += numRowsToAppendInGroup; + } + numRows += numRowsAppended; +} + +void NodeGroup::merge(Transaction*, std::unique_ptr chunkedGroup) { + KU_ASSERT(chunkedGroup->getNumColumns() == dataTypes.size()); + for (auto i = 0u; i < chunkedGroup->getNumColumns(); i++) { + KU_ASSERT(chunkedGroup->getColumnChunk(i).getDataType().getPhysicalType() == + dataTypes[i].getPhysicalType()); + } + const auto lock = chunkedGroups.lock(); + numRows += chunkedGroup->getNumRows(); + 
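+    // Take ownership of the incoming chunked group as-is; its rows were already folded into
+    // this node group's row count above.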
chunkedGroups.appendGroup(lock, std::move(chunkedGroup)); +} + +void NodeGroup::initializeScanState(const Transaction* transaction, TableScanState& state) const { + const auto lock = chunkedGroups.lock(); + initializeScanState(transaction, lock, state); +} + +static void initializeScanStateForChunkedGroup(const TableScanState& state, + const ChunkedNodeGroup* chunkedGroup) { + KU_ASSERT(chunkedGroup); + if (chunkedGroup->getResidencyState() != ResidencyState::ON_DISK) { + return; + } + auto& nodeGroupScanState = *state.nodeGroupScanState; + for (auto i = 0u; i < state.columnIDs.size(); i++) { + KU_ASSERT(i < state.columnIDs.size()); + KU_ASSERT(i < nodeGroupScanState.chunkStates.size()); + const auto columnID = state.columnIDs[i]; + if (columnID == INVALID_COLUMN_ID || columnID == ROW_IDX_COLUMN_ID) { + continue; + } + auto& chunk = chunkedGroup->getColumnChunk(columnID); + auto& chunkState = nodeGroupScanState.chunkStates[i]; + chunk.initializeScanState(chunkState, state.columns[i]); + } +} + +void NodeGroup::initializeScanState(const Transaction*, const UniqLock& lock, + TableScanState& state) const { + auto& nodeGroupScanState = *state.nodeGroupScanState; + nodeGroupScanState.chunkedGroupIdx = 0; + ChunkedNodeGroup* firstChunkedGroup = chunkedGroups.getFirstGroup(lock); + nodeGroupScanState.nextRowToScan = firstChunkedGroup->getStartRowIdx(); + initializeScanStateForChunkedGroup(state, firstChunkedGroup); +} + +void applySemiMaskFilter(const TableScanState& state, row_idx_t numRowsToScan, + SelectionVector& selVector) { + auto& nodeGroupScanState = *state.nodeGroupScanState; + const auto startNodeOffset = nodeGroupScanState.nextRowToScan + + StorageUtils::getStartOffsetOfNodeGroup(state.nodeGroupIdx); + const auto endNodeOffset = startNodeOffset + numRowsToScan; + const auto& arr = state.semiMask->range(startNodeOffset, endNodeOffset); + if (arr.empty()) { + selVector.setSelSize(0); + } else { + auto stat = selVector.getMutableBuffer(); + uint64_t numSelectedValues = 0; + size_t i = 0, j = 0; + while (i < numRowsToScan && j < arr.size()) { + auto temp = arr[j] - startNodeOffset; + if (selVector[i] < temp) { + ++i; + } else if (selVector[i] > temp) { + ++j; + } else { + stat[numSelectedValues++] = temp; + ++i; + ++j; + } + } + selVector.setToFiltered(numSelectedValues); + } +} + +NodeGroupScanResult NodeGroup::scan(const Transaction* transaction, TableScanState& state) const { + // TODO(Guodong): Move the locked part of figuring out the chunked group to initScan. 
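+    // Advance to the chunked group that contains nextRowToScan, apply the semi-mask filter if
+    // enabled, and scan at most one vector's worth of rows from that group.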
+ const auto lock = chunkedGroups.lock(); + auto& nodeGroupScanState = *state.nodeGroupScanState; + KU_ASSERT(nodeGroupScanState.chunkedGroupIdx < chunkedGroups.getNumGroups(lock)); + const auto chunkedGroup = chunkedGroups.getGroup(lock, nodeGroupScanState.chunkedGroupIdx); + if (nodeGroupScanState.nextRowToScan >= + chunkedGroup->getNumRows() + chunkedGroup->getStartRowIdx()) { + nodeGroupScanState.chunkedGroupIdx++; + if (nodeGroupScanState.chunkedGroupIdx >= chunkedGroups.getNumGroups(lock)) { + return NODE_GROUP_SCAN_EMPTY_RESULT; + } + ChunkedNodeGroup* currentChunkedGroup = + chunkedGroups.getGroup(lock, nodeGroupScanState.chunkedGroupIdx); + initializeScanStateForChunkedGroup(state, currentChunkedGroup); + } + const auto& chunkedGroupToScan = + *chunkedGroups.getGroup(lock, nodeGroupScanState.chunkedGroupIdx); + KU_ASSERT(nodeGroupScanState.nextRowToScan >= chunkedGroupToScan.getStartRowIdx()); + const auto rowIdxInChunkToScan = + nodeGroupScanState.nextRowToScan - chunkedGroupToScan.getStartRowIdx(); + const auto numRowsToScan = + std::min(chunkedGroupToScan.getNumRows() - rowIdxInChunkToScan, DEFAULT_VECTOR_CAPACITY); + bool enableSemiMask = + state.source == TableScanSource::COMMITTED && state.semiMask && state.semiMask->isEnabled(); + if (enableSemiMask) { + applySemiMaskFilter(state, numRowsToScan, state.outState->getSelVectorUnsafe()); + if (state.outState->getSelVector().getSelSize() == 0) { + state.nodeGroupScanState->nextRowToScan += numRowsToScan; + return NodeGroupScanResult{nodeGroupScanState.nextRowToScan, 0}; + } + } + chunkedGroupToScan.scan(transaction, state, nodeGroupScanState, rowIdxInChunkToScan, + numRowsToScan); + const auto startRow = nodeGroupScanState.nextRowToScan; + nodeGroupScanState.nextRowToScan += numRowsToScan; + return NodeGroupScanResult{startRow, numRowsToScan}; +} + +NodeGroupScanResult NodeGroup::scan(Transaction* transaction, TableScanState& state, + offset_t startOffsetInGroup, offset_t numRowsToScan) const { + bool enableSemiMask = + state.source == TableScanSource::COMMITTED && state.semiMask && state.semiMask->isEnabled(); + if (enableSemiMask) { + applySemiMaskFilter(state, numRowsToScan, state.outState->getSelVectorUnsafe()); + if (state.outState->getSelVector().getSelSize() == 0) { + state.nodeGroupScanState->nextRowToScan += numRowsToScan; + return NodeGroupScanResult{state.nodeGroupScanState->nextRowToScan, 0}; + } + } + if (state.outputVectors.size() == 0) { + KU_ASSERT(scanInternal(chunkedGroups.lock(), transaction, state, startOffsetInGroup, + numRowsToScan) == NodeGroupScanResult(startOffsetInGroup, numRowsToScan)); + return NodeGroupScanResult{startOffsetInGroup, numRowsToScan}; + } + return scanInternal(chunkedGroups.lock(), transaction, state, startOffsetInGroup, + numRowsToScan); +} + +NodeGroupScanResult NodeGroup::scanInternal(const UniqLock& lock, Transaction* transaction, + TableScanState& state, offset_t startOffsetInGroup, offset_t numRowsToScan) const { + // Only meant for scanning once + KU_ASSERT(numRowsToScan <= DEFAULT_VECTOR_CAPACITY); + + auto startRowIdxInGroup = getStartRowIdxInGroupNoLock(); + if (startOffsetInGroup < startRowIdxInGroup) { + numRowsToScan = std::min(numRowsToScan, startRowIdxInGroup - startOffsetInGroup); + // If the scan starts before the first row in the group, skip the deleted part and return. 
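+        // No rows are materialized for this range; the returned result only reports how many
+        // rows the caller can skip.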
+ return NodeGroupScanResult{startOffsetInGroup, numRowsToScan}; + } + + auto& nodeGroupScanState = *state.nodeGroupScanState; + nodeGroupScanState.nextRowToScan = startOffsetInGroup; + + auto [newChunkedGroupIdx, _] = findChunkedGroupIdxFromRowIdxNoLock(startOffsetInGroup); + KU_ASSERT(newChunkedGroupIdx != INVALID_CHUNKED_GROUP_IDX); + + const auto* chunkedGroupToScan = chunkedGroups.getGroup(lock, newChunkedGroupIdx); + if (newChunkedGroupIdx != nodeGroupScanState.chunkedGroupIdx) { + // If the chunked group matches the scan state, don't re-initialize it. + // E.g., we may scan a group multiple times in parts + initializeScanStateForChunkedGroup(state, chunkedGroupToScan); + nodeGroupScanState.chunkedGroupIdx = newChunkedGroupIdx; + } + + uint64_t numRowsScanned = 0; + const auto rowIdxInChunkToScan = + (startOffsetInGroup + numRowsScanned) - chunkedGroupToScan->getStartRowIdx(); + uint64_t numRowsToScanInChunk = std::min(numRowsToScan - numRowsScanned, + chunkedGroupToScan->getNumRows() - rowIdxInChunkToScan); + KU_ASSERT(startOffsetInGroup + numRowsToScanInChunk <= numRows); + chunkedGroupToScan->scan(transaction, state, nodeGroupScanState, rowIdxInChunkToScan, + numRowsToScanInChunk); + numRowsScanned += numRowsToScanInChunk; + nodeGroupScanState.nextRowToScan += numRowsToScanInChunk; + + return NodeGroupScanResult{startOffsetInGroup, numRowsScanned}; +} + +bool NodeGroup::lookupNoLock(const Transaction* transaction, const TableScanState& state, + sel_t posInSel) const { + auto& nodeGroupScanState = *state.nodeGroupScanState; + const auto pos = state.rowIdxVector->state->getSelVector().getSelectedPositions()[posInSel]; + KU_ASSERT(!state.rowIdxVector->isNull(pos)); + const auto rowIdx = state.rowIdxVector->getValue(pos); + const ChunkedNodeGroup* chunkedGroupToScan = findChunkedGroupFromRowIdxNoLock(rowIdx); + KU_ASSERT(chunkedGroupToScan); + const auto rowIdxInChunkedGroup = rowIdx - chunkedGroupToScan->getStartRowIdx(); + return chunkedGroupToScan->lookup(transaction, state, nodeGroupScanState, rowIdxInChunkedGroup, + posInSel); +} + +bool NodeGroup::lookupMultiple(const UniqLock& lock, const Transaction* transaction, + const TableScanState& state) const { + idx_t numTuplesFound = 0; + for (auto i = 0u; i < state.rowIdxVector->state->getSelVector().getSelSize(); i++) { + auto& nodeGroupScanState = *state.nodeGroupScanState; + const auto pos = state.rowIdxVector->state->getSelVector().getSelectedPositions()[i]; + KU_ASSERT(!state.rowIdxVector->isNull(pos)); + const auto rowIdx = state.rowIdxVector->getValue(pos); + const ChunkedNodeGroup* chunkedGroupToScan = findChunkedGroupFromRowIdx(lock, rowIdx); + KU_ASSERT(chunkedGroupToScan); + const auto rowIdxInChunkedGroup = rowIdx - chunkedGroupToScan->getStartRowIdx(); + numTuplesFound += chunkedGroupToScan->lookup(transaction, state, nodeGroupScanState, + rowIdxInChunkedGroup, i); + } + return numTuplesFound == state.rowIdxVector->state->getSelVector().getSelSize(); +} + +bool NodeGroup::lookup(const Transaction* transaction, const TableScanState& state, + sel_t posInSel) const { + const auto lock = chunkedGroups.lock(); + return lookupNoLock(transaction, state, posInSel); +} + +bool NodeGroup::lookupMultiple(const Transaction* transaction, const TableScanState& state) const { + const auto lock = chunkedGroups.lock(); + return lookupMultiple(lock, transaction, state); +} + +// NOLINTNEXTLINE(readability-make-member-function-const): Semantically non-const. 
+void NodeGroup::update(const Transaction* transaction, row_idx_t rowIdxInGroup, + column_id_t columnID, const ValueVector& propertyVector) { + KU_ASSERT(propertyVector.state->getSelVector().getSelSize() == 1); + ChunkedNodeGroup* chunkedGroupToUpdate = nullptr; + { + const auto lock = chunkedGroups.lock(); + chunkedGroupToUpdate = findChunkedGroupFromRowIdx(lock, rowIdxInGroup); + } + KU_ASSERT(chunkedGroupToUpdate); + const auto rowIdxInChunkedGroup = rowIdxInGroup - chunkedGroupToUpdate->getStartRowIdx(); + chunkedGroupToUpdate->update(transaction, rowIdxInChunkedGroup, columnID, propertyVector); +} + +// NOLINTNEXTLINE(readability-make-member-function-const): Semantically non-const. +bool NodeGroup::delete_(const Transaction* transaction, row_idx_t rowIdxInGroup) { + ChunkedNodeGroup* groupToDelete = nullptr; + { + const auto lock = chunkedGroups.lock(); + groupToDelete = findChunkedGroupFromRowIdx(lock, rowIdxInGroup); + } + KU_ASSERT(groupToDelete); + const auto rowIdxInChunkedGroup = rowIdxInGroup - groupToDelete->getStartRowIdx(); + return groupToDelete->delete_(transaction, rowIdxInChunkedGroup); +} + +bool NodeGroup::hasDeletions(const Transaction* transaction) const { + const auto lock = chunkedGroups.lock(); + for (auto i = 0u; i < chunkedGroups.getNumGroups(lock); i++) { + const auto chunkedGroup = chunkedGroups.getGroup(lock, i); + if (chunkedGroup->hasDeletions(transaction)) { + return true; + } + } + return false; +} + +void NodeGroup::addColumn(TableAddColumnState& addColumnState, PageAllocator* pageAllocator, + ColumnStats* newColumnStats) { + dataTypes.push_back(addColumnState.propertyDefinition.getType().copy()); + const auto lock = chunkedGroups.lock(); + for (auto& chunkedGroup : chunkedGroups.getAllGroups(lock)) { + chunkedGroup->addColumn(mm, addColumnState, enableCompression, pageAllocator, + newColumnStats); + } +} + +void NodeGroup::rollbackInsert(row_idx_t startRow) { + const auto lock = chunkedGroups.lock(); + const auto numEmptyTrailingGroups = chunkedGroups.getNumEmptyTrailingGroups(lock); + chunkedGroups.removeTrailingGroups(lock, numEmptyTrailingGroups); + numRows = startRow; +} + +void NodeGroup::reclaimStorage(PageAllocator& pageAllocator) const { + reclaimStorage(pageAllocator, chunkedGroups.lock()); +} + +void NodeGroup::reclaimStorage(PageAllocator& pageAllocator, const UniqLock& lock) const { + for (auto& chunkedGroup : chunkedGroups.getAllGroups(lock)) { + chunkedGroup->reclaimStorage(pageAllocator); + } +} + +void NodeGroup::checkpoint(MemoryManager& memoryManager, NodeGroupCheckpointState& state) { + const auto lock = chunkedGroups.lock(); + KU_ASSERT(chunkedGroups.getNumGroups(lock) >= 1); + const auto firstGroup = chunkedGroups.getFirstGroup(lock); + const auto hasPersistentData = firstGroup->getResidencyState() == ResidencyState::ON_DISK; + // Re-populate version info here first. 
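+    // If all remaining rows turn out to be deleted, the group is flushed as empty; otherwise
+    // the in-memory data is checkpointed together with (or without) the existing on-disk
+    // chunked group.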
+ auto checkpointedVersionInfo = checkpointVersionInfo(lock, &DUMMY_CHECKPOINT_TRANSACTION); + std::unique_ptr checkpointedChunkedGroup; + if (checkpointedVersionInfo->getNumDeletions(&DUMMY_CHECKPOINT_TRANSACTION, 0, numRows) == + numRows - firstGroup->getStartRowIdx()) { + reclaimStorage(state.pageAllocator, lock); + checkpointedChunkedGroup = + ChunkedNodeGroup::flushEmpty(memoryManager, dataTypes, enableCompression, + StorageConfig::CHUNKED_NODE_GROUP_CAPACITY, numRows, state.pageAllocator); + } else { + if (hasPersistentData) { + checkpointedChunkedGroup = checkpointInMemAndOnDisk(memoryManager, lock, state); + } else { + checkpointedChunkedGroup = checkpointInMemOnly(memoryManager, lock, state); + } + checkpointedChunkedGroup->setVersionInfo(std::move(checkpointedVersionInfo)); + } + chunkedGroups.clear(lock); + chunkedGroups.appendGroup(lock, std::move(checkpointedChunkedGroup)); + checkpointDataTypesNoLock(state); +} + +void NodeGroup::checkpointDataTypesNoLock(const NodeGroupCheckpointState& state) { + std::vector checkpointedTypes; + for (auto i = 0u; i < state.columnIDs.size(); i++) { + auto columnID = state.columnIDs[i]; + KU_ASSERT(columnID < dataTypes.size()); + checkpointedTypes.push_back(dataTypes[columnID].copy()); + } + dataTypes = std::move(checkpointedTypes); +} + +void NodeGroup::scanCommittedUpdatesForColumn( + std::vector& chunkCheckpointStates, MemoryManager& memoryManager, + const UniqLock& lock, column_id_t columnID, const Column* column) const { + auto updateSegmentScanner = + LazySegmentScanner(memoryManager, column->getDataType().copy(), enableCompression); + ChunkState chunkState; + auto& firstColumnChunk = chunkedGroups.getFirstGroup(lock)->getColumnChunk(columnID); + const auto numPersistentRows = firstColumnChunk.getNumValues(); + firstColumnChunk.initializeScanState(chunkState, column); + for (auto& chunkedGroup : chunkedGroups.getAllGroups(lock)) { + chunkedGroup->getColumnChunk(columnID).scanCommitted( + &DUMMY_CHECKPOINT_TRANSACTION, chunkState, updateSegmentScanner); + } + KU_ASSERT(updateSegmentScanner.getNumValues() == numPersistentRows); + updateSegmentScanner.rangeSegments(updateSegmentScanner.begin(), numPersistentRows, + [&chunkCheckpointStates](auto& segment, auto, auto segmentLength, auto offsetInChunk) { + if (segment.segmentData) { + chunkCheckpointStates.emplace_back(std::move(segment.segmentData), offsetInChunk, + segmentLength); + } + }); +} + +std::unique_ptr NodeGroup::checkpointInMemAndOnDisk(MemoryManager& memoryManager, + const UniqLock& lock, NodeGroupCheckpointState& state) const { + const auto firstGroup = chunkedGroups.getFirstGroup(lock); + const auto numPersistentRows = firstGroup->getNumRows(); + std::vector columnPtrs; + columnPtrs.reserve(state.columns.size()); + for (auto* column : state.columns) { + columnPtrs.push_back(column); + } + const auto insertChunkedGroup = scanAllInsertedAndVersions( + memoryManager, lock, state.columnIDs, columnPtrs); + const auto numInsertedRows = insertChunkedGroup->getNumRows(); + for (auto i = 0u; i < state.columnIDs.size(); i++) { + const auto columnID = state.columnIDs[i]; + // if has persistent data, scan updates from persistent chunked group; + KU_ASSERT(firstGroup && firstGroup->getResidencyState() == ResidencyState::ON_DISK); + const auto columnHasUpdates = firstGroup->hasAnyUpdates(&DUMMY_CHECKPOINT_TRANSACTION, + columnID, 0, firstGroup->getNumRows()); + if (numInsertedRows == 0 && !columnHasUpdates) { + continue; + } + std::vector chunkCheckpointStates; + if (columnHasUpdates) { + 
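+            // Collect the committed in-place updates of the persistent chunk so they are
+            // checkpointed together with the newly inserted rows appended below.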
scanCommittedUpdatesForColumn(chunkCheckpointStates, memoryManager, lock, columnID, + state.columns[columnID]); + } + if (numInsertedRows > 0) { + chunkCheckpointStates.emplace_back(insertChunkedGroup->moveColumnChunk(columnID), + numPersistentRows, numInsertedRows); + } + firstGroup->getColumnChunk(columnID).checkpoint(*state.columns[i], + std::move(chunkCheckpointStates), state.pageAllocator); + } + auto checkpointedChunkedGroup = + std::make_unique(*chunkedGroups.getGroup(lock, 0), state.columnIDs); + KU_ASSERT(checkpointedChunkedGroup->getResidencyState() == ResidencyState::ON_DISK); + checkpointedChunkedGroup->resetNumRowsFromChunks(); + checkpointedChunkedGroup->resetVersionAndUpdateInfo(); + // The first chunked group is the only persistent one + // The checkpointed columns have been moved to the checkpointedChunkedGroup, the + // remaining must have been dropped + firstGroup->reclaimStorage(state.pageAllocator); + return checkpointedChunkedGroup; +} + +std::unique_ptr NodeGroup::checkpointInMemOnly(MemoryManager& memoryManager, + const UniqLock& lock, const NodeGroupCheckpointState& state) const { + // Flush insertChunkedGroup to persistent one. + std::vector columnPtrs; + columnPtrs.reserve(state.columns.size()); + for (auto& column : state.columns) { + columnPtrs.push_back(column); + } + auto insertChunkedGroup = scanAllInsertedAndVersions(memoryManager, + lock, state.columnIDs, columnPtrs); + return insertChunkedGroup->flush(&DUMMY_CHECKPOINT_TRANSACTION, state.pageAllocator); +} + +std::unique_ptr NodeGroup::checkpointVersionInfo(const UniqLock& lock, + const Transaction* transaction) const { + auto checkpointVersionInfo = std::make_unique(); + row_idx_t currRow = 0; + for (auto& chunkedGroup : chunkedGroups.getAllGroups(lock)) { + if (chunkedGroup->hasVersionInfo()) { + // TODO(Guodong): Optimize the for loop here to directly acess the version info. + for (auto i = 0u; i < chunkedGroup->getNumRows(); i++) { + if (chunkedGroup->isDeleted(transaction, i)) { + checkpointVersionInfo->delete_(transaction->getID(), currRow + i); + } + } + } + currRow += chunkedGroup->getNumRows(); + } + return checkpointVersionInfo; +} + +uint64_t NodeGroup::getEstimatedMemoryUsage() const { + uint64_t memUsage = 0; + const auto lock = chunkedGroups.lock(); + for (const auto& chunkedGroup : chunkedGroups.getAllGroups(lock)) { + memUsage += chunkedGroup->getEstimatedMemoryUsage(); + } + return memUsage; +} + +void NodeGroup::serialize(Serializer& serializer) { + // Serialize checkpointed chunks. 
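+    // At this point the node group is expected to hold exactly one chunked group (asserted
+    // below), and its data is only written out if it has been checkpointed to disk.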
+ serializer.writeDebuggingInfo("node_group_idx"); + serializer.write(nodeGroupIdx); + serializer.writeDebuggingInfo("enable_compression"); + serializer.write(enableCompression); + serializer.writeDebuggingInfo("format"); + serializer.write(format); + const auto lock = chunkedGroups.lock(); + KU_ASSERT(chunkedGroups.getNumGroups(lock) == 1); + const auto chunkedGroup = chunkedGroups.getFirstGroup(lock); + serializer.writeDebuggingInfo("has_checkpointed_data"); + serializer.write(chunkedGroup->getResidencyState() == ResidencyState::ON_DISK); + if (chunkedGroup->getResidencyState() == ResidencyState::ON_DISK) { + serializer.writeDebuggingInfo("checkpointed_data"); + chunkedGroup->serialize(serializer); + } +} + +std::unique_ptr NodeGroup::deserialize(MemoryManager& mm, Deserializer& deSer, + const std::vector& columnTypes) { + std::string key; + node_group_idx_t nodeGroupIdx = INVALID_NODE_GROUP_IDX; + bool enableCompression = false; + auto format = NodeGroupDataFormat::REGULAR; + bool hasCheckpointedData = false; + deSer.validateDebuggingInfo(key, "node_group_idx"); + deSer.deserializeValue(nodeGroupIdx); + deSer.validateDebuggingInfo(key, "enable_compression"); + deSer.deserializeValue(enableCompression); + deSer.validateDebuggingInfo(key, "format"); + deSer.deserializeValue(format); + deSer.validateDebuggingInfo(key, "has_checkpointed_data"); + deSer.deserializeValue(hasCheckpointedData); + if (hasCheckpointedData) { + deSer.validateDebuggingInfo(key, "checkpointed_data"); + } + std::unique_ptr chunkedNodeGroup; + switch (format) { + case NodeGroupDataFormat::REGULAR: { + if (hasCheckpointedData) { + chunkedNodeGroup = ChunkedNodeGroup::deserialize(mm, deSer); + } else { + chunkedNodeGroup = std::make_unique(mm, columnTypes, + enableCompression, 0, 0, ResidencyState::IN_MEMORY); + } + return std::make_unique(mm, nodeGroupIdx, enableCompression, + std::move(chunkedNodeGroup)); + } + case NodeGroupDataFormat::CSR: { + if (hasCheckpointedData) { + chunkedNodeGroup = ChunkedCSRNodeGroup::deserialize(mm, deSer); + return std::make_unique(mm, nodeGroupIdx, enableCompression, + std::move(chunkedNodeGroup)); + } else { + return std::make_unique(mm, nodeGroupIdx, enableCompression, + copyVector(columnTypes)); + } + } + default: { + KU_UNREACHABLE; + } + } +} + +std::pair NodeGroup::findChunkedGroupIdxFromRowIdxNoLock(row_idx_t rowIdx) const { + if (chunkedGroups.getNumGroupsNoLock() == 0 || rowIdx < getStartRowIdxInGroupNoLock()) { + return {INVALID_CHUNKED_GROUP_IDX, INVALID_START_ROW_IDX}; + } + rowIdx -= getStartRowIdxInGroupNoLock(); + const auto numRowsInFirstGroup = chunkedGroups.getFirstGroupNoLock()->getNumRows(); + if (rowIdx < numRowsInFirstGroup) { + return {0, rowIdx}; + } + rowIdx -= numRowsInFirstGroup; + const auto chunkedGroupIdx = rowIdx / StorageConfig::CHUNKED_NODE_GROUP_CAPACITY + 1; + const auto rowIdxInChunk = rowIdx % StorageConfig::CHUNKED_NODE_GROUP_CAPACITY; + if (chunkedGroupIdx >= chunkedGroups.getNumGroupsNoLock()) { + return {INVALID_CHUNKED_GROUP_IDX, INVALID_START_ROW_IDX}; + } + return {chunkedGroupIdx, rowIdxInChunk}; +} + +ChunkedNodeGroup* NodeGroup::findChunkedGroupFromRowIdx(const UniqLock& lock, + row_idx_t rowIdx) const { + const auto [chunkedGroupIdx, rowIdxInChunkedGroup] = + findChunkedGroupIdxFromRowIdxNoLock(rowIdx); + if (chunkedGroupIdx == INVALID_CHUNKED_GROUP_IDX) { + return nullptr; + } + return chunkedGroups.getGroup(lock, chunkedGroupIdx); +} + +ChunkedNodeGroup* NodeGroup::findChunkedGroupFromRowIdxNoLock(row_idx_t rowIdx) const { + const auto 
[chunkedGroupIdx, rowIdxInChunkedGroup] = + findChunkedGroupIdxFromRowIdxNoLock(rowIdx); + if (chunkedGroupIdx == INVALID_CHUNKED_GROUP_IDX) { + return nullptr; + } + return chunkedGroups.getGroupNoLock(chunkedGroupIdx); +} + +template +row_idx_t NodeGroup::getNumResidentRows(const UniqLock& lock) const { + row_idx_t numResidentRows = 0u; + for (auto& chunkedGroup : chunkedGroups.getAllGroups(lock)) { + if (chunkedGroup->getResidencyState() == RESIDENCY_STATE) { + numResidentRows += chunkedGroup->getNumRows(); + } + } + return numResidentRows; +} + +template +std::unique_ptr NodeGroup::scanAllInsertedAndVersions( + MemoryManager& memoryManager, const UniqLock& lock, const std::vector& columnIDs, + const std::vector& columns) const { + auto numResidentRows = getNumResidentRows(lock); + std::vector columnTypes; + for (const auto* column : columns) { + columnTypes.push_back(column->getDataType().copy()); + } + auto mergedInMemGroup = std::make_unique(memoryManager, columnTypes, + enableCompression, numResidentRows, chunkedGroups.getFirstGroup(lock)->getStartRowIdx()); + auto scanState = std::make_unique(columnIDs, columns); + scanState->nodeGroupScanState = std::make_unique(columnIDs.size()); + initializeScanState(&DUMMY_CHECKPOINT_TRANSACTION, lock, *scanState); + for (auto& chunkedGroup : chunkedGroups.getAllGroups(lock)) { + chunkedGroup->scanCommitted(&DUMMY_CHECKPOINT_TRANSACTION, *scanState, + *mergedInMemGroup); + } + for (auto i = 0u; i < columnIDs.size(); i++) { + if (columnIDs[i] != 0) { + KU_ASSERT(numResidentRows == mergedInMemGroup->getColumnChunk(i).getNumValues()); + } + } + mergedInMemGroup->setNumRows(numResidentRows); + return mergedInMemGroup; +} + +template std::unique_ptr +NodeGroup::scanAllInsertedAndVersions(MemoryManager& memoryManager, + const UniqLock& lock, const std::vector& columnIDs, + const std::vector& columns) const; +template std::unique_ptr +NodeGroup::scanAllInsertedAndVersions(MemoryManager& memoryManager, + const UniqLock& lock, const std::vector& columnIDs, + const std::vector& columns) const; + +bool NodeGroup::isVisible(const Transaction* transaction, row_idx_t rowIdxInGroup) const { + ChunkedNodeGroup* chunkedGroup = nullptr; + { + const auto lock = chunkedGroups.lock(); + chunkedGroup = findChunkedGroupFromRowIdx(lock, rowIdxInGroup); + } + if (!chunkedGroup) { + return false; + } + const auto rowIdxInChunkedGroup = rowIdxInGroup - chunkedGroup->getStartRowIdx(); + return !chunkedGroup->isDeleted(transaction, rowIdxInChunkedGroup) && + chunkedGroup->isInserted(transaction, rowIdxInChunkedGroup); +} + +bool NodeGroup::isVisibleNoLock(const Transaction* transaction, row_idx_t rowIdxInGroup) const { + const auto* chunkedGroup = findChunkedGroupFromRowIdxNoLock(rowIdxInGroup); + if (!chunkedGroup) { + return false; + } + const auto rowIdxInChunkedGroup = rowIdxInGroup - chunkedGroup->getStartRowIdx(); + return !chunkedGroup->isDeleted(transaction, rowIdxInChunkedGroup) && + chunkedGroup->isInserted(transaction, rowIdxInChunkedGroup); +} + +bool NodeGroup::isDeleted(const Transaction* transaction, offset_t offsetInGroup) const { + const auto lock = chunkedGroups.lock(); + const auto* chunkedGroup = findChunkedGroupFromRowIdx(lock, offsetInGroup); + KU_ASSERT(chunkedGroup); + return chunkedGroup->isDeleted(transaction, offsetInGroup - chunkedGroup->getStartRowIdx()); +} + +bool NodeGroup::isInserted(const Transaction* transaction, offset_t offsetInGroup) const { + const auto lock = chunkedGroups.lock(); + const auto* chunkedGroup = 
findChunkedGroupFromRowIdx(lock, offsetInGroup); + KU_ASSERT(chunkedGroup); + return chunkedGroup->isInserted(transaction, offsetInGroup - chunkedGroup->getStartRowIdx()); +} + +void NodeGroup::applyFuncToChunkedGroups(version_record_handler_op_t func, row_idx_t startRow, + row_idx_t numRows, transaction_t commitTS) const { + KU_ASSERT(startRow <= getNumRows()); + + auto lock = chunkedGroups.lock(); + const auto [chunkedGroupIdx, startRowInChunkedGroup] = + findChunkedGroupIdxFromRowIdxNoLock(startRow); + if (chunkedGroupIdx != INVALID_CHUNKED_GROUP_IDX) { + auto curChunkedGroupIdx = chunkedGroupIdx; + auto curStartRowIdxInChunk = startRowInChunkedGroup; + + auto numRowsLeft = numRows; + while (numRowsLeft > 0 && curChunkedGroupIdx < chunkedGroups.getNumGroups(lock)) { + auto* chunkedGroup = chunkedGroups.getGroup(lock, curChunkedGroupIdx); + const auto numRowsForGroup = + std::min(numRowsLeft, chunkedGroup->getNumRows() - curStartRowIdxInChunk); + std::invoke(func, *chunkedGroup, curStartRowIdxInChunk, numRowsForGroup, commitTS); + + ++curChunkedGroupIdx; + numRowsLeft -= numRowsForGroup; + curStartRowIdxInChunk = 0; + } + } +} + +row_idx_t NodeGroup::getStartRowIdxInGroupNoLock() const { + return chunkedGroups.getFirstGroupNoLock()->getStartRowIdx(); +} + +row_idx_t NodeGroup::getStartRowIdxInGroup(const common::UniqLock& lock) const { + return chunkedGroups.getFirstGroup(lock)->getStartRowIdx(); +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/node_group_collection.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/node_group_collection.cpp new file mode 100644 index 0000000000..76d03aa342 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/node_group_collection.cpp @@ -0,0 +1,276 @@ +#include "storage/table/node_group_collection.h" + +#include "common/vector/value_vector.h" +#include "storage/table/chunked_node_group.h" +#include "storage/table/csr_node_group.h" +#include "storage/table/table.h" +#include "transaction/transaction.h" + +using namespace lbug::common; +using namespace lbug::transaction; + +namespace lbug { +namespace storage { + +NodeGroupCollection::NodeGroupCollection(MemoryManager& mm, const std::vector& types, + const bool enableCompression, ResidencyState residency, + const VersionRecordHandler* versionRecordHandler) + : mm{mm}, enableCompression{enableCompression}, numTotalRows{0}, + types{LogicalType::copy(types)}, residency{residency}, stats{std::span{types}}, + versionRecordHandler(versionRecordHandler) { + const auto lock = nodeGroups.lock(); + for (auto& nodeGroup : nodeGroups.getAllGroups(lock)) { + numTotalRows += nodeGroup->getNumRows(); + } +} + +void NodeGroupCollection::append(const Transaction* transaction, + const std::vector& vectors) { + const auto numRowsToAppend = vectors[0]->state->getSelVector().getSelSize(); + KU_ASSERT(numRowsToAppend == vectors[0]->state->getSelVector().getSelSize()); + for (auto i = 1u; i < vectors.size(); i++) { + KU_ASSERT(vectors[i]->state->getSelVector().getSelSize() == numRowsToAppend); + } + const auto lock = nodeGroups.lock(); + if (nodeGroups.isEmpty(lock)) { + auto newGroup = + std::make_unique(mm, 0, enableCompression, LogicalType::copy(types)); + nodeGroups.appendGroup(lock, std::move(newGroup)); + } + row_idx_t numRowsAppended = 0u; + while (numRowsAppended < numRowsToAppend) { + auto lastNodeGroup = nodeGroups.getLastGroup(lock); + if (!lastNodeGroup || lastNodeGroup->isFull()) { + auto newGroup = std::make_unique(mm, 
nodeGroups.getNumGroups(lock), + enableCompression, LogicalType::copy(types)); + nodeGroups.appendGroup(lock, std::move(newGroup)); + } + lastNodeGroup = nodeGroups.getLastGroup(lock); + const auto numToAppendInNodeGroup = + std::min(numRowsToAppend - numRowsAppended, lastNodeGroup->getNumRowsLeftToAppend()); + lastNodeGroup->moveNextRowToAppend(numToAppendInNodeGroup); + pushInsertInfo(transaction, lastNodeGroup, numToAppendInNodeGroup); + numTotalRows += numToAppendInNodeGroup; + lastNodeGroup->append(transaction, vectors, numRowsAppended, numToAppendInNodeGroup); + numRowsAppended += numToAppendInNodeGroup; + } + stats.update(vectors); +} + +void NodeGroupCollection::append(const Transaction* transaction, + const std::vector& columnIDs, const NodeGroupCollection& other) { + const auto otherLock = other.nodeGroups.lock(); + for (auto& nodeGroup : other.nodeGroups.getAllGroups(otherLock)) { + append(transaction, columnIDs, *nodeGroup); + } + mergeStats(columnIDs, other.getStats(otherLock)); +} + +void NodeGroupCollection::append(const Transaction* transaction, + const std::vector& columnIDs, const NodeGroup& nodeGroup) { + KU_ASSERT(nodeGroup.getDataTypes().size() == columnIDs.size()); + const auto lock = nodeGroups.lock(); + if (nodeGroups.isEmpty(lock)) { + auto newGroup = + std::make_unique(mm, 0, enableCompression, LogicalType::copy(types)); + nodeGroups.appendGroup(lock, std::move(newGroup)); + } + const auto numChunkedGroupsToAppend = nodeGroup.getNumChunkedGroups(); + node_group_idx_t numChunkedGroupsAppended = 0; + while (numChunkedGroupsAppended < numChunkedGroupsToAppend) { + const auto chunkedGroupToAppend = nodeGroup.getChunkedNodeGroup(numChunkedGroupsAppended); + const auto numRowsToAppendInChunkedGroup = chunkedGroupToAppend->getNumRows(); + row_idx_t numRowsAppendedInChunkedGroup = 0; + while (numRowsAppendedInChunkedGroup < numRowsToAppendInChunkedGroup) { + auto lastNodeGroup = nodeGroups.getLastGroup(lock); + if (!lastNodeGroup || lastNodeGroup->isFull()) { + auto newGroup = std::make_unique(mm, nodeGroups.getNumGroups(lock), + enableCompression, LogicalType::copy(types)); + nodeGroups.appendGroup(lock, std::move(newGroup)); + } + lastNodeGroup = nodeGroups.getLastGroup(lock); + const auto numToAppendInBatch = + std::min(numRowsToAppendInChunkedGroup - numRowsAppendedInChunkedGroup, + lastNodeGroup->getNumRowsLeftToAppend()); + lastNodeGroup->moveNextRowToAppend(numToAppendInBatch); + pushInsertInfo(transaction, lastNodeGroup, numToAppendInBatch); + numTotalRows += numToAppendInBatch; + lastNodeGroup->append(transaction, columnIDs, *chunkedGroupToAppend, + numRowsAppendedInChunkedGroup, numToAppendInBatch); + numRowsAppendedInChunkedGroup += numToAppendInBatch; + } + numChunkedGroupsAppended++; + } +} + +std::pair NodeGroupCollection::appendToLastNodeGroupAndFlushWhenFull( + Transaction* transaction, const std::vector& columnIDs, + InMemChunkedNodeGroup& chunkedGroup, PageAllocator& pageAllocator) { + NodeGroup* lastNodeGroup = nullptr; + offset_t startOffset = 0; + offset_t numToAppend = 0; + bool directFlushWhenAppend = false; + { + const auto lock = nodeGroups.lock(); + startOffset = numTotalRows; + if (nodeGroups.isEmpty(lock)) { + nodeGroups.appendGroup(lock, + std::make_unique(mm, nodeGroups.getNumGroups(lock), enableCompression, + LogicalType::copy(types))); + } + lastNodeGroup = nodeGroups.getLastGroup(lock); + auto numRowsLeftInLastNodeGroup = lastNodeGroup->getNumRowsLeftToAppend(); + if (numRowsLeftInLastNodeGroup == 0) { + nodeGroups.appendGroup(lock, + 
std::make_unique(mm, nodeGroups.getNumGroups(lock), enableCompression, + LogicalType::copy(types))); + lastNodeGroup = nodeGroups.getLastGroup(lock); + numRowsLeftInLastNodeGroup = lastNodeGroup->getNumRowsLeftToAppend(); + } + numToAppend = std::min(chunkedGroup.getNumRows(), numRowsLeftInLastNodeGroup); + lastNodeGroup->moveNextRowToAppend(numToAppend); + // If the node group is empty now and the chunked group is full, we can directly flush it. + directFlushWhenAppend = + numToAppend == numRowsLeftInLastNodeGroup && lastNodeGroup->getNumRows() == 0; + pushInsertInfo(transaction, lastNodeGroup, numToAppend); + numTotalRows += numToAppend; + if (!directFlushWhenAppend) { + // TODO(Guodong): Further optimize on this. Should directly figure out startRowIdx to + // start appending into the node group and pass in as param. + lastNodeGroup->append(transaction, columnIDs, chunkedGroup, 0, numToAppend); + } + } + if (directFlushWhenAppend) { + auto flushedGroup = chunkedGroup.flush(transaction, pageAllocator); + + // If there are deleted columns that haven't been vacuumed yet, + // we need to add extra columns to the chunked group + // to ensure that the number of columns is consistent with the rest of the node group + auto groupToMerge = std::make_unique(mm, *flushedGroup, + lastNodeGroup->getDataTypes(), columnIDs); + + KU_ASSERT(lastNodeGroup->getNumChunkedGroups() == 0); + lastNodeGroup->merge(transaction, std::move(groupToMerge)); + } + return {startOffset, numToAppend}; +} + +row_idx_t NodeGroupCollection::getNumTotalRows() const { + const auto lock = nodeGroups.lock(); + return numTotalRows; +} + +NodeGroup* NodeGroupCollection::getOrCreateNodeGroup(const Transaction* transaction, + node_group_idx_t groupIdx, NodeGroupDataFormat format) { + const auto lock = nodeGroups.lock(); + while (groupIdx >= nodeGroups.getNumGroups(lock)) { + const auto currentGroupIdx = nodeGroups.getNumGroups(lock); + nodeGroups.appendGroup(lock, format == NodeGroupDataFormat::REGULAR ? + std::make_unique(mm, currentGroupIdx, + enableCompression, LogicalType::copy(types)) : + std::make_unique(mm, currentGroupIdx, + enableCompression, LogicalType::copy(types))); + // push an insert of size 0 so that we can roll back the creation of this node group if + // needed + pushInsertInfo(transaction, nodeGroups.getLastGroup(lock), 0); + } + KU_ASSERT(groupIdx < nodeGroups.getNumGroups(lock)); + return nodeGroups.getGroup(lock, groupIdx); +} + +void NodeGroupCollection::addColumn(TableAddColumnState& addColumnState, + PageAllocator* pageAllocator) { + KU_ASSERT((pageAllocator == nullptr) == (residency == ResidencyState::IN_MEMORY)); + const auto lock = nodeGroups.lock(); + auto& newColumnStats = stats.addNewColumn(addColumnState.propertyDefinition.getType()); + for (const auto& nodeGroup : nodeGroups.getAllGroups(lock)) { + nodeGroup->addColumn(addColumnState, pageAllocator, &newColumnStats); + } + types.push_back(addColumnState.propertyDefinition.getType().copy()); +} + +uint64_t NodeGroupCollection::getEstimatedMemoryUsage() const { + auto estimatedMemUsage = 0u; + const auto lock = nodeGroups.lock(); + for (const auto& nodeGroup : nodeGroups.getAllGroups(lock)) { + estimatedMemUsage += nodeGroup->getEstimatedMemoryUsage(); + } + return estimatedMemUsage; +} + +// NOLINTNEXTLINE(readability-make-member-function-const): Semantically non-const. 
+void NodeGroupCollection::checkpoint(MemoryManager& memoryManager, + NodeGroupCheckpointState& state) { + KU_ASSERT(residency == ResidencyState::ON_DISK); + const auto lock = nodeGroups.lock(); + for (const auto& nodeGroup : nodeGroups.getAllGroups(lock)) { + nodeGroup->checkpoint(memoryManager, state); + } + std::vector typesAfterCheckpoint; + for (auto i = 0u; i < state.columnIDs.size(); i++) { + typesAfterCheckpoint.push_back(types[state.columnIDs[i]].copy()); + } + types = std::move(typesAfterCheckpoint); +} + +void NodeGroupCollection::reclaimStorage(PageAllocator& pageAllocator) const { + const auto lock = nodeGroups.lock(); + for (auto& nodeGroup : nodeGroups.getAllGroups(lock)) { + nodeGroup->reclaimStorage(pageAllocator); + } +} + +void NodeGroupCollection::rollbackInsert(row_idx_t numRows_, bool updateNumRows) { + const auto lock = nodeGroups.lock(); + + // remove any empty trailing node groups after the rollback + const auto numGroupsToRemove = nodeGroups.getNumEmptyTrailingGroups(lock); + nodeGroups.removeTrailingGroups(lock, numGroupsToRemove); + + if (updateNumRows) { + KU_ASSERT(numRows_ <= numTotalRows); + numTotalRows -= numRows_; + } +} + +void NodeGroupCollection::pushInsertInfo(const Transaction* transaction, const NodeGroup* nodeGroup, + row_idx_t numRows) { + pushInsertInfo(transaction, nodeGroup->getNodeGroupIdx(), nodeGroup->getNumRows(), numRows, + versionRecordHandler, false); +}; + +void NodeGroupCollection::pushInsertInfo(const Transaction* transaction, + node_group_idx_t nodeGroupIdx, row_idx_t startRow, row_idx_t numRows, + const VersionRecordHandler* versionRecordHandler, bool incrementNumRows) { + // we only append to the undo buffer if the node group collection is persistent + if (residency == ResidencyState::ON_DISK && transaction->shouldAppendToUndoBuffer()) { + transaction->pushInsertInfo(nodeGroupIdx, startRow, numRows, versionRecordHandler); + } + if (incrementNumRows) { + numTotalRows += numRows; + } +} + +void NodeGroupCollection::serialize(Serializer& ser) { + ser.writeDebuggingInfo("node_groups"); + nodeGroups.serializeGroups(ser); + ser.writeDebuggingInfo("stats"); + stats.serialize(ser); +} + +void NodeGroupCollection::deserialize(Deserializer& deSer, MemoryManager& memoryManager) { + std::string key; + deSer.validateDebuggingInfo(key, "node_groups"); + KU_ASSERT(residency == ResidencyState::ON_DISK); + nodeGroups.deserializeGroups(memoryManager, deSer, types); + deSer.validateDebuggingInfo(key, "stats"); + stats.deserialize(deSer); + numTotalRows = 0; + const auto lock = nodeGroups.lock(); + for (auto& nodeGroup : nodeGroups.getAllGroups(lock)) { + numTotalRows += nodeGroup->getNumRows(); + } +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/node_table.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/node_table.cpp new file mode 100644 index 0000000000..3edbc9a573 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/node_table.cpp @@ -0,0 +1,857 @@ +#include "storage/table/node_table.h" + +#include "catalog/catalog_entry/node_table_catalog_entry.h" +#include "common/cast.h" +#include "common/exception/message.h" +#include "common/exception/runtime.h" +#include "common/types/types.h" +#include "main/client_context.h" +#include "storage/local_storage/local_node_table.h" +#include "storage/local_storage/local_storage.h" +#include "storage/local_storage/local_table.h" +#include "storage/storage_manager.h" +#include "storage/wal/local_wal.h" +#include 
"transaction/transaction.h" + +using namespace lbug::catalog; +using namespace lbug::common; +using namespace lbug::transaction; +using namespace lbug::evaluator; + +namespace lbug { +namespace storage { + +NodeTableVersionRecordHandler::NodeTableVersionRecordHandler(NodeTable* table) : table(table) {} + +void NodeTableVersionRecordHandler::applyFuncToChunkedGroups(version_record_handler_op_t func, + node_group_idx_t nodeGroupIdx, row_idx_t startRow, row_idx_t numRows, + transaction_t commitTS) const { + table->getNodeGroupNoLock(nodeGroupIdx) + ->applyFuncToChunkedGroups(func, startRow, numRows, commitTS); +} + +void NodeTableVersionRecordHandler::rollbackInsert(main::ClientContext* context, + node_group_idx_t nodeGroupIdx, row_idx_t startRow, row_idx_t numRows) const { + table->rollbackPKIndexInsert(context, startRow, numRows, nodeGroupIdx); + + // the only case where a node group would be empty (and potentially removed before) is if an + // exception occurred while adding its first chunk + KU_ASSERT(nodeGroupIdx < table->getNumNodeGroups() || startRow == 0); + if (nodeGroupIdx < table->getNumNodeGroups()) { + VersionRecordHandler::rollbackInsert(context, nodeGroupIdx, startRow, numRows); + auto* nodeGroup = table->getNodeGroupNoLock(nodeGroupIdx); + const auto numRowsToRollback = std::min(numRows, nodeGroup->getNumRows() - startRow); + nodeGroup->rollbackInsert(startRow); + table->rollbackGroupCollectionInsert(numRowsToRollback); + } +} + +NodeGroupScanResult NodeTableScanState::scanNext(Transaction* transaction, offset_t startOffset, + offset_t numNodes) { + KU_ASSERT(columns.size() == outputVectors.size()); + if (source == TableScanSource::NONE) { + return NODE_GROUP_SCAN_EMPTY_RESULT; + } + auto nodeGroupStartOffset = StorageUtils::getStartOffsetOfNodeGroup(nodeGroupIdx); + const auto tableID = table->getTableID(); + if (source == TableScanSource::UNCOMMITTED) { + nodeGroupStartOffset = transaction->getUncommittedOffset(tableID, nodeGroupStartOffset); + } + auto startOffsetInGroup = startOffset - nodeGroupStartOffset; + const NodeGroupScanResult scanResult = + nodeGroup->scan(transaction, *this, startOffsetInGroup, numNodes); + if (scanResult == NODE_GROUP_SCAN_EMPTY_RESULT) { + return scanResult; + } + for (auto i = 0u; i < scanResult.numRows; i++) { + nodeIDVector->setValue(i, + nodeID_t{nodeGroupStartOffset + scanResult.startRow + i, tableID}); + } + return scanResult; +} + +std::unique_ptr IndexScanHelper::initScanState(const Transaction* transaction, + DataChunk& dataChunk) { + std::vector outVectors; + for (auto& vector : dataChunk.valueVectors) { + outVectors.push_back(vector.get()); + } + auto scanState = std::make_unique(nullptr, outVectors, dataChunk.state); + scanState->setToTable(transaction, table, index->getIndexInfo().columnIDs, {}); + return scanState; +} + +namespace { + +struct UncommittedIndexInserter final : IndexScanHelper { + UncommittedIndexInserter(row_idx_t startNodeOffset, NodeTable* table, Index* index, + visible_func isVisible) + : IndexScanHelper(table, index), startNodeOffset(startNodeOffset), + nodeIDVector(LogicalType::INTERNAL_ID()), isVisible(std::move(isVisible)) {} + + std::unique_ptr initScanState(const Transaction* transaction, + DataChunk& dataChunk) override; + + bool processScanOutput(main::ClientContext* context, NodeGroupScanResult scanResult, + const std::vector& scannedVectors) override; + + row_idx_t startNodeOffset; + ValueVector nodeIDVector; + visible_func isVisible; + std::unique_ptr insertState; +}; + +struct RollbackPKDeleter final 
: IndexScanHelper { + RollbackPKDeleter(row_idx_t startNodeOffset, row_idx_t numRows, NodeTable* table, + PrimaryKeyIndex* pkIndex) + : IndexScanHelper(table, pkIndex), + semiMask(SemiMaskUtil::createMask(startNodeOffset + numRows)) { + semiMask->maskRange(startNodeOffset, startNodeOffset + numRows); + semiMask->enable(); + } + + std::unique_ptr initScanState(const Transaction* transaction, + DataChunk& dataChunk) override; + + bool processScanOutput(main::ClientContext* context, NodeGroupScanResult scanResult, + const std::vector& scannedVectors) override; + + std::unique_ptr semiMask; +}; + +std::unique_ptr UncommittedIndexInserter::initScanState( + const Transaction* transaction, DataChunk& dataChunk) { + auto scanState = IndexScanHelper::initScanState(transaction, dataChunk); + nodeIDVector.setState(dataChunk.state); + scanState->source = TableScanSource::UNCOMMITTED; + return scanState; +} + +bool UncommittedIndexInserter::processScanOutput(main::ClientContext* context, + NodeGroupScanResult scanResult, const std::vector& scannedVectors) { + if (scanResult == NODE_GROUP_SCAN_EMPTY_RESULT) { + return false; + } + for (auto i = 0u; i < scanResult.numRows; i++) { + nodeIDVector.setValue(i, nodeID_t{startNodeOffset + i, table->getTableID()}); + } + if (!insertState) { + insertState = index->initInsertState(context, isVisible); + } + index->commitInsert(transaction::Transaction::Get(*context), nodeIDVector, {scannedVectors}, + *insertState); + startNodeOffset += scanResult.numRows; + return true; +} + +std::unique_ptr RollbackPKDeleter::initScanState(const Transaction* transaction, + DataChunk& dataChunk) { + auto scanState = IndexScanHelper::initScanState(transaction, dataChunk); + scanState->source = TableScanSource::COMMITTED; + scanState->semiMask = semiMask.get(); + return scanState; +} + +template +concept notIndexHashable = !IndexHashable; + +bool RollbackPKDeleter::processScanOutput(main::ClientContext* context, + NodeGroupScanResult scanResult, const std::vector& scannedVectors) { + if (scanResult == NODE_GROUP_SCAN_EMPTY_RESULT) { + return false; + } + KU_ASSERT(scannedVectors.size() == 1); + auto& scannedVector = *scannedVectors[0]; + auto& pkIndex = index->cast(); + const auto rollbackFunc = [&](T) { + for (idx_t i = 0; i < scannedVector.state->getSelSize(); ++i) { + const auto pos = scannedVector.state->getSelVector()[i]; + T key = scannedVector.getValue(pos); + static constexpr auto isVisible = [](offset_t) { return true; }; + if (offset_t lookupOffset = 0; pkIndex.lookup(transaction::Transaction::Get(*context), + key, lookupOffset, isVisible)) { + // If we delete the key then it will not be visible to future transactions within + // this process + pkIndex.discardLocal(key); + } + } + }; + TypeUtils::visit(scannedVector.dataType.getPhysicalType(), std::cref(rollbackFunc), + [](T) { KU_UNREACHABLE; }); + return true; +} +} // namespace + +void NodeTableScanState::setToTable(const Transaction* transaction, Table* table_, + std::vector columnIDs_, std::vector columnPredicateSets_, + RelDataDirection) { + TableScanState::setToTable(transaction, table_, columnIDs_, std::move(columnPredicateSets_)); + columns.resize(columnIDs.size()); + for (auto i = 0u; i < columnIDs.size(); i++) { + if (const auto columnID = columnIDs[i]; + columnID == INVALID_COLUMN_ID || columnID == ROW_IDX_COLUMN_ID) { + columns[i] = nullptr; + } else { + columns[i] = &table->cast().getColumn(columnID); + } + } +} + +bool NodeTableScanState::scanNext(Transaction* transaction) { + if (source == 
TableScanSource::NONE) { + return false; + } + KU_ASSERT(columns.size() == outputVectors.size()); + const NodeGroupScanResult scanResult = nodeGroup->scan(transaction, *this); + if (scanResult == NODE_GROUP_SCAN_EMPTY_RESULT) { + return false; + } + auto nodeGroupStartOffset = StorageUtils::getStartOffsetOfNodeGroup(nodeGroupIdx); + const auto tableID = table->getTableID(); + if (source == TableScanSource::UNCOMMITTED) { + nodeGroupStartOffset = transaction->getUncommittedOffset(tableID, nodeGroupStartOffset); + } + for (auto i = 0u; i < scanResult.numRows; i++) { + auto& nodeID = nodeIDVector->getValue(i); + nodeID.tableID = tableID; + nodeID.offset = nodeGroupStartOffset + scanResult.startRow + i; + } + return true; +} + +NodeTable::NodeTable(const StorageManager* storageManager, + const NodeTableCatalogEntry* nodeTableEntry, MemoryManager* mm) + : Table{nodeTableEntry, storageManager, mm}, + pkColumnID{nodeTableEntry->getColumnID(nodeTableEntry->getPrimaryKeyName())}, + versionRecordHandler(this) { + auto* dataFH = storageManager->getDataFH(); + auto& pageAllocator = *dataFH->getPageManager(); + const auto maxColumnID = nodeTableEntry->getMaxColumnID(); + columns.resize(maxColumnID + 1); + for (auto& property : nodeTableEntry->getProperties()) { + const auto columnID = nodeTableEntry->getColumnID(property.getName()); + const auto columnName = + StorageUtils::getColumnName(property.getName(), StorageUtils::ColumnType::DEFAULT, ""); + columns[columnID] = ColumnFactory::createColumn(columnName, property.getType().copy(), + dataFH, mm, shadowFile, enableCompression); + } + auto& pkDefinition = nodeTableEntry->getPrimaryKeyDefinition(); + KU_ASSERT(pkColumnID != INVALID_COLUMN_ID); + auto hashIndexType = PrimaryKeyIndex::getIndexType(); + IndexInfo indexInfo{PrimaryKeyIndex::DEFAULT_NAME, hashIndexType.typeName, tableID, + {pkColumnID}, {pkDefinition.getType().getPhysicalType()}, + hashIndexType.constraintType == IndexConstraintType::PRIMARY, + hashIndexType.definitionType == IndexDefinitionType::BUILTIN}; + indexes.push_back(IndexHolder{PrimaryKeyIndex::createNewIndex(indexInfo, + storageManager->isInMemory(), *mm, pageAllocator, shadowFile)}); + nodeGroups = std::make_unique(*mm, + LocalNodeTable::getNodeTableColumnTypes(*nodeTableEntry), enableCompression, + storageManager->getDataFH() ? ResidencyState::ON_DISK : ResidencyState::IN_MEMORY, + &versionRecordHandler); +} + +row_idx_t NodeTable::getNumTotalRows(const Transaction* transaction) { + auto numLocalRows = 0u; + if (transaction && transaction->getLocalStorage()) { + if (const auto localTable = transaction->getLocalStorage()->getLocalTable(tableID)) { + numLocalRows = localTable->getNumTotalRows(); + } + } + return numLocalRows + nodeGroups->getNumTotalRows(); +} + +void NodeTable::initScanState(Transaction* transaction, TableScanState& scanState, bool) const { + auto& nodeScanState = scanState.cast(); + NodeGroup* nodeGroup = nullptr; + switch (nodeScanState.source) { + case TableScanSource::COMMITTED: { + nodeGroup = nodeGroups->getNodeGroup(nodeScanState.nodeGroupIdx); + } break; + case TableScanSource::UNCOMMITTED: { + const auto localTable = transaction->getLocalStorage()->getLocalTable(tableID); + KU_ASSERT(localTable); + const auto& localNodeTable = localTable->cast(); + nodeGroup = localNodeTable.getNodeGroup(nodeScanState.nodeGroupIdx); + KU_ASSERT(nodeGroup); + } break; + case TableScanSource::NONE: { + // DO NOTHING. 
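+        // nodeGroup remains nullptr and is passed as-is to initState below.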
+ } break; + default: { + KU_UNREACHABLE; + } + } + nodeScanState.initState(transaction, nodeGroup); +} + +void NodeTable::initScanState(Transaction* transaction, TableScanState& scanState, + table_id_t tableID, offset_t startOffset) const { + if (transaction->isUnCommitted(tableID, startOffset)) { + scanState.source = TableScanSource::UNCOMMITTED; + scanState.nodeGroupIdx = + StorageUtils::getNodeGroupIdx(transaction->getLocalRowIdx(tableID, startOffset)); + } else { + scanState.source = TableScanSource::COMMITTED; + scanState.nodeGroupIdx = StorageUtils::getNodeGroupIdx(startOffset); + } + initScanState(transaction, scanState); +} + +bool NodeTable::scanInternal(Transaction* transaction, TableScanState& scanState) { + scanState.resetOutVectors(); + return scanState.scanNext(transaction); +} + +template +bool NodeTable::lookup(const Transaction* transaction, const TableScanState& scanState) const { + KU_ASSERT(scanState.nodeIDVector->state->getSelVector().getSelSize() == 1); + const auto nodeIDPos = scanState.nodeIDVector->state->getSelVector()[0]; + if (scanState.nodeIDVector->isNull(nodeIDPos)) { + return false; + } + const auto nodeOffset = scanState.nodeIDVector->readNodeOffset(nodeIDPos); + const offset_t rowIdxInGroup = + transaction->isUnCommitted(tableID, nodeOffset) ? + transaction->getLocalRowIdx(tableID, nodeOffset) - + StorageUtils::getStartOffsetOfNodeGroup(scanState.nodeGroupIdx) : + nodeOffset - StorageUtils::getStartOffsetOfNodeGroup(scanState.nodeGroupIdx); + scanState.rowIdxVector->setValue(nodeIDPos, rowIdxInGroup); + if constexpr (lock) { + return scanState.nodeGroup->lookup(transaction, scanState); + } else { + return scanState.nodeGroup->lookupNoLock(transaction, scanState); + } +} + +template bool NodeTable::lookup(const Transaction* transaction, + const TableScanState& scanState) const; +template bool NodeTable::lookup(const Transaction* transaction, + const TableScanState& scanState) const; + +template +bool NodeTable::lookupMultiple(Transaction* transaction, TableScanState& scanState) const { + const auto numRowsToRead = scanState.nodeIDVector->state->getSelSize(); + sel_t numRowsRead = 0; + for (auto i = 0u; i < numRowsToRead; i++) { + const auto nodeIDPos = scanState.nodeIDVector->state->getSelVector()[i]; + if (scanState.nodeIDVector->isNull(nodeIDPos)) { + continue; + } + const auto nodeOffset = scanState.nodeIDVector->readNodeOffset(nodeIDPos); + const auto isUnCommitted = transaction->isUnCommitted(tableID, nodeOffset); + const auto source = + isUnCommitted ? TableScanSource::UNCOMMITTED : TableScanSource::COMMITTED; + const auto nodeGroupIdx = + isUnCommitted ? + StorageUtils::getNodeGroupIdx(transaction->getLocalRowIdx(tableID, nodeOffset)) : + StorageUtils::getNodeGroupIdx(nodeOffset); + const offset_t rowIdxInGroup = + isUnCommitted ? transaction->getLocalRowIdx(tableID, nodeOffset) - + StorageUtils::getStartOffsetOfNodeGroup(nodeGroupIdx) : + nodeOffset - StorageUtils::getStartOffsetOfNodeGroup(nodeGroupIdx); + if (scanState.source == source && scanState.nodeGroupIdx == nodeGroupIdx) { + // If the scan state is already initialized for the same source and node group, we can + // skip re-initialization. 
+ } else { + scanState.source = source; + scanState.nodeGroupIdx = nodeGroupIdx; + initScanState(transaction, scanState); + } + scanState.rowIdxVector->setValue(nodeIDPos, rowIdxInGroup); + if constexpr (lock) { + numRowsRead += scanState.nodeGroup->lookup(transaction, scanState, i); + } else { + numRowsRead += scanState.nodeGroup->lookupNoLock(transaction, scanState, i); + } + } + return numRowsRead == numRowsToRead; +} + +template bool NodeTable::lookupMultiple(Transaction* transaction, + TableScanState& scanState) const; +template bool NodeTable::lookupMultiple(Transaction* transaction, + TableScanState& scanState) const; + +offset_t NodeTable::validateUniquenessConstraint(const Transaction* transaction, + const std::vector& propertyVectors) const { + const auto pkVector = propertyVectors[pkColumnID]; + KU_ASSERT(pkVector->state->getSelVector().getSelSize() == 1); + const auto pkVectorPos = pkVector->state->getSelVector()[0]; + if (offset_t offset = INVALID_OFFSET; + getPKIndex()->lookup(transaction, propertyVectors[pkColumnID], pkVectorPos, offset, + [&](offset_t offset_) { return isVisible(transaction, offset_); })) { + return offset; + } + if (const auto localTable = transaction->getLocalStorage()->getLocalTable(tableID)) { + return localTable->cast().validateUniquenessConstraint(transaction, + *pkVector); + } + return INVALID_OFFSET; +} + +void NodeTable::validatePkNotExists(const Transaction* transaction, ValueVector* pkVector) const { + offset_t dummyOffset = INVALID_OFFSET; + auto& selVector = pkVector->state->getSelVector(); + KU_ASSERT(selVector.getSelSize() == 1); + if (pkVector->isNull(selVector[0])) { + throw RuntimeException(ExceptionMessage::nullPKException()); + } + if (getPKIndex()->lookup(transaction, pkVector, selVector[0], dummyOffset, + [&](offset_t offset) { return isVisible(transaction, offset); })) { + throw RuntimeException( + ExceptionMessage::duplicatePKException(pkVector->getAsValue(selVector[0])->toString())); + } +} + +void NodeTable::initInsertState(main::ClientContext* context, TableInsertState& insertState) { + auto& nodeInsertState = insertState.cast(); + nodeInsertState.indexInsertStates.resize(indexes.size()); + for (auto i = 0u; i < indexes.size(); i++) { + auto& indexHolder = indexes[i]; + const auto index = indexHolder.getIndex(); + nodeInsertState.indexInsertStates[i] = + index->initInsertState(context, [&](offset_t offset) { + return isVisible(transaction::Transaction::Get(*context), offset); + }); + } +} + +void NodeTable::insert(Transaction* transaction, TableInsertState& insertState) { + const auto& nodeInsertState = insertState.cast(); + auto& nodeIDSelVector = nodeInsertState.nodeIDVector.state->getSelVector(); + KU_ASSERT(nodeInsertState.propertyVectors[0]->state->getSelVector().getSelSize() == 1); + KU_ASSERT(nodeIDSelVector.getSelSize() == 1); + if (nodeInsertState.nodeIDVector.isNull(nodeIDSelVector[0])) { + return; + } + const auto localTable = transaction->getLocalStorage()->getOrCreateLocalTable(*this); + validatePkNotExists(transaction, const_cast(&nodeInsertState.pkVector)); + localTable->insert(transaction, insertState); + for (auto i = 0u; i < indexes.size(); i++) { + auto index = indexes[i].getIndex(); + std::vector indexedPropertyVectors; + for (const auto columnID : index->getIndexInfo().columnIDs) { + indexedPropertyVectors.push_back(insertState.propertyVectors[columnID]); + } + index->insert(transaction, nodeInsertState.nodeIDVector, indexedPropertyVectors, + *nodeInsertState.indexInsertStates[i]); + } + if 
(insertState.logToWAL && transaction->shouldLogToWAL()) { + KU_ASSERT(transaction->isWriteTransaction()); + auto& wal = transaction->getLocalWAL(); + wal.logTableInsertion(tableID, TableType::NODE, + nodeInsertState.nodeIDVector.state->getSelVector().getSelSize(), + insertState.propertyVectors); + } + hasChanges = true; +} + +void NodeTable::initUpdateState(main::ClientContext* context, TableUpdateState& updateState) const { + auto& nodeUpdateState = updateState.cast(); + nodeUpdateState.indexUpdateState.resize(indexes.size()); + for (auto i = 0u; i < indexes.size(); i++) { + auto& indexHolder = indexes[i]; + auto index = indexHolder.getIndex(); + if (index->isPrimary() || !index->isBuiltOnColumn(nodeUpdateState.columnID)) { + nodeUpdateState.indexUpdateState[i] = nullptr; + continue; + } + nodeUpdateState.indexUpdateState[i] = + index->initUpdateState(context, nodeUpdateState.columnID, [&](offset_t offset) { + return isVisible(transaction::Transaction::Get(*context), offset); + }); + } +} + +void NodeTable::update(Transaction* transaction, TableUpdateState& updateState) { + // NOTE: We assume all inputs are flattened now. This is to simplify the implementation. + // We should optimize this to take unflattened input later. + auto& nodeUpdateState = updateState.constCast(); + KU_ASSERT(nodeUpdateState.nodeIDVector.state->getSelVector().getSelSize() == 1 && + nodeUpdateState.propertyVector.state->getSelVector().getSelSize() == 1); + const auto pos = nodeUpdateState.nodeIDVector.state->getSelVector()[0]; + if (nodeUpdateState.nodeIDVector.isNull(pos)) { + return; + } + const auto pkIndex = getPKIndex(); + if (nodeUpdateState.columnID == pkColumnID && pkIndex) { + throw RuntimeException("Cannot update pk."); + } + const auto nodeOffset = nodeUpdateState.nodeIDVector.readNodeOffset(pos); + for (auto i = 0u; i < indexes.size(); i++) { + auto index = indexes[i].getIndex(); + if (!nodeUpdateState.needToUpdateIndex(i)) { + continue; + } + index->update(transaction, nodeUpdateState.nodeIDVector, nodeUpdateState.propertyVector, + *nodeUpdateState.indexUpdateState[i]); + } + if (transaction->isUnCommitted(tableID, nodeOffset)) { + const auto localTable = transaction->getLocalStorage()->getLocalTable(tableID); + KU_ASSERT(localTable); + localTable->update(&DUMMY_TRANSACTION, updateState); + } else { + const auto nodeGroupIdx = StorageUtils::getNodeGroupIdx(nodeOffset); + const auto rowIdxInGroup = + nodeOffset - StorageUtils::getStartOffsetOfNodeGroup(nodeGroupIdx); + nodeGroups->getNodeGroup(nodeGroupIdx) + ->update(transaction, rowIdxInGroup, nodeUpdateState.columnID, + nodeUpdateState.propertyVector); + } + if (updateState.logToWAL && transaction->shouldLogToWAL()) { + KU_ASSERT(transaction->isWriteTransaction()); + auto& wal = transaction->getLocalWAL(); + wal.logNodeUpdate(tableID, nodeUpdateState.columnID, nodeOffset, + &nodeUpdateState.propertyVector); + } + hasChanges = true; +} + +bool NodeTable::delete_(Transaction* transaction, TableDeleteState& deleteState) { + const auto& nodeDeleteState = ku_dynamic_cast(deleteState); + KU_ASSERT(nodeDeleteState.nodeIDVector.state->getSelVector().getSelSize() == 1); + const auto pos = nodeDeleteState.nodeIDVector.state->getSelVector()[0]; + if (nodeDeleteState.nodeIDVector.isNull(pos)) { + return false; + } + bool isDeleted = false; + const auto nodeOffset = nodeDeleteState.nodeIDVector.readNodeOffset(pos); + for (auto& index : indexes) { + auto indexDeleteState = index.getIndex()->initDeleteState(transaction, memoryManager, + 
getVisibleFunc(transaction)); + index.getIndex()->delete_(transaction, nodeDeleteState.nodeIDVector, *indexDeleteState); + } + + if (transaction->isUnCommitted(tableID, nodeOffset)) { + const auto localTable = transaction->getLocalStorage()->getLocalTable(tableID); + isDeleted = localTable->delete_(&DUMMY_TRANSACTION, deleteState); + } else { + const auto nodeGroupIdx = StorageUtils::getNodeGroupIdx(nodeOffset); + const auto rowIdxInGroup = + nodeOffset - StorageUtils::getStartOffsetOfNodeGroup(nodeGroupIdx); + isDeleted = nodeGroups->getNodeGroup(nodeGroupIdx)->delete_(transaction, rowIdxInGroup); + if (transaction->shouldAppendToUndoBuffer()) { + transaction->pushDeleteInfo(nodeGroupIdx, rowIdxInGroup, 1, &versionRecordHandler); + } + } + if (isDeleted) { + hasChanges = true; + if (deleteState.logToWAL && transaction->shouldLogToWAL()) { + KU_ASSERT(transaction->isWriteTransaction()); + auto& wal = transaction->getLocalWAL(); + wal.logNodeDeletion(tableID, nodeOffset, &nodeDeleteState.pkVector); + } + } + return isDeleted; +} + +void NodeTable::addColumn(Transaction* transaction, TableAddColumnState& addColumnState, + PageAllocator& pageAllocator) { + auto& definition = addColumnState.propertyDefinition; + columns.push_back(ColumnFactory::createColumn(definition.getName(), definition.getType().copy(), + pageAllocator.getDataFH(), memoryManager, shadowFile, enableCompression)); + LocalTable* localTable = nullptr; + if (transaction->getLocalStorage()) { + localTable = transaction->getLocalStorage()->getLocalTable(tableID); + } + if (localTable) { + localTable->addColumn(addColumnState); + } + nodeGroups->addColumn(addColumnState, &pageAllocator); + hasChanges = true; +} + +std::pair NodeTable::appendToLastNodeGroup(Transaction* transaction, + const std::vector& columnIDs, InMemChunkedNodeGroup& chunkedGroup, + PageAllocator& pageAllocator) { + hasChanges = true; + return nodeGroups->appendToLastNodeGroupAndFlushWhenFull(transaction, columnIDs, chunkedGroup, + pageAllocator); +} + +DataChunk NodeTable::constructDataChunkForColumns(const std::vector& columnIDs) const { + std::vector types; + for (const auto& columnID : columnIDs) { + KU_ASSERT(columnID < columns.size()); + types.push_back(columns[columnID]->getDataType().copy()); + } + return constructDataChunk(memoryManager, std::move(types)); +} + +void NodeTable::commit(main::ClientContext* context, TableCatalogEntry* tableEntry, + LocalTable* localTable) { + const auto startNodeOffset = nodeGroups->getNumTotalRows(); + auto& localNodeTable = localTable->cast(); + + std::vector columnIDsToCommit; + for (auto& property : tableEntry->getProperties()) { + auto columnID = tableEntry->getColumnID(property.getName()); + columnIDsToCommit.push_back(columnID); + } + + auto transaction = transaction::Transaction::Get(*context); + // 1. Append all tuples from local storage to nodeGroups regardless of deleted or not. + // Note: We cannot simply remove all deleted tuples in local node table, as they may have + // connected local rels. Directly removing them will cause shift of committed node offset, + // leading to an inconsistent result with connected rels. + nodeGroups->append(transaction, columnIDsToCommit, localNodeTable.getNodeGroups()); + // 2. Set deleted flag for tuples that are deleted in local storage. 
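+    // Local row positions map to committed offsets as startNodeOffset + numLocalRows + row.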
+ row_idx_t numLocalRows = 0u; + for (auto localNodeGroupIdx = 0u; localNodeGroupIdx < localNodeTable.getNumNodeGroups(); + localNodeGroupIdx++) { + const auto localNodeGroup = localNodeTable.getNodeGroup(localNodeGroupIdx); + if (localNodeGroup->hasDeletions(transaction)) { + // TODO(Guodong): Assume local storage is small here. Should optimize the loop away by + // grabbing a set of deleted rows. + for (auto row = 0u; row < localNodeGroup->getNumRows(); row++) { + if (localNodeGroup->isDeleted(transaction, row)) { + const auto nodeOffset = startNodeOffset + numLocalRows + row; + const auto nodeGroupIdx = StorageUtils::getNodeGroupIdx(nodeOffset); + const auto rowIdxInGroup = + nodeOffset - StorageUtils::getStartOffsetOfNodeGroup(nodeGroupIdx); + [[maybe_unused]] const bool isDeleted = + nodeGroups->getNodeGroup(nodeGroupIdx)->delete_(transaction, rowIdxInGroup); + KU_ASSERT(isDeleted); + if (transaction->shouldAppendToUndoBuffer()) { + transaction->pushDeleteInfo(nodeGroupIdx, rowIdxInGroup, 1, + &versionRecordHandler); + } + } + } + } + numLocalRows += localNodeGroup->getNumRows(); + } + + // 3. Scan index columns for newly inserted tuples. + for (auto& index : indexes) { + if (!index.needCommitInsert()) { + continue; + } + if (!index.isLoaded()) { + throw RuntimeException( + "Cannot commit index insertions for index " + index.getName() + + ", because it is not loaded. Please load the extension for the index first."); + } + UncommittedIndexInserter indexInserter{startNodeOffset, this, index.getIndex(), + getVisibleFunc(transaction)}; + // We need to scan from local storage here because some tuples in local node groups might + // have been deleted. + scanIndexColumns(context, indexInserter, localNodeTable.getNodeGroups()); + } + + // 4. Clear local table. + localTable->clear(*MemoryManager::Get(*context)); +} + +visible_func NodeTable::getVisibleFunc(const Transaction* transaction) const { + return + [this, transaction](offset_t offset_) -> bool { return isVisible(transaction, offset_); }; +} + +bool NodeTable::checkpoint(main::ClientContext* context, TableCatalogEntry* tableEntry, + PageAllocator& pageAllocator) { + const bool ret = hasChanges; + if (hasChanges) { + // Deleted columns are vacuumed and not checkpointed. + std::vector> checkpointColumns; + std::vector columnIDs; + for (auto& property : tableEntry->getProperties()) { + auto columnID = tableEntry->getColumnID(property.getName()); + checkpointColumns.push_back(std::move(columns[columnID])); + columnIDs.push_back(columnID); + } + columns = std::move(checkpointColumns); + + std::vector checkpointColumnPtrs; + for (const auto& column : columns) { + checkpointColumnPtrs.push_back(column.get()); + } + + NodeGroupCheckpointState state{columnIDs, std::move(checkpointColumnPtrs), pageAllocator, + memoryManager}; + nodeGroups->checkpoint(*memoryManager, state); + for (auto& index : indexes) { + index.checkpoint(context, pageAllocator); + } + tableEntry->vacuumColumnIDs(0 /*nextColumnID*/); + hasChanges = false; + } + return ret; +} + +void NodeTable::rollbackPKIndexInsert(main::ClientContext* context, row_idx_t startRow, + row_idx_t numRows_, node_group_idx_t nodeGroupIdx_) { + const row_idx_t startNodeOffset = + startRow + StorageUtils::getStartOffsetOfNodeGroup(nodeGroupIdx_); + + RollbackPKDeleter pkDeleter{startNodeOffset, numRows_, this, getPKIndex()}; + scanIndexColumns(context, pkDeleter, *nodeGroups); +} + +// NOLINTNEXTLINE(readability-make-member-function-const): Semantically non-const. 
+void NodeTable::rollbackGroupCollectionInsert(row_idx_t numRows_) { + nodeGroups->rollbackInsert(numRows_); +} + +void NodeTable::rollbackCheckpoint() { + for (auto& index : indexes) { + index.rollbackCheckpoint(); + } +} + +void NodeTable::reclaimStorage(PageAllocator& pageAllocator) const { + nodeGroups->reclaimStorage(pageAllocator); + getPKIndex()->reclaimStorage(pageAllocator); +} + +TableStats NodeTable::getStats(const Transaction* transaction) const { + auto stats = nodeGroups->getStats(); + if (const auto localTable = transaction->getLocalStorage()->getLocalTable(tableID)) { + const auto localStats = localTable->cast().getStats(); + stats.merge(localStats); + } + return stats; +} + +bool NodeTable::isVisible(const Transaction* transaction, offset_t offset) const { + auto [nodeGroupIdx, offsetInGroup] = StorageUtils::getNodeGroupIdxAndOffsetInChunk(offset); + const auto* nodeGroup = getNodeGroup(nodeGroupIdx); + return nodeGroup->isVisible(transaction, offsetInGroup); +} + +bool NodeTable::isVisibleNoLock(const Transaction* transaction, offset_t offset) const { + auto [nodeGroupIdx, offsetInGroup] = StorageUtils::getNodeGroupIdxAndOffsetInChunk(offset); + if (nodeGroupIdx >= nodeGroups->getNumNodeGroupsNoLock()) { + return false; + } + const auto* nodeGroup = getNodeGroupNoLock(nodeGroupIdx); + return nodeGroup->isVisibleNoLock(transaction, offsetInGroup); +} + +bool NodeTable::lookupPK(const Transaction* transaction, ValueVector* keyVector, uint64_t vectorPos, + offset_t& result) const { + if (transaction->getLocalStorage()) { + if (const auto localTable = transaction->getLocalStorage()->getLocalTable(tableID); + localTable && localTable->cast().lookupPK(transaction, keyVector, + vectorPos, result)) { + return true; + } + } + return getPKIndex()->lookup(transaction, keyVector, vectorPos, result, + [&](offset_t offset) { return isVisibleNoLock(transaction, offset); }); +} + +void NodeTable::scanIndexColumns(main::ClientContext* context, IndexScanHelper& scanHelper, + const NodeGroupCollection& nodeGroups_) const { + auto dataChunk = constructDataChunkForColumns(scanHelper.index->getIndexInfo().columnIDs); + const auto scanState = + scanHelper.initScanState(transaction::Transaction::Get(*context), dataChunk); + + const auto numNodeGroups = nodeGroups_.getNumNodeGroups(); + for (node_group_idx_t nodeGroupToScan = 0u; nodeGroupToScan < numNodeGroups; + ++nodeGroupToScan) { + scanState->nodeGroup = nodeGroups_.getNodeGroupNoLock(nodeGroupToScan); + + // It is possible for the node group to have no chunked groups if we are rolling back due to + // an exception that is thrown before any chunked groups could be appended to the node group + if (scanState->nodeGroup->getNumChunkedGroups() > 0) { + scanState->nodeGroupIdx = nodeGroupToScan; + KU_ASSERT(scanState->nodeGroup); + scanState->nodeGroup->initializeScanState(transaction::Transaction::Get(*context), + *scanState); + while (true) { + if (const auto scanResult = scanState->nodeGroup->scan( + transaction::Transaction::Get(*context), *scanState); + !scanHelper.processScanOutput(context, scanResult, scanState->outputVectors)) { + break; + } + } + } + } +} + +void NodeTable::addIndex(std::unique_ptr index) { + if (getIndex(index->getName()).has_value()) { + throw RuntimeException("Index with name " + index->getName() + " already exists."); + } + indexes.push_back(IndexHolder{std::move(index)}); + hasChanges = true; +} + +void NodeTable::dropIndex(const std::string& name) { + KU_ASSERT(getIndex(name) != nullptr); + for (auto it = 
indexes.begin(); it != indexes.end(); ++it) { + if (StringUtils::caseInsensitiveEquals(it->getName(), name)) { + KU_ASSERT(it->isLoaded()); + indexes.erase(it); + hasChanges = true; + return; + } + } +} + +std::optional> NodeTable::getIndexHolder( + const std::string& name) { + for (auto& index : indexes) { + if (StringUtils::caseInsensitiveEquals(index.getName(), name)) { + return index; + } + } + return std::nullopt; +} + +std::optional NodeTable::getIndex(const std::string& name) const { + for (auto& index : indexes) { + if (StringUtils::caseInsensitiveEquals(index.getName(), name)) { + if (index.isLoaded()) { + return index.getIndex(); + } + throw RuntimeException(stringFormat( + "Index {} is not loaded yet. Please load the index before accessing it.", name)); + } + } + return std::nullopt; +} + +void NodeTable::serialize(Serializer& serializer) const { + nodeGroups->serialize(serializer); + serializer.write(indexes.size()); + for (auto i = 0u; i < indexes.size(); ++i) { + indexes[i].serialize(serializer); + } +} + +void NodeTable::deserialize(main::ClientContext* context, StorageManager* storageManager, + Deserializer& deSer) { + nodeGroups->deserialize(deSer, *memoryManager); + std::vector indexInfos; + std::vector storageInfoBufferSizes; + std::vector> storageInfoBuffers; + uint64_t numIndexes = 0u; + deSer.deserializeValue(numIndexes); + indexInfos.reserve(numIndexes); + storageInfoBufferSizes.reserve(numIndexes); + storageInfoBuffers.reserve(numIndexes); + for (uint64_t i = 0; i < numIndexes; ++i) { + IndexInfo indexInfo = IndexInfo::deserialize(deSer); + indexInfos.push_back(indexInfo); + uint64_t storageInfoSize = 0u; + deSer.deserializeValue(storageInfoSize); + storageInfoBufferSizes.push_back(storageInfoSize); + auto storageInfoBuffer = std::make_unique(storageInfoSize); + deSer.read(storageInfoBuffer.get(), storageInfoSize); + storageInfoBuffers.push_back(std::move(storageInfoBuffer)); + } + indexes.clear(); + indexes.reserve(indexInfos.size()); + for (auto i = 0u; i < indexInfos.size(); ++i) { + indexes.push_back(IndexHolder(indexInfos[i], std::move(storageInfoBuffers[i]), + storageInfoBufferSizes[i])); + if (indexInfos[i].isBuiltin) { + indexes[i].load(context, storageManager); + } + } +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/null_column.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/null_column.cpp new file mode 100644 index 0000000000..44daf91335 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/null_column.cpp @@ -0,0 +1,41 @@ +#include "storage/table/null_column.h" + +#include "common/vector/value_vector.h" +#include "storage/buffer_manager/memory_manager.h" +#include "storage/compression/compression.h" +#include "storage/storage_utils.h" + +using namespace lbug::common; +using namespace lbug::transaction; + +namespace lbug { +namespace storage { + +struct NullColumnFunc { + static void readValuesFromPageToVector(const uint8_t* frame, PageCursor& pageCursor, + ValueVector* resultVector, uint32_t posInVector, uint32_t numValuesToRead, + const CompressionMetadata& metadata) { + // Read bit-packed null flags from the frame into the result vector + // Casting to uint64_t should be safe as long as the page size is a multiple of 8 bytes. + // Otherwise, it could read off the end of the page. 
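+        // Constant-compressed metadata yields a single null flag applied to the whole range;
+        // otherwise the null bits are read directly from the frame.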
+ if (metadata.isConstant()) { + bool value = false; + ConstantCompression::decompressValues(reinterpret_cast(&value), 0 /*offset*/, + 1 /*numValues*/, PhysicalTypeID::BOOL, 1 /*numBytesPerValue*/, metadata); + resultVector->setNullRange(posInVector, numValuesToRead, value); + } else { + resultVector->setNullFromBits(reinterpret_cast(frame), + pageCursor.elemPosInPage, posInVector, numValuesToRead); + } + } +}; + +NullColumn::NullColumn(const std::string& name, FileHandle* dataFH, MemoryManager* mm, + ShadowFile* shadowFile, bool enableCompression) + : Column{name, LogicalType::BOOL(), dataFH, mm, shadowFile, enableCompression, + false /*requireNullColumn*/} { + readToVectorFunc = NullColumnFunc::readValuesFromPageToVector; +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/rel_table.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/rel_table.cpp new file mode 100644 index 0000000000..49820f9efc --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/rel_table.cpp @@ -0,0 +1,548 @@ +#include "storage/table/rel_table.h" + +#include + +#include "catalog/catalog_entry/rel_group_catalog_entry.h" +#include "common/exception/message.h" +#include "common/exception/runtime.h" +#include "common/types/types.h" +#include "main/client_context.h" +#include "storage/local_storage/local_rel_table.h" +#include "storage/local_storage/local_storage.h" +#include "storage/local_storage/local_table.h" +#include "storage/storage_manager.h" +#include "storage/storage_utils.h" +#include "storage/table/column_chunk.h" +#include "storage/table/column_chunk_data.h" +#include "storage/table/rel_table_data.h" +#include "storage/wal/local_wal.h" +#include "transaction/transaction.h" +#include + +using namespace lbug::catalog; +using namespace lbug::common; +using namespace lbug::transaction; +using namespace lbug::evaluator; + +namespace lbug { +namespace storage { + +void RelTableScanState::setToTable(const Transaction* transaction, Table* table_, + std::vector columnIDs_, std::vector columnPredicateSets_, + RelDataDirection direction_) { + TableScanState::setToTable(transaction, table_, std::move(columnIDs_), + std::move(columnPredicateSets_)); + columns.resize(columnIDs.size()); + direction = direction_; + for (size_t i = 0; i < columnIDs.size(); ++i) { + auto columnID = columnIDs[i]; + if (columnID == INVALID_COLUMN_ID || columnID == ROW_IDX_COLUMN_ID) { + columns[i] = nullptr; + } else { + columns[i] = table->cast().getColumn(columnID, direction); + } + } + csrOffsetColumn = table->cast().getCSROffsetColumn(direction); + csrLengthColumn = table->cast().getCSRLengthColumn(direction); + nodeGroupIdx = INVALID_NODE_GROUP_IDX; + if (const auto localRelTable = + transaction->getLocalStorage()->getLocalTable(table->getTableID())) { + auto localTableColumnIDs = LocalRelTable::rewriteLocalColumnIDs(direction, columnIDs); + localTableScanState = std::make_unique(*this, + localRelTable->ptrCast(), localTableColumnIDs); + } +} + +void RelTableScanState::initState(Transaction* transaction, NodeGroup* nodeGroup, + bool resetCachedBoundNodeIDs) { + this->nodeGroup = nodeGroup; + if (resetCachedBoundNodeIDs) { + initCachedBoundNodeIDSelVector(); + } + if (this->nodeGroup) { + initStateForCommitted(transaction); + } else if (hasUnCommittedData()) { + initStateForUncommitted(); + } else { + source = TableScanSource::NONE; + } +} + +void RelTableScanState::initCachedBoundNodeIDSelVector() { + if (nodeIDVector->state->getSelVector().isUnfiltered()) { + 
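+        // The bound selection is unfiltered, so the cached copy can stay unfiltered as well.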
cachedBoundNodeSelVector.setToUnfiltered(); + } else { + cachedBoundNodeSelVector.setToFiltered(); + memcpy(cachedBoundNodeSelVector.getMutableBuffer().data(), + nodeIDVector->state->getSelVectorUnsafe().getMutableBuffer().data(), + nodeIDVector->state->getSelVector().getSelSize() * sizeof(sel_t)); + } + cachedBoundNodeSelVector.setSelSize(nodeIDVector->state->getSelVector().getSelSize()); +} + +bool RelTableScanState::hasUnCommittedData() const { + return localTableScanState && localTableScanState->localRelTable; +} + +void RelTableScanState::initStateForCommitted(const Transaction* transaction) { + source = TableScanSource::COMMITTED; + currBoundNodeIdx = 0; + nodeGroup->initializeScanState(transaction, *this); +} + +void RelTableScanState::initStateForUncommitted() { + KU_ASSERT(localTableScanState); + source = TableScanSource::UNCOMMITTED; + currBoundNodeIdx = 0; + localTableScanState->localRelTable->initializeScan(*this); +} + +bool RelTableScanState::scanNext(Transaction* transaction) { + while (true) { + switch (source) { + case TableScanSource::COMMITTED: { + const auto scanResult = nodeGroup->scan(transaction, *this); + if (scanResult == NODE_GROUP_SCAN_EMPTY_RESULT) { + if (hasUnCommittedData()) { + initStateForUncommitted(); + } else { + source = TableScanSource::NONE; + } + continue; + } + return true; + } + case TableScanSource::UNCOMMITTED: { + KU_ASSERT(localTableScanState && localTableScanState->localRelTable); + return localTableScanState->localRelTable->scan(transaction, *this); + } + case TableScanSource::NONE: { + return false; + } + default: { + KU_UNREACHABLE; + } + } + } +} + +void RelTableScanState::setNodeIDVectorToFlat(sel_t selPos) const { + nodeIDVector->state->setToFlat(); + nodeIDVector->state->getSelVectorUnsafe().setToFiltered(1); + nodeIDVector->state->getSelVectorUnsafe()[0] = selPos; +} + +RelTable::RelTable(RelGroupCatalogEntry* relGroupEntry, table_id_t fromTableID, + table_id_t toTableID, const StorageManager* storageManager, MemoryManager* memoryManager) + : Table{relGroupEntry, storageManager, memoryManager}, fromNodeTableID{fromTableID}, + toNodeTableID{toTableID}, nextRelOffset{0} { + auto relEntryInfo = relGroupEntry->getRelEntryInfo(fromNodeTableID, toNodeTableID); + tableID = relEntryInfo->oid; + relGroupID = relGroupEntry->getTableID(); + for (auto direction : relGroupEntry->getRelDataDirections()) { + auto nbrTableID = RelDirectionUtils::getNbrTableID(direction, fromTableID, toTableID); + directedRelData.emplace_back( + std::make_unique(storageManager->getDataFH(), memoryManager, shadowFile, + *relGroupEntry, *this, direction, nbrTableID, enableCompression)); + } +} + +void RelTable::initScanState(Transaction* transaction, TableScanState& scanState, + bool resetCachedBoundNodeSelVec) const { + auto& relScanState = scanState.cast(); + // Note there we directly read node at pos 0 here regardless the selVector is filtered or not. + // This is because we're assuming the nodeIDVector is always a sequence here. + const auto boundNodeID = relScanState.nodeIDVector->getValue( + relScanState.nodeIDVector->state->getSelVector()[0]); + NodeGroup* nodeGroup = nullptr; + // Check if the node group idx is same as previous scan. + const auto nodeGroupIdx = StorageUtils::getNodeGroupIdx(boundNodeID.offset); + if (relScanState.nodeGroupIdx != nodeGroupIdx) { + // We need to re-initialize the node group scan state. 
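+        // Fetch the node group that owns this bound node offset for the current scan direction.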
+ nodeGroup = getDirectedTableData(relScanState.direction)->getNodeGroup(nodeGroupIdx); + } else { + nodeGroup = relScanState.nodeGroup; + } + scanState.initState(transaction, nodeGroup, resetCachedBoundNodeSelVec); +} + +bool RelTable::scanInternal(Transaction* transaction, TableScanState& scanState) { + return scanState.scanNext(transaction); +} + +static void throwRelMultiplicityConstraintError(const std::string& tableName, offset_t nodeOffset, + RelDataDirection direction) { + throw RuntimeException(ExceptionMessage::violateRelMultiplicityConstraint(tableName, + std::to_string(nodeOffset), RelDirectionUtils::relDirectionToString(direction))); +} + +void RelTable::checkRelMultiplicityConstraint(Transaction* transaction, + const TableInsertState& state) const { + const auto& insertState = state.constCast(); + KU_ASSERT(insertState.srcNodeIDVector.state->getSelVector().getSelSize() == 1 && + insertState.dstNodeIDVector.state->getSelVector().getSelSize() == 1); + + for (auto& relData : directedRelData) { + if (relData->getMultiplicity() == RelMultiplicity::ONE) { + throwIfNodeHasRels(transaction, relData->getDirection(), + &insertState.getBoundNodeIDVector(relData->getDirection()), + throwRelMultiplicityConstraintError); + } + } +} + +void RelTable::insert(Transaction* transaction, TableInsertState& insertState) { + checkRelMultiplicityConstraint(transaction, insertState); + + KU_ASSERT(transaction->getLocalStorage()); + const auto localTable = transaction->getLocalStorage()->getOrCreateLocalTable(*this); + localTable->insert(transaction, insertState); + if (insertState.logToWAL && transaction->shouldLogToWAL()) { + KU_ASSERT(transaction->isWriteTransaction()); + const auto& relInsertState = insertState.cast(); + std::vector vectorsToLog; + vectorsToLog.push_back(&relInsertState.srcNodeIDVector); + vectorsToLog.push_back(&relInsertState.dstNodeIDVector); + vectorsToLog.insert(vectorsToLog.end(), relInsertState.propertyVectors.begin(), + relInsertState.propertyVectors.end()); + KU_ASSERT(relInsertState.srcNodeIDVector.state->getSelVector().getSelSize() == 1); + auto& wal = transaction->getLocalWAL(); + wal.logTableInsertion(tableID, TableType::REL, + relInsertState.srcNodeIDVector.state->getSelVector().getSelSize(), vectorsToLog); + } + hasChanges = true; +} + +void RelTable::update(Transaction* transaction, TableUpdateState& updateState) { + const auto& relUpdateState = updateState.cast(); + KU_ASSERT(relUpdateState.relIDVector.state->getSelVector().getSelSize() == 1); + const auto relIDPos = relUpdateState.relIDVector.state->getSelVector()[0]; + if (const auto relOffset = relUpdateState.relIDVector.readNodeOffset(relIDPos); + relOffset >= StorageConstants::MAX_NUM_ROWS_IN_TABLE) { + const auto localTable = transaction->getLocalStorage()->getLocalTable(tableID); + KU_ASSERT(localTable); + localTable->update(&DUMMY_TRANSACTION, updateState); + } else { + for (auto& relData : directedRelData) { + relData->update(transaction, + relUpdateState.getBoundNodeIDVector(relData->getDirection()), + relUpdateState.relIDVector, relUpdateState.columnID, relUpdateState.propertyVector); + } + } + if (updateState.logToWAL && transaction->shouldLogToWAL()) { + KU_ASSERT(transaction->isWriteTransaction()); + auto& wal = transaction->getLocalWAL(); + wal.logRelUpdate(tableID, relUpdateState.columnID, &relUpdateState.srcNodeIDVector, + &relUpdateState.dstNodeIDVector, &relUpdateState.relIDVector, + &relUpdateState.propertyVector); + } + hasChanges = true; +} + +bool RelTable::delete_(Transaction* transaction, 
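The update and delete paths above route a rel either to local (uncommitted) storage or to the committed CSR data based on its offset: offsets at or above the table's row ceiling denote rows still held locally. A stripped-down version of that routing rule, with a placeholder constant standing in for StorageConstants::MAX_NUM_ROWS_IN_TABLE:

#include <cstdint>

constexpr uint64_t kLocalOffsetBase = 1ull << 62; // placeholder ceiling, not the real constant

enum class RowHome { LocalUncommitted, CommittedCSR };

RowHome classifyRelOffset(uint64_t relOffset) {
    // Local rows are numbered from the ceiling upward, so anything at or above it
    // must be resolved through local storage rather than the on-disk CSR groups.
    return relOffset >= kLocalOffsetBase ? RowHome::LocalUncommitted : RowHome::CommittedCSR;
}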
TableDeleteState& deleteState) { + const auto& relDeleteState = deleteState.cast(); + KU_ASSERT(relDeleteState.relIDVector.state->getSelVector().getSelSize() == 1); + const auto relIDPos = relDeleteState.relIDVector.state->getSelVector()[0]; + bool isDeleted = false; + if (const auto relOffset = relDeleteState.relIDVector.readNodeOffset(relIDPos); + relOffset >= StorageConstants::MAX_NUM_ROWS_IN_TABLE) { + const auto localTable = transaction->getLocalStorage()->getLocalTable(tableID); + KU_ASSERT(localTable); + isDeleted = localTable->delete_(transaction, deleteState); + } else { + for (auto& relData : directedRelData) { + isDeleted = relData->delete_(transaction, + relDeleteState.getBoundNodeIDVector(relData->getDirection()), + relDeleteState.relIDVector); + if (!isDeleted) { + break; + } + } + } + if (isDeleted) { + hasChanges = true; + if (deleteState.logToWAL && transaction->shouldLogToWAL()) { + KU_ASSERT(transaction->isWriteTransaction()); + auto& wal = transaction->getLocalWAL(); + wal.logRelDelete(tableID, &relDeleteState.srcNodeIDVector, + &relDeleteState.dstNodeIDVector, &relDeleteState.relIDVector); + } + } + return isDeleted; +} + +void RelTable::detachDelete(Transaction* transaction, RelTableDeleteState* deleteState) { + auto direction = deleteState->detachDeleteDirection; + if (std::ranges::count(getStorageDirections(), direction) == 0) { + throw RuntimeException( + stringFormat("Cannot delete edges of direction {} from table {} as they do not exist.", + RelDirectionUtils::relDirectionToString(direction), tableName)); + } + KU_ASSERT(deleteState->srcNodeIDVector.state->getSelVector().getSelSize() == 1); + const auto tableData = getDirectedTableData(direction); + const auto reverseTableData = + directedRelData.size() == NUM_REL_DIRECTIONS ? 
+ getDirectedTableData(RelDirectionUtils::getOppositeDirection(direction)) : + nullptr; + auto relReadState = + std::make_unique(*memoryManager, &deleteState->srcNodeIDVector, + std::vector{&deleteState->dstNodeIDVector, &deleteState->relIDVector}, + deleteState->dstNodeIDVector.state, true /*randomLookup*/); + relReadState->setToTable(transaction, this, {NBR_ID_COLUMN_ID, REL_ID_COLUMN_ID}, {}, + direction); + initScanState(transaction, *relReadState); + detachDeleteForCSRRels(transaction, tableData, reverseTableData, relReadState.get(), + deleteState); + if (deleteState->logToWAL && transaction->shouldLogToWAL()) { + KU_ASSERT(transaction->isWriteTransaction()); + auto& wal = transaction->getLocalWAL(); + wal.logRelDetachDelete(tableID, direction, &deleteState->srcNodeIDVector); + } + hasChanges = true; +} + +std::vector RelTable::getStorageDirections() const { + std::vector ret; + for (const auto& relData : directedRelData) { + ret.push_back(relData->getDirection()); + } + return ret; +} + +bool RelTable::checkIfNodeHasRels(Transaction* transaction, RelDataDirection direction, + ValueVector* srcNodeIDVector) const { + bool hasRels = false; + const auto localTable = transaction->getLocalStorage()->getLocalTable(tableID); + if (localTable) { + hasRels = localTable->cast().checkIfNodeHasRels(srcNodeIDVector, direction); + } + hasRels = hasRels || + getDirectedTableData(direction)->checkIfNodeHasRels(transaction, srcNodeIDVector); + return hasRels; +} + +void RelTable::throwIfNodeHasRels(Transaction* transaction, RelDataDirection direction, + ValueVector* srcNodeIDVector, const rel_multiplicity_constraint_throw_func_t& throwFunc) const { + const auto nodeIDPos = srcNodeIDVector->state->getSelVector()[0]; + const auto nodeOffset = srcNodeIDVector->getValue(nodeIDPos).offset; + if (checkIfNodeHasRels(transaction, direction, srcNodeIDVector)) { + throwFunc(tableName, nodeOffset, direction); + } +} + +void RelTable::detachDeleteForCSRRels(Transaction* transaction, RelTableData* tableData, + RelTableData* reverseTableData, RelTableScanState* relDataReadState, + RelTableDeleteState* deleteState) { + const auto localTable = transaction->getLocalStorage()->getLocalTable(tableID); + const auto tempState = deleteState->dstNodeIDVector.state.get(); + while (scan(transaction, *relDataReadState)) { + const auto numRelsScanned = tempState->getSelVector().getSelSize(); + + // rel table data delete_() expects the input to be flat + // so we manually flatten the scanned rels here + // also if the scanned state is unfiltered we need to copy over the unfiltered values to the + // filtered buffer + // TODO(Royi/Guodong) remove this once delete_() supports unflat vectors + if (tempState->getSelVector().isUnfiltered()) { + tempState->getSelVectorUnsafe().setRange(0, numRelsScanned); + } + tempState->getSelVectorUnsafe().setToFiltered(1); + + for (auto i = 0u; i < numRelsScanned; i++) { + tempState->getSelVectorUnsafe()[0] = deleteState->relIDVector.state->getSelVector()[i]; + + const auto relIDPos = deleteState->relIDVector.state->getSelVector()[0]; + const auto relOffset = deleteState->relIDVector.readNodeOffset(relIDPos); + if (relOffset >= StorageConstants::MAX_NUM_ROWS_IN_TABLE) { + KU_ASSERT(localTable); + localTable->delete_(transaction, *deleteState); + continue; + } + [[maybe_unused]] const auto deleted = tableData->delete_(transaction, + deleteState->srcNodeIDVector, deleteState->relIDVector); + if (reverseTableData) { + [[maybe_unused]] const auto reverseDeleted = reverseTableData->delete_(transaction, 
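detachDeleteForCSRRels above manually flattens the scanned batch because the downstream delete path only accepts a flat (single-row) selection: the selection vector is forced to size one and each scanned position is processed in turn. A standalone sketch of that pattern, with a callback standing in for the per-row delete:

#include <cstdint>
#include <functional>
#include <vector>

void deleteScannedRelsFlattened(const std::vector<uint32_t>& scannedPositions,
                                const std::function<void(uint32_t)>& deleteOneAtPos) {
    for (uint32_t pos : scannedPositions) {
        // Equivalent to selVector.setToFiltered(1); selVector[0] = pos; in the code above:
        // one logical row is exposed per iteration to a consumer that expects flat input.
        deleteOneAtPos(pos);
    }
}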
+ deleteState->dstNodeIDVector, deleteState->relIDVector); + KU_ASSERT(deleted == reverseDeleted); + } + } + tempState->getSelVectorUnsafe().setToUnfiltered(); + } +} + +void RelTable::addColumn(Transaction* transaction, TableAddColumnState& addColumnState, + PageAllocator& pageAllocator) { + LocalTable* localTable = nullptr; + if (transaction->getLocalStorage()) { + localTable = transaction->getLocalStorage()->getLocalTable(tableID); + } + if (localTable) { + localTable->addColumn(addColumnState); + } + for (auto& directedRelData : directedRelData) { + directedRelData->addColumn(addColumnState, pageAllocator); + } + hasChanges = true; +} + +RelTableData* RelTable::getDirectedTableData(RelDataDirection direction) const { + const auto directionIdx = RelDirectionUtils::relDirectionToKeyIdx(direction); + if (directionIdx >= directedRelData.size()) { + throw RuntimeException(stringFormat( + "Failed to get {} data for rel table \"{}\", please set the storage direction to BOTH", + RelDirectionUtils::relDirectionToString(direction), tableName)); + } + KU_ASSERT(directedRelData[directionIdx]->getDirection() == direction); + return directedRelData[directionIdx].get(); +} + +NodeGroup* RelTable::getOrCreateNodeGroup(const Transaction* transaction, + node_group_idx_t nodeGroupIdx, RelDataDirection direction) const { + return getDirectedTableData(direction)->getOrCreateNodeGroup(transaction, nodeGroupIdx); +} + +void RelTable::pushInsertInfo(const Transaction* transaction, RelDataDirection direction, + const CSRNodeGroup& nodeGroup, row_idx_t numRows_, CSRNodeGroupScanSource source) const { + getDirectedTableData(direction)->pushInsertInfo(transaction, nodeGroup, numRows_, source); +} + +void RelTable::commit(main::ClientContext* context, TableCatalogEntry* tableEntry, + LocalTable* localTable) { + auto& localRelTable = localTable->cast(); + if (localRelTable.isEmpty()) { + localTable->clear(*MemoryManager::Get(*context)); + return; + } + // Update relID in local storage. + updateRelOffsets(localRelTable); + // For both forward and backward directions, re-org local storage into compact CSR node groups. + auto& localNodeGroup = localRelTable.getLocalNodeGroup(); + // Scan from local node group and write to WAL. + std::vector columnIDsToScan; + for (auto i = 0u; i < localRelTable.getNumColumns(); i++) { + columnIDsToScan.push_back(i); + } + + std::vector columnIDsToCommit; + columnIDsToCommit.push_back(0); // NBR column. + for (auto& property : tableEntry->getProperties()) { + auto columnID = tableEntry->getColumnID(property.getName()); + columnIDsToCommit.push_back(columnID); + } + // commit rel table data + auto transaction = transaction::Transaction::Get(*context); + for (auto& relData : directedRelData) { + const auto direction = relData->getDirection(); + const auto columnToSkip = (direction == RelDataDirection::FWD) ? 
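getDirectedTableData above maps a direction to an index into the per-direction storage and throws when the table was created with single-direction storage and the other direction is requested. A simplified stand-in over a plain vector (the real container holds owning pointers):

#include <stdexcept>
#include <string>
#include <vector>

enum class RelDir { Fwd, Bwd };

int directionToIdx(RelDir d) { return d == RelDir::Fwd ? 0 : 1; }

template <typename TableData>
TableData& getDirectedData(std::vector<TableData>& perDirection, RelDir d,
                           const std::string& tableName) {
    const int idx = directionToIdx(d);
    if (idx >= static_cast<int>(perDirection.size())) {
        // Single-direction storage: the missing direction is an error, not an empty result.
        throw std::runtime_error("rel table \"" + tableName +
                                 "\" does not store this direction; set the storage direction to BOTH");
    }
    return perDirection[idx];
}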
+ LOCAL_BOUND_NODE_ID_COLUMN_ID : + LOCAL_NBR_NODE_ID_COLUMN_ID; + for (auto& [boundNodeOffset, rowIndices] : localRelTable.getCSRIndex(direction)) { + auto [nodeGroupIdx, boundOffsetInGroup] = + StorageUtils::getQuotientRemainder(boundNodeOffset, StorageConfig::NODE_GROUP_SIZE); + auto& nodeGroup = + relData->getOrCreateNodeGroup(transaction, nodeGroupIdx)->cast(); + pushInsertInfo(transaction, direction, nodeGroup, rowIndices.size(), + CSRNodeGroupScanSource::COMMITTED_IN_MEMORY); + prepareCommitForNodeGroup(transaction, columnIDsToCommit, localNodeGroup, nodeGroup, + boundOffsetInGroup, rowIndices, columnToSkip); + } + } + + localRelTable.clear(*MemoryManager::Get(*context)); +} + +void RelTable::reclaimStorage(PageAllocator& pageAllocator) const { + for (auto& relData : directedRelData) { + relData->reclaimStorage(pageAllocator); + } +} + +void RelTable::updateRelOffsets(const LocalRelTable& localRelTable) { + auto& localNodeGroup = localRelTable.getLocalNodeGroup(); + const offset_t maxCommittedOffset = reserveRelOffsets(localNodeGroup.getNumRows()); + RUNTIME_CHECK(uint64_t totalNumRows = 0); + for (auto i = 0u; i < localNodeGroup.getNumChunkedGroups(); i++) { + const auto chunkedGroup = localNodeGroup.getChunkedNodeGroup(i); + KU_ASSERT(chunkedGroup); + auto& internalIDChunk = chunkedGroup->getColumnChunk(LOCAL_REL_ID_COLUMN_ID); + RUNTIME_CHECK(totalNumRows += internalIDChunk.getNumValues()); + for (auto rowIdx = 0u; rowIdx < internalIDChunk.getNumValues(); rowIdx++) { + const auto localRelOffset = internalIDChunk.getValue(rowIdx); + const auto committedRelOffset = getCommittedOffset(localRelOffset, maxCommittedOffset); + internalIDChunk.setValue(committedRelOffset, rowIdx); + } + + internalIDChunk.setTableID(tableID); + } + KU_ASSERT(totalNumRows == localNodeGroup.getNumRows()); +} + +offset_t RelTable::getCommittedOffset(offset_t uncommittedOffset, offset_t maxCommittedOffset) { + return uncommittedOffset - StorageConstants::MAX_NUM_ROWS_IN_TABLE + maxCommittedOffset; +} + +void RelTable::prepareCommitForNodeGroup(const Transaction* transaction, + const std::vector& columnIDs, const NodeGroup& localNodeGroup, + CSRNodeGroup& csrNodeGroup, offset_t boundOffsetInGroup, const row_idx_vec_t& rowIndices, + column_id_t skippedColumn) { + for (const auto row : rowIndices) { + auto [chunkedGroupIdx, rowInChunkedGroup] = + StorageUtils::getQuotientRemainder(row, StorageConfig::CHUNKED_NODE_GROUP_CAPACITY); + std::vector chunks; + const auto chunkedGroup = localNodeGroup.getChunkedNodeGroup(chunkedGroupIdx); + for (auto i = 0u; i < chunkedGroup->getNumColumns(); i++) { + if (i == skippedColumn) { + continue; + } + chunks.push_back(&chunkedGroup->getColumnChunk(i)); + } + csrNodeGroup.append(transaction, columnIDs, boundOffsetInGroup, chunks, rowInChunkedGroup, + 1 /*numRows*/); + } +} + +bool RelTable::checkpoint(main::ClientContext*, TableCatalogEntry* tableEntry, + PageAllocator& pageAllocator) { + bool ret = hasChanges; + if (hasChanges) { + // Deleted columns are vacuumed and not checkpointed or serialized. 
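The offset rewrite performed at commit (updateRelOffsets / getCommittedOffset above) is plain arithmetic: local rels are numbered from the row ceiling upward, so subtracting that base and adding the block of offsets reserved for this commit yields the final rel offset. Sketch with a placeholder base constant:

#include <cstdint>

constexpr uint64_t kLocalOffsetBase = 1ull << 62; // stand-in for MAX_NUM_ROWS_IN_TABLE

uint64_t toCommittedOffset(uint64_t uncommittedOffset, uint64_t maxCommittedOffset) {
    return uncommittedOffset - kLocalOffsetBase + maxCommittedOffset;
}

// Example: with 10 rels already committed (maxCommittedOffset == 10), the first local rel
// (uncommittedOffset == kLocalOffsetBase) maps to committed offset 10, the next to 11, and so on.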
+ std::vector columnIDs; + columnIDs.push_back(0); + for (auto& property : tableEntry->getProperties()) { + columnIDs.push_back(tableEntry->getColumnID(property.getName())); + } + for (auto& directedRelData : directedRelData) { + directedRelData->checkpoint(columnIDs, pageAllocator); + } + hasChanges = false; + } + return ret; +} + +row_idx_t RelTable::getNumTotalRows(const Transaction* transaction) { + auto numLocalRows = 0u; + if (auto localTable = transaction->getLocalStorage()->getLocalTable(tableID)) { + numLocalRows = localTable->getNumTotalRows(); + } + return numLocalRows + nextRelOffset; +} + +void RelTable::serialize(Serializer& ser) const { + ser.writeDebuggingInfo("next_rel_offset"); + ser.write(nextRelOffset); + for (auto& directedRelData : directedRelData) { + directedRelData->serialize(ser); + } +} + +void RelTable::deserialize(main::ClientContext*, StorageManager*, Deserializer& deSer) { + std::string key; + deSer.validateDebuggingInfo(key, "next_rel_offset"); + deSer.deserializeValue(nextRelOffset); + for (auto i = 0u; i < directedRelData.size(); i++) { + directedRelData[i]->deserialize(deSer, *memoryManager); + } +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/rel_table_data.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/rel_table_data.cpp new file mode 100644 index 0000000000..7640f6c202 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/rel_table_data.cpp @@ -0,0 +1,292 @@ +#include "storage/table/rel_table_data.h" + +#include "catalog/catalog_entry/rel_group_catalog_entry.h" +#include "common/enums/rel_direction.h" +#include "common/types/types.h" +#include "main/client_context.h" +#include "storage/storage_manager.h" +#include "storage/storage_utils.h" +#include "storage/table/node_group.h" +#include "storage/table/rel_table.h" +#include "transaction/transaction.h" + +using namespace lbug::catalog; +using namespace lbug::common; +using namespace lbug::transaction; + +namespace lbug { +namespace storage { + +PersistentVersionRecordHandler::PersistentVersionRecordHandler(RelTableData* relTableData) + : relTableData(relTableData) {} + +void PersistentVersionRecordHandler::applyFuncToChunkedGroups(version_record_handler_op_t func, + node_group_idx_t nodeGroupIdx, row_idx_t startRow, row_idx_t numRows, + transaction_t commitTS) const { + if (nodeGroupIdx < relTableData->getNumNodeGroups()) { + auto& nodeGroup = relTableData->getNodeGroupNoLock(nodeGroupIdx)->cast(); + if (auto* persistentChunkedGroup = nodeGroup.getPersistentChunkedGroup()) { + std::invoke(func, *persistentChunkedGroup, startRow, numRows, commitTS); + } + } +} + +void PersistentVersionRecordHandler::rollbackInsert(main::ClientContext* context, + node_group_idx_t nodeGroupIdx, row_idx_t startRow, row_idx_t numRows) const { + VersionRecordHandler::rollbackInsert(context, nodeGroupIdx, startRow, numRows); + relTableData->rollbackGroupCollectionInsert(numRows, true); +} + +InMemoryVersionRecordHandler::InMemoryVersionRecordHandler(RelTableData* relTableData) + : relTableData(relTableData) {} + +void InMemoryVersionRecordHandler::applyFuncToChunkedGroups(version_record_handler_op_t func, + node_group_idx_t nodeGroupIdx, row_idx_t startRow, row_idx_t numRows, + transaction_t commitTS) const { + auto* nodeGroup = relTableData->getNodeGroupNoLock(nodeGroupIdx); + nodeGroup->applyFuncToChunkedGroups(func, startRow, numRows, commitTS); +} + +void InMemoryVersionRecordHandler::rollbackInsert(main::ClientContext* context, + 
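The serialize/deserialize pair above uses tagged fields (writeDebuggingInfo / validateDebuggingInfo) so that a reader detects format drift immediately instead of silently misreading bytes. A self-contained, stream-based stand-in for that pattern, not the real Serializer API:

#include <cstdint>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>

void writeTagged(std::ostream& os, const std::string& key, uint64_t value) {
    os << key << ' ' << value << '\n';
}

uint64_t readTagged(std::istream& is, const std::string& expectedKey) {
    std::string key;
    uint64_t value = 0;
    is >> key >> value;
    if (key != expectedKey) {
        throw std::runtime_error("unexpected key: " + key);
    }
    return value;
}

int main() {
    std::stringstream buf;
    writeTagged(buf, "next_rel_offset", 42);
    std::cout << readTagged(buf, "next_rel_offset") << '\n'; // prints 42
}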
node_group_idx_t nodeGroupIdx, row_idx_t startRow, row_idx_t numRows) const { + VersionRecordHandler::rollbackInsert(context, nodeGroupIdx, startRow, numRows); + auto* nodeGroup = relTableData->getNodeGroupNoLock(nodeGroupIdx); + const auto numRowsToRollback = std::min(numRows, nodeGroup->getNumRows() - startRow); + nodeGroup->rollbackInsert(startRow); + relTableData->rollbackGroupCollectionInsert(numRowsToRollback, false); +} + +RelTableData::RelTableData(FileHandle* dataFH, MemoryManager* mm, ShadowFile* shadowFile, + const RelGroupCatalogEntry& relGroupEntry, Table& table, RelDataDirection direction, + table_id_t nbrTableID, bool enableCompression) + : table{table}, mm{mm}, shadowFile{shadowFile}, enableCompression{enableCompression}, + direction{direction}, multiplicity{relGroupEntry.getMultiplicity(direction)}, + persistentVersionRecordHandler(this), inMemoryVersionRecordHandler(this) { + initCSRHeaderColumns(dataFH); + initPropertyColumns(relGroupEntry, nbrTableID, dataFH); + // default to using the persistent version record handler + // if we want to use the in-memory handler, we will explicitly pass it into + // nodeGroups.pushInsertInfo() + nodeGroups = std::make_unique(*mm, getColumnTypes(), enableCompression, + ResidencyState::ON_DISK, &persistentVersionRecordHandler); +} + +void RelTableData::initCSRHeaderColumns(FileHandle* dataFH) { + // No NULL values is allowed for the csr length and offset column. + auto csrOffsetColumnName = StorageUtils::getColumnName("", StorageUtils::ColumnType::CSR_OFFSET, + RelDirectionUtils::relDirectionToString(direction)); + csrHeaderColumns.offset = std::make_unique(csrOffsetColumnName, LogicalType::UINT64(), + dataFH, mm, shadowFile, enableCompression, false /* requireNullColumn */); + auto csrLengthColumnName = StorageUtils::getColumnName("", StorageUtils::ColumnType::CSR_LENGTH, + RelDirectionUtils::relDirectionToString(direction)); + csrHeaderColumns.length = std::make_unique(csrLengthColumnName, LogicalType::UINT64(), + dataFH, mm, shadowFile, enableCompression, false /* requireNullColumn */); +} + +void RelTableData::initPropertyColumns(const RelGroupCatalogEntry& relGroupEntry, + table_id_t nbrTableID, FileHandle* dataFH) { + const auto maxColumnID = relGroupEntry.getMaxColumnID(); + columns.resize(maxColumnID + 1); + auto nbrIDColName = StorageUtils::getColumnName("NBR_ID", StorageUtils::ColumnType::DEFAULT, + RelDirectionUtils::relDirectionToString(direction)); + auto nbrIDColumn = + std::make_unique(nbrIDColName, dataFH, mm, shadowFile, enableCompression); + columns[NBR_ID_COLUMN_ID] = std::move(nbrIDColumn); + for (auto& property : relGroupEntry.getProperties()) { + const auto columnID = relGroupEntry.getColumnID(property.getName()); + const auto colName = StorageUtils::getColumnName(property.getName(), + StorageUtils::ColumnType::DEFAULT, RelDirectionUtils::relDirectionToString(direction)); + columns[columnID] = ColumnFactory::createColumn(colName, property.getType().copy(), dataFH, + mm, shadowFile, enableCompression); + } + // Set common tableID for nbrIDColumn and relIDColumn. 
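The two CSR header columns created in initCSRHeaderColumns describe, per bound node, where that node's packed neighbor rows live and how many there are. The standalone sketch below assumes the offset marks the start of each node's slice and the length gives its size; the exact convention in the vendored code is not fully visible here, so treat this as illustrative only.

#include <cstdint>
#include <vector>

struct CsrSketch {
    std::vector<uint64_t> csrOffset;      // one entry per bound node
    std::vector<uint64_t> csrLength;      // one entry per bound node
    std::vector<uint64_t> nbrNodeOffsets; // packed neighbor rows for the whole node group

    std::vector<uint64_t> neighborsOf(uint64_t boundNode) const {
        std::vector<uint64_t> result;
        const uint64_t start = csrOffset[boundNode];
        const uint64_t len = csrLength[boundNode];
        for (uint64_t i = 0; i < len; i++) {
            result.push_back(nbrNodeOffsets[start + i]); // the node's slice of the packed rows
        }
        return result;
    }
};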
+ columns[NBR_ID_COLUMN_ID]->cast().setCommonTableID(nbrTableID); + columns[REL_ID_COLUMN_ID]->cast().setCommonTableID(table.getTableID()); +} + +bool RelTableData::update(Transaction* transaction, ValueVector& boundNodeIDVector, + const ValueVector& relIDVector, column_id_t columnID, const ValueVector& dataVector) const { + KU_ASSERT(boundNodeIDVector.state->getSelVector().getSelSize() == 1); + KU_ASSERT(relIDVector.state->getSelVector().getSelSize() == 1); + const auto boundNodePos = boundNodeIDVector.state->getSelVector()[0]; + const auto relIDPos = relIDVector.state->getSelVector()[0]; + if (boundNodeIDVector.isNull(boundNodePos) || relIDVector.isNull(relIDPos)) { + return false; + } + const auto [source, rowIdx] = findMatchingRow(transaction, boundNodeIDVector, relIDVector); + KU_ASSERT(rowIdx != INVALID_ROW_IDX); + const auto boundNodeOffset = boundNodeIDVector.getValue(boundNodePos).offset; + const auto nodeGroupIdx = StorageUtils::getNodeGroupIdx(boundNodeOffset); + auto& csrNodeGroup = getNodeGroup(nodeGroupIdx)->cast(); + csrNodeGroup.update(transaction, source, rowIdx, columnID, dataVector); + return true; +} + +// NOLINTNEXTLINE(readability-make-member-function-const): Semantically non-const. +bool RelTableData::delete_(Transaction* transaction, ValueVector& boundNodeIDVector, + const ValueVector& relIDVector) { + const auto boundNodePos = boundNodeIDVector.state->getSelVector()[0]; + const auto relIDPos = relIDVector.state->getSelVector()[0]; + if (boundNodeIDVector.isNull(boundNodePos) || relIDVector.isNull(relIDPos)) { + return false; + } + const auto [source, rowIdx] = findMatchingRow(transaction, boundNodeIDVector, relIDVector); + if (rowIdx == INVALID_ROW_IDX) { + return false; + } + const auto boundNodeOffset = boundNodeIDVector.getValue(boundNodePos).offset; + const auto nodeGroupIdx = StorageUtils::getNodeGroupIdx(boundNodeOffset); + auto& csrNodeGroup = getNodeGroup(nodeGroupIdx)->cast(); + bool isDeleted = csrNodeGroup.delete_(transaction, source, rowIdx); + if (isDeleted && transaction->shouldAppendToUndoBuffer()) { + transaction->pushDeleteInfo(nodeGroupIdx, rowIdx, 1, getVersionRecordHandler(source)); + } + return isDeleted; +} + +void RelTableData::addColumn(TableAddColumnState& addColumnState, PageAllocator& pageAllocator) { + auto& definition = addColumnState.propertyDefinition; + columns.push_back(ColumnFactory::createColumn(definition.getName(), definition.getType().copy(), + pageAllocator.getDataFH(), mm, shadowFile, enableCompression)); + nodeGroups->addColumn(addColumnState, &pageAllocator); +} + +std::pair RelTableData::findMatchingRow(Transaction* transaction, + ValueVector& boundNodeIDVector, const ValueVector& relIDVector) const { + KU_ASSERT(boundNodeIDVector.state->getSelVector().getSelSize() == 1); + KU_ASSERT(relIDVector.state->getSelVector().getSelSize() == 1); + const auto boundNodePos = boundNodeIDVector.state->getSelVector()[0]; + const auto relIDPos = relIDVector.state->getSelVector()[0]; + const auto boundNodeOffset = boundNodeIDVector.getValue(boundNodePos).offset; + const auto relOffset = relIDVector.getValue(relIDPos).offset; + const auto nodeGroupIdx = StorageUtils::getNodeGroupIdx(boundNodeOffset); + + DataChunk scanChunk(1); + // RelID output vector. 
+ scanChunk.insert(0, std::make_shared(LogicalType::INTERNAL_ID())); + std::vector columnIDs = {REL_ID_COLUMN_ID, ROW_IDX_COLUMN_ID}; + std::vector columns{getColumn(REL_ID_COLUMN_ID), nullptr}; + auto scanState = std::make_unique(*mm, &boundNodeIDVector, + std::vector{&scanChunk.getValueVectorMutable(0)}, scanChunk.state, true /*randomLookup*/); + scanState->setToTable(transaction, &table, columnIDs, {}, direction); + scanState->initState(transaction, getNodeGroup(nodeGroupIdx)); + row_idx_t matchingRowIdx = INVALID_ROW_IDX; + auto source = CSRNodeGroupScanSource::NONE; + const auto scannedIDVector = scanState->outputVectors[0]; + while (true) { + const auto scanResult = scanState->nodeGroup->scan(transaction, *scanState); + if (scanResult == NODE_GROUP_SCAN_EMPTY_RESULT) { + break; + } + for (auto i = 0u; i < scanState->outState->getSelVector().getSelSize(); i++) { + const auto pos = scanState->outState->getSelVector()[i]; + if (scannedIDVector->getValue(pos).offset == relOffset) { + const auto rowIdxPos = scanState->rowIdxVector->state->getSelVector()[i]; + matchingRowIdx = scanState->rowIdxVector->getValue(rowIdxPos); + source = scanState->nodeGroupScanState->cast().source; + break; + } + } + if (matchingRowIdx != INVALID_ROW_IDX) { + break; + } + } + return {source, matchingRowIdx}; +} + +bool RelTableData::checkIfNodeHasRels(Transaction* transaction, + ValueVector* srcNodeIDVector) const { + KU_ASSERT(srcNodeIDVector->state->isFlat()); + const auto nodeIDPos = srcNodeIDVector->state->getSelVector()[0]; + const auto nodeOffset = srcNodeIDVector->getValue(nodeIDPos).offset; + const auto nodeGroupIdx = StorageUtils::getNodeGroupIdx(nodeOffset); + if (nodeGroupIdx >= getNumNodeGroups()) { + return false; + } + DataChunk scanChunk(1); + // RelID output vector. + scanChunk.insert(0, std::make_shared(LogicalType::INTERNAL_ID())); + std::vector columnIDs = {REL_ID_COLUMN_ID}; + std::vector columns{getColumn(REL_ID_COLUMN_ID)}; + auto scanState = std::make_unique(*mm, srcNodeIDVector, + std::vector{&scanChunk.getValueVectorMutable(0)}, scanChunk.state, true /*randomLookup*/); + scanState->setToTable(transaction, &table, columnIDs, {}, direction); + scanState->initState(transaction, getNodeGroup(nodeGroupIdx)); + while (true) { + const auto scanResult = scanState->nodeGroup->scan(transaction, *scanState); + if (scanResult == NODE_GROUP_SCAN_EMPTY_RESULT) { + break; + } + if (scanState->outState->getSelVector().getSelSize() > 0) { + return true; + } + } + return false; +} + +// NOLINTNEXTLINE(readability-make-member-function-const): Semantically non-const. +void RelTableData::pushInsertInfo(const Transaction* transaction, const CSRNodeGroup& nodeGroup, + row_idx_t numRows_, CSRNodeGroupScanSource source) { + // we shouldn't be appending directly to the to the persistent data + // unless we are performing batch insert and the persistent chunked group is empty + KU_ASSERT(source != CSRNodeGroupScanSource::COMMITTED_PERSISTENT || + !nodeGroup.getPersistentChunkedGroup() || + nodeGroup.getPersistentChunkedGroup()->getNumRows() == 0); + + const auto [startRow, shouldIncrementNumRows] = + (source == CSRNodeGroupScanSource::COMMITTED_PERSISTENT) ? 
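The matching loop above scans the bound node's rel list in batches and compares each scanned rel ID against the target offset, returning the first matching row index or an invalid sentinel. The same logic stripped down to plain vectors, with a placeholder sentinel:

#include <cstddef>
#include <cstdint>
#include <vector>

constexpr uint64_t kInvalidRowIdx = UINT64_MAX; // stand-in for INVALID_ROW_IDX

uint64_t findMatchingRowSketch(const std::vector<uint64_t>& scannedRelIDs,
                               const std::vector<uint64_t>& scannedRowIndices,
                               uint64_t targetRelOffset) {
    for (std::size_t i = 0; i < scannedRelIDs.size(); i++) {
        if (scannedRelIDs[i] == targetRelOffset) {
            return scannedRowIndices[i]; // first hit wins, mirroring the early break above
        }
    }
    return kInvalidRowIdx;
}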
+ std::make_pair(static_cast(0), false) : + std::make_pair(nodeGroup.getNumRows(), true); + + nodeGroups->pushInsertInfo(transaction, nodeGroup.getNodeGroupIdx(), startRow, numRows_, + getVersionRecordHandler(source), shouldIncrementNumRows); +} + +void RelTableData::checkpoint(const std::vector& columnIDs, + PageAllocator& pageAllocator) { + std::vector> checkpointColumns; + for (auto i = 0u; i < columnIDs.size(); i++) { + const auto columnID = columnIDs[i]; + checkpointColumns.push_back(std::move(columns[columnID])); + } + columns = std::move(checkpointColumns); + + std::vector checkpointColumnPtrs; + for (const auto& column : columns) { + checkpointColumnPtrs.push_back(column.get()); + } + + CSRNodeGroupCheckpointState state{columnIDs, std::move(checkpointColumnPtrs), pageAllocator, mm, + csrHeaderColumns.offset.get(), csrHeaderColumns.length.get()}; + nodeGroups->checkpoint(*mm, state); +} + +void RelTableData::serialize(Serializer& serializer) const { + nodeGroups->serialize(serializer); +} + +void RelTableData::deserialize(Deserializer& deSerializer, MemoryManager& memoryManager) { + nodeGroups->deserialize(deSerializer, memoryManager); +} + +const VersionRecordHandler* RelTableData::getVersionRecordHandler( + CSRNodeGroupScanSource source) const { + if (source == CSRNodeGroupScanSource::COMMITTED_PERSISTENT) { + return &persistentVersionRecordHandler; + } else { + KU_ASSERT(source == CSRNodeGroupScanSource::COMMITTED_IN_MEMORY); + return &inMemoryVersionRecordHandler; + } +} + +// NOLINTNEXTLINE(readability-make-member-function-const): Semantically non-const. +void RelTableData::rollbackGroupCollectionInsert(row_idx_t numRows_, bool isPersistent) { + nodeGroups->rollbackInsert(numRows_, !isPersistent); +} + +void RelTableData::reclaimStorage(PageAllocator& pageAllocator) const { + nodeGroups->reclaimStorage(pageAllocator); +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/string_chunk_data.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/string_chunk_data.cpp new file mode 100644 index 0000000000..05eeb6863c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/string_chunk_data.cpp @@ -0,0 +1,325 @@ +#include "storage/table/string_chunk_data.h" + +#include "common/data_chunk/sel_vector.h" +#include "common/serializer/deserializer.h" +#include "common/serializer/serializer.h" +#include "common/types/types.h" +#include "common/vector/value_vector.h" +#include "storage/buffer_manager/memory_manager.h" +#include "storage/table/column_chunk_data.h" +#include "storage/table/dictionary_chunk.h" +#include "storage/table/string_column.h" + +using namespace lbug::common; + +namespace lbug { +namespace storage { + +StringChunkData::StringChunkData(MemoryManager& mm, LogicalType dataType, uint64_t capacity, + bool enableCompression, ResidencyState residencyState) + : ColumnChunkData{mm, std::move(dataType), capacity, enableCompression, residencyState, + true /*hasNullData*/}, + indexColumnChunk{ColumnChunkFactory::createColumnChunkData(mm, LogicalType::UINT32(), + enableCompression, capacity, residencyState, false /*hasNullData*/)}, + dictionaryChunk{ + std::make_unique(mm, capacity, enableCompression, residencyState)}, + needFinalize{false} {} + +StringChunkData::StringChunkData(MemoryManager& mm, bool enableCompression, + const ColumnChunkMetadata& metadata) + : ColumnChunkData{mm, LogicalType::STRING(), enableCompression, metadata, true /*hasNullData*/}, + dictionaryChunk{ + std::make_unique(mm, 0, 
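StringChunkData, whose constructors begin above, stores string values in two parts: a 32-bit index per row and a dictionary of unique strings. The sketch below keeps the dictionary as a map plus a vector of strings for clarity; the real DictionaryChunk stores offsets into a packed byte buffer instead, so this is a simplified model rather than the vendored layout.

#include <cstddef>
#include <cstdint>
#include <string>
#include <string_view>
#include <unordered_map>
#include <vector>

struct StringChunkSketch {
    std::vector<uint32_t> indices;                      // the "index column chunk"
    std::vector<std::string> dictionary;                // unique strings, insertion order
    std::unordered_map<std::string, uint32_t> interned; // dedup as values are appended

    void append(std::string_view value) {
        std::string key(value);
        auto it = interned.find(key);
        uint32_t idx = 0;
        if (it == interned.end()) {
            idx = static_cast<uint32_t>(dictionary.size());
            dictionary.push_back(key);
            interned.emplace(std::move(key), idx);
        } else {
            idx = it->second; // existing string, reuse its index
        }
        indices.push_back(idx);
    }

    const std::string& getValue(std::size_t pos) const { return dictionary[indices[pos]]; }
};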
enableCompression, ResidencyState::IN_MEMORY)}, + needFinalize{false} { + // create index chunk + indexColumnChunk = ColumnChunkFactory::createColumnChunkData(mm, LogicalType::UINT32(), + enableCompression, capacity, ResidencyState::ON_DISK); +} + +ColumnChunkData* StringChunkData::getIndexColumnChunk() { + return indexColumnChunk.get(); +} + +const ColumnChunkData* StringChunkData::getIndexColumnChunk() const { + return indexColumnChunk.get(); +} + +void StringChunkData::updateNumValues(size_t newValue) { + numValues = newValue; + indexColumnChunk->setNumValues(newValue); +} + +void StringChunkData::setToInMemory() { + ColumnChunkData::setToInMemory(); + indexColumnChunk->setToInMemory(); + dictionaryChunk->setToInMemory(); +} + +void StringChunkData::resize(uint64_t newCapacity) { + ColumnChunkData::resize(newCapacity); + indexColumnChunk->resize(newCapacity); +} + +void StringChunkData::resizeWithoutPreserve(uint64_t newCapacity) { + ColumnChunkData::resizeWithoutPreserve(newCapacity); + indexColumnChunk->resizeWithoutPreserve(newCapacity); +} + +void StringChunkData::resetToEmpty() { + ColumnChunkData::resetToEmpty(); + indexColumnChunk->resetToEmpty(); + dictionaryChunk->resetToEmpty(); +} + +void StringChunkData::append(ValueVector* vector, const SelectionView& selView) { + selView.forEach([&](auto pos) { + // index is stored in main chunk, data is stored in the data chunk + KU_ASSERT(vector->dataType.getPhysicalType() == PhysicalTypeID::STRING); + // index is stored in main chunk, data is stored in the data chunk + nullData->setNull(numValues, vector->isNull(pos)); + auto dstPos = numValues; + updateNumValues(numValues + 1); + if (!vector->isNull(pos)) { + auto kuString = vector->getValue(pos); + setValueFromString(kuString.getAsStringView(), dstPos); + } + }); +} + +void StringChunkData::append(const ColumnChunkData* other, offset_t startPosInOtherChunk, + uint32_t numValuesToAppend) { + const auto& otherChunk = other->cast(); + nullData->append(otherChunk.getNullData(), startPosInOtherChunk, numValuesToAppend); + switch (dataType.getLogicalTypeID()) { + case LogicalTypeID::BLOB: + case LogicalTypeID::STRING: { + appendStringColumnChunk(&otherChunk, startPosInOtherChunk, numValuesToAppend); + } break; + default: { + KU_UNREACHABLE; + } + } +} + +void StringChunkData::scan(ValueVector& output, offset_t offset, length_t length, + sel_t posInOutputVector) const { + KU_ASSERT(offset + length <= numValues && nullData); + nullData->scan(output, offset, length, posInOutputVector); + if (!nullData->noNullsGuaranteedInMem()) { + for (auto i = 0u; i < length; i++) { + if (!nullData->isNull(offset + i)) { + output.setValue(posInOutputVector + i, + getValue(offset + i)); + } + } + } else { + for (auto i = 0u; i < length; i++) { + output.setValue(posInOutputVector + i, + getValue(offset + i)); + } + } +} + +void StringChunkData::lookup(offset_t offsetInChunk, ValueVector& output, + sel_t posInOutputVector) const { + KU_ASSERT(offsetInChunk < numValues); + output.setNull(posInOutputVector, nullData->isNull(offsetInChunk)); + if (nullData->isNull(offsetInChunk)) { + return; + } + auto str = getValue(offsetInChunk); + output.setValue(posInOutputVector, str); +} + +void StringChunkData::initializeScanState(SegmentState& state, const Column* column) const { + ColumnChunkData::initializeScanState(state, column); + + auto* stringColumn = ku_dynamic_cast(column); + state.childrenStates.resize(CHILD_COLUMN_COUNT); + 
indexColumnChunk->initializeScanState(state.childrenStates[INDEX_COLUMN_CHILD_READ_STATE_IDX], + stringColumn->getIndexColumn()); + dictionaryChunk->getOffsetChunk()->initializeScanState( + state.childrenStates[OFFSET_COLUMN_CHILD_READ_STATE_IDX], + stringColumn->getDictionary().getOffsetColumn()); + dictionaryChunk->getStringDataChunk()->initializeScanState( + state.childrenStates[DATA_COLUMN_CHILD_READ_STATE_IDX], + stringColumn->getDictionary().getDataColumn()); +} + +void StringChunkData::write(const ValueVector* vector, offset_t offsetInVector, + offset_t offsetInChunk) { + KU_ASSERT(vector->dataType.getPhysicalType() == PhysicalTypeID::STRING); + if (!needFinalize && offsetInChunk < numValues) [[unlikely]] { + needFinalize = true; + } + nullData->setNull(offsetInChunk, vector->isNull(offsetInVector)); + if (offsetInChunk >= numValues) { + updateNumValues(offsetInChunk + 1); + } + if (!vector->isNull(offsetInVector)) { + auto kuStr = vector->getValue(offsetInVector); + setValueFromString(kuStr.getAsStringView(), offsetInChunk); + } +} + +void StringChunkData::write(ColumnChunkData* chunk, ColumnChunkData* dstOffsets, RelMultiplicity) { + KU_ASSERT(chunk->getDataType().getPhysicalType() == PhysicalTypeID::STRING && + dstOffsets->getDataType().getPhysicalType() == PhysicalTypeID::INTERNAL_ID && + chunk->getNumValues() == dstOffsets->getNumValues()); + auto& stringChunk = chunk->cast(); + for (auto i = 0u; i < chunk->getNumValues(); i++) { + auto offsetInChunk = dstOffsets->getValue(i); + if (!needFinalize && offsetInChunk < numValues) [[unlikely]] { + needFinalize = true; + } + bool isNull = chunk->getNullData()->isNull(i); + nullData->setNull(offsetInChunk, isNull); + if (offsetInChunk >= numValues) { + updateNumValues(offsetInChunk + 1); + } + if (!isNull) { + setValueFromString(stringChunk.getValue(i), offsetInChunk); + } + } +} + +void StringChunkData::write(const ColumnChunkData* srcChunk, offset_t srcOffsetInChunk, + offset_t dstOffsetInChunk, offset_t numValuesToCopy) { + KU_ASSERT(srcChunk->getDataType().getPhysicalType() == PhysicalTypeID::STRING); + if ((dstOffsetInChunk + numValuesToCopy) >= numValues) { + updateNumValues(dstOffsetInChunk + numValuesToCopy); + } + auto& srcStringChunk = srcChunk->cast(); + for (auto i = 0u; i < numValuesToCopy; i++) { + auto srcPos = srcOffsetInChunk + i; + auto dstPos = dstOffsetInChunk + i; + bool isNull = srcChunk->getNullData()->isNull(srcPos); + nullData->setNull(dstPos, isNull); + if (isNull) { + continue; + } + setValueFromString(srcStringChunk.getValue(srcPos), dstPos); + } +} + +void StringChunkData::appendStringColumnChunk(const StringChunkData* other, + offset_t startPosInOtherChunk, uint32_t numValuesToAppend) { + for (auto i = 0u; i < numValuesToAppend; i++) { + auto posInChunk = numValues; + auto posInOtherChunk = i + startPosInOtherChunk; + updateNumValues(numValues + 1); + if (nullData->isNull(posInChunk)) { + indexColumnChunk->setValue(0, posInChunk); + continue; + } + setValueFromString(other->getValue(posInOtherChunk), posInChunk); + } +} + +void StringChunkData::setValueFromString(std::string_view value, uint64_t pos) { + auto index = dictionaryChunk->appendString(value); + indexColumnChunk->setValue(index, pos); +} + +void StringChunkData::resetNumValuesFromMetadata() { + ColumnChunkData::resetNumValuesFromMetadata(); + indexColumnChunk->resetNumValuesFromMetadata(); + dictionaryChunk->resetNumValuesFromMetadata(); +} + +void StringChunkData::finalize() { + if (!needFinalize) { + return; + } + // Prune unused entries in 
the dictionary before we flush + // We already de-duplicate as we go, but when out of place updates occur new values will be + // appended to the end and the original values may be able to be pruned before flushing them to + // disk + auto newDictionaryChunk = std::make_unique(getMemoryManager(), numValues, + enableCompression, residencyState); + // Each index is replaced by a new one for the de-duplicated data in the new dictionary. + for (auto i = 0u; i < numValues; i++) { + if (nullData->isNull(i)) { + continue; + } + auto stringData = getValue(i); + auto index = newDictionaryChunk->appendString(stringData); + indexColumnChunk->setValue(index, i); + } + dictionaryChunk = std::move(newDictionaryChunk); +} + +void StringChunkData::flush(PageAllocator& pageAllocator) { + ColumnChunkData::flush(pageAllocator); + indexColumnChunk->flush(pageAllocator); + dictionaryChunk->flush(pageAllocator); +} + +void StringChunkData::reclaimStorage(PageAllocator& pageAllocator) { + ColumnChunkData::reclaimStorage(pageAllocator); + indexColumnChunk->reclaimStorage(pageAllocator); + dictionaryChunk->getOffsetChunk()->reclaimStorage(pageAllocator); + dictionaryChunk->getStringDataChunk()->reclaimStorage(pageAllocator); +} + +uint64_t StringChunkData::getSizeOnDisk() const { + return ColumnChunkData::getSizeOnDisk() + indexColumnChunk->getSizeOnDisk() + + dictionaryChunk->getOffsetChunk()->getSizeOnDisk() + + dictionaryChunk->getStringDataChunk()->getSizeOnDisk(); +} +uint64_t StringChunkData::getMinimumSizeOnDisk() const { + return ColumnChunkData::getMinimumSizeOnDisk() + indexColumnChunk->getMinimumSizeOnDisk() + + dictionaryChunk->getOffsetChunk()->getMinimumSizeOnDisk() + + dictionaryChunk->getStringDataChunk()->getMinimumSizeOnDisk(); +} + +uint64_t StringChunkData::getSizeOnDiskInMemoryStats() const { + return ColumnChunkData::getSizeOnDiskInMemoryStats() + + indexColumnChunk->getSizeOnDiskInMemoryStats() + + dictionaryChunk->getOffsetChunk()->getSizeOnDiskInMemoryStats() + + dictionaryChunk->getStringDataChunk()->getSizeOnDiskInMemoryStats(); +} + +uint64_t StringChunkData::getEstimatedMemoryUsage() const { + return ColumnChunkData::getEstimatedMemoryUsage() + dictionaryChunk->getEstimatedMemoryUsage(); +} + +void StringChunkData::serialize(Serializer& serializer) const { + ColumnChunkData::serialize(serializer); + serializer.writeDebuggingInfo("index_column_chunk"); + indexColumnChunk->serialize(serializer); + serializer.writeDebuggingInfo("dictionary_chunk"); + dictionaryChunk->serialize(serializer); +} + +void StringChunkData::deserialize(Deserializer& deSer, ColumnChunkData& chunkData) { + std::string key; + deSer.validateDebuggingInfo(key, "index_column_chunk"); + chunkData.cast().indexColumnChunk = + ColumnChunkData::deserialize(chunkData.getMemoryManager(), deSer); + deSer.validateDebuggingInfo(key, "dictionary_chunk"); + chunkData.cast().dictionaryChunk = + DictionaryChunk::deserialize(chunkData.getMemoryManager(), deSer); +} + +template<> +ku_string_t StringChunkData::getValue(offset_t) const { + KU_UNREACHABLE; +} + +// STRING +template<> +std::string_view StringChunkData::getValue(offset_t pos) const { + KU_ASSERT(pos < numValues); + KU_ASSERT(!nullData->isNull(pos)); + auto index = indexColumnChunk->getValue(pos); + return dictionaryChunk->getString(index); +} + +template<> +std::string StringChunkData::getValue(offset_t pos) const { + return std::string(getValue(pos)); +} + +} // namespace storage +} // namespace lbug diff --git 
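The pruning that finalize() performs above can be pictured as a rebuild-and-remap pass: after out-of-place updates, some dictionary entries are no longer referenced, so a fresh dictionary is assembled from the live rows and every index is rewritten to point into it. A standalone version over plain vectors; the real code also skips null rows, which is omitted here for brevity.

#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

void pruneDictionary(std::vector<uint32_t>& indices, std::vector<std::string>& dictionary) {
    std::vector<std::string> pruned;
    std::unordered_map<uint32_t, uint32_t> remap; // old index -> new index
    for (auto& idx : indices) {
        auto it = remap.find(idx);
        if (it == remap.end()) {
            const auto newIdx = static_cast<uint32_t>(pruned.size());
            pruned.push_back(dictionary[idx]); // first reference copies the string over
            remap.emplace(idx, newIdx);
            idx = newIdx;
        } else {
            idx = it->second; // later references reuse the remapped index
        }
    }
    dictionary = std::move(pruned); // unreferenced entries are dropped with the old dictionary
}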
a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/string_column.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/string_column.cpp new file mode 100644 index 0000000000..2bf870ccd4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/string_column.cpp @@ -0,0 +1,302 @@ +#include "storage/table/string_column.h" + +#include +#include + +#include "common/assert.h" +#include "common/cast.h" +#include "common/null_mask.h" +#include "common/types/types.h" +#include "common/vector/value_vector.h" +#include "storage/buffer_manager/memory_manager.h" +#include "storage/compression/compression.h" +#include "storage/page_allocator.h" +#include "storage/storage_utils.h" +#include "storage/table/column.h" +#include "storage/table/column_chunk.h" +#include "storage/table/null_column.h" +#include "storage/table/string_chunk_data.h" + +using namespace lbug::catalog; +using namespace lbug::common; + +namespace lbug { +namespace storage { + +using string_index_t = DictionaryChunk::string_index_t; +using string_offset_t = DictionaryChunk::string_offset_t; + +StringColumn::StringColumn(std::string name, common::LogicalType dataType, FileHandle* dataFH, + MemoryManager* mm, ShadowFile* shadowFile, bool enableCompression) + : Column{std::move(name), std::move(dataType), dataFH, mm, shadowFile, enableCompression, + true /* requireNullColumn */}, + dictionary{this->name, dataFH, mm, shadowFile, enableCompression} { + auto indexColumnName = + StorageUtils::getColumnName(this->name, StorageUtils::ColumnType::INDEX, "index"); + indexColumn = std::make_unique(indexColumnName, LogicalType::UINT32(), dataFH, mm, + shadowFile, enableCompression, false /*requireNullColumn*/); +} + +SegmentState& StringColumn::getChildState(SegmentState& state, ChildStateIndex child) { + const auto childIdx = static_cast(child); + return state.getChildState(childIdx); +} + +const SegmentState& StringColumn::getChildState(const SegmentState& state, ChildStateIndex child) { + const auto childIdx = static_cast(child); + return state.getChildState(childIdx); +} + +std::unique_ptr StringColumn::flushChunkData(const ColumnChunkData& chunkData, + PageAllocator& pageAllocator) { + auto flushedChunkData = flushNonNestedChunkData(chunkData, pageAllocator); + auto& flushedStringData = flushedChunkData->cast(); + + auto& stringChunk = chunkData.cast(); + flushedStringData.setIndexChunk( + Column::flushChunkData(*stringChunk.getIndexColumnChunk(), pageAllocator)); + auto& dictChunk = stringChunk.getDictionaryChunk(); + flushedStringData.getDictionaryChunk().setOffsetChunk( + Column::flushChunkData(*dictChunk.getOffsetChunk(), pageAllocator)); + flushedStringData.getDictionaryChunk().setStringDataChunk( + Column::flushChunkData(*dictChunk.getStringDataChunk(), pageAllocator)); + return flushedChunkData; +} + +void StringColumn::lookupInternal(const SegmentState& state, offset_t nodeOffset, + ValueVector* resultVector, uint32_t posInVector) const { + auto [nodeGroupIdx, offsetInChunk] = StorageUtils::getNodeGroupIdxAndOffsetInChunk(nodeOffset); + string_index_t index = 0; + indexColumn->scanSegment(getChildState(state, ChildStateIndex::INDEX), offsetInChunk, 1, + reinterpret_cast(&index)); + std::vector> offsetsToScan; + offsetsToScan.emplace_back(index, posInVector); + dictionary.scan(getChildState(state, ChildStateIndex::OFFSET), + getChildState(state, ChildStateIndex::DATA), offsetsToScan, resultVector, + getChildState(state, ChildStateIndex::INDEX).metadata); +} + +void StringColumn::writeSegment(ColumnChunkData& 
persistentChunk, SegmentState& state, + offset_t dstOffsetInSegment, const ColumnChunkData& data, offset_t srcOffset, + length_t numValues) const { + auto& stringPersistentChunk = persistentChunk.cast(); + numValues = std::min(numValues, data.getNumValues() - srcOffset); + auto& strChunkToWriteFrom = data.cast(); + std::vector indices; + indices.resize(numValues); + for (auto i = 0u; i < numValues; i++) { + if (strChunkToWriteFrom.getNullData()->isNull(i + srcOffset)) { + indices[i] = 0; + continue; + } + const auto strVal = strChunkToWriteFrom.getValue(i + srcOffset); + indices[i] = dictionary.append(persistentChunk.cast().getDictionaryChunk(), + state, strVal); + } + NullMask nullMask(numValues); + nullMask.copyFromNullBits(data.getNullData()->getNullMask().getData(), srcOffset, + 0 /*dstOffset*/, numValues); + // Write index to main column + indexColumn->writeValuesInternal(getChildState(state, ChildStateIndex::INDEX), + dstOffsetInSegment, reinterpret_cast(&indices[0]), &nullMask, + 0 /*srcOffset*/, numValues); + auto [min, max] = std::minmax_element(indices.begin(), indices.end()); + auto minWritten = StorageValue(*min); + auto maxWritten = StorageValue(*max); + updateStatistics(persistentChunk.getMetadata(), dstOffsetInSegment + numValues - 1, minWritten, + maxWritten); + indexColumn->updateStatistics(stringPersistentChunk.getIndexColumnChunk()->getMetadata(), + dstOffsetInSegment + numValues - 1, minWritten, maxWritten); +} + +std::vector> StringColumn::checkpointSegment( + ColumnCheckpointState&& checkpointState, PageAllocator& pageAllocator, + bool canSplitSegment) const { + auto& persistentData = checkpointState.persistentData; + auto result = + Column::checkpointSegment(std::move(checkpointState), pageAllocator, canSplitSegment); + persistentData.syncNumValues(); + return result; +} + +void StringColumn::scanSegment(const SegmentState& state, offset_t startOffsetInChunk, + row_idx_t numValuesToScan, ValueVector* resultVector, offset_t offsetInResult) const { + if (nullColumn) { + KU_ASSERT(state.nullState); + nullColumn->scanSegment(*state.nullState, startOffsetInChunk, numValuesToScan, resultVector, + offsetInResult); + } + + KU_ASSERT(resultVector->dataType.getPhysicalType() == PhysicalTypeID::STRING); + if (!resultVector->state || resultVector->state->getSelVector().isUnfiltered()) { + scanUnfiltered(state, startOffsetInChunk, numValuesToScan, resultVector, offsetInResult); + } else { + scanFiltered(state, startOffsetInChunk, resultVector, offsetInResult); + } +} + +void StringColumn::scanSegment(const SegmentState& state, ColumnChunkData* resultChunk, + common::offset_t startOffsetInSegment, common::row_idx_t numValuesToScan) const { + auto startOffsetInResult = resultChunk->getNumValues(); + Column::scanSegment(state, resultChunk, startOffsetInSegment, numValuesToScan); + KU_ASSERT(resultChunk->getDataType().getPhysicalType() == PhysicalTypeID::STRING); + + auto* stringResultChunk = ku_dynamic_cast(resultChunk); + // Revert change to numValues from Column::scanSegment (see note in list_column.cpp) + // This shouldn't be necessary in future + stringResultChunk->getIndexColumnChunk()->setNumValues(startOffsetInResult); + + auto* indexChunk = stringResultChunk->getIndexColumnChunk(); + indexColumn->scanSegment(getChildState(state, ChildStateIndex::INDEX), indexChunk, + startOffsetInSegment, numValuesToScan); + + const auto initialDictSize = + stringResultChunk->getDictionaryChunk().getOffsetChunk()->getNumValues(); + if (numValuesToScan == state.metadata.numValues) { + // 
Append the entire dictionary into the chunk + // Since the resultChunk may be non-empty, each index needs to be incremented by the initial + // size of the dictionary so that the indices line up with the values that will be scanned + // into the dictionary chunk + for (row_idx_t i = 0; i < numValuesToScan; i++) { + indexChunk->setValue( + indexChunk->getValue(startOffsetInResult + i) + initialDictSize, + startOffsetInResult + i); + } + dictionary.scan(state, stringResultChunk->getDictionaryChunk()); + } else { + // Any strings which are duplicated only need to be scanned once, so we track duplicate + // indices + std::unordered_map indexMap; + std::vector> offsetsToScan; + for (auto i = 0u; i < numValuesToScan; i++) { + if (!resultChunk->isNull(startOffsetInResult + i)) { + auto index = indexChunk->getValue(startOffsetInResult + i); + auto element = indexMap.find(index); + if (element == indexMap.end()) { + indexMap.insert(std::make_pair(index, initialDictSize + offsetsToScan.size())); + indexChunk->setValue(initialDictSize + offsetsToScan.size(), + startOffsetInResult + i); + offsetsToScan.emplace_back(index, initialDictSize + offsetsToScan.size()); + } else { + indexChunk->setValue(element->second, startOffsetInResult + i); + } + } + } + + if (offsetsToScan.size() == 0) { + // All scanned values are null + return; + } + dictionary.scan(getChildState(state, ChildStateIndex::OFFSET), + getChildState(state, ChildStateIndex::DATA), offsetsToScan, stringResultChunk, + getChildState(state, ChildStateIndex::INDEX).metadata); + } + KU_ASSERT(resultChunk->getNumValues() == startOffsetInResult + numValuesToScan && + stringResultChunk->getIndexColumnChunk()->getNumValues() == + startOffsetInResult + numValuesToScan); + RUNTIME_CHECK({ + auto dictionarySize = + stringResultChunk->getDictionaryChunk().getOffsetChunk()->getNumValues(); + auto indexSize = stringResultChunk->getIndexColumnChunk()->getNumValues(); + for (offset_t i = 0; i < indexSize; i++) { + if (!stringResultChunk->isNull(i)) { + auto stringIndex = + stringResultChunk->getIndexColumnChunk()->getValue(i); + KU_ASSERT(stringIndex < dictionarySize); + } + } + }); +} + +void StringColumn::scanUnfiltered(const SegmentState& state, offset_t startOffsetInChunk, + offset_t numValuesToRead, ValueVector* resultVector, sel_t startPosInVector) const { + // TODO: Replace indices with ValueVector to avoid maintaining `scan` interface from + // uint8_t*. 
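The whole-segment fast path described above avoids per-value dictionary lookups: when every value of a segment is scanned, the segment's entire dictionary is appended to the result chunk and each scanned index only needs to be shifted by the size the result dictionary had beforehand. A standalone sketch of that index shift:

#include <cstdint>
#include <string>
#include <vector>

void appendWholeDictionary(std::vector<uint32_t>& resultIndices,
                           std::vector<std::string>& resultDict,
                           const std::vector<uint32_t>& segmentIndices,
                           const std::vector<std::string>& segmentDict) {
    const auto initialDictSize = static_cast<uint32_t>(resultDict.size());
    // Shift every index so it lands past the strings already in the result dictionary.
    for (uint32_t idx : segmentIndices) {
        resultIndices.push_back(idx + initialDictSize);
    }
    resultDict.insert(resultDict.end(), segmentDict.begin(), segmentDict.end());
}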
+ auto indices = std::make_unique(numValuesToRead); + indexColumn->scanSegment(getChildState(state, ChildStateIndex::INDEX), startOffsetInChunk, + numValuesToRead, reinterpret_cast(indices.get())); + + std::vector> offsetsToScan; + for (auto i = 0u; i < numValuesToRead; i++) { + if (!resultVector->isNull(startPosInVector + i)) { + offsetsToScan.emplace_back(indices[i], startPosInVector + i); + } + } + + if (offsetsToScan.size() == 0) { + // All scanned values are null + return; + } + dictionary.scan(getChildState(state, ChildStateIndex::OFFSET), + getChildState(state, ChildStateIndex::DATA), offsetsToScan, resultVector, + getChildState(state, ChildStateIndex::INDEX).metadata); +} + +void StringColumn::scanFiltered(const SegmentState& state, offset_t startOffsetInChunk, + ValueVector* resultVector, offset_t offsetInResult) const { + std::vector> offsetsToScan; + for (sel_t i = 0; i < resultVector->state->getSelVector().getSelSize(); i++) { + const auto pos = resultVector->state->getSelVector()[i]; + if (pos >= offsetInResult && startOffsetInChunk + pos < state.metadata.numValues && + !resultVector->isNull(pos)) { + // TODO(bmwinger): optimize index scans by grouping them when adjacent + const auto offsetInGroup = startOffsetInChunk + pos - offsetInResult; + string_index_t index = 0; + indexColumn->scanSegment(getChildState(state, ChildStateIndex::INDEX), offsetInGroup, 1, + reinterpret_cast(&index)); + offsetsToScan.emplace_back(index, pos); + } + } + + if (offsetsToScan.size() == 0) { + // All scanned values are null + return; + } + dictionary.scan(getChildState(state, ChildStateIndex::OFFSET), + getChildState(state, ChildStateIndex::DATA), offsetsToScan, resultVector, + getChildState(state, ChildStateIndex::INDEX).metadata); +} + +bool StringColumn::canCheckpointInPlace(const SegmentState& state, + const ColumnCheckpointState& checkpointState) const { + row_idx_t strLenToAdd = 0u; + idx_t numStrings = 0u; + for (auto& segmentCheckpointState : checkpointState.segmentCheckpointStates) { + auto& strChunk = segmentCheckpointState.chunkData.cast(); + numStrings += segmentCheckpointState.numRows; + for (auto i = 0u; i < segmentCheckpointState.numRows; i++) { + if (strChunk.getNullData()->isNull(segmentCheckpointState.startRowInData + i)) { + continue; + } + strLenToAdd += strChunk.getStringLength(segmentCheckpointState.startRowInData + i); + } + } + if (!dictionary.canCommitInPlace(state, numStrings, strLenToAdd)) { + return false; + } + return canIndexCommitInPlace(state, numStrings, checkpointState.endRowIdxToWrite); +} + +bool StringColumn::canIndexCommitInPlace(const SegmentState& state, uint64_t numStrings, + offset_t maxOffset) const { + const SegmentState& indexState = getChildState(state, ChildStateIndex::INDEX); + if (indexColumn->isEndOffsetOutOfPagesCapacity(indexState.metadata, maxOffset)) { + return false; + } + if (indexState.metadata.compMeta.canAlwaysUpdateInPlace()) { + return true; + } + const auto totalStringsAfterUpdate = + getChildState(state, ChildStateIndex::OFFSET).metadata.numValues + numStrings; + InPlaceUpdateLocalState localUpdateState{}; + // Check if the index column can store the largest new index in-place + if (!indexState.metadata.compMeta.canUpdateInPlace( + reinterpret_cast(&totalStringsAfterUpdate), 0 /*pos*/, 1 /*numValues*/, + PhysicalTypeID::UINT32, localUpdateState)) { + return false; + } + return true; +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/struct_chunk_data.cpp 
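canCheckpointInPlace above sizes up the incoming segments (how many new strings, how many bytes in total) and stays in place only if both the dictionary and the index column can absorb that growth. The sketch below uses a deliberately simplified capacity model with hypothetical parameters; the real checks go through the dictionary's and index column's own compression metadata.

#include <cstdint>
#include <string>
#include <vector>

bool canCheckpointStringsInPlace(const std::vector<std::string>& newStrings,
                                 uint64_t dictionaryBytesFree, uint64_t dictionaryEntriesFree,
                                 uint64_t maxIndexRepresentable, uint64_t currentDictSize) {
    uint64_t bytesToAdd = 0;
    for (const auto& s : newStrings) {
        bytesToAdd += s.size();
    }
    if (bytesToAdd > dictionaryBytesFree) {
        return false; // string data would overflow the dictionary's data pages
    }
    if (newStrings.size() > dictionaryEntriesFree) {
        return false; // offset column cannot hold that many new entries
    }
    // The largest index written afterwards must still fit the index column's encoding.
    return currentDictSize + newStrings.size() <= maxIndexRepresentable;
}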
b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/struct_chunk_data.cpp new file mode 100644 index 0000000000..118977695c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/struct_chunk_data.cpp @@ -0,0 +1,266 @@ +#include "storage/table/struct_chunk_data.h" + +#include "common/data_chunk/sel_vector.h" +#include "common/serializer/deserializer.h" +#include "common/serializer/serializer.h" +#include "common/types/types.h" +#include "common/vector/value_vector.h" +#include "storage/buffer_manager/memory_manager.h" +#include "storage/table/column_chunk_data.h" +#include "storage/table/struct_column.h" + +using namespace lbug::common; + +namespace lbug { +namespace storage { + +StructChunkData::StructChunkData(MemoryManager& mm, LogicalType dataType, uint64_t capacity, + bool enableCompression, ResidencyState residencyState) + : ColumnChunkData{mm, std::move(dataType), capacity, enableCompression, residencyState, + true /*hasNullData*/} { + const auto fieldTypes = StructType::getFieldTypes(this->dataType); + childChunks.resize(fieldTypes.size()); + for (auto i = 0u; i < fieldTypes.size(); i++) { + childChunks[i] = ColumnChunkFactory::createColumnChunkData(mm, fieldTypes[i]->copy(), + enableCompression, capacity, residencyState); + } +} + +StructChunkData::StructChunkData(MemoryManager& mm, LogicalType dataType, bool enableCompression, + const ColumnChunkMetadata& metadata) + : ColumnChunkData{mm, std::move(dataType), enableCompression, metadata, true /*hasNullData*/} { + const auto fieldTypes = StructType::getFieldTypes(this->dataType); + childChunks.resize(fieldTypes.size()); + for (auto i = 0u; i < fieldTypes.size(); i++) { + childChunks[i] = ColumnChunkFactory::createColumnChunkData(mm, fieldTypes[i]->copy(), + enableCompression, 0, ResidencyState::IN_MEMORY); + } +} + +void StructChunkData::finalize() { + for (const auto& childChunk : childChunks) { + childChunk->finalize(); + } +} + +uint64_t StructChunkData::getEstimatedMemoryUsage() const { + auto estimatedMemoryUsage = ColumnChunkData::getEstimatedMemoryUsage(); + for (auto& childChunk : childChunks) { + estimatedMemoryUsage += childChunk->getEstimatedMemoryUsage(); + } + return estimatedMemoryUsage; +} + +void StructChunkData::resetNumValuesFromMetadata() { + ColumnChunkData::resetNumValuesFromMetadata(); + for (const auto& childChunk : childChunks) { + childChunk->resetNumValuesFromMetadata(); + } +} + +void StructChunkData::resetToAllNull() { + ColumnChunkData::resetToAllNull(); + for (const auto& childChunk : childChunks) { + childChunk->resetToAllNull(); + } +} + +void StructChunkData::serialize(Serializer& serializer) const { + ColumnChunkData::serialize(serializer); + serializer.writeDebuggingInfo("struct_children"); + serializer.serializeVectorOfPtrs(childChunks); +} + +void StructChunkData::deserialize(Deserializer& deSer, ColumnChunkData& chunkData) { + std::string key; + deSer.validateDebuggingInfo(key, "struct_children"); + deSer.deserializeVectorOfPtrs(chunkData.cast().childChunks, + [&](Deserializer& deser) { + return ColumnChunkData::deserialize(chunkData.getMemoryManager(), deser); + }); +} + +void StructChunkData::flush(PageAllocator& pageAllocator) { + ColumnChunkData::flush(pageAllocator); + for (const auto& childChunk : childChunks) { + childChunk->flush(pageAllocator); + } +} + +void StructChunkData::reclaimStorage(PageAllocator& pageAllocator) { + ColumnChunkData::reclaimStorage(pageAllocator); + for (const auto& childChunk : childChunks) { + childChunk->reclaimStorage(pageAllocator); + } +} + 
+uint64_t StructChunkData::getSizeOnDisk() const { + uint64_t size = ColumnChunkData::getSizeOnDisk(); + for (const auto& childChunk : childChunks) { + size += childChunk->getSizeOnDisk(); + } + return size; +} + +uint64_t StructChunkData::getMinimumSizeOnDisk() const { + uint64_t size = ColumnChunkData::getMinimumSizeOnDisk(); + for (const auto& childChunk : childChunks) { + size += childChunk->getMinimumSizeOnDisk(); + } + return size; +} + +uint64_t StructChunkData::getSizeOnDiskInMemoryStats() const { + uint64_t size = ColumnChunkData::getSizeOnDiskInMemoryStats(); + for (const auto& childChunk : childChunks) { + size += childChunk->getSizeOnDiskInMemoryStats(); + } + return size; +} + +void StructChunkData::append(const ColumnChunkData* other, offset_t startPosInOtherChunk, + uint32_t numValuesToAppend) { + KU_ASSERT(other->getDataType().getPhysicalType() == PhysicalTypeID::STRUCT); + const auto& otherStructChunk = other->cast(); + KU_ASSERT(childChunks.size() == otherStructChunk.childChunks.size()); + nullData->append(other->getNullData(), startPosInOtherChunk, numValuesToAppend); + for (auto i = 0u; i < childChunks.size(); i++) { + childChunks[i]->append(otherStructChunk.childChunks[i].get(), startPosInOtherChunk, + numValuesToAppend); + } + numValues += numValuesToAppend; +} + +void StructChunkData::append(ValueVector* vector, const SelectionView& selView) { + const auto numFields = StructType::getNumFields(dataType); + for (auto i = 0u; i < numFields; i++) { + childChunks[i]->append(StructVector::getFieldVector(vector, i).get(), selView); + } + for (auto i = 0u; i < selView.getSelSize(); i++) { + nullData->setNull(numValues + i, vector->isNull(selView[i])); + } + numValues += selView.getSelSize(); +} + +void StructChunkData::scan(ValueVector& output, offset_t offset, length_t length, + sel_t posInOutputVector) const { + KU_ASSERT(offset + length <= numValues); + if (nullData) { + nullData->scan(output, offset, length, posInOutputVector); + } + const auto numFields = StructType::getNumFields(dataType); + for (auto i = 0u; i < numFields; i++) { + childChunks[i]->scan(*StructVector::getFieldVector(&output, i), offset, length, + posInOutputVector); + } +} + +void StructChunkData::lookup(offset_t offsetInChunk, ValueVector& output, + sel_t posInOutputVector) const { + KU_ASSERT(offsetInChunk < numValues); + const auto numFields = StructType::getNumFields(dataType); + output.setNull(posInOutputVector, nullData->isNull(offsetInChunk)); + for (auto i = 0u; i < numFields; i++) { + childChunks[i]->lookup(offsetInChunk, *StructVector::getFieldVector(&output, i).get(), + posInOutputVector); + } +} + +void StructChunkData::initializeScanState(SegmentState& state, const Column* column) const { + ColumnChunkData::initializeScanState(state, column); + + auto* structColumn = ku_dynamic_cast(column); + state.childrenStates.resize(childChunks.size()); + for (auto i = 0u; i < childChunks.size(); i++) { + childChunks[i]->initializeScanState(state.childrenStates[i], structColumn->getChild(i)); + } +} + +void StructChunkData::setToInMemory() { + ColumnChunkData::setToInMemory(); + for (const auto& child : childChunks) { + child->setToInMemory(); + } +} + +void StructChunkData::resize(uint64_t newCapacity) { + ColumnChunkData::resize(newCapacity); + capacity = newCapacity; + for (const auto& child : childChunks) { + child->resize(newCapacity); + } +} + +void StructChunkData::resizeWithoutPreserve(uint64_t newCapacity) { + ColumnChunkData::resizeWithoutPreserve(newCapacity); + capacity = newCapacity; + 
for (const auto& child : childChunks) { + child->resizeWithoutPreserve(newCapacity); + } +} + +void StructChunkData::resetToEmpty() { + ColumnChunkData::resetToEmpty(); + for (const auto& child : childChunks) { + child->resetToEmpty(); + } +} + +void StructChunkData::write(const ValueVector* vector, offset_t offsetInVector, + offset_t offsetInChunk) { + KU_ASSERT(vector->dataType.getPhysicalType() == PhysicalTypeID::STRUCT); + nullData->setNull(offsetInChunk, vector->isNull(offsetInVector)); + const auto fields = StructVector::getFieldVectors(vector); + for (auto i = 0u; i < fields.size(); i++) { + childChunks[i]->write(fields[i].get(), offsetInVector, offsetInChunk); + } + if (offsetInChunk >= numValues) { + numValues = offsetInChunk + 1; + } +} + +void StructChunkData::write(ColumnChunkData* chunk, ColumnChunkData* dstOffsets, + RelMultiplicity multiplicity) { + KU_ASSERT(chunk->getDataType().getPhysicalType() == PhysicalTypeID::STRUCT && + dstOffsets->getDataType().getPhysicalType() == PhysicalTypeID::INTERNAL_ID); + for (auto i = 0u; i < dstOffsets->getNumValues(); i++) { + const auto offsetInChunk = dstOffsets->getValue(i); + KU_ASSERT(offsetInChunk < capacity); + nullData->setNull(offsetInChunk, chunk->getNullData()->isNull(i)); + numValues = offsetInChunk >= numValues ? offsetInChunk + 1 : numValues; + } + auto& structChunk = chunk->cast(); + for (auto i = 0u; i < childChunks.size(); i++) { + childChunks[i]->write(structChunk.getChild(i), dstOffsets, multiplicity); + } +} + +void StructChunkData::write(const ColumnChunkData* srcChunk, offset_t srcOffsetInChunk, + offset_t dstOffsetInChunk, offset_t numValuesToCopy) { + KU_ASSERT(srcChunk->getDataType().getPhysicalType() == PhysicalTypeID::STRUCT); + const auto& srcStructChunk = srcChunk->cast(); + KU_ASSERT(childChunks.size() == srcStructChunk.childChunks.size()); + nullData->write(srcChunk->getNullData(), srcOffsetInChunk, dstOffsetInChunk, numValuesToCopy); + if ((dstOffsetInChunk + numValuesToCopy) >= numValues) { + numValues = dstOffsetInChunk + numValuesToCopy; + } + for (auto i = 0u; i < childChunks.size(); i++) { + childChunks[i]->write(srcStructChunk.childChunks[i].get(), srcOffsetInChunk, + dstOffsetInChunk, numValuesToCopy); + } +} + +bool StructChunkData::numValuesSanityCheck() const { + for (auto& child : childChunks) { + if (child->getNumValues() != numValues) { + return false; + } + if (!child->numValuesSanityCheck()) { + return false; + } + } + return nullData->getNumValues() == numValues; +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/struct_column.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/struct_column.cpp new file mode 100644 index 0000000000..ab52006474 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/struct_column.cpp @@ -0,0 +1,148 @@ +#include "storage/table/struct_column.h" + +#include "common/types/types.h" +#include "common/vector/value_vector.h" +#include "storage/buffer_manager/memory_manager.h" +#include "storage/storage_utils.h" +#include "storage/table/column_chunk.h" +#include "storage/table/null_column.h" +#include "storage/table/struct_chunk_data.h" + +using namespace lbug::catalog; +using namespace lbug::common; +using namespace lbug::transaction; + +namespace lbug { +namespace storage { + +StructColumn::StructColumn(std::string name, LogicalType dataType, FileHandle* dataFH, + MemoryManager* mm, ShadowFile* shadowFile, bool enableCompression) + : Column{std::move(name), std::move(dataType), dataFH, 
mm, shadowFile, enableCompression, + true /* requireNullColumn */} { + const auto fieldTypes = StructType::getFieldTypes(this->dataType); + childColumns.resize(fieldTypes.size()); + for (auto i = 0u; i < fieldTypes.size(); i++) { + const auto childColName = StorageUtils::getColumnName(this->name, + StorageUtils::ColumnType::STRUCT_CHILD, std::to_string(i)); + childColumns[i] = ColumnFactory::createColumn(childColName, fieldTypes[i]->copy(), dataFH, + mm, shadowFile, enableCompression); + } +} + +std::unique_ptr StructColumn::flushChunkData(const ColumnChunkData& chunk, + PageAllocator& pageAllocator) { + auto flushedChunk = flushNonNestedChunkData(chunk, pageAllocator); + auto& structChunk = chunk.cast(); + auto& flushedStructChunk = flushedChunk->cast(); + for (auto i = 0u; i < structChunk.getNumChildren(); i++) { + auto flushedChildChunk = Column::flushChunkData(structChunk.getChild(i), pageAllocator); + flushedStructChunk.setChild(i, std::move(flushedChildChunk)); + } + return flushedChunk; +} + +void StructColumn::scanSegment(const SegmentState& state, ColumnChunkData* resultChunk, + common::offset_t startOffsetInSegment, common::row_idx_t numValuesToScan) const { + KU_ASSERT(resultChunk->getDataType().getPhysicalType() == PhysicalTypeID::STRUCT); + // Fix size since Column::scanSegment will adjust the size of the child chunks to be equal to + // the size of the main one (see note in list_column.cpp) + // TODO(bmwinger): eventually this shouldn't be necessary + auto sizeBeforeScan = resultChunk->getNumValues(); + Column::scanSegment(state, resultChunk, startOffsetInSegment, numValuesToScan); + auto& structColumnChunk = resultChunk->cast(); + for (auto i = 0u; i < childColumns.size(); i++) { + structColumnChunk.getChild(i)->setNumValues(sizeBeforeScan); + childColumns[i]->scanSegment(state.childrenStates[i], structColumnChunk.getChild(i), + startOffsetInSegment, numValuesToScan); + } +} + +void StructColumn::scanSegment(const SegmentState& state, offset_t startOffsetInSegment, + row_idx_t numValuesToScan, ValueVector* resultVector, offset_t offsetInResult) const { + Column::scanSegment(state, startOffsetInSegment, numValuesToScan, resultVector, offsetInResult); + for (auto i = 0u; i < childColumns.size(); i++) { + const auto fieldVector = StructVector::getFieldVector(resultVector, i).get(); + childColumns[i]->scanSegment(state.childrenStates[i], startOffsetInSegment, numValuesToScan, + fieldVector, offsetInResult); + } +} + +void StructColumn::lookupInternal(const SegmentState& state, offset_t offsetInSegment, + ValueVector* resultVector, uint32_t posInVector) const { + for (auto i = 0u; i < childColumns.size(); i++) { + const auto fieldVector = StructVector::getFieldVector(resultVector, i).get(); + childColumns[i]->lookupInternal(state.childrenStates[i], offsetInSegment, fieldVector, + posInVector); + } +} + +void StructColumn::writeSegment(ColumnChunkData& persistentChunk, SegmentState& state, + offset_t offsetInSegment, const ColumnChunkData& data, offset_t dataOffset, + length_t numValues) const { + KU_ASSERT(data.getDataType().getPhysicalType() == PhysicalTypeID::STRUCT); + nullColumn->writeSegment(*persistentChunk.getNullData(), *state.nullState, offsetInSegment, + *data.getNullData(), dataOffset, numValues); + auto& structData = data.cast(); + auto& persistentStructChunk = persistentChunk.cast(); + for (auto i = 0u; i < childColumns.size(); i++) { + const auto& childData = structData.getChild(i); + childColumns[i]->writeSegment(*persistentStructChunk.getChild(i), 
state.childrenStates[i], + offsetInSegment, childData, dataOffset, numValues); + } +} + +std::vector> StructColumn::checkpointSegment( + ColumnCheckpointState&& checkpointState, PageAllocator& pageAllocator, + bool canSplitSegment) const { + auto& persistentStructChunk = checkpointState.persistentData.cast(); + // TODO(bmwinger): child columns are now handled as a group so they get split together + // Re-introduce the code below when struct columns checkpoing each field individually again + /* + for (auto i = 0u; i < childColumns.size(); i++) { + std::vector childSegmentCheckpointStates; + for (const auto& segmentCheckpointState : checkpointState.segmentCheckpointStates) { + childSegmentCheckpointStates.emplace_back( + segmentCheckpointState.chunkData.cast().getChild(i), + segmentCheckpointState.startRowInData, segmentCheckpointState.offsetInSegment, + segmentCheckpointState.numRows); + } + childColumns[i]->checkpointSegment(ColumnCheckpointState(*persistentStructChunk.getChild(i), + std::move(childSegmentCheckpointStates)), pageAllocator); + } + Column::checkpointNullData(checkpointState, pageAllocator); + */ + auto result = + Column::checkpointSegment(std::move(checkpointState), pageAllocator, canSplitSegment); + persistentStructChunk.syncNumValues(); + return result; +} + +bool StructColumn::canCheckpointInPlace(const SegmentState& state, + const ColumnCheckpointState& checkpointState) const { + if (!Column::canCheckpointInPlace(state, checkpointState)) { + return false; + } + for (size_t i = 0; i < childColumns.size(); ++i) { + auto& structChunkData = checkpointState.persistentData.cast(); + KU_ASSERT(childColumns.size() == structChunkData.getNumChildren()); + auto* childChunkData = structChunkData.getChild(i); + + std::vector childSegmentCheckpointStates; + for (auto& segmentCheckpointState : checkpointState.segmentCheckpointStates) { + auto& structSegmentData = segmentCheckpointState.chunkData.cast(); + auto& childSegmentData = structSegmentData.getChild(i); + childSegmentCheckpointStates.emplace_back(childSegmentData, + segmentCheckpointState.offsetInSegment, segmentCheckpointState.startRowInData, + segmentCheckpointState.numRows); + } + + if (!childColumns[i]->canCheckpointInPlace(state.getChildState(i), + ColumnCheckpointState(*childChunkData, std::move(childSegmentCheckpointStates)))) { + return false; + } + } + return true; +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/table.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/table.cpp new file mode 100644 index 0000000000..26400bd206 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/table.cpp @@ -0,0 +1,65 @@ +#include "storage/table/table.h" + +#include "storage/storage_manager.h" +#include "storage/table/node_table.h" +#include "storage/table/rel_table.h" + +using namespace lbug::common; + +namespace lbug { +namespace storage { + +TableScanState::~TableScanState() = default; + +// NOLINTNEXTLINE(readability-make-member-function-const): Semantically non-const. 
+void TableScanState::resetOutVectors() { + for (const auto& outputVector : outputVectors) { + KU_ASSERT(outputVector->state.get() == outState.get()); + KU_UNUSED(outputVector); + outputVector->resetAuxiliaryBuffer(); + } + outState->getSelVectorUnsafe().setToUnfiltered(); +} + +void TableScanState::setToTable(const transaction::Transaction*, Table* table_, + std::vector columnIDs_, std::vector columnPredicateSets_, + RelDataDirection) { + table = table_; + columnIDs = std::move(columnIDs_); + columnPredicateSets = std::move(columnPredicateSets_); + nodeGroupScanState->chunkStates.resize(columnIDs.size()); +} + +TableInsertState::TableInsertState(std::vector propertyVectors) + : propertyVectors{std::move(propertyVectors)}, logToWAL{true} {} +TableInsertState::~TableInsertState() = default; +TableUpdateState::TableUpdateState(column_id_t columnID, ValueVector& propertyVector) + : columnID{columnID}, propertyVector{propertyVector}, logToWAL{true} {} +TableUpdateState::~TableUpdateState() = default; +TableDeleteState::TableDeleteState() : logToWAL{true} {} +TableDeleteState::~TableDeleteState() = default; + +Table::Table(const catalog::TableCatalogEntry* tableEntry, const StorageManager* storageManager, + MemoryManager* memoryManager) + : tableType{tableEntry->getTableType()}, tableID{tableEntry->getTableID()}, + tableName{tableEntry->getName()}, enableCompression{storageManager->compressionEnabled()}, + memoryManager{memoryManager}, shadowFile{&storageManager->getShadowFile()}, + hasChanges{false} {} + +Table::~Table() = default; + +bool Table::scan(transaction::Transaction* transaction, TableScanState& scanState) { + return scanInternal(transaction, scanState); +} + +DataChunk Table::constructDataChunk(MemoryManager* mm, std::vector types) { + DataChunk dataChunk(types.size()); + for (auto i = 0u; i < types.size(); i++) { + auto valueVector = std::make_unique(std::move(types[i]), mm); + dataChunk.insert(i, std::move(valueVector)); + } + return dataChunk; +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/update_info.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/update_info.cpp new file mode 100644 index 0000000000..032f3530a1 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/update_info.cpp @@ -0,0 +1,299 @@ +#include "storage/table/update_info.h" + +#include + +#include "common/exception/runtime.h" +#include "common/vector/value_vector.h" +#include "storage/storage_utils.h" +#include "storage/table/column_chunk_data.h" +#include "transaction/transaction.h" + +using namespace lbug::transaction; +using namespace lbug::common; + +namespace lbug { +namespace storage { + +VectorUpdateInfo& UpdateInfo::update(MemoryManager& memoryManager, const Transaction* transaction, + const idx_t vectorIdx, const sel_t rowIdxInVector, const ValueVector& values) { + UpdateNode& header = getOrCreateUpdateNode(vectorIdx); + // We always lock the head of the chain of vectorUpdateInfo to ensure that we can safely + // read/write to any part of the chain. + std::unique_lock chainLock{header.mtx}; + // Traverse the chain of vectorUpdateInfo to find the one that matches the transaction. Also + // detect if there is any write-write conflicts. + auto current = header.info.get(); + VectorUpdateInfo* vecUpdateInfo = nullptr; + while (current) { + if (current->version == transaction->getID()) { + // Same transaction, we can update the existing vector info. 
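+            // (As the assertion below reflects, uncommitted entries carry the owning
+            // transaction ID, which is always >= START_TRANSACTION_ID; once committed
+            // they carry the commit timestamp, which lies below that threshold.)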
+ KU_ASSERT(current->version >= Transaction::START_TRANSACTION_ID); + vecUpdateInfo = current; + } else if (current->version > transaction->getStartTS()) { + // Potentially there can be conflicts. `current` can be uncommitted transaction (version + // is transaction ID) or committed transaction started after this transaction. + for (auto i = 0u; i < current->numRowsUpdated; i++) { + if (current->rowsInVector[i] == rowIdxInVector) { + throw RuntimeException("Write-write conflict of updating the same row."); + } + } + } + current = current->prev.get(); + } + if (!vecUpdateInfo) { + // Create a new version here if not found in the chain. + auto newInfo = std::make_unique(memoryManager, transaction->getID(), + values.dataType.copy()); + vecUpdateInfo = newInfo.get(); + auto currentInfo = std::move(header.info); + if (currentInfo) { + currentInfo->next = newInfo.get(); + } + newInfo->prev = std::move(currentInfo); + header.info = std::move(newInfo); + } + KU_ASSERT(vecUpdateInfo); + // Check if the row is already updated in this transaction. + idx_t idxInUpdateData = INVALID_IDX; + for (auto i = 0u; i < vecUpdateInfo->numRowsUpdated; i++) { + if (vecUpdateInfo->rowsInVector[i] == rowIdxInVector) { + idxInUpdateData = i; + break; + } + } + if (idxInUpdateData != INVALID_IDX) { + // Overwrite existing update value. + vecUpdateInfo->data->write(&values, values.state->getSelVector()[0], idxInUpdateData); + } else { + // Append new value and update `rowsInVector`. + vecUpdateInfo->rowsInVector[vecUpdateInfo->numRowsUpdated] = rowIdxInVector; + vecUpdateInfo->data->write(&values, values.state->getSelVector()[0], + vecUpdateInfo->numRowsUpdated++); + } + return *vecUpdateInfo; +} + +void UpdateInfo::scan(const Transaction* transaction, ValueVector& output, offset_t offsetInChunk, + length_t length) const { + iterateScan(transaction, offsetInChunk, length, 0 /* startPosInOutput */, + [&](const VectorUpdateInfo& vecUpdateInfo, uint64_t i, uint64_t posInOutput) -> void { + vecUpdateInfo.data->lookup(i, output, posInOutput); + }); +} + +void UpdateInfo::lookup(const Transaction* transaction, offset_t rowInChunk, ValueVector& output, + sel_t posInOutputVector) const { + if (!isSet()) { + return; + } + auto [vectorIdx, rowInVector] = + StorageUtils::getQuotientRemainder(rowInChunk, DEFAULT_VECTOR_CAPACITY); + bool updated = false; + iterateVectorInfo(transaction, vectorIdx, [&](const VectorUpdateInfo& vectorInfo) { + if (updated) { + return; + } + for (auto i = 0u; i < vectorInfo.numRowsUpdated; i++) { + if (vectorInfo.rowsInVector[i] == rowInVector) { + vectorInfo.data->lookup(i, output, posInOutputVector); + updated = true; + return; + } + } + }); +} + +void UpdateInfo::scanCommitted(const Transaction* transaction, ColumnChunkData& output, + offset_t startOffsetInOutput, row_idx_t startRowScanned, row_idx_t numRows) const { + iterateScan(transaction, startRowScanned, numRows, startOffsetInOutput, + [&](const VectorUpdateInfo& vecUpdateInfo, uint64_t i, uint64_t posInOutput) -> void { + output.write(vecUpdateInfo.data.get(), i, posInOutput, 1); + }); +} + +void UpdateInfo::iterateVectorInfo(const Transaction* transaction, idx_t idx, + const std::function& func) const { + const UpdateNode* head = nullptr; + { + std::shared_lock lock{mtx}; + if (idx >= updates.size() || !updates[idx]->isEmpty()) { + return; + } + head = updates[idx].get(); + } + // We lock the head of the chain to ensure that we can safely read from any part of the + // chain. 
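+    // Writers (UpdateInfo::update and UpdateInfo::rollback) also acquire this head
+    // lock before touching any node, so holding it here guarantees the chain is not
+    // modified while we traverse it.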
+ KU_ASSERT(head); + std::shared_lock chainLock{head->mtx}; + auto current = head->info.get(); + KU_ASSERT(current); + while (current) { + if (current->version == transaction->getID() || + current->version <= transaction->getStartTS()) { + KU_ASSERT((current->version == transaction->getID() && + current->version >= Transaction::START_TRANSACTION_ID) || + (current->version <= transaction->getStartTS() && + current->version < Transaction::START_TRANSACTION_ID)); + func(*current); + } + current = current->getPrev(); + } +} + +#if defined(LBUG_RUNTIME_CHECKS) || !defined(NDEBUG) +// Assert that info is in the updatedNode version chain. +static bool validateUpdateChain(const UpdateNode& updatedNode, const VectorUpdateInfo* info) { + auto current = updatedNode.info.get(); + while (current) { + if (current == info) { + return true; + } + current = current->getPrev(); + } + return false; +} +#endif + +void UpdateInfo::commit(idx_t vectorIdx, VectorUpdateInfo* info, transaction_t commitTS) { + auto& updateNode = getUpdateNode(vectorIdx); + std::unique_lock chainLock{updateNode.mtx}; + KU_ASSERT(validateUpdateChain(updateNode, info)); + info->version = commitTS; +} + +void UpdateInfo::rollback(idx_t vectorIdx, transaction_t version) { + UpdateNode* header = nullptr; + // Note that we lock the entire UpdateInfo structure here because we might modify the + // head of the version chain. This is just a simplification and should be optimized later. + { + std::unique_lock lock{mtx}; + KU_ASSERT(updates.size() > vectorIdx); + header = updates[vectorIdx].get(); + } + KU_ASSERT(header); + std::unique_lock chainLock{header->mtx}; + // First check if this version is still in the chain. It might have been removed by + // a previous rollback entry of the same transaction. + // TODO(Guodong): This will be optimized by moving VectorUpdateInfo into UndoBuffer. + auto current = header->info.get(); + while (current) { + if (current->version == version) { + auto prevVersion = current->movePrev(); + if (current->next) { + // Has newer version. Remove this from the version chain. + const auto newerVersion = current->next; + if (prevVersion) { + prevVersion->next = newerVersion; + } + newerVersion->setPrev(std::move(prevVersion)); + } else { + KU_ASSERT(header->info.get() == current); + // This is the beginning of the version chain. 
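+                // Drop the head and promote the previous (older) version, if any,
+                // to become the new head of the chain.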
+ if (prevVersion) { + prevVersion->next = nullptr; + } + header->info = std::move(prevVersion); + } + break; + } + current = current->getPrev(); + } +} + +row_idx_t UpdateInfo::getNumUpdatedRows(const Transaction* transaction) const { + std::unordered_set updatedRows; + for (auto vectorIdx = 0u; vectorIdx < updates.size(); vectorIdx++) { + iterateVectorInfo(transaction, vectorIdx, [&](const VectorUpdateInfo& info) { + for (auto i = 0u; i < info.numRowsUpdated; i++) { + updatedRows.insert(info.rowsInVector[i]); + } + }); + } + return updatedRows.size(); +} + +bool UpdateInfo::hasUpdates(const Transaction* transaction, row_idx_t startRow, + length_t numRows) const { + bool hasUpdates = false; + iterateScan(transaction, startRow, numRows, 0 /* startPosInOutput */, + [&](const VectorUpdateInfo&, uint64_t, uint64_t) -> void { hasUpdates = true; }); + return hasUpdates; +} + +UpdateNode& UpdateInfo::getUpdateNode(idx_t vectorIdx) { + std::shared_lock lock{mtx}; + if (vectorIdx >= updates.size()) { + throw InternalException( + "UpdateInfo does not have update node for vector index: " + std::to_string(vectorIdx)); + } + return *updates[vectorIdx]; +} + +UpdateNode& UpdateInfo::getOrCreateUpdateNode(idx_t vectorIdx) { + std::unique_lock lock{mtx}; + if (vectorIdx >= updates.size()) { + updates.resize(vectorIdx + 1); + for (auto i = 0u; i < updates.size(); i++) { + if (!updates[i]) { + updates[i] = std::make_unique(); + } + } + } + return *updates[vectorIdx]; +} + +void UpdateInfo::iterateScan(const Transaction* transaction, uint64_t startOffsetToScan, + uint64_t numRowsToScan, uint64_t startPosInOutput, + const iterate_read_from_row_func_t& readFromRowFunc) const { + if (!isSet()) { + return; + } + auto [startVectorIdx, startOffsetInVector] = + StorageUtils::getQuotientRemainder(startOffsetToScan, DEFAULT_VECTOR_CAPACITY); + auto [endVectorIdx, endOffsetInVector] = StorageUtils::getQuotientRemainder( + startOffsetToScan + numRowsToScan, DEFAULT_VECTOR_CAPACITY); + idx_t idx = startVectorIdx; + sel_t posInVector = startPosInOutput; + while (idx <= endVectorIdx) { + const auto startOffsetInclusively = idx == startVectorIdx ? startOffsetInVector : 0; + const auto endOffsetExclusively = + idx == endVectorIdx ? endOffsetInVector : DEFAULT_VECTOR_CAPACITY; + const auto numRowsInVector = endOffsetExclusively - startOffsetInclusively; + // We keep track of the rows that have been applied with updates from updateInfo. The update + // version chain is maintained with the newest version at the head and the oldest version at + // the tail. For each tuple, we iterate through the chain to merge the updates from latest + // visible version. If a row has been updated in the current vectorInfo, we should skip it + // in older versions. + std::bitset rowsUpdated; + iterateVectorInfo(transaction, idx, [&](const VectorUpdateInfo& vecUpdateInfo) -> void { + if (vecUpdateInfo.numRowsUpdated == 0) { + return; + } + if (rowsUpdated.count() == numRowsInVector) { + // All rows in this vector have been updated with a newer visible version already. + return; + } + // TODO(Guodong): Ideally we should make sure vecUpdateInfo.rowsInVector is sorted to + // simplify the checks here. + for (auto i = 0u; i < vecUpdateInfo.numRowsUpdated; i++) { + if (vecUpdateInfo.rowsInVector[i] < startOffsetInclusively || + vecUpdateInfo.rowsInVector[i] >= endOffsetExclusively) { + // Continue if the row is out of the current scan range. 
+ continue; + } + auto updatedRowIdx = vecUpdateInfo.rowsInVector[i] - startOffsetInclusively; + if (rowsUpdated[updatedRowIdx]) { + // Skip the rows that have been updated with a newer visible version already. + continue; + } + readFromRowFunc(vecUpdateInfo, i, posInVector + updatedRowIdx); + rowsUpdated[updatedRowIdx] = true; + } + }); + posInVector += numRowsInVector; + idx++; + } +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/version_info.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/version_info.cpp new file mode 100644 index 0000000000..7bd524706e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/version_info.cpp @@ -0,0 +1,700 @@ +#include "storage/table/version_info.h" + +#include "common/exception/runtime.h" +#include "common/serializer/deserializer.h" +#include "common/serializer/serializer.h" +#include "storage/storage_utils.h" +#include "transaction/transaction.h" + +using namespace lbug::common; + +namespace lbug { +namespace storage { + +struct VectorVersionInfo { + enum class InsertionStatus : uint8_t { NO_INSERTED, CHECK_VERSION, ALWAYS_INSERTED }; + // TODO(Guodong): ALWAYS_INSERTED is not added for now, but it may be useful as an optimization + // to mark the vector data after checkpoint is all deleted. + enum class DeletionStatus : uint8_t { NO_DELETED, CHECK_VERSION }; + + // TODO: Keep an additional same insertion/deletion field as an optimization to avoid the need + // of `array` if all are inserted/deleted in the same transaction. + // Also, avoid allocate `array` when status are NO_INSERTED and NO_DELETED. + // We can even consider separating the insertion and deletion into two separate Vectors. + std::unique_ptr> insertedVersions; + std::unique_ptr> deletedVersions; + // If all values in the Vector are inserted/deleted in the same transaction, we can use this to + // avoid the allocation of `array`. + transaction_t sameInsertionVersion; + transaction_t sameDeletionVersion; + InsertionStatus insertionStatus; + DeletionStatus deletionStatus; + + VectorVersionInfo() + : sameInsertionVersion{INVALID_TRANSACTION}, sameDeletionVersion{INVALID_TRANSACTION}, + insertionStatus{InsertionStatus::NO_INSERTED}, + deletionStatus{DeletionStatus::NO_DELETED} {} + DELETE_COPY_DEFAULT_MOVE(VectorVersionInfo); + + void append(transaction_t transactionID, row_idx_t startRow, row_idx_t numRows); + bool delete_(transaction_t transactionID, row_idx_t rowIdx); + void setInsertCommitTS(transaction_t commitTS, row_idx_t startRow, row_idx_t numRows); + void setDeleteCommitTS(transaction_t commitTS, row_idx_t startRow, row_idx_t numRows); + + bool isSelected(transaction_t startTS, transaction_t transactionID, row_idx_t rowIdx) const; + void getSelVectorForScan(transaction_t startTS, transaction_t transactionID, + SelectionVector& selVector, row_idx_t startRow, row_idx_t numRows, + sel_t startOutputPos) const; + + void rollbackInsertions(row_idx_t startRowInVector, row_idx_t numRows); + void rollbackDeletions(row_idx_t startRowInVector, row_idx_t numRows); + + bool hasDeletions(const transaction::Transaction* transaction) const; + + // Given startTS and transactionID, if the row is deleted to the transaction, return true. + bool isDeleted(transaction_t startTS, transaction_t transactionID, row_idx_t rowIdx) const; + // Given startTS and transactionID, if the row is readable to the transaction, return true. 
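+    // A row is readable when it was inserted by this transaction itself, or by a
+    // transaction whose commit timestamp is <= startTS. For illustration, with
+    // startTS = 10 a row committed at timestamp 8 is visible, while a row written
+    // by another, still uncommitted transaction is not.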
+ bool isInserted(transaction_t startTS, transaction_t transactionID, row_idx_t rowIdx) const; + + row_idx_t getNumDeletions(transaction_t startTS, transaction_t transactionID, + row_idx_t startRow, length_t numRows) const; + + void serialize(Serializer& serializer) const; + static std::unique_ptr deSerialize(Deserializer& deSer); + +private: + void initInsertionVersionArray(); + void initDeletionVersionArray(); + + bool isSameInsertionVersion() const; + bool isSameDeletionVersion() const; +}; + +void VectorVersionInfo::append(const transaction_t transactionID, const row_idx_t startRow, + const row_idx_t numRows) { + insertionStatus = InsertionStatus::CHECK_VERSION; + if (transactionID == sameInsertionVersion) { + return; + } + if (!isSameInsertionVersion() && !insertedVersions) { + // No insertions before, and no need to allocate array. + sameInsertionVersion = transactionID; + return; + } + if (!insertedVersions) { + initInsertionVersionArray(); + for (auto i = 0u; i < startRow; i++) { + insertedVersions->operator[](i) = sameInsertionVersion; + } + sameInsertionVersion = INVALID_TRANSACTION; + } + for (auto i = 0u; i < numRows; i++) { + KU_ASSERT(insertedVersions->operator[](startRow + i) == INVALID_TRANSACTION); + insertedVersions->operator[](startRow + i) = transactionID; + } +} + +bool VectorVersionInfo::delete_(const transaction_t transactionID, const row_idx_t rowIdx) { + deletionStatus = DeletionStatus::CHECK_VERSION; + if (transactionID == sameDeletionVersion) { + // All are deleted in the same transaction. + return false; + } + if (isSameDeletionVersion()) { + // All are deleted in a different transaction. + throw RuntimeException( + "Write-write conflict: deleting a row that is already deleted by another transaction."); + } + if (!deletedVersions) { + // No deletions before. 
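+        // Lazily allocate the per-row deletion array (pre-filled with
+        // INVALID_TRANSACTION) now that the first targeted delete has arrived.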
+ initDeletionVersionArray(); + } + if (deletedVersions->operator[](rowIdx) == transactionID) { + return false; + } + if (deletedVersions->operator[](rowIdx) != INVALID_TRANSACTION) { + throw RuntimeException( + "Write-write conflict: deleting a row that is already deleted by another transaction."); + } + deletedVersions->operator[](rowIdx) = transactionID; + return true; +} + +void VectorVersionInfo::setInsertCommitTS(transaction_t commitTS, row_idx_t startRow, + row_idx_t numRows) { + if (isSameInsertionVersion()) { + sameInsertionVersion = commitTS; + return; + } + KU_ASSERT(insertedVersions); + for (auto rowIdx = startRow; rowIdx < startRow + numRows; rowIdx++) { + insertedVersions->operator[](rowIdx) = commitTS; + } +} + +void VectorVersionInfo::setDeleteCommitTS(transaction_t commitTS, row_idx_t startRow, + row_idx_t numRows) { + if (isSameDeletionVersion()) { + sameDeletionVersion = commitTS; + return; + } + KU_ASSERT(deletedVersions); + for (auto rowIdx = startRow; rowIdx < startRow + numRows; rowIdx++) { + deletedVersions->operator[](rowIdx) = commitTS; + } +} + +bool VectorVersionInfo::isSelected(const transaction_t startTS, const transaction_t transactionID, + const row_idx_t rowIdx) const { + if (deletionStatus == DeletionStatus::NO_DELETED && + insertionStatus == InsertionStatus::ALWAYS_INSERTED) { + return true; + } + if (insertionStatus == InsertionStatus::NO_INSERTED) { + return false; + } + if (isInserted(startTS, transactionID, rowIdx)) { + return !isDeleted(startTS, transactionID, rowIdx); + } + return false; +} + +void VectorVersionInfo::getSelVectorForScan(const transaction_t startTS, + const transaction_t transactionID, SelectionVector& selVector, const row_idx_t startRow, + const row_idx_t numRows, sel_t startOutputPos) const { + auto numSelected = selVector.getSelSize(); + if (deletionStatus == DeletionStatus::NO_DELETED && + insertionStatus == InsertionStatus::ALWAYS_INSERTED) { + if (selVector.isUnfiltered()) { + selVector.setSelSize(selVector.getSelSize() + numRows); + } else { + for (auto i = 0u; i < numRows; i++) { + selVector.getMutableBuffer()[numSelected++] = startOutputPos + i; + } + selVector.setToFiltered(numSelected); + } + } else if (insertionStatus != InsertionStatus::NO_INSERTED) { + // If there were no deleted values up to this point the selVector may be unfiltered but have + // non-zero size, and the mutable buffer may have arbitrary contents + if (selVector.isUnfiltered()) { + selVector.makeDynamic(); + } + for (auto i = 0u; i < numRows; i++) { + if (const auto rowIdx = startRow + i; isInserted(startTS, transactionID, rowIdx) && + !isDeleted(startTS, transactionID, rowIdx)) { + selVector.getMutableBuffer()[numSelected++] = startOutputPos + i; + } + } + selVector.setToFiltered(numSelected); + } +} + +bool VectorVersionInfo::isDeleted(const transaction_t startTS, const transaction_t transactionID, + const row_idx_t rowIdx) const { + switch (deletionStatus) { + case DeletionStatus::NO_DELETED: { + return false; + } + case DeletionStatus::CHECK_VERSION: { + transaction_t deletion = INVALID_TRANSACTION; + if (isSameDeletionVersion()) { + deletion = sameDeletionVersion; + } else { + KU_ASSERT(deletedVersions); + deletion = deletedVersions->operator[](rowIdx); + } + const auto isDeletedWithinSameTransaction = deletion == transactionID; + const auto isDeletedByPrevCommittedTransaction = deletion <= startTS; + return isDeletedWithinSameTransaction || isDeletedByPrevCommittedTransaction; + } + default: { + KU_UNREACHABLE; + } + } +} + +bool 
VectorVersionInfo::isInserted(const transaction_t startTS, const transaction_t transactionID, + const row_idx_t rowIdx) const { + switch (insertionStatus) { + case InsertionStatus::ALWAYS_INSERTED: { + return true; + } + case InsertionStatus::NO_INSERTED: { + return false; + } + case InsertionStatus::CHECK_VERSION: { + transaction_t insertion = INVALID_TRANSACTION; + if (isSameInsertionVersion()) { + insertion = sameInsertionVersion; + } else { + KU_ASSERT(insertedVersions); + insertion = insertedVersions->operator[](rowIdx); + } + const auto isInsertedWithinSameTransaction = insertion == transactionID; + const auto isInsertedByPrevCommittedTransaction = insertion <= startTS; + return isInsertedWithinSameTransaction || isInsertedByPrevCommittedTransaction; + } + default: { + KU_UNREACHABLE; + } + } +} + +row_idx_t VectorVersionInfo::getNumDeletions(transaction_t startTS, transaction_t transactionID, + row_idx_t startRow, length_t numRows) const { + if (deletionStatus == DeletionStatus::NO_DELETED) { + return 0; + } + row_idx_t numDeletions = 0u; + for (auto i = 0u; i < numRows; i++) { + numDeletions += isDeleted(startTS, transactionID, startRow + i); + } + return numDeletions; +} + +void VectorVersionInfo::rollbackInsertions(row_idx_t startRowInVector, row_idx_t numRows) { + if (isSameInsertionVersion()) { + // This implicitly assumes that all rows are inserted in the same transaction, so regardless + // which rows to be rolled back, we just reset the sameInsertionVersion. + sameInsertionVersion = INVALID_TRANSACTION; + } else { + if (insertedVersions) { + for (auto row = startRowInVector; row < startRowInVector + numRows; row++) { + insertedVersions->operator[](row) = INVALID_TRANSACTION; + } + bool hasAnyInsertions = false; + for (const auto& version : *insertedVersions) { + if (version != INVALID_TRANSACTION) { + hasAnyInsertions = true; + break; + } + } + if (!hasAnyInsertions) { + insertedVersions.reset(); + } + } + } + if (!insertedVersions) { + insertionStatus = InsertionStatus::NO_INSERTED; + deletionStatus = DeletionStatus::NO_DELETED; + } +} + +void VectorVersionInfo::rollbackDeletions(row_idx_t startRowInVector, row_idx_t numRows) { + if (isSameDeletionVersion()) { + // This implicitly assumes that all rows are deleted in the same transaction, so regardless + // which rows to be rollbacked, we just reset the sameInsertionVersion. 
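+        // (Deletion counterpart of rollbackInsertions: the field cleared here is
+        // sameDeletionVersion.)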
+ sameDeletionVersion = INVALID_TRANSACTION; + } else { + if (deletedVersions) { + for (auto row = startRowInVector; row < startRowInVector + numRows; row++) { + deletedVersions->operator[](row) = INVALID_TRANSACTION; + } + bool hasAnyDeletions = false; + for (const auto& version : *deletedVersions) { + if (version != INVALID_TRANSACTION) { + hasAnyDeletions = true; + break; + } + } + if (!hasAnyDeletions) { + deletedVersions.reset(); + } + } + } + if (!deletedVersions) { + deletionStatus = DeletionStatus::NO_DELETED; + } +} + +void VectorVersionInfo::initInsertionVersionArray() { + insertedVersions = std::make_unique>(); + insertedVersions->fill(INVALID_TRANSACTION); +} + +void VectorVersionInfo::initDeletionVersionArray() { + deletedVersions = std::make_unique>(); + deletedVersions->fill(INVALID_TRANSACTION); +} + +bool VectorVersionInfo::isSameInsertionVersion() const { + return sameInsertionVersion != INVALID_TRANSACTION; +} + +bool VectorVersionInfo::isSameDeletionVersion() const { + return sameDeletionVersion != INVALID_TRANSACTION; +} + +void VectorVersionInfo::serialize(Serializer& serializer) const { + if (deletedVersions) { + for (const auto deleted : *deletedVersions) { + // Versions should be either INVALID_TRANSACTION or committed timestamps. + KU_ASSERT(deleted == INVALID_TRANSACTION || + deleted < transaction::Transaction::START_TRANSACTION_ID); + KU_UNUSED(deleted); + } + } + KU_ASSERT(insertionStatus == InsertionStatus::NO_INSERTED || + insertionStatus == InsertionStatus::ALWAYS_INSERTED); + serializer.writeDebuggingInfo("insertion_status"); + serializer.serializeValue(insertionStatus); + serializer.writeDebuggingInfo("deletion_status"); + serializer.serializeValue(deletionStatus); + switch (deletionStatus) { + case DeletionStatus::NO_DELETED: { + // Nothing to serialize. + } break; + case DeletionStatus::CHECK_VERSION: { + serializer.writeDebuggingInfo("same_deletion_version"); + serializer.serializeValue(sameDeletionVersion); + if (sameDeletionVersion == INVALID_TRANSACTION) { + KU_ASSERT(deletedVersions); + serializer.writeDebuggingInfo("deleted_versions"); + serializer.serializeArray(*deletedVersions); + } + } break; + default: { + KU_UNREACHABLE; + } + } +} + +std::unique_ptr VectorVersionInfo::deSerialize(Deserializer& deSer) { + std::string key; + auto vectorVersionInfo = std::make_unique(); + deSer.validateDebuggingInfo(key, "insertion_status"); + deSer.deserializeValue(vectorVersionInfo->insertionStatus); + KU_ASSERT(vectorVersionInfo->insertionStatus == InsertionStatus::NO_INSERTED || + vectorVersionInfo->insertionStatus == InsertionStatus::ALWAYS_INSERTED); + deSer.validateDebuggingInfo(key, "deletion_status"); + deSer.deserializeValue(vectorVersionInfo->deletionStatus); + switch (vectorVersionInfo->deletionStatus) { + case DeletionStatus::NO_DELETED: { + // Nothing to deserialize. + } break; + case DeletionStatus::CHECK_VERSION: { + deSer.validateDebuggingInfo(key, "same_deletion_version"); + deSer.deserializeValue(vectorVersionInfo->sameDeletionVersion); + if (vectorVersionInfo->sameDeletionVersion == INVALID_TRANSACTION) { + deSer.validateDebuggingInfo(key, "deleted_versions"); + vectorVersionInfo->initDeletionVersionArray(); + deSer.deserializeArray( + *vectorVersionInfo->deletedVersions); + } + } break; + default: { + KU_UNREACHABLE; + } + } + if (vectorVersionInfo->deletedVersions) { + for (const auto deleted : *vectorVersionInfo->deletedVersions) { + // Versions should be either INVALID_TRANSACTION or committed timestamps. 
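+            // Same invariant as checked in serialize(): only committed state is ever
+            // persisted, so no live transaction IDs may appear after deserialization.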
+ KU_ASSERT(deleted == INVALID_TRANSACTION || + deleted < transaction::Transaction::START_TRANSACTION_ID); + KU_UNUSED(deleted); + } + } + return vectorVersionInfo; +} + +bool VectorVersionInfo::hasDeletions(const transaction::Transaction* transaction) const { + if (isSameDeletionVersion()) { + return sameDeletionVersion <= transaction->getStartTS() || + sameDeletionVersion == transaction->getID(); + } + row_idx_t numDeletions = 0; + for (auto i = 0u; i < deletedVersions->size(); i++) { + numDeletions += isDeleted(transaction->getStartTS(), transaction->getID(), i); + } + return numDeletions > 0; +} + +VectorVersionInfo& VersionInfo::getOrCreateVersionInfo(idx_t vectorIdx) { + if (vectorsInfo.size() <= vectorIdx) { + vectorsInfo.resize(vectorIdx + 1); + } + if (!vectorsInfo[vectorIdx]) { + vectorsInfo[vectorIdx] = std::make_unique(); + } + return *vectorsInfo[vectorIdx]; +} + +VectorVersionInfo* VersionInfo::getVectorVersionInfo(idx_t vectorIdx) const { + if (vectorIdx >= vectorsInfo.size()) { + return nullptr; + } + return vectorsInfo[vectorIdx].get(); +} + +VersionInfo::VersionInfo() = default; +VersionInfo::~VersionInfo() = default; + +void VersionInfo::append(transaction_t transactionID, const row_idx_t startRow, + const row_idx_t numRows) { + if (numRows == 0) { + return; + } + auto [startVectorIdx, startRowIdxInVector] = + StorageUtils::getQuotientRemainder(startRow, DEFAULT_VECTOR_CAPACITY); + auto [endVectorIdx, endRowIdxInVector] = + StorageUtils::getQuotientRemainder(startRow + numRows - 1, DEFAULT_VECTOR_CAPACITY); + for (auto vectorIdx = startVectorIdx; vectorIdx <= endVectorIdx; vectorIdx++) { + auto& vectorVersionInfo = getOrCreateVersionInfo(vectorIdx); + const auto startRowIdx = vectorIdx == startVectorIdx ? startRowIdxInVector : 0; + const auto endRowIdx = + vectorIdx == endVectorIdx ? endRowIdxInVector : DEFAULT_VECTOR_CAPACITY - 1; + const auto numRowsInVector = endRowIdx - startRowIdx + 1; + vectorVersionInfo.append(transactionID, startRowIdx, numRowsInVector); + } +} + +bool VersionInfo::delete_(transaction_t transactionID, const row_idx_t rowIdx) { + auto [vectorIdx, rowIdxInVector] = + StorageUtils::getQuotientRemainder(rowIdx, DEFAULT_VECTOR_CAPACITY); + auto& vectorVersionInfo = getOrCreateVersionInfo(vectorIdx); + if (vectorVersionInfo.insertionStatus == VectorVersionInfo::InsertionStatus::NO_INSERTED) { + // Note: The version info is newly created due to `delete_`. There is no newly inserted rows + // in this vector, thus all are rows checkpointed. We set the insertion status to + // ALWAYS_INSERTED to avoid checking the version in the future. 
+ vectorVersionInfo.insertionStatus = VectorVersionInfo::InsertionStatus::ALWAYS_INSERTED; + } + return vectorVersionInfo.delete_(transactionID, rowIdxInVector); +} + +bool VersionInfo::isSelected(transaction_t startTS, transaction_t transactionID, + row_idx_t rowIdx) const { + auto [vectorIdx, rowIdxInVector] = + StorageUtils::getQuotientRemainder(rowIdx, DEFAULT_VECTOR_CAPACITY); + if (const auto vectorVersion = getVectorVersionInfo(vectorIdx)) { + return vectorVersion->isSelected(startTS, transactionID, rowIdxInVector); + } + return true; +} + +void VersionInfo::getSelVectorToScan(const transaction_t startTS, const transaction_t transactionID, + SelectionVector& selVector, const row_idx_t startRow, const row_idx_t numRows) const { + if (numRows == 0) { + return; + } + auto [startVectorIdx, startRowIdxInVector] = + StorageUtils::getQuotientRemainder(startRow, DEFAULT_VECTOR_CAPACITY); + auto [endVectorIdx, endRowIdxInVector] = + StorageUtils::getQuotientRemainder(startRow + numRows - 1, DEFAULT_VECTOR_CAPACITY); + auto vectorIdx = startVectorIdx; + selVector.setToUnfiltered(0); + sel_t outputPos = 0u; + while (vectorIdx <= endVectorIdx) { + const auto startRowIdx = vectorIdx == startVectorIdx ? startRowIdxInVector : 0; + const auto endRowIdx = + vectorIdx == endVectorIdx ? endRowIdxInVector : DEFAULT_VECTOR_CAPACITY - 1; + const auto numRowsInVector = endRowIdx - startRowIdx + 1; + const auto vectorVersion = getVectorVersionInfo(vectorIdx); + if (!vectorVersion) { + auto numSelected = selVector.getSelSize(); + if (selVector.isUnfiltered()) { + selVector.setSelSize(numSelected + numRowsInVector); + } else { + for (auto i = 0u; i < numRowsInVector; i++) { + selVector.getMutableBuffer()[numSelected++] = outputPos + i; + } + selVector.setToFiltered(numSelected); + } + } else { + vectorVersion->getSelVectorForScan(startTS, transactionID, selVector, startRowIdx, + numRowsInVector, outputPos); + } + outputPos += numRowsInVector; + vectorIdx++; + } + KU_ASSERT(outputPos <= DEFAULT_VECTOR_CAPACITY); +} + +void VersionInfo::clearVectorInfo(const idx_t vectorIdx) { + KU_ASSERT(vectorIdx < vectorsInfo.size()); + vectorsInfo[vectorIdx] = nullptr; +} + +bool VersionInfo::hasDeletions() const { + for (auto& vectorInfo : vectorsInfo) { + if (vectorInfo && + vectorInfo->deletionStatus == VectorVersionInfo::DeletionStatus::CHECK_VERSION) { + return true; + } + } + return false; +} + +row_idx_t VersionInfo::getNumDeletions(const transaction::Transaction* transaction, + row_idx_t startRow, length_t numRows) const { + if (numRows == 0) { + return 0; + } + auto [startVector, startRowInVector] = + StorageUtils::getQuotientRemainder(startRow, DEFAULT_VECTOR_CAPACITY); + auto [endVectorIdx, endRowInVector] = + StorageUtils::getQuotientRemainder(startRow + numRows - 1, DEFAULT_VECTOR_CAPACITY); + idx_t vectorIdx = startVector; + row_idx_t numDeletions = 0u; + while (vectorIdx <= endVectorIdx) { + const auto rowInVector = vectorIdx == startVector ? startRowInVector : 0; + const auto numRowsInVector = vectorIdx == endVectorIdx ? 
+ endRowInVector - rowInVector + 1 : + DEFAULT_VECTOR_CAPACITY - rowInVector; + const auto vectorVersion = getVectorVersionInfo(vectorIdx); + if (vectorVersion) { + numDeletions += vectorVersion->getNumDeletions(transaction->getStartTS(), + transaction->getID(), rowInVector, numRowsInVector); + } + vectorIdx++; + } + return numDeletions; +} + +bool VersionInfo::hasInsertions() const { + for (auto& vectorInfo : vectorsInfo) { + if (vectorInfo && + vectorInfo->insertionStatus == VectorVersionInfo::InsertionStatus::CHECK_VERSION) { + return true; + } + } + return false; +} + +bool VersionInfo::isDeleted(const transaction::Transaction* transaction, + row_idx_t rowInChunk) const { + auto [vectorIdx, rowInVector] = + StorageUtils::getQuotientRemainder(rowInChunk, DEFAULT_VECTOR_CAPACITY); + const auto vectorVersion = getVectorVersionInfo(vectorIdx); + if (vectorVersion) { + return vectorVersion->isDeleted(transaction->getStartTS(), transaction->getID(), + rowInVector); + } + return false; +} + +bool VersionInfo::isInserted(const transaction::Transaction* transaction, + row_idx_t rowInChunk) const { + auto [vectorIdx, rowInVector] = + StorageUtils::getQuotientRemainder(rowInChunk, DEFAULT_VECTOR_CAPACITY); + const auto vectorVersion = getVectorVersionInfo(vectorIdx); + if (vectorVersion) { + return vectorVersion->isInserted(transaction->getStartTS(), transaction->getID(), + rowInVector); + } + return true; +} + +bool VersionInfo::hasDeletions(const transaction::Transaction* transaction) const { + for (auto& vectorInfo : vectorsInfo) { + if (vectorInfo && vectorInfo->hasDeletions(transaction) > 0) { + return true; + } + } + return false; +} + +void VersionInfo::commitInsert(row_idx_t startRow, row_idx_t numRows, transaction_t commitTS) { + if (numRows == 0) { + return; + } + auto [startVectorIdx, startRowIdxInVector] = + StorageUtils::getQuotientRemainder(startRow, DEFAULT_VECTOR_CAPACITY); + auto [endVectorIdx, endRowIdxInVector] = + StorageUtils::getQuotientRemainder(startRow + numRows - 1, DEFAULT_VECTOR_CAPACITY); + for (auto vectorIdx = startVectorIdx; vectorIdx <= endVectorIdx; vectorIdx++) { + const auto startRowIdx = vectorIdx == startVectorIdx ? startRowIdxInVector : 0; + const auto endRowIdx = + vectorIdx == endVectorIdx ? endRowIdxInVector : DEFAULT_VECTOR_CAPACITY - 1; + auto& vectorVersionInfo = getOrCreateVersionInfo(vectorIdx); + vectorVersionInfo.setInsertCommitTS(commitTS, startRowIdx, endRowIdx - startRowIdx + 1); + } +} + +void VersionInfo::rollbackInsert(row_idx_t startRow, row_idx_t numRows) { + if (numRows == 0) { + return; + } + auto [startVectorIdx, startRowIdxInVector] = + StorageUtils::getQuotientRemainder(startRow, DEFAULT_VECTOR_CAPACITY); + auto [endVectorIdx, endRowIdxInVector] = + StorageUtils::getQuotientRemainder(startRow + numRows - 1, DEFAULT_VECTOR_CAPACITY); + for (auto vectorIdx = startVectorIdx; vectorIdx <= endVectorIdx; vectorIdx++) { + const auto startRowIdx = vectorIdx == startVectorIdx ? startRowIdxInVector : 0; + const auto endRowIdx = + vectorIdx == endVectorIdx ? 
endRowIdxInVector : DEFAULT_VECTOR_CAPACITY - 1; + auto& vectorVersionInfo = getOrCreateVersionInfo(vectorIdx); + vectorVersionInfo.rollbackInsertions(startRowIdx, endRowIdx - startRowIdx + 1); + } +} + +void VersionInfo::commitDelete(row_idx_t startRow, row_idx_t numRows, transaction_t commitTS) { + if (numRows == 0) { + return; + } + auto [startVectorIdx, startRowIdxInVector] = + StorageUtils::getQuotientRemainder(startRow, DEFAULT_VECTOR_CAPACITY); + auto [endVectorIdx, endRowIdxInVector] = + StorageUtils::getQuotientRemainder(startRow + numRows - 1, DEFAULT_VECTOR_CAPACITY); + for (auto vectorIdx = startVectorIdx; vectorIdx <= endVectorIdx; vectorIdx++) { + const auto startRowIdx = vectorIdx == startVectorIdx ? startRowIdxInVector : 0; + const auto endRowIdx = + vectorIdx == endVectorIdx ? endRowIdxInVector : DEFAULT_VECTOR_CAPACITY - 1; + auto& vectorVersionInfo = getOrCreateVersionInfo(vectorIdx); + vectorVersionInfo.setDeleteCommitTS(commitTS, startRowIdx, endRowIdx - startRowIdx + 1); + } +} + +void VersionInfo::rollbackDelete(row_idx_t startRow, row_idx_t numRows) { + if (numRows == 0) { + return; + } + auto [startVectorIdx, startRowIdxInVector] = + StorageUtils::getQuotientRemainder(startRow, DEFAULT_VECTOR_CAPACITY); + auto [endVectorIdx, endRowIdxInVector] = + StorageUtils::getQuotientRemainder(startRow + numRows - 1, DEFAULT_VECTOR_CAPACITY); + for (auto vectorIdx = startVectorIdx; vectorIdx <= endVectorIdx; vectorIdx++) { + auto& vectorVersionInfo = getOrCreateVersionInfo(vectorIdx); + const auto startRowIdx = vectorIdx == startVectorIdx ? startRowIdxInVector : 0; + const auto endRowIdx = + vectorIdx == endVectorIdx ? endRowIdxInVector : DEFAULT_VECTOR_CAPACITY - 1; + vectorVersionInfo.rollbackDeletions(startRowIdx, endRowIdx - startRowIdx + 1); + } +} + +void VersionInfo::serialize(Serializer& serializer) const { + serializer.writeDebuggingInfo("vectors_info_size"); + serializer.write(vectorsInfo.size()); + for (auto i = 0u; i < vectorsInfo.size(); i++) { + auto hasVectorVersion = vectorsInfo[i] != nullptr; + serializer.writeDebuggingInfo("has_vector_info"); + serializer.write(hasVectorVersion); + if (hasVectorVersion) { + serializer.writeDebuggingInfo("vector_info"); + vectorsInfo[i]->serialize(serializer); + } + } +} + +std::unique_ptr VersionInfo::deserialize(Deserializer& deSer) { + std::string key; + uint64_t vectorSize = 0; + deSer.validateDebuggingInfo(key, "vectors_info_size"); + deSer.deserializeValue(vectorSize); + auto versionInfo = std::make_unique(); + for (auto i = 0u; i < vectorSize; i++) { + bool hasVectorVersion = false; + deSer.validateDebuggingInfo(key, "has_vector_info"); + deSer.deserializeValue(hasVectorVersion); + if (hasVectorVersion) { + deSer.validateDebuggingInfo(key, "vector_info"); + auto vectorVersionInfo = VectorVersionInfo::deSerialize(deSer); + versionInfo->vectorsInfo.push_back(std::move(vectorVersionInfo)); + } else { + versionInfo->vectorsInfo.push_back(nullptr); + } + } + return versionInfo; +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/version_record_handler.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/version_record_handler.cpp new file mode 100644 index 0000000000..7efa34727b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/table/version_record_handler.cpp @@ -0,0 +1,15 @@ +#include "storage/table/version_record_handler.h" + +#include "main/client_context.h" +#include "storage/table/chunked_node_group.h" + +namespace lbug::storage { + +void 
VersionRecordHandler::rollbackInsert(main::ClientContext* context, + common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, + common::row_idx_t numRows) const { + applyFuncToChunkedGroups(&ChunkedNodeGroup::rollbackInsert, nodeGroupIdx, startRow, numRows, + transaction::Transaction::Get(*context)->getCommitTS()); +} + +} // namespace lbug::storage diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/undo_buffer.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/undo_buffer.cpp new file mode 100644 index 0000000000..c6e9faf6e1 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/undo_buffer.cpp @@ -0,0 +1,309 @@ +#include "storage/undo_buffer.h" + +#include "catalog/catalog_entry/catalog_entry.h" +#include "catalog/catalog_entry/sequence_catalog_entry.h" +#include "catalog/catalog_entry/table_catalog_entry.h" +#include "catalog/catalog_set.h" +#include "storage/table/chunked_node_group.h" +#include "storage/table/update_info.h" +#include "storage/table/version_record_handler.h" +#include "transaction/transaction.h" + +using namespace lbug::catalog; +using namespace lbug::common; +using namespace lbug::main; + +namespace lbug { +namespace storage { + +struct UndoRecordHeader { + UndoBuffer::UndoRecordType recordType; + uint32_t recordSize; + + UndoRecordHeader(const UndoBuffer::UndoRecordType recordType, const uint32_t recordSize) + : recordType{recordType}, recordSize{recordSize} {} +}; + +struct CatalogEntryRecord { + CatalogSet* catalogSet; + CatalogEntry* catalogEntry; +}; + +struct SequenceEntryRecord { + SequenceCatalogEntry* sequenceEntry; + SequenceRollbackData sequenceRollbackData; +}; + +struct NodeBatchInsertRecord { + table_id_t tableID; +}; + +struct VersionRecord { + row_idx_t startRow; + row_idx_t numRows; + node_group_idx_t nodeGroupIdx; + const VersionRecordHandler* versionRecordHandler; +}; + +struct VectorUpdateRecord { + UpdateInfo* updateInfo; + idx_t vectorIdx; + VectorUpdateInfo* vectorUpdateInfo; + transaction_t version; // This is used during roll back. +}; + +template +void UndoBufferIterator::iterate(F&& callback) { + idx_t bufferIdx = 0; + while (bufferIdx < undoBuffer.memoryBuffers.size()) { + auto& currentBuffer = undoBuffer.memoryBuffers[bufferIdx]; + auto current = currentBuffer.getData(); + const auto end = current + currentBuffer.getCurrentPosition(); + while (current < end) { + UndoRecordHeader recordHeader = *reinterpret_cast(current); + current += sizeof(UndoRecordHeader); + callback(recordHeader.recordType, current); + current += recordHeader.recordSize; // Skip the current entry. + } + bufferIdx++; + } +} + +template +void UndoBufferIterator::reverseIterate(F&& callback) { + idx_t numBuffersLeft = undoBuffer.memoryBuffers.size(); + while (numBuffersLeft > 0) { + const auto bufferIdx = numBuffersLeft - 1; + auto& currentBuffer = undoBuffer.memoryBuffers[bufferIdx]; + auto current = currentBuffer.getData(); + const auto end = current + currentBuffer.getCurrentPosition(); + std::vector> entries; + while (current < end) { + UndoRecordHeader recordHeader = *reinterpret_cast(current); + current += sizeof(UndoRecordHeader); + entries.push_back({recordHeader.recordType, current}); + current += recordHeader.recordSize; // Skip the current entry. 
+ } + for (auto i = entries.size(); i >= 1; i--) { + callback(entries[i - 1].first, entries[i - 1].second); + } + numBuffersLeft--; + } +} + +void UndoBuffer::createCatalogEntry(CatalogSet& catalogSet, CatalogEntry& catalogEntry) { + auto buffer = createUndoRecord(sizeof(UndoRecordHeader) + sizeof(CatalogEntryRecord)); + const UndoRecordHeader recordHeader{UndoRecordType::CATALOG_ENTRY, sizeof(CatalogEntryRecord)}; + *reinterpret_cast(buffer) = recordHeader; + buffer += sizeof(UndoRecordHeader); + const CatalogEntryRecord catalogEntryRecord{&catalogSet, &catalogEntry}; + *reinterpret_cast(buffer) = catalogEntryRecord; +} + +void UndoBuffer::createSequenceChange(SequenceCatalogEntry& sequenceEntry, + const SequenceRollbackData& data) { + auto buffer = createUndoRecord(sizeof(UndoRecordHeader) + sizeof(SequenceEntryRecord)); + const UndoRecordHeader recordHeader{UndoRecordType::SEQUENCE_ENTRY, + sizeof(SequenceEntryRecord)}; + *reinterpret_cast(buffer) = recordHeader; + buffer += sizeof(UndoRecordHeader); + const SequenceEntryRecord sequenceEntryRecord{&sequenceEntry, data}; + *reinterpret_cast(buffer) = sequenceEntryRecord; +} + +void UndoBuffer::createInsertInfo(node_group_idx_t nodeGroupIdx, row_idx_t startRow, + row_idx_t numRows, const VersionRecordHandler* versionRecordHandler) { + createVersionInfo(UndoRecordType::INSERT_INFO, startRow, numRows, versionRecordHandler, + nodeGroupIdx); +} + +void UndoBuffer::createDeleteInfo(node_group_idx_t nodeGroupIdx, row_idx_t startRow, + row_idx_t numRows, const VersionRecordHandler* versionRecordHandler) { + createVersionInfo(UndoRecordType::DELETE_INFO, startRow, numRows, versionRecordHandler, + nodeGroupIdx); +} + +void UndoBuffer::createVersionInfo(const UndoRecordType recordType, row_idx_t startRow, + row_idx_t numRows, const VersionRecordHandler* versionRecordHandler, + node_group_idx_t nodeGroupIdx) { + KU_ASSERT(versionRecordHandler); + auto buffer = createUndoRecord(sizeof(UndoRecordHeader) + sizeof(VersionRecord)); + const UndoRecordHeader recordHeader{recordType, sizeof(VersionRecord)}; + *reinterpret_cast(buffer) = recordHeader; + buffer += sizeof(UndoRecordHeader); + *reinterpret_cast(buffer) = + VersionRecord{startRow, numRows, nodeGroupIdx, versionRecordHandler}; +} + +void UndoBuffer::createVectorUpdateInfo(UpdateInfo* updateInfo, const idx_t vectorIdx, + VectorUpdateInfo* vectorUpdateInfo, transaction_t version) { + auto buffer = createUndoRecord(sizeof(UndoRecordHeader) + sizeof(VectorUpdateRecord)); + const UndoRecordHeader recordHeader{UndoRecordType::UPDATE_INFO, sizeof(VectorUpdateRecord)}; + *reinterpret_cast(buffer) = recordHeader; + buffer += sizeof(UndoRecordHeader); + const VectorUpdateRecord vectorUpdateRecord{updateInfo, vectorIdx, vectorUpdateInfo, version}; + *reinterpret_cast(buffer) = vectorUpdateRecord; +} + +uint8_t* UndoBuffer::createUndoRecord(const uint64_t size) { + std::unique_lock xLck{mtx}; + if (memoryBuffers.empty() || !memoryBuffers.back().canFit(size)) { + auto capacity = UndoMemoryBuffer::UNDO_MEMORY_BUFFER_INIT_CAPACITY; + while (size > capacity) { + capacity *= 2; + } + // We need to allocate a new memory buffer. 
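+        // Its capacity grows geometrically: start from UNDO_MEMORY_BUFFER_INIT_CAPACITY
+        // and keep doubling (loop above) until the requested record fits.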
+ memoryBuffers.emplace_back(mm->allocateBuffer(false, capacity), capacity); + } + const auto res = + memoryBuffers.back().getDataUnsafe() + memoryBuffers.back().getCurrentPosition(); + memoryBuffers.back().moveCurrentPosition(size); + return res; +} + +void UndoBuffer::commit(transaction_t commitTS) const { + UndoBufferIterator iterator{*this}; + iterator.iterate([&](UndoRecordType entryType, uint8_t const* entry) { + commitRecord(entryType, entry, commitTS); + }); +} + +void UndoBuffer::rollback(ClientContext* context) const { + UndoBufferIterator iterator{*this}; + iterator.reverseIterate([&](UndoRecordType entryType, uint8_t const* entry) { + rollbackRecord(context, entryType, entry); + }); +} + +void UndoBuffer::commitRecord(UndoRecordType recordType, const uint8_t* record, + transaction_t commitTS) { + switch (recordType) { + case UndoRecordType::CATALOG_ENTRY: { + commitCatalogEntryRecord(record, commitTS); + } break; + case UndoRecordType::SEQUENCE_ENTRY: { + commitSequenceEntry(record, commitTS); + } break; + case UndoRecordType::INSERT_INFO: + case UndoRecordType::DELETE_INFO: { + commitVersionInfo(recordType, record, commitTS); + } break; + case UndoRecordType::UPDATE_INFO: { + commitVectorUpdateInfo(record, commitTS); + } break; + default: + KU_UNREACHABLE; + } +} + +void UndoBuffer::commitCatalogEntryRecord(const uint8_t* record, const transaction_t commitTS) { + const auto& [_, catalogEntry] = *reinterpret_cast(record); + const auto newCatalogEntry = catalogEntry->getNext(); + KU_ASSERT(newCatalogEntry); + newCatalogEntry->setTimestamp(commitTS); +} + +void UndoBuffer::commitVersionInfo(UndoRecordType recordType, const uint8_t* record, + transaction_t commitTS) { + const auto& undoRecord = *reinterpret_cast(record); + switch (recordType) { + case UndoRecordType::INSERT_INFO: { + undoRecord.versionRecordHandler->applyFuncToChunkedGroups(&ChunkedNodeGroup::commitInsert, + undoRecord.nodeGroupIdx, undoRecord.startRow, undoRecord.numRows, commitTS); + } break; + case UndoRecordType::DELETE_INFO: { + undoRecord.versionRecordHandler->applyFuncToChunkedGroups(&ChunkedNodeGroup::commitDelete, + undoRecord.nodeGroupIdx, undoRecord.startRow, undoRecord.numRows, commitTS); + } break; + default: { + KU_UNREACHABLE; + } + } +} + +void UndoBuffer::commitVectorUpdateInfo(const uint8_t* record, transaction_t commitTS) { + auto& undoRecord = *reinterpret_cast(record); + KU_ASSERT(undoRecord.updateInfo); + KU_ASSERT(undoRecord.vectorUpdateInfo); + undoRecord.updateInfo->commit(undoRecord.vectorIdx, undoRecord.vectorUpdateInfo, commitTS); +} + +void UndoBuffer::rollbackRecord(ClientContext* context, const UndoRecordType recordType, + const uint8_t* record) { + switch (recordType) { + case UndoRecordType::CATALOG_ENTRY: { + rollbackCatalogEntryRecord(record); + } break; + case UndoRecordType::SEQUENCE_ENTRY: { + rollbackSequenceEntry(record); + } break; + case UndoRecordType::INSERT_INFO: + case UndoRecordType::DELETE_INFO: { + rollbackVersionInfo(context, recordType, record); + } break; + case UndoRecordType::UPDATE_INFO: { + rollbackVectorUpdateInfo(record); + } break; + default: { + KU_UNREACHABLE; + } + } +} + +void UndoBuffer::rollbackCatalogEntryRecord(const uint8_t* record) { + const auto& [catalogSet, catalogEntry] = *reinterpret_cast(record); + const auto entryToRollback = catalogEntry->getNext(); + KU_ASSERT(entryToRollback); + if (entryToRollback->getNext()) { + // If entryToRollback has a newer entry (next) in the version chain. Simple remove + // entryToRollback from the chain. 
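+        // In the version chain `next` points towards newer versions and `prev` towards older
+        // ones, roughly:  older <-prev- entryToRollback <-prev- newerEntry
+        // so dropping entryToRollback only requires re-pointing newerEntry's prev at the entry
+        // that precedes entryToRollback.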
+ const auto newerEntry = entryToRollback->getNext(); + newerEntry->setPrev(entryToRollback->movePrev()); + } else { + // This is the beginning of the version chain. + auto olderEntry = entryToRollback->movePrev(); + catalogSet->eraseNoLock(catalogEntry->getName()); + if (olderEntry) { + catalogSet->emplaceNoLock(std::move(olderEntry)); + } + } +} + +void UndoBuffer::commitSequenceEntry(const uint8_t*, transaction_t) { + // DO NOTHING. +} + +void UndoBuffer::rollbackSequenceEntry(const uint8_t* entry) { + const auto& sequenceRecord = *reinterpret_cast(entry); + const auto sequenceEntry = sequenceRecord.sequenceEntry; + const auto& data = sequenceRecord.sequenceRollbackData; + sequenceEntry->rollbackVal(data.usageCount, data.currVal); +} + +void UndoBuffer::rollbackVersionInfo(ClientContext* context, UndoRecordType recordType, + const uint8_t* record) { + auto& undoRecord = *reinterpret_cast(record); + switch (recordType) { + case UndoRecordType::INSERT_INFO: { + undoRecord.versionRecordHandler->rollbackInsert(context, undoRecord.nodeGroupIdx, + undoRecord.startRow, undoRecord.numRows); + } break; + case UndoRecordType::DELETE_INFO: { + undoRecord.versionRecordHandler->applyFuncToChunkedGroups(&ChunkedNodeGroup::rollbackDelete, + undoRecord.nodeGroupIdx, undoRecord.startRow, undoRecord.numRows, + transaction::Transaction::Get(*context)->getCommitTS()); + } break; + default: { + KU_UNREACHABLE; + } + } +} + +void UndoBuffer::rollbackVectorUpdateInfo(const uint8_t* record) { + auto& undoRecord = *reinterpret_cast(record); + KU_ASSERT(undoRecord.updateInfo); + undoRecord.updateInfo->rollback(undoRecord.vectorIdx, undoRecord.version); +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/wal/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/wal/CMakeLists.txt new file mode 100644 index 0000000000..eae0dd8343 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/wal/CMakeLists.txt @@ -0,0 +1,12 @@ +add_library(lbug_storage_wal + OBJECT + checksum_reader.cpp + checksum_writer.cpp + local_wal.cpp + wal.cpp + wal_record.cpp + wal_replayer.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/wal/checksum_reader.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/wal/checksum_reader.cpp new file mode 100644 index 0000000000..1282c7507b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/wal/checksum_reader.cpp @@ -0,0 +1,62 @@ +#include "storage/wal/checksum_reader.h" + +#include + +#include "common/checksum.h" +#include "common/exception/storage.h" +#include "common/serializer/buffered_file.h" +#include "common/serializer/deserializer.h" +#include + +namespace lbug::storage { +static constexpr uint64_t INITIAL_BUFFER_SIZE = common::LBUG_PAGE_SIZE; + +ChecksumReader::ChecksumReader(common::FileInfo& fileInfo, MemoryManager& memoryManager, + std::string_view checksumMismatchMessage) + : deserializer(std::make_unique(fileInfo)), + entryBuffer(memoryManager.allocateBuffer(false, INITIAL_BUFFER_SIZE)), + checksumMismatchMessage(checksumMismatchMessage) {} + +static void resizeBufferIfNeeded(std::unique_ptr& entryBuffer, + uint64_t requestedSize) { + const auto currentBufferSize = entryBuffer->getBuffer().size_bytes(); + if (requestedSize > currentBufferSize) { + auto* memoryManager = entryBuffer->getMemoryManager(); + entryBuffer = memoryManager->allocateBuffer(false, std::bit_ceil(requestedSize)); + } +} + +void ChecksumReader::read(uint8_t* data, 
uint64_t size) { + deserializer.read(data, size); + if (currentEntrySize.has_value()) { + resizeBufferIfNeeded(entryBuffer, *currentEntrySize + size); + std::memcpy(entryBuffer->getData() + *currentEntrySize, data, size); + *currentEntrySize += size; + } +} + +bool ChecksumReader::finished() { + return deserializer.finished(); +} + +void ChecksumReader::onObjectBegin() { + currentEntrySize.emplace(0); +} + +void ChecksumReader::onObjectEnd() { + KU_ASSERT(currentEntrySize.has_value()); + const uint64_t computedChecksum = common::checksum(entryBuffer->getData(), *currentEntrySize); + uint64_t storedChecksum{}; + deserializer.deserializeValue(storedChecksum); + if (storedChecksum != computedChecksum) { + throw common::StorageException(std::string{checksumMismatchMessage}); + } + + currentEntrySize.reset(); +} + +uint64_t ChecksumReader::getReadOffset() const { + return deserializer.getReader()->cast()->getReadOffset(); +} + +} // namespace lbug::storage diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/wal/checksum_writer.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/wal/checksum_writer.cpp new file mode 100644 index 0000000000..99fd00a51d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/wal/checksum_writer.cpp @@ -0,0 +1,66 @@ +#include "storage/wal/checksum_writer.h" + +#include + +#include "common/checksum.h" +#include "common/serializer/serializer.h" +#include + +namespace lbug::storage { +static constexpr uint64_t INITIAL_BUFFER_SIZE = common::LBUG_PAGE_SIZE; + +ChecksumWriter::ChecksumWriter(std::shared_ptr outputWriter, + MemoryManager& memoryManager) + : outputSerializer(std::move(outputWriter)), + entryBuffer(memoryManager.allocateBuffer(false, INITIAL_BUFFER_SIZE)) {} + +static void resizeBufferIfNeeded(std::unique_ptr& entryBuffer, + uint64_t requestedSize) { + const auto currentBufferSize = entryBuffer->getBuffer().size_bytes(); + if (requestedSize > currentBufferSize) { + auto* memoryManager = entryBuffer->getMemoryManager(); + entryBuffer = memoryManager->allocateBuffer(false, std::bit_ceil(requestedSize)); + } +} + +void ChecksumWriter::write(const uint8_t* data, uint64_t size) { + if (currentEntrySize.has_value()) { + resizeBufferIfNeeded(entryBuffer, *currentEntrySize + size); + std::memcpy(entryBuffer->getData() + *currentEntrySize, data, size); + *currentEntrySize += size; + } else { + // The data we are writing does not need to be checksummed + outputSerializer.write(data, size); + } +} + +void ChecksumWriter::clear() { + currentEntrySize.reset(); + outputSerializer.getWriter()->clear(); +} + +void ChecksumWriter::flush() { + outputSerializer.getWriter()->flush(); +} + +void ChecksumWriter::onObjectBegin() { + currentEntrySize.emplace(0); +} + +void ChecksumWriter::onObjectEnd() { + KU_ASSERT(currentEntrySize.has_value()); + const auto checksum = common::checksum(entryBuffer->getData(), *currentEntrySize); + outputSerializer.write(entryBuffer->getData(), *currentEntrySize); + outputSerializer.serializeValue(checksum); + currentEntrySize.reset(); +} + +uint64_t ChecksumWriter::getSize() const { + return currentEntrySize.value_or(0) + outputSerializer.getWriter()->getSize(); +} + +void ChecksumWriter::sync() { + outputSerializer.getWriter()->sync(); +} + +} // namespace lbug::storage diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/wal/local_wal.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/wal/local_wal.cpp new file mode 100644 index 0000000000..94e9f88c69 --- /dev/null +++ 
b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/wal/local_wal.cpp @@ -0,0 +1,113 @@ +#include "storage/wal/local_wal.h" + +#include "binder/ddl/bound_alter_info.h" +#include "catalog/catalog_entry/sequence_catalog_entry.h" +#include "common/serializer/in_mem_file_writer.h" +#include "common/vector/value_vector.h" +#include "storage/wal/checksum_writer.h" + +using namespace lbug::catalog; +using namespace lbug::common; +using namespace lbug::binder; + +namespace lbug { +namespace storage { + +LocalWAL::LocalWAL(MemoryManager& mm, bool enableChecksums) + : inMemWriter(std::make_shared(mm)), + serializer(enableChecksums ? std::make_shared(inMemWriter, mm) : + std::static_pointer_cast(inMemWriter)) {} + +void LocalWAL::logBeginTransaction() { + BeginTransactionRecord walRecord; + addNewWALRecord(walRecord); +} + +void LocalWAL::logCommit() { + CommitRecord walRecord; + addNewWALRecord(walRecord); +} + +void LocalWAL::logCreateCatalogEntryRecord(CatalogEntry* catalogEntry, bool isInternal) { + CreateCatalogEntryRecord walRecord(catalogEntry, isInternal); + addNewWALRecord(walRecord); +} + +void LocalWAL::logDropCatalogEntryRecord(table_id_t tableID, CatalogEntryType type) { + DropCatalogEntryRecord walRecord(tableID, type); + addNewWALRecord(walRecord); +} + +void LocalWAL::logAlterCatalogEntryRecord(const BoundAlterInfo* alterInfo) { + AlterTableEntryRecord walRecord(alterInfo); + addNewWALRecord(walRecord); +} + +void LocalWAL::logTableInsertion(table_id_t tableID, TableType tableType, row_idx_t numRows, + const std::vector& vectors) { + TableInsertionRecord walRecord(tableID, tableType, numRows, vectors); + addNewWALRecord(walRecord); +} + +void LocalWAL::logNodeDeletion(table_id_t tableID, offset_t nodeOffset, ValueVector* pkVector) { + NodeDeletionRecord walRecord(tableID, nodeOffset, pkVector); + addNewWALRecord(walRecord); +} + +void LocalWAL::logNodeUpdate(table_id_t tableID, column_id_t columnID, offset_t nodeOffset, + ValueVector* propertyVector) { + NodeUpdateRecord walRecord(tableID, columnID, nodeOffset, propertyVector); + addNewWALRecord(walRecord); +} + +void LocalWAL::logRelDelete(table_id_t tableID, ValueVector* srcNodeVector, + ValueVector* dstNodeVector, ValueVector* relIDVector) { + RelDeletionRecord walRecord(tableID, srcNodeVector, dstNodeVector, relIDVector); + addNewWALRecord(walRecord); +} + +void LocalWAL::logRelDetachDelete(table_id_t tableID, RelDataDirection direction, + ValueVector* srcNodeVector) { + RelDetachDeleteRecord walRecord(tableID, direction, srcNodeVector); + addNewWALRecord(walRecord); +} + +void LocalWAL::logRelUpdate(table_id_t tableID, column_id_t columnID, ValueVector* srcNodeVector, + ValueVector* dstNodeVector, ValueVector* relIDVector, ValueVector* propertyVector) { + RelUpdateRecord walRecord(tableID, columnID, srcNodeVector, dstNodeVector, relIDVector, + propertyVector); + addNewWALRecord(walRecord); +} + +void LocalWAL::logUpdateSequenceRecord(sequence_id_t sequenceID, uint64_t kCount) { + UpdateSequenceRecord walRecord(sequenceID, kCount); + addNewWALRecord(walRecord); +} + +void LocalWAL::logLoadExtension(std::string path) { + LoadExtensionRecord walRecord(std::move(path)); + addNewWALRecord(walRecord); +} + +// NOLINTNEXTLINE(readability-make-member-function-const): semantically non-const function. 
+void LocalWAL::clear() { + std::unique_lock lck{mtx}; + serializer.getWriter()->clear(); +} + +uint64_t LocalWAL::getSize() { + std::unique_lock lck{mtx}; + return serializer.getWriter()->getSize(); +} + +// NOLINTNEXTLINE(readability-make-member-function-const): semantically non-const function. +void LocalWAL::addNewWALRecord(const WALRecord& walRecord) { + std::unique_lock lck{mtx}; + KU_ASSERT(walRecord.type != WALRecordType::INVALID_RECORD); + serializer.getWriter()->onObjectBegin(); + walRecord.serialize(serializer); + serializer.getWriter()->onObjectEnd(); +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/wal/wal.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/wal/wal.cpp new file mode 100644 index 0000000000..046378ef40 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/wal/wal.cpp @@ -0,0 +1,122 @@ +#include "storage/wal/wal.h" + +#include "common/file_system/file_info.h" +#include "common/file_system/virtual_file_system.h" +#include "common/serializer/buffered_file.h" +#include "common/serializer/in_mem_file_writer.h" +#include "main/client_context.h" +#include "main/database.h" +#include "main/db_config.h" +#include "storage/file_db_id_utils.h" +#include "storage/storage_manager.h" +#include "storage/storage_utils.h" +#include "storage/wal/checksum_writer.h" +#include "storage/wal/local_wal.h" + +using namespace lbug::common; + +namespace lbug { +namespace storage { + +WAL::WAL(const std::string& dbPath, bool readOnly, bool enableChecksums, VirtualFileSystem* vfs) + : walPath{StorageUtils::getWALFilePath(dbPath)}, + inMemory{main::DBConfig::isDBPathInMemory(dbPath)}, readOnly{readOnly}, vfs{vfs}, + enableChecksums(enableChecksums) {} + +WAL::~WAL() {} + +void WAL::logCommittedWAL(LocalWAL& localWAL, main::ClientContext* context) { + KU_ASSERT(!readOnly); + if (inMemory || localWAL.getSize() == 0) { + return; // No need to log empty WAL. + } + std::unique_lock lck{mtx}; + initWriter(context); + localWAL.inMemWriter->flush(*serializer->getWriter()); + flushAndSyncNoLock(); +} + +void WAL::logAndFlushCheckpoint(main::ClientContext* context) { + std::unique_lock lck{mtx}; + initWriter(context); + CheckpointRecord walRecord; + addNewWALRecordNoLock(walRecord); + flushAndSyncNoLock(); +} + +// NOLINTNEXTLINE(readability-make-member-function-const): semantically non-const function. +void WAL::clear() { + std::unique_lock lck{mtx}; + serializer->getWriter()->clear(); +} + +void WAL::reset() { + std::unique_lock lck{mtx}; + fileInfo.reset(); + serializer.reset(); + vfs->removeFileIfExists(walPath); +} + +// NOLINTNEXTLINE(readability-make-member-function-const): semantically non-const function. 
+void WAL::flushAndSyncNoLock() { + serializer->getWriter()->flush(); + serializer->getWriter()->sync(); +} + +uint64_t WAL::getFileSize() { + std::unique_lock lck{mtx}; + return serializer->getWriter()->getSize(); +} + +void WAL::writeHeader(main::ClientContext& context) { + serializer->getWriter()->onObjectBegin(); + FileDBIDUtils::writeDatabaseID(*serializer, + StorageManager::Get(context)->getOrInitDatabaseID(context)); + serializer->write(enableChecksums); + serializer->getWriter()->onObjectEnd(); +} + +void WAL::initWriter(main::ClientContext* context) { + if (serializer) { + return; + } + fileInfo = vfs->openFile(walPath, + FileOpenFlags(FileFlags::CREATE_IF_NOT_EXISTS | FileFlags::READ_ONLY | FileFlags::WRITE), + context); + + std::shared_ptr writer = std::make_shared(*fileInfo); + auto& bufferedWriter = writer->cast(); + if (enableChecksums) { + writer = std::make_shared(std::move(writer), *MemoryManager::Get(*context)); + } + serializer = std::make_unique(std::move(writer)); + + // Write the databaseID at the start of the WAL if needed + // This is used to ensure that when replaying the WAL matches the database + if (fileInfo->getFileSize() == 0) { + writeHeader(*context); + } + + // WAL should always be APPEND only. We don't want to overwrite the file as it may still + // contain records not replayed. This can happen if checkpoint is not triggered before the + // Database is closed last time. + bufferedWriter.setFileOffset(fileInfo->getFileSize()); +} + +// NOLINTNEXTLINE(readability-make-member-function-const): semantically non-const function. +void WAL::addNewWALRecordNoLock(const WALRecord& walRecord) { + KU_ASSERT(walRecord.type != WALRecordType::INVALID_RECORD); + KU_ASSERT(!inMemory); + KU_ASSERT(serializer != nullptr); + serializer->getWriter()->onObjectBegin(); + walRecord.serialize(*serializer); + serializer->getWriter()->onObjectEnd(); +} + +WAL* WAL::Get(const main::ClientContext& context) { + KU_ASSERT(context.getDatabase() && context.getDatabase()->getStorageManager()); + return &context.getDatabase()->getStorageManager()->getWAL(); +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/wal/wal_record.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/wal/wal_record.cpp new file mode 100644 index 0000000000..ccd12cb94c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/wal/wal_record.cpp @@ -0,0 +1,486 @@ +#include "storage/wal/wal_record.h" + +#include "catalog/catalog_entry/catalog_entry.h" +#include "common/exception/runtime.h" +#include "common/serializer/deserializer.h" +#include "common/serializer/serializer.h" +#include "main/client_context.h" +#include "storage/buffer_manager/memory_manager.h" + +using namespace lbug::common; +using namespace lbug::binder; + +namespace lbug { +namespace storage { + +void WALRecord::serialize(Serializer& serializer) const { + serializer.writeDebuggingInfo("type"); + serializer.write(type); +} + +std::unique_ptr WALRecord::deserialize(Deserializer& deserializer, + const main::ClientContext& clientContext) { + std::string key; + auto type = WALRecordType::INVALID_RECORD; + deserializer.getReader()->onObjectBegin(); + deserializer.validateDebuggingInfo(key, "type"); + deserializer.deserializeValue(type); + std::unique_ptr walRecord; + switch (type) { + case WALRecordType::BEGIN_TRANSACTION_RECORD: { + walRecord = BeginTransactionRecord::deserialize(deserializer); + } break; + case WALRecordType::COMMIT_RECORD: { + walRecord = CommitRecord::deserialize(deserializer); 
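+        // (Each case only reads the record-specific payload; the shared `type` field was
+        // consumed above and is assigned back onto the result after the switch.)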
+ } break; + case WALRecordType::CREATE_CATALOG_ENTRY_RECORD: { + walRecord = CreateCatalogEntryRecord::deserialize(deserializer); + } break; + case WALRecordType::DROP_CATALOG_ENTRY_RECORD: { + walRecord = DropCatalogEntryRecord::deserialize(deserializer); + } break; + case WALRecordType::ALTER_TABLE_ENTRY_RECORD: { + walRecord = AlterTableEntryRecord::deserialize(deserializer); + } break; + case WALRecordType::TABLE_INSERTION_RECORD: { + walRecord = TableInsertionRecord::deserialize(deserializer, clientContext); + } break; + case WALRecordType::NODE_DELETION_RECORD: { + walRecord = NodeDeletionRecord::deserialize(deserializer, clientContext); + } break; + case WALRecordType::NODE_UPDATE_RECORD: { + walRecord = NodeUpdateRecord::deserialize(deserializer, clientContext); + } break; + case WALRecordType::REL_DELETION_RECORD: { + walRecord = RelDeletionRecord::deserialize(deserializer, clientContext); + } break; + case WALRecordType::REL_DETACH_DELETE_RECORD: { + walRecord = RelDetachDeleteRecord::deserialize(deserializer, clientContext); + } break; + case WALRecordType::REL_UPDATE_RECORD: { + walRecord = RelUpdateRecord::deserialize(deserializer, clientContext); + } break; + case WALRecordType::COPY_TABLE_RECORD: { + walRecord = CopyTableRecord::deserialize(deserializer); + } break; + case WALRecordType::CHECKPOINT_RECORD: { + walRecord = CheckpointRecord::deserialize(deserializer); + } break; + case WALRecordType::UPDATE_SEQUENCE_RECORD: { + walRecord = UpdateSequenceRecord::deserialize(deserializer); + } break; + case WALRecordType::LOAD_EXTENSION_RECORD: { + walRecord = LoadExtensionRecord::deserialize(deserializer); + } break; + case WALRecordType::INVALID_RECORD: { + throw RuntimeException("Corrupted wal file. Read out invalid WAL record type."); + } + default: { + KU_UNREACHABLE; + } + } + walRecord->type = type; + deserializer.getReader()->onObjectEnd(); + return walRecord; +} + +void BeginTransactionRecord::serialize(Serializer& serializer) const { + WALRecord::serialize(serializer); +} + +std::unique_ptr BeginTransactionRecord::deserialize(Deserializer&) { + return std::make_unique(); +} + +void CommitRecord::serialize(Serializer& serializer) const { + WALRecord::serialize(serializer); +} + +std::unique_ptr CommitRecord::deserialize(Deserializer&) { + return std::make_unique(); +} + +void CheckpointRecord::serialize(Serializer& serializer) const { + WALRecord::serialize(serializer); +} + +std::unique_ptr CheckpointRecord::deserialize(Deserializer&) { + return std::make_unique(); +} + +void CreateCatalogEntryRecord::serialize(Serializer& serializer) const { + WALRecord::serialize(serializer); + catalogEntry->serialize(serializer); + serializer.serializeValue(isInternal); +} + +std::unique_ptr CreateCatalogEntryRecord::deserialize( + Deserializer& deserializer) { + auto retVal = std::make_unique(); + retVal->ownedCatalogEntry = catalog::CatalogEntry::deserialize(deserializer); + bool isInternal = false; + deserializer.deserializeValue(isInternal); + retVal->isInternal = isInternal; + return retVal; +} + +void CopyTableRecord::serialize(Serializer& serializer) const { + WALRecord::serialize(serializer); + serializer.write(tableID); +} + +std::unique_ptr CopyTableRecord::deserialize(Deserializer& deserializer) { + auto retVal = std::make_unique(); + deserializer.deserializeValue(retVal->tableID); + return retVal; +} + +void DropCatalogEntryRecord::serialize(Serializer& serializer) const { + WALRecord::serialize(serializer); + serializer.write(entryID); + serializer.write(entryType); +} 
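+
+// All records in this file follow the same symmetric pattern: serialize() first calls
+// WALRecord::serialize() to emit the record type and then writes the record's own fields in a
+// fixed order; deserialize() reads them back in exactly that order. The enclosing
+// onObjectBegin()/onObjectEnd() calls (see LocalWAL::addNewWALRecord and WALRecord::deserialize)
+// are what delimit a single record for the optional per-record checksum.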
+ +std::unique_ptr DropCatalogEntryRecord::deserialize( + Deserializer& deserializer) { + auto retVal = std::make_unique(); + deserializer.deserializeValue(retVal->entryID); + deserializer.deserializeValue(retVal->entryType); + return retVal; +} + +static void serializeAlterExtraInfo(Serializer& serializer, const BoundAlterInfo* alterInfo) { + const auto* extraInfo = alterInfo->extraInfo.get(); + serializer.write(alterInfo->alterType); + serializer.write(alterInfo->tableName); + switch (alterInfo->alterType) { + case AlterType::ADD_PROPERTY: { + auto addInfo = extraInfo->constPtrCast(); + addInfo->propertyDefinition.serialize(serializer); + } break; + case AlterType::DROP_PROPERTY: { + auto dropInfo = extraInfo->constPtrCast(); + serializer.write(dropInfo->propertyName); + } break; + case AlterType::RENAME_PROPERTY: { + auto renamePropertyInfo = extraInfo->constPtrCast(); + serializer.write(renamePropertyInfo->newName); + serializer.write(renamePropertyInfo->oldName); + } break; + case AlterType::COMMENT: { + auto commentInfo = extraInfo->constPtrCast(); + serializer.write(commentInfo->comment); + } break; + case AlterType::RENAME: { + auto renameTableInfo = extraInfo->constPtrCast(); + serializer.write(renameTableInfo->newName); + } break; + case AlterType::ADD_FROM_TO_CONNECTION: + case AlterType::DROP_FROM_TO_CONNECTION: { + auto connectionInfo = extraInfo->constPtrCast(); + serializer.write(connectionInfo->fromTableID); + serializer.write(connectionInfo->toTableID); + } break; + default: { + KU_UNREACHABLE; + } + } +} + +static decltype(auto) deserializeAlterRecord(Deserializer& deserializer) { + auto alterType = AlterType::INVALID; + std::string tableName; + deserializer.deserializeValue(alterType); + deserializer.deserializeValue(tableName); + std::unique_ptr extraInfo; + switch (alterType) { + case AlterType::ADD_PROPERTY: { + auto definition = PropertyDefinition::deserialize(deserializer); + extraInfo = std::make_unique(std::move(definition), nullptr); + } break; + case AlterType::DROP_PROPERTY: { + std::string propertyName; + deserializer.deserializeValue(propertyName); + extraInfo = std::make_unique(std::move(propertyName)); + } break; + case AlterType::RENAME_PROPERTY: { + std::string newName; + std::string oldName; + deserializer.deserializeValue(newName); + deserializer.deserializeValue(oldName); + extraInfo = + std::make_unique(std::move(newName), std::move(oldName)); + } break; + case AlterType::COMMENT: { + std::string comment; + deserializer.deserializeValue(comment); + extraInfo = std::make_unique(std::move(comment)); + } break; + case AlterType::RENAME: { + std::string newName; + deserializer.deserializeValue(newName); + extraInfo = std::make_unique(std::move(newName)); + } break; + case AlterType::ADD_FROM_TO_CONNECTION: + case AlterType::DROP_FROM_TO_CONNECTION: { + table_id_t fromTableID = INVALID_TABLE_ID; + table_id_t toTableID = INVALID_TABLE_ID; + deserializer.deserializeValue(fromTableID); + deserializer.deserializeValue(toTableID); + extraInfo = std::make_unique(fromTableID, toTableID); + } break; + default: { + KU_UNREACHABLE; + } + } + return std::make_tuple(alterType, tableName, std::move(extraInfo)); +} + +void AlterTableEntryRecord::serialize(Serializer& serializer) const { + WALRecord::serialize(serializer); + serializeAlterExtraInfo(serializer, alterInfo); +} + +std::unique_ptr AlterTableEntryRecord::deserialize( + Deserializer& deserializer) { + auto [alterType, tableName, extraInfo] = deserializeAlterRecord(deserializer); + auto retval = 
std::make_unique(); + retval->ownedAlterInfo = + std::make_unique(alterType, tableName, std::move(extraInfo)); + return retval; +} + +void UpdateSequenceRecord::serialize(Serializer& serializer) const { + WALRecord::serialize(serializer); + serializer.write(sequenceID); + serializer.write(kCount); +} + +std::unique_ptr UpdateSequenceRecord::deserialize( + Deserializer& deserializer) { + auto retVal = std::make_unique(); + deserializer.deserializeValue(retVal->sequenceID); + deserializer.deserializeValue(retVal->kCount); + return retVal; +} + +void TableInsertionRecord::serialize(Serializer& serializer) const { + WALRecord::serialize(serializer); + serializer.writeDebuggingInfo("table_id"); + serializer.write(tableID); + serializer.writeDebuggingInfo("table_type"); + serializer.write(tableType); + serializer.writeDebuggingInfo("num_rows"); + serializer.write(numRows); + serializer.writeDebuggingInfo("num_vectors"); + serializer.write(vectors.size()); + for (auto& vector : vectors) { + vector->serialize(serializer); + } +} + +std::unique_ptr TableInsertionRecord::deserialize(Deserializer& deserializer, + const main::ClientContext& clientContext) { + std::string key; + table_id_t tableID = INVALID_TABLE_ID; + auto tableType = TableType::UNKNOWN; + row_idx_t numRows = INVALID_ROW_IDX; + idx_t numVectors = 0; + std::vector> valueVectors; + deserializer.validateDebuggingInfo(key, "table_id"); + deserializer.deserializeValue(tableID); + deserializer.validateDebuggingInfo(key, "table_type"); + deserializer.deserializeValue(tableType); + deserializer.validateDebuggingInfo(key, "num_rows"); + deserializer.deserializeValue(numRows); + deserializer.validateDebuggingInfo(key, "num_vectors"); + deserializer.deserializeValue(numVectors); + auto resultChunkState = DataChunkState::getSingleValueDataChunkState(); + valueVectors.reserve(numVectors); + for (auto i = 0u; i < numVectors; i++) { + valueVectors.push_back(ValueVector::deSerialize(deserializer, + MemoryManager::Get(clientContext), resultChunkState)); + } + return std::make_unique(tableID, tableType, numRows, + std::move(valueVectors)); +} + +void NodeDeletionRecord::serialize(Serializer& serializer) const { + WALRecord::serialize(serializer); + serializer.writeDebuggingInfo("table_id"); + serializer.write(tableID); + serializer.writeDebuggingInfo("node_offset"); + serializer.write(nodeOffset); + serializer.writeDebuggingInfo("pk_vector"); + pkVector->serialize(serializer); +} + +std::unique_ptr NodeDeletionRecord::deserialize(Deserializer& deserializer, + const main::ClientContext& clientContext) { + std::string key; + table_id_t tableID = INVALID_TABLE_ID; + offset_t nodeOffset = INVALID_OFFSET; + + deserializer.validateDebuggingInfo(key, "table_id"); + deserializer.deserializeValue(tableID); + deserializer.validateDebuggingInfo(key, "node_offset"); + deserializer.deserializeValue(nodeOffset); + deserializer.validateDebuggingInfo(key, "pk_vector"); + auto resultChunkState = std::make_shared(); + auto ownedVector = + ValueVector::deSerialize(deserializer, MemoryManager::Get(clientContext), resultChunkState); + return std::make_unique(tableID, nodeOffset, std::move(ownedVector)); +} + +void NodeUpdateRecord::serialize(Serializer& serializer) const { + WALRecord::serialize(serializer); + serializer.writeDebuggingInfo("table_id"); + serializer.write(tableID); + serializer.writeDebuggingInfo("column_id"); + serializer.write(columnID); + serializer.writeDebuggingInfo("node_offset"); + serializer.write(nodeOffset); + 
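+    // The updated value is written out as a full ValueVector so that WAL replay can re-apply
+    // the update without re-running the original statement; see the matching deserialize()
+    // below and WALReplayer::replayNodeUpdateRecord.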
serializer.writeDebuggingInfo("property_vector"); + propertyVector->serialize(serializer); +} + +std::unique_ptr NodeUpdateRecord::deserialize(Deserializer& deserializer, + const main::ClientContext& clientContext) { + std::string key; + table_id_t tableID = INVALID_TABLE_ID; + column_id_t columnID = INVALID_COLUMN_ID; + offset_t nodeOffset = INVALID_OFFSET; + + deserializer.validateDebuggingInfo(key, "table_id"); + deserializer.deserializeValue(tableID); + deserializer.validateDebuggingInfo(key, "column_id"); + deserializer.deserializeValue(columnID); + deserializer.validateDebuggingInfo(key, "node_offset"); + deserializer.deserializeValue(nodeOffset); + deserializer.validateDebuggingInfo(key, "property_vector"); + auto resultChunkState = std::make_shared(); + auto ownedVector = + ValueVector::deSerialize(deserializer, MemoryManager::Get(clientContext), resultChunkState); + return std::make_unique(tableID, columnID, nodeOffset, + std::move(ownedVector)); +} + +void RelDeletionRecord::serialize(Serializer& serializer) const { + WALRecord::serialize(serializer); + serializer.writeDebuggingInfo("table_id"); + serializer.write(tableID); + serializer.writeDebuggingInfo("src_node_vector"); + srcNodeIDVector->serialize(serializer); + serializer.writeDebuggingInfo("dst_node_vector"); + dstNodeIDVector->serialize(serializer); + serializer.writeDebuggingInfo("rel_id_vector"); + relIDVector->serialize(serializer); +} + +std::unique_ptr RelDeletionRecord::deserialize(Deserializer& deserializer, + const main::ClientContext& clientContext) { + std::string key; + table_id_t tableID = INVALID_TABLE_ID; + + deserializer.validateDebuggingInfo(key, "table_id"); + deserializer.deserializeValue(tableID); + deserializer.validateDebuggingInfo(key, "src_node_vector"); + auto resultChunkState = std::make_shared(); + auto srcNodeIDVector = + ValueVector::deSerialize(deserializer, MemoryManager::Get(clientContext), resultChunkState); + deserializer.validateDebuggingInfo(key, "dst_node_vector"); + auto dstNodeIDVector = + ValueVector::deSerialize(deserializer, MemoryManager::Get(clientContext), resultChunkState); + deserializer.validateDebuggingInfo(key, "rel_id_vector"); + auto relIDVector = + ValueVector::deSerialize(deserializer, MemoryManager::Get(clientContext), resultChunkState); + return std::make_unique(tableID, std::move(srcNodeIDVector), + std::move(dstNodeIDVector), std::move(relIDVector)); +} + +void RelDetachDeleteRecord::serialize(Serializer& serializer) const { + WALRecord::serialize(serializer); + serializer.writeDebuggingInfo("table_id"); + serializer.write(tableID); + serializer.writeDebuggingInfo("direction"); + serializer.write(direction); + serializer.writeDebuggingInfo("src_node_vector"); + srcNodeIDVector->serialize(serializer); +} + +std::unique_ptr RelDetachDeleteRecord::deserialize( + Deserializer& deserializer, const main::ClientContext& clientContext) { + std::string key; + table_id_t tableID = INVALID_TABLE_ID; + auto direction = RelDataDirection::INVALID; + + deserializer.validateDebuggingInfo(key, "table_id"); + deserializer.deserializeValue(tableID); + deserializer.validateDebuggingInfo(key, "direction"); + deserializer.deserializeValue(direction); + deserializer.validateDebuggingInfo(key, "src_node_vector"); + auto resultChunkState = std::make_shared(); + auto srcNodeIDVector = + ValueVector::deSerialize(deserializer, MemoryManager::Get(clientContext), resultChunkState); + return std::make_unique(tableID, direction, std::move(srcNodeIDVector)); +} + +void 
RelUpdateRecord::serialize(Serializer& serializer) const { + WALRecord::serialize(serializer); + serializer.writeDebuggingInfo("table_id"); + serializer.write(tableID); + serializer.writeDebuggingInfo("column_id"); + serializer.write(columnID); + serializer.writeDebuggingInfo("src_node_vector"); + srcNodeIDVector->serialize(serializer); + serializer.writeDebuggingInfo("dst_node_vector"); + dstNodeIDVector->serialize(serializer); + serializer.writeDebuggingInfo("rel_id_vector"); + relIDVector->serialize(serializer); + serializer.writeDebuggingInfo("property_vector"); + propertyVector->serialize(serializer); +} + +std::unique_ptr RelUpdateRecord::deserialize(Deserializer& deserializer, + const main::ClientContext& clientContext) { + std::string key; + table_id_t tableID = INVALID_TABLE_ID; + column_id_t columnID = INVALID_COLUMN_ID; + + deserializer.validateDebuggingInfo(key, "table_id"); + deserializer.deserializeValue(tableID); + deserializer.validateDebuggingInfo(key, "column_id"); + deserializer.deserializeValue(columnID); + deserializer.validateDebuggingInfo(key, "src_node_vector"); + auto resultChunkState = std::make_shared(); + auto srcNodeIDVector = + ValueVector::deSerialize(deserializer, MemoryManager::Get(clientContext), resultChunkState); + deserializer.validateDebuggingInfo(key, "dst_node_vector"); + auto dstNodeIDVector = + ValueVector::deSerialize(deserializer, MemoryManager::Get(clientContext), resultChunkState); + deserializer.validateDebuggingInfo(key, "rel_id_vector"); + auto relIDVector = + ValueVector::deSerialize(deserializer, MemoryManager::Get(clientContext), resultChunkState); + deserializer.validateDebuggingInfo(key, "property_vector"); + auto propertyVector = + ValueVector::deSerialize(deserializer, MemoryManager::Get(clientContext), resultChunkState); + return std::make_unique(tableID, columnID, std::move(srcNodeIDVector), + std::move(dstNodeIDVector), std::move(relIDVector), std::move(propertyVector)); +} + +void LoadExtensionRecord::serialize(Serializer& serializer) const { + WALRecord::serialize(serializer); + serializer.writeDebuggingInfo("path"); + serializer.write(path); +} + +std::unique_ptr LoadExtensionRecord::deserialize(Deserializer& deserializer) { + std::string key; + deserializer.validateDebuggingInfo(key, "path"); + std::string path; + deserializer.deserializeValue(path); + return std::make_unique(std::move(path)); +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/storage/wal/wal_replayer.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/wal/wal_replayer.cpp new file mode 100644 index 0000000000..6f1996dd44 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/storage/wal/wal_replayer.cpp @@ -0,0 +1,592 @@ +#include "storage/wal/wal_replayer.h" + +#include "binder/binder.h" +#include "catalog/catalog_entry/scalar_macro_catalog_entry.h" +#include "catalog/catalog_entry/sequence_catalog_entry.h" +#include "catalog/catalog_entry/table_catalog_entry.h" +#include "catalog/catalog_entry/type_catalog_entry.h" +#include "common/file_system/file_info.h" +#include "common/file_system/file_system.h" +#include "common/file_system/virtual_file_system.h" +#include "common/serializer/buffered_file.h" +#include "extension/extension_manager.h" +#include "main/client_context.h" +#include "processor/expression_mapper.h" +#include "storage/file_db_id_utils.h" +#include "storage/local_storage/local_rel_table.h" +#include "storage/storage_manager.h" +#include "storage/table/node_table.h" +#include 
"storage/table/rel_table.h" +#include "storage/wal/checksum_reader.h" +#include "storage/wal/wal_record.h" +#include "transaction/transaction_context.h" + +using namespace lbug::binder; +using namespace lbug::catalog; +using namespace lbug::common; +using namespace lbug::processor; +using namespace lbug::storage; +using namespace lbug::transaction; + +namespace lbug { +namespace storage { + +static constexpr std::string_view checksumMismatchMessage = + "Checksum verification failed, the WAL file is corrupted."; + +WALReplayer::WALReplayer(main::ClientContext& clientContext) : clientContext{clientContext} { + walPath = StorageUtils::getWALFilePath(clientContext.getDatabasePath()); + shadowFilePath = StorageUtils::getShadowFilePath(clientContext.getDatabasePath()); +} + +static WALHeader readWALHeader(Deserializer& deserializer) { + WALHeader header{}; + deserializer.deserializeValue(header.databaseID); + + // It is possible to read a value other than 0/1 when deserializing the flag + // This causes some weird behaviours with some toolchains so we manually do the conversion here + uint8_t enableChecksumsBytes = 0; + deserializer.deserializeValue(enableChecksumsBytes); + header.enableChecksums = enableChecksumsBytes != 0; + + return header; +} + +static Deserializer initDeserializer(FileInfo& fileInfo, main::ClientContext& clientContext, + bool enableChecksums) { + if (enableChecksums) { + return Deserializer{std::make_unique(fileInfo, + *MemoryManager::Get(clientContext), checksumMismatchMessage)}; + } else { + return Deserializer{std::make_unique(fileInfo)}; + } +} + +static void checkWALHeader(const WALHeader& header, bool enableChecksums) { + if (enableChecksums != header.enableChecksums) { + throw RuntimeException(stringFormat( + "The database you are trying to open was serialized with enableChecksums={} but you " + "are trying to open it with enableChecksums={}. Please open your database using the " + "correct enableChecksums config. If you wish to change this for your database, please " + "use the export/import functionality.", + TypeUtils::toString(header.enableChecksums), TypeUtils::toString(enableChecksums))); + } +} + +static uint64_t getReadOffset(Deserializer& deSer, bool enableChecksums) { + if (enableChecksums) { + return deSer.getReader()->cast()->getReadOffset(); + } else { + return deSer.getReader()->cast()->getReadOffset(); + } +} + +void WALReplayer::replay(bool throwOnWalReplayFailure, bool enableChecksums) const { + auto vfs = VirtualFileSystem::GetUnsafe(clientContext); + Checkpointer checkpointer(clientContext); + // First, check if the WAL file exists. If it does not, we can safely remove the shadow file. + if (!vfs->fileOrPathExists(walPath, &clientContext)) { + removeFileIfExists(shadowFilePath); + // Read the checkpointed data from the disk. + checkpointer.readCheckpoint(); + return; + } + // If the WAL file exists, we need to replay it. + auto fileInfo = openWALFile(); + // Check if the wal file is empty. If so, we do not need to replay anything. + if (fileInfo->getFileSize() == 0) { + removeWALAndShadowFiles(); + // Read the checkpointed data from the disk. + checkpointer.readCheckpoint(); + return; + } + // A previous unclean exit may have left non-durable contents in the WAL, so before we start + // replaying the WAL records, make a best-effort attempt at ensuring the WAL is fully durable. + syncWALFile(*fileInfo); + + // Start replaying the WAL records. 
+ try { + // First, we dry run the replay to find out the offset of the last record that was + // CHECKPOINT or COMMIT. + auto [offsetDeserialized, isLastRecordCheckpoint] = + dryReplay(*fileInfo, throwOnWalReplayFailure, enableChecksums); + if (isLastRecordCheckpoint) { + // If the last record is a checkpoint, we resume by replaying the shadow file. + ShadowFile::replayShadowPageRecords(clientContext); + removeWALAndShadowFiles(); + // Re-read checkpointed data from disk again as now the shadow file is applied. + checkpointer.readCheckpoint(); + } else { + // There is no checkpoint record, so we should remove the shadow file if it exists. + removeFileIfExists(shadowFilePath); + // Read the checkpointed data from the disk. + checkpointer.readCheckpoint(); + // Resume by replaying the WAL file from the beginning until the last COMMIT record. + Deserializer deserializer = initDeserializer(*fileInfo, clientContext, enableChecksums); + + if (offsetDeserialized > 0) { + // Make sure the WAL file is for the current database + deserializer.getReader()->onObjectBegin(); + const auto walHeader = readWALHeader(deserializer); + FileDBIDUtils::verifyDatabaseID(*fileInfo, + StorageManager::Get(clientContext)->getOrInitDatabaseID(clientContext), + walHeader.databaseID); + deserializer.getReader()->onObjectEnd(); + } + + while (getReadOffset(deserializer, enableChecksums) < offsetDeserialized) { + KU_ASSERT(!deserializer.finished()); + auto walRecord = WALRecord::deserialize(deserializer, clientContext); + replayWALRecord(*walRecord); + } + // After replaying all the records, we should truncate the WAL file to the last + // COMMIT/CHECKPOINT record. + truncateWALFile(*fileInfo, offsetDeserialized); + } + } catch (const std::exception&) { + auto transactionContext = TransactionContext::Get(clientContext); + if (transactionContext->hasActiveTransaction()) { + // Handle the case that some transaction went during replaying. We should roll back + // under this case. Usually this shouldn't happen, but it is possible if we have a bug + // with the replay logic. This is to handle cases like that so we don't corrupt + // transactions that have been replayed. + transactionContext->rollback(); + } + throw; + } +} + +WALReplayer::WALReplayInfo WALReplayer::dryReplay(FileInfo& fileInfo, bool throwOnWalReplayFailure, + bool enableChecksums) const { + uint64_t offsetDeserialized = 0; + bool isLastRecordCheckpoint = false; + try { + Deserializer deserializer = initDeserializer(fileInfo, clientContext, enableChecksums); + + // Skip the databaseID here, we'll verify it when we actually replay + deserializer.getReader()->onObjectBegin(); + const auto walHeader = readWALHeader(deserializer); + checkWALHeader(walHeader, enableChecksums); + deserializer.getReader()->onObjectEnd(); + + bool finishedDeserializing = deserializer.finished(); + while (!finishedDeserializing) { + auto walRecord = WALRecord::deserialize(deserializer, clientContext); + finishedDeserializing = deserializer.finished(); + switch (walRecord->type) { + case WALRecordType::CHECKPOINT_RECORD: { + KU_ASSERT(finishedDeserializing); + // If we reach a checkpoint record, we can stop replaying. + isLastRecordCheckpoint = true; + finishedDeserializing = true; + offsetDeserialized = getReadOffset(deserializer, enableChecksums); + } break; + case WALRecordType::COMMIT_RECORD: { + // Update the offset to the end of the last commit record. + offsetDeserialized = getReadOffset(deserializer, enableChecksums); + } break; + default: { + // DO NOTHING. 
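+                // Other record types never move the durable boundary: the dry run's only job
+                // is to locate where the last COMMIT (or a trailing CHECKPOINT) ends, so that
+                // the second pass in replay() applies records up to that offset and everything
+                // after it can be truncated away.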
+ } + } + } + } catch (...) { + // If we hit an exception while deserializing, we assume that the WAL file is (partially) + // corrupted. This should only happen for records of the last transaction recorded. + if (throwOnWalReplayFailure) { + throw; + } + } + return {offsetDeserialized, isLastRecordCheckpoint}; +} + +void WALReplayer::replayWALRecord(WALRecord& walRecord) const { + switch (walRecord.type) { + case WALRecordType::BEGIN_TRANSACTION_RECORD: { + TransactionContext::Get(clientContext)->beginRecoveryTransaction(); + } break; + case WALRecordType::COMMIT_RECORD: { + TransactionContext::Get(clientContext)->commit(); + } break; + case WALRecordType::CREATE_CATALOG_ENTRY_RECORD: { + replayCreateCatalogEntryRecord(walRecord); + } break; + case WALRecordType::DROP_CATALOG_ENTRY_RECORD: { + replayDropCatalogEntryRecord(walRecord); + } break; + case WALRecordType::ALTER_TABLE_ENTRY_RECORD: { + replayAlterTableEntryRecord(walRecord); + } break; + case WALRecordType::TABLE_INSERTION_RECORD: { + replayTableInsertionRecord(walRecord); + } break; + case WALRecordType::NODE_DELETION_RECORD: { + replayNodeDeletionRecord(walRecord); + } break; + case WALRecordType::NODE_UPDATE_RECORD: { + replayNodeUpdateRecord(walRecord); + } break; + case WALRecordType::REL_DELETION_RECORD: { + replayRelDeletionRecord(walRecord); + } break; + case WALRecordType::REL_DETACH_DELETE_RECORD: { + replayRelDetachDeletionRecord(walRecord); + } break; + case WALRecordType::REL_UPDATE_RECORD: { + replayRelUpdateRecord(walRecord); + } break; + case WALRecordType::COPY_TABLE_RECORD: { + replayCopyTableRecord(walRecord); + } break; + case WALRecordType::UPDATE_SEQUENCE_RECORD: { + replayUpdateSequenceRecord(walRecord); + } break; + case WALRecordType::LOAD_EXTENSION_RECORD: { + replayLoadExtensionRecord(walRecord); + } break; + case WALRecordType::CHECKPOINT_RECORD: { + // This record should not be replayed. It is only used to indicate that the previous records + // had been replayed and shadow files are created. 
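+        // replay() handles a trailing CHECKPOINT_RECORD by applying the shadow file instead of
+        // walking the WAL, so this branch is never taken.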
+ KU_UNREACHABLE; + } + default: + KU_UNREACHABLE; + } +} + +void WALReplayer::replayCreateCatalogEntryRecord(WALRecord& walRecord) const { + auto catalog = Catalog::Get(clientContext); + auto transaction = transaction::Transaction::Get(clientContext); + auto storageManager = StorageManager::Get(clientContext); + auto& record = walRecord.cast(); + switch (record.ownedCatalogEntry->getType()) { + case CatalogEntryType::NODE_TABLE_ENTRY: + case CatalogEntryType::REL_GROUP_ENTRY: { + auto& entry = record.ownedCatalogEntry->constCast(); + auto newEntry = catalog->createTableEntry(transaction, + entry.getBoundCreateTableInfo(transaction, record.isInternal)); + storageManager->createTable(newEntry->ptrCast()); + } break; + case CatalogEntryType::SCALAR_MACRO_ENTRY: { + auto& macroEntry = record.ownedCatalogEntry->constCast(); + catalog->addScalarMacroFunction(transaction, macroEntry.getName(), + macroEntry.getMacroFunction()->copy()); + } break; + case CatalogEntryType::SEQUENCE_ENTRY: { + auto& sequenceEntry = record.ownedCatalogEntry->constCast(); + catalog->createSequence(transaction, + sequenceEntry.getBoundCreateSequenceInfo(record.isInternal)); + } break; + case CatalogEntryType::TYPE_ENTRY: { + auto& typeEntry = record.ownedCatalogEntry->constCast(); + catalog->createType(transaction, typeEntry.getName(), typeEntry.getLogicalType().copy()); + } break; + case CatalogEntryType::INDEX_ENTRY: { + catalog->createIndex(transaction, std::move(record.ownedCatalogEntry)); + } break; + default: { + KU_UNREACHABLE; + } + } +} + +void WALReplayer::replayDropCatalogEntryRecord(const WALRecord& walRecord) const { + auto& dropEntryRecord = walRecord.constCast(); + auto catalog = Catalog::Get(clientContext); + auto transaction = transaction::Transaction::Get(clientContext); + const auto entryID = dropEntryRecord.entryID; + switch (dropEntryRecord.entryType) { + case CatalogEntryType::NODE_TABLE_ENTRY: + case CatalogEntryType::REL_GROUP_ENTRY: { + KU_ASSERT(Catalog::Get(clientContext)); + catalog->dropTableEntry(transaction, entryID); + } break; + case CatalogEntryType::SEQUENCE_ENTRY: { + catalog->dropSequence(transaction, entryID); + } break; + case CatalogEntryType::INDEX_ENTRY: { + catalog->dropIndex(transaction, entryID); + } break; + case CatalogEntryType::SCALAR_MACRO_ENTRY: { + catalog->dropMacroEntry(transaction, entryID); + } break; + default: { + KU_UNREACHABLE; + } + } +} + +void WALReplayer::replayAlterTableEntryRecord(const WALRecord& walRecord) const { + auto binder = Binder(&clientContext); + auto& alterEntryRecord = walRecord.constCast(); + auto catalog = Catalog::Get(clientContext); + auto transaction = transaction::Transaction::Get(clientContext); + auto storageManager = StorageManager::Get(clientContext); + auto ownedAlterInfo = alterEntryRecord.ownedAlterInfo.get(); + catalog->alterTableEntry(transaction, *ownedAlterInfo); + auto& pageAllocator = *PageManager::Get(clientContext); + switch (ownedAlterInfo->alterType) { + case AlterType::ADD_PROPERTY: { + const auto exprBinder = binder.getExpressionBinder(); + const auto addInfo = ownedAlterInfo->extraInfo->constPtrCast(); + // We don't implicit cast here since it must already be done the first time + const auto boundDefault = + exprBinder->bindExpression(*addInfo->propertyDefinition.defaultExpr); + auto exprMapper = ExpressionMapper(); + const auto defaultValueEvaluator = exprMapper.getEvaluator(boundDefault); + defaultValueEvaluator->init(ResultSet(0) /* dummy ResultSet */, &clientContext); + const auto entry = 
catalog->getTableCatalogEntry(transaction, ownedAlterInfo->tableName); + const auto& addedProp = entry->getProperty(addInfo->propertyDefinition.getName()); + TableAddColumnState state{addedProp, *defaultValueEvaluator}; + KU_ASSERT(StorageManager::Get(clientContext)); + switch (entry->getTableType()) { + case TableType::REL: { + for (auto& relEntryInfo : entry->cast().getRelEntryInfos()) { + storageManager->getTable(relEntryInfo.oid) + ->addColumn(transaction, state, pageAllocator); + } + } break; + case TableType::NODE: { + storageManager->getTable(entry->getTableID()) + ->addColumn(transaction, state, pageAllocator); + } break; + default: { + KU_UNREACHABLE; + } + } + } break; + case AlterType::ADD_FROM_TO_CONNECTION: { + auto extraInfo = ownedAlterInfo->extraInfo->constPtrCast(); + auto relGroupEntry = catalog->getTableCatalogEntry(transaction, ownedAlterInfo->tableName) + ->ptrCast(); + auto relEntryInfo = + relGroupEntry->getRelEntryInfo(extraInfo->fromTableID, extraInfo->toTableID); + storageManager->addRelTable(relGroupEntry, *relEntryInfo); + } break; + default: + break; + } +} + +void WALReplayer::replayTableInsertionRecord(const WALRecord& walRecord) const { + const auto& insertionRecord = walRecord.constCast(); + switch (insertionRecord.tableType) { + case TableType::NODE: { + replayNodeTableInsertRecord(walRecord); + } break; + case TableType::REL: { + replayRelTableInsertRecord(walRecord); + } break; + default: { + throw RuntimeException("Invalid table type for insertion replay in WAL record."); + } + } +} + +void WALReplayer::replayNodeTableInsertRecord(const WALRecord& walRecord) const { + const auto& insertionRecord = walRecord.constCast(); + const auto tableID = insertionRecord.tableID; + auto& table = StorageManager::Get(clientContext)->getTable(tableID)->cast(); + KU_ASSERT(!insertionRecord.ownedVectors.empty()); + const auto anchorState = insertionRecord.ownedVectors[0]->state; + const auto numNodes = anchorState->getSelVector().getSelSize(); + for (auto i = 0u; i < insertionRecord.ownedVectors.size(); i++) { + insertionRecord.ownedVectors[i]->setState(anchorState); + } + std::vector propertyVectors(insertionRecord.ownedVectors.size()); + for (auto i = 0u; i < insertionRecord.ownedVectors.size(); i++) { + propertyVectors[i] = insertionRecord.ownedVectors[i].get(); + } + KU_ASSERT(table.getPKColumnID() < insertionRecord.ownedVectors.size()); + auto& pkVector = *insertionRecord.ownedVectors[table.getPKColumnID()]; + const auto nodeIDVector = std::make_unique(LogicalType::INTERNAL_ID()); + nodeIDVector->setState(anchorState); + const auto insertState = + std::make_unique(*nodeIDVector, pkVector, propertyVectors); + KU_ASSERT(transaction::Transaction::Get(clientContext) && + transaction::Transaction::Get(clientContext)->isRecovery()); + table.initInsertState(&clientContext, *insertState); + anchorState->getSelVectorUnsafe().setToFiltered(1); + for (auto i = 0u; i < numNodes; i++) { + anchorState->getSelVectorUnsafe()[0] = i; + table.insert(transaction::Transaction::Get(clientContext), *insertState); + } +} + +void WALReplayer::replayRelTableInsertRecord(const WALRecord& walRecord) const { + const auto& insertionRecord = walRecord.constCast(); + const auto tableID = insertionRecord.tableID; + auto& table = StorageManager::Get(clientContext)->getTable(tableID)->cast(); + KU_ASSERT(!insertionRecord.ownedVectors.empty()); + const auto anchorState = insertionRecord.ownedVectors[0]->state; + const auto numRels = anchorState->getSelVector().getSelSize(); + 
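+    // As in replayNodeTableInsertRecord above, the logged tuples are re-inserted one at a
+    // time: the shared selection vector is narrowed to a single slot and repointed at each row
+    // index inside the loop below.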
anchorState->getSelVectorUnsafe().setToFiltered(1); + for (auto i = 0u; i < insertionRecord.ownedVectors.size(); i++) { + insertionRecord.ownedVectors[i]->setState(anchorState); + } + std::vector propertyVectors; + for (auto i = 0u; i < insertionRecord.ownedVectors.size(); i++) { + if (i < LOCAL_REL_ID_COLUMN_ID) { + // Skip the first two vectors which are the src nodeID and the dst nodeID. + continue; + } + propertyVectors.push_back(insertionRecord.ownedVectors[i].get()); + } + const auto insertState = std::make_unique( + *insertionRecord.ownedVectors[LOCAL_BOUND_NODE_ID_COLUMN_ID], + *insertionRecord.ownedVectors[LOCAL_NBR_NODE_ID_COLUMN_ID], propertyVectors); + KU_ASSERT(transaction::Transaction::Get(clientContext) && + transaction::Transaction::Get(clientContext)->isRecovery()); + for (auto i = 0u; i < numRels; i++) { + anchorState->getSelVectorUnsafe()[0] = i; + table.initInsertState(&clientContext, *insertState); + table.insert(transaction::Transaction::Get(clientContext), *insertState); + } +} + +void WALReplayer::replayNodeDeletionRecord(const WALRecord& walRecord) const { + const auto& deletionRecord = walRecord.constCast(); + const auto tableID = deletionRecord.tableID; + auto& table = StorageManager::Get(clientContext)->getTable(tableID)->cast(); + const auto anchorState = deletionRecord.ownedPKVector->state; + KU_ASSERT(anchorState->getSelVector().getSelSize() == 1); + const auto nodeIDVector = std::make_unique(LogicalType::INTERNAL_ID()); + nodeIDVector->setState(anchorState); + nodeIDVector->setValue(0, + internalID_t{deletionRecord.nodeOffset, deletionRecord.tableID}); + const auto deleteState = + std::make_unique(*nodeIDVector, *deletionRecord.ownedPKVector); + KU_ASSERT(transaction::Transaction::Get(clientContext) && + transaction::Transaction::Get(clientContext)->isRecovery()); + table.delete_(transaction::Transaction::Get(clientContext), *deleteState); +} + +void WALReplayer::replayNodeUpdateRecord(const WALRecord& walRecord) const { + const auto& updateRecord = walRecord.constCast(); + const auto tableID = updateRecord.tableID; + auto& table = StorageManager::Get(clientContext)->getTable(tableID)->cast(); + const auto anchorState = updateRecord.ownedPropertyVector->state; + KU_ASSERT(anchorState->getSelVector().getSelSize() == 1); + const auto nodeIDVector = std::make_unique(LogicalType::INTERNAL_ID()); + nodeIDVector->setState(anchorState); + nodeIDVector->setValue(0, + internalID_t{updateRecord.nodeOffset, updateRecord.tableID}); + const auto updateState = std::make_unique(updateRecord.columnID, + *nodeIDVector, *updateRecord.ownedPropertyVector); + KU_ASSERT(transaction::Transaction::Get(clientContext) && + transaction::Transaction::Get(clientContext)->isRecovery()); + table.update(transaction::Transaction::Get(clientContext), *updateState); +} + +void WALReplayer::replayRelDeletionRecord(const WALRecord& walRecord) const { + const auto& deletionRecord = walRecord.constCast(); + const auto tableID = deletionRecord.tableID; + auto& table = StorageManager::Get(clientContext)->getTable(tableID)->cast(); + const auto anchorState = deletionRecord.ownedRelIDVector->state; + KU_ASSERT(anchorState->getSelVector().getSelSize() == 1); + const auto deleteState = + std::make_unique(*deletionRecord.ownedSrcNodeIDVector, + *deletionRecord.ownedDstNodeIDVector, *deletionRecord.ownedRelIDVector); + KU_ASSERT(transaction::Transaction::Get(clientContext) && + transaction::Transaction::Get(clientContext)->isRecovery()); + table.delete_(transaction::Transaction::Get(clientContext), 
*deleteState); +} + +void WALReplayer::replayRelDetachDeletionRecord(const WALRecord& walRecord) const { + const auto& deletionRecord = walRecord.constCast(); + const auto tableID = deletionRecord.tableID; + auto& table = StorageManager::Get(clientContext)->getTable(tableID)->cast(); + KU_ASSERT(transaction::Transaction::Get(clientContext) && + transaction::Transaction::Get(clientContext)->isRecovery()); + const auto anchorState = deletionRecord.ownedSrcNodeIDVector->state; + KU_ASSERT(anchorState->getSelVector().getSelSize() == 1); + const auto dstNodeIDVector = + std::make_unique(LogicalType{LogicalTypeID::INTERNAL_ID}); + const auto relIDVector = std::make_unique(LogicalType{LogicalTypeID::INTERNAL_ID}); + dstNodeIDVector->setState(anchorState); + relIDVector->setState(anchorState); + const auto deleteState = std::make_unique( + *deletionRecord.ownedSrcNodeIDVector, *dstNodeIDVector, *relIDVector); + deleteState->detachDeleteDirection = deletionRecord.direction; + table.detachDelete(transaction::Transaction::Get(clientContext), deleteState.get()); +} + +void WALReplayer::replayRelUpdateRecord(const WALRecord& walRecord) const { + const auto& updateRecord = walRecord.constCast(); + const auto tableID = updateRecord.tableID; + auto& table = StorageManager::Get(clientContext)->getTable(tableID)->cast(); + const auto anchorState = updateRecord.ownedRelIDVector->state; + KU_ASSERT(anchorState == updateRecord.ownedSrcNodeIDVector->state && + anchorState == updateRecord.ownedSrcNodeIDVector->state && + anchorState == updateRecord.ownedPropertyVector->state); + KU_ASSERT(anchorState->getSelVector().getSelSize() == 1); + const auto updateState = std::make_unique(updateRecord.columnID, + *updateRecord.ownedSrcNodeIDVector, *updateRecord.ownedDstNodeIDVector, + *updateRecord.ownedRelIDVector, *updateRecord.ownedPropertyVector); + KU_ASSERT(transaction::Transaction::Get(clientContext) && + transaction::Transaction::Get(clientContext)->isRecovery()); + table.update(transaction::Transaction::Get(clientContext), *updateState); +} + +void WALReplayer::replayCopyTableRecord(const WALRecord&) const { + // DO NOTHING. +} + +void WALReplayer::replayUpdateSequenceRecord(const WALRecord& walRecord) const { + auto& sequenceEntryRecord = walRecord.constCast(); + const auto sequenceID = sequenceEntryRecord.sequenceID; + const auto entry = + Catalog::Get(clientContext) + ->getSequenceEntry(transaction::Transaction::Get(clientContext), sequenceID); + entry->nextKVal(transaction::Transaction::Get(clientContext), sequenceEntryRecord.kCount); +} + +void WALReplayer::replayLoadExtensionRecord(const WALRecord& walRecord) const { + const auto& loadExtensionRecord = walRecord.constCast(); + extension::ExtensionManager::Get(clientContext) + ->loadExtension(loadExtensionRecord.path, &clientContext); +} + +void WALReplayer::removeWALAndShadowFiles() const { + removeFileIfExists(shadowFilePath); + removeFileIfExists(walPath); +} + +void WALReplayer::removeFileIfExists(const std::string& path) const { + if (StorageManager::Get(clientContext)->isReadOnly()) { + return; + } + auto vfs = VirtualFileSystem::GetUnsafe(clientContext); + if (vfs->fileOrPathExists(path, &clientContext)) { + vfs->removeFileIfExists(path); + } +} + +std::unique_ptr WALReplayer::openWALFile() const { + auto flag = FileFlags::READ_ONLY; + if (!StorageManager::Get(clientContext)->isReadOnly()) { + flag |= FileFlags::WRITE; // The write flag here is to ensure the file is opened with O_RDWR + // so that we can sync it. 
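+        // (Syncing a file opened read-only is not reliable on every platform; e.g.
+        // FlushFileBuffers on Windows requires a handle opened with write access.)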
+ } + return VirtualFileSystem::GetUnsafe(clientContext)->openFile(walPath, FileOpenFlags(flag)); +} + +void WALReplayer::syncWALFile(const FileInfo& fileInfo) const { + if (StorageManager::Get(clientContext)->isReadOnly()) { + return; + } + fileInfo.syncFile(); +} + +void WALReplayer::truncateWALFile(FileInfo& fileInfo, uint64_t size) const { + if (StorageManager::Get(clientContext)->isReadOnly()) { + return; + } + if (fileInfo.getFileSize() > size) { + fileInfo.truncate(size); + fileInfo.syncFile(); + } +} + +} // namespace storage +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/transaction/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/src/transaction/CMakeLists.txt new file mode 100644 index 0000000000..c833479015 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/transaction/CMakeLists.txt @@ -0,0 +1,9 @@ +add_library(lbug_transaction + OBJECT + transaction.cpp + transaction_context.cpp + transaction_manager.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/transaction/transaction.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/transaction/transaction.cpp new file mode 100644 index 0000000000..c8d027e74e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/transaction/transaction.cpp @@ -0,0 +1,222 @@ +#include "transaction/transaction.h" + +#include "common/exception/runtime.h" +#include "main/client_context.h" +#include "main/db_config.h" +#include "storage/local_storage/local_node_table.h" +#include "storage/local_storage/local_storage.h" +#include "storage/storage_manager.h" +#include "storage/undo_buffer.h" +#include "storage/wal/local_wal.h" +#include "transaction/transaction_context.h" + +using namespace lbug::catalog; + +namespace lbug { +namespace transaction { + +bool LocalCacheManager::put(std::unique_ptr object) { + std::unique_lock lck{mtx}; + const auto key = object->getKey(); + if (cachedObjects.contains(key)) { + return false; + } + cachedObjects[object->getKey()] = std::move(object); + return true; +} + +Transaction::Transaction(main::ClientContext& clientContext, TransactionType transactionType, + common::transaction_t transactionID, common::transaction_t startTS) + : type{transactionType}, ID{transactionID}, startTS{startTS}, + commitTS{common::INVALID_TRANSACTION}, forceCheckpoint{false}, hasCatalogChanges{false} { + this->clientContext = &clientContext; + localStorage = std::make_unique(clientContext); + undoBuffer = std::make_unique(storage::MemoryManager::Get(clientContext)); + currentTS = common::Timestamp::getCurrentTimestamp().value; + localWAL = std::make_unique(*storage::MemoryManager::Get(clientContext), + clientContext.getDBConfig()->enableChecksums); +} + +Transaction::Transaction(TransactionType transactionType) noexcept + : type{transactionType}, ID{DUMMY_TRANSACTION_ID}, startTS{DUMMY_START_TIMESTAMP}, + commitTS{common::INVALID_TRANSACTION}, clientContext{nullptr}, undoBuffer{nullptr}, + forceCheckpoint{false}, hasCatalogChanges{false} { + currentTS = common::Timestamp::getCurrentTimestamp().value; +} + +Transaction::Transaction(TransactionType transactionType, common::transaction_t ID, + common::transaction_t startTS) noexcept + : type{transactionType}, ID{ID}, startTS{startTS}, commitTS{common::INVALID_TRANSACTION}, + clientContext{nullptr}, undoBuffer{nullptr}, forceCheckpoint{false}, + hasCatalogChanges{false} { + currentTS = common::Timestamp::getCurrentTimestamp().value; +} + +bool Transaction::shouldLogToWAL() const { + return isWriteTransaction() && 
!clientContext->isInMemory(); +} + +bool Transaction::shouldForceCheckpoint() const { + return !clientContext->isInMemory() && forceCheckpoint; +} + +void Transaction::commit(storage::WAL* wal) { + localStorage->commit(); + undoBuffer->commit(commitTS); + if (shouldLogToWAL()) { + KU_ASSERT(localWAL && wal); + localWAL->logCommit(); + wal->logCommittedWAL(*localWAL, clientContext); + localWAL->clear(); + } + if (hasCatalogChanges) { + Catalog::Get(*clientContext)->incrementVersion(); + hasCatalogChanges = false; + } +} + +void Transaction::rollback(storage::WAL*) { + // Rolling back the local storage will free + evict all optimistically-allocated pages + // Since the undo buffer may do some scanning (e.g. to delete inserted keys from the hash index) + // this must be rolled back first + undoBuffer->rollback(clientContext); + localStorage->rollback(); + hasCatalogChanges = false; +} + +bool Transaction::isUnCommitted(common::table_id_t tableID, common::offset_t nodeOffset) const { + return localStorage && localStorage->getLocalTable(tableID) && + nodeOffset >= getMinUncommittedNodeOffset(tableID); +} + +void Transaction::pushCreateDropCatalogEntry(CatalogSet& catalogSet, CatalogEntry& catalogEntry, + bool isInternal, bool skipLoggingToWAL) { + undoBuffer->createCatalogEntry(catalogSet, catalogEntry); + hasCatalogChanges = true; + if (!shouldLogToWAL() || skipLoggingToWAL) { + return; + } + KU_ASSERT(localWAL); + const auto newCatalogEntry = catalogEntry.getNext(); + switch (newCatalogEntry->getType()) { + case CatalogEntryType::INDEX_ENTRY: + case CatalogEntryType::NODE_TABLE_ENTRY: + case CatalogEntryType::REL_GROUP_ENTRY: { + if (catalogEntry.getType() == CatalogEntryType::DUMMY_ENTRY) { + KU_ASSERT(catalogEntry.isDeleted()); + localWAL->logCreateCatalogEntryRecord(newCatalogEntry, isInternal); + } else { + throw common::RuntimeException("This shouldn't happen. Alter table is not supported."); + } + } break; + case CatalogEntryType::SEQUENCE_ENTRY: { + KU_ASSERT( + catalogEntry.getType() == CatalogEntryType::DUMMY_ENTRY && catalogEntry.isDeleted()); + if (newCatalogEntry->hasParent()) { + // We don't log SERIAL catalog entry creation as it is implicit + return; + } + localWAL->logCreateCatalogEntryRecord(newCatalogEntry, isInternal); + } break; + case CatalogEntryType::SCALAR_MACRO_ENTRY: + case CatalogEntryType::TYPE_ENTRY: { + KU_ASSERT( + catalogEntry.getType() == CatalogEntryType::DUMMY_ENTRY && catalogEntry.isDeleted()); + localWAL->logCreateCatalogEntryRecord(newCatalogEntry, isInternal); + } break; + case CatalogEntryType::DUMMY_ENTRY: { + KU_ASSERT(newCatalogEntry->isDeleted()); + if (catalogEntry.hasParent()) { + return; + } + switch (catalogEntry.getType()) { + case CatalogEntryType::INDEX_ENTRY: + case CatalogEntryType::SCALAR_MACRO_ENTRY: + case CatalogEntryType::NODE_TABLE_ENTRY: + case CatalogEntryType::REL_GROUP_ENTRY: + case CatalogEntryType::SEQUENCE_ENTRY: { + localWAL->logDropCatalogEntryRecord(catalogEntry.getOID(), catalogEntry.getType()); + } break; + case CatalogEntryType::SCALAR_FUNCTION_ENTRY: + case CatalogEntryType::TABLE_FUNCTION_ENTRY: + case CatalogEntryType::STANDALONE_TABLE_FUNCTION_ENTRY: { + // DO NOTHING. We don't persist function entries. 
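+            // (Creation of function entries is likewise a no-op further down in the
+            // outer switch.)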
+ } break; + case CatalogEntryType::TYPE_ENTRY: + default: { + throw common::RuntimeException( + common::stringFormat("Not supported catalog entry type {} yet.", + CatalogEntryTypeUtils::toString(catalogEntry.getType()))); + } + } + } break; + case CatalogEntryType::SCALAR_FUNCTION_ENTRY: + case CatalogEntryType::TABLE_FUNCTION_ENTRY: + case CatalogEntryType::STANDALONE_TABLE_FUNCTION_ENTRY: { + // DO NOTHING. We don't persist function entries. + } break; + default: { + throw common::RuntimeException( + common::stringFormat("Not supported catalog entry type {} yet.", + CatalogEntryTypeUtils::toString(catalogEntry.getType()))); + } + } +} + +void Transaction::pushAlterCatalogEntry(CatalogSet& catalogSet, CatalogEntry& catalogEntry, + const binder::BoundAlterInfo& alterInfo) { + undoBuffer->createCatalogEntry(catalogSet, catalogEntry); + hasCatalogChanges = true; + if (shouldLogToWAL()) { + KU_ASSERT(localWAL); + localWAL->logAlterCatalogEntryRecord(&alterInfo); + } +} + +void Transaction::pushSequenceChange(SequenceCatalogEntry* sequenceEntry, int64_t kCount, + const SequenceRollbackData& data) { + undoBuffer->createSequenceChange(*sequenceEntry, data); + hasCatalogChanges = true; + if (shouldLogToWAL()) { + KU_ASSERT(localWAL); + localWAL->logUpdateSequenceRecord(sequenceEntry->getOID(), kCount); + } +} + +void Transaction::pushInsertInfo(common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, + common::row_idx_t numRows, const storage::VersionRecordHandler* versionRecordHandler) const { + undoBuffer->createInsertInfo(nodeGroupIdx, startRow, numRows, versionRecordHandler); +} + +void Transaction::pushDeleteInfo(common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, + common::row_idx_t numRows, const storage::VersionRecordHandler* versionRecordHandler) const { + undoBuffer->createDeleteInfo(nodeGroupIdx, startRow, numRows, versionRecordHandler); +} + +void Transaction::pushVectorUpdateInfo(storage::UpdateInfo& updateInfo, + const common::idx_t vectorIdx, storage::VectorUpdateInfo& vectorUpdateInfo, + common::transaction_t version) const { + undoBuffer->createVectorUpdateInfo(&updateInfo, vectorIdx, &vectorUpdateInfo, version); +} + +Transaction::~Transaction() = default; + +common::offset_t Transaction::getMinUncommittedNodeOffset(common::table_id_t tableID) const { + if (localStorage && localStorage->getLocalTable(tableID)) { + return localStorage->getLocalTable(tableID) + ->cast() + .getStartOffset(); + } + return 0; +} + +Transaction* Transaction::Get(const main::ClientContext& context) { + return TransactionContext::Get(context)->getActiveTransaction(); +} + +Transaction DUMMY_TRANSACTION = Transaction(TransactionType::DUMMY); +Transaction DUMMY_CHECKPOINT_TRANSACTION = Transaction(TransactionType::CHECKPOINT, + Transaction::DUMMY_TRANSACTION_ID, Transaction::START_TRANSACTION_ID - 1); + +} // namespace transaction +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/transaction/transaction_context.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/transaction/transaction_context.cpp new file mode 100644 index 0000000000..066b0dd02d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/transaction/transaction_context.cpp @@ -0,0 +1,88 @@ +#include "transaction/transaction_context.h" + +#include "common/exception/transaction_manager.h" +#include "main/client_context.h" +#include "main/database.h" +#include "transaction/transaction_manager.h" + +using namespace lbug::common; + +namespace lbug { +namespace transaction { + 
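+// Rough usage sketch (illustrative only; the real call sites are outside this diff):
+// a connection drives transactions through this context, either per-statement in
+// AUTO mode or explicitly in MANUAL mode.
+//
+//     auto* ctx = TransactionContext::Get(clientContext);
+//     ctx->beginWriteTransaction();   // puts the context into MANUAL mode
+//     /* ... execute write statements ... */
+//     ctx->commit();                  // hands the transaction to the TransactionManager
+//                                     // and resets the context back to AUTO mode
+//
+// An error path would call ctx->rollback() instead; both paths end in
+// clearTransaction(), so the context never keeps a dangling active transaction.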
+TransactionContext::TransactionContext(main::ClientContext& clientContext) + : clientContext{clientContext}, mode{TransactionMode::AUTO}, activeTransaction{nullptr} {} + +TransactionContext::~TransactionContext() = default; + +void TransactionContext::beginReadTransaction() { + std::unique_lock lck{mtx}; + mode = TransactionMode::MANUAL; + beginTransactionInternal(TransactionType::READ_ONLY); +} + +void TransactionContext::beginWriteTransaction() { + std::unique_lock lck{mtx}; + mode = TransactionMode::MANUAL; + beginTransactionInternal(TransactionType::WRITE); +} + +void TransactionContext::beginAutoTransaction(bool readOnlyStatement) { + // LCOV_EXCL_START + if (hasActiveTransaction()) { + throw TransactionManagerException( + "Cannot start a new transaction while there is an active transaction."); + } + // LCOV_EXCL_STOP + beginTransactionInternal( + readOnlyStatement ? TransactionType::READ_ONLY : TransactionType::WRITE); +} + +void TransactionContext::beginRecoveryTransaction() { + std::unique_lock lck{mtx}; + mode = TransactionMode::MANUAL; + beginTransactionInternal(TransactionType::RECOVERY); +} + +void TransactionContext::validateManualTransaction(bool readOnlyStatement) const { + KU_ASSERT(hasActiveTransaction()); + if (activeTransaction->isReadOnly() && !readOnlyStatement) { + throw TransactionManagerException( + "Can not execute a write query inside a read-only transaction."); + } +} + +void TransactionContext::commit() { + if (!hasActiveTransaction()) { + return; + } + clientContext.getDatabase()->getTransactionManager()->commit(clientContext, activeTransaction); + clearTransaction(); +} + +void TransactionContext::rollback() { + if (!hasActiveTransaction()) { + return; + } + clientContext.getDatabase()->getTransactionManager()->rollback(clientContext, + activeTransaction); + clearTransaction(); +} + +void TransactionContext::clearTransaction() { + activeTransaction = nullptr; + mode = TransactionMode::AUTO; +} + +TransactionContext* TransactionContext::Get(const main::ClientContext& context) { + return context.transactionContext.get(); +} + +void TransactionContext::beginTransactionInternal(TransactionType transactionType) { + KU_ASSERT(!activeTransaction); + activeTransaction = clientContext.getDatabase()->getTransactionManager()->beginTransaction( + clientContext, transactionType); +} + +} // namespace transaction +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/src/transaction/transaction_manager.cpp b/graph-wasm/lbug-0.12.2/lbug-src/src/transaction/transaction_manager.cpp new file mode 100644 index 0000000000..157b778132 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/src/transaction/transaction_manager.cpp @@ -0,0 +1,187 @@ +#include "transaction/transaction_manager.h" + +#include + +#include "common/exception/checkpoint.h" +#include "common/exception/transaction_manager.h" +#include "main/attached_database.h" +#include "main/client_context.h" +#include "main/database.h" +#include "main/db_config.h" +#include "storage/checkpointer.h" +#include "storage/wal/local_wal.h" + +using namespace lbug::common; +using namespace lbug::storage; + +namespace lbug { +namespace transaction { + +Transaction* TransactionManager::beginTransaction(main::ClientContext& clientContext, + TransactionType type) { + // We acquire the lock for starting new transactions. In case this cannot be acquired, this + // ensures calls to other public functions are not restricted. 
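+    // Concretely: this is the only function here that takes both mutexes. commit,
+    // rollback and checkpoint take mtxForSerializingPublicFunctionCalls only, while
+    // stopNewTransactionsAndWaitUntilAllTransactionsLeave takes
+    // mtxForStartingNewTransactions only and holds it while waiting for
+    // activeTransactions to drain.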
+ std::unique_lock publicFunctionLck{mtxForSerializingPublicFunctionCalls}; + std::unique_lock newTransactionLck{mtxForStartingNewTransactions}; + switch (type) { + case TransactionType::READ_ONLY: { + auto transaction = + std::make_unique(clientContext, type, ++lastTransactionID, lastTimestamp); + activeTransactions.push_back(std::move(transaction)); + return activeTransactions.back().get(); + } + case TransactionType::RECOVERY: + case TransactionType::WRITE: { + if (!clientContext.getDBConfig()->enableMultiWrites && hasActiveWriteTransactionNoLock()) { + throw TransactionManagerException( + "Cannot start a new write transaction in the system. " + "Only one write transaction at a time is allowed in the system."); + } + auto transaction = + std::make_unique(clientContext, type, ++lastTransactionID, lastTimestamp); + if (transaction->shouldLogToWAL()) { + transaction->getLocalWAL().logBeginTransaction(); + } + activeTransactions.push_back(std::move(transaction)); + return activeTransactions.back().get(); + } + // LCOV_EXCL_START + default: { + throw TransactionManagerException("Invalid transaction type to begin transaction."); + } + // LCOV_EXCL_STOP + } +} + +void TransactionManager::commit(main::ClientContext& clientContext, Transaction* transaction) { + std::unique_lock lck{mtxForSerializingPublicFunctionCalls}; + clientContext.cleanUp(); + switch (transaction->getType()) { + case TransactionType::READ_ONLY: { + clearTransactionNoLock(transaction->getID()); + } break; + case TransactionType::RECOVERY: + case TransactionType::WRITE: { + lastTimestamp++; + transaction->commitTS = lastTimestamp; + transaction->commit(&wal); + auto shouldCheckpoint = transaction->shouldForceCheckpoint() || + Checkpointer::canAutoCheckpoint(clientContext, *transaction); + clearTransactionNoLock(transaction->getID()); + if (shouldCheckpoint) { + checkpointNoLock(clientContext); + } + } break; + // LCOV_EXCL_START + default: { + throw TransactionManagerException("Invalid transaction type to commit."); + } + // LCOV_EXCL_STOP + } +} + +// Note: We take in additional `transaction` here is due to that `transactionContext` might be +// destructed when a transaction throws an exception, while we need to roll back the active +// transaction still. 
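+// Note also that the rollback path never touches the shared WAL: a transaction's
+// records accumulate in its local WAL and are only handed to the shared WAL inside
+// Transaction::commit, so rolling back the undo buffer and the local storage is
+// enough to undo it.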
+void TransactionManager::rollback(main::ClientContext& clientContext, Transaction* transaction) { + std::unique_lock lck{mtxForSerializingPublicFunctionCalls}; + clientContext.cleanUp(); + switch (transaction->getType()) { + case TransactionType::READ_ONLY: { + clearTransactionNoLock(transaction->getID()); + } break; + case TransactionType::RECOVERY: + case TransactionType::WRITE: { + transaction->rollback(&wal); + clearTransactionNoLock(transaction->getID()); + } break; + default: { + throw TransactionManagerException("Invalid transaction type to rollback."); + } + } +} + +void TransactionManager::checkpoint(main::ClientContext& clientContext) { + UniqLock lck{mtxForSerializingPublicFunctionCalls}; + if (clientContext.isInMemory()) { + return; + } + checkpointNoLock(clientContext); +} + +TransactionManager* TransactionManager::Get(const main::ClientContext& context) { + if (context.getAttachedDatabase() != nullptr) { + context.getAttachedDatabase()->getTransactionManager(); + } + return context.getDatabase()->getTransactionManager(); +} + +UniqLock TransactionManager::stopNewTransactionsAndWaitUntilAllTransactionsLeave() { + UniqLock startTransactionLock{mtxForStartingNewTransactions}; + uint64_t numTimesWaited = 0; + while (true) { + if (hasNoActiveTransactions()) { + break; + } + numTimesWaited++; + if (numTimesWaited * THREAD_SLEEP_TIME_WHEN_WAITING_IN_MICROS > + checkpointWaitTimeoutInMicros) { + throw TransactionManagerException( + "Timeout waiting for active transactions to leave the system before " + "checkpointing. If you have an open transaction, please close it and try " + "again."); + } + std::this_thread::sleep_for( + std::chrono::microseconds(THREAD_SLEEP_TIME_WHEN_WAITING_IN_MICROS)); + } + return startTransactionLock; +} + +bool TransactionManager::hasNoActiveTransactions() const { + return activeTransactions.empty(); +} + +bool TransactionManager::hasActiveWriteTransactionNoLock() const { + return std::ranges::any_of(activeTransactions, + [](const auto& transaction) { return transaction->isWriteTransaction(); }); +} + +void TransactionManager::clearTransactionNoLock(transaction_t transactionID) { + KU_ASSERT(std::ranges::any_of(activeTransactions.begin(), activeTransactions.end(), + [transactionID](const auto& activeTransaction) { + return activeTransaction->getID() == transactionID; + })); + std::erase_if(activeTransactions, [transactionID](const auto& activeTransaction) { + return activeTransaction->getID() == transactionID; + }); +} + +std::unique_ptr TransactionManager::initCheckpointer( + main::ClientContext& clientContext) { + return std::make_unique(clientContext); +} + +void TransactionManager::checkpointNoLock(main::ClientContext& clientContext) { + // Note: It is enough to stop and wait for transactions to leave the system instead of, for + // example, checking on the query processor's task scheduler. This is because the + // first and last steps that a connection performs when executing a query are to + // start and commit/rollback transaction. The query processor also ensures that it + // will only return results or error after all threads working on the tasks of a + // query stop working on the tasks of the query and these tasks are removed from the + // query. 
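+    // In code terms: every caller of this function already holds
+    // mtxForSerializingPublicFunctionCalls, which beginTransaction also needs, and the
+    // wait below only returns once activeTransactions is empty, so no transaction can
+    // be running or starting while writeCheckpoint() executes.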
+ try { + auto lockForStartingTransaction = stopNewTransactionsAndWaitUntilAllTransactionsLeave(); + } catch (std::exception& e) { + throw CheckpointException{e}; + } + auto checkpointer = initCheckpointerFunc(clientContext); + try { + checkpointer->writeCheckpoint(); + } catch (std::exception& e) { + checkpointer->rollback(); + throw CheckpointException{e}; + } +} + +} // namespace transaction +} // namespace lbug diff --git a/graph-wasm/lbug-0.12.2/lbug-src/third_party/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/third_party/CMakeLists.txt new file mode 100644 index 0000000000..5554d7759c --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/third_party/CMakeLists.txt @@ -0,0 +1,28 @@ +if(NOT MSVC) + add_compile_options(-Wno-extra) +endif() + +add_subdirectory(alp) +add_subdirectory(antlr4_cypher) +add_subdirectory(antlr4_runtime) +set(BROTLI_BUNDLED_MODE ON) +add_subdirectory(brotli) +add_subdirectory(fast_float) +add_subdirectory(fastpfor) +add_subdirectory(glob) +add_subdirectory(lz4) +add_subdirectory(mbedtls) +add_subdirectory(miniz) +add_subdirectory(parquet) +if(${BUILD_PYTHON}) + add_subdirectory(pybind11) +endif() +add_subdirectory(re2) +add_subdirectory(roaring_bitmap) +add_subdirectory(simsimd) +add_subdirectory(snappy) +add_subdirectory(thrift) +add_subdirectory(utf8proc) +add_subdirectory(yyjson) +add_subdirectory(zstd) +add_subdirectory(cppjieba) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/third_party/alp/.clang-format b/graph-wasm/lbug-0.12.2/lbug-src/third_party/alp/.clang-format new file mode 100644 index 0000000000..faae7b654d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/third_party/alp/.clang-format @@ -0,0 +1,36 @@ +BasedOnStyle: LLVM +TabWidth: 4 +IndentWidth: 4 +UseTab: ForIndentation +DerivePointerAlignment: false +PointerAlignment: Left +AlignConsecutiveMacros: AcrossEmptyLinesAndComments +AlignAfterOpenBracket: Align +AlignTrailingComments: true +AlignConsecutiveDeclarations: Consecutive +AlignConsecutiveAssignments: Consecutive +AllowAllArgumentsOnNextLine: true +AllowAllConstructorInitializersOnNextLine: true +AllowAllParametersOfDeclarationOnNextLine: true +SpaceBeforeCpp11BracedList: true +SpaceBeforeCtorInitializerColon: true +SpaceBeforeInheritanceColon: true +SpacesInAngles: false +SpacesInCStyleCastParentheses: false +SpacesInConditionalStatement: false +AllowShortLambdasOnASingleLine: Inline +AllowShortLoopsOnASingleLine: false +AlwaysBreakTemplateDeclarations: Yes +ColumnLimit: 120 +IncludeBlocks: Merge +SortIncludes: CaseSensitive +Language: Cpp +AccessModifierOffset: -4 +BreakConstructorInitializers: BeforeComma +AllowShortBlocksOnASingleLine: Always +AllowShortFunctionsOnASingleLine: All +AllowShortIfStatementsOnASingleLine: true +CompactNamespaces: true +BinPackArguments: false +BinPackParameters: false + diff --git a/graph-wasm/lbug-0.12.2/lbug-src/third_party/alp/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/third_party/alp/CMakeLists.txt new file mode 100644 index 0000000000..f38d0c0717 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/third_party/alp/CMakeLists.txt @@ -0,0 +1,3 @@ +add_library(libalp INTERFACE) + +target_include_directories(libalp INTERFACE include) diff --git a/graph-wasm/lbug-0.12.2/lbug-src/third_party/alp/LICENSE b/graph-wasm/lbug-0.12.2/lbug-src/third_party/alp/LICENSE new file mode 100644 index 0000000000..46301aaa73 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/third_party/alp/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 CWI, Azim Afroozeh, Leonardo Xavier Kuffo Rivero + +Permission is 
hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/graph-wasm/lbug-0.12.2/lbug-src/third_party/alp/include/alp.hpp b/graph-wasm/lbug-0.12.2/lbug-src/third_party/alp/include/alp.hpp new file mode 100644 index 0000000000..7baa7e26e1 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/third_party/alp/include/alp.hpp @@ -0,0 +1,17 @@ +#ifndef ALP_ALP_HPP +#define ALP_ALP_HPP + +#include "alp/compressor.hpp" +#include "alp/config.hpp" +#include "alp/constants.hpp" +#include "alp/decode.hpp" +#include "alp/decompressor.hpp" +#include "alp/encode.hpp" +#include "alp/rd.hpp" +#include "alp/sampler.hpp" +#include "alp/storer.hpp" +#include "alp/utils.hpp" +#include "fastlanes/ffor.hpp" +#include "fastlanes/unffor.hpp" + +#endif // ALP_ALP_HPP diff --git a/graph-wasm/lbug-0.12.2/lbug-src/third_party/alp/include/alp/common.hpp b/graph-wasm/lbug-0.12.2/lbug-src/third_party/alp/include/alp/common.hpp new file mode 100644 index 0000000000..5af36b5b43 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/third_party/alp/include/alp/common.hpp @@ -0,0 +1,19 @@ +#ifndef ALP_COMMON_HPP +#define ALP_COMMON_HPP + +#include + +namespace alp { +//! bitwidth type +using bw_t = uint8_t; +//! exception counter type +using exp_c_t = uint32_t; +//! exception position type +using exp_p_t = uint32_t; +//! factor idx type +using factor_idx_t = uint8_t; +//! 
exponent idx type +using exponent_idx_t = uint8_t; +} // namespace alp + +#endif // ALP_COMMON_HPP diff --git a/graph-wasm/lbug-0.12.2/lbug-src/third_party/alp/include/alp/config.hpp b/graph-wasm/lbug-0.12.2/lbug-src/third_party/alp/include/alp/config.hpp new file mode 100644 index 0000000000..da102f46d8 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/third_party/alp/include/alp/config.hpp @@ -0,0 +1,29 @@ +#ifndef ALP_CONFIG_HPP +#define ALP_CONFIG_HPP + +#include + +/* + * ALP Configs + */ +namespace alp::config { +/// ALP Vector size (We recommend against changing this; it should be constant) +inline constexpr size_t VECTOR_SIZE = 128UL * 1024; +/// Rowgroup size +inline constexpr size_t ROWGROUP_SIZE = VECTOR_SIZE; +/// Vectors from the rowgroup from which to take samples; this will be used to then calculate the jumps +inline constexpr size_t ROWGROUP_VECTOR_SAMPLES = 1; +/// We calculate how many equidistant vector we must jump within a rowgroup +inline constexpr size_t ROWGROUP_SAMPLES_JUMP = (ROWGROUP_SIZE / ROWGROUP_VECTOR_SAMPLES) / VECTOR_SIZE; +/// Values to sample per vector +inline constexpr size_t SAMPLES_PER_VECTOR = 4 * 1024; +inline constexpr size_t SAMPLES_PER_ROWGROUP = SAMPLES_PER_VECTOR; +/// Maximum number of combinations obtained from row group sampling +inline constexpr size_t MAX_K_COMBINATIONS = 5; +inline constexpr size_t CUTTING_LIMIT = 16; +inline constexpr size_t MAX_RD_DICT_BIT_WIDTH = 3; +inline constexpr size_t MAX_RD_DICTIONARY_SIZE = (1 << MAX_RD_DICT_BIT_WIDTH); + +} // namespace alp::config + +#endif diff --git a/graph-wasm/lbug-0.12.2/lbug-src/third_party/alp/include/alp/constants.hpp b/graph-wasm/lbug-0.12.2/lbug-src/third_party/alp/include/alp/constants.hpp new file mode 100644 index 0000000000..0d7d088211 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/third_party/alp/include/alp/constants.hpp @@ -0,0 +1,245 @@ +#ifndef ALP_CONSTANTS_HPP +#define ALP_CONSTANTS_HPP + +#include "alp/config.hpp" +#include +#include + +namespace alp { + +enum class SCHEME : uint8_t { + ALP_RD, + ALP, +}; + +template +struct FloatingToExact {}; + +template <> +struct FloatingToExact { + typedef uint64_t type; +}; + +template <> +struct FloatingToExact { + typedef uint32_t type; +}; + +template +using FloatingToEncodedType = std::conditional_t, int64_t, int32_t>; + +inline constexpr uint8_t SAMPLING_EARLY_EXIT_THRESHOLD = 2; +inline constexpr double ENCODING_UPPER_LIMIT = 9223372036854774784; +inline constexpr double ENCODING_LOWER_LIMIT = -9223372036854774784; +inline constexpr uint8_t DICTIONARY_ELEMENT_SIZE_BYTES = 2; +inline constexpr uint8_t RD_EXCEPTION_POSITION_SIZE = 16; +inline constexpr uint8_t RD_EXCEPTION_POSITION_SIZE_BYTES = RD_EXCEPTION_POSITION_SIZE / 8; +inline constexpr uint8_t EXCEPTION_POSITION_SIZE = 32; +inline constexpr uint8_t EXCEPTION_POSITION_SIZE_BYTES = EXCEPTION_POSITION_SIZE / 8; +inline constexpr uint8_t RD_EXCEPTION_SIZE = 16; +inline constexpr uint8_t RD_EXCEPTION_SIZE_BYTES = RD_EXCEPTION_SIZE / 8; + +template +struct Constants {}; + +template <> +struct Constants { + /// 22 bits per value * 32 values in the sampled vector + static inline constexpr size_t RD_SIZE_THRESHOLD_LIMIT = 22 * alp::config::SAMPLES_PER_VECTOR; + static inline constexpr float MAGIC_NUMBER = 12582912.0; + static inline constexpr uint8_t EXCEPTION_SIZE = 32; + static inline constexpr uint8_t EXCEPTION_SIZE_BYTES = EXCEPTION_SIZE / 8; + static inline constexpr uint8_t MAX_EXPONENT = 10; + + // -Inf: 11111111100000000000000000000000 + // +Inf: 
01111111100000000000000000000000 + // -0.0: 10000000000000000000000000000000 + static constexpr uint32_t NEGATIVE_ZERO = 0b10000000000000000000000000000000; + static constexpr uint32_t POSITIVE_INF = 0b11111111100000000000000000000000; + static constexpr uint32_t NEGATIVE_INF = 0b11111111100000000000000000000000; + + static inline constexpr float FRAC_ARR[] = { + 1.0, 0.1, 0.01, 0.001, 0.0001, 0.00001, 0.000001, 0.0000001, 0.00000001, 0.000000001, 0.0000000001}; + + static inline constexpr float EXP_ARR[] = { + 1.0, 10.0, 100.0, 1000.0, 10000.0, 100000.0, 1000000.0, 10000000.0, 100000000.0, 1000000000.0, 10000000000.0}; +}; + +template <> +struct Constants { + /// 48 bits per value * 32 values in the sampled vector + static inline constexpr size_t RD_SIZE_THRESHOLD_LIMIT = 48 * alp::config::SAMPLES_PER_VECTOR; + static inline constexpr double MAGIC_NUMBER {0x0018000000000000}; + static inline constexpr uint8_t EXCEPTION_SIZE = 64; + static inline constexpr uint8_t EXCEPTION_SIZE_BYTES = EXCEPTION_SIZE / 8; + static inline constexpr uint8_t MAX_EXPONENT = 18; + + // -Inf: 1111111111110000000000000000000000000000000000000000000000000000 + // +Inf: 0111111111110000000000000000000000000000000000000000000000000000 + // -0.0: 1000000000000000000000000000000000000000000000000000000000000000 + static constexpr uint64_t NEGATIVE_ZERO = 0b1000000000000000000000000000000000000000000000000000000000000000; + static constexpr uint64_t POSITIVE_INF = 0b0111111111110000000000000000000000000000000000000000000000000000; + static constexpr uint64_t NEGATIVE_INF = 0b1111111111110000000000000000000000000000000000000000000000000000; + + static inline constexpr double FRAC_ARR[] = { + 1.0, + 0.1, + 0.01, + 0.001, + 0.0001, + 0.00001, + 0.000001, + 0.0000001, + 0.00000001, + 0.000000001, + 0.0000000001, + 0.00000000001, + 0.000000000001, + 0.0000000000001, + 0.00000000000001, + 0.000000000000001, + 0.0000000000000001, + 0.00000000000000001, + 0.000000000000000001, + 0.0000000000000000001, + 0.00000000000000000001, + }; + + static inline constexpr double EXP_ARR[] = { + 1.0, + 10.0, + 100.0, + 1000.0, + 10000.0, + 100000.0, + 1000000.0, + 10000000.0, + 100000000.0, + 1000000000.0, + 10000000000.0, + 100000000000.0, + 1000000000000.0, + 10000000000000.0, + 100000000000000.0, + 1000000000000000.0, + 10000000000000000.0, + 100000000000000000.0, + 1000000000000000000.0, + 10000000000000000000.0, + 100000000000000000000.0, + 1000000000000000000000.0, + 10000000000000000000000.0, + 100000000000000000000000.0, + }; +}; + +inline constexpr int64_t FACT_ARR[] = {1, + 10, + 100, + 1000, + 10000, + 100000, + 1000000, + 10000000, + 100000000, + 1000000000, + 10000000000, + 100000000000, + 1000000000000, + 10000000000000, + 100000000000000, + 1000000000000000, + 10000000000000000, + 100000000000000000, + 1000000000000000000}; + +inline constexpr int64_t U_FACT_ARR[] = {1, + 10, + 100, + 1000, + 10000, + 100000, + 1000000, + 10000000, + 100000000, + 1000000000, + 10000000000, + 100000000000, + 1000000000000, + 10000000000000, + 100000000000000, + 1000000000000000, + 10000000000000000, + 100000000000000000, + 1000000000000000000}; + +alignas(64) inline constexpr uint64_t INDEX_ARR[1024] { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, + 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, + 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, + 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 
86, 87, 88, 89, 90, 91, 92, 93, 94, + 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, + 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, + 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, + 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, + 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, + 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, + 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, + 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, + 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, + 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, + 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, + 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, + 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, + 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, + 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, + 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, + 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, + 437, 438, 439, 440, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, + 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, + 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, + 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, 506, 507, 508, 509, 510, 511, 512, + 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, 527, 528, 529, 530, 531, + 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, 548, 549, 550, + 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, 569, + 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, + 589, 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, + 608, 609, 610, 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, + 627, 628, 629, 630, 631, 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, + 646, 647, 648, 649, 650, 651, 652, 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, + 665, 666, 667, 668, 669, 670, 671, 672, 673, 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, + 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, 695, 696, 697, 698, 699, 700, 701, 702, + 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, 716, 717, 718, 719, 720, 721, + 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, 737, 738, 739, 740, + 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, 758, 759, + 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, + 779, 780, 781, 782, 783, 784, 
785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, + 798, 799, 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, + 817, 818, 819, 820, 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, + 836, 837, 838, 839, 840, 841, 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, + 855, 856, 857, 858, 859, 860, 861, 862, 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, + 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, 884, 885, 886, 887, 888, 889, 890, 891, 892, + 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, 905, 906, 907, 908, 909, 910, 911, + 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, 926, 927, 928, 929, 930, + 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, 947, 948, 949, + 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, 968, + 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, + 988, 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, + 1007, 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, +}; + +alignas(64) inline constexpr uint8_t LOOKUP_TABLE[256] { + 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, 1, 2, 2, 3, 2, + 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, + 3, 4, 3, 4, 4, 5, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, + 6, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, 2, 3, 3, 4, + 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, 4, + 5, 5, 6, 5, 6, 6, 7, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, + 6, 7, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, 4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8, +}; + +} // namespace alp + +#endif diff --git a/graph-wasm/lbug-0.12.2/lbug-src/third_party/alp/include/alp/decode.hpp b/graph-wasm/lbug-0.12.2/lbug-src/third_party/alp/include/alp/decode.hpp new file mode 100644 index 0000000000..d253b7afea --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/third_party/alp/include/alp/decode.hpp @@ -0,0 +1,131 @@ +#ifndef ALP_DECODE_HPP +#define ALP_DECODE_HPP + +#include "common.hpp" +#include "constants.hpp" +#include + +namespace alp { + +#ifdef AVX2 +#include "immintrin.h" + +// from: https://stackoverflow.com/questions/41144668/how-to-efficiently-perform-double-int64-conversions-with-sse-avx +// Only works for inputs in the range: [-2^51, 2^51] +__m128i double_to_int64(__m128d x) { + x = _mm_add_pd(x, _mm_set1_pd(0x0018000000000000)); + return _mm_sub_epi64(_mm_castpd_si128(x), _mm_castpd_si128(_mm_set1_pd(0x0018000000000000))); +} + +// Only works for inputs in the range: [-2^51, 2^51] +__m128d int64_to_double(__m128i x) { + x = _mm_add_epi64(x, _mm_castpd_si128(_mm_set1_pd(0x0018000000000000))); + return _mm_sub_pd(_mm_castsi128_pd(x), _mm_set1_pd(0x0018000000000000)); +} + +/* + * scalar version of int64_to_double + */ +double int64_to_double(int64_t x) { + double magic_number = static_cast(0x0018000000000000); + x = x + static_cast(magic_number); + return static_cast(x) - static_cast(magic_number); +} + +// SSE version of int64_to_double +// Only works for inputs in the range: [-2^51, 2^51] 
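+// Why the magic constant works, roughly: 0x0018000000000000 is 2^52 + 2^51
+// (6755399441055744.0 as a double). For |x| <= 2^51, x + 2^52 + 2^51 lands in
+// [2^52, 2^53], where consecutive doubles are spaced 1.0 apart, so the addition is
+// exact and the low 52 mantissa bits hold x + 2^51 verbatim; subtracting the
+// constant again (as a double, or bit-wise as an integer) recovers x exactly.
+// E.g. int64_to_double(5): 5 + 6755399441055744 converts to 6755399441055749.0
+// exactly (it is below 2^53), and subtracting 6755399441055744.0 gives back 5.0.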
+__m128d sse_int64_to_double(__m128i x) { + x = _mm_add_epi64(x, _mm_castpd_si128(_mm_set1_pd(0x0018000000000000))); + return _mm_sub_pd(_mm_castsi128_pd(x), _mm_set1_pd(0x0018000000000000)); +} + +__m256d int64_to_double_fast_precise(const __m256i v) +/* Optimized full range int64_t to double conversion */ +/* Emulate _mm256_cvtepi64_pd() */ +{ + __m256i magic_i_lo = _mm256_set1_epi64x(0x4330000000000000); /* 2^52 encoded as floating-point */ + __m256i magic_i_hi32 = _mm256_set1_epi64x(0x4530000080000000); /* 2^84 + 2^63 encoded as floating-point */ + __m256i magic_i_all = _mm256_set1_epi64x(0x4530000080100000); /* 2^84 + 2^63 + 2^52 encoded as floating-point */ + __m256d magic_d_all = _mm256_castsi256_pd(magic_i_all); + + __m256i v_lo = + _mm256_blend_epi32(magic_i_lo, v, 0b01010101); /* Blend the 32 lowest significant bits of v with magic_int_lo */ + __m256i v_hi = _mm256_srli_epi64(v, 32); /* Extract the 32 most significant bits of v */ + v_hi = _mm256_xor_si256(v_hi, magic_i_hi32); /* Flip the msb of v_hi and blend with 0x45300000 */ + __m256d v_hi_dbl = _mm256_sub_pd(_mm256_castsi256_pd(v_hi), magic_d_all); /* Compute in double precision: */ + __m256d result = _mm256_add_pd( + v_hi_dbl, + _mm256_castsi256_pd( + v_lo)); /* (v_hi - magic_d_all) + v_lo Do not assume associativity of floating point addition !! */ + return result; /* With gcc use -O3, then -fno-associative-math is default. Do not use -Ofast, which enables + -fassociative-math! */ +} + +void sse_decode(const int64_t* digits, uint8_t fac_idx, uint8_t exp_idx, double* out_p) { + uint64_t factor = alp::U_FACT_ARR[fac_idx]; + double frac10 = alp::Constants::FRAC_ARR[exp_idx]; + __m128i factor_sse = _mm_set1_epi64x(factor); + __m128d frac10_sse = _mm_set1_pd(frac10); + + auto digits_p = reinterpret_cast(digits); + + for (size_t i {0}; i < 512; ++i) { + __m128i digit = _mm_loadu_si128(digits_p + i); + __m128i tmp_int = digit * factor_sse; + __m128d tmp_dbl = sse_int64_to_double(tmp_int); + __m128d tmp_dbl_mlt = tmp_dbl * frac10_sse; + _mm_storeu_pd(out_p + (i * 2), tmp_dbl_mlt); + } +} + +void avx2_decode(const int64_t* digits, uint8_t fac_idx, uint8_t exp_idx, double* out_p) { + uint64_t factor = alp::U_FACT_ARR[fac_idx]; + double frac10 = alp::Constants::FRAC_ARR[exp_idx]; + __m256i factor_sse = _mm256_set1_epi64x(factor); + __m256d frac10_sse = _mm256_set1_pd(frac10); + + auto digits_p = reinterpret_cast(digits); + + for (size_t i {0}; i < 256; ++i) { + __m256i digit = _mm256_loadu_si256(digits_p + i); + __m256i tmp_int = digit * factor_sse; + __m256d tmp_dbl = int64_to_double_fast_precise(tmp_int); + __m256d tmp_dbl_mlt = tmp_dbl * frac10_sse; + _mm256_storeu_pd(out_p + (i * 4), tmp_dbl_mlt); + } +} + +#endif + +template +struct AlpDecode { + + //! Scalar decoding a single value with ALP + static inline T decode_value(const int64_t encoded_value, const uint8_t factor, const uint8_t exponent) { + const T decoded_value = encoded_value * FACT_ARR[factor] * alp::Constants::FRAC_ARR[exponent]; + return decoded_value; + } + + //! Scalar decoding of an ALP vector + static inline void + decode(const int64_t* encoded_integers, const uint8_t fac_idx, const uint8_t exp_idx, T* output) { + for (size_t i {0}; i < config::VECTOR_SIZE; i++) { + output[i] = decode_value(encoded_integers[i], fac_idx, exp_idx); + } + } + + //! 
Patch Exceptions + static inline void patch_exceptions(T* out, + const T* exceptions, + const exp_p_t* exceptions_positions, + const exp_c_t* exceptions_count) { + const auto exp_c = exceptions_count[0]; + for (exp_c_t i {0}; i < exp_c; i++) { + out[exceptions_positions[i]] = exceptions[i]; + } + } +}; + +} // namespace alp + +#endif // ALP_DECODE_HPP diff --git a/graph-wasm/lbug-0.12.2/lbug-src/third_party/alp/include/alp/encode.hpp b/graph-wasm/lbug-0.12.2/lbug-src/third_party/alp/include/alp/encode.hpp new file mode 100644 index 0000000000..611842c7b3 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/third_party/alp/include/alp/encode.hpp @@ -0,0 +1,450 @@ +#ifndef ALP_ENCODE_HPP +#define ALP_ENCODE_HPP + +#include "alp/config.hpp" +#include "alp/constants.hpp" +#include "alp/decode.hpp" +#include "alp/sampler.hpp" +#include "alp/state.hpp" +#include "common.hpp" +#include +#include +#include +#include +#include +#include +#include + +#ifdef __AVX2__ + +#include + +#endif + +/* + * ALP Encoding + */ +namespace alp { + +template +struct AlpEncode { + + using EXACT_TYPE = typename FloatingToExact::type; + using ENCODED_TYPE = FloatingToEncodedType; + static constexpr uint8_t EXACT_TYPE_BITSIZE = sizeof(EXACT_TYPE) * 8; + + /* + * Check for special values which are impossible for ALP to encode + * because they cannot be cast to int64 without an undefined behaviour + */ + static inline bool is_impossible_to_encode(const T n) { + return !std::isfinite(n) || std::isnan(n) || n > ENCODING_UPPER_LIMIT || n < ENCODING_LOWER_LIMIT || + (n == 0.0 && std::signbit(n)); //! Verification for -0.0 + } + + //! Scalar encoding a single value with ALP + template + static ENCODED_TYPE encode_value(const T value, const factor_idx_t factor_idx, const exponent_idx_t exponent_idx) { + T tmp_encoded_value = value * Constants::EXP_ARR[exponent_idx] * Constants::FRAC_ARR[factor_idx]; + if constexpr (SAFE) { + if (is_impossible_to_encode(tmp_encoded_value)) { return static_cast(ENCODING_UPPER_LIMIT); } + } + tmp_encoded_value = tmp_encoded_value + Constants::MAGIC_NUMBER - Constants::MAGIC_NUMBER; + return static_cast(tmp_encoded_value); + } + + //! 
Analyze FFOR to obtain bitwidth and frame-of-reference value + static inline void analyze_ffor(const int64_t* input_vector, bw_t& bit_width, int64_t* base_for) { + auto min = std::numeric_limits::max(); + auto max = std::numeric_limits::min(); + + for (size_t i {0}; i < config::VECTOR_SIZE; i++) { + if (input_vector[i] < min) { min = input_vector[i]; } + if (input_vector[i] > max) { max = input_vector[i]; } + } + + const auto delta = (static_cast(max) - static_cast(min)); + const auto estimated_bits_per_value = static_cast(ceil(log2(delta + 1))); + bit_width = estimated_bits_per_value; + base_for[0] = min; + } + + /* + * Function to sort the best combinations from each vector sampled from the rowgroup + * First criteria is number of times it appears + * Second criteria is bigger exponent + * Third criteria is bigger factor + */ + static inline bool compare_best_combinations(const std::pair, int>& t1, + const std::pair, int>& t2) { + return (t1.second > t2.second) || (t1.second == t2.second && (t2.first.first < t1.first.first)) || + ((t1.second == t2.second && t2.first.first == t1.first.first) && (t2.first.second < t1.first.second)); + } + + /* + * Find the best combinations of factor-exponent from each vector sampled from a rowgroup + * This function is called once per rowgroup + * This operates over ALP first level samples + */ + static inline void find_top_k_combinations(const T* smp_arr, state& stt) { + const auto n_vectors_to_sample = + static_cast(std::ceil(static_cast(stt.sampled_values_n) / config::SAMPLES_PER_VECTOR)); + const uint64_t samples_size = std::min(stt.sampled_values_n, config::SAMPLES_PER_VECTOR); + std::map, int> global_combinations; + uint64_t smp_offset {0}; + + // For each vector in the rg sample + uint64_t best_estimated_compression_size { + (samples_size * (Constants::EXCEPTION_SIZE + EXCEPTION_POSITION_SIZE)) + + (samples_size * (Constants::EXCEPTION_SIZE))}; + for (size_t smp_n = 0; smp_n < n_vectors_to_sample; smp_n++) { + uint8_t found_factor {0}; + uint8_t found_exponent {0}; + // We start our optimization with the worst possible total bits obtained from compression + uint64_t sample_estimated_compression_size { + (samples_size * (Constants::EXCEPTION_SIZE + EXCEPTION_POSITION_SIZE)) + + (samples_size * (Constants::EXCEPTION_SIZE))}; // worst scenario + + // We try all combinations in search for the one which minimize the compression size + for (int8_t exp_ref = Constants::MAX_EXPONENT; exp_ref >= 0; exp_ref--) { + for (int8_t factor_idx = exp_ref; factor_idx >= 0; factor_idx--) { + uint32_t exceptions_count = {0}; + uint32_t non_exceptions_count = {0}; + uint32_t estimated_bits_per_value = {0}; + uint64_t estimated_compression_size = {0}; + + ENCODED_TYPE max_encoded_value = {std::numeric_limits::min()}; + ENCODED_TYPE min_encoded_value = {std::numeric_limits::max()}; + + for (size_t i = 0; i < samples_size; i++) { + const T actual_value = smp_arr[smp_offset + i]; + const ENCODED_TYPE encoded_value = encode_value(actual_value, factor_idx, exp_ref); + const T decoded_value = AlpDecode::decode_value(encoded_value, factor_idx, exp_ref); + if (decoded_value == actual_value) { + non_exceptions_count++; + if (encoded_value > max_encoded_value) { max_encoded_value = encoded_value; } + if (encoded_value < min_encoded_value) { min_encoded_value = encoded_value; } + } else { + exceptions_count++; + } + } + + // We do not take into account combinations which yield to almsot all exceptions + if (non_exceptions_count < 2) { continue; } + + // Evaluate factor/exponent 
compression size (we optimize for FOR) + const uint64_t delta = + (static_cast(max_encoded_value) - static_cast(min_encoded_value)); + estimated_bits_per_value = std::ceil(std::log2(delta + 1)); + estimated_compression_size += samples_size * estimated_bits_per_value; + estimated_compression_size += + exceptions_count * (Constants::EXCEPTION_SIZE + EXCEPTION_POSITION_SIZE); + + if ((estimated_compression_size < sample_estimated_compression_size) || + (estimated_compression_size == sample_estimated_compression_size && + (found_exponent < exp_ref)) || + // We prefer bigger exponents + ((estimated_compression_size == sample_estimated_compression_size && + found_exponent == exp_ref) && + (found_factor < factor_idx)) // We prefer bigger factors + ) { + sample_estimated_compression_size = estimated_compression_size; + found_exponent = exp_ref; + found_factor = factor_idx; + if (sample_estimated_compression_size < best_estimated_compression_size) { + best_estimated_compression_size = sample_estimated_compression_size; + } + } + } + } + std::pair cmb = std::make_pair(found_exponent, found_factor); + global_combinations[cmb]++; + smp_offset += samples_size; + } + + // We adapt scheme if we were not able to achieve compression in the current rg + if (best_estimated_compression_size >= Constants::RD_SIZE_THRESHOLD_LIMIT) { + stt.scheme = SCHEME::ALP_RD; + return; + } + + // Convert our hash to a Combination vector to be able to sort + // Note that this vector is always small (< 10 combinations) + std::vector, int>> best_k_combinations; + best_k_combinations.reserve(global_combinations.size()); + for (auto const& itr : global_combinations) { + best_k_combinations.emplace_back(itr.first, // Pair exp, fac + itr.second // N of times it appeared + ); + } + // We sort combinations based on times they appeared + std::sort(best_k_combinations.begin(), best_k_combinations.end(), compare_best_combinations); + if (best_k_combinations.size() < stt.k_combinations) { stt.k_combinations = best_k_combinations.size(); } + + // Save k' best exp, fac combination pairs + for (size_t i {0}; i < stt.k_combinations; i++) { + stt.best_k_combinations.push_back(best_k_combinations[i].first); + } + } + + /* + * Find the best combination of factor-exponent for a vector from within the best k combinations + * This is ALP second level sampling + */ + static inline void + find_best_exponent_factor_from_combinations(const std::vector>& top_combinations, + const uint8_t top_k, + const T* input_vector, + const size_t input_vector_size, + uint8_t& factor, + uint8_t& exponent) { + uint8_t found_exponent {0}; + uint8_t found_factor {0}; + uint64_t best_estimated_compression_size {0}; + uint8_t worse_threshold_count {0}; + + const int32_t sample_increments = + std::max(1, static_cast(std::ceil(input_vector_size / config::SAMPLES_PER_ROWGROUP))); + + // We try each K combination in search for the one which minimize the compression size in the vector + for (size_t k {0}; k < top_k; k++) { + const int exp_idx = top_combinations[k].first; + const int factor_idx = top_combinations[k].second; + uint32_t exception_count {0}; + uint32_t estimated_bits_per_value {0}; + uint64_t estimated_compression_size {0}; + ENCODED_TYPE max_encoded_value {std::numeric_limits::min()}; + ENCODED_TYPE min_encoded_value {std::numeric_limits::max()}; + + for (size_t sample_idx = 0; sample_idx < input_vector_size; sample_idx += sample_increments) { + const T actual_value = input_vector[sample_idx]; + const ENCODED_TYPE encoded_value = encode_value(actual_value, 
factor_idx, exp_idx); + const T decoded_value = AlpDecode::decode_value(encoded_value, factor_idx, exp_idx); + if (decoded_value == actual_value) { + if (encoded_value > max_encoded_value) { max_encoded_value = encoded_value; } + if (encoded_value < min_encoded_value) { min_encoded_value = encoded_value; } + } else { + exception_count++; + } + } + + // Evaluate factor/exponent performance (we optimize for FOR) + const uint64_t delta = max_encoded_value - min_encoded_value; + estimated_bits_per_value = ceil(log2(delta + 1)); + estimated_compression_size += config::SAMPLES_PER_ROWGROUP * estimated_bits_per_value; + estimated_compression_size += exception_count * (Constants::EXCEPTION_SIZE + EXCEPTION_POSITION_SIZE); + + if (k == 0) { // First try with first combination + best_estimated_compression_size = estimated_compression_size; + found_factor = factor_idx; + found_exponent = exp_idx; + continue; // Go to second + } + if (estimated_compression_size >= + best_estimated_compression_size) { // If current is worse or equal than previous + worse_threshold_count += 1; + if (worse_threshold_count == SAMPLING_EARLY_EXIT_THRESHOLD) { + break; // We stop only if two are worse + } + continue; + } + // Otherwise we replace best and continue with next + best_estimated_compression_size = estimated_compression_size; + found_factor = factor_idx; + found_exponent = exp_idx; + worse_threshold_count = 0; + } + exponent = found_exponent; + factor = found_factor; + } + + // DOUBLE + static inline void encode_simdized(const double* input_vector, + double* exceptions, + exp_p_t* exceptions_positions, + exp_c_t* exceptions_count, + int64_t* encoded_integers, + const factor_idx_t factor_idx, + const exponent_idx_t exponent_idx) { + alignas(64) static double encoded_dbl_arr[1024]; + alignas(64) static double dbl_arr_without_specials[1024]; + alignas(64) static uint64_t INDEX_ARR[1024]; + + exp_p_t current_exceptions_count {0}; + uint64_t exceptions_idx {0}; + + // make copy of input with all special values replaced by ENCODING_UPPER_LIMIT + const auto* tmp_input = reinterpret_cast(input_vector); + for (size_t i {0}; i < config::VECTOR_SIZE; i++) { + const auto is_special = + ((tmp_input[i] & 0x7FFFFFFFFFFFFFFF) >= + 0x7FF0000000000000) // any NaN, +inf and -inf (https://stackoverflow.com/questions/29730530/) + || tmp_input[i] == Constants::NEGATIVE_ZERO; + + if (is_special) { + dbl_arr_without_specials[i] = ENCODING_UPPER_LIMIT; + } else { + dbl_arr_without_specials[i] = input_vector[i]; + } + } + +#pragma clang loop vectorize_width(64) + for (size_t i {0}; i < config::VECTOR_SIZE; i++) { + auto const actual_value = dbl_arr_without_specials[i]; + + // Attempt conversion + const int64_t encoded_value = encode_value(actual_value, factor_idx, exponent_idx); + encoded_integers[i] = encoded_value; + const double decoded_value = AlpDecode::decode_value(encoded_value, factor_idx, exponent_idx); + encoded_dbl_arr[i] = decoded_value; + } + +#ifdef __AVX512F__ + for (size_t i {0}; i < config::VECTOR_SIZE; i = i + 8) { + __m512d l = _mm512_loadu_pd(tmp_dbl_arr + i); + __m512d r = _mm512_loadu_pd(input_vector + i); + __m512i index = _mm512_loadu_pd(INDEX_ARR + i); + auto is_exception = _mm512_cmpneq_pd_mask(l, r); + _mm512_mask_compressstoreu_pd(tmp_index + exceptions_idx, is_exception, index); + exceptions_idx += LOOKUP_TABLE[is_exception]; + } +#else + for (size_t i {0}; i < config::VECTOR_SIZE; i++) { + auto l = encoded_dbl_arr[i]; + auto r = dbl_arr_without_specials[i]; + auto is_exception = (l != r); + 
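+            // Branchless exception gather: the index is written unconditionally, but
+            // the cursor only advances when the round trip did not reproduce the
+            // input, so INDEX_ARR[0..exceptions_idx) ends up holding the positions of
+            // all exceptions in order.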
+        INDEX_ARR[exceptions_idx] = i;
+        exceptions_idx += is_exception;
+    }
+#endif
+
+    int64_t a_non_exception_value = 0;
+    for (size_t i {0}; i < config::VECTOR_SIZE; i++) {
+        if (i != INDEX_ARR[i]) {
+            a_non_exception_value = encoded_integers[i];
+            break;
+        }
+    }
+
+    for (size_t j {0}; j < exceptions_idx; j++) {
+        size_t i = INDEX_ARR[j];
+        const auto actual_value = input_vector[i];
+        encoded_integers[i] = a_non_exception_value;
+        exceptions[current_exceptions_count] = actual_value;
+        exceptions_positions[current_exceptions_count] = i;
+        current_exceptions_count = current_exceptions_count + 1;
+    }
+
+    *exceptions_count = current_exceptions_count;
+}
+
+// FLOAT
+static inline void encode_simdized(const float* input_vector,
+                                   float* exceptions,
+                                   exp_p_t* exceptions_positions,
+                                   exp_c_t* exceptions_count,
+                                   int64_t* encoded_integers,
+                                   const factor_idx_t factor_idx,
+                                   const exponent_idx_t exponent_idx) {
+    alignas(64) static float encoded_dbl_arr[1024];
+    alignas(64) static float dbl_arr_without_specials[1024];
+    alignas(64) static uint64_t INDEX_ARR[1024];
+
+    exp_p_t current_exceptions_count {0};
+    uint64_t exceptions_idx {0};
+
+    // Make a copy of the input with all special values replaced by ENCODING_UPPER_LIMIT
+    const auto* tmp_input = reinterpret_cast<const uint32_t*>(input_vector);
+    for (size_t i {0}; i < config::VECTOR_SIZE; i++) {
+        const auto is_special =
+            ((tmp_input[i] & 0x7FFFFFFF) >=
+             0x7F800000) // any NaN, +inf and -inf (https://stackoverflow.com/questions/29730530/)
+            || tmp_input[i] == Constants::NEGATIVE_ZERO;
+
+        if (is_special) {
+            dbl_arr_without_specials[i] = ENCODING_UPPER_LIMIT;
+        } else {
+            dbl_arr_without_specials[i] = input_vector[i];
+        }
+    }
+
+#pragma clang loop vectorize_width(64)
+    for (size_t i {0}; i < config::VECTOR_SIZE; i++) {
+        auto const actual_value = dbl_arr_without_specials[i];
+
+        // Attempt conversion
+        const int64_t encoded_value = encode_value(actual_value, factor_idx, exponent_idx);
+        encoded_integers[i] = encoded_value;
+        const float decoded_value = AlpDecode<float>::decode_value(encoded_value, factor_idx, exponent_idx);
+        encoded_dbl_arr[i] = decoded_value;
+    }
+
+#ifdef __AVX512F__
+    for (size_t i {0}; i < config::VECTOR_SIZE; i = i + 16) {
+        __m512 l = _mm512_loadu_ps(tmp_dbl_arr + i);
+        __m512 r = _mm512_loadu_ps(input_vector + i);
+        __m512i index = _mm512_loadu_ps(INDEX_ARR + i);
+        auto is_exception = _mm512_cmpneq_ps_mask(l, r);
+        _mm512_mask_compressstoreu_ps(tmp_index + exceptions_idx, is_exception, index);
+        exceptions_idx += LOOKUP_TABLE[is_exception];
+    }
+#else
+    for (size_t i {0}; i < config::VECTOR_SIZE; i++) {
+        auto l = encoded_dbl_arr[i];
+        auto r = dbl_arr_without_specials[i];
+        auto is_exception = (l != r);
+        INDEX_ARR[exceptions_idx] = i;
+        exceptions_idx += is_exception;
+    }
+#endif
+
+    int64_t a_non_exception_value = 0;
+    for (size_t i {0}; i < config::VECTOR_SIZE; i++) {
+        if (i != INDEX_ARR[i]) {
+            a_non_exception_value = encoded_integers[i];
+            break;
+        }
+    }
+
+    for (size_t j {0}; j < exceptions_idx; j++) {
+        size_t i = INDEX_ARR[j];
+        const auto actual_value = input_vector[i];
+        encoded_integers[i] = a_non_exception_value;
+        exceptions[current_exceptions_count] = actual_value;
+        exceptions_positions[current_exceptions_count] = i;
+        current_exceptions_count = current_exceptions_count + 1;
+    }
+
+    *exceptions_count = current_exceptions_count;
+}
+
+static inline void encode(const T* input_vector,
+                          T* exceptions,
+                          uint32_t* exceptions_positions,
+                          uint32_t* exceptions_count,
+                          int64_t* encoded_integers,
+                          state& stt) {
+
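+    // Per-vector entry point: if first-level (rowgroup) sampling kept more than one candidate
+    // (exponent, factor) pair, a second-level sample of this vector picks the best one;
+    // otherwise the single surviving pair is used directly.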
+    if (stt.k_combinations > 1) { // We only sample and search if more than one top combination was found
+        find_best_exponent_factor_from_combinations(
+            stt.best_k_combinations, stt.k_combinations, input_vector, stt.vector_size, stt.fac, stt.exp);
+    } else {
+        stt.exp = stt.best_k_combinations[0].first;
+        stt.fac = stt.best_k_combinations[0].second;
+    }
+    encode_simdized(
+        input_vector, exceptions, exceptions_positions, exceptions_count, encoded_integers, stt.fac, stt.exp);
+}
+
+static inline void
+init(const T* data_column, const size_t column_offset, const size_t tuples_count, T* sample_arr, state& stt) {
+    stt.scheme = SCHEME::ALP;
+    stt.sampled_values_n = sampler::first_level_sample(data_column, column_offset, tuples_count, sample_arr);
+    stt.k_combinations = config::MAX_K_COMBINATIONS;
+    stt.best_k_combinations.clear();
+    find_top_k_combinations(sample_arr, stt);
+}
+};
+
+} // namespace alp
+#endif
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/third_party/alp/include/alp/sampler.hpp b/graph-wasm/lbug-0.12.2/lbug-src/third_party/alp/include/alp/sampler.hpp
new file mode 100644
index 0000000000..88d2d6001f
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/third_party/alp/include/alp/sampler.hpp
@@ -0,0 +1,52 @@
+#ifndef ALP_SAMPLER_HPP
+#define ALP_SAMPLER_HPP
+
+#include "alp/config.hpp"
+#include <algorithm>
+#include <cmath>
+
+namespace alp::sampler {
+
+template <typename T>
+inline size_t first_level_sample(const T* data, const size_t data_offset, const size_t data_size, T* data_sample) {
+    const size_t left_in_data = data_size - data_offset;
+    const size_t portion_to_sample = std::min(config::ROWGROUP_SIZE, left_in_data);
+    const size_t available_alp_vectors = std::ceil(static_cast<double>(portion_to_sample) / config::VECTOR_SIZE);
+    size_t sample_idx = 0;
+    size_t data_idx = data_offset;
+
+    for (size_t vector_idx = 0; vector_idx < available_alp_vectors; vector_idx++) {
+        const size_t current_vector_n_values = std::min(data_size - data_idx, config::VECTOR_SIZE);
+
+        //! We sample equidistant vectors; to do this we skip a fixed number of vectors
+        //! If we are not on the correct jump, we do not take a sample from this vector
+        if (const bool must_select_rowgroup_sample = (vector_idx % config::ROWGROUP_SAMPLES_JUMP) == 0;
+            !must_select_rowgroup_sample) {
+            data_idx += current_vector_n_values;
+            continue;
+        }
+
+        const size_t n_sampled_increments = std::max<size_t>(
+            1,
+            static_cast<size_t>(std::ceil(static_cast<double>(current_vector_n_values) / config::SAMPLES_PER_VECTOR)));
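+        // The increment spreads the samples evenly across the vector so that roughly
+        // config::SAMPLES_PER_VECTOR values are taken from it; e.g. with a full 1024-value
+        // vector and SAMPLES_PER_VECTOR = 32 this samples every 32nd value.
+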
+        //! We do not take samples of non-complete duckdb vectors (usually the last one)
+        //! Except in the case of too little data
+        if (current_vector_n_values < config::SAMPLES_PER_VECTOR && sample_idx != 0) {
+            data_idx += current_vector_n_values;
+            continue;
+        }
+
+        // Storing the sample of that vector
+        for (size_t i = 0; i < current_vector_n_values; i += n_sampled_increments) {
+            data_sample[sample_idx] = data[data_idx + i];
+            sample_idx++;
+        }
+        data_idx += current_vector_n_values;
+    }
+    return sample_idx;
+}
+
+} // namespace alp::sampler
+
+#endif
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/third_party/alp/include/alp/state.hpp b/graph-wasm/lbug-0.12.2/lbug-src/third_party/alp/include/alp/state.hpp
new file mode 100644
index 0000000000..d6c9c68677
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/third_party/alp/include/alp/state.hpp
@@ -0,0 +1,27 @@
+#ifndef ALP_STATE_HPP
+#define ALP_STATE_HPP
+
+#include "alp/common.hpp"
+#include "alp/config.hpp"
+#include "alp/constants.hpp"
+#include <cstdint>
+#include <vector>
+
+namespace alp {
+struct state {
+    SCHEME scheme {SCHEME::ALP};
+    uint32_t vector_size {config::VECTOR_SIZE};
+    uint32_t exceptions_count {0};
+    size_t sampled_values_n {0};
+
+    // ALP
+    uint16_t k_combinations {5};
+    std::vector<std::pair<int, int>> best_k_combinations;
+    uint8_t exp;
+    uint8_t fac;
+    bw_t bit_width;
+    int64_t for_base;
+};
+} // namespace alp
+
+#endif
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/third_party/alp/include/alp/storer.hpp b/graph-wasm/lbug-0.12.2/lbug-src/third_party/alp/include/alp/storer.hpp
new file mode 100644
index 0000000000..bc761fbcbd
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/third_party/alp/include/alp/storer.hpp
@@ -0,0 +1,56 @@
+#ifndef ALP_API_MEM_STORER_HPP
+#define ALP_API_MEM_STORER_HPP
+
+#include <cstdint>
+#include <cstring>
+
+namespace alp { namespace storer {
+
+template <bool DRY>
+struct MemStorer {
+
+    uint8_t* out_buffer;
+    size_t buffer_offset;
+
+    MemStorer() {}
+    MemStorer(uint8_t* out_buffer)
+        : out_buffer(out_buffer)
+        , buffer_offset(0) {}
+
+    void set_buffer(uint8_t* out) { out_buffer = out; }
+
+    void reset() { buffer_offset = 0; }
+
+    size_t get_size() { return buffer_offset; }
+
+    void store(void* in, size_t bytes_to_store) {
+        if (!DRY) memcpy((void*)(out_buffer + buffer_offset), in, bytes_to_store);
+        buffer_offset += bytes_to_store;
+    }
+};
+
+struct MemReader {
+
+    uint8_t* in_buffer;
+    size_t buffer_offset;
+
+    MemReader() {}
+    MemReader(uint8_t* in_buffer)
+        : in_buffer(in_buffer)
+        , buffer_offset(0) {}
+
+    void set_buffer(uint8_t* in) { in_buffer = in; }
+
+    void reset() { buffer_offset = 0; }
+
+    size_t get_size() { return buffer_offset; }
+
+    void read(void* out, size_t bytes_to_read) {
+        memcpy(out, (void*)(in_buffer + buffer_offset), bytes_to_read);
+        buffer_offset += bytes_to_read;
+    }
+};
+
+}} // namespace alp::storer
+
+#endif
\ No newline at end of file
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/third_party/alp/include/alp/utils.hpp b/graph-wasm/lbug-0.12.2/lbug-src/third_party/alp/include/alp/utils.hpp
new file mode 100644
index 0000000000..bf570125a9
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/third_party/alp/include/alp/utils.hpp
@@ -0,0 +1,74 @@
+#ifndef ALP_UTILS_HPP
+#define ALP_UTILS_HPP
+
+#include "alp/config.hpp"
+#include "alp/encode.hpp"
+#include <cmath>
+#include <cstdint>
+
+namespace alp {
+
+template <class T>
+struct AlpApiUtils {
+
+    static size_t get_rowgroup_count(size_t values_count) {
+        return std::ceil((double)values_count / config::ROWGROUP_SIZE);
+    };
+
+    static size_t get_complete_vector_count(size_t n_values) {
+        return std::floor(static_cast<double>(n_values) / config::VECTOR_SIZE);
+    }
+
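+    // The bit-packed size of a vector is VECTOR_SIZE * bit_width bits rounded up to whole
+    // bytes; e.g. 1024 values at 3 bits each pack into 3072 bits, i.e. 384 bytes.
+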
+    /*
+     * Function to get the size of a vector after bit packing
+     * Note that we always store VECTOR_SIZE size vectors
+     */
+    static size_t get_size_after_bitpacking(uint8_t bit_width) {
+        return align_value(config::VECTOR_SIZE * bit_width) / 8;
+    }
+
+    template <class M, M val = 8>
+    static M align_value(M n) {
+        return ((n + (val - 1)) / val) * val;
+    }
+
+    static void fill_incomplete_alp_vector(T* input_vector,
+                                           T* exceptions,
+                                           uint32_t* exceptions_positions,
+                                           uint32_t* exceptions_count,
+                                           int64_t* encoded_integers,
+                                           state& stt) {
+
+        static auto* tmp_index = new (std::align_val_t {64}) uint64_t[1024];
+
+        // We fill the vector with 0s since these values will never be exceptions
+        for (size_t i = stt.vector_size; i < config::VECTOR_SIZE; i++) {
+            input_vector[i] = 0.0;
+        }
+        // We encode the vector filled with the dummy values
+        AlpEncode<T>::encode(input_vector, exceptions, exceptions_positions, exceptions_count, encoded_integers, stt);
+        T a_non_exception_value = 0.0;
+        // We look up the first non-exception value from the true vector values
+        for (size_t i {0}; i < stt.vector_size; i++) {
+            if (i != tmp_index[i]) {
+                a_non_exception_value = input_vector[i];
+                break;
+            }
+        }
+        // We fill the vector with this dummy value
+        for (size_t i = stt.vector_size; i < config::VECTOR_SIZE; i++) {
+            input_vector[i] = a_non_exception_value;
+        }
+    }
+
+    static void fill_incomplete_alprd_vector(T* input_vector, const state& stt) {
+        // We just fill the vector with the first value
+        const T first_vector_value = input_vector[0];
+        for (size_t i = stt.vector_size; i < config::VECTOR_SIZE; i++) {
+            input_vector[i] = first_vector_value;
+        }
+    }
+};
+} // namespace alp
+
+#endif
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_cypher/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_cypher/CMakeLists.txt
new file mode 100644
index 0000000000..2c8b639231
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_cypher/CMakeLists.txt
@@ -0,0 +1,23 @@
+if(${AUTO_UPDATE_GRAMMAR})
+  add_custom_command(
+    OUTPUT
+      ${CMAKE_CURRENT_SOURCE_DIR}/cypher_lexer.cpp
+      ${CMAKE_CURRENT_SOURCE_DIR}/cypher_parser.cpp
+      ${CMAKE_CURRENT_SOURCE_DIR}/include/cypher_lexer.h
+      ${CMAKE_CURRENT_SOURCE_DIR}/include/cypher_parser.h
+    COMMAND cmake -D ROOT_DIR=${PROJECT_SOURCE_DIR} -P generate_grammar.cmake
+    DEPENDS
+      ${PROJECT_SOURCE_DIR}/src/antlr4/Cypher.g4
+      ${PROJECT_SOURCE_DIR}/scripts/antlr4/generate_grammar.cmake
+    WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}/scripts/antlr4)
+endif()
+
+add_library(antlr4_cypher
+  STATIC
+  cypher_lexer.cpp
+  cypher_parser.cpp)
+
+target_include_directories(antlr4_cypher
+  PRIVATE ../antlr4_runtime/src)
+
+target_link_libraries(antlr4_cypher PRIVATE antlr4_runtime)
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_cypher/cypher_lexer.cpp b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_cypher/cypher_lexer.cpp
new file mode 100644
index 0000000000..6cffcad8ce
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_cypher/cypher_lexer.cpp
@@ -0,0 +1,965 @@
+
+// Generated from Cypher.g4 by ANTLR 4.13.1
+
+
+#include "cypher_lexer.h"
+
+
+using namespace antlr4;
+
+
+
+using namespace antlr4;
+
+namespace {
+
+struct CypherLexerStaticData final {
+  CypherLexerStaticData(std::vector<std::string> ruleNames,
+                        std::vector<std::string> channelNames,
+                        std::vector<std::string> modeNames,
+                        std::vector<std::string> literalNames,
+                        std::vector<std::string> symbolicNames)
+      : ruleNames(std::move(ruleNames)), channelNames(std::move(channelNames)),
+        modeNames(std::move(modeNames)), literalNames(std::move(literalNames)),
symbolicNames(std::move(symbolicNames)), + vocabulary(this->literalNames, this->symbolicNames) {} + + CypherLexerStaticData(const CypherLexerStaticData&) = delete; + CypherLexerStaticData(CypherLexerStaticData&&) = delete; + CypherLexerStaticData& operator=(const CypherLexerStaticData&) = delete; + CypherLexerStaticData& operator=(CypherLexerStaticData&&) = delete; + + std::vector decisionToDFA; + antlr4::atn::PredictionContextCache sharedContextCache; + const std::vector ruleNames; + const std::vector channelNames; + const std::vector modeNames; + const std::vector literalNames; + const std::vector symbolicNames; + const antlr4::dfa::Vocabulary vocabulary; + antlr4::atn::SerializedATNView serializedATN; + std::unique_ptr atn; +}; + +::antlr4::internal::OnceFlag cypherlexerLexerOnceFlag; +#if ANTLR4_USE_THREAD_LOCAL_CACHE +static thread_local +#endif +CypherLexerStaticData *cypherlexerLexerStaticData = nullptr; + +void cypherlexerLexerInitialize() { +#if ANTLR4_USE_THREAD_LOCAL_CACHE + if (cypherlexerLexerStaticData != nullptr) { + return; + } +#else + assert(cypherlexerLexerStaticData == nullptr); +#endif + auto staticData = std::make_unique( + std::vector{ + "T__0", "T__1", "T__2", "T__3", "T__4", "T__5", "T__6", "T__7", "T__8", + "T__9", "T__10", "T__11", "T__12", "T__13", "T__14", "T__15", "T__16", + "T__17", "T__18", "T__19", "T__20", "T__21", "T__22", "T__23", "T__24", + "T__25", "T__26", "T__27", "T__28", "T__29", "T__30", "T__31", "T__32", + "T__33", "T__34", "T__35", "T__36", "T__37", "T__38", "T__39", "T__40", + "T__41", "T__42", "T__43", "ACYCLIC", "ANY", "ADD", "ALL", "ALTER", + "AND", "AS", "ASC", "ASCENDING", "ATTACH", "BEGIN", "BY", "CALL", + "CASE", "CAST", "CHECKPOINT", "COLUMN", "COMMENT", "COMMIT", "COMMIT_SKIP_CHECKPOINT", + "CONTAINS", "COPY", "COUNT", "CREATE", "CYCLE", "DATABASE", "DBTYPE", + "DEFAULT", "DELETE", "DESC", "DESCENDING", "DETACH", "DISTINCT", "DROP", + "ELSE", "END", "ENDS", "EXISTS", "EXPLAIN", "EXPORT", "EXTENSION", + "FALSE", "FROM", "FORCE", "GLOB", "GRAPH", "GROUP", "HEADERS", "HINT", + "IMPORT", "IF", "IN", "INCREMENT", "INSTALL", "IS", "JOIN", "KEY", + "LIMIT", "LOAD", "LOGICAL", "MACRO", "MATCH", "MAXVALUE", "MERGE", + "MINVALUE", "MULTI_JOIN", "NO", "NODE", "NOT", "NONE", "NULL", "ON", + "ONLY", "OPTIONAL", "OR", "ORDER", "PRIMARY", "PROFILE", "PROJECT", + "READ", "REL", "RENAME", "RETURN", "ROLLBACK", "ROLLBACK_SKIP_CHECKPOINT", + "SEQUENCE", "SET", "SHORTEST", "START", "STARTS", "STRUCT", "TABLE", + "THEN", "TO", "TRAIL", "TRANSACTION", "TRUE", "TYPE", "UNION", "UNWIND", + "UNINSTALL", "UPDATE", "USE", "WHEN", "WHERE", "WITH", "WRITE", "WSHORTEST", + "XOR", "SINGLE", "YIELD", "USER", "PASSWORD", "ROLE", "MAP", "DECIMAL", + "STAR", "L_SKIP", "INVALID_NOT_EQUAL", "COLON", "DOTDOT", "MINUS", + "FACTORIAL", "StringLiteral", "EscapedChar", "DecimalInteger", "HexLetter", + "HexDigit", "Digit", "NonZeroDigit", "NonZeroOctDigit", "ZeroDigit", + "ExponentDecimalReal", "RegularDecimalReal", "UnescapedSymbolicName", + "IdentifierStart", "IdentifierPart", "EscapedSymbolicName", "SP", + "WHITESPACE", "CypherComment", "FF", "EscapedSymbolicName_0", "RS", + "ID_Continue", "Comment_1", "StringLiteral_1", "Comment_3", "Comment_2", + "GS", "FS", "CR", "Sc", "SPACE", "Pc", "TAB", "StringLiteral_0", "LF", + "VT", "US", "ID_Start", "Unknown" + }, + std::vector{ + "DEFAULT_TOKEN_CHANNEL", "HIDDEN" + }, + std::vector{ + "DEFAULT_MODE" + }, + std::vector{ + "", "';'", "'('", "')'", "','", "'.'", "'='", "'['", "']'", "'{'", + "'}'", "'|'", "'<>'", "'<'", "'<='", 
"'>'", "'>='", "'&'", "'>>'", + "'<<'", "'+'", "'/'", "'%'", "'^'", "'=~'", "'$'", "'\\u27E8'", "'\\u3008'", + "'\\uFE64'", "'\\uFF1C'", "'\\u27E9'", "'\\u3009'", "'\\uFE65'", "'\\uFF1E'", + "'\\u00AD'", "'\\u2010'", "'\\u2011'", "'\\u2012'", "'\\u2013'", "'\\u2014'", + "'\\u2015'", "'\\u2212'", "'\\uFE58'", "'\\uFE63'", "'\\uFF0D'", "", + "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", + "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", + "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", + "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", + "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", + "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", + "", "", "", "", "", "", "", "", "", "", "", "", "", "'*'", "", "'!='", + "':'", "'..'", "'-'", "'!'", "", "", "", "", "", "", "", "", "'0'" + }, + std::vector{ + "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", + "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", + "", "", "", "", "", "", "", "", "", "", "", "ACYCLIC", "ANY", "ADD", + "ALL", "ALTER", "AND", "AS", "ASC", "ASCENDING", "ATTACH", "BEGIN", + "BY", "CALL", "CASE", "CAST", "CHECKPOINT", "COLUMN", "COMMENT", "COMMIT", + "COMMIT_SKIP_CHECKPOINT", "CONTAINS", "COPY", "COUNT", "CREATE", "CYCLE", + "DATABASE", "DBTYPE", "DEFAULT", "DELETE", "DESC", "DESCENDING", "DETACH", + "DISTINCT", "DROP", "ELSE", "END", "ENDS", "EXISTS", "EXPLAIN", "EXPORT", + "EXTENSION", "FALSE", "FROM", "FORCE", "GLOB", "GRAPH", "GROUP", "HEADERS", + "HINT", "IMPORT", "IF", "IN", "INCREMENT", "INSTALL", "IS", "JOIN", + "KEY", "LIMIT", "LOAD", "LOGICAL", "MACRO", "MATCH", "MAXVALUE", "MERGE", + "MINVALUE", "MULTI_JOIN", "NO", "NODE", "NOT", "NONE", "NULL", "ON", + "ONLY", "OPTIONAL", "OR", "ORDER", "PRIMARY", "PROFILE", "PROJECT", + "READ", "REL", "RENAME", "RETURN", "ROLLBACK", "ROLLBACK_SKIP_CHECKPOINT", + "SEQUENCE", "SET", "SHORTEST", "START", "STARTS", "STRUCT", "TABLE", + "THEN", "TO", "TRAIL", "TRANSACTION", "TRUE", "TYPE", "UNION", "UNWIND", + "UNINSTALL", "UPDATE", "USE", "WHEN", "WHERE", "WITH", "WRITE", "WSHORTEST", + "XOR", "SINGLE", "YIELD", "USER", "PASSWORD", "ROLE", "MAP", "DECIMAL", + "STAR", "L_SKIP", "INVALID_NOT_EQUAL", "COLON", "DOTDOT", "MINUS", + "FACTORIAL", "StringLiteral", "EscapedChar", "DecimalInteger", "HexLetter", + "HexDigit", "Digit", "NonZeroDigit", "NonZeroOctDigit", "ZeroDigit", + "ExponentDecimalReal", "RegularDecimalReal", "UnescapedSymbolicName", + "IdentifierStart", "IdentifierPart", "EscapedSymbolicName", "SP", + "WHITESPACE", "CypherComment", "Unknown" + } + ); + static const int32_t serializedATNSegment[] = { + 4,0,186,1527,6,-1,2,0,7,0,2,1,7,1,2,2,7,2,2,3,7,3,2,4,7,4,2,5,7,5,2,6, + 7,6,2,7,7,7,2,8,7,8,2,9,7,9,2,10,7,10,2,11,7,11,2,12,7,12,2,13,7,13,2, + 14,7,14,2,15,7,15,2,16,7,16,2,17,7,17,2,18,7,18,2,19,7,19,2,20,7,20,2, + 21,7,21,2,22,7,22,2,23,7,23,2,24,7,24,2,25,7,25,2,26,7,26,2,27,7,27,2, + 28,7,28,2,29,7,29,2,30,7,30,2,31,7,31,2,32,7,32,2,33,7,33,2,34,7,34,2, + 35,7,35,2,36,7,36,2,37,7,37,2,38,7,38,2,39,7,39,2,40,7,40,2,41,7,41,2, + 42,7,42,2,43,7,43,2,44,7,44,2,45,7,45,2,46,7,46,2,47,7,47,2,48,7,48,2, + 49,7,49,2,50,7,50,2,51,7,51,2,52,7,52,2,53,7,53,2,54,7,54,2,55,7,55,2, + 56,7,56,2,57,7,57,2,58,7,58,2,59,7,59,2,60,7,60,2,61,7,61,2,62,7,62,2, + 63,7,63,2,64,7,64,2,65,7,65,2,66,7,66,2,67,7,67,2,68,7,68,2,69,7,69,2, + 70,7,70,2,71,7,71,2,72,7,72,2,73,7,73,2,74,7,74,2,75,7,75,2,76,7,76,2, + 
77,7,77,2,78,7,78,2,79,7,79,2,80,7,80,2,81,7,81,2,82,7,82,2,83,7,83,2, + 84,7,84,2,85,7,85,2,86,7,86,2,87,7,87,2,88,7,88,2,89,7,89,2,90,7,90,2, + 91,7,91,2,92,7,92,2,93,7,93,2,94,7,94,2,95,7,95,2,96,7,96,2,97,7,97,2, + 98,7,98,2,99,7,99,2,100,7,100,2,101,7,101,2,102,7,102,2,103,7,103,2,104, + 7,104,2,105,7,105,2,106,7,106,2,107,7,107,2,108,7,108,2,109,7,109,2,110, + 7,110,2,111,7,111,2,112,7,112,2,113,7,113,2,114,7,114,2,115,7,115,2,116, + 7,116,2,117,7,117,2,118,7,118,2,119,7,119,2,120,7,120,2,121,7,121,2,122, + 7,122,2,123,7,123,2,124,7,124,2,125,7,125,2,126,7,126,2,127,7,127,2,128, + 7,128,2,129,7,129,2,130,7,130,2,131,7,131,2,132,7,132,2,133,7,133,2,134, + 7,134,2,135,7,135,2,136,7,136,2,137,7,137,2,138,7,138,2,139,7,139,2,140, + 7,140,2,141,7,141,2,142,7,142,2,143,7,143,2,144,7,144,2,145,7,145,2,146, + 7,146,2,147,7,147,2,148,7,148,2,149,7,149,2,150,7,150,2,151,7,151,2,152, + 7,152,2,153,7,153,2,154,7,154,2,155,7,155,2,156,7,156,2,157,7,157,2,158, + 7,158,2,159,7,159,2,160,7,160,2,161,7,161,2,162,7,162,2,163,7,163,2,164, + 7,164,2,165,7,165,2,166,7,166,2,167,7,167,2,168,7,168,2,169,7,169,2,170, + 7,170,2,171,7,171,2,172,7,172,2,173,7,173,2,174,7,174,2,175,7,175,2,176, + 7,176,2,177,7,177,2,178,7,178,2,179,7,179,2,180,7,180,2,181,7,181,2,182, + 7,182,2,183,7,183,2,184,7,184,2,185,7,185,2,186,7,186,2,187,7,187,2,188, + 7,188,2,189,7,189,2,190,7,190,2,191,7,191,2,192,7,192,2,193,7,193,2,194, + 7,194,2,195,7,195,2,196,7,196,2,197,7,197,2,198,7,198,2,199,7,199,2,200, + 7,200,2,201,7,201,2,202,7,202,2,203,7,203,2,204,7,204,2,205,7,205,1,0, + 1,0,1,1,1,1,1,2,1,2,1,3,1,3,1,4,1,4,1,5,1,5,1,6,1,6,1,7,1,7,1,8,1,8,1, + 9,1,9,1,10,1,10,1,11,1,11,1,11,1,12,1,12,1,13,1,13,1,13,1,14,1,14,1,15, + 1,15,1,15,1,16,1,16,1,17,1,17,1,17,1,18,1,18,1,18,1,19,1,19,1,20,1,20, + 1,21,1,21,1,22,1,22,1,23,1,23,1,23,1,24,1,24,1,25,1,25,1,26,1,26,1,27, + 1,27,1,28,1,28,1,29,1,29,1,30,1,30,1,31,1,31,1,32,1,32,1,33,1,33,1,34, + 1,34,1,35,1,35,1,36,1,36,1,37,1,37,1,38,1,38,1,39,1,39,1,40,1,40,1,41, + 1,41,1,42,1,42,1,43,1,43,1,44,1,44,1,44,1,44,1,44,1,44,1,44,1,44,1,45, + 1,45,1,45,1,45,1,46,1,46,1,46,1,46,1,47,1,47,1,47,1,47,1,48,1,48,1,48, + 1,48,1,48,1,48,1,49,1,49,1,49,1,49,1,50,1,50,1,50,1,51,1,51,1,51,1,51, + 1,52,1,52,1,52,1,52,1,52,1,52,1,52,1,52,1,52,1,52,1,53,1,53,1,53,1,53, + 1,53,1,53,1,53,1,54,1,54,1,54,1,54,1,54,1,54,1,55,1,55,1,55,1,56,1,56, + 1,56,1,56,1,56,1,57,1,57,1,57,1,57,1,57,1,58,1,58,1,58,1,58,1,58,1,59, + 1,59,1,59,1,59,1,59,1,59,1,59,1,59,1,59,1,59,1,59,1,60,1,60,1,60,1,60, + 1,60,1,60,1,60,1,61,1,61,1,61,1,61,1,61,1,61,1,61,1,61,1,62,1,62,1,62, + 1,62,1,62,1,62,1,62,1,63,1,63,1,63,1,63,1,63,1,63,1,63,1,63,1,63,1,63, + 1,63,1,63,1,63,1,63,1,63,1,63,1,63,1,63,1,63,1,63,1,63,1,63,1,63,1,64, + 1,64,1,64,1,64,1,64,1,64,1,64,1,64,1,64,1,65,1,65,1,65,1,65,1,65,1,66, + 1,66,1,66,1,66,1,66,1,66,1,67,1,67,1,67,1,67,1,67,1,67,1,67,1,68,1,68, + 1,68,1,68,1,68,1,68,1,69,1,69,1,69,1,69,1,69,1,69,1,69,1,69,1,69,1,70, + 1,70,1,70,1,70,1,70,1,70,1,70,1,71,1,71,1,71,1,71,1,71,1,71,1,71,1,71, + 1,72,1,72,1,72,1,72,1,72,1,72,1,72,1,73,1,73,1,73,1,73,1,73,1,74,1,74, + 1,74,1,74,1,74,1,74,1,74,1,74,1,74,1,74,1,74,1,75,1,75,1,75,1,75,1,75, + 1,75,1,75,1,76,1,76,1,76,1,76,1,76,1,76,1,76,1,76,1,76,1,77,1,77,1,77, + 1,77,1,77,1,78,1,78,1,78,1,78,1,78,1,79,1,79,1,79,1,79,1,80,1,80,1,80, + 1,80,1,80,1,81,1,81,1,81,1,81,1,81,1,81,1,81,1,82,1,82,1,82,1,82,1,82, + 1,82,1,82,1,82,1,83,1,83,1,83,1,83,1,83,1,83,1,83,1,84,1,84,1,84,1,84, + 1,84,1,84,1,84,1,84,1,84,1,84,1,85,1,85,1,85,1,85,1,85,1,85,1,86,1,86, + 
1,86,1,86,1,86,1,87,1,87,1,87,1,87,1,87,1,87,1,88,1,88,1,88,1,88,1,88, + 1,89,1,89,1,89,1,89,1,89,1,89,1,90,1,90,1,90,1,90,1,90,1,90,1,91,1,91, + 1,91,1,91,1,91,1,91,1,91,1,91,1,92,1,92,1,92,1,92,1,92,1,93,1,93,1,93, + 1,93,1,93,1,93,1,93,1,94,1,94,1,94,1,95,1,95,1,95,1,96,1,96,1,96,1,96, + 1,96,1,96,1,96,1,96,1,96,1,96,1,97,1,97,1,97,1,97,1,97,1,97,1,97,1,97, + 1,98,1,98,1,98,1,99,1,99,1,99,1,99,1,99,1,100,1,100,1,100,1,100,1,101, + 1,101,1,101,1,101,1,101,1,101,1,102,1,102,1,102,1,102,1,102,1,103,1,103, + 1,103,1,103,1,103,1,103,1,103,1,103,1,104,1,104,1,104,1,104,1,104,1,104, + 1,105,1,105,1,105,1,105,1,105,1,105,1,106,1,106,1,106,1,106,1,106,1,106, + 1,106,1,106,1,106,1,107,1,107,1,107,1,107,1,107,1,107,1,108,1,108,1,108, + 1,108,1,108,1,108,1,108,1,108,1,108,1,109,1,109,1,109,1,109,1,109,1,109, + 1,109,1,109,1,109,1,109,1,109,1,110,1,110,1,110,1,111,1,111,1,111,1,111, + 1,111,1,112,1,112,1,112,1,112,1,113,1,113,1,113,1,113,1,113,1,114,1,114, + 1,114,1,114,1,114,1,115,1,115,1,115,1,116,1,116,1,116,1,116,1,116,1,117, + 1,117,1,117,1,117,1,117,1,117,1,117,1,117,1,117,1,118,1,118,1,118,1,119, + 1,119,1,119,1,119,1,119,1,119,1,120,1,120,1,120,1,120,1,120,1,120,1,120, + 1,120,1,121,1,121,1,121,1,121,1,121,1,121,1,121,1,121,1,122,1,122,1,122, + 1,122,1,122,1,122,1,122,1,122,1,123,1,123,1,123,1,123,1,123,1,124,1,124, + 1,124,1,124,1,125,1,125,1,125,1,125,1,125,1,125,1,125,1,126,1,126,1,126, + 1,126,1,126,1,126,1,126,1,127,1,127,1,127,1,127,1,127,1,127,1,127,1,127, + 1,127,1,128,1,128,1,128,1,128,1,128,1,128,1,128,1,128,1,128,1,128,1,128, + 1,128,1,128,1,128,1,128,1,128,1,128,1,128,1,128,1,128,1,128,1,128,1,128, + 1,128,1,128,1,129,1,129,1,129,1,129,1,129,1,129,1,129,1,129,1,129,1,130, + 1,130,1,130,1,130,1,131,1,131,1,131,1,131,1,131,1,131,1,131,1,131,1,131, + 1,132,1,132,1,132,1,132,1,132,1,132,1,133,1,133,1,133,1,133,1,133,1,133, + 1,133,1,134,1,134,1,134,1,134,1,134,1,134,1,134,1,135,1,135,1,135,1,135, + 1,135,1,135,1,136,1,136,1,136,1,136,1,136,1,137,1,137,1,137,1,138,1,138, + 1,138,1,138,1,138,1,138,1,139,1,139,1,139,1,139,1,139,1,139,1,139,1,139, + 1,139,1,139,1,139,1,139,1,140,1,140,1,140,1,140,1,140,1,141,1,141,1,141, + 1,141,1,141,1,142,1,142,1,142,1,142,1,142,1,142,1,143,1,143,1,143,1,143, + 1,143,1,143,1,143,1,144,1,144,1,144,1,144,1,144,1,144,1,144,1,144,1,144, + 1,144,1,145,1,145,1,145,1,145,1,145,1,145,1,145,1,146,1,146,1,146,1,146, + 1,147,1,147,1,147,1,147,1,147,1,148,1,148,1,148,1,148,1,148,1,148,1,149, + 1,149,1,149,1,149,1,149,1,150,1,150,1,150,1,150,1,150,1,150,1,151,1,151, + 1,151,1,151,1,151,1,151,1,151,1,151,1,151,1,151,1,152,1,152,1,152,1,152, + 1,153,1,153,1,153,1,153,1,153,1,153,1,153,1,154,1,154,1,154,1,154,1,154, + 1,154,1,155,1,155,1,155,1,155,1,155,1,156,1,156,1,156,1,156,1,156,1,156, + 1,156,1,156,1,156,1,157,1,157,1,157,1,157,1,157,1,158,1,158,1,158,1,158, + 1,159,1,159,1,159,1,159,1,159,1,159,1,159,1,159,1,160,1,160,1,161,1,161, + 1,161,1,161,1,161,1,162,1,162,1,162,1,163,1,163,1,164,1,164,1,164,1,165, + 1,165,1,166,1,166,1,167,1,167,1,167,5,167,1294,8,167,10,167,12,167,1297, + 9,167,1,167,1,167,1,167,1,167,5,167,1303,8,167,10,167,12,167,1306,9,167, + 1,167,3,167,1309,8,167,1,168,1,168,1,168,1,168,1,168,1,168,1,168,1,168, + 1,168,1,168,1,168,1,168,1,168,1,168,1,168,1,168,1,168,1,168,1,168,1,168, + 1,168,1,168,3,168,1333,8,168,1,169,1,169,1,169,5,169,1338,8,169,10,169, + 12,169,1341,9,169,3,169,1343,8,169,1,170,3,170,1346,8,170,1,171,1,171, + 3,171,1350,8,171,1,172,1,172,3,172,1354,8,172,1,173,1,173,3,173,1358, + 
8,173,1,174,1,174,1,175,1,175,1,176,4,176,1365,8,176,11,176,12,176,1366, + 1,176,4,176,1370,8,176,11,176,12,176,1371,1,176,1,176,4,176,1376,8,176, + 11,176,12,176,1377,1,176,1,176,4,176,1382,8,176,11,176,12,176,1383,3, + 176,1386,8,176,1,176,1,176,3,176,1390,8,176,1,176,4,176,1393,8,176,11, + 176,12,176,1394,1,177,5,177,1398,8,177,10,177,12,177,1401,9,177,1,177, + 1,177,4,177,1405,8,177,11,177,12,177,1406,1,178,1,178,5,178,1411,8,178, + 10,178,12,178,1414,9,178,1,179,1,179,3,179,1418,8,179,1,180,1,180,3,180, + 1422,8,180,1,181,1,181,5,181,1426,8,181,10,181,12,181,1429,9,181,1,181, + 4,181,1432,8,181,11,181,12,181,1433,1,182,4,182,1437,8,182,11,182,12, + 182,1438,1,183,1,183,1,183,1,183,1,183,1,183,1,183,1,183,1,183,1,183, + 1,183,1,183,3,183,1453,8,183,1,184,1,184,1,184,1,184,1,184,1,184,5,184, + 1461,8,184,10,184,12,184,1464,9,184,1,184,1,184,1,184,1,184,1,184,1,184, + 5,184,1472,8,184,10,184,12,184,1475,9,184,1,184,3,184,1478,8,184,1,184, + 1,184,3,184,1482,8,184,3,184,1484,8,184,1,185,1,185,1,186,1,186,1,187, + 1,187,1,188,1,188,1,189,1,189,1,190,1,190,1,191,1,191,1,192,1,192,1,193, + 1,193,1,194,1,194,1,195,1,195,1,196,1,196,1,197,1,197,1,198,1,198,1,199, + 1,199,1,200,1,200,1,201,1,201,1,202,1,202,1,203,1,203,1,204,1,204,1,205, + 1,205,0,0,206,1,1,3,2,5,3,7,4,9,5,11,6,13,7,15,8,17,9,19,10,21,11,23, + 12,25,13,27,14,29,15,31,16,33,17,35,18,37,19,39,20,41,21,43,22,45,23, + 47,24,49,25,51,26,53,27,55,28,57,29,59,30,61,31,63,32,65,33,67,34,69, + 35,71,36,73,37,75,38,77,39,79,40,81,41,83,42,85,43,87,44,89,45,91,46, + 93,47,95,48,97,49,99,50,101,51,103,52,105,53,107,54,109,55,111,56,113, + 57,115,58,117,59,119,60,121,61,123,62,125,63,127,64,129,65,131,66,133, + 67,135,68,137,69,139,70,141,71,143,72,145,73,147,74,149,75,151,76,153, + 77,155,78,157,79,159,80,161,81,163,82,165,83,167,84,169,85,171,86,173, + 87,175,88,177,89,179,90,181,91,183,92,185,93,187,94,189,95,191,96,193, + 97,195,98,197,99,199,100,201,101,203,102,205,103,207,104,209,105,211, + 106,213,107,215,108,217,109,219,110,221,111,223,112,225,113,227,114,229, + 115,231,116,233,117,235,118,237,119,239,120,241,121,243,122,245,123,247, + 124,249,125,251,126,253,127,255,128,257,129,259,130,261,131,263,132,265, + 133,267,134,269,135,271,136,273,137,275,138,277,139,279,140,281,141,283, + 142,285,143,287,144,289,145,291,146,293,147,295,148,297,149,299,150,301, + 151,303,152,305,153,307,154,309,155,311,156,313,157,315,158,317,159,319, + 160,321,161,323,162,325,163,327,164,329,165,331,166,333,167,335,168,337, + 169,339,170,341,171,343,172,345,173,347,174,349,175,351,176,353,177,355, + 178,357,179,359,180,361,181,363,182,365,183,367,184,369,185,371,0,373, + 0,375,0,377,0,379,0,381,0,383,0,385,0,387,0,389,0,391,0,393,0,395,0,397, + 0,399,0,401,0,403,0,405,0,407,0,409,0,411,186,1,0,48,2,0,65,65,97,97, + 2,0,67,67,99,99,2,0,89,89,121,121,2,0,76,76,108,108,2,0,73,73,105,105, + 2,0,78,78,110,110,2,0,68,68,100,100,2,0,84,84,116,116,2,0,69,69,101,101, + 2,0,82,82,114,114,2,0,83,83,115,115,2,0,71,71,103,103,2,0,72,72,104,104, + 2,0,66,66,98,98,2,0,75,75,107,107,2,0,80,80,112,112,2,0,79,79,111,111, + 2,0,85,85,117,117,2,0,77,77,109,109,2,0,70,70,102,102,2,0,88,88,120,120, + 2,0,74,74,106,106,2,0,86,86,118,118,2,0,81,81,113,113,2,0,87,87,119,119, + 13,0,34,34,39,39,66,66,70,70,78,78,82,82,84,84,92,92,98,98,102,102,110, + 110,114,114,116,116,2,0,65,70,97,102,8,0,160,160,5760,5760,6158,6158, + 8192,8202,8232,8233,8239,8239,8287,8287,12288,12288,1,0,12,12,1,0,96, + 96,1,0,30,30,768,0,48,57,65,90,95,95,97,122,170,170,181,181,183,183,186, + 
186,192,214,216,246,248,705,710,721,736,740,748,748,750,750,768,884,886, + 887,890,893,895,895,902,906,908,908,910,929,931,1013,1015,1153,1155,1159, + 1162,1327,1329,1366,1369,1369,1376,1416,1425,1469,1471,1471,1473,1474, + 1476,1477,1479,1479,1488,1514,1519,1522,1552,1562,1568,1641,1646,1747, + 1749,1756,1759,1768,1770,1788,1791,1791,1808,1866,1869,1969,1984,2037, + 2042,2042,2045,2045,2048,2093,2112,2139,2144,2154,2160,2183,2185,2190, + 2200,2273,2275,2403,2406,2415,2417,2435,2437,2444,2447,2448,2451,2472, + 2474,2480,2482,2482,2486,2489,2492,2500,2503,2504,2507,2510,2519,2519, + 2524,2525,2527,2531,2534,2545,2556,2556,2558,2558,2561,2563,2565,2570, + 2575,2576,2579,2600,2602,2608,2610,2611,2613,2614,2616,2617,2620,2620, + 2622,2626,2631,2632,2635,2637,2641,2641,2649,2652,2654,2654,2662,2677, + 2689,2691,2693,2701,2703,2705,2707,2728,2730,2736,2738,2739,2741,2745, + 2748,2757,2759,2761,2763,2765,2768,2768,2784,2787,2790,2799,2809,2815, + 2817,2819,2821,2828,2831,2832,2835,2856,2858,2864,2866,2867,2869,2873, + 2876,2884,2887,2888,2891,2893,2901,2903,2908,2909,2911,2915,2918,2927, + 2929,2929,2946,2947,2949,2954,2958,2960,2962,2965,2969,2970,2972,2972, + 2974,2975,2979,2980,2984,2986,2990,3001,3006,3010,3014,3016,3018,3021, + 3024,3024,3031,3031,3046,3055,3072,3084,3086,3088,3090,3112,3114,3129, + 3132,3140,3142,3144,3146,3149,3157,3158,3160,3162,3165,3165,3168,3171, + 3174,3183,3200,3203,3205,3212,3214,3216,3218,3240,3242,3251,3253,3257, + 3260,3268,3270,3272,3274,3277,3285,3286,3293,3294,3296,3299,3302,3311, + 3313,3315,3328,3340,3342,3344,3346,3396,3398,3400,3402,3406,3412,3415, + 3423,3427,3430,3439,3450,3455,3457,3459,3461,3478,3482,3505,3507,3515, + 3517,3517,3520,3526,3530,3530,3535,3540,3542,3542,3544,3551,3558,3567, + 3570,3571,3585,3642,3648,3662,3664,3673,3713,3714,3716,3716,3718,3722, + 3724,3747,3749,3749,3751,3773,3776,3780,3782,3782,3784,3790,3792,3801, + 3804,3807,3840,3840,3864,3865,3872,3881,3893,3893,3895,3895,3897,3897, + 3902,3911,3913,3948,3953,3972,3974,3991,3993,4028,4038,4038,4096,4169, + 4176,4253,4256,4293,4295,4295,4301,4301,4304,4346,4348,4680,4682,4685, + 4688,4694,4696,4696,4698,4701,4704,4744,4746,4749,4752,4784,4786,4789, + 4792,4798,4800,4800,4802,4805,4808,4822,4824,4880,4882,4885,4888,4954, + 4957,4959,4969,4977,4992,5007,5024,5109,5112,5117,5121,5740,5743,5759, + 5761,5786,5792,5866,5870,5880,5888,5909,5919,5940,5952,5971,5984,5996, + 5998,6000,6002,6003,6016,6099,6103,6103,6108,6109,6112,6121,6155,6157, + 6159,6169,6176,6264,6272,6314,6320,6389,6400,6430,6432,6443,6448,6459, + 6470,6509,6512,6516,6528,6571,6576,6601,6608,6618,6656,6683,6688,6750, + 6752,6780,6783,6793,6800,6809,6823,6823,6832,6845,6847,6862,6912,6988, + 6992,7001,7019,7027,7040,7155,7168,7223,7232,7241,7245,7293,7296,7304, + 7312,7354,7357,7359,7376,7378,7380,7418,7424,7957,7960,7965,7968,8005, + 8008,8013,8016,8023,8025,8025,8027,8027,8029,8029,8031,8061,8064,8116, + 8118,8124,8126,8126,8130,8132,8134,8140,8144,8147,8150,8155,8160,8172, + 8178,8180,8182,8188,8255,8256,8276,8276,8305,8305,8319,8319,8336,8348, + 8400,8412,8417,8417,8421,8432,8450,8450,8455,8455,8458,8467,8469,8469, + 8472,8477,8484,8484,8486,8486,8488,8488,8490,8505,8508,8511,8517,8521, + 8526,8526,8544,8584,11264,11492,11499,11507,11520,11557,11559,11559,11565, + 11565,11568,11623,11631,11631,11647,11670,11680,11686,11688,11694,11696, + 11702,11704,11710,11712,11718,11720,11726,11728,11734,11736,11742,11744, + 11775,12293,12295,12321,12335,12337,12341,12344,12348,12353,12438,12441, + 
12447,12449,12538,12540,12543,12549,12591,12593,12686,12704,12735,12784, + 12799,13312,19903,19968,42124,42192,42237,42240,42508,42512,42539,42560, + 42607,42612,42621,42623,42737,42775,42783,42786,42888,42891,42954,42960, + 42961,42963,42963,42965,42969,42994,43047,43052,43052,43072,43123,43136, + 43205,43216,43225,43232,43255,43259,43259,43261,43309,43312,43347,43360, + 43388,43392,43456,43471,43481,43488,43518,43520,43574,43584,43597,43600, + 43609,43616,43638,43642,43714,43739,43741,43744,43759,43762,43766,43777, + 43782,43785,43790,43793,43798,43808,43814,43816,43822,43824,43866,43868, + 43881,43888,44010,44012,44013,44016,44025,44032,55203,55216,55238,55243, + 55291,63744,64109,64112,64217,64256,64262,64275,64279,64285,64296,64298, + 64310,64312,64316,64318,64318,64320,64321,64323,64324,64326,64433,64467, + 64829,64848,64911,64914,64967,65008,65019,65024,65039,65056,65071,65075, + 65076,65101,65103,65136,65140,65142,65276,65296,65305,65313,65338,65343, + 65343,65345,65370,65382,65470,65474,65479,65482,65487,65490,65495,65498, + 65500,65536,65547,65549,65574,65576,65594,65596,65597,65599,65613,65616, + 65629,65664,65786,65856,65908,66045,66045,66176,66204,66208,66256,66272, + 66272,66304,66335,66349,66378,66384,66426,66432,66461,66464,66499,66504, + 66511,66513,66517,66560,66717,66720,66729,66736,66771,66776,66811,66816, + 66855,66864,66915,66928,66938,66940,66954,66956,66962,66964,66965,66967, + 66977,66979,66993,66995,67001,67003,67004,67072,67382,67392,67413,67424, + 67431,67456,67461,67463,67504,67506,67514,67584,67589,67592,67592,67594, + 67637,67639,67640,67644,67644,67647,67669,67680,67702,67712,67742,67808, + 67826,67828,67829,67840,67861,67872,67897,67968,68023,68030,68031,68096, + 68099,68101,68102,68108,68115,68117,68119,68121,68149,68152,68154,68159, + 68159,68192,68220,68224,68252,68288,68295,68297,68326,68352,68405,68416, + 68437,68448,68466,68480,68497,68608,68680,68736,68786,68800,68850,68864, + 68903,68912,68921,69248,69289,69291,69292,69296,69297,69373,69404,69415, + 69415,69424,69456,69488,69509,69552,69572,69600,69622,69632,69702,69734, + 69749,69759,69818,69826,69826,69840,69864,69872,69881,69888,69940,69942, + 69951,69956,69959,69968,70003,70006,70006,70016,70084,70089,70092,70094, + 70106,70108,70108,70144,70161,70163,70199,70206,70209,70272,70278,70280, + 70280,70282,70285,70287,70301,70303,70312,70320,70378,70384,70393,70400, + 70403,70405,70412,70415,70416,70419,70440,70442,70448,70450,70451,70453, + 70457,70459,70468,70471,70472,70475,70477,70480,70480,70487,70487,70493, + 70499,70502,70508,70512,70516,70656,70730,70736,70745,70750,70753,70784, + 70853,70855,70855,70864,70873,71040,71093,71096,71104,71128,71133,71168, + 71232,71236,71236,71248,71257,71296,71352,71360,71369,71424,71450,71453, + 71467,71472,71481,71488,71494,71680,71738,71840,71913,71935,71942,71945, + 71945,71948,71955,71957,71958,71960,71989,71991,71992,71995,72003,72016, + 72025,72096,72103,72106,72151,72154,72161,72163,72164,72192,72254,72263, + 72263,72272,72345,72349,72349,72368,72440,72704,72712,72714,72758,72760, + 72768,72784,72793,72818,72847,72850,72871,72873,72886,72960,72966,72968, + 72969,72971,73014,73018,73018,73020,73021,73023,73031,73040,73049,73056, + 73061,73063,73064,73066,73102,73104,73105,73107,73112,73120,73129,73440, + 73462,73472,73488,73490,73530,73534,73538,73552,73561,73648,73648,73728, + 74649,74752,74862,74880,75075,77712,77808,77824,78895,78912,78933,82944, + 83526,92160,92728,92736,92766,92768,92777,92784,92862,92864,92873,92880, + 
92909,92912,92916,92928,92982,92992,92995,93008,93017,93027,93047,93053, + 93071,93760,93823,93952,94026,94031,94087,94095,94111,94176,94177,94179, + 94180,94192,94193,94208,100343,100352,101589,101632,101640,110576,110579, + 110581,110587,110589,110590,110592,110882,110898,110898,110928,110930, + 110933,110933,110948,110951,110960,111355,113664,113770,113776,113788, + 113792,113800,113808,113817,113821,113822,118528,118573,118576,118598, + 119141,119145,119149,119154,119163,119170,119173,119179,119210,119213, + 119362,119364,119808,119892,119894,119964,119966,119967,119970,119970, + 119973,119974,119977,119980,119982,119993,119995,119995,119997,120003, + 120005,120069,120071,120074,120077,120084,120086,120092,120094,120121, + 120123,120126,120128,120132,120134,120134,120138,120144,120146,120485, + 120488,120512,120514,120538,120540,120570,120572,120596,120598,120628, + 120630,120654,120656,120686,120688,120712,120714,120744,120746,120770, + 120772,120779,120782,120831,121344,121398,121403,121452,121461,121461, + 121476,121476,121499,121503,121505,121519,122624,122654,122661,122666, + 122880,122886,122888,122904,122907,122913,122915,122916,122918,122922, + 122928,122989,123023,123023,123136,123180,123184,123197,123200,123209, + 123214,123214,123536,123566,123584,123641,124112,124153,124896,124902, + 124904,124907,124909,124910,124912,124926,124928,125124,125136,125142, + 125184,125259,125264,125273,126464,126467,126469,126495,126497,126498, + 126500,126500,126503,126503,126505,126514,126516,126519,126521,126521, + 126523,126523,126530,126530,126535,126535,126537,126537,126539,126539, + 126541,126543,126545,126546,126548,126548,126551,126551,126553,126553, + 126555,126555,126557,126557,126559,126559,126561,126562,126564,126564, + 126567,126570,126572,126578,126580,126583,126585,126588,126590,126590, + 126592,126601,126603,126619,126625,126627,126629,126633,126635,126651, + 130032,130041,131072,173791,173824,177977,177984,178205,178208,183969, + 183984,191456,194560,195101,196608,201546,201552,205743,917760,917999, + 1,0,42,42,2,0,39,39,92,92,2,0,10,10,13,13,1,0,47,47,1,0,29,29,1,0,28, + 28,1,0,13,13,21,0,36,36,162,165,1423,1423,1547,1547,2046,2047,2546,2547, + 2555,2555,2801,2801,3065,3065,3647,3647,6107,6107,8352,8384,43064,43064, + 65020,65020,65129,65129,65284,65284,65504,65505,65509,65510,73693,73696, + 123647,123647,126128,126128,1,0,32,32,6,0,95,95,8255,8256,8276,8276,65075, + 65076,65101,65103,65343,65343,1,0,9,9,2,0,34,34,92,92,1,0,10,10,1,0,11, + 11,1,0,31,31,659,0,65,90,97,122,170,170,181,181,186,186,192,214,216,246, + 248,705,710,721,736,740,748,748,750,750,880,884,886,887,890,893,895,895, + 902,902,904,906,908,908,910,929,931,1013,1015,1153,1162,1327,1329,1366, + 1369,1369,1376,1416,1488,1514,1519,1522,1568,1610,1646,1647,1649,1747, + 1749,1749,1765,1766,1774,1775,1786,1788,1791,1791,1808,1808,1810,1839, + 1869,1957,1969,1969,1994,2026,2036,2037,2042,2042,2048,2069,2074,2074, + 2084,2084,2088,2088,2112,2136,2144,2154,2160,2183,2185,2190,2208,2249, + 2308,2361,2365,2365,2384,2384,2392,2401,2417,2432,2437,2444,2447,2448, + 2451,2472,2474,2480,2482,2482,2486,2489,2493,2493,2510,2510,2524,2525, + 2527,2529,2544,2545,2556,2556,2565,2570,2575,2576,2579,2600,2602,2608, + 2610,2611,2613,2614,2616,2617,2649,2652,2654,2654,2674,2676,2693,2701, + 2703,2705,2707,2728,2730,2736,2738,2739,2741,2745,2749,2749,2768,2768, + 2784,2785,2809,2809,2821,2828,2831,2832,2835,2856,2858,2864,2866,2867, + 2869,2873,2877,2877,2908,2909,2911,2913,2929,2929,2947,2947,2949,2954, + 
2958,2960,2962,2965,2969,2970,2972,2972,2974,2975,2979,2980,2984,2986, + 2990,3001,3024,3024,3077,3084,3086,3088,3090,3112,3114,3129,3133,3133, + 3160,3162,3165,3165,3168,3169,3200,3200,3205,3212,3214,3216,3218,3240, + 3242,3251,3253,3257,3261,3261,3293,3294,3296,3297,3313,3314,3332,3340, + 3342,3344,3346,3386,3389,3389,3406,3406,3412,3414,3423,3425,3450,3455, + 3461,3478,3482,3505,3507,3515,3517,3517,3520,3526,3585,3632,3634,3635, + 3648,3654,3713,3714,3716,3716,3718,3722,3724,3747,3749,3749,3751,3760, + 3762,3763,3773,3773,3776,3780,3782,3782,3804,3807,3840,3840,3904,3911, + 3913,3948,3976,3980,4096,4138,4159,4159,4176,4181,4186,4189,4193,4193, + 4197,4198,4206,4208,4213,4225,4238,4238,4256,4293,4295,4295,4301,4301, + 4304,4346,4348,4680,4682,4685,4688,4694,4696,4696,4698,4701,4704,4744, + 4746,4749,4752,4784,4786,4789,4792,4798,4800,4800,4802,4805,4808,4822, + 4824,4880,4882,4885,4888,4954,4992,5007,5024,5109,5112,5117,5121,5740, + 5743,5759,5761,5786,5792,5866,5870,5880,5888,5905,5919,5937,5952,5969, + 5984,5996,5998,6000,6016,6067,6103,6103,6108,6108,6176,6264,6272,6312, + 6314,6314,6320,6389,6400,6430,6480,6509,6512,6516,6528,6571,6576,6601, + 6656,6678,6688,6740,6823,6823,6917,6963,6981,6988,7043,7072,7086,7087, + 7098,7141,7168,7203,7245,7247,7258,7293,7296,7304,7312,7354,7357,7359, + 7401,7404,7406,7411,7413,7414,7418,7418,7424,7615,7680,7957,7960,7965, + 7968,8005,8008,8013,8016,8023,8025,8025,8027,8027,8029,8029,8031,8061, + 8064,8116,8118,8124,8126,8126,8130,8132,8134,8140,8144,8147,8150,8155, + 8160,8172,8178,8180,8182,8188,8305,8305,8319,8319,8336,8348,8450,8450, + 8455,8455,8458,8467,8469,8469,8472,8477,8484,8484,8486,8486,8488,8488, + 8490,8505,8508,8511,8517,8521,8526,8526,8544,8584,11264,11492,11499,11502, + 11506,11507,11520,11557,11559,11559,11565,11565,11568,11623,11631,11631, + 11648,11670,11680,11686,11688,11694,11696,11702,11704,11710,11712,11718, + 11720,11726,11728,11734,11736,11742,12293,12295,12321,12329,12337,12341, + 12344,12348,12353,12438,12443,12447,12449,12538,12540,12543,12549,12591, + 12593,12686,12704,12735,12784,12799,13312,19903,19968,42124,42192,42237, + 42240,42508,42512,42527,42538,42539,42560,42606,42623,42653,42656,42735, + 42775,42783,42786,42888,42891,42954,42960,42961,42963,42963,42965,42969, + 42994,43009,43011,43013,43015,43018,43020,43042,43072,43123,43138,43187, + 43250,43255,43259,43259,43261,43262,43274,43301,43312,43334,43360,43388, + 43396,43442,43471,43471,43488,43492,43494,43503,43514,43518,43520,43560, + 43584,43586,43588,43595,43616,43638,43642,43642,43646,43695,43697,43697, + 43701,43702,43705,43709,43712,43712,43714,43714,43739,43741,43744,43754, + 43762,43764,43777,43782,43785,43790,43793,43798,43808,43814,43816,43822, + 43824,43866,43868,43881,43888,44002,44032,55203,55216,55238,55243,55291, + 63744,64109,64112,64217,64256,64262,64275,64279,64285,64285,64287,64296, + 64298,64310,64312,64316,64318,64318,64320,64321,64323,64324,64326,64433, + 64467,64829,64848,64911,64914,64967,65008,65019,65136,65140,65142,65276, + 65313,65338,65345,65370,65382,65470,65474,65479,65482,65487,65490,65495, + 65498,65500,65536,65547,65549,65574,65576,65594,65596,65597,65599,65613, + 65616,65629,65664,65786,65856,65908,66176,66204,66208,66256,66304,66335, + 66349,66378,66384,66421,66432,66461,66464,66499,66504,66511,66513,66517, + 66560,66717,66736,66771,66776,66811,66816,66855,66864,66915,66928,66938, + 66940,66954,66956,66962,66964,66965,66967,66977,66979,66993,66995,67001, + 67003,67004,67072,67382,67392,67413,67424,67431,67456,67461,67463,67504, 
+ 67506,67514,67584,67589,67592,67592,67594,67637,67639,67640,67644,67644, + 67647,67669,67680,67702,67712,67742,67808,67826,67828,67829,67840,67861, + 67872,67897,67968,68023,68030,68031,68096,68096,68112,68115,68117,68119, + 68121,68149,68192,68220,68224,68252,68288,68295,68297,68324,68352,68405, + 68416,68437,68448,68466,68480,68497,68608,68680,68736,68786,68800,68850, + 68864,68899,69248,69289,69296,69297,69376,69404,69415,69415,69424,69445, + 69488,69505,69552,69572,69600,69622,69635,69687,69745,69746,69749,69749, + 69763,69807,69840,69864,69891,69926,69956,69956,69959,69959,69968,70002, + 70006,70006,70019,70066,70081,70084,70106,70106,70108,70108,70144,70161, + 70163,70187,70207,70208,70272,70278,70280,70280,70282,70285,70287,70301, + 70303,70312,70320,70366,70405,70412,70415,70416,70419,70440,70442,70448, + 70450,70451,70453,70457,70461,70461,70480,70480,70493,70497,70656,70708, + 70727,70730,70751,70753,70784,70831,70852,70853,70855,70855,71040,71086, + 71128,71131,71168,71215,71236,71236,71296,71338,71352,71352,71424,71450, + 71488,71494,71680,71723,71840,71903,71935,71942,71945,71945,71948,71955, + 71957,71958,71960,71983,71999,71999,72001,72001,72096,72103,72106,72144, + 72161,72161,72163,72163,72192,72192,72203,72242,72250,72250,72272,72272, + 72284,72329,72349,72349,72368,72440,72704,72712,72714,72750,72768,72768, + 72818,72847,72960,72966,72968,72969,72971,73008,73030,73030,73056,73061, + 73063,73064,73066,73097,73112,73112,73440,73458,73474,73474,73476,73488, + 73490,73523,73648,73648,73728,74649,74752,74862,74880,75075,77712,77808, + 77824,78895,78913,78918,82944,83526,92160,92728,92736,92766,92784,92862, + 92880,92909,92928,92975,92992,92995,93027,93047,93053,93071,93760,93823, + 93952,94026,94032,94032,94099,94111,94176,94177,94179,94179,94208,100343, + 100352,101589,101632,101640,110576,110579,110581,110587,110589,110590, + 110592,110882,110898,110898,110928,110930,110933,110933,110948,110951, + 110960,111355,113664,113770,113776,113788,113792,113800,113808,113817, + 119808,119892,119894,119964,119966,119967,119970,119970,119973,119974, + 119977,119980,119982,119993,119995,119995,119997,120003,120005,120069, + 120071,120074,120077,120084,120086,120092,120094,120121,120123,120126, + 120128,120132,120134,120134,120138,120144,120146,120485,120488,120512, + 120514,120538,120540,120570,120572,120596,120598,120628,120630,120654, + 120656,120686,120688,120712,120714,120744,120746,120770,120772,120779, + 122624,122654,122661,122666,122928,122989,123136,123180,123191,123197, + 123214,123214,123536,123565,123584,123627,124112,124139,124896,124902, + 124904,124907,124909,124910,124912,124926,124928,125124,125184,125251, + 125259,125259,126464,126467,126469,126495,126497,126498,126500,126500, + 126503,126503,126505,126514,126516,126519,126521,126521,126523,126523, + 126530,126530,126535,126535,126537,126537,126539,126539,126541,126543, + 126545,126546,126548,126548,126551,126551,126553,126553,126555,126555, + 126557,126557,126559,126559,126561,126562,126564,126564,126567,126570, + 126572,126578,126580,126583,126585,126588,126590,126590,126592,126601, + 126603,126619,126625,126627,126629,126633,126635,126651,131072,173791, + 173824,177977,177984,178205,178208,183969,183984,191456,194560,195101, + 196608,201546,201552,205743,1552,0,1,1,0,0,0,0,3,1,0,0,0,0,5,1,0,0,0, + 0,7,1,0,0,0,0,9,1,0,0,0,0,11,1,0,0,0,0,13,1,0,0,0,0,15,1,0,0,0,0,17,1, + 0,0,0,0,19,1,0,0,0,0,21,1,0,0,0,0,23,1,0,0,0,0,25,1,0,0,0,0,27,1,0,0, + 0,0,29,1,0,0,0,0,31,1,0,0,0,0,33,1,0,0,0,0,35,1,0,0,0,0,37,1,0,0,0,0, + 
39,1,0,0,0,0,41,1,0,0,0,0,43,1,0,0,0,0,45,1,0,0,0,0,47,1,0,0,0,0,49,1, + 0,0,0,0,51,1,0,0,0,0,53,1,0,0,0,0,55,1,0,0,0,0,57,1,0,0,0,0,59,1,0,0, + 0,0,61,1,0,0,0,0,63,1,0,0,0,0,65,1,0,0,0,0,67,1,0,0,0,0,69,1,0,0,0,0, + 71,1,0,0,0,0,73,1,0,0,0,0,75,1,0,0,0,0,77,1,0,0,0,0,79,1,0,0,0,0,81,1, + 0,0,0,0,83,1,0,0,0,0,85,1,0,0,0,0,87,1,0,0,0,0,89,1,0,0,0,0,91,1,0,0, + 0,0,93,1,0,0,0,0,95,1,0,0,0,0,97,1,0,0,0,0,99,1,0,0,0,0,101,1,0,0,0,0, + 103,1,0,0,0,0,105,1,0,0,0,0,107,1,0,0,0,0,109,1,0,0,0,0,111,1,0,0,0,0, + 113,1,0,0,0,0,115,1,0,0,0,0,117,1,0,0,0,0,119,1,0,0,0,0,121,1,0,0,0,0, + 123,1,0,0,0,0,125,1,0,0,0,0,127,1,0,0,0,0,129,1,0,0,0,0,131,1,0,0,0,0, + 133,1,0,0,0,0,135,1,0,0,0,0,137,1,0,0,0,0,139,1,0,0,0,0,141,1,0,0,0,0, + 143,1,0,0,0,0,145,1,0,0,0,0,147,1,0,0,0,0,149,1,0,0,0,0,151,1,0,0,0,0, + 153,1,0,0,0,0,155,1,0,0,0,0,157,1,0,0,0,0,159,1,0,0,0,0,161,1,0,0,0,0, + 163,1,0,0,0,0,165,1,0,0,0,0,167,1,0,0,0,0,169,1,0,0,0,0,171,1,0,0,0,0, + 173,1,0,0,0,0,175,1,0,0,0,0,177,1,0,0,0,0,179,1,0,0,0,0,181,1,0,0,0,0, + 183,1,0,0,0,0,185,1,0,0,0,0,187,1,0,0,0,0,189,1,0,0,0,0,191,1,0,0,0,0, + 193,1,0,0,0,0,195,1,0,0,0,0,197,1,0,0,0,0,199,1,0,0,0,0,201,1,0,0,0,0, + 203,1,0,0,0,0,205,1,0,0,0,0,207,1,0,0,0,0,209,1,0,0,0,0,211,1,0,0,0,0, + 213,1,0,0,0,0,215,1,0,0,0,0,217,1,0,0,0,0,219,1,0,0,0,0,221,1,0,0,0,0, + 223,1,0,0,0,0,225,1,0,0,0,0,227,1,0,0,0,0,229,1,0,0,0,0,231,1,0,0,0,0, + 233,1,0,0,0,0,235,1,0,0,0,0,237,1,0,0,0,0,239,1,0,0,0,0,241,1,0,0,0,0, + 243,1,0,0,0,0,245,1,0,0,0,0,247,1,0,0,0,0,249,1,0,0,0,0,251,1,0,0,0,0, + 253,1,0,0,0,0,255,1,0,0,0,0,257,1,0,0,0,0,259,1,0,0,0,0,261,1,0,0,0,0, + 263,1,0,0,0,0,265,1,0,0,0,0,267,1,0,0,0,0,269,1,0,0,0,0,271,1,0,0,0,0, + 273,1,0,0,0,0,275,1,0,0,0,0,277,1,0,0,0,0,279,1,0,0,0,0,281,1,0,0,0,0, + 283,1,0,0,0,0,285,1,0,0,0,0,287,1,0,0,0,0,289,1,0,0,0,0,291,1,0,0,0,0, + 293,1,0,0,0,0,295,1,0,0,0,0,297,1,0,0,0,0,299,1,0,0,0,0,301,1,0,0,0,0, + 303,1,0,0,0,0,305,1,0,0,0,0,307,1,0,0,0,0,309,1,0,0,0,0,311,1,0,0,0,0, + 313,1,0,0,0,0,315,1,0,0,0,0,317,1,0,0,0,0,319,1,0,0,0,0,321,1,0,0,0,0, + 323,1,0,0,0,0,325,1,0,0,0,0,327,1,0,0,0,0,329,1,0,0,0,0,331,1,0,0,0,0, + 333,1,0,0,0,0,335,1,0,0,0,0,337,1,0,0,0,0,339,1,0,0,0,0,341,1,0,0,0,0, + 343,1,0,0,0,0,345,1,0,0,0,0,347,1,0,0,0,0,349,1,0,0,0,0,351,1,0,0,0,0, + 353,1,0,0,0,0,355,1,0,0,0,0,357,1,0,0,0,0,359,1,0,0,0,0,361,1,0,0,0,0, + 363,1,0,0,0,0,365,1,0,0,0,0,367,1,0,0,0,0,369,1,0,0,0,0,411,1,0,0,0,1, + 413,1,0,0,0,3,415,1,0,0,0,5,417,1,0,0,0,7,419,1,0,0,0,9,421,1,0,0,0,11, + 423,1,0,0,0,13,425,1,0,0,0,15,427,1,0,0,0,17,429,1,0,0,0,19,431,1,0,0, + 0,21,433,1,0,0,0,23,435,1,0,0,0,25,438,1,0,0,0,27,440,1,0,0,0,29,443, + 1,0,0,0,31,445,1,0,0,0,33,448,1,0,0,0,35,450,1,0,0,0,37,453,1,0,0,0,39, + 456,1,0,0,0,41,458,1,0,0,0,43,460,1,0,0,0,45,462,1,0,0,0,47,464,1,0,0, + 0,49,467,1,0,0,0,51,469,1,0,0,0,53,471,1,0,0,0,55,473,1,0,0,0,57,475, + 1,0,0,0,59,477,1,0,0,0,61,479,1,0,0,0,63,481,1,0,0,0,65,483,1,0,0,0,67, + 485,1,0,0,0,69,487,1,0,0,0,71,489,1,0,0,0,73,491,1,0,0,0,75,493,1,0,0, + 0,77,495,1,0,0,0,79,497,1,0,0,0,81,499,1,0,0,0,83,501,1,0,0,0,85,503, + 1,0,0,0,87,505,1,0,0,0,89,507,1,0,0,0,91,515,1,0,0,0,93,519,1,0,0,0,95, + 523,1,0,0,0,97,527,1,0,0,0,99,533,1,0,0,0,101,537,1,0,0,0,103,540,1,0, + 0,0,105,544,1,0,0,0,107,554,1,0,0,0,109,561,1,0,0,0,111,567,1,0,0,0,113, + 570,1,0,0,0,115,575,1,0,0,0,117,580,1,0,0,0,119,585,1,0,0,0,121,596,1, + 0,0,0,123,603,1,0,0,0,125,611,1,0,0,0,127,618,1,0,0,0,129,641,1,0,0,0, + 131,650,1,0,0,0,133,655,1,0,0,0,135,661,1,0,0,0,137,668,1,0,0,0,139,674, + 
1,0,0,0,141,683,1,0,0,0,143,690,1,0,0,0,145,698,1,0,0,0,147,705,1,0,0, + 0,149,710,1,0,0,0,151,721,1,0,0,0,153,728,1,0,0,0,155,737,1,0,0,0,157, + 742,1,0,0,0,159,747,1,0,0,0,161,751,1,0,0,0,163,756,1,0,0,0,165,763,1, + 0,0,0,167,771,1,0,0,0,169,778,1,0,0,0,171,788,1,0,0,0,173,794,1,0,0,0, + 175,799,1,0,0,0,177,805,1,0,0,0,179,810,1,0,0,0,181,816,1,0,0,0,183,822, + 1,0,0,0,185,830,1,0,0,0,187,835,1,0,0,0,189,842,1,0,0,0,191,845,1,0,0, + 0,193,848,1,0,0,0,195,858,1,0,0,0,197,866,1,0,0,0,199,869,1,0,0,0,201, + 874,1,0,0,0,203,878,1,0,0,0,205,884,1,0,0,0,207,889,1,0,0,0,209,897,1, + 0,0,0,211,903,1,0,0,0,213,909,1,0,0,0,215,918,1,0,0,0,217,924,1,0,0,0, + 219,933,1,0,0,0,221,944,1,0,0,0,223,947,1,0,0,0,225,952,1,0,0,0,227,956, + 1,0,0,0,229,961,1,0,0,0,231,966,1,0,0,0,233,969,1,0,0,0,235,974,1,0,0, + 0,237,983,1,0,0,0,239,986,1,0,0,0,241,992,1,0,0,0,243,1000,1,0,0,0,245, + 1008,1,0,0,0,247,1016,1,0,0,0,249,1021,1,0,0,0,251,1025,1,0,0,0,253,1032, + 1,0,0,0,255,1039,1,0,0,0,257,1048,1,0,0,0,259,1073,1,0,0,0,261,1082,1, + 0,0,0,263,1086,1,0,0,0,265,1095,1,0,0,0,267,1101,1,0,0,0,269,1108,1,0, + 0,0,271,1115,1,0,0,0,273,1121,1,0,0,0,275,1126,1,0,0,0,277,1129,1,0,0, + 0,279,1135,1,0,0,0,281,1147,1,0,0,0,283,1152,1,0,0,0,285,1157,1,0,0,0, + 287,1163,1,0,0,0,289,1170,1,0,0,0,291,1180,1,0,0,0,293,1187,1,0,0,0,295, + 1191,1,0,0,0,297,1196,1,0,0,0,299,1202,1,0,0,0,301,1207,1,0,0,0,303,1213, + 1,0,0,0,305,1223,1,0,0,0,307,1227,1,0,0,0,309,1234,1,0,0,0,311,1240,1, + 0,0,0,313,1245,1,0,0,0,315,1254,1,0,0,0,317,1259,1,0,0,0,319,1263,1,0, + 0,0,321,1271,1,0,0,0,323,1273,1,0,0,0,325,1278,1,0,0,0,327,1281,1,0,0, + 0,329,1283,1,0,0,0,331,1286,1,0,0,0,333,1288,1,0,0,0,335,1308,1,0,0,0, + 337,1310,1,0,0,0,339,1342,1,0,0,0,341,1345,1,0,0,0,343,1349,1,0,0,0,345, + 1353,1,0,0,0,347,1357,1,0,0,0,349,1359,1,0,0,0,351,1361,1,0,0,0,353,1385, + 1,0,0,0,355,1399,1,0,0,0,357,1408,1,0,0,0,359,1417,1,0,0,0,361,1421,1, + 0,0,0,363,1431,1,0,0,0,365,1436,1,0,0,0,367,1452,1,0,0,0,369,1483,1,0, + 0,0,371,1485,1,0,0,0,373,1487,1,0,0,0,375,1489,1,0,0,0,377,1491,1,0,0, + 0,379,1493,1,0,0,0,381,1495,1,0,0,0,383,1497,1,0,0,0,385,1499,1,0,0,0, + 387,1501,1,0,0,0,389,1503,1,0,0,0,391,1505,1,0,0,0,393,1507,1,0,0,0,395, + 1509,1,0,0,0,397,1511,1,0,0,0,399,1513,1,0,0,0,401,1515,1,0,0,0,403,1517, + 1,0,0,0,405,1519,1,0,0,0,407,1521,1,0,0,0,409,1523,1,0,0,0,411,1525,1, + 0,0,0,413,414,5,59,0,0,414,2,1,0,0,0,415,416,5,40,0,0,416,4,1,0,0,0,417, + 418,5,41,0,0,418,6,1,0,0,0,419,420,5,44,0,0,420,8,1,0,0,0,421,422,5,46, + 0,0,422,10,1,0,0,0,423,424,5,61,0,0,424,12,1,0,0,0,425,426,5,91,0,0,426, + 14,1,0,0,0,427,428,5,93,0,0,428,16,1,0,0,0,429,430,5,123,0,0,430,18,1, + 0,0,0,431,432,5,125,0,0,432,20,1,0,0,0,433,434,5,124,0,0,434,22,1,0,0, + 0,435,436,5,60,0,0,436,437,5,62,0,0,437,24,1,0,0,0,438,439,5,60,0,0,439, + 26,1,0,0,0,440,441,5,60,0,0,441,442,5,61,0,0,442,28,1,0,0,0,443,444,5, + 62,0,0,444,30,1,0,0,0,445,446,5,62,0,0,446,447,5,61,0,0,447,32,1,0,0, + 0,448,449,5,38,0,0,449,34,1,0,0,0,450,451,5,62,0,0,451,452,5,62,0,0,452, + 36,1,0,0,0,453,454,5,60,0,0,454,455,5,60,0,0,455,38,1,0,0,0,456,457,5, + 43,0,0,457,40,1,0,0,0,458,459,5,47,0,0,459,42,1,0,0,0,460,461,5,37,0, + 0,461,44,1,0,0,0,462,463,5,94,0,0,463,46,1,0,0,0,464,465,5,61,0,0,465, + 466,5,126,0,0,466,48,1,0,0,0,467,468,5,36,0,0,468,50,1,0,0,0,469,470, + 5,10216,0,0,470,52,1,0,0,0,471,472,5,12296,0,0,472,54,1,0,0,0,473,474, + 5,65124,0,0,474,56,1,0,0,0,475,476,5,65308,0,0,476,58,1,0,0,0,477,478, + 5,10217,0,0,478,60,1,0,0,0,479,480,5,12297,0,0,480,62,1,0,0,0,481,482, + 
5,65125,0,0,482,64,1,0,0,0,483,484,5,65310,0,0,484,66,1,0,0,0,485,486, + 5,173,0,0,486,68,1,0,0,0,487,488,5,8208,0,0,488,70,1,0,0,0,489,490,5, + 8209,0,0,490,72,1,0,0,0,491,492,5,8210,0,0,492,74,1,0,0,0,493,494,5,8211, + 0,0,494,76,1,0,0,0,495,496,5,8212,0,0,496,78,1,0,0,0,497,498,5,8213,0, + 0,498,80,1,0,0,0,499,500,5,8722,0,0,500,82,1,0,0,0,501,502,5,65112,0, + 0,502,84,1,0,0,0,503,504,5,65123,0,0,504,86,1,0,0,0,505,506,5,65293,0, + 0,506,88,1,0,0,0,507,508,7,0,0,0,508,509,7,1,0,0,509,510,7,2,0,0,510, + 511,7,1,0,0,511,512,7,3,0,0,512,513,7,4,0,0,513,514,7,1,0,0,514,90,1, + 0,0,0,515,516,7,0,0,0,516,517,7,5,0,0,517,518,7,2,0,0,518,92,1,0,0,0, + 519,520,7,0,0,0,520,521,7,6,0,0,521,522,7,6,0,0,522,94,1,0,0,0,523,524, + 7,0,0,0,524,525,7,3,0,0,525,526,7,3,0,0,526,96,1,0,0,0,527,528,7,0,0, + 0,528,529,7,3,0,0,529,530,7,7,0,0,530,531,7,8,0,0,531,532,7,9,0,0,532, + 98,1,0,0,0,533,534,7,0,0,0,534,535,7,5,0,0,535,536,7,6,0,0,536,100,1, + 0,0,0,537,538,7,0,0,0,538,539,7,10,0,0,539,102,1,0,0,0,540,541,7,0,0, + 0,541,542,7,10,0,0,542,543,7,1,0,0,543,104,1,0,0,0,544,545,7,0,0,0,545, + 546,7,10,0,0,546,547,7,1,0,0,547,548,7,8,0,0,548,549,7,5,0,0,549,550, + 7,6,0,0,550,551,7,4,0,0,551,552,7,5,0,0,552,553,7,11,0,0,553,106,1,0, + 0,0,554,555,7,0,0,0,555,556,7,7,0,0,556,557,7,7,0,0,557,558,7,0,0,0,558, + 559,7,1,0,0,559,560,7,12,0,0,560,108,1,0,0,0,561,562,7,13,0,0,562,563, + 7,8,0,0,563,564,7,11,0,0,564,565,7,4,0,0,565,566,7,5,0,0,566,110,1,0, + 0,0,567,568,7,13,0,0,568,569,7,2,0,0,569,112,1,0,0,0,570,571,7,1,0,0, + 571,572,7,0,0,0,572,573,7,3,0,0,573,574,7,3,0,0,574,114,1,0,0,0,575,576, + 7,1,0,0,576,577,7,0,0,0,577,578,7,10,0,0,578,579,7,8,0,0,579,116,1,0, + 0,0,580,581,7,1,0,0,581,582,7,0,0,0,582,583,7,10,0,0,583,584,7,7,0,0, + 584,118,1,0,0,0,585,586,7,1,0,0,586,587,7,12,0,0,587,588,7,8,0,0,588, + 589,7,1,0,0,589,590,7,14,0,0,590,591,7,15,0,0,591,592,7,16,0,0,592,593, + 7,4,0,0,593,594,7,5,0,0,594,595,7,7,0,0,595,120,1,0,0,0,596,597,7,1,0, + 0,597,598,7,16,0,0,598,599,7,3,0,0,599,600,7,17,0,0,600,601,7,18,0,0, + 601,602,7,5,0,0,602,122,1,0,0,0,603,604,7,1,0,0,604,605,7,16,0,0,605, + 606,7,18,0,0,606,607,7,18,0,0,607,608,7,8,0,0,608,609,7,5,0,0,609,610, + 7,7,0,0,610,124,1,0,0,0,611,612,7,1,0,0,612,613,7,16,0,0,613,614,7,18, + 0,0,614,615,7,18,0,0,615,616,7,4,0,0,616,617,7,7,0,0,617,126,1,0,0,0, + 618,619,7,1,0,0,619,620,7,16,0,0,620,621,7,18,0,0,621,622,7,18,0,0,622, + 623,7,4,0,0,623,624,7,7,0,0,624,625,5,95,0,0,625,626,7,10,0,0,626,627, + 7,14,0,0,627,628,7,4,0,0,628,629,7,15,0,0,629,630,5,95,0,0,630,631,7, + 1,0,0,631,632,7,12,0,0,632,633,7,8,0,0,633,634,7,1,0,0,634,635,7,14,0, + 0,635,636,7,15,0,0,636,637,7,16,0,0,637,638,7,4,0,0,638,639,7,5,0,0,639, + 640,7,7,0,0,640,128,1,0,0,0,641,642,7,1,0,0,642,643,7,16,0,0,643,644, + 7,5,0,0,644,645,7,7,0,0,645,646,7,0,0,0,646,647,7,4,0,0,647,648,7,5,0, + 0,648,649,7,10,0,0,649,130,1,0,0,0,650,651,7,1,0,0,651,652,7,16,0,0,652, + 653,7,15,0,0,653,654,7,2,0,0,654,132,1,0,0,0,655,656,7,1,0,0,656,657, + 7,16,0,0,657,658,7,17,0,0,658,659,7,5,0,0,659,660,7,7,0,0,660,134,1,0, + 0,0,661,662,7,1,0,0,662,663,7,9,0,0,663,664,7,8,0,0,664,665,7,0,0,0,665, + 666,7,7,0,0,666,667,7,8,0,0,667,136,1,0,0,0,668,669,7,1,0,0,669,670,7, + 2,0,0,670,671,7,1,0,0,671,672,7,3,0,0,672,673,7,8,0,0,673,138,1,0,0,0, + 674,675,7,6,0,0,675,676,7,0,0,0,676,677,7,7,0,0,677,678,7,0,0,0,678,679, + 7,13,0,0,679,680,7,0,0,0,680,681,7,10,0,0,681,682,7,8,0,0,682,140,1,0, + 0,0,683,684,7,6,0,0,684,685,7,13,0,0,685,686,7,7,0,0,686,687,7,2,0,0, + 
687,688,7,15,0,0,688,689,7,8,0,0,689,142,1,0,0,0,690,691,7,6,0,0,691, + 692,7,8,0,0,692,693,7,19,0,0,693,694,7,0,0,0,694,695,7,17,0,0,695,696, + 7,3,0,0,696,697,7,7,0,0,697,144,1,0,0,0,698,699,7,6,0,0,699,700,7,8,0, + 0,700,701,7,3,0,0,701,702,7,8,0,0,702,703,7,7,0,0,703,704,7,8,0,0,704, + 146,1,0,0,0,705,706,7,6,0,0,706,707,7,8,0,0,707,708,7,10,0,0,708,709, + 7,1,0,0,709,148,1,0,0,0,710,711,7,6,0,0,711,712,7,8,0,0,712,713,7,10, + 0,0,713,714,7,1,0,0,714,715,7,8,0,0,715,716,7,5,0,0,716,717,7,6,0,0,717, + 718,7,4,0,0,718,719,7,5,0,0,719,720,7,11,0,0,720,150,1,0,0,0,721,722, + 7,6,0,0,722,723,7,8,0,0,723,724,7,7,0,0,724,725,7,0,0,0,725,726,7,1,0, + 0,726,727,7,12,0,0,727,152,1,0,0,0,728,729,7,6,0,0,729,730,7,4,0,0,730, + 731,7,10,0,0,731,732,7,7,0,0,732,733,7,4,0,0,733,734,7,5,0,0,734,735, + 7,1,0,0,735,736,7,7,0,0,736,154,1,0,0,0,737,738,7,6,0,0,738,739,7,9,0, + 0,739,740,7,16,0,0,740,741,7,15,0,0,741,156,1,0,0,0,742,743,7,8,0,0,743, + 744,7,3,0,0,744,745,7,10,0,0,745,746,7,8,0,0,746,158,1,0,0,0,747,748, + 7,8,0,0,748,749,7,5,0,0,749,750,7,6,0,0,750,160,1,0,0,0,751,752,7,8,0, + 0,752,753,7,5,0,0,753,754,7,6,0,0,754,755,7,10,0,0,755,162,1,0,0,0,756, + 757,7,8,0,0,757,758,7,20,0,0,758,759,7,4,0,0,759,760,7,10,0,0,760,761, + 7,7,0,0,761,762,7,10,0,0,762,164,1,0,0,0,763,764,7,8,0,0,764,765,7,20, + 0,0,765,766,7,15,0,0,766,767,7,3,0,0,767,768,7,0,0,0,768,769,7,4,0,0, + 769,770,7,5,0,0,770,166,1,0,0,0,771,772,7,8,0,0,772,773,7,20,0,0,773, + 774,7,15,0,0,774,775,7,16,0,0,775,776,7,9,0,0,776,777,7,7,0,0,777,168, + 1,0,0,0,778,779,7,8,0,0,779,780,7,20,0,0,780,781,7,7,0,0,781,782,7,8, + 0,0,782,783,7,5,0,0,783,784,7,10,0,0,784,785,7,4,0,0,785,786,7,16,0,0, + 786,787,7,5,0,0,787,170,1,0,0,0,788,789,7,19,0,0,789,790,7,0,0,0,790, + 791,7,3,0,0,791,792,7,10,0,0,792,793,7,8,0,0,793,172,1,0,0,0,794,795, + 7,19,0,0,795,796,7,9,0,0,796,797,7,16,0,0,797,798,7,18,0,0,798,174,1, + 0,0,0,799,800,7,19,0,0,800,801,7,16,0,0,801,802,7,9,0,0,802,803,7,1,0, + 0,803,804,7,8,0,0,804,176,1,0,0,0,805,806,7,11,0,0,806,807,7,3,0,0,807, + 808,7,16,0,0,808,809,7,13,0,0,809,178,1,0,0,0,810,811,7,11,0,0,811,812, + 7,9,0,0,812,813,7,0,0,0,813,814,7,15,0,0,814,815,7,12,0,0,815,180,1,0, + 0,0,816,817,7,11,0,0,817,818,7,9,0,0,818,819,7,16,0,0,819,820,7,17,0, + 0,820,821,7,15,0,0,821,182,1,0,0,0,822,823,7,12,0,0,823,824,7,8,0,0,824, + 825,7,0,0,0,825,826,7,6,0,0,826,827,7,8,0,0,827,828,7,9,0,0,828,829,7, + 10,0,0,829,184,1,0,0,0,830,831,7,12,0,0,831,832,7,4,0,0,832,833,7,5,0, + 0,833,834,7,7,0,0,834,186,1,0,0,0,835,836,7,4,0,0,836,837,7,18,0,0,837, + 838,7,15,0,0,838,839,7,16,0,0,839,840,7,9,0,0,840,841,7,7,0,0,841,188, + 1,0,0,0,842,843,7,4,0,0,843,844,7,19,0,0,844,190,1,0,0,0,845,846,7,4, + 0,0,846,847,7,5,0,0,847,192,1,0,0,0,848,849,7,4,0,0,849,850,7,5,0,0,850, + 851,7,1,0,0,851,852,7,9,0,0,852,853,7,8,0,0,853,854,7,18,0,0,854,855, + 7,8,0,0,855,856,7,5,0,0,856,857,7,7,0,0,857,194,1,0,0,0,858,859,7,4,0, + 0,859,860,7,5,0,0,860,861,7,10,0,0,861,862,7,7,0,0,862,863,7,0,0,0,863, + 864,7,3,0,0,864,865,7,3,0,0,865,196,1,0,0,0,866,867,7,4,0,0,867,868,7, + 10,0,0,868,198,1,0,0,0,869,870,7,21,0,0,870,871,7,16,0,0,871,872,7,4, + 0,0,872,873,7,5,0,0,873,200,1,0,0,0,874,875,7,14,0,0,875,876,7,8,0,0, + 876,877,7,2,0,0,877,202,1,0,0,0,878,879,7,3,0,0,879,880,7,4,0,0,880,881, + 7,18,0,0,881,882,7,4,0,0,882,883,7,7,0,0,883,204,1,0,0,0,884,885,7,3, + 0,0,885,886,7,16,0,0,886,887,7,0,0,0,887,888,7,6,0,0,888,206,1,0,0,0, + 889,890,7,3,0,0,890,891,7,16,0,0,891,892,7,11,0,0,892,893,7,4,0,0,893, + 
894,7,1,0,0,894,895,7,0,0,0,895,896,7,3,0,0,896,208,1,0,0,0,897,898,7, + 18,0,0,898,899,7,0,0,0,899,900,7,1,0,0,900,901,7,9,0,0,901,902,7,16,0, + 0,902,210,1,0,0,0,903,904,7,18,0,0,904,905,7,0,0,0,905,906,7,7,0,0,906, + 907,7,1,0,0,907,908,7,12,0,0,908,212,1,0,0,0,909,910,7,18,0,0,910,911, + 7,0,0,0,911,912,7,20,0,0,912,913,7,22,0,0,913,914,7,0,0,0,914,915,7,3, + 0,0,915,916,7,17,0,0,916,917,7,8,0,0,917,214,1,0,0,0,918,919,7,18,0,0, + 919,920,7,8,0,0,920,921,7,9,0,0,921,922,7,11,0,0,922,923,7,8,0,0,923, + 216,1,0,0,0,924,925,7,18,0,0,925,926,7,4,0,0,926,927,7,5,0,0,927,928, + 7,22,0,0,928,929,7,0,0,0,929,930,7,3,0,0,930,931,7,17,0,0,931,932,7,8, + 0,0,932,218,1,0,0,0,933,934,7,18,0,0,934,935,7,17,0,0,935,936,7,3,0,0, + 936,937,7,7,0,0,937,938,7,4,0,0,938,939,5,95,0,0,939,940,7,21,0,0,940, + 941,7,16,0,0,941,942,7,4,0,0,942,943,7,5,0,0,943,220,1,0,0,0,944,945, + 7,5,0,0,945,946,7,16,0,0,946,222,1,0,0,0,947,948,7,5,0,0,948,949,7,16, + 0,0,949,950,7,6,0,0,950,951,7,8,0,0,951,224,1,0,0,0,952,953,7,5,0,0,953, + 954,7,16,0,0,954,955,7,7,0,0,955,226,1,0,0,0,956,957,7,5,0,0,957,958, + 7,16,0,0,958,959,7,5,0,0,959,960,7,8,0,0,960,228,1,0,0,0,961,962,7,5, + 0,0,962,963,7,17,0,0,963,964,7,3,0,0,964,965,7,3,0,0,965,230,1,0,0,0, + 966,967,7,16,0,0,967,968,7,5,0,0,968,232,1,0,0,0,969,970,7,16,0,0,970, + 971,7,5,0,0,971,972,7,3,0,0,972,973,7,2,0,0,973,234,1,0,0,0,974,975,7, + 16,0,0,975,976,7,15,0,0,976,977,7,7,0,0,977,978,7,4,0,0,978,979,7,16, + 0,0,979,980,7,5,0,0,980,981,7,0,0,0,981,982,7,3,0,0,982,236,1,0,0,0,983, + 984,7,16,0,0,984,985,7,9,0,0,985,238,1,0,0,0,986,987,7,16,0,0,987,988, + 7,9,0,0,988,989,7,6,0,0,989,990,7,8,0,0,990,991,7,9,0,0,991,240,1,0,0, + 0,992,993,7,15,0,0,993,994,7,9,0,0,994,995,7,4,0,0,995,996,7,18,0,0,996, + 997,7,0,0,0,997,998,7,9,0,0,998,999,7,2,0,0,999,242,1,0,0,0,1000,1001, + 7,15,0,0,1001,1002,7,9,0,0,1002,1003,7,16,0,0,1003,1004,7,19,0,0,1004, + 1005,7,4,0,0,1005,1006,7,3,0,0,1006,1007,7,8,0,0,1007,244,1,0,0,0,1008, + 1009,7,15,0,0,1009,1010,7,9,0,0,1010,1011,7,16,0,0,1011,1012,7,21,0,0, + 1012,1013,7,8,0,0,1013,1014,7,1,0,0,1014,1015,7,7,0,0,1015,246,1,0,0, + 0,1016,1017,7,9,0,0,1017,1018,7,8,0,0,1018,1019,7,0,0,0,1019,1020,7,6, + 0,0,1020,248,1,0,0,0,1021,1022,7,9,0,0,1022,1023,7,8,0,0,1023,1024,7, + 3,0,0,1024,250,1,0,0,0,1025,1026,7,9,0,0,1026,1027,7,8,0,0,1027,1028, + 7,5,0,0,1028,1029,7,0,0,0,1029,1030,7,18,0,0,1030,1031,7,8,0,0,1031,252, + 1,0,0,0,1032,1033,7,9,0,0,1033,1034,7,8,0,0,1034,1035,7,7,0,0,1035,1036, + 7,17,0,0,1036,1037,7,9,0,0,1037,1038,7,5,0,0,1038,254,1,0,0,0,1039,1040, + 7,9,0,0,1040,1041,7,16,0,0,1041,1042,7,3,0,0,1042,1043,7,3,0,0,1043,1044, + 7,13,0,0,1044,1045,7,0,0,0,1045,1046,7,1,0,0,1046,1047,7,14,0,0,1047, + 256,1,0,0,0,1048,1049,7,9,0,0,1049,1050,7,16,0,0,1050,1051,7,3,0,0,1051, + 1052,7,3,0,0,1052,1053,7,13,0,0,1053,1054,7,0,0,0,1054,1055,7,1,0,0,1055, + 1056,7,14,0,0,1056,1057,5,95,0,0,1057,1058,7,10,0,0,1058,1059,7,14,0, + 0,1059,1060,7,4,0,0,1060,1061,7,15,0,0,1061,1062,5,95,0,0,1062,1063,7, + 1,0,0,1063,1064,7,12,0,0,1064,1065,7,8,0,0,1065,1066,7,1,0,0,1066,1067, + 7,14,0,0,1067,1068,7,15,0,0,1068,1069,7,16,0,0,1069,1070,7,4,0,0,1070, + 1071,7,5,0,0,1071,1072,7,7,0,0,1072,258,1,0,0,0,1073,1074,7,10,0,0,1074, + 1075,7,8,0,0,1075,1076,7,23,0,0,1076,1077,7,17,0,0,1077,1078,7,8,0,0, + 1078,1079,7,5,0,0,1079,1080,7,1,0,0,1080,1081,7,8,0,0,1081,260,1,0,0, + 0,1082,1083,7,10,0,0,1083,1084,7,8,0,0,1084,1085,7,7,0,0,1085,262,1,0, + 0,0,1086,1087,7,10,0,0,1087,1088,7,12,0,0,1088,1089,7,16,0,0,1089,1090, + 
7,9,0,0,1090,1091,7,7,0,0,1091,1092,7,8,0,0,1092,1093,7,10,0,0,1093,1094, + 7,7,0,0,1094,264,1,0,0,0,1095,1096,7,10,0,0,1096,1097,7,7,0,0,1097,1098, + 7,0,0,0,1098,1099,7,9,0,0,1099,1100,7,7,0,0,1100,266,1,0,0,0,1101,1102, + 7,10,0,0,1102,1103,7,7,0,0,1103,1104,7,0,0,0,1104,1105,7,9,0,0,1105,1106, + 7,7,0,0,1106,1107,7,10,0,0,1107,268,1,0,0,0,1108,1109,7,10,0,0,1109,1110, + 7,7,0,0,1110,1111,7,9,0,0,1111,1112,7,17,0,0,1112,1113,7,1,0,0,1113,1114, + 7,7,0,0,1114,270,1,0,0,0,1115,1116,7,7,0,0,1116,1117,7,0,0,0,1117,1118, + 7,13,0,0,1118,1119,7,3,0,0,1119,1120,7,8,0,0,1120,272,1,0,0,0,1121,1122, + 7,7,0,0,1122,1123,7,12,0,0,1123,1124,7,8,0,0,1124,1125,7,5,0,0,1125,274, + 1,0,0,0,1126,1127,7,7,0,0,1127,1128,7,16,0,0,1128,276,1,0,0,0,1129,1130, + 7,7,0,0,1130,1131,7,9,0,0,1131,1132,7,0,0,0,1132,1133,7,4,0,0,1133,1134, + 7,3,0,0,1134,278,1,0,0,0,1135,1136,7,7,0,0,1136,1137,7,9,0,0,1137,1138, + 7,0,0,0,1138,1139,7,5,0,0,1139,1140,7,10,0,0,1140,1141,7,0,0,0,1141,1142, + 7,1,0,0,1142,1143,7,7,0,0,1143,1144,7,4,0,0,1144,1145,7,16,0,0,1145,1146, + 7,5,0,0,1146,280,1,0,0,0,1147,1148,7,7,0,0,1148,1149,7,9,0,0,1149,1150, + 7,17,0,0,1150,1151,7,8,0,0,1151,282,1,0,0,0,1152,1153,7,7,0,0,1153,1154, + 7,2,0,0,1154,1155,7,15,0,0,1155,1156,7,8,0,0,1156,284,1,0,0,0,1157,1158, + 7,17,0,0,1158,1159,7,5,0,0,1159,1160,7,4,0,0,1160,1161,7,16,0,0,1161, + 1162,7,5,0,0,1162,286,1,0,0,0,1163,1164,7,17,0,0,1164,1165,7,5,0,0,1165, + 1166,7,24,0,0,1166,1167,7,4,0,0,1167,1168,7,5,0,0,1168,1169,7,6,0,0,1169, + 288,1,0,0,0,1170,1171,7,17,0,0,1171,1172,7,5,0,0,1172,1173,7,4,0,0,1173, + 1174,7,5,0,0,1174,1175,7,10,0,0,1175,1176,7,7,0,0,1176,1177,7,0,0,0,1177, + 1178,7,3,0,0,1178,1179,7,3,0,0,1179,290,1,0,0,0,1180,1181,7,17,0,0,1181, + 1182,7,15,0,0,1182,1183,7,6,0,0,1183,1184,7,0,0,0,1184,1185,7,7,0,0,1185, + 1186,7,8,0,0,1186,292,1,0,0,0,1187,1188,7,17,0,0,1188,1189,7,10,0,0,1189, + 1190,7,8,0,0,1190,294,1,0,0,0,1191,1192,7,24,0,0,1192,1193,7,12,0,0,1193, + 1194,7,8,0,0,1194,1195,7,5,0,0,1195,296,1,0,0,0,1196,1197,7,24,0,0,1197, + 1198,7,12,0,0,1198,1199,7,8,0,0,1199,1200,7,9,0,0,1200,1201,7,8,0,0,1201, + 298,1,0,0,0,1202,1203,7,24,0,0,1203,1204,7,4,0,0,1204,1205,7,7,0,0,1205, + 1206,7,12,0,0,1206,300,1,0,0,0,1207,1208,7,24,0,0,1208,1209,7,9,0,0,1209, + 1210,7,4,0,0,1210,1211,7,7,0,0,1211,1212,7,8,0,0,1212,302,1,0,0,0,1213, + 1214,7,24,0,0,1214,1215,7,10,0,0,1215,1216,7,12,0,0,1216,1217,7,16,0, + 0,1217,1218,7,9,0,0,1218,1219,7,7,0,0,1219,1220,7,8,0,0,1220,1221,7,10, + 0,0,1221,1222,7,7,0,0,1222,304,1,0,0,0,1223,1224,7,20,0,0,1224,1225,7, + 16,0,0,1225,1226,7,9,0,0,1226,306,1,0,0,0,1227,1228,7,10,0,0,1228,1229, + 7,4,0,0,1229,1230,7,5,0,0,1230,1231,7,11,0,0,1231,1232,7,3,0,0,1232,1233, + 7,8,0,0,1233,308,1,0,0,0,1234,1235,7,2,0,0,1235,1236,7,4,0,0,1236,1237, + 7,8,0,0,1237,1238,7,3,0,0,1238,1239,7,6,0,0,1239,310,1,0,0,0,1240,1241, + 7,17,0,0,1241,1242,7,10,0,0,1242,1243,7,8,0,0,1243,1244,7,9,0,0,1244, + 312,1,0,0,0,1245,1246,7,15,0,0,1246,1247,7,0,0,0,1247,1248,7,10,0,0,1248, + 1249,7,10,0,0,1249,1250,7,24,0,0,1250,1251,7,16,0,0,1251,1252,7,9,0,0, + 1252,1253,7,6,0,0,1253,314,1,0,0,0,1254,1255,7,9,0,0,1255,1256,7,16,0, + 0,1256,1257,7,3,0,0,1257,1258,7,8,0,0,1258,316,1,0,0,0,1259,1260,7,18, + 0,0,1260,1261,7,0,0,0,1261,1262,7,15,0,0,1262,318,1,0,0,0,1263,1264,7, + 6,0,0,1264,1265,7,8,0,0,1265,1266,7,1,0,0,1266,1267,7,4,0,0,1267,1268, + 7,18,0,0,1268,1269,7,0,0,0,1269,1270,7,3,0,0,1270,320,1,0,0,0,1271,1272, + 5,42,0,0,1272,322,1,0,0,0,1273,1274,7,10,0,0,1274,1275,7,14,0,0,1275, + 
1276,7,4,0,0,1276,1277,7,15,0,0,1277,324,1,0,0,0,1278,1279,5,33,0,0,1279, + 1280,5,61,0,0,1280,326,1,0,0,0,1281,1282,5,58,0,0,1282,328,1,0,0,0,1283, + 1284,5,46,0,0,1284,1285,5,46,0,0,1285,330,1,0,0,0,1286,1287,5,45,0,0, + 1287,332,1,0,0,0,1288,1289,5,33,0,0,1289,334,1,0,0,0,1290,1295,5,34,0, + 0,1291,1294,3,401,200,0,1292,1294,3,337,168,0,1293,1291,1,0,0,0,1293, + 1292,1,0,0,0,1294,1297,1,0,0,0,1295,1293,1,0,0,0,1295,1296,1,0,0,0,1296, + 1298,1,0,0,0,1297,1295,1,0,0,0,1298,1309,5,34,0,0,1299,1304,5,39,0,0, + 1300,1303,3,381,190,0,1301,1303,3,337,168,0,1302,1300,1,0,0,0,1302,1301, + 1,0,0,0,1303,1306,1,0,0,0,1304,1302,1,0,0,0,1304,1305,1,0,0,0,1305,1307, + 1,0,0,0,1306,1304,1,0,0,0,1307,1309,5,39,0,0,1308,1290,1,0,0,0,1308,1299, + 1,0,0,0,1309,336,1,0,0,0,1310,1332,5,92,0,0,1311,1333,7,25,0,0,1312,1313, + 7,20,0,0,1313,1314,3,343,171,0,1314,1315,3,343,171,0,1315,1333,1,0,0, + 0,1316,1317,7,17,0,0,1317,1318,3,343,171,0,1318,1319,3,343,171,0,1319, + 1320,3,343,171,0,1320,1321,3,343,171,0,1321,1333,1,0,0,0,1322,1323,7, + 17,0,0,1323,1324,3,343,171,0,1324,1325,3,343,171,0,1325,1326,3,343,171, + 0,1326,1327,3,343,171,0,1327,1328,3,343,171,0,1328,1329,3,343,171,0,1329, + 1330,3,343,171,0,1330,1331,3,343,171,0,1331,1333,1,0,0,0,1332,1311,1, + 0,0,0,1332,1312,1,0,0,0,1332,1316,1,0,0,0,1332,1322,1,0,0,0,1333,338, + 1,0,0,0,1334,1343,3,351,175,0,1335,1339,3,347,173,0,1336,1338,3,345,172, + 0,1337,1336,1,0,0,0,1338,1341,1,0,0,0,1339,1337,1,0,0,0,1339,1340,1,0, + 0,0,1340,1343,1,0,0,0,1341,1339,1,0,0,0,1342,1334,1,0,0,0,1342,1335,1, + 0,0,0,1343,340,1,0,0,0,1344,1346,7,26,0,0,1345,1344,1,0,0,0,1346,342, + 1,0,0,0,1347,1350,3,345,172,0,1348,1350,3,341,170,0,1349,1347,1,0,0,0, + 1349,1348,1,0,0,0,1350,344,1,0,0,0,1351,1354,3,351,175,0,1352,1354,3, + 347,173,0,1353,1351,1,0,0,0,1353,1352,1,0,0,0,1354,346,1,0,0,0,1355,1358, + 3,349,174,0,1356,1358,2,56,57,0,1357,1355,1,0,0,0,1357,1356,1,0,0,0,1358, + 348,1,0,0,0,1359,1360,2,49,55,0,1360,350,1,0,0,0,1361,1362,5,48,0,0,1362, + 352,1,0,0,0,1363,1365,3,345,172,0,1364,1363,1,0,0,0,1365,1366,1,0,0,0, + 1366,1364,1,0,0,0,1366,1367,1,0,0,0,1367,1386,1,0,0,0,1368,1370,3,345, + 172,0,1369,1368,1,0,0,0,1370,1371,1,0,0,0,1371,1369,1,0,0,0,1371,1372, + 1,0,0,0,1372,1373,1,0,0,0,1373,1375,5,46,0,0,1374,1376,3,345,172,0,1375, + 1374,1,0,0,0,1376,1377,1,0,0,0,1377,1375,1,0,0,0,1377,1378,1,0,0,0,1378, + 1386,1,0,0,0,1379,1381,5,46,0,0,1380,1382,3,345,172,0,1381,1380,1,0,0, + 0,1382,1383,1,0,0,0,1383,1381,1,0,0,0,1383,1384,1,0,0,0,1384,1386,1,0, + 0,0,1385,1364,1,0,0,0,1385,1369,1,0,0,0,1385,1379,1,0,0,0,1386,1387,1, + 0,0,0,1387,1389,7,8,0,0,1388,1390,5,45,0,0,1389,1388,1,0,0,0,1389,1390, + 1,0,0,0,1390,1392,1,0,0,0,1391,1393,3,345,172,0,1392,1391,1,0,0,0,1393, + 1394,1,0,0,0,1394,1392,1,0,0,0,1394,1395,1,0,0,0,1395,354,1,0,0,0,1396, + 1398,3,345,172,0,1397,1396,1,0,0,0,1398,1401,1,0,0,0,1399,1397,1,0,0, + 0,1399,1400,1,0,0,0,1400,1402,1,0,0,0,1401,1399,1,0,0,0,1402,1404,5,46, + 0,0,1403,1405,3,345,172,0,1404,1403,1,0,0,0,1405,1406,1,0,0,0,1406,1404, + 1,0,0,0,1406,1407,1,0,0,0,1407,356,1,0,0,0,1408,1412,3,359,179,0,1409, + 1411,3,361,180,0,1410,1409,1,0,0,0,1411,1414,1,0,0,0,1412,1410,1,0,0, + 0,1412,1413,1,0,0,0,1413,358,1,0,0,0,1414,1412,1,0,0,0,1415,1418,3,409, + 204,0,1416,1418,3,397,198,0,1417,1415,1,0,0,0,1417,1416,1,0,0,0,1418, + 360,1,0,0,0,1419,1422,3,377,188,0,1420,1422,3,393,196,0,1421,1419,1,0, + 0,0,1421,1420,1,0,0,0,1422,362,1,0,0,0,1423,1427,5,96,0,0,1424,1426,3, + 373,186,0,1425,1424,1,0,0,0,1426,1429,1,0,0,0,1427,1425,1,0,0,0,1427, + 
1428,1,0,0,0,1428,1430,1,0,0,0,1429,1427,1,0,0,0,1430,1432,5,96,0,0,1431, + 1423,1,0,0,0,1432,1433,1,0,0,0,1433,1431,1,0,0,0,1433,1434,1,0,0,0,1434, + 364,1,0,0,0,1435,1437,3,367,183,0,1436,1435,1,0,0,0,1437,1438,1,0,0,0, + 1438,1436,1,0,0,0,1438,1439,1,0,0,0,1439,366,1,0,0,0,1440,1453,3,395, + 197,0,1441,1453,3,399,199,0,1442,1453,3,403,201,0,1443,1453,3,405,202, + 0,1444,1453,3,371,185,0,1445,1453,3,391,195,0,1446,1453,3,389,194,0,1447, + 1453,3,387,193,0,1448,1453,3,375,187,0,1449,1453,3,407,203,0,1450,1453, + 7,27,0,0,1451,1453,3,369,184,0,1452,1440,1,0,0,0,1452,1441,1,0,0,0,1452, + 1442,1,0,0,0,1452,1443,1,0,0,0,1452,1444,1,0,0,0,1452,1445,1,0,0,0,1452, + 1446,1,0,0,0,1452,1447,1,0,0,0,1452,1448,1,0,0,0,1452,1449,1,0,0,0,1452, + 1450,1,0,0,0,1452,1451,1,0,0,0,1453,368,1,0,0,0,1454,1455,5,47,0,0,1455, + 1456,5,42,0,0,1456,1462,1,0,0,0,1457,1461,3,379,189,0,1458,1459,5,42, + 0,0,1459,1461,3,385,192,0,1460,1457,1,0,0,0,1460,1458,1,0,0,0,1461,1464, + 1,0,0,0,1462,1460,1,0,0,0,1462,1463,1,0,0,0,1463,1465,1,0,0,0,1464,1462, + 1,0,0,0,1465,1466,5,42,0,0,1466,1484,5,47,0,0,1467,1468,5,47,0,0,1468, + 1469,5,47,0,0,1469,1473,1,0,0,0,1470,1472,3,383,191,0,1471,1470,1,0,0, + 0,1472,1475,1,0,0,0,1473,1471,1,0,0,0,1473,1474,1,0,0,0,1474,1477,1,0, + 0,0,1475,1473,1,0,0,0,1476,1478,3,391,195,0,1477,1476,1,0,0,0,1477,1478, + 1,0,0,0,1478,1481,1,0,0,0,1479,1482,3,403,201,0,1480,1482,5,0,0,1,1481, + 1479,1,0,0,0,1481,1480,1,0,0,0,1482,1484,1,0,0,0,1483,1454,1,0,0,0,1483, + 1467,1,0,0,0,1484,370,1,0,0,0,1485,1486,7,28,0,0,1486,372,1,0,0,0,1487, + 1488,8,29,0,0,1488,374,1,0,0,0,1489,1490,7,30,0,0,1490,376,1,0,0,0,1491, + 1492,7,31,0,0,1492,378,1,0,0,0,1493,1494,8,32,0,0,1494,380,1,0,0,0,1495, + 1496,8,33,0,0,1496,382,1,0,0,0,1497,1498,8,34,0,0,1498,384,1,0,0,0,1499, + 1500,8,35,0,0,1500,386,1,0,0,0,1501,1502,7,36,0,0,1502,388,1,0,0,0,1503, + 1504,7,37,0,0,1504,390,1,0,0,0,1505,1506,7,38,0,0,1506,392,1,0,0,0,1507, + 1508,7,39,0,0,1508,394,1,0,0,0,1509,1510,7,40,0,0,1510,396,1,0,0,0,1511, + 1512,7,41,0,0,1512,398,1,0,0,0,1513,1514,7,42,0,0,1514,400,1,0,0,0,1515, + 1516,8,43,0,0,1516,402,1,0,0,0,1517,1518,7,44,0,0,1518,404,1,0,0,0,1519, + 1520,7,45,0,0,1520,406,1,0,0,0,1521,1522,7,46,0,0,1522,408,1,0,0,0,1523, + 1524,7,47,0,0,1524,410,1,0,0,0,1525,1526,9,0,0,0,1526,412,1,0,0,0,35, + 0,1293,1295,1302,1304,1308,1332,1339,1342,1345,1349,1353,1357,1366,1371, + 1377,1383,1385,1389,1394,1399,1406,1412,1417,1421,1427,1433,1438,1452, + 1460,1462,1473,1477,1481,1483,0 + }; + staticData->serializedATN = antlr4::atn::SerializedATNView(serializedATNSegment, sizeof(serializedATNSegment) / sizeof(serializedATNSegment[0])); + + antlr4::atn::ATNDeserializer deserializer; + staticData->atn = deserializer.deserialize(staticData->serializedATN); + + const size_t count = staticData->atn->getNumberOfDecisions(); + staticData->decisionToDFA.reserve(count); + for (size_t i = 0; i < count; i++) { + staticData->decisionToDFA.emplace_back(staticData->atn->getDecisionState(i), i); + } + cypherlexerLexerStaticData = staticData.release(); +} + +} + +CypherLexer::CypherLexer(CharStream *input) : Lexer(input) { + CypherLexer::initialize(); + _interpreter = new atn::LexerATNSimulator(this, *cypherlexerLexerStaticData->atn, cypherlexerLexerStaticData->decisionToDFA, cypherlexerLexerStaticData->sharedContextCache); +} + +CypherLexer::~CypherLexer() { + delete _interpreter; +} + +std::string CypherLexer::getGrammarFileName() const { + return "Cypher.g4"; +} + +const std::vector& CypherLexer::getRuleNames() const { + return 
cypherlexerLexerStaticData->ruleNames;
+}
+
+const std::vector<std::string>& CypherLexer::getChannelNames() const {
+ return cypherlexerLexerStaticData->channelNames;
+}
+
+const std::vector<std::string>& CypherLexer::getModeNames() const {
+ return cypherlexerLexerStaticData->modeNames;
+}
+
+const dfa::Vocabulary& CypherLexer::getVocabulary() const {
+ return cypherlexerLexerStaticData->vocabulary;
+}
+
+antlr4::atn::SerializedATNView CypherLexer::getSerializedATN() const {
+ return cypherlexerLexerStaticData->serializedATN;
+}
+
+const atn::ATN& CypherLexer::getATN() const {
+ return *cypherlexerLexerStaticData->atn;
+}
+
+
+
+
+void CypherLexer::initialize() {
+#if ANTLR4_USE_THREAD_LOCAL_CACHE
+ cypherlexerLexerInitialize();
+#else
+ ::antlr4::internal::call_once(cypherlexerLexerOnceFlag, cypherlexerLexerInitialize);
+#endif
+}
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_cypher/cypher_parser.cpp b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_cypher/cypher_parser.cpp
new file mode 100644
index 0000000000..1bff0da747
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_cypher/cypher_parser.cpp
@@ -0,0 +1,19407 @@
+
+// Generated from Cypher.g4 by ANTLR 4.13.1
+
+
+
+#include "cypher_parser.h"
+
+
+using namespace antlrcpp;
+
+using namespace antlr4;
+
+namespace {
+
+struct CypherParserStaticData final {
+ CypherParserStaticData(std::vector<std::string> ruleNames,
+ std::vector<std::string> literalNames,
+ std::vector<std::string> symbolicNames)
+ : ruleNames(std::move(ruleNames)), literalNames(std::move(literalNames)),
+ symbolicNames(std::move(symbolicNames)),
+ vocabulary(this->literalNames, this->symbolicNames) {}
+
+ CypherParserStaticData(const CypherParserStaticData&) = delete;
+ CypherParserStaticData(CypherParserStaticData&&) = delete;
+ CypherParserStaticData& operator=(const CypherParserStaticData&) = delete;
+ CypherParserStaticData& operator=(CypherParserStaticData&&) = delete;
+
+ std::vector<antlr4::dfa::DFA> decisionToDFA;
+ antlr4::atn::PredictionContextCache sharedContextCache;
+ const std::vector<std::string> ruleNames;
+ const std::vector<std::string> literalNames;
+ const std::vector<std::string> symbolicNames;
+ const antlr4::dfa::Vocabulary vocabulary;
+ antlr4::atn::SerializedATNView serializedATN;
+ std::unique_ptr<antlr4::atn::ATN> atn;
+};
+
+::antlr4::internal::OnceFlag cypherParserOnceFlag;
+#if ANTLR4_USE_THREAD_LOCAL_CACHE
+static thread_local
+#endif
+CypherParserStaticData *cypherParserStaticData = nullptr;
+
+void cypherParserInitialize() {
+#if ANTLR4_USE_THREAD_LOCAL_CACHE
+ if (cypherParserStaticData != nullptr) {
+ return;
+ }
+#else
+ assert(cypherParserStaticData == nullptr);
+#endif
+ auto staticData = std::make_unique<CypherParserStaticData>(
+ std::vector<std::string>{
+ "ku_Statements", "oC_Cypher", "oC_Statement", "kU_CopyFrom", "kU_ColumnNames",
+ "kU_ScanSource", "kU_CopyFromByColumn", "kU_CopyTO", "kU_ExportDatabase",
+ "kU_ImportDatabase", "kU_AttachDatabase", "kU_Option", "kU_Options",
+ "kU_DetachDatabase", "kU_UseDatabase", "kU_StandaloneCall", "kU_CommentOn",
+ "kU_CreateMacro", "kU_PositionalArgs", "kU_DefaultArg", "kU_FilePaths",
+ "kU_IfNotExists", "kU_CreateNodeTable", "kU_CreateRelTable", "kU_FromToConnections",
+ "kU_FromToConnection", "kU_CreateSequence", "kU_CreateType", "kU_SequenceOptions",
+ "kU_WithPasswd", "kU_CreateUser", "kU_CreateRole", "kU_IncrementBy",
+ "kU_MinValue", "kU_MaxValue", "kU_StartWith", "kU_Cycle", "kU_IfExists",
+ "kU_Drop", "kU_AlterTable", "kU_AlterOptions", "kU_AddProperty", "kU_Default",
+ "kU_DropProperty", "kU_RenameTable", "kU_RenameProperty", "kU_AddFromToConnection",
+ "kU_DropFromToConnection", "kU_ColumnDefinitions",
"kU_ColumnDefinition", + "kU_PropertyDefinitions", "kU_PropertyDefinition", "kU_CreateNodeConstraint", + "kU_UnionType", "kU_StructType", "kU_MapType", "kU_DecimalType", "kU_DataType", + "kU_ListIdentifiers", "kU_ListIdentifier", "oC_AnyCypherOption", "oC_Explain", + "oC_Profile", "kU_Transaction", "kU_Extension", "kU_LoadExtension", + "kU_InstallExtension", "kU_UninstallExtension", "kU_UpdateExtension", + "oC_Query", "oC_RegularQuery", "oC_Union", "oC_SingleQuery", "oC_SinglePartQuery", + "oC_MultiPartQuery", "kU_QueryPart", "oC_UpdatingClause", "oC_ReadingClause", + "kU_LoadFrom", "oC_YieldItem", "oC_YieldItems", "kU_InQueryCall", + "oC_Match", "kU_Hint", "kU_JoinNode", "oC_Unwind", "oC_Create", "oC_Merge", + "oC_MergeAction", "oC_Set", "oC_SetItem", "oC_Delete", "oC_With", + "oC_Return", "oC_ProjectionBody", "oC_ProjectionItems", "oC_ProjectionItem", + "oC_Order", "oC_Skip", "oC_Limit", "oC_SortItem", "oC_Where", "oC_Pattern", + "oC_PatternPart", "oC_AnonymousPatternPart", "oC_PatternElement", + "oC_NodePattern", "oC_PatternElementChain", "oC_RelationshipPattern", + "oC_RelationshipDetail", "kU_Properties", "oC_RelationshipTypes", + "oC_NodeLabels", "kU_RecursiveDetail", "kU_RecursiveType", "oC_RangeLiteral", + "kU_RecursiveComprehension", "kU_RecursiveProjectionItems", "oC_LowerBound", + "oC_UpperBound", "oC_LabelName", "oC_RelTypeName", "oC_Expression", + "oC_OrExpression", "oC_XorExpression", "oC_AndExpression", "oC_NotExpression", + "oC_ComparisonExpression", "kU_ComparisonOperator", "kU_BitwiseOrOperatorExpression", + "kU_BitwiseAndOperatorExpression", "kU_BitShiftOperatorExpression", + "kU_BitShiftOperator", "oC_AddOrSubtractExpression", "kU_AddOrSubtractOperator", + "oC_MultiplyDivideModuloExpression", "kU_MultiplyDivideModuloOperator", + "oC_PowerOfExpression", "oC_StringListNullOperatorExpression", "oC_ListOperatorExpression", + "oC_StringOperatorExpression", "oC_RegularExpression", "oC_NullOperatorExpression", + "oC_UnaryAddSubtractOrFactorialExpression", "oC_PropertyOrLabelsExpression", + "oC_Atom", "oC_Quantifier", "oC_FilterExpression", "oC_IdInColl", + "oC_Literal", "oC_BooleanLiteral", "oC_ListLiteral", "kU_ListEntry", + "kU_StructLiteral", "kU_StructField", "oC_ParenthesizedExpression", + "oC_FunctionInvocation", "oC_FunctionName", "kU_FunctionParameter", + "kU_LambdaParameter", "kU_LambdaVars", "oC_PathPatterns", "oC_ExistCountSubquery", + "oC_PropertyLookup", "oC_CaseExpression", "oC_CaseAlternative", "oC_Variable", + "oC_NumberLiteral", "oC_Parameter", "oC_PropertyExpression", "oC_PropertyKeyName", + "oC_IntegerLiteral", "oC_DoubleLiteral", "oC_SchemaName", "oC_SymbolicName", + "kU_NonReservedKeywords", "oC_LeftArrowHead", "oC_RightArrowHead", + "oC_Dash" + }, + std::vector{ + "", "';'", "'('", "')'", "','", "'.'", "'='", "'['", "']'", "'{'", + "'}'", "'|'", "'<>'", "'<'", "'<='", "'>'", "'>='", "'&'", "'>>'", + "'<<'", "'+'", "'/'", "'%'", "'^'", "'=~'", "'$'", "'\\u27E8'", "'\\u3008'", + "'\\uFE64'", "'\\uFF1C'", "'\\u27E9'", "'\\u3009'", "'\\uFE65'", "'\\uFF1E'", + "'\\u00AD'", "'\\u2010'", "'\\u2011'", "'\\u2012'", "'\\u2013'", "'\\u2014'", + "'\\u2015'", "'\\u2212'", "'\\uFE58'", "'\\uFE63'", "'\\uFF0D'", "", + "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", + "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", + "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", + "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", + "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", + "", "", 
"", "", "", "", "", "", "", "", "", "", "", "", "", "", "", + "", "", "", "", "", "", "", "", "", "", "", "", "", "'*'", "", "'!='", + "':'", "'..'", "'-'", "'!'", "", "", "", "", "", "", "", "", "'0'" + }, + std::vector{ + "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", + "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", + "", "", "", "", "", "", "", "", "", "", "", "ACYCLIC", "ANY", "ADD", + "ALL", "ALTER", "AND", "AS", "ASC", "ASCENDING", "ATTACH", "BEGIN", + "BY", "CALL", "CASE", "CAST", "CHECKPOINT", "COLUMN", "COMMENT", "COMMIT", + "COMMIT_SKIP_CHECKPOINT", "CONTAINS", "COPY", "COUNT", "CREATE", "CYCLE", + "DATABASE", "DBTYPE", "DEFAULT", "DELETE", "DESC", "DESCENDING", "DETACH", + "DISTINCT", "DROP", "ELSE", "END", "ENDS", "EXISTS", "EXPLAIN", "EXPORT", + "EXTENSION", "FALSE", "FROM", "FORCE", "GLOB", "GRAPH", "GROUP", "HEADERS", + "HINT", "IMPORT", "IF", "IN", "INCREMENT", "INSTALL", "IS", "JOIN", + "KEY", "LIMIT", "LOAD", "LOGICAL", "MACRO", "MATCH", "MAXVALUE", "MERGE", + "MINVALUE", "MULTI_JOIN", "NO", "NODE", "NOT", "NONE", "NULL", "ON", + "ONLY", "OPTIONAL", "OR", "ORDER", "PRIMARY", "PROFILE", "PROJECT", + "READ", "REL", "RENAME", "RETURN", "ROLLBACK", "ROLLBACK_SKIP_CHECKPOINT", + "SEQUENCE", "SET", "SHORTEST", "START", "STARTS", "STRUCT", "TABLE", + "THEN", "TO", "TRAIL", "TRANSACTION", "TRUE", "TYPE", "UNION", "UNWIND", + "UNINSTALL", "UPDATE", "USE", "WHEN", "WHERE", "WITH", "WRITE", "WSHORTEST", + "XOR", "SINGLE", "YIELD", "USER", "PASSWORD", "ROLE", "MAP", "DECIMAL", + "STAR", "L_SKIP", "INVALID_NOT_EQUAL", "COLON", "DOTDOT", "MINUS", + "FACTORIAL", "StringLiteral", "EscapedChar", "DecimalInteger", "HexLetter", + "HexDigit", "Digit", "NonZeroDigit", "NonZeroOctDigit", "ZeroDigit", + "ExponentDecimalReal", "RegularDecimalReal", "UnescapedSymbolicName", + "IdentifierStart", "IdentifierPart", "EscapedSymbolicName", "SP", + "WHITESPACE", "CypherComment", "Unknown" + } + ); + static const int32_t serializedATNSegment[] = { + 4,1,186,2909,2,0,7,0,2,1,7,1,2,2,7,2,2,3,7,3,2,4,7,4,2,5,7,5,2,6,7,6, + 2,7,7,7,2,8,7,8,2,9,7,9,2,10,7,10,2,11,7,11,2,12,7,12,2,13,7,13,2,14, + 7,14,2,15,7,15,2,16,7,16,2,17,7,17,2,18,7,18,2,19,7,19,2,20,7,20,2,21, + 7,21,2,22,7,22,2,23,7,23,2,24,7,24,2,25,7,25,2,26,7,26,2,27,7,27,2,28, + 7,28,2,29,7,29,2,30,7,30,2,31,7,31,2,32,7,32,2,33,7,33,2,34,7,34,2,35, + 7,35,2,36,7,36,2,37,7,37,2,38,7,38,2,39,7,39,2,40,7,40,2,41,7,41,2,42, + 7,42,2,43,7,43,2,44,7,44,2,45,7,45,2,46,7,46,2,47,7,47,2,48,7,48,2,49, + 7,49,2,50,7,50,2,51,7,51,2,52,7,52,2,53,7,53,2,54,7,54,2,55,7,55,2,56, + 7,56,2,57,7,57,2,58,7,58,2,59,7,59,2,60,7,60,2,61,7,61,2,62,7,62,2,63, + 7,63,2,64,7,64,2,65,7,65,2,66,7,66,2,67,7,67,2,68,7,68,2,69,7,69,2,70, + 7,70,2,71,7,71,2,72,7,72,2,73,7,73,2,74,7,74,2,75,7,75,2,76,7,76,2,77, + 7,77,2,78,7,78,2,79,7,79,2,80,7,80,2,81,7,81,2,82,7,82,2,83,7,83,2,84, + 7,84,2,85,7,85,2,86,7,86,2,87,7,87,2,88,7,88,2,89,7,89,2,90,7,90,2,91, + 7,91,2,92,7,92,2,93,7,93,2,94,7,94,2,95,7,95,2,96,7,96,2,97,7,97,2,98, + 7,98,2,99,7,99,2,100,7,100,2,101,7,101,2,102,7,102,2,103,7,103,2,104, + 7,104,2,105,7,105,2,106,7,106,2,107,7,107,2,108,7,108,2,109,7,109,2,110, + 7,110,2,111,7,111,2,112,7,112,2,113,7,113,2,114,7,114,2,115,7,115,2,116, + 7,116,2,117,7,117,2,118,7,118,2,119,7,119,2,120,7,120,2,121,7,121,2,122, + 7,122,2,123,7,123,2,124,7,124,2,125,7,125,2,126,7,126,2,127,7,127,2,128, + 7,128,2,129,7,129,2,130,7,130,2,131,7,131,2,132,7,132,2,133,7,133,2,134, + 7,134,2,135,7,135,2,136,7,136,2,137,7,137,2,138,7,138,2,139,7,139,2,140, + 
7,140,2,141,7,141,2,142,7,142,2,143,7,143,2,144,7,144,2,145,7,145,2,146, + 7,146,2,147,7,147,2,148,7,148,2,149,7,149,2,150,7,150,2,151,7,151,2,152, + 7,152,2,153,7,153,2,154,7,154,2,155,7,155,2,156,7,156,2,157,7,157,2,158, + 7,158,2,159,7,159,2,160,7,160,2,161,7,161,2,162,7,162,2,163,7,163,2,164, + 7,164,2,165,7,165,2,166,7,166,2,167,7,167,2,168,7,168,2,169,7,169,2,170, + 7,170,2,171,7,171,2,172,7,172,2,173,7,173,2,174,7,174,2,175,7,175,2,176, + 7,176,2,177,7,177,2,178,7,178,1,0,1,0,3,0,361,8,0,1,0,1,0,3,0,365,8,0, + 1,0,5,0,368,8,0,10,0,12,0,371,9,0,1,0,3,0,374,8,0,1,0,1,0,1,1,3,1,379, + 8,1,1,1,3,1,382,8,1,1,1,1,1,3,1,386,8,1,1,1,3,1,389,8,1,1,2,1,2,1,2,1, + 2,1,2,1,2,1,2,1,2,1,2,1,2,1,2,1,2,1,2,1,2,1,2,1,2,1,2,1,2,1,2,1,2,1,2, + 1,2,3,2,413,8,2,1,3,1,3,1,3,1,3,3,3,419,8,3,1,3,1,3,1,3,1,3,1,3,3,3,426, + 8,3,1,3,1,3,3,3,430,8,3,1,3,1,3,3,3,434,8,3,1,3,1,3,3,3,438,8,3,1,4,3, + 4,441,8,4,1,4,1,4,3,4,445,8,4,1,4,1,4,3,4,449,8,4,1,4,1,4,3,4,453,8,4, + 1,4,5,4,456,8,4,10,4,12,4,459,9,4,1,4,3,4,462,8,4,3,4,464,8,4,1,4,1,4, + 1,5,1,5,1,5,3,5,471,8,5,1,5,1,5,3,5,475,8,5,1,5,1,5,1,5,1,5,1,5,1,5,1, + 5,3,5,484,8,5,1,5,1,5,1,5,3,5,489,8,5,1,6,1,6,1,6,1,6,1,6,1,6,1,6,1,6, + 3,6,499,8,6,1,6,1,6,3,6,503,8,6,1,6,1,6,3,6,507,8,6,1,6,5,6,510,8,6,10, + 6,12,6,513,9,6,1,6,1,6,1,6,1,6,1,6,1,6,1,7,1,7,1,7,1,7,3,7,525,8,7,1, + 7,1,7,3,7,529,8,7,1,7,1,7,1,7,1,7,1,7,1,7,3,7,537,8,7,1,7,1,7,3,7,541, + 8,7,1,7,1,7,3,7,545,8,7,1,7,1,7,3,7,549,8,7,1,8,1,8,1,8,1,8,1,8,1,8,3, + 8,557,8,8,1,8,1,8,3,8,561,8,8,1,8,1,8,3,8,565,8,8,1,8,1,8,3,8,569,8,8, + 1,9,1,9,1,9,1,9,1,9,1,9,1,10,1,10,1,10,1,10,1,10,1,10,1,10,3,10,584,8, + 10,1,10,1,10,1,10,3,10,589,8,10,1,10,1,10,1,10,1,10,3,10,595,8,10,1,10, + 1,10,3,10,599,8,10,1,10,3,10,602,8,10,1,10,3,10,605,8,10,1,10,1,10,1, + 11,1,11,3,11,611,8,11,1,11,1,11,3,11,615,8,11,1,11,5,11,618,8,11,10,11, + 12,11,621,9,11,3,11,623,8,11,1,11,1,11,1,11,3,11,628,8,11,1,12,1,12,3, + 12,632,8,12,1,12,1,12,3,12,636,8,12,1,12,5,12,639,8,12,10,12,12,12,642, + 9,12,1,13,1,13,1,13,1,13,1,14,1,14,1,14,1,14,1,15,1,15,1,15,1,15,3,15, + 656,8,15,1,15,1,15,3,15,660,8,15,1,15,1,15,1,15,1,15,1,15,3,15,667,8, + 15,1,16,1,16,1,16,1,16,1,16,1,16,1,16,1,16,1,16,1,16,1,16,1,16,1,17,1, + 17,1,17,1,17,1,17,1,17,3,17,687,8,17,1,17,1,17,3,17,691,8,17,1,17,3,17, + 694,8,17,1,17,3,17,697,8,17,1,17,3,17,700,8,17,1,17,3,17,703,8,17,1,17, + 1,17,3,17,707,8,17,1,17,5,17,710,8,17,10,17,12,17,713,9,17,1,17,3,17, + 716,8,17,1,17,1,17,1,17,1,17,1,17,1,17,1,18,1,18,3,18,726,8,18,1,18,1, + 18,3,18,730,8,18,1,18,5,18,733,8,18,10,18,12,18,736,9,18,1,19,1,19,3, + 19,740,8,19,1,19,1,19,1,19,3,19,745,8,19,1,19,1,19,1,20,1,20,3,20,751, + 8,20,1,20,1,20,3,20,755,8,20,1,20,1,20,3,20,759,8,20,1,20,5,20,762,8, + 20,10,20,12,20,765,9,20,1,20,1,20,1,20,1,20,3,20,771,8,20,1,20,1,20,3, + 20,775,8,20,1,20,1,20,3,20,779,8,20,1,20,3,20,782,8,20,1,21,1,21,1,21, + 1,21,1,21,1,21,1,22,1,22,1,22,1,22,1,22,1,22,1,22,1,22,1,22,3,22,799, + 8,22,1,22,1,22,3,22,803,8,22,1,22,1,22,3,22,807,8,22,1,22,1,22,3,22,811, + 8,22,1,22,1,22,3,22,815,8,22,1,22,3,22,818,8,22,1,22,3,22,821,8,22,1, + 22,1,22,1,22,1,22,1,22,1,22,3,22,829,8,22,1,23,1,23,1,23,1,23,1,23,1, + 23,1,23,3,23,838,8,23,1,23,1,23,3,23,842,8,23,1,23,1,23,1,23,3,23,847, + 8,23,1,23,1,23,3,23,851,8,23,1,23,1,23,3,23,855,8,23,1,23,1,23,3,23,859, + 8,23,1,23,1,23,3,23,863,8,23,3,23,865,8,23,1,23,1,23,3,23,869,8,23,1, + 23,1,23,3,23,873,8,23,3,23,875,8,23,1,23,1,23,1,23,1,23,1,23,1,23,3,23, + 883,8,23,1,23,1,23,1,23,3,23,888,8,23,1,23,1,23,3,23,892,8,23,1,23,1, + 
23,3,23,896,8,23,1,23,1,23,3,23,900,8,23,1,24,1,24,3,24,904,8,24,1,24, + 1,24,3,24,908,8,24,1,24,5,24,911,8,24,10,24,12,24,914,9,24,1,25,1,25, + 1,25,1,25,1,25,1,25,1,25,1,25,1,26,1,26,1,26,1,26,1,26,1,26,1,26,3,26, + 931,8,26,1,26,1,26,1,26,5,26,936,8,26,10,26,12,26,939,9,26,1,27,1,27, + 1,27,1,27,1,27,1,27,1,27,1,27,1,27,1,27,3,27,951,8,27,1,28,1,28,1,28, + 1,28,1,28,3,28,958,8,28,1,29,1,29,1,29,1,29,1,29,1,29,1,29,1,30,1,30, + 1,30,1,30,1,30,1,30,1,30,3,30,974,8,30,1,30,1,30,3,30,978,8,30,1,31,1, + 31,1,31,1,31,1,31,1,31,1,31,3,31,987,8,31,1,31,1,31,1,32,1,32,1,32,1, + 32,3,32,995,8,32,1,32,3,32,998,8,32,1,32,1,32,1,33,1,33,1,33,1,33,1,33, + 1,33,3,33,1008,8,33,1,33,3,33,1011,8,33,1,34,1,34,1,34,1,34,1,34,1,34, + 3,34,1019,8,34,1,34,3,34,1022,8,34,1,35,1,35,1,35,1,35,3,35,1028,8,35, + 1,35,3,35,1031,8,35,1,35,1,35,1,36,1,36,3,36,1037,8,36,1,36,1,36,1,37, + 1,37,1,37,1,37,1,38,1,38,1,38,1,38,1,38,1,38,1,38,3,38,1052,8,38,1,38, + 1,38,1,39,1,39,1,39,1,39,1,39,1,39,1,39,1,39,1,40,1,40,1,40,1,40,1,40, + 1,40,3,40,1070,8,40,1,41,1,41,1,41,1,41,1,41,3,41,1077,8,41,1,41,1,41, + 1,41,1,41,1,41,3,41,1084,8,41,1,42,1,42,1,42,1,42,1,43,1,43,1,43,1,43, + 1,43,3,43,1095,8,43,1,43,1,43,1,44,1,44,1,44,1,44,1,44,1,44,1,45,1,45, + 1,45,1,45,1,45,1,45,1,45,1,45,1,46,1,46,1,46,1,46,1,46,3,46,1118,8,46, + 1,46,1,46,1,47,1,47,1,47,1,47,1,47,3,47,1127,8,47,1,47,1,47,1,48,1,48, + 3,48,1133,8,48,1,48,1,48,3,48,1137,8,48,1,48,5,48,1140,8,48,10,48,12, + 48,1143,9,48,1,49,1,49,1,49,1,49,1,50,1,50,3,50,1151,8,50,1,50,1,50,3, + 50,1155,8,50,1,50,5,50,1158,8,50,10,50,12,50,1161,9,50,1,51,1,51,1,51, + 3,51,1166,8,51,1,51,1,51,1,51,1,51,3,51,1172,8,51,1,52,1,52,1,52,1,52, + 3,52,1178,8,52,1,52,1,52,3,52,1182,8,52,1,52,1,52,3,52,1186,8,52,1,52, + 1,52,1,53,1,53,3,53,1192,8,53,1,53,1,53,3,53,1196,8,53,1,53,1,53,3,53, + 1200,8,53,1,53,1,53,1,54,1,54,3,54,1206,8,54,1,54,1,54,3,54,1210,8,54, + 1,54,1,54,3,54,1214,8,54,1,54,1,54,1,55,1,55,3,55,1220,8,55,1,55,1,55, + 3,55,1224,8,55,1,55,1,55,3,55,1228,8,55,1,55,1,55,3,55,1232,8,55,1,55, + 1,55,3,55,1236,8,55,1,55,1,55,1,56,1,56,3,56,1242,8,56,1,56,1,56,3,56, + 1246,8,56,1,56,1,56,3,56,1250,8,56,1,56,1,56,3,56,1254,8,56,1,56,1,56, + 3,56,1258,8,56,1,56,1,56,1,57,1,57,1,57,1,57,1,57,1,57,3,57,1268,8,57, + 1,57,1,57,5,57,1272,8,57,10,57,12,57,1275,9,57,1,58,1,58,5,58,1279,8, + 58,10,58,12,58,1282,9,58,1,59,1,59,3,59,1286,8,59,1,59,1,59,1,60,1,60, + 3,60,1292,8,60,1,61,1,61,1,61,3,61,1297,8,61,1,62,1,62,1,63,1,63,1,63, + 1,63,1,63,1,63,1,63,1,63,1,63,1,63,1,63,1,63,1,63,3,63,1314,8,63,1,64, + 1,64,1,64,1,64,3,64,1320,8,64,1,65,1,65,1,65,1,65,3,65,1326,8,65,1,65, + 1,65,3,65,1330,8,65,1,66,1,66,3,66,1334,8,66,1,66,1,66,1,66,1,66,1,66, + 1,66,1,66,3,66,1343,8,66,1,67,1,67,1,67,1,67,1,68,1,68,1,68,1,68,1,69, + 1,69,1,70,1,70,3,70,1357,8,70,1,70,5,70,1360,8,70,10,70,12,70,1363,9, + 70,1,70,1,70,3,70,1367,8,70,4,70,1369,8,70,11,70,12,70,1370,1,70,1,70, + 1,70,3,70,1376,8,70,1,71,1,71,1,71,1,71,3,71,1382,8,71,1,71,1,71,1,71, + 3,71,1387,8,71,1,71,3,71,1390,8,71,1,72,1,72,3,72,1394,8,72,1,73,1,73, + 3,73,1398,8,73,5,73,1400,8,73,10,73,12,73,1403,9,73,1,73,1,73,1,73,3, + 73,1408,8,73,5,73,1410,8,73,10,73,12,73,1413,9,73,1,73,1,73,3,73,1417, + 8,73,1,73,5,73,1420,8,73,10,73,12,73,1423,9,73,1,73,3,73,1426,8,73,1, + 73,3,73,1429,8,73,3,73,1431,8,73,1,74,1,74,3,74,1435,8,74,4,74,1437,8, + 74,11,74,12,74,1438,1,74,1,74,1,75,1,75,3,75,1445,8,75,5,75,1447,8,75, + 10,75,12,75,1450,9,75,1,75,1,75,3,75,1454,8,75,5,75,1456,8,75,10,75,12, + 
75,1459,9,75,1,75,1,75,1,76,1,76,1,76,1,76,3,76,1467,8,76,1,77,1,77,1, + 77,1,77,3,77,1473,8,77,1,78,1,78,1,78,1,78,1,78,1,78,3,78,1481,8,78,1, + 78,1,78,3,78,1485,8,78,1,78,1,78,3,78,1489,8,78,1,78,1,78,3,78,1493,8, + 78,1,78,1,78,1,78,1,78,1,78,3,78,1500,8,78,1,78,1,78,3,78,1504,8,78,1, + 78,1,78,3,78,1508,8,78,1,78,1,78,3,78,1512,8,78,1,78,3,78,1515,8,78,1, + 78,3,78,1518,8,78,1,79,1,79,1,79,1,79,1,79,3,79,1525,8,79,1,79,1,79,1, + 80,1,80,3,80,1531,8,80,1,80,1,80,3,80,1535,8,80,1,80,5,80,1538,8,80,10, + 80,12,80,1541,9,80,1,81,1,81,1,81,1,81,3,81,1547,8,81,1,81,3,81,1550, + 8,81,1,81,3,81,1553,8,81,1,81,1,81,1,81,3,81,1558,8,81,1,82,1,82,3,82, + 1562,8,82,1,82,1,82,3,82,1566,8,82,1,82,1,82,1,82,3,82,1571,8,82,1,82, + 1,82,3,82,1575,8,82,1,83,1,83,1,83,1,83,1,84,1,84,1,84,3,84,1584,8,84, + 1,84,1,84,3,84,1588,8,84,1,84,1,84,1,84,3,84,1593,8,84,1,84,1,84,1,84, + 1,84,1,84,1,84,1,84,1,84,1,84,1,84,4,84,1605,8,84,11,84,12,84,1606,5, + 84,1609,8,84,10,84,12,84,1612,9,84,1,85,1,85,3,85,1616,8,85,1,85,1,85, + 1,85,1,85,1,85,1,85,1,86,1,86,3,86,1626,8,86,1,86,1,86,1,87,1,87,3,87, + 1632,8,87,1,87,1,87,1,87,5,87,1637,8,87,10,87,12,87,1640,9,87,1,88,1, + 88,1,88,1,88,1,88,1,88,1,88,1,88,1,88,1,88,3,88,1652,8,88,1,89,1,89,3, + 89,1656,8,89,1,89,1,89,3,89,1660,8,89,1,89,1,89,3,89,1664,8,89,1,89,5, + 89,1667,8,89,10,89,12,89,1670,9,89,1,89,1,89,3,89,1674,8,89,1,89,1,89, + 3,89,1678,8,89,1,89,1,89,3,89,1682,8,89,1,89,1,89,3,89,1686,8,89,1,90, + 1,90,3,90,1690,8,90,1,90,1,90,3,90,1694,8,90,1,90,1,90,1,91,1,91,3,91, + 1700,8,91,1,91,1,91,3,91,1704,8,91,1,91,1,91,3,91,1708,8,91,1,91,1,91, + 3,91,1712,8,91,1,91,5,91,1715,8,91,10,91,12,91,1718,9,91,1,92,1,92,1, + 92,3,92,1723,8,92,1,92,3,92,1726,8,92,1,93,1,93,1,93,1,94,3,94,1732,8, + 94,1,94,3,94,1735,8,94,1,94,1,94,1,94,1,94,3,94,1741,8,94,1,94,1,94,3, + 94,1745,8,94,1,94,1,94,3,94,1749,8,94,1,95,1,95,3,95,1753,8,95,1,95,1, + 95,3,95,1757,8,95,1,95,5,95,1760,8,95,10,95,12,95,1763,9,95,1,95,1,95, + 3,95,1767,8,95,1,95,1,95,3,95,1771,8,95,1,95,5,95,1774,8,95,10,95,12, + 95,1777,9,95,3,95,1779,8,95,1,96,1,96,1,96,1,96,1,96,1,96,1,96,3,96,1788, + 8,96,1,97,1,97,1,97,1,97,1,97,1,97,1,97,3,97,1797,8,97,1,97,5,97,1800, + 8,97,10,97,12,97,1803,9,97,1,98,1,98,1,98,1,98,1,99,1,99,1,99,1,99,1, + 100,1,100,3,100,1815,8,100,1,100,3,100,1818,8,100,1,101,1,101,1,101,1, + 101,1,102,1,102,3,102,1826,8,102,1,102,1,102,3,102,1830,8,102,1,102,5, + 102,1833,8,102,10,102,12,102,1836,9,102,1,103,1,103,3,103,1840,8,103, + 1,103,1,103,3,103,1844,8,103,1,103,1,103,1,103,3,103,1849,8,103,1,104, + 1,104,1,105,1,105,3,105,1855,8,105,1,105,5,105,1858,8,105,10,105,12,105, + 1861,9,105,1,105,1,105,1,105,1,105,3,105,1867,8,105,1,106,1,106,3,106, + 1871,8,106,1,106,1,106,3,106,1875,8,106,3,106,1877,8,106,1,106,1,106, + 3,106,1881,8,106,3,106,1883,8,106,1,106,1,106,3,106,1887,8,106,3,106, + 1889,8,106,1,106,1,106,1,107,1,107,3,107,1895,8,107,1,107,1,107,1,108, + 1,108,3,108,1901,8,108,1,108,1,108,3,108,1905,8,108,1,108,3,108,1908, + 8,108,1,108,3,108,1911,8,108,1,108,1,108,1,108,1,108,3,108,1917,8,108, + 1,108,3,108,1920,8,108,1,108,3,108,1923,8,108,1,108,1,108,3,108,1927, + 8,108,1,108,1,108,1,108,1,108,3,108,1933,8,108,1,108,3,108,1936,8,108, + 1,108,3,108,1939,8,108,1,108,1,108,3,108,1943,8,108,1,109,1,109,3,109, + 1947,8,109,1,109,1,109,3,109,1951,8,109,3,109,1953,8,109,1,109,1,109, + 3,109,1957,8,109,3,109,1959,8,109,1,109,1,109,3,109,1963,8,109,3,109, + 1965,8,109,1,109,1,109,3,109,1969,8,109,3,109,1971,8,109,1,109,1,109, + 
1,110,1,110,3,110,1977,8,110,1,110,1,110,3,110,1981,8,110,1,110,1,110, + 3,110,1985,8,110,1,110,1,110,3,110,1989,8,110,1,110,1,110,3,110,1993, + 8,110,1,110,1,110,3,110,1997,8,110,1,110,1,110,3,110,2001,8,110,1,110, + 1,110,3,110,2005,8,110,5,110,2007,8,110,10,110,12,110,2010,9,110,3,110, + 2012,8,110,1,110,1,110,1,111,1,111,3,111,2018,8,111,1,111,1,111,3,111, + 2022,8,111,1,111,1,111,3,111,2026,8,111,1,111,3,111,2029,8,111,1,111, + 5,111,2032,8,111,10,111,12,111,2035,9,111,1,112,1,112,3,112,2039,8,112, + 1,112,1,112,3,112,2043,8,112,1,112,1,112,3,112,2047,8,112,1,112,3,112, + 2050,8,112,1,112,3,112,2053,8,112,1,112,5,112,2056,8,112,10,112,12,112, + 2059,9,112,1,113,1,113,3,113,2063,8,113,1,113,3,113,2066,8,113,1,113, + 3,113,2069,8,113,1,113,3,113,2072,8,113,1,113,3,113,2075,8,113,1,113, + 3,113,2078,8,113,1,114,1,114,3,114,2082,8,114,1,114,1,114,3,114,2086, + 8,114,1,114,1,114,3,114,2090,8,114,1,114,1,114,3,114,2094,8,114,1,114, + 1,114,1,114,1,114,1,114,1,114,1,114,1,114,3,114,2104,8,114,1,115,3,115, + 2107,8,115,1,115,3,115,2110,8,115,1,115,1,115,3,115,2114,8,115,1,115, + 3,115,2117,8,115,1,115,3,115,2120,8,115,1,116,1,116,3,116,2124,8,116, + 1,116,1,116,3,116,2128,8,116,1,116,1,116,3,116,2132,8,116,1,116,1,116, + 3,116,2136,8,116,1,116,1,116,3,116,2140,8,116,1,116,1,116,3,116,2144, + 8,116,3,116,2146,8,116,1,116,3,116,2149,8,116,1,116,1,116,3,116,2153, + 8,116,1,116,1,116,3,116,2157,8,116,1,116,1,116,3,116,2161,8,116,1,116, + 1,116,3,116,2165,8,116,3,116,2167,8,116,1,116,1,116,1,117,1,117,3,117, + 2173,8,117,1,117,3,117,2176,8,117,1,117,3,117,2179,8,117,1,117,1,117, + 1,118,1,118,1,119,1,119,1,120,1,120,1,121,1,121,1,122,1,122,1,123,1,123, + 1,123,1,123,1,123,5,123,2198,8,123,10,123,12,123,2201,9,123,1,124,1,124, + 1,124,1,124,1,124,5,124,2208,8,124,10,124,12,124,2211,9,124,1,125,1,125, + 1,125,1,125,1,125,5,125,2218,8,125,10,125,12,125,2221,9,125,1,126,1,126, + 3,126,2225,8,126,5,126,2227,8,126,10,126,12,126,2230,9,126,1,126,1,126, + 1,127,1,127,3,127,2236,8,127,1,127,1,127,3,127,2240,8,127,1,127,1,127, + 3,127,2244,8,127,1,127,1,127,3,127,2248,8,127,1,127,1,127,3,127,2252, + 8,127,1,127,1,127,1,127,1,127,1,127,1,127,3,127,2260,8,127,1,127,1,127, + 3,127,2264,8,127,1,127,1,127,3,127,2268,8,127,1,127,1,127,3,127,2272, + 8,127,1,127,1,127,4,127,2276,8,127,11,127,12,127,2277,1,127,1,127,3,127, + 2282,8,127,1,128,1,128,1,129,1,129,3,129,2288,8,129,1,129,1,129,3,129, + 2292,8,129,1,129,5,129,2295,8,129,10,129,12,129,2298,9,129,1,130,1,130, + 3,130,2302,8,130,1,130,1,130,3,130,2306,8,130,1,130,5,130,2309,8,130, + 10,130,12,130,2312,9,130,1,131,1,131,3,131,2316,8,131,1,131,1,131,3,131, + 2320,8,131,1,131,1,131,5,131,2324,8,131,10,131,12,131,2327,9,131,1,132, + 1,132,1,133,1,133,3,133,2333,8,133,1,133,1,133,3,133,2337,8,133,1,133, + 1,133,5,133,2341,8,133,10,133,12,133,2344,9,133,1,134,1,134,1,135,1,135, + 3,135,2350,8,135,1,135,1,135,3,135,2354,8,135,1,135,1,135,5,135,2358, + 8,135,10,135,12,135,2361,9,135,1,136,1,136,1,137,1,137,3,137,2367,8,137, + 1,137,1,137,3,137,2371,8,137,1,137,5,137,2374,8,137,10,137,12,137,2377, + 9,137,1,138,1,138,1,138,4,138,2382,8,138,11,138,12,138,2383,1,138,3,138, + 2387,8,138,1,139,1,139,1,139,3,139,2392,8,139,1,139,1,139,1,139,1,139, + 1,139,1,139,1,139,3,139,2401,8,139,1,139,1,139,3,139,2405,8,139,1,139, + 3,139,2408,8,139,1,140,1,140,1,140,1,140,1,140,1,140,1,140,1,140,1,140, + 1,140,1,140,3,140,2421,8,140,1,140,3,140,2424,8,140,1,140,1,140,1,141, + 3,141,2429,8,141,1,141,1,141,1,142,1,142,1,142,1,142,1,142,1,142,1,142, + 
1,142,1,142,1,142,3,142,2443,8,142,1,143,1,143,3,143,2447,8,143,5,143, + 2449,8,143,10,143,12,143,2452,9,143,1,143,1,143,3,143,2456,8,143,1,143, + 3,143,2459,8,143,1,144,1,144,3,144,2463,8,144,1,144,5,144,2466,8,144, + 10,144,12,144,2469,9,144,1,145,1,145,1,145,1,145,1,145,1,145,1,145,1, + 145,1,145,3,145,2480,8,145,1,146,1,146,3,146,2484,8,146,1,146,1,146,3, + 146,2488,8,146,1,146,1,146,3,146,2492,8,146,1,146,1,146,1,146,1,146,3, + 146,2498,8,146,1,146,1,146,3,146,2502,8,146,1,146,1,146,3,146,2506,8, + 146,1,146,1,146,1,146,1,146,3,146,2512,8,146,1,146,1,146,3,146,2516,8, + 146,1,146,1,146,3,146,2520,8,146,1,146,1,146,1,146,1,146,3,146,2526,8, + 146,1,146,1,146,3,146,2530,8,146,1,146,1,146,3,146,2534,8,146,1,146,1, + 146,3,146,2538,8,146,1,147,1,147,1,147,1,147,1,148,1,148,1,148,1,148, + 1,148,1,148,1,149,1,149,1,149,1,149,1,149,1,149,3,149,2556,8,149,1,150, + 1,150,1,151,1,151,3,151,2562,8,151,1,151,1,151,3,151,2566,8,151,1,151, + 1,151,3,151,2570,8,151,5,151,2572,8,151,10,151,12,151,2575,9,151,3,151, + 2577,8,151,1,151,1,151,1,152,1,152,3,152,2583,8,152,1,152,3,152,2586, + 8,152,1,153,1,153,3,153,2590,8,153,1,153,1,153,3,153,2594,8,153,1,153, + 1,153,3,153,2598,8,153,1,153,1,153,3,153,2602,8,153,5,153,2604,8,153, + 10,153,12,153,2607,9,153,1,153,1,153,1,154,1,154,3,154,2613,8,154,1,154, + 3,154,2616,8,154,1,154,1,154,3,154,2620,8,154,1,154,1,154,1,155,1,155, + 3,155,2626,8,155,1,155,1,155,3,155,2630,8,155,1,155,1,155,1,156,1,156, + 3,156,2636,8,156,1,156,1,156,3,156,2640,8,156,1,156,1,156,3,156,2644, + 8,156,1,156,1,156,1,156,3,156,2649,8,156,1,156,1,156,3,156,2653,8,156, + 1,156,1,156,3,156,2657,8,156,1,156,1,156,3,156,2661,8,156,1,156,1,156, + 1,156,3,156,2666,8,156,1,156,3,156,2669,8,156,1,156,3,156,2672,8,156, + 1,156,1,156,1,156,1,156,3,156,2678,8,156,1,156,1,156,3,156,2682,8,156, + 1,156,1,156,3,156,2686,8,156,3,156,2688,8,156,1,156,1,156,3,156,2692, + 8,156,1,156,1,156,3,156,2696,8,156,1,156,1,156,3,156,2700,8,156,5,156, + 2702,8,156,10,156,12,156,2705,9,156,3,156,2707,8,156,1,156,1,156,3,156, + 2711,8,156,1,157,1,157,1,158,1,158,3,158,2717,8,158,1,158,1,158,1,158, + 3,158,2722,8,158,3,158,2724,8,158,1,158,1,158,3,158,2728,8,158,1,159, + 1,159,3,159,2732,8,159,1,159,1,159,1,159,3,159,2737,8,159,1,159,1,159, + 3,159,2741,8,159,1,160,1,160,1,160,3,160,2746,8,160,1,160,1,160,3,160, + 2750,8,160,1,160,1,160,3,160,2754,8,160,1,160,1,160,3,160,2758,8,160, + 5,160,2760,8,160,10,160,12,160,2763,9,160,1,160,1,160,3,160,2767,8,160, + 1,161,1,161,3,161,2771,8,161,1,161,4,161,2774,8,161,11,161,12,161,2775, + 1,162,1,162,3,162,2780,8,162,1,162,1,162,3,162,2784,8,162,1,162,1,162, + 3,162,2788,8,162,1,162,1,162,3,162,2792,8,162,1,162,3,162,2795,8,162, + 1,162,3,162,2798,8,162,1,162,3,162,2801,8,162,1,162,3,162,2804,8,162, + 1,162,1,162,1,163,1,163,3,163,2810,8,163,1,163,1,163,3,163,2814,8,163, + 1,164,1,164,3,164,2818,8,164,1,164,4,164,2821,8,164,11,164,12,164,2822, + 1,164,1,164,3,164,2827,8,164,1,164,1,164,3,164,2831,8,164,1,164,4,164, + 2834,8,164,11,164,12,164,2835,3,164,2838,8,164,1,164,3,164,2841,8,164, + 1,164,1,164,3,164,2845,8,164,1,164,3,164,2848,8,164,1,164,3,164,2851, + 8,164,1,164,1,164,1,165,1,165,3,165,2857,8,165,1,165,1,165,3,165,2861, + 8,165,1,165,1,165,3,165,2865,8,165,1,165,1,165,1,166,1,166,1,167,1,167, + 3,167,2873,8,167,1,168,1,168,1,168,3,168,2878,8,168,1,169,1,169,3,169, + 2882,8,169,1,169,1,169,1,170,1,170,1,171,1,171,1,172,1,172,1,173,1,173, + 1,174,1,174,1,174,1,174,1,174,3,174,2899,8,174,1,175,1,175,1,176,1,176, + 
1,177,1,177,1,178,1,178,1,178,0,2,114,168,179,0,2,4,6,8,10,12,14,16,18, + 20,22,24,26,28,30,32,34,36,38,40,42,44,46,48,50,52,54,56,58,60,62,64, + 66,68,70,72,74,76,78,80,82,84,86,88,90,92,94,96,98,100,102,104,106,108, + 110,112,114,116,118,120,122,124,126,128,130,132,134,136,138,140,142,144, + 146,148,150,152,154,156,158,160,162,164,166,168,170,172,174,176,178,180, + 182,184,186,188,190,192,194,196,198,200,202,204,206,208,210,212,214,216, + 218,220,222,224,226,228,230,232,234,236,238,240,242,244,246,248,250,252, + 254,256,258,260,262,264,266,268,270,272,274,276,278,280,282,284,286,288, + 290,292,294,296,298,300,302,304,306,308,310,312,314,316,318,320,322,324, + 326,328,330,332,334,336,338,340,342,344,346,348,350,352,354,356,0,14, + 3,0,105,105,130,130,136,136,2,0,52,53,74,75,2,0,6,6,12,16,1,0,18,19,2, + 0,20,20,166,166,2,0,21,22,161,161,1,0,164,165,2,0,86,86,141,141,2,0,67, + 67,82,82,1,0,177,178,31,0,47,47,49,49,51,51,54,57,60,60,62,63,65,67,69, + 70,73,73,76,76,78,78,83,85,87,88,90,90,94,95,97,97,99,99,101,104,106, + 109,111,112,123,128,130,131,133,133,135,135,138,138,140,140,142,142,145, + 147,151,151,155,160,162,162,2,0,13,13,26,29,2,0,15,15,30,33,2,0,34,44, + 166,166,3294,0,358,1,0,0,0,2,378,1,0,0,0,4,412,1,0,0,0,6,414,1,0,0,0, + 8,440,1,0,0,0,10,488,1,0,0,0,12,490,1,0,0,0,14,520,1,0,0,0,16,550,1,0, + 0,0,18,570,1,0,0,0,20,576,1,0,0,0,22,627,1,0,0,0,24,629,1,0,0,0,26,643, + 1,0,0,0,28,647,1,0,0,0,30,666,1,0,0,0,32,668,1,0,0,0,34,680,1,0,0,0,36, + 723,1,0,0,0,38,737,1,0,0,0,40,781,1,0,0,0,42,783,1,0,0,0,44,789,1,0,0, + 0,46,830,1,0,0,0,48,901,1,0,0,0,50,915,1,0,0,0,52,923,1,0,0,0,54,940, + 1,0,0,0,56,957,1,0,0,0,58,959,1,0,0,0,60,966,1,0,0,0,62,979,1,0,0,0,64, + 990,1,0,0,0,66,1010,1,0,0,0,68,1021,1,0,0,0,70,1023,1,0,0,0,72,1036,1, + 0,0,0,74,1040,1,0,0,0,76,1044,1,0,0,0,78,1055,1,0,0,0,80,1069,1,0,0,0, + 82,1071,1,0,0,0,84,1085,1,0,0,0,86,1089,1,0,0,0,88,1098,1,0,0,0,90,1104, + 1,0,0,0,92,1112,1,0,0,0,94,1121,1,0,0,0,96,1130,1,0,0,0,98,1144,1,0,0, + 0,100,1148,1,0,0,0,102,1162,1,0,0,0,104,1173,1,0,0,0,106,1189,1,0,0,0, + 108,1203,1,0,0,0,110,1217,1,0,0,0,112,1239,1,0,0,0,114,1267,1,0,0,0,116, + 1276,1,0,0,0,118,1283,1,0,0,0,120,1291,1,0,0,0,122,1293,1,0,0,0,124,1298, + 1,0,0,0,126,1313,1,0,0,0,128,1319,1,0,0,0,130,1321,1,0,0,0,132,1333,1, + 0,0,0,134,1344,1,0,0,0,136,1348,1,0,0,0,138,1352,1,0,0,0,140,1375,1,0, + 0,0,142,1389,1,0,0,0,144,1393,1,0,0,0,146,1430,1,0,0,0,148,1436,1,0,0, + 0,150,1448,1,0,0,0,152,1466,1,0,0,0,154,1472,1,0,0,0,156,1474,1,0,0,0, + 158,1524,1,0,0,0,160,1528,1,0,0,0,162,1542,1,0,0,0,164,1561,1,0,0,0,166, + 1576,1,0,0,0,168,1592,1,0,0,0,170,1613,1,0,0,0,172,1623,1,0,0,0,174,1629, + 1,0,0,0,176,1651,1,0,0,0,178,1685,1,0,0,0,180,1687,1,0,0,0,182,1699,1, + 0,0,0,184,1719,1,0,0,0,186,1727,1,0,0,0,188,1734,1,0,0,0,190,1778,1,0, + 0,0,192,1787,1,0,0,0,194,1789,1,0,0,0,196,1804,1,0,0,0,198,1808,1,0,0, + 0,200,1812,1,0,0,0,202,1819,1,0,0,0,204,1823,1,0,0,0,206,1848,1,0,0,0, + 208,1850,1,0,0,0,210,1866,1,0,0,0,212,1868,1,0,0,0,214,1892,1,0,0,0,216, + 1942,1,0,0,0,218,1944,1,0,0,0,220,1974,1,0,0,0,222,2015,1,0,0,0,224,2036, + 1,0,0,0,226,2060,1,0,0,0,228,2103,1,0,0,0,230,2119,1,0,0,0,232,2121,1, + 0,0,0,234,2170,1,0,0,0,236,2182,1,0,0,0,238,2184,1,0,0,0,240,2186,1,0, + 0,0,242,2188,1,0,0,0,244,2190,1,0,0,0,246,2192,1,0,0,0,248,2202,1,0,0, + 0,250,2212,1,0,0,0,252,2228,1,0,0,0,254,2281,1,0,0,0,256,2283,1,0,0,0, + 258,2285,1,0,0,0,260,2299,1,0,0,0,262,2313,1,0,0,0,264,2328,1,0,0,0,266, + 2330,1,0,0,0,268,2345,1,0,0,0,270,2347,1,0,0,0,272,2362,1,0,0,0,274,2364, + 
1,0,0,0,276,2378,1,0,0,0,278,2407,1,0,0,0,280,2420,1,0,0,0,282,2428,1, + 0,0,0,284,2442,1,0,0,0,286,2450,1,0,0,0,288,2460,1,0,0,0,290,2479,1,0, + 0,0,292,2537,1,0,0,0,294,2539,1,0,0,0,296,2543,1,0,0,0,298,2555,1,0,0, + 0,300,2557,1,0,0,0,302,2559,1,0,0,0,304,2580,1,0,0,0,306,2587,1,0,0,0, + 308,2612,1,0,0,0,310,2623,1,0,0,0,312,2710,1,0,0,0,314,2712,1,0,0,0,316, + 2727,1,0,0,0,318,2729,1,0,0,0,320,2766,1,0,0,0,322,2768,1,0,0,0,324,2777, + 1,0,0,0,326,2807,1,0,0,0,328,2837,1,0,0,0,330,2854,1,0,0,0,332,2868,1, + 0,0,0,334,2872,1,0,0,0,336,2874,1,0,0,0,338,2879,1,0,0,0,340,2885,1,0, + 0,0,342,2887,1,0,0,0,344,2889,1,0,0,0,346,2891,1,0,0,0,348,2898,1,0,0, + 0,350,2900,1,0,0,0,352,2902,1,0,0,0,354,2904,1,0,0,0,356,2906,1,0,0,0, + 358,369,3,2,1,0,359,361,5,183,0,0,360,359,1,0,0,0,360,361,1,0,0,0,361, + 362,1,0,0,0,362,364,5,1,0,0,363,365,5,183,0,0,364,363,1,0,0,0,364,365, + 1,0,0,0,365,366,1,0,0,0,366,368,3,2,1,0,367,360,1,0,0,0,368,371,1,0,0, + 0,369,367,1,0,0,0,369,370,1,0,0,0,370,373,1,0,0,0,371,369,1,0,0,0,372, + 374,5,183,0,0,373,372,1,0,0,0,373,374,1,0,0,0,374,375,1,0,0,0,375,376, + 5,0,0,1,376,1,1,0,0,0,377,379,3,120,60,0,378,377,1,0,0,0,378,379,1,0, + 0,0,379,381,1,0,0,0,380,382,5,183,0,0,381,380,1,0,0,0,381,382,1,0,0,0, + 382,383,1,0,0,0,383,388,3,4,2,0,384,386,5,183,0,0,385,384,1,0,0,0,385, + 386,1,0,0,0,386,387,1,0,0,0,387,389,5,1,0,0,388,385,1,0,0,0,388,389,1, + 0,0,0,389,3,1,0,0,0,390,413,3,138,69,0,391,413,3,60,30,0,392,413,3,62, + 31,0,393,413,3,44,22,0,394,413,3,46,23,0,395,413,3,52,26,0,396,413,3, + 54,27,0,397,413,3,76,38,0,398,413,3,78,39,0,399,413,3,6,3,0,400,413,3, + 12,6,0,401,413,3,14,7,0,402,413,3,30,15,0,403,413,3,34,17,0,404,413,3, + 32,16,0,405,413,3,126,63,0,406,413,3,128,64,0,407,413,3,16,8,0,408,413, + 3,18,9,0,409,413,3,20,10,0,410,413,3,26,13,0,411,413,3,28,14,0,412,390, + 1,0,0,0,412,391,1,0,0,0,412,392,1,0,0,0,412,393,1,0,0,0,412,394,1,0,0, + 0,412,395,1,0,0,0,412,396,1,0,0,0,412,397,1,0,0,0,412,398,1,0,0,0,412, + 399,1,0,0,0,412,400,1,0,0,0,412,401,1,0,0,0,412,402,1,0,0,0,412,403,1, + 0,0,0,412,404,1,0,0,0,412,405,1,0,0,0,412,406,1,0,0,0,412,407,1,0,0,0, + 412,408,1,0,0,0,412,409,1,0,0,0,412,410,1,0,0,0,412,411,1,0,0,0,413,5, + 1,0,0,0,414,415,5,66,0,0,415,416,5,183,0,0,416,418,3,346,173,0,417,419, + 3,8,4,0,418,417,1,0,0,0,418,419,1,0,0,0,419,420,1,0,0,0,420,421,5,183, + 0,0,421,422,5,87,0,0,422,423,5,183,0,0,423,437,3,10,5,0,424,426,5,183, + 0,0,425,424,1,0,0,0,425,426,1,0,0,0,426,427,1,0,0,0,427,429,5,2,0,0,428, + 430,5,183,0,0,429,428,1,0,0,0,429,430,1,0,0,0,430,431,1,0,0,0,431,433, + 3,24,12,0,432,434,5,183,0,0,433,432,1,0,0,0,433,434,1,0,0,0,434,435,1, + 0,0,0,435,436,5,3,0,0,436,438,1,0,0,0,437,425,1,0,0,0,437,438,1,0,0,0, + 438,7,1,0,0,0,439,441,5,183,0,0,440,439,1,0,0,0,440,441,1,0,0,0,441,442, + 1,0,0,0,442,444,5,2,0,0,443,445,5,183,0,0,444,443,1,0,0,0,444,445,1,0, + 0,0,445,463,1,0,0,0,446,457,3,346,173,0,447,449,5,183,0,0,448,447,1,0, + 0,0,448,449,1,0,0,0,449,450,1,0,0,0,450,452,5,4,0,0,451,453,5,183,0,0, + 452,451,1,0,0,0,452,453,1,0,0,0,453,454,1,0,0,0,454,456,3,346,173,0,455, + 448,1,0,0,0,456,459,1,0,0,0,457,455,1,0,0,0,457,458,1,0,0,0,458,461,1, + 0,0,0,459,457,1,0,0,0,460,462,5,183,0,0,461,460,1,0,0,0,461,462,1,0,0, + 0,462,464,1,0,0,0,463,446,1,0,0,0,463,464,1,0,0,0,464,465,1,0,0,0,465, + 466,5,3,0,0,466,9,1,0,0,0,467,489,3,40,20,0,468,470,5,2,0,0,469,471,5, + 183,0,0,470,469,1,0,0,0,470,471,1,0,0,0,471,472,1,0,0,0,472,474,3,138, + 69,0,473,475,5,183,0,0,474,473,1,0,0,0,474,475,1,0,0,0,475,476,1,0,0, + 
0,476,477,5,3,0,0,477,489,1,0,0,0,478,489,3,336,168,0,479,489,3,332,166, + 0,480,481,3,332,166,0,481,483,5,5,0,0,482,484,5,183,0,0,483,482,1,0,0, + 0,483,484,1,0,0,0,484,485,1,0,0,0,485,486,3,346,173,0,486,489,1,0,0,0, + 487,489,3,312,156,0,488,467,1,0,0,0,488,468,1,0,0,0,488,478,1,0,0,0,488, + 479,1,0,0,0,488,480,1,0,0,0,488,487,1,0,0,0,489,11,1,0,0,0,490,491,5, + 66,0,0,491,492,5,183,0,0,492,493,3,346,173,0,493,494,5,183,0,0,494,495, + 5,87,0,0,495,496,5,183,0,0,496,498,5,2,0,0,497,499,5,183,0,0,498,497, + 1,0,0,0,498,499,1,0,0,0,499,500,1,0,0,0,500,511,5,168,0,0,501,503,5,183, + 0,0,502,501,1,0,0,0,502,503,1,0,0,0,503,504,1,0,0,0,504,506,5,4,0,0,505, + 507,5,183,0,0,506,505,1,0,0,0,506,507,1,0,0,0,507,508,1,0,0,0,508,510, + 5,168,0,0,509,502,1,0,0,0,510,513,1,0,0,0,511,509,1,0,0,0,511,512,1,0, + 0,0,512,514,1,0,0,0,513,511,1,0,0,0,514,515,5,3,0,0,515,516,5,183,0,0, + 516,517,5,56,0,0,517,518,5,183,0,0,518,519,5,61,0,0,519,13,1,0,0,0,520, + 521,5,66,0,0,521,522,5,183,0,0,522,524,5,2,0,0,523,525,5,183,0,0,524, + 523,1,0,0,0,524,525,1,0,0,0,525,526,1,0,0,0,526,528,3,138,69,0,527,529, + 5,183,0,0,528,527,1,0,0,0,528,529,1,0,0,0,529,530,1,0,0,0,530,531,5,3, + 0,0,531,532,5,183,0,0,532,533,5,138,0,0,533,534,5,183,0,0,534,548,5,168, + 0,0,535,537,5,183,0,0,536,535,1,0,0,0,536,537,1,0,0,0,537,538,1,0,0,0, + 538,540,5,2,0,0,539,541,5,183,0,0,540,539,1,0,0,0,540,541,1,0,0,0,541, + 542,1,0,0,0,542,544,3,24,12,0,543,545,5,183,0,0,544,543,1,0,0,0,544,545, + 1,0,0,0,545,546,1,0,0,0,546,547,5,3,0,0,547,549,1,0,0,0,548,536,1,0,0, + 0,548,549,1,0,0,0,549,15,1,0,0,0,550,551,5,84,0,0,551,552,5,183,0,0,552, + 553,5,70,0,0,553,554,5,183,0,0,554,568,5,168,0,0,555,557,5,183,0,0,556, + 555,1,0,0,0,556,557,1,0,0,0,557,558,1,0,0,0,558,560,5,2,0,0,559,561,5, + 183,0,0,560,559,1,0,0,0,560,561,1,0,0,0,561,562,1,0,0,0,562,564,3,24, + 12,0,563,565,5,183,0,0,564,563,1,0,0,0,564,565,1,0,0,0,565,566,1,0,0, + 0,566,567,5,3,0,0,567,569,1,0,0,0,568,556,1,0,0,0,568,569,1,0,0,0,569, + 17,1,0,0,0,570,571,5,94,0,0,571,572,5,183,0,0,572,573,5,70,0,0,573,574, + 5,183,0,0,574,575,5,168,0,0,575,19,1,0,0,0,576,577,5,54,0,0,577,578,5, + 183,0,0,578,583,5,168,0,0,579,580,5,183,0,0,580,581,5,51,0,0,581,582, + 5,183,0,0,582,584,3,346,173,0,583,579,1,0,0,0,583,584,1,0,0,0,584,585, + 1,0,0,0,585,586,5,183,0,0,586,588,5,2,0,0,587,589,5,183,0,0,588,587,1, + 0,0,0,588,589,1,0,0,0,589,590,1,0,0,0,590,591,5,71,0,0,591,592,5,183, + 0,0,592,601,3,348,174,0,593,595,5,183,0,0,594,593,1,0,0,0,594,595,1,0, + 0,0,595,596,1,0,0,0,596,598,5,4,0,0,597,599,5,183,0,0,598,597,1,0,0,0, + 598,599,1,0,0,0,599,600,1,0,0,0,600,602,3,24,12,0,601,594,1,0,0,0,601, + 602,1,0,0,0,602,604,1,0,0,0,603,605,5,183,0,0,604,603,1,0,0,0,604,605, + 1,0,0,0,605,606,1,0,0,0,606,607,5,3,0,0,607,21,1,0,0,0,608,622,3,348, + 174,0,609,611,5,183,0,0,610,609,1,0,0,0,610,611,1,0,0,0,611,612,1,0,0, + 0,612,614,5,6,0,0,613,615,5,183,0,0,614,613,1,0,0,0,614,615,1,0,0,0,615, + 623,1,0,0,0,616,618,5,183,0,0,617,616,1,0,0,0,618,621,1,0,0,0,619,617, + 1,0,0,0,619,620,1,0,0,0,620,623,1,0,0,0,621,619,1,0,0,0,622,610,1,0,0, + 0,622,619,1,0,0,0,623,624,1,0,0,0,624,625,3,298,149,0,625,628,1,0,0,0, + 626,628,3,348,174,0,627,608,1,0,0,0,627,626,1,0,0,0,628,23,1,0,0,0,629, + 640,3,22,11,0,630,632,5,183,0,0,631,630,1,0,0,0,631,632,1,0,0,0,632,633, + 1,0,0,0,633,635,5,4,0,0,634,636,5,183,0,0,635,634,1,0,0,0,635,636,1,0, + 0,0,636,637,1,0,0,0,637,639,3,22,11,0,638,631,1,0,0,0,639,642,1,0,0,0, + 640,638,1,0,0,0,640,641,1,0,0,0,641,25,1,0,0,0,642,640,1,0,0,0,643,644, + 
5,76,0,0,644,645,5,183,0,0,645,646,3,346,173,0,646,27,1,0,0,0,647,648, + 5,147,0,0,648,649,5,183,0,0,649,650,3,346,173,0,650,29,1,0,0,0,651,652, + 5,57,0,0,652,653,5,183,0,0,653,655,3,348,174,0,654,656,5,183,0,0,655, + 654,1,0,0,0,655,656,1,0,0,0,656,657,1,0,0,0,657,659,5,6,0,0,658,660,5, + 183,0,0,659,658,1,0,0,0,659,660,1,0,0,0,660,661,1,0,0,0,661,662,3,244, + 122,0,662,667,1,0,0,0,663,664,5,57,0,0,664,665,5,183,0,0,665,667,3,312, + 156,0,666,651,1,0,0,0,666,663,1,0,0,0,667,31,1,0,0,0,668,669,5,62,0,0, + 669,670,5,183,0,0,670,671,5,116,0,0,671,672,5,183,0,0,672,673,5,136,0, + 0,673,674,5,183,0,0,674,675,3,346,173,0,675,676,5,183,0,0,676,677,5,99, + 0,0,677,678,5,183,0,0,678,679,5,168,0,0,679,33,1,0,0,0,680,681,5,68,0, + 0,681,682,5,183,0,0,682,683,5,105,0,0,683,684,5,183,0,0,684,686,3,314, + 157,0,685,687,5,183,0,0,686,685,1,0,0,0,686,687,1,0,0,0,687,688,1,0,0, + 0,688,690,5,2,0,0,689,691,5,183,0,0,690,689,1,0,0,0,690,691,1,0,0,0,691, + 693,1,0,0,0,692,694,3,36,18,0,693,692,1,0,0,0,693,694,1,0,0,0,694,696, + 1,0,0,0,695,697,5,183,0,0,696,695,1,0,0,0,696,697,1,0,0,0,697,699,1,0, + 0,0,698,700,3,38,19,0,699,698,1,0,0,0,699,700,1,0,0,0,700,711,1,0,0,0, + 701,703,5,183,0,0,702,701,1,0,0,0,702,703,1,0,0,0,703,704,1,0,0,0,704, + 706,5,4,0,0,705,707,5,183,0,0,706,705,1,0,0,0,706,707,1,0,0,0,707,708, + 1,0,0,0,708,710,3,38,19,0,709,702,1,0,0,0,710,713,1,0,0,0,711,709,1,0, + 0,0,711,712,1,0,0,0,712,715,1,0,0,0,713,711,1,0,0,0,714,716,5,183,0,0, + 715,714,1,0,0,0,715,716,1,0,0,0,716,717,1,0,0,0,717,718,5,3,0,0,718,719, + 5,183,0,0,719,720,5,51,0,0,720,721,5,183,0,0,721,722,3,244,122,0,722, + 35,1,0,0,0,723,734,3,348,174,0,724,726,5,183,0,0,725,724,1,0,0,0,725, + 726,1,0,0,0,726,727,1,0,0,0,727,729,5,4,0,0,728,730,5,183,0,0,729,728, + 1,0,0,0,729,730,1,0,0,0,730,731,1,0,0,0,731,733,3,348,174,0,732,725,1, + 0,0,0,733,736,1,0,0,0,734,732,1,0,0,0,734,735,1,0,0,0,735,37,1,0,0,0, + 736,734,1,0,0,0,737,739,3,348,174,0,738,740,5,183,0,0,739,738,1,0,0,0, + 739,740,1,0,0,0,740,741,1,0,0,0,741,742,5,164,0,0,742,744,5,6,0,0,743, + 745,5,183,0,0,744,743,1,0,0,0,744,745,1,0,0,0,745,746,1,0,0,0,746,747, + 3,298,149,0,747,39,1,0,0,0,748,750,5,7,0,0,749,751,5,183,0,0,750,749, + 1,0,0,0,750,751,1,0,0,0,751,752,1,0,0,0,752,763,5,168,0,0,753,755,5,183, + 0,0,754,753,1,0,0,0,754,755,1,0,0,0,755,756,1,0,0,0,756,758,5,4,0,0,757, + 759,5,183,0,0,758,757,1,0,0,0,758,759,1,0,0,0,759,760,1,0,0,0,760,762, + 5,168,0,0,761,754,1,0,0,0,762,765,1,0,0,0,763,761,1,0,0,0,763,764,1,0, + 0,0,764,766,1,0,0,0,765,763,1,0,0,0,766,782,5,8,0,0,767,782,5,168,0,0, + 768,770,5,89,0,0,769,771,5,183,0,0,770,769,1,0,0,0,770,771,1,0,0,0,771, + 772,1,0,0,0,772,774,5,2,0,0,773,775,5,183,0,0,774,773,1,0,0,0,774,775, + 1,0,0,0,775,776,1,0,0,0,776,778,5,168,0,0,777,779,5,183,0,0,778,777,1, + 0,0,0,778,779,1,0,0,0,779,780,1,0,0,0,780,782,5,3,0,0,781,748,1,0,0,0, + 781,767,1,0,0,0,781,768,1,0,0,0,782,41,1,0,0,0,783,784,5,95,0,0,784,785, + 5,183,0,0,785,786,5,113,0,0,786,787,5,183,0,0,787,788,5,82,0,0,788,43, + 1,0,0,0,789,790,5,68,0,0,790,791,5,183,0,0,791,792,5,112,0,0,792,793, + 5,183,0,0,793,794,5,136,0,0,794,798,5,183,0,0,795,796,3,42,21,0,796,797, + 5,183,0,0,797,799,1,0,0,0,798,795,1,0,0,0,798,799,1,0,0,0,799,800,1,0, + 0,0,800,828,3,346,173,0,801,803,5,183,0,0,802,801,1,0,0,0,802,803,1,0, + 0,0,803,804,1,0,0,0,804,806,5,2,0,0,805,807,5,183,0,0,806,805,1,0,0,0, + 806,807,1,0,0,0,807,808,1,0,0,0,808,810,3,100,50,0,809,811,5,183,0,0, + 810,809,1,0,0,0,810,811,1,0,0,0,811,817,1,0,0,0,812,814,5,4,0,0,813,815, + 
5,183,0,0,814,813,1,0,0,0,814,815,1,0,0,0,815,816,1,0,0,0,816,818,3,104, + 52,0,817,812,1,0,0,0,817,818,1,0,0,0,818,820,1,0,0,0,819,821,5,183,0, + 0,820,819,1,0,0,0,820,821,1,0,0,0,821,822,1,0,0,0,822,823,5,3,0,0,823, + 829,1,0,0,0,824,825,5,183,0,0,825,826,5,51,0,0,826,827,5,183,0,0,827, + 829,3,138,69,0,828,802,1,0,0,0,828,824,1,0,0,0,829,45,1,0,0,0,830,831, + 5,68,0,0,831,832,5,183,0,0,832,833,5,125,0,0,833,834,5,183,0,0,834,837, + 5,136,0,0,835,836,5,183,0,0,836,838,5,91,0,0,837,835,1,0,0,0,837,838, + 1,0,0,0,838,841,1,0,0,0,839,840,5,183,0,0,840,842,3,42,21,0,841,839,1, + 0,0,0,841,842,1,0,0,0,842,843,1,0,0,0,843,844,5,183,0,0,844,846,3,346, + 173,0,845,847,5,183,0,0,846,845,1,0,0,0,846,847,1,0,0,0,847,848,1,0,0, + 0,848,850,5,2,0,0,849,851,5,183,0,0,850,849,1,0,0,0,850,851,1,0,0,0,851, + 852,1,0,0,0,852,854,3,48,24,0,853,855,5,183,0,0,854,853,1,0,0,0,854,855, + 1,0,0,0,855,882,1,0,0,0,856,858,5,4,0,0,857,859,5,183,0,0,858,857,1,0, + 0,0,858,859,1,0,0,0,859,860,1,0,0,0,860,862,3,100,50,0,861,863,5,183, + 0,0,862,861,1,0,0,0,862,863,1,0,0,0,863,865,1,0,0,0,864,856,1,0,0,0,864, + 865,1,0,0,0,865,874,1,0,0,0,866,868,5,4,0,0,867,869,5,183,0,0,868,867, + 1,0,0,0,868,869,1,0,0,0,869,870,1,0,0,0,870,872,3,348,174,0,871,873,5, + 183,0,0,872,871,1,0,0,0,872,873,1,0,0,0,873,875,1,0,0,0,874,866,1,0,0, + 0,874,875,1,0,0,0,875,876,1,0,0,0,876,883,5,3,0,0,877,878,5,3,0,0,878, + 879,5,183,0,0,879,880,5,51,0,0,880,881,5,183,0,0,881,883,3,138,69,0,882, + 864,1,0,0,0,882,877,1,0,0,0,883,899,1,0,0,0,884,885,5,183,0,0,885,887, + 5,150,0,0,886,888,5,183,0,0,887,886,1,0,0,0,887,888,1,0,0,0,888,889,1, + 0,0,0,889,891,5,2,0,0,890,892,5,183,0,0,891,890,1,0,0,0,891,892,1,0,0, + 0,892,893,1,0,0,0,893,895,3,24,12,0,894,896,5,183,0,0,895,894,1,0,0,0, + 895,896,1,0,0,0,896,897,1,0,0,0,897,898,5,3,0,0,898,900,1,0,0,0,899,884, + 1,0,0,0,899,900,1,0,0,0,900,47,1,0,0,0,901,912,3,50,25,0,902,904,5,183, + 0,0,903,902,1,0,0,0,903,904,1,0,0,0,904,905,1,0,0,0,905,907,5,4,0,0,906, + 908,5,183,0,0,907,906,1,0,0,0,907,908,1,0,0,0,908,909,1,0,0,0,909,911, + 3,50,25,0,910,903,1,0,0,0,911,914,1,0,0,0,912,910,1,0,0,0,912,913,1,0, + 0,0,913,49,1,0,0,0,914,912,1,0,0,0,915,916,5,87,0,0,916,917,5,183,0,0, + 917,918,3,346,173,0,918,919,5,183,0,0,919,920,5,138,0,0,920,921,5,183, + 0,0,921,922,3,346,173,0,922,51,1,0,0,0,923,924,5,68,0,0,924,925,5,183, + 0,0,925,926,5,130,0,0,926,930,5,183,0,0,927,928,3,42,21,0,928,929,5,183, + 0,0,929,931,1,0,0,0,930,927,1,0,0,0,930,931,1,0,0,0,931,932,1,0,0,0,932, + 937,3,346,173,0,933,934,5,183,0,0,934,936,3,56,28,0,935,933,1,0,0,0,936, + 939,1,0,0,0,937,935,1,0,0,0,937,938,1,0,0,0,938,53,1,0,0,0,939,937,1, + 0,0,0,940,941,5,68,0,0,941,942,5,183,0,0,942,943,5,142,0,0,943,944,5, + 183,0,0,944,945,3,346,173,0,945,946,5,183,0,0,946,947,5,51,0,0,947,948, + 5,183,0,0,948,950,3,114,57,0,949,951,5,183,0,0,950,949,1,0,0,0,950,951, + 1,0,0,0,951,55,1,0,0,0,952,958,3,64,32,0,953,958,3,66,33,0,954,958,3, + 68,34,0,955,958,3,70,35,0,956,958,3,72,36,0,957,952,1,0,0,0,957,953,1, + 0,0,0,957,954,1,0,0,0,957,955,1,0,0,0,957,956,1,0,0,0,958,57,1,0,0,0, + 959,960,5,183,0,0,960,961,5,150,0,0,961,962,5,183,0,0,962,963,5,157,0, + 0,963,964,5,183,0,0,964,965,5,168,0,0,965,59,1,0,0,0,966,967,5,68,0,0, + 967,968,5,183,0,0,968,969,5,156,0,0,969,973,5,183,0,0,970,971,3,42,21, + 0,971,972,5,183,0,0,972,974,1,0,0,0,973,970,1,0,0,0,973,974,1,0,0,0,974, + 975,1,0,0,0,975,977,3,332,166,0,976,978,3,58,29,0,977,976,1,0,0,0,977, + 978,1,0,0,0,978,61,1,0,0,0,979,980,5,68,0,0,980,981,5,183,0,0,981,982, + 
5,158,0,0,982,986,5,183,0,0,983,984,3,42,21,0,984,985,5,183,0,0,985,987, + 1,0,0,0,986,983,1,0,0,0,986,987,1,0,0,0,987,988,1,0,0,0,988,989,3,332, + 166,0,989,63,1,0,0,0,990,991,5,97,0,0,991,994,5,183,0,0,992,993,5,56, + 0,0,993,995,5,183,0,0,994,992,1,0,0,0,994,995,1,0,0,0,995,997,1,0,0,0, + 996,998,5,166,0,0,997,996,1,0,0,0,997,998,1,0,0,0,998,999,1,0,0,0,999, + 1000,3,342,171,0,1000,65,1,0,0,0,1001,1002,5,111,0,0,1002,1003,5,183, + 0,0,1003,1011,5,109,0,0,1004,1005,5,109,0,0,1005,1007,5,183,0,0,1006, + 1008,5,166,0,0,1007,1006,1,0,0,0,1007,1008,1,0,0,0,1008,1009,1,0,0,0, + 1009,1011,3,342,171,0,1010,1001,1,0,0,0,1010,1004,1,0,0,0,1011,67,1,0, + 0,0,1012,1013,5,111,0,0,1013,1014,5,183,0,0,1014,1022,5,107,0,0,1015, + 1016,5,107,0,0,1016,1018,5,183,0,0,1017,1019,5,166,0,0,1018,1017,1,0, + 0,0,1018,1019,1,0,0,0,1019,1020,1,0,0,0,1020,1022,3,342,171,0,1021,1012, + 1,0,0,0,1021,1015,1,0,0,0,1022,69,1,0,0,0,1023,1024,5,133,0,0,1024,1027, + 5,183,0,0,1025,1026,5,150,0,0,1026,1028,5,183,0,0,1027,1025,1,0,0,0,1027, + 1028,1,0,0,0,1028,1030,1,0,0,0,1029,1031,5,166,0,0,1030,1029,1,0,0,0, + 1030,1031,1,0,0,0,1031,1032,1,0,0,0,1032,1033,3,342,171,0,1033,71,1,0, + 0,0,1034,1035,5,111,0,0,1035,1037,5,183,0,0,1036,1034,1,0,0,0,1036,1037, + 1,0,0,0,1037,1038,1,0,0,0,1038,1039,5,69,0,0,1039,73,1,0,0,0,1040,1041, + 5,95,0,0,1041,1042,5,183,0,0,1042,1043,5,82,0,0,1043,75,1,0,0,0,1044, + 1045,5,78,0,0,1045,1046,5,183,0,0,1046,1047,7,0,0,0,1047,1051,5,183,0, + 0,1048,1049,3,74,37,0,1049,1050,5,183,0,0,1050,1052,1,0,0,0,1051,1048, + 1,0,0,0,1051,1052,1,0,0,0,1052,1053,1,0,0,0,1053,1054,3,346,173,0,1054, + 77,1,0,0,0,1055,1056,5,49,0,0,1056,1057,5,183,0,0,1057,1058,5,136,0,0, + 1058,1059,5,183,0,0,1059,1060,3,346,173,0,1060,1061,5,183,0,0,1061,1062, + 3,80,40,0,1062,79,1,0,0,0,1063,1070,3,82,41,0,1064,1070,3,86,43,0,1065, + 1070,3,88,44,0,1066,1070,3,90,45,0,1067,1070,3,92,46,0,1068,1070,3,94, + 47,0,1069,1063,1,0,0,0,1069,1064,1,0,0,0,1069,1065,1,0,0,0,1069,1066, + 1,0,0,0,1069,1067,1,0,0,0,1069,1068,1,0,0,0,1070,81,1,0,0,0,1071,1072, + 5,47,0,0,1072,1076,5,183,0,0,1073,1074,3,42,21,0,1074,1075,5,183,0,0, + 1075,1077,1,0,0,0,1076,1073,1,0,0,0,1076,1077,1,0,0,0,1077,1078,1,0,0, + 0,1078,1079,3,340,170,0,1079,1080,5,183,0,0,1080,1083,3,114,57,0,1081, + 1082,5,183,0,0,1082,1084,3,84,42,0,1083,1081,1,0,0,0,1083,1084,1,0,0, + 0,1084,83,1,0,0,0,1085,1086,5,72,0,0,1086,1087,5,183,0,0,1087,1088,3, + 244,122,0,1088,85,1,0,0,0,1089,1090,5,78,0,0,1090,1094,5,183,0,0,1091, + 1092,3,74,37,0,1092,1093,5,183,0,0,1093,1095,1,0,0,0,1094,1091,1,0,0, + 0,1094,1095,1,0,0,0,1095,1096,1,0,0,0,1096,1097,3,340,170,0,1097,87,1, + 0,0,0,1098,1099,5,126,0,0,1099,1100,5,183,0,0,1100,1101,5,138,0,0,1101, + 1102,5,183,0,0,1102,1103,3,346,173,0,1103,89,1,0,0,0,1104,1105,5,126, + 0,0,1105,1106,5,183,0,0,1106,1107,3,340,170,0,1107,1108,5,183,0,0,1108, + 1109,5,138,0,0,1109,1110,5,183,0,0,1110,1111,3,340,170,0,1111,91,1,0, + 0,0,1112,1113,5,47,0,0,1113,1117,5,183,0,0,1114,1115,3,42,21,0,1115,1116, + 5,183,0,0,1116,1118,1,0,0,0,1117,1114,1,0,0,0,1117,1118,1,0,0,0,1118, + 1119,1,0,0,0,1119,1120,3,50,25,0,1120,93,1,0,0,0,1121,1122,5,78,0,0,1122, + 1126,5,183,0,0,1123,1124,3,74,37,0,1124,1125,5,183,0,0,1125,1127,1,0, + 0,0,1126,1123,1,0,0,0,1126,1127,1,0,0,0,1127,1128,1,0,0,0,1128,1129,3, + 50,25,0,1129,95,1,0,0,0,1130,1141,3,98,49,0,1131,1133,5,183,0,0,1132, + 1131,1,0,0,0,1132,1133,1,0,0,0,1133,1134,1,0,0,0,1134,1136,5,4,0,0,1135, + 1137,5,183,0,0,1136,1135,1,0,0,0,1136,1137,1,0,0,0,1137,1138,1,0,0,0, + 
1138,1140,3,98,49,0,1139,1132,1,0,0,0,1140,1143,1,0,0,0,1141,1139,1,0, + 0,0,1141,1142,1,0,0,0,1142,97,1,0,0,0,1143,1141,1,0,0,0,1144,1145,3,340, + 170,0,1145,1146,5,183,0,0,1146,1147,3,114,57,0,1147,99,1,0,0,0,1148,1159, + 3,102,51,0,1149,1151,5,183,0,0,1150,1149,1,0,0,0,1150,1151,1,0,0,0,1151, + 1152,1,0,0,0,1152,1154,5,4,0,0,1153,1155,5,183,0,0,1154,1153,1,0,0,0, + 1154,1155,1,0,0,0,1155,1156,1,0,0,0,1156,1158,3,102,51,0,1157,1150,1, + 0,0,0,1158,1161,1,0,0,0,1159,1157,1,0,0,0,1159,1160,1,0,0,0,1160,101, + 1,0,0,0,1161,1159,1,0,0,0,1162,1165,3,98,49,0,1163,1164,5,183,0,0,1164, + 1166,3,84,42,0,1165,1163,1,0,0,0,1165,1166,1,0,0,0,1166,1171,1,0,0,0, + 1167,1168,5,183,0,0,1168,1169,5,121,0,0,1169,1170,5,183,0,0,1170,1172, + 5,101,0,0,1171,1167,1,0,0,0,1171,1172,1,0,0,0,1172,103,1,0,0,0,1173,1174, + 5,121,0,0,1174,1175,5,183,0,0,1175,1177,5,101,0,0,1176,1178,5,183,0,0, + 1177,1176,1,0,0,0,1177,1178,1,0,0,0,1178,1179,1,0,0,0,1179,1181,5,2,0, + 0,1180,1182,5,183,0,0,1181,1180,1,0,0,0,1181,1182,1,0,0,0,1182,1183,1, + 0,0,0,1183,1185,3,340,170,0,1184,1186,5,183,0,0,1185,1184,1,0,0,0,1185, + 1186,1,0,0,0,1186,1187,1,0,0,0,1187,1188,5,3,0,0,1188,105,1,0,0,0,1189, + 1191,5,143,0,0,1190,1192,5,183,0,0,1191,1190,1,0,0,0,1191,1192,1,0,0, + 0,1192,1193,1,0,0,0,1193,1195,5,2,0,0,1194,1196,5,183,0,0,1195,1194,1, + 0,0,0,1195,1196,1,0,0,0,1196,1197,1,0,0,0,1197,1199,3,96,48,0,1198,1200, + 5,183,0,0,1199,1198,1,0,0,0,1199,1200,1,0,0,0,1200,1201,1,0,0,0,1201, + 1202,5,3,0,0,1202,107,1,0,0,0,1203,1205,5,135,0,0,1204,1206,5,183,0,0, + 1205,1204,1,0,0,0,1205,1206,1,0,0,0,1206,1207,1,0,0,0,1207,1209,5,2,0, + 0,1208,1210,5,183,0,0,1209,1208,1,0,0,0,1209,1210,1,0,0,0,1210,1211,1, + 0,0,0,1211,1213,3,96,48,0,1212,1214,5,183,0,0,1213,1212,1,0,0,0,1213, + 1214,1,0,0,0,1214,1215,1,0,0,0,1215,1216,5,3,0,0,1216,109,1,0,0,0,1217, + 1219,5,159,0,0,1218,1220,5,183,0,0,1219,1218,1,0,0,0,1219,1220,1,0,0, + 0,1220,1221,1,0,0,0,1221,1223,5,2,0,0,1222,1224,5,183,0,0,1223,1222,1, + 0,0,0,1223,1224,1,0,0,0,1224,1225,1,0,0,0,1225,1227,3,114,57,0,1226,1228, + 5,183,0,0,1227,1226,1,0,0,0,1227,1228,1,0,0,0,1228,1229,1,0,0,0,1229, + 1231,5,4,0,0,1230,1232,5,183,0,0,1231,1230,1,0,0,0,1231,1232,1,0,0,0, + 1232,1233,1,0,0,0,1233,1235,3,114,57,0,1234,1236,5,183,0,0,1235,1234, + 1,0,0,0,1235,1236,1,0,0,0,1236,1237,1,0,0,0,1237,1238,5,3,0,0,1238,111, + 1,0,0,0,1239,1241,5,160,0,0,1240,1242,5,183,0,0,1241,1240,1,0,0,0,1241, + 1242,1,0,0,0,1242,1243,1,0,0,0,1243,1245,5,2,0,0,1244,1246,5,183,0,0, + 1245,1244,1,0,0,0,1245,1246,1,0,0,0,1246,1247,1,0,0,0,1247,1249,3,342, + 171,0,1248,1250,5,183,0,0,1249,1248,1,0,0,0,1249,1250,1,0,0,0,1250,1251, + 1,0,0,0,1251,1253,5,4,0,0,1252,1254,5,183,0,0,1253,1252,1,0,0,0,1253, + 1254,1,0,0,0,1254,1255,1,0,0,0,1255,1257,3,342,171,0,1256,1258,5,183, + 0,0,1257,1256,1,0,0,0,1257,1258,1,0,0,0,1258,1259,1,0,0,0,1259,1260,5, + 3,0,0,1260,113,1,0,0,0,1261,1262,6,57,-1,0,1262,1268,3,348,174,0,1263, + 1268,3,106,53,0,1264,1268,3,108,54,0,1265,1268,3,110,55,0,1266,1268,3, + 112,56,0,1267,1261,1,0,0,0,1267,1263,1,0,0,0,1267,1264,1,0,0,0,1267,1265, + 1,0,0,0,1267,1266,1,0,0,0,1268,1273,1,0,0,0,1269,1270,10,5,0,0,1270,1272, + 3,116,58,0,1271,1269,1,0,0,0,1272,1275,1,0,0,0,1273,1271,1,0,0,0,1273, + 1274,1,0,0,0,1274,115,1,0,0,0,1275,1273,1,0,0,0,1276,1280,3,118,59,0, + 1277,1279,3,118,59,0,1278,1277,1,0,0,0,1279,1282,1,0,0,0,1280,1278,1, + 0,0,0,1280,1281,1,0,0,0,1281,117,1,0,0,0,1282,1280,1,0,0,0,1283,1285, + 5,7,0,0,1284,1286,3,342,171,0,1285,1284,1,0,0,0,1285,1286,1,0,0,0,1286, + 
1287,1,0,0,0,1287,1288,5,8,0,0,1288,119,1,0,0,0,1289,1292,3,122,61,0, + 1290,1292,3,124,62,0,1291,1289,1,0,0,0,1291,1290,1,0,0,0,1292,121,1,0, + 0,0,1293,1296,5,83,0,0,1294,1295,5,183,0,0,1295,1297,5,104,0,0,1296,1294, + 1,0,0,0,1296,1297,1,0,0,0,1297,123,1,0,0,0,1298,1299,5,122,0,0,1299,125, + 1,0,0,0,1300,1301,5,55,0,0,1301,1302,5,183,0,0,1302,1314,5,140,0,0,1303, + 1304,5,55,0,0,1304,1305,5,183,0,0,1305,1306,5,140,0,0,1306,1307,5,183, + 0,0,1307,1308,5,124,0,0,1308,1309,5,183,0,0,1309,1314,5,117,0,0,1310, + 1314,5,63,0,0,1311,1314,5,128,0,0,1312,1314,5,60,0,0,1313,1300,1,0,0, + 0,1313,1303,1,0,0,0,1313,1310,1,0,0,0,1313,1311,1,0,0,0,1313,1312,1,0, + 0,0,1314,127,1,0,0,0,1315,1320,3,130,65,0,1316,1320,3,132,66,0,1317,1320, + 3,134,67,0,1318,1320,3,136,68,0,1319,1315,1,0,0,0,1319,1316,1,0,0,0,1319, + 1317,1,0,0,0,1319,1318,1,0,0,0,1320,129,1,0,0,0,1321,1322,5,103,0,0,1322, + 1325,5,183,0,0,1323,1324,5,85,0,0,1324,1326,5,183,0,0,1325,1323,1,0,0, + 0,1325,1326,1,0,0,0,1326,1329,1,0,0,0,1327,1330,5,168,0,0,1328,1330,3, + 332,166,0,1329,1327,1,0,0,0,1329,1328,1,0,0,0,1330,131,1,0,0,0,1331,1332, + 5,88,0,0,1332,1334,5,183,0,0,1333,1331,1,0,0,0,1333,1334,1,0,0,0,1334, + 1335,1,0,0,0,1335,1336,5,98,0,0,1336,1337,5,183,0,0,1337,1342,3,332,166, + 0,1338,1339,5,183,0,0,1339,1340,5,87,0,0,1340,1341,5,183,0,0,1341,1343, + 5,168,0,0,1342,1338,1,0,0,0,1342,1343,1,0,0,0,1343,133,1,0,0,0,1344,1345, + 5,145,0,0,1345,1346,5,183,0,0,1346,1347,3,332,166,0,1347,135,1,0,0,0, + 1348,1349,5,146,0,0,1349,1350,5,183,0,0,1350,1351,3,332,166,0,1351,137, + 1,0,0,0,1352,1353,3,140,70,0,1353,139,1,0,0,0,1354,1361,3,144,72,0,1355, + 1357,5,183,0,0,1356,1355,1,0,0,0,1356,1357,1,0,0,0,1357,1358,1,0,0,0, + 1358,1360,3,142,71,0,1359,1356,1,0,0,0,1360,1363,1,0,0,0,1361,1359,1, + 0,0,0,1361,1362,1,0,0,0,1362,1376,1,0,0,0,1363,1361,1,0,0,0,1364,1366, + 3,186,93,0,1365,1367,5,183,0,0,1366,1365,1,0,0,0,1366,1367,1,0,0,0,1367, + 1369,1,0,0,0,1368,1364,1,0,0,0,1369,1370,1,0,0,0,1370,1368,1,0,0,0,1370, + 1371,1,0,0,0,1371,1372,1,0,0,0,1372,1373,3,144,72,0,1373,1374,6,70,-1, + 0,1374,1376,1,0,0,0,1375,1354,1,0,0,0,1375,1368,1,0,0,0,1376,141,1,0, + 0,0,1377,1378,5,143,0,0,1378,1379,5,183,0,0,1379,1381,5,48,0,0,1380,1382, + 5,183,0,0,1381,1380,1,0,0,0,1381,1382,1,0,0,0,1382,1383,1,0,0,0,1383, + 1390,3,144,72,0,1384,1386,5,143,0,0,1385,1387,5,183,0,0,1386,1385,1,0, + 0,0,1386,1387,1,0,0,0,1387,1388,1,0,0,0,1388,1390,3,144,72,0,1389,1377, + 1,0,0,0,1389,1384,1,0,0,0,1390,143,1,0,0,0,1391,1394,3,146,73,0,1392, + 1394,3,148,74,0,1393,1391,1,0,0,0,1393,1392,1,0,0,0,1394,145,1,0,0,0, + 1395,1397,3,154,77,0,1396,1398,5,183,0,0,1397,1396,1,0,0,0,1397,1398, + 1,0,0,0,1398,1400,1,0,0,0,1399,1395,1,0,0,0,1400,1403,1,0,0,0,1401,1399, + 1,0,0,0,1401,1402,1,0,0,0,1402,1404,1,0,0,0,1403,1401,1,0,0,0,1404,1431, + 3,186,93,0,1405,1407,3,154,77,0,1406,1408,5,183,0,0,1407,1406,1,0,0,0, + 1407,1408,1,0,0,0,1408,1410,1,0,0,0,1409,1405,1,0,0,0,1410,1413,1,0,0, + 0,1411,1409,1,0,0,0,1411,1412,1,0,0,0,1412,1414,1,0,0,0,1413,1411,1,0, + 0,0,1414,1421,3,152,76,0,1415,1417,5,183,0,0,1416,1415,1,0,0,0,1416,1417, + 1,0,0,0,1417,1418,1,0,0,0,1418,1420,3,152,76,0,1419,1416,1,0,0,0,1420, + 1423,1,0,0,0,1421,1419,1,0,0,0,1421,1422,1,0,0,0,1422,1428,1,0,0,0,1423, + 1421,1,0,0,0,1424,1426,5,183,0,0,1425,1424,1,0,0,0,1425,1426,1,0,0,0, + 1426,1427,1,0,0,0,1427,1429,3,186,93,0,1428,1425,1,0,0,0,1428,1429,1, + 0,0,0,1429,1431,1,0,0,0,1430,1401,1,0,0,0,1430,1411,1,0,0,0,1431,147, + 1,0,0,0,1432,1434,3,150,75,0,1433,1435,5,183,0,0,1434,1433,1,0,0,0,1434, + 
1435,1,0,0,0,1435,1437,1,0,0,0,1436,1432,1,0,0,0,1437,1438,1,0,0,0,1438, + 1436,1,0,0,0,1438,1439,1,0,0,0,1439,1440,1,0,0,0,1440,1441,3,146,73,0, + 1441,149,1,0,0,0,1442,1444,3,154,77,0,1443,1445,5,183,0,0,1444,1443,1, + 0,0,0,1444,1445,1,0,0,0,1445,1447,1,0,0,0,1446,1442,1,0,0,0,1447,1450, + 1,0,0,0,1448,1446,1,0,0,0,1448,1449,1,0,0,0,1449,1457,1,0,0,0,1450,1448, + 1,0,0,0,1451,1453,3,152,76,0,1452,1454,5,183,0,0,1453,1452,1,0,0,0,1453, + 1454,1,0,0,0,1454,1456,1,0,0,0,1455,1451,1,0,0,0,1456,1459,1,0,0,0,1457, + 1455,1,0,0,0,1457,1458,1,0,0,0,1458,1460,1,0,0,0,1459,1457,1,0,0,0,1460, + 1461,3,184,92,0,1461,151,1,0,0,0,1462,1467,3,172,86,0,1463,1467,3,174, + 87,0,1464,1467,3,178,89,0,1465,1467,3,182,91,0,1466,1462,1,0,0,0,1466, + 1463,1,0,0,0,1466,1464,1,0,0,0,1466,1465,1,0,0,0,1467,153,1,0,0,0,1468, + 1473,3,164,82,0,1469,1473,3,170,85,0,1470,1473,3,162,81,0,1471,1473,3, + 156,78,0,1472,1468,1,0,0,0,1472,1469,1,0,0,0,1472,1470,1,0,0,0,1472,1471, + 1,0,0,0,1473,155,1,0,0,0,1474,1492,5,103,0,0,1475,1476,5,183,0,0,1476, + 1477,5,150,0,0,1477,1478,5,183,0,0,1478,1480,5,92,0,0,1479,1481,5,183, + 0,0,1480,1479,1,0,0,0,1480,1481,1,0,0,0,1481,1482,1,0,0,0,1482,1484,5, + 2,0,0,1483,1485,5,183,0,0,1484,1483,1,0,0,0,1484,1485,1,0,0,0,1485,1486, + 1,0,0,0,1486,1488,3,96,48,0,1487,1489,5,183,0,0,1488,1487,1,0,0,0,1488, + 1489,1,0,0,0,1489,1490,1,0,0,0,1490,1491,5,3,0,0,1491,1493,1,0,0,0,1492, + 1475,1,0,0,0,1492,1493,1,0,0,0,1493,1494,1,0,0,0,1494,1495,5,183,0,0, + 1495,1496,5,87,0,0,1496,1497,5,183,0,0,1497,1511,3,10,5,0,1498,1500,5, + 183,0,0,1499,1498,1,0,0,0,1499,1500,1,0,0,0,1500,1501,1,0,0,0,1501,1503, + 5,2,0,0,1502,1504,5,183,0,0,1503,1502,1,0,0,0,1503,1504,1,0,0,0,1504, + 1505,1,0,0,0,1505,1507,3,24,12,0,1506,1508,5,183,0,0,1507,1506,1,0,0, + 0,1507,1508,1,0,0,0,1508,1509,1,0,0,0,1509,1510,5,3,0,0,1510,1512,1,0, + 0,0,1511,1499,1,0,0,0,1511,1512,1,0,0,0,1512,1517,1,0,0,0,1513,1515,5, + 183,0,0,1514,1513,1,0,0,0,1514,1515,1,0,0,0,1515,1516,1,0,0,0,1516,1518, + 3,202,101,0,1517,1514,1,0,0,0,1517,1518,1,0,0,0,1518,157,1,0,0,0,1519, + 1520,3,332,166,0,1520,1521,5,183,0,0,1521,1522,5,51,0,0,1522,1523,5,183, + 0,0,1523,1525,1,0,0,0,1524,1519,1,0,0,0,1524,1525,1,0,0,0,1525,1526,1, + 0,0,0,1526,1527,3,332,166,0,1527,159,1,0,0,0,1528,1539,3,158,79,0,1529, + 1531,5,183,0,0,1530,1529,1,0,0,0,1530,1531,1,0,0,0,1531,1532,1,0,0,0, + 1532,1534,5,4,0,0,1533,1535,5,183,0,0,1534,1533,1,0,0,0,1534,1535,1,0, + 0,0,1535,1536,1,0,0,0,1536,1538,3,158,79,0,1537,1530,1,0,0,0,1538,1541, + 1,0,0,0,1539,1537,1,0,0,0,1539,1540,1,0,0,0,1540,161,1,0,0,0,1541,1539, + 1,0,0,0,1542,1543,5,57,0,0,1543,1544,5,183,0,0,1544,1549,3,312,156,0, + 1545,1547,5,183,0,0,1546,1545,1,0,0,0,1546,1547,1,0,0,0,1547,1548,1,0, + 0,0,1548,1550,3,202,101,0,1549,1546,1,0,0,0,1549,1550,1,0,0,0,1550,1557, + 1,0,0,0,1551,1553,5,183,0,0,1552,1551,1,0,0,0,1552,1553,1,0,0,0,1553, + 1554,1,0,0,0,1554,1555,5,155,0,0,1555,1556,5,183,0,0,1556,1558,3,160, + 80,0,1557,1552,1,0,0,0,1557,1558,1,0,0,0,1558,163,1,0,0,0,1559,1560,5, + 118,0,0,1560,1562,5,183,0,0,1561,1559,1,0,0,0,1561,1562,1,0,0,0,1562, + 1563,1,0,0,0,1563,1565,5,106,0,0,1564,1566,5,183,0,0,1565,1564,1,0,0, + 0,1565,1566,1,0,0,0,1566,1567,1,0,0,0,1567,1570,3,204,102,0,1568,1569, + 5,183,0,0,1569,1571,3,202,101,0,1570,1568,1,0,0,0,1570,1571,1,0,0,0,1571, + 1574,1,0,0,0,1572,1573,5,183,0,0,1573,1575,3,166,83,0,1574,1572,1,0,0, + 0,1574,1575,1,0,0,0,1575,165,1,0,0,0,1576,1577,5,93,0,0,1577,1578,5,183, + 0,0,1578,1579,3,168,84,0,1579,167,1,0,0,0,1580,1581,6,84,-1,0,1581,1583, + 
5,2,0,0,1582,1584,5,183,0,0,1583,1582,1,0,0,0,1583,1584,1,0,0,0,1584, + 1585,1,0,0,0,1585,1587,3,168,84,0,1586,1588,5,183,0,0,1587,1586,1,0,0, + 0,1587,1588,1,0,0,0,1588,1589,1,0,0,0,1589,1590,5,3,0,0,1590,1593,1,0, + 0,0,1591,1593,3,346,173,0,1592,1580,1,0,0,0,1592,1591,1,0,0,0,1593,1610, + 1,0,0,0,1594,1595,10,4,0,0,1595,1596,5,183,0,0,1596,1597,5,100,0,0,1597, + 1598,5,183,0,0,1598,1609,3,168,84,5,1599,1604,10,3,0,0,1600,1601,5,183, + 0,0,1601,1602,5,110,0,0,1602,1603,5,183,0,0,1603,1605,3,346,173,0,1604, + 1600,1,0,0,0,1605,1606,1,0,0,0,1606,1604,1,0,0,0,1606,1607,1,0,0,0,1607, + 1609,1,0,0,0,1608,1594,1,0,0,0,1608,1599,1,0,0,0,1609,1612,1,0,0,0,1610, + 1608,1,0,0,0,1610,1611,1,0,0,0,1611,169,1,0,0,0,1612,1610,1,0,0,0,1613, + 1615,5,144,0,0,1614,1616,5,183,0,0,1615,1614,1,0,0,0,1615,1616,1,0,0, + 0,1616,1617,1,0,0,0,1617,1618,3,244,122,0,1618,1619,5,183,0,0,1619,1620, + 5,51,0,0,1620,1621,5,183,0,0,1621,1622,3,332,166,0,1622,171,1,0,0,0,1623, + 1625,5,68,0,0,1624,1626,5,183,0,0,1625,1624,1,0,0,0,1625,1626,1,0,0,0, + 1626,1627,1,0,0,0,1627,1628,3,204,102,0,1628,173,1,0,0,0,1629,1631,5, + 108,0,0,1630,1632,5,183,0,0,1631,1630,1,0,0,0,1631,1632,1,0,0,0,1632, + 1633,1,0,0,0,1633,1638,3,204,102,0,1634,1635,5,183,0,0,1635,1637,3,176, + 88,0,1636,1634,1,0,0,0,1637,1640,1,0,0,0,1638,1636,1,0,0,0,1638,1639, + 1,0,0,0,1639,175,1,0,0,0,1640,1638,1,0,0,0,1641,1642,5,116,0,0,1642,1643, + 5,183,0,0,1643,1644,5,106,0,0,1644,1645,5,183,0,0,1645,1652,3,178,89, + 0,1646,1647,5,116,0,0,1647,1648,5,183,0,0,1648,1649,5,68,0,0,1649,1650, + 5,183,0,0,1650,1652,3,178,89,0,1651,1641,1,0,0,0,1651,1646,1,0,0,0,1652, + 177,1,0,0,0,1653,1655,5,131,0,0,1654,1656,5,183,0,0,1655,1654,1,0,0,0, + 1655,1656,1,0,0,0,1656,1657,1,0,0,0,1657,1668,3,180,90,0,1658,1660,5, + 183,0,0,1659,1658,1,0,0,0,1659,1660,1,0,0,0,1660,1661,1,0,0,0,1661,1663, + 5,4,0,0,1662,1664,5,183,0,0,1663,1662,1,0,0,0,1663,1664,1,0,0,0,1664, + 1665,1,0,0,0,1665,1667,3,180,90,0,1666,1659,1,0,0,0,1667,1670,1,0,0,0, + 1668,1666,1,0,0,0,1668,1669,1,0,0,0,1669,1686,1,0,0,0,1670,1668,1,0,0, + 0,1671,1673,5,131,0,0,1672,1674,5,183,0,0,1673,1672,1,0,0,0,1673,1674, + 1,0,0,0,1674,1675,1,0,0,0,1675,1677,3,290,145,0,1676,1678,5,183,0,0,1677, + 1676,1,0,0,0,1677,1678,1,0,0,0,1678,1679,1,0,0,0,1679,1681,5,6,0,0,1680, + 1682,5,183,0,0,1681,1680,1,0,0,0,1681,1682,1,0,0,0,1682,1683,1,0,0,0, + 1683,1684,3,220,110,0,1684,1686,1,0,0,0,1685,1653,1,0,0,0,1685,1671,1, + 0,0,0,1686,179,1,0,0,0,1687,1689,3,338,169,0,1688,1690,5,183,0,0,1689, + 1688,1,0,0,0,1689,1690,1,0,0,0,1690,1691,1,0,0,0,1691,1693,5,6,0,0,1692, + 1694,5,183,0,0,1693,1692,1,0,0,0,1693,1694,1,0,0,0,1694,1695,1,0,0,0, + 1695,1696,3,244,122,0,1696,181,1,0,0,0,1697,1698,5,76,0,0,1698,1700,5, + 183,0,0,1699,1697,1,0,0,0,1699,1700,1,0,0,0,1700,1701,1,0,0,0,1701,1703, + 5,73,0,0,1702,1704,5,183,0,0,1703,1702,1,0,0,0,1703,1704,1,0,0,0,1704, + 1705,1,0,0,0,1705,1716,3,244,122,0,1706,1708,5,183,0,0,1707,1706,1,0, + 0,0,1707,1708,1,0,0,0,1708,1709,1,0,0,0,1709,1711,5,4,0,0,1710,1712,5, + 183,0,0,1711,1710,1,0,0,0,1711,1712,1,0,0,0,1712,1713,1,0,0,0,1713,1715, + 3,244,122,0,1714,1707,1,0,0,0,1715,1718,1,0,0,0,1716,1714,1,0,0,0,1716, + 1717,1,0,0,0,1717,183,1,0,0,0,1718,1716,1,0,0,0,1719,1720,5,150,0,0,1720, + 1725,3,188,94,0,1721,1723,5,183,0,0,1722,1721,1,0,0,0,1722,1723,1,0,0, + 0,1723,1724,1,0,0,0,1724,1726,3,202,101,0,1725,1722,1,0,0,0,1725,1726, + 1,0,0,0,1726,185,1,0,0,0,1727,1728,5,127,0,0,1728,1729,3,188,94,0,1729, + 187,1,0,0,0,1730,1732,5,183,0,0,1731,1730,1,0,0,0,1731,1732,1,0,0,0,1732, + 
1733,1,0,0,0,1733,1735,5,77,0,0,1734,1731,1,0,0,0,1734,1735,1,0,0,0,1735, + 1736,1,0,0,0,1736,1737,5,183,0,0,1737,1740,3,190,95,0,1738,1739,5,183, + 0,0,1739,1741,3,194,97,0,1740,1738,1,0,0,0,1740,1741,1,0,0,0,1741,1744, + 1,0,0,0,1742,1743,5,183,0,0,1743,1745,3,196,98,0,1744,1742,1,0,0,0,1744, + 1745,1,0,0,0,1745,1748,1,0,0,0,1746,1747,5,183,0,0,1747,1749,3,198,99, + 0,1748,1746,1,0,0,0,1748,1749,1,0,0,0,1749,189,1,0,0,0,1750,1761,5,161, + 0,0,1751,1753,5,183,0,0,1752,1751,1,0,0,0,1752,1753,1,0,0,0,1753,1754, + 1,0,0,0,1754,1756,5,4,0,0,1755,1757,5,183,0,0,1756,1755,1,0,0,0,1756, + 1757,1,0,0,0,1757,1758,1,0,0,0,1758,1760,3,192,96,0,1759,1752,1,0,0,0, + 1760,1763,1,0,0,0,1761,1759,1,0,0,0,1761,1762,1,0,0,0,1762,1779,1,0,0, + 0,1763,1761,1,0,0,0,1764,1775,3,192,96,0,1765,1767,5,183,0,0,1766,1765, + 1,0,0,0,1766,1767,1,0,0,0,1767,1768,1,0,0,0,1768,1770,5,4,0,0,1769,1771, + 5,183,0,0,1770,1769,1,0,0,0,1770,1771,1,0,0,0,1771,1772,1,0,0,0,1772, + 1774,3,192,96,0,1773,1766,1,0,0,0,1774,1777,1,0,0,0,1775,1773,1,0,0,0, + 1775,1776,1,0,0,0,1776,1779,1,0,0,0,1777,1775,1,0,0,0,1778,1750,1,0,0, + 0,1778,1764,1,0,0,0,1779,191,1,0,0,0,1780,1781,3,244,122,0,1781,1782, + 5,183,0,0,1782,1783,5,51,0,0,1783,1784,5,183,0,0,1784,1785,3,332,166, + 0,1785,1788,1,0,0,0,1786,1788,3,244,122,0,1787,1780,1,0,0,0,1787,1786, + 1,0,0,0,1788,193,1,0,0,0,1789,1790,5,120,0,0,1790,1791,5,183,0,0,1791, + 1792,5,56,0,0,1792,1793,5,183,0,0,1793,1801,3,200,100,0,1794,1796,5,4, + 0,0,1795,1797,5,183,0,0,1796,1795,1,0,0,0,1796,1797,1,0,0,0,1797,1798, + 1,0,0,0,1798,1800,3,200,100,0,1799,1794,1,0,0,0,1800,1803,1,0,0,0,1801, + 1799,1,0,0,0,1801,1802,1,0,0,0,1802,195,1,0,0,0,1803,1801,1,0,0,0,1804, + 1805,5,162,0,0,1805,1806,5,183,0,0,1806,1807,3,244,122,0,1807,197,1,0, + 0,0,1808,1809,5,102,0,0,1809,1810,5,183,0,0,1810,1811,3,244,122,0,1811, + 199,1,0,0,0,1812,1817,3,244,122,0,1813,1815,5,183,0,0,1814,1813,1,0,0, + 0,1814,1815,1,0,0,0,1815,1816,1,0,0,0,1816,1818,7,1,0,0,1817,1814,1,0, + 0,0,1817,1818,1,0,0,0,1818,201,1,0,0,0,1819,1820,5,149,0,0,1820,1821, + 5,183,0,0,1821,1822,3,244,122,0,1822,203,1,0,0,0,1823,1834,3,206,103, + 0,1824,1826,5,183,0,0,1825,1824,1,0,0,0,1825,1826,1,0,0,0,1826,1827,1, + 0,0,0,1827,1829,5,4,0,0,1828,1830,5,183,0,0,1829,1828,1,0,0,0,1829,1830, + 1,0,0,0,1830,1831,1,0,0,0,1831,1833,3,206,103,0,1832,1825,1,0,0,0,1833, + 1836,1,0,0,0,1834,1832,1,0,0,0,1834,1835,1,0,0,0,1835,205,1,0,0,0,1836, + 1834,1,0,0,0,1837,1839,3,332,166,0,1838,1840,5,183,0,0,1839,1838,1,0, + 0,0,1839,1840,1,0,0,0,1840,1841,1,0,0,0,1841,1843,5,6,0,0,1842,1844,5, + 183,0,0,1843,1842,1,0,0,0,1843,1844,1,0,0,0,1844,1845,1,0,0,0,1845,1846, + 3,208,104,0,1846,1849,1,0,0,0,1847,1849,3,208,104,0,1848,1837,1,0,0,0, + 1848,1847,1,0,0,0,1849,207,1,0,0,0,1850,1851,3,210,105,0,1851,209,1,0, + 0,0,1852,1859,3,212,106,0,1853,1855,5,183,0,0,1854,1853,1,0,0,0,1854, + 1855,1,0,0,0,1855,1856,1,0,0,0,1856,1858,3,214,107,0,1857,1854,1,0,0, + 0,1858,1861,1,0,0,0,1859,1857,1,0,0,0,1859,1860,1,0,0,0,1860,1867,1,0, + 0,0,1861,1859,1,0,0,0,1862,1863,5,2,0,0,1863,1864,3,210,105,0,1864,1865, + 5,3,0,0,1865,1867,1,0,0,0,1866,1852,1,0,0,0,1866,1862,1,0,0,0,1867,211, + 1,0,0,0,1868,1870,5,2,0,0,1869,1871,5,183,0,0,1870,1869,1,0,0,0,1870, + 1871,1,0,0,0,1871,1876,1,0,0,0,1872,1874,3,332,166,0,1873,1875,5,183, + 0,0,1874,1873,1,0,0,0,1874,1875,1,0,0,0,1875,1877,1,0,0,0,1876,1872,1, + 0,0,0,1876,1877,1,0,0,0,1877,1882,1,0,0,0,1878,1880,3,224,112,0,1879, + 1881,5,183,0,0,1880,1879,1,0,0,0,1880,1881,1,0,0,0,1881,1883,1,0,0,0, + 
1882,1878,1,0,0,0,1882,1883,1,0,0,0,1883,1888,1,0,0,0,1884,1886,3,220, + 110,0,1885,1887,5,183,0,0,1886,1885,1,0,0,0,1886,1887,1,0,0,0,1887,1889, + 1,0,0,0,1888,1884,1,0,0,0,1888,1889,1,0,0,0,1889,1890,1,0,0,0,1890,1891, + 5,3,0,0,1891,213,1,0,0,0,1892,1894,3,216,108,0,1893,1895,5,183,0,0,1894, + 1893,1,0,0,0,1894,1895,1,0,0,0,1895,1896,1,0,0,0,1896,1897,3,212,106, + 0,1897,215,1,0,0,0,1898,1900,3,352,176,0,1899,1901,5,183,0,0,1900,1899, + 1,0,0,0,1900,1901,1,0,0,0,1901,1902,1,0,0,0,1902,1904,3,356,178,0,1903, + 1905,5,183,0,0,1904,1903,1,0,0,0,1904,1905,1,0,0,0,1905,1907,1,0,0,0, + 1906,1908,3,218,109,0,1907,1906,1,0,0,0,1907,1908,1,0,0,0,1908,1910,1, + 0,0,0,1909,1911,5,183,0,0,1910,1909,1,0,0,0,1910,1911,1,0,0,0,1911,1912, + 1,0,0,0,1912,1913,3,356,178,0,1913,1943,1,0,0,0,1914,1916,3,356,178,0, + 1915,1917,5,183,0,0,1916,1915,1,0,0,0,1916,1917,1,0,0,0,1917,1919,1,0, + 0,0,1918,1920,3,218,109,0,1919,1918,1,0,0,0,1919,1920,1,0,0,0,1920,1922, + 1,0,0,0,1921,1923,5,183,0,0,1922,1921,1,0,0,0,1922,1923,1,0,0,0,1923, + 1924,1,0,0,0,1924,1926,3,356,178,0,1925,1927,5,183,0,0,1926,1925,1,0, + 0,0,1926,1927,1,0,0,0,1927,1928,1,0,0,0,1928,1929,3,354,177,0,1929,1943, + 1,0,0,0,1930,1932,3,356,178,0,1931,1933,5,183,0,0,1932,1931,1,0,0,0,1932, + 1933,1,0,0,0,1933,1935,1,0,0,0,1934,1936,3,218,109,0,1935,1934,1,0,0, + 0,1935,1936,1,0,0,0,1936,1938,1,0,0,0,1937,1939,5,183,0,0,1938,1937,1, + 0,0,0,1938,1939,1,0,0,0,1939,1940,1,0,0,0,1940,1941,3,356,178,0,1941, + 1943,1,0,0,0,1942,1898,1,0,0,0,1942,1914,1,0,0,0,1942,1930,1,0,0,0,1943, + 217,1,0,0,0,1944,1946,5,7,0,0,1945,1947,5,183,0,0,1946,1945,1,0,0,0,1946, + 1947,1,0,0,0,1947,1952,1,0,0,0,1948,1950,3,332,166,0,1949,1951,5,183, + 0,0,1950,1949,1,0,0,0,1950,1951,1,0,0,0,1951,1953,1,0,0,0,1952,1948,1, + 0,0,0,1952,1953,1,0,0,0,1953,1958,1,0,0,0,1954,1956,3,222,111,0,1955, + 1957,5,183,0,0,1956,1955,1,0,0,0,1956,1957,1,0,0,0,1957,1959,1,0,0,0, + 1958,1954,1,0,0,0,1958,1959,1,0,0,0,1959,1964,1,0,0,0,1960,1962,3,226, + 113,0,1961,1963,5,183,0,0,1962,1961,1,0,0,0,1962,1963,1,0,0,0,1963,1965, + 1,0,0,0,1964,1960,1,0,0,0,1964,1965,1,0,0,0,1965,1970,1,0,0,0,1966,1968, + 3,220,110,0,1967,1969,5,183,0,0,1968,1967,1,0,0,0,1968,1969,1,0,0,0,1969, + 1971,1,0,0,0,1970,1966,1,0,0,0,1970,1971,1,0,0,0,1971,1972,1,0,0,0,1972, + 1973,5,8,0,0,1973,219,1,0,0,0,1974,1976,5,9,0,0,1975,1977,5,183,0,0,1976, + 1975,1,0,0,0,1976,1977,1,0,0,0,1977,2011,1,0,0,0,1978,1980,3,340,170, + 0,1979,1981,5,183,0,0,1980,1979,1,0,0,0,1980,1981,1,0,0,0,1981,1982,1, + 0,0,0,1982,1984,5,164,0,0,1983,1985,5,183,0,0,1984,1983,1,0,0,0,1984, + 1985,1,0,0,0,1985,1986,1,0,0,0,1986,1988,3,244,122,0,1987,1989,5,183, + 0,0,1988,1987,1,0,0,0,1988,1989,1,0,0,0,1989,2008,1,0,0,0,1990,1992,5, + 4,0,0,1991,1993,5,183,0,0,1992,1991,1,0,0,0,1992,1993,1,0,0,0,1993,1994, + 1,0,0,0,1994,1996,3,340,170,0,1995,1997,5,183,0,0,1996,1995,1,0,0,0,1996, + 1997,1,0,0,0,1997,1998,1,0,0,0,1998,2000,5,164,0,0,1999,2001,5,183,0, + 0,2000,1999,1,0,0,0,2000,2001,1,0,0,0,2001,2002,1,0,0,0,2002,2004,3,244, + 122,0,2003,2005,5,183,0,0,2004,2003,1,0,0,0,2004,2005,1,0,0,0,2005,2007, + 1,0,0,0,2006,1990,1,0,0,0,2007,2010,1,0,0,0,2008,2006,1,0,0,0,2008,2009, + 1,0,0,0,2009,2012,1,0,0,0,2010,2008,1,0,0,0,2011,1978,1,0,0,0,2011,2012, + 1,0,0,0,2012,2013,1,0,0,0,2013,2014,5,10,0,0,2014,221,1,0,0,0,2015,2017, + 5,164,0,0,2016,2018,5,183,0,0,2017,2016,1,0,0,0,2017,2018,1,0,0,0,2018, + 2019,1,0,0,0,2019,2033,3,242,121,0,2020,2022,5,183,0,0,2021,2020,1,0, + 0,0,2021,2022,1,0,0,0,2022,2023,1,0,0,0,2023,2025,5,11,0,0,2024,2026, + 
5,164,0,0,2025,2024,1,0,0,0,2025,2026,1,0,0,0,2026,2028,1,0,0,0,2027, + 2029,5,183,0,0,2028,2027,1,0,0,0,2028,2029,1,0,0,0,2029,2030,1,0,0,0, + 2030,2032,3,242,121,0,2031,2021,1,0,0,0,2032,2035,1,0,0,0,2033,2031,1, + 0,0,0,2033,2034,1,0,0,0,2034,223,1,0,0,0,2035,2033,1,0,0,0,2036,2038, + 5,164,0,0,2037,2039,5,183,0,0,2038,2037,1,0,0,0,2038,2039,1,0,0,0,2039, + 2040,1,0,0,0,2040,2057,3,240,120,0,2041,2043,5,183,0,0,2042,2041,1,0, + 0,0,2042,2043,1,0,0,0,2043,2049,1,0,0,0,2044,2046,5,11,0,0,2045,2047, + 5,164,0,0,2046,2045,1,0,0,0,2046,2047,1,0,0,0,2047,2050,1,0,0,0,2048, + 2050,5,164,0,0,2049,2044,1,0,0,0,2049,2048,1,0,0,0,2050,2052,1,0,0,0, + 2051,2053,5,183,0,0,2052,2051,1,0,0,0,2052,2053,1,0,0,0,2053,2054,1,0, + 0,0,2054,2056,3,240,120,0,2055,2042,1,0,0,0,2056,2059,1,0,0,0,2057,2055, + 1,0,0,0,2057,2058,1,0,0,0,2058,225,1,0,0,0,2059,2057,1,0,0,0,2060,2065, + 5,161,0,0,2061,2063,5,183,0,0,2062,2061,1,0,0,0,2062,2063,1,0,0,0,2063, + 2064,1,0,0,0,2064,2066,3,228,114,0,2065,2062,1,0,0,0,2065,2066,1,0,0, + 0,2066,2071,1,0,0,0,2067,2069,5,183,0,0,2068,2067,1,0,0,0,2068,2069,1, + 0,0,0,2069,2070,1,0,0,0,2070,2072,3,230,115,0,2071,2068,1,0,0,0,2071, + 2072,1,0,0,0,2072,2077,1,0,0,0,2073,2075,5,183,0,0,2074,2073,1,0,0,0, + 2074,2075,1,0,0,0,2075,2076,1,0,0,0,2076,2078,3,232,116,0,2077,2074,1, + 0,0,0,2077,2078,1,0,0,0,2078,227,1,0,0,0,2079,2080,5,48,0,0,2080,2082, + 5,183,0,0,2081,2079,1,0,0,0,2081,2082,1,0,0,0,2082,2083,1,0,0,0,2083, + 2085,5,152,0,0,2084,2086,5,183,0,0,2085,2084,1,0,0,0,2085,2086,1,0,0, + 0,2086,2087,1,0,0,0,2087,2089,5,2,0,0,2088,2090,5,183,0,0,2089,2088,1, + 0,0,0,2089,2090,1,0,0,0,2090,2091,1,0,0,0,2091,2093,3,340,170,0,2092, + 2094,5,183,0,0,2093,2092,1,0,0,0,2093,2094,1,0,0,0,2094,2095,1,0,0,0, + 2095,2096,5,3,0,0,2096,2104,1,0,0,0,2097,2104,5,132,0,0,2098,2099,5,48, + 0,0,2099,2100,5,183,0,0,2100,2104,5,132,0,0,2101,2104,5,139,0,0,2102, + 2104,5,45,0,0,2103,2081,1,0,0,0,2103,2097,1,0,0,0,2103,2098,1,0,0,0,2103, + 2101,1,0,0,0,2103,2102,1,0,0,0,2104,229,1,0,0,0,2105,2107,3,236,118,0, + 2106,2105,1,0,0,0,2106,2107,1,0,0,0,2107,2109,1,0,0,0,2108,2110,5,183, + 0,0,2109,2108,1,0,0,0,2109,2110,1,0,0,0,2110,2111,1,0,0,0,2111,2113,5, + 165,0,0,2112,2114,5,183,0,0,2113,2112,1,0,0,0,2113,2114,1,0,0,0,2114, + 2116,1,0,0,0,2115,2117,3,238,119,0,2116,2115,1,0,0,0,2116,2117,1,0,0, + 0,2117,2120,1,0,0,0,2118,2120,3,342,171,0,2119,2106,1,0,0,0,2119,2118, + 1,0,0,0,2120,231,1,0,0,0,2121,2123,5,2,0,0,2122,2124,5,183,0,0,2123,2122, + 1,0,0,0,2123,2124,1,0,0,0,2124,2125,1,0,0,0,2125,2127,3,332,166,0,2126, + 2128,5,183,0,0,2127,2126,1,0,0,0,2127,2128,1,0,0,0,2128,2129,1,0,0,0, + 2129,2131,5,4,0,0,2130,2132,5,183,0,0,2131,2130,1,0,0,0,2131,2132,1,0, + 0,0,2132,2133,1,0,0,0,2133,2145,3,332,166,0,2134,2136,5,183,0,0,2135, + 2134,1,0,0,0,2135,2136,1,0,0,0,2136,2137,1,0,0,0,2137,2139,5,11,0,0,2138, + 2140,5,183,0,0,2139,2138,1,0,0,0,2139,2140,1,0,0,0,2140,2141,1,0,0,0, + 2141,2143,3,202,101,0,2142,2144,5,183,0,0,2143,2142,1,0,0,0,2143,2144, + 1,0,0,0,2144,2146,1,0,0,0,2145,2135,1,0,0,0,2145,2146,1,0,0,0,2146,2166, + 1,0,0,0,2147,2149,5,183,0,0,2148,2147,1,0,0,0,2148,2149,1,0,0,0,2149, + 2150,1,0,0,0,2150,2152,5,11,0,0,2151,2153,5,183,0,0,2152,2151,1,0,0,0, + 2152,2153,1,0,0,0,2153,2154,1,0,0,0,2154,2156,3,234,117,0,2155,2157,5, + 183,0,0,2156,2155,1,0,0,0,2156,2157,1,0,0,0,2157,2158,1,0,0,0,2158,2160, + 5,4,0,0,2159,2161,5,183,0,0,2160,2159,1,0,0,0,2160,2161,1,0,0,0,2161, + 2162,1,0,0,0,2162,2164,3,234,117,0,2163,2165,5,183,0,0,2164,2163,1,0, + 
0,0,2164,2165,1,0,0,0,2165,2167,1,0,0,0,2166,2148,1,0,0,0,2166,2167,1, + 0,0,0,2167,2168,1,0,0,0,2168,2169,5,3,0,0,2169,233,1,0,0,0,2170,2172, + 5,9,0,0,2171,2173,5,183,0,0,2172,2171,1,0,0,0,2172,2173,1,0,0,0,2173, + 2175,1,0,0,0,2174,2176,3,190,95,0,2175,2174,1,0,0,0,2175,2176,1,0,0,0, + 2176,2178,1,0,0,0,2177,2179,5,183,0,0,2178,2177,1,0,0,0,2178,2179,1,0, + 0,0,2179,2180,1,0,0,0,2180,2181,5,10,0,0,2181,235,1,0,0,0,2182,2183,5, + 170,0,0,2183,237,1,0,0,0,2184,2185,5,170,0,0,2185,239,1,0,0,0,2186,2187, + 3,346,173,0,2187,241,1,0,0,0,2188,2189,3,346,173,0,2189,243,1,0,0,0,2190, + 2191,3,246,123,0,2191,245,1,0,0,0,2192,2199,3,248,124,0,2193,2194,5,183, + 0,0,2194,2195,5,119,0,0,2195,2196,5,183,0,0,2196,2198,3,248,124,0,2197, + 2193,1,0,0,0,2198,2201,1,0,0,0,2199,2197,1,0,0,0,2199,2200,1,0,0,0,2200, + 247,1,0,0,0,2201,2199,1,0,0,0,2202,2209,3,250,125,0,2203,2204,5,183,0, + 0,2204,2205,5,153,0,0,2205,2206,5,183,0,0,2206,2208,3,250,125,0,2207, + 2203,1,0,0,0,2208,2211,1,0,0,0,2209,2207,1,0,0,0,2209,2210,1,0,0,0,2210, + 249,1,0,0,0,2211,2209,1,0,0,0,2212,2219,3,252,126,0,2213,2214,5,183,0, + 0,2214,2215,5,50,0,0,2215,2216,5,183,0,0,2216,2218,3,252,126,0,2217,2213, + 1,0,0,0,2218,2221,1,0,0,0,2219,2217,1,0,0,0,2219,2220,1,0,0,0,2220,251, + 1,0,0,0,2221,2219,1,0,0,0,2222,2224,5,113,0,0,2223,2225,5,183,0,0,2224, + 2223,1,0,0,0,2224,2225,1,0,0,0,2225,2227,1,0,0,0,2226,2222,1,0,0,0,2227, + 2230,1,0,0,0,2228,2226,1,0,0,0,2228,2229,1,0,0,0,2229,2231,1,0,0,0,2230, + 2228,1,0,0,0,2231,2232,3,254,127,0,2232,253,1,0,0,0,2233,2243,3,258,129, + 0,2234,2236,5,183,0,0,2235,2234,1,0,0,0,2235,2236,1,0,0,0,2236,2237,1, + 0,0,0,2237,2239,3,256,128,0,2238,2240,5,183,0,0,2239,2238,1,0,0,0,2239, + 2240,1,0,0,0,2240,2241,1,0,0,0,2241,2242,3,258,129,0,2242,2244,1,0,0, + 0,2243,2235,1,0,0,0,2243,2244,1,0,0,0,2244,2282,1,0,0,0,2245,2247,3,258, + 129,0,2246,2248,5,183,0,0,2247,2246,1,0,0,0,2247,2248,1,0,0,0,2248,2249, + 1,0,0,0,2249,2251,5,163,0,0,2250,2252,5,183,0,0,2251,2250,1,0,0,0,2251, + 2252,1,0,0,0,2252,2253,1,0,0,0,2253,2254,3,258,129,0,2254,2255,1,0,0, + 0,2255,2256,6,127,-1,0,2256,2282,1,0,0,0,2257,2259,3,258,129,0,2258,2260, + 5,183,0,0,2259,2258,1,0,0,0,2259,2260,1,0,0,0,2260,2261,1,0,0,0,2261, + 2263,3,256,128,0,2262,2264,5,183,0,0,2263,2262,1,0,0,0,2263,2264,1,0, + 0,0,2264,2265,1,0,0,0,2265,2275,3,258,129,0,2266,2268,5,183,0,0,2267, + 2266,1,0,0,0,2267,2268,1,0,0,0,2268,2269,1,0,0,0,2269,2271,3,256,128, + 0,2270,2272,5,183,0,0,2271,2270,1,0,0,0,2271,2272,1,0,0,0,2272,2273,1, + 0,0,0,2273,2274,3,258,129,0,2274,2276,1,0,0,0,2275,2267,1,0,0,0,2276, + 2277,1,0,0,0,2277,2275,1,0,0,0,2277,2278,1,0,0,0,2278,2279,1,0,0,0,2279, + 2280,6,127,-1,0,2280,2282,1,0,0,0,2281,2233,1,0,0,0,2281,2245,1,0,0,0, + 2281,2257,1,0,0,0,2282,255,1,0,0,0,2283,2284,7,2,0,0,2284,257,1,0,0,0, + 2285,2296,3,260,130,0,2286,2288,5,183,0,0,2287,2286,1,0,0,0,2287,2288, + 1,0,0,0,2288,2289,1,0,0,0,2289,2291,5,11,0,0,2290,2292,5,183,0,0,2291, + 2290,1,0,0,0,2291,2292,1,0,0,0,2292,2293,1,0,0,0,2293,2295,3,260,130, + 0,2294,2287,1,0,0,0,2295,2298,1,0,0,0,2296,2294,1,0,0,0,2296,2297,1,0, + 0,0,2297,259,1,0,0,0,2298,2296,1,0,0,0,2299,2310,3,262,131,0,2300,2302, + 5,183,0,0,2301,2300,1,0,0,0,2301,2302,1,0,0,0,2302,2303,1,0,0,0,2303, + 2305,5,17,0,0,2304,2306,5,183,0,0,2305,2304,1,0,0,0,2305,2306,1,0,0,0, + 2306,2307,1,0,0,0,2307,2309,3,262,131,0,2308,2301,1,0,0,0,2309,2312,1, + 0,0,0,2310,2308,1,0,0,0,2310,2311,1,0,0,0,2311,261,1,0,0,0,2312,2310, + 1,0,0,0,2313,2325,3,266,133,0,2314,2316,5,183,0,0,2315,2314,1,0,0,0,2315, + 
2316,1,0,0,0,2316,2317,1,0,0,0,2317,2319,3,264,132,0,2318,2320,5,183, + 0,0,2319,2318,1,0,0,0,2319,2320,1,0,0,0,2320,2321,1,0,0,0,2321,2322,3, + 266,133,0,2322,2324,1,0,0,0,2323,2315,1,0,0,0,2324,2327,1,0,0,0,2325, + 2323,1,0,0,0,2325,2326,1,0,0,0,2326,263,1,0,0,0,2327,2325,1,0,0,0,2328, + 2329,7,3,0,0,2329,265,1,0,0,0,2330,2342,3,270,135,0,2331,2333,5,183,0, + 0,2332,2331,1,0,0,0,2332,2333,1,0,0,0,2333,2334,1,0,0,0,2334,2336,3,268, + 134,0,2335,2337,5,183,0,0,2336,2335,1,0,0,0,2336,2337,1,0,0,0,2337,2338, + 1,0,0,0,2338,2339,3,270,135,0,2339,2341,1,0,0,0,2340,2332,1,0,0,0,2341, + 2344,1,0,0,0,2342,2340,1,0,0,0,2342,2343,1,0,0,0,2343,267,1,0,0,0,2344, + 2342,1,0,0,0,2345,2346,7,4,0,0,2346,269,1,0,0,0,2347,2359,3,274,137,0, + 2348,2350,5,183,0,0,2349,2348,1,0,0,0,2349,2350,1,0,0,0,2350,2351,1,0, + 0,0,2351,2353,3,272,136,0,2352,2354,5,183,0,0,2353,2352,1,0,0,0,2353, + 2354,1,0,0,0,2354,2355,1,0,0,0,2355,2356,3,274,137,0,2356,2358,1,0,0, + 0,2357,2349,1,0,0,0,2358,2361,1,0,0,0,2359,2357,1,0,0,0,2359,2360,1,0, + 0,0,2360,271,1,0,0,0,2361,2359,1,0,0,0,2362,2363,7,5,0,0,2363,273,1,0, + 0,0,2364,2375,3,276,138,0,2365,2367,5,183,0,0,2366,2365,1,0,0,0,2366, + 2367,1,0,0,0,2367,2368,1,0,0,0,2368,2370,5,23,0,0,2369,2371,5,183,0,0, + 2370,2369,1,0,0,0,2370,2371,1,0,0,0,2371,2372,1,0,0,0,2372,2374,3,276, + 138,0,2373,2366,1,0,0,0,2374,2377,1,0,0,0,2375,2373,1,0,0,0,2375,2376, + 1,0,0,0,2376,275,1,0,0,0,2377,2375,1,0,0,0,2378,2386,3,286,143,0,2379, + 2387,3,280,140,0,2380,2382,3,278,139,0,2381,2380,1,0,0,0,2382,2383,1, + 0,0,0,2383,2381,1,0,0,0,2383,2384,1,0,0,0,2384,2387,1,0,0,0,2385,2387, + 3,284,142,0,2386,2379,1,0,0,0,2386,2381,1,0,0,0,2386,2385,1,0,0,0,2386, + 2387,1,0,0,0,2387,277,1,0,0,0,2388,2389,5,183,0,0,2389,2391,5,96,0,0, + 2390,2392,5,183,0,0,2391,2390,1,0,0,0,2391,2392,1,0,0,0,2392,2393,1,0, + 0,0,2393,2408,3,288,144,0,2394,2395,5,7,0,0,2395,2396,3,244,122,0,2396, + 2397,5,8,0,0,2397,2408,1,0,0,0,2398,2400,5,7,0,0,2399,2401,3,244,122, + 0,2400,2399,1,0,0,0,2400,2401,1,0,0,0,2401,2402,1,0,0,0,2402,2404,7,6, + 0,0,2403,2405,3,244,122,0,2404,2403,1,0,0,0,2404,2405,1,0,0,0,2405,2406, + 1,0,0,0,2406,2408,5,8,0,0,2407,2388,1,0,0,0,2407,2394,1,0,0,0,2407,2398, + 1,0,0,0,2408,279,1,0,0,0,2409,2421,3,282,141,0,2410,2411,5,183,0,0,2411, + 2412,5,134,0,0,2412,2413,5,183,0,0,2413,2421,5,150,0,0,2414,2415,5,183, + 0,0,2415,2416,5,81,0,0,2416,2417,5,183,0,0,2417,2421,5,150,0,0,2418,2419, + 5,183,0,0,2419,2421,5,65,0,0,2420,2409,1,0,0,0,2420,2410,1,0,0,0,2420, + 2414,1,0,0,0,2420,2418,1,0,0,0,2421,2423,1,0,0,0,2422,2424,5,183,0,0, + 2423,2422,1,0,0,0,2423,2424,1,0,0,0,2424,2425,1,0,0,0,2425,2426,3,288, + 144,0,2426,281,1,0,0,0,2427,2429,5,183,0,0,2428,2427,1,0,0,0,2428,2429, + 1,0,0,0,2429,2430,1,0,0,0,2430,2431,5,24,0,0,2431,283,1,0,0,0,2432,2433, + 5,183,0,0,2433,2434,5,99,0,0,2434,2435,5,183,0,0,2435,2443,5,115,0,0, + 2436,2437,5,183,0,0,2437,2438,5,99,0,0,2438,2439,5,183,0,0,2439,2440, + 5,113,0,0,2440,2441,5,183,0,0,2441,2443,5,115,0,0,2442,2432,1,0,0,0,2442, + 2436,1,0,0,0,2443,285,1,0,0,0,2444,2446,5,166,0,0,2445,2447,5,183,0,0, + 2446,2445,1,0,0,0,2446,2447,1,0,0,0,2447,2449,1,0,0,0,2448,2444,1,0,0, + 0,2449,2452,1,0,0,0,2450,2448,1,0,0,0,2450,2451,1,0,0,0,2451,2453,1,0, + 0,0,2452,2450,1,0,0,0,2453,2458,3,288,144,0,2454,2456,5,183,0,0,2455, + 2454,1,0,0,0,2455,2456,1,0,0,0,2456,2457,1,0,0,0,2457,2459,5,167,0,0, + 2458,2455,1,0,0,0,2458,2459,1,0,0,0,2459,287,1,0,0,0,2460,2467,3,290, + 145,0,2461,2463,5,183,0,0,2462,2461,1,0,0,0,2462,2463,1,0,0,0,2463,2464, + 
1,0,0,0,2464,2466,3,326,163,0,2465,2462,1,0,0,0,2466,2469,1,0,0,0,2467, + 2465,1,0,0,0,2467,2468,1,0,0,0,2468,289,1,0,0,0,2469,2467,1,0,0,0,2470, + 2480,3,298,149,0,2471,2480,3,336,168,0,2472,2480,3,328,164,0,2473,2480, + 3,310,155,0,2474,2480,3,312,156,0,2475,2480,3,322,161,0,2476,2480,3,324, + 162,0,2477,2480,3,332,166,0,2478,2480,3,292,146,0,2479,2470,1,0,0,0,2479, + 2471,1,0,0,0,2479,2472,1,0,0,0,2479,2473,1,0,0,0,2479,2474,1,0,0,0,2479, + 2475,1,0,0,0,2479,2476,1,0,0,0,2479,2477,1,0,0,0,2479,2478,1,0,0,0,2480, + 291,1,0,0,0,2481,2483,5,48,0,0,2482,2484,5,183,0,0,2483,2482,1,0,0,0, + 2483,2484,1,0,0,0,2484,2485,1,0,0,0,2485,2487,5,2,0,0,2486,2488,5,183, + 0,0,2487,2486,1,0,0,0,2487,2488,1,0,0,0,2488,2489,1,0,0,0,2489,2491,3, + 294,147,0,2490,2492,5,183,0,0,2491,2490,1,0,0,0,2491,2492,1,0,0,0,2492, + 2493,1,0,0,0,2493,2494,5,3,0,0,2494,2538,1,0,0,0,2495,2497,5,46,0,0,2496, + 2498,5,183,0,0,2497,2496,1,0,0,0,2497,2498,1,0,0,0,2498,2499,1,0,0,0, + 2499,2501,5,2,0,0,2500,2502,5,183,0,0,2501,2500,1,0,0,0,2501,2502,1,0, + 0,0,2502,2503,1,0,0,0,2503,2505,3,294,147,0,2504,2506,5,183,0,0,2505, + 2504,1,0,0,0,2505,2506,1,0,0,0,2506,2507,1,0,0,0,2507,2508,5,3,0,0,2508, + 2538,1,0,0,0,2509,2511,5,114,0,0,2510,2512,5,183,0,0,2511,2510,1,0,0, + 0,2511,2512,1,0,0,0,2512,2513,1,0,0,0,2513,2515,5,2,0,0,2514,2516,5,183, + 0,0,2515,2514,1,0,0,0,2515,2516,1,0,0,0,2516,2517,1,0,0,0,2517,2519,3, + 294,147,0,2518,2520,5,183,0,0,2519,2518,1,0,0,0,2519,2520,1,0,0,0,2520, + 2521,1,0,0,0,2521,2522,5,3,0,0,2522,2538,1,0,0,0,2523,2525,5,154,0,0, + 2524,2526,5,183,0,0,2525,2524,1,0,0,0,2525,2526,1,0,0,0,2526,2527,1,0, + 0,0,2527,2529,5,2,0,0,2528,2530,5,183,0,0,2529,2528,1,0,0,0,2529,2530, + 1,0,0,0,2530,2531,1,0,0,0,2531,2533,3,294,147,0,2532,2534,5,183,0,0,2533, + 2532,1,0,0,0,2533,2534,1,0,0,0,2534,2535,1,0,0,0,2535,2536,5,3,0,0,2536, + 2538,1,0,0,0,2537,2481,1,0,0,0,2537,2495,1,0,0,0,2537,2509,1,0,0,0,2537, + 2523,1,0,0,0,2538,293,1,0,0,0,2539,2540,3,296,148,0,2540,2541,5,183,0, + 0,2541,2542,3,202,101,0,2542,295,1,0,0,0,2543,2544,3,332,166,0,2544,2545, + 5,183,0,0,2545,2546,5,96,0,0,2546,2547,5,183,0,0,2547,2548,3,244,122, + 0,2548,297,1,0,0,0,2549,2556,3,334,167,0,2550,2556,5,168,0,0,2551,2556, + 3,300,150,0,2552,2556,5,115,0,0,2553,2556,3,302,151,0,2554,2556,3,306, + 153,0,2555,2549,1,0,0,0,2555,2550,1,0,0,0,2555,2551,1,0,0,0,2555,2552, + 1,0,0,0,2555,2553,1,0,0,0,2555,2554,1,0,0,0,2556,299,1,0,0,0,2557,2558, + 7,7,0,0,2558,301,1,0,0,0,2559,2561,5,7,0,0,2560,2562,5,183,0,0,2561,2560, + 1,0,0,0,2561,2562,1,0,0,0,2562,2576,1,0,0,0,2563,2565,3,244,122,0,2564, + 2566,5,183,0,0,2565,2564,1,0,0,0,2565,2566,1,0,0,0,2566,2573,1,0,0,0, + 2567,2569,3,304,152,0,2568,2570,5,183,0,0,2569,2568,1,0,0,0,2569,2570, + 1,0,0,0,2570,2572,1,0,0,0,2571,2567,1,0,0,0,2572,2575,1,0,0,0,2573,2571, + 1,0,0,0,2573,2574,1,0,0,0,2574,2577,1,0,0,0,2575,2573,1,0,0,0,2576,2563, + 1,0,0,0,2576,2577,1,0,0,0,2577,2578,1,0,0,0,2578,2579,5,8,0,0,2579,303, + 1,0,0,0,2580,2582,5,4,0,0,2581,2583,5,183,0,0,2582,2581,1,0,0,0,2582, + 2583,1,0,0,0,2583,2585,1,0,0,0,2584,2586,3,244,122,0,2585,2584,1,0,0, + 0,2585,2586,1,0,0,0,2586,305,1,0,0,0,2587,2589,5,9,0,0,2588,2590,5,183, + 0,0,2589,2588,1,0,0,0,2589,2590,1,0,0,0,2590,2591,1,0,0,0,2591,2593,3, + 308,154,0,2592,2594,5,183,0,0,2593,2592,1,0,0,0,2593,2594,1,0,0,0,2594, + 2605,1,0,0,0,2595,2597,5,4,0,0,2596,2598,5,183,0,0,2597,2596,1,0,0,0, + 2597,2598,1,0,0,0,2598,2599,1,0,0,0,2599,2601,3,308,154,0,2600,2602,5, + 183,0,0,2601,2600,1,0,0,0,2601,2602,1,0,0,0,2602,2604,1,0,0,0,2603,2595, + 
1,0,0,0,2604,2607,1,0,0,0,2605,2603,1,0,0,0,2605,2606,1,0,0,0,2606,2608, + 1,0,0,0,2607,2605,1,0,0,0,2608,2609,5,10,0,0,2609,307,1,0,0,0,2610,2613, + 3,348,174,0,2611,2613,5,168,0,0,2612,2610,1,0,0,0,2612,2611,1,0,0,0,2613, + 2615,1,0,0,0,2614,2616,5,183,0,0,2615,2614,1,0,0,0,2615,2616,1,0,0,0, + 2616,2617,1,0,0,0,2617,2619,5,164,0,0,2618,2620,5,183,0,0,2619,2618,1, + 0,0,0,2619,2620,1,0,0,0,2620,2621,1,0,0,0,2621,2622,3,244,122,0,2622, + 309,1,0,0,0,2623,2625,5,2,0,0,2624,2626,5,183,0,0,2625,2624,1,0,0,0,2625, + 2626,1,0,0,0,2626,2627,1,0,0,0,2627,2629,3,244,122,0,2628,2630,5,183, + 0,0,2629,2628,1,0,0,0,2629,2630,1,0,0,0,2630,2631,1,0,0,0,2631,2632,5, + 3,0,0,2632,311,1,0,0,0,2633,2635,5,67,0,0,2634,2636,5,183,0,0,2635,2634, + 1,0,0,0,2635,2636,1,0,0,0,2636,2637,1,0,0,0,2637,2639,5,2,0,0,2638,2640, + 5,183,0,0,2639,2638,1,0,0,0,2639,2640,1,0,0,0,2640,2641,1,0,0,0,2641, + 2643,5,161,0,0,2642,2644,5,183,0,0,2643,2642,1,0,0,0,2643,2644,1,0,0, + 0,2644,2645,1,0,0,0,2645,2711,5,3,0,0,2646,2648,5,59,0,0,2647,2649,5, + 183,0,0,2648,2647,1,0,0,0,2648,2649,1,0,0,0,2649,2650,1,0,0,0,2650,2652, + 5,2,0,0,2651,2653,5,183,0,0,2652,2651,1,0,0,0,2652,2653,1,0,0,0,2653, + 2654,1,0,0,0,2654,2656,3,316,158,0,2655,2657,5,183,0,0,2656,2655,1,0, + 0,0,2656,2657,1,0,0,0,2657,2668,1,0,0,0,2658,2660,5,51,0,0,2659,2661, + 5,183,0,0,2660,2659,1,0,0,0,2660,2661,1,0,0,0,2661,2662,1,0,0,0,2662, + 2669,3,114,57,0,2663,2665,5,4,0,0,2664,2666,5,183,0,0,2665,2664,1,0,0, + 0,2665,2666,1,0,0,0,2666,2667,1,0,0,0,2667,2669,3,316,158,0,2668,2658, + 1,0,0,0,2668,2663,1,0,0,0,2669,2671,1,0,0,0,2670,2672,5,183,0,0,2671, + 2670,1,0,0,0,2671,2672,1,0,0,0,2672,2673,1,0,0,0,2673,2674,5,3,0,0,2674, + 2711,1,0,0,0,2675,2677,3,314,157,0,2676,2678,5,183,0,0,2677,2676,1,0, + 0,0,2677,2678,1,0,0,0,2678,2679,1,0,0,0,2679,2681,5,2,0,0,2680,2682,5, + 183,0,0,2681,2680,1,0,0,0,2681,2682,1,0,0,0,2682,2687,1,0,0,0,2683,2685, + 5,77,0,0,2684,2686,5,183,0,0,2685,2684,1,0,0,0,2685,2686,1,0,0,0,2686, + 2688,1,0,0,0,2687,2683,1,0,0,0,2687,2688,1,0,0,0,2688,2706,1,0,0,0,2689, + 2691,3,316,158,0,2690,2692,5,183,0,0,2691,2690,1,0,0,0,2691,2692,1,0, + 0,0,2692,2703,1,0,0,0,2693,2695,5,4,0,0,2694,2696,5,183,0,0,2695,2694, + 1,0,0,0,2695,2696,1,0,0,0,2696,2697,1,0,0,0,2697,2699,3,316,158,0,2698, + 2700,5,183,0,0,2699,2698,1,0,0,0,2699,2700,1,0,0,0,2700,2702,1,0,0,0, + 2701,2693,1,0,0,0,2702,2705,1,0,0,0,2703,2701,1,0,0,0,2703,2704,1,0,0, + 0,2704,2707,1,0,0,0,2705,2703,1,0,0,0,2706,2689,1,0,0,0,2706,2707,1,0, + 0,0,2707,2708,1,0,0,0,2708,2709,5,3,0,0,2709,2711,1,0,0,0,2710,2633,1, + 0,0,0,2710,2646,1,0,0,0,2710,2675,1,0,0,0,2711,313,1,0,0,0,2712,2713, + 3,348,174,0,2713,315,1,0,0,0,2714,2716,3,348,174,0,2715,2717,5,183,0, + 0,2716,2715,1,0,0,0,2716,2717,1,0,0,0,2717,2718,1,0,0,0,2718,2719,5,164, + 0,0,2719,2721,5,6,0,0,2720,2722,5,183,0,0,2721,2720,1,0,0,0,2721,2722, + 1,0,0,0,2722,2724,1,0,0,0,2723,2714,1,0,0,0,2723,2724,1,0,0,0,2724,2725, + 1,0,0,0,2725,2728,3,244,122,0,2726,2728,3,318,159,0,2727,2723,1,0,0,0, + 2727,2726,1,0,0,0,2728,317,1,0,0,0,2729,2731,3,320,160,0,2730,2732,5, + 183,0,0,2731,2730,1,0,0,0,2731,2732,1,0,0,0,2732,2733,1,0,0,0,2733,2734, + 5,166,0,0,2734,2736,5,15,0,0,2735,2737,5,183,0,0,2736,2735,1,0,0,0,2736, + 2737,1,0,0,0,2737,2738,1,0,0,0,2738,2740,3,244,122,0,2739,2741,5,183, + 0,0,2740,2739,1,0,0,0,2740,2741,1,0,0,0,2741,319,1,0,0,0,2742,2767,3, + 348,174,0,2743,2745,5,2,0,0,2744,2746,5,183,0,0,2745,2744,1,0,0,0,2745, + 2746,1,0,0,0,2746,2747,1,0,0,0,2747,2749,3,348,174,0,2748,2750,5,183, + 
0,0,2749,2748,1,0,0,0,2749,2750,1,0,0,0,2750,2761,1,0,0,0,2751,2753,5, + 4,0,0,2752,2754,5,183,0,0,2753,2752,1,0,0,0,2753,2754,1,0,0,0,2754,2755, + 1,0,0,0,2755,2757,3,348,174,0,2756,2758,5,183,0,0,2757,2756,1,0,0,0,2757, + 2758,1,0,0,0,2758,2760,1,0,0,0,2759,2751,1,0,0,0,2760,2763,1,0,0,0,2761, + 2759,1,0,0,0,2761,2762,1,0,0,0,2762,2764,1,0,0,0,2763,2761,1,0,0,0,2764, + 2765,5,3,0,0,2765,2767,1,0,0,0,2766,2742,1,0,0,0,2766,2743,1,0,0,0,2767, + 321,1,0,0,0,2768,2773,3,212,106,0,2769,2771,5,183,0,0,2770,2769,1,0,0, + 0,2770,2771,1,0,0,0,2771,2772,1,0,0,0,2772,2774,3,214,107,0,2773,2770, + 1,0,0,0,2774,2775,1,0,0,0,2775,2773,1,0,0,0,2775,2776,1,0,0,0,2776,323, + 1,0,0,0,2777,2779,7,8,0,0,2778,2780,5,183,0,0,2779,2778,1,0,0,0,2779, + 2780,1,0,0,0,2780,2781,1,0,0,0,2781,2783,5,9,0,0,2782,2784,5,183,0,0, + 2783,2782,1,0,0,0,2783,2784,1,0,0,0,2784,2785,1,0,0,0,2785,2787,5,106, + 0,0,2786,2788,5,183,0,0,2787,2786,1,0,0,0,2787,2788,1,0,0,0,2788,2789, + 1,0,0,0,2789,2794,3,204,102,0,2790,2792,5,183,0,0,2791,2790,1,0,0,0,2791, + 2792,1,0,0,0,2792,2793,1,0,0,0,2793,2795,3,202,101,0,2794,2791,1,0,0, + 0,2794,2795,1,0,0,0,2795,2800,1,0,0,0,2796,2798,5,183,0,0,2797,2796,1, + 0,0,0,2797,2798,1,0,0,0,2798,2799,1,0,0,0,2799,2801,3,166,83,0,2800,2797, + 1,0,0,0,2800,2801,1,0,0,0,2801,2803,1,0,0,0,2802,2804,5,183,0,0,2803, + 2802,1,0,0,0,2803,2804,1,0,0,0,2804,2805,1,0,0,0,2805,2806,5,10,0,0,2806, + 325,1,0,0,0,2807,2809,5,5,0,0,2808,2810,5,183,0,0,2809,2808,1,0,0,0,2809, + 2810,1,0,0,0,2810,2813,1,0,0,0,2811,2814,3,340,170,0,2812,2814,5,161, + 0,0,2813,2811,1,0,0,0,2813,2812,1,0,0,0,2814,327,1,0,0,0,2815,2820,5, + 58,0,0,2816,2818,5,183,0,0,2817,2816,1,0,0,0,2817,2818,1,0,0,0,2818,2819, + 1,0,0,0,2819,2821,3,330,165,0,2820,2817,1,0,0,0,2821,2822,1,0,0,0,2822, + 2820,1,0,0,0,2822,2823,1,0,0,0,2823,2838,1,0,0,0,2824,2826,5,58,0,0,2825, + 2827,5,183,0,0,2826,2825,1,0,0,0,2826,2827,1,0,0,0,2827,2828,1,0,0,0, + 2828,2833,3,244,122,0,2829,2831,5,183,0,0,2830,2829,1,0,0,0,2830,2831, + 1,0,0,0,2831,2832,1,0,0,0,2832,2834,3,330,165,0,2833,2830,1,0,0,0,2834, + 2835,1,0,0,0,2835,2833,1,0,0,0,2835,2836,1,0,0,0,2836,2838,1,0,0,0,2837, + 2815,1,0,0,0,2837,2824,1,0,0,0,2838,2847,1,0,0,0,2839,2841,5,183,0,0, + 2840,2839,1,0,0,0,2840,2841,1,0,0,0,2841,2842,1,0,0,0,2842,2844,5,79, + 0,0,2843,2845,5,183,0,0,2844,2843,1,0,0,0,2844,2845,1,0,0,0,2845,2846, + 1,0,0,0,2846,2848,3,244,122,0,2847,2840,1,0,0,0,2847,2848,1,0,0,0,2848, + 2850,1,0,0,0,2849,2851,5,183,0,0,2850,2849,1,0,0,0,2850,2851,1,0,0,0, + 2851,2852,1,0,0,0,2852,2853,5,80,0,0,2853,329,1,0,0,0,2854,2856,5,148, + 0,0,2855,2857,5,183,0,0,2856,2855,1,0,0,0,2856,2857,1,0,0,0,2857,2858, + 1,0,0,0,2858,2860,3,244,122,0,2859,2861,5,183,0,0,2860,2859,1,0,0,0,2860, + 2861,1,0,0,0,2861,2862,1,0,0,0,2862,2864,5,137,0,0,2863,2865,5,183,0, + 0,2864,2863,1,0,0,0,2864,2865,1,0,0,0,2865,2866,1,0,0,0,2866,2867,3,244, + 122,0,2867,331,1,0,0,0,2868,2869,3,348,174,0,2869,333,1,0,0,0,2870,2873, + 3,344,172,0,2871,2873,3,342,171,0,2872,2870,1,0,0,0,2872,2871,1,0,0,0, + 2873,335,1,0,0,0,2874,2877,5,25,0,0,2875,2878,3,348,174,0,2876,2878,5, + 170,0,0,2877,2875,1,0,0,0,2877,2876,1,0,0,0,2878,337,1,0,0,0,2879,2881, + 3,290,145,0,2880,2882,5,183,0,0,2881,2880,1,0,0,0,2881,2882,1,0,0,0,2882, + 2883,1,0,0,0,2883,2884,3,326,163,0,2884,339,1,0,0,0,2885,2886,3,346,173, + 0,2886,341,1,0,0,0,2887,2888,5,170,0,0,2888,343,1,0,0,0,2889,2890,7,9, + 0,0,2890,345,1,0,0,0,2891,2892,3,348,174,0,2892,347,1,0,0,0,2893,2899, + 5,179,0,0,2894,2895,5,182,0,0,2895,2899,6,174,-1,0,2896,2899,5,171,0, + 
0,2897,2899,3,350,175,0,2898,2893,1,0,0,0,2898,2894,1,0,0,0,2898,2896, + 1,0,0,0,2898,2897,1,0,0,0,2899,349,1,0,0,0,2900,2901,7,10,0,0,2901,351, + 1,0,0,0,2902,2903,7,11,0,0,2903,353,1,0,0,0,2904,2905,7,12,0,0,2905,355, + 1,0,0,0,2906,2907,7,13,0,0,2907,357,1,0,0,0,495,360,364,369,373,378,381, + 385,388,412,418,425,429,433,437,440,444,448,452,457,461,463,470,474,483, + 488,498,502,506,511,524,528,536,540,544,548,556,560,564,568,583,588,594, + 598,601,604,610,614,619,622,627,631,635,640,655,659,666,686,690,693,696, + 699,702,706,711,715,725,729,734,739,744,750,754,758,763,770,774,778,781, + 798,802,806,810,814,817,820,828,837,841,846,850,854,858,862,864,868,872, + 874,882,887,891,895,899,903,907,912,930,937,950,957,973,977,986,994,997, + 1007,1010,1018,1021,1027,1030,1036,1051,1069,1076,1083,1094,1117,1126, + 1132,1136,1141,1150,1154,1159,1165,1171,1177,1181,1185,1191,1195,1199, + 1205,1209,1213,1219,1223,1227,1231,1235,1241,1245,1249,1253,1257,1267, + 1273,1280,1285,1291,1296,1313,1319,1325,1329,1333,1342,1356,1361,1366, + 1370,1375,1381,1386,1389,1393,1397,1401,1407,1411,1416,1421,1425,1428, + 1430,1434,1438,1444,1448,1453,1457,1466,1472,1480,1484,1488,1492,1499, + 1503,1507,1511,1514,1517,1524,1530,1534,1539,1546,1549,1552,1557,1561, + 1565,1570,1574,1583,1587,1592,1606,1608,1610,1615,1625,1631,1638,1651, + 1655,1659,1663,1668,1673,1677,1681,1685,1689,1693,1699,1703,1707,1711, + 1716,1722,1725,1731,1734,1740,1744,1748,1752,1756,1761,1766,1770,1775, + 1778,1787,1796,1801,1814,1817,1825,1829,1834,1839,1843,1848,1854,1859, + 1866,1870,1874,1876,1880,1882,1886,1888,1894,1900,1904,1907,1910,1916, + 1919,1922,1926,1932,1935,1938,1942,1946,1950,1952,1956,1958,1962,1964, + 1968,1970,1976,1980,1984,1988,1992,1996,2000,2004,2008,2011,2017,2021, + 2025,2028,2033,2038,2042,2046,2049,2052,2057,2062,2065,2068,2071,2074, + 2077,2081,2085,2089,2093,2103,2106,2109,2113,2116,2119,2123,2127,2131, + 2135,2139,2143,2145,2148,2152,2156,2160,2164,2166,2172,2175,2178,2199, + 2209,2219,2224,2228,2235,2239,2243,2247,2251,2259,2263,2267,2271,2277, + 2281,2287,2291,2296,2301,2305,2310,2315,2319,2325,2332,2336,2342,2349, + 2353,2359,2366,2370,2375,2383,2386,2391,2400,2404,2407,2420,2423,2428, + 2442,2446,2450,2455,2458,2462,2467,2479,2483,2487,2491,2497,2501,2505, + 2511,2515,2519,2525,2529,2533,2537,2555,2561,2565,2569,2573,2576,2582, + 2585,2589,2593,2597,2601,2605,2612,2615,2619,2625,2629,2635,2639,2643, + 2648,2652,2656,2660,2665,2668,2671,2677,2681,2685,2687,2691,2695,2699, + 2703,2706,2710,2716,2721,2723,2727,2731,2736,2740,2745,2749,2753,2757, + 2761,2766,2770,2775,2779,2783,2787,2791,2794,2797,2800,2803,2809,2813, + 2817,2822,2826,2830,2835,2837,2840,2844,2847,2850,2856,2860,2864,2872, + 2877,2881,2898 + }; + staticData->serializedATN = antlr4::atn::SerializedATNView(serializedATNSegment, sizeof(serializedATNSegment) / sizeof(serializedATNSegment[0])); + + antlr4::atn::ATNDeserializer deserializer; + staticData->atn = deserializer.deserialize(staticData->serializedATN); + + const size_t count = staticData->atn->getNumberOfDecisions(); + staticData->decisionToDFA.reserve(count); + for (size_t i = 0; i < count; i++) { + staticData->decisionToDFA.emplace_back(staticData->atn->getDecisionState(i), i); + } + cypherParserStaticData = staticData.release(); +} + +} + +CypherParser::CypherParser(TokenStream *input) : CypherParser(input, antlr4::atn::ParserATNSimulatorOptions()) {} + +CypherParser::CypherParser(TokenStream *input, const antlr4::atn::ParserATNSimulatorOptions &options) : Parser(input) { + 
+  CypherParser::initialize();
+  _interpreter = new atn::ParserATNSimulator(this, *cypherParserStaticData->atn, cypherParserStaticData->decisionToDFA, cypherParserStaticData->sharedContextCache, options);
+}
+
+CypherParser::~CypherParser() {
+  delete _interpreter;
+}
+
+const atn::ATN& CypherParser::getATN() const {
+  return *cypherParserStaticData->atn;
+}
+
+std::string CypherParser::getGrammarFileName() const {
+  return "Cypher.g4";
+}
+
+const std::vector<std::string>& CypherParser::getRuleNames() const {
+  return cypherParserStaticData->ruleNames;
+}
+
+const dfa::Vocabulary& CypherParser::getVocabulary() const {
+  return cypherParserStaticData->vocabulary;
+}
+
+antlr4::atn::SerializedATNView CypherParser::getSerializedATN() const {
+  return cypherParserStaticData->serializedATN;
+}
+
+
+//----------------- Ku_StatementsContext ------------------------------------------------------------------
+
+CypherParser::Ku_StatementsContext::Ku_StatementsContext(ParserRuleContext *parent, size_t invokingState)
+  : ParserRuleContext(parent, invokingState) {
+}
+
+std::vector<CypherParser::OC_CypherContext *> CypherParser::Ku_StatementsContext::oC_Cypher() {
+  return getRuleContexts<CypherParser::OC_CypherContext>();
+}
+
+CypherParser::OC_CypherContext* CypherParser::Ku_StatementsContext::oC_Cypher(size_t i) {
+  return getRuleContext<CypherParser::OC_CypherContext>(i);
+}
+
+tree::TerminalNode* CypherParser::Ku_StatementsContext::EOF() {
+  return getToken(CypherParser::EOF, 0);
+}
+
+std::vector<tree::TerminalNode *> CypherParser::Ku_StatementsContext::SP() {
+  return getTokens(CypherParser::SP);
+}
+
+tree::TerminalNode* CypherParser::Ku_StatementsContext::SP(size_t i) {
+  return getToken(CypherParser::SP, i);
+}
+
+
+size_t CypherParser::Ku_StatementsContext::getRuleIndex() const {
+  return CypherParser::RuleKu_Statements;
+}
+
+
+CypherParser::Ku_StatementsContext* CypherParser::ku_Statements() {
+  Ku_StatementsContext *_localctx = _tracker.createInstance<Ku_StatementsContext>(_ctx, getState());
+  enterRule(_localctx, 0, CypherParser::RuleKu_Statements);
+  size_t _la = 0;
+
+#if __cplusplus > 201703L
+  auto onExit = finally([=, this] {
+#else
+  auto onExit = finally([=] {
+#endif
+    exitRule();
+  });
+  try {
+    size_t alt;
+    enterOuterAlt(_localctx, 1);
+    setState(358);
+    oC_Cypher();
+    setState(369);
+    _errHandler->sync(this);
+    alt = getInterpreter<atn::ParserATNSimulator>()->adaptivePredict(_input, 2, _ctx);
+    while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER) {
+      if (alt == 1) {
+        setState(360);
+        _errHandler->sync(this);
+
+        _la = _input->LA(1);
+        if (_la == CypherParser::SP) {
+          setState(359);
+          match(CypherParser::SP);
+        }
+        setState(362);
+        match(CypherParser::T__0);
+        setState(364);
+        _errHandler->sync(this);
+
+        switch (getInterpreter<atn::ParserATNSimulator>()->adaptivePredict(_input, 1, _ctx)) {
+        case 1: {
+          setState(363);
+          match(CypherParser::SP);
+          break;
+        }
+
+        default:
+          break;
+        }
+        setState(366);
+        oC_Cypher();
+      }
+      setState(371);
+      _errHandler->sync(this);
+      alt = getInterpreter<atn::ParserATNSimulator>()->adaptivePredict(_input, 2, _ctx);
+    }
+    setState(373);
+    _errHandler->sync(this);
+
+    _la = _input->LA(1);
+    if (_la == CypherParser::SP) {
+      setState(372);
+      match(CypherParser::SP);
+    }
+    setState(375);
+    match(CypherParser::EOF);
+
+  }
+  catch (RecognitionException &e) {
+    _errHandler->reportError(this, e);
+    _localctx->exception = std::current_exception();
+    _errHandler->recover(this, _localctx->exception);
+  }
+
+  return _localctx;
+}
+
+//----------------- OC_CypherContext ------------------------------------------------------------------
+
+CypherParser::OC_CypherContext::OC_CypherContext(ParserRuleContext *parent, size_t invokingState)
+  : ParserRuleContext(parent, invokingState) {
+}
+
+CypherParser::OC_StatementContext* CypherParser::OC_CypherContext::oC_Statement() {
+  return getRuleContext<CypherParser::OC_StatementContext>(0);
+}
+
+CypherParser::OC_AnyCypherOptionContext* CypherParser::OC_CypherContext::oC_AnyCypherOption() {
+  return getRuleContext<CypherParser::OC_AnyCypherOptionContext>(0);
+}
+
+std::vector<tree::TerminalNode *> CypherParser::OC_CypherContext::SP() {
+  return getTokens(CypherParser::SP);
+}
+
+tree::TerminalNode* CypherParser::OC_CypherContext::SP(size_t i) {
+  return getToken(CypherParser::SP, i);
+}
+
+
+size_t CypherParser::OC_CypherContext::getRuleIndex() const {
+  return CypherParser::RuleOC_Cypher;
+}
+
+
+CypherParser::OC_CypherContext* CypherParser::oC_Cypher() {
+  OC_CypherContext *_localctx = _tracker.createInstance<OC_CypherContext>(_ctx, getState());
+  enterRule(_localctx, 2, CypherParser::RuleOC_Cypher);
+  size_t _la = 0;
+
+#if __cplusplus > 201703L
+  auto onExit = finally([=, this] {
+#else
+  auto onExit = finally([=] {
+#endif
+    exitRule();
+  });
+  try {
+    enterOuterAlt(_localctx, 1);
+    setState(378);
+    _errHandler->sync(this);
+
+    _la = _input->LA(1);
+    if (_la == CypherParser::EXPLAIN
+
+    || _la == CypherParser::PROFILE) {
+      setState(377);
+      oC_AnyCypherOption();
+    }
+    setState(381);
+    _errHandler->sync(this);
+
+    _la = _input->LA(1);
+    if (_la == CypherParser::SP) {
+      setState(380);
+      match(CypherParser::SP);
+    }
+
+    setState(383);
+    oC_Statement();
+    setState(388);
+    _errHandler->sync(this);
+
+    switch (getInterpreter<atn::ParserATNSimulator>()->adaptivePredict(_input, 7, _ctx)) {
+    case 1: {
+      setState(385);
+      _errHandler->sync(this);
+
+      _la = _input->LA(1);
+      if (_la == CypherParser::SP) {
+        setState(384);
+        match(CypherParser::SP);
+      }
+      setState(387);
+      match(CypherParser::T__0);
+      break;
+    }
+
+    default:
+      break;
+    }
+
+  }
+  catch (RecognitionException &e) {
+    _errHandler->reportError(this, e);
+    _localctx->exception = std::current_exception();
+    _errHandler->recover(this, _localctx->exception);
+  }
+
+  return _localctx;
+}
+
+//----------------- OC_StatementContext ------------------------------------------------------------------
+
+CypherParser::OC_StatementContext::OC_StatementContext(ParserRuleContext *parent, size_t invokingState)
+  : ParserRuleContext(parent, invokingState) {
+}
+
+CypherParser::OC_QueryContext* CypherParser::OC_StatementContext::oC_Query() {
+  return getRuleContext<CypherParser::OC_QueryContext>(0);
+}
+
+CypherParser::KU_CreateUserContext* CypherParser::OC_StatementContext::kU_CreateUser() {
+  return getRuleContext<CypherParser::KU_CreateUserContext>(0);
+}
+
+CypherParser::KU_CreateRoleContext* CypherParser::OC_StatementContext::kU_CreateRole() {
+  return getRuleContext<CypherParser::KU_CreateRoleContext>(0);
+}
+
+CypherParser::KU_CreateNodeTableContext* CypherParser::OC_StatementContext::kU_CreateNodeTable() {
+  return getRuleContext<CypherParser::KU_CreateNodeTableContext>(0);
+}
+
+CypherParser::KU_CreateRelTableContext* CypherParser::OC_StatementContext::kU_CreateRelTable() {
+  return getRuleContext<CypherParser::KU_CreateRelTableContext>(0);
+}
+
+CypherParser::KU_CreateSequenceContext* CypherParser::OC_StatementContext::kU_CreateSequence() {
+  return getRuleContext<CypherParser::KU_CreateSequenceContext>(0);
+}
+
+CypherParser::KU_CreateTypeContext* CypherParser::OC_StatementContext::kU_CreateType() {
+  return getRuleContext<CypherParser::KU_CreateTypeContext>(0);
+}
+
+CypherParser::KU_DropContext* CypherParser::OC_StatementContext::kU_Drop() {
+  return getRuleContext<CypherParser::KU_DropContext>(0);
+}
+
+CypherParser::KU_AlterTableContext* CypherParser::OC_StatementContext::kU_AlterTable() {
+  return getRuleContext<CypherParser::KU_AlterTableContext>(0);
+}
+
+CypherParser::KU_CopyFromContext* CypherParser::OC_StatementContext::kU_CopyFrom() {
+  return getRuleContext<CypherParser::KU_CopyFromContext>(0);
+}
+
+CypherParser::KU_CopyFromByColumnContext* CypherParser::OC_StatementContext::kU_CopyFromByColumn() {
+  return getRuleContext<CypherParser::KU_CopyFromByColumnContext>(0);
+}
+
+CypherParser::KU_CopyTOContext* CypherParser::OC_StatementContext::kU_CopyTO() { + return getRuleContext(0); +} + +CypherParser::KU_StandaloneCallContext* CypherParser::OC_StatementContext::kU_StandaloneCall() { + return getRuleContext(0); +} + +CypherParser::KU_CreateMacroContext* CypherParser::OC_StatementContext::kU_CreateMacro() { + return getRuleContext(0); +} + +CypherParser::KU_CommentOnContext* CypherParser::OC_StatementContext::kU_CommentOn() { + return getRuleContext(0); +} + +CypherParser::KU_TransactionContext* CypherParser::OC_StatementContext::kU_Transaction() { + return getRuleContext(0); +} + +CypherParser::KU_ExtensionContext* CypherParser::OC_StatementContext::kU_Extension() { + return getRuleContext(0); +} + +CypherParser::KU_ExportDatabaseContext* CypherParser::OC_StatementContext::kU_ExportDatabase() { + return getRuleContext(0); +} + +CypherParser::KU_ImportDatabaseContext* CypherParser::OC_StatementContext::kU_ImportDatabase() { + return getRuleContext(0); +} + +CypherParser::KU_AttachDatabaseContext* CypherParser::OC_StatementContext::kU_AttachDatabase() { + return getRuleContext(0); +} + +CypherParser::KU_DetachDatabaseContext* CypherParser::OC_StatementContext::kU_DetachDatabase() { + return getRuleContext(0); +} + +CypherParser::KU_UseDatabaseContext* CypherParser::OC_StatementContext::kU_UseDatabase() { + return getRuleContext(0); +} + + +size_t CypherParser::OC_StatementContext::getRuleIndex() const { + return CypherParser::RuleOC_Statement; +} + + +CypherParser::OC_StatementContext* CypherParser::oC_Statement() { + OC_StatementContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 4, CypherParser::RuleOC_Statement); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + setState(412); + _errHandler->sync(this); + switch (getInterpreter()->adaptivePredict(_input, 8, _ctx)) { + case 1: { + enterOuterAlt(_localctx, 1); + setState(390); + oC_Query(); + break; + } + + case 2: { + enterOuterAlt(_localctx, 2); + setState(391); + kU_CreateUser(); + break; + } + + case 3: { + enterOuterAlt(_localctx, 3); + setState(392); + kU_CreateRole(); + break; + } + + case 4: { + enterOuterAlt(_localctx, 4); + setState(393); + kU_CreateNodeTable(); + break; + } + + case 5: { + enterOuterAlt(_localctx, 5); + setState(394); + kU_CreateRelTable(); + break; + } + + case 6: { + enterOuterAlt(_localctx, 6); + setState(395); + kU_CreateSequence(); + break; + } + + case 7: { + enterOuterAlt(_localctx, 7); + setState(396); + kU_CreateType(); + break; + } + + case 8: { + enterOuterAlt(_localctx, 8); + setState(397); + kU_Drop(); + break; + } + + case 9: { + enterOuterAlt(_localctx, 9); + setState(398); + kU_AlterTable(); + break; + } + + case 10: { + enterOuterAlt(_localctx, 10); + setState(399); + kU_CopyFrom(); + break; + } + + case 11: { + enterOuterAlt(_localctx, 11); + setState(400); + kU_CopyFromByColumn(); + break; + } + + case 12: { + enterOuterAlt(_localctx, 12); + setState(401); + kU_CopyTO(); + break; + } + + case 13: { + enterOuterAlt(_localctx, 13); + setState(402); + kU_StandaloneCall(); + break; + } + + case 14: { + enterOuterAlt(_localctx, 14); + setState(403); + kU_CreateMacro(); + break; + } + + case 15: { + enterOuterAlt(_localctx, 15); + setState(404); + kU_CommentOn(); + break; + } + + case 16: { + enterOuterAlt(_localctx, 16); + setState(405); + kU_Transaction(); + break; + } + + case 17: { + enterOuterAlt(_localctx, 17); + setState(406); + 
kU_Extension(); + break; + } + + case 18: { + enterOuterAlt(_localctx, 18); + setState(407); + kU_ExportDatabase(); + break; + } + + case 19: { + enterOuterAlt(_localctx, 19); + setState(408); + kU_ImportDatabase(); + break; + } + + case 20: { + enterOuterAlt(_localctx, 20); + setState(409); + kU_AttachDatabase(); + break; + } + + case 21: { + enterOuterAlt(_localctx, 21); + setState(410); + kU_DetachDatabase(); + break; + } + + case 22: { + enterOuterAlt(_localctx, 22); + setState(411); + kU_UseDatabase(); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_CopyFromContext ------------------------------------------------------------------ + +CypherParser::KU_CopyFromContext::KU_CopyFromContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_CopyFromContext::COPY() { + return getToken(CypherParser::COPY, 0); +} + +std::vector CypherParser::KU_CopyFromContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_CopyFromContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +CypherParser::OC_SchemaNameContext* CypherParser::KU_CopyFromContext::oC_SchemaName() { + return getRuleContext(0); +} + +tree::TerminalNode* CypherParser::KU_CopyFromContext::FROM() { + return getToken(CypherParser::FROM, 0); +} + +CypherParser::KU_ScanSourceContext* CypherParser::KU_CopyFromContext::kU_ScanSource() { + return getRuleContext(0); +} + +CypherParser::KU_ColumnNamesContext* CypherParser::KU_CopyFromContext::kU_ColumnNames() { + return getRuleContext(0); +} + +CypherParser::KU_OptionsContext* CypherParser::KU_CopyFromContext::kU_Options() { + return getRuleContext(0); +} + + +size_t CypherParser::KU_CopyFromContext::getRuleIndex() const { + return CypherParser::RuleKU_CopyFrom; +} + + +CypherParser::KU_CopyFromContext* CypherParser::kU_CopyFrom() { + KU_CopyFromContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 6, CypherParser::RuleKU_CopyFrom); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(414); + match(CypherParser::COPY); + setState(415); + match(CypherParser::SP); + setState(416); + oC_SchemaName(); + setState(418); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 9, _ctx)) { + case 1: { + setState(417); + kU_ColumnNames(); + break; + } + + default: + break; + } + setState(420); + match(CypherParser::SP); + setState(421); + match(CypherParser::FROM); + setState(422); + match(CypherParser::SP); + setState(423); + kU_ScanSource(); + setState(437); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 13, _ctx)) { + case 1: { + setState(425); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(424); + match(CypherParser::SP); + } + setState(427); + match(CypherParser::T__1); + setState(429); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(428); + match(CypherParser::SP); + } + setState(431); + kU_Options(); + setState(433); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + 
setState(432); + match(CypherParser::SP); + } + setState(435); + match(CypherParser::T__2); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_ColumnNamesContext ------------------------------------------------------------------ + +CypherParser::KU_ColumnNamesContext::KU_ColumnNamesContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +std::vector CypherParser::KU_ColumnNamesContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_ColumnNamesContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +std::vector CypherParser::KU_ColumnNamesContext::oC_SchemaName() { + return getRuleContexts(); +} + +CypherParser::OC_SchemaNameContext* CypherParser::KU_ColumnNamesContext::oC_SchemaName(size_t i) { + return getRuleContext(i); +} + + +size_t CypherParser::KU_ColumnNamesContext::getRuleIndex() const { + return CypherParser::RuleKU_ColumnNames; +} + + +CypherParser::KU_ColumnNamesContext* CypherParser::kU_ColumnNames() { + KU_ColumnNamesContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 8, CypherParser::RuleKU_ColumnNames); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + size_t alt; + enterOuterAlt(_localctx, 1); + setState(440); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(439); + match(CypherParser::SP); + } + setState(442); + match(CypherParser::T__1); + setState(444); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(443); + match(CypherParser::SP); + } + setState(463); + _errHandler->sync(this); + + _la = _input->LA(1); + if ((((_la & ~ 0x3fULL) == 0) && + ((1ULL << _la) & -3185593048922849280) != 0) || ((((_la - 65) & ~ 0x3fULL) == 0) && + ((1ULL << (_la - 65)) & -287985230644762313) != 0) || ((((_la - 130) & ~ 0x3fULL) == 0) && + ((1ULL << (_la - 130)) & 5068755015275819) != 0)) { + setState(446); + oC_SchemaName(); + setState(457); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 18, _ctx); + while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER) { + if (alt == 1) { + setState(448); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(447); + match(CypherParser::SP); + } + setState(450); + match(CypherParser::T__3); + setState(452); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(451); + match(CypherParser::SP); + } + setState(454); + oC_SchemaName(); + } + setState(459); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 18, _ctx); + } + setState(461); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(460); + match(CypherParser::SP); + } + } + setState(465); + match(CypherParser::T__2); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_ScanSourceContext ------------------------------------------------------------------ + 
+CypherParser::KU_ScanSourceContext::KU_ScanSourceContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::KU_FilePathsContext* CypherParser::KU_ScanSourceContext::kU_FilePaths() { + return getRuleContext(0); +} + +CypherParser::OC_QueryContext* CypherParser::KU_ScanSourceContext::oC_Query() { + return getRuleContext(0); +} + +std::vector CypherParser::KU_ScanSourceContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_ScanSourceContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +CypherParser::OC_ParameterContext* CypherParser::KU_ScanSourceContext::oC_Parameter() { + return getRuleContext(0); +} + +CypherParser::OC_VariableContext* CypherParser::KU_ScanSourceContext::oC_Variable() { + return getRuleContext(0); +} + +CypherParser::OC_SchemaNameContext* CypherParser::KU_ScanSourceContext::oC_SchemaName() { + return getRuleContext(0); +} + +CypherParser::OC_FunctionInvocationContext* CypherParser::KU_ScanSourceContext::oC_FunctionInvocation() { + return getRuleContext(0); +} + + +size_t CypherParser::KU_ScanSourceContext::getRuleIndex() const { + return CypherParser::RuleKU_ScanSource; +} + + +CypherParser::KU_ScanSourceContext* CypherParser::kU_ScanSource() { + KU_ScanSourceContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 10, CypherParser::RuleKU_ScanSource); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + setState(488); + _errHandler->sync(this); + switch (getInterpreter()->adaptivePredict(_input, 24, _ctx)) { + case 1: { + enterOuterAlt(_localctx, 1); + setState(467); + kU_FilePaths(); + break; + } + + case 2: { + enterOuterAlt(_localctx, 2); + setState(468); + match(CypherParser::T__1); + setState(470); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(469); + match(CypherParser::SP); + } + setState(472); + oC_Query(); + setState(474); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(473); + match(CypherParser::SP); + } + setState(476); + match(CypherParser::T__2); + break; + } + + case 3: { + enterOuterAlt(_localctx, 3); + setState(478); + oC_Parameter(); + break; + } + + case 4: { + enterOuterAlt(_localctx, 4); + setState(479); + oC_Variable(); + break; + } + + case 5: { + enterOuterAlt(_localctx, 5); + setState(480); + oC_Variable(); + setState(481); + match(CypherParser::T__4); + setState(483); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(482); + match(CypherParser::SP); + } + setState(485); + oC_SchemaName(); + break; + } + + case 6: { + enterOuterAlt(_localctx, 6); + setState(487); + oC_FunctionInvocation(); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_CopyFromByColumnContext ------------------------------------------------------------------ + +CypherParser::KU_CopyFromByColumnContext::KU_CopyFromByColumnContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_CopyFromByColumnContext::COPY() { + return getToken(CypherParser::COPY, 0); +} + +std::vector 
CypherParser::KU_CopyFromByColumnContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_CopyFromByColumnContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +CypherParser::OC_SchemaNameContext* CypherParser::KU_CopyFromByColumnContext::oC_SchemaName() { + return getRuleContext(0); +} + +tree::TerminalNode* CypherParser::KU_CopyFromByColumnContext::FROM() { + return getToken(CypherParser::FROM, 0); +} + +std::vector CypherParser::KU_CopyFromByColumnContext::StringLiteral() { + return getTokens(CypherParser::StringLiteral); +} + +tree::TerminalNode* CypherParser::KU_CopyFromByColumnContext::StringLiteral(size_t i) { + return getToken(CypherParser::StringLiteral, i); +} + +tree::TerminalNode* CypherParser::KU_CopyFromByColumnContext::BY() { + return getToken(CypherParser::BY, 0); +} + +tree::TerminalNode* CypherParser::KU_CopyFromByColumnContext::COLUMN() { + return getToken(CypherParser::COLUMN, 0); +} + + +size_t CypherParser::KU_CopyFromByColumnContext::getRuleIndex() const { + return CypherParser::RuleKU_CopyFromByColumn; +} + + +CypherParser::KU_CopyFromByColumnContext* CypherParser::kU_CopyFromByColumn() { + KU_CopyFromByColumnContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 12, CypherParser::RuleKU_CopyFromByColumn); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(490); + match(CypherParser::COPY); + setState(491); + match(CypherParser::SP); + setState(492); + oC_SchemaName(); + setState(493); + match(CypherParser::SP); + setState(494); + match(CypherParser::FROM); + setState(495); + match(CypherParser::SP); + setState(496); + match(CypherParser::T__1); + setState(498); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(497); + match(CypherParser::SP); + } + setState(500); + match(CypherParser::StringLiteral); + setState(511); + _errHandler->sync(this); + _la = _input->LA(1); + while (_la == CypherParser::T__3 || _la == CypherParser::SP) { + setState(502); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(501); + match(CypherParser::SP); + } + setState(504); + match(CypherParser::T__3); + setState(506); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(505); + match(CypherParser::SP); + } + setState(508); + match(CypherParser::StringLiteral); + setState(513); + _errHandler->sync(this); + _la = _input->LA(1); + } + setState(514); + match(CypherParser::T__2); + setState(515); + match(CypherParser::SP); + setState(516); + match(CypherParser::BY); + setState(517); + match(CypherParser::SP); + setState(518); + match(CypherParser::COLUMN); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_CopyTOContext ------------------------------------------------------------------ + +CypherParser::KU_CopyTOContext::KU_CopyTOContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_CopyTOContext::COPY() { + return getToken(CypherParser::COPY, 0); +} + +std::vector CypherParser::KU_CopyTOContext::SP() { + return getTokens(CypherParser::SP); +} + 
+tree::TerminalNode* CypherParser::KU_CopyTOContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +CypherParser::OC_QueryContext* CypherParser::KU_CopyTOContext::oC_Query() { + return getRuleContext(0); +} + +tree::TerminalNode* CypherParser::KU_CopyTOContext::TO() { + return getToken(CypherParser::TO, 0); +} + +tree::TerminalNode* CypherParser::KU_CopyTOContext::StringLiteral() { + return getToken(CypherParser::StringLiteral, 0); +} + +CypherParser::KU_OptionsContext* CypherParser::KU_CopyTOContext::kU_Options() { + return getRuleContext(0); +} + + +size_t CypherParser::KU_CopyTOContext::getRuleIndex() const { + return CypherParser::RuleKU_CopyTO; +} + + +CypherParser::KU_CopyTOContext* CypherParser::kU_CopyTO() { + KU_CopyTOContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 14, CypherParser::RuleKU_CopyTO); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(520); + match(CypherParser::COPY); + setState(521); + match(CypherParser::SP); + setState(522); + match(CypherParser::T__1); + setState(524); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(523); + match(CypherParser::SP); + } + setState(526); + oC_Query(); + setState(528); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(527); + match(CypherParser::SP); + } + setState(530); + match(CypherParser::T__2); + setState(531); + match(CypherParser::SP); + setState(532); + match(CypherParser::TO); + setState(533); + match(CypherParser::SP); + setState(534); + match(CypherParser::StringLiteral); + setState(548); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 34, _ctx)) { + case 1: { + setState(536); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(535); + match(CypherParser::SP); + } + setState(538); + match(CypherParser::T__1); + setState(540); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(539); + match(CypherParser::SP); + } + setState(542); + kU_Options(); + setState(544); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(543); + match(CypherParser::SP); + } + setState(546); + match(CypherParser::T__2); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_ExportDatabaseContext ------------------------------------------------------------------ + +CypherParser::KU_ExportDatabaseContext::KU_ExportDatabaseContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_ExportDatabaseContext::EXPORT() { + return getToken(CypherParser::EXPORT, 0); +} + +std::vector CypherParser::KU_ExportDatabaseContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_ExportDatabaseContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +tree::TerminalNode* CypherParser::KU_ExportDatabaseContext::DATABASE() { + return getToken(CypherParser::DATABASE, 0); +} + +tree::TerminalNode* CypherParser::KU_ExportDatabaseContext::StringLiteral() { + return 
getToken(CypherParser::StringLiteral, 0); +} + +CypherParser::KU_OptionsContext* CypherParser::KU_ExportDatabaseContext::kU_Options() { + return getRuleContext(0); +} + + +size_t CypherParser::KU_ExportDatabaseContext::getRuleIndex() const { + return CypherParser::RuleKU_ExportDatabase; +} + + +CypherParser::KU_ExportDatabaseContext* CypherParser::kU_ExportDatabase() { + KU_ExportDatabaseContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 16, CypherParser::RuleKU_ExportDatabase); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(550); + match(CypherParser::EXPORT); + setState(551); + match(CypherParser::SP); + setState(552); + match(CypherParser::DATABASE); + setState(553); + match(CypherParser::SP); + setState(554); + match(CypherParser::StringLiteral); + setState(568); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 38, _ctx)) { + case 1: { + setState(556); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(555); + match(CypherParser::SP); + } + setState(558); + match(CypherParser::T__1); + setState(560); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(559); + match(CypherParser::SP); + } + setState(562); + kU_Options(); + setState(564); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(563); + match(CypherParser::SP); + } + setState(566); + match(CypherParser::T__2); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_ImportDatabaseContext ------------------------------------------------------------------ + +CypherParser::KU_ImportDatabaseContext::KU_ImportDatabaseContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_ImportDatabaseContext::IMPORT() { + return getToken(CypherParser::IMPORT, 0); +} + +std::vector CypherParser::KU_ImportDatabaseContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_ImportDatabaseContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +tree::TerminalNode* CypherParser::KU_ImportDatabaseContext::DATABASE() { + return getToken(CypherParser::DATABASE, 0); +} + +tree::TerminalNode* CypherParser::KU_ImportDatabaseContext::StringLiteral() { + return getToken(CypherParser::StringLiteral, 0); +} + + +size_t CypherParser::KU_ImportDatabaseContext::getRuleIndex() const { + return CypherParser::RuleKU_ImportDatabase; +} + + +CypherParser::KU_ImportDatabaseContext* CypherParser::kU_ImportDatabase() { + KU_ImportDatabaseContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 18, CypherParser::RuleKU_ImportDatabase); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(570); + match(CypherParser::IMPORT); + setState(571); + match(CypherParser::SP); + setState(572); + match(CypherParser::DATABASE); + setState(573); + match(CypherParser::SP); + setState(574); + match(CypherParser::StringLiteral); + + } + 
catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_AttachDatabaseContext ------------------------------------------------------------------ + +CypherParser::KU_AttachDatabaseContext::KU_AttachDatabaseContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_AttachDatabaseContext::ATTACH() { + return getToken(CypherParser::ATTACH, 0); +} + +std::vector CypherParser::KU_AttachDatabaseContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_AttachDatabaseContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +tree::TerminalNode* CypherParser::KU_AttachDatabaseContext::StringLiteral() { + return getToken(CypherParser::StringLiteral, 0); +} + +tree::TerminalNode* CypherParser::KU_AttachDatabaseContext::DBTYPE() { + return getToken(CypherParser::DBTYPE, 0); +} + +CypherParser::OC_SymbolicNameContext* CypherParser::KU_AttachDatabaseContext::oC_SymbolicName() { + return getRuleContext(0); +} + +tree::TerminalNode* CypherParser::KU_AttachDatabaseContext::AS() { + return getToken(CypherParser::AS, 0); +} + +CypherParser::OC_SchemaNameContext* CypherParser::KU_AttachDatabaseContext::oC_SchemaName() { + return getRuleContext(0); +} + +CypherParser::KU_OptionsContext* CypherParser::KU_AttachDatabaseContext::kU_Options() { + return getRuleContext(0); +} + + +size_t CypherParser::KU_AttachDatabaseContext::getRuleIndex() const { + return CypherParser::RuleKU_AttachDatabase; +} + + +CypherParser::KU_AttachDatabaseContext* CypherParser::kU_AttachDatabase() { + KU_AttachDatabaseContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 20, CypherParser::RuleKU_AttachDatabase); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(576); + match(CypherParser::ATTACH); + setState(577); + match(CypherParser::SP); + setState(578); + match(CypherParser::StringLiteral); + setState(583); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 39, _ctx)) { + case 1: { + setState(579); + match(CypherParser::SP); + setState(580); + match(CypherParser::AS); + setState(581); + match(CypherParser::SP); + setState(582); + oC_SchemaName(); + break; + } + + default: + break; + } + setState(585); + match(CypherParser::SP); + setState(586); + match(CypherParser::T__1); + setState(588); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(587); + match(CypherParser::SP); + } + setState(590); + match(CypherParser::DBTYPE); + setState(591); + match(CypherParser::SP); + setState(592); + oC_SymbolicName(); + setState(601); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 43, _ctx)) { + case 1: { + setState(594); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(593); + match(CypherParser::SP); + } + setState(596); + match(CypherParser::T__3); + setState(598); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(597); + match(CypherParser::SP); + } + setState(600); + kU_Options(); + break; + } + + default: + break; + } + setState(604); + _errHandler->sync(this); + 
+ _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(603); + match(CypherParser::SP); + } + setState(606); + match(CypherParser::T__2); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_OptionContext ------------------------------------------------------------------ + +CypherParser::KU_OptionContext::KU_OptionContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_SymbolicNameContext* CypherParser::KU_OptionContext::oC_SymbolicName() { + return getRuleContext(0); +} + +CypherParser::OC_LiteralContext* CypherParser::KU_OptionContext::oC_Literal() { + return getRuleContext(0); +} + +std::vector CypherParser::KU_OptionContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_OptionContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + + +size_t CypherParser::KU_OptionContext::getRuleIndex() const { + return CypherParser::RuleKU_Option; +} + + +CypherParser::KU_OptionContext* CypherParser::kU_Option() { + KU_OptionContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 22, CypherParser::RuleKU_Option); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + setState(627); + _errHandler->sync(this); + switch (getInterpreter()->adaptivePredict(_input, 49, _ctx)) { + case 1: { + enterOuterAlt(_localctx, 1); + setState(608); + oC_SymbolicName(); + setState(622); + _errHandler->sync(this); + switch (getInterpreter()->adaptivePredict(_input, 48, _ctx)) { + case 1: { + setState(610); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(609); + match(CypherParser::SP); + } + setState(612); + match(CypherParser::T__5); + setState(614); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(613); + match(CypherParser::SP); + } + break; + } + + case 2: { + setState(619); + _errHandler->sync(this); + _la = _input->LA(1); + while (_la == CypherParser::SP) { + setState(616); + match(CypherParser::SP); + setState(621); + _errHandler->sync(this); + _la = _input->LA(1); + } + break; + } + + default: + break; + } + setState(624); + oC_Literal(); + break; + } + + case 2: { + enterOuterAlt(_localctx, 2); + setState(626); + oC_SymbolicName(); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_OptionsContext ------------------------------------------------------------------ + +CypherParser::KU_OptionsContext::KU_OptionsContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +std::vector CypherParser::KU_OptionsContext::kU_Option() { + return getRuleContexts(); +} + +CypherParser::KU_OptionContext* CypherParser::KU_OptionsContext::kU_Option(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::KU_OptionsContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_OptionsContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + + +size_t 
CypherParser::KU_OptionsContext::getRuleIndex() const { + return CypherParser::RuleKU_Options; +} + + +CypherParser::KU_OptionsContext* CypherParser::kU_Options() { + KU_OptionsContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 24, CypherParser::RuleKU_Options); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + size_t alt; + enterOuterAlt(_localctx, 1); + setState(629); + kU_Option(); + setState(640); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 52, _ctx); + while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER) { + if (alt == 1) { + setState(631); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(630); + match(CypherParser::SP); + } + setState(633); + match(CypherParser::T__3); + setState(635); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(634); + match(CypherParser::SP); + } + setState(637); + kU_Option(); + } + setState(642); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 52, _ctx); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_DetachDatabaseContext ------------------------------------------------------------------ + +CypherParser::KU_DetachDatabaseContext::KU_DetachDatabaseContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_DetachDatabaseContext::DETACH() { + return getToken(CypherParser::DETACH, 0); +} + +tree::TerminalNode* CypherParser::KU_DetachDatabaseContext::SP() { + return getToken(CypherParser::SP, 0); +} + +CypherParser::OC_SchemaNameContext* CypherParser::KU_DetachDatabaseContext::oC_SchemaName() { + return getRuleContext(0); +} + + +size_t CypherParser::KU_DetachDatabaseContext::getRuleIndex() const { + return CypherParser::RuleKU_DetachDatabase; +} + + +CypherParser::KU_DetachDatabaseContext* CypherParser::kU_DetachDatabase() { + KU_DetachDatabaseContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 26, CypherParser::RuleKU_DetachDatabase); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(643); + match(CypherParser::DETACH); + setState(644); + match(CypherParser::SP); + setState(645); + oC_SchemaName(); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_UseDatabaseContext ------------------------------------------------------------------ + +CypherParser::KU_UseDatabaseContext::KU_UseDatabaseContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_UseDatabaseContext::USE() { + return getToken(CypherParser::USE, 0); +} + +tree::TerminalNode* CypherParser::KU_UseDatabaseContext::SP() { + return getToken(CypherParser::SP, 0); +} + +CypherParser::OC_SchemaNameContext* CypherParser::KU_UseDatabaseContext::oC_SchemaName() { + return getRuleContext(0); +} + + +size_t 
CypherParser::KU_UseDatabaseContext::getRuleIndex() const { + return CypherParser::RuleKU_UseDatabase; +} + + +CypherParser::KU_UseDatabaseContext* CypherParser::kU_UseDatabase() { + KU_UseDatabaseContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 28, CypherParser::RuleKU_UseDatabase); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(647); + match(CypherParser::USE); + setState(648); + match(CypherParser::SP); + setState(649); + oC_SchemaName(); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_StandaloneCallContext ------------------------------------------------------------------ + +CypherParser::KU_StandaloneCallContext::KU_StandaloneCallContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_StandaloneCallContext::CALL() { + return getToken(CypherParser::CALL, 0); +} + +std::vector CypherParser::KU_StandaloneCallContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_StandaloneCallContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +CypherParser::OC_SymbolicNameContext* CypherParser::KU_StandaloneCallContext::oC_SymbolicName() { + return getRuleContext(0); +} + +CypherParser::OC_ExpressionContext* CypherParser::KU_StandaloneCallContext::oC_Expression() { + return getRuleContext(0); +} + +CypherParser::OC_FunctionInvocationContext* CypherParser::KU_StandaloneCallContext::oC_FunctionInvocation() { + return getRuleContext(0); +} + + +size_t CypherParser::KU_StandaloneCallContext::getRuleIndex() const { + return CypherParser::RuleKU_StandaloneCall; +} + + +CypherParser::KU_StandaloneCallContext* CypherParser::kU_StandaloneCall() { + KU_StandaloneCallContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 30, CypherParser::RuleKU_StandaloneCall); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + setState(666); + _errHandler->sync(this); + switch (getInterpreter()->adaptivePredict(_input, 55, _ctx)) { + case 1: { + enterOuterAlt(_localctx, 1); + setState(651); + match(CypherParser::CALL); + setState(652); + match(CypherParser::SP); + setState(653); + oC_SymbolicName(); + setState(655); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(654); + match(CypherParser::SP); + } + setState(657); + match(CypherParser::T__5); + setState(659); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(658); + match(CypherParser::SP); + } + setState(661); + oC_Expression(); + break; + } + + case 2: { + enterOuterAlt(_localctx, 2); + setState(663); + match(CypherParser::CALL); + setState(664); + match(CypherParser::SP); + setState(665); + oC_FunctionInvocation(); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_CommentOnContext 
------------------------------------------------------------------ + +CypherParser::KU_CommentOnContext::KU_CommentOnContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_CommentOnContext::COMMENT() { + return getToken(CypherParser::COMMENT, 0); +} + +std::vector CypherParser::KU_CommentOnContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_CommentOnContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +tree::TerminalNode* CypherParser::KU_CommentOnContext::ON() { + return getToken(CypherParser::ON, 0); +} + +tree::TerminalNode* CypherParser::KU_CommentOnContext::TABLE() { + return getToken(CypherParser::TABLE, 0); +} + +CypherParser::OC_SchemaNameContext* CypherParser::KU_CommentOnContext::oC_SchemaName() { + return getRuleContext(0); +} + +tree::TerminalNode* CypherParser::KU_CommentOnContext::IS() { + return getToken(CypherParser::IS, 0); +} + +tree::TerminalNode* CypherParser::KU_CommentOnContext::StringLiteral() { + return getToken(CypherParser::StringLiteral, 0); +} + + +size_t CypherParser::KU_CommentOnContext::getRuleIndex() const { + return CypherParser::RuleKU_CommentOn; +} + + +CypherParser::KU_CommentOnContext* CypherParser::kU_CommentOn() { + KU_CommentOnContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 32, CypherParser::RuleKU_CommentOn); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(668); + match(CypherParser::COMMENT); + setState(669); + match(CypherParser::SP); + setState(670); + match(CypherParser::ON); + setState(671); + match(CypherParser::SP); + setState(672); + match(CypherParser::TABLE); + setState(673); + match(CypherParser::SP); + setState(674); + oC_SchemaName(); + setState(675); + match(CypherParser::SP); + setState(676); + match(CypherParser::IS); + setState(677); + match(CypherParser::SP); + setState(678); + match(CypherParser::StringLiteral); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_CreateMacroContext ------------------------------------------------------------------ + +CypherParser::KU_CreateMacroContext::KU_CreateMacroContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_CreateMacroContext::CREATE() { + return getToken(CypherParser::CREATE, 0); +} + +std::vector CypherParser::KU_CreateMacroContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_CreateMacroContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +tree::TerminalNode* CypherParser::KU_CreateMacroContext::MACRO() { + return getToken(CypherParser::MACRO, 0); +} + +CypherParser::OC_FunctionNameContext* CypherParser::KU_CreateMacroContext::oC_FunctionName() { + return getRuleContext(0); +} + +tree::TerminalNode* CypherParser::KU_CreateMacroContext::AS() { + return getToken(CypherParser::AS, 0); +} + +CypherParser::OC_ExpressionContext* CypherParser::KU_CreateMacroContext::oC_Expression() { + return getRuleContext(0); +} + +CypherParser::KU_PositionalArgsContext* CypherParser::KU_CreateMacroContext::kU_PositionalArgs() { + return 
getRuleContext(0); +} + +std::vector CypherParser::KU_CreateMacroContext::kU_DefaultArg() { + return getRuleContexts(); +} + +CypherParser::KU_DefaultArgContext* CypherParser::KU_CreateMacroContext::kU_DefaultArg(size_t i) { + return getRuleContext(i); +} + + +size_t CypherParser::KU_CreateMacroContext::getRuleIndex() const { + return CypherParser::RuleKU_CreateMacro; +} + + +CypherParser::KU_CreateMacroContext* CypherParser::kU_CreateMacro() { + KU_CreateMacroContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 34, CypherParser::RuleKU_CreateMacro); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + size_t alt; + enterOuterAlt(_localctx, 1); + setState(680); + match(CypherParser::CREATE); + setState(681); + match(CypherParser::SP); + setState(682); + match(CypherParser::MACRO); + setState(683); + match(CypherParser::SP); + setState(684); + oC_FunctionName(); + setState(686); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(685); + match(CypherParser::SP); + } + setState(688); + match(CypherParser::T__1); + setState(690); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 57, _ctx)) { + case 1: { + setState(689); + match(CypherParser::SP); + break; + } + + default: + break; + } + setState(693); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 58, _ctx)) { + case 1: { + setState(692); + kU_PositionalArgs(); + break; + } + + default: + break; + } + setState(696); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 59, _ctx)) { + case 1: { + setState(695); + match(CypherParser::SP); + break; + } + + default: + break; + } + setState(699); + _errHandler->sync(this); + + _la = _input->LA(1); + if ((((_la & ~ 0x3fULL) == 0) && + ((1ULL << _la) & -3185593048922849280) != 0) || ((((_la - 65) & ~ 0x3fULL) == 0) && + ((1ULL << (_la - 65)) & -287985230644762313) != 0) || ((((_la - 130) & ~ 0x3fULL) == 0) && + ((1ULL << (_la - 130)) & 5068755015275819) != 0)) { + setState(698); + kU_DefaultArg(); + } + setState(711); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 63, _ctx); + while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER) { + if (alt == 1) { + setState(702); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(701); + match(CypherParser::SP); + } + setState(704); + match(CypherParser::T__3); + setState(706); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(705); + match(CypherParser::SP); + } + setState(708); + kU_DefaultArg(); + } + setState(713); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 63, _ctx); + } + setState(715); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(714); + match(CypherParser::SP); + } + setState(717); + match(CypherParser::T__2); + setState(718); + match(CypherParser::SP); + setState(719); + match(CypherParser::AS); + setState(720); + match(CypherParser::SP); + setState(721); + oC_Expression(); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_PositionalArgsContext 
------------------------------------------------------------------ + +CypherParser::KU_PositionalArgsContext::KU_PositionalArgsContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +std::vector CypherParser::KU_PositionalArgsContext::oC_SymbolicName() { + return getRuleContexts(); +} + +CypherParser::OC_SymbolicNameContext* CypherParser::KU_PositionalArgsContext::oC_SymbolicName(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::KU_PositionalArgsContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_PositionalArgsContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + + +size_t CypherParser::KU_PositionalArgsContext::getRuleIndex() const { + return CypherParser::RuleKU_PositionalArgs; +} + + +CypherParser::KU_PositionalArgsContext* CypherParser::kU_PositionalArgs() { + KU_PositionalArgsContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 36, CypherParser::RuleKU_PositionalArgs); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + size_t alt; + enterOuterAlt(_localctx, 1); + setState(723); + oC_SymbolicName(); + setState(734); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 67, _ctx); + while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER) { + if (alt == 1) { + setState(725); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(724); + match(CypherParser::SP); + } + setState(727); + match(CypherParser::T__3); + setState(729); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(728); + match(CypherParser::SP); + } + setState(731); + oC_SymbolicName(); + } + setState(736); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 67, _ctx); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_DefaultArgContext ------------------------------------------------------------------ + +CypherParser::KU_DefaultArgContext::KU_DefaultArgContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_SymbolicNameContext* CypherParser::KU_DefaultArgContext::oC_SymbolicName() { + return getRuleContext(0); +} + +tree::TerminalNode* CypherParser::KU_DefaultArgContext::COLON() { + return getToken(CypherParser::COLON, 0); +} + +CypherParser::OC_LiteralContext* CypherParser::KU_DefaultArgContext::oC_Literal() { + return getRuleContext(0); +} + +std::vector CypherParser::KU_DefaultArgContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_DefaultArgContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + + +size_t CypherParser::KU_DefaultArgContext::getRuleIndex() const { + return CypherParser::RuleKU_DefaultArg; +} + + +CypherParser::KU_DefaultArgContext* CypherParser::kU_DefaultArg() { + KU_DefaultArgContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 38, CypherParser::RuleKU_DefaultArg); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + 
setState(737); + oC_SymbolicName(); + setState(739); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(738); + match(CypherParser::SP); + } + setState(741); + match(CypherParser::COLON); + setState(742); + match(CypherParser::T__5); + setState(744); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(743); + match(CypherParser::SP); + } + setState(746); + oC_Literal(); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_FilePathsContext ------------------------------------------------------------------ + +CypherParser::KU_FilePathsContext::KU_FilePathsContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +std::vector CypherParser::KU_FilePathsContext::StringLiteral() { + return getTokens(CypherParser::StringLiteral); +} + +tree::TerminalNode* CypherParser::KU_FilePathsContext::StringLiteral(size_t i) { + return getToken(CypherParser::StringLiteral, i); +} + +std::vector CypherParser::KU_FilePathsContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_FilePathsContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +tree::TerminalNode* CypherParser::KU_FilePathsContext::GLOB() { + return getToken(CypherParser::GLOB, 0); +} + + +size_t CypherParser::KU_FilePathsContext::getRuleIndex() const { + return CypherParser::RuleKU_FilePaths; +} + + +CypherParser::KU_FilePathsContext* CypherParser::kU_FilePaths() { + KU_FilePathsContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 40, CypherParser::RuleKU_FilePaths); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + setState(781); + _errHandler->sync(this); + switch (_input->LA(1)) { + case CypherParser::T__6: { + enterOuterAlt(_localctx, 1); + setState(748); + match(CypherParser::T__6); + setState(750); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(749); + match(CypherParser::SP); + } + setState(752); + match(CypherParser::StringLiteral); + setState(763); + _errHandler->sync(this); + _la = _input->LA(1); + while (_la == CypherParser::T__3 || _la == CypherParser::SP) { + setState(754); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(753); + match(CypherParser::SP); + } + setState(756); + match(CypherParser::T__3); + setState(758); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(757); + match(CypherParser::SP); + } + setState(760); + match(CypherParser::StringLiteral); + setState(765); + _errHandler->sync(this); + _la = _input->LA(1); + } + setState(766); + match(CypherParser::T__7); + break; + } + + case CypherParser::StringLiteral: { + enterOuterAlt(_localctx, 2); + setState(767); + match(CypherParser::StringLiteral); + break; + } + + case CypherParser::GLOB: { + enterOuterAlt(_localctx, 3); + setState(768); + match(CypherParser::GLOB); + setState(770); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(769); + match(CypherParser::SP); + } + setState(772); + match(CypherParser::T__1); + setState(774); + _errHandler->sync(this); + + _la = 
_input->LA(1); + if (_la == CypherParser::SP) { + setState(773); + match(CypherParser::SP); + } + setState(776); + match(CypherParser::StringLiteral); + setState(778); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(777); + match(CypherParser::SP); + } + setState(780); + match(CypherParser::T__2); + break; + } + + default: + throw NoViableAltException(this); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_IfNotExistsContext ------------------------------------------------------------------ + +CypherParser::KU_IfNotExistsContext::KU_IfNotExistsContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_IfNotExistsContext::IF() { + return getToken(CypherParser::IF, 0); +} + +std::vector CypherParser::KU_IfNotExistsContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_IfNotExistsContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +tree::TerminalNode* CypherParser::KU_IfNotExistsContext::NOT() { + return getToken(CypherParser::NOT, 0); +} + +tree::TerminalNode* CypherParser::KU_IfNotExistsContext::EXISTS() { + return getToken(CypherParser::EXISTS, 0); +} + + +size_t CypherParser::KU_IfNotExistsContext::getRuleIndex() const { + return CypherParser::RuleKU_IfNotExists; +} + + +CypherParser::KU_IfNotExistsContext* CypherParser::kU_IfNotExists() { + KU_IfNotExistsContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 42, CypherParser::RuleKU_IfNotExists); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(783); + match(CypherParser::IF); + setState(784); + match(CypherParser::SP); + setState(785); + match(CypherParser::NOT); + setState(786); + match(CypherParser::SP); + setState(787); + match(CypherParser::EXISTS); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_CreateNodeTableContext ------------------------------------------------------------------ + +CypherParser::KU_CreateNodeTableContext::KU_CreateNodeTableContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_CreateNodeTableContext::CREATE() { + return getToken(CypherParser::CREATE, 0); +} + +std::vector CypherParser::KU_CreateNodeTableContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_CreateNodeTableContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +tree::TerminalNode* CypherParser::KU_CreateNodeTableContext::NODE() { + return getToken(CypherParser::NODE, 0); +} + +tree::TerminalNode* CypherParser::KU_CreateNodeTableContext::TABLE() { + return getToken(CypherParser::TABLE, 0); +} + +CypherParser::OC_SchemaNameContext* CypherParser::KU_CreateNodeTableContext::oC_SchemaName() { + return getRuleContext(0); +} + +CypherParser::KU_PropertyDefinitionsContext* CypherParser::KU_CreateNodeTableContext::kU_PropertyDefinitions() { + return getRuleContext(0); +} + 
+tree::TerminalNode* CypherParser::KU_CreateNodeTableContext::AS() { + return getToken(CypherParser::AS, 0); +} + +CypherParser::OC_QueryContext* CypherParser::KU_CreateNodeTableContext::oC_Query() { + return getRuleContext(0); +} + +CypherParser::KU_IfNotExistsContext* CypherParser::KU_CreateNodeTableContext::kU_IfNotExists() { + return getRuleContext(0); +} + +CypherParser::KU_CreateNodeConstraintContext* CypherParser::KU_CreateNodeTableContext::kU_CreateNodeConstraint() { + return getRuleContext(0); +} + + +size_t CypherParser::KU_CreateNodeTableContext::getRuleIndex() const { + return CypherParser::RuleKU_CreateNodeTable; +} + + +CypherParser::KU_CreateNodeTableContext* CypherParser::kU_CreateNodeTable() { + KU_CreateNodeTableContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 44, CypherParser::RuleKU_CreateNodeTable); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(789); + match(CypherParser::CREATE); + setState(790); + match(CypherParser::SP); + setState(791); + match(CypherParser::NODE); + setState(792); + match(CypherParser::SP); + setState(793); + match(CypherParser::TABLE); + setState(794); + match(CypherParser::SP); + setState(798); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 78, _ctx)) { + case 1: { + setState(795); + kU_IfNotExists(); + setState(796); + match(CypherParser::SP); + break; + } + + default: + break; + } + setState(800); + oC_SchemaName(); + setState(828); + _errHandler->sync(this); + switch (getInterpreter()->adaptivePredict(_input, 85, _ctx)) { + case 1: { + setState(802); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(801); + match(CypherParser::SP); + } + setState(804); + match(CypherParser::T__1); + setState(806); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(805); + match(CypherParser::SP); + } + setState(808); + kU_PropertyDefinitions(); + setState(810); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 81, _ctx)) { + case 1: { + setState(809); + match(CypherParser::SP); + break; + } + + default: + break; + } + setState(817); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::T__3) { + setState(812); + match(CypherParser::T__3); + setState(814); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(813); + match(CypherParser::SP); + } + setState(816); + kU_CreateNodeConstraint(); + } + setState(820); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(819); + match(CypherParser::SP); + } + setState(822); + match(CypherParser::T__2); + break; + } + + case 2: { + setState(824); + match(CypherParser::SP); + setState(825); + match(CypherParser::AS); + setState(826); + match(CypherParser::SP); + setState(827); + oC_Query(); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_CreateRelTableContext ------------------------------------------------------------------ + +CypherParser::KU_CreateRelTableContext::KU_CreateRelTableContext(ParserRuleContext *parent, size_t invokingState) + : 
ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_CreateRelTableContext::CREATE() { + return getToken(CypherParser::CREATE, 0); +} + +std::vector CypherParser::KU_CreateRelTableContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_CreateRelTableContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +tree::TerminalNode* CypherParser::KU_CreateRelTableContext::REL() { + return getToken(CypherParser::REL, 0); +} + +tree::TerminalNode* CypherParser::KU_CreateRelTableContext::TABLE() { + return getToken(CypherParser::TABLE, 0); +} + +CypherParser::OC_SchemaNameContext* CypherParser::KU_CreateRelTableContext::oC_SchemaName() { + return getRuleContext(0); +} + +CypherParser::KU_FromToConnectionsContext* CypherParser::KU_CreateRelTableContext::kU_FromToConnections() { + return getRuleContext(0); +} + +tree::TerminalNode* CypherParser::KU_CreateRelTableContext::AS() { + return getToken(CypherParser::AS, 0); +} + +CypherParser::OC_QueryContext* CypherParser::KU_CreateRelTableContext::oC_Query() { + return getRuleContext(0); +} + +tree::TerminalNode* CypherParser::KU_CreateRelTableContext::GROUP() { + return getToken(CypherParser::GROUP, 0); +} + +CypherParser::KU_IfNotExistsContext* CypherParser::KU_CreateRelTableContext::kU_IfNotExists() { + return getRuleContext(0); +} + +tree::TerminalNode* CypherParser::KU_CreateRelTableContext::WITH() { + return getToken(CypherParser::WITH, 0); +} + +CypherParser::KU_OptionsContext* CypherParser::KU_CreateRelTableContext::kU_Options() { + return getRuleContext(0); +} + +CypherParser::KU_PropertyDefinitionsContext* CypherParser::KU_CreateRelTableContext::kU_PropertyDefinitions() { + return getRuleContext(0); +} + +CypherParser::OC_SymbolicNameContext* CypherParser::KU_CreateRelTableContext::oC_SymbolicName() { + return getRuleContext(0); +} + + +size_t CypherParser::KU_CreateRelTableContext::getRuleIndex() const { + return CypherParser::RuleKU_CreateRelTable; +} + + +CypherParser::KU_CreateRelTableContext* CypherParser::kU_CreateRelTable() { + KU_CreateRelTableContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 46, CypherParser::RuleKU_CreateRelTable); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(830); + match(CypherParser::CREATE); + setState(831); + match(CypherParser::SP); + setState(832); + match(CypherParser::REL); + setState(833); + match(CypherParser::SP); + setState(834); + match(CypherParser::TABLE); + setState(837); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 86, _ctx)) { + case 1: { + setState(835); + match(CypherParser::SP); + setState(836); + match(CypherParser::GROUP); + break; + } + + default: + break; + } + setState(841); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 87, _ctx)) { + case 1: { + setState(839); + match(CypherParser::SP); + setState(840); + kU_IfNotExists(); + break; + } + + default: + break; + } + setState(843); + match(CypherParser::SP); + setState(844); + oC_SchemaName(); + setState(846); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(845); + match(CypherParser::SP); + } + setState(848); + match(CypherParser::T__1); + setState(850); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(849); + 
match(CypherParser::SP); + } + setState(852); + kU_FromToConnections(); + setState(854); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(853); + match(CypherParser::SP); + } + setState(882); + _errHandler->sync(this); + switch (getInterpreter()->adaptivePredict(_input, 97, _ctx)) { + case 1: { + setState(864); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 93, _ctx)) { + case 1: { + setState(856); + match(CypherParser::T__3); + setState(858); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(857); + match(CypherParser::SP); + } + setState(860); + kU_PropertyDefinitions(); + setState(862); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(861); + match(CypherParser::SP); + } + break; + } + + default: + break; + } + setState(874); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::T__3) { + setState(866); + match(CypherParser::T__3); + setState(868); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(867); + match(CypherParser::SP); + } + setState(870); + oC_SymbolicName(); + setState(872); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(871); + match(CypherParser::SP); + } + } + setState(876); + match(CypherParser::T__2); + break; + } + + case 2: { + setState(877); + match(CypherParser::T__2); + setState(878); + match(CypherParser::SP); + setState(879); + match(CypherParser::AS); + setState(880); + match(CypherParser::SP); + setState(881); + oC_Query(); + break; + } + + default: + break; + } + setState(899); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 101, _ctx)) { + case 1: { + setState(884); + match(CypherParser::SP); + setState(885); + match(CypherParser::WITH); + setState(887); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(886); + match(CypherParser::SP); + } + setState(889); + match(CypherParser::T__1); + setState(891); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(890); + match(CypherParser::SP); + } + setState(893); + kU_Options(); + setState(895); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(894); + match(CypherParser::SP); + } + setState(897); + match(CypherParser::T__2); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_FromToConnectionsContext ------------------------------------------------------------------ + +CypherParser::KU_FromToConnectionsContext::KU_FromToConnectionsContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +std::vector CypherParser::KU_FromToConnectionsContext::kU_FromToConnection() { + return getRuleContexts(); +} + +CypherParser::KU_FromToConnectionContext* CypherParser::KU_FromToConnectionsContext::kU_FromToConnection(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::KU_FromToConnectionsContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_FromToConnectionsContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + + +size_t 
CypherParser::KU_FromToConnectionsContext::getRuleIndex() const { + return CypherParser::RuleKU_FromToConnections; +} + + +CypherParser::KU_FromToConnectionsContext* CypherParser::kU_FromToConnections() { + KU_FromToConnectionsContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 48, CypherParser::RuleKU_FromToConnections); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + size_t alt; + enterOuterAlt(_localctx, 1); + setState(901); + kU_FromToConnection(); + setState(912); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 104, _ctx); + while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER) { + if (alt == 1) { + setState(903); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(902); + match(CypherParser::SP); + } + setState(905); + match(CypherParser::T__3); + setState(907); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(906); + match(CypherParser::SP); + } + setState(909); + kU_FromToConnection(); + } + setState(914); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 104, _ctx); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_FromToConnectionContext ------------------------------------------------------------------ + +CypherParser::KU_FromToConnectionContext::KU_FromToConnectionContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_FromToConnectionContext::FROM() { + return getToken(CypherParser::FROM, 0); +} + +std::vector CypherParser::KU_FromToConnectionContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_FromToConnectionContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +std::vector CypherParser::KU_FromToConnectionContext::oC_SchemaName() { + return getRuleContexts(); +} + +CypherParser::OC_SchemaNameContext* CypherParser::KU_FromToConnectionContext::oC_SchemaName(size_t i) { + return getRuleContext(i); +} + +tree::TerminalNode* CypherParser::KU_FromToConnectionContext::TO() { + return getToken(CypherParser::TO, 0); +} + + +size_t CypherParser::KU_FromToConnectionContext::getRuleIndex() const { + return CypherParser::RuleKU_FromToConnection; +} + + +CypherParser::KU_FromToConnectionContext* CypherParser::kU_FromToConnection() { + KU_FromToConnectionContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 50, CypherParser::RuleKU_FromToConnection); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(915); + match(CypherParser::FROM); + setState(916); + match(CypherParser::SP); + setState(917); + oC_SchemaName(); + setState(918); + match(CypherParser::SP); + setState(919); + match(CypherParser::TO); + setState(920); + match(CypherParser::SP); + setState(921); + oC_SchemaName(); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- 
KU_CreateSequenceContext ------------------------------------------------------------------
+
+CypherParser::KU_CreateSequenceContext::KU_CreateSequenceContext(ParserRuleContext *parent, size_t invokingState)
+  : ParserRuleContext(parent, invokingState) {
+}
+
+tree::TerminalNode* CypherParser::KU_CreateSequenceContext::CREATE() {
+  return getToken(CypherParser::CREATE, 0);
+}
+
+std::vector<tree::TerminalNode *> CypherParser::KU_CreateSequenceContext::SP() {
+  return getTokens(CypherParser::SP);
+}
+
+tree::TerminalNode* CypherParser::KU_CreateSequenceContext::SP(size_t i) {
+  return getToken(CypherParser::SP, i);
+}
+
+tree::TerminalNode* CypherParser::KU_CreateSequenceContext::SEQUENCE() {
+  return getToken(CypherParser::SEQUENCE, 0);
+}
+
+CypherParser::OC_SchemaNameContext* CypherParser::KU_CreateSequenceContext::oC_SchemaName() {
+  return getRuleContext<CypherParser::OC_SchemaNameContext>(0);
+}
+
+CypherParser::KU_IfNotExistsContext* CypherParser::KU_CreateSequenceContext::kU_IfNotExists() {
+  return getRuleContext<CypherParser::KU_IfNotExistsContext>(0);
+}
+
+std::vector<CypherParser::KU_SequenceOptionsContext *> CypherParser::KU_CreateSequenceContext::kU_SequenceOptions() {
+  return getRuleContexts<CypherParser::KU_SequenceOptionsContext>();
+}
+
+CypherParser::KU_SequenceOptionsContext* CypherParser::KU_CreateSequenceContext::kU_SequenceOptions(size_t i) {
+  return getRuleContext<CypherParser::KU_SequenceOptionsContext>(i);
+}
+
+
+size_t CypherParser::KU_CreateSequenceContext::getRuleIndex() const {
+  return CypherParser::RuleKU_CreateSequence;
+}
+
+
+CypherParser::KU_CreateSequenceContext* CypherParser::kU_CreateSequence() {
+  KU_CreateSequenceContext *_localctx = _tracker.createInstance<KU_CreateSequenceContext>(_ctx, getState());
+  enterRule(_localctx, 52, CypherParser::RuleKU_CreateSequence);
+
+#if __cplusplus > 201703L
+  auto onExit = finally([=, this] {
+#else
+  auto onExit = finally([=] {
+#endif
+    exitRule();
+  });
+  try {
+    size_t alt;
+    enterOuterAlt(_localctx, 1);
+    setState(923);
+    match(CypherParser::CREATE);
+    setState(924);
+    match(CypherParser::SP);
+    setState(925);
+    match(CypherParser::SEQUENCE);
+    setState(926);
+    match(CypherParser::SP);
+    setState(930);
+    _errHandler->sync(this);
+
+    switch (getInterpreter<atn::ParserATNSimulator>()->adaptivePredict(_input, 105, _ctx)) {
+    case 1: {
+      setState(927);
+      kU_IfNotExists();
+      setState(928);
+      match(CypherParser::SP);
+      break;
+    }
+
+    default:
+      break;
+    }
+    setState(932);
+    oC_SchemaName();
+    setState(937);
+    _errHandler->sync(this);
+    alt = getInterpreter<atn::ParserATNSimulator>()->adaptivePredict(_input, 106, _ctx);
+    while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER) {
+      if (alt == 1) {
+        setState(933);
+        match(CypherParser::SP);
+        setState(934);
+        kU_SequenceOptions();
+      }
+      setState(939);
+      _errHandler->sync(this);
+      alt = getInterpreter<atn::ParserATNSimulator>()->adaptivePredict(_input, 106, _ctx);
+    }
+
+  }
+  catch (RecognitionException &e) {
+    _errHandler->reportError(this, e);
+    _localctx->exception = std::current_exception();
+    _errHandler->recover(this, _localctx->exception);
+  }
+
+  return _localctx;
+}
+
+//----------------- KU_CreateTypeContext ------------------------------------------------------------------
+
+CypherParser::KU_CreateTypeContext::KU_CreateTypeContext(ParserRuleContext *parent, size_t invokingState)
+  : ParserRuleContext(parent, invokingState) {
+}
+
+tree::TerminalNode* CypherParser::KU_CreateTypeContext::CREATE() {
+  return getToken(CypherParser::CREATE, 0);
+}
+
+std::vector<tree::TerminalNode *> CypherParser::KU_CreateTypeContext::SP() {
+  return getTokens(CypherParser::SP);
+}
+
+tree::TerminalNode* CypherParser::KU_CreateTypeContext::SP(size_t i) {
+  return getToken(CypherParser::SP, i);
+}
+
+tree::TerminalNode* CypherParser::KU_CreateTypeContext::TYPE() {
+  return getToken(CypherParser::TYPE, 0);
+}
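+
+// Illustrative note (hand-written comment, not generated by ANTLR): the sequence and
+// type rules in this area are meant to accept statements along the lines of
+//   CREATE SEQUENCE IF NOT EXISTS id_seq START WITH 1 INCREMENT BY 1;
+//   CREATE TYPE len_t AS INT64;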
+ +CypherParser::OC_SchemaNameContext* CypherParser::KU_CreateTypeContext::oC_SchemaName() { + return getRuleContext(0); +} + +tree::TerminalNode* CypherParser::KU_CreateTypeContext::AS() { + return getToken(CypherParser::AS, 0); +} + +CypherParser::KU_DataTypeContext* CypherParser::KU_CreateTypeContext::kU_DataType() { + return getRuleContext(0); +} + + +size_t CypherParser::KU_CreateTypeContext::getRuleIndex() const { + return CypherParser::RuleKU_CreateType; +} + + +CypherParser::KU_CreateTypeContext* CypherParser::kU_CreateType() { + KU_CreateTypeContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 54, CypherParser::RuleKU_CreateType); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(940); + match(CypherParser::CREATE); + setState(941); + match(CypherParser::SP); + setState(942); + match(CypherParser::TYPE); + setState(943); + match(CypherParser::SP); + setState(944); + oC_SchemaName(); + setState(945); + match(CypherParser::SP); + setState(946); + match(CypherParser::AS); + setState(947); + match(CypherParser::SP); + setState(948); + kU_DataType(0); + setState(950); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 107, _ctx)) { + case 1: { + setState(949); + match(CypherParser::SP); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_SequenceOptionsContext ------------------------------------------------------------------ + +CypherParser::KU_SequenceOptionsContext::KU_SequenceOptionsContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::KU_IncrementByContext* CypherParser::KU_SequenceOptionsContext::kU_IncrementBy() { + return getRuleContext(0); +} + +CypherParser::KU_MinValueContext* CypherParser::KU_SequenceOptionsContext::kU_MinValue() { + return getRuleContext(0); +} + +CypherParser::KU_MaxValueContext* CypherParser::KU_SequenceOptionsContext::kU_MaxValue() { + return getRuleContext(0); +} + +CypherParser::KU_StartWithContext* CypherParser::KU_SequenceOptionsContext::kU_StartWith() { + return getRuleContext(0); +} + +CypherParser::KU_CycleContext* CypherParser::KU_SequenceOptionsContext::kU_Cycle() { + return getRuleContext(0); +} + + +size_t CypherParser::KU_SequenceOptionsContext::getRuleIndex() const { + return CypherParser::RuleKU_SequenceOptions; +} + + +CypherParser::KU_SequenceOptionsContext* CypherParser::kU_SequenceOptions() { + KU_SequenceOptionsContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 56, CypherParser::RuleKU_SequenceOptions); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + setState(957); + _errHandler->sync(this); + switch (getInterpreter()->adaptivePredict(_input, 108, _ctx)) { + case 1: { + enterOuterAlt(_localctx, 1); + setState(952); + kU_IncrementBy(); + break; + } + + case 2: { + enterOuterAlt(_localctx, 2); + setState(953); + kU_MinValue(); + break; + } + + case 3: { + enterOuterAlt(_localctx, 3); + setState(954); + kU_MaxValue(); + break; + } + + case 4: { + enterOuterAlt(_localctx, 4); + setState(955); + kU_StartWith(); + break; + } + + case 5: 
{ + enterOuterAlt(_localctx, 5); + setState(956); + kU_Cycle(); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_WithPasswdContext ------------------------------------------------------------------ + +CypherParser::KU_WithPasswdContext::KU_WithPasswdContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +std::vector CypherParser::KU_WithPasswdContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_WithPasswdContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +tree::TerminalNode* CypherParser::KU_WithPasswdContext::WITH() { + return getToken(CypherParser::WITH, 0); +} + +tree::TerminalNode* CypherParser::KU_WithPasswdContext::PASSWORD() { + return getToken(CypherParser::PASSWORD, 0); +} + +tree::TerminalNode* CypherParser::KU_WithPasswdContext::StringLiteral() { + return getToken(CypherParser::StringLiteral, 0); +} + + +size_t CypherParser::KU_WithPasswdContext::getRuleIndex() const { + return CypherParser::RuleKU_WithPasswd; +} + + +CypherParser::KU_WithPasswdContext* CypherParser::kU_WithPasswd() { + KU_WithPasswdContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 58, CypherParser::RuleKU_WithPasswd); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(959); + match(CypherParser::SP); + setState(960); + match(CypherParser::WITH); + setState(961); + match(CypherParser::SP); + setState(962); + match(CypherParser::PASSWORD); + setState(963); + match(CypherParser::SP); + setState(964); + match(CypherParser::StringLiteral); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_CreateUserContext ------------------------------------------------------------------ + +CypherParser::KU_CreateUserContext::KU_CreateUserContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_CreateUserContext::CREATE() { + return getToken(CypherParser::CREATE, 0); +} + +std::vector CypherParser::KU_CreateUserContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_CreateUserContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +tree::TerminalNode* CypherParser::KU_CreateUserContext::USER() { + return getToken(CypherParser::USER, 0); +} + +CypherParser::OC_VariableContext* CypherParser::KU_CreateUserContext::oC_Variable() { + return getRuleContext(0); +} + +CypherParser::KU_IfNotExistsContext* CypherParser::KU_CreateUserContext::kU_IfNotExists() { + return getRuleContext(0); +} + +CypherParser::KU_WithPasswdContext* CypherParser::KU_CreateUserContext::kU_WithPasswd() { + return getRuleContext(0); +} + + +size_t CypherParser::KU_CreateUserContext::getRuleIndex() const { + return CypherParser::RuleKU_CreateUser; +} + + +CypherParser::KU_CreateUserContext* CypherParser::kU_CreateUser() { + KU_CreateUserContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 60, 
CypherParser::RuleKU_CreateUser); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(966); + match(CypherParser::CREATE); + setState(967); + match(CypherParser::SP); + setState(968); + match(CypherParser::USER); + setState(969); + match(CypherParser::SP); + setState(973); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 109, _ctx)) { + case 1: { + setState(970); + kU_IfNotExists(); + setState(971); + match(CypherParser::SP); + break; + } + + default: + break; + } + setState(975); + oC_Variable(); + setState(977); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 110, _ctx)) { + case 1: { + setState(976); + kU_WithPasswd(); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_CreateRoleContext ------------------------------------------------------------------ + +CypherParser::KU_CreateRoleContext::KU_CreateRoleContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_CreateRoleContext::CREATE() { + return getToken(CypherParser::CREATE, 0); +} + +std::vector CypherParser::KU_CreateRoleContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_CreateRoleContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +tree::TerminalNode* CypherParser::KU_CreateRoleContext::ROLE() { + return getToken(CypherParser::ROLE, 0); +} + +CypherParser::OC_VariableContext* CypherParser::KU_CreateRoleContext::oC_Variable() { + return getRuleContext(0); +} + +CypherParser::KU_IfNotExistsContext* CypherParser::KU_CreateRoleContext::kU_IfNotExists() { + return getRuleContext(0); +} + + +size_t CypherParser::KU_CreateRoleContext::getRuleIndex() const { + return CypherParser::RuleKU_CreateRole; +} + + +CypherParser::KU_CreateRoleContext* CypherParser::kU_CreateRole() { + KU_CreateRoleContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 62, CypherParser::RuleKU_CreateRole); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(979); + match(CypherParser::CREATE); + setState(980); + match(CypherParser::SP); + setState(981); + match(CypherParser::ROLE); + setState(982); + match(CypherParser::SP); + setState(986); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 111, _ctx)) { + case 1: { + setState(983); + kU_IfNotExists(); + setState(984); + match(CypherParser::SP); + break; + } + + default: + break; + } + setState(988); + oC_Variable(); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_IncrementByContext ------------------------------------------------------------------ + +CypherParser::KU_IncrementByContext::KU_IncrementByContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_IncrementByContext::INCREMENT() { + 
return getToken(CypherParser::INCREMENT, 0); +} + +std::vector CypherParser::KU_IncrementByContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_IncrementByContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +CypherParser::OC_IntegerLiteralContext* CypherParser::KU_IncrementByContext::oC_IntegerLiteral() { + return getRuleContext(0); +} + +tree::TerminalNode* CypherParser::KU_IncrementByContext::BY() { + return getToken(CypherParser::BY, 0); +} + +tree::TerminalNode* CypherParser::KU_IncrementByContext::MINUS() { + return getToken(CypherParser::MINUS, 0); +} + + +size_t CypherParser::KU_IncrementByContext::getRuleIndex() const { + return CypherParser::RuleKU_IncrementBy; +} + + +CypherParser::KU_IncrementByContext* CypherParser::kU_IncrementBy() { + KU_IncrementByContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 64, CypherParser::RuleKU_IncrementBy); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(990); + match(CypherParser::INCREMENT); + setState(991); + match(CypherParser::SP); + setState(994); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::BY) { + setState(992); + match(CypherParser::BY); + setState(993); + match(CypherParser::SP); + } + setState(997); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::MINUS) { + setState(996); + match(CypherParser::MINUS); + } + setState(999); + oC_IntegerLiteral(); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_MinValueContext ------------------------------------------------------------------ + +CypherParser::KU_MinValueContext::KU_MinValueContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_MinValueContext::NO() { + return getToken(CypherParser::NO, 0); +} + +tree::TerminalNode* CypherParser::KU_MinValueContext::SP() { + return getToken(CypherParser::SP, 0); +} + +tree::TerminalNode* CypherParser::KU_MinValueContext::MINVALUE() { + return getToken(CypherParser::MINVALUE, 0); +} + +CypherParser::OC_IntegerLiteralContext* CypherParser::KU_MinValueContext::oC_IntegerLiteral() { + return getRuleContext(0); +} + +tree::TerminalNode* CypherParser::KU_MinValueContext::MINUS() { + return getToken(CypherParser::MINUS, 0); +} + + +size_t CypherParser::KU_MinValueContext::getRuleIndex() const { + return CypherParser::RuleKU_MinValue; +} + + +CypherParser::KU_MinValueContext* CypherParser::kU_MinValue() { + KU_MinValueContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 66, CypherParser::RuleKU_MinValue); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + setState(1010); + _errHandler->sync(this); + switch (_input->LA(1)) { + case CypherParser::NO: { + enterOuterAlt(_localctx, 1); + setState(1001); + match(CypherParser::NO); + setState(1002); + match(CypherParser::SP); + setState(1003); + match(CypherParser::MINVALUE); + break; + } + + case CypherParser::MINVALUE: { + enterOuterAlt(_localctx, 2); + setState(1004); + 
match(CypherParser::MINVALUE); + setState(1005); + match(CypherParser::SP); + setState(1007); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::MINUS) { + setState(1006); + match(CypherParser::MINUS); + } + setState(1009); + oC_IntegerLiteral(); + break; + } + + default: + throw NoViableAltException(this); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_MaxValueContext ------------------------------------------------------------------ + +CypherParser::KU_MaxValueContext::KU_MaxValueContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_MaxValueContext::NO() { + return getToken(CypherParser::NO, 0); +} + +tree::TerminalNode* CypherParser::KU_MaxValueContext::SP() { + return getToken(CypherParser::SP, 0); +} + +tree::TerminalNode* CypherParser::KU_MaxValueContext::MAXVALUE() { + return getToken(CypherParser::MAXVALUE, 0); +} + +CypherParser::OC_IntegerLiteralContext* CypherParser::KU_MaxValueContext::oC_IntegerLiteral() { + return getRuleContext(0); +} + +tree::TerminalNode* CypherParser::KU_MaxValueContext::MINUS() { + return getToken(CypherParser::MINUS, 0); +} + + +size_t CypherParser::KU_MaxValueContext::getRuleIndex() const { + return CypherParser::RuleKU_MaxValue; +} + + +CypherParser::KU_MaxValueContext* CypherParser::kU_MaxValue() { + KU_MaxValueContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 68, CypherParser::RuleKU_MaxValue); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + setState(1021); + _errHandler->sync(this); + switch (_input->LA(1)) { + case CypherParser::NO: { + enterOuterAlt(_localctx, 1); + setState(1012); + match(CypherParser::NO); + setState(1013); + match(CypherParser::SP); + setState(1014); + match(CypherParser::MAXVALUE); + break; + } + + case CypherParser::MAXVALUE: { + enterOuterAlt(_localctx, 2); + setState(1015); + match(CypherParser::MAXVALUE); + setState(1016); + match(CypherParser::SP); + setState(1018); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::MINUS) { + setState(1017); + match(CypherParser::MINUS); + } + setState(1020); + oC_IntegerLiteral(); + break; + } + + default: + throw NoViableAltException(this); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_StartWithContext ------------------------------------------------------------------ + +CypherParser::KU_StartWithContext::KU_StartWithContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_StartWithContext::START() { + return getToken(CypherParser::START, 0); +} + +std::vector CypherParser::KU_StartWithContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_StartWithContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +CypherParser::OC_IntegerLiteralContext* CypherParser::KU_StartWithContext::oC_IntegerLiteral() { + return getRuleContext(0); +} + +tree::TerminalNode* 
CypherParser::KU_StartWithContext::WITH() { + return getToken(CypherParser::WITH, 0); +} + +tree::TerminalNode* CypherParser::KU_StartWithContext::MINUS() { + return getToken(CypherParser::MINUS, 0); +} + + +size_t CypherParser::KU_StartWithContext::getRuleIndex() const { + return CypherParser::RuleKU_StartWith; +} + + +CypherParser::KU_StartWithContext* CypherParser::kU_StartWith() { + KU_StartWithContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 70, CypherParser::RuleKU_StartWith); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1023); + match(CypherParser::START); + setState(1024); + match(CypherParser::SP); + setState(1027); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::WITH) { + setState(1025); + match(CypherParser::WITH); + setState(1026); + match(CypherParser::SP); + } + setState(1030); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::MINUS) { + setState(1029); + match(CypherParser::MINUS); + } + setState(1032); + oC_IntegerLiteral(); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_CycleContext ------------------------------------------------------------------ + +CypherParser::KU_CycleContext::KU_CycleContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_CycleContext::CYCLE() { + return getToken(CypherParser::CYCLE, 0); +} + +tree::TerminalNode* CypherParser::KU_CycleContext::NO() { + return getToken(CypherParser::NO, 0); +} + +tree::TerminalNode* CypherParser::KU_CycleContext::SP() { + return getToken(CypherParser::SP, 0); +} + + +size_t CypherParser::KU_CycleContext::getRuleIndex() const { + return CypherParser::RuleKU_Cycle; +} + + +CypherParser::KU_CycleContext* CypherParser::kU_Cycle() { + KU_CycleContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 72, CypherParser::RuleKU_Cycle); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1036); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::NO) { + setState(1034); + match(CypherParser::NO); + setState(1035); + match(CypherParser::SP); + } + setState(1038); + match(CypherParser::CYCLE); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_IfExistsContext ------------------------------------------------------------------ + +CypherParser::KU_IfExistsContext::KU_IfExistsContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_IfExistsContext::IF() { + return getToken(CypherParser::IF, 0); +} + +tree::TerminalNode* CypherParser::KU_IfExistsContext::SP() { + return getToken(CypherParser::SP, 0); +} + +tree::TerminalNode* CypherParser::KU_IfExistsContext::EXISTS() { + return getToken(CypherParser::EXISTS, 0); +} + + +size_t 
CypherParser::KU_IfExistsContext::getRuleIndex() const { + return CypherParser::RuleKU_IfExists; +} + + +CypherParser::KU_IfExistsContext* CypherParser::kU_IfExists() { + KU_IfExistsContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 74, CypherParser::RuleKU_IfExists); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1040); + match(CypherParser::IF); + setState(1041); + match(CypherParser::SP); + setState(1042); + match(CypherParser::EXISTS); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_DropContext ------------------------------------------------------------------ + +CypherParser::KU_DropContext::KU_DropContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_DropContext::DROP() { + return getToken(CypherParser::DROP, 0); +} + +std::vector CypherParser::KU_DropContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_DropContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +CypherParser::OC_SchemaNameContext* CypherParser::KU_DropContext::oC_SchemaName() { + return getRuleContext(0); +} + +tree::TerminalNode* CypherParser::KU_DropContext::TABLE() { + return getToken(CypherParser::TABLE, 0); +} + +tree::TerminalNode* CypherParser::KU_DropContext::SEQUENCE() { + return getToken(CypherParser::SEQUENCE, 0); +} + +tree::TerminalNode* CypherParser::KU_DropContext::MACRO() { + return getToken(CypherParser::MACRO, 0); +} + +CypherParser::KU_IfExistsContext* CypherParser::KU_DropContext::kU_IfExists() { + return getRuleContext(0); +} + + +size_t CypherParser::KU_DropContext::getRuleIndex() const { + return CypherParser::RuleKU_Drop; +} + + +CypherParser::KU_DropContext* CypherParser::kU_Drop() { + KU_DropContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 76, CypherParser::RuleKU_Drop); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1044); + match(CypherParser::DROP); + setState(1045); + match(CypherParser::SP); + setState(1046); + _la = _input->LA(1); + if (!(((((_la - 105) & ~ 0x3fULL) == 0) && + ((1ULL << (_la - 105)) & 2181038081) != 0))) { + _errHandler->recoverInline(this); + } + else { + _errHandler->reportMatch(this); + consume(); + } + setState(1047); + match(CypherParser::SP); + setState(1051); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 121, _ctx)) { + case 1: { + setState(1048); + kU_IfExists(); + setState(1049); + match(CypherParser::SP); + break; + } + + default: + break; + } + setState(1053); + oC_SchemaName(); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_AlterTableContext ------------------------------------------------------------------ + +CypherParser::KU_AlterTableContext::KU_AlterTableContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, 
invokingState) { +} + +tree::TerminalNode* CypherParser::KU_AlterTableContext::ALTER() { + return getToken(CypherParser::ALTER, 0); +} + +std::vector CypherParser::KU_AlterTableContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_AlterTableContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +tree::TerminalNode* CypherParser::KU_AlterTableContext::TABLE() { + return getToken(CypherParser::TABLE, 0); +} + +CypherParser::OC_SchemaNameContext* CypherParser::KU_AlterTableContext::oC_SchemaName() { + return getRuleContext(0); +} + +CypherParser::KU_AlterOptionsContext* CypherParser::KU_AlterTableContext::kU_AlterOptions() { + return getRuleContext(0); +} + + +size_t CypherParser::KU_AlterTableContext::getRuleIndex() const { + return CypherParser::RuleKU_AlterTable; +} + + +CypherParser::KU_AlterTableContext* CypherParser::kU_AlterTable() { + KU_AlterTableContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 78, CypherParser::RuleKU_AlterTable); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1055); + match(CypherParser::ALTER); + setState(1056); + match(CypherParser::SP); + setState(1057); + match(CypherParser::TABLE); + setState(1058); + match(CypherParser::SP); + setState(1059); + oC_SchemaName(); + setState(1060); + match(CypherParser::SP); + setState(1061); + kU_AlterOptions(); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_AlterOptionsContext ------------------------------------------------------------------ + +CypherParser::KU_AlterOptionsContext::KU_AlterOptionsContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::KU_AddPropertyContext* CypherParser::KU_AlterOptionsContext::kU_AddProperty() { + return getRuleContext(0); +} + +CypherParser::KU_DropPropertyContext* CypherParser::KU_AlterOptionsContext::kU_DropProperty() { + return getRuleContext(0); +} + +CypherParser::KU_RenameTableContext* CypherParser::KU_AlterOptionsContext::kU_RenameTable() { + return getRuleContext(0); +} + +CypherParser::KU_RenamePropertyContext* CypherParser::KU_AlterOptionsContext::kU_RenameProperty() { + return getRuleContext(0); +} + +CypherParser::KU_AddFromToConnectionContext* CypherParser::KU_AlterOptionsContext::kU_AddFromToConnection() { + return getRuleContext(0); +} + +CypherParser::KU_DropFromToConnectionContext* CypherParser::KU_AlterOptionsContext::kU_DropFromToConnection() { + return getRuleContext(0); +} + + +size_t CypherParser::KU_AlterOptionsContext::getRuleIndex() const { + return CypherParser::RuleKU_AlterOptions; +} + + +CypherParser::KU_AlterOptionsContext* CypherParser::kU_AlterOptions() { + KU_AlterOptionsContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 80, CypherParser::RuleKU_AlterOptions); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + setState(1069); + _errHandler->sync(this); + switch (getInterpreter()->adaptivePredict(_input, 122, _ctx)) { + case 1: { + enterOuterAlt(_localctx, 1); + setState(1063); + kU_AddProperty(); + break; + } + + case 2: { + enterOuterAlt(_localctx, 2); + 
setState(1064); + kU_DropProperty(); + break; + } + + case 3: { + enterOuterAlt(_localctx, 3); + setState(1065); + kU_RenameTable(); + break; + } + + case 4: { + enterOuterAlt(_localctx, 4); + setState(1066); + kU_RenameProperty(); + break; + } + + case 5: { + enterOuterAlt(_localctx, 5); + setState(1067); + kU_AddFromToConnection(); + break; + } + + case 6: { + enterOuterAlt(_localctx, 6); + setState(1068); + kU_DropFromToConnection(); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_AddPropertyContext ------------------------------------------------------------------ + +CypherParser::KU_AddPropertyContext::KU_AddPropertyContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_AddPropertyContext::ADD() { + return getToken(CypherParser::ADD, 0); +} + +std::vector CypherParser::KU_AddPropertyContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_AddPropertyContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +CypherParser::OC_PropertyKeyNameContext* CypherParser::KU_AddPropertyContext::oC_PropertyKeyName() { + return getRuleContext(0); +} + +CypherParser::KU_DataTypeContext* CypherParser::KU_AddPropertyContext::kU_DataType() { + return getRuleContext(0); +} + +CypherParser::KU_IfNotExistsContext* CypherParser::KU_AddPropertyContext::kU_IfNotExists() { + return getRuleContext(0); +} + +CypherParser::KU_DefaultContext* CypherParser::KU_AddPropertyContext::kU_Default() { + return getRuleContext(0); +} + + +size_t CypherParser::KU_AddPropertyContext::getRuleIndex() const { + return CypherParser::RuleKU_AddProperty; +} + + +CypherParser::KU_AddPropertyContext* CypherParser::kU_AddProperty() { + KU_AddPropertyContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 82, CypherParser::RuleKU_AddProperty); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1071); + match(CypherParser::ADD); + setState(1072); + match(CypherParser::SP); + setState(1076); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 123, _ctx)) { + case 1: { + setState(1073); + kU_IfNotExists(); + setState(1074); + match(CypherParser::SP); + break; + } + + default: + break; + } + setState(1078); + oC_PropertyKeyName(); + setState(1079); + match(CypherParser::SP); + setState(1080); + kU_DataType(0); + setState(1083); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 124, _ctx)) { + case 1: { + setState(1081); + match(CypherParser::SP); + setState(1082); + kU_Default(); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_DefaultContext ------------------------------------------------------------------ + +CypherParser::KU_DefaultContext::KU_DefaultContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_DefaultContext::DEFAULT() { + return 
getToken(CypherParser::DEFAULT, 0);
+}
+
+tree::TerminalNode* CypherParser::KU_DefaultContext::SP() {
+  return getToken(CypherParser::SP, 0);
+}
+
+CypherParser::OC_ExpressionContext* CypherParser::KU_DefaultContext::oC_Expression() {
+  return getRuleContext<CypherParser::OC_ExpressionContext>(0);
+}
+
+
+size_t CypherParser::KU_DefaultContext::getRuleIndex() const {
+  return CypherParser::RuleKU_Default;
+}
+
+
+CypherParser::KU_DefaultContext* CypherParser::kU_Default() {
+  KU_DefaultContext *_localctx = _tracker.createInstance<KU_DefaultContext>(_ctx, getState());
+  enterRule(_localctx, 84, CypherParser::RuleKU_Default);
+
+#if __cplusplus > 201703L
+  auto onExit = finally([=, this] {
+#else
+  auto onExit = finally([=] {
+#endif
+    exitRule();
+  });
+  try {
+    enterOuterAlt(_localctx, 1);
+    setState(1085);
+    match(CypherParser::DEFAULT);
+    setState(1086);
+    match(CypherParser::SP);
+    setState(1087);
+    oC_Expression();
+
+  }
+  catch (RecognitionException &e) {
+    _errHandler->reportError(this, e);
+    _localctx->exception = std::current_exception();
+    _errHandler->recover(this, _localctx->exception);
+  }
+
+  return _localctx;
+}
+
+//----------------- KU_DropPropertyContext ------------------------------------------------------------------
+
+CypherParser::KU_DropPropertyContext::KU_DropPropertyContext(ParserRuleContext *parent, size_t invokingState)
+  : ParserRuleContext(parent, invokingState) {
+}
+
+tree::TerminalNode* CypherParser::KU_DropPropertyContext::DROP() {
+  return getToken(CypherParser::DROP, 0);
+}
+
+std::vector<tree::TerminalNode *> CypherParser::KU_DropPropertyContext::SP() {
+  return getTokens(CypherParser::SP);
+}
+
+tree::TerminalNode* CypherParser::KU_DropPropertyContext::SP(size_t i) {
+  return getToken(CypherParser::SP, i);
+}
+
+CypherParser::OC_PropertyKeyNameContext* CypherParser::KU_DropPropertyContext::oC_PropertyKeyName() {
+  return getRuleContext<CypherParser::OC_PropertyKeyNameContext>(0);
+}
+
+CypherParser::KU_IfExistsContext* CypherParser::KU_DropPropertyContext::kU_IfExists() {
+  return getRuleContext<CypherParser::KU_IfExistsContext>(0);
+}
+
+
+size_t CypherParser::KU_DropPropertyContext::getRuleIndex() const {
+  return CypherParser::RuleKU_DropProperty;
+}
+
+
+CypherParser::KU_DropPropertyContext* CypherParser::kU_DropProperty() {
+  KU_DropPropertyContext *_localctx = _tracker.createInstance<KU_DropPropertyContext>(_ctx, getState());
+  enterRule(_localctx, 86, CypherParser::RuleKU_DropProperty);
+
+#if __cplusplus > 201703L
+  auto onExit = finally([=, this] {
+#else
+  auto onExit = finally([=] {
+#endif
+    exitRule();
+  });
+  try {
+    enterOuterAlt(_localctx, 1);
+    setState(1089);
+    match(CypherParser::DROP);
+    setState(1090);
+    match(CypherParser::SP);
+    setState(1094);
+    _errHandler->sync(this);
+
+    switch (getInterpreter<atn::ParserATNSimulator>()->adaptivePredict(_input, 125, _ctx)) {
+    case 1: {
+      setState(1091);
+      kU_IfExists();
+      setState(1092);
+      match(CypherParser::SP);
+      break;
+    }
+
+    default:
+      break;
+    }
+    setState(1096);
+    oC_PropertyKeyName();
+
+  }
+  catch (RecognitionException &e) {
+    _errHandler->reportError(this, e);
+    _localctx->exception = std::current_exception();
+    _errHandler->recover(this, _localctx->exception);
+  }
+
+  return _localctx;
+}
+
+//----------------- KU_RenameTableContext ------------------------------------------------------------------
+
+CypherParser::KU_RenameTableContext::KU_RenameTableContext(ParserRuleContext *parent, size_t invokingState)
+  : ParserRuleContext(parent, invokingState) {
+}
+
+tree::TerminalNode* CypherParser::KU_RenameTableContext::RENAME() {
+  return getToken(CypherParser::RENAME, 0);
+}
+
+std::vector<tree::TerminalNode *> CypherParser::KU_RenameTableContext::SP() {
+  return getTokens(CypherParser::SP);
+}
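+
+// Illustrative note (hand-written comment, not generated by ANTLR): the alter-option
+// rules in this area correspond to statements such as
+//   ALTER TABLE Person DROP IF EXISTS age;
+//   ALTER TABLE Person RENAME TO Student;
+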
+tree::TerminalNode* CypherParser::KU_RenameTableContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +tree::TerminalNode* CypherParser::KU_RenameTableContext::TO() { + return getToken(CypherParser::TO, 0); +} + +CypherParser::OC_SchemaNameContext* CypherParser::KU_RenameTableContext::oC_SchemaName() { + return getRuleContext(0); +} + + +size_t CypherParser::KU_RenameTableContext::getRuleIndex() const { + return CypherParser::RuleKU_RenameTable; +} + + +CypherParser::KU_RenameTableContext* CypherParser::kU_RenameTable() { + KU_RenameTableContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 88, CypherParser::RuleKU_RenameTable); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1098); + match(CypherParser::RENAME); + setState(1099); + match(CypherParser::SP); + setState(1100); + match(CypherParser::TO); + setState(1101); + match(CypherParser::SP); + setState(1102); + oC_SchemaName(); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_RenamePropertyContext ------------------------------------------------------------------ + +CypherParser::KU_RenamePropertyContext::KU_RenamePropertyContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_RenamePropertyContext::RENAME() { + return getToken(CypherParser::RENAME, 0); +} + +std::vector CypherParser::KU_RenamePropertyContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_RenamePropertyContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +std::vector CypherParser::KU_RenamePropertyContext::oC_PropertyKeyName() { + return getRuleContexts(); +} + +CypherParser::OC_PropertyKeyNameContext* CypherParser::KU_RenamePropertyContext::oC_PropertyKeyName(size_t i) { + return getRuleContext(i); +} + +tree::TerminalNode* CypherParser::KU_RenamePropertyContext::TO() { + return getToken(CypherParser::TO, 0); +} + + +size_t CypherParser::KU_RenamePropertyContext::getRuleIndex() const { + return CypherParser::RuleKU_RenameProperty; +} + + +CypherParser::KU_RenamePropertyContext* CypherParser::kU_RenameProperty() { + KU_RenamePropertyContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 90, CypherParser::RuleKU_RenameProperty); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1104); + match(CypherParser::RENAME); + setState(1105); + match(CypherParser::SP); + setState(1106); + oC_PropertyKeyName(); + setState(1107); + match(CypherParser::SP); + setState(1108); + match(CypherParser::TO); + setState(1109); + match(CypherParser::SP); + setState(1110); + oC_PropertyKeyName(); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_AddFromToConnectionContext ------------------------------------------------------------------ + +CypherParser::KU_AddFromToConnectionContext::KU_AddFromToConnectionContext(ParserRuleContext *parent, 
size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_AddFromToConnectionContext::ADD() { + return getToken(CypherParser::ADD, 0); +} + +std::vector CypherParser::KU_AddFromToConnectionContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_AddFromToConnectionContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +CypherParser::KU_FromToConnectionContext* CypherParser::KU_AddFromToConnectionContext::kU_FromToConnection() { + return getRuleContext(0); +} + +CypherParser::KU_IfNotExistsContext* CypherParser::KU_AddFromToConnectionContext::kU_IfNotExists() { + return getRuleContext(0); +} + + +size_t CypherParser::KU_AddFromToConnectionContext::getRuleIndex() const { + return CypherParser::RuleKU_AddFromToConnection; +} + + +CypherParser::KU_AddFromToConnectionContext* CypherParser::kU_AddFromToConnection() { + KU_AddFromToConnectionContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 92, CypherParser::RuleKU_AddFromToConnection); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1112); + match(CypherParser::ADD); + setState(1113); + match(CypherParser::SP); + setState(1117); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::IF) { + setState(1114); + kU_IfNotExists(); + setState(1115); + match(CypherParser::SP); + } + setState(1119); + kU_FromToConnection(); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_DropFromToConnectionContext ------------------------------------------------------------------ + +CypherParser::KU_DropFromToConnectionContext::KU_DropFromToConnectionContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_DropFromToConnectionContext::DROP() { + return getToken(CypherParser::DROP, 0); +} + +std::vector CypherParser::KU_DropFromToConnectionContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_DropFromToConnectionContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +CypherParser::KU_FromToConnectionContext* CypherParser::KU_DropFromToConnectionContext::kU_FromToConnection() { + return getRuleContext(0); +} + +CypherParser::KU_IfExistsContext* CypherParser::KU_DropFromToConnectionContext::kU_IfExists() { + return getRuleContext(0); +} + + +size_t CypherParser::KU_DropFromToConnectionContext::getRuleIndex() const { + return CypherParser::RuleKU_DropFromToConnection; +} + + +CypherParser::KU_DropFromToConnectionContext* CypherParser::kU_DropFromToConnection() { + KU_DropFromToConnectionContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 94, CypherParser::RuleKU_DropFromToConnection); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1121); + match(CypherParser::DROP); + setState(1122); + match(CypherParser::SP); + setState(1126); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::IF) { + setState(1123); + kU_IfExists(); + 
setState(1124); + match(CypherParser::SP); + } + setState(1128); + kU_FromToConnection(); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_ColumnDefinitionsContext ------------------------------------------------------------------ + +CypherParser::KU_ColumnDefinitionsContext::KU_ColumnDefinitionsContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +std::vector CypherParser::KU_ColumnDefinitionsContext::kU_ColumnDefinition() { + return getRuleContexts(); +} + +CypherParser::KU_ColumnDefinitionContext* CypherParser::KU_ColumnDefinitionsContext::kU_ColumnDefinition(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::KU_ColumnDefinitionsContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_ColumnDefinitionsContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + + +size_t CypherParser::KU_ColumnDefinitionsContext::getRuleIndex() const { + return CypherParser::RuleKU_ColumnDefinitions; +} + + +CypherParser::KU_ColumnDefinitionsContext* CypherParser::kU_ColumnDefinitions() { + KU_ColumnDefinitionsContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 96, CypherParser::RuleKU_ColumnDefinitions); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + size_t alt; + enterOuterAlt(_localctx, 1); + setState(1130); + kU_ColumnDefinition(); + setState(1141); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 130, _ctx); + while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER) { + if (alt == 1) { + setState(1132); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1131); + match(CypherParser::SP); + } + setState(1134); + match(CypherParser::T__3); + setState(1136); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1135); + match(CypherParser::SP); + } + setState(1138); + kU_ColumnDefinition(); + } + setState(1143); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 130, _ctx); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_ColumnDefinitionContext ------------------------------------------------------------------ + +CypherParser::KU_ColumnDefinitionContext::KU_ColumnDefinitionContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_PropertyKeyNameContext* CypherParser::KU_ColumnDefinitionContext::oC_PropertyKeyName() { + return getRuleContext(0); +} + +tree::TerminalNode* CypherParser::KU_ColumnDefinitionContext::SP() { + return getToken(CypherParser::SP, 0); +} + +CypherParser::KU_DataTypeContext* CypherParser::KU_ColumnDefinitionContext::kU_DataType() { + return getRuleContext(0); +} + + +size_t CypherParser::KU_ColumnDefinitionContext::getRuleIndex() const { + return CypherParser::RuleKU_ColumnDefinition; +} + + +CypherParser::KU_ColumnDefinitionContext* CypherParser::kU_ColumnDefinition() { + KU_ColumnDefinitionContext *_localctx = 
_tracker.createInstance<KU_ColumnDefinitionContext>(_ctx, getState());
+  enterRule(_localctx, 98, CypherParser::RuleKU_ColumnDefinition);
+
+#if __cplusplus > 201703L
+  auto onExit = finally([=, this] {
+#else
+  auto onExit = finally([=] {
+#endif
+    exitRule();
+  });
+  try {
+    enterOuterAlt(_localctx, 1);
+    setState(1144);
+    oC_PropertyKeyName();
+    setState(1145);
+    match(CypherParser::SP);
+    setState(1146);
+    kU_DataType(0);
+
+  }
+  catch (RecognitionException &e) {
+    _errHandler->reportError(this, e);
+    _localctx->exception = std::current_exception();
+    _errHandler->recover(this, _localctx->exception);
+  }
+
+  return _localctx;
+}
+
+//----------------- KU_PropertyDefinitionsContext ------------------------------------------------------------------
+
+CypherParser::KU_PropertyDefinitionsContext::KU_PropertyDefinitionsContext(ParserRuleContext *parent, size_t invokingState)
+  : ParserRuleContext(parent, invokingState) {
+}
+
+std::vector<CypherParser::KU_PropertyDefinitionContext *> CypherParser::KU_PropertyDefinitionsContext::kU_PropertyDefinition() {
+  return getRuleContexts<CypherParser::KU_PropertyDefinitionContext>();
+}
+
+CypherParser::KU_PropertyDefinitionContext* CypherParser::KU_PropertyDefinitionsContext::kU_PropertyDefinition(size_t i) {
+  return getRuleContext<CypherParser::KU_PropertyDefinitionContext>(i);
+}
+
+std::vector<tree::TerminalNode *> CypherParser::KU_PropertyDefinitionsContext::SP() {
+  return getTokens(CypherParser::SP);
+}
+
+tree::TerminalNode* CypherParser::KU_PropertyDefinitionsContext::SP(size_t i) {
+  return getToken(CypherParser::SP, i);
+}
+
+
+size_t CypherParser::KU_PropertyDefinitionsContext::getRuleIndex() const {
+  return CypherParser::RuleKU_PropertyDefinitions;
+}
+
+
+CypherParser::KU_PropertyDefinitionsContext* CypherParser::kU_PropertyDefinitions() {
+  KU_PropertyDefinitionsContext *_localctx = _tracker.createInstance<KU_PropertyDefinitionsContext>(_ctx, getState());
+  enterRule(_localctx, 100, CypherParser::RuleKU_PropertyDefinitions);
+  size_t _la = 0;
+
+#if __cplusplus > 201703L
+  auto onExit = finally([=, this] {
+#else
+  auto onExit = finally([=] {
+#endif
+    exitRule();
+  });
+  try {
+    size_t alt;
+    enterOuterAlt(_localctx, 1);
+    setState(1148);
+    kU_PropertyDefinition();
+    setState(1159);
+    _errHandler->sync(this);
+    alt = getInterpreter<atn::ParserATNSimulator>()->adaptivePredict(_input, 133, _ctx);
+    while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER) {
+      if (alt == 1) {
+        setState(1150);
+        _errHandler->sync(this);
+
+        _la = _input->LA(1);
+        if (_la == CypherParser::SP) {
+          setState(1149);
+          match(CypherParser::SP);
+        }
+        setState(1152);
+        match(CypherParser::T__3);
+        setState(1154);
+        _errHandler->sync(this);
+
+        _la = _input->LA(1);
+        if (_la == CypherParser::SP) {
+          setState(1153);
+          match(CypherParser::SP);
+        }
+        setState(1156);
+        kU_PropertyDefinition();
+      }
+      setState(1161);
+      _errHandler->sync(this);
+      alt = getInterpreter<atn::ParserATNSimulator>()->adaptivePredict(_input, 133, _ctx);
+    }
+
+  }
+  catch (RecognitionException &e) {
+    _errHandler->reportError(this, e);
+    _localctx->exception = std::current_exception();
+    _errHandler->recover(this, _localctx->exception);
+  }
+
+  return _localctx;
+}
+
+//----------------- KU_PropertyDefinitionContext ------------------------------------------------------------------
+
+CypherParser::KU_PropertyDefinitionContext::KU_PropertyDefinitionContext(ParserRuleContext *parent, size_t invokingState)
+  : ParserRuleContext(parent, invokingState) {
+}
+
+CypherParser::KU_ColumnDefinitionContext* CypherParser::KU_PropertyDefinitionContext::kU_ColumnDefinition() {
+  return getRuleContext<CypherParser::KU_ColumnDefinitionContext>(0);
+}
+
+std::vector<tree::TerminalNode *> CypherParser::KU_PropertyDefinitionContext::SP() {
+  return getTokens(CypherParser::SP);
+}
+
+tree::TerminalNode*
CypherParser::KU_PropertyDefinitionContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +CypherParser::KU_DefaultContext* CypherParser::KU_PropertyDefinitionContext::kU_Default() { + return getRuleContext(0); +} + +tree::TerminalNode* CypherParser::KU_PropertyDefinitionContext::PRIMARY() { + return getToken(CypherParser::PRIMARY, 0); +} + +tree::TerminalNode* CypherParser::KU_PropertyDefinitionContext::KEY() { + return getToken(CypherParser::KEY, 0); +} + + +size_t CypherParser::KU_PropertyDefinitionContext::getRuleIndex() const { + return CypherParser::RuleKU_PropertyDefinition; +} + + +CypherParser::KU_PropertyDefinitionContext* CypherParser::kU_PropertyDefinition() { + KU_PropertyDefinitionContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 102, CypherParser::RuleKU_PropertyDefinition); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1162); + kU_ColumnDefinition(); + setState(1165); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 134, _ctx)) { + case 1: { + setState(1163); + match(CypherParser::SP); + setState(1164); + kU_Default(); + break; + } + + default: + break; + } + setState(1171); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 135, _ctx)) { + case 1: { + setState(1167); + match(CypherParser::SP); + setState(1168); + match(CypherParser::PRIMARY); + setState(1169); + match(CypherParser::SP); + setState(1170); + match(CypherParser::KEY); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_CreateNodeConstraintContext ------------------------------------------------------------------ + +CypherParser::KU_CreateNodeConstraintContext::KU_CreateNodeConstraintContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_CreateNodeConstraintContext::PRIMARY() { + return getToken(CypherParser::PRIMARY, 0); +} + +std::vector CypherParser::KU_CreateNodeConstraintContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_CreateNodeConstraintContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +tree::TerminalNode* CypherParser::KU_CreateNodeConstraintContext::KEY() { + return getToken(CypherParser::KEY, 0); +} + +CypherParser::OC_PropertyKeyNameContext* CypherParser::KU_CreateNodeConstraintContext::oC_PropertyKeyName() { + return getRuleContext(0); +} + + +size_t CypherParser::KU_CreateNodeConstraintContext::getRuleIndex() const { + return CypherParser::RuleKU_CreateNodeConstraint; +} + + +CypherParser::KU_CreateNodeConstraintContext* CypherParser::kU_CreateNodeConstraint() { + KU_CreateNodeConstraintContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 104, CypherParser::RuleKU_CreateNodeConstraint); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1173); + match(CypherParser::PRIMARY); + setState(1174); + match(CypherParser::SP); + setState(1175); + match(CypherParser::KEY); + setState(1177); + 
_errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1176); + match(CypherParser::SP); + } + setState(1179); + match(CypherParser::T__1); + setState(1181); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1180); + match(CypherParser::SP); + } + setState(1183); + oC_PropertyKeyName(); + setState(1185); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1184); + match(CypherParser::SP); + } + setState(1187); + match(CypherParser::T__2); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_UnionTypeContext ------------------------------------------------------------------ + +CypherParser::KU_UnionTypeContext::KU_UnionTypeContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_UnionTypeContext::UNION() { + return getToken(CypherParser::UNION, 0); +} + +CypherParser::KU_ColumnDefinitionsContext* CypherParser::KU_UnionTypeContext::kU_ColumnDefinitions() { + return getRuleContext(0); +} + +std::vector CypherParser::KU_UnionTypeContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_UnionTypeContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + + +size_t CypherParser::KU_UnionTypeContext::getRuleIndex() const { + return CypherParser::RuleKU_UnionType; +} + + +CypherParser::KU_UnionTypeContext* CypherParser::kU_UnionType() { + KU_UnionTypeContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 106, CypherParser::RuleKU_UnionType); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1189); + match(CypherParser::UNION); + setState(1191); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1190); + match(CypherParser::SP); + } + setState(1193); + match(CypherParser::T__1); + setState(1195); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1194); + match(CypherParser::SP); + } + setState(1197); + kU_ColumnDefinitions(); + setState(1199); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1198); + match(CypherParser::SP); + } + setState(1201); + match(CypherParser::T__2); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_StructTypeContext ------------------------------------------------------------------ + +CypherParser::KU_StructTypeContext::KU_StructTypeContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_StructTypeContext::STRUCT() { + return getToken(CypherParser::STRUCT, 0); +} + +CypherParser::KU_ColumnDefinitionsContext* CypherParser::KU_StructTypeContext::kU_ColumnDefinitions() { + return getRuleContext(0); +} + +std::vector CypherParser::KU_StructTypeContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* 
CypherParser::KU_StructTypeContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + + +size_t CypherParser::KU_StructTypeContext::getRuleIndex() const { + return CypherParser::RuleKU_StructType; +} + + +CypherParser::KU_StructTypeContext* CypherParser::kU_StructType() { + KU_StructTypeContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 108, CypherParser::RuleKU_StructType); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1203); + match(CypherParser::STRUCT); + setState(1205); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1204); + match(CypherParser::SP); + } + setState(1207); + match(CypherParser::T__1); + setState(1209); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1208); + match(CypherParser::SP); + } + setState(1211); + kU_ColumnDefinitions(); + setState(1213); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1212); + match(CypherParser::SP); + } + setState(1215); + match(CypherParser::T__2); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_MapTypeContext ------------------------------------------------------------------ + +CypherParser::KU_MapTypeContext::KU_MapTypeContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_MapTypeContext::MAP() { + return getToken(CypherParser::MAP, 0); +} + +std::vector CypherParser::KU_MapTypeContext::kU_DataType() { + return getRuleContexts(); +} + +CypherParser::KU_DataTypeContext* CypherParser::KU_MapTypeContext::kU_DataType(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::KU_MapTypeContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_MapTypeContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + + +size_t CypherParser::KU_MapTypeContext::getRuleIndex() const { + return CypherParser::RuleKU_MapType; +} + + +CypherParser::KU_MapTypeContext* CypherParser::kU_MapType() { + KU_MapTypeContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 110, CypherParser::RuleKU_MapType); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1217); + match(CypherParser::MAP); + setState(1219); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1218); + match(CypherParser::SP); + } + setState(1221); + match(CypherParser::T__1); + setState(1223); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1222); + match(CypherParser::SP); + } + setState(1225); + kU_DataType(0); + setState(1227); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1226); + match(CypherParser::SP); + } + setState(1229); + match(CypherParser::T__3); + setState(1231); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1230); + match(CypherParser::SP); + } + 
setState(1233); + kU_DataType(0); + setState(1235); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1234); + match(CypherParser::SP); + } + setState(1237); + match(CypherParser::T__2); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_DecimalTypeContext ------------------------------------------------------------------ + +CypherParser::KU_DecimalTypeContext::KU_DecimalTypeContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_DecimalTypeContext::DECIMAL() { + return getToken(CypherParser::DECIMAL, 0); +} + +std::vector CypherParser::KU_DecimalTypeContext::oC_IntegerLiteral() { + return getRuleContexts(); +} + +CypherParser::OC_IntegerLiteralContext* CypherParser::KU_DecimalTypeContext::oC_IntegerLiteral(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::KU_DecimalTypeContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_DecimalTypeContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + + +size_t CypherParser::KU_DecimalTypeContext::getRuleIndex() const { + return CypherParser::RuleKU_DecimalType; +} + + +CypherParser::KU_DecimalTypeContext* CypherParser::kU_DecimalType() { + KU_DecimalTypeContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 112, CypherParser::RuleKU_DecimalType); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1239); + match(CypherParser::DECIMAL); + setState(1241); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1240); + match(CypherParser::SP); + } + setState(1243); + match(CypherParser::T__1); + setState(1245); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1244); + match(CypherParser::SP); + } + setState(1247); + oC_IntegerLiteral(); + setState(1249); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1248); + match(CypherParser::SP); + } + setState(1251); + match(CypherParser::T__3); + setState(1253); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1252); + match(CypherParser::SP); + } + setState(1255); + oC_IntegerLiteral(); + setState(1257); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1256); + match(CypherParser::SP); + } + setState(1259); + match(CypherParser::T__2); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_DataTypeContext ------------------------------------------------------------------ + +CypherParser::KU_DataTypeContext::KU_DataTypeContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_SymbolicNameContext* CypherParser::KU_DataTypeContext::oC_SymbolicName() { + return getRuleContext(0); +} + +CypherParser::KU_UnionTypeContext* CypherParser::KU_DataTypeContext::kU_UnionType() { + 
return getRuleContext(0); +} + +CypherParser::KU_StructTypeContext* CypherParser::KU_DataTypeContext::kU_StructType() { + return getRuleContext(0); +} + +CypherParser::KU_MapTypeContext* CypherParser::KU_DataTypeContext::kU_MapType() { + return getRuleContext(0); +} + +CypherParser::KU_DecimalTypeContext* CypherParser::KU_DataTypeContext::kU_DecimalType() { + return getRuleContext(0); +} + +CypherParser::KU_DataTypeContext* CypherParser::KU_DataTypeContext::kU_DataType() { + return getRuleContext(0); +} + +CypherParser::KU_ListIdentifiersContext* CypherParser::KU_DataTypeContext::kU_ListIdentifiers() { + return getRuleContext(0); +} + + +size_t CypherParser::KU_DataTypeContext::getRuleIndex() const { + return CypherParser::RuleKU_DataType; +} + + + +CypherParser::KU_DataTypeContext* CypherParser::kU_DataType() { + return kU_DataType(0); +} + +CypherParser::KU_DataTypeContext* CypherParser::kU_DataType(int precedence) { + ParserRuleContext *parentContext = _ctx; + size_t parentState = getState(); + CypherParser::KU_DataTypeContext *_localctx = _tracker.createInstance(_ctx, parentState); + CypherParser::KU_DataTypeContext *previousContext = _localctx; + (void)previousContext; // Silence compiler, in case the context is not used by generated code. + size_t startState = 114; + enterRecursionRule(_localctx, 114, CypherParser::RuleKU_DataType, precedence); + + + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + unrollRecursionContexts(parentContext); + }); + try { + size_t alt; + enterOuterAlt(_localctx, 1); + setState(1267); + _errHandler->sync(this); + switch (getInterpreter()->adaptivePredict(_input, 155, _ctx)) { + case 1: { + setState(1262); + oC_SymbolicName(); + break; + } + + case 2: { + setState(1263); + kU_UnionType(); + break; + } + + case 3: { + setState(1264); + kU_StructType(); + break; + } + + case 4: { + setState(1265); + kU_MapType(); + break; + } + + case 5: { + setState(1266); + kU_DecimalType(); + break; + } + + default: + break; + } + _ctx->stop = _input->LT(-1); + setState(1273); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 156, _ctx); + while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER) { + if (alt == 1) { + if (!_parseListeners.empty()) + triggerExitRuleEvent(); + previousContext = _localctx; + _localctx = _tracker.createInstance(parentContext, parentState); + pushNewRecursionContext(_localctx, startState, RuleKU_DataType); + setState(1269); + + if (!(precpred(_ctx, 5))) throw FailedPredicateException(this, "precpred(_ctx, 5)"); + setState(1270); + kU_ListIdentifiers(); + } + setState(1275); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 156, _ctx); + } + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + return _localctx; +} + +//----------------- KU_ListIdentifiersContext ------------------------------------------------------------------ + +CypherParser::KU_ListIdentifiersContext::KU_ListIdentifiersContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +std::vector CypherParser::KU_ListIdentifiersContext::kU_ListIdentifier() { + return getRuleContexts(); +} + +CypherParser::KU_ListIdentifierContext* CypherParser::KU_ListIdentifiersContext::kU_ListIdentifier(size_t i) { + return getRuleContext(i); +} + + +size_t 
CypherParser::KU_ListIdentifiersContext::getRuleIndex() const { + return CypherParser::RuleKU_ListIdentifiers; +} + + +CypherParser::KU_ListIdentifiersContext* CypherParser::kU_ListIdentifiers() { + KU_ListIdentifiersContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 116, CypherParser::RuleKU_ListIdentifiers); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + size_t alt; + enterOuterAlt(_localctx, 1); + setState(1276); + kU_ListIdentifier(); + setState(1280); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 157, _ctx); + while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER) { + if (alt == 1) { + setState(1277); + kU_ListIdentifier(); + } + setState(1282); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 157, _ctx); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_ListIdentifierContext ------------------------------------------------------------------ + +CypherParser::KU_ListIdentifierContext::KU_ListIdentifierContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_IntegerLiteralContext* CypherParser::KU_ListIdentifierContext::oC_IntegerLiteral() { + return getRuleContext(0); +} + + +size_t CypherParser::KU_ListIdentifierContext::getRuleIndex() const { + return CypherParser::RuleKU_ListIdentifier; +} + + +CypherParser::KU_ListIdentifierContext* CypherParser::kU_ListIdentifier() { + KU_ListIdentifierContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 118, CypherParser::RuleKU_ListIdentifier); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1283); + match(CypherParser::T__6); + setState(1285); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::DecimalInteger) { + setState(1284); + oC_IntegerLiteral(); + } + setState(1287); + match(CypherParser::T__7); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_AnyCypherOptionContext ------------------------------------------------------------------ + +CypherParser::OC_AnyCypherOptionContext::OC_AnyCypherOptionContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_ExplainContext* CypherParser::OC_AnyCypherOptionContext::oC_Explain() { + return getRuleContext(0); +} + +CypherParser::OC_ProfileContext* CypherParser::OC_AnyCypherOptionContext::oC_Profile() { + return getRuleContext(0); +} + + +size_t CypherParser::OC_AnyCypherOptionContext::getRuleIndex() const { + return CypherParser::RuleOC_AnyCypherOption; +} + + +CypherParser::OC_AnyCypherOptionContext* CypherParser::oC_AnyCypherOption() { + OC_AnyCypherOptionContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 120, CypherParser::RuleOC_AnyCypherOption); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + 
exitRule(); + }); + try { + setState(1291); + _errHandler->sync(this); + switch (_input->LA(1)) { + case CypherParser::EXPLAIN: { + enterOuterAlt(_localctx, 1); + setState(1289); + oC_Explain(); + break; + } + + case CypherParser::PROFILE: { + enterOuterAlt(_localctx, 2); + setState(1290); + oC_Profile(); + break; + } + + default: + throw NoViableAltException(this); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_ExplainContext ------------------------------------------------------------------ + +CypherParser::OC_ExplainContext::OC_ExplainContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::OC_ExplainContext::EXPLAIN() { + return getToken(CypherParser::EXPLAIN, 0); +} + +tree::TerminalNode* CypherParser::OC_ExplainContext::SP() { + return getToken(CypherParser::SP, 0); +} + +tree::TerminalNode* CypherParser::OC_ExplainContext::LOGICAL() { + return getToken(CypherParser::LOGICAL, 0); +} + + +size_t CypherParser::OC_ExplainContext::getRuleIndex() const { + return CypherParser::RuleOC_Explain; +} + + +CypherParser::OC_ExplainContext* CypherParser::oC_Explain() { + OC_ExplainContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 122, CypherParser::RuleOC_Explain); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1293); + match(CypherParser::EXPLAIN); + setState(1296); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 160, _ctx)) { + case 1: { + setState(1294); + match(CypherParser::SP); + setState(1295); + match(CypherParser::LOGICAL); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_ProfileContext ------------------------------------------------------------------ + +CypherParser::OC_ProfileContext::OC_ProfileContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::OC_ProfileContext::PROFILE() { + return getToken(CypherParser::PROFILE, 0); +} + + +size_t CypherParser::OC_ProfileContext::getRuleIndex() const { + return CypherParser::RuleOC_Profile; +} + + +CypherParser::OC_ProfileContext* CypherParser::oC_Profile() { + OC_ProfileContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 124, CypherParser::RuleOC_Profile); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1298); + match(CypherParser::PROFILE); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_TransactionContext ------------------------------------------------------------------ + +CypherParser::KU_TransactionContext::KU_TransactionContext(ParserRuleContext *parent, size_t invokingState) + : 
ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_TransactionContext::BEGIN() { + return getToken(CypherParser::BEGIN, 0); +} + +std::vector CypherParser::KU_TransactionContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_TransactionContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +tree::TerminalNode* CypherParser::KU_TransactionContext::TRANSACTION() { + return getToken(CypherParser::TRANSACTION, 0); +} + +tree::TerminalNode* CypherParser::KU_TransactionContext::READ() { + return getToken(CypherParser::READ, 0); +} + +tree::TerminalNode* CypherParser::KU_TransactionContext::ONLY() { + return getToken(CypherParser::ONLY, 0); +} + +tree::TerminalNode* CypherParser::KU_TransactionContext::COMMIT() { + return getToken(CypherParser::COMMIT, 0); +} + +tree::TerminalNode* CypherParser::KU_TransactionContext::ROLLBACK() { + return getToken(CypherParser::ROLLBACK, 0); +} + +tree::TerminalNode* CypherParser::KU_TransactionContext::CHECKPOINT() { + return getToken(CypherParser::CHECKPOINT, 0); +} + + +size_t CypherParser::KU_TransactionContext::getRuleIndex() const { + return CypherParser::RuleKU_Transaction; +} + + +CypherParser::KU_TransactionContext* CypherParser::kU_Transaction() { + KU_TransactionContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 126, CypherParser::RuleKU_Transaction); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + setState(1313); + _errHandler->sync(this); + switch (getInterpreter()->adaptivePredict(_input, 161, _ctx)) { + case 1: { + enterOuterAlt(_localctx, 1); + setState(1300); + match(CypherParser::BEGIN); + setState(1301); + match(CypherParser::SP); + setState(1302); + match(CypherParser::TRANSACTION); + break; + } + + case 2: { + enterOuterAlt(_localctx, 2); + setState(1303); + match(CypherParser::BEGIN); + setState(1304); + match(CypherParser::SP); + setState(1305); + match(CypherParser::TRANSACTION); + setState(1306); + match(CypherParser::SP); + setState(1307); + match(CypherParser::READ); + setState(1308); + match(CypherParser::SP); + setState(1309); + match(CypherParser::ONLY); + break; + } + + case 3: { + enterOuterAlt(_localctx, 3); + setState(1310); + match(CypherParser::COMMIT); + break; + } + + case 4: { + enterOuterAlt(_localctx, 4); + setState(1311); + match(CypherParser::ROLLBACK); + break; + } + + case 5: { + enterOuterAlt(_localctx, 5); + setState(1312); + match(CypherParser::CHECKPOINT); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_ExtensionContext ------------------------------------------------------------------ + +CypherParser::KU_ExtensionContext::KU_ExtensionContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::KU_LoadExtensionContext* CypherParser::KU_ExtensionContext::kU_LoadExtension() { + return getRuleContext(0); +} + +CypherParser::KU_InstallExtensionContext* CypherParser::KU_ExtensionContext::kU_InstallExtension() { + return getRuleContext(0); +} + +CypherParser::KU_UninstallExtensionContext* CypherParser::KU_ExtensionContext::kU_UninstallExtension() { + return getRuleContext(0); +} + 
+CypherParser::KU_UpdateExtensionContext* CypherParser::KU_ExtensionContext::kU_UpdateExtension() { + return getRuleContext(0); +} + + +size_t CypherParser::KU_ExtensionContext::getRuleIndex() const { + return CypherParser::RuleKU_Extension; +} + + +CypherParser::KU_ExtensionContext* CypherParser::kU_Extension() { + KU_ExtensionContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 128, CypherParser::RuleKU_Extension); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + setState(1319); + _errHandler->sync(this); + switch (_input->LA(1)) { + case CypherParser::LOAD: { + enterOuterAlt(_localctx, 1); + setState(1315); + kU_LoadExtension(); + break; + } + + case CypherParser::FORCE: + case CypherParser::INSTALL: { + enterOuterAlt(_localctx, 2); + setState(1316); + kU_InstallExtension(); + break; + } + + case CypherParser::UNINSTALL: { + enterOuterAlt(_localctx, 3); + setState(1317); + kU_UninstallExtension(); + break; + } + + case CypherParser::UPDATE: { + enterOuterAlt(_localctx, 4); + setState(1318); + kU_UpdateExtension(); + break; + } + + default: + throw NoViableAltException(this); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_LoadExtensionContext ------------------------------------------------------------------ + +CypherParser::KU_LoadExtensionContext::KU_LoadExtensionContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_LoadExtensionContext::LOAD() { + return getToken(CypherParser::LOAD, 0); +} + +std::vector CypherParser::KU_LoadExtensionContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_LoadExtensionContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +tree::TerminalNode* CypherParser::KU_LoadExtensionContext::StringLiteral() { + return getToken(CypherParser::StringLiteral, 0); +} + +CypherParser::OC_VariableContext* CypherParser::KU_LoadExtensionContext::oC_Variable() { + return getRuleContext(0); +} + +tree::TerminalNode* CypherParser::KU_LoadExtensionContext::EXTENSION() { + return getToken(CypherParser::EXTENSION, 0); +} + + +size_t CypherParser::KU_LoadExtensionContext::getRuleIndex() const { + return CypherParser::RuleKU_LoadExtension; +} + + +CypherParser::KU_LoadExtensionContext* CypherParser::kU_LoadExtension() { + KU_LoadExtensionContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 130, CypherParser::RuleKU_LoadExtension); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1321); + match(CypherParser::LOAD); + setState(1322); + match(CypherParser::SP); + setState(1325); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 163, _ctx)) { + case 1: { + setState(1323); + match(CypherParser::EXTENSION); + setState(1324); + match(CypherParser::SP); + break; + } + + default: + break; + } + setState(1329); + _errHandler->sync(this); + switch (_input->LA(1)) { + case CypherParser::StringLiteral: { + setState(1327); + match(CypherParser::StringLiteral); + break; + } + + case CypherParser::ADD: + case CypherParser::ALTER: + case 
CypherParser::AS: + case CypherParser::ATTACH: + case CypherParser::BEGIN: + case CypherParser::BY: + case CypherParser::CALL: + case CypherParser::CHECKPOINT: + case CypherParser::COMMENT: + case CypherParser::COMMIT: + case CypherParser::CONTAINS: + case CypherParser::COPY: + case CypherParser::COUNT: + case CypherParser::CYCLE: + case CypherParser::DATABASE: + case CypherParser::DELETE: + case CypherParser::DETACH: + case CypherParser::DROP: + case CypherParser::EXPLAIN: + case CypherParser::EXPORT: + case CypherParser::EXTENSION: + case CypherParser::FROM: + case CypherParser::FORCE: + case CypherParser::GRAPH: + case CypherParser::IMPORT: + case CypherParser::IF: + case CypherParser::INCREMENT: + case CypherParser::IS: + case CypherParser::KEY: + case CypherParser::LIMIT: + case CypherParser::LOAD: + case CypherParser::LOGICAL: + case CypherParser::MATCH: + case CypherParser::MAXVALUE: + case CypherParser::MERGE: + case CypherParser::MINVALUE: + case CypherParser::NO: + case CypherParser::NODE: + case CypherParser::PROJECT: + case CypherParser::READ: + case CypherParser::REL: + case CypherParser::RENAME: + case CypherParser::RETURN: + case CypherParser::ROLLBACK: + case CypherParser::SEQUENCE: + case CypherParser::SET: + case CypherParser::START: + case CypherParser::STRUCT: + case CypherParser::TO: + case CypherParser::TRANSACTION: + case CypherParser::TYPE: + case CypherParser::UNINSTALL: + case CypherParser::UPDATE: + case CypherParser::USE: + case CypherParser::WRITE: + case CypherParser::YIELD: + case CypherParser::USER: + case CypherParser::PASSWORD: + case CypherParser::ROLE: + case CypherParser::MAP: + case CypherParser::DECIMAL: + case CypherParser::L_SKIP: + case CypherParser::HexLetter: + case CypherParser::UnescapedSymbolicName: + case CypherParser::EscapedSymbolicName: { + setState(1328); + oC_Variable(); + break; + } + + default: + throw NoViableAltException(this); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_InstallExtensionContext ------------------------------------------------------------------ + +CypherParser::KU_InstallExtensionContext::KU_InstallExtensionContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_InstallExtensionContext::INSTALL() { + return getToken(CypherParser::INSTALL, 0); +} + +std::vector CypherParser::KU_InstallExtensionContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_InstallExtensionContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +CypherParser::OC_VariableContext* CypherParser::KU_InstallExtensionContext::oC_Variable() { + return getRuleContext(0); +} + +tree::TerminalNode* CypherParser::KU_InstallExtensionContext::FORCE() { + return getToken(CypherParser::FORCE, 0); +} + +tree::TerminalNode* CypherParser::KU_InstallExtensionContext::FROM() { + return getToken(CypherParser::FROM, 0); +} + +tree::TerminalNode* CypherParser::KU_InstallExtensionContext::StringLiteral() { + return getToken(CypherParser::StringLiteral, 0); +} + + +size_t CypherParser::KU_InstallExtensionContext::getRuleIndex() const { + return CypherParser::RuleKU_InstallExtension; +} + + +CypherParser::KU_InstallExtensionContext* CypherParser::kU_InstallExtension() { + KU_InstallExtensionContext *_localctx = 
_tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 132, CypherParser::RuleKU_InstallExtension); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1333); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::FORCE) { + setState(1331); + match(CypherParser::FORCE); + setState(1332); + match(CypherParser::SP); + } + setState(1335); + match(CypherParser::INSTALL); + setState(1336); + match(CypherParser::SP); + setState(1337); + oC_Variable(); + setState(1342); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 166, _ctx)) { + case 1: { + setState(1338); + match(CypherParser::SP); + setState(1339); + match(CypherParser::FROM); + setState(1340); + match(CypherParser::SP); + setState(1341); + match(CypherParser::StringLiteral); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_UninstallExtensionContext ------------------------------------------------------------------ + +CypherParser::KU_UninstallExtensionContext::KU_UninstallExtensionContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_UninstallExtensionContext::UNINSTALL() { + return getToken(CypherParser::UNINSTALL, 0); +} + +tree::TerminalNode* CypherParser::KU_UninstallExtensionContext::SP() { + return getToken(CypherParser::SP, 0); +} + +CypherParser::OC_VariableContext* CypherParser::KU_UninstallExtensionContext::oC_Variable() { + return getRuleContext(0); +} + + +size_t CypherParser::KU_UninstallExtensionContext::getRuleIndex() const { + return CypherParser::RuleKU_UninstallExtension; +} + + +CypherParser::KU_UninstallExtensionContext* CypherParser::kU_UninstallExtension() { + KU_UninstallExtensionContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 134, CypherParser::RuleKU_UninstallExtension); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1344); + match(CypherParser::UNINSTALL); + setState(1345); + match(CypherParser::SP); + setState(1346); + oC_Variable(); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_UpdateExtensionContext ------------------------------------------------------------------ + +CypherParser::KU_UpdateExtensionContext::KU_UpdateExtensionContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_UpdateExtensionContext::UPDATE() { + return getToken(CypherParser::UPDATE, 0); +} + +tree::TerminalNode* CypherParser::KU_UpdateExtensionContext::SP() { + return getToken(CypherParser::SP, 0); +} + +CypherParser::OC_VariableContext* CypherParser::KU_UpdateExtensionContext::oC_Variable() { + return getRuleContext(0); +} + + +size_t CypherParser::KU_UpdateExtensionContext::getRuleIndex() const { + return CypherParser::RuleKU_UpdateExtension; +} + + 
+CypherParser::KU_UpdateExtensionContext* CypherParser::kU_UpdateExtension() {
+  KU_UpdateExtensionContext *_localctx = _tracker.createInstance<KU_UpdateExtensionContext>(_ctx, getState());
+  enterRule(_localctx, 136, CypherParser::RuleKU_UpdateExtension);
+
+#if __cplusplus > 201703L
+  auto onExit = finally([=, this] {
+#else
+  auto onExit = finally([=] {
+#endif
+    exitRule();
+  });
+  try {
+    enterOuterAlt(_localctx, 1);
+    setState(1348);
+    match(CypherParser::UPDATE);
+    setState(1349);
+    match(CypherParser::SP);
+    setState(1350);
+    oC_Variable();
+
+  }
+  catch (RecognitionException &e) {
+    _errHandler->reportError(this, e);
+    _localctx->exception = std::current_exception();
+    _errHandler->recover(this, _localctx->exception);
+  }
+
+  return _localctx;
+}
+
+//----------------- OC_QueryContext ------------------------------------------------------------------
+
+CypherParser::OC_QueryContext::OC_QueryContext(ParserRuleContext *parent, size_t invokingState)
+  : ParserRuleContext(parent, invokingState) {
+}
+
+CypherParser::OC_RegularQueryContext* CypherParser::OC_QueryContext::oC_RegularQuery() {
+  return getRuleContext<CypherParser::OC_RegularQueryContext>(0);
+}
+
+
+size_t CypherParser::OC_QueryContext::getRuleIndex() const {
+  return CypherParser::RuleOC_Query;
+}
+
+
+CypherParser::OC_QueryContext* CypherParser::oC_Query() {
+  OC_QueryContext *_localctx = _tracker.createInstance<OC_QueryContext>(_ctx, getState());
+  enterRule(_localctx, 138, CypherParser::RuleOC_Query);
+
+#if __cplusplus > 201703L
+  auto onExit = finally([=, this] {
+#else
+  auto onExit = finally([=] {
+#endif
+    exitRule();
+  });
+  try {
+    enterOuterAlt(_localctx, 1);
+    setState(1352);
+    oC_RegularQuery();
+
+  }
+  catch (RecognitionException &e) {
+    _errHandler->reportError(this, e);
+    _localctx->exception = std::current_exception();
+    _errHandler->recover(this, _localctx->exception);
+  }
+
+  return _localctx;
+}
+
+//----------------- OC_RegularQueryContext ------------------------------------------------------------------
+
+CypherParser::OC_RegularQueryContext::OC_RegularQueryContext(ParserRuleContext *parent, size_t invokingState)
+  : ParserRuleContext(parent, invokingState) {
+}
+
+CypherParser::OC_SingleQueryContext* CypherParser::OC_RegularQueryContext::oC_SingleQuery() {
+  return getRuleContext<CypherParser::OC_SingleQueryContext>(0);
+}
+
+std::vector<CypherParser::OC_UnionContext *> CypherParser::OC_RegularQueryContext::oC_Union() {
+  return getRuleContexts<CypherParser::OC_UnionContext>();
+}
+
+CypherParser::OC_UnionContext* CypherParser::OC_RegularQueryContext::oC_Union(size_t i) {
+  return getRuleContext<CypherParser::OC_UnionContext>(i);
+}
+
+std::vector<tree::TerminalNode *> CypherParser::OC_RegularQueryContext::SP() {
+  return getTokens(CypherParser::SP);
+}
+
+tree::TerminalNode* CypherParser::OC_RegularQueryContext::SP(size_t i) {
+  return getToken(CypherParser::SP, i);
+}
+
+std::vector<CypherParser::OC_ReturnContext *> CypherParser::OC_RegularQueryContext::oC_Return() {
+  return getRuleContexts<CypherParser::OC_ReturnContext>();
+}
+
+CypherParser::OC_ReturnContext* CypherParser::OC_RegularQueryContext::oC_Return(size_t i) {
+  return getRuleContext<CypherParser::OC_ReturnContext>(i);
+}
+
+
+size_t CypherParser::OC_RegularQueryContext::getRuleIndex() const {
+  return CypherParser::RuleOC_RegularQuery;
+}
+
+
+CypherParser::OC_RegularQueryContext* CypherParser::oC_RegularQuery() {
+  OC_RegularQueryContext *_localctx = _tracker.createInstance<OC_RegularQueryContext>(_ctx, getState());
+  enterRule(_localctx, 140, CypherParser::RuleOC_RegularQuery);
+  size_t _la = 0;
+
+#if __cplusplus > 201703L
+  auto onExit = finally([=, this] {
+#else
+  auto onExit = finally([=] {
+#endif
+    exitRule();
+  });
+  try {
+    size_t alt;
+    setState(1375);
+    _errHandler->sync(this);
+    switch (getInterpreter<atn::ParserATNSimulator>()->adaptivePredict(_input, 171, _ctx)) {
+    case 1: {
+
enterOuterAlt(_localctx, 1); + setState(1354); + oC_SingleQuery(); + setState(1361); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 168, _ctx); + while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER) { + if (alt == 1) { + setState(1356); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1355); + match(CypherParser::SP); + } + setState(1358); + oC_Union(); + } + setState(1363); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 168, _ctx); + } + break; + } + + case 2: { + enterOuterAlt(_localctx, 2); + setState(1368); + _errHandler->sync(this); + alt = 1; + do { + switch (alt) { + case 1: { + setState(1364); + oC_Return(); + setState(1366); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1365); + match(CypherParser::SP); + } + break; + } + + default: + throw NoViableAltException(this); + } + setState(1370); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 170, _ctx); + } while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER); + setState(1372); + oC_SingleQuery(); + notifyReturnNotAtEnd(_localctx->start); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_UnionContext ------------------------------------------------------------------ + +CypherParser::OC_UnionContext::OC_UnionContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::OC_UnionContext::UNION() { + return getToken(CypherParser::UNION, 0); +} + +std::vector CypherParser::OC_UnionContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_UnionContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +tree::TerminalNode* CypherParser::OC_UnionContext::ALL() { + return getToken(CypherParser::ALL, 0); +} + +CypherParser::OC_SingleQueryContext* CypherParser::OC_UnionContext::oC_SingleQuery() { + return getRuleContext(0); +} + + +size_t CypherParser::OC_UnionContext::getRuleIndex() const { + return CypherParser::RuleOC_Union; +} + + +CypherParser::OC_UnionContext* CypherParser::oC_Union() { + OC_UnionContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 142, CypherParser::RuleOC_Union); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + setState(1389); + _errHandler->sync(this); + switch (getInterpreter()->adaptivePredict(_input, 174, _ctx)) { + case 1: { + enterOuterAlt(_localctx, 1); + setState(1377); + match(CypherParser::UNION); + setState(1378); + match(CypherParser::SP); + setState(1379); + match(CypherParser::ALL); + setState(1381); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1380); + match(CypherParser::SP); + } + setState(1383); + oC_SingleQuery(); + break; + } + + case 2: { + enterOuterAlt(_localctx, 2); + setState(1384); + match(CypherParser::UNION); + setState(1386); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1385); + match(CypherParser::SP); + } + setState(1388); + oC_SingleQuery(); + break; + } + + default: + break; + } + + } + catch 
(RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_SingleQueryContext ------------------------------------------------------------------ + +CypherParser::OC_SingleQueryContext::OC_SingleQueryContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_SinglePartQueryContext* CypherParser::OC_SingleQueryContext::oC_SinglePartQuery() { + return getRuleContext(0); +} + +CypherParser::OC_MultiPartQueryContext* CypherParser::OC_SingleQueryContext::oC_MultiPartQuery() { + return getRuleContext(0); +} + + +size_t CypherParser::OC_SingleQueryContext::getRuleIndex() const { + return CypherParser::RuleOC_SingleQuery; +} + + +CypherParser::OC_SingleQueryContext* CypherParser::oC_SingleQuery() { + OC_SingleQueryContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 144, CypherParser::RuleOC_SingleQuery); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + setState(1393); + _errHandler->sync(this); + switch (getInterpreter()->adaptivePredict(_input, 175, _ctx)) { + case 1: { + enterOuterAlt(_localctx, 1); + setState(1391); + oC_SinglePartQuery(); + break; + } + + case 2: { + enterOuterAlt(_localctx, 2); + setState(1392); + oC_MultiPartQuery(); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_SinglePartQueryContext ------------------------------------------------------------------ + +CypherParser::OC_SinglePartQueryContext::OC_SinglePartQueryContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_ReturnContext* CypherParser::OC_SinglePartQueryContext::oC_Return() { + return getRuleContext(0); +} + +std::vector CypherParser::OC_SinglePartQueryContext::oC_ReadingClause() { + return getRuleContexts(); +} + +CypherParser::OC_ReadingClauseContext* CypherParser::OC_SinglePartQueryContext::oC_ReadingClause(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::OC_SinglePartQueryContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_SinglePartQueryContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +std::vector CypherParser::OC_SinglePartQueryContext::oC_UpdatingClause() { + return getRuleContexts(); +} + +CypherParser::OC_UpdatingClauseContext* CypherParser::OC_SinglePartQueryContext::oC_UpdatingClause(size_t i) { + return getRuleContext(i); +} + + +size_t CypherParser::OC_SinglePartQueryContext::getRuleIndex() const { + return CypherParser::RuleOC_SinglePartQuery; +} + + +CypherParser::OC_SinglePartQueryContext* CypherParser::oC_SinglePartQuery() { + OC_SinglePartQueryContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 146, CypherParser::RuleOC_SinglePartQuery); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + size_t alt; + setState(1430); + _errHandler->sync(this); + switch (getInterpreter()->adaptivePredict(_input, 184, _ctx)) { + case 1: { + 
enterOuterAlt(_localctx, 1); + setState(1401); + _errHandler->sync(this); + _la = _input->LA(1); + while (_la == CypherParser::CALL || ((((_la - 103) & ~ 0x3fULL) == 0) && + ((1ULL << (_la - 103)) & 2199023288329) != 0)) { + setState(1395); + oC_ReadingClause(); + setState(1397); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1396); + match(CypherParser::SP); + } + setState(1403); + _errHandler->sync(this); + _la = _input->LA(1); + } + setState(1404); + oC_Return(); + break; + } + + case 2: { + enterOuterAlt(_localctx, 2); + setState(1411); + _errHandler->sync(this); + _la = _input->LA(1); + while (_la == CypherParser::CALL || ((((_la - 103) & ~ 0x3fULL) == 0) && + ((1ULL << (_la - 103)) & 2199023288329) != 0)) { + setState(1405); + oC_ReadingClause(); + setState(1407); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1406); + match(CypherParser::SP); + } + setState(1413); + _errHandler->sync(this); + _la = _input->LA(1); + } + setState(1414); + oC_UpdatingClause(); + setState(1421); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 181, _ctx); + while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER) { + if (alt == 1) { + setState(1416); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1415); + match(CypherParser::SP); + } + setState(1418); + oC_UpdatingClause(); + } + setState(1423); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 181, _ctx); + } + setState(1428); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 183, _ctx)) { + case 1: { + setState(1425); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1424); + match(CypherParser::SP); + } + setState(1427); + oC_Return(); + break; + } + + default: + break; + } + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_MultiPartQueryContext ------------------------------------------------------------------ + +CypherParser::OC_MultiPartQueryContext::OC_MultiPartQueryContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_SinglePartQueryContext* CypherParser::OC_MultiPartQueryContext::oC_SinglePartQuery() { + return getRuleContext(0); +} + +std::vector CypherParser::OC_MultiPartQueryContext::kU_QueryPart() { + return getRuleContexts(); +} + +CypherParser::KU_QueryPartContext* CypherParser::OC_MultiPartQueryContext::kU_QueryPart(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::OC_MultiPartQueryContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_MultiPartQueryContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + + +size_t CypherParser::OC_MultiPartQueryContext::getRuleIndex() const { + return CypherParser::RuleOC_MultiPartQuery; +} + + +CypherParser::OC_MultiPartQueryContext* CypherParser::oC_MultiPartQuery() { + OC_MultiPartQueryContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 148, CypherParser::RuleOC_MultiPartQuery); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + 
try { + size_t alt; + enterOuterAlt(_localctx, 1); + setState(1436); + _errHandler->sync(this); + alt = 1; + do { + switch (alt) { + case 1: { + setState(1432); + kU_QueryPart(); + setState(1434); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1433); + match(CypherParser::SP); + } + break; + } + + default: + throw NoViableAltException(this); + } + setState(1438); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 186, _ctx); + } while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER); + setState(1440); + oC_SinglePartQuery(); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_QueryPartContext ------------------------------------------------------------------ + +CypherParser::KU_QueryPartContext::KU_QueryPartContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_WithContext* CypherParser::KU_QueryPartContext::oC_With() { + return getRuleContext(0); +} + +std::vector CypherParser::KU_QueryPartContext::oC_ReadingClause() { + return getRuleContexts(); +} + +CypherParser::OC_ReadingClauseContext* CypherParser::KU_QueryPartContext::oC_ReadingClause(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::KU_QueryPartContext::oC_UpdatingClause() { + return getRuleContexts(); +} + +CypherParser::OC_UpdatingClauseContext* CypherParser::KU_QueryPartContext::oC_UpdatingClause(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::KU_QueryPartContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_QueryPartContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + + +size_t CypherParser::KU_QueryPartContext::getRuleIndex() const { + return CypherParser::RuleKU_QueryPart; +} + + +CypherParser::KU_QueryPartContext* CypherParser::kU_QueryPart() { + KU_QueryPartContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 150, CypherParser::RuleKU_QueryPart); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1448); + _errHandler->sync(this); + _la = _input->LA(1); + while (_la == CypherParser::CALL || ((((_la - 103) & ~ 0x3fULL) == 0) && + ((1ULL << (_la - 103)) & 2199023288329) != 0)) { + setState(1442); + oC_ReadingClause(); + setState(1444); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1443); + match(CypherParser::SP); + } + setState(1450); + _errHandler->sync(this); + _la = _input->LA(1); + } + setState(1457); + _errHandler->sync(this); + _la = _input->LA(1); + while (((((_la - 68) & ~ 0x3fULL) == 0) && + ((1ULL << (_la - 68)) & -9223370937343147743) != 0)) { + setState(1451); + oC_UpdatingClause(); + setState(1453); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1452); + match(CypherParser::SP); + } + setState(1459); + _errHandler->sync(this); + _la = _input->LA(1); + } + setState(1460); + oC_With(); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + 
+//----------------- OC_UpdatingClauseContext ------------------------------------------------------------------ + +CypherParser::OC_UpdatingClauseContext::OC_UpdatingClauseContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_CreateContext* CypherParser::OC_UpdatingClauseContext::oC_Create() { + return getRuleContext(0); +} + +CypherParser::OC_MergeContext* CypherParser::OC_UpdatingClauseContext::oC_Merge() { + return getRuleContext(0); +} + +CypherParser::OC_SetContext* CypherParser::OC_UpdatingClauseContext::oC_Set() { + return getRuleContext(0); +} + +CypherParser::OC_DeleteContext* CypherParser::OC_UpdatingClauseContext::oC_Delete() { + return getRuleContext(0); +} + + +size_t CypherParser::OC_UpdatingClauseContext::getRuleIndex() const { + return CypherParser::RuleOC_UpdatingClause; +} + + +CypherParser::OC_UpdatingClauseContext* CypherParser::oC_UpdatingClause() { + OC_UpdatingClauseContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 152, CypherParser::RuleOC_UpdatingClause); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + setState(1466); + _errHandler->sync(this); + switch (_input->LA(1)) { + case CypherParser::CREATE: { + enterOuterAlt(_localctx, 1); + setState(1462); + oC_Create(); + break; + } + + case CypherParser::MERGE: { + enterOuterAlt(_localctx, 2); + setState(1463); + oC_Merge(); + break; + } + + case CypherParser::SET: { + enterOuterAlt(_localctx, 3); + setState(1464); + oC_Set(); + break; + } + + case CypherParser::DELETE: + case CypherParser::DETACH: { + enterOuterAlt(_localctx, 4); + setState(1465); + oC_Delete(); + break; + } + + default: + throw NoViableAltException(this); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_ReadingClauseContext ------------------------------------------------------------------ + +CypherParser::OC_ReadingClauseContext::OC_ReadingClauseContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_MatchContext* CypherParser::OC_ReadingClauseContext::oC_Match() { + return getRuleContext(0); +} + +CypherParser::OC_UnwindContext* CypherParser::OC_ReadingClauseContext::oC_Unwind() { + return getRuleContext(0); +} + +CypherParser::KU_InQueryCallContext* CypherParser::OC_ReadingClauseContext::kU_InQueryCall() { + return getRuleContext(0); +} + +CypherParser::KU_LoadFromContext* CypherParser::OC_ReadingClauseContext::kU_LoadFrom() { + return getRuleContext(0); +} + + +size_t CypherParser::OC_ReadingClauseContext::getRuleIndex() const { + return CypherParser::RuleOC_ReadingClause; +} + + +CypherParser::OC_ReadingClauseContext* CypherParser::oC_ReadingClause() { + OC_ReadingClauseContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 154, CypherParser::RuleOC_ReadingClause); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + setState(1472); + _errHandler->sync(this); + switch (_input->LA(1)) { + case CypherParser::MATCH: + case CypherParser::OPTIONAL: { + enterOuterAlt(_localctx, 1); + setState(1468); + oC_Match(); + break; + } + + case CypherParser::UNWIND: { + 
enterOuterAlt(_localctx, 2); + setState(1469); + oC_Unwind(); + break; + } + + case CypherParser::CALL: { + enterOuterAlt(_localctx, 3); + setState(1470); + kU_InQueryCall(); + break; + } + + case CypherParser::LOAD: { + enterOuterAlt(_localctx, 4); + setState(1471); + kU_LoadFrom(); + break; + } + + default: + throw NoViableAltException(this); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_LoadFromContext ------------------------------------------------------------------ + +CypherParser::KU_LoadFromContext::KU_LoadFromContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_LoadFromContext::LOAD() { + return getToken(CypherParser::LOAD, 0); +} + +std::vector CypherParser::KU_LoadFromContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_LoadFromContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +tree::TerminalNode* CypherParser::KU_LoadFromContext::FROM() { + return getToken(CypherParser::FROM, 0); +} + +CypherParser::KU_ScanSourceContext* CypherParser::KU_LoadFromContext::kU_ScanSource() { + return getRuleContext(0); +} + +tree::TerminalNode* CypherParser::KU_LoadFromContext::WITH() { + return getToken(CypherParser::WITH, 0); +} + +tree::TerminalNode* CypherParser::KU_LoadFromContext::HEADERS() { + return getToken(CypherParser::HEADERS, 0); +} + +CypherParser::KU_ColumnDefinitionsContext* CypherParser::KU_LoadFromContext::kU_ColumnDefinitions() { + return getRuleContext(0); +} + +CypherParser::KU_OptionsContext* CypherParser::KU_LoadFromContext::kU_Options() { + return getRuleContext(0); +} + +CypherParser::OC_WhereContext* CypherParser::KU_LoadFromContext::oC_Where() { + return getRuleContext(0); +} + + +size_t CypherParser::KU_LoadFromContext::getRuleIndex() const { + return CypherParser::RuleKU_LoadFrom; +} + + +CypherParser::KU_LoadFromContext* CypherParser::kU_LoadFrom() { + KU_LoadFromContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 156, CypherParser::RuleKU_LoadFrom); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1474); + match(CypherParser::LOAD); + setState(1492); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 196, _ctx)) { + case 1: { + setState(1475); + match(CypherParser::SP); + setState(1476); + match(CypherParser::WITH); + setState(1477); + match(CypherParser::SP); + setState(1478); + match(CypherParser::HEADERS); + setState(1480); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1479); + match(CypherParser::SP); + } + setState(1482); + match(CypherParser::T__1); + setState(1484); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1483); + match(CypherParser::SP); + } + setState(1486); + kU_ColumnDefinitions(); + setState(1488); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1487); + match(CypherParser::SP); + } + setState(1490); + match(CypherParser::T__2); + break; + } + + default: + break; + } + setState(1494); + match(CypherParser::SP); + setState(1495); + 
match(CypherParser::FROM); + setState(1496); + match(CypherParser::SP); + setState(1497); + kU_ScanSource(); + setState(1511); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 200, _ctx)) { + case 1: { + setState(1499); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1498); + match(CypherParser::SP); + } + setState(1501); + match(CypherParser::T__1); + setState(1503); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1502); + match(CypherParser::SP); + } + setState(1505); + kU_Options(); + setState(1507); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1506); + match(CypherParser::SP); + } + setState(1509); + match(CypherParser::T__2); + break; + } + + default: + break; + } + setState(1517); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 202, _ctx)) { + case 1: { + setState(1514); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1513); + match(CypherParser::SP); + } + setState(1516); + oC_Where(); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_YieldItemContext ------------------------------------------------------------------ + +CypherParser::OC_YieldItemContext::OC_YieldItemContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +std::vector CypherParser::OC_YieldItemContext::oC_Variable() { + return getRuleContexts(); +} + +CypherParser::OC_VariableContext* CypherParser::OC_YieldItemContext::oC_Variable(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::OC_YieldItemContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_YieldItemContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +tree::TerminalNode* CypherParser::OC_YieldItemContext::AS() { + return getToken(CypherParser::AS, 0); +} + + +size_t CypherParser::OC_YieldItemContext::getRuleIndex() const { + return CypherParser::RuleOC_YieldItem; +} + + +CypherParser::OC_YieldItemContext* CypherParser::oC_YieldItem() { + OC_YieldItemContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 158, CypherParser::RuleOC_YieldItem); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1524); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 203, _ctx)) { + case 1: { + setState(1519); + oC_Variable(); + setState(1520); + match(CypherParser::SP); + setState(1521); + match(CypherParser::AS); + setState(1522); + match(CypherParser::SP); + break; + } + + default: + break; + } + setState(1526); + oC_Variable(); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_YieldItemsContext ------------------------------------------------------------------ + +CypherParser::OC_YieldItemsContext::OC_YieldItemsContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, 
invokingState) { +} + +std::vector CypherParser::OC_YieldItemsContext::oC_YieldItem() { + return getRuleContexts(); +} + +CypherParser::OC_YieldItemContext* CypherParser::OC_YieldItemsContext::oC_YieldItem(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::OC_YieldItemsContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_YieldItemsContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + + +size_t CypherParser::OC_YieldItemsContext::getRuleIndex() const { + return CypherParser::RuleOC_YieldItems; +} + + +CypherParser::OC_YieldItemsContext* CypherParser::oC_YieldItems() { + OC_YieldItemsContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 160, CypherParser::RuleOC_YieldItems); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + size_t alt; + enterOuterAlt(_localctx, 1); + setState(1528); + oC_YieldItem(); + setState(1539); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 206, _ctx); + while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER) { + if (alt == 1) { + setState(1530); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1529); + match(CypherParser::SP); + } + setState(1532); + match(CypherParser::T__3); + setState(1534); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1533); + match(CypherParser::SP); + } + setState(1536); + oC_YieldItem(); + } + setState(1541); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 206, _ctx); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_InQueryCallContext ------------------------------------------------------------------ + +CypherParser::KU_InQueryCallContext::KU_InQueryCallContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_InQueryCallContext::CALL() { + return getToken(CypherParser::CALL, 0); +} + +std::vector CypherParser::KU_InQueryCallContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_InQueryCallContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +CypherParser::OC_FunctionInvocationContext* CypherParser::KU_InQueryCallContext::oC_FunctionInvocation() { + return getRuleContext(0); +} + +CypherParser::OC_WhereContext* CypherParser::KU_InQueryCallContext::oC_Where() { + return getRuleContext(0); +} + +tree::TerminalNode* CypherParser::KU_InQueryCallContext::YIELD() { + return getToken(CypherParser::YIELD, 0); +} + +CypherParser::OC_YieldItemsContext* CypherParser::KU_InQueryCallContext::oC_YieldItems() { + return getRuleContext(0); +} + + +size_t CypherParser::KU_InQueryCallContext::getRuleIndex() const { + return CypherParser::RuleKU_InQueryCall; +} + + +CypherParser::KU_InQueryCallContext* CypherParser::kU_InQueryCall() { + KU_InQueryCallContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 162, CypherParser::RuleKU_InQueryCall); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); 
+ setState(1542); + match(CypherParser::CALL); + setState(1543); + match(CypherParser::SP); + setState(1544); + oC_FunctionInvocation(); + setState(1549); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 208, _ctx)) { + case 1: { + setState(1546); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1545); + match(CypherParser::SP); + } + setState(1548); + oC_Where(); + break; + } + + default: + break; + } + setState(1557); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 210, _ctx)) { + case 1: { + setState(1552); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1551); + match(CypherParser::SP); + } + setState(1554); + match(CypherParser::YIELD); + setState(1555); + match(CypherParser::SP); + setState(1556); + oC_YieldItems(); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_MatchContext ------------------------------------------------------------------ + +CypherParser::OC_MatchContext::OC_MatchContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::OC_MatchContext::MATCH() { + return getToken(CypherParser::MATCH, 0); +} + +CypherParser::OC_PatternContext* CypherParser::OC_MatchContext::oC_Pattern() { + return getRuleContext(0); +} + +tree::TerminalNode* CypherParser::OC_MatchContext::OPTIONAL() { + return getToken(CypherParser::OPTIONAL, 0); +} + +std::vector CypherParser::OC_MatchContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_MatchContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +CypherParser::OC_WhereContext* CypherParser::OC_MatchContext::oC_Where() { + return getRuleContext(0); +} + +CypherParser::KU_HintContext* CypherParser::OC_MatchContext::kU_Hint() { + return getRuleContext(0); +} + + +size_t CypherParser::OC_MatchContext::getRuleIndex() const { + return CypherParser::RuleOC_Match; +} + + +CypherParser::OC_MatchContext* CypherParser::oC_Match() { + OC_MatchContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 164, CypherParser::RuleOC_Match); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1561); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::OPTIONAL) { + setState(1559); + match(CypherParser::OPTIONAL); + setState(1560); + match(CypherParser::SP); + } + setState(1563); + match(CypherParser::MATCH); + setState(1565); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1564); + match(CypherParser::SP); + } + setState(1567); + oC_Pattern(); + setState(1570); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 213, _ctx)) { + case 1: { + setState(1568); + match(CypherParser::SP); + setState(1569); + oC_Where(); + break; + } + + default: + break; + } + setState(1574); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 214, _ctx)) { + case 1: { + setState(1572); + match(CypherParser::SP); + setState(1573); + kU_Hint(); + break; + } + + 
default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_HintContext ------------------------------------------------------------------ + +CypherParser::KU_HintContext::KU_HintContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_HintContext::HINT() { + return getToken(CypherParser::HINT, 0); +} + +tree::TerminalNode* CypherParser::KU_HintContext::SP() { + return getToken(CypherParser::SP, 0); +} + +CypherParser::KU_JoinNodeContext* CypherParser::KU_HintContext::kU_JoinNode() { + return getRuleContext(0); +} + + +size_t CypherParser::KU_HintContext::getRuleIndex() const { + return CypherParser::RuleKU_Hint; +} + + +CypherParser::KU_HintContext* CypherParser::kU_Hint() { + KU_HintContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 166, CypherParser::RuleKU_Hint); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1576); + match(CypherParser::HINT); + setState(1577); + match(CypherParser::SP); + setState(1578); + kU_JoinNode(0); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_JoinNodeContext ------------------------------------------------------------------ + +CypherParser::KU_JoinNodeContext::KU_JoinNodeContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +std::vector CypherParser::KU_JoinNodeContext::kU_JoinNode() { + return getRuleContexts(); +} + +CypherParser::KU_JoinNodeContext* CypherParser::KU_JoinNodeContext::kU_JoinNode(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::KU_JoinNodeContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_JoinNodeContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +std::vector CypherParser::KU_JoinNodeContext::oC_SchemaName() { + return getRuleContexts(); +} + +CypherParser::OC_SchemaNameContext* CypherParser::KU_JoinNodeContext::oC_SchemaName(size_t i) { + return getRuleContext(i); +} + +tree::TerminalNode* CypherParser::KU_JoinNodeContext::JOIN() { + return getToken(CypherParser::JOIN, 0); +} + +std::vector CypherParser::KU_JoinNodeContext::MULTI_JOIN() { + return getTokens(CypherParser::MULTI_JOIN); +} + +tree::TerminalNode* CypherParser::KU_JoinNodeContext::MULTI_JOIN(size_t i) { + return getToken(CypherParser::MULTI_JOIN, i); +} + + +size_t CypherParser::KU_JoinNodeContext::getRuleIndex() const { + return CypherParser::RuleKU_JoinNode; +} + + + +CypherParser::KU_JoinNodeContext* CypherParser::kU_JoinNode() { + return kU_JoinNode(0); +} + +CypherParser::KU_JoinNodeContext* CypherParser::kU_JoinNode(int precedence) { + ParserRuleContext *parentContext = _ctx; + size_t parentState = getState(); + CypherParser::KU_JoinNodeContext *_localctx = _tracker.createInstance(_ctx, parentState); + CypherParser::KU_JoinNodeContext *previousContext = _localctx; + (void)previousContext; // Silence compiler, in case the context is not used by generated code. 
+ size_t startState = 168; + enterRecursionRule(_localctx, 168, CypherParser::RuleKU_JoinNode, precedence); + + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + unrollRecursionContexts(parentContext); + }); + try { + size_t alt; + enterOuterAlt(_localctx, 1); + setState(1592); + _errHandler->sync(this); + switch (_input->LA(1)) { + case CypherParser::T__1: { + setState(1581); + match(CypherParser::T__1); + setState(1583); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1582); + match(CypherParser::SP); + } + setState(1585); + kU_JoinNode(0); + setState(1587); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1586); + match(CypherParser::SP); + } + setState(1589); + match(CypherParser::T__2); + break; + } + + case CypherParser::ADD: + case CypherParser::ALTER: + case CypherParser::AS: + case CypherParser::ATTACH: + case CypherParser::BEGIN: + case CypherParser::BY: + case CypherParser::CALL: + case CypherParser::CHECKPOINT: + case CypherParser::COMMENT: + case CypherParser::COMMIT: + case CypherParser::CONTAINS: + case CypherParser::COPY: + case CypherParser::COUNT: + case CypherParser::CYCLE: + case CypherParser::DATABASE: + case CypherParser::DELETE: + case CypherParser::DETACH: + case CypherParser::DROP: + case CypherParser::EXPLAIN: + case CypherParser::EXPORT: + case CypherParser::EXTENSION: + case CypherParser::FROM: + case CypherParser::FORCE: + case CypherParser::GRAPH: + case CypherParser::IMPORT: + case CypherParser::IF: + case CypherParser::INCREMENT: + case CypherParser::IS: + case CypherParser::KEY: + case CypherParser::LIMIT: + case CypherParser::LOAD: + case CypherParser::LOGICAL: + case CypherParser::MATCH: + case CypherParser::MAXVALUE: + case CypherParser::MERGE: + case CypherParser::MINVALUE: + case CypherParser::NO: + case CypherParser::NODE: + case CypherParser::PROJECT: + case CypherParser::READ: + case CypherParser::REL: + case CypherParser::RENAME: + case CypherParser::RETURN: + case CypherParser::ROLLBACK: + case CypherParser::SEQUENCE: + case CypherParser::SET: + case CypherParser::START: + case CypherParser::STRUCT: + case CypherParser::TO: + case CypherParser::TRANSACTION: + case CypherParser::TYPE: + case CypherParser::UNINSTALL: + case CypherParser::UPDATE: + case CypherParser::USE: + case CypherParser::WRITE: + case CypherParser::YIELD: + case CypherParser::USER: + case CypherParser::PASSWORD: + case CypherParser::ROLE: + case CypherParser::MAP: + case CypherParser::DECIMAL: + case CypherParser::L_SKIP: + case CypherParser::HexLetter: + case CypherParser::UnescapedSymbolicName: + case CypherParser::EscapedSymbolicName: { + setState(1591); + oC_SchemaName(); + break; + } + + default: + throw NoViableAltException(this); + } + _ctx->stop = _input->LT(-1); + setState(1610); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 220, _ctx); + while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER) { + if (alt == 1) { + if (!_parseListeners.empty()) + triggerExitRuleEvent(); + previousContext = _localctx; + setState(1608); + _errHandler->sync(this); + switch (getInterpreter()->adaptivePredict(_input, 219, _ctx)) { + case 1: { + _localctx = _tracker.createInstance(parentContext, parentState); + pushNewRecursionContext(_localctx, startState, RuleKU_JoinNode); + setState(1594); + + if (!(precpred(_ctx, 4))) throw FailedPredicateException(this, "precpred(_ctx, 4)"); + 
setState(1595); + match(CypherParser::SP); + setState(1596); + match(CypherParser::JOIN); + setState(1597); + match(CypherParser::SP); + setState(1598); + kU_JoinNode(5); + break; + } + + case 2: { + _localctx = _tracker.createInstance(parentContext, parentState); + pushNewRecursionContext(_localctx, startState, RuleKU_JoinNode); + setState(1599); + + if (!(precpred(_ctx, 3))) throw FailedPredicateException(this, "precpred(_ctx, 3)"); + setState(1604); + _errHandler->sync(this); + alt = 1; + do { + switch (alt) { + case 1: { + setState(1600); + match(CypherParser::SP); + setState(1601); + match(CypherParser::MULTI_JOIN); + setState(1602); + match(CypherParser::SP); + setState(1603); + oC_SchemaName(); + break; + } + + default: + throw NoViableAltException(this); + } + setState(1606); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 218, _ctx); + } while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER); + break; + } + + default: + break; + } + } + setState(1612); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 220, _ctx); + } + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + return _localctx; +} + +//----------------- OC_UnwindContext ------------------------------------------------------------------ + +CypherParser::OC_UnwindContext::OC_UnwindContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::OC_UnwindContext::UNWIND() { + return getToken(CypherParser::UNWIND, 0); +} + +CypherParser::OC_ExpressionContext* CypherParser::OC_UnwindContext::oC_Expression() { + return getRuleContext(0); +} + +std::vector CypherParser::OC_UnwindContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_UnwindContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +tree::TerminalNode* CypherParser::OC_UnwindContext::AS() { + return getToken(CypherParser::AS, 0); +} + +CypherParser::OC_VariableContext* CypherParser::OC_UnwindContext::oC_Variable() { + return getRuleContext(0); +} + + +size_t CypherParser::OC_UnwindContext::getRuleIndex() const { + return CypherParser::RuleOC_Unwind; +} + + +CypherParser::OC_UnwindContext* CypherParser::oC_Unwind() { + OC_UnwindContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 170, CypherParser::RuleOC_Unwind); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1613); + match(CypherParser::UNWIND); + setState(1615); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1614); + match(CypherParser::SP); + } + setState(1617); + oC_Expression(); + setState(1618); + match(CypherParser::SP); + setState(1619); + match(CypherParser::AS); + setState(1620); + match(CypherParser::SP); + setState(1621); + oC_Variable(); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_CreateContext ------------------------------------------------------------------ + +CypherParser::OC_CreateContext::OC_CreateContext(ParserRuleContext *parent, size_t invokingState) + : 
ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::OC_CreateContext::CREATE() { + return getToken(CypherParser::CREATE, 0); +} + +CypherParser::OC_PatternContext* CypherParser::OC_CreateContext::oC_Pattern() { + return getRuleContext(0); +} + +tree::TerminalNode* CypherParser::OC_CreateContext::SP() { + return getToken(CypherParser::SP, 0); +} + + +size_t CypherParser::OC_CreateContext::getRuleIndex() const { + return CypherParser::RuleOC_Create; +} + + +CypherParser::OC_CreateContext* CypherParser::oC_Create() { + OC_CreateContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 172, CypherParser::RuleOC_Create); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1623); + match(CypherParser::CREATE); + setState(1625); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1624); + match(CypherParser::SP); + } + setState(1627); + oC_Pattern(); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_MergeContext ------------------------------------------------------------------ + +CypherParser::OC_MergeContext::OC_MergeContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::OC_MergeContext::MERGE() { + return getToken(CypherParser::MERGE, 0); +} + +CypherParser::OC_PatternContext* CypherParser::OC_MergeContext::oC_Pattern() { + return getRuleContext(0); +} + +std::vector CypherParser::OC_MergeContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_MergeContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +std::vector CypherParser::OC_MergeContext::oC_MergeAction() { + return getRuleContexts(); +} + +CypherParser::OC_MergeActionContext* CypherParser::OC_MergeContext::oC_MergeAction(size_t i) { + return getRuleContext(i); +} + + +size_t CypherParser::OC_MergeContext::getRuleIndex() const { + return CypherParser::RuleOC_Merge; +} + + +CypherParser::OC_MergeContext* CypherParser::oC_Merge() { + OC_MergeContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 174, CypherParser::RuleOC_Merge); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + size_t alt; + enterOuterAlt(_localctx, 1); + setState(1629); + match(CypherParser::MERGE); + setState(1631); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1630); + match(CypherParser::SP); + } + setState(1633); + oC_Pattern(); + setState(1638); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 224, _ctx); + while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER) { + if (alt == 1) { + setState(1634); + match(CypherParser::SP); + setState(1635); + oC_MergeAction(); + } + setState(1640); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 224, _ctx); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} 
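Note (not part of the generated file): oC_Create and oC_Merge above each parse a pattern, and oC_Merge additionally collects any number of trailing merge actions handled by oC_MergeAction below. As a hedged illustration, the sub-rule can be invoked in isolation and inspected through the accessors generated in this hunk; CypherLexer and the include names are assumed, as before.

    // Sketch only: count ON CREATE / ON MATCH actions via the generated accessors.
    #include <iostream>
    #include "antlr4-runtime.h"
    #include "CypherLexer.h"   // assumed name of the companion generated lexer
    #include "CypherParser.h"

    int main() {
        antlr4::ANTLRInputStream input(
            "MERGE (p:Person) ON CREATE SET p.visits = 1 "
            "ON MATCH SET p.visits = p.visits + 1");
        CypherLexer lexer(&input);
        antlr4::CommonTokenStream tokens(&lexer);
        CypherParser parser(&tokens);

        CypherParser::OC_MergeContext* ctx = parser.oC_Merge();
        std::cout << "merge actions: " << ctx->oC_MergeAction().size() << std::endl;
        return 0;
    }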
+ +//----------------- OC_MergeActionContext ------------------------------------------------------------------ + +CypherParser::OC_MergeActionContext::OC_MergeActionContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::OC_MergeActionContext::ON() { + return getToken(CypherParser::ON, 0); +} + +std::vector CypherParser::OC_MergeActionContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_MergeActionContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +tree::TerminalNode* CypherParser::OC_MergeActionContext::MATCH() { + return getToken(CypherParser::MATCH, 0); +} + +CypherParser::OC_SetContext* CypherParser::OC_MergeActionContext::oC_Set() { + return getRuleContext(0); +} + +tree::TerminalNode* CypherParser::OC_MergeActionContext::CREATE() { + return getToken(CypherParser::CREATE, 0); +} + + +size_t CypherParser::OC_MergeActionContext::getRuleIndex() const { + return CypherParser::RuleOC_MergeAction; +} + + +CypherParser::OC_MergeActionContext* CypherParser::oC_MergeAction() { + OC_MergeActionContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 176, CypherParser::RuleOC_MergeAction); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + setState(1651); + _errHandler->sync(this); + switch (getInterpreter()->adaptivePredict(_input, 225, _ctx)) { + case 1: { + enterOuterAlt(_localctx, 1); + setState(1641); + match(CypherParser::ON); + setState(1642); + match(CypherParser::SP); + setState(1643); + match(CypherParser::MATCH); + setState(1644); + match(CypherParser::SP); + setState(1645); + oC_Set(); + break; + } + + case 2: { + enterOuterAlt(_localctx, 2); + setState(1646); + match(CypherParser::ON); + setState(1647); + match(CypherParser::SP); + setState(1648); + match(CypherParser::CREATE); + setState(1649); + match(CypherParser::SP); + setState(1650); + oC_Set(); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_SetContext ------------------------------------------------------------------ + +CypherParser::OC_SetContext::OC_SetContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::OC_SetContext::SET() { + return getToken(CypherParser::SET, 0); +} + +std::vector CypherParser::OC_SetContext::oC_SetItem() { + return getRuleContexts(); +} + +CypherParser::OC_SetItemContext* CypherParser::OC_SetContext::oC_SetItem(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::OC_SetContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_SetContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +CypherParser::OC_AtomContext* CypherParser::OC_SetContext::oC_Atom() { + return getRuleContext(0); +} + +CypherParser::KU_PropertiesContext* CypherParser::OC_SetContext::kU_Properties() { + return getRuleContext(0); +} + + +size_t CypherParser::OC_SetContext::getRuleIndex() const { + return CypherParser::RuleOC_Set; +} + + +CypherParser::OC_SetContext* CypherParser::oC_Set() { + OC_SetContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 
178, CypherParser::RuleOC_Set); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + size_t alt; + setState(1685); + _errHandler->sync(this); + switch (getInterpreter()->adaptivePredict(_input, 233, _ctx)) { + case 1: { + enterOuterAlt(_localctx, 1); + setState(1653); + match(CypherParser::SET); + setState(1655); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1654); + match(CypherParser::SP); + } + setState(1657); + oC_SetItem(); + setState(1668); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 229, _ctx); + while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER) { + if (alt == 1) { + setState(1659); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1658); + match(CypherParser::SP); + } + setState(1661); + match(CypherParser::T__3); + setState(1663); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1662); + match(CypherParser::SP); + } + setState(1665); + oC_SetItem(); + } + setState(1670); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 229, _ctx); + } + break; + } + + case 2: { + enterOuterAlt(_localctx, 2); + setState(1671); + match(CypherParser::SET); + setState(1673); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1672); + match(CypherParser::SP); + } + setState(1675); + oC_Atom(); + setState(1677); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1676); + match(CypherParser::SP); + } + setState(1679); + match(CypherParser::T__5); + setState(1681); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1680); + match(CypherParser::SP); + } + setState(1683); + kU_Properties(); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_SetItemContext ------------------------------------------------------------------ + +CypherParser::OC_SetItemContext::OC_SetItemContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_PropertyExpressionContext* CypherParser::OC_SetItemContext::oC_PropertyExpression() { + return getRuleContext(0); +} + +CypherParser::OC_ExpressionContext* CypherParser::OC_SetItemContext::oC_Expression() { + return getRuleContext(0); +} + +std::vector CypherParser::OC_SetItemContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_SetItemContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + + +size_t CypherParser::OC_SetItemContext::getRuleIndex() const { + return CypherParser::RuleOC_SetItem; +} + + +CypherParser::OC_SetItemContext* CypherParser::oC_SetItem() { + OC_SetItemContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 180, CypherParser::RuleOC_SetItem); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1687); + oC_PropertyExpression(); + setState(1689); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == 
CypherParser::SP) { + setState(1688); + match(CypherParser::SP); + } + setState(1691); + match(CypherParser::T__5); + setState(1693); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1692); + match(CypherParser::SP); + } + setState(1695); + oC_Expression(); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_DeleteContext ------------------------------------------------------------------ + +CypherParser::OC_DeleteContext::OC_DeleteContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::OC_DeleteContext::DELETE() { + return getToken(CypherParser::DELETE, 0); +} + +std::vector CypherParser::OC_DeleteContext::oC_Expression() { + return getRuleContexts(); +} + +CypherParser::OC_ExpressionContext* CypherParser::OC_DeleteContext::oC_Expression(size_t i) { + return getRuleContext(i); +} + +tree::TerminalNode* CypherParser::OC_DeleteContext::DETACH() { + return getToken(CypherParser::DETACH, 0); +} + +std::vector CypherParser::OC_DeleteContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_DeleteContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + + +size_t CypherParser::OC_DeleteContext::getRuleIndex() const { + return CypherParser::RuleOC_Delete; +} + + +CypherParser::OC_DeleteContext* CypherParser::oC_Delete() { + OC_DeleteContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 182, CypherParser::RuleOC_Delete); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + size_t alt; + enterOuterAlt(_localctx, 1); + setState(1699); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::DETACH) { + setState(1697); + match(CypherParser::DETACH); + setState(1698); + match(CypherParser::SP); + } + setState(1701); + match(CypherParser::DELETE); + setState(1703); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1702); + match(CypherParser::SP); + } + setState(1705); + oC_Expression(); + setState(1716); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 240, _ctx); + while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER) { + if (alt == 1) { + setState(1707); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1706); + match(CypherParser::SP); + } + setState(1709); + match(CypherParser::T__3); + setState(1711); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1710); + match(CypherParser::SP); + } + setState(1713); + oC_Expression(); + } + setState(1718); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 240, _ctx); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_WithContext ------------------------------------------------------------------ + +CypherParser::OC_WithContext::OC_WithContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + 
+tree::TerminalNode* CypherParser::OC_WithContext::WITH() { + return getToken(CypherParser::WITH, 0); +} + +CypherParser::OC_ProjectionBodyContext* CypherParser::OC_WithContext::oC_ProjectionBody() { + return getRuleContext(0); +} + +CypherParser::OC_WhereContext* CypherParser::OC_WithContext::oC_Where() { + return getRuleContext(0); +} + +tree::TerminalNode* CypherParser::OC_WithContext::SP() { + return getToken(CypherParser::SP, 0); +} + + +size_t CypherParser::OC_WithContext::getRuleIndex() const { + return CypherParser::RuleOC_With; +} + + +CypherParser::OC_WithContext* CypherParser::oC_With() { + OC_WithContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 184, CypherParser::RuleOC_With); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1719); + match(CypherParser::WITH); + setState(1720); + oC_ProjectionBody(); + setState(1725); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 242, _ctx)) { + case 1: { + setState(1722); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1721); + match(CypherParser::SP); + } + setState(1724); + oC_Where(); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_ReturnContext ------------------------------------------------------------------ + +CypherParser::OC_ReturnContext::OC_ReturnContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::OC_ReturnContext::RETURN() { + return getToken(CypherParser::RETURN, 0); +} + +CypherParser::OC_ProjectionBodyContext* CypherParser::OC_ReturnContext::oC_ProjectionBody() { + return getRuleContext(0); +} + + +size_t CypherParser::OC_ReturnContext::getRuleIndex() const { + return CypherParser::RuleOC_Return; +} + + +CypherParser::OC_ReturnContext* CypherParser::oC_Return() { + OC_ReturnContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 186, CypherParser::RuleOC_Return); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1727); + match(CypherParser::RETURN); + setState(1728); + oC_ProjectionBody(); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_ProjectionBodyContext ------------------------------------------------------------------ + +CypherParser::OC_ProjectionBodyContext::OC_ProjectionBodyContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +std::vector CypherParser::OC_ProjectionBodyContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_ProjectionBodyContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +CypherParser::OC_ProjectionItemsContext* CypherParser::OC_ProjectionBodyContext::oC_ProjectionItems() { + return getRuleContext(0); +} + +tree::TerminalNode* 
CypherParser::OC_ProjectionBodyContext::DISTINCT() { + return getToken(CypherParser::DISTINCT, 0); +} + +CypherParser::OC_OrderContext* CypherParser::OC_ProjectionBodyContext::oC_Order() { + return getRuleContext(0); +} + +CypherParser::OC_SkipContext* CypherParser::OC_ProjectionBodyContext::oC_Skip() { + return getRuleContext(0); +} + +CypherParser::OC_LimitContext* CypherParser::OC_ProjectionBodyContext::oC_Limit() { + return getRuleContext(0); +} + + +size_t CypherParser::OC_ProjectionBodyContext::getRuleIndex() const { + return CypherParser::RuleOC_ProjectionBody; +} + + +CypherParser::OC_ProjectionBodyContext* CypherParser::oC_ProjectionBody() { + OC_ProjectionBodyContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 188, CypherParser::RuleOC_ProjectionBody); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1734); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 244, _ctx)) { + case 1: { + setState(1731); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1730); + match(CypherParser::SP); + } + setState(1733); + match(CypherParser::DISTINCT); + break; + } + + default: + break; + } + setState(1736); + match(CypherParser::SP); + setState(1737); + oC_ProjectionItems(); + setState(1740); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 245, _ctx)) { + case 1: { + setState(1738); + match(CypherParser::SP); + setState(1739); + oC_Order(); + break; + } + + default: + break; + } + setState(1744); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 246, _ctx)) { + case 1: { + setState(1742); + match(CypherParser::SP); + setState(1743); + oC_Skip(); + break; + } + + default: + break; + } + setState(1748); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 247, _ctx)) { + case 1: { + setState(1746); + match(CypherParser::SP); + setState(1747); + oC_Limit(); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_ProjectionItemsContext ------------------------------------------------------------------ + +CypherParser::OC_ProjectionItemsContext::OC_ProjectionItemsContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::OC_ProjectionItemsContext::STAR() { + return getToken(CypherParser::STAR, 0); +} + +std::vector CypherParser::OC_ProjectionItemsContext::oC_ProjectionItem() { + return getRuleContexts(); +} + +CypherParser::OC_ProjectionItemContext* CypherParser::OC_ProjectionItemsContext::oC_ProjectionItem(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::OC_ProjectionItemsContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_ProjectionItemsContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + + +size_t CypherParser::OC_ProjectionItemsContext::getRuleIndex() const { + return CypherParser::RuleOC_ProjectionItems; +} + + +CypherParser::OC_ProjectionItemsContext* CypherParser::oC_ProjectionItems() { + OC_ProjectionItemsContext *_localctx = _tracker.createInstance(_ctx, 
getState()); + enterRule(_localctx, 190, CypherParser::RuleOC_ProjectionItems); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + size_t alt; + setState(1778); + _errHandler->sync(this); + switch (_input->LA(1)) { + case CypherParser::STAR: { + enterOuterAlt(_localctx, 1); + setState(1750); + match(CypherParser::STAR); + setState(1761); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 250, _ctx); + while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER) { + if (alt == 1) { + setState(1752); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1751); + match(CypherParser::SP); + } + setState(1754); + match(CypherParser::T__3); + setState(1756); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1755); + match(CypherParser::SP); + } + setState(1758); + oC_ProjectionItem(); + } + setState(1763); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 250, _ctx); + } + break; + } + + case CypherParser::T__1: + case CypherParser::T__6: + case CypherParser::T__8: + case CypherParser::T__24: + case CypherParser::ANY: + case CypherParser::ADD: + case CypherParser::ALL: + case CypherParser::ALTER: + case CypherParser::AS: + case CypherParser::ATTACH: + case CypherParser::BEGIN: + case CypherParser::BY: + case CypherParser::CALL: + case CypherParser::CASE: + case CypherParser::CAST: + case CypherParser::CHECKPOINT: + case CypherParser::COMMENT: + case CypherParser::COMMIT: + case CypherParser::CONTAINS: + case CypherParser::COPY: + case CypherParser::COUNT: + case CypherParser::CYCLE: + case CypherParser::DATABASE: + case CypherParser::DELETE: + case CypherParser::DETACH: + case CypherParser::DROP: + case CypherParser::EXISTS: + case CypherParser::EXPLAIN: + case CypherParser::EXPORT: + case CypherParser::EXTENSION: + case CypherParser::FALSE: + case CypherParser::FROM: + case CypherParser::FORCE: + case CypherParser::GRAPH: + case CypherParser::IMPORT: + case CypherParser::IF: + case CypherParser::INCREMENT: + case CypherParser::IS: + case CypherParser::KEY: + case CypherParser::LIMIT: + case CypherParser::LOAD: + case CypherParser::LOGICAL: + case CypherParser::MATCH: + case CypherParser::MAXVALUE: + case CypherParser::MERGE: + case CypherParser::MINVALUE: + case CypherParser::NO: + case CypherParser::NODE: + case CypherParser::NOT: + case CypherParser::NONE: + case CypherParser::NULL_: + case CypherParser::PROJECT: + case CypherParser::READ: + case CypherParser::REL: + case CypherParser::RENAME: + case CypherParser::RETURN: + case CypherParser::ROLLBACK: + case CypherParser::SEQUENCE: + case CypherParser::SET: + case CypherParser::START: + case CypherParser::STRUCT: + case CypherParser::TO: + case CypherParser::TRANSACTION: + case CypherParser::TRUE: + case CypherParser::TYPE: + case CypherParser::UNINSTALL: + case CypherParser::UPDATE: + case CypherParser::USE: + case CypherParser::WRITE: + case CypherParser::SINGLE: + case CypherParser::YIELD: + case CypherParser::USER: + case CypherParser::PASSWORD: + case CypherParser::ROLE: + case CypherParser::MAP: + case CypherParser::DECIMAL: + case CypherParser::L_SKIP: + case CypherParser::MINUS: + case CypherParser::StringLiteral: + case CypherParser::DecimalInteger: + case CypherParser::HexLetter: + case CypherParser::ExponentDecimalReal: + case CypherParser::RegularDecimalReal: + case 
CypherParser::UnescapedSymbolicName: + case CypherParser::EscapedSymbolicName: { + enterOuterAlt(_localctx, 2); + setState(1764); + oC_ProjectionItem(); + setState(1775); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 253, _ctx); + while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER) { + if (alt == 1) { + setState(1766); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1765); + match(CypherParser::SP); + } + setState(1768); + match(CypherParser::T__3); + setState(1770); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1769); + match(CypherParser::SP); + } + setState(1772); + oC_ProjectionItem(); + } + setState(1777); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 253, _ctx); + } + break; + } + + default: + throw NoViableAltException(this); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_ProjectionItemContext ------------------------------------------------------------------ + +CypherParser::OC_ProjectionItemContext::OC_ProjectionItemContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_ExpressionContext* CypherParser::OC_ProjectionItemContext::oC_Expression() { + return getRuleContext(0); +} + +std::vector CypherParser::OC_ProjectionItemContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_ProjectionItemContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +tree::TerminalNode* CypherParser::OC_ProjectionItemContext::AS() { + return getToken(CypherParser::AS, 0); +} + +CypherParser::OC_VariableContext* CypherParser::OC_ProjectionItemContext::oC_Variable() { + return getRuleContext(0); +} + + +size_t CypherParser::OC_ProjectionItemContext::getRuleIndex() const { + return CypherParser::RuleOC_ProjectionItem; +} + + +CypherParser::OC_ProjectionItemContext* CypherParser::oC_ProjectionItem() { + OC_ProjectionItemContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 192, CypherParser::RuleOC_ProjectionItem); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + setState(1787); + _errHandler->sync(this); + switch (getInterpreter()->adaptivePredict(_input, 255, _ctx)) { + case 1: { + enterOuterAlt(_localctx, 1); + setState(1780); + oC_Expression(); + setState(1781); + match(CypherParser::SP); + setState(1782); + match(CypherParser::AS); + setState(1783); + match(CypherParser::SP); + setState(1784); + oC_Variable(); + break; + } + + case 2: { + enterOuterAlt(_localctx, 2); + setState(1786); + oC_Expression(); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_OrderContext ------------------------------------------------------------------ + +CypherParser::OC_OrderContext::OC_OrderContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::OC_OrderContext::ORDER() { + return getToken(CypherParser::ORDER, 0); 
+} + +std::vector CypherParser::OC_OrderContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_OrderContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +tree::TerminalNode* CypherParser::OC_OrderContext::BY() { + return getToken(CypherParser::BY, 0); +} + +std::vector CypherParser::OC_OrderContext::oC_SortItem() { + return getRuleContexts(); +} + +CypherParser::OC_SortItemContext* CypherParser::OC_OrderContext::oC_SortItem(size_t i) { + return getRuleContext(i); +} + + +size_t CypherParser::OC_OrderContext::getRuleIndex() const { + return CypherParser::RuleOC_Order; +} + + +CypherParser::OC_OrderContext* CypherParser::oC_Order() { + OC_OrderContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 194, CypherParser::RuleOC_Order); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1789); + match(CypherParser::ORDER); + setState(1790); + match(CypherParser::SP); + setState(1791); + match(CypherParser::BY); + setState(1792); + match(CypherParser::SP); + setState(1793); + oC_SortItem(); + setState(1801); + _errHandler->sync(this); + _la = _input->LA(1); + while (_la == CypherParser::T__3) { + setState(1794); + match(CypherParser::T__3); + setState(1796); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1795); + match(CypherParser::SP); + } + setState(1798); + oC_SortItem(); + setState(1803); + _errHandler->sync(this); + _la = _input->LA(1); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_SkipContext ------------------------------------------------------------------ + +CypherParser::OC_SkipContext::OC_SkipContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::OC_SkipContext::L_SKIP() { + return getToken(CypherParser::L_SKIP, 0); +} + +tree::TerminalNode* CypherParser::OC_SkipContext::SP() { + return getToken(CypherParser::SP, 0); +} + +CypherParser::OC_ExpressionContext* CypherParser::OC_SkipContext::oC_Expression() { + return getRuleContext(0); +} + + +size_t CypherParser::OC_SkipContext::getRuleIndex() const { + return CypherParser::RuleOC_Skip; +} + + +CypherParser::OC_SkipContext* CypherParser::oC_Skip() { + OC_SkipContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 196, CypherParser::RuleOC_Skip); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1804); + match(CypherParser::L_SKIP); + setState(1805); + match(CypherParser::SP); + setState(1806); + oC_Expression(); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_LimitContext ------------------------------------------------------------------ + +CypherParser::OC_LimitContext::OC_LimitContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* 
CypherParser::OC_LimitContext::LIMIT() { + return getToken(CypherParser::LIMIT, 0); +} + +tree::TerminalNode* CypherParser::OC_LimitContext::SP() { + return getToken(CypherParser::SP, 0); +} + +CypherParser::OC_ExpressionContext* CypherParser::OC_LimitContext::oC_Expression() { + return getRuleContext(0); +} + + +size_t CypherParser::OC_LimitContext::getRuleIndex() const { + return CypherParser::RuleOC_Limit; +} + + +CypherParser::OC_LimitContext* CypherParser::oC_Limit() { + OC_LimitContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 198, CypherParser::RuleOC_Limit); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1808); + match(CypherParser::LIMIT); + setState(1809); + match(CypherParser::SP); + setState(1810); + oC_Expression(); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_SortItemContext ------------------------------------------------------------------ + +CypherParser::OC_SortItemContext::OC_SortItemContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_ExpressionContext* CypherParser::OC_SortItemContext::oC_Expression() { + return getRuleContext(0); +} + +tree::TerminalNode* CypherParser::OC_SortItemContext::ASCENDING() { + return getToken(CypherParser::ASCENDING, 0); +} + +tree::TerminalNode* CypherParser::OC_SortItemContext::ASC() { + return getToken(CypherParser::ASC, 0); +} + +tree::TerminalNode* CypherParser::OC_SortItemContext::DESCENDING() { + return getToken(CypherParser::DESCENDING, 0); +} + +tree::TerminalNode* CypherParser::OC_SortItemContext::DESC() { + return getToken(CypherParser::DESC, 0); +} + +tree::TerminalNode* CypherParser::OC_SortItemContext::SP() { + return getToken(CypherParser::SP, 0); +} + + +size_t CypherParser::OC_SortItemContext::getRuleIndex() const { + return CypherParser::RuleOC_SortItem; +} + + +CypherParser::OC_SortItemContext* CypherParser::oC_SortItem() { + OC_SortItemContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 200, CypherParser::RuleOC_SortItem); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1812); + oC_Expression(); + setState(1817); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 259, _ctx)) { + case 1: { + setState(1814); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1813); + match(CypherParser::SP); + } + setState(1816); + _la = _input->LA(1); + if (!(((((_la - 52) & ~ 0x3fULL) == 0) && + ((1ULL << (_la - 52)) & 12582915) != 0))) { + _errHandler->recoverInline(this); + } + else { + _errHandler->reportMatch(this); + consume(); + } + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_WhereContext ------------------------------------------------------------------ + +CypherParser::OC_WhereContext::OC_WhereContext(ParserRuleContext 
*parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::OC_WhereContext::WHERE() { + return getToken(CypherParser::WHERE, 0); +} + +tree::TerminalNode* CypherParser::OC_WhereContext::SP() { + return getToken(CypherParser::SP, 0); +} + +CypherParser::OC_ExpressionContext* CypherParser::OC_WhereContext::oC_Expression() { + return getRuleContext(0); +} + + +size_t CypherParser::OC_WhereContext::getRuleIndex() const { + return CypherParser::RuleOC_Where; +} + + +CypherParser::OC_WhereContext* CypherParser::oC_Where() { + OC_WhereContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 202, CypherParser::RuleOC_Where); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1819); + match(CypherParser::WHERE); + setState(1820); + match(CypherParser::SP); + setState(1821); + oC_Expression(); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_PatternContext ------------------------------------------------------------------ + +CypherParser::OC_PatternContext::OC_PatternContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +std::vector CypherParser::OC_PatternContext::oC_PatternPart() { + return getRuleContexts(); +} + +CypherParser::OC_PatternPartContext* CypherParser::OC_PatternContext::oC_PatternPart(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::OC_PatternContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_PatternContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + + +size_t CypherParser::OC_PatternContext::getRuleIndex() const { + return CypherParser::RuleOC_Pattern; +} + + +CypherParser::OC_PatternContext* CypherParser::oC_Pattern() { + OC_PatternContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 204, CypherParser::RuleOC_Pattern); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + size_t alt; + enterOuterAlt(_localctx, 1); + setState(1823); + oC_PatternPart(); + setState(1834); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 262, _ctx); + while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER) { + if (alt == 1) { + setState(1825); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1824); + match(CypherParser::SP); + } + setState(1827); + match(CypherParser::T__3); + setState(1829); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1828); + match(CypherParser::SP); + } + setState(1831); + oC_PatternPart(); + } + setState(1836); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 262, _ctx); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_PatternPartContext ------------------------------------------------------------------ + 
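+// Descriptive note (not emitted by ANTLR): oC_PatternPart below accepts either a
+// variable, optional SP, the unnamed literal token T__5 (which, judging from the
+// openCypher grammar, is most likely '='), optional SP and an anonymous pattern
+// part, or a bare anonymous pattern part. Roughly `p = (a)-[:KNOWS]->(b)` versus
+// `(a)-[:KNOWS]->(b)`; the '=' reading and the sample pattern are inferences from
+// the surrounding rules, not taken from the grammar file itself.
+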
+CypherParser::OC_PatternPartContext::OC_PatternPartContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_VariableContext* CypherParser::OC_PatternPartContext::oC_Variable() { + return getRuleContext(0); +} + +CypherParser::OC_AnonymousPatternPartContext* CypherParser::OC_PatternPartContext::oC_AnonymousPatternPart() { + return getRuleContext(0); +} + +std::vector CypherParser::OC_PatternPartContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_PatternPartContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + + +size_t CypherParser::OC_PatternPartContext::getRuleIndex() const { + return CypherParser::RuleOC_PatternPart; +} + + +CypherParser::OC_PatternPartContext* CypherParser::oC_PatternPart() { + OC_PatternPartContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 206, CypherParser::RuleOC_PatternPart); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + setState(1848); + _errHandler->sync(this); + switch (_input->LA(1)) { + case CypherParser::ADD: + case CypherParser::ALTER: + case CypherParser::AS: + case CypherParser::ATTACH: + case CypherParser::BEGIN: + case CypherParser::BY: + case CypherParser::CALL: + case CypherParser::CHECKPOINT: + case CypherParser::COMMENT: + case CypherParser::COMMIT: + case CypherParser::CONTAINS: + case CypherParser::COPY: + case CypherParser::COUNT: + case CypherParser::CYCLE: + case CypherParser::DATABASE: + case CypherParser::DELETE: + case CypherParser::DETACH: + case CypherParser::DROP: + case CypherParser::EXPLAIN: + case CypherParser::EXPORT: + case CypherParser::EXTENSION: + case CypherParser::FROM: + case CypherParser::FORCE: + case CypherParser::GRAPH: + case CypherParser::IMPORT: + case CypherParser::IF: + case CypherParser::INCREMENT: + case CypherParser::IS: + case CypherParser::KEY: + case CypherParser::LIMIT: + case CypherParser::LOAD: + case CypherParser::LOGICAL: + case CypherParser::MATCH: + case CypherParser::MAXVALUE: + case CypherParser::MERGE: + case CypherParser::MINVALUE: + case CypherParser::NO: + case CypherParser::NODE: + case CypherParser::PROJECT: + case CypherParser::READ: + case CypherParser::REL: + case CypherParser::RENAME: + case CypherParser::RETURN: + case CypherParser::ROLLBACK: + case CypherParser::SEQUENCE: + case CypherParser::SET: + case CypherParser::START: + case CypherParser::STRUCT: + case CypherParser::TO: + case CypherParser::TRANSACTION: + case CypherParser::TYPE: + case CypherParser::UNINSTALL: + case CypherParser::UPDATE: + case CypherParser::USE: + case CypherParser::WRITE: + case CypherParser::YIELD: + case CypherParser::USER: + case CypherParser::PASSWORD: + case CypherParser::ROLE: + case CypherParser::MAP: + case CypherParser::DECIMAL: + case CypherParser::L_SKIP: + case CypherParser::HexLetter: + case CypherParser::UnescapedSymbolicName: + case CypherParser::EscapedSymbolicName: { + enterOuterAlt(_localctx, 1); + setState(1837); + oC_Variable(); + setState(1839); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1838); + match(CypherParser::SP); + } + setState(1841); + match(CypherParser::T__5); + setState(1843); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1842); + match(CypherParser::SP); + } + setState(1845); + oC_AnonymousPatternPart(); + 
break; + } + + case CypherParser::T__1: { + enterOuterAlt(_localctx, 2); + setState(1847); + oC_AnonymousPatternPart(); + break; + } + + default: + throw NoViableAltException(this); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_AnonymousPatternPartContext ------------------------------------------------------------------ + +CypherParser::OC_AnonymousPatternPartContext::OC_AnonymousPatternPartContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_PatternElementContext* CypherParser::OC_AnonymousPatternPartContext::oC_PatternElement() { + return getRuleContext(0); +} + + +size_t CypherParser::OC_AnonymousPatternPartContext::getRuleIndex() const { + return CypherParser::RuleOC_AnonymousPatternPart; +} + + +CypherParser::OC_AnonymousPatternPartContext* CypherParser::oC_AnonymousPatternPart() { + OC_AnonymousPatternPartContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 208, CypherParser::RuleOC_AnonymousPatternPart); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1850); + oC_PatternElement(); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_PatternElementContext ------------------------------------------------------------------ + +CypherParser::OC_PatternElementContext::OC_PatternElementContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_NodePatternContext* CypherParser::OC_PatternElementContext::oC_NodePattern() { + return getRuleContext(0); +} + +std::vector CypherParser::OC_PatternElementContext::oC_PatternElementChain() { + return getRuleContexts(); +} + +CypherParser::OC_PatternElementChainContext* CypherParser::OC_PatternElementContext::oC_PatternElementChain(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::OC_PatternElementContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_PatternElementContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +CypherParser::OC_PatternElementContext* CypherParser::OC_PatternElementContext::oC_PatternElement() { + return getRuleContext(0); +} + + +size_t CypherParser::OC_PatternElementContext::getRuleIndex() const { + return CypherParser::RuleOC_PatternElement; +} + + +CypherParser::OC_PatternElementContext* CypherParser::oC_PatternElement() { + OC_PatternElementContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 210, CypherParser::RuleOC_PatternElement); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + size_t alt; + setState(1866); + _errHandler->sync(this); + switch (getInterpreter()->adaptivePredict(_input, 268, _ctx)) { + case 1: { + enterOuterAlt(_localctx, 1); + setState(1852); + oC_NodePattern(); + setState(1859); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 267, _ctx); + while (alt != 2 && alt != 
atn::ATN::INVALID_ALT_NUMBER) { + if (alt == 1) { + setState(1854); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1853); + match(CypherParser::SP); + } + setState(1856); + oC_PatternElementChain(); + } + setState(1861); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 267, _ctx); + } + break; + } + + case 2: { + enterOuterAlt(_localctx, 2); + setState(1862); + match(CypherParser::T__1); + setState(1863); + oC_PatternElement(); + setState(1864); + match(CypherParser::T__2); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_NodePatternContext ------------------------------------------------------------------ + +CypherParser::OC_NodePatternContext::OC_NodePatternContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +std::vector CypherParser::OC_NodePatternContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_NodePatternContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +CypherParser::OC_VariableContext* CypherParser::OC_NodePatternContext::oC_Variable() { + return getRuleContext(0); +} + +CypherParser::OC_NodeLabelsContext* CypherParser::OC_NodePatternContext::oC_NodeLabels() { + return getRuleContext(0); +} + +CypherParser::KU_PropertiesContext* CypherParser::OC_NodePatternContext::kU_Properties() { + return getRuleContext(0); +} + + +size_t CypherParser::OC_NodePatternContext::getRuleIndex() const { + return CypherParser::RuleOC_NodePattern; +} + + +CypherParser::OC_NodePatternContext* CypherParser::oC_NodePattern() { + OC_NodePatternContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 212, CypherParser::RuleOC_NodePattern); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1868); + match(CypherParser::T__1); + setState(1870); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1869); + match(CypherParser::SP); + } + setState(1876); + _errHandler->sync(this); + + _la = _input->LA(1); + if ((((_la & ~ 0x3fULL) == 0) && + ((1ULL << _la) & -3185593048922849280) != 0) || ((((_la - 65) & ~ 0x3fULL) == 0) && + ((1ULL << (_la - 65)) & -287985230644762313) != 0) || ((((_la - 130) & ~ 0x3fULL) == 0) && + ((1ULL << (_la - 130)) & 5068755015275819) != 0)) { + setState(1872); + oC_Variable(); + setState(1874); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1873); + match(CypherParser::SP); + } + } + setState(1882); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::COLON) { + setState(1878); + oC_NodeLabels(); + setState(1880); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1879); + match(CypherParser::SP); + } + } + setState(1888); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::T__8) { + setState(1884); + kU_Properties(); + setState(1886); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1885); + match(CypherParser::SP); + } + } + setState(1890); + 
match(CypherParser::T__2); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_PatternElementChainContext ------------------------------------------------------------------ + +CypherParser::OC_PatternElementChainContext::OC_PatternElementChainContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_RelationshipPatternContext* CypherParser::OC_PatternElementChainContext::oC_RelationshipPattern() { + return getRuleContext(0); +} + +CypherParser::OC_NodePatternContext* CypherParser::OC_PatternElementChainContext::oC_NodePattern() { + return getRuleContext(0); +} + +tree::TerminalNode* CypherParser::OC_PatternElementChainContext::SP() { + return getToken(CypherParser::SP, 0); +} + + +size_t CypherParser::OC_PatternElementChainContext::getRuleIndex() const { + return CypherParser::RuleOC_PatternElementChain; +} + + +CypherParser::OC_PatternElementChainContext* CypherParser::oC_PatternElementChain() { + OC_PatternElementChainContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 214, CypherParser::RuleOC_PatternElementChain); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1892); + oC_RelationshipPattern(); + setState(1894); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1893); + match(CypherParser::SP); + } + setState(1896); + oC_NodePattern(); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_RelationshipPatternContext ------------------------------------------------------------------ + +CypherParser::OC_RelationshipPatternContext::OC_RelationshipPatternContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_LeftArrowHeadContext* CypherParser::OC_RelationshipPatternContext::oC_LeftArrowHead() { + return getRuleContext(0); +} + +std::vector CypherParser::OC_RelationshipPatternContext::oC_Dash() { + return getRuleContexts(); +} + +CypherParser::OC_DashContext* CypherParser::OC_RelationshipPatternContext::oC_Dash(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::OC_RelationshipPatternContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_RelationshipPatternContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +CypherParser::OC_RelationshipDetailContext* CypherParser::OC_RelationshipPatternContext::oC_RelationshipDetail() { + return getRuleContext(0); +} + +CypherParser::OC_RightArrowHeadContext* CypherParser::OC_RelationshipPatternContext::oC_RightArrowHead() { + return getRuleContext(0); +} + + +size_t CypherParser::OC_RelationshipPatternContext::getRuleIndex() const { + return CypherParser::RuleOC_RelationshipPattern; +} + + +CypherParser::OC_RelationshipPatternContext* CypherParser::oC_RelationshipPattern() { + OC_RelationshipPatternContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 216, CypherParser::RuleOC_RelationshipPattern); + size_t _la = 0; + 
+#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + setState(1942); + _errHandler->sync(this); + switch (getInterpreter()->adaptivePredict(_input, 288, _ctx)) { + case 1: { + enterOuterAlt(_localctx, 1); + setState(1898); + oC_LeftArrowHead(); + setState(1900); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1899); + match(CypherParser::SP); + } + setState(1902); + oC_Dash(); + setState(1904); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 278, _ctx)) { + case 1: { + setState(1903); + match(CypherParser::SP); + break; + } + + default: + break; + } + setState(1907); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::T__6) { + setState(1906); + oC_RelationshipDetail(); + } + setState(1910); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1909); + match(CypherParser::SP); + } + setState(1912); + oC_Dash(); + break; + } + + case 2: { + enterOuterAlt(_localctx, 2); + setState(1914); + oC_Dash(); + setState(1916); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 281, _ctx)) { + case 1: { + setState(1915); + match(CypherParser::SP); + break; + } + + default: + break; + } + setState(1919); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::T__6) { + setState(1918); + oC_RelationshipDetail(); + } + setState(1922); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1921); + match(CypherParser::SP); + } + setState(1924); + oC_Dash(); + setState(1926); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1925); + match(CypherParser::SP); + } + setState(1928); + oC_RightArrowHead(); + break; + } + + case 3: { + enterOuterAlt(_localctx, 3); + setState(1930); + oC_Dash(); + setState(1932); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 285, _ctx)) { + case 1: { + setState(1931); + match(CypherParser::SP); + break; + } + + default: + break; + } + setState(1935); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::T__6) { + setState(1934); + oC_RelationshipDetail(); + } + setState(1938); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1937); + match(CypherParser::SP); + } + setState(1940); + oC_Dash(); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_RelationshipDetailContext ------------------------------------------------------------------ + +CypherParser::OC_RelationshipDetailContext::OC_RelationshipDetailContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +std::vector CypherParser::OC_RelationshipDetailContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_RelationshipDetailContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +CypherParser::OC_VariableContext* CypherParser::OC_RelationshipDetailContext::oC_Variable() { + return getRuleContext(0); +} + +CypherParser::OC_RelationshipTypesContext* CypherParser::OC_RelationshipDetailContext::oC_RelationshipTypes() { + return 
getRuleContext(0); +} + +CypherParser::KU_RecursiveDetailContext* CypherParser::OC_RelationshipDetailContext::kU_RecursiveDetail() { + return getRuleContext(0); +} + +CypherParser::KU_PropertiesContext* CypherParser::OC_RelationshipDetailContext::kU_Properties() { + return getRuleContext(0); +} + + +size_t CypherParser::OC_RelationshipDetailContext::getRuleIndex() const { + return CypherParser::RuleOC_RelationshipDetail; +} + + +CypherParser::OC_RelationshipDetailContext* CypherParser::oC_RelationshipDetail() { + OC_RelationshipDetailContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 218, CypherParser::RuleOC_RelationshipDetail); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1944); + match(CypherParser::T__6); + setState(1946); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1945); + match(CypherParser::SP); + } + setState(1952); + _errHandler->sync(this); + + _la = _input->LA(1); + if ((((_la & ~ 0x3fULL) == 0) && + ((1ULL << _la) & -3185593048922849280) != 0) || ((((_la - 65) & ~ 0x3fULL) == 0) && + ((1ULL << (_la - 65)) & -287985230644762313) != 0) || ((((_la - 130) & ~ 0x3fULL) == 0) && + ((1ULL << (_la - 130)) & 5068755015275819) != 0)) { + setState(1948); + oC_Variable(); + setState(1950); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1949); + match(CypherParser::SP); + } + } + setState(1958); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::COLON) { + setState(1954); + oC_RelationshipTypes(); + setState(1956); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1955); + match(CypherParser::SP); + } + } + setState(1964); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::STAR) { + setState(1960); + kU_RecursiveDetail(); + setState(1962); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1961); + match(CypherParser::SP); + } + } + setState(1970); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::T__8) { + setState(1966); + kU_Properties(); + setState(1968); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1967); + match(CypherParser::SP); + } + } + setState(1972); + match(CypherParser::T__7); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_PropertiesContext ------------------------------------------------------------------ + +CypherParser::KU_PropertiesContext::KU_PropertiesContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +std::vector CypherParser::KU_PropertiesContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_PropertiesContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +std::vector CypherParser::KU_PropertiesContext::oC_PropertyKeyName() { + return getRuleContexts(); +} + +CypherParser::OC_PropertyKeyNameContext* CypherParser::KU_PropertiesContext::oC_PropertyKeyName(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::KU_PropertiesContext::COLON() { + 
return getTokens(CypherParser::COLON); +} + +tree::TerminalNode* CypherParser::KU_PropertiesContext::COLON(size_t i) { + return getToken(CypherParser::COLON, i); +} + +std::vector CypherParser::KU_PropertiesContext::oC_Expression() { + return getRuleContexts(); +} + +CypherParser::OC_ExpressionContext* CypherParser::KU_PropertiesContext::oC_Expression(size_t i) { + return getRuleContext(i); +} + + +size_t CypherParser::KU_PropertiesContext::getRuleIndex() const { + return CypherParser::RuleKU_Properties; +} + + +CypherParser::KU_PropertiesContext* CypherParser::kU_Properties() { + KU_PropertiesContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 220, CypherParser::RuleKU_Properties); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(1974); + match(CypherParser::T__8); + setState(1976); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1975); + match(CypherParser::SP); + } + setState(2011); + _errHandler->sync(this); + + _la = _input->LA(1); + if ((((_la & ~ 0x3fULL) == 0) && + ((1ULL << _la) & -3185593048922849280) != 0) || ((((_la - 65) & ~ 0x3fULL) == 0) && + ((1ULL << (_la - 65)) & -287985230644762313) != 0) || ((((_la - 130) & ~ 0x3fULL) == 0) && + ((1ULL << (_la - 130)) & 5068755015275819) != 0)) { + setState(1978); + oC_PropertyKeyName(); + setState(1980); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1979); + match(CypherParser::SP); + } + setState(1982); + match(CypherParser::COLON); + setState(1984); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1983); + match(CypherParser::SP); + } + setState(1986); + oC_Expression(); + setState(1988); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1987); + match(CypherParser::SP); + } + setState(2008); + _errHandler->sync(this); + _la = _input->LA(1); + while (_la == CypherParser::T__3) { + setState(1990); + match(CypherParser::T__3); + setState(1992); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1991); + match(CypherParser::SP); + } + setState(1994); + oC_PropertyKeyName(); + setState(1996); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1995); + match(CypherParser::SP); + } + setState(1998); + match(CypherParser::COLON); + setState(2000); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(1999); + match(CypherParser::SP); + } + setState(2002); + oC_Expression(); + setState(2004); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2003); + match(CypherParser::SP); + } + setState(2010); + _errHandler->sync(this); + _la = _input->LA(1); + } + } + setState(2013); + match(CypherParser::T__9); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_RelationshipTypesContext ------------------------------------------------------------------ + +CypherParser::OC_RelationshipTypesContext::OC_RelationshipTypesContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + 
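+// Descriptive note (not emitted by ANTLR): oC_RelationshipTypes below matches a
+// relationship-type list: a leading COLON, a rel type name, then zero or more
+// repetitions of the unnamed literal T__10 (by all appearances '|') with an optional
+// extra COLON, e.g. `:KNOWS|FOLLOWS` or `:KNOWS|:FOLLOWS`. A minimal sketch of
+// driving this rule directly through the ANTLR4 C++ runtime, assuming the companion
+// generated lexer is called CypherLexer (that class name is an assumption, as are
+// the '|' reading and the sample input):
+//
+//   antlr4::ANTLRInputStream input(":KNOWS|FOLLOWS");
+//   CypherLexer lexer(&input);                  // assumed generated lexer class
+//   antlr4::CommonTokenStream tokens(&lexer);
+//   CypherParser parser(&tokens);
+//   auto *ctx = parser.oC_RelationshipTypes();  // rule entry defined below
+//   std::cout << ctx->toStringTree(&parser) << std::endl;  // needs <iostream>
+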
+std::vector CypherParser::OC_RelationshipTypesContext::COLON() { + return getTokens(CypherParser::COLON); +} + +tree::TerminalNode* CypherParser::OC_RelationshipTypesContext::COLON(size_t i) { + return getToken(CypherParser::COLON, i); +} + +std::vector CypherParser::OC_RelationshipTypesContext::oC_RelTypeName() { + return getRuleContexts(); +} + +CypherParser::OC_RelTypeNameContext* CypherParser::OC_RelationshipTypesContext::oC_RelTypeName(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::OC_RelationshipTypesContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_RelationshipTypesContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + + +size_t CypherParser::OC_RelationshipTypesContext::getRuleIndex() const { + return CypherParser::RuleOC_RelationshipTypes; +} + + +CypherParser::OC_RelationshipTypesContext* CypherParser::oC_RelationshipTypes() { + OC_RelationshipTypesContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 222, CypherParser::RuleOC_RelationshipTypes); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + size_t alt; + enterOuterAlt(_localctx, 1); + setState(2015); + match(CypherParser::COLON); + setState(2017); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2016); + match(CypherParser::SP); + } + setState(2019); + oC_RelTypeName(); + setState(2033); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 312, _ctx); + while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER) { + if (alt == 1) { + setState(2021); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2020); + match(CypherParser::SP); + } + setState(2023); + match(CypherParser::T__10); + setState(2025); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::COLON) { + setState(2024); + match(CypherParser::COLON); + } + setState(2028); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2027); + match(CypherParser::SP); + } + setState(2030); + oC_RelTypeName(); + } + setState(2035); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 312, _ctx); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_NodeLabelsContext ------------------------------------------------------------------ + +CypherParser::OC_NodeLabelsContext::OC_NodeLabelsContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +std::vector CypherParser::OC_NodeLabelsContext::COLON() { + return getTokens(CypherParser::COLON); +} + +tree::TerminalNode* CypherParser::OC_NodeLabelsContext::COLON(size_t i) { + return getToken(CypherParser::COLON, i); +} + +std::vector CypherParser::OC_NodeLabelsContext::oC_LabelName() { + return getRuleContexts(); +} + +CypherParser::OC_LabelNameContext* CypherParser::OC_NodeLabelsContext::oC_LabelName(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::OC_NodeLabelsContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_NodeLabelsContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + + +size_t 
CypherParser::OC_NodeLabelsContext::getRuleIndex() const { + return CypherParser::RuleOC_NodeLabels; +} + + +CypherParser::OC_NodeLabelsContext* CypherParser::oC_NodeLabels() { + OC_NodeLabelsContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 224, CypherParser::RuleOC_NodeLabels); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + size_t alt; + enterOuterAlt(_localctx, 1); + setState(2036); + match(CypherParser::COLON); + setState(2038); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2037); + match(CypherParser::SP); + } + setState(2040); + oC_LabelName(); + setState(2057); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 318, _ctx); + while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER) { + if (alt == 1) { + setState(2042); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2041); + match(CypherParser::SP); + } + setState(2049); + _errHandler->sync(this); + switch (_input->LA(1)) { + case CypherParser::T__10: { + setState(2044); + match(CypherParser::T__10); + setState(2046); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::COLON) { + setState(2045); + match(CypherParser::COLON); + } + break; + } + + case CypherParser::COLON: { + setState(2048); + match(CypherParser::COLON); + break; + } + + default: + throw NoViableAltException(this); + } + setState(2052); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2051); + match(CypherParser::SP); + } + setState(2054); + oC_LabelName(); + } + setState(2059); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 318, _ctx); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_RecursiveDetailContext ------------------------------------------------------------------ + +CypherParser::KU_RecursiveDetailContext::KU_RecursiveDetailContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_RecursiveDetailContext::STAR() { + return getToken(CypherParser::STAR, 0); +} + +CypherParser::KU_RecursiveTypeContext* CypherParser::KU_RecursiveDetailContext::kU_RecursiveType() { + return getRuleContext(0); +} + +CypherParser::OC_RangeLiteralContext* CypherParser::KU_RecursiveDetailContext::oC_RangeLiteral() { + return getRuleContext(0); +} + +CypherParser::KU_RecursiveComprehensionContext* CypherParser::KU_RecursiveDetailContext::kU_RecursiveComprehension() { + return getRuleContext(0); +} + +std::vector CypherParser::KU_RecursiveDetailContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_RecursiveDetailContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + + +size_t CypherParser::KU_RecursiveDetailContext::getRuleIndex() const { + return CypherParser::RuleKU_RecursiveDetail; +} + + +CypherParser::KU_RecursiveDetailContext* CypherParser::kU_RecursiveDetail() { + KU_RecursiveDetailContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 226, CypherParser::RuleKU_RecursiveDetail); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = 
finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(2060); + match(CypherParser::STAR); + setState(2065); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 320, _ctx)) { + case 1: { + setState(2062); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2061); + match(CypherParser::SP); + } + setState(2064); + kU_RecursiveType(); + break; + } + + default: + break; + } + setState(2071); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 322, _ctx)) { + case 1: { + setState(2068); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 321, _ctx)) { + case 1: { + setState(2067); + match(CypherParser::SP); + break; + } + + default: + break; + } + setState(2070); + oC_RangeLiteral(); + break; + } + + default: + break; + } + setState(2077); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 324, _ctx)) { + case 1: { + setState(2074); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2073); + match(CypherParser::SP); + } + setState(2076); + kU_RecursiveComprehension(); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_RecursiveTypeContext ------------------------------------------------------------------ + +CypherParser::KU_RecursiveTypeContext::KU_RecursiveTypeContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_RecursiveTypeContext::WSHORTEST() { + return getToken(CypherParser::WSHORTEST, 0); +} + +CypherParser::OC_PropertyKeyNameContext* CypherParser::KU_RecursiveTypeContext::oC_PropertyKeyName() { + return getRuleContext(0); +} + +tree::TerminalNode* CypherParser::KU_RecursiveTypeContext::ALL() { + return getToken(CypherParser::ALL, 0); +} + +std::vector CypherParser::KU_RecursiveTypeContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_RecursiveTypeContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +tree::TerminalNode* CypherParser::KU_RecursiveTypeContext::SHORTEST() { + return getToken(CypherParser::SHORTEST, 0); +} + +tree::TerminalNode* CypherParser::KU_RecursiveTypeContext::TRAIL() { + return getToken(CypherParser::TRAIL, 0); +} + +tree::TerminalNode* CypherParser::KU_RecursiveTypeContext::ACYCLIC() { + return getToken(CypherParser::ACYCLIC, 0); +} + + +size_t CypherParser::KU_RecursiveTypeContext::getRuleIndex() const { + return CypherParser::RuleKU_RecursiveType; +} + + +CypherParser::KU_RecursiveTypeContext* CypherParser::kU_RecursiveType() { + KU_RecursiveTypeContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 228, CypherParser::RuleKU_RecursiveType); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + setState(2103); + _errHandler->sync(this); + switch (getInterpreter()->adaptivePredict(_input, 329, _ctx)) { + case 1: { + enterOuterAlt(_localctx, 1); + setState(2081); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::ALL) { + setState(2079); + 
match(CypherParser::ALL); + setState(2080); + match(CypherParser::SP); + } + setState(2083); + match(CypherParser::WSHORTEST); + setState(2085); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2084); + match(CypherParser::SP); + } + setState(2087); + match(CypherParser::T__1); + setState(2089); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2088); + match(CypherParser::SP); + } + setState(2091); + oC_PropertyKeyName(); + setState(2093); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2092); + match(CypherParser::SP); + } + setState(2095); + match(CypherParser::T__2); + break; + } + + case 2: { + enterOuterAlt(_localctx, 2); + setState(2097); + match(CypherParser::SHORTEST); + break; + } + + case 3: { + enterOuterAlt(_localctx, 3); + setState(2098); + match(CypherParser::ALL); + setState(2099); + match(CypherParser::SP); + setState(2100); + match(CypherParser::SHORTEST); + break; + } + + case 4: { + enterOuterAlt(_localctx, 4); + setState(2101); + match(CypherParser::TRAIL); + break; + } + + case 5: { + enterOuterAlt(_localctx, 5); + setState(2102); + match(CypherParser::ACYCLIC); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_RangeLiteralContext ------------------------------------------------------------------ + +CypherParser::OC_RangeLiteralContext::OC_RangeLiteralContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::OC_RangeLiteralContext::DOTDOT() { + return getToken(CypherParser::DOTDOT, 0); +} + +CypherParser::OC_LowerBoundContext* CypherParser::OC_RangeLiteralContext::oC_LowerBound() { + return getRuleContext(0); +} + +std::vector CypherParser::OC_RangeLiteralContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_RangeLiteralContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +CypherParser::OC_UpperBoundContext* CypherParser::OC_RangeLiteralContext::oC_UpperBound() { + return getRuleContext(0); +} + +CypherParser::OC_IntegerLiteralContext* CypherParser::OC_RangeLiteralContext::oC_IntegerLiteral() { + return getRuleContext(0); +} + + +size_t CypherParser::OC_RangeLiteralContext::getRuleIndex() const { + return CypherParser::RuleOC_RangeLiteral; +} + + +CypherParser::OC_RangeLiteralContext* CypherParser::oC_RangeLiteral() { + OC_RangeLiteralContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 230, CypherParser::RuleOC_RangeLiteral); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + setState(2119); + _errHandler->sync(this); + switch (getInterpreter()->adaptivePredict(_input, 334, _ctx)) { + case 1: { + enterOuterAlt(_localctx, 1); + setState(2106); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::DecimalInteger) { + setState(2105); + oC_LowerBound(); + } + setState(2109); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2108); + match(CypherParser::SP); + } + setState(2111); + match(CypherParser::DOTDOT); + setState(2113); + _errHandler->sync(this); + + 
switch (getInterpreter()->adaptivePredict(_input, 332, _ctx)) { + case 1: { + setState(2112); + match(CypherParser::SP); + break; + } + + default: + break; + } + setState(2116); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::DecimalInteger) { + setState(2115); + oC_UpperBound(); + } + break; + } + + case 2: { + enterOuterAlt(_localctx, 2); + setState(2118); + oC_IntegerLiteral(); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_RecursiveComprehensionContext ------------------------------------------------------------------ + +CypherParser::KU_RecursiveComprehensionContext::KU_RecursiveComprehensionContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +std::vector CypherParser::KU_RecursiveComprehensionContext::oC_Variable() { + return getRuleContexts(); +} + +CypherParser::OC_VariableContext* CypherParser::KU_RecursiveComprehensionContext::oC_Variable(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::KU_RecursiveComprehensionContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_RecursiveComprehensionContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +CypherParser::OC_WhereContext* CypherParser::KU_RecursiveComprehensionContext::oC_Where() { + return getRuleContext(0); +} + +std::vector CypherParser::KU_RecursiveComprehensionContext::kU_RecursiveProjectionItems() { + return getRuleContexts(); +} + +CypherParser::KU_RecursiveProjectionItemsContext* CypherParser::KU_RecursiveComprehensionContext::kU_RecursiveProjectionItems(size_t i) { + return getRuleContext(i); +} + + +size_t CypherParser::KU_RecursiveComprehensionContext::getRuleIndex() const { + return CypherParser::RuleKU_RecursiveComprehension; +} + + +CypherParser::KU_RecursiveComprehensionContext* CypherParser::kU_RecursiveComprehension() { + KU_RecursiveComprehensionContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 232, CypherParser::RuleKU_RecursiveComprehension); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(2121); + match(CypherParser::T__1); + setState(2123); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2122); + match(CypherParser::SP); + } + setState(2125); + oC_Variable(); + setState(2127); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2126); + match(CypherParser::SP); + } + setState(2129); + match(CypherParser::T__3); + setState(2131); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2130); + match(CypherParser::SP); + } + setState(2133); + oC_Variable(); + setState(2145); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 341, _ctx)) { + case 1: { + setState(2135); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2134); + match(CypherParser::SP); + } + setState(2137); + match(CypherParser::T__10); + setState(2139); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2138); + 
match(CypherParser::SP); + } + setState(2141); + oC_Where(); + setState(2143); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 340, _ctx)) { + case 1: { + setState(2142); + match(CypherParser::SP); + break; + } + + default: + break; + } + break; + } + + default: + break; + } + setState(2166); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::T__10 || _la == CypherParser::SP) { + setState(2148); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2147); + match(CypherParser::SP); + } + setState(2150); + match(CypherParser::T__10); + setState(2152); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2151); + match(CypherParser::SP); + } + setState(2154); + kU_RecursiveProjectionItems(); + setState(2156); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2155); + match(CypherParser::SP); + } + setState(2158); + match(CypherParser::T__3); + setState(2160); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2159); + match(CypherParser::SP); + } + setState(2162); + kU_RecursiveProjectionItems(); + setState(2164); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2163); + match(CypherParser::SP); + } + } + setState(2168); + match(CypherParser::T__2); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_RecursiveProjectionItemsContext ------------------------------------------------------------------ + +CypherParser::KU_RecursiveProjectionItemsContext::KU_RecursiveProjectionItemsContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +std::vector CypherParser::KU_RecursiveProjectionItemsContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_RecursiveProjectionItemsContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +CypherParser::OC_ProjectionItemsContext* CypherParser::KU_RecursiveProjectionItemsContext::oC_ProjectionItems() { + return getRuleContext(0); +} + + +size_t CypherParser::KU_RecursiveProjectionItemsContext::getRuleIndex() const { + return CypherParser::RuleKU_RecursiveProjectionItems; +} + + +CypherParser::KU_RecursiveProjectionItemsContext* CypherParser::kU_RecursiveProjectionItems() { + KU_RecursiveProjectionItemsContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 234, CypherParser::RuleKU_RecursiveProjectionItems); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(2170); + match(CypherParser::T__8); + setState(2172); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 348, _ctx)) { + case 1: { + setState(2171); + match(CypherParser::SP); + break; + } + + default: + break; + } + setState(2175); + _errHandler->sync(this); + + _la = _input->LA(1); + if ((((_la & ~ 0x3fULL) == 0) && + ((1ULL << _la) & -2320550076713270652) != 0) || ((((_la - 65) & ~ 0x3fULL) == 0) && + ((1ULL << (_la - 65)) & -286014905805559497) != 0) || ((((_la - 130) & ~ 0x3fULL) == 0) && + ((1ULL << (_la - 130)) & 5492412753616171) != 
0)) { + setState(2174); + oC_ProjectionItems(); + } + setState(2178); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2177); + match(CypherParser::SP); + } + setState(2180); + match(CypherParser::T__9); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_LowerBoundContext ------------------------------------------------------------------ + +CypherParser::OC_LowerBoundContext::OC_LowerBoundContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::OC_LowerBoundContext::DecimalInteger() { + return getToken(CypherParser::DecimalInteger, 0); +} + + +size_t CypherParser::OC_LowerBoundContext::getRuleIndex() const { + return CypherParser::RuleOC_LowerBound; +} + + +CypherParser::OC_LowerBoundContext* CypherParser::oC_LowerBound() { + OC_LowerBoundContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 236, CypherParser::RuleOC_LowerBound); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(2182); + match(CypherParser::DecimalInteger); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_UpperBoundContext ------------------------------------------------------------------ + +CypherParser::OC_UpperBoundContext::OC_UpperBoundContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::OC_UpperBoundContext::DecimalInteger() { + return getToken(CypherParser::DecimalInteger, 0); +} + + +size_t CypherParser::OC_UpperBoundContext::getRuleIndex() const { + return CypherParser::RuleOC_UpperBound; +} + + +CypherParser::OC_UpperBoundContext* CypherParser::oC_UpperBound() { + OC_UpperBoundContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 238, CypherParser::RuleOC_UpperBound); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(2184); + match(CypherParser::DecimalInteger); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_LabelNameContext ------------------------------------------------------------------ + +CypherParser::OC_LabelNameContext::OC_LabelNameContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_SchemaNameContext* CypherParser::OC_LabelNameContext::oC_SchemaName() { + return getRuleContext(0); +} + + +size_t CypherParser::OC_LabelNameContext::getRuleIndex() const { + return CypherParser::RuleOC_LabelName; +} + + +CypherParser::OC_LabelNameContext* CypherParser::oC_LabelName() { + OC_LabelNameContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 240, CypherParser::RuleOC_LabelName); + +#if __cplusplus > 201703L + auto 
onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(2186); + oC_SchemaName(); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_RelTypeNameContext ------------------------------------------------------------------ + +CypherParser::OC_RelTypeNameContext::OC_RelTypeNameContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_SchemaNameContext* CypherParser::OC_RelTypeNameContext::oC_SchemaName() { + return getRuleContext(0); +} + + +size_t CypherParser::OC_RelTypeNameContext::getRuleIndex() const { + return CypherParser::RuleOC_RelTypeName; +} + + +CypherParser::OC_RelTypeNameContext* CypherParser::oC_RelTypeName() { + OC_RelTypeNameContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 242, CypherParser::RuleOC_RelTypeName); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(2188); + oC_SchemaName(); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_ExpressionContext ------------------------------------------------------------------ + +CypherParser::OC_ExpressionContext::OC_ExpressionContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_OrExpressionContext* CypherParser::OC_ExpressionContext::oC_OrExpression() { + return getRuleContext(0); +} + + +size_t CypherParser::OC_ExpressionContext::getRuleIndex() const { + return CypherParser::RuleOC_Expression; +} + + +CypherParser::OC_ExpressionContext* CypherParser::oC_Expression() { + OC_ExpressionContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 244, CypherParser::RuleOC_Expression); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(2190); + oC_OrExpression(); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_OrExpressionContext ------------------------------------------------------------------ + +CypherParser::OC_OrExpressionContext::OC_OrExpressionContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +std::vector CypherParser::OC_OrExpressionContext::oC_XorExpression() { + return getRuleContexts(); +} + +CypherParser::OC_XorExpressionContext* CypherParser::OC_OrExpressionContext::oC_XorExpression(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::OC_OrExpressionContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_OrExpressionContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +std::vector CypherParser::OC_OrExpressionContext::OR() { + return getTokens(CypherParser::OR); +} + +tree::TerminalNode* 
CypherParser::OC_OrExpressionContext::OR(size_t i) { + return getToken(CypherParser::OR, i); +} + + +size_t CypherParser::OC_OrExpressionContext::getRuleIndex() const { + return CypherParser::RuleOC_OrExpression; +} + + +CypherParser::OC_OrExpressionContext* CypherParser::oC_OrExpression() { + OC_OrExpressionContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 246, CypherParser::RuleOC_OrExpression); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + size_t alt; + enterOuterAlt(_localctx, 1); + setState(2192); + oC_XorExpression(); + setState(2199); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 351, _ctx); + while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER) { + if (alt == 1) { + setState(2193); + match(CypherParser::SP); + setState(2194); + match(CypherParser::OR); + setState(2195); + match(CypherParser::SP); + setState(2196); + oC_XorExpression(); + } + setState(2201); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 351, _ctx); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_XorExpressionContext ------------------------------------------------------------------ + +CypherParser::OC_XorExpressionContext::OC_XorExpressionContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +std::vector CypherParser::OC_XorExpressionContext::oC_AndExpression() { + return getRuleContexts(); +} + +CypherParser::OC_AndExpressionContext* CypherParser::OC_XorExpressionContext::oC_AndExpression(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::OC_XorExpressionContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_XorExpressionContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +std::vector CypherParser::OC_XorExpressionContext::XOR() { + return getTokens(CypherParser::XOR); +} + +tree::TerminalNode* CypherParser::OC_XorExpressionContext::XOR(size_t i) { + return getToken(CypherParser::XOR, i); +} + + +size_t CypherParser::OC_XorExpressionContext::getRuleIndex() const { + return CypherParser::RuleOC_XorExpression; +} + + +CypherParser::OC_XorExpressionContext* CypherParser::oC_XorExpression() { + OC_XorExpressionContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 248, CypherParser::RuleOC_XorExpression); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + size_t alt; + enterOuterAlt(_localctx, 1); + setState(2202); + oC_AndExpression(); + setState(2209); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 352, _ctx); + while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER) { + if (alt == 1) { + setState(2203); + match(CypherParser::SP); + setState(2204); + match(CypherParser::XOR); + setState(2205); + match(CypherParser::SP); + setState(2206); + oC_AndExpression(); + } + setState(2211); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 352, _ctx); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } 
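+  // Note on the expression rules in this region: oC_OrExpression,
+  // oC_XorExpression, oC_AndExpression, oC_NotExpression and the
+  // comparison/bitwise/arithmetic rules below all share one generated shape:
+  // parse one operand of the next-higher-precedence rule, then loop while
+  // adaptivePredict() sees another operator of the current level, consuming
+  // the optional SP tokens around it. Operator precedence is therefore encoded
+  // purely by which rule calls which. A minimal hand-written sketch of the
+  // same shape (hypothetical helpers parseXor()/peekOr()/makeOr(), shown only
+  // for illustration -- they are not part of this generated parser):
+  //
+  //   Expr parseOr() {
+  //     Expr lhs = parseXor();          // operand at the next precedence level
+  //     while (peekOr()) {              // another "SP OR SP operand" repetition?
+  //       Expr rhs = parseXor();
+  //       lhs = makeOr(lhs, rhs);       // fold left-associatively
+  //     }
+  //     return lhs;
+  //   }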
+ + return _localctx; +} + +//----------------- OC_AndExpressionContext ------------------------------------------------------------------ + +CypherParser::OC_AndExpressionContext::OC_AndExpressionContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +std::vector CypherParser::OC_AndExpressionContext::oC_NotExpression() { + return getRuleContexts(); +} + +CypherParser::OC_NotExpressionContext* CypherParser::OC_AndExpressionContext::oC_NotExpression(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::OC_AndExpressionContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_AndExpressionContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +std::vector CypherParser::OC_AndExpressionContext::AND() { + return getTokens(CypherParser::AND); +} + +tree::TerminalNode* CypherParser::OC_AndExpressionContext::AND(size_t i) { + return getToken(CypherParser::AND, i); +} + + +size_t CypherParser::OC_AndExpressionContext::getRuleIndex() const { + return CypherParser::RuleOC_AndExpression; +} + + +CypherParser::OC_AndExpressionContext* CypherParser::oC_AndExpression() { + OC_AndExpressionContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 250, CypherParser::RuleOC_AndExpression); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + size_t alt; + enterOuterAlt(_localctx, 1); + setState(2212); + oC_NotExpression(); + setState(2219); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 353, _ctx); + while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER) { + if (alt == 1) { + setState(2213); + match(CypherParser::SP); + setState(2214); + match(CypherParser::AND); + setState(2215); + match(CypherParser::SP); + setState(2216); + oC_NotExpression(); + } + setState(2221); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 353, _ctx); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_NotExpressionContext ------------------------------------------------------------------ + +CypherParser::OC_NotExpressionContext::OC_NotExpressionContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_ComparisonExpressionContext* CypherParser::OC_NotExpressionContext::oC_ComparisonExpression() { + return getRuleContext(0); +} + +std::vector CypherParser::OC_NotExpressionContext::NOT() { + return getTokens(CypherParser::NOT); +} + +tree::TerminalNode* CypherParser::OC_NotExpressionContext::NOT(size_t i) { + return getToken(CypherParser::NOT, i); +} + +std::vector CypherParser::OC_NotExpressionContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_NotExpressionContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + + +size_t CypherParser::OC_NotExpressionContext::getRuleIndex() const { + return CypherParser::RuleOC_NotExpression; +} + + +CypherParser::OC_NotExpressionContext* CypherParser::oC_NotExpression() { + OC_NotExpressionContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 252, CypherParser::RuleOC_NotExpression); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, 
this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(2228); + _errHandler->sync(this); + _la = _input->LA(1); + while (_la == CypherParser::NOT) { + setState(2222); + match(CypherParser::NOT); + setState(2224); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2223); + match(CypherParser::SP); + } + setState(2230); + _errHandler->sync(this); + _la = _input->LA(1); + } + setState(2231); + oC_ComparisonExpression(); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_ComparisonExpressionContext ------------------------------------------------------------------ + +CypherParser::OC_ComparisonExpressionContext::OC_ComparisonExpressionContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +std::vector CypherParser::OC_ComparisonExpressionContext::kU_BitwiseOrOperatorExpression() { + return getRuleContexts(); +} + +CypherParser::KU_BitwiseOrOperatorExpressionContext* CypherParser::OC_ComparisonExpressionContext::kU_BitwiseOrOperatorExpression(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::OC_ComparisonExpressionContext::kU_ComparisonOperator() { + return getRuleContexts(); +} + +CypherParser::KU_ComparisonOperatorContext* CypherParser::OC_ComparisonExpressionContext::kU_ComparisonOperator(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::OC_ComparisonExpressionContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_ComparisonExpressionContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +tree::TerminalNode* CypherParser::OC_ComparisonExpressionContext::INVALID_NOT_EQUAL() { + return getToken(CypherParser::INVALID_NOT_EQUAL, 0); +} + + +size_t CypherParser::OC_ComparisonExpressionContext::getRuleIndex() const { + return CypherParser::RuleOC_ComparisonExpression; +} + + +CypherParser::OC_ComparisonExpressionContext* CypherParser::oC_ComparisonExpression() { + OC_ComparisonExpressionContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 254, CypherParser::RuleOC_ComparisonExpression); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + size_t alt; + setState(2281); + _errHandler->sync(this); + switch (getInterpreter()->adaptivePredict(_input, 366, _ctx)) { + case 1: { + enterOuterAlt(_localctx, 1); + setState(2233); + kU_BitwiseOrOperatorExpression(); + setState(2243); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 358, _ctx)) { + case 1: { + setState(2235); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2234); + match(CypherParser::SP); + } + setState(2237); + kU_ComparisonOperator(); + setState(2239); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2238); + match(CypherParser::SP); + } + setState(2241); + kU_BitwiseOrOperatorExpression(); + break; + } + + default: + break; + } + break; + } + + case 2: { + enterOuterAlt(_localctx, 2); + setState(2245); + kU_BitwiseOrOperatorExpression(); + + setState(2247); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == 
CypherParser::SP) { + setState(2246); + match(CypherParser::SP); + } + setState(2249); + antlrcpp::downCast(_localctx)->invalid_not_equalToken = match(CypherParser::INVALID_NOT_EQUAL); + setState(2251); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2250); + match(CypherParser::SP); + } + setState(2253); + kU_BitwiseOrOperatorExpression(); + notifyInvalidNotEqualOperator(antlrcpp::downCast(_localctx)->invalid_not_equalToken); + break; + } + + case 3: { + enterOuterAlt(_localctx, 3); + setState(2257); + kU_BitwiseOrOperatorExpression(); + setState(2259); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2258); + match(CypherParser::SP); + } + setState(2261); + kU_ComparisonOperator(); + setState(2263); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2262); + match(CypherParser::SP); + } + setState(2265); + kU_BitwiseOrOperatorExpression(); + setState(2275); + _errHandler->sync(this); + alt = 1; + do { + switch (alt) { + case 1: { + setState(2267); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2266); + match(CypherParser::SP); + } + setState(2269); + kU_ComparisonOperator(); + setState(2271); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2270); + match(CypherParser::SP); + } + setState(2273); + kU_BitwiseOrOperatorExpression(); + break; + } + + default: + throw NoViableAltException(this); + } + setState(2277); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 365, _ctx); + } while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER); + notifyNonBinaryComparison(_localctx->start); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_ComparisonOperatorContext ------------------------------------------------------------------ + +CypherParser::KU_ComparisonOperatorContext::KU_ComparisonOperatorContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + + +size_t CypherParser::KU_ComparisonOperatorContext::getRuleIndex() const { + return CypherParser::RuleKU_ComparisonOperator; +} + + +CypherParser::KU_ComparisonOperatorContext* CypherParser::kU_ComparisonOperator() { + KU_ComparisonOperatorContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 256, CypherParser::RuleKU_ComparisonOperator); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(2283); + _la = _input->LA(1); + if (!((((_la & ~ 0x3fULL) == 0) && + ((1ULL << _la) & 127040) != 0))) { + _errHandler->recoverInline(this); + } + else { + _errHandler->reportMatch(this); + consume(); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_BitwiseOrOperatorExpressionContext ------------------------------------------------------------------ + +CypherParser::KU_BitwiseOrOperatorExpressionContext::KU_BitwiseOrOperatorExpressionContext(ParserRuleContext *parent, 
size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +std::vector CypherParser::KU_BitwiseOrOperatorExpressionContext::kU_BitwiseAndOperatorExpression() { + return getRuleContexts(); +} + +CypherParser::KU_BitwiseAndOperatorExpressionContext* CypherParser::KU_BitwiseOrOperatorExpressionContext::kU_BitwiseAndOperatorExpression(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::KU_BitwiseOrOperatorExpressionContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_BitwiseOrOperatorExpressionContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + + +size_t CypherParser::KU_BitwiseOrOperatorExpressionContext::getRuleIndex() const { + return CypherParser::RuleKU_BitwiseOrOperatorExpression; +} + + +CypherParser::KU_BitwiseOrOperatorExpressionContext* CypherParser::kU_BitwiseOrOperatorExpression() { + KU_BitwiseOrOperatorExpressionContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 258, CypherParser::RuleKU_BitwiseOrOperatorExpression); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + size_t alt; + enterOuterAlt(_localctx, 1); + setState(2285); + kU_BitwiseAndOperatorExpression(); + setState(2296); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 369, _ctx); + while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER) { + if (alt == 1) { + setState(2287); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2286); + match(CypherParser::SP); + } + setState(2289); + match(CypherParser::T__10); + setState(2291); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2290); + match(CypherParser::SP); + } + setState(2293); + kU_BitwiseAndOperatorExpression(); + } + setState(2298); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 369, _ctx); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_BitwiseAndOperatorExpressionContext ------------------------------------------------------------------ + +CypherParser::KU_BitwiseAndOperatorExpressionContext::KU_BitwiseAndOperatorExpressionContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +std::vector CypherParser::KU_BitwiseAndOperatorExpressionContext::kU_BitShiftOperatorExpression() { + return getRuleContexts(); +} + +CypherParser::KU_BitShiftOperatorExpressionContext* CypherParser::KU_BitwiseAndOperatorExpressionContext::kU_BitShiftOperatorExpression(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::KU_BitwiseAndOperatorExpressionContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_BitwiseAndOperatorExpressionContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + + +size_t CypherParser::KU_BitwiseAndOperatorExpressionContext::getRuleIndex() const { + return CypherParser::RuleKU_BitwiseAndOperatorExpression; +} + + +CypherParser::KU_BitwiseAndOperatorExpressionContext* CypherParser::kU_BitwiseAndOperatorExpression() { + KU_BitwiseAndOperatorExpressionContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 260, 
CypherParser::RuleKU_BitwiseAndOperatorExpression); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + size_t alt; + enterOuterAlt(_localctx, 1); + setState(2299); + kU_BitShiftOperatorExpression(); + setState(2310); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 372, _ctx); + while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER) { + if (alt == 1) { + setState(2301); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2300); + match(CypherParser::SP); + } + setState(2303); + match(CypherParser::T__16); + setState(2305); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2304); + match(CypherParser::SP); + } + setState(2307); + kU_BitShiftOperatorExpression(); + } + setState(2312); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 372, _ctx); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_BitShiftOperatorExpressionContext ------------------------------------------------------------------ + +CypherParser::KU_BitShiftOperatorExpressionContext::KU_BitShiftOperatorExpressionContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +std::vector CypherParser::KU_BitShiftOperatorExpressionContext::oC_AddOrSubtractExpression() { + return getRuleContexts(); +} + +CypherParser::OC_AddOrSubtractExpressionContext* CypherParser::KU_BitShiftOperatorExpressionContext::oC_AddOrSubtractExpression(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::KU_BitShiftOperatorExpressionContext::kU_BitShiftOperator() { + return getRuleContexts(); +} + +CypherParser::KU_BitShiftOperatorContext* CypherParser::KU_BitShiftOperatorExpressionContext::kU_BitShiftOperator(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::KU_BitShiftOperatorExpressionContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_BitShiftOperatorExpressionContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + + +size_t CypherParser::KU_BitShiftOperatorExpressionContext::getRuleIndex() const { + return CypherParser::RuleKU_BitShiftOperatorExpression; +} + + +CypherParser::KU_BitShiftOperatorExpressionContext* CypherParser::kU_BitShiftOperatorExpression() { + KU_BitShiftOperatorExpressionContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 262, CypherParser::RuleKU_BitShiftOperatorExpression); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + size_t alt; + enterOuterAlt(_localctx, 1); + setState(2313); + oC_AddOrSubtractExpression(); + setState(2325); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 375, _ctx); + while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER) { + if (alt == 1) { + setState(2315); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2314); + match(CypherParser::SP); + } + setState(2317); + kU_BitShiftOperator(); + setState(2319); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2318); + 
match(CypherParser::SP); + } + setState(2321); + oC_AddOrSubtractExpression(); + } + setState(2327); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 375, _ctx); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_BitShiftOperatorContext ------------------------------------------------------------------ + +CypherParser::KU_BitShiftOperatorContext::KU_BitShiftOperatorContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + + +size_t CypherParser::KU_BitShiftOperatorContext::getRuleIndex() const { + return CypherParser::RuleKU_BitShiftOperator; +} + + +CypherParser::KU_BitShiftOperatorContext* CypherParser::kU_BitShiftOperator() { + KU_BitShiftOperatorContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 264, CypherParser::RuleKU_BitShiftOperator); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(2328); + _la = _input->LA(1); + if (!(_la == CypherParser::T__17 + + || _la == CypherParser::T__18)) { + _errHandler->recoverInline(this); + } + else { + _errHandler->reportMatch(this); + consume(); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_AddOrSubtractExpressionContext ------------------------------------------------------------------ + +CypherParser::OC_AddOrSubtractExpressionContext::OC_AddOrSubtractExpressionContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +std::vector CypherParser::OC_AddOrSubtractExpressionContext::oC_MultiplyDivideModuloExpression() { + return getRuleContexts(); +} + +CypherParser::OC_MultiplyDivideModuloExpressionContext* CypherParser::OC_AddOrSubtractExpressionContext::oC_MultiplyDivideModuloExpression(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::OC_AddOrSubtractExpressionContext::kU_AddOrSubtractOperator() { + return getRuleContexts(); +} + +CypherParser::KU_AddOrSubtractOperatorContext* CypherParser::OC_AddOrSubtractExpressionContext::kU_AddOrSubtractOperator(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::OC_AddOrSubtractExpressionContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_AddOrSubtractExpressionContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + + +size_t CypherParser::OC_AddOrSubtractExpressionContext::getRuleIndex() const { + return CypherParser::RuleOC_AddOrSubtractExpression; +} + + +CypherParser::OC_AddOrSubtractExpressionContext* CypherParser::oC_AddOrSubtractExpression() { + OC_AddOrSubtractExpressionContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 266, CypherParser::RuleOC_AddOrSubtractExpression); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + size_t alt; + enterOuterAlt(_localctx, 1); + setState(2330); + oC_MultiplyDivideModuloExpression(); + setState(2342); + _errHandler->sync(this); + alt = 
getInterpreter()->adaptivePredict(_input, 378, _ctx); + while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER) { + if (alt == 1) { + setState(2332); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2331); + match(CypherParser::SP); + } + setState(2334); + kU_AddOrSubtractOperator(); + setState(2336); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2335); + match(CypherParser::SP); + } + setState(2338); + oC_MultiplyDivideModuloExpression(); + } + setState(2344); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 378, _ctx); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_AddOrSubtractOperatorContext ------------------------------------------------------------------ + +CypherParser::KU_AddOrSubtractOperatorContext::KU_AddOrSubtractOperatorContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_AddOrSubtractOperatorContext::MINUS() { + return getToken(CypherParser::MINUS, 0); +} + + +size_t CypherParser::KU_AddOrSubtractOperatorContext::getRuleIndex() const { + return CypherParser::RuleKU_AddOrSubtractOperator; +} + + +CypherParser::KU_AddOrSubtractOperatorContext* CypherParser::kU_AddOrSubtractOperator() { + KU_AddOrSubtractOperatorContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 268, CypherParser::RuleKU_AddOrSubtractOperator); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(2345); + _la = _input->LA(1); + if (!(_la == CypherParser::T__19 || _la == CypherParser::MINUS)) { + _errHandler->recoverInline(this); + } + else { + _errHandler->reportMatch(this); + consume(); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_MultiplyDivideModuloExpressionContext ------------------------------------------------------------------ + +CypherParser::OC_MultiplyDivideModuloExpressionContext::OC_MultiplyDivideModuloExpressionContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +std::vector CypherParser::OC_MultiplyDivideModuloExpressionContext::oC_PowerOfExpression() { + return getRuleContexts(); +} + +CypherParser::OC_PowerOfExpressionContext* CypherParser::OC_MultiplyDivideModuloExpressionContext::oC_PowerOfExpression(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::OC_MultiplyDivideModuloExpressionContext::kU_MultiplyDivideModuloOperator() { + return getRuleContexts(); +} + +CypherParser::KU_MultiplyDivideModuloOperatorContext* CypherParser::OC_MultiplyDivideModuloExpressionContext::kU_MultiplyDivideModuloOperator(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::OC_MultiplyDivideModuloExpressionContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_MultiplyDivideModuloExpressionContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + + +size_t 
CypherParser::OC_MultiplyDivideModuloExpressionContext::getRuleIndex() const { + return CypherParser::RuleOC_MultiplyDivideModuloExpression; +} + + +CypherParser::OC_MultiplyDivideModuloExpressionContext* CypherParser::oC_MultiplyDivideModuloExpression() { + OC_MultiplyDivideModuloExpressionContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 270, CypherParser::RuleOC_MultiplyDivideModuloExpression); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + size_t alt; + enterOuterAlt(_localctx, 1); + setState(2347); + oC_PowerOfExpression(); + setState(2359); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 381, _ctx); + while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER) { + if (alt == 1) { + setState(2349); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2348); + match(CypherParser::SP); + } + setState(2351); + kU_MultiplyDivideModuloOperator(); + setState(2353); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2352); + match(CypherParser::SP); + } + setState(2355); + oC_PowerOfExpression(); + } + setState(2361); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 381, _ctx); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_MultiplyDivideModuloOperatorContext ------------------------------------------------------------------ + +CypherParser::KU_MultiplyDivideModuloOperatorContext::KU_MultiplyDivideModuloOperatorContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_MultiplyDivideModuloOperatorContext::STAR() { + return getToken(CypherParser::STAR, 0); +} + + +size_t CypherParser::KU_MultiplyDivideModuloOperatorContext::getRuleIndex() const { + return CypherParser::RuleKU_MultiplyDivideModuloOperator; +} + + +CypherParser::KU_MultiplyDivideModuloOperatorContext* CypherParser::kU_MultiplyDivideModuloOperator() { + KU_MultiplyDivideModuloOperatorContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 272, CypherParser::RuleKU_MultiplyDivideModuloOperator); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(2362); + _la = _input->LA(1); + if (!(_la == CypherParser::T__20 + + || _la == CypherParser::T__21 || _la == CypherParser::STAR)) { + _errHandler->recoverInline(this); + } + else { + _errHandler->reportMatch(this); + consume(); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_PowerOfExpressionContext ------------------------------------------------------------------ + +CypherParser::OC_PowerOfExpressionContext::OC_PowerOfExpressionContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +std::vector CypherParser::OC_PowerOfExpressionContext::oC_StringListNullOperatorExpression() { + return getRuleContexts(); +} + 
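+// Whitespace note: this grammar keeps whitespace as an explicit SP token, so
+// every optional gap is generated as a one-token lookahead followed by an
+// optional match, rather than being skipped by the lexer. The recurring idiom,
+// as it appears throughout the rules above and below:
+//
+//   _la = _input->LA(1);            // peek at the next token type
+//   if (_la == CypherParser::SP) {  // optional whitespace in the grammar (SP?)
+//     match(CypherParser::SP);      // consume it; otherwise fall through
+//   }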
+CypherParser::OC_StringListNullOperatorExpressionContext* CypherParser::OC_PowerOfExpressionContext::oC_StringListNullOperatorExpression(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::OC_PowerOfExpressionContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_PowerOfExpressionContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + + +size_t CypherParser::OC_PowerOfExpressionContext::getRuleIndex() const { + return CypherParser::RuleOC_PowerOfExpression; +} + + +CypherParser::OC_PowerOfExpressionContext* CypherParser::oC_PowerOfExpression() { + OC_PowerOfExpressionContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 274, CypherParser::RuleOC_PowerOfExpression); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + size_t alt; + enterOuterAlt(_localctx, 1); + setState(2364); + oC_StringListNullOperatorExpression(); + setState(2375); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 384, _ctx); + while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER) { + if (alt == 1) { + setState(2366); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2365); + match(CypherParser::SP); + } + setState(2368); + match(CypherParser::T__22); + setState(2370); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2369); + match(CypherParser::SP); + } + setState(2372); + oC_StringListNullOperatorExpression(); + } + setState(2377); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 384, _ctx); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_StringListNullOperatorExpressionContext ------------------------------------------------------------------ + +CypherParser::OC_StringListNullOperatorExpressionContext::OC_StringListNullOperatorExpressionContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_UnaryAddSubtractOrFactorialExpressionContext* CypherParser::OC_StringListNullOperatorExpressionContext::oC_UnaryAddSubtractOrFactorialExpression() { + return getRuleContext(0); +} + +CypherParser::OC_StringOperatorExpressionContext* CypherParser::OC_StringListNullOperatorExpressionContext::oC_StringOperatorExpression() { + return getRuleContext(0); +} + +CypherParser::OC_NullOperatorExpressionContext* CypherParser::OC_StringListNullOperatorExpressionContext::oC_NullOperatorExpression() { + return getRuleContext(0); +} + +std::vector CypherParser::OC_StringListNullOperatorExpressionContext::oC_ListOperatorExpression() { + return getRuleContexts(); +} + +CypherParser::OC_ListOperatorExpressionContext* CypherParser::OC_StringListNullOperatorExpressionContext::oC_ListOperatorExpression(size_t i) { + return getRuleContext(i); +} + + +size_t CypherParser::OC_StringListNullOperatorExpressionContext::getRuleIndex() const { + return CypherParser::RuleOC_StringListNullOperatorExpression; +} + + +CypherParser::OC_StringListNullOperatorExpressionContext* CypherParser::oC_StringListNullOperatorExpression() { + OC_StringListNullOperatorExpressionContext *_localctx = _tracker.createInstance(_ctx, getState()); + 
enterRule(_localctx, 276, CypherParser::RuleOC_StringListNullOperatorExpression); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + size_t alt; + enterOuterAlt(_localctx, 1); + setState(2378); + oC_UnaryAddSubtractOrFactorialExpression(); + setState(2386); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 386, _ctx)) { + case 1: { + setState(2379); + oC_StringOperatorExpression(); + break; + } + + case 2: { + setState(2381); + _errHandler->sync(this); + alt = 1; + do { + switch (alt) { + case 1: { + setState(2380); + oC_ListOperatorExpression(); + break; + } + + default: + throw NoViableAltException(this); + } + setState(2383); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 385, _ctx); + } while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER); + break; + } + + case 3: { + setState(2385); + oC_NullOperatorExpression(); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_ListOperatorExpressionContext ------------------------------------------------------------------ + +CypherParser::OC_ListOperatorExpressionContext::OC_ListOperatorExpressionContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +std::vector CypherParser::OC_ListOperatorExpressionContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_ListOperatorExpressionContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +tree::TerminalNode* CypherParser::OC_ListOperatorExpressionContext::IN() { + return getToken(CypherParser::IN, 0); +} + +CypherParser::OC_PropertyOrLabelsExpressionContext* CypherParser::OC_ListOperatorExpressionContext::oC_PropertyOrLabelsExpression() { + return getRuleContext(0); +} + +std::vector CypherParser::OC_ListOperatorExpressionContext::oC_Expression() { + return getRuleContexts(); +} + +CypherParser::OC_ExpressionContext* CypherParser::OC_ListOperatorExpressionContext::oC_Expression(size_t i) { + return getRuleContext(i); +} + +tree::TerminalNode* CypherParser::OC_ListOperatorExpressionContext::COLON() { + return getToken(CypherParser::COLON, 0); +} + +tree::TerminalNode* CypherParser::OC_ListOperatorExpressionContext::DOTDOT() { + return getToken(CypherParser::DOTDOT, 0); +} + + +size_t CypherParser::OC_ListOperatorExpressionContext::getRuleIndex() const { + return CypherParser::RuleOC_ListOperatorExpression; +} + + +CypherParser::OC_ListOperatorExpressionContext* CypherParser::oC_ListOperatorExpression() { + OC_ListOperatorExpressionContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 278, CypherParser::RuleOC_ListOperatorExpression); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + setState(2407); + _errHandler->sync(this); + switch (getInterpreter()->adaptivePredict(_input, 390, _ctx)) { + case 1: { + enterOuterAlt(_localctx, 1); + setState(2388); + match(CypherParser::SP); + setState(2389); + match(CypherParser::IN); + setState(2391); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2390); + match(CypherParser::SP); + } + setState(2393); + 
oC_PropertyOrLabelsExpression(); + break; + } + + case 2: { + enterOuterAlt(_localctx, 2); + setState(2394); + match(CypherParser::T__6); + setState(2395); + oC_Expression(); + setState(2396); + match(CypherParser::T__7); + break; + } + + case 3: { + enterOuterAlt(_localctx, 3); + setState(2398); + match(CypherParser::T__6); + setState(2400); + _errHandler->sync(this); + + _la = _input->LA(1); + if ((((_la & ~ 0x3fULL) == 0) && + ((1ULL << _la) & -2320550076713270652) != 0) || ((((_la - 65) & ~ 0x3fULL) == 0) && + ((1ULL << (_la - 65)) & -286014905805559497) != 0) || ((((_la - 130) & ~ 0x3fULL) == 0) && + ((1ULL << (_la - 130)) & 5492410606132523) != 0)) { + setState(2399); + oC_Expression(); + } + setState(2402); + _la = _input->LA(1); + if (!(_la == CypherParser::COLON + + || _la == CypherParser::DOTDOT)) { + _errHandler->recoverInline(this); + } + else { + _errHandler->reportMatch(this); + consume(); + } + setState(2404); + _errHandler->sync(this); + + _la = _input->LA(1); + if ((((_la & ~ 0x3fULL) == 0) && + ((1ULL << _la) & -2320550076713270652) != 0) || ((((_la - 65) & ~ 0x3fULL) == 0) && + ((1ULL << (_la - 65)) & -286014905805559497) != 0) || ((((_la - 130) & ~ 0x3fULL) == 0) && + ((1ULL << (_la - 130)) & 5492410606132523) != 0)) { + setState(2403); + oC_Expression(); + } + setState(2406); + match(CypherParser::T__7); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_StringOperatorExpressionContext ------------------------------------------------------------------ + +CypherParser::OC_StringOperatorExpressionContext::OC_StringOperatorExpressionContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_PropertyOrLabelsExpressionContext* CypherParser::OC_StringOperatorExpressionContext::oC_PropertyOrLabelsExpression() { + return getRuleContext(0); +} + +CypherParser::OC_RegularExpressionContext* CypherParser::OC_StringOperatorExpressionContext::oC_RegularExpression() { + return getRuleContext(0); +} + +std::vector CypherParser::OC_StringOperatorExpressionContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_StringOperatorExpressionContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +tree::TerminalNode* CypherParser::OC_StringOperatorExpressionContext::STARTS() { + return getToken(CypherParser::STARTS, 0); +} + +tree::TerminalNode* CypherParser::OC_StringOperatorExpressionContext::WITH() { + return getToken(CypherParser::WITH, 0); +} + +tree::TerminalNode* CypherParser::OC_StringOperatorExpressionContext::ENDS() { + return getToken(CypherParser::ENDS, 0); +} + +tree::TerminalNode* CypherParser::OC_StringOperatorExpressionContext::CONTAINS() { + return getToken(CypherParser::CONTAINS, 0); +} + + +size_t CypherParser::OC_StringOperatorExpressionContext::getRuleIndex() const { + return CypherParser::RuleOC_StringOperatorExpression; +} + + +CypherParser::OC_StringOperatorExpressionContext* CypherParser::oC_StringOperatorExpression() { + OC_StringOperatorExpressionContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 280, CypherParser::RuleOC_StringOperatorExpression); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + 
}); + try { + enterOuterAlt(_localctx, 1); + setState(2420); + _errHandler->sync(this); + switch (getInterpreter()->adaptivePredict(_input, 391, _ctx)) { + case 1: { + setState(2409); + oC_RegularExpression(); + break; + } + + case 2: { + setState(2410); + match(CypherParser::SP); + setState(2411); + match(CypherParser::STARTS); + setState(2412); + match(CypherParser::SP); + setState(2413); + match(CypherParser::WITH); + break; + } + + case 3: { + setState(2414); + match(CypherParser::SP); + setState(2415); + match(CypherParser::ENDS); + setState(2416); + match(CypherParser::SP); + setState(2417); + match(CypherParser::WITH); + break; + } + + case 4: { + setState(2418); + match(CypherParser::SP); + setState(2419); + match(CypherParser::CONTAINS); + break; + } + + default: + break; + } + setState(2423); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2422); + match(CypherParser::SP); + } + setState(2425); + oC_PropertyOrLabelsExpression(); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_RegularExpressionContext ------------------------------------------------------------------ + +CypherParser::OC_RegularExpressionContext::OC_RegularExpressionContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::OC_RegularExpressionContext::SP() { + return getToken(CypherParser::SP, 0); +} + + +size_t CypherParser::OC_RegularExpressionContext::getRuleIndex() const { + return CypherParser::RuleOC_RegularExpression; +} + + +CypherParser::OC_RegularExpressionContext* CypherParser::oC_RegularExpression() { + OC_RegularExpressionContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 282, CypherParser::RuleOC_RegularExpression); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(2428); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2427); + match(CypherParser::SP); + } + setState(2430); + match(CypherParser::T__23); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_NullOperatorExpressionContext ------------------------------------------------------------------ + +CypherParser::OC_NullOperatorExpressionContext::OC_NullOperatorExpressionContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +std::vector CypherParser::OC_NullOperatorExpressionContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_NullOperatorExpressionContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +tree::TerminalNode* CypherParser::OC_NullOperatorExpressionContext::IS() { + return getToken(CypherParser::IS, 0); +} + +tree::TerminalNode* CypherParser::OC_NullOperatorExpressionContext::NULL_() { + return getToken(CypherParser::NULL_, 0); +} + +tree::TerminalNode* CypherParser::OC_NullOperatorExpressionContext::NOT() { + return getToken(CypherParser::NOT, 0); +} + + +size_t 
CypherParser::OC_NullOperatorExpressionContext::getRuleIndex() const { + return CypherParser::RuleOC_NullOperatorExpression; +} + + +CypherParser::OC_NullOperatorExpressionContext* CypherParser::oC_NullOperatorExpression() { + OC_NullOperatorExpressionContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 284, CypherParser::RuleOC_NullOperatorExpression); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + setState(2442); + _errHandler->sync(this); + switch (getInterpreter()->adaptivePredict(_input, 394, _ctx)) { + case 1: { + enterOuterAlt(_localctx, 1); + setState(2432); + match(CypherParser::SP); + setState(2433); + match(CypherParser::IS); + setState(2434); + match(CypherParser::SP); + setState(2435); + match(CypherParser::NULL_); + break; + } + + case 2: { + enterOuterAlt(_localctx, 2); + setState(2436); + match(CypherParser::SP); + setState(2437); + match(CypherParser::IS); + setState(2438); + match(CypherParser::SP); + setState(2439); + match(CypherParser::NOT); + setState(2440); + match(CypherParser::SP); + setState(2441); + match(CypherParser::NULL_); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_UnaryAddSubtractOrFactorialExpressionContext ------------------------------------------------------------------ + +CypherParser::OC_UnaryAddSubtractOrFactorialExpressionContext::OC_UnaryAddSubtractOrFactorialExpressionContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_PropertyOrLabelsExpressionContext* CypherParser::OC_UnaryAddSubtractOrFactorialExpressionContext::oC_PropertyOrLabelsExpression() { + return getRuleContext(0); +} + +std::vector CypherParser::OC_UnaryAddSubtractOrFactorialExpressionContext::MINUS() { + return getTokens(CypherParser::MINUS); +} + +tree::TerminalNode* CypherParser::OC_UnaryAddSubtractOrFactorialExpressionContext::MINUS(size_t i) { + return getToken(CypherParser::MINUS, i); +} + +tree::TerminalNode* CypherParser::OC_UnaryAddSubtractOrFactorialExpressionContext::FACTORIAL() { + return getToken(CypherParser::FACTORIAL, 0); +} + +std::vector CypherParser::OC_UnaryAddSubtractOrFactorialExpressionContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_UnaryAddSubtractOrFactorialExpressionContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + + +size_t CypherParser::OC_UnaryAddSubtractOrFactorialExpressionContext::getRuleIndex() const { + return CypherParser::RuleOC_UnaryAddSubtractOrFactorialExpression; +} + + +CypherParser::OC_UnaryAddSubtractOrFactorialExpressionContext* CypherParser::oC_UnaryAddSubtractOrFactorialExpression() { + OC_UnaryAddSubtractOrFactorialExpressionContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 286, CypherParser::RuleOC_UnaryAddSubtractOrFactorialExpression); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(2450); + _errHandler->sync(this); + _la = _input->LA(1); + while (_la == CypherParser::MINUS) { + setState(2444); + match(CypherParser::MINUS); + setState(2446); + 
_errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2445); + match(CypherParser::SP); + } + setState(2452); + _errHandler->sync(this); + _la = _input->LA(1); + } + setState(2453); + oC_PropertyOrLabelsExpression(); + setState(2458); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 398, _ctx)) { + case 1: { + setState(2455); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2454); + match(CypherParser::SP); + } + setState(2457); + match(CypherParser::FACTORIAL); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_PropertyOrLabelsExpressionContext ------------------------------------------------------------------ + +CypherParser::OC_PropertyOrLabelsExpressionContext::OC_PropertyOrLabelsExpressionContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_AtomContext* CypherParser::OC_PropertyOrLabelsExpressionContext::oC_Atom() { + return getRuleContext(0); +} + +std::vector CypherParser::OC_PropertyOrLabelsExpressionContext::oC_PropertyLookup() { + return getRuleContexts(); +} + +CypherParser::OC_PropertyLookupContext* CypherParser::OC_PropertyOrLabelsExpressionContext::oC_PropertyLookup(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::OC_PropertyOrLabelsExpressionContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_PropertyOrLabelsExpressionContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + + +size_t CypherParser::OC_PropertyOrLabelsExpressionContext::getRuleIndex() const { + return CypherParser::RuleOC_PropertyOrLabelsExpression; +} + + +CypherParser::OC_PropertyOrLabelsExpressionContext* CypherParser::oC_PropertyOrLabelsExpression() { + OC_PropertyOrLabelsExpressionContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 288, CypherParser::RuleOC_PropertyOrLabelsExpression); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + size_t alt; + enterOuterAlt(_localctx, 1); + setState(2460); + oC_Atom(); + setState(2467); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 400, _ctx); + while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER) { + if (alt == 1) { + setState(2462); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2461); + match(CypherParser::SP); + } + setState(2464); + oC_PropertyLookup(); + } + setState(2469); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 400, _ctx); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_AtomContext ------------------------------------------------------------------ + +CypherParser::OC_AtomContext::OC_AtomContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_LiteralContext* CypherParser::OC_AtomContext::oC_Literal() { + return getRuleContext(0); +} + 
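+// Usage note: oC_Atom below is the leaf of the expression chain -- it
+// dispatches via adaptivePredict() to literals, parameters, CASE expressions,
+// parenthesized expressions, function invocations, path patterns,
+// EXISTS/COUNT subqueries, variables and quantifiers. For orientation, a
+// typical ANTLR4 C++ driver for this parser looks roughly like the sketch
+// below; the lexer class name (CypherLexer) and the entry rule (oC_Cypher)
+// are assumptions based on the usual openCypher setup, not something shown in
+// this file:
+//
+//   antlr4::ANTLRInputStream input("MATCH (a) RETURN a");
+//   CypherLexer lexer(&input);                 // assumed companion lexer
+//   antlr4::CommonTokenStream tokens(&lexer);
+//   CypherParser parser(&tokens);
+//   auto* tree = parser.oC_Cypher();           // assumed top-level rule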
+CypherParser::OC_ParameterContext* CypherParser::OC_AtomContext::oC_Parameter() { + return getRuleContext(0); +} + +CypherParser::OC_CaseExpressionContext* CypherParser::OC_AtomContext::oC_CaseExpression() { + return getRuleContext(0); +} + +CypherParser::OC_ParenthesizedExpressionContext* CypherParser::OC_AtomContext::oC_ParenthesizedExpression() { + return getRuleContext(0); +} + +CypherParser::OC_FunctionInvocationContext* CypherParser::OC_AtomContext::oC_FunctionInvocation() { + return getRuleContext(0); +} + +CypherParser::OC_PathPatternsContext* CypherParser::OC_AtomContext::oC_PathPatterns() { + return getRuleContext(0); +} + +CypherParser::OC_ExistCountSubqueryContext* CypherParser::OC_AtomContext::oC_ExistCountSubquery() { + return getRuleContext(0); +} + +CypherParser::OC_VariableContext* CypherParser::OC_AtomContext::oC_Variable() { + return getRuleContext(0); +} + +CypherParser::OC_QuantifierContext* CypherParser::OC_AtomContext::oC_Quantifier() { + return getRuleContext(0); +} + + +size_t CypherParser::OC_AtomContext::getRuleIndex() const { + return CypherParser::RuleOC_Atom; +} + + +CypherParser::OC_AtomContext* CypherParser::oC_Atom() { + OC_AtomContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 290, CypherParser::RuleOC_Atom); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + setState(2479); + _errHandler->sync(this); + switch (getInterpreter()->adaptivePredict(_input, 401, _ctx)) { + case 1: { + enterOuterAlt(_localctx, 1); + setState(2470); + oC_Literal(); + break; + } + + case 2: { + enterOuterAlt(_localctx, 2); + setState(2471); + oC_Parameter(); + break; + } + + case 3: { + enterOuterAlt(_localctx, 3); + setState(2472); + oC_CaseExpression(); + break; + } + + case 4: { + enterOuterAlt(_localctx, 4); + setState(2473); + oC_ParenthesizedExpression(); + break; + } + + case 5: { + enterOuterAlt(_localctx, 5); + setState(2474); + oC_FunctionInvocation(); + break; + } + + case 6: { + enterOuterAlt(_localctx, 6); + setState(2475); + oC_PathPatterns(); + break; + } + + case 7: { + enterOuterAlt(_localctx, 7); + setState(2476); + oC_ExistCountSubquery(); + break; + } + + case 8: { + enterOuterAlt(_localctx, 8); + setState(2477); + oC_Variable(); + break; + } + + case 9: { + enterOuterAlt(_localctx, 9); + setState(2478); + oC_Quantifier(); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_QuantifierContext ------------------------------------------------------------------ + +CypherParser::OC_QuantifierContext::OC_QuantifierContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::OC_QuantifierContext::ALL() { + return getToken(CypherParser::ALL, 0); +} + +CypherParser::OC_FilterExpressionContext* CypherParser::OC_QuantifierContext::oC_FilterExpression() { + return getRuleContext(0); +} + +std::vector CypherParser::OC_QuantifierContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_QuantifierContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +tree::TerminalNode* CypherParser::OC_QuantifierContext::ANY() { + return getToken(CypherParser::ANY, 0); +} + +tree::TerminalNode* 
CypherParser::OC_QuantifierContext::NONE() { + return getToken(CypherParser::NONE, 0); +} + +tree::TerminalNode* CypherParser::OC_QuantifierContext::SINGLE() { + return getToken(CypherParser::SINGLE, 0); +} + + +size_t CypherParser::OC_QuantifierContext::getRuleIndex() const { + return CypherParser::RuleOC_Quantifier; +} + + +CypherParser::OC_QuantifierContext* CypherParser::oC_Quantifier() { + OC_QuantifierContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 292, CypherParser::RuleOC_Quantifier); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + setState(2537); + _errHandler->sync(this); + switch (_input->LA(1)) { + case CypherParser::ALL: { + enterOuterAlt(_localctx, 1); + setState(2481); + match(CypherParser::ALL); + setState(2483); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2482); + match(CypherParser::SP); + } + setState(2485); + match(CypherParser::T__1); + setState(2487); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2486); + match(CypherParser::SP); + } + setState(2489); + oC_FilterExpression(); + setState(2491); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2490); + match(CypherParser::SP); + } + setState(2493); + match(CypherParser::T__2); + break; + } + + case CypherParser::ANY: { + enterOuterAlt(_localctx, 2); + setState(2495); + match(CypherParser::ANY); + setState(2497); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2496); + match(CypherParser::SP); + } + setState(2499); + match(CypherParser::T__1); + setState(2501); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2500); + match(CypherParser::SP); + } + setState(2503); + oC_FilterExpression(); + setState(2505); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2504); + match(CypherParser::SP); + } + setState(2507); + match(CypherParser::T__2); + break; + } + + case CypherParser::NONE: { + enterOuterAlt(_localctx, 3); + setState(2509); + match(CypherParser::NONE); + setState(2511); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2510); + match(CypherParser::SP); + } + setState(2513); + match(CypherParser::T__1); + setState(2515); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2514); + match(CypherParser::SP); + } + setState(2517); + oC_FilterExpression(); + setState(2519); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2518); + match(CypherParser::SP); + } + setState(2521); + match(CypherParser::T__2); + break; + } + + case CypherParser::SINGLE: { + enterOuterAlt(_localctx, 4); + setState(2523); + match(CypherParser::SINGLE); + setState(2525); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2524); + match(CypherParser::SP); + } + setState(2527); + match(CypherParser::T__1); + setState(2529); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2528); + match(CypherParser::SP); + } + setState(2531); + oC_FilterExpression(); + setState(2533); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2532); + 
match(CypherParser::SP); + } + setState(2535); + match(CypherParser::T__2); + break; + } + + default: + throw NoViableAltException(this); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_FilterExpressionContext ------------------------------------------------------------------ + +CypherParser::OC_FilterExpressionContext::OC_FilterExpressionContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_IdInCollContext* CypherParser::OC_FilterExpressionContext::oC_IdInColl() { + return getRuleContext(0); +} + +tree::TerminalNode* CypherParser::OC_FilterExpressionContext::SP() { + return getToken(CypherParser::SP, 0); +} + +CypherParser::OC_WhereContext* CypherParser::OC_FilterExpressionContext::oC_Where() { + return getRuleContext(0); +} + + +size_t CypherParser::OC_FilterExpressionContext::getRuleIndex() const { + return CypherParser::RuleOC_FilterExpression; +} + + +CypherParser::OC_FilterExpressionContext* CypherParser::oC_FilterExpression() { + OC_FilterExpressionContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 294, CypherParser::RuleOC_FilterExpression); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(2539); + oC_IdInColl(); + setState(2540); + match(CypherParser::SP); + setState(2541); + oC_Where(); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_IdInCollContext ------------------------------------------------------------------ + +CypherParser::OC_IdInCollContext::OC_IdInCollContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_VariableContext* CypherParser::OC_IdInCollContext::oC_Variable() { + return getRuleContext(0); +} + +std::vector CypherParser::OC_IdInCollContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_IdInCollContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +tree::TerminalNode* CypherParser::OC_IdInCollContext::IN() { + return getToken(CypherParser::IN, 0); +} + +CypherParser::OC_ExpressionContext* CypherParser::OC_IdInCollContext::oC_Expression() { + return getRuleContext(0); +} + + +size_t CypherParser::OC_IdInCollContext::getRuleIndex() const { + return CypherParser::RuleOC_IdInColl; +} + + +CypherParser::OC_IdInCollContext* CypherParser::oC_IdInColl() { + OC_IdInCollContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 296, CypherParser::RuleOC_IdInColl); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(2543); + oC_Variable(); + setState(2544); + match(CypherParser::SP); + setState(2545); + match(CypherParser::IN); + setState(2546); + match(CypherParser::SP); + setState(2547); + oC_Expression(); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, 
_localctx->exception); + } + + return _localctx; +} + +//----------------- OC_LiteralContext ------------------------------------------------------------------ + +CypherParser::OC_LiteralContext::OC_LiteralContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_NumberLiteralContext* CypherParser::OC_LiteralContext::oC_NumberLiteral() { + return getRuleContext(0); +} + +tree::TerminalNode* CypherParser::OC_LiteralContext::StringLiteral() { + return getToken(CypherParser::StringLiteral, 0); +} + +CypherParser::OC_BooleanLiteralContext* CypherParser::OC_LiteralContext::oC_BooleanLiteral() { + return getRuleContext(0); +} + +tree::TerminalNode* CypherParser::OC_LiteralContext::NULL_() { + return getToken(CypherParser::NULL_, 0); +} + +CypherParser::OC_ListLiteralContext* CypherParser::OC_LiteralContext::oC_ListLiteral() { + return getRuleContext(0); +} + +CypherParser::KU_StructLiteralContext* CypherParser::OC_LiteralContext::kU_StructLiteral() { + return getRuleContext(0); +} + + +size_t CypherParser::OC_LiteralContext::getRuleIndex() const { + return CypherParser::RuleOC_Literal; +} + + +CypherParser::OC_LiteralContext* CypherParser::oC_Literal() { + OC_LiteralContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 298, CypherParser::RuleOC_Literal); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + setState(2555); + _errHandler->sync(this); + switch (_input->LA(1)) { + case CypherParser::DecimalInteger: + case CypherParser::ExponentDecimalReal: + case CypherParser::RegularDecimalReal: { + enterOuterAlt(_localctx, 1); + setState(2549); + oC_NumberLiteral(); + break; + } + + case CypherParser::StringLiteral: { + enterOuterAlt(_localctx, 2); + setState(2550); + match(CypherParser::StringLiteral); + break; + } + + case CypherParser::FALSE: + case CypherParser::TRUE: { + enterOuterAlt(_localctx, 3); + setState(2551); + oC_BooleanLiteral(); + break; + } + + case CypherParser::NULL_: { + enterOuterAlt(_localctx, 4); + setState(2552); + match(CypherParser::NULL_); + break; + } + + case CypherParser::T__6: { + enterOuterAlt(_localctx, 5); + setState(2553); + oC_ListLiteral(); + break; + } + + case CypherParser::T__8: { + enterOuterAlt(_localctx, 6); + setState(2554); + kU_StructLiteral(); + break; + } + + default: + throw NoViableAltException(this); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_BooleanLiteralContext ------------------------------------------------------------------ + +CypherParser::OC_BooleanLiteralContext::OC_BooleanLiteralContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::OC_BooleanLiteralContext::TRUE() { + return getToken(CypherParser::TRUE, 0); +} + +tree::TerminalNode* CypherParser::OC_BooleanLiteralContext::FALSE() { + return getToken(CypherParser::FALSE, 0); +} + + +size_t CypherParser::OC_BooleanLiteralContext::getRuleIndex() const { + return CypherParser::RuleOC_BooleanLiteral; +} + + +CypherParser::OC_BooleanLiteralContext* CypherParser::oC_BooleanLiteral() { + OC_BooleanLiteralContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 300, 
CypherParser::RuleOC_BooleanLiteral); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(2557); + _la = _input->LA(1); + if (!(_la == CypherParser::FALSE + + || _la == CypherParser::TRUE)) { + _errHandler->recoverInline(this); + } + else { + _errHandler->reportMatch(this); + consume(); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_ListLiteralContext ------------------------------------------------------------------ + +CypherParser::OC_ListLiteralContext::OC_ListLiteralContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +std::vector CypherParser::OC_ListLiteralContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_ListLiteralContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +CypherParser::OC_ExpressionContext* CypherParser::OC_ListLiteralContext::oC_Expression() { + return getRuleContext(0); +} + +std::vector CypherParser::OC_ListLiteralContext::kU_ListEntry() { + return getRuleContexts(); +} + +CypherParser::KU_ListEntryContext* CypherParser::OC_ListLiteralContext::kU_ListEntry(size_t i) { + return getRuleContext(i); +} + + +size_t CypherParser::OC_ListLiteralContext::getRuleIndex() const { + return CypherParser::RuleOC_ListLiteral; +} + + +CypherParser::OC_ListLiteralContext* CypherParser::oC_ListLiteral() { + OC_ListLiteralContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 302, CypherParser::RuleOC_ListLiteral); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(2559); + match(CypherParser::T__6); + setState(2561); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2560); + match(CypherParser::SP); + } + setState(2576); + _errHandler->sync(this); + + _la = _input->LA(1); + if ((((_la & ~ 0x3fULL) == 0) && + ((1ULL << _la) & -2320550076713270652) != 0) || ((((_la - 65) & ~ 0x3fULL) == 0) && + ((1ULL << (_la - 65)) & -286014905805559497) != 0) || ((((_la - 130) & ~ 0x3fULL) == 0) && + ((1ULL << (_la - 130)) & 5492410606132523) != 0)) { + setState(2563); + oC_Expression(); + setState(2565); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2564); + match(CypherParser::SP); + } + setState(2573); + _errHandler->sync(this); + _la = _input->LA(1); + while (_la == CypherParser::T__3) { + setState(2567); + kU_ListEntry(); + setState(2569); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2568); + match(CypherParser::SP); + } + setState(2575); + _errHandler->sync(this); + _la = _input->LA(1); + } + } + setState(2578); + match(CypherParser::T__7); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_ListEntryContext ------------------------------------------------------------------ + 
+CypherParser::KU_ListEntryContext::KU_ListEntryContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_ListEntryContext::SP() { + return getToken(CypherParser::SP, 0); +} + +CypherParser::OC_ExpressionContext* CypherParser::KU_ListEntryContext::oC_Expression() { + return getRuleContext(0); +} + + +size_t CypherParser::KU_ListEntryContext::getRuleIndex() const { + return CypherParser::RuleKU_ListEntry; +} + + +CypherParser::KU_ListEntryContext* CypherParser::kU_ListEntry() { + KU_ListEntryContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 304, CypherParser::RuleKU_ListEntry); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(2580); + match(CypherParser::T__3); + setState(2582); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 421, _ctx)) { + case 1: { + setState(2581); + match(CypherParser::SP); + break; + } + + default: + break; + } + setState(2585); + _errHandler->sync(this); + + _la = _input->LA(1); + if ((((_la & ~ 0x3fULL) == 0) && + ((1ULL << _la) & -2320550076713270652) != 0) || ((((_la - 65) & ~ 0x3fULL) == 0) && + ((1ULL << (_la - 65)) & -286014905805559497) != 0) || ((((_la - 130) & ~ 0x3fULL) == 0) && + ((1ULL << (_la - 130)) & 5492410606132523) != 0)) { + setState(2584); + oC_Expression(); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_StructLiteralContext ------------------------------------------------------------------ + +CypherParser::KU_StructLiteralContext::KU_StructLiteralContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +std::vector CypherParser::KU_StructLiteralContext::kU_StructField() { + return getRuleContexts(); +} + +CypherParser::KU_StructFieldContext* CypherParser::KU_StructLiteralContext::kU_StructField(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::KU_StructLiteralContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_StructLiteralContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + + +size_t CypherParser::KU_StructLiteralContext::getRuleIndex() const { + return CypherParser::RuleKU_StructLiteral; +} + + +CypherParser::KU_StructLiteralContext* CypherParser::kU_StructLiteral() { + KU_StructLiteralContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 306, CypherParser::RuleKU_StructLiteral); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(2587); + match(CypherParser::T__8); + setState(2589); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2588); + match(CypherParser::SP); + } + setState(2591); + kU_StructField(); + setState(2593); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2592); + match(CypherParser::SP); + } + setState(2605); + _errHandler->sync(this); + _la = _input->LA(1); + while (_la == CypherParser::T__3) { + setState(2595); + 
match(CypherParser::T__3); + setState(2597); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2596); + match(CypherParser::SP); + } + setState(2599); + kU_StructField(); + setState(2601); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2600); + match(CypherParser::SP); + } + setState(2607); + _errHandler->sync(this); + _la = _input->LA(1); + } + setState(2608); + match(CypherParser::T__9); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_StructFieldContext ------------------------------------------------------------------ + +CypherParser::KU_StructFieldContext::KU_StructFieldContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_StructFieldContext::COLON() { + return getToken(CypherParser::COLON, 0); +} + +CypherParser::OC_ExpressionContext* CypherParser::KU_StructFieldContext::oC_Expression() { + return getRuleContext(0); +} + +CypherParser::OC_SymbolicNameContext* CypherParser::KU_StructFieldContext::oC_SymbolicName() { + return getRuleContext(0); +} + +tree::TerminalNode* CypherParser::KU_StructFieldContext::StringLiteral() { + return getToken(CypherParser::StringLiteral, 0); +} + +std::vector CypherParser::KU_StructFieldContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_StructFieldContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + + +size_t CypherParser::KU_StructFieldContext::getRuleIndex() const { + return CypherParser::RuleKU_StructField; +} + + +CypherParser::KU_StructFieldContext* CypherParser::kU_StructField() { + KU_StructFieldContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 308, CypherParser::RuleKU_StructField); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(2612); + _errHandler->sync(this); + switch (_input->LA(1)) { + case CypherParser::ADD: + case CypherParser::ALTER: + case CypherParser::AS: + case CypherParser::ATTACH: + case CypherParser::BEGIN: + case CypherParser::BY: + case CypherParser::CALL: + case CypherParser::CHECKPOINT: + case CypherParser::COMMENT: + case CypherParser::COMMIT: + case CypherParser::CONTAINS: + case CypherParser::COPY: + case CypherParser::COUNT: + case CypherParser::CYCLE: + case CypherParser::DATABASE: + case CypherParser::DELETE: + case CypherParser::DETACH: + case CypherParser::DROP: + case CypherParser::EXPLAIN: + case CypherParser::EXPORT: + case CypherParser::EXTENSION: + case CypherParser::FROM: + case CypherParser::FORCE: + case CypherParser::GRAPH: + case CypherParser::IMPORT: + case CypherParser::IF: + case CypherParser::INCREMENT: + case CypherParser::IS: + case CypherParser::KEY: + case CypherParser::LIMIT: + case CypherParser::LOAD: + case CypherParser::LOGICAL: + case CypherParser::MATCH: + case CypherParser::MAXVALUE: + case CypherParser::MERGE: + case CypherParser::MINVALUE: + case CypherParser::NO: + case CypherParser::NODE: + case CypherParser::PROJECT: + case CypherParser::READ: + case CypherParser::REL: + case CypherParser::RENAME: + case CypherParser::RETURN: + case CypherParser::ROLLBACK: + case 
CypherParser::SEQUENCE: + case CypherParser::SET: + case CypherParser::START: + case CypherParser::STRUCT: + case CypherParser::TO: + case CypherParser::TRANSACTION: + case CypherParser::TYPE: + case CypherParser::UNINSTALL: + case CypherParser::UPDATE: + case CypherParser::USE: + case CypherParser::WRITE: + case CypherParser::YIELD: + case CypherParser::USER: + case CypherParser::PASSWORD: + case CypherParser::ROLE: + case CypherParser::MAP: + case CypherParser::DECIMAL: + case CypherParser::L_SKIP: + case CypherParser::HexLetter: + case CypherParser::UnescapedSymbolicName: + case CypherParser::EscapedSymbolicName: { + setState(2610); + oC_SymbolicName(); + break; + } + + case CypherParser::StringLiteral: { + setState(2611); + match(CypherParser::StringLiteral); + break; + } + + default: + throw NoViableAltException(this); + } + setState(2615); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2614); + match(CypherParser::SP); + } + setState(2617); + match(CypherParser::COLON); + setState(2619); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2618); + match(CypherParser::SP); + } + setState(2621); + oC_Expression(); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_ParenthesizedExpressionContext ------------------------------------------------------------------ + +CypherParser::OC_ParenthesizedExpressionContext::OC_ParenthesizedExpressionContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_ExpressionContext* CypherParser::OC_ParenthesizedExpressionContext::oC_Expression() { + return getRuleContext(0); +} + +std::vector CypherParser::OC_ParenthesizedExpressionContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_ParenthesizedExpressionContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + + +size_t CypherParser::OC_ParenthesizedExpressionContext::getRuleIndex() const { + return CypherParser::RuleOC_ParenthesizedExpression; +} + + +CypherParser::OC_ParenthesizedExpressionContext* CypherParser::oC_ParenthesizedExpression() { + OC_ParenthesizedExpressionContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 310, CypherParser::RuleOC_ParenthesizedExpression); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(2623); + match(CypherParser::T__1); + setState(2625); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2624); + match(CypherParser::SP); + } + setState(2627); + oC_Expression(); + setState(2629); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2628); + match(CypherParser::SP); + } + setState(2631); + match(CypherParser::T__2); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_FunctionInvocationContext ------------------------------------------------------------------ + 
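// Illustrative sketch (not part of the generated parser): how a rule method such as
// oC_FunctionInvocation() below is typically driven through the ANTLR4 C++ runtime.
// The companion lexer class name `CypherLexer` and the exact header names are
// assumptions here; they depend on how the grammar was generated in this build.
//
//   #include <iostream>
//   #include "antlr4-runtime.h"   // ANTLRInputStream, CommonTokenStream
//   #include "CypherLexer.h"      // assumed generated lexer header
//   #include "CypherParser.h"     // the parser defined in this file
//
//   int main() {
//     antlr4::ANTLRInputStream input("count(*)");
//     CypherLexer lexer(&input);                  // assumed lexer class name
//     antlr4::CommonTokenStream tokens(&lexer);
//     CypherParser parser(&tokens);
//     // Each grammar rule is exposed as a public method returning its context object;
//     // "count(*)" exercises the COUNT '(' STAR ')' alternative handled as case 1 below.
//     auto* ctx = parser.oC_FunctionInvocation();
//     std::cout << ctx->toStringTree(&parser) << std::endl;
//     return 0;
//   }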
+CypherParser::OC_FunctionInvocationContext::OC_FunctionInvocationContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::OC_FunctionInvocationContext::COUNT() { + return getToken(CypherParser::COUNT, 0); +} + +tree::TerminalNode* CypherParser::OC_FunctionInvocationContext::STAR() { + return getToken(CypherParser::STAR, 0); +} + +std::vector CypherParser::OC_FunctionInvocationContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_FunctionInvocationContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +tree::TerminalNode* CypherParser::OC_FunctionInvocationContext::CAST() { + return getToken(CypherParser::CAST, 0); +} + +std::vector CypherParser::OC_FunctionInvocationContext::kU_FunctionParameter() { + return getRuleContexts(); +} + +CypherParser::KU_FunctionParameterContext* CypherParser::OC_FunctionInvocationContext::kU_FunctionParameter(size_t i) { + return getRuleContext(i); +} + +tree::TerminalNode* CypherParser::OC_FunctionInvocationContext::AS() { + return getToken(CypherParser::AS, 0); +} + +CypherParser::KU_DataTypeContext* CypherParser::OC_FunctionInvocationContext::kU_DataType() { + return getRuleContext(0); +} + +CypherParser::OC_FunctionNameContext* CypherParser::OC_FunctionInvocationContext::oC_FunctionName() { + return getRuleContext(0); +} + +tree::TerminalNode* CypherParser::OC_FunctionInvocationContext::DISTINCT() { + return getToken(CypherParser::DISTINCT, 0); +} + + +size_t CypherParser::OC_FunctionInvocationContext::getRuleIndex() const { + return CypherParser::RuleOC_FunctionInvocation; +} + + +CypherParser::OC_FunctionInvocationContext* CypherParser::oC_FunctionInvocation() { + OC_FunctionInvocationContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 312, CypherParser::RuleOC_FunctionInvocation); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + setState(2710); + _errHandler->sync(this); + switch (getInterpreter()->adaptivePredict(_input, 452, _ctx)) { + case 1: { + enterOuterAlt(_localctx, 1); + setState(2633); + match(CypherParser::COUNT); + setState(2635); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2634); + match(CypherParser::SP); + } + setState(2637); + match(CypherParser::T__1); + setState(2639); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2638); + match(CypherParser::SP); + } + setState(2641); + match(CypherParser::STAR); + setState(2643); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2642); + match(CypherParser::SP); + } + setState(2645); + match(CypherParser::T__2); + break; + } + + case 2: { + enterOuterAlt(_localctx, 2); + setState(2646); + match(CypherParser::CAST); + setState(2648); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2647); + match(CypherParser::SP); + } + setState(2650); + match(CypherParser::T__1); + setState(2652); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2651); + match(CypherParser::SP); + } + setState(2654); + kU_FunctionParameter(); + setState(2656); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2655); + match(CypherParser::SP); + } + setState(2668); + 
_errHandler->sync(this); + switch (_input->LA(1)) { + case CypherParser::AS: { + setState(2658); + match(CypherParser::AS); + setState(2660); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2659); + match(CypherParser::SP); + } + setState(2662); + kU_DataType(0); + break; + } + + case CypherParser::T__3: { + setState(2663); + match(CypherParser::T__3); + setState(2665); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2664); + match(CypherParser::SP); + } + setState(2667); + kU_FunctionParameter(); + break; + } + + default: + throw NoViableAltException(this); + } + setState(2671); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2670); + match(CypherParser::SP); + } + setState(2673); + match(CypherParser::T__2); + break; + } + + case 3: { + enterOuterAlt(_localctx, 3); + setState(2675); + oC_FunctionName(); + setState(2677); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2676); + match(CypherParser::SP); + } + setState(2679); + match(CypherParser::T__1); + setState(2681); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2680); + match(CypherParser::SP); + } + setState(2687); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::DISTINCT) { + setState(2683); + match(CypherParser::DISTINCT); + setState(2685); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2684); + match(CypherParser::SP); + } + } + setState(2706); + _errHandler->sync(this); + + _la = _input->LA(1); + if ((((_la & ~ 0x3fULL) == 0) && + ((1ULL << _la) & -2320550076713270652) != 0) || ((((_la - 65) & ~ 0x3fULL) == 0) && + ((1ULL << (_la - 65)) & -286014905805559497) != 0) || ((((_la - 130) & ~ 0x3fULL) == 0) && + ((1ULL << (_la - 130)) & 5492410606132523) != 0)) { + setState(2689); + kU_FunctionParameter(); + setState(2691); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2690); + match(CypherParser::SP); + } + setState(2703); + _errHandler->sync(this); + _la = _input->LA(1); + while (_la == CypherParser::T__3) { + setState(2693); + match(CypherParser::T__3); + setState(2695); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2694); + match(CypherParser::SP); + } + setState(2697); + kU_FunctionParameter(); + setState(2699); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2698); + match(CypherParser::SP); + } + setState(2705); + _errHandler->sync(this); + _la = _input->LA(1); + } + } + setState(2708); + match(CypherParser::T__2); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_FunctionNameContext ------------------------------------------------------------------ + +CypherParser::OC_FunctionNameContext::OC_FunctionNameContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_SymbolicNameContext* CypherParser::OC_FunctionNameContext::oC_SymbolicName() { + return getRuleContext(0); +} + + +size_t CypherParser::OC_FunctionNameContext::getRuleIndex() const { + return 
CypherParser::RuleOC_FunctionName; +} + + +CypherParser::OC_FunctionNameContext* CypherParser::oC_FunctionName() { + OC_FunctionNameContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 314, CypherParser::RuleOC_FunctionName); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(2712); + oC_SymbolicName(); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_FunctionParameterContext ------------------------------------------------------------------ + +CypherParser::KU_FunctionParameterContext::KU_FunctionParameterContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_ExpressionContext* CypherParser::KU_FunctionParameterContext::oC_Expression() { + return getRuleContext(0); +} + +CypherParser::OC_SymbolicNameContext* CypherParser::KU_FunctionParameterContext::oC_SymbolicName() { + return getRuleContext(0); +} + +tree::TerminalNode* CypherParser::KU_FunctionParameterContext::COLON() { + return getToken(CypherParser::COLON, 0); +} + +std::vector CypherParser::KU_FunctionParameterContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_FunctionParameterContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +CypherParser::KU_LambdaParameterContext* CypherParser::KU_FunctionParameterContext::kU_LambdaParameter() { + return getRuleContext(0); +} + + +size_t CypherParser::KU_FunctionParameterContext::getRuleIndex() const { + return CypherParser::RuleKU_FunctionParameter; +} + + +CypherParser::KU_FunctionParameterContext* CypherParser::kU_FunctionParameter() { + KU_FunctionParameterContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 316, CypherParser::RuleKU_FunctionParameter); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + setState(2727); + _errHandler->sync(this); + switch (getInterpreter()->adaptivePredict(_input, 456, _ctx)) { + case 1: { + enterOuterAlt(_localctx, 1); + setState(2723); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 455, _ctx)) { + case 1: { + setState(2714); + oC_SymbolicName(); + setState(2716); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2715); + match(CypherParser::SP); + } + setState(2718); + match(CypherParser::COLON); + setState(2719); + match(CypherParser::T__5); + setState(2721); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2720); + match(CypherParser::SP); + } + break; + } + + default: + break; + } + setState(2725); + oC_Expression(); + break; + } + + case 2: { + enterOuterAlt(_localctx, 2); + setState(2726); + kU_LambdaParameter(); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_LambdaParameterContext ------------------------------------------------------------------ + 
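// Illustrative sketch (not part of the generated parser): when kU_FunctionParameter()
// above takes its second alternative, the resulting context can be inspected through the
// accessors defined below. A minimal usage sketch, assuming a `parser` instance whose
// input is a lambda such as "x -> x + 1":
//
//   CypherParser::KU_LambdaParameterContext* lambda = parser.kU_LambdaParameter();
//   CypherParser::KU_LambdaVarsContext* vars = lambda->kU_LambdaVars();   // bound variable(s)
//   CypherParser::OC_ExpressionContext* body = lambda->oC_Expression();   // lambda body
//   // lambda->MINUS() exposes the '-' terminal that opens the arrow separating the two.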
+CypherParser::KU_LambdaParameterContext::KU_LambdaParameterContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::KU_LambdaVarsContext* CypherParser::KU_LambdaParameterContext::kU_LambdaVars() { + return getRuleContext(0); +} + +tree::TerminalNode* CypherParser::KU_LambdaParameterContext::MINUS() { + return getToken(CypherParser::MINUS, 0); +} + +CypherParser::OC_ExpressionContext* CypherParser::KU_LambdaParameterContext::oC_Expression() { + return getRuleContext(0); +} + +std::vector CypherParser::KU_LambdaParameterContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_LambdaParameterContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + + +size_t CypherParser::KU_LambdaParameterContext::getRuleIndex() const { + return CypherParser::RuleKU_LambdaParameter; +} + + +CypherParser::KU_LambdaParameterContext* CypherParser::kU_LambdaParameter() { + KU_LambdaParameterContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 318, CypherParser::RuleKU_LambdaParameter); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(2729); + kU_LambdaVars(); + setState(2731); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2730); + match(CypherParser::SP); + } + setState(2733); + match(CypherParser::MINUS); + setState(2734); + match(CypherParser::T__14); + setState(2736); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2735); + match(CypherParser::SP); + } + setState(2738); + oC_Expression(); + setState(2740); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 459, _ctx)) { + case 1: { + setState(2739); + match(CypherParser::SP); + break; + } + + default: + break; + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_LambdaVarsContext ------------------------------------------------------------------ + +CypherParser::KU_LambdaVarsContext::KU_LambdaVarsContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +std::vector CypherParser::KU_LambdaVarsContext::oC_SymbolicName() { + return getRuleContexts(); +} + +CypherParser::OC_SymbolicNameContext* CypherParser::KU_LambdaVarsContext::oC_SymbolicName(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::KU_LambdaVarsContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::KU_LambdaVarsContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + + +size_t CypherParser::KU_LambdaVarsContext::getRuleIndex() const { + return CypherParser::RuleKU_LambdaVars; +} + + +CypherParser::KU_LambdaVarsContext* CypherParser::kU_LambdaVars() { + KU_LambdaVarsContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 320, CypherParser::RuleKU_LambdaVars); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + setState(2766); + _errHandler->sync(this); + switch (_input->LA(1)) { + case CypherParser::ADD: + case CypherParser::ALTER: + case 
CypherParser::AS: + case CypherParser::ATTACH: + case CypherParser::BEGIN: + case CypherParser::BY: + case CypherParser::CALL: + case CypherParser::CHECKPOINT: + case CypherParser::COMMENT: + case CypherParser::COMMIT: + case CypherParser::CONTAINS: + case CypherParser::COPY: + case CypherParser::COUNT: + case CypherParser::CYCLE: + case CypherParser::DATABASE: + case CypherParser::DELETE: + case CypherParser::DETACH: + case CypherParser::DROP: + case CypherParser::EXPLAIN: + case CypherParser::EXPORT: + case CypherParser::EXTENSION: + case CypherParser::FROM: + case CypherParser::FORCE: + case CypherParser::GRAPH: + case CypherParser::IMPORT: + case CypherParser::IF: + case CypherParser::INCREMENT: + case CypherParser::IS: + case CypherParser::KEY: + case CypherParser::LIMIT: + case CypherParser::LOAD: + case CypherParser::LOGICAL: + case CypherParser::MATCH: + case CypherParser::MAXVALUE: + case CypherParser::MERGE: + case CypherParser::MINVALUE: + case CypherParser::NO: + case CypherParser::NODE: + case CypherParser::PROJECT: + case CypherParser::READ: + case CypherParser::REL: + case CypherParser::RENAME: + case CypherParser::RETURN: + case CypherParser::ROLLBACK: + case CypherParser::SEQUENCE: + case CypherParser::SET: + case CypherParser::START: + case CypherParser::STRUCT: + case CypherParser::TO: + case CypherParser::TRANSACTION: + case CypherParser::TYPE: + case CypherParser::UNINSTALL: + case CypherParser::UPDATE: + case CypherParser::USE: + case CypherParser::WRITE: + case CypherParser::YIELD: + case CypherParser::USER: + case CypherParser::PASSWORD: + case CypherParser::ROLE: + case CypherParser::MAP: + case CypherParser::DECIMAL: + case CypherParser::L_SKIP: + case CypherParser::HexLetter: + case CypherParser::UnescapedSymbolicName: + case CypherParser::EscapedSymbolicName: { + enterOuterAlt(_localctx, 1); + setState(2742); + oC_SymbolicName(); + break; + } + + case CypherParser::T__1: { + enterOuterAlt(_localctx, 2); + setState(2743); + match(CypherParser::T__1); + setState(2745); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2744); + match(CypherParser::SP); + } + setState(2747); + oC_SymbolicName(); + setState(2749); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2748); + match(CypherParser::SP); + } + setState(2761); + _errHandler->sync(this); + _la = _input->LA(1); + while (_la == CypherParser::T__3) { + setState(2751); + match(CypherParser::T__3); + setState(2753); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2752); + match(CypherParser::SP); + } + setState(2755); + oC_SymbolicName(); + setState(2757); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2756); + match(CypherParser::SP); + } + setState(2763); + _errHandler->sync(this); + _la = _input->LA(1); + } + setState(2764); + match(CypherParser::T__2); + break; + } + + default: + throw NoViableAltException(this); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_PathPatternsContext ------------------------------------------------------------------ + +CypherParser::OC_PathPatternsContext::OC_PathPatternsContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + 
+CypherParser::OC_NodePatternContext* CypherParser::OC_PathPatternsContext::oC_NodePattern() { + return getRuleContext(0); +} + +std::vector CypherParser::OC_PathPatternsContext::oC_PatternElementChain() { + return getRuleContexts(); +} + +CypherParser::OC_PatternElementChainContext* CypherParser::OC_PathPatternsContext::oC_PatternElementChain(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::OC_PathPatternsContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_PathPatternsContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + + +size_t CypherParser::OC_PathPatternsContext::getRuleIndex() const { + return CypherParser::RuleOC_PathPatterns; +} + + +CypherParser::OC_PathPatternsContext* CypherParser::oC_PathPatterns() { + OC_PathPatternsContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 322, CypherParser::RuleOC_PathPatterns); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + size_t alt; + enterOuterAlt(_localctx, 1); + setState(2768); + oC_NodePattern(); + setState(2773); + _errHandler->sync(this); + alt = 1; + do { + switch (alt) { + case 1: { + setState(2770); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2769); + match(CypherParser::SP); + } + setState(2772); + oC_PatternElementChain(); + break; + } + + default: + throw NoViableAltException(this); + } + setState(2775); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 467, _ctx); + } while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_ExistCountSubqueryContext ------------------------------------------------------------------ + +CypherParser::OC_ExistCountSubqueryContext::OC_ExistCountSubqueryContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::OC_ExistCountSubqueryContext::MATCH() { + return getToken(CypherParser::MATCH, 0); +} + +CypherParser::OC_PatternContext* CypherParser::OC_ExistCountSubqueryContext::oC_Pattern() { + return getRuleContext(0); +} + +tree::TerminalNode* CypherParser::OC_ExistCountSubqueryContext::EXISTS() { + return getToken(CypherParser::EXISTS, 0); +} + +tree::TerminalNode* CypherParser::OC_ExistCountSubqueryContext::COUNT() { + return getToken(CypherParser::COUNT, 0); +} + +std::vector CypherParser::OC_ExistCountSubqueryContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_ExistCountSubqueryContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +CypherParser::OC_WhereContext* CypherParser::OC_ExistCountSubqueryContext::oC_Where() { + return getRuleContext(0); +} + +CypherParser::KU_HintContext* CypherParser::OC_ExistCountSubqueryContext::kU_Hint() { + return getRuleContext(0); +} + + +size_t CypherParser::OC_ExistCountSubqueryContext::getRuleIndex() const { + return CypherParser::RuleOC_ExistCountSubquery; +} + + +CypherParser::OC_ExistCountSubqueryContext* CypherParser::oC_ExistCountSubquery() { + OC_ExistCountSubqueryContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 324, 
CypherParser::RuleOC_ExistCountSubquery); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(2777); + _la = _input->LA(1); + if (!(_la == CypherParser::COUNT + + || _la == CypherParser::EXISTS)) { + _errHandler->recoverInline(this); + } + else { + _errHandler->reportMatch(this); + consume(); + } + setState(2779); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2778); + match(CypherParser::SP); + } + setState(2781); + match(CypherParser::T__8); + setState(2783); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2782); + match(CypherParser::SP); + } + setState(2785); + match(CypherParser::MATCH); + setState(2787); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2786); + match(CypherParser::SP); + } + setState(2789); + oC_Pattern(); + setState(2794); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 472, _ctx)) { + case 1: { + setState(2791); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2790); + match(CypherParser::SP); + } + setState(2793); + oC_Where(); + break; + } + + default: + break; + } + setState(2800); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 474, _ctx)) { + case 1: { + setState(2797); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2796); + match(CypherParser::SP); + } + setState(2799); + kU_Hint(); + break; + } + + default: + break; + } + setState(2803); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2802); + match(CypherParser::SP); + } + setState(2805); + match(CypherParser::T__9); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_PropertyLookupContext ------------------------------------------------------------------ + +CypherParser::OC_PropertyLookupContext::OC_PropertyLookupContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_PropertyKeyNameContext* CypherParser::OC_PropertyLookupContext::oC_PropertyKeyName() { + return getRuleContext(0); +} + +tree::TerminalNode* CypherParser::OC_PropertyLookupContext::STAR() { + return getToken(CypherParser::STAR, 0); +} + +tree::TerminalNode* CypherParser::OC_PropertyLookupContext::SP() { + return getToken(CypherParser::SP, 0); +} + + +size_t CypherParser::OC_PropertyLookupContext::getRuleIndex() const { + return CypherParser::RuleOC_PropertyLookup; +} + + +CypherParser::OC_PropertyLookupContext* CypherParser::oC_PropertyLookup() { + OC_PropertyLookupContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 326, CypherParser::RuleOC_PropertyLookup); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(2807); + match(CypherParser::T__4); + setState(2809); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2808); + match(CypherParser::SP); + } + setState(2813); + 
_errHandler->sync(this); + switch (_input->LA(1)) { + case CypherParser::ADD: + case CypherParser::ALTER: + case CypherParser::AS: + case CypherParser::ATTACH: + case CypherParser::BEGIN: + case CypherParser::BY: + case CypherParser::CALL: + case CypherParser::CHECKPOINT: + case CypherParser::COMMENT: + case CypherParser::COMMIT: + case CypherParser::CONTAINS: + case CypherParser::COPY: + case CypherParser::COUNT: + case CypherParser::CYCLE: + case CypherParser::DATABASE: + case CypherParser::DELETE: + case CypherParser::DETACH: + case CypherParser::DROP: + case CypherParser::EXPLAIN: + case CypherParser::EXPORT: + case CypherParser::EXTENSION: + case CypherParser::FROM: + case CypherParser::FORCE: + case CypherParser::GRAPH: + case CypherParser::IMPORT: + case CypherParser::IF: + case CypherParser::INCREMENT: + case CypherParser::IS: + case CypherParser::KEY: + case CypherParser::LIMIT: + case CypherParser::LOAD: + case CypherParser::LOGICAL: + case CypherParser::MATCH: + case CypherParser::MAXVALUE: + case CypherParser::MERGE: + case CypherParser::MINVALUE: + case CypherParser::NO: + case CypherParser::NODE: + case CypherParser::PROJECT: + case CypherParser::READ: + case CypherParser::REL: + case CypherParser::RENAME: + case CypherParser::RETURN: + case CypherParser::ROLLBACK: + case CypherParser::SEQUENCE: + case CypherParser::SET: + case CypherParser::START: + case CypherParser::STRUCT: + case CypherParser::TO: + case CypherParser::TRANSACTION: + case CypherParser::TYPE: + case CypherParser::UNINSTALL: + case CypherParser::UPDATE: + case CypherParser::USE: + case CypherParser::WRITE: + case CypherParser::YIELD: + case CypherParser::USER: + case CypherParser::PASSWORD: + case CypherParser::ROLE: + case CypherParser::MAP: + case CypherParser::DECIMAL: + case CypherParser::L_SKIP: + case CypherParser::HexLetter: + case CypherParser::UnescapedSymbolicName: + case CypherParser::EscapedSymbolicName: { + setState(2811); + oC_PropertyKeyName(); + break; + } + + case CypherParser::STAR: { + setState(2812); + match(CypherParser::STAR); + break; + } + + default: + throw NoViableAltException(this); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_CaseExpressionContext ------------------------------------------------------------------ + +CypherParser::OC_CaseExpressionContext::OC_CaseExpressionContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::OC_CaseExpressionContext::END() { + return getToken(CypherParser::END, 0); +} + +tree::TerminalNode* CypherParser::OC_CaseExpressionContext::ELSE() { + return getToken(CypherParser::ELSE, 0); +} + +std::vector CypherParser::OC_CaseExpressionContext::oC_Expression() { + return getRuleContexts(); +} + +CypherParser::OC_ExpressionContext* CypherParser::OC_CaseExpressionContext::oC_Expression(size_t i) { + return getRuleContext(i); +} + +std::vector CypherParser::OC_CaseExpressionContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_CaseExpressionContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + +tree::TerminalNode* CypherParser::OC_CaseExpressionContext::CASE() { + return getToken(CypherParser::CASE, 0); +} + +std::vector CypherParser::OC_CaseExpressionContext::oC_CaseAlternative() { + return getRuleContexts(); +} + 
+CypherParser::OC_CaseAlternativeContext* CypherParser::OC_CaseExpressionContext::oC_CaseAlternative(size_t i) { + return getRuleContext(i); +} + + +size_t CypherParser::OC_CaseExpressionContext::getRuleIndex() const { + return CypherParser::RuleOC_CaseExpression; +} + + +CypherParser::OC_CaseExpressionContext* CypherParser::oC_CaseExpression() { + OC_CaseExpressionContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 328, CypherParser::RuleOC_CaseExpression); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + size_t alt; + enterOuterAlt(_localctx, 1); + setState(2837); + _errHandler->sync(this); + switch (getInterpreter()->adaptivePredict(_input, 483, _ctx)) { + case 1: { + setState(2815); + match(CypherParser::CASE); + setState(2820); + _errHandler->sync(this); + alt = 1; + do { + switch (alt) { + case 1: { + setState(2817); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2816); + match(CypherParser::SP); + } + setState(2819); + oC_CaseAlternative(); + break; + } + + default: + throw NoViableAltException(this); + } + setState(2822); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 479, _ctx); + } while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER); + break; + } + + case 2: { + setState(2824); + match(CypherParser::CASE); + setState(2826); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2825); + match(CypherParser::SP); + } + setState(2828); + oC_Expression(); + setState(2833); + _errHandler->sync(this); + alt = 1; + do { + switch (alt) { + case 1: { + setState(2830); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2829); + match(CypherParser::SP); + } + setState(2832); + oC_CaseAlternative(); + break; + } + + default: + throw NoViableAltException(this); + } + setState(2835); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 482, _ctx); + } while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER); + break; + } + + default: + break; + } + setState(2847); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict(_input, 486, _ctx)) { + case 1: { + setState(2840); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2839); + match(CypherParser::SP); + } + setState(2842); + match(CypherParser::ELSE); + setState(2844); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2843); + match(CypherParser::SP); + } + setState(2846); + oC_Expression(); + break; + } + + default: + break; + } + setState(2850); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2849); + match(CypherParser::SP); + } + setState(2852); + match(CypherParser::END); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_CaseAlternativeContext ------------------------------------------------------------------ + +CypherParser::OC_CaseAlternativeContext::OC_CaseAlternativeContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::OC_CaseAlternativeContext::WHEN() { + return 
getToken(CypherParser::WHEN, 0); +} + +std::vector CypherParser::OC_CaseAlternativeContext::oC_Expression() { + return getRuleContexts(); +} + +CypherParser::OC_ExpressionContext* CypherParser::OC_CaseAlternativeContext::oC_Expression(size_t i) { + return getRuleContext(i); +} + +tree::TerminalNode* CypherParser::OC_CaseAlternativeContext::THEN() { + return getToken(CypherParser::THEN, 0); +} + +std::vector CypherParser::OC_CaseAlternativeContext::SP() { + return getTokens(CypherParser::SP); +} + +tree::TerminalNode* CypherParser::OC_CaseAlternativeContext::SP(size_t i) { + return getToken(CypherParser::SP, i); +} + + +size_t CypherParser::OC_CaseAlternativeContext::getRuleIndex() const { + return CypherParser::RuleOC_CaseAlternative; +} + + +CypherParser::OC_CaseAlternativeContext* CypherParser::oC_CaseAlternative() { + OC_CaseAlternativeContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 330, CypherParser::RuleOC_CaseAlternative); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(2854); + match(CypherParser::WHEN); + setState(2856); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2855); + match(CypherParser::SP); + } + setState(2858); + oC_Expression(); + setState(2860); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2859); + match(CypherParser::SP); + } + setState(2862); + match(CypherParser::THEN); + setState(2864); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2863); + match(CypherParser::SP); + } + setState(2866); + oC_Expression(); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_VariableContext ------------------------------------------------------------------ + +CypherParser::OC_VariableContext::OC_VariableContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_SymbolicNameContext* CypherParser::OC_VariableContext::oC_SymbolicName() { + return getRuleContext(0); +} + + +size_t CypherParser::OC_VariableContext::getRuleIndex() const { + return CypherParser::RuleOC_Variable; +} + + +CypherParser::OC_VariableContext* CypherParser::oC_Variable() { + OC_VariableContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 332, CypherParser::RuleOC_Variable); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(2868); + oC_SymbolicName(); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_NumberLiteralContext ------------------------------------------------------------------ + +CypherParser::OC_NumberLiteralContext::OC_NumberLiteralContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_DoubleLiteralContext* CypherParser::OC_NumberLiteralContext::oC_DoubleLiteral() { + return getRuleContext(0); +} + 
+CypherParser::OC_IntegerLiteralContext* CypherParser::OC_NumberLiteralContext::oC_IntegerLiteral() { + return getRuleContext(0); +} + + +size_t CypherParser::OC_NumberLiteralContext::getRuleIndex() const { + return CypherParser::RuleOC_NumberLiteral; +} + + +CypherParser::OC_NumberLiteralContext* CypherParser::oC_NumberLiteral() { + OC_NumberLiteralContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 334, CypherParser::RuleOC_NumberLiteral); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + setState(2872); + _errHandler->sync(this); + switch (_input->LA(1)) { + case CypherParser::ExponentDecimalReal: + case CypherParser::RegularDecimalReal: { + enterOuterAlt(_localctx, 1); + setState(2870); + oC_DoubleLiteral(); + break; + } + + case CypherParser::DecimalInteger: { + enterOuterAlt(_localctx, 2); + setState(2871); + oC_IntegerLiteral(); + break; + } + + default: + throw NoViableAltException(this); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_ParameterContext ------------------------------------------------------------------ + +CypherParser::OC_ParameterContext::OC_ParameterContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_SymbolicNameContext* CypherParser::OC_ParameterContext::oC_SymbolicName() { + return getRuleContext(0); +} + +tree::TerminalNode* CypherParser::OC_ParameterContext::DecimalInteger() { + return getToken(CypherParser::DecimalInteger, 0); +} + + +size_t CypherParser::OC_ParameterContext::getRuleIndex() const { + return CypherParser::RuleOC_Parameter; +} + + +CypherParser::OC_ParameterContext* CypherParser::oC_Parameter() { + OC_ParameterContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 336, CypherParser::RuleOC_Parameter); + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(2874); + match(CypherParser::T__24); + setState(2877); + _errHandler->sync(this); + switch (_input->LA(1)) { + case CypherParser::ADD: + case CypherParser::ALTER: + case CypherParser::AS: + case CypherParser::ATTACH: + case CypherParser::BEGIN: + case CypherParser::BY: + case CypherParser::CALL: + case CypherParser::CHECKPOINT: + case CypherParser::COMMENT: + case CypherParser::COMMIT: + case CypherParser::CONTAINS: + case CypherParser::COPY: + case CypherParser::COUNT: + case CypherParser::CYCLE: + case CypherParser::DATABASE: + case CypherParser::DELETE: + case CypherParser::DETACH: + case CypherParser::DROP: + case CypherParser::EXPLAIN: + case CypherParser::EXPORT: + case CypherParser::EXTENSION: + case CypherParser::FROM: + case CypherParser::FORCE: + case CypherParser::GRAPH: + case CypherParser::IMPORT: + case CypherParser::IF: + case CypherParser::INCREMENT: + case CypherParser::IS: + case CypherParser::KEY: + case CypherParser::LIMIT: + case CypherParser::LOAD: + case CypherParser::LOGICAL: + case CypherParser::MATCH: + case CypherParser::MAXVALUE: + case CypherParser::MERGE: + case CypherParser::MINVALUE: + case CypherParser::NO: + case CypherParser::NODE: + case CypherParser::PROJECT: + case CypherParser::READ: + case CypherParser::REL: + case 
CypherParser::RENAME: + case CypherParser::RETURN: + case CypherParser::ROLLBACK: + case CypherParser::SEQUENCE: + case CypherParser::SET: + case CypherParser::START: + case CypherParser::STRUCT: + case CypherParser::TO: + case CypherParser::TRANSACTION: + case CypherParser::TYPE: + case CypherParser::UNINSTALL: + case CypherParser::UPDATE: + case CypherParser::USE: + case CypherParser::WRITE: + case CypherParser::YIELD: + case CypherParser::USER: + case CypherParser::PASSWORD: + case CypherParser::ROLE: + case CypherParser::MAP: + case CypherParser::DECIMAL: + case CypherParser::L_SKIP: + case CypherParser::HexLetter: + case CypherParser::UnescapedSymbolicName: + case CypherParser::EscapedSymbolicName: { + setState(2875); + oC_SymbolicName(); + break; + } + + case CypherParser::DecimalInteger: { + setState(2876); + match(CypherParser::DecimalInteger); + break; + } + + default: + throw NoViableAltException(this); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_PropertyExpressionContext ------------------------------------------------------------------ + +CypherParser::OC_PropertyExpressionContext::OC_PropertyExpressionContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_AtomContext* CypherParser::OC_PropertyExpressionContext::oC_Atom() { + return getRuleContext(0); +} + +CypherParser::OC_PropertyLookupContext* CypherParser::OC_PropertyExpressionContext::oC_PropertyLookup() { + return getRuleContext(0); +} + +tree::TerminalNode* CypherParser::OC_PropertyExpressionContext::SP() { + return getToken(CypherParser::SP, 0); +} + + +size_t CypherParser::OC_PropertyExpressionContext::getRuleIndex() const { + return CypherParser::RuleOC_PropertyExpression; +} + + +CypherParser::OC_PropertyExpressionContext* CypherParser::oC_PropertyExpression() { + OC_PropertyExpressionContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 338, CypherParser::RuleOC_PropertyExpression); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(2879); + oC_Atom(); + setState(2881); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == CypherParser::SP) { + setState(2880); + match(CypherParser::SP); + } + setState(2883); + oC_PropertyLookup(); + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- OC_PropertyKeyNameContext ------------------------------------------------------------------ + +CypherParser::OC_PropertyKeyNameContext::OC_PropertyKeyNameContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +CypherParser::OC_SchemaNameContext* CypherParser::OC_PropertyKeyNameContext::oC_SchemaName() { + return getRuleContext(0); +} + + +size_t CypherParser::OC_PropertyKeyNameContext::getRuleIndex() const { + return CypherParser::RuleOC_PropertyKeyName; +} + + +CypherParser::OC_PropertyKeyNameContext* CypherParser::oC_PropertyKeyName() { + OC_PropertyKeyNameContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 340, 
CypherParser::RuleOC_PropertyKeyName);
+
+#if __cplusplus > 201703L
+  auto onExit = finally([=, this] {
+#else
+  auto onExit = finally([=] {
+#endif
+    exitRule();
+  });
+  try {
+    enterOuterAlt(_localctx, 1);
+    setState(2885);
+    oC_SchemaName();
+
+  }
+  catch (RecognitionException &e) {
+    _errHandler->reportError(this, e);
+    _localctx->exception = std::current_exception();
+    _errHandler->recover(this, _localctx->exception);
+  }
+
+  return _localctx;
+}
+
+//----------------- OC_IntegerLiteralContext ------------------------------------------------------------------
+
+CypherParser::OC_IntegerLiteralContext::OC_IntegerLiteralContext(ParserRuleContext *parent, size_t invokingState)
+  : ParserRuleContext(parent, invokingState) {
+}
+
+tree::TerminalNode* CypherParser::OC_IntegerLiteralContext::DecimalInteger() {
+  return getToken(CypherParser::DecimalInteger, 0);
+}
+
+
+size_t CypherParser::OC_IntegerLiteralContext::getRuleIndex() const {
+  return CypherParser::RuleOC_IntegerLiteral;
+}
+
+
+CypherParser::OC_IntegerLiteralContext* CypherParser::oC_IntegerLiteral() {
+  OC_IntegerLiteralContext *_localctx = _tracker.createInstance<OC_IntegerLiteralContext>(_ctx, getState());
+  enterRule(_localctx, 342, CypherParser::RuleOC_IntegerLiteral);
+
+#if __cplusplus > 201703L
+  auto onExit = finally([=, this] {
+#else
+  auto onExit = finally([=] {
+#endif
+    exitRule();
+  });
+  try {
+    enterOuterAlt(_localctx, 1);
+    setState(2887);
+    match(CypherParser::DecimalInteger);
+
+  }
+  catch (RecognitionException &e) {
+    _errHandler->reportError(this, e);
+    _localctx->exception = std::current_exception();
+    _errHandler->recover(this, _localctx->exception);
+  }
+
+  return _localctx;
+}
+
+//----------------- OC_DoubleLiteralContext ------------------------------------------------------------------
+
+CypherParser::OC_DoubleLiteralContext::OC_DoubleLiteralContext(ParserRuleContext *parent, size_t invokingState)
+  : ParserRuleContext(parent, invokingState) {
+}
+
+tree::TerminalNode* CypherParser::OC_DoubleLiteralContext::ExponentDecimalReal() {
+  return getToken(CypherParser::ExponentDecimalReal, 0);
+}
+
+tree::TerminalNode* CypherParser::OC_DoubleLiteralContext::RegularDecimalReal() {
+  return getToken(CypherParser::RegularDecimalReal, 0);
+}
+
+
+size_t CypherParser::OC_DoubleLiteralContext::getRuleIndex() const {
+  return CypherParser::RuleOC_DoubleLiteral;
+}
+
+
+CypherParser::OC_DoubleLiteralContext* CypherParser::oC_DoubleLiteral() {
+  OC_DoubleLiteralContext *_localctx = _tracker.createInstance<OC_DoubleLiteralContext>(_ctx, getState());
+  enterRule(_localctx, 344, CypherParser::RuleOC_DoubleLiteral);
+  size_t _la = 0;
+
+#if __cplusplus > 201703L
+  auto onExit = finally([=, this] {
+#else
+  auto onExit = finally([=] {
+#endif
+    exitRule();
+  });
+  try {
+    enterOuterAlt(_localctx, 1);
+    setState(2889);
+    _la = _input->LA(1);
+    if (!(_la == CypherParser::ExponentDecimalReal
+
+    || _la == CypherParser::RegularDecimalReal)) {
+      _errHandler->recoverInline(this);
+    }
+    else {
+      _errHandler->reportMatch(this);
+      consume();
+    }
+
+  }
+  catch (RecognitionException &e) {
+    _errHandler->reportError(this, e);
+    _localctx->exception = std::current_exception();
+    _errHandler->recover(this, _localctx->exception);
+  }
+
+  return _localctx;
+}
+
+//----------------- OC_SchemaNameContext ------------------------------------------------------------------
+
+CypherParser::OC_SchemaNameContext::OC_SchemaNameContext(ParserRuleContext *parent, size_t invokingState)
+  : ParserRuleContext(parent, invokingState) {
+}
+
+CypherParser::OC_SymbolicNameContext*
CypherParser::OC_SchemaNameContext::oC_SymbolicName() {
+  return getRuleContext<CypherParser::OC_SymbolicNameContext>(0);
+}
+
+
+size_t CypherParser::OC_SchemaNameContext::getRuleIndex() const {
+  return CypherParser::RuleOC_SchemaName;
+}
+
+
+CypherParser::OC_SchemaNameContext* CypherParser::oC_SchemaName() {
+  OC_SchemaNameContext *_localctx = _tracker.createInstance<OC_SchemaNameContext>(_ctx, getState());
+  enterRule(_localctx, 346, CypherParser::RuleOC_SchemaName);
+
+#if __cplusplus > 201703L
+  auto onExit = finally([=, this] {
+#else
+  auto onExit = finally([=] {
+#endif
+    exitRule();
+  });
+  try {
+    enterOuterAlt(_localctx, 1);
+    setState(2891);
+    oC_SymbolicName();
+
+  }
+  catch (RecognitionException &e) {
+    _errHandler->reportError(this, e);
+    _localctx->exception = std::current_exception();
+    _errHandler->recover(this, _localctx->exception);
+  }
+
+  return _localctx;
+}
+
+//----------------- OC_SymbolicNameContext ------------------------------------------------------------------
+
+CypherParser::OC_SymbolicNameContext::OC_SymbolicNameContext(ParserRuleContext *parent, size_t invokingState)
+  : ParserRuleContext(parent, invokingState) {
+}
+
+tree::TerminalNode* CypherParser::OC_SymbolicNameContext::UnescapedSymbolicName() {
+  return getToken(CypherParser::UnescapedSymbolicName, 0);
+}
+
+tree::TerminalNode* CypherParser::OC_SymbolicNameContext::EscapedSymbolicName() {
+  return getToken(CypherParser::EscapedSymbolicName, 0);
+}
+
+tree::TerminalNode* CypherParser::OC_SymbolicNameContext::HexLetter() {
+  return getToken(CypherParser::HexLetter, 0);
+}
+
+CypherParser::KU_NonReservedKeywordsContext* CypherParser::OC_SymbolicNameContext::kU_NonReservedKeywords() {
+  return getRuleContext<CypherParser::KU_NonReservedKeywordsContext>(0);
+}
+
+
+size_t CypherParser::OC_SymbolicNameContext::getRuleIndex() const {
+  return CypherParser::RuleOC_SymbolicName;
+}
+
+
+CypherParser::OC_SymbolicNameContext* CypherParser::oC_SymbolicName() {
+  OC_SymbolicNameContext *_localctx = _tracker.createInstance<OC_SymbolicNameContext>(_ctx, getState());
+  enterRule(_localctx, 348, CypherParser::RuleOC_SymbolicName);
+
+#if __cplusplus > 201703L
+  auto onExit = finally([=, this] {
+#else
+  auto onExit = finally([=] {
+#endif
+    exitRule();
+  });
+  try {
+    setState(2898);
+    _errHandler->sync(this);
+    switch (_input->LA(1)) {
+      case CypherParser::UnescapedSymbolicName: {
+        enterOuterAlt(_localctx, 1);
+        setState(2893);
+        match(CypherParser::UnescapedSymbolicName);
+        break;
+      }
+
+      case CypherParser::EscapedSymbolicName: {
+        enterOuterAlt(_localctx, 2);
+        setState(2894);
+        antlrcpp::downCast<OC_SymbolicNameContext *>(_localctx)->escapedsymbolicnameToken = match(CypherParser::EscapedSymbolicName);
+        if ((antlrcpp::downCast<OC_SymbolicNameContext *>(_localctx)->escapedsymbolicnameToken != nullptr ?
antlrcpp::downCast(_localctx)->escapedsymbolicnameToken->getText() : "") == "``") { notifyEmptyToken(antlrcpp::downCast(_localctx)->escapedsymbolicnameToken); } + break; + } + + case CypherParser::HexLetter: { + enterOuterAlt(_localctx, 3); + setState(2896); + match(CypherParser::HexLetter); + break; + } + + case CypherParser::ADD: + case CypherParser::ALTER: + case CypherParser::AS: + case CypherParser::ATTACH: + case CypherParser::BEGIN: + case CypherParser::BY: + case CypherParser::CALL: + case CypherParser::CHECKPOINT: + case CypherParser::COMMENT: + case CypherParser::COMMIT: + case CypherParser::CONTAINS: + case CypherParser::COPY: + case CypherParser::COUNT: + case CypherParser::CYCLE: + case CypherParser::DATABASE: + case CypherParser::DELETE: + case CypherParser::DETACH: + case CypherParser::DROP: + case CypherParser::EXPLAIN: + case CypherParser::EXPORT: + case CypherParser::EXTENSION: + case CypherParser::FROM: + case CypherParser::FORCE: + case CypherParser::GRAPH: + case CypherParser::IMPORT: + case CypherParser::IF: + case CypherParser::INCREMENT: + case CypherParser::IS: + case CypherParser::KEY: + case CypherParser::LIMIT: + case CypherParser::LOAD: + case CypherParser::LOGICAL: + case CypherParser::MATCH: + case CypherParser::MAXVALUE: + case CypherParser::MERGE: + case CypherParser::MINVALUE: + case CypherParser::NO: + case CypherParser::NODE: + case CypherParser::PROJECT: + case CypherParser::READ: + case CypherParser::REL: + case CypherParser::RENAME: + case CypherParser::RETURN: + case CypherParser::ROLLBACK: + case CypherParser::SEQUENCE: + case CypherParser::SET: + case CypherParser::START: + case CypherParser::STRUCT: + case CypherParser::TO: + case CypherParser::TRANSACTION: + case CypherParser::TYPE: + case CypherParser::UNINSTALL: + case CypherParser::UPDATE: + case CypherParser::USE: + case CypherParser::WRITE: + case CypherParser::YIELD: + case CypherParser::USER: + case CypherParser::PASSWORD: + case CypherParser::ROLE: + case CypherParser::MAP: + case CypherParser::DECIMAL: + case CypherParser::L_SKIP: { + enterOuterAlt(_localctx, 4); + setState(2897); + kU_NonReservedKeywords(); + break; + } + + default: + throw NoViableAltException(this); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- KU_NonReservedKeywordsContext ------------------------------------------------------------------ + +CypherParser::KU_NonReservedKeywordsContext::KU_NonReservedKeywordsContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::COMMENT() { + return getToken(CypherParser::COMMENT, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::ADD() { + return getToken(CypherParser::ADD, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::ALTER() { + return getToken(CypherParser::ALTER, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::AS() { + return getToken(CypherParser::AS, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::ATTACH() { + return getToken(CypherParser::ATTACH, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::BEGIN() { + return getToken(CypherParser::BEGIN, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::BY() { + return 
getToken(CypherParser::BY, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::CALL() { + return getToken(CypherParser::CALL, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::CHECKPOINT() { + return getToken(CypherParser::CHECKPOINT, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::COMMIT() { + return getToken(CypherParser::COMMIT, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::CONTAINS() { + return getToken(CypherParser::CONTAINS, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::COPY() { + return getToken(CypherParser::COPY, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::COUNT() { + return getToken(CypherParser::COUNT, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::CYCLE() { + return getToken(CypherParser::CYCLE, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::DATABASE() { + return getToken(CypherParser::DATABASE, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::DECIMAL() { + return getToken(CypherParser::DECIMAL, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::DELETE() { + return getToken(CypherParser::DELETE, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::DETACH() { + return getToken(CypherParser::DETACH, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::DROP() { + return getToken(CypherParser::DROP, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::EXPLAIN() { + return getToken(CypherParser::EXPLAIN, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::EXPORT() { + return getToken(CypherParser::EXPORT, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::EXTENSION() { + return getToken(CypherParser::EXTENSION, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::FORCE() { + return getToken(CypherParser::FORCE, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::GRAPH() { + return getToken(CypherParser::GRAPH, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::IF() { + return getToken(CypherParser::IF, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::IS() { + return getToken(CypherParser::IS, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::IMPORT() { + return getToken(CypherParser::IMPORT, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::INCREMENT() { + return getToken(CypherParser::INCREMENT, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::KEY() { + return getToken(CypherParser::KEY, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::LOAD() { + return getToken(CypherParser::LOAD, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::LOGICAL() { + return getToken(CypherParser::LOGICAL, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::MATCH() { + return getToken(CypherParser::MATCH, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::MAXVALUE() { + return getToken(CypherParser::MAXVALUE, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::MERGE() { + return getToken(CypherParser::MERGE, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::MINVALUE() { + return getToken(CypherParser::MINVALUE, 0); +} 
+ +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::NO() { + return getToken(CypherParser::NO, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::NODE() { + return getToken(CypherParser::NODE, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::PROJECT() { + return getToken(CypherParser::PROJECT, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::READ() { + return getToken(CypherParser::READ, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::REL() { + return getToken(CypherParser::REL, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::RENAME() { + return getToken(CypherParser::RENAME, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::RETURN() { + return getToken(CypherParser::RETURN, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::ROLLBACK() { + return getToken(CypherParser::ROLLBACK, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::ROLE() { + return getToken(CypherParser::ROLE, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::SEQUENCE() { + return getToken(CypherParser::SEQUENCE, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::SET() { + return getToken(CypherParser::SET, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::START() { + return getToken(CypherParser::START, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::STRUCT() { + return getToken(CypherParser::STRUCT, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::L_SKIP() { + return getToken(CypherParser::L_SKIP, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::LIMIT() { + return getToken(CypherParser::LIMIT, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::TRANSACTION() { + return getToken(CypherParser::TRANSACTION, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::TYPE() { + return getToken(CypherParser::TYPE, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::USE() { + return getToken(CypherParser::USE, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::UNINSTALL() { + return getToken(CypherParser::UNINSTALL, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::UPDATE() { + return getToken(CypherParser::UPDATE, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::WRITE() { + return getToken(CypherParser::WRITE, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::FROM() { + return getToken(CypherParser::FROM, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::TO() { + return getToken(CypherParser::TO, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::YIELD() { + return getToken(CypherParser::YIELD, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::USER() { + return getToken(CypherParser::USER, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::PASSWORD() { + return getToken(CypherParser::PASSWORD, 0); +} + +tree::TerminalNode* CypherParser::KU_NonReservedKeywordsContext::MAP() { + return getToken(CypherParser::MAP, 0); +} + + +size_t CypherParser::KU_NonReservedKeywordsContext::getRuleIndex() const { + return CypherParser::RuleKU_NonReservedKeywords; +} + + +CypherParser::KU_NonReservedKeywordsContext* 
CypherParser::kU_NonReservedKeywords() {
+  KU_NonReservedKeywordsContext *_localctx = _tracker.createInstance<KU_NonReservedKeywordsContext>(_ctx, getState());
+  enterRule(_localctx, 350, CypherParser::RuleKU_NonReservedKeywords);
+  size_t _la = 0;
+
+#if __cplusplus > 201703L
+  auto onExit = finally([=, this] {
+#else
+  auto onExit = finally([=] {
+#endif
+    exitRule();
+  });
+  try {
+    enterOuterAlt(_localctx, 1);
+    setState(2900);
+    _la = _input->LA(1);
+    if (!(((((_la - 47) & ~ 0x3fULL) == 0) &&
+      ((1ULL << (_la - 47)) & 8923191552623093653) != 0) || ((((_la - 111) & ~ 0x3fULL) == 0) &&
+      ((1ULL << (_la - 111)) & 3361330146570243) != 0))) {
+      _errHandler->recoverInline(this);
+    }
+    else {
+      _errHandler->reportMatch(this);
+      consume();
+    }
+
+  }
+  catch (RecognitionException &e) {
+    _errHandler->reportError(this, e);
+    _localctx->exception = std::current_exception();
+    _errHandler->recover(this, _localctx->exception);
+  }
+
+  return _localctx;
+}
+
+//----------------- OC_LeftArrowHeadContext ------------------------------------------------------------------
+
+CypherParser::OC_LeftArrowHeadContext::OC_LeftArrowHeadContext(ParserRuleContext *parent, size_t invokingState)
+  : ParserRuleContext(parent, invokingState) {
+}
+
+
+size_t CypherParser::OC_LeftArrowHeadContext::getRuleIndex() const {
+  return CypherParser::RuleOC_LeftArrowHead;
+}
+
+
+CypherParser::OC_LeftArrowHeadContext* CypherParser::oC_LeftArrowHead() {
+  OC_LeftArrowHeadContext *_localctx = _tracker.createInstance<OC_LeftArrowHeadContext>(_ctx, getState());
+  enterRule(_localctx, 352, CypherParser::RuleOC_LeftArrowHead);
+  size_t _la = 0;
+
+#if __cplusplus > 201703L
+  auto onExit = finally([=, this] {
+#else
+  auto onExit = finally([=] {
+#endif
+    exitRule();
+  });
+  try {
+    enterOuterAlt(_localctx, 1);
+    setState(2902);
+    _la = _input->LA(1);
+    if (!((((_la & ~ 0x3fULL) == 0) &&
+      ((1ULL << _la) & 1006641152) != 0))) {
+      _errHandler->recoverInline(this);
+    }
+    else {
+      _errHandler->reportMatch(this);
+      consume();
+    }
+
+  }
+  catch (RecognitionException &e) {
+    _errHandler->reportError(this, e);
+    _localctx->exception = std::current_exception();
+    _errHandler->recover(this, _localctx->exception);
+  }
+
+  return _localctx;
+}
+
+//----------------- OC_RightArrowHeadContext ------------------------------------------------------------------
+
+CypherParser::OC_RightArrowHeadContext::OC_RightArrowHeadContext(ParserRuleContext *parent, size_t invokingState)
+  : ParserRuleContext(parent, invokingState) {
+}
+
+
+size_t CypherParser::OC_RightArrowHeadContext::getRuleIndex() const {
+  return CypherParser::RuleOC_RightArrowHead;
+}
+
+
+CypherParser::OC_RightArrowHeadContext* CypherParser::oC_RightArrowHead() {
+  OC_RightArrowHeadContext *_localctx = _tracker.createInstance<OC_RightArrowHeadContext>(_ctx, getState());
+  enterRule(_localctx, 354, CypherParser::RuleOC_RightArrowHead);
+  size_t _la = 0;
+
+#if __cplusplus > 201703L
+  auto onExit = finally([=, this] {
+#else
+  auto onExit = finally([=] {
+#endif
+    exitRule();
+  });
+  try {
+    enterOuterAlt(_localctx, 1);
+    setState(2904);
+    _la = _input->LA(1);
+    if (!((((_la & ~ 0x3fULL) == 0) &&
+      ((1ULL << _la) & 16106160128) != 0))) {
+      _errHandler->recoverInline(this);
+    }
+    else {
+      _errHandler->reportMatch(this);
+      consume();
+    }
+
+  }
+  catch (RecognitionException &e) {
+    _errHandler->reportError(this, e);
+    _localctx->exception = std::current_exception();
+    _errHandler->recover(this, _localctx->exception);
+  }
+
+  return _localctx;
+}
+
+//----------------- OC_DashContext ------------------------------------------------------------------
+
+CypherParser::OC_DashContext::OC_DashContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + +tree::TerminalNode* CypherParser::OC_DashContext::MINUS() { + return getToken(CypherParser::MINUS, 0); +} + + +size_t CypherParser::OC_DashContext::getRuleIndex() const { + return CypherParser::RuleOC_Dash; +} + + +CypherParser::OC_DashContext* CypherParser::oC_Dash() { + OC_DashContext *_localctx = _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 356, CypherParser::RuleOC_Dash); + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + exitRule(); + }); + try { + enterOuterAlt(_localctx, 1); + setState(2906); + _la = _input->LA(1); + if (!((((_la & ~ 0x3fULL) == 0) && + ((1ULL << _la) & 35167192219648) != 0) || _la == CypherParser::MINUS)) { + _errHandler->recoverInline(this); + } + else { + _errHandler->reportMatch(this); + consume(); + } + + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +bool CypherParser::sempred(RuleContext *context, size_t ruleIndex, size_t predicateIndex) { + switch (ruleIndex) { + case 57: return kU_DataTypeSempred(antlrcpp::downCast(context), predicateIndex); + case 84: return kU_JoinNodeSempred(antlrcpp::downCast(context), predicateIndex); + + default: + break; + } + return true; +} + +bool CypherParser::kU_DataTypeSempred(KU_DataTypeContext *_localctx, size_t predicateIndex) { + switch (predicateIndex) { + case 0: return precpred(_ctx, 5); + + default: + break; + } + return true; +} + +bool CypherParser::kU_JoinNodeSempred(KU_JoinNodeContext *_localctx, size_t predicateIndex) { + switch (predicateIndex) { + case 1: return precpred(_ctx, 4); + case 2: return precpred(_ctx, 3); + + default: + break; + } + return true; +} + +void CypherParser::initialize() { +#if ANTLR4_USE_THREAD_LOCAL_CACHE + cypherParserInitialize(); +#else + ::antlr4::internal::call_once(cypherParserOnceFlag, cypherParserInitialize); +#endif +} diff --git a/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_cypher/include/cypher_lexer.h b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_cypher/include/cypher_lexer.h new file mode 100644 index 0000000000..f4357ba655 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_cypher/include/cypher_lexer.h @@ -0,0 +1,82 @@ + +// Generated from Cypher.g4 by ANTLR 4.13.1 + +#pragma once + + +#include "antlr4-runtime.h" + + + + +class CypherLexer : public antlr4::Lexer { +public: + enum { + T__0 = 1, T__1 = 2, T__2 = 3, T__3 = 4, T__4 = 5, T__5 = 6, T__6 = 7, + T__7 = 8, T__8 = 9, T__9 = 10, T__10 = 11, T__11 = 12, T__12 = 13, T__13 = 14, + T__14 = 15, T__15 = 16, T__16 = 17, T__17 = 18, T__18 = 19, T__19 = 20, + T__20 = 21, T__21 = 22, T__22 = 23, T__23 = 24, T__24 = 25, T__25 = 26, + T__26 = 27, T__27 = 28, T__28 = 29, T__29 = 30, T__30 = 31, T__31 = 32, + T__32 = 33, T__33 = 34, T__34 = 35, T__35 = 36, T__36 = 37, T__37 = 38, + T__38 = 39, T__39 = 40, T__40 = 41, T__41 = 42, T__42 = 43, T__43 = 44, + ACYCLIC = 45, ANY = 46, ADD = 47, ALL = 48, ALTER = 49, AND = 50, AS = 51, + ASC = 52, ASCENDING = 53, ATTACH = 54, BEGIN = 55, BY = 56, CALL = 57, + CASE = 58, CAST = 59, CHECKPOINT = 60, COLUMN = 61, COMMENT = 62, COMMIT = 63, + COMMIT_SKIP_CHECKPOINT = 64, CONTAINS = 65, COPY = 66, COUNT = 67, CREATE = 68, + CYCLE = 69, DATABASE = 70, DBTYPE = 
71, DEFAULT = 72, DELETE = 73, DESC = 74, + DESCENDING = 75, DETACH = 76, DISTINCT = 77, DROP = 78, ELSE = 79, END = 80, + ENDS = 81, EXISTS = 82, EXPLAIN = 83, EXPORT = 84, EXTENSION = 85, FALSE = 86, + FROM = 87, FORCE = 88, GLOB = 89, GRAPH = 90, GROUP = 91, HEADERS = 92, + HINT = 93, IMPORT = 94, IF = 95, IN = 96, INCREMENT = 97, INSTALL = 98, + IS = 99, JOIN = 100, KEY = 101, LIMIT = 102, LOAD = 103, LOGICAL = 104, + MACRO = 105, MATCH = 106, MAXVALUE = 107, MERGE = 108, MINVALUE = 109, + MULTI_JOIN = 110, NO = 111, NODE = 112, NOT = 113, NONE = 114, NULL_ = 115, + ON = 116, ONLY = 117, OPTIONAL = 118, OR = 119, ORDER = 120, PRIMARY = 121, + PROFILE = 122, PROJECT = 123, READ = 124, REL = 125, RENAME = 126, RETURN = 127, + ROLLBACK = 128, ROLLBACK_SKIP_CHECKPOINT = 129, SEQUENCE = 130, SET = 131, + SHORTEST = 132, START = 133, STARTS = 134, STRUCT = 135, TABLE = 136, + THEN = 137, TO = 138, TRAIL = 139, TRANSACTION = 140, TRUE = 141, TYPE = 142, + UNION = 143, UNWIND = 144, UNINSTALL = 145, UPDATE = 146, USE = 147, + WHEN = 148, WHERE = 149, WITH = 150, WRITE = 151, WSHORTEST = 152, XOR = 153, + SINGLE = 154, YIELD = 155, USER = 156, PASSWORD = 157, ROLE = 158, MAP = 159, + DECIMAL = 160, STAR = 161, L_SKIP = 162, INVALID_NOT_EQUAL = 163, COLON = 164, + DOTDOT = 165, MINUS = 166, FACTORIAL = 167, StringLiteral = 168, EscapedChar = 169, + DecimalInteger = 170, HexLetter = 171, HexDigit = 172, Digit = 173, + NonZeroDigit = 174, NonZeroOctDigit = 175, ZeroDigit = 176, ExponentDecimalReal = 177, + RegularDecimalReal = 178, UnescapedSymbolicName = 179, IdentifierStart = 180, + IdentifierPart = 181, EscapedSymbolicName = 182, SP = 183, WHITESPACE = 184, + CypherComment = 185, Unknown = 186 + }; + + explicit CypherLexer(antlr4::CharStream *input); + + ~CypherLexer() override; + + + std::string getGrammarFileName() const override; + + const std::vector& getRuleNames() const override; + + const std::vector& getChannelNames() const override; + + const std::vector& getModeNames() const override; + + const antlr4::dfa::Vocabulary& getVocabulary() const override; + + antlr4::atn::SerializedATNView getSerializedATN() const override; + + const antlr4::atn::ATN& getATN() const override; + + // By default the static state used to implement the lexer is lazily initialized during the first + // call to the constructor. You can call this function if you wish to initialize the static state + // ahead of time. + static void initialize(); + +private: + + // Individual action functions triggered by action() above. + + // Individual semantic predicate functions triggered by sempred() above. 
+ +}; + diff --git a/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_cypher/include/cypher_parser.h b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_cypher/include/cypher_parser.h new file mode 100644 index 0000000000..7cfd61bda7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_cypher/include/cypher_parser.h @@ -0,0 +1,3065 @@ + +// Generated from Cypher.g4 by ANTLR 4.13.1 + +#pragma once + + +#include "antlr4-runtime.h" + + + + +class CypherParser : public antlr4::Parser { +public: + enum { + T__0 = 1, T__1 = 2, T__2 = 3, T__3 = 4, T__4 = 5, T__5 = 6, T__6 = 7, + T__7 = 8, T__8 = 9, T__9 = 10, T__10 = 11, T__11 = 12, T__12 = 13, T__13 = 14, + T__14 = 15, T__15 = 16, T__16 = 17, T__17 = 18, T__18 = 19, T__19 = 20, + T__20 = 21, T__21 = 22, T__22 = 23, T__23 = 24, T__24 = 25, T__25 = 26, + T__26 = 27, T__27 = 28, T__28 = 29, T__29 = 30, T__30 = 31, T__31 = 32, + T__32 = 33, T__33 = 34, T__34 = 35, T__35 = 36, T__36 = 37, T__37 = 38, + T__38 = 39, T__39 = 40, T__40 = 41, T__41 = 42, T__42 = 43, T__43 = 44, + ACYCLIC = 45, ANY = 46, ADD = 47, ALL = 48, ALTER = 49, AND = 50, AS = 51, + ASC = 52, ASCENDING = 53, ATTACH = 54, BEGIN = 55, BY = 56, CALL = 57, + CASE = 58, CAST = 59, CHECKPOINT = 60, COLUMN = 61, COMMENT = 62, COMMIT = 63, + COMMIT_SKIP_CHECKPOINT = 64, CONTAINS = 65, COPY = 66, COUNT = 67, CREATE = 68, + CYCLE = 69, DATABASE = 70, DBTYPE = 71, DEFAULT = 72, DELETE = 73, DESC = 74, + DESCENDING = 75, DETACH = 76, DISTINCT = 77, DROP = 78, ELSE = 79, END = 80, + ENDS = 81, EXISTS = 82, EXPLAIN = 83, EXPORT = 84, EXTENSION = 85, FALSE = 86, + FROM = 87, FORCE = 88, GLOB = 89, GRAPH = 90, GROUP = 91, HEADERS = 92, + HINT = 93, IMPORT = 94, IF = 95, IN = 96, INCREMENT = 97, INSTALL = 98, + IS = 99, JOIN = 100, KEY = 101, LIMIT = 102, LOAD = 103, LOGICAL = 104, + MACRO = 105, MATCH = 106, MAXVALUE = 107, MERGE = 108, MINVALUE = 109, + MULTI_JOIN = 110, NO = 111, NODE = 112, NOT = 113, NONE = 114, NULL_ = 115, + ON = 116, ONLY = 117, OPTIONAL = 118, OR = 119, ORDER = 120, PRIMARY = 121, + PROFILE = 122, PROJECT = 123, READ = 124, REL = 125, RENAME = 126, RETURN = 127, + ROLLBACK = 128, ROLLBACK_SKIP_CHECKPOINT = 129, SEQUENCE = 130, SET = 131, + SHORTEST = 132, START = 133, STARTS = 134, STRUCT = 135, TABLE = 136, + THEN = 137, TO = 138, TRAIL = 139, TRANSACTION = 140, TRUE = 141, TYPE = 142, + UNION = 143, UNWIND = 144, UNINSTALL = 145, UPDATE = 146, USE = 147, + WHEN = 148, WHERE = 149, WITH = 150, WRITE = 151, WSHORTEST = 152, XOR = 153, + SINGLE = 154, YIELD = 155, USER = 156, PASSWORD = 157, ROLE = 158, MAP = 159, + DECIMAL = 160, STAR = 161, L_SKIP = 162, INVALID_NOT_EQUAL = 163, COLON = 164, + DOTDOT = 165, MINUS = 166, FACTORIAL = 167, StringLiteral = 168, EscapedChar = 169, + DecimalInteger = 170, HexLetter = 171, HexDigit = 172, Digit = 173, + NonZeroDigit = 174, NonZeroOctDigit = 175, ZeroDigit = 176, ExponentDecimalReal = 177, + RegularDecimalReal = 178, UnescapedSymbolicName = 179, IdentifierStart = 180, + IdentifierPart = 181, EscapedSymbolicName = 182, SP = 183, WHITESPACE = 184, + CypherComment = 185, Unknown = 186 + }; + + enum { + RuleKu_Statements = 0, RuleOC_Cypher = 1, RuleOC_Statement = 2, RuleKU_CopyFrom = 3, + RuleKU_ColumnNames = 4, RuleKU_ScanSource = 5, RuleKU_CopyFromByColumn = 6, + RuleKU_CopyTO = 7, RuleKU_ExportDatabase = 8, RuleKU_ImportDatabase = 9, + RuleKU_AttachDatabase = 10, RuleKU_Option = 11, RuleKU_Options = 12, + RuleKU_DetachDatabase = 13, RuleKU_UseDatabase = 14, RuleKU_StandaloneCall = 15, + RuleKU_CommentOn = 16, 
RuleKU_CreateMacro = 17, RuleKU_PositionalArgs = 18, + RuleKU_DefaultArg = 19, RuleKU_FilePaths = 20, RuleKU_IfNotExists = 21, + RuleKU_CreateNodeTable = 22, RuleKU_CreateRelTable = 23, RuleKU_FromToConnections = 24, + RuleKU_FromToConnection = 25, RuleKU_CreateSequence = 26, RuleKU_CreateType = 27, + RuleKU_SequenceOptions = 28, RuleKU_WithPasswd = 29, RuleKU_CreateUser = 30, + RuleKU_CreateRole = 31, RuleKU_IncrementBy = 32, RuleKU_MinValue = 33, + RuleKU_MaxValue = 34, RuleKU_StartWith = 35, RuleKU_Cycle = 36, RuleKU_IfExists = 37, + RuleKU_Drop = 38, RuleKU_AlterTable = 39, RuleKU_AlterOptions = 40, + RuleKU_AddProperty = 41, RuleKU_Default = 42, RuleKU_DropProperty = 43, + RuleKU_RenameTable = 44, RuleKU_RenameProperty = 45, RuleKU_AddFromToConnection = 46, + RuleKU_DropFromToConnection = 47, RuleKU_ColumnDefinitions = 48, RuleKU_ColumnDefinition = 49, + RuleKU_PropertyDefinitions = 50, RuleKU_PropertyDefinition = 51, RuleKU_CreateNodeConstraint = 52, + RuleKU_UnionType = 53, RuleKU_StructType = 54, RuleKU_MapType = 55, + RuleKU_DecimalType = 56, RuleKU_DataType = 57, RuleKU_ListIdentifiers = 58, + RuleKU_ListIdentifier = 59, RuleOC_AnyCypherOption = 60, RuleOC_Explain = 61, + RuleOC_Profile = 62, RuleKU_Transaction = 63, RuleKU_Extension = 64, + RuleKU_LoadExtension = 65, RuleKU_InstallExtension = 66, RuleKU_UninstallExtension = 67, + RuleKU_UpdateExtension = 68, RuleOC_Query = 69, RuleOC_RegularQuery = 70, + RuleOC_Union = 71, RuleOC_SingleQuery = 72, RuleOC_SinglePartQuery = 73, + RuleOC_MultiPartQuery = 74, RuleKU_QueryPart = 75, RuleOC_UpdatingClause = 76, + RuleOC_ReadingClause = 77, RuleKU_LoadFrom = 78, RuleOC_YieldItem = 79, + RuleOC_YieldItems = 80, RuleKU_InQueryCall = 81, RuleOC_Match = 82, + RuleKU_Hint = 83, RuleKU_JoinNode = 84, RuleOC_Unwind = 85, RuleOC_Create = 86, + RuleOC_Merge = 87, RuleOC_MergeAction = 88, RuleOC_Set = 89, RuleOC_SetItem = 90, + RuleOC_Delete = 91, RuleOC_With = 92, RuleOC_Return = 93, RuleOC_ProjectionBody = 94, + RuleOC_ProjectionItems = 95, RuleOC_ProjectionItem = 96, RuleOC_Order = 97, + RuleOC_Skip = 98, RuleOC_Limit = 99, RuleOC_SortItem = 100, RuleOC_Where = 101, + RuleOC_Pattern = 102, RuleOC_PatternPart = 103, RuleOC_AnonymousPatternPart = 104, + RuleOC_PatternElement = 105, RuleOC_NodePattern = 106, RuleOC_PatternElementChain = 107, + RuleOC_RelationshipPattern = 108, RuleOC_RelationshipDetail = 109, RuleKU_Properties = 110, + RuleOC_RelationshipTypes = 111, RuleOC_NodeLabels = 112, RuleKU_RecursiveDetail = 113, + RuleKU_RecursiveType = 114, RuleOC_RangeLiteral = 115, RuleKU_RecursiveComprehension = 116, + RuleKU_RecursiveProjectionItems = 117, RuleOC_LowerBound = 118, RuleOC_UpperBound = 119, + RuleOC_LabelName = 120, RuleOC_RelTypeName = 121, RuleOC_Expression = 122, + RuleOC_OrExpression = 123, RuleOC_XorExpression = 124, RuleOC_AndExpression = 125, + RuleOC_NotExpression = 126, RuleOC_ComparisonExpression = 127, RuleKU_ComparisonOperator = 128, + RuleKU_BitwiseOrOperatorExpression = 129, RuleKU_BitwiseAndOperatorExpression = 130, + RuleKU_BitShiftOperatorExpression = 131, RuleKU_BitShiftOperator = 132, + RuleOC_AddOrSubtractExpression = 133, RuleKU_AddOrSubtractOperator = 134, + RuleOC_MultiplyDivideModuloExpression = 135, RuleKU_MultiplyDivideModuloOperator = 136, + RuleOC_PowerOfExpression = 137, RuleOC_StringListNullOperatorExpression = 138, + RuleOC_ListOperatorExpression = 139, RuleOC_StringOperatorExpression = 140, + RuleOC_RegularExpression = 141, RuleOC_NullOperatorExpression = 142, + 
RuleOC_UnaryAddSubtractOrFactorialExpression = 143, RuleOC_PropertyOrLabelsExpression = 144, + RuleOC_Atom = 145, RuleOC_Quantifier = 146, RuleOC_FilterExpression = 147, + RuleOC_IdInColl = 148, RuleOC_Literal = 149, RuleOC_BooleanLiteral = 150, + RuleOC_ListLiteral = 151, RuleKU_ListEntry = 152, RuleKU_StructLiteral = 153, + RuleKU_StructField = 154, RuleOC_ParenthesizedExpression = 155, RuleOC_FunctionInvocation = 156, + RuleOC_FunctionName = 157, RuleKU_FunctionParameter = 158, RuleKU_LambdaParameter = 159, + RuleKU_LambdaVars = 160, RuleOC_PathPatterns = 161, RuleOC_ExistCountSubquery = 162, + RuleOC_PropertyLookup = 163, RuleOC_CaseExpression = 164, RuleOC_CaseAlternative = 165, + RuleOC_Variable = 166, RuleOC_NumberLiteral = 167, RuleOC_Parameter = 168, + RuleOC_PropertyExpression = 169, RuleOC_PropertyKeyName = 170, RuleOC_IntegerLiteral = 171, + RuleOC_DoubleLiteral = 172, RuleOC_SchemaName = 173, RuleOC_SymbolicName = 174, + RuleKU_NonReservedKeywords = 175, RuleOC_LeftArrowHead = 176, RuleOC_RightArrowHead = 177, + RuleOC_Dash = 178 + }; + + explicit CypherParser(antlr4::TokenStream *input); + + CypherParser(antlr4::TokenStream *input, const antlr4::atn::ParserATNSimulatorOptions &options); + + ~CypherParser() override; + + std::string getGrammarFileName() const override; + + const antlr4::atn::ATN& getATN() const override; + + const std::vector& getRuleNames() const override; + + const antlr4::dfa::Vocabulary& getVocabulary() const override; + + antlr4::atn::SerializedATNView getSerializedATN() const override; + + + class Ku_StatementsContext; + class OC_CypherContext; + class OC_StatementContext; + class KU_CopyFromContext; + class KU_ColumnNamesContext; + class KU_ScanSourceContext; + class KU_CopyFromByColumnContext; + class KU_CopyTOContext; + class KU_ExportDatabaseContext; + class KU_ImportDatabaseContext; + class KU_AttachDatabaseContext; + class KU_OptionContext; + class KU_OptionsContext; + class KU_DetachDatabaseContext; + class KU_UseDatabaseContext; + class KU_StandaloneCallContext; + class KU_CommentOnContext; + class KU_CreateMacroContext; + class KU_PositionalArgsContext; + class KU_DefaultArgContext; + class KU_FilePathsContext; + class KU_IfNotExistsContext; + class KU_CreateNodeTableContext; + class KU_CreateRelTableContext; + class KU_FromToConnectionsContext; + class KU_FromToConnectionContext; + class KU_CreateSequenceContext; + class KU_CreateTypeContext; + class KU_SequenceOptionsContext; + class KU_WithPasswdContext; + class KU_CreateUserContext; + class KU_CreateRoleContext; + class KU_IncrementByContext; + class KU_MinValueContext; + class KU_MaxValueContext; + class KU_StartWithContext; + class KU_CycleContext; + class KU_IfExistsContext; + class KU_DropContext; + class KU_AlterTableContext; + class KU_AlterOptionsContext; + class KU_AddPropertyContext; + class KU_DefaultContext; + class KU_DropPropertyContext; + class KU_RenameTableContext; + class KU_RenamePropertyContext; + class KU_AddFromToConnectionContext; + class KU_DropFromToConnectionContext; + class KU_ColumnDefinitionsContext; + class KU_ColumnDefinitionContext; + class KU_PropertyDefinitionsContext; + class KU_PropertyDefinitionContext; + class KU_CreateNodeConstraintContext; + class KU_UnionTypeContext; + class KU_StructTypeContext; + class KU_MapTypeContext; + class KU_DecimalTypeContext; + class KU_DataTypeContext; + class KU_ListIdentifiersContext; + class KU_ListIdentifierContext; + class OC_AnyCypherOptionContext; + class OC_ExplainContext; + class OC_ProfileContext; + class 
KU_TransactionContext; + class KU_ExtensionContext; + class KU_LoadExtensionContext; + class KU_InstallExtensionContext; + class KU_UninstallExtensionContext; + class KU_UpdateExtensionContext; + class OC_QueryContext; + class OC_RegularQueryContext; + class OC_UnionContext; + class OC_SingleQueryContext; + class OC_SinglePartQueryContext; + class OC_MultiPartQueryContext; + class KU_QueryPartContext; + class OC_UpdatingClauseContext; + class OC_ReadingClauseContext; + class KU_LoadFromContext; + class OC_YieldItemContext; + class OC_YieldItemsContext; + class KU_InQueryCallContext; + class OC_MatchContext; + class KU_HintContext; + class KU_JoinNodeContext; + class OC_UnwindContext; + class OC_CreateContext; + class OC_MergeContext; + class OC_MergeActionContext; + class OC_SetContext; + class OC_SetItemContext; + class OC_DeleteContext; + class OC_WithContext; + class OC_ReturnContext; + class OC_ProjectionBodyContext; + class OC_ProjectionItemsContext; + class OC_ProjectionItemContext; + class OC_OrderContext; + class OC_SkipContext; + class OC_LimitContext; + class OC_SortItemContext; + class OC_WhereContext; + class OC_PatternContext; + class OC_PatternPartContext; + class OC_AnonymousPatternPartContext; + class OC_PatternElementContext; + class OC_NodePatternContext; + class OC_PatternElementChainContext; + class OC_RelationshipPatternContext; + class OC_RelationshipDetailContext; + class KU_PropertiesContext; + class OC_RelationshipTypesContext; + class OC_NodeLabelsContext; + class KU_RecursiveDetailContext; + class KU_RecursiveTypeContext; + class OC_RangeLiteralContext; + class KU_RecursiveComprehensionContext; + class KU_RecursiveProjectionItemsContext; + class OC_LowerBoundContext; + class OC_UpperBoundContext; + class OC_LabelNameContext; + class OC_RelTypeNameContext; + class OC_ExpressionContext; + class OC_OrExpressionContext; + class OC_XorExpressionContext; + class OC_AndExpressionContext; + class OC_NotExpressionContext; + class OC_ComparisonExpressionContext; + class KU_ComparisonOperatorContext; + class KU_BitwiseOrOperatorExpressionContext; + class KU_BitwiseAndOperatorExpressionContext; + class KU_BitShiftOperatorExpressionContext; + class KU_BitShiftOperatorContext; + class OC_AddOrSubtractExpressionContext; + class KU_AddOrSubtractOperatorContext; + class OC_MultiplyDivideModuloExpressionContext; + class KU_MultiplyDivideModuloOperatorContext; + class OC_PowerOfExpressionContext; + class OC_StringListNullOperatorExpressionContext; + class OC_ListOperatorExpressionContext; + class OC_StringOperatorExpressionContext; + class OC_RegularExpressionContext; + class OC_NullOperatorExpressionContext; + class OC_UnaryAddSubtractOrFactorialExpressionContext; + class OC_PropertyOrLabelsExpressionContext; + class OC_AtomContext; + class OC_QuantifierContext; + class OC_FilterExpressionContext; + class OC_IdInCollContext; + class OC_LiteralContext; + class OC_BooleanLiteralContext; + class OC_ListLiteralContext; + class KU_ListEntryContext; + class KU_StructLiteralContext; + class KU_StructFieldContext; + class OC_ParenthesizedExpressionContext; + class OC_FunctionInvocationContext; + class OC_FunctionNameContext; + class KU_FunctionParameterContext; + class KU_LambdaParameterContext; + class KU_LambdaVarsContext; + class OC_PathPatternsContext; + class OC_ExistCountSubqueryContext; + class OC_PropertyLookupContext; + class OC_CaseExpressionContext; + class OC_CaseAlternativeContext; + class OC_VariableContext; + class OC_NumberLiteralContext; + class OC_ParameterContext; + 
class OC_PropertyExpressionContext; + class OC_PropertyKeyNameContext; + class OC_IntegerLiteralContext; + class OC_DoubleLiteralContext; + class OC_SchemaNameContext; + class OC_SymbolicNameContext; + class KU_NonReservedKeywordsContext; + class OC_LeftArrowHeadContext; + class OC_RightArrowHeadContext; + class OC_DashContext; + + class Ku_StatementsContext : public antlr4::ParserRuleContext { + public: + Ku_StatementsContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + std::vector oC_Cypher(); + OC_CypherContext* oC_Cypher(size_t i); + antlr4::tree::TerminalNode *EOF(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + + + }; + + Ku_StatementsContext* ku_Statements(); + + class OC_CypherContext : public antlr4::ParserRuleContext { + public: + OC_CypherContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_StatementContext *oC_Statement(); + OC_AnyCypherOptionContext *oC_AnyCypherOption(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + + + }; + + OC_CypherContext* oC_Cypher(); + + class OC_StatementContext : public antlr4::ParserRuleContext { + public: + OC_StatementContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_QueryContext *oC_Query(); + KU_CreateUserContext *kU_CreateUser(); + KU_CreateRoleContext *kU_CreateRole(); + KU_CreateNodeTableContext *kU_CreateNodeTable(); + KU_CreateRelTableContext *kU_CreateRelTable(); + KU_CreateSequenceContext *kU_CreateSequence(); + KU_CreateTypeContext *kU_CreateType(); + KU_DropContext *kU_Drop(); + KU_AlterTableContext *kU_AlterTable(); + KU_CopyFromContext *kU_CopyFrom(); + KU_CopyFromByColumnContext *kU_CopyFromByColumn(); + KU_CopyTOContext *kU_CopyTO(); + KU_StandaloneCallContext *kU_StandaloneCall(); + KU_CreateMacroContext *kU_CreateMacro(); + KU_CommentOnContext *kU_CommentOn(); + KU_TransactionContext *kU_Transaction(); + KU_ExtensionContext *kU_Extension(); + KU_ExportDatabaseContext *kU_ExportDatabase(); + KU_ImportDatabaseContext *kU_ImportDatabase(); + KU_AttachDatabaseContext *kU_AttachDatabase(); + KU_DetachDatabaseContext *kU_DetachDatabase(); + KU_UseDatabaseContext *kU_UseDatabase(); + + + }; + + OC_StatementContext* oC_Statement(); + + class KU_CopyFromContext : public antlr4::ParserRuleContext { + public: + KU_CopyFromContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *COPY(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + OC_SchemaNameContext *oC_SchemaName(); + antlr4::tree::TerminalNode *FROM(); + KU_ScanSourceContext *kU_ScanSource(); + KU_ColumnNamesContext *kU_ColumnNames(); + KU_OptionsContext *kU_Options(); + + + }; + + KU_CopyFromContext* kU_CopyFrom(); + + class KU_ColumnNamesContext : public antlr4::ParserRuleContext { + public: + KU_ColumnNamesContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + std::vector oC_SchemaName(); + OC_SchemaNameContext* oC_SchemaName(size_t i); + + + }; + + KU_ColumnNamesContext* kU_ColumnNames(); + + class KU_ScanSourceContext : public antlr4::ParserRuleContext { + public: + KU_ScanSourceContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + KU_FilePathsContext 
*kU_FilePaths(); + OC_QueryContext *oC_Query(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + OC_ParameterContext *oC_Parameter(); + OC_VariableContext *oC_Variable(); + OC_SchemaNameContext *oC_SchemaName(); + OC_FunctionInvocationContext *oC_FunctionInvocation(); + + + }; + + KU_ScanSourceContext* kU_ScanSource(); + + class KU_CopyFromByColumnContext : public antlr4::ParserRuleContext { + public: + KU_CopyFromByColumnContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *COPY(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + OC_SchemaNameContext *oC_SchemaName(); + antlr4::tree::TerminalNode *FROM(); + std::vector StringLiteral(); + antlr4::tree::TerminalNode* StringLiteral(size_t i); + antlr4::tree::TerminalNode *BY(); + antlr4::tree::TerminalNode *COLUMN(); + + + }; + + KU_CopyFromByColumnContext* kU_CopyFromByColumn(); + + class KU_CopyTOContext : public antlr4::ParserRuleContext { + public: + KU_CopyTOContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *COPY(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + OC_QueryContext *oC_Query(); + antlr4::tree::TerminalNode *TO(); + antlr4::tree::TerminalNode *StringLiteral(); + KU_OptionsContext *kU_Options(); + + + }; + + KU_CopyTOContext* kU_CopyTO(); + + class KU_ExportDatabaseContext : public antlr4::ParserRuleContext { + public: + KU_ExportDatabaseContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *EXPORT(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + antlr4::tree::TerminalNode *DATABASE(); + antlr4::tree::TerminalNode *StringLiteral(); + KU_OptionsContext *kU_Options(); + + + }; + + KU_ExportDatabaseContext* kU_ExportDatabase(); + + class KU_ImportDatabaseContext : public antlr4::ParserRuleContext { + public: + KU_ImportDatabaseContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *IMPORT(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + antlr4::tree::TerminalNode *DATABASE(); + antlr4::tree::TerminalNode *StringLiteral(); + + + }; + + KU_ImportDatabaseContext* kU_ImportDatabase(); + + class KU_AttachDatabaseContext : public antlr4::ParserRuleContext { + public: + KU_AttachDatabaseContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *ATTACH(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + antlr4::tree::TerminalNode *StringLiteral(); + antlr4::tree::TerminalNode *DBTYPE(); + OC_SymbolicNameContext *oC_SymbolicName(); + antlr4::tree::TerminalNode *AS(); + OC_SchemaNameContext *oC_SchemaName(); + KU_OptionsContext *kU_Options(); + + + }; + + KU_AttachDatabaseContext* kU_AttachDatabase(); + + class KU_OptionContext : public antlr4::ParserRuleContext { + public: + KU_OptionContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_SymbolicNameContext *oC_SymbolicName(); + OC_LiteralContext *oC_Literal(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + + + }; + + KU_OptionContext* kU_Option(); + + class KU_OptionsContext : public antlr4::ParserRuleContext { + public: + 
KU_OptionsContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + std::vector kU_Option(); + KU_OptionContext* kU_Option(size_t i); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + + + }; + + KU_OptionsContext* kU_Options(); + + class KU_DetachDatabaseContext : public antlr4::ParserRuleContext { + public: + KU_DetachDatabaseContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *DETACH(); + antlr4::tree::TerminalNode *SP(); + OC_SchemaNameContext *oC_SchemaName(); + + + }; + + KU_DetachDatabaseContext* kU_DetachDatabase(); + + class KU_UseDatabaseContext : public antlr4::ParserRuleContext { + public: + KU_UseDatabaseContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *USE(); + antlr4::tree::TerminalNode *SP(); + OC_SchemaNameContext *oC_SchemaName(); + + + }; + + KU_UseDatabaseContext* kU_UseDatabase(); + + class KU_StandaloneCallContext : public antlr4::ParserRuleContext { + public: + KU_StandaloneCallContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *CALL(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + OC_SymbolicNameContext *oC_SymbolicName(); + OC_ExpressionContext *oC_Expression(); + OC_FunctionInvocationContext *oC_FunctionInvocation(); + + + }; + + KU_StandaloneCallContext* kU_StandaloneCall(); + + class KU_CommentOnContext : public antlr4::ParserRuleContext { + public: + KU_CommentOnContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *COMMENT(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + antlr4::tree::TerminalNode *ON(); + antlr4::tree::TerminalNode *TABLE(); + OC_SchemaNameContext *oC_SchemaName(); + antlr4::tree::TerminalNode *IS(); + antlr4::tree::TerminalNode *StringLiteral(); + + + }; + + KU_CommentOnContext* kU_CommentOn(); + + class KU_CreateMacroContext : public antlr4::ParserRuleContext { + public: + KU_CreateMacroContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *CREATE(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + antlr4::tree::TerminalNode *MACRO(); + OC_FunctionNameContext *oC_FunctionName(); + antlr4::tree::TerminalNode *AS(); + OC_ExpressionContext *oC_Expression(); + KU_PositionalArgsContext *kU_PositionalArgs(); + std::vector kU_DefaultArg(); + KU_DefaultArgContext* kU_DefaultArg(size_t i); + + + }; + + KU_CreateMacroContext* kU_CreateMacro(); + + class KU_PositionalArgsContext : public antlr4::ParserRuleContext { + public: + KU_PositionalArgsContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + std::vector oC_SymbolicName(); + OC_SymbolicNameContext* oC_SymbolicName(size_t i); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + + + }; + + KU_PositionalArgsContext* kU_PositionalArgs(); + + class KU_DefaultArgContext : public antlr4::ParserRuleContext { + public: + KU_DefaultArgContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_SymbolicNameContext *oC_SymbolicName(); + antlr4::tree::TerminalNode *COLON(); + OC_LiteralContext 
*oC_Literal(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + + + }; + + KU_DefaultArgContext* kU_DefaultArg(); + + class KU_FilePathsContext : public antlr4::ParserRuleContext { + public: + KU_FilePathsContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + std::vector StringLiteral(); + antlr4::tree::TerminalNode* StringLiteral(size_t i); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + antlr4::tree::TerminalNode *GLOB(); + + + }; + + KU_FilePathsContext* kU_FilePaths(); + + class KU_IfNotExistsContext : public antlr4::ParserRuleContext { + public: + KU_IfNotExistsContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *IF(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + antlr4::tree::TerminalNode *NOT(); + antlr4::tree::TerminalNode *EXISTS(); + + + }; + + KU_IfNotExistsContext* kU_IfNotExists(); + + class KU_CreateNodeTableContext : public antlr4::ParserRuleContext { + public: + KU_CreateNodeTableContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *CREATE(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + antlr4::tree::TerminalNode *NODE(); + antlr4::tree::TerminalNode *TABLE(); + OC_SchemaNameContext *oC_SchemaName(); + KU_PropertyDefinitionsContext *kU_PropertyDefinitions(); + antlr4::tree::TerminalNode *AS(); + OC_QueryContext *oC_Query(); + KU_IfNotExistsContext *kU_IfNotExists(); + KU_CreateNodeConstraintContext *kU_CreateNodeConstraint(); + + + }; + + KU_CreateNodeTableContext* kU_CreateNodeTable(); + + class KU_CreateRelTableContext : public antlr4::ParserRuleContext { + public: + KU_CreateRelTableContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *CREATE(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + antlr4::tree::TerminalNode *REL(); + antlr4::tree::TerminalNode *TABLE(); + OC_SchemaNameContext *oC_SchemaName(); + KU_FromToConnectionsContext *kU_FromToConnections(); + antlr4::tree::TerminalNode *AS(); + OC_QueryContext *oC_Query(); + antlr4::tree::TerminalNode *GROUP(); + KU_IfNotExistsContext *kU_IfNotExists(); + antlr4::tree::TerminalNode *WITH(); + KU_OptionsContext *kU_Options(); + KU_PropertyDefinitionsContext *kU_PropertyDefinitions(); + OC_SymbolicNameContext *oC_SymbolicName(); + + + }; + + KU_CreateRelTableContext* kU_CreateRelTable(); + + class KU_FromToConnectionsContext : public antlr4::ParserRuleContext { + public: + KU_FromToConnectionsContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + std::vector kU_FromToConnection(); + KU_FromToConnectionContext* kU_FromToConnection(size_t i); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + + + }; + + KU_FromToConnectionsContext* kU_FromToConnections(); + + class KU_FromToConnectionContext : public antlr4::ParserRuleContext { + public: + KU_FromToConnectionContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *FROM(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + std::vector oC_SchemaName(); + OC_SchemaNameContext* oC_SchemaName(size_t i); + antlr4::tree::TerminalNode *TO(); + + + }; + + KU_FromToConnectionContext* 
kU_FromToConnection(); + + class KU_CreateSequenceContext : public antlr4::ParserRuleContext { + public: + KU_CreateSequenceContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *CREATE(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + antlr4::tree::TerminalNode *SEQUENCE(); + OC_SchemaNameContext *oC_SchemaName(); + KU_IfNotExistsContext *kU_IfNotExists(); + std::vector kU_SequenceOptions(); + KU_SequenceOptionsContext* kU_SequenceOptions(size_t i); + + + }; + + KU_CreateSequenceContext* kU_CreateSequence(); + + class KU_CreateTypeContext : public antlr4::ParserRuleContext { + public: + KU_CreateTypeContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *CREATE(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + antlr4::tree::TerminalNode *TYPE(); + OC_SchemaNameContext *oC_SchemaName(); + antlr4::tree::TerminalNode *AS(); + KU_DataTypeContext *kU_DataType(); + + + }; + + KU_CreateTypeContext* kU_CreateType(); + + class KU_SequenceOptionsContext : public antlr4::ParserRuleContext { + public: + KU_SequenceOptionsContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + KU_IncrementByContext *kU_IncrementBy(); + KU_MinValueContext *kU_MinValue(); + KU_MaxValueContext *kU_MaxValue(); + KU_StartWithContext *kU_StartWith(); + KU_CycleContext *kU_Cycle(); + + + }; + + KU_SequenceOptionsContext* kU_SequenceOptions(); + + class KU_WithPasswdContext : public antlr4::ParserRuleContext { + public: + KU_WithPasswdContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + antlr4::tree::TerminalNode *WITH(); + antlr4::tree::TerminalNode *PASSWORD(); + antlr4::tree::TerminalNode *StringLiteral(); + + + }; + + KU_WithPasswdContext* kU_WithPasswd(); + + class KU_CreateUserContext : public antlr4::ParserRuleContext { + public: + KU_CreateUserContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *CREATE(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + antlr4::tree::TerminalNode *USER(); + OC_VariableContext *oC_Variable(); + KU_IfNotExistsContext *kU_IfNotExists(); + KU_WithPasswdContext *kU_WithPasswd(); + + + }; + + KU_CreateUserContext* kU_CreateUser(); + + class KU_CreateRoleContext : public antlr4::ParserRuleContext { + public: + KU_CreateRoleContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *CREATE(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + antlr4::tree::TerminalNode *ROLE(); + OC_VariableContext *oC_Variable(); + KU_IfNotExistsContext *kU_IfNotExists(); + + + }; + + KU_CreateRoleContext* kU_CreateRole(); + + class KU_IncrementByContext : public antlr4::ParserRuleContext { + public: + KU_IncrementByContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *INCREMENT(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + OC_IntegerLiteralContext *oC_IntegerLiteral(); + antlr4::tree::TerminalNode *BY(); + antlr4::tree::TerminalNode *MINUS(); + + + }; + + KU_IncrementByContext* kU_IncrementBy(); + + 
class KU_MinValueContext : public antlr4::ParserRuleContext { + public: + KU_MinValueContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *NO(); + antlr4::tree::TerminalNode *SP(); + antlr4::tree::TerminalNode *MINVALUE(); + OC_IntegerLiteralContext *oC_IntegerLiteral(); + antlr4::tree::TerminalNode *MINUS(); + + + }; + + KU_MinValueContext* kU_MinValue(); + + class KU_MaxValueContext : public antlr4::ParserRuleContext { + public: + KU_MaxValueContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *NO(); + antlr4::tree::TerminalNode *SP(); + antlr4::tree::TerminalNode *MAXVALUE(); + OC_IntegerLiteralContext *oC_IntegerLiteral(); + antlr4::tree::TerminalNode *MINUS(); + + + }; + + KU_MaxValueContext* kU_MaxValue(); + + class KU_StartWithContext : public antlr4::ParserRuleContext { + public: + KU_StartWithContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *START(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + OC_IntegerLiteralContext *oC_IntegerLiteral(); + antlr4::tree::TerminalNode *WITH(); + antlr4::tree::TerminalNode *MINUS(); + + + }; + + KU_StartWithContext* kU_StartWith(); + + class KU_CycleContext : public antlr4::ParserRuleContext { + public: + KU_CycleContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *CYCLE(); + antlr4::tree::TerminalNode *NO(); + antlr4::tree::TerminalNode *SP(); + + + }; + + KU_CycleContext* kU_Cycle(); + + class KU_IfExistsContext : public antlr4::ParserRuleContext { + public: + KU_IfExistsContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *IF(); + antlr4::tree::TerminalNode *SP(); + antlr4::tree::TerminalNode *EXISTS(); + + + }; + + KU_IfExistsContext* kU_IfExists(); + + class KU_DropContext : public antlr4::ParserRuleContext { + public: + KU_DropContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *DROP(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + OC_SchemaNameContext *oC_SchemaName(); + antlr4::tree::TerminalNode *TABLE(); + antlr4::tree::TerminalNode *SEQUENCE(); + antlr4::tree::TerminalNode *MACRO(); + KU_IfExistsContext *kU_IfExists(); + + + }; + + KU_DropContext* kU_Drop(); + + class KU_AlterTableContext : public antlr4::ParserRuleContext { + public: + KU_AlterTableContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *ALTER(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + antlr4::tree::TerminalNode *TABLE(); + OC_SchemaNameContext *oC_SchemaName(); + KU_AlterOptionsContext *kU_AlterOptions(); + + + }; + + KU_AlterTableContext* kU_AlterTable(); + + class KU_AlterOptionsContext : public antlr4::ParserRuleContext { + public: + KU_AlterOptionsContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + KU_AddPropertyContext *kU_AddProperty(); + KU_DropPropertyContext *kU_DropProperty(); + KU_RenameTableContext *kU_RenameTable(); + KU_RenamePropertyContext *kU_RenameProperty(); + KU_AddFromToConnectionContext 
*kU_AddFromToConnection(); + KU_DropFromToConnectionContext *kU_DropFromToConnection(); + + + }; + + KU_AlterOptionsContext* kU_AlterOptions(); + + class KU_AddPropertyContext : public antlr4::ParserRuleContext { + public: + KU_AddPropertyContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *ADD(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + OC_PropertyKeyNameContext *oC_PropertyKeyName(); + KU_DataTypeContext *kU_DataType(); + KU_IfNotExistsContext *kU_IfNotExists(); + KU_DefaultContext *kU_Default(); + + + }; + + KU_AddPropertyContext* kU_AddProperty(); + + class KU_DefaultContext : public antlr4::ParserRuleContext { + public: + KU_DefaultContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *DEFAULT(); + antlr4::tree::TerminalNode *SP(); + OC_ExpressionContext *oC_Expression(); + + + }; + + KU_DefaultContext* kU_Default(); + + class KU_DropPropertyContext : public antlr4::ParserRuleContext { + public: + KU_DropPropertyContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *DROP(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + OC_PropertyKeyNameContext *oC_PropertyKeyName(); + KU_IfExistsContext *kU_IfExists(); + + + }; + + KU_DropPropertyContext* kU_DropProperty(); + + class KU_RenameTableContext : public antlr4::ParserRuleContext { + public: + KU_RenameTableContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *RENAME(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + antlr4::tree::TerminalNode *TO(); + OC_SchemaNameContext *oC_SchemaName(); + + + }; + + KU_RenameTableContext* kU_RenameTable(); + + class KU_RenamePropertyContext : public antlr4::ParserRuleContext { + public: + KU_RenamePropertyContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *RENAME(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + std::vector oC_PropertyKeyName(); + OC_PropertyKeyNameContext* oC_PropertyKeyName(size_t i); + antlr4::tree::TerminalNode *TO(); + + + }; + + KU_RenamePropertyContext* kU_RenameProperty(); + + class KU_AddFromToConnectionContext : public antlr4::ParserRuleContext { + public: + KU_AddFromToConnectionContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *ADD(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + KU_FromToConnectionContext *kU_FromToConnection(); + KU_IfNotExistsContext *kU_IfNotExists(); + + + }; + + KU_AddFromToConnectionContext* kU_AddFromToConnection(); + + class KU_DropFromToConnectionContext : public antlr4::ParserRuleContext { + public: + KU_DropFromToConnectionContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *DROP(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + KU_FromToConnectionContext *kU_FromToConnection(); + KU_IfExistsContext *kU_IfExists(); + + + }; + + KU_DropFromToConnectionContext* kU_DropFromToConnection(); + + class KU_ColumnDefinitionsContext : public antlr4::ParserRuleContext { + public: + 
KU_ColumnDefinitionsContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + std::vector kU_ColumnDefinition(); + KU_ColumnDefinitionContext* kU_ColumnDefinition(size_t i); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + + + }; + + KU_ColumnDefinitionsContext* kU_ColumnDefinitions(); + + class KU_ColumnDefinitionContext : public antlr4::ParserRuleContext { + public: + KU_ColumnDefinitionContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_PropertyKeyNameContext *oC_PropertyKeyName(); + antlr4::tree::TerminalNode *SP(); + KU_DataTypeContext *kU_DataType(); + + + }; + + KU_ColumnDefinitionContext* kU_ColumnDefinition(); + + class KU_PropertyDefinitionsContext : public antlr4::ParserRuleContext { + public: + KU_PropertyDefinitionsContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + std::vector kU_PropertyDefinition(); + KU_PropertyDefinitionContext* kU_PropertyDefinition(size_t i); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + + + }; + + KU_PropertyDefinitionsContext* kU_PropertyDefinitions(); + + class KU_PropertyDefinitionContext : public antlr4::ParserRuleContext { + public: + KU_PropertyDefinitionContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + KU_ColumnDefinitionContext *kU_ColumnDefinition(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + KU_DefaultContext *kU_Default(); + antlr4::tree::TerminalNode *PRIMARY(); + antlr4::tree::TerminalNode *KEY(); + + + }; + + KU_PropertyDefinitionContext* kU_PropertyDefinition(); + + class KU_CreateNodeConstraintContext : public antlr4::ParserRuleContext { + public: + KU_CreateNodeConstraintContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *PRIMARY(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + antlr4::tree::TerminalNode *KEY(); + OC_PropertyKeyNameContext *oC_PropertyKeyName(); + + + }; + + KU_CreateNodeConstraintContext* kU_CreateNodeConstraint(); + + class KU_UnionTypeContext : public antlr4::ParserRuleContext { + public: + KU_UnionTypeContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *UNION(); + KU_ColumnDefinitionsContext *kU_ColumnDefinitions(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + + + }; + + KU_UnionTypeContext* kU_UnionType(); + + class KU_StructTypeContext : public antlr4::ParserRuleContext { + public: + KU_StructTypeContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *STRUCT(); + KU_ColumnDefinitionsContext *kU_ColumnDefinitions(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + + + }; + + KU_StructTypeContext* kU_StructType(); + + class KU_MapTypeContext : public antlr4::ParserRuleContext { + public: + KU_MapTypeContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *MAP(); + std::vector kU_DataType(); + KU_DataTypeContext* kU_DataType(size_t i); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + + + }; + + KU_MapTypeContext* kU_MapType(); + + class KU_DecimalTypeContext : public 
antlr4::ParserRuleContext { + public: + KU_DecimalTypeContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *DECIMAL(); + std::vector oC_IntegerLiteral(); + OC_IntegerLiteralContext* oC_IntegerLiteral(size_t i); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + + + }; + + KU_DecimalTypeContext* kU_DecimalType(); + + class KU_DataTypeContext : public antlr4::ParserRuleContext { + public: + KU_DataTypeContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_SymbolicNameContext *oC_SymbolicName(); + KU_UnionTypeContext *kU_UnionType(); + KU_StructTypeContext *kU_StructType(); + KU_MapTypeContext *kU_MapType(); + KU_DecimalTypeContext *kU_DecimalType(); + KU_DataTypeContext *kU_DataType(); + KU_ListIdentifiersContext *kU_ListIdentifiers(); + + + }; + + KU_DataTypeContext* kU_DataType(); + KU_DataTypeContext* kU_DataType(int precedence); + class KU_ListIdentifiersContext : public antlr4::ParserRuleContext { + public: + KU_ListIdentifiersContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + std::vector kU_ListIdentifier(); + KU_ListIdentifierContext* kU_ListIdentifier(size_t i); + + + }; + + KU_ListIdentifiersContext* kU_ListIdentifiers(); + + class KU_ListIdentifierContext : public antlr4::ParserRuleContext { + public: + KU_ListIdentifierContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_IntegerLiteralContext *oC_IntegerLiteral(); + + + }; + + KU_ListIdentifierContext* kU_ListIdentifier(); + + class OC_AnyCypherOptionContext : public antlr4::ParserRuleContext { + public: + OC_AnyCypherOptionContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_ExplainContext *oC_Explain(); + OC_ProfileContext *oC_Profile(); + + + }; + + OC_AnyCypherOptionContext* oC_AnyCypherOption(); + + class OC_ExplainContext : public antlr4::ParserRuleContext { + public: + OC_ExplainContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *EXPLAIN(); + antlr4::tree::TerminalNode *SP(); + antlr4::tree::TerminalNode *LOGICAL(); + + + }; + + OC_ExplainContext* oC_Explain(); + + class OC_ProfileContext : public antlr4::ParserRuleContext { + public: + OC_ProfileContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *PROFILE(); + + + }; + + OC_ProfileContext* oC_Profile(); + + class KU_TransactionContext : public antlr4::ParserRuleContext { + public: + KU_TransactionContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *BEGIN(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + antlr4::tree::TerminalNode *TRANSACTION(); + antlr4::tree::TerminalNode *READ(); + antlr4::tree::TerminalNode *ONLY(); + antlr4::tree::TerminalNode *COMMIT(); + antlr4::tree::TerminalNode *ROLLBACK(); + antlr4::tree::TerminalNode *CHECKPOINT(); + + + }; + + KU_TransactionContext* kU_Transaction(); + + class KU_ExtensionContext : public antlr4::ParserRuleContext { + public: + KU_ExtensionContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + KU_LoadExtensionContext 
*kU_LoadExtension(); + KU_InstallExtensionContext *kU_InstallExtension(); + KU_UninstallExtensionContext *kU_UninstallExtension(); + KU_UpdateExtensionContext *kU_UpdateExtension(); + + + }; + + KU_ExtensionContext* kU_Extension(); + + class KU_LoadExtensionContext : public antlr4::ParserRuleContext { + public: + KU_LoadExtensionContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *LOAD(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + antlr4::tree::TerminalNode *StringLiteral(); + OC_VariableContext *oC_Variable(); + antlr4::tree::TerminalNode *EXTENSION(); + + + }; + + KU_LoadExtensionContext* kU_LoadExtension(); + + class KU_InstallExtensionContext : public antlr4::ParserRuleContext { + public: + KU_InstallExtensionContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *INSTALL(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + OC_VariableContext *oC_Variable(); + antlr4::tree::TerminalNode *FORCE(); + antlr4::tree::TerminalNode *FROM(); + antlr4::tree::TerminalNode *StringLiteral(); + + + }; + + KU_InstallExtensionContext* kU_InstallExtension(); + + class KU_UninstallExtensionContext : public antlr4::ParserRuleContext { + public: + KU_UninstallExtensionContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *UNINSTALL(); + antlr4::tree::TerminalNode *SP(); + OC_VariableContext *oC_Variable(); + + + }; + + KU_UninstallExtensionContext* kU_UninstallExtension(); + + class KU_UpdateExtensionContext : public antlr4::ParserRuleContext { + public: + KU_UpdateExtensionContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *UPDATE(); + antlr4::tree::TerminalNode *SP(); + OC_VariableContext *oC_Variable(); + + + }; + + KU_UpdateExtensionContext* kU_UpdateExtension(); + + class OC_QueryContext : public antlr4::ParserRuleContext { + public: + OC_QueryContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_RegularQueryContext *oC_RegularQuery(); + + + }; + + OC_QueryContext* oC_Query(); + + class OC_RegularQueryContext : public antlr4::ParserRuleContext { + public: + OC_RegularQueryContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_SingleQueryContext *oC_SingleQuery(); + std::vector oC_Union(); + OC_UnionContext* oC_Union(size_t i); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + std::vector oC_Return(); + OC_ReturnContext* oC_Return(size_t i); + + + }; + + OC_RegularQueryContext* oC_RegularQuery(); + + class OC_UnionContext : public antlr4::ParserRuleContext { + public: + OC_UnionContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *UNION(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + antlr4::tree::TerminalNode *ALL(); + OC_SingleQueryContext *oC_SingleQuery(); + + + }; + + OC_UnionContext* oC_Union(); + + class OC_SingleQueryContext : public antlr4::ParserRuleContext { + public: + OC_SingleQueryContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_SinglePartQueryContext *oC_SinglePartQuery(); + 
OC_MultiPartQueryContext *oC_MultiPartQuery(); + + + }; + + OC_SingleQueryContext* oC_SingleQuery(); + + class OC_SinglePartQueryContext : public antlr4::ParserRuleContext { + public: + OC_SinglePartQueryContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_ReturnContext *oC_Return(); + std::vector oC_ReadingClause(); + OC_ReadingClauseContext* oC_ReadingClause(size_t i); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + std::vector oC_UpdatingClause(); + OC_UpdatingClauseContext* oC_UpdatingClause(size_t i); + + + }; + + OC_SinglePartQueryContext* oC_SinglePartQuery(); + + class OC_MultiPartQueryContext : public antlr4::ParserRuleContext { + public: + OC_MultiPartQueryContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_SinglePartQueryContext *oC_SinglePartQuery(); + std::vector kU_QueryPart(); + KU_QueryPartContext* kU_QueryPart(size_t i); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + + + }; + + OC_MultiPartQueryContext* oC_MultiPartQuery(); + + class KU_QueryPartContext : public antlr4::ParserRuleContext { + public: + KU_QueryPartContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_WithContext *oC_With(); + std::vector oC_ReadingClause(); + OC_ReadingClauseContext* oC_ReadingClause(size_t i); + std::vector oC_UpdatingClause(); + OC_UpdatingClauseContext* oC_UpdatingClause(size_t i); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + + + }; + + KU_QueryPartContext* kU_QueryPart(); + + class OC_UpdatingClauseContext : public antlr4::ParserRuleContext { + public: + OC_UpdatingClauseContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_CreateContext *oC_Create(); + OC_MergeContext *oC_Merge(); + OC_SetContext *oC_Set(); + OC_DeleteContext *oC_Delete(); + + + }; + + OC_UpdatingClauseContext* oC_UpdatingClause(); + + class OC_ReadingClauseContext : public antlr4::ParserRuleContext { + public: + OC_ReadingClauseContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_MatchContext *oC_Match(); + OC_UnwindContext *oC_Unwind(); + KU_InQueryCallContext *kU_InQueryCall(); + KU_LoadFromContext *kU_LoadFrom(); + + + }; + + OC_ReadingClauseContext* oC_ReadingClause(); + + class KU_LoadFromContext : public antlr4::ParserRuleContext { + public: + KU_LoadFromContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *LOAD(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + antlr4::tree::TerminalNode *FROM(); + KU_ScanSourceContext *kU_ScanSource(); + antlr4::tree::TerminalNode *WITH(); + antlr4::tree::TerminalNode *HEADERS(); + KU_ColumnDefinitionsContext *kU_ColumnDefinitions(); + KU_OptionsContext *kU_Options(); + OC_WhereContext *oC_Where(); + + + }; + + KU_LoadFromContext* kU_LoadFrom(); + + class OC_YieldItemContext : public antlr4::ParserRuleContext { + public: + OC_YieldItemContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + std::vector oC_Variable(); + OC_VariableContext* oC_Variable(size_t i); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + antlr4::tree::TerminalNode *AS(); + + + }; + + OC_YieldItemContext* oC_YieldItem(); + + class 
OC_YieldItemsContext : public antlr4::ParserRuleContext { + public: + OC_YieldItemsContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + std::vector oC_YieldItem(); + OC_YieldItemContext* oC_YieldItem(size_t i); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + + + }; + + OC_YieldItemsContext* oC_YieldItems(); + + class KU_InQueryCallContext : public antlr4::ParserRuleContext { + public: + KU_InQueryCallContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *CALL(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + OC_FunctionInvocationContext *oC_FunctionInvocation(); + OC_WhereContext *oC_Where(); + antlr4::tree::TerminalNode *YIELD(); + OC_YieldItemsContext *oC_YieldItems(); + + + }; + + KU_InQueryCallContext* kU_InQueryCall(); + + class OC_MatchContext : public antlr4::ParserRuleContext { + public: + OC_MatchContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *MATCH(); + OC_PatternContext *oC_Pattern(); + antlr4::tree::TerminalNode *OPTIONAL(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + OC_WhereContext *oC_Where(); + KU_HintContext *kU_Hint(); + + + }; + + OC_MatchContext* oC_Match(); + + class KU_HintContext : public antlr4::ParserRuleContext { + public: + KU_HintContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *HINT(); + antlr4::tree::TerminalNode *SP(); + KU_JoinNodeContext *kU_JoinNode(); + + + }; + + KU_HintContext* kU_Hint(); + + class KU_JoinNodeContext : public antlr4::ParserRuleContext { + public: + KU_JoinNodeContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + std::vector kU_JoinNode(); + KU_JoinNodeContext* kU_JoinNode(size_t i); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + std::vector oC_SchemaName(); + OC_SchemaNameContext* oC_SchemaName(size_t i); + antlr4::tree::TerminalNode *JOIN(); + std::vector MULTI_JOIN(); + antlr4::tree::TerminalNode* MULTI_JOIN(size_t i); + + + }; + + KU_JoinNodeContext* kU_JoinNode(); + KU_JoinNodeContext* kU_JoinNode(int precedence); + class OC_UnwindContext : public antlr4::ParserRuleContext { + public: + OC_UnwindContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *UNWIND(); + OC_ExpressionContext *oC_Expression(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + antlr4::tree::TerminalNode *AS(); + OC_VariableContext *oC_Variable(); + + + }; + + OC_UnwindContext* oC_Unwind(); + + class OC_CreateContext : public antlr4::ParserRuleContext { + public: + OC_CreateContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *CREATE(); + OC_PatternContext *oC_Pattern(); + antlr4::tree::TerminalNode *SP(); + + + }; + + OC_CreateContext* oC_Create(); + + class OC_MergeContext : public antlr4::ParserRuleContext { + public: + OC_MergeContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *MERGE(); + OC_PatternContext *oC_Pattern(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + std::vector 
oC_MergeAction(); + OC_MergeActionContext* oC_MergeAction(size_t i); + + + }; + + OC_MergeContext* oC_Merge(); + + class OC_MergeActionContext : public antlr4::ParserRuleContext { + public: + OC_MergeActionContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *ON(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + antlr4::tree::TerminalNode *MATCH(); + OC_SetContext *oC_Set(); + antlr4::tree::TerminalNode *CREATE(); + + + }; + + OC_MergeActionContext* oC_MergeAction(); + + class OC_SetContext : public antlr4::ParserRuleContext { + public: + OC_SetContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *SET(); + std::vector oC_SetItem(); + OC_SetItemContext* oC_SetItem(size_t i); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + OC_AtomContext *oC_Atom(); + KU_PropertiesContext *kU_Properties(); + + + }; + + OC_SetContext* oC_Set(); + + class OC_SetItemContext : public antlr4::ParserRuleContext { + public: + OC_SetItemContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_PropertyExpressionContext *oC_PropertyExpression(); + OC_ExpressionContext *oC_Expression(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + + + }; + + OC_SetItemContext* oC_SetItem(); + + class OC_DeleteContext : public antlr4::ParserRuleContext { + public: + OC_DeleteContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *DELETE(); + std::vector oC_Expression(); + OC_ExpressionContext* oC_Expression(size_t i); + antlr4::tree::TerminalNode *DETACH(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + + + }; + + OC_DeleteContext* oC_Delete(); + + class OC_WithContext : public antlr4::ParserRuleContext { + public: + OC_WithContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *WITH(); + OC_ProjectionBodyContext *oC_ProjectionBody(); + OC_WhereContext *oC_Where(); + antlr4::tree::TerminalNode *SP(); + + + }; + + OC_WithContext* oC_With(); + + class OC_ReturnContext : public antlr4::ParserRuleContext { + public: + OC_ReturnContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *RETURN(); + OC_ProjectionBodyContext *oC_ProjectionBody(); + + + }; + + OC_ReturnContext* oC_Return(); + + class OC_ProjectionBodyContext : public antlr4::ParserRuleContext { + public: + OC_ProjectionBodyContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + OC_ProjectionItemsContext *oC_ProjectionItems(); + antlr4::tree::TerminalNode *DISTINCT(); + OC_OrderContext *oC_Order(); + OC_SkipContext *oC_Skip(); + OC_LimitContext *oC_Limit(); + + + }; + + OC_ProjectionBodyContext* oC_ProjectionBody(); + + class OC_ProjectionItemsContext : public antlr4::ParserRuleContext { + public: + OC_ProjectionItemsContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *STAR(); + std::vector oC_ProjectionItem(); + OC_ProjectionItemContext* oC_ProjectionItem(size_t i); + std::vector SP(); + 
antlr4::tree::TerminalNode* SP(size_t i); + + + }; + + OC_ProjectionItemsContext* oC_ProjectionItems(); + + class OC_ProjectionItemContext : public antlr4::ParserRuleContext { + public: + OC_ProjectionItemContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_ExpressionContext *oC_Expression(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + antlr4::tree::TerminalNode *AS(); + OC_VariableContext *oC_Variable(); + + + }; + + OC_ProjectionItemContext* oC_ProjectionItem(); + + class OC_OrderContext : public antlr4::ParserRuleContext { + public: + OC_OrderContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *ORDER(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + antlr4::tree::TerminalNode *BY(); + std::vector oC_SortItem(); + OC_SortItemContext* oC_SortItem(size_t i); + + + }; + + OC_OrderContext* oC_Order(); + + class OC_SkipContext : public antlr4::ParserRuleContext { + public: + OC_SkipContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *L_SKIP(); + antlr4::tree::TerminalNode *SP(); + OC_ExpressionContext *oC_Expression(); + + + }; + + OC_SkipContext* oC_Skip(); + + class OC_LimitContext : public antlr4::ParserRuleContext { + public: + OC_LimitContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *LIMIT(); + antlr4::tree::TerminalNode *SP(); + OC_ExpressionContext *oC_Expression(); + + + }; + + OC_LimitContext* oC_Limit(); + + class OC_SortItemContext : public antlr4::ParserRuleContext { + public: + OC_SortItemContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_ExpressionContext *oC_Expression(); + antlr4::tree::TerminalNode *ASCENDING(); + antlr4::tree::TerminalNode *ASC(); + antlr4::tree::TerminalNode *DESCENDING(); + antlr4::tree::TerminalNode *DESC(); + antlr4::tree::TerminalNode *SP(); + + + }; + + OC_SortItemContext* oC_SortItem(); + + class OC_WhereContext : public antlr4::ParserRuleContext { + public: + OC_WhereContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *WHERE(); + antlr4::tree::TerminalNode *SP(); + OC_ExpressionContext *oC_Expression(); + + + }; + + OC_WhereContext* oC_Where(); + + class OC_PatternContext : public antlr4::ParserRuleContext { + public: + OC_PatternContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + std::vector oC_PatternPart(); + OC_PatternPartContext* oC_PatternPart(size_t i); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + + + }; + + OC_PatternContext* oC_Pattern(); + + class OC_PatternPartContext : public antlr4::ParserRuleContext { + public: + OC_PatternPartContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_VariableContext *oC_Variable(); + OC_AnonymousPatternPartContext *oC_AnonymousPatternPart(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + + + }; + + OC_PatternPartContext* oC_PatternPart(); + + class OC_AnonymousPatternPartContext : public antlr4::ParserRuleContext { + public: + OC_AnonymousPatternPartContext(antlr4::ParserRuleContext *parent, size_t 
invokingState); + virtual size_t getRuleIndex() const override; + OC_PatternElementContext *oC_PatternElement(); + + + }; + + OC_AnonymousPatternPartContext* oC_AnonymousPatternPart(); + + class OC_PatternElementContext : public antlr4::ParserRuleContext { + public: + OC_PatternElementContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_NodePatternContext *oC_NodePattern(); + std::vector oC_PatternElementChain(); + OC_PatternElementChainContext* oC_PatternElementChain(size_t i); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + OC_PatternElementContext *oC_PatternElement(); + + + }; + + OC_PatternElementContext* oC_PatternElement(); + + class OC_NodePatternContext : public antlr4::ParserRuleContext { + public: + OC_NodePatternContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + OC_VariableContext *oC_Variable(); + OC_NodeLabelsContext *oC_NodeLabels(); + KU_PropertiesContext *kU_Properties(); + + + }; + + OC_NodePatternContext* oC_NodePattern(); + + class OC_PatternElementChainContext : public antlr4::ParserRuleContext { + public: + OC_PatternElementChainContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_RelationshipPatternContext *oC_RelationshipPattern(); + OC_NodePatternContext *oC_NodePattern(); + antlr4::tree::TerminalNode *SP(); + + + }; + + OC_PatternElementChainContext* oC_PatternElementChain(); + + class OC_RelationshipPatternContext : public antlr4::ParserRuleContext { + public: + OC_RelationshipPatternContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_LeftArrowHeadContext *oC_LeftArrowHead(); + std::vector oC_Dash(); + OC_DashContext* oC_Dash(size_t i); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + OC_RelationshipDetailContext *oC_RelationshipDetail(); + OC_RightArrowHeadContext *oC_RightArrowHead(); + + + }; + + OC_RelationshipPatternContext* oC_RelationshipPattern(); + + class OC_RelationshipDetailContext : public antlr4::ParserRuleContext { + public: + OC_RelationshipDetailContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + OC_VariableContext *oC_Variable(); + OC_RelationshipTypesContext *oC_RelationshipTypes(); + KU_RecursiveDetailContext *kU_RecursiveDetail(); + KU_PropertiesContext *kU_Properties(); + + + }; + + OC_RelationshipDetailContext* oC_RelationshipDetail(); + + class KU_PropertiesContext : public antlr4::ParserRuleContext { + public: + KU_PropertiesContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + std::vector oC_PropertyKeyName(); + OC_PropertyKeyNameContext* oC_PropertyKeyName(size_t i); + std::vector COLON(); + antlr4::tree::TerminalNode* COLON(size_t i); + std::vector oC_Expression(); + OC_ExpressionContext* oC_Expression(size_t i); + + + }; + + KU_PropertiesContext* kU_Properties(); + + class OC_RelationshipTypesContext : public antlr4::ParserRuleContext { + public: + OC_RelationshipTypesContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + std::vector COLON(); + 
antlr4::tree::TerminalNode* COLON(size_t i); + std::vector oC_RelTypeName(); + OC_RelTypeNameContext* oC_RelTypeName(size_t i); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + + + }; + + OC_RelationshipTypesContext* oC_RelationshipTypes(); + + class OC_NodeLabelsContext : public antlr4::ParserRuleContext { + public: + OC_NodeLabelsContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + std::vector COLON(); + antlr4::tree::TerminalNode* COLON(size_t i); + std::vector oC_LabelName(); + OC_LabelNameContext* oC_LabelName(size_t i); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + + + }; + + OC_NodeLabelsContext* oC_NodeLabels(); + + class KU_RecursiveDetailContext : public antlr4::ParserRuleContext { + public: + KU_RecursiveDetailContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *STAR(); + KU_RecursiveTypeContext *kU_RecursiveType(); + OC_RangeLiteralContext *oC_RangeLiteral(); + KU_RecursiveComprehensionContext *kU_RecursiveComprehension(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + + + }; + + KU_RecursiveDetailContext* kU_RecursiveDetail(); + + class KU_RecursiveTypeContext : public antlr4::ParserRuleContext { + public: + KU_RecursiveTypeContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *WSHORTEST(); + OC_PropertyKeyNameContext *oC_PropertyKeyName(); + antlr4::tree::TerminalNode *ALL(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + antlr4::tree::TerminalNode *SHORTEST(); + antlr4::tree::TerminalNode *TRAIL(); + antlr4::tree::TerminalNode *ACYCLIC(); + + + }; + + KU_RecursiveTypeContext* kU_RecursiveType(); + + class OC_RangeLiteralContext : public antlr4::ParserRuleContext { + public: + OC_RangeLiteralContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *DOTDOT(); + OC_LowerBoundContext *oC_LowerBound(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + OC_UpperBoundContext *oC_UpperBound(); + OC_IntegerLiteralContext *oC_IntegerLiteral(); + + + }; + + OC_RangeLiteralContext* oC_RangeLiteral(); + + class KU_RecursiveComprehensionContext : public antlr4::ParserRuleContext { + public: + KU_RecursiveComprehensionContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + std::vector oC_Variable(); + OC_VariableContext* oC_Variable(size_t i); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + OC_WhereContext *oC_Where(); + std::vector kU_RecursiveProjectionItems(); + KU_RecursiveProjectionItemsContext* kU_RecursiveProjectionItems(size_t i); + + + }; + + KU_RecursiveComprehensionContext* kU_RecursiveComprehension(); + + class KU_RecursiveProjectionItemsContext : public antlr4::ParserRuleContext { + public: + KU_RecursiveProjectionItemsContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + OC_ProjectionItemsContext *oC_ProjectionItems(); + + + }; + + KU_RecursiveProjectionItemsContext* kU_RecursiveProjectionItems(); + + class OC_LowerBoundContext : public antlr4::ParserRuleContext { + public: + OC_LowerBoundContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual 
size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *DecimalInteger(); + + + }; + + OC_LowerBoundContext* oC_LowerBound(); + + class OC_UpperBoundContext : public antlr4::ParserRuleContext { + public: + OC_UpperBoundContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *DecimalInteger(); + + + }; + + OC_UpperBoundContext* oC_UpperBound(); + + class OC_LabelNameContext : public antlr4::ParserRuleContext { + public: + OC_LabelNameContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_SchemaNameContext *oC_SchemaName(); + + + }; + + OC_LabelNameContext* oC_LabelName(); + + class OC_RelTypeNameContext : public antlr4::ParserRuleContext { + public: + OC_RelTypeNameContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_SchemaNameContext *oC_SchemaName(); + + + }; + + OC_RelTypeNameContext* oC_RelTypeName(); + + class OC_ExpressionContext : public antlr4::ParserRuleContext { + public: + OC_ExpressionContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_OrExpressionContext *oC_OrExpression(); + + + }; + + OC_ExpressionContext* oC_Expression(); + + class OC_OrExpressionContext : public antlr4::ParserRuleContext { + public: + OC_OrExpressionContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + std::vector oC_XorExpression(); + OC_XorExpressionContext* oC_XorExpression(size_t i); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + std::vector OR(); + antlr4::tree::TerminalNode* OR(size_t i); + + + }; + + OC_OrExpressionContext* oC_OrExpression(); + + class OC_XorExpressionContext : public antlr4::ParserRuleContext { + public: + OC_XorExpressionContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + std::vector oC_AndExpression(); + OC_AndExpressionContext* oC_AndExpression(size_t i); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + std::vector XOR(); + antlr4::tree::TerminalNode* XOR(size_t i); + + + }; + + OC_XorExpressionContext* oC_XorExpression(); + + class OC_AndExpressionContext : public antlr4::ParserRuleContext { + public: + OC_AndExpressionContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + std::vector oC_NotExpression(); + OC_NotExpressionContext* oC_NotExpression(size_t i); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + std::vector AND(); + antlr4::tree::TerminalNode* AND(size_t i); + + + }; + + OC_AndExpressionContext* oC_AndExpression(); + + class OC_NotExpressionContext : public antlr4::ParserRuleContext { + public: + OC_NotExpressionContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_ComparisonExpressionContext *oC_ComparisonExpression(); + std::vector NOT(); + antlr4::tree::TerminalNode* NOT(size_t i); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + + + }; + + OC_NotExpressionContext* oC_NotExpression(); + + class OC_ComparisonExpressionContext : public antlr4::ParserRuleContext { + public: + antlr4::Token *invalid_not_equalToken = nullptr; + OC_ComparisonExpressionContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + 
std::vector kU_BitwiseOrOperatorExpression(); + KU_BitwiseOrOperatorExpressionContext* kU_BitwiseOrOperatorExpression(size_t i); + std::vector kU_ComparisonOperator(); + KU_ComparisonOperatorContext* kU_ComparisonOperator(size_t i); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + antlr4::tree::TerminalNode *INVALID_NOT_EQUAL(); + + + }; + + OC_ComparisonExpressionContext* oC_ComparisonExpression(); + + class KU_ComparisonOperatorContext : public antlr4::ParserRuleContext { + public: + KU_ComparisonOperatorContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + + + }; + + KU_ComparisonOperatorContext* kU_ComparisonOperator(); + + class KU_BitwiseOrOperatorExpressionContext : public antlr4::ParserRuleContext { + public: + KU_BitwiseOrOperatorExpressionContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + std::vector kU_BitwiseAndOperatorExpression(); + KU_BitwiseAndOperatorExpressionContext* kU_BitwiseAndOperatorExpression(size_t i); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + + + }; + + KU_BitwiseOrOperatorExpressionContext* kU_BitwiseOrOperatorExpression(); + + class KU_BitwiseAndOperatorExpressionContext : public antlr4::ParserRuleContext { + public: + KU_BitwiseAndOperatorExpressionContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + std::vector kU_BitShiftOperatorExpression(); + KU_BitShiftOperatorExpressionContext* kU_BitShiftOperatorExpression(size_t i); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + + + }; + + KU_BitwiseAndOperatorExpressionContext* kU_BitwiseAndOperatorExpression(); + + class KU_BitShiftOperatorExpressionContext : public antlr4::ParserRuleContext { + public: + KU_BitShiftOperatorExpressionContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + std::vector oC_AddOrSubtractExpression(); + OC_AddOrSubtractExpressionContext* oC_AddOrSubtractExpression(size_t i); + std::vector kU_BitShiftOperator(); + KU_BitShiftOperatorContext* kU_BitShiftOperator(size_t i); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + + + }; + + KU_BitShiftOperatorExpressionContext* kU_BitShiftOperatorExpression(); + + class KU_BitShiftOperatorContext : public antlr4::ParserRuleContext { + public: + KU_BitShiftOperatorContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + + + }; + + KU_BitShiftOperatorContext* kU_BitShiftOperator(); + + class OC_AddOrSubtractExpressionContext : public antlr4::ParserRuleContext { + public: + OC_AddOrSubtractExpressionContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + std::vector oC_MultiplyDivideModuloExpression(); + OC_MultiplyDivideModuloExpressionContext* oC_MultiplyDivideModuloExpression(size_t i); + std::vector kU_AddOrSubtractOperator(); + KU_AddOrSubtractOperatorContext* kU_AddOrSubtractOperator(size_t i); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + + + }; + + OC_AddOrSubtractExpressionContext* oC_AddOrSubtractExpression(); + + class KU_AddOrSubtractOperatorContext : public antlr4::ParserRuleContext { + public: + KU_AddOrSubtractOperatorContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *MINUS(); + + + }; + + 
KU_AddOrSubtractOperatorContext* kU_AddOrSubtractOperator(); + + class OC_MultiplyDivideModuloExpressionContext : public antlr4::ParserRuleContext { + public: + OC_MultiplyDivideModuloExpressionContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + std::vector oC_PowerOfExpression(); + OC_PowerOfExpressionContext* oC_PowerOfExpression(size_t i); + std::vector kU_MultiplyDivideModuloOperator(); + KU_MultiplyDivideModuloOperatorContext* kU_MultiplyDivideModuloOperator(size_t i); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + + + }; + + OC_MultiplyDivideModuloExpressionContext* oC_MultiplyDivideModuloExpression(); + + class KU_MultiplyDivideModuloOperatorContext : public antlr4::ParserRuleContext { + public: + KU_MultiplyDivideModuloOperatorContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *STAR(); + + + }; + + KU_MultiplyDivideModuloOperatorContext* kU_MultiplyDivideModuloOperator(); + + class OC_PowerOfExpressionContext : public antlr4::ParserRuleContext { + public: + OC_PowerOfExpressionContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + std::vector oC_StringListNullOperatorExpression(); + OC_StringListNullOperatorExpressionContext* oC_StringListNullOperatorExpression(size_t i); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + + + }; + + OC_PowerOfExpressionContext* oC_PowerOfExpression(); + + class OC_StringListNullOperatorExpressionContext : public antlr4::ParserRuleContext { + public: + OC_StringListNullOperatorExpressionContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_UnaryAddSubtractOrFactorialExpressionContext *oC_UnaryAddSubtractOrFactorialExpression(); + OC_StringOperatorExpressionContext *oC_StringOperatorExpression(); + OC_NullOperatorExpressionContext *oC_NullOperatorExpression(); + std::vector oC_ListOperatorExpression(); + OC_ListOperatorExpressionContext* oC_ListOperatorExpression(size_t i); + + + }; + + OC_StringListNullOperatorExpressionContext* oC_StringListNullOperatorExpression(); + + class OC_ListOperatorExpressionContext : public antlr4::ParserRuleContext { + public: + OC_ListOperatorExpressionContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + antlr4::tree::TerminalNode *IN(); + OC_PropertyOrLabelsExpressionContext *oC_PropertyOrLabelsExpression(); + std::vector oC_Expression(); + OC_ExpressionContext* oC_Expression(size_t i); + antlr4::tree::TerminalNode *COLON(); + antlr4::tree::TerminalNode *DOTDOT(); + + + }; + + OC_ListOperatorExpressionContext* oC_ListOperatorExpression(); + + class OC_StringOperatorExpressionContext : public antlr4::ParserRuleContext { + public: + OC_StringOperatorExpressionContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_PropertyOrLabelsExpressionContext *oC_PropertyOrLabelsExpression(); + OC_RegularExpressionContext *oC_RegularExpression(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + antlr4::tree::TerminalNode *STARTS(); + antlr4::tree::TerminalNode *WITH(); + antlr4::tree::TerminalNode *ENDS(); + antlr4::tree::TerminalNode *CONTAINS(); + + + }; + + OC_StringOperatorExpressionContext* 
oC_StringOperatorExpression(); + + class OC_RegularExpressionContext : public antlr4::ParserRuleContext { + public: + OC_RegularExpressionContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *SP(); + + + }; + + OC_RegularExpressionContext* oC_RegularExpression(); + + class OC_NullOperatorExpressionContext : public antlr4::ParserRuleContext { + public: + OC_NullOperatorExpressionContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + antlr4::tree::TerminalNode *IS(); + antlr4::tree::TerminalNode *NULL_(); + antlr4::tree::TerminalNode *NOT(); + + + }; + + OC_NullOperatorExpressionContext* oC_NullOperatorExpression(); + + class OC_UnaryAddSubtractOrFactorialExpressionContext : public antlr4::ParserRuleContext { + public: + OC_UnaryAddSubtractOrFactorialExpressionContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_PropertyOrLabelsExpressionContext *oC_PropertyOrLabelsExpression(); + std::vector MINUS(); + antlr4::tree::TerminalNode* MINUS(size_t i); + antlr4::tree::TerminalNode *FACTORIAL(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + + + }; + + OC_UnaryAddSubtractOrFactorialExpressionContext* oC_UnaryAddSubtractOrFactorialExpression(); + + class OC_PropertyOrLabelsExpressionContext : public antlr4::ParserRuleContext { + public: + OC_PropertyOrLabelsExpressionContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_AtomContext *oC_Atom(); + std::vector oC_PropertyLookup(); + OC_PropertyLookupContext* oC_PropertyLookup(size_t i); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + + + }; + + OC_PropertyOrLabelsExpressionContext* oC_PropertyOrLabelsExpression(); + + class OC_AtomContext : public antlr4::ParserRuleContext { + public: + OC_AtomContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_LiteralContext *oC_Literal(); + OC_ParameterContext *oC_Parameter(); + OC_CaseExpressionContext *oC_CaseExpression(); + OC_ParenthesizedExpressionContext *oC_ParenthesizedExpression(); + OC_FunctionInvocationContext *oC_FunctionInvocation(); + OC_PathPatternsContext *oC_PathPatterns(); + OC_ExistCountSubqueryContext *oC_ExistCountSubquery(); + OC_VariableContext *oC_Variable(); + OC_QuantifierContext *oC_Quantifier(); + + + }; + + OC_AtomContext* oC_Atom(); + + class OC_QuantifierContext : public antlr4::ParserRuleContext { + public: + OC_QuantifierContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *ALL(); + OC_FilterExpressionContext *oC_FilterExpression(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + antlr4::tree::TerminalNode *ANY(); + antlr4::tree::TerminalNode *NONE(); + antlr4::tree::TerminalNode *SINGLE(); + + + }; + + OC_QuantifierContext* oC_Quantifier(); + + class OC_FilterExpressionContext : public antlr4::ParserRuleContext { + public: + OC_FilterExpressionContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_IdInCollContext *oC_IdInColl(); + antlr4::tree::TerminalNode *SP(); + OC_WhereContext *oC_Where(); + + + }; + + OC_FilterExpressionContext* oC_FilterExpression(); + + class 
OC_IdInCollContext : public antlr4::ParserRuleContext { + public: + OC_IdInCollContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_VariableContext *oC_Variable(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + antlr4::tree::TerminalNode *IN(); + OC_ExpressionContext *oC_Expression(); + + + }; + + OC_IdInCollContext* oC_IdInColl(); + + class OC_LiteralContext : public antlr4::ParserRuleContext { + public: + OC_LiteralContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_NumberLiteralContext *oC_NumberLiteral(); + antlr4::tree::TerminalNode *StringLiteral(); + OC_BooleanLiteralContext *oC_BooleanLiteral(); + antlr4::tree::TerminalNode *NULL_(); + OC_ListLiteralContext *oC_ListLiteral(); + KU_StructLiteralContext *kU_StructLiteral(); + + + }; + + OC_LiteralContext* oC_Literal(); + + class OC_BooleanLiteralContext : public antlr4::ParserRuleContext { + public: + OC_BooleanLiteralContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *TRUE(); + antlr4::tree::TerminalNode *FALSE(); + + + }; + + OC_BooleanLiteralContext* oC_BooleanLiteral(); + + class OC_ListLiteralContext : public antlr4::ParserRuleContext { + public: + OC_ListLiteralContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + OC_ExpressionContext *oC_Expression(); + std::vector kU_ListEntry(); + KU_ListEntryContext* kU_ListEntry(size_t i); + + + }; + + OC_ListLiteralContext* oC_ListLiteral(); + + class KU_ListEntryContext : public antlr4::ParserRuleContext { + public: + KU_ListEntryContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *SP(); + OC_ExpressionContext *oC_Expression(); + + + }; + + KU_ListEntryContext* kU_ListEntry(); + + class KU_StructLiteralContext : public antlr4::ParserRuleContext { + public: + KU_StructLiteralContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + std::vector kU_StructField(); + KU_StructFieldContext* kU_StructField(size_t i); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + + + }; + + KU_StructLiteralContext* kU_StructLiteral(); + + class KU_StructFieldContext : public antlr4::ParserRuleContext { + public: + KU_StructFieldContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *COLON(); + OC_ExpressionContext *oC_Expression(); + OC_SymbolicNameContext *oC_SymbolicName(); + antlr4::tree::TerminalNode *StringLiteral(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + + + }; + + KU_StructFieldContext* kU_StructField(); + + class OC_ParenthesizedExpressionContext : public antlr4::ParserRuleContext { + public: + OC_ParenthesizedExpressionContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_ExpressionContext *oC_Expression(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + + + }; + + OC_ParenthesizedExpressionContext* oC_ParenthesizedExpression(); + + class OC_FunctionInvocationContext : public antlr4::ParserRuleContext { + public: + OC_FunctionInvocationContext(antlr4::ParserRuleContext *parent, 
size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *COUNT(); + antlr4::tree::TerminalNode *STAR(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + antlr4::tree::TerminalNode *CAST(); + std::vector kU_FunctionParameter(); + KU_FunctionParameterContext* kU_FunctionParameter(size_t i); + antlr4::tree::TerminalNode *AS(); + KU_DataTypeContext *kU_DataType(); + OC_FunctionNameContext *oC_FunctionName(); + antlr4::tree::TerminalNode *DISTINCT(); + + + }; + + OC_FunctionInvocationContext* oC_FunctionInvocation(); + + class OC_FunctionNameContext : public antlr4::ParserRuleContext { + public: + OC_FunctionNameContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_SymbolicNameContext *oC_SymbolicName(); + + + }; + + OC_FunctionNameContext* oC_FunctionName(); + + class KU_FunctionParameterContext : public antlr4::ParserRuleContext { + public: + KU_FunctionParameterContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_ExpressionContext *oC_Expression(); + OC_SymbolicNameContext *oC_SymbolicName(); + antlr4::tree::TerminalNode *COLON(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + KU_LambdaParameterContext *kU_LambdaParameter(); + + + }; + + KU_FunctionParameterContext* kU_FunctionParameter(); + + class KU_LambdaParameterContext : public antlr4::ParserRuleContext { + public: + KU_LambdaParameterContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + KU_LambdaVarsContext *kU_LambdaVars(); + antlr4::tree::TerminalNode *MINUS(); + OC_ExpressionContext *oC_Expression(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + + + }; + + KU_LambdaParameterContext* kU_LambdaParameter(); + + class KU_LambdaVarsContext : public antlr4::ParserRuleContext { + public: + KU_LambdaVarsContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + std::vector oC_SymbolicName(); + OC_SymbolicNameContext* oC_SymbolicName(size_t i); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + + + }; + + KU_LambdaVarsContext* kU_LambdaVars(); + + class OC_PathPatternsContext : public antlr4::ParserRuleContext { + public: + OC_PathPatternsContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_NodePatternContext *oC_NodePattern(); + std::vector oC_PatternElementChain(); + OC_PatternElementChainContext* oC_PatternElementChain(size_t i); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + + + }; + + OC_PathPatternsContext* oC_PathPatterns(); + + class OC_ExistCountSubqueryContext : public antlr4::ParserRuleContext { + public: + OC_ExistCountSubqueryContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *MATCH(); + OC_PatternContext *oC_Pattern(); + antlr4::tree::TerminalNode *EXISTS(); + antlr4::tree::TerminalNode *COUNT(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + OC_WhereContext *oC_Where(); + KU_HintContext *kU_Hint(); + + + }; + + OC_ExistCountSubqueryContext* oC_ExistCountSubquery(); + + class OC_PropertyLookupContext : public antlr4::ParserRuleContext { + public: + OC_PropertyLookupContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; 
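For orientation, the context classes declared in this generated header are the accessors that ANTLR parser rules return; callers never construct them directly. Below is a minimal sketch of how such a generated parser is typically driven. The CypherLexer / CypherParser class names and the oC_Cypher() start rule are assumptions made for illustration only; they are not shown in this diff.

#include <string>
#include "antlr4-runtime.h"
#include "CypherLexer.h"   // assumed generated lexer header
#include "CypherParser.h"  // assumed generated parser header

// Parse a query and return a flat textual dump of the resulting parse tree.
std::string dumpParseTree(const std::string& query) {
    antlr4::ANTLRInputStream input(query);       // UTF-8 in, UTF-32 internally
    CypherLexer lexer(&input);
    antlr4::CommonTokenStream tokens(&lexer);
    CypherParser parser(&tokens);
    antlr4::tree::ParseTree* tree = parser.oC_Cypher();  // assumed start rule
    return tree->toStringTree(&parser);
}

Each rule invocation returns one of the *Context classes above, whose accessor methods (for sub-rules and terminal nodes) mirror the grammar alternatives.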
+ OC_PropertyKeyNameContext *oC_PropertyKeyName(); + antlr4::tree::TerminalNode *STAR(); + antlr4::tree::TerminalNode *SP(); + + + }; + + OC_PropertyLookupContext* oC_PropertyLookup(); + + class OC_CaseExpressionContext : public antlr4::ParserRuleContext { + public: + OC_CaseExpressionContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *END(); + antlr4::tree::TerminalNode *ELSE(); + std::vector oC_Expression(); + OC_ExpressionContext* oC_Expression(size_t i); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + antlr4::tree::TerminalNode *CASE(); + std::vector oC_CaseAlternative(); + OC_CaseAlternativeContext* oC_CaseAlternative(size_t i); + + + }; + + OC_CaseExpressionContext* oC_CaseExpression(); + + class OC_CaseAlternativeContext : public antlr4::ParserRuleContext { + public: + OC_CaseAlternativeContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *WHEN(); + std::vector oC_Expression(); + OC_ExpressionContext* oC_Expression(size_t i); + antlr4::tree::TerminalNode *THEN(); + std::vector SP(); + antlr4::tree::TerminalNode* SP(size_t i); + + + }; + + OC_CaseAlternativeContext* oC_CaseAlternative(); + + class OC_VariableContext : public antlr4::ParserRuleContext { + public: + OC_VariableContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_SymbolicNameContext *oC_SymbolicName(); + + + }; + + OC_VariableContext* oC_Variable(); + + class OC_NumberLiteralContext : public antlr4::ParserRuleContext { + public: + OC_NumberLiteralContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_DoubleLiteralContext *oC_DoubleLiteral(); + OC_IntegerLiteralContext *oC_IntegerLiteral(); + + + }; + + OC_NumberLiteralContext* oC_NumberLiteral(); + + class OC_ParameterContext : public antlr4::ParserRuleContext { + public: + OC_ParameterContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_SymbolicNameContext *oC_SymbolicName(); + antlr4::tree::TerminalNode *DecimalInteger(); + + + }; + + OC_ParameterContext* oC_Parameter(); + + class OC_PropertyExpressionContext : public antlr4::ParserRuleContext { + public: + OC_PropertyExpressionContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_AtomContext *oC_Atom(); + OC_PropertyLookupContext *oC_PropertyLookup(); + antlr4::tree::TerminalNode *SP(); + + + }; + + OC_PropertyExpressionContext* oC_PropertyExpression(); + + class OC_PropertyKeyNameContext : public antlr4::ParserRuleContext { + public: + OC_PropertyKeyNameContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_SchemaNameContext *oC_SchemaName(); + + + }; + + OC_PropertyKeyNameContext* oC_PropertyKeyName(); + + class OC_IntegerLiteralContext : public antlr4::ParserRuleContext { + public: + OC_IntegerLiteralContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *DecimalInteger(); + + + }; + + OC_IntegerLiteralContext* oC_IntegerLiteral(); + + class OC_DoubleLiteralContext : public antlr4::ParserRuleContext { + public: + OC_DoubleLiteralContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t 
getRuleIndex() const override; + antlr4::tree::TerminalNode *ExponentDecimalReal(); + antlr4::tree::TerminalNode *RegularDecimalReal(); + + + }; + + OC_DoubleLiteralContext* oC_DoubleLiteral(); + + class OC_SchemaNameContext : public antlr4::ParserRuleContext { + public: + OC_SchemaNameContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + OC_SymbolicNameContext *oC_SymbolicName(); + + + }; + + OC_SchemaNameContext* oC_SchemaName(); + + class OC_SymbolicNameContext : public antlr4::ParserRuleContext { + public: + antlr4::Token *escapedsymbolicnameToken = nullptr; + OC_SymbolicNameContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *UnescapedSymbolicName(); + antlr4::tree::TerminalNode *EscapedSymbolicName(); + antlr4::tree::TerminalNode *HexLetter(); + KU_NonReservedKeywordsContext *kU_NonReservedKeywords(); + + + }; + + OC_SymbolicNameContext* oC_SymbolicName(); + + class KU_NonReservedKeywordsContext : public antlr4::ParserRuleContext { + public: + KU_NonReservedKeywordsContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *COMMENT(); + antlr4::tree::TerminalNode *ADD(); + antlr4::tree::TerminalNode *ALTER(); + antlr4::tree::TerminalNode *AS(); + antlr4::tree::TerminalNode *ATTACH(); + antlr4::tree::TerminalNode *BEGIN(); + antlr4::tree::TerminalNode *BY(); + antlr4::tree::TerminalNode *CALL(); + antlr4::tree::TerminalNode *CHECKPOINT(); + antlr4::tree::TerminalNode *COMMIT(); + antlr4::tree::TerminalNode *CONTAINS(); + antlr4::tree::TerminalNode *COPY(); + antlr4::tree::TerminalNode *COUNT(); + antlr4::tree::TerminalNode *CYCLE(); + antlr4::tree::TerminalNode *DATABASE(); + antlr4::tree::TerminalNode *DECIMAL(); + antlr4::tree::TerminalNode *DELETE(); + antlr4::tree::TerminalNode *DETACH(); + antlr4::tree::TerminalNode *DROP(); + antlr4::tree::TerminalNode *EXPLAIN(); + antlr4::tree::TerminalNode *EXPORT(); + antlr4::tree::TerminalNode *EXTENSION(); + antlr4::tree::TerminalNode *FORCE(); + antlr4::tree::TerminalNode *GRAPH(); + antlr4::tree::TerminalNode *IF(); + antlr4::tree::TerminalNode *IS(); + antlr4::tree::TerminalNode *IMPORT(); + antlr4::tree::TerminalNode *INCREMENT(); + antlr4::tree::TerminalNode *KEY(); + antlr4::tree::TerminalNode *LOAD(); + antlr4::tree::TerminalNode *LOGICAL(); + antlr4::tree::TerminalNode *MATCH(); + antlr4::tree::TerminalNode *MAXVALUE(); + antlr4::tree::TerminalNode *MERGE(); + antlr4::tree::TerminalNode *MINVALUE(); + antlr4::tree::TerminalNode *NO(); + antlr4::tree::TerminalNode *NODE(); + antlr4::tree::TerminalNode *PROJECT(); + antlr4::tree::TerminalNode *READ(); + antlr4::tree::TerminalNode *REL(); + antlr4::tree::TerminalNode *RENAME(); + antlr4::tree::TerminalNode *RETURN(); + antlr4::tree::TerminalNode *ROLLBACK(); + antlr4::tree::TerminalNode *ROLE(); + antlr4::tree::TerminalNode *SEQUENCE(); + antlr4::tree::TerminalNode *SET(); + antlr4::tree::TerminalNode *START(); + antlr4::tree::TerminalNode *STRUCT(); + antlr4::tree::TerminalNode *L_SKIP(); + antlr4::tree::TerminalNode *LIMIT(); + antlr4::tree::TerminalNode *TRANSACTION(); + antlr4::tree::TerminalNode *TYPE(); + antlr4::tree::TerminalNode *USE(); + antlr4::tree::TerminalNode *UNINSTALL(); + antlr4::tree::TerminalNode *UPDATE(); + antlr4::tree::TerminalNode *WRITE(); + antlr4::tree::TerminalNode *FROM(); + antlr4::tree::TerminalNode *TO(); + 
antlr4::tree::TerminalNode *YIELD(); + antlr4::tree::TerminalNode *USER(); + antlr4::tree::TerminalNode *PASSWORD(); + antlr4::tree::TerminalNode *MAP(); + + + }; + + KU_NonReservedKeywordsContext* kU_NonReservedKeywords(); + + class OC_LeftArrowHeadContext : public antlr4::ParserRuleContext { + public: + OC_LeftArrowHeadContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + + + }; + + OC_LeftArrowHeadContext* oC_LeftArrowHead(); + + class OC_RightArrowHeadContext : public antlr4::ParserRuleContext { + public: + OC_RightArrowHeadContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + + + }; + + OC_RightArrowHeadContext* oC_RightArrowHead(); + + class OC_DashContext : public antlr4::ParserRuleContext { + public: + OC_DashContext(antlr4::ParserRuleContext *parent, size_t invokingState); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *MINUS(); + + + }; + + OC_DashContext* oC_Dash(); + + + bool sempred(antlr4::RuleContext *_localctx, size_t ruleIndex, size_t predicateIndex) override; + + bool kU_DataTypeSempred(KU_DataTypeContext *_localctx, size_t predicateIndex); + bool kU_JoinNodeSempred(KU_JoinNodeContext *_localctx, size_t predicateIndex); + + // By default the static state used to implement the parser is lazily initialized during the first + // call to the constructor. You can call this function if you wish to initialize the static state + // ahead of time. + static void initialize(); + +private: + + virtual void notifyQueryNotConcludeWithReturn(antlr4::Token* startToken) {}; + virtual void notifyNodePatternWithoutParentheses(std::string nodeName, antlr4::Token* startToken) {}; + virtual void notifyInvalidNotEqualOperator(antlr4::Token* startToken) {}; + virtual void notifyEmptyToken(antlr4::Token* startToken) {}; + virtual void notifyReturnNotAtEnd(antlr4::Token* startToken) {}; + virtual void notifyNonBinaryComparison(antlr4::Token* startToken) {}; + +}; + diff --git a/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/CMakeLists.txt b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/CMakeLists.txt new file mode 100644 index 0000000000..887e19257b --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/CMakeLists.txt @@ -0,0 +1,337 @@ +#------------------- LIBRARY CONFIGURATION ------------------------------------ + +set(ANTLR4_RUNTIME antlr4_runtime CACHE INTERNAL "") +set(ANTLR4_RUNTIME_SRC + src/antlr4-common.h + src/antlr4-runtime.h + src/ANTLRErrorListener.cpp + src/ANTLRErrorListener.h + src/ANTLRErrorStrategy.cpp + src/ANTLRErrorStrategy.h + src/ANTLRFileStream.cpp + src/ANTLRFileStream.h + src/ANTLRInputStream.cpp + src/ANTLRInputStream.h + src/atn/ActionTransition.cpp + src/atn/ActionTransition.h + src/atn/AmbiguityInfo.cpp + src/atn/AmbiguityInfo.h + src/atn/ArrayPredictionContext.cpp + src/atn/ArrayPredictionContext.h + src/atn/ATNConfig.cpp + src/atn/ATNConfig.h + src/atn/ATNConfigSet.cpp + src/atn/ATNConfigSet.h + src/atn/ATN.cpp + src/atn/ATNDeserializationOptions.cpp + src/atn/ATNDeserializationOptions.h + src/atn/ATNDeserializer.cpp + src/atn/ATNDeserializer.h + src/atn/ATN.h + src/atn/ATNSimulator.cpp + src/atn/ATNSimulator.h + src/atn/ATNState.cpp + src/atn/ATNState.h + src/atn/ATNStateType.cpp + src/atn/ATNStateType.h + src/atn/ATNType.h + src/atn/AtomTransition.cpp + src/atn/AtomTransition.h + src/atn/BasicBlockStartState.h + src/atn/BasicState.h + src/atn/BlockEndState.h + 
src/atn/BlockStartState.h + src/atn/ContextSensitivityInfo.cpp + src/atn/ContextSensitivityInfo.h + src/atn/DecisionEventInfo.cpp + src/atn/DecisionEventInfo.h + src/atn/DecisionInfo.cpp + src/atn/DecisionInfo.h + src/atn/DecisionState.cpp + src/atn/DecisionState.h + src/atn/EpsilonTransition.cpp + src/atn/EpsilonTransition.h + src/atn/ErrorInfo.cpp + src/atn/ErrorInfo.h + src/atn/HashUtils.h + src/atn/LexerAction.cpp + src/atn/LexerActionExecutor.cpp + src/atn/LexerActionExecutor.h + src/atn/LexerAction.h + src/atn/LexerActionType.h + src/atn/LexerATNConfig.cpp + src/atn/LexerATNConfig.h + src/atn/LexerATNSimulator.cpp + src/atn/LexerATNSimulator.h + src/atn/LexerChannelAction.cpp + src/atn/LexerChannelAction.h + src/atn/LexerCustomAction.cpp + src/atn/LexerCustomAction.h + src/atn/LexerIndexedCustomAction.cpp + src/atn/LexerIndexedCustomAction.h + src/atn/LexerModeAction.cpp + src/atn/LexerModeAction.h + src/atn/LexerMoreAction.cpp + src/atn/LexerMoreAction.h + src/atn/LexerPopModeAction.cpp + src/atn/LexerPopModeAction.h + src/atn/LexerPushModeAction.cpp + src/atn/LexerPushModeAction.h + src/atn/LexerSkipAction.cpp + src/atn/LexerSkipAction.h + src/atn/LexerTypeAction.cpp + src/atn/LexerTypeAction.h + src/atn/LL1Analyzer.cpp + src/atn/LL1Analyzer.h + src/atn/LookaheadEventInfo.cpp + src/atn/LookaheadEventInfo.h + src/atn/LoopEndState.h + src/atn/NotSetTransition.cpp + src/atn/NotSetTransition.h + src/atn/OrderedATNConfigSet.cpp + src/atn/OrderedATNConfigSet.h + src/atn/ParseInfo.cpp + src/atn/ParseInfo.h + src/atn/ParserATNSimulator.cpp + src/atn/ParserATNSimulator.h + src/atn/ParserATNSimulatorOptions.h + src/atn/PlusBlockStartState.h + src/atn/PlusLoopbackState.h + src/atn/PrecedencePredicateTransition.cpp + src/atn/PrecedencePredicateTransition.h + src/atn/PredicateEvalInfo.cpp + src/atn/PredicateEvalInfo.h + src/atn/PredicateTransition.cpp + src/atn/PredicateTransition.h + src/atn/PredictionContextCache.cpp + src/atn/PredictionContextCache.h + src/atn/PredictionContext.cpp + src/atn/PredictionContext.h + src/atn/PredictionContextMergeCache.cpp + src/atn/PredictionContextMergeCache.h + src/atn/PredictionContextMergeCacheOptions.h + src/atn/PredictionContextType.h + src/atn/PredictionMode.cpp + src/atn/PredictionMode.h + src/atn/ProfilingATNSimulator.cpp + src/atn/ProfilingATNSimulator.h + src/atn/RangeTransition.cpp + src/atn/RangeTransition.h + src/atn/RuleStartState.h + src/atn/RuleStopState.h + src/atn/RuleTransition.cpp + src/atn/RuleTransition.h + src/atn/SemanticContext.cpp + src/atn/SemanticContext.h + src/atn/SemanticContextType.h + src/atn/SerializedATNView.h + src/atn/SetTransition.cpp + src/atn/SetTransition.h + src/atn/SingletonPredictionContext.cpp + src/atn/SingletonPredictionContext.h + src/atn/StarBlockStartState.h + src/atn/StarLoopbackState.cpp + src/atn/StarLoopbackState.h + src/atn/StarLoopEntryState.h + src/atn/TokensStartState.h + src/atn/Transition.cpp + src/atn/Transition.h + src/atn/TransitionType.cpp + src/atn/TransitionType.h + src/atn/WildcardTransition.cpp + src/atn/WildcardTransition.h + src/BailErrorStrategy.cpp + src/BailErrorStrategy.h + src/BaseErrorListener.cpp + src/BaseErrorListener.h + src/BufferedTokenStream.cpp + src/BufferedTokenStream.h + src/CharStream.cpp + src/CharStream.h + src/CommonToken.cpp + src/CommonTokenFactory.cpp + src/CommonTokenFactory.h + src/CommonToken.h + src/CommonTokenStream.cpp + src/CommonTokenStream.h + src/ConsoleErrorListener.cpp + src/ConsoleErrorListener.h + src/DefaultErrorStrategy.cpp + src/DefaultErrorStrategy.h 
+ src/dfa/DFA.cpp + src/dfa/DFA.h + src/dfa/DFASerializer.cpp + src/dfa/DFASerializer.h + src/dfa/DFAState.cpp + src/dfa/DFAState.h + src/dfa/LexerDFASerializer.cpp + src/dfa/LexerDFASerializer.h + src/DiagnosticErrorListener.cpp + src/DiagnosticErrorListener.h + src/Exceptions.cpp + src/Exceptions.h + src/FailedPredicateException.cpp + src/FailedPredicateException.h + src/FlatHashMap.h + src/FlatHashSet.h + src/InputMismatchException.cpp + src/InputMismatchException.h + src/internal/Synchronization.cpp + src/internal/Synchronization.h + src/InterpreterRuleContext.cpp + src/InterpreterRuleContext.h + src/IntStream.cpp + src/IntStream.h + src/Lexer.cpp + src/Lexer.h + src/LexerInterpreter.cpp + src/LexerInterpreter.h + src/LexerNoViableAltException.cpp + src/LexerNoViableAltException.h + src/ListTokenSource.cpp + src/ListTokenSource.h + src/misc/InterpreterDataReader.cpp + src/misc/InterpreterDataReader.h + src/misc/Interval.cpp + src/misc/Interval.h + src/misc/IntervalSet.cpp + src/misc/IntervalSet.h + src/misc/MurmurHash.cpp + src/misc/MurmurHash.h + src/misc/Predicate.cpp + src/misc/Predicate.h + src/NoViableAltException.cpp + src/NoViableAltException.h + src/Parser.cpp + src/Parser.h + src/ParserInterpreter.cpp + src/ParserInterpreter.h + src/ParserRuleContext.cpp + src/ParserRuleContext.h + src/ProxyErrorListener.cpp + src/ProxyErrorListener.h + src/RecognitionException.cpp + src/RecognitionException.h + src/Recognizer.cpp + src/Recognizer.h + src/RuleContext.cpp + src/RuleContext.h + src/RuleContextWithAltNum.cpp + src/RuleContextWithAltNum.h + src/RuntimeMetaData.cpp + src/RuntimeMetaData.h + src/support/Any.cpp + src/support/Any.h + src/support/Arrays.cpp + src/support/Arrays.h + src/support/BitSet.h + src/support/Casts.h + src/support/CPPUtils.cpp + src/support/CPPUtils.h + src/support/Declarations.h + src/support/StringUtils.cpp + src/support/StringUtils.h + src/support/Unicode.h + src/support/Utf8.cpp + src/support/Utf8.h + src/Token.cpp + src/TokenFactory.h + src/Token.h + src/TokenSource.cpp + src/TokenSource.h + src/TokenStream.cpp + src/TokenStream.h + src/TokenStreamRewriter.cpp + src/TokenStreamRewriter.h + src/tree/AbstractParseTreeVisitor.h + src/tree/ErrorNode.h + src/tree/ErrorNodeImpl.cpp + src/tree/ErrorNodeImpl.h + src/tree/IterativeParseTreeWalker.cpp + src/tree/IterativeParseTreeWalker.h + src/tree/ParseTree.cpp + src/tree/ParseTree.h + src/tree/ParseTreeListener.cpp + src/tree/ParseTreeListener.h + src/tree/ParseTreeProperty.h + src/tree/ParseTreeType.h + src/tree/ParseTreeVisitor.cpp + src/tree/ParseTreeVisitor.h + src/tree/ParseTreeWalker.cpp + src/tree/ParseTreeWalker.h + src/tree/pattern/Chunk.cpp + src/tree/pattern/Chunk.h + src/tree/pattern/ParseTreeMatch.cpp + src/tree/pattern/ParseTreeMatch.h + src/tree/pattern/ParseTreePattern.cpp + src/tree/pattern/ParseTreePattern.h + src/tree/pattern/ParseTreePatternMatcher.cpp + src/tree/pattern/ParseTreePatternMatcher.h + src/tree/pattern/RuleTagToken.cpp + src/tree/pattern/RuleTagToken.h + src/tree/pattern/TagChunk.cpp + src/tree/pattern/TagChunk.h + src/tree/pattern/TextChunk.cpp + src/tree/pattern/TextChunk.h + src/tree/pattern/TokenTagToken.cpp + src/tree/pattern/TokenTagToken.h + src/tree/TerminalNode.h + src/tree/TerminalNodeImpl.cpp + src/tree/TerminalNodeImpl.h + src/tree/Trees.cpp + src/tree/Trees.h + src/tree/xpath/XPath.cpp + src/tree/xpath/XPathElement.cpp + src/tree/xpath/XPathElement.h + src/tree/xpath/XPath.h + src/tree/xpath/XPathLexer.cpp + src/tree/xpath/XPathLexerErrorListener.cpp + 
src/tree/xpath/XPathLexerErrorListener.h
+    src/tree/xpath/XPathLexer.g4
+    src/tree/xpath/XPathLexer.h
+    src/tree/xpath/XPathLexer.tokens
+    src/tree/xpath/XPathRuleAnywhereElement.cpp
+    src/tree/xpath/XPathRuleAnywhereElement.h
+    src/tree/xpath/XPathRuleElement.cpp
+    src/tree/xpath/XPathRuleElement.h
+    src/tree/xpath/XPathTokenAnywhereElement.cpp
+    src/tree/xpath/XPathTokenAnywhereElement.h
+    src/tree/xpath/XPathTokenElement.cpp
+    src/tree/xpath/XPathTokenElement.h
+    src/tree/xpath/XPathWildcardAnywhereElement.cpp
+    src/tree/xpath/XPathWildcardAnywhereElement.h
+    src/tree/xpath/XPathWildcardElement.cpp
+    src/tree/xpath/XPathWildcardElement.h
+    src/UnbufferedCharStream.cpp
+    src/UnbufferedCharStream.h
+    src/UnbufferedTokenStream.cpp
+    src/UnbufferedTokenStream.h
+    src/Version.h
+    src/Vocabulary.cpp
+    src/Vocabulary.h
+    src/WritableToken.cpp
+    src/WritableToken.h
+)
+
+add_library(${ANTLR4_RUNTIME} STATIC ${ANTLR4_RUNTIME_SRC})
+
+target_include_directories(${ANTLR4_RUNTIME} PUBLIC
+    ${CMAKE_CURRENT_SOURCE_DIR}/src
+)
+
+target_compile_definitions(${ANTLR4_RUNTIME} PUBLIC -DANTLR4CPP_STATIC)
+
+if(MSVC)
+    target_compile_options(${ANTLR4_RUNTIME} PRIVATE
+        $<$<CONFIG:Debug>:/W4 /Od>
+        $<$<CONFIG:Release>:/O2>
+    )
+else()
+    target_compile_options(${ANTLR4_RUNTIME} PRIVATE
+        $<$<CONFIG:Debug>:-g -Wall -O0>
+        $<$<CONFIG:Release>:-w -O3>
+    )
+    target_compile_options(${ANTLR4_RUNTIME} PUBLIC
+        $<$<CXX_COMPILER_ID:GNU>:-Wno-attributes>
+    )
+endif()
diff --git a/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/LICENSE b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/LICENSE
new file mode 100644
index 0000000000..d70ecc0464
--- /dev/null
+++ b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/LICENSE
@@ -0,0 +1,84 @@
+BSD 3-Clause License
+
+Copyright (c) 2020-2021, Alejandro de Haro
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+   list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice,
+   this list of conditions and the following disclaimer in the documentation
+   and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+   contributors may be used to endorse or promote products derived from
+   this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+=====
+
+[The "BSD 3-clause license"]
+Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions
+are met:
+
+ 1.
Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + 2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + 3. Neither the name of the copyright holder nor the names of its contributors + may be used to endorse or promote products derived from this software + without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR +IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES +OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, +INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT +NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF +THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +===== + +MIT License for codepointat.js from https://git.io/codepointat +MIT License for fromcodepoint.js from https://git.io/vDW1m + +Copyright Mathias Bynens + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/ANTLRErrorListener.cpp b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/ANTLRErrorListener.cpp new file mode 100644 index 0000000000..6ceadb87f9 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/ANTLRErrorListener.cpp @@ -0,0 +1,10 @@ +/* Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. + * Use of this file is governed by the BSD 3-clause license that + * can be found in the LICENSE.txt file in the project root. + */ + +#include "ANTLRErrorListener.h" + +antlr4::ANTLRErrorListener::~ANTLRErrorListener() +{ +} diff --git a/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/ANTLRErrorListener.h b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/ANTLRErrorListener.h new file mode 100755 index 0000000000..d6efad1d9e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/ANTLRErrorListener.h @@ -0,0 +1,167 @@ +/* Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. 
+ * Use of this file is governed by the BSD 3-clause license that + * can be found in the LICENSE.txt file in the project root. + */ + +#pragma once + +#include "RecognitionException.h" + +namespace antlrcpp { + class BitSet; +} + +namespace antlr4 { + + /// How to emit recognition errors (an interface in Java). + class ANTLR4CPP_PUBLIC ANTLRErrorListener { + public: + virtual ~ANTLRErrorListener(); + + /// + /// Upon syntax error, notify any interested parties. This is not how to + /// recover from errors or compute error messages. + /// specifies how to recover from syntax errors and how to compute error + /// messages. This listener's job is simply to emit a computed message, + /// though it has enough information to create its own message in many cases. + ///

+  /// The <seealso cref="RecognitionException"/> is non-null for all syntax errors except
+  /// when we discover mismatched token errors that we can recover from
+  /// in-line, without returning from the surrounding rule (via the single
+  /// token insertion and deletion mechanism).
+  ///

+ /// + /// What parser got the error. From this + /// object, you can access the context as well + /// as the input stream. + /// + /// The offending token in the input token + /// stream, unless recognizer is a lexer (then it's null). If + /// no viable alternative error, {@code e} has token at which we + /// started production for the decision. + /// + /// The line number in the input where the error occurred. + /// + /// The character position within that line where the error occurred. + /// + /// The message to emit. + /// + /// The exception generated by the parser that led to + /// the reporting of an error. It is null in the case where + /// the parser was able to recover in line without exiting the + /// surrounding rule. + virtual void syntaxError(Recognizer *recognizer, Token *offendingSymbol, size_t line, + size_t charPositionInLine, const std::string &msg, std::exception_ptr e) = 0; + + /** + * This method is called by the parser when a full-context prediction + * results in an ambiguity. + * + *

+   * <p>Each full-context prediction which does not result in a syntax error
+   * will call either {@link #reportContextSensitivity} or
+   * {@link #reportAmbiguity}.</p>

+ * + *

+   * <p>When {@code ambigAlts} is not null, it contains the set of potentially
+   * viable alternatives identified by the prediction algorithm. When
+   * {@code ambigAlts} is null, use {@link ATNConfigSet#getAlts} to obtain the
+   * represented alternatives from the {@code configs} argument.</p>

+ * + *

+   * <p>When {@code exact} is {@code true}, all of the potentially
+   * viable alternatives are truly viable, i.e. this is reporting an exact
+   * ambiguity. When {@code exact} is {@code false}, at least two of
+   * the potentially viable alternatives are viable for the current input, but
+   * the prediction algorithm terminated as soon as it determined that at
+   * least the minimum potentially viable alternative is truly
+   * viable.</p>

+ * + *

+   * <p>When the {@link PredictionMode#LL_EXACT_AMBIG_DETECTION} prediction
+   * mode is used, the parser is required to identify exact ambiguities so
+   * {@code exact} will always be {@code true}.</p>

+ * + *

+   * <p>This method is not used by lexers.</p>

+ * + * @param recognizer the parser instance + * @param dfa the DFA for the current decision + * @param startIndex the input index where the decision started + * @param stopIndex the input input where the ambiguity was identified + * @param exact {@code true} if the ambiguity is exactly known, otherwise + * {@code false}. This is always {@code true} when + * {@link PredictionMode#LL_EXACT_AMBIG_DETECTION} is used. + * @param ambigAlts the potentially ambiguous alternatives, or {@code null} + * to indicate that the potentially ambiguous alternatives are the complete + * set of represented alternatives in {@code configs} + * @param configs the ATN configuration set where the ambiguity was + * identified + */ + virtual void reportAmbiguity(Parser *recognizer, const dfa::DFA &dfa, size_t startIndex, size_t stopIndex, bool exact, + const antlrcpp::BitSet &ambigAlts, atn::ATNConfigSet *configs) = 0; + + /** + * This method is called when an SLL conflict occurs and the parser is about + * to use the full context information to make an LL decision. + * + *

+   * <p>If one or more configurations in {@code configs} contains a semantic
+   * predicate, the predicates are evaluated before this method is called. The
+   * subset of alternatives which are still viable after predicates are
+   * evaluated is reported in {@code conflictingAlts}.</p>

+ * + *

+   * <p>This method is not used by lexers.</p>

+ * + * @param recognizer the parser instance + * @param dfa the DFA for the current decision + * @param startIndex the input index where the decision started + * @param stopIndex the input index where the SLL conflict occurred + * @param conflictingAlts The specific conflicting alternatives. If this is + * {@code null}, the conflicting alternatives are all alternatives + * represented in {@code configs}. At the moment, conflictingAlts is non-null + * (for the reference implementation, but Sam's optimized version can see this + * as null). + * @param configs the ATN configuration set where the SLL conflict was + * detected + */ + virtual void reportAttemptingFullContext(Parser *recognizer, const dfa::DFA &dfa, size_t startIndex, size_t stopIndex, + const antlrcpp::BitSet &conflictingAlts, atn::ATNConfigSet *configs) = 0; + + /** + * This method is called by the parser when a full-context prediction has a + * unique result. + * + *

+   * <p>Each full-context prediction which does not result in a syntax error
+   * will call either {@link #reportContextSensitivity} or
+   * {@link #reportAmbiguity}.</p>

+ * + *

+   * <p>For prediction implementations that only evaluate full-context
+   * predictions when an SLL conflict is found (including the default
+   * {@link ParserATNSimulator} implementation), this method reports cases
+   * where SLL conflicts were resolved to unique full-context predictions,
+   * i.e. the decision was context-sensitive. This report does not necessarily
+   * indicate a problem, and it may appear even in completely unambiguous
+   * grammars.</p>

+ * + *

+   * <p>{@code configs} may have more than one represented alternative if the
+   * full-context prediction algorithm does not evaluate predicates before
+   * beginning the full-context prediction. In all cases, the final prediction
+   * is passed as the {@code prediction} argument.</p>

+ * + *

+   * <p>Note that the definition of "context sensitivity" in this method
+   * differs from the concept in {@link DecisionInfo#contextSensitivities}.
+   * This method reports all instances where an SLL conflict occurred but LL
+   * parsing produced a unique result, whether or not that unique result
+   * matches the minimum alternative in the SLL conflicting set.</p>

+ * + *

+   * <p>This method is not used by lexers.</p>

+ * + * @param recognizer the parser instance + * @param dfa the DFA for the current decision + * @param startIndex the input index where the decision started + * @param stopIndex the input index where the context sensitivity was + * finally determined + * @param prediction the unambiguous result of the full-context prediction + * @param configs the ATN configuration set where the unambiguous prediction + * was determined + */ + virtual void reportContextSensitivity(Parser *recognizer, const dfa::DFA &dfa, size_t startIndex, size_t stopIndex, + size_t prediction, atn::ATNConfigSet *configs) = 0; + }; + +} // namespace antlr4 diff --git a/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/ANTLRErrorStrategy.cpp b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/ANTLRErrorStrategy.cpp new file mode 100644 index 0000000000..1655a5731d --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/ANTLRErrorStrategy.cpp @@ -0,0 +1,10 @@ +/* Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. + * Use of this file is governed by the BSD 3-clause license that + * can be found in the LICENSE.txt file in the project root. + */ + +#include "ANTLRErrorStrategy.h" + +antlr4::ANTLRErrorStrategy::~ANTLRErrorStrategy() +{ +} diff --git a/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/ANTLRErrorStrategy.h b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/ANTLRErrorStrategy.h new file mode 100755 index 0000000000..a3eecd14c4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/ANTLRErrorStrategy.h @@ -0,0 +1,121 @@ +/* Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. + * Use of this file is governed by the BSD 3-clause license that + * can be found in the LICENSE.txt file in the project root. + */ + +#pragma once + +#include "Token.h" + +namespace antlr4 { + + /// + /// The interface for defining strategies to deal with syntax errors encountered + /// during a parse by ANTLR-generated parsers. We distinguish between three + /// different kinds of errors: + /// + ///
+  /// <ul>
+  /// <li>The parser could not figure out which path to take in the ATN (none of
+  /// the available alternatives could possibly match)</li>
+  /// <li>The current input does not match what we were looking for</li>
+  /// <li>A predicate evaluated to false</li>
+  /// </ul>
+  ///
+  /// Implementations of this interface report syntax errors by calling
+  /// Parser::notifyErrorListeners.
+  ///
+  /// TODO: what to do about lexers
+  ///
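The strategy interface declared next is what a parser consults on every syntax error. A minimal sketch of how a caller usually swaps the strategy is shown here; "parser" stands for whatever generated antlr4::Parser subclass this runtime is linked against, and nothing below is taken from the patch itself.

#include <memory>
#include "antlr4-runtime.h"

// Sketch: replace the default recovery strategy with the bail-out strategy
// (declared later in this patch) so the first syntax error aborts the parse.
void configureStrictParsing(antlr4::Parser& parser) {
    parser.setErrorHandler(std::make_shared<antlr4::BailErrorStrategy>());
    parser.removeErrorListeners();  // drop the default console error listener
}

A caller using this setup typically wraps the start-rule invocation in a try/catch for antlr4::ParseCancellationException and, if needed, re-parses with the default strategy to get full error reporting.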

+ class ANTLR4CPP_PUBLIC ANTLRErrorStrategy { + public: + + /// + /// Reset the error handler state for the specified {@code recognizer}. + /// the parser instance + virtual ~ANTLRErrorStrategy(); + + virtual void reset(Parser *recognizer) = 0; + + /** + * This method is called when an unexpected symbol is encountered during an + * inline match operation, such as {@link Parser#match}. If the error + * strategy successfully recovers from the match failure, this method + * returns the {@link Token} instance which should be treated as the + * successful result of the match. + * + *

+   * <p>This method handles the consumption of any tokens - the caller should
+   * not call {@link Parser#consume} after a successful recovery.</p>

+ * + *

+   * <p>Note that the calling code will not report an error if this method
+   * returns successfully. The error strategy implementation is responsible
+   * for calling {@link Parser#notifyErrorListeners} as appropriate.</p>

+ * + * @param recognizer the parser instance + * @throws RecognitionException if the error strategy was not able to + * recover from the unexpected input symbol + */ + virtual Token* recoverInline(Parser *recognizer) = 0; + + /// + /// This method is called to recover from exception {@code e}. This method is + /// called after by the default exception handler + /// generated for a rule method. + /// + /// + /// the parser instance + /// the recognition exception to recover from + /// if the error strategy could not recover from + /// the recognition exception + virtual void recover(Parser *recognizer, std::exception_ptr e) = 0; + + /// + /// This method provides the error handler with an opportunity to handle + /// syntactic or semantic errors in the input stream before they result in a + /// . + ///

+  /// The generated code currently contains calls to <seealso cref="#sync"/> after
+  /// entering the decision state of a closure block ({@code (...)*} or
+  /// {@code (...)+}).
+  ///

+  /// For an implementation based on Jim Idle's "magic sync" mechanism, see
+  /// <seealso cref="DefaultErrorStrategy#sync"/>.
+  ///

+ /// + /// the parser instance + /// if an error is detected by the error + /// strategy but cannot be automatically recovered at the current state in + /// the parsing process + virtual void sync(Parser *recognizer) = 0; + + /// + /// Tests whether or not {@code recognizer} is in the process of recovering + /// from an error. In error recovery mode, adds + /// symbols to the parse tree by calling + /// {@link Parser#createErrorNode(ParserRuleContext, Token)} then + /// {@link ParserRuleContext#addErrorNode(ErrorNode)} instead of + /// {@link Parser#createTerminalNode(ParserRuleContext, Token)}. + /// + /// the parser instance + /// {@code true} if the parser is currently recovering from a parse + /// error, otherwise {@code false} + virtual bool inErrorRecoveryMode(Parser *recognizer) = 0; + + /// + /// This method is called by when the parser successfully matches an input + /// symbol. + /// + /// the parser instance + virtual void reportMatch(Parser *recognizer) = 0; + + /// + /// Report any kind of . This method is called by + /// the default exception handler generated for a rule method. + /// + /// the parser instance + /// the recognition exception to report + virtual void reportError(Parser *recognizer, const RecognitionException &e) = 0; + }; + +} // namespace antlr4 diff --git a/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/ANTLRFileStream.cpp b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/ANTLRFileStream.cpp new file mode 100755 index 0000000000..674817ac0e --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/ANTLRFileStream.cpp @@ -0,0 +1,23 @@ +/* Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. + * Use of this file is governed by the BSD 3-clause license that + * can be found in the LICENSE.txt file in the project root. + */ + +#include "ANTLRFileStream.h" + +using namespace antlr4; + +void ANTLRFileStream::loadFromFile(const std::string &fileName) { + _fileName = fileName; + if (_fileName.empty()) { + return; + } + + std::ifstream stream(fileName, std::ios::binary); + + ANTLRInputStream::load(stream); +} + +std::string ANTLRFileStream::getSourceName() const { + return _fileName; +} diff --git a/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/ANTLRFileStream.h b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/ANTLRFileStream.h new file mode 100755 index 0000000000..6c7d619a00 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/ANTLRFileStream.h @@ -0,0 +1,30 @@ +/* Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. + * Use of this file is governed by the BSD 3-clause license that + * can be found in the LICENSE.txt file in the project root. + */ + +#pragma once + +#include "ANTLRInputStream.h" + +namespace antlr4 { + + /// This is an ANTLRInputStream that is loaded from a file all at once + /// when you construct the object (or call load()). + // TODO: this class needs testing. + class ANTLR4CPP_PUBLIC ANTLRFileStream : public ANTLRInputStream { + public: + ANTLRFileStream() = default; + ANTLRFileStream(const std::string &) = delete; + ANTLRFileStream(const char *data, size_t length) = delete; + ANTLRFileStream(std::istream &stream) = delete; + + // Assumes a file name encoded in UTF-8 and file content in the same encoding (with or w/o BOM). + virtual void loadFromFile(const std::string &fileName); + virtual std::string getSourceName() const override; + + private: + std::string _fileName; // UTF-8 encoded file name. 
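As a quick orientation for the two stream classes just added, here is a hedged sketch of the usual loading paths: an in-memory UTF-8 string goes through ANTLRInputStream, while whole-file input goes through ANTLRFileStream (whose string constructors are deleted, so it is default-constructed and then loaded). The "query.cypher" path is purely hypothetical.

#include <string>
#include "antlr4-runtime.h"

void loadInputs(const std::string& utf8Text) {
    // In-memory UTF-8 text, decoded strictly to UTF-32 up front; malformed
    // input raises IllegalArgumentException.
    antlr4::ANTLRInputStream fromString(utf8Text);

    // Whole-file load: construct first, then call loadFromFile().
    antlr4::ANTLRFileStream fromFile;
    fromFile.loadFromFile("query.cypher");  // hypothetical file name
}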
+ }; + +} // namespace antlr4 diff --git a/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/ANTLRInputStream.cpp b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/ANTLRInputStream.cpp new file mode 100755 index 0000000000..b6470af9b7 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/ANTLRInputStream.cpp @@ -0,0 +1,180 @@ +/* Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. + * Use of this file is governed by the BSD 3-clause license that + * can be found in the LICENSE.txt file in the project root. + */ + +#include + +#include "Exceptions.h" +#include "misc/Interval.h" +#include "IntStream.h" + +#include "support/Utf8.h" +#include "support/CPPUtils.h" + +#include "ANTLRInputStream.h" + +using namespace antlr4; +using namespace antlrcpp; + +using misc::Interval; + +ANTLRInputStream::ANTLRInputStream() { + InitializeInstanceFields(); +} + +ANTLRInputStream::ANTLRInputStream(std::string_view input): ANTLRInputStream() { + load(input.data(), input.length()); +} + +ANTLRInputStream::ANTLRInputStream(const char *data, size_t length) { + load(data, length); +} + +ANTLRInputStream::ANTLRInputStream(std::istream &stream): ANTLRInputStream() { + load(stream); +} + +void ANTLRInputStream::load(const std::string &input, bool lenient) { + load(input.data(), input.size(), lenient); +} + +void ANTLRInputStream::load(const char *data, size_t length, bool lenient) { + // Remove the UTF-8 BOM if present. + const char *bom = "\xef\xbb\xbf"; + if (length >= 3 && strncmp(data, bom, 3) == 0) { + data += 3; + length -= 3; + } + if (lenient) { + _data = Utf8::lenientDecode(std::string_view(data, length)); + } else { + auto maybe_utf32 = Utf8::strictDecode(std::string_view(data, length)); + if (!maybe_utf32.has_value()) { + throw IllegalArgumentException("UTF-8 string contains an illegal byte sequence"); + } + _data = std::move(maybe_utf32).value(); + } + p = 0; +} + +void ANTLRInputStream::load(std::istream &stream, bool lenient) { + if (!stream.good() || stream.eof()) // No fail, bad or EOF. + return; + + _data.clear(); + + std::string s((std::istreambuf_iterator(stream)), std::istreambuf_iterator()); + load(s.data(), s.length(), lenient); +} + +void ANTLRInputStream::reset() { + p = 0; +} + +void ANTLRInputStream::consume() { + if (p >= _data.size()) { + assert(LA(1) == IntStream::EOF); + throw IllegalStateException("cannot consume EOF"); + } + + if (p < _data.size()) { + p++; + } +} + +size_t ANTLRInputStream::LA(ssize_t i) { + if (i == 0) { + return 0; // undefined + } + + ssize_t position = static_cast(p); + if (i < 0) { + i++; // e.g., translate LA(-1) to use offset i=0; then _data[p+0-1] + if ((position + i - 1) < 0) { + return IntStream::EOF; // invalid; no char before first char + } + } + + if ((position + i - 1) >= static_cast(_data.size())) { + return IntStream::EOF; + } + + return _data[static_cast((position + i - 1))]; +} + +size_t ANTLRInputStream::LT(ssize_t i) { + return LA(i); +} + +size_t ANTLRInputStream::index() { + return p; +} + +size_t ANTLRInputStream::size() { + return _data.size(); +} + +// Mark/release do nothing. We have entire buffer. +ssize_t ANTLRInputStream::mark() { + return -1; +} + +void ANTLRInputStream::release(ssize_t /* marker */) { +} + +void ANTLRInputStream::seek(size_t index) { + if (index <= p) { + p = index; // just jump; don't update stream state (line, ...) 
+ return; + } + // seek forward, consume until p hits index or n (whichever comes first) + index = std::min(index, _data.size()); + while (p < index) { + consume(); + } +} + +std::string ANTLRInputStream::getText(const Interval &interval) { + if (interval.a < 0 || interval.b < 0) { + return ""; + } + + size_t start = static_cast(interval.a); + size_t stop = static_cast(interval.b); + + + if (stop >= _data.size()) { + stop = _data.size() - 1; + } + + size_t count = stop - start + 1; + if (start >= _data.size()) { + return ""; + } + + auto maybeUtf8 = Utf8::strictEncode(std::u32string_view(_data).substr(start, count)); + if (!maybeUtf8.has_value()) { + throw IllegalArgumentException("Input stream contains invalid Unicode code points"); + } + return std::move(maybeUtf8).value(); +} + +std::string ANTLRInputStream::getSourceName() const { + if (name.empty()) { + return IntStream::UNKNOWN_SOURCE_NAME; + } + return name; +} + +std::string ANTLRInputStream::toString() const { + auto maybeUtf8 = Utf8::strictEncode(_data); + if (!maybeUtf8.has_value()) { + throw IllegalArgumentException("Input stream contains invalid Unicode code points"); + } + return std::move(maybeUtf8).value(); +} + +void ANTLRInputStream::InitializeInstanceFields() { + p = 0; +} diff --git a/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/ANTLRInputStream.h b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/ANTLRInputStream.h new file mode 100755 index 0000000000..413eadefa4 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/ANTLRInputStream.h @@ -0,0 +1,79 @@ +/* Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. + * Use of this file is governed by the BSD 3-clause license that + * can be found in the LICENSE.txt file in the project root. + */ + +#pragma once + +#include + +#include "CharStream.h" + +namespace antlr4 { + + // Vacuum all input from a stream and then treat it + // like a string. Can also pass in a string or char[] to use. + // Input is expected to be encoded in UTF-8 and converted to UTF-32 internally. + class ANTLR4CPP_PUBLIC ANTLRInputStream : public CharStream { + protected: + /// The data being scanned. + // UTF-32 + std::u32string _data; + + /// 0..n-1 index into string of next char + size_t p; + + public: + /// What is name or source of this char stream? + std::string name; + + ANTLRInputStream(); + + ANTLRInputStream(std::string_view input); + + ANTLRInputStream(const char *data, size_t length); + ANTLRInputStream(std::istream &stream); + + virtual void load(const std::string &input, bool lenient); + virtual void load(const char *data, size_t length, bool lenient); + virtual void load(std::istream &stream, bool lenient); + + virtual void load(const std::string &input) { load(input, false); } + virtual void load(const char *data, size_t length) { load(data, length, false); } + virtual void load(std::istream &stream) { load(stream, false); } + + /// Reset the stream so that it's in the same state it was + /// when the object was created *except* the data array is not + /// touched. + virtual void reset(); + virtual void consume() override; + virtual size_t LA(ssize_t i) override; + virtual size_t LT(ssize_t i); + + /// + /// Return the current input symbol index 0..n where n indicates the + /// last symbol has been read. The index is the index of char to + /// be returned from LA(1). 
+ /// + virtual size_t index() override; + virtual size_t size() override; + + /// + /// mark/release do nothing; we have entire buffer + virtual ssize_t mark() override; + virtual void release(ssize_t marker) override; + + /// + /// consume() ahead until p==index; can't just set p=index as we must + /// update line and charPositionInLine. If we seek backwards, just set p + /// + virtual void seek(size_t index) override; + virtual std::string getText(const misc::Interval &interval) override; + virtual std::string getSourceName() const override; + virtual std::string toString() const override; + + private: + void InitializeInstanceFields(); + }; + +} // namespace antlr4 diff --git a/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/BailErrorStrategy.cpp b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/BailErrorStrategy.cpp new file mode 100755 index 0000000000..5fbc011611 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/BailErrorStrategy.cpp @@ -0,0 +1,61 @@ +/* Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. + * Use of this file is governed by the BSD 3-clause license that + * can be found in the LICENSE.txt file in the project root. + */ + +#include "Exceptions.h" +#include "ParserRuleContext.h" +#include "InputMismatchException.h" +#include "Parser.h" + +#include "BailErrorStrategy.h" + +using namespace antlr4; + +void BailErrorStrategy::recover(Parser *recognizer, std::exception_ptr e) { + ParserRuleContext *context = recognizer->getContext(); + do { + context->exception = e; + if (context->parent == nullptr) + break; + context = static_cast(context->parent); + } while (true); + + try { + std::rethrow_exception(e); // Throw the exception to be able to catch and rethrow nested. +#if defined(_MSC_FULL_VER) && _MSC_FULL_VER < 190023026 + } catch (RecognitionException &inner) { + throw ParseCancellationException(inner.what()); +#else + } catch (RecognitionException & /*inner*/) { + std::throw_with_nested(ParseCancellationException()); +#endif + } +} + +Token* BailErrorStrategy::recoverInline(Parser *recognizer) { + InputMismatchException e(recognizer); + std::exception_ptr exception = std::make_exception_ptr(e); + + ParserRuleContext *context = recognizer->getContext(); + do { + context->exception = exception; + if (context->parent == nullptr) + break; + context = static_cast(context->parent); + } while (true); + + try { + throw e; +#if defined(_MSC_FULL_VER) && _MSC_FULL_VER < 190023026 + } catch (InputMismatchException &inner) { + throw ParseCancellationException(inner.what()); +#else + } catch (InputMismatchException & /*inner*/) { + std::throw_with_nested(ParseCancellationException()); +#endif + } +} + +void BailErrorStrategy::sync(Parser * /*recognizer*/) { +} diff --git a/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/BailErrorStrategy.h b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/BailErrorStrategy.h new file mode 100755 index 0000000000..2a8c36f9ed --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/BailErrorStrategy.h @@ -0,0 +1,59 @@ +/* Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. + * Use of this file is governed by the BSD 3-clause license that + * can be found in the LICENSE.txt file in the project root. 
+ */ + +#pragma once + +#include "DefaultErrorStrategy.h" + +namespace antlr4 { + + /** + * This implementation of {@link ANTLRErrorStrategy} responds to syntax errors + * by immediately canceling the parse operation with a + * {@link ParseCancellationException}. The implementation ensures that the + * {@link ParserRuleContext#exception} field is set for all parse tree nodes + * that were not completed prior to encountering the error. + * + *

+ * <p>
+ * This error strategy is useful in the following scenarios.</p>
+ *
+ * <ul>
+ * <li><strong>Two-stage parsing:</strong> This error strategy allows the first
+ * stage of two-stage parsing to immediately terminate if an error is
+ * encountered, and immediately fall back to the second stage. In addition to
+ * avoiding wasted work by attempting to recover from errors here, the empty
+ * implementation of {@link BailErrorStrategy#sync} improves the performance of
+ * the first stage.</li>
+ * <li><strong>Silent validation:</strong> When syntax errors are not being
+ * reported or logged, and the parse result is simply ignored if errors occur,
+ * the {@link BailErrorStrategy} avoids wasting work on recovering from errors
+ * when the result will be ignored either way.</li>
+ * </ul>
+ *
+ * <p>
+ * {@code myparser.setErrorHandler(new BailErrorStrategy());}</p>
+ * + * @see Parser#setErrorHandler(ANTLRErrorStrategy) + */ + class ANTLR4CPP_PUBLIC BailErrorStrategy : public DefaultErrorStrategy { + /// + /// Instead of recovering from exception {@code e}, re-throw it wrapped + /// in a so it is not caught by the + /// rule function catches. Use to get the + /// original . + /// + public: + virtual void recover(Parser *recognizer, std::exception_ptr e) override; + + /// Make sure we don't attempt to recover inline; if the parser + /// successfully recovers, it won't throw an exception. + virtual Token* recoverInline(Parser *recognizer) override; + + /// + /// Make sure we don't attempt to recover from problems in subrules. + virtual void sync(Parser *recognizer) override; + }; + +} // namespace antlr4 diff --git a/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/BaseErrorListener.cpp b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/BaseErrorListener.cpp new file mode 100755 index 0000000000..c035f09f0f --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/BaseErrorListener.cpp @@ -0,0 +1,25 @@ +/* Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. + * Use of this file is governed by the BSD 3-clause license that + * can be found in the LICENSE.txt file in the project root. + */ + +#include "BaseErrorListener.h" +#include "RecognitionException.h" + +using namespace antlr4; + +void BaseErrorListener::syntaxError(Recognizer * /*recognizer*/, Token * /*offendingSymbol*/, size_t /*line*/, + size_t /*charPositionInLine*/, const std::string &/*msg*/, std::exception_ptr /*e*/) { +} + +void BaseErrorListener::reportAmbiguity(Parser * /*recognizer*/, const dfa::DFA &/*dfa*/, size_t /*startIndex*/, + size_t /*stopIndex*/, bool /*exact*/, const antlrcpp::BitSet &/*ambigAlts*/, atn::ATNConfigSet * /*configs*/) { +} + +void BaseErrorListener::reportAttemptingFullContext(Parser * /*recognizer*/, const dfa::DFA &/*dfa*/, size_t /*startIndex*/, + size_t /*stopIndex*/, const antlrcpp::BitSet &/*conflictingAlts*/, atn::ATNConfigSet * /*configs*/) { +} + +void BaseErrorListener::reportContextSensitivity(Parser * /*recognizer*/, const dfa::DFA &/*dfa*/, size_t /*startIndex*/, + size_t /*stopIndex*/, size_t /*prediction*/, atn::ATNConfigSet * /*configs*/) { +} diff --git a/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/BaseErrorListener.h b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/BaseErrorListener.h new file mode 100755 index 0000000000..aad2e5d755 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/BaseErrorListener.h @@ -0,0 +1,36 @@ +/* Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. + * Use of this file is governed by the BSD 3-clause license that + * can be found in the LICENSE.txt file in the project root. + */ + +#pragma once + +#include "ANTLRErrorListener.h" + +namespace antlrcpp { + class BitSet; +} + +namespace antlr4 { + + /** + * Provides an empty default implementation of {@link ANTLRErrorListener}. The + * default implementation of each method does nothing, but can be overridden as + * necessary. 
+ */ + class ANTLR4CPP_PUBLIC BaseErrorListener : public ANTLRErrorListener { + + virtual void syntaxError(Recognizer *recognizer, Token * offendingSymbol, size_t line, size_t charPositionInLine, + const std::string &msg, std::exception_ptr e) override; + + virtual void reportAmbiguity(Parser *recognizer, const dfa::DFA &dfa, size_t startIndex, size_t stopIndex, bool exact, + const antlrcpp::BitSet &ambigAlts, atn::ATNConfigSet *configs) override; + + virtual void reportAttemptingFullContext(Parser *recognizer, const dfa::DFA &dfa, size_t startIndex, size_t stopIndex, + const antlrcpp::BitSet &conflictingAlts, atn::ATNConfigSet *configs) override; + + virtual void reportContextSensitivity(Parser *recognizer, const dfa::DFA &dfa, size_t startIndex, size_t stopIndex, + size_t prediction, atn::ATNConfigSet *configs) override; + }; + +} // namespace antlr4 diff --git a/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/BufferedTokenStream.cpp b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/BufferedTokenStream.cpp new file mode 100755 index 0000000000..241dfe5c47 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/BufferedTokenStream.cpp @@ -0,0 +1,414 @@ +/* Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. + * Use of this file is governed by the BSD 3-clause license that + * can be found in the LICENSE.txt file in the project root. + */ + +#include "WritableToken.h" +#include "Lexer.h" +#include "RuleContext.h" +#include "misc/Interval.h" +#include "Exceptions.h" +#include "support/CPPUtils.h" + +#include "BufferedTokenStream.h" + +using namespace antlr4; +using namespace antlrcpp; + +BufferedTokenStream::BufferedTokenStream(TokenSource *tokenSource) : _tokenSource(tokenSource){ + InitializeInstanceFields(); +} + +TokenSource* BufferedTokenStream::getTokenSource() const { + return _tokenSource; +} + +size_t BufferedTokenStream::index() { + return _p; +} + +ssize_t BufferedTokenStream::mark() { + return 0; +} + +void BufferedTokenStream::release(ssize_t /*marker*/) { + // no resources to release +} + +void BufferedTokenStream::reset() { + seek(0); +} + +void BufferedTokenStream::seek(size_t index) { + lazyInit(); + _p = adjustSeekIndex(index); +} + +size_t BufferedTokenStream::size() { + return _tokens.size(); +} + +void BufferedTokenStream::consume() { + bool skipEofCheck = false; + if (!_needSetup) { + if (_fetchedEOF) { + // the last token in tokens is EOF. skip check if p indexes any + // fetched token except the last. + skipEofCheck = _p < _tokens.size() - 1; + } else { + // no EOF token in tokens. skip check if p indexes a fetched token. + skipEofCheck = _p < _tokens.size(); + } + } else { + // not yet initialized + skipEofCheck = false; + } + + if (!skipEofCheck && LA(1) == Token::EOF) { + throw IllegalStateException("cannot consume EOF"); + } + + if (sync(_p + 1)) { + _p = adjustSeekIndex(_p + 1); + } +} + +bool BufferedTokenStream::sync(size_t i) { + if (i + 1 < _tokens.size()) + return true; + size_t n = i - _tokens.size() + 1; // how many more elements we need? 
+ + if (n > 0) { + size_t fetched = fetch(n); + return fetched >= n; + } + + return true; +} + +size_t BufferedTokenStream::fetch(size_t n) { + if (_fetchedEOF) { + return 0; + } + + size_t i = 0; + while (i < n) { + std::unique_ptr t(_tokenSource->nextToken()); + + if (is(t.get())) { + (static_cast(t.get()))->setTokenIndex(_tokens.size()); + } + + _tokens.push_back(std::move(t)); + ++i; + + if (_tokens.back()->getType() == Token::EOF) { + _fetchedEOF = true; + break; + } + } + + return i; +} + +Token* BufferedTokenStream::get(size_t i) const { + if (i >= _tokens.size()) { + throw IndexOutOfBoundsException(std::string("token index ") + + std::to_string(i) + + std::string(" out of range 0..") + + std::to_string(_tokens.size() - 1)); + } + return _tokens[i].get(); +} + +std::vector BufferedTokenStream::get(size_t start, size_t stop) { + std::vector subset; + + lazyInit(); + + if (_tokens.empty()) { + return subset; + } + + if (stop >= _tokens.size()) { + stop = _tokens.size() - 1; + } + for (size_t i = start; i <= stop; i++) { + Token *t = _tokens[i].get(); + if (t->getType() == Token::EOF) { + break; + } + subset.push_back(t); + } + return subset; +} + +size_t BufferedTokenStream::LA(ssize_t i) { + return LT(i)->getType(); +} + +Token* BufferedTokenStream::LB(size_t k) { + if (k > _p) { + return nullptr; + } + return _tokens[_p - k].get(); +} + +Token* BufferedTokenStream::LT(ssize_t k) { + lazyInit(); + if (k == 0) { + return nullptr; + } + if (k < 0) { + return LB(-k); + } + + size_t i = _p + k - 1; + sync(i); + if (i >= _tokens.size()) { // return EOF token + // EOF must be last token + return _tokens.back().get(); + } + + return _tokens[i].get(); +} + +ssize_t BufferedTokenStream::adjustSeekIndex(size_t i) { + return i; +} + +void BufferedTokenStream::lazyInit() { + if (_needSetup) { + setup(); + } +} + +void BufferedTokenStream::setup() { + _needSetup = false; + sync(0); + _p = adjustSeekIndex(0); +} + +void BufferedTokenStream::setTokenSource(TokenSource *tokenSource) { + _tokenSource = tokenSource; + _tokens.clear(); + _fetchedEOF = false; + _needSetup = true; +} + +std::vector BufferedTokenStream::getTokens() { + std::vector result; + for (auto &t : _tokens) + result.push_back(t.get()); + return result; +} + +std::vector BufferedTokenStream::getTokens(size_t start, size_t stop) { + return getTokens(start, stop, std::vector()); +} + +std::vector BufferedTokenStream::getTokens(size_t start, size_t stop, const std::vector &types) { + lazyInit(); + if (stop >= _tokens.size() || start >= _tokens.size()) { + throw IndexOutOfBoundsException(std::string("start ") + + std::to_string(start) + + std::string(" or stop ") + + std::to_string(stop) + + std::string(" not in 0..") + + std::to_string(_tokens.size() - 1)); + } + + std::vector filteredTokens; + + if (start > stop) { + return filteredTokens; + } + + for (size_t i = start; i <= stop; i++) { + Token *tok = _tokens[i].get(); + + if (types.empty() || std::find(types.begin(), types.end(), tok->getType()) != types.end()) { + filteredTokens.push_back(tok); + } + } + return filteredTokens; +} + +std::vector BufferedTokenStream::getTokens(size_t start, size_t stop, size_t ttype) { + std::vector s; + s.push_back(ttype); + return getTokens(start, stop, s); +} + +ssize_t BufferedTokenStream::nextTokenOnChannel(size_t i, size_t channel) { + sync(i); + if (i >= size()) { + return size() - 1; + } + + Token *token = _tokens[i].get(); + while (token->getChannel() != channel) { + if (token->getType() == Token::EOF) { + return i; + } + i++; + sync(i); + 
token = _tokens[i].get(); + } + return i; +} + +ssize_t BufferedTokenStream::previousTokenOnChannel(size_t i, size_t channel) { + sync(i); + if (i >= size()) { + // the EOF token is on every channel + return size() - 1; + } + + while (true) { + Token *token = _tokens[i].get(); + if (token->getType() == Token::EOF || token->getChannel() == channel) { + return i; + } + + if (i == 0) + return -1; + i--; + } + return i; +} + +std::vector BufferedTokenStream::getHiddenTokensToRight(size_t tokenIndex, ssize_t channel) { + lazyInit(); + if (tokenIndex >= _tokens.size()) { + throw IndexOutOfBoundsException(std::to_string(tokenIndex) + " not in 0.." + std::to_string(_tokens.size() - 1)); + } + + ssize_t nextOnChannel = nextTokenOnChannel(tokenIndex + 1, Lexer::DEFAULT_TOKEN_CHANNEL); + size_t to; + size_t from = tokenIndex + 1; + // if none onchannel to right, nextOnChannel=-1 so set to = last token + if (nextOnChannel == -1) { + to = static_cast(size() - 1); + } else { + to = nextOnChannel; + } + + return filterForChannel(from, to, channel); +} + +std::vector BufferedTokenStream::getHiddenTokensToRight(size_t tokenIndex) { + return getHiddenTokensToRight(tokenIndex, -1); +} + +std::vector BufferedTokenStream::getHiddenTokensToLeft(size_t tokenIndex, ssize_t channel) { + lazyInit(); + if (tokenIndex >= _tokens.size()) { + throw IndexOutOfBoundsException(std::to_string(tokenIndex) + " not in 0.." + std::to_string(_tokens.size() - 1)); + } + + if (tokenIndex == 0) { + // Obviously no tokens can appear before the first token. + return { }; + } + + ssize_t prevOnChannel = previousTokenOnChannel(tokenIndex - 1, Lexer::DEFAULT_TOKEN_CHANNEL); + if (prevOnChannel == static_cast(tokenIndex - 1)) { + return { }; + } + // if none onchannel to left, prevOnChannel=-1 then from=0 + size_t from = static_cast(prevOnChannel + 1); + size_t to = tokenIndex - 1; + + return filterForChannel(from, to, channel); +} + +std::vector BufferedTokenStream::getHiddenTokensToLeft(size_t tokenIndex) { + return getHiddenTokensToLeft(tokenIndex, -1); +} + +std::vector BufferedTokenStream::filterForChannel(size_t from, size_t to, ssize_t channel) { + std::vector hidden; + for (size_t i = from; i <= to; i++) { + Token *t = _tokens[i].get(); + if (channel == -1) { + if (t->getChannel() != Lexer::DEFAULT_TOKEN_CHANNEL) { + hidden.push_back(t); + } + } else { + if (t->getChannel() == static_cast(channel)) { + hidden.push_back(t); + } + } + } + + return hidden; +} + +bool BufferedTokenStream::isInitialized() const { + return !_needSetup; +} + +/** + * Get the text of all tokens in this buffer. 
+ */ +std::string BufferedTokenStream::getSourceName() const +{ + return _tokenSource->getSourceName(); +} + +std::string BufferedTokenStream::getText() { + fill(); + return getText(misc::Interval(0U, size() - 1)); +} + +std::string BufferedTokenStream::getText(const misc::Interval &interval) { + lazyInit(); + size_t start = interval.a; + size_t stop = interval.b; + if (start == INVALID_INDEX || stop == INVALID_INDEX) { + return ""; + } + sync(stop); + if (stop >= _tokens.size()) { + stop = _tokens.size() - 1; + } + + std::stringstream ss; + for (size_t i = start; i <= stop; i++) { + Token *t = _tokens[i].get(); + if (t->getType() == Token::EOF) { + break; + } + ss << t->getText(); + } + return ss.str(); +} + +std::string BufferedTokenStream::getText(RuleContext *ctx) { + return getText(ctx->getSourceInterval()); +} + +std::string BufferedTokenStream::getText(Token *start, Token *stop) { + if (start != nullptr && stop != nullptr) { + return getText(misc::Interval(start->getTokenIndex(), stop->getTokenIndex())); + } + + return ""; +} + +void BufferedTokenStream::fill() { + lazyInit(); + const size_t blockSize = 1000; + while (true) { + size_t fetched = fetch(blockSize); + if (fetched < blockSize) { + return; + } + } +} + +void BufferedTokenStream::InitializeInstanceFields() { + _needSetup = true; + _fetchedEOF = false; +} diff --git a/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/BufferedTokenStream.h b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/BufferedTokenStream.h new file mode 100755 index 0000000000..fab74d24c2 --- /dev/null +++ b/graph-wasm/lbug-0.12.2/lbug-src/third_party/antlr4_runtime/src/BufferedTokenStream.h @@ -0,0 +1,200 @@ +/* Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. + * Use of this file is governed by the BSD 3-clause license that + * can be found in the LICENSE.txt file in the project root. + */ + +#pragma once + +#include "TokenStream.h" + +namespace antlr4 { + + /** + * This implementation of {@link TokenStream} loads tokens from a + * {@link TokenSource} on-demand, and places the tokens in a buffer to provide + * access to any previous token by index. + * + *
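A rough illustration of that on-demand buffering, not part of this diff (MyLexer is a hypothetical generated lexer): lookahead pulls in only as many tokens as it needs, fill() forces the rest, and every fetched token stays addressable by index.

#include "antlr4-runtime.h"
// MyLexer is a hypothetical ANTLR-generated lexer.

void inspectBuffering() {
  antlr4::ANTLRInputStream input("x = 1 + 2");
  MyLexer lexer(&input);
  antlr4::BufferedTokenStream tokens(&lexer);

  tokens.LT(2);                          // lazily fetches just enough tokens for the lookahead
  size_t buffered = tokens.size();       // counts only the tokens fetched so far
  tokens.fill();                         // pulls everything up to and including EOF
  antlr4::Token *first = tokens.get(0);  // any fetched token stays addressable by index
  (void) buffered; (void) first;
}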

+ * This token stream ignores the value of {@link Token#getChannel}. If your + * parser requires the token stream to filter tokens to only those on a particular + * channel, such as {@link Token#DEFAULT_CHANNEL} or + * {@link Token#HIDDEN_CHANNEL}, use a filtering token stream such as a + * {@link CommonTokenStream}.
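A hedged sketch of that distinction (MyLexer again hypothetical, assumed to send comments and whitespace to a hidden channel): CommonTokenStream hides the off-channel tokens from LT/LA, but they remain in the buffer and can be recovered by index.

#include "antlr4-runtime.h"
#include <vector>
// MyLexer is a hypothetical generated lexer that sends COMMENT and WS tokens to a hidden channel.

void inspectChannels() {
  antlr4::ANTLRInputStream input("a /* note */ b");
  MyLexer lexer(&input);
  antlr4::CommonTokenStream tokens(&lexer);  // LT/LA skip everything off the default channel
  tokens.fill();

  // Hidden tokens remain in the buffer; assuming `b` lands at index 4 with this
  // lexer, the comment and whitespace to its left are still retrievable:
  std::vector<antlr4::Token*> hidden = tokens.getHiddenTokensToLeft(4);
  (void) hidden;
}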

+ */ + class ANTLR4CPP_PUBLIC BufferedTokenStream : public TokenStream { + public: + BufferedTokenStream(TokenSource *tokenSource); + BufferedTokenStream(const BufferedTokenStream& other) = delete; + + BufferedTokenStream& operator = (const BufferedTokenStream& other) = delete; + + virtual TokenSource* getTokenSource() const override; + virtual size_t index() override; + virtual ssize_t mark() override; + + virtual void release(ssize_t marker) override; + virtual void reset(); + virtual void seek(size_t index) override; + + virtual size_t size() override; + virtual void consume() override; + + virtual Token* get(size_t i) const override; + + /// Get all tokens from start..stop inclusively. + virtual std::vector get(size_t start, size_t stop); + + virtual size_t LA(ssize_t i) override; + virtual Token* LT(ssize_t k) override; + + /// Reset this token stream by setting its token source. + virtual void setTokenSource(TokenSource *tokenSource); + virtual std::vector getTokens(); + virtual std::vector getTokens(size_t start, size_t stop); + + /// + /// Given a start and stop index, return a List of all tokens in + /// the token type BitSet. Return null if no tokens were found. This + /// method looks at both on and off channel tokens. + /// + virtual std::vector getTokens(size_t start, size_t stop, const std::vector &types); + virtual std::vector getTokens(size_t start, size_t stop, size_t ttype); + + /// Collect all tokens on specified channel to the right of + /// the current token up until we see a token on DEFAULT_TOKEN_CHANNEL or + /// EOF. If channel is -1, find any non default channel token. + virtual std::vector getHiddenTokensToRight(size_t tokenIndex, ssize_t channel); + + /// + /// Collect all hidden tokens (any off-default channel) to the right of + /// the current token up until we see a token on DEFAULT_TOKEN_CHANNEL + /// or EOF. + /// + virtual std::vector getHiddenTokensToRight(size_t tokenIndex); + + /// + /// Collect all tokens on specified channel to the left of + /// the current token up until we see a token on DEFAULT_TOKEN_CHANNEL. + /// If channel is -1, find any non default channel token. + /// + virtual std::vector getHiddenTokensToLeft(size_t tokenIndex, ssize_t channel); + + /// + /// Collect all hidden tokens (any off-default channel) to the left of + /// the current token up until we see a token on DEFAULT_TOKEN_CHANNEL. + /// + virtual std::vector getHiddenTokensToLeft(size_t tokenIndex); + + virtual std::string getSourceName() const override; + virtual std::string getText() override; + virtual std::string getText(const misc::Interval &interval) override; + virtual std::string getText(RuleContext *ctx) override; + virtual std::string getText(Token *start, Token *stop) override; + + /// Get all tokens from lexer until EOF. + virtual void fill(); + + protected: + /** + * The {@link TokenSource} from which tokens for this stream are fetched. + */ + TokenSource *_tokenSource; + + /** + * A collection of all tokens fetched from the token source. The list is + * considered a complete view of the input once {@link #fetchedEOF} is set + * to {@code true}. + */ + std::vector> _tokens; + + /** + * The index into {@link #tokens} of the current token (next token to + * {@link #consume}). {@link #tokens}{@code [}{@link #p}{@code ]} should be + * {@link #LT LT(1)}. + * + *

This field is set to -1 when the stream is first constructed or when + * {@link #setTokenSource} is called, indicating that the first token has + * not yet been fetched from the token source. For additional information, + * see the documentation of {@link IntStream} for a description of + * Initializing Methods.
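A small sketch of that lazy initialization as seen through the public API (again with a hypothetical MyLexer); the first lookahead is what positions the cursor on the first token:

#include "antlr4-runtime.h"
// MyLexer is a hypothetical ANTLR-generated lexer.

void inspectCursor() {
  antlr4::ANTLRInputStream input("select 1;");
  MyLexer lexer(&input);
  antlr4::CommonTokenStream tokens(&lexer);

  tokens.LT(1);                   // first access triggers setup(): the cursor moves to token 0
  size_t start = tokens.index();  // index of the token LT(1) would return
  tokens.consume();               // advances the cursor by one token
  tokens.seek(start);             // rewinds the cursor without refetching anything
}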

+ */ + // ml: since -1 would require making this member signed just for this single aspect, we use a member _needSetup instead. + // Use bool isInitialized() to find out whether this stream has started reading. + size_t _p; + + /** + * Indicates whether the {@link Token#EOF} token has been fetched from + * {@link #tokenSource} and added to {@link #tokens}. This field improves + * performance for the following cases: + * + *