🎉 Basic graph wasm support

This commit is contained in:
Alejandro Alonso
2025-12-12 18:57:57 +01:00
parent 33c786498d
commit 222481fa0d
2610 changed files with 1259200 additions and 7 deletions

View File

@@ -56,6 +56,7 @@
"text-editor/v2-html-paste"
"text-editor/v2"
"render-wasm/v1"
"graph-wasm/v1"
"variants/v1"})
;; A set of features enabled by default
@@ -79,7 +80,8 @@
"text-editor/v2-html-paste"
"text-editor/v2"
"tokens/numeric-input"
"render-wasm/v1"})
"render-wasm/v1"
"graph-wasm/v1"})
;; Features that are mainly backend only or there are a proper
;; fallback when frontend reports no support for it
@@ -128,6 +130,7 @@
:feature-text-editor-v2 "text-editor/v2"
:feature-text-editor-v2-html-paste "text-editor/v2-html-paste"
:feature-render-wasm "render-wasm/v1"
:feature-graph-wasm "graph-wasm/v1"
:feature-variants "variants/v1"
:feature-token-input "tokens/numeric-input"
nil))

View File

@@ -0,0 +1,23 @@
<!DOCTYPE html>
<!-- Manual smoke-test page: loads the graph-wasm module and calls its
     exported `_hello` function once the module has been instantiated. -->
<html lang="en">
  <head>
    <meta charset="UTF-8" />
    <title>graph-wasm smoke test</title>
  </head>
  <body>
    <script type="module">
      import initWasmModule from '/js/graph-wasm.js';

      // Instantiated module; stays null until loading has finished.
      let Module = null;

      function init(moduleInstance) {
        Module = moduleInstance;
      }

      console.log("Loading module");
      initWasmModule()
        // The callback parameter is deliberately NOT called `Module`, so it
        // does not shadow the module-level binding assigned by `init`.
        .then((moduleInstance) => {
          init(moduleInstance);
          Module._hello();
        })
        // Surface load/instantiation failures instead of leaving an
        // unhandled promise rejection.
        .catch((cause) => {
          console.error("Error loading graph-wasm module", cause);
        });
    </script>
  </body>
</html>

View File

@@ -92,7 +92,7 @@
{:main
{:entries [app.worker]
:web-worker true
:prepend-js "importScripts('./render.js');"
:prepend-js "importScripts('./render.js', './graph-wasm-worker.js');"
:depends-on #{}}}
:js-options

View File

@@ -0,0 +1,12 @@
;; This Source Code Form is subject to the terms of the Mozilla Public
;; License, v. 2.0. If a copy of the MPL was not distributed with this
;; file, You can obtain one at http://mozilla.org/MPL/2.0/.
;;
;; Copyright (c) KALEIDOS INC
(ns app.graph-wasm
  "A WASM based graph API"
  (:require
   [app.graph-wasm.api :as wasm.api]))

;; Re-export of `app.graph-wasm.api/module` so callers can depend on this
;; namespace instead of the lower level api namespace.
(def module wasm.api/module)

View File

@@ -0,0 +1,91 @@
;; This Source Code Form is subject to the terms of the Mozilla Public
;; License, v. 2.0. If a copy of the MPL was not distributed with this
;; file, You can obtain one at http://mozilla.org/MPL/2.0/.
;;
;; Copyright (c) KALEIDOS INC
(ns app.graph-wasm.api
(:require
[app.common.data.macros :as dm]
[app.common.uuid :as uuid]
[app.config :as cf]
[app.graph-wasm.wasm :as wasm]
[app.render-wasm.helpers :as h]
[app.render-wasm.serializers :as sr]
[app.util.modules :as mod]
[promesa.core :as p]))
(defn hello
  "Invokes the `_hello` function exported by the loaded graph-wasm module."
  []
  (let [module wasm/internal-module]
    (h/call module "_hello")))
(defn init
  "Invokes the `_init` function exported by the loaded graph-wasm module."
  []
  (let [module wasm/internal-module]
    (h/call module "_init")))
(defn use-shape
  "Calls the `_use_shape` export of the graph-wasm module, passing the
  shape `id` as the four values read from the `uuid/get-u32` buffer.

  NOTE(review): presumably this selects the shape that the following
  `set-shape-*` calls operate on; confirm against the wasm side."
  [id]
  (let [buffer (uuid/get-u32 id)]
    (h/call wasm/internal-module "_use_shape"
            (aget buffer 0)
            (aget buffer 1)
            (aget buffer 2)
            (aget buffer 3))))
(defn set-shape-parent-id
  "Calls the `_set_shape_parent` export of the graph-wasm module, passing
  the parent `id` as the four values read from the `uuid/get-u32` buffer."
  [id]
  (let [words (uuid/get-u32 id)
        w0    (aget words 0)
        w1    (aget words 1)
        w2    (aget words 2)
        w3    (aget words 3)]
    (h/call wasm/internal-module "_set_shape_parent" w0 w1 w2 w3)))
(defn set-shape-type
  "Calls the `_set_shape_type` export with `type` translated through
  `sr/translate-shape-type`."
  [type]
  (let [raw-type (sr/translate-shape-type type)]
    (h/call wasm/internal-module "_set_shape_type" raw-type)))
(defn set-shape-selrect
  "Calls the `_set_shape_selrect` export with the `:x1`, `:y1`, `:x2` and
  `:y2` props of `selrect`, in that order."
  [selrect]
  (let [x1 (dm/get-prop selrect :x1)
        y1 (dm/get-prop selrect :y1)
        x2 (dm/get-prop selrect :x2)
        y2 (dm/get-prop selrect :y2)]
    (h/call wasm/internal-module "_set_shape_selrect" x1 y1 x2 y2)))
(defn set-object
  "Sends one `shape` to the graph-wasm module: its id (via `use-shape`),
  then its type, parent id and selrect, in that order.

  The `:shapes` (children) of the shape are not sent; only the
  `:parent-id` is transmitted."
  [shape]
  (let [id        (dm/get-prop shape :id)
        type      (dm/get-prop shape :type)
        parent-id (get shape :parent-id)
        selrect   (get shape :selrect)]
    ;; `use-shape` must run first: the setters below take no shape id.
    (use-shape id)
    (set-shape-type type)
    (set-shape-parent-id parent-id)
    (set-shape-selrect selrect)))
(defn set-objects
  "Sends every shape found in the values of the `objects` map to the
  graph-wasm module using `set-object`. Returns nil."
  [objects]
  (run! set-object (vals objects)))
(defn init-wasm-module
  "Calls the `default` export of the imported JS `module`, passing an
  options object whose `locateFile` always answers with the resolved
  href of `js/graph-wasm.wasm`. Returns whatever that export returns."
  [module]
  (let [wasm-href (cf/resolve-href "js/graph-wasm.wasm")
        options   #js {:locateFile (fn [& _] wasm-href)}
        factory   (unchecked-get module "default")]
    (factory options)))
;; Delayed loader of the graph-wasm module. When dereferenced it yields a
;; promise that resolves to `true` once the module has been imported,
;; initialized and stored in `wasm/internal-module`, or to `false` when
;; `js/dynamicImport` does not exist or when loading fails (the cause is
;; printed with `console.error`).
(defonce module
  (delay
    (if-not (exists? js/dynamicImport)
      (p/resolved false)
      (let [on-loaded (fn [instance]
                        (set! wasm/internal-module instance)
                        true)
            on-error  (fn [cause]
                        (js/console.error cause)
                        (p/resolved false))
            uri       (cf/resolve-href "js/graph-wasm.js")]
        (->> (mod/import uri)
             (p/mcat init-wasm-module)
             (p/fmap on-loaded)
             (p/merr on-error))))))

View File

@@ -0,0 +1,9 @@
;; This Source Code Form is subject to the terms of the Mozilla Public
;; License, v. 2.0. If a copy of the MPL was not distributed with this
;; file, You can obtain one at http://mozilla.org/MPL/2.0/.
;;
;; Copyright (c) KALEIDOS INC
(ns app.graph-wasm.wasm)
;; Holder for the instantiated graph-wasm module. It starts as an empty JS
;; object and is meant to be replaced with `set!` once the module has been
;; loaded.
(defonce internal-module #js {})

View File

@@ -0,0 +1,288 @@
;; This Source Code Form is subject to the terms of the Mozilla Public
;; License, v. 2.0. If a copy of the MPL was not distributed with this
;; file, You can obtain one at http://mozilla.org/MPL/2.0/.
;;
;; Copyright (c) KALEIDOS INC
;;
;; High level helpers to turn a shape subtree into a component and
;; replace equivalent subtrees by instances of that component.
(ns app.main.data.workspace.componentize
(:require
[app.common.data :as d]
[app.common.data.macros :as dm]
[app.common.files.changes-builder :as pcb]
[app.common.files.helpers :as cfh]
[app.common.geom.point :as gpt]
[app.common.logic.libraries :as cll]
[app.common.logic.shapes :as cls]
[app.common.types.shape :as cts]
[app.common.uuid :as uuid]
[app.main.data.changes :as dch]
[app.main.data.helpers :as dsh]
[app.main.data.workspace.libraries :as dwl]
[app.main.data.workspace.selection :as dws]
[app.main.data.workspace.shapes :as dwsh]
[app.main.data.workspace.undo :as dwu]
[beicon.v2.core :as rx]
[potok.v2.core :as ptk]))
;; NOTE: We keep this separate from `workspace.libraries` to avoid
;; introducing more complexity in that already big namespace.
(def ^:private instance-structural-keys
  "Shape keys that must never be copied from an original shape onto a
  freshly created component instance: they carry identity, tree
  structure or component bookkeeping, all of which is owned by the
  component system itself."
  #{;; Identity / tree structure
    :id :parent-id :frame-id :shapes
    ;; Component metadata
    :component-id :component-file :component-root
    :main-instance :remote-synced :shape-ref :touched})
(def ^:private instance-geometry-keys
  "Geometry related shape keys. These are the ones compared (and, when
  different, overridden per instance) while copying props from an
  existing subtree onto a component instance."
  #{;; Position and size
    :x :y :width :height
    ;; Orientation
    :rotation :flip-x :flip-y
    ;; Derived bounds
    :selrect :points
    ;; Proportion handling
    :proportion :proportion-lock
    ;; Transformation matrices
    :transform :transform-inverse})
(defn- props-delta
  "Returns a map with the entries of `candidate` whose value is different
  from (or missing in) `reference`. Keys that exist only in `reference`
  are ignored."
  [reference candidate]
  (reduce-kv (fn [delta k v]
               (if (= (get reference k ::none) v)
                 delta
                 (assoc delta k v)))
             {}
             candidate))

(defn- instantiate-similar-subtrees
  "Internal helper. Given an atom `id-ref` that will contain the
  `component-id`, replace each subtree rooted at the ids in
  `similar-ids` by an instance of that component.

  Emits nothing when the component id is still `uuid/zero` or when
  `similar-ids` is empty. Ids in `similar-ids` that are not present in
  the page objects are skipped when instantiating.

  The operation is performed in a single undo transaction:
  - Instantiate the component once per similar id, roughly at the same
    top-left position as the original root, and move it to the parent
    and index of the original root.
  - Copy onto each instance only the style, geometry and content values
    that differ from the reference subtree rooted at `root-id`. Shapes
    are paired by position in the children-id listings.
  - Delete the original subtrees.
  - Select `root-id` plus all the new instances."
  [id-ref root-id similar-ids]
  (ptk/reify ::instantiate-similar-subtrees
    ptk/WatchEvent
    (watch [it state _]
      (let [component-id @id-ref
            similar-ids  (vec (or similar-ids []))]
        (if (or (uuid/zero? component-id)
                (empty? similar-ids))
          (rx/empty)
          (let [file-id   (:current-file-id state)
                page      (dsh/lookup-page state)
                page-id   (:id page)
                objects   (:objects page)
                libraries (dsh/lookup-libraries state)
                fdata     (dsh/lookup-file-data state file-id)

                ;; Reference subtree: shapes used to build the component.
                ;; We'll compute per-shape deltas against this subtree so
                ;; that we only override attributes that actually differ.
                ref-subtree-ids (cfh/get-children-ids objects root-id)
                ref-all-ids     (into [root-id] ref-subtree-ids)

                undo-id (js/Symbol)

                ;; 1) Instantiate component at each similar root position,
                ;;    preserving per-instance overrides (geometry, style, etc.)
                [changes new-root-ids]
                (reduce
                 (fn [[changes acc] sid]
                   (if-let [shape (get objects sid)]
                     (let [position (gpt/point (:x shape) (:y shape))

                           ;; Remember original parent and index so we can keep
                           ;; the same ordering among the parent's children.
                           orig-parent-id (:parent-id shape)
                           orig-index     (when orig-parent-id
                                            (cfh/get-position-on-parent objects sid))

                           ;; Instantiate a new component instance at the same position
                           [new-shape changes']
                           (cll/generate-instantiate-component
                            (or changes
                                (-> (pcb/empty-changes it page-id)
                                    (pcb/with-objects objects)))
                            objects
                            file-id
                            component-id
                            position
                            page
                            libraries)

                           ;; NOTE 1: instantiating a component can introduce an extra
                           ;; wrapper frame, so we try to align the original root
                           ;; with the "equivalent" root inside the instance.
                           ;; NOTE 2: by default the instance may be created *inside*
                           ;; the original shape (because of layout / hit-testing).
                           ;; We explicitly move the new instance to the same parent
                           ;; and index as the original root, so that later deletes of
                           ;; the original subtree don't remove the new instances and
                           ;; the ordering among siblings is preserved.
                           changes' (cond-> changes'
                                      (some? orig-parent-id)
                                      (pcb/change-parent orig-parent-id [new-shape] orig-index
                                                         {:allow-altering-copies true
                                                          :ignore-touched true}))

                           objects'  (pcb/get-objects changes')
                           orig-type (:type shape)
                           new-type  (:type new-shape)

                           ;; Full original subtree (root + descendants)
                           orig-subtree-ids (cfh/get-children-ids objects sid)
                           orig-all-ids     (into [sid] orig-subtree-ids)

                           ;; Try to find an inner instance root matching the original type
                           ;; when the outer instance root type differs (e.g. rect -> frame+rect).
                           direct-new-children (cfh/get-children-ids objects' (:id new-shape))
                           candidate-instance-root
                           (when (and orig-type (not= orig-type new-type))
                             (let [cands (->> direct-new-children
                                              (filter (fn [nid]
                                                        (when-let [s (get objects' nid)]
                                                          (= (:type s) orig-type)))))]
                               ;; Only accept an unambiguous match
                               (when (= 1 (count cands))
                                 (first cands))))

                           instance-root-id (or candidate-instance-root (:id new-shape))
                           new-subtree-ids  (cfh/get-children-ids objects' instance-root-id)
                           new-all-ids      (into [instance-root-id] new-subtree-ids)

                           ;; Compute per-shape deltas against the reference
                           ;; subtree (root-id) and apply only the differences
                           ;; to the new instance subtree, so we don't blindly
                           ;; overwrite attributes that are the same.
                           changes''
                           (reduce
                            (fn [ch [idx orig-id new-id]]
                              (let [ref-shape  (get objects (nth ref-all-ids idx nil))
                                    orig-shape (get objects orig-id)]
                                (if (and ref-shape orig-shape)
                                  (let [;; Style / layout / text props (see `extract-props`)
                                        style-delta (props-delta (cts/extract-props ref-shape)
                                                                 (cts/extract-props orig-shape))
                                        ;; Geometry props
                                        geom-delta  (props-delta (select-keys ref-shape instance-geometry-keys)
                                                                 (select-keys orig-shape instance-geometry-keys))
                                        ;; Text content: if the subtree reference and the
                                        ;; original differ in `:content`, treat the whole
                                        ;; content tree as an override for this instance.
                                        content-delta? (not= (:content ref-shape) (:content orig-shape))]
                                    (cond-> ch
                                      ;; First patch style/text/layout props using the
                                      ;; canonical helpers so we don't touch structural ids.
                                      (seq style-delta)
                                      (pcb/update-shapes
                                       [new-id]
                                       (fn [s objs] (cts/patch-props s style-delta objs))
                                       {:with-objects? true})

                                      ;; Then patch geometry directly on the instance.
                                      (seq geom-delta)
                                      (pcb/update-shapes
                                       [new-id]
                                       (d/patch-object geom-delta))

                                      ;; Finally, if text content differs between the
                                      ;; reference subtree and the similar subtree,
                                      ;; override the instance content with the original.
                                      content-delta?
                                      (pcb/update-shapes
                                       [new-id]
                                       #(assoc % :content (:content orig-shape)))))
                                  ch)))
                            changes'
                            ;; [index original-id instance-id] triples, paired by position
                            (map vector (range) orig-all-ids new-all-ids))]
                       [changes'' (conj acc (:id new-shape))])
                     ;; If the shape does not exist we just skip it
                     [changes acc]))
                 [nil []]
                 similar-ids)

                changes (or changes
                            (-> (pcb/empty-changes it page-id)
                                (pcb/with-objects objects)))

                ;; 2) Delete original similar subtrees
                ;; NOTE: `d/ordered-set` with a single arg treats it as a single
                ;; element, so we must use `into` when we already have a collection.
                ids-to-delete (into (d/ordered-set) similar-ids)

                [all-parents changes]
                (cls/generate-delete-shapes
                 changes
                 fdata
                 page
                 objects
                 ids-to-delete
                 {:allow-altering-copies true})

                ;; 3) Select main instance + new instances
                ;; Root id is kept as-is; add all new roots.
                sel-ids (into (d/ordered-set) (cons root-id new-root-ids))]

            (rx/of
             (dwu/start-undo-transaction undo-id)
             (dch/commit-changes changes)
             (ptk/data-event :layout/update {:ids all-parents})
             ;; Emitted after the commit so the new instances already exist
             ;; when they become the selection.
             (dws/select-shapes sel-ids)
             (dwu/commit-undo-transaction undo-id))))))))
(defn componentize-similar-subtrees
  "Converts the subtree rooted at `root-id` into a component and then
  swaps every subtree rooted at one of `similar-ids` for an instance of
  that component.

  Works in two sequential steps that share an atom holding the
  component id:
  1) `dwl/add-component` creates the component from `root-id` and
     receives the atom.
  2) `instantiate-similar-subtrees` reads the atom and replaces the
     similar subtrees.

  A nil `similar-ids` is treated as an empty collection."
  [root-id similar-ids]
  (dm/assert!
   "expected valid uuid for `root-id`"
   (uuid? root-id))
  (let [similar-ids (if similar-ids (vec similar-ids) [])]
    (ptk/reify ::componentize-similar-subtrees
      ptk/WatchEvent
      (watch [_ _ _]
        (let [id-ref           (atom uuid/zero)
              ;; Step 1: create the component through the existing pipeline
              create-component (dwl/add-component id-ref (d/ordered-set root-id))
              ;; Step 2: replace similar subtrees by instances of it
              replace-similar  (instantiate-similar-subtrees id-ref root-id similar-ids)]
          (rx/concat
           (rx/of create-component)
           (rx/of replace-similar)))))))

View File

@@ -13,11 +13,16 @@
[app.common.geom.shapes :as gsh]
[app.common.types.color :as clr]
[app.common.types.component :as ctk]
[app.common.types.container :as ctn]
[app.common.types.path :as path]
[app.common.types.shape :as cts]
[app.common.types.shape-tree :as ctt]
[app.common.types.shape.layout :as ctl]
[app.config :as cf]
[app.graph-wasm.api :as graph-wasm.api]
[app.main.data.workspace.componentize :as dwc]
[app.main.data.workspace.modifiers :as dwm]
[app.main.data.workspace.selection :as dws]
[app.main.data.workspace.variants :as dwv]
[app.main.features :as features]
[app.main.refs :as refs]
@@ -57,8 +62,11 @@
[app.main.ui.workspace.viewport.utils :as utils]
[app.main.ui.workspace.viewport.viewport-ref :refer [create-viewport-ref]]
[app.main.ui.workspace.viewport.widgets :as widgets]
[app.main.worker :as worker]
[app.util.debug :as dbg]
[app.util.modules :as mod]
[beicon.v2.core :as rx]
[promesa.core :as p]
[rumext.v2 :as mf]))
;; --- Viewport
@@ -134,6 +142,7 @@
mod? (mf/use-state false)
space? (mf/use-state false)
z? (mf/use-state false)
g? (mf/use-state false)
cursor (mf/use-state #(utils/get-cursor :pointer-inner))
hover-ids (mf/use-state nil)
hover (mf/use-state nil)
@@ -302,12 +311,79 @@
(mf/use-fn
(mf/deps first-shape)
#(st/emit!
(dwv/add-new-variant (:id first-shape))))]
(dwv/add-new-variant (:id first-shape))))
graph-wasm-enabled? (features/use-feature "graph-wasm/v1")]
(mf/with-effect [page-id]
(when graph-wasm-enabled?
;; Initialize graph-wasm in the worker to avoid blocking main thread
(let [subscription
(->> (worker/ask! {:cmd :graph-wasm/init})
(rx/filter #(= (:status %) :ok))
(rx/take 1)
(rx/merge-map (fn [_]
(worker/ask! {:cmd :graph-wasm/set-objects
:objects base-objects}))))]
(rx/subscribe subscription
(fn [result]
(when (= (:status result) :ok)
(js/console.debug "Graph WASM initialized in worker"
(select-keys result [:processed]))))
(fn [error]
(js/console.error "Error initializing graph-wasm in worker:" error))
(fn []
(js/console.debug "Graph WASM worker operations completed"))))))
(mf/with-effect [selected @g?]
(when graph-wasm-enabled?
;; Search for similar shapes while the "g" key is pressed and
;; exactly one shape is selected. The effect re-runs whenever the
;; selection or the key state changes.
(when (and @g?
(some? selected)
(= (count selected) 1))
(let [selected-id (first selected)
selected-shape (get base-objects selected-id)
;; Skip shapes that already belong to a component
non-component? (and (some? selected-shape)
(not (ctn/in-any-component? base-objects selected-shape)))]
(println selected-shape)
(println (ctn/in-any-component? base-objects selected-shape))
(when non-component?
(let [subscription
(worker/ask! {:cmd :graph-wasm/search-similar-shapes
:shape-id selected-id})]
(rx/subscribe subscription
(fn [result]
(when (= (:status result) :ok)
(let [raw-similar-shapes (:similar-shapes result)
;; Filter out shapes that already belong to some component
;; (main instance, instance head or inside a component copy).
similar-shapes (->> raw-similar-shapes
(remove (fn [sid]
(when-let [s (get base-objects sid)]
(ctn/in-any-component? base-objects s))))
(into []))]
(when (d/not-empty? similar-shapes)
;; Transform the selected subtree into a component and
;; replace similar subtrees by instances of that component.
(st/emit! (dwc/componentize-similar-subtrees
selected-id
similar-shapes))))))
(fn [error]
(js/console.error "Error searching similar shapes:" error))
(fn []
(js/console.debug "Similar shapes search completed")))))))))
(hooks/setup-dom-events zoom disable-paste-ref in-viewport-ref read-only? drawing-tool path-drawing?)
(hooks/setup-viewport-size vport viewport-ref)
(hooks/setup-cursor cursor alt? mod? space? panning drawing-tool path-drawing? path-editing? z? read-only?)
(hooks/setup-keyboard alt? mod? space? z? shift?)
(hooks/setup-keyboard alt? mod? space? z? shift? g?)
(hooks/setup-hover-shapes page-id move-stream base-objects transform selected mod? hover measure-hover
hover-ids hover-top-frame-id @hover-disabled? focus zoom show-measures?)
(hooks/setup-viewport-modifiers modifiers base-objects)

View File

@@ -124,7 +124,7 @@
(reset! cursor new-cursor))))))
(defn setup-keyboard
[alt* mod* space* z* shift*]
[alt* mod* space* z* shift* g*]
(let [kbd-zoom-s
(mf/with-memo []
(->> ms/keyboard
@@ -151,12 +151,22 @@
(rx/filter kbd/z?)
(rx/filter (complement kbd/editing-event?))
(rx/map kbd/key-down-event?)
(rx/pipe (rxo/distinct-contiguous))))]
(rx/pipe (rxo/distinct-contiguous))))
kbd-g-s
(mf/with-memo []
(let [c-pred (kbd/is-key-ignore-case? "g")]
(->> ms/keyboard
(rx/filter c-pred)
(rx/filter (complement kbd/editing-event?))
(rx/map kbd/key-down-event?)
(rx/pipe (rxo/distinct-contiguous)))))]
(hooks/use-stream ms/keyboard-alt (partial reset! alt*))
(hooks/use-stream ms/keyboard-space (partial reset! space*))
(hooks/use-stream kbd-z-s (partial reset! z*))
(hooks/use-stream kbd-shift-s (partial reset! shift*))
(hooks/use-stream kbd-g-s (partial reset! g*))
(hooks/use-stream ms/keyboard-mod
(fn [value]
(reset! mod* value)

View File

@@ -122,6 +122,7 @@
mod? (mf/use-state false)
space? (mf/use-state false)
z? (mf/use-state false)
c? (mf/use-state false)
cursor (mf/use-state (utils/get-cursor :pointer-inner))
hover-ids (mf/use-state nil)
hover (mf/use-state nil)
@@ -360,7 +361,7 @@
(hooks/setup-dom-events zoom disable-paste-ref in-viewport-ref read-only? drawing-tool path-drawing?)
(hooks/setup-viewport-size vport viewport-ref)
(hooks/setup-cursor cursor alt? mod? space? panning drawing-tool path-drawing? path-editing? z? read-only?)
(hooks/setup-keyboard alt? mod? space? z? shift?)
(hooks/setup-keyboard alt? mod? space? z? shift? c?)
(hooks/setup-hover-shapes page-id move-stream base-objects transform selected mod? hover measure-hover
hover-ids hover-top-frame-id @hover-disabled? focus zoom show-measures?)
(hooks/setup-shortcuts path-editing? path-drawing? text-editing? grid-editing?)

View File

@@ -11,6 +11,7 @@
[app.common.schema :as sm]
[app.common.types.objects-map]
[app.util.object :as obj]
[app.worker.graph-wasm]
[app.worker.impl :as impl]
[app.worker.import]
[app.worker.index]

View File

@@ -0,0 +1,181 @@
;; This Source Code Form is subject to the terms of the Mozilla Public
;; License, v. 2.0. If a copy of the MPL was not distributed with this
;; file, You can obtain one at http://mozilla.org/MPL/2.0/.
;;
;; Copyright (c) KALEIDOS INC
(ns app.worker.graph-wasm
"Graph WASM operations within the worker."
(:require
[app.common.data.macros :as dm]
[app.common.logging :as log]
[app.common.uuid :as uuid]
[app.config :as cf]
[app.graph-wasm.wasm :as wasm]
[app.render-wasm.helpers :as h]
[app.render-wasm.serializers :as sr]
[app.worker.impl :as impl]
[beicon.v2.core :as rx]
[promesa.core :as p]))
(log/set-level! :info)
(defn- use-shape
  "Calls the `_use_shape` export of `module`, passing the shape `id` as
  the four values read from the `uuid/get-u32` buffer."
  [module id]
  (let [words (uuid/get-u32 id)
        w0    (aget words 0)
        w1    (aget words 1)
        w2    (aget words 2)
        w3    (aget words 3)]
    (h/call module "_use_shape" w0 w1 w2 w3)))
(defn- set-shape-parent-id
  "Calls the `_set_shape_parent` export of `module`, passing the parent
  `id` as the four values read from the `uuid/get-u32` buffer."
  [module id]
  (let [words (uuid/get-u32 id)
        w0    (aget words 0)
        w1    (aget words 1)
        w2    (aget words 2)
        w3    (aget words 3)]
    (h/call module "_set_shape_parent" w0 w1 w2 w3)))
(defn- set-shape-type
  "Calls the `_set_shape_type` export of `module` with `type` translated
  through `sr/translate-shape-type`."
  [module type]
  (let [raw-type (sr/translate-shape-type type)]
    (h/call module "_set_shape_type" raw-type)))
(defn- set-shape-selrect
  "Calls the `_set_shape_selrect` export of `module` with the `:x1`,
  `:y1`, `:x2` and `:y2` props of `selrect`, in that order."
  [module selrect]
  (let [x1 (dm/get-prop selrect :x1)
        y1 (dm/get-prop selrect :y1)
        x2 (dm/get-prop selrect :x2)
        y2 (dm/get-prop selrect :y2)]
    (h/call module "_set_shape_selrect" x1 y1 x2 y2)))
(defn- set-object
  "Sends one `shape` to `module`: its id (via `use-shape`), then its
  type, parent id and selrect, in that order."
  [module shape]
  (let [shape-id   (dm/get-prop shape :id)
        shape-type (dm/get-prop shape :type)
        parent     (:parent-id shape)
        bounds     (:selrect shape)]
    ;; `use-shape` goes first: the setters below take no shape id.
    (use-shape module shape-id)
    (set-shape-type module shape-type)
    (set-shape-parent-id module parent)
    (set-shape-selrect module bounds)))
;; Delayed instantiation of the graph-wasm module inside the worker. On
;; first deref it reads the `GraphWasmModule` global, calls its `default`
;; export with a `locateFile` that always answers with the resolved href
;; of `js/graph-wasm.wasm`, and attaches a handler that either stores the
;; instance in `wasm/internal-module` or prints the failure cause.
(defonce ^:private graph-wasm-module
  (delay
    (let [bundle     (unchecked-get js/globalThis "GraphWasmModule")
          factory    (unchecked-get bundle "default")
          wasm-href  (cf/resolve-href "js/graph-wasm.wasm")
          on-settled (fn [instance cause]
                       (if-not cause
                         (set! wasm/internal-module instance)
                         (js/console.error cause)))]
      (p/fnly on-settled
              (factory #js {:locateFile (constantly wasm-href)})))))
;; Worker handler for `:graph-wasm/init`. Waits for the graph-wasm module
;; and calls its `_init` export. The returned observable emits
;; `{:status :ok}` on success, `{:status :error ...}` when the resolved
;; module is falsy, and signals an error when `_init` throws or the
;; module promise is rejected.
(defmethod impl/handler :graph-wasm/init
  [message transfer]
  (rx/create
   (fn [subs]
     (let [on-module
           (fn [module]
             (if-not module
               (do
                 (log/warn :hint "Graph WASM module not available")
                 (rx/push! subs {:status :error :message "Module not available"})
                 (rx/end! subs))
               (try
                 (h/call module "_init")
                 (rx/push! subs {:status :ok})
                 (rx/end! subs)
                 (catch :default cause
                   (log/error :hint "Error in graph-wasm/init" :cause cause)
                   (rx/error! subs cause)
                   (rx/end! subs)))))

           on-load-error
           (fn [cause]
             (log/error :hint "Error loading graph-wasm module" :cause cause)
             (rx/error! subs cause)
             (rx/end! subs))]
       (-> @graph-wasm-module
           (p/then on-module)
           (p/catch on-load-error)))
     nil)))
;; Worker handler for `:graph-wasm/set-objects`. Sends every shape in the
;; values of the message `:objects` map to the graph-wasm module and then
;; calls its `_generate_db` export. The returned observable emits
;; `{:status :ok :processed <count of objects>}` on success,
;; `{:status :error ...}` when the resolved module is falsy, and signals
;; an error when any call throws or the module promise is rejected.
(defmethod impl/handler :graph-wasm/set-objects
  [message transfer]
  (let [objects (:objects message)]
    (rx/create
     (fn [subs]
       (let [on-module
             (fn [module]
               (if-not module
                 (do
                   (log/warn :hint "Graph WASM module not available")
                   (rx/push! subs {:status :error :message "Module not available"})
                   (rx/end! subs))
                 (try
                   (run! #(set-object module %) (vals objects))
                   (h/call module "_generate_db")
                   (rx/push! subs {:status :ok :processed (count objects)})
                   (rx/end! subs)
                   (catch :default cause
                     (log/error :hint "Error in graph-wasm/set-objects" :cause cause)
                     (rx/error! subs cause)
                     (rx/end! subs)))))

             on-load-error
             (fn [cause]
               (log/error :hint "Error loading graph-wasm module" :cause cause)
               (rx/error! subs cause)
               (rx/end! subs))]
         (-> @graph-wasm-module
             (p/then on-module)
             (p/catch on-load-error)))
       nil))))
(defn- read-u32-le
  "Reads four consecutive bytes of `heap` starting at `offset` as a
  little-endian unsigned 32-bit integer."
  [heap offset]
  ;; `bit-shift-left ... 24` produces a signed 32-bit value; the final
  ;; unsigned shift by 0 converts it back to the unsigned range.
  (unsigned-bit-shift-right
   (bit-or (aget heap offset)
           (bit-shift-left (aget heap (+ offset 1)) 8)
           (bit-shift-left (aget heap (+ offset 2)) 16)
           (bit-shift-left (aget heap (+ offset 3)) 24))
   0))

(defn- read-similar-shapes
  "Decodes the result buffer located at `ptr` inside `heap`: a u32 count
  followed by that many 16-byte uuids. Returns a vector of uuids."
  [heap ptr]
  (let [total (read-u32-le heap ptr)
        base  (+ ptr 4)]
    (mapv (fn [index]
            (let [start (+ base (* index 16))]
              (uuid/from-bytes (.slice heap start (+ start 16)))))
          (range total))))

;; Worker handler for `:graph-wasm/search-similar-shapes`. Calls the
;; `_search_similar_shapes` export with the message `:shape-id` and decodes
;; the returned buffer. The returned observable emits
;; `{:status :ok :similar-shapes [uuid ...]}` on success,
;; `{:status :error ...}` when the resolved module is falsy, and signals an
;; error when any call throws or the module promise is rejected.
(defmethod impl/handler :graph-wasm/search-similar-shapes
  [message transfer]
  (let [shape-id (:shape-id message)]
    (rx/create
     (fn [subs]
       (-> @graph-wasm-module
           (p/then (fn [module]
                     (if module
                       (try
                         (let [buffer  (uuid/get-u32 shape-id)
                               ptr-raw (h/call module "_search_similar_shapes"
                                               (aget buffer 0)
                                               (aget buffer 1)
                                               (aget buffer 2)
                                               (aget buffer 3))
                               ;; Convert pointer to unsigned 32-bit (handle negative
                               ;; numbers from WASM)
                               ptr     (unsigned-bit-shift-right ptr-raw 0)

                               similar-shapes
                               (try
                                 ;; NOTE(review): a null pointer is treated as "no
                                 ;; results" instead of decoding address 0; confirm
                                 ;; against the wasm side.
                                 (if (zero? ptr)
                                   []
                                   (read-similar-shapes (unchecked-get module "HEAPU8") ptr))
                                 (finally
                                   ;; Free the buffer even when decoding fails
                                   (h/call module "_free_similar_shapes_buffer")))]
                           (rx/push! subs {:status :ok :similar-shapes similar-shapes})
                           (rx/end! subs))
                         (catch :default cause
                           (log/error :hint "Error in graph-wasm/search-similar-shapes" :cause cause)
                           (rx/error! subs cause)
                           (rx/end! subs)))
                       (do
                         (log/warn :hint "Graph WASM module not available")
                         (rx/push! subs {:status :error :message "Module not available"})
                         (rx/end! subs)))))
           (p/catch (fn [cause]
                      (log/error :hint "Error loading graph-wasm module" :cause cause)
                      (rx/error! subs cause)
                      (rx/end! subs))))
       nil))))

View File

@@ -47,3 +47,6 @@
result (svg-filters/apply-svg-filters shape)]
(is (= shape result))))

View File

@@ -0,0 +1,6 @@
[target.wasm32-unknown-emscripten]
# Note: Not using atomics to avoid recompiling std library
# We're running without pthreads, so lbug needs to work single-threaded
rustflags = ["-C", "link-arg=-fexceptions"]
# Linker is configured via environment variable CARGO_TARGET_WASM32_UNKNOWN_EMSCRIPTEN_LINKER in _build_env

5
graph-wasm/.gitignore vendored Normal file
View File

@@ -0,0 +1,5 @@
target/
debug/
**/*.rs.bk

486
graph-wasm/Cargo.lock generated Normal file
View File

@@ -0,0 +1,486 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 4
[[package]]
name = "anstyle"
version = "1.0.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78"
[[package]]
name = "arrayvec"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
[[package]]
name = "autocfg"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
[[package]]
name = "bumpalo"
version = "3.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43"
[[package]]
name = "cc"
version = "1.2.49"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "90583009037521a116abf44494efecd645ba48b6622457080f080b85544e2215"
dependencies = [
"find-msvc-tools",
"shlex",
]
[[package]]
name = "cfg-if"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801"
[[package]]
name = "clap"
version = "4.5.53"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c9e340e012a1bf4935f5282ed1436d1489548e8f72308207ea5df0e23d2d03f8"
dependencies = [
"clap_builder",
]
[[package]]
name = "clap_builder"
version = "4.5.53"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d76b5d13eaa18c901fd2f7fca939fefe3a0727a953561fefdf3b2922b8569d00"
dependencies = [
"anstyle",
"clap_lex",
"strsim",
]
[[package]]
name = "clap_lex"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1d728cc89cf3aee9ff92b05e62b19ee65a02b5702cff7d5a377e32c6ae29d8d"
[[package]]
name = "cmake"
version = "0.1.56"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b042e5d8a74ae91bb0961acd039822472ec99f8ab0948cbf6d1369588f8be586"
dependencies = [
"cc",
]
[[package]]
name = "codespan-reporting"
version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3538270d33cc669650c4b093848450d380def10c331d38c768e34cac80576e6e"
dependencies = [
"termcolor",
"unicode-width",
]
[[package]]
name = "cxx"
version = "1.0.138"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3956d60afa98653c5a57f60d7056edd513bfe0307ef6fb06f6167400c3884459"
dependencies = [
"cc",
"cxxbridge-cmd",
"cxxbridge-flags",
"cxxbridge-macro",
"foldhash",
"link-cplusplus",
]
[[package]]
name = "cxx-build"
version = "1.0.138"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a4b7522f539fe056f1d6fc8577d8ab731451f6f33a89b1e5912e22b76c553e7"
dependencies = [
"cc",
"codespan-reporting",
"proc-macro2",
"quote",
"scratch",
"syn",
]
[[package]]
name = "cxxbridge-cmd"
version = "1.0.138"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f01e92ab4ce9fd4d16e3bb11b158d98cbdcca803c1417aa43130a6526fbf208"
dependencies = [
"clap",
"codespan-reporting",
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "cxxbridge-flags"
version = "1.0.138"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c41cbfab344869e70998b388923f7d1266588f56c8ca284abf259b1c1ffc695"
[[package]]
name = "cxxbridge-macro"
version = "1.0.138"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "88d82a2f759f0ad3eae43b96604efd42b1d4729a35a6f2dc7bdb797ae25d9284"
dependencies = [
"proc-macro2",
"quote",
"rustversion",
"syn",
]
[[package]]
name = "deranged"
version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ececcb659e7ba858fb4f10388c250a7252eb0a27373f1a72b8748afdd248e587"
dependencies = [
"powerfmt",
]
[[package]]
name = "find-msvc-tools"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3a3076410a55c90011c298b04d0cfa770b00fa04e1e3c97d3f6c9de105a03844"
[[package]]
name = "foldhash"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
[[package]]
name = "getrandom"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd"
dependencies = [
"cfg-if",
"libc",
"r-efi",
"wasip2",
]
[[package]]
name = "graph"
version = "0.1.0"
dependencies = [
"lbug",
"uuid",
]
[[package]]
name = "js-sys"
version = "0.3.83"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "464a3709c7f55f1f721e5389aa6ea4e3bc6aba669353300af094b29ffbdde1d8"
dependencies = [
"once_cell",
"wasm-bindgen",
]
[[package]]
name = "lbug"
version = "0.12.2"
dependencies = [
"cmake",
"cxx",
"cxx-build",
"rust_decimal",
"rustversion",
"time",
"uuid",
]
[[package]]
name = "libc"
version = "0.2.178"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091"
[[package]]
name = "link-cplusplus"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f78c730aaa7d0b9336a299029ea49f9ee53b0ed06e9202e8cb7db9bae7b8c82"
dependencies = [
"cc",
]
[[package]]
name = "num-conv"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"
[[package]]
name = "num-traits"
version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
dependencies = [
"autocfg",
]
[[package]]
name = "once_cell"
version = "1.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
[[package]]
name = "powerfmt"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391"
[[package]]
name = "proc-macro2"
version = "1.0.103"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ee95bc4ef87b8d5ba32e8b7714ccc834865276eab0aed5c9958d00ec45f49e8"
dependencies = [
"unicode-ident",
]
[[package]]
name = "quote"
version = "1.0.42"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a338cc41d27e6cc6dce6cefc13a0729dfbb81c262b1f519331575dd80ef3067f"
dependencies = [
"proc-macro2",
]
[[package]]
name = "r-efi"
version = "5.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f"
[[package]]
name = "rust_decimal"
version = "1.39.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "35affe401787a9bd846712274d97654355d21b2a2c092a3139aabe31e9022282"
dependencies = [
"arrayvec",
"num-traits",
]
[[package]]
name = "rustversion"
version = "1.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d"
[[package]]
name = "scratch"
version = "1.0.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d68f2ec51b097e4c1a75b681a8bec621909b5e91f15bb7b840c4f2f7b01148b2"
[[package]]
name = "serde"
version = "1.0.228"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e"
dependencies = [
"serde_core",
]
[[package]]
name = "serde_core"
version = "1.0.228"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.228"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "shlex"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
[[package]]
name = "strsim"
version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
[[package]]
name = "syn"
version = "2.0.111"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "390cc9a294ab71bdb1aa2e99d13be9c753cd2d7bd6560c77118597410c4d2e87"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "termcolor"
version = "1.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755"
dependencies = [
"winapi-util",
]
[[package]]
name = "time"
version = "0.3.44"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91e7d9e3bb61134e77bde20dd4825b97c010155709965fedf0f49bb138e52a9d"
dependencies = [
"deranged",
"num-conv",
"powerfmt",
"serde",
"time-core",
]
[[package]]
name = "time-core"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "40868e7c1d2f0b8d73e4a8c7f0ff63af4f6d19be117e90bd73eb1d62cf831c6b"
[[package]]
name = "unicode-ident"
version = "1.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5"
[[package]]
name = "unicode-width"
version = "0.1.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af"
[[package]]
name = "uuid"
version = "1.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2e054861b4bd027cd373e18e8d8d8e6548085000e41290d95ce0c373a654b4a"
dependencies = [
"getrandom",
"js-sys",
"wasm-bindgen",
]
[[package]]
name = "wasip2"
version = "1.0.1+wasi-0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0562428422c63773dad2c345a1882263bbf4d65cf3f42e90921f787ef5ad58e7"
dependencies = [
"wit-bindgen",
]
[[package]]
name = "wasm-bindgen"
version = "0.2.106"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d759f433fa64a2d763d1340820e46e111a7a5ab75f993d1852d70b03dbb80fd"
dependencies = [
"cfg-if",
"once_cell",
"rustversion",
"wasm-bindgen-macro",
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-macro"
version = "0.2.106"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48cb0d2638f8baedbc542ed444afc0644a29166f1595371af4fecf8ce1e7eeb3"
dependencies = [
"quote",
"wasm-bindgen-macro-support",
]
[[package]]
name = "wasm-bindgen-macro-support"
version = "0.2.106"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cefb59d5cd5f92d9dcf80e4683949f15ca4b511f4ac0a6e14d4e1ac60c6ecd40"
dependencies = [
"bumpalo",
"proc-macro2",
"quote",
"syn",
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-shared"
version = "0.2.106"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cbc538057e648b67f72a982e708d485b2efa771e1ac05fec311f9f63e5800db4"
dependencies = [
"unicode-ident",
]
[[package]]
name = "winapi-util"
version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22"
dependencies = [
"windows-sys",
]
[[package]]
name = "windows-link"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5"
[[package]]
name = "windows-sys"
version = "0.61.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc"
dependencies = [
"windows-link",
]
[[package]]
name = "wit-bindgen"
version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59"

31
graph-wasm/Cargo.toml Normal file
View File

@@ -0,0 +1,31 @@
[package]
name = "graph"
version = "0.1.0"
edition = "2021"
repository = "https://github.com/penpot/penpot"
license-file = "../LICENSE"
description = "Wasm-based graph module for Penpot"
build = "build.rs"
[[bin]]
name = "graph_wasm"
path = "src/main.rs"
[profile.release]
opt-level = "s"
# Note: We need panic=unwind because core requires it
# We'll compile std from source with atomics support instead
[profile.dev]
# Note: We need panic=unwind because core requires it
# We'll compile std from source with atomics support instead
[dependencies]
lbug = "0.12.2"
uuid = { version = "1.11.0", features = ["v4", "js"] }
# Patch lbug to use local version with wasm32-unknown-emscripten support
[patch.crates-io]
lbug = { path = "./lbug-0.12.2" }

197
graph-wasm/README_WASM.md Normal file
View File

@@ -0,0 +1,197 @@
# lbug WASM Support
This document describes the modifications made to the `lbug` crate to enable compilation for the `wasm32-unknown-emscripten` target, and the build requirements for using it in a WASM context.
## Overview
The `lbug` crate is a Rust wrapper around a C++ graph database library. To compile it for WebAssembly using Emscripten, several modifications were necessary to handle:
1. C++ exception handling in WASM
2. Conditional compilation for WASM-specific code paths
3. Proper linking of static libraries for Emscripten
4. CMake configuration for single-threaded mode
## Changes Made to lbug
### 1. `build.rs` Modifications
The build script (`build.rs`) was modified to detect and handle the `wasm32-unknown-emscripten` target:
#### WASM Detection
```rust
fn is_wasm_emscripten() -> bool {
env::var("TARGET")
.map(|t| t == "wasm32-unknown-emscripten")
.unwrap_or(false)
}
```
#### CMake Configuration (`build_bundled_cmake()`)
- **Single-threaded mode**: Sets `SINGLE_THREADED=TRUE` for WASM builds (required by Emscripten)
- The Emscripten toolchain is automatically detected when `CC`/`CXX` point to `emcc`/`em++`
#### FFI Build Configuration (`build_ffi()`)
- **C++20 standard**: Uses `-std=c++20` flag for WASM
- **Exception support**: Enables `-fexceptions` flag (exceptions must be enabled at compile time)
- Note: `-sDISABLE_EXCEPTION_CATCHING=0` is a linker flag and should be set via `EMCC_CFLAGS`
#### Library Linking (`link_libraries()`)
- **Explicit dependency linking**: For WASM, all static dependencies are explicitly linked:
- `utf8proc`, `antlr4_cypher`, `antlr4_runtime`, `re2`, `fastpfor`
- `parquet`, `thrift`, `snappy`, `zstd`, `miniz`
- `mbedtls`, `brotlidec`, `brotlicommon`, `lz4`
- `roaring_bitmap`, `simsimd`
- **Linking order**: Libraries are linked after FFI compilation for WASM (different from native builds)
### 2. `src/error.rs` Modifications
The error handling code was modified to conditionally compile C++ exception support:
#### Conditional C++ Exception Variant
The `Error::CxxException` variant and related implementations are conditionally compiled:
```rust
pub enum Error {
    // ... other variants ...
    #[cfg(not(target_arch = "wasm32"))]
    CxxException(cxx::Exception),
    // ...
}
```
#### Exception Mapping for WASM
In WASM builds, `cxx::Exception` is mapped to `Error::FailedQuery`:
```rust
impl From<cxx::Exception> for Error {
fn from(item: cxx::Exception) -> Self {
#[cfg(not(target_arch = "wasm32"))]
{
Error::CxxException(item)
}
#[cfg(target_arch = "wasm32")]
{
// In wasm, CxxException is not available, map to a generic error
Error::FailedQuery(item.to_string())
}
}
}
```
**Note**: This change does not affect the rest of `lbug` due to `#[cfg]` guards, ensuring native builds remain unchanged.
## Build Requirements
### 1. Using the Modified lbug Crate
To use the modified `lbug` crate in your project, add a `[patch.crates-io]` section to your `Cargo.toml`:
```toml
[dependencies]
lbug = "0.12.2"
# Patch lbug to use local version with wasm32-unknown-emscripten support
[patch.crates-io]
lbug = { path = "./lbug-0.12.2" }
```
### 2. Emscripten Environment Setup
The build requires Emscripten to be properly configured. The following environment variables should be set:
#### Memory Configuration
```bash
export EM_INITIAL_HEAP=$((256 * 1024 * 1024)) # 256 MB initial heap
export EM_MAXIMUM_MEMORY=$((4 * 1024 * 1024 * 1024)) # 4 GB maximum
export EM_MEMORY_GROWTH_GEOMETRIC_STEP=0.8
export EM_MALLOC=dlmalloc
```
#### Compiler/Linker Configuration
```bash
# Prevent cc-rs from adding default flags that conflict with Emscripten
export CRATE_CC_NO_DEFAULTS=1
# Emscripten compiler flags
export EMCC_CFLAGS="--no-entry \
-sASSERTIONS=1 \
-sALLOW_TABLE_GROWTH=1 \
-sALLOW_MEMORY_GROWTH=1 \
-sINITIAL_HEAP=$EM_INITIAL_HEAP \
-sMEMORY_GROWTH_GEOMETRIC_STEP=$EM_MEMORY_GROWTH_GEOMETRIC_STEP \
-sMAXIMUM_MEMORY=$EM_MAXIMUM_MEMORY \
-sERROR_ON_UNDEFINED_SYMBOLS=0 \
-sDISABLE_EXCEPTION_CATCHING=0 \
-sEXPORT_NAME=createGraphModule \
-sEXPORTED_RUNTIME_METHODS=stringToUTF8,HEAPU8 \
-sENVIRONMENT=web \
-sMODULARIZE=1 \
-sEXPORT_ES6=1"
```
#### Function Exports
To control which functions are exported (avoiding issues with `$` symbols in auto-generated exports), use `RUSTFLAGS`:
```bash
export RUSTFLAGS="-C link-arg=-sEXPORTED_FUNCTIONS=@${SCRIPT_DIR}/exports.txt -C link-arg=-sEXPORT_ALL=0"
```
Where `exports.txt` contains the list of functions to export (one per line, with `_` prefix):
```
_hello
_generate_db
_init
_search_similar_shapes
# ... etc
```
### 3. Build Process
1. **Source Emscripten environment**:
```bash
source /opt/emsdk/emsdk_env.sh
```
2. **Set build environment**:
```bash
source ./_build_env
```
3. **Build**:
```bash
cargo build --target=wasm32-unknown-emscripten
```
## Key Differences from Native Builds
1. **Single-threaded**: WASM builds use `SINGLE_THREADED=TRUE` in CMake
2. **Exception handling**: C++ exceptions are enabled at compile time (`-fexceptions`) and at link time (`-sDISABLE_EXCEPTION_CATCHING=0`)
3. **Linking order**: Libraries are linked after FFI compilation for WASM
4. **Error handling**: C++ exceptions are mapped to `FailedQuery` errors in WASM
5. **Function exports**: Manual control of exported functions via `EXPORTED_FUNCTIONS` file
## Troubleshooting
### Missing Symbols
If you encounter "missing function" errors at runtime, ensure:
- All required static libraries are listed in `link_libraries()` for WASM
- Libraries are linked in the correct order (after FFI compilation)
- `EXPORTED_FUNCTIONS` includes all functions you need to call from JavaScript
### Invalid Export Names
If you see errors like `invalid export name: cxxbridge1$exception`:
- Use `EXPORT_ALL=0` and manually specify functions in `exports.txt`
- Avoid using `EXPORT_ALL=1` with auto-generated export lists that may contain `$` symbols
### CMake Compiler Detection Errors
If CMake fails to detect the compiler:
- Ensure `CC` and `CXX` environment variables point to `emcc` and `em++`
- The Emscripten toolchain should be automatically detected by `cmake-rs`
## References
- [Emscripten Documentation](https://emscripten.org/docs/getting_started/index.html)
- [Rust and WebAssembly](https://rustwasm.github.io/docs/book/)
- [cxx crate documentation](https://cxx.rs/)

105
graph-wasm/_build_env Normal file
View File

@@ -0,0 +1,105 @@
#!/usr/bin/env bash
# --------------------
# Build configuration
# --------------------
_SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
export CURRENT_VERSION="${CURRENT_VERSION:-develop}"
export BUILD_NAME="${BUILD_NAME:-graph-wasm}"
export CARGO_BUILD_TARGET="${CARGO_BUILD_TARGET:-wasm32-unknown-emscripten}"
# Keep the emscripten cache in-repo so system cache cleaners do not wipe it.
export EM_CACHE="${EM_CACHE:-${_SCRIPT_DIR}/.emsdk_cache}"
export CARGO_TARGET_DIR="${CARGO_TARGET_DIR:-${_SCRIPT_DIR}/target}"
if [[ -z "${CARGO_INCREMENTAL:-}" && "${NODE_ENV:-}" != "production" ]]; then
export CARGO_INCREMENTAL=1
fi
if [[ -z "${RUSTC_WRAPPER:-}" ]] && command -v sccache >/dev/null 2>&1; then
export RUSTC_WRAPPER=sccache
fi
export CRATE_CC_NO_DEFAULTS=1
# BUILD_MODE: "release" or "debug"; it is also the profile directory name used
# under the cargo target dir by copy_artifacts.
# NODE_ENV=production always forces a release build; otherwise the first
# positional argument selects the mode and defaults to "debug".
if [[ "${NODE_ENV:-}" == "production" ]]; then
BUILD_MODE=release
else
BUILD_MODE="${1:-debug}"
fi
export BUILD_MODE
# --------------------
# Emscripten memory
# --------------------
export EM_INITIAL_HEAP=$((256 * 1024 * 1024))
export EM_MAXIMUM_MEMORY=$((4 * 1024 * 1024 * 1024))
export EM_MEMORY_GROWTH_GEOMETRIC_STEP=0.8
export EM_MALLOC=dlmalloc
# --------------------
# Flags
# --------------------
EMCC_COMMON_FLAGS=(
--no-entry
-sASSERTIONS=1
-sALLOW_TABLE_GROWTH=1
-sALLOW_MEMORY_GROWTH=1
-sINITIAL_HEAP=$EM_INITIAL_HEAP
-sMEMORY_GROWTH_GEOMETRIC_STEP=$EM_MEMORY_GROWTH_GEOMETRIC_STEP
-sMAXIMUM_MEMORY=$EM_MAXIMUM_MEMORY
-sERROR_ON_UNDEFINED_SYMBOLS=0
-sDISABLE_EXCEPTION_CATCHING=0
-sEXPORT_NAME=createGraphModule
-sEXPORTED_RUNTIME_METHODS=stringToUTF8,HEAPU8
-sENVIRONMENT=web
-sMODULARIZE=1
-sEXPORT_ES6=1
)
export RUSTFLAGS="-C link-arg=-sEXPORTED_FUNCTIONS=@${_SCRIPT_DIR}/exports.txt -C link-arg=-sEXPORT_ALL=0"
# Mode-specific flags.
# Release: size-optimised (-Os) and cargo gets --release.
# Debug: debug info, verbose emcc output and an explicit allocator choice.
# In both modes every positional argument after the first is forwarded to cargo.
if [[ "$BUILD_MODE" == "release" ]]; then
export EMCC_CFLAGS="-Os ${EMCC_COMMON_FLAGS[*]}"
CARGO_PARAMS=(--release "${@:2}")
else
export EMCC_CFLAGS="-g -sVERBOSE=1 -sMALLOC=$EM_MALLOC ${EMCC_COMMON_FLAGS[*]}"
CARGO_PARAMS=("${@:2}")
fi
# NOTE(review): CARGO_PARAMS is an array and is consumed by build() in this same
# shell; bash is documented as not exporting array variables to child
# processes, so this export is presumably a no-op — confirm before relying on it.
export CARGO_PARAMS
# --------------------
# Tasks
# --------------------
# Remove cargo build artifacts by running `cargo clean`.
clean() {
cargo clean
}
# Pre-build hook invoked by the `build` script before build(); currently a
# no-op (`:` does nothing and returns success).
setup() {
:
}
# Run `cargo build`, forwarding the arguments collected in CARGO_PARAMS
# (--release in release mode plus any extra positional arguments).
build() {
cargo build "${CARGO_PARAMS[@]}"
}
# Copy the generated JS glue and wasm binary into a destination directory,
# renaming them to $BUILD_NAME.{js,wasm}, and rewrite the wasm file name inside
# the JS so it points at the renamed file with a ?version=$CURRENT_VERSION
# cache-busting query.
#   $1 - destination directory (created if it does not exist)
copy_artifacts() {
    local dest=$1
    # Honour CARGO_TARGET_DIR (exported by this file and overridable by the
    # caller) instead of assuming ./target relative to the current directory.
    local base="${CARGO_TARGET_DIR:-target}/$CARGO_BUILD_TARGET/$BUILD_MODE"

    mkdir -p "$dest"

    cp "$base/graph_wasm.js" "$dest/$BUILD_NAME.js"
    cp "$base/graph_wasm.wasm" "$dest/$BUILD_NAME.wasm"

    # The dot is escaped so only the literal file name is rewritten.
    sed -i "s/graph_wasm\.wasm/$BUILD_NAME.wasm?version=$CURRENT_VERSION/g" \
        "$dest/$BUILD_NAME.js"
}
# Post-copy hook invoked by the `build` script after copy_artifacts; currently
# a no-op (`:` does nothing and returns success).
copy_shared_artifact() {
:
}

20
graph-wasm/build Executable file
View File

@@ -0,0 +1,20 @@
#!/usr/bin/env bash
# Build the graph wasm module and copy the artifacts into the frontend's
# public JS directory. Positional arguments are seen by ./_build_env when it is
# sourced: $1 selects the build mode, the rest are forwarded to cargo.

# Load the Emscripten SDK environment (quietly).
EMSDK_QUIET=1 . /opt/emsdk/emsdk_env.sh

_SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd)
# Quote the path (it may contain spaces) and stop if we cannot enter it, so the
# relative paths below never resolve against the wrong directory.
pushd "$_SCRIPT_DIR" || exit 1

# Defines the build configuration and the setup/build/copy_* tasks.
. ./_build_env

# From here on abort on the first failing command and echo each command.
set -ex

setup
build
copy_artifacts "../frontend/resources/public/js"
copy_shared_artifact

# No popd: the directory change only affects this script's own process, which
# ends here (the previous popd after `exit` was unreachable).
exit $?

26
graph-wasm/build.log Normal file
View File

@@ -0,0 +1,26 @@
~/penpot/graph-wasm ~/penpot/graph-wasm
./_build_env: line 60: export: `CXX_wasm32-unknown-emscripten=/home/penpot/penpot/graph-wasm/wrapper-em++.sh': not a valid identifier
./_build_env: line 110: export: `CXXFLAGS_wasm32-unknown-emscripten=-ffunction-sections -fdata-sections -fexceptions --target=wasm32-unknown-emscripten': not a valid identifier
+ setup
+ true
+ build
+ cargo build
Compiling graph v0.1.0 (/home/penpot/penpot/graph-wasm)
warning: unused import: `Error`
--> src/main.rs:1:48
|
1 | use lbug::{Database, Connection, SystemConfig, Error};
| ^^^^^
|
= note: `#[warn(unused_imports)]` (part of `#[warn(unused)]`) on by default
warning: variable does not need to be mutable
--> src/main.rs:23:9
|
23 | let mut conn = match Connection::new(&db) {
| ----^^^^
| |
| help: remove this `mut`
|
= note: `#[warn(unused_mut)]` (part of `#[warn(unused)]`) on by default

2
graph-wasm/build.rs Normal file
View File

@@ -0,0 +1,2 @@
// Intentionally empty build script: Cargo only sets the OUT_DIR environment
// variable for packages that have a build script, so this file exists solely
// to make OUT_DIR available during the build.
fn main() {}

11
graph-wasm/exports.txt Normal file
View File

@@ -0,0 +1,11 @@
_hello
_generate_db
_init
_search_similar_shapes
_free_similar_shapes_buffer
_set_shape_parent
_set_shape_selrect
_set_shape_type
_use_shape

5
graph-wasm/lbug-0.12.2/.gitignore vendored Normal file
View File

@@ -0,0 +1,5 @@
target/
debug/
**/*.rs.bk

1234
graph-wasm/lbug-0.12.2/Cargo.lock generated Normal file
View File

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,144 @@
# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO
#
# When uploading crates to the registry Cargo will automatically
# "normalize" Cargo.toml files for maximal compatibility
# with all versions of Cargo and also rewrite `path` dependencies
# to registry (e.g., crates.io) dependencies.
#
# If you are reading this file be aware that the original Cargo.toml
# will likely look very different (and much more reasonable).
# See Cargo.toml.orig for the original contents.
[package]
edition = "2021"
rust-version = "1.81"
name = "lbug"
version = "0.12.2"
build = "build.rs"
include = [
"build.rs",
"/src",
"/include",
"/lbug-src/src",
"/lbug-src/cmake",
"/lbug-src/third_party",
"/lbug-src/CMakeLists.txt",
"/lbug-src/tools/CMakeLists.txt",
]
autolib = false
autobins = false
autoexamples = false
autotests = false
autobenches = false
description = "An in-process property graph database management system built for query speed and scalability"
homepage = "https://ladybugdb.com/"
readme = "lbug-src/README.md"
keywords = [
"database",
"graph",
"ffi",
]
categories = ["database"]
license = "MIT"
repository = "https://github.com/lbugdb/lbug"
[package.metadata.docs.rs]
all-features = true
[features]
arrow = ["dep:arrow"]
default = []
extension_tests = []
[lib]
name = "lbug"
path = "src/lib.rs"
[dependencies.arrow]
version = "55"
features = ["ffi"]
optional = true
default-features = false
[dependencies.cxx]
version = "=1.0.138"
[dependencies.rust_decimal]
version = "1.37"
default-features = false
[dependencies.time]
version = "0.3"
[dependencies.uuid]
version = "1.6"
[dev-dependencies.anyhow]
version = "1"
[dev-dependencies.rust_decimal_macros]
version = "1.37"
[dev-dependencies.tempfile]
version = "3"
[dev-dependencies.time]
version = "0.3"
features = ["macros"]
[build-dependencies.cmake]
version = "0.1"
[build-dependencies.cxx-build]
version = "=1.0.138"
[build-dependencies.rustversion]
version = "1"
[lints.clippy]
inline_always = "allow"
missing_errors_doc = "allow"
missing_panics_doc = "allow"
module_name_repetitions = "allow"
must_use_candidate = "allow"
needless_pass_by_value = "allow"
redundant_closure_for_method_calls = "allow"
return_self_not_must_use = "allow"
similar_names = "allow"
struct_excessive_bools = "allow"
too_many_arguments = "allow"
too_many_lines = "allow"
type_complexity = "allow"
unreadable_literal = "allow"
[lints.clippy.cargo]
level = "warn"
priority = -1
[lints.clippy.complexity]
level = "warn"
priority = -1
[lints.clippy.correctness]
level = "warn"
priority = -1
[lints.clippy.pedantic]
level = "warn"
priority = -1
[lints.clippy.perf]
level = "warn"
priority = -1
[lints.clippy.style]
level = "warn"
priority = -1
[lints.clippy.suspicious]
level = "warn"
priority = -1
[profile.relwithdebinfo]
debug = 2
inherits = "release"

View File

@@ -0,0 +1,283 @@
use std::env;
use std::path::{Path, PathBuf};
/// Chooses how the lbug library is linked.
///
/// Returns `"dylib"` when the `LBUG_SHARED` environment variable is set (to
/// any valid Unicode value) and `"static"` otherwise.
fn link_mode() -> &'static str {
    match env::var("LBUG_SHARED") {
        Ok(_) => "dylib",
        Err(_) => "static",
    }
}
/// Returns the value of the `PROFILE` environment variable.
///
/// Despite the name this is the build profile, not the target triple.
/// Panics if `PROFILE` is unset or not valid Unicode.
fn get_target() -> String {
    let profile = env::var("PROFILE");
    profile.unwrap()
}
/// Reports whether the crate is being built for `wasm32-unknown-emscripten`.
///
/// Reads the `TARGET` environment variable; an unset or non-Unicode value
/// counts as "not wasm".
fn is_wasm_emscripten() -> bool {
    matches!(
        env::var("TARGET").as_deref(),
        Ok("wasm32-unknown-emscripten")
    )
}
/// Prints the `cargo:rustc-link-lib` / `cargo:rustc-link-arg` directives for
/// lbug and the third-party libraries it depends on.
///
/// * wasm32-unknown-emscripten: every dependency and then lbug are linked as
///   plain static libraries; nothing else is emitted.
/// * dylib mode: only the shared lbug library is linked.
/// * static mode: lbug, the platform C++ runtime and all dependencies are
///   linked (with `+whole-archive` on Rust >= 1.82).
fn link_libraries() {
    // Static archives built by CMake, in the order they are given to the linker.
    const THIRD_PARTY_LIBS: [&str; 16] = [
        "utf8proc",
        "antlr4_cypher",
        "antlr4_runtime",
        "re2",
        "fastpfor",
        "parquet",
        "thrift",
        "snappy",
        "zstd",
        "miniz",
        "mbedtls",
        "brotlidec",
        "brotlicommon",
        "lz4",
        "roaring_bitmap",
        "simsimd",
    ];

    // For wasm32-unknown-emscripten, link the dependencies first and lbug last.
    // System libraries are not linked here (they're handled by Emscripten).
    if is_wasm_emscripten() {
        for lib in THIRD_PARTY_LIBS {
            println!("cargo:rustc-link-lib=static={lib}");
        }
        println!("cargo:rustc-link-lib=static=lbug");
        return;
    }

    let mode = link_mode();
    let is_static = mode == "static";
    let whole_archive = rustversion::cfg!(since(1.82));

    // This also needs to be set by any crates using it if they want to use extensions
    if !cfg!(windows) && is_static {
        println!("cargo:rustc-link-arg=-rdynamic");
    }

    // The lbug library itself.
    if mode == "dylib" {
        if cfg!(windows) {
            println!("cargo:rustc-link-lib=dylib=lbug_shared");
        } else {
            println!("cargo:rustc-link-lib={}=lbug", mode);
        }
    } else if whole_archive {
        println!("cargo:rustc-link-lib=static:+whole-archive=lbug");
    } else {
        println!("cargo:rustc-link-lib=static=lbug");
    }

    // Everything below is only needed when lbug is linked statically.
    if !is_static {
        return;
    }

    // Platform runtime libraries.
    if cfg!(windows) {
        println!("cargo:rustc-link-lib=dylib=msvcrt");
        println!("cargo:rustc-link-lib=dylib=shell32");
        println!("cargo:rustc-link-lib=dylib=ole32");
    } else if cfg!(target_os = "macos") {
        println!("cargo:rustc-link-lib=dylib=c++");
    } else {
        println!("cargo:rustc-link-lib=dylib=stdc++");
    }

    for lib in THIRD_PARTY_LIBS {
        if whole_archive {
            println!("cargo:rustc-link-lib=static:+whole-archive={lib}");
        } else {
            println!("cargo:rustc-link-lib=static={lib}");
        }
    }
}
fn build_bundled_cmake() -> Vec<PathBuf> {
let lbug_root = {
let root = Path::new(&std::env::var("CARGO_MANIFEST_DIR").unwrap()).join("lbug-src");
if root.is_symlink() || root.is_dir() {
root
} else {
// If the path is not directory, this is probably an in-source build on windows where the
// symlink is unreadable.
Path::new(&std::env::var("CARGO_MANIFEST_DIR").unwrap()).join("../..")
}
};
let mut build = cmake::Config::new(&lbug_root);
build
.no_build_target(true)
.define("BUILD_SHELL", "OFF")
.define("BUILD_SINGLE_FILE_HEADER", "OFF")
.define("AUTO_UPDATE_GRAMMAR", "OFF");
// Configure for wasm32-unknown-emscripten
if is_wasm_emscripten() {
// Same configuration as ladybug/tools/wasm build
build.define("SINGLE_THREADED", "TRUE");
// cmake-rs should automatically detect emscripten toolchain when CC/CXX point to emcc/em++
} else if cfg!(windows) {
build.generator("Ninja");
build.cxxflag("/EHsc");
build.define("CMAKE_MSVC_RUNTIME_LIBRARY", "MultiThreadedDLL");
build.define("CMAKE_POLICY_DEFAULT_CMP0091", "NEW");
}
if let Ok(jobs) = env::var("NUM_JOBS") {
// SAFETY: Setting environment variables in build scripts is safe
unsafe {
env::set_var("CMAKE_BUILD_PARALLEL_LEVEL", jobs);
}
}
let build_dir = build.build();
let lbug_lib_path = build_dir.join("build").join("src");
println!("cargo:rustc-link-search=native={}", lbug_lib_path.display());
for dir in [
"utf8proc",
"antlr4_cypher",
"antlr4_runtime",
"re2",
"brotli",
"alp",
"fastpfor",
"parquet",
"thrift",
"snappy",
"zstd",
"miniz",
"mbedtls",
"lz4",
"roaring_bitmap",
"simsimd",
] {
let lib_path = build_dir
.join("build")
.join("third_party")
.join(dir)
.canonicalize()
.unwrap_or_else(|_| {
panic!(
"Could not find {}/build/third_party/{}",
build_dir.display(),
dir
)
});
println!("cargo:rustc-link-search=native={}", lib_path.display());
}
vec![
lbug_root.join("src/include"),
build_dir.join("build/src"),
build_dir.join("build/src/include"),
lbug_root.join("third_party/nlohmann_json"),
lbug_root.join("third_party/fastpfor"),
lbug_root.join("third_party/alp/include"),
]
}
/// Compiles one cxx bridge plus its C++ source file into a static library.
///
/// * `bridge_file` - Rust file containing the `#[cxx::bridge]` module.
/// * `out_name` - name of the library produced by the compilation.
/// * `source_file` - C++ implementation compiled together with the bridge.
/// * `bundled` - when true, defines `LBUG_BUNDLED` so the headers include the
///   in-tree lbug sources instead of an installed `<lbug.hpp>`.
/// * `include_paths` - include directories passed to the C++ compiler. Taken
///   as a slice; existing callers passing `&Vec<PathBuf>` still compile
///   through deref coercion.
///
/// Also emits the `cargo:rerun-if-*` directives for the FFI sources.
fn build_ffi(
    bridge_file: &str,
    out_name: &str,
    source_file: &str,
    bundled: bool,
    include_paths: &[PathBuf],
) {
    let mut build = cxx_build::bridge(bridge_file);
    build.file(source_file);

    if bundled {
        build.define("LBUG_BUNDLED", None);
    }
    // Read the profile once instead of querying the environment twice.
    let profile = get_target();
    if profile == "debug" || profile == "relwithdebinfo" {
        build.define("ENABLE_RUNTIME_CHECKS", "1");
    }
    if link_mode() == "static" {
        build.define("LBUG_STATIC_DEFINE", None);
    }

    build.includes(include_paths);

    println!("cargo:rerun-if-env-changed=LBUG_SHARED");

    println!("cargo:rerun-if-changed=include/lbug_rs.h");
    println!("cargo:rerun-if-changed=src/lbug_rs.cpp");
    // Note that this should match the lbug-src/* entries in the package.include list in Cargo.toml
    // Unfortunately they appear to need to be specified individually since the symlink is
    // considered to be changed each time.
    println!("cargo:rerun-if-changed=lbug-src/src");
    println!("cargo:rerun-if-changed=lbug-src/cmake");
    println!("cargo:rerun-if-changed=lbug-src/third_party");
    println!("cargo:rerun-if-changed=lbug-src/CMakeLists.txt");
    println!("cargo:rerun-if-changed=lbug-src/tools/CMakeLists.txt");

    if is_wasm_emscripten() {
        // For emscripten, use C++20 and enable exceptions
        build.flag("-std=c++20");
        build.flag("-fexceptions");
        // Note: -sDISABLE_EXCEPTION_CATCHING=0 is a linker flag, not a compiler flag
        // It should be set via EMCC_CFLAGS environment variable or cargo rustc-link-arg
    } else if cfg!(windows) {
        build.flag("/std:c++20");
        build.flag("/MD");
    } else {
        build.flag("-std=c++2a");
    }
    build.compile(out_name);
}
fn main() {
if env::var("DOCS_RS").is_ok() {
// Do nothing; we're just building docs and don't need the C++ library
return;
}
let mut bundled = false;
let mut include_paths =
vec![Path::new(&std::env::var("CARGO_MANIFEST_DIR").unwrap()).join("include")];
if let (Ok(lbug_lib_dir), Ok(lbug_include)) =
(env::var("LBUG_LIBRARY_DIR"), env::var("LBUG_INCLUDE_DIR"))
{
println!("cargo:rustc-link-search=native={lbug_lib_dir}");
println!("cargo:rustc-link-arg=-Wl,-rpath,{lbug_lib_dir}");
include_paths.push(Path::new(&lbug_include).to_path_buf());
} else {
include_paths.extend(build_bundled_cmake());
bundled = true;
}
// For wasm, we need to link libraries after building FFI to ensure proper symbol resolution
if !is_wasm_emscripten() && link_mode() == "static" {
link_libraries();
}
build_ffi(
"src/ffi.rs",
"lbug_rs",
"src/lbug_rs.cpp",
bundled,
&include_paths,
);
if cfg!(feature = "arrow") {
build_ffi(
"src/ffi/arrow.rs",
"lbug_arrow_rs",
"src/lbug_arrow.cpp",
bundled,
&include_paths,
);
}
// For wasm, link libraries after FFI; for dylib, link after FFI
if is_wasm_emscripten() || link_mode() == "dylib" {
link_libraries();
}
}

View File

@@ -0,0 +1,15 @@
#pragma once
#include "rust/cxx.h"
#ifdef LBUG_BUNDLED
#include "main/lbug.h"
#else
#include <lbug.hpp>
#endif
// Helpers that expose a query result through Arrow structures
// (ArrowSchema / ArrowArray).
namespace lbug_arrow {
// Returns an ArrowSchema for `result`.
// NOTE(review): presumably describes the result's columns — confirm against the implementation.
ArrowSchema query_result_get_arrow_schema(const lbug::main::QueryResult& result);
// Returns the next ArrowArray chunk of `result`; takes the result by non-const reference.
// NOTE(review): `chunkSize` is presumably a row count — confirm against the implementation.
ArrowArray query_result_get_next_arrow_chunk(lbug::main::QueryResult& result, uint64_t chunkSize);
} // namespace lbug_arrow

View File

@@ -0,0 +1,243 @@
#pragma once
#include <cstdint>
#include <memory>
#include "rust/cxx.h"
#ifdef LBUG_BUNDLED
#include "common/type_utils.h"
#include "common/types/int128_t.h"
#include "common/types/types.h"
#include "common/types/value/nested.h"
#include "common/types/value/node.h"
#include "common/types/value/recursive_rel.h"
#include "common/types/value/rel.h"
#include "common/types/value/value.h"
#include "main/lbug.h"
#include "storage/storage_version_info.h"
#else
#include <lbug.hpp>
#endif
namespace lbug_rs {
struct TypeListBuilder {
std::vector<lbug::common::LogicalType> types;
void insert(std::unique_ptr<lbug::common::LogicalType> type) {
types.push_back(std::move(*type));
}
};
std::unique_ptr<TypeListBuilder> create_type_list();
/// Named parameter values collected for a query, keyed by parameter name.
struct QueryParams {
    std::unordered_map<std::string, std::unique_ptr<lbug::common::Value>> inputParams;

    /// Stores `value` under `key`. If `key` is already present the existing
    /// entry is kept and `value` is discarded.
    void insert(const rust::Str key, std::unique_ptr<lbug::common::Value> value) {
        inputParams.emplace(std::string(key), std::move(value));
    }
};
std::unique_ptr<QueryParams> new_params();
std::unique_ptr<lbug::common::LogicalType> create_logical_type(lbug::common::LogicalTypeID id);
std::unique_ptr<lbug::common::LogicalType> create_logical_type_list(
std::unique_ptr<lbug::common::LogicalType> childType);
std::unique_ptr<lbug::common::LogicalType> create_logical_type_array(
std::unique_ptr<lbug::common::LogicalType> childType, uint64_t numElements);
/// Builds a STRUCT logical type from parallel lists of field names and types.
///
/// @param fieldNames  Name of each field, in order.
/// @param fieldTypes  Builder holding the type of each field; entry `i` is
///                    moved out and paired with `fieldNames[i]`.
/// @return Heap-allocated STRUCT logical type.
///
/// Precondition: `fieldTypes->types` holds at least `fieldNames.size()`
/// entries; the types are indexed without bounds checking.
inline std::unique_ptr<lbug::common::LogicalType> create_logical_type_struct(
    const rust::Vec<rust::String>& fieldNames, std::unique_ptr<TypeListBuilder> fieldTypes) {
    std::vector<lbug::common::StructField> fields;
    // The final size is known up front, so allocate once.
    fields.reserve(fieldNames.size());
    // size_t index matches the width of fieldNames.size().
    for (size_t i = 0; i < fieldNames.size(); ++i) {
        fields.emplace_back(std::string(fieldNames[i]), std::move(fieldTypes->types[i]));
    }
    return std::make_unique<lbug::common::LogicalType>(
        lbug::common::LogicalType::STRUCT(std::move(fields)));
}
/// Builds a UNION logical type from parallel lists of field names and types.
///
/// @param fieldNames  Name of each union member, in order.
/// @param fieldTypes  Builder holding the type of each member; entry `i` is
///                    moved out and paired with `fieldNames[i]`.
/// @return Heap-allocated UNION logical type.
///
/// Precondition: `fieldTypes->types` holds at least `fieldNames.size()`
/// entries; the types are indexed without bounds checking.
inline std::unique_ptr<lbug::common::LogicalType> create_logical_type_union(
    const rust::Vec<rust::String>& fieldNames, std::unique_ptr<TypeListBuilder> fieldTypes) {
    std::vector<lbug::common::StructField> fields;
    // The final size is known up front, so allocate once.
    fields.reserve(fieldNames.size());
    // size_t index matches the width of fieldNames.size().
    for (size_t i = 0; i < fieldNames.size(); ++i) {
        fields.emplace_back(std::string(fieldNames[i]), std::move(fieldTypes->types[i]));
    }
    return std::make_unique<lbug::common::LogicalType>(
        lbug::common::LogicalType::UNION(std::move(fields)));
}
std::unique_ptr<lbug::common::LogicalType> create_logical_type_map(
std::unique_ptr<lbug::common::LogicalType> keyType,
std::unique_ptr<lbug::common::LogicalType> valueType);
inline std::unique_ptr<lbug::common::LogicalType> create_logical_type_decimal(uint32_t precision,
uint32_t scale) {
return std::make_unique<lbug::common::LogicalType>(
lbug::common::LogicalType::DECIMAL(precision, scale));
}
std::unique_ptr<lbug::common::LogicalType> logical_type_get_list_child_type(
const lbug::common::LogicalType& logicalType);
std::unique_ptr<lbug::common::LogicalType> logical_type_get_array_child_type(
const lbug::common::LogicalType& logicalType);
uint64_t logical_type_get_array_num_elements(const lbug::common::LogicalType& logicalType);
rust::Vec<rust::String> logical_type_get_struct_field_names(const lbug::common::LogicalType& value);
std::unique_ptr<std::vector<lbug::common::LogicalType>> logical_type_get_struct_field_types(
const lbug::common::LogicalType& value);
/// Returns the precision reported by DecimalType for `logicalType`.
inline uint32_t logical_type_get_decimal_precision(const lbug::common::LogicalType& logicalType) {
    using lbug::common::DecimalType;
    return DecimalType::getPrecision(logicalType);
}
/// Returns the scale reported by DecimalType for `logicalType`.
inline uint32_t logical_type_get_decimal_scale(const lbug::common::LogicalType& logicalType) {
    using lbug::common::DecimalType;
    return DecimalType::getScale(logicalType);
}
/* Database */
std::unique_ptr<lbug::main::Database> new_database(std::string_view databasePath,
uint64_t bufferPoolSize, uint64_t maxNumThreads, bool enableCompression, bool readOnly,
uint64_t maxDBSize, bool autoCheckpoint, int64_t checkpointThreshold,
bool throwOnWalReplayFailure, bool enableChecksums);
void database_set_logging_level(lbug::main::Database& database, const std::string& level);
/* Connection */
std::unique_ptr<lbug::main::Connection> database_connect(lbug::main::Database& database);
std::unique_ptr<lbug::main::QueryResult> connection_execute(lbug::main::Connection& connection,
lbug::main::PreparedStatement& query, std::unique_ptr<QueryParams> params);
/// Runs `query` on `connection` and returns the resulting QueryResult.
inline std::unique_ptr<lbug::main::QueryResult> connection_query(lbug::main::Connection& connection,
    std::string_view query) {
    auto result = connection.query(query);
    return result;
}
/* PreparedStatement */
rust::String prepared_statement_error_message(const lbug::main::PreparedStatement& statement);
/* QueryResult */
rust::String query_result_to_string(const lbug::main::QueryResult& result);
rust::String query_result_get_error_message(const lbug::main::QueryResult& result);
double query_result_get_compiling_time(const lbug::main::QueryResult& result);
double query_result_get_execution_time(const lbug::main::QueryResult& result);
std::unique_ptr<std::vector<lbug::common::LogicalType>> query_result_column_data_types(
const lbug::main::QueryResult& query_result);
rust::Vec<rust::String> query_result_column_names(const lbug::main::QueryResult& query_result);
/* NodeVal/RelVal */
rust::String node_value_get_label_name(const lbug::common::Value& val);
rust::String rel_value_get_label_name(const lbug::common::Value& val);
size_t node_value_get_num_properties(const lbug::common::Value& value);
size_t rel_value_get_num_properties(const lbug::common::Value& value);
rust::String node_value_get_property_name(const lbug::common::Value& value, size_t index);
rust::String rel_value_get_property_name(const lbug::common::Value& value, size_t index);
const lbug::common::Value& node_value_get_property_value(const lbug::common::Value& value,
size_t index);
const lbug::common::Value& rel_value_get_property_value(const lbug::common::Value& value,
size_t index);
/* NodeVal */
const lbug::common::Value& node_value_get_node_id(const lbug::common::Value& val);
/* RelVal */
const lbug::common::Value& rel_value_get_src_id(const lbug::common::Value& val);
std::array<uint64_t, 2> rel_value_get_dst_id(const lbug::common::Value& val);
/* RecursiveRel */
const lbug::common::Value& recursive_rel_get_nodes(const lbug::common::Value& val);
const lbug::common::Value& recursive_rel_get_rels(const lbug::common::Value& val);
/* FlatTuple */
const lbug::common::Value& flat_tuple_get_value(const lbug::processor::FlatTuple& flatTuple,
uint32_t index);
/* Value */
const std::string& value_get_string(const lbug::common::Value& value);
// Reads the payload of `value` as a T (via Value::getValue<T>) and returns it
// in a freshly allocated std::unique_ptr<T>.
template<typename T>
std::unique_ptr<T> value_get_unique(const lbug::common::Value& value) {
    auto extracted = value.getValue<T>();
    return std::make_unique<T>(std::move(extracted));
}
int64_t value_get_interval_secs(const lbug::common::Value& value);
int32_t value_get_interval_micros(const lbug::common::Value& value);
int32_t value_get_date_days(const lbug::common::Value& value);
int64_t value_get_timestamp_ns(const lbug::common::Value& value);
int64_t value_get_timestamp_ms(const lbug::common::Value& value);
int64_t value_get_timestamp_sec(const lbug::common::Value& value);
int64_t value_get_timestamp_micros(const lbug::common::Value& value);
int64_t value_get_timestamp_tz(const lbug::common::Value& value);
std::array<uint64_t, 2> value_get_int128_t(const lbug::common::Value& value);
std::array<uint64_t, 2> value_get_internal_id(const lbug::common::Value& value);
uint32_t value_get_children_size(const lbug::common::Value& value);
const lbug::common::Value& value_get_child(const lbug::common::Value& value, uint32_t index);
lbug::common::LogicalTypeID value_get_data_type_id(const lbug::common::Value& value);
const lbug::common::LogicalType& value_get_data_type(const lbug::common::Value& value);
// Returns the physical type ID of the data type attached to `value`.
inline lbug::common::PhysicalTypeID value_get_physical_type(const lbug::common::Value& value) {
    const auto& dataType = value.getDataType();
    return dataType.getPhysicalType();
}
rust::String value_to_string(const lbug::common::Value& val);
std::unique_ptr<lbug::common::Value> create_value_string(lbug::common::LogicalTypeID typ,
const rust::Slice<const unsigned char> value);
std::unique_ptr<lbug::common::Value> create_value_timestamp(const int64_t timestamp);
std::unique_ptr<lbug::common::Value> create_value_timestamp_tz(const int64_t timestamp);
std::unique_ptr<lbug::common::Value> create_value_timestamp_ns(const int64_t timestamp);
std::unique_ptr<lbug::common::Value> create_value_timestamp_ms(const int64_t timestamp);
std::unique_ptr<lbug::common::Value> create_value_timestamp_sec(const int64_t timestamp);
inline std::unique_ptr<lbug::common::Value> create_value_date(const int32_t date) {
return std::make_unique<lbug::common::Value>(lbug::common::date_t(date));
}
std::unique_ptr<lbug::common::Value> create_value_interval(const int32_t months, const int32_t days,
const int64_t micros);
std::unique_ptr<lbug::common::Value> create_value_null(
std::unique_ptr<lbug::common::LogicalType> typ);
std::unique_ptr<lbug::common::Value> create_value_int128_t(int64_t high, uint64_t low);
std::unique_ptr<lbug::common::Value> create_value_internal_id(uint64_t offset, uint64_t table);
inline std::unique_ptr<lbug::common::Value> create_value_uuid_t(int64_t high, uint64_t low) {
return std::make_unique<lbug::common::Value>(
lbug::common::ku_uuid_t{lbug::common::int128_t(low, high)});
}
template<typename T>
std::unique_ptr<lbug::common::Value> create_value(const T value) {
return std::make_unique<lbug::common::Value>(value);
}
inline std::unique_ptr<lbug::common::Value> create_value_decimal(int64_t high, uint64_t low,
uint32_t scale, uint32_t precision) {
auto value =
std::make_unique<lbug::common::Value>(lbug::common::LogicalType::DECIMAL(precision, scale),
std::vector<std::unique_ptr<lbug::common::Value>>{});
auto i128 = lbug::common::int128_t(low, high);
lbug::common::TypeUtils::visit(
value->getDataType().getPhysicalType(),
[&](lbug::common::int128_t) { value->val.int128Val = i128; },
[&](int64_t) { value->val.int64Val = static_cast<int64_t>(i128); },
[&](int32_t) { value->val.int32Val = static_cast<int32_t>(i128); },
[&](int16_t) { value->val.int16Val = static_cast<int16_t>(i128); },
[](auto) { KU_UNREACHABLE; });
return value;
}
// Accumulates owned Value objects in insertion order.
struct ValueListBuilder {
    // Values collected so far; each element is exclusively owned by the builder.
    std::vector<std::unique_ptr<lbug::common::Value>> values;

    // Takes ownership of `value` and appends it to `values`.
    void insert(std::unique_ptr<lbug::common::Value> value) {
        values.emplace_back(std::move(value));
    }
};
std::unique_ptr<lbug::common::Value> get_list_value(std::unique_ptr<lbug::common::LogicalType> typ,
std::unique_ptr<ValueListBuilder> value);
std::unique_ptr<ValueListBuilder> create_list();
// Creates a std::string_view over the bytes referenced by `s` (no copy is
// made; the view spans s.data() for s.size() bytes).
inline std::string_view string_view_from_str(rust::Str s) {
    return std::string_view(s.data(), s.size());
}
// Returns the value of StorageVersionInfo::getStorageVersion().
inline lbug::storage::storage_version_t get_storage_version() {
    using lbug::storage::StorageVersionInfo;
    return StorageVersionInfo::getStorageVersion();
}
} // namespace lbug_rs

View File

@@ -0,0 +1,454 @@
cmake_minimum_required(VERSION 3.15)
project(Lbug VERSION 0.12.2 LANGUAGES CXX C)
# Threading mode selection.
# When SINGLE_THREADED is on, the __SINGLE_THREADED__ CMake variable and the
# compile definition of the same name are set. Otherwise the Threads package
# is required.
option(SINGLE_THREADED "Single-threaded mode" FALSE)
if(SINGLE_THREADED)
set(__SINGLE_THREADED__ TRUE)
add_compile_definitions(__SINGLE_THREADED__)
message(STATUS "Single-threaded mode is enabled")
else()
# The environment variable is only echoed for information; it is not modified here.
message(STATUS "Multi-threaded mode is enabled: CMAKE_BUILD_PARALLEL_LEVEL=$ENV{CMAKE_BUILD_PARALLEL_LEVEL}")
find_package(Threads REQUIRED)
endif()
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED TRUE)
set(CMAKE_CXX_VISIBILITY_PRESET hidden)
set(CMAKE_C_VISIBILITY_PRESET hidden)
set(CMAKE_EXPORT_COMPILE_COMMANDS TRUE)
set(CMAKE_FIND_PACKAGE_RESOLVE_SYMLINKS TRUE)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_VISIBILITY_INLINES_HIDDEN ON)
# On Linux, symbols in executables are not accessible by loaded shared libraries (e.g. via dlopen(3)). However, we need to export public symbols in executables so that extensions can access public symbols. This enables that behaviour.
set(CMAKE_ENABLE_EXPORTS TRUE)
# Optionally promote every compiler warning to an error.
option(ENABLE_WERROR "Treat all warnings as errors" FALSE)
if(ENABLE_WERROR)
    if(CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0")
        # CMake 3.24 and later provide a portable switch for this.
        set(CMAKE_COMPILE_WARNING_AS_ERROR TRUE)
    elseif(MSVC)
        # The MSVC flag is /WX. It was previously written with a backslash,
        # which CMake treats as the start of an escape sequence.
        add_compile_options(/WX)
    else()
        add_compile_options(-Werror)
    endif()
endif()
# Detect OS and architecture, copied from DuckDB
set(OS_NAME "unknown")
set(OS_ARCH "amd64")
string(REGEX MATCH "(arm64|aarch64)" IS_ARM "${CMAKE_SYSTEM_PROCESSOR}")
if(IS_ARM)
set(OS_ARCH "arm64")
elseif(FORCE_32_BIT)
set(OS_ARCH "i386")
endif()
if(APPLE)
set(OS_NAME "osx")
endif()
if(WIN32)
set(OS_NAME "windows")
endif()
if(UNIX AND NOT APPLE)
set(OS_NAME "linux") # sorry BSD
endif()
if(CMAKE_SIZEOF_VOID_P EQUAL 8)
message(STATUS "64-bit architecture detected")
add_compile_definitions(__64BIT__)
elseif(CMAKE_SIZEOF_VOID_P EQUAL 4)
message(STATUS "32-bit architecture detected")
add_compile_definitions(__32BIT__)
set(__32BIT__ TRUE)
endif()
if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Release)
endif()
if(DEFINED ENV{PYBIND11_PYTHON_VERSION})
set(PYBIND11_PYTHON_VERSION $ENV{PYBIND11_PYTHON_VERSION})
endif()
if(DEFINED ENV{PYTHON_EXECUTABLE})
set(PYTHON_EXECUTABLE $ENV{PYTHON_EXECUTABLE})
endif()
find_program(CCACHE_PROGRAM ccache)
if (CCACHE_PROGRAM)
set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
message(STATUS "ccache found and enabled")
else ()
find_program(CCACHE_PROGRAM sccache)
if (CCACHE_PROGRAM)
set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
message(STATUS "sccache found and enabled")
endif ()
endif ()
set(INSTALL_LIB_DIR
lib
CACHE PATH "Installation directory for libraries")
set(INSTALL_BIN_DIR
bin
CACHE PATH "Installation directory for executables")
set(INSTALL_INCLUDE_DIR
include
CACHE PATH "Installation directory for header files")
set(INSTALL_CMAKE_DIR
${DEF_INSTALL_CMAKE_DIR}
CACHE PATH "Installation directory for CMake files")
option(ENABLE_ADDRESS_SANITIZER "Enable address sanitizer." FALSE)
option(ENABLE_THREAD_SANITIZER "Enable thread sanitizer." FALSE)
option(ENABLE_UBSAN "Enable undefined behavior sanitizer." FALSE)
option(ENABLE_RUNTIME_CHECKS "Enable runtime coherency checks (e.g. asserts)" FALSE)
option(ENABLE_LTO "Enable Link-Time Optimization" FALSE)
option(ENABLE_MALLOC_BUFFER_MANAGER "Enable Buffer manager using malloc. Default option for webassembly" OFF)
option(LBUG_DEFAULT_REL_STORAGE_DIRECTION "Only store fwd direction in rel tables by default." BOTH)
if(NOT LBUG_DEFAULT_REL_STORAGE_DIRECTION)
set(LBUG_DEFAULT_REL_STORAGE_DIRECTION BOTH)
endif()
set(LBUG_DEFAULT_REL_STORAGE_DIRECTION ${LBUG_DEFAULT_REL_STORAGE_DIRECTION}_REL_STORAGE)
option(LBUG_PAGE_SIZE_LOG2 "Log2 of the page size." 12)
if(NOT LBUG_PAGE_SIZE_LOG2)
set(LBUG_PAGE_SIZE_LOG2 12)
endif()
message(STATUS "LBUG_PAGE_SIZE_LOG2: ${LBUG_PAGE_SIZE_LOG2}")
option(LBUG_VECTOR_CAPACITY_LOG2 "Log2 of the vector capacity." 11)
if(NOT LBUG_VECTOR_CAPACITY_LOG2)
set(LBUG_VECTOR_CAPACITY_LOG2 11)
endif()
message(STATUS "LBUG_VECTOR_CAPACITY_LOG2: ${LBUG_VECTOR_CAPACITY_LOG2}")
# 64 * 2048 nodes per group
# NOTE(review): option() creates a boolean cache entry, so the numeric default
# presumably does not survive and the fallback below is what actually applies
# 17 when the user passes nothing - confirm before refactoring.
option(LBUG_NODE_GROUP_SIZE_LOG2 "Log2 of the node group size." 17)
if(NOT LBUG_NODE_GROUP_SIZE_LOG2)
    set(LBUG_NODE_GROUP_SIZE_LOG2 17)
endif()
message(STATUS "LBUG_NODE_GROUP_SIZE_LOG2: ${LBUG_NODE_GROUP_SIZE_LOG2}")
option(LBUG_MAX_SEGMENT_SIZE_LOG2 "Log2 of the maximum segment size in bytes." 18)
if(NOT LBUG_MAX_SEGMENT_SIZE_LOG2)
set(LBUG_MAX_SEGMENT_SIZE_LOG2 18)
endif()
message(STATUS "LBUG_MAX_SEGMENT_SIZE_LOG2: ${LBUG_MAX_SEGMENT_SIZE_LOG2}")
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/cmake/templates/system_config.h.in ${CMAKE_CURRENT_BINARY_DIR}/src/include/common/system_config.h @ONLY)
include(CheckCXXSymbolExists)
check_cxx_symbol_exists(F_FULLFSYNC "fcntl.h" HAS_FULLFSYNC)
check_cxx_symbol_exists(fdatasync "unistd.h" HAS_FDATASYNC)
if(HAS_FULLFSYNC)
message(STATUS "✓ F_FULLFSYNC will be used on this platform")
add_compile_definitions(HAS_FULLFSYNC)
else()
message(STATUS "✗ F_FULLFSYNC not available")
endif()
if(HAS_FDATASYNC)
message(STATUS "✓ fdatasync will be used on this platform")
add_compile_definitions(HAS_FDATASYNC)
else()
message(STATUS "✗ fdatasync not available, using fsync fallback")
endif()
if(MSVC)
# Required for M_PI on Windows
add_compile_definitions(_USE_MATH_DEFINES)
add_compile_definitions(NOMINMAX)
add_compile_definitions(SERD_STATIC)
# This is a workaround for regex oom issue on windows in gtest.
add_compile_definitions(_REGEX_MAX_STACK_COUNT=0)
add_compile_definitions(_REGEX_MAX_COMPLEXITY_COUNT=0)
# Disable constexpr mutex constructor to avoid compatibility issues with
# older versions of the MSVC runtime library
# See: https://github.com/microsoft/STL/wiki/Changelog#vs-2022-1710
add_compile_definitions(_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR)
# TODO (bmwinger): Figure out if this can be set automatically by cmake,
# or at least better integrated with user-specified options
# For now, hardcode _AMD64_
# CMAKE_GENERATOR_PLATFORM can be used for visual studio builds, but not for ninja
add_compile_definitions(_AMD64_)
# Non-english windows system may use other encodings other than utf-8 (e.g. Chinese use GBK).
add_compile_options("/utf-8")
# Enables support for custom hardware exception handling
add_compile_options("/EHa")
# Reduces the size of the static library by roughly 1/2
add_compile_options("/Zc:inline")
# Disable type conversion warnings
add_compile_options(/wd4244 /wd4267)
# Remove the default to avoid warnings
STRING(REPLACE "/EHsc" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
STRING(REPLACE "/EHs" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
# Store all libraries and binaries in the same directory so that lbug_shared.dll is found at runtime
set(LIBRARY_OUTPUT_PATH "${CMAKE_BINARY_DIR}/src")
set(EXECUTABLE_OUTPUT_PATH "${CMAKE_BINARY_DIR}/src")
# This is a workaround for regex stackoverflow issue on windows in gtest.
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} /STACK:8388608")
string(REGEX REPLACE "/W[3|4]" "/w" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
add_compile_options($<$<CONFIG:Release>:/W0>)
else()
add_compile_options(-Wall -Wextra)
# Disable warnings for unknown pragmas, which is used by several third-party libraries
add_compile_options(-Wno-unknown-pragmas)
endif()
# WebAssembly build configuration: adds compile/link flags, defines __WASM__
# and forces the malloc-based buffer manager on.
if(${BUILD_WASM})
# Threaded builds need pthread support at compile and link time, plus a
# pthread pool of 8.
if(NOT __SINGLE_THREADED__)
add_compile_options(-pthread)
add_link_options(-pthread)
add_link_options(-sPTHREAD_POOL_SIZE=8)
endif()
add_compile_options(-s DISABLE_EXCEPTION_CATCHING=0)
add_link_options(-sSTACK_SIZE=4MB)
add_link_options(-sASSERTIONS=1)
add_link_options(-lembind)
add_link_options(-sWASM_BIGINT)
# Three link profiles follow:
#  1. test builds: fixed 3892MB initial memory, NODERAWFS;
#  2. WASM_NODEFS: NODERAWFS, growable memory up to 4GB, modularized as "lbug";
#  3. default: single-file output, idbfs/workerfs libraries, growable memory
#     up to 4GB, modularized as "lbug".
if(BUILD_TESTS OR BUILD_EXTENSION_TESTS)
add_link_options(-sINITIAL_MEMORY=3892MB)
add_link_options(-sNODERAWFS=1)
elseif(WASM_NODEFS)
add_link_options(-sNODERAWFS=1)
add_link_options(-sALLOW_MEMORY_GROWTH=1)
add_link_options(-sMODULARIZE=1)
add_link_options(-sEXPORTED_RUNTIME_METHODS=FS,wasmMemory)
add_link_options(-sEXPORT_NAME=lbug)
add_link_options(-sMAXIMUM_MEMORY=4GB)
else()
add_link_options(-sSINGLE_FILE=1)
add_link_options(-sALLOW_MEMORY_GROWTH=1)
add_link_options(-sMODULARIZE=1)
add_link_options(-sEXPORTED_RUNTIME_METHODS=FS,wasmMemory)
add_link_options(-lidbfs.js)
add_link_options(-lworkerfs.js)
add_link_options(-sEXPORT_NAME=lbug)
add_link_options(-sMAXIMUM_MEMORY=4GB)
endif()
# __WASM__ is exposed both as a CMake variable (set here) and as a compile
# definition (added below).
set(__WASM__ TRUE)
add_compile_options(-fexceptions)
add_link_options(-s DISABLE_EXCEPTION_CATCHING=0)
add_link_options(-fexceptions)
add_compile_definitions(__WASM__)
set(ENABLE_MALLOC_BUFFER_MANAGER ON)
endif()
if(${BUILD_SWIFT})
add_compile_definitions(__SWIFT__)
set(ENABLE_MALLOC_BUFFER_MANAGER ON)
endif()
if (${ENABLE_MALLOC_BUFFER_MANAGER})
add_compile_definitions(BM_MALLOC)
endif()
if(ANDROID_ABI)
message(STATUS "Android ABI detected: ${ANDROID_ABI}")
add_compile_definitions(__ANDROID__)
set(__ANDROID__ TRUE)
endif()
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
add_compile_options(-Wno-restrict) # no restrict until https://gcc.gnu.org/bugzilla/show_bug.cgi?id=105651 is fixed
endif()
if(${ENABLE_THREAD_SANITIZER} AND (NOT __SINGLE_THREADED__))
if(MSVC)
message(FATAL_ERROR "Thread sanitizer is not supported on MSVC")
else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=thread -fno-omit-frame-pointer")
endif()
endif()
if(${ENABLE_ADDRESS_SANITIZER})
if(MSVC)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /fsanitize=address")
else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address -fno-omit-frame-pointer")
endif()
endif()
if(${ENABLE_UBSAN})
if(MSVC)
message(FATAL_ERROR "Undefined behavior sanitizer is not supported on MSVC")
else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=undefined -fno-omit-frame-pointer")
endif()
endif()
if(${ENABLE_RUNTIME_CHECKS})
add_compile_definitions(LBUG_RUNTIME_CHECKS)
endif()
if (${ENABLE_DESER_DEBUG})
add_compile_definitions(LBUG_DESER_DEBUG)
endif()
if(${ENABLE_LTO})
set(CMAKE_INTERPROCEDURAL_OPTIMIZATION TRUE)
endif()
option(AUTO_UPDATE_GRAMMAR "Automatically regenerate C++ grammar files on change." TRUE)
option(BUILD_BENCHMARK "Build benchmarks." FALSE)
option(BUILD_EXTENSIONS "Semicolon-separated list of extensions to build." "")
option(BUILD_EXAMPLES "Build examples." FALSE)
option(BUILD_JAVA "Build Java API." FALSE)
option(BUILD_NODEJS "Build NodeJS API." FALSE)
option(BUILD_PYTHON "Build Python API." FALSE)
option(BUILD_SHELL "Build Interactive Shell" TRUE)
option(BUILD_SINGLE_FILE_HEADER "Build single file header. Requires Python >= 3.9." TRUE)
option(BUILD_TESTS "Build C++ tests." FALSE)
option(BUILD_EXTENSION_TESTS "Build C++ extension tests." FALSE)
option(BUILD_LBUG "Build Lbug." TRUE)
option(ENABLE_BACKTRACES "Enable backtrace printing for exceptions and segfaults" FALSE)
option(USE_STD_FORMAT "Use std::format instead of a custom formatter." FALSE)
option(PREFER_SYSTEM_DEPS "Only download certain deps if not found on the system" TRUE)
option(BUILD_LCOV "Build coverage report." FALSE)
if(${BUILD_LCOV})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fprofile-arcs -ftest-coverage")
endif()
# Backtrace support: locate cpptrace and define LBUG_BACKTRACE.
# A system-installed cpptrace is used when PREFER_SYSTEM_DEPS is on and the
# package is found; otherwise it is fetched from GitHub at tag v0.8.3.
if (ENABLE_BACKTRACES)
set(DOWNLOAD_CPPTRACE TRUE)
if(${PREFER_SYSTEM_DEPS})
# QUIET: a missing system package is not an error, we fall back to fetching.
find_package(cpptrace QUIET)
if(cpptrace_FOUND)
message(STATUS "Using system cpptrace")
set(DOWNLOAD_CPPTRACE FALSE)
endif()
endif()
if(${DOWNLOAD_CPPTRACE})
message(STATUS "Fetching cpptrace from GitHub...")
include(FetchContent)
FetchContent_Declare(
cpptrace
GIT_REPOSITORY https://github.com/jeremy-rifkin/cpptrace.git
GIT_TAG v0.8.3
GIT_SHALLOW TRUE
)
FetchContent_MakeAvailable(cpptrace)
endif()
add_compile_definitions(LBUG_BACKTRACE)
endif()
if (USE_STD_FORMAT)
add_compile_definitions(USE_STD_FORMAT)
endif()
# add_lbug_test(<name> <source>...)
# Creates the test executable <name> from the given sources, links it against
# the shared test libraries (plus the backtrace signal handler when
# ENABLE_BACKTRACES is on) and registers its GoogleTest cases at test time.
# The executable named "e2e_test" additionally gets the "e2e_test_" prefix on
# every discovered test.
function(add_lbug_test TEST_NAME)
    add_executable(${TEST_NAME} ${ARGN})
    target_link_libraries(${TEST_NAME} PRIVATE test_helper test_runner graph_test)
    if(ENABLE_BACKTRACES)
        target_link_libraries(${TEST_NAME} PRIVATE register_backtrace_signal_handler)
    endif()
    target_include_directories(${TEST_NAME} PRIVATE ${PROJECT_SOURCE_DIR}/test/include)

    include(GoogleTest)
    # Arguments shared by every test target; the prefix is appended only for e2e_test.
    set(DISCOVERY_ARGS DISCOVERY_TIMEOUT 600 DISCOVERY_MODE PRE_TEST)
    if(TEST_NAME STREQUAL "e2e_test")
        list(APPEND DISCOVERY_ARGS TEST_PREFIX e2e_test_)
    endif()
    gtest_discover_tests(${TEST_NAME} ${DISCOVERY_ARGS})
endfunction()
# add_lbug_api_test(<name> <source>...)
# Creates the API test executable <name> from the given sources, links it
# against the API test libraries (plus the backtrace signal handler when
# ENABLE_BACKTRACES is on) and registers its GoogleTest cases with default
# discovery settings.
function(add_lbug_api_test TEST_NAME)
    add_executable(${TEST_NAME} ${ARGN})
    target_link_libraries(${TEST_NAME} PRIVATE api_graph_test api_test_helper)
    if(ENABLE_BACKTRACES)
        target_link_libraries(${TEST_NAME} PRIVATE register_backtrace_signal_handler)
    endif()
    target_include_directories(${TEST_NAME} PRIVATE ${PROJECT_SOURCE_DIR}/test/include)

    include(GoogleTest)
    gtest_discover_tests(${TEST_NAME})
endfunction()
# Windows doesn't support dynamic lookup, so we have to link extensions against lbug.
# BUILD_EXTENSIONS holds a list of extension names, so it is tested for
# truthiness. The previous `NOT BUILD_EXTENSIONS EQUAL ""` used the numeric
# EQUAL operator, which never matches an empty string, so the condition was
# true on every MSVC build regardless of whether extensions were requested.
if(MSVC AND BUILD_EXTENSIONS)
    set(BUILD_LBUG TRUE)
endif()
include_directories(third_party/antlr4_cypher/include)
include_directories(third_party/antlr4_runtime/src)
include_directories(third_party/brotli/c/include)
include_directories(third_party/fast_float/include)
include_directories(third_party/mbedtls/include)
include_directories(third_party/parquet)
include_directories(third_party/snappy)
include_directories(third_party/thrift)
include_directories(third_party/miniz)
include_directories(third_party/nlohmann_json)
include_directories(third_party/pybind11/include)
include_directories(third_party/pyparse)
include_directories(third_party/re2/include)
include_directories(third_party/alp/include)
if (${BUILD_TESTS} OR ${BUILD_EXTENSION_TESTS})
include_directories(third_party/spdlog)
elseif (${BUILD_BENCHMARK})
include_directories(third_party/spdlog)
endif ()
include_directories(third_party/utf8proc/include)
include_directories(third_party/zstd/include)
include_directories(third_party/httplib)
include_directories(third_party/pcg)
include_directories(third_party/lz4)
include_directories(third_party/roaring_bitmap)
# Use SYSTEM to suppress warnings from simsimd
include_directories(SYSTEM third_party/simsimd/include)
add_subdirectory(third_party)
add_definitions(-DLBUG_ROOT_DIRECTORY="${PROJECT_SOURCE_DIR}")
add_definitions(-DLBUG_CMAKE_VERSION="${CMAKE_PROJECT_VERSION}")
add_definitions(-DLBUG_EXTENSION_VERSION="0.12.0")
if(BUILD_LBUG)
include_directories(
src/include
${CMAKE_CURRENT_BINARY_DIR}/src/include
)
endif()
if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/extension/CMakeLists.txt")
add_subdirectory(extension)
endif ()
# Core build: adds the src tree, wires statically linked extensions into both
# library targets, then adds tests/benchmark helpers and the tools tree.
if(BUILD_LBUG)
add_subdirectory(src)
# Link extensions which require static linking.
foreach(ext IN LISTS STATICALLY_LINKED_EXTENSIONS)
if (${BUILD_EXTENSION_TESTS})
add_compile_definitions(__STATIC_LINK_EXTENSION_TEST__)
endif ()
# Each extension is linked into both the static (lbug) and shared (lbug_shared) targets.
target_link_libraries(lbug PRIVATE "lbug_${ext}_static_extension")
target_link_libraries(lbug_shared PRIVATE "lbug_${ext}_static_extension")
endforeach()
# Benchmarks only need the test helper, not the full test tree.
if (${BUILD_TESTS} OR ${BUILD_EXTENSION_TESTS})
add_subdirectory(test)
elseif (${BUILD_BENCHMARK})
add_subdirectory(test/test_helper)
endif ()
add_subdirectory(tools)
endif ()
if (${BUILD_EXAMPLES})
add_subdirectory(examples/c)
add_subdirectory(examples/cpp)
endif()

View File

@@ -0,0 +1,73 @@
<div align="center">
<picture>
<!-- <source srcset="https://ladybugdb.com/img/lbug-logo-dark.png" media="(prefers-color-scheme: dark)"> -->
<img src="https://ladybugdb.com/logo.png" height="100" alt="Ladybug Logo">
</picture>
</div>
<br>
<p align="center">
<a href="https://github.com/LadybugDB/ladybug/actions">
<img src="https://github.com/LadybugDB/ladybug/actions/workflows/ci-workflow.yml/badge.svg?branch=master" alt="Github Actions Badge"></a>
<a href="https://discord.com/invite/hXyHmvW3Vy">
<img src="https://img.shields.io/discord/1162999022819225631?logo=discord" alt="discord" /></a>
<a href="https://twitter.com/lbugdb">
<img src="https://img.shields.io/badge/follow-@lbugdb-1DA1F2?logo=twitter" alt="twitter"></a>
</p>
# Ladybug
Ladybug is an embedded graph database built for query speed and scalability. Ladybug is optimized for handling complex analytical workloads
on very large databases and provides a set of retrieval features, such as full-text search and vector indices. Our core feature set includes:
- Flexible Property Graph Data Model and Cypher query language
- Embeddable, serverless integration into applications
- Native full text search and vector index
- Columnar disk-based storage
- Compressed sparse row (CSR) adjacency list/join indices
- Vectorized and factorized query processor
- Novel and very fast join algorithms
- Multi-core query parallelism
- Serializable ACID transactions
- Wasm (WebAssembly) bindings for fast, secure execution in the browser
Ladybug is being developed by [LadybugDB Developers](https://github.com/LadybugDB) and
is available under a permissive license. So try it out and help us make it better! We welcome your feedback and feature requests.
The database was formerly known as [Kuzu](https://github.com/kuzudb/kuzu).
## Installation
> [!WARNING]
> Many of these binary installation methods are not functional yet. We need to work through package names, availability and convention issues.
> For now, use the build from source method.
| Language | Installation |
| -------- |------------------------------------------------------------------------|
| Python | `pip install real_ladybug` |
| NodeJS | `npm install lbug` |
| Rust | `cargo add lbug` |
| Go | `go get github.com/lbugdb/go-lbug` |
| Swift | [lbug-swift](https://github.com/lbugdb/lbug-swift) |
| Java | [Maven Central](https://central.sonatype.com/artifact/com.ladybugdb/lbug) |
| C/C++ | [precompiled binaries](https://github.com/LadybugDB/ladybug/releases/latest) |
| CLI | [precompiled binaries](https://github.com/LadybugDB/ladybug/releases/latest) |
To learn more about installation, see our [Installation](https://docs.ladybugdb.com/installation) page.
## Getting Started
Refer to our [Getting Started](https://docs.ladybugdb.com/get-started/) page for your first example.
## Build from Source
You can build from source using the instructions provided in the [developer guide](https://docs.ladybugdb.com/developer-guide/).
## Contributing
We welcome contributions to Ladybug. If you are interested in contributing to Ladybug, please read our [Contributing Guide](CONTRIBUTING.md).
## License
By contributing to Ladybug, you agree that your contributions will be licensed under the [MIT License](LICENSE).
## Contact
You can contact us at [social@ladybugdb.com](mailto:social@ladybugdb.com) or [join our Discord community](https://discord.com/invite/hXyHmvW3Vy).

View File

@@ -0,0 +1,75 @@
/*
* This is a template header used for generating the header 'system_config.h'
* Any value in the format @VALUE_NAME@ can be substituted with a value passed into CMakeLists.txt
* See https://cmake.org/cmake/help/latest/command/configure_file.html for more details
*/
#pragma once
#include <algorithm>
#include <cstdint>
#include "common/enums/extend_direction.h"
#define BOTH_REL_STORAGE 0
#define FWD_REL_STORAGE 1
#define BWD_REL_STORAGE 2
namespace lbug {
namespace common {
#define VECTOR_CAPACITY_LOG_2 @LBUG_VECTOR_CAPACITY_LOG2@
#if VECTOR_CAPACITY_LOG_2 > 12
#error "Vector capacity log2 should be less than or equal to 12"
#endif
constexpr uint64_t DEFAULT_VECTOR_CAPACITY = static_cast<uint64_t>(1) << VECTOR_CAPACITY_LOG_2;
// Currently the system supports files with 2 different pages size, which we refer to as
// PAGE_SIZE and TEMP_PAGE_SIZE. PAGE_SIZE is the default size of the page which is the
// unit of read/write to the database files.
static constexpr uint64_t PAGE_SIZE_LOG2 = @LBUG_PAGE_SIZE_LOG2@; // Default to 4KB.
static constexpr uint64_t LBUG_PAGE_SIZE = static_cast<uint64_t>(1) << PAGE_SIZE_LOG2;
// Page size for files with large pages, e.g., temporary files that are used by operators that
// may require large amounts of memory.
static constexpr uint64_t TEMP_PAGE_SIZE_LOG2 = 18;
// 2^18 bytes = 256 KiB. Declared constexpr for consistency with the other
// compile-time constants in this header.
static constexpr uint64_t TEMP_PAGE_SIZE = static_cast<uint64_t>(1) << TEMP_PAGE_SIZE_LOG2;
#define DEFAULT_REL_STORAGE_DIRECTION @LBUG_DEFAULT_REL_STORAGE_DIRECTION@
#if DEFAULT_REL_STORAGE_DIRECTION == FWD_REL_STORAGE
static constexpr ExtendDirection DEFAULT_EXTEND_DIRECTION = ExtendDirection::FWD;
#elif DEFAULT_REL_STORAGE_DIRECTION == BWD_REL_STORAGE
static constexpr ExtendDirection DEFAULT_EXTEND_DIRECTION = ExtendDirection::BWD;
#else
static constexpr ExtendDirection DEFAULT_EXTEND_DIRECTION = ExtendDirection::BOTH;
#endif
// Compile-time storage layout constants. Tokens written as @NAME@ are
// substituted when this template is configured by CMake.
struct StorageConfig {
    // Number of nodes per node group, expressed as a power of two.
    static constexpr uint64_t NODE_GROUP_SIZE_LOG2 = @LBUG_NODE_GROUP_SIZE_LOG2@;
    static constexpr uint64_t NODE_GROUP_SIZE = static_cast<uint64_t>(1) << NODE_GROUP_SIZE_LOG2;
    // The number of CSR lists in a leaf region.
    // Capped at 2^10 and always at most half of a node group.
    static constexpr uint64_t CSR_LEAF_REGION_SIZE_LOG2 =
        std::min(static_cast<uint64_t>(10), NODE_GROUP_SIZE_LOG2 - 1);
    static constexpr uint64_t CSR_LEAF_REGION_SIZE = static_cast<uint64_t>(1)
                                                     << CSR_LEAF_REGION_SIZE_LOG2;
    // At most 2048, and never larger than a node group.
    static constexpr uint64_t CHUNKED_NODE_GROUP_CAPACITY =
        std::min(static_cast<uint64_t>(2048), NODE_GROUP_SIZE);
    // Maximum size for a segment in bytes
    static constexpr uint64_t MAX_SEGMENT_SIZE_LOG2 = @LBUG_MAX_SEGMENT_SIZE_LOG2@;
    // Shift a 64-bit one (not a plain int) so that exponents of 31 or more
    // cannot overflow, matching the other shifts in this header.
    static constexpr uint64_t MAX_SEGMENT_SIZE = static_cast<uint64_t>(1) << MAX_SEGMENT_SIZE_LOG2;
};
// Constants used by ORDER BY processing.
struct OrderByConfig {
    // Five times the default vector capacity.
    static constexpr uint64_t MIN_SIZE_TO_REDUCE = 5 * common::DEFAULT_VECTOR_CAPACITY;
};
// Constants used by COPY processing.
struct CopyConfig {
    // Fifty times the default vector capacity.
    static constexpr uint64_t PANDAS_PARTITION_COUNT = DEFAULT_VECTOR_CAPACITY * 50;
};
} // namespace common
} // namespace lbug
#undef BOTH_REL_STORAGE
#undef FWD_REL_STORAGE
#undef BWD_REL_STORAGE

View File

@@ -0,0 +1,79 @@
include_directories(${CMAKE_CURRENT_BINARY_DIR})
# Have to pass this down to every subdirectory, which actually adds the files.
# This doesn't affect parent directories.
add_compile_definitions(LBUG_EXPORTS)
add_compile_definitions(ANTLR4CPP_STATIC)
add_subdirectory(binder)
add_subdirectory(c_api)
add_subdirectory(catalog)
add_subdirectory(common)
add_subdirectory(expression_evaluator)
add_subdirectory(function)
add_subdirectory(graph)
add_subdirectory(main)
add_subdirectory(optimizer)
add_subdirectory(parser)
add_subdirectory(planner)
add_subdirectory(processor)
add_subdirectory(storage)
add_subdirectory(transaction)
add_subdirectory(extension)
add_library(lbug STATIC ${ALL_OBJECT_FILES})
add_library(lbug_shared SHARED ${ALL_OBJECT_FILES})
set(LBUG_LIBRARIES antlr4_cypher antlr4_runtime brotlidec brotlicommon fast_float utf8proc re2 fastpfor parquet snappy thrift yyjson zstd miniz mbedtls lz4 roaring_bitmap simsimd)
if (NOT __SINGLE_THREADED__)
set(LBUG_LIBRARIES ${LBUG_LIBRARIES} Threads::Threads)
endif()
if(NOT WIN32)
set(LBUG_LIBRARIES dl ${LBUG_LIBRARIES})
endif()
# Seems to be needed for clang on linux only
# for compiling std::atomic<T>::compare_exchange_weak
if ((NOT APPLE AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang") AND NOT __WASM__ AND NOT __SINGLE_THREADED__)
set(LBUG_LIBRARIES atomic ${LBUG_LIBRARIES})
endif()
if (ENABLE_BACKTRACES)
set(LBUG_LIBRARIES ${LBUG_LIBRARIES} cpptrace::cpptrace)
endif()
target_link_libraries(lbug PUBLIC ${LBUG_LIBRARIES})
target_link_libraries(lbug_shared PUBLIC ${LBUG_LIBRARIES})
unset(LBUG_LIBRARIES)
set(LBUG_INCLUDES $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include> $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}> ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/include/c_api ${CMAKE_CURRENT_BINARY_DIR}/../src/include)
target_include_directories(lbug PUBLIC ${LBUG_INCLUDES})
target_include_directories(lbug_shared PUBLIC ${LBUG_INCLUDES})
unset(LBUG_INCLUDES)
if(WIN32)
    # Anything linking against the static library must not use dllimport.
    target_compile_definitions(lbug INTERFACE LBUG_STATIC_DEFINE)
else()
    # Outside Windows the shared library is emitted under the output name "lbug".
    set_target_properties(lbug_shared PROPERTIES OUTPUT_NAME lbug)
endif()
install(TARGETS lbug lbug_shared)
if(${BUILD_SINGLE_FILE_HEADER})
# Create a command to generate lbug.hpp, and then create a target that is
# always built that depends on it. This allows our generator to detect when
# exactly to build lbug.hpp, while still building the target by default.
find_package(Python3 3.9...4 REQUIRED)
add_custom_command(
OUTPUT lbug.hpp
COMMAND
${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/scripts/collect-single-file-header.py ${CMAKE_CURRENT_BINARY_DIR}/..
DEPENDS
${PROJECT_SOURCE_DIR}/scripts/collect-single-file-header.py lbug_shared)
add_custom_target(single_file_header ALL DEPENDS lbug.hpp)
endif()
install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/include/c_api/lbug.h TYPE INCLUDE)
if(${BUILD_SINGLE_FILE_HEADER})
install(FILES ${CMAKE_CURRENT_BINARY_DIR}/lbug.hpp TYPE INCLUDE)
endif()

View File

@@ -0,0 +1,917 @@
// Entry rule: one or more statements separated by ';', followed by EOF.
ku_Statements
: oC_Cypher ( SP? ';' SP? oC_Cypher )* SP? EOF ;
// A single statement, optionally prefixed by an EXPLAIN/PROFILE option
// (oC_AnyCypherOption) and optionally followed by a trailing ';'.
oC_Cypher
: oC_AnyCypherOption? SP? ( oC_Statement ) ( SP? ';' )?;
// All top-level statement kinds: queries plus DDL, COPY, transaction,
// extension and database-management statements.
oC_Statement
: oC_Query
| kU_CreateUser
| kU_CreateRole
| kU_CreateNodeTable
| kU_CreateRelTable
| kU_CreateSequence
| kU_CreateType
| kU_Drop
| kU_AlterTable
| kU_CopyFrom
| kU_CopyFromByColumn
| kU_CopyTO
| kU_StandaloneCall
| kU_CreateMacro
| kU_CommentOn
| kU_Transaction
| kU_Extension
| kU_ExportDatabase
| kU_ImportDatabase
| kU_AttachDatabase
| kU_DetachDatabase
| kU_UseDatabase;
kU_CopyFrom
: COPY SP oC_SchemaName kU_ColumnNames? SP FROM SP kU_ScanSource ( SP? '(' SP? kU_Options SP? ')' )? ;
kU_ColumnNames
: SP? '(' SP? (oC_SchemaName ( SP? ',' SP? oC_SchemaName )* SP?)? ')';
kU_ScanSource
: kU_FilePaths
| '(' SP? oC_Query SP? ')'
| oC_Parameter
| oC_Variable
| oC_Variable '.' SP? oC_SchemaName
| oC_FunctionInvocation ;
kU_CopyFromByColumn
: COPY SP oC_SchemaName SP FROM SP '(' SP? StringLiteral ( SP? ',' SP? StringLiteral )* ')' SP BY SP COLUMN ;
kU_CopyTO
: COPY SP '(' SP? oC_Query SP? ')' SP TO SP StringLiteral ( SP? '(' SP? kU_Options SP? ')' )? ;
kU_ExportDatabase
: EXPORT SP DATABASE SP StringLiteral ( SP? '(' SP? kU_Options SP? ')' )? ;
kU_ImportDatabase
: IMPORT SP DATABASE SP StringLiteral;
kU_AttachDatabase
: ATTACH SP StringLiteral (SP AS SP oC_SchemaName)? SP '(' SP? DBTYPE SP oC_SymbolicName (SP? ',' SP? kU_Options)? SP? ')' ;
kU_Option
: oC_SymbolicName (SP? '=' SP? | SP*) oC_Literal | oC_SymbolicName;
kU_Options
: kU_Option ( SP? ',' SP? kU_Option )* ;
kU_DetachDatabase
: DETACH SP oC_SchemaName;
kU_UseDatabase
: USE SP oC_SchemaName;
kU_StandaloneCall
: CALL SP oC_SymbolicName SP? '=' SP? oC_Expression
| CALL SP oC_FunctionInvocation;
kU_CommentOn
: COMMENT SP ON SP TABLE SP oC_SchemaName SP IS SP StringLiteral ;
kU_CreateMacro
: CREATE SP MACRO SP oC_FunctionName SP? '(' SP? kU_PositionalArgs? SP? kU_DefaultArg? ( SP? ',' SP? kU_DefaultArg )* SP? ')' SP AS SP oC_Expression ;
kU_PositionalArgs
: oC_SymbolicName ( SP? ',' SP? oC_SymbolicName )* ;
kU_DefaultArg
: oC_SymbolicName SP? ':' '=' SP? oC_Literal ;
kU_FilePaths
: '[' SP? StringLiteral ( SP? ',' SP? StringLiteral )* ']'
| StringLiteral
| GLOB SP? '(' SP? StringLiteral SP? ')' ;
kU_IfNotExists
: IF SP NOT SP EXISTS ;
kU_CreateNodeTable
: CREATE SP NODE SP TABLE SP (kU_IfNotExists SP)? oC_SchemaName ( SP? '(' SP? kU_PropertyDefinitions SP? ( ',' SP? kU_CreateNodeConstraint )? SP? ')' | SP AS SP oC_Query ) ;
kU_CreateRelTable
: CREATE SP REL SP TABLE ( SP GROUP )? ( SP kU_IfNotExists )? SP oC_SchemaName
SP? '(' SP?
kU_FromToConnections SP? (
( ',' SP? kU_PropertyDefinitions SP? )?
( ',' SP? oC_SymbolicName SP? )? // Constraints
')'
| ')' SP AS SP oC_Query )
( SP WITH SP? '(' SP? kU_Options SP? ')')? ;
kU_FromToConnections
: kU_FromToConnection ( SP? ',' SP? kU_FromToConnection )* ;
kU_FromToConnection
: FROM SP oC_SchemaName SP TO SP oC_SchemaName ;
kU_CreateSequence
: CREATE SP SEQUENCE SP (kU_IfNotExists SP)? oC_SchemaName (SP kU_SequenceOptions)* ;
kU_CreateType
: CREATE SP TYPE SP oC_SchemaName SP AS SP kU_DataType SP? ;
kU_SequenceOptions
: kU_IncrementBy
| kU_MinValue
| kU_MaxValue
| kU_StartWith
| kU_Cycle;
kU_WithPasswd
: SP WITH SP PASSWORD SP StringLiteral ;
kU_CreateUser
: CREATE SP USER SP (kU_IfNotExists SP)? oC_Variable kU_WithPasswd? ;
kU_CreateRole
: CREATE SP ROLE SP (kU_IfNotExists SP)? oC_Variable ;
kU_IncrementBy : INCREMENT SP ( BY SP )? MINUS? oC_IntegerLiteral ;
kU_MinValue : (NO SP MINVALUE) | (MINVALUE SP MINUS? oC_IntegerLiteral) ;
kU_MaxValue : (NO SP MAXVALUE) | (MAXVALUE SP MINUS? oC_IntegerLiteral) ;
kU_StartWith : START SP ( WITH SP )? MINUS? oC_IntegerLiteral ;
kU_Cycle : (NO SP)? CYCLE ;
kU_IfExists
: IF SP EXISTS ;
kU_Drop
: DROP SP (TABLE | SEQUENCE | MACRO) SP (kU_IfExists SP)? oC_SchemaName ;
kU_AlterTable
: ALTER SP TABLE SP oC_SchemaName SP kU_AlterOptions ;
kU_AlterOptions
: kU_AddProperty
| kU_DropProperty
| kU_RenameTable
| kU_RenameProperty
| kU_AddFromToConnection
| kU_DropFromToConnection;
kU_AddProperty
: ADD SP (kU_IfNotExists SP)? oC_PropertyKeyName SP kU_DataType ( SP kU_Default )? ;
kU_Default
: DEFAULT SP oC_Expression ;
kU_DropProperty
: DROP SP (kU_IfExists SP)? oC_PropertyKeyName ;
kU_RenameTable
: RENAME SP TO SP oC_SchemaName ;
kU_RenameProperty
: RENAME SP oC_PropertyKeyName SP TO SP oC_PropertyKeyName ;
kU_AddFromToConnection
: ADD SP (kU_IfNotExists SP)? kU_FromToConnection ;
kU_DropFromToConnection
: DROP SP (kU_IfExists SP)? kU_FromToConnection ;
kU_ColumnDefinitions: kU_ColumnDefinition ( SP? ',' SP? kU_ColumnDefinition )* ;
kU_ColumnDefinition : oC_PropertyKeyName SP kU_DataType ;
kU_PropertyDefinitions : kU_PropertyDefinition ( SP? ',' SP? kU_PropertyDefinition )* ;
kU_PropertyDefinition : kU_ColumnDefinition ( SP kU_Default )? ( SP PRIMARY SP KEY)?;
kU_CreateNodeConstraint : PRIMARY SP KEY SP? '(' SP? oC_PropertyKeyName SP? ')' ;
DECIMAL: ( 'D' | 'd' ) ( 'E' | 'e' ) ( 'C' | 'c' ) ( 'I' | 'i' ) ( 'M' | 'm' ) ( 'A' | 'a' ) ( 'L' | 'l' ) ;
kU_UnionType
: UNION SP? '(' SP? kU_ColumnDefinitions SP? ')' ;
kU_StructType
: STRUCT SP? '(' SP? kU_ColumnDefinitions SP? ')' ;
kU_MapType
: MAP SP? '(' SP? kU_DataType SP? ',' SP? kU_DataType SP? ')' ;
kU_DecimalType
: DECIMAL SP? '(' SP? oC_IntegerLiteral SP? ',' SP? oC_IntegerLiteral SP? ')' ;
kU_DataType
: oC_SymbolicName
| kU_DataType kU_ListIdentifiers
| kU_UnionType
| kU_StructType
| kU_MapType
| kU_DecimalType ;
kU_ListIdentifiers : kU_ListIdentifier ( kU_ListIdentifier )* ;
kU_ListIdentifier : '[' oC_IntegerLiteral? ']' ;
oC_AnyCypherOption
: oC_Explain
| oC_Profile ;
oC_Explain
: EXPLAIN (SP LOGICAL)? ;
oC_Profile
: PROFILE ;
kU_Transaction
: BEGIN SP TRANSACTION
| BEGIN SP TRANSACTION SP READ SP ONLY
| COMMIT
| ROLLBACK
| CHECKPOINT;
kU_Extension
: kU_LoadExtension
| kU_InstallExtension
| kU_UninstallExtension
| kU_UpdateExtension ;
kU_LoadExtension
: LOAD SP (EXTENSION SP)? ( StringLiteral | oC_Variable ) ;
kU_InstallExtension
: (FORCE SP)? INSTALL SP oC_Variable (SP FROM SP StringLiteral)?;
kU_UninstallExtension
: UNINSTALL SP oC_Variable;
kU_UpdateExtension
: UPDATE SP oC_Variable;
oC_Query
: oC_RegularQuery ;
// A regular query is a single query followed by any number of UNION parts.
// The second alternative matches one or more RETURN clauses placed before a
// single query and runs notifyReturnNotAtEnd on the first token; it appears to
// exist so this mistake gets a dedicated message instead of a generic syntax error.
oC_RegularQuery
: oC_SingleQuery ( SP? oC_Union )*
| (oC_Return SP? )+ oC_SingleQuery { notifyReturnNotAtEnd($ctx->start); }
;
oC_Union
: ( UNION SP ALL SP? oC_SingleQuery )
| ( UNION SP? oC_SingleQuery ) ;
oC_SingleQuery
: oC_SinglePartQuery
| oC_MultiPartQuery
;
oC_SinglePartQuery
: ( oC_ReadingClause SP? )* oC_Return
| ( ( oC_ReadingClause SP? )* oC_UpdatingClause ( SP? oC_UpdatingClause )* ( SP? oC_Return )? )
;
oC_MultiPartQuery
: ( kU_QueryPart SP? )+ oC_SinglePartQuery;
kU_QueryPart
: (oC_ReadingClause SP? )* ( oC_UpdatingClause SP? )* oC_With ;
oC_UpdatingClause
: oC_Create
| oC_Merge
| oC_Set
| oC_Delete
;
oC_ReadingClause
: oC_Match
| oC_Unwind
| kU_InQueryCall
| kU_LoadFrom
;
kU_LoadFrom
: LOAD ( SP WITH SP HEADERS SP? '(' SP? kU_ColumnDefinitions SP? ')' )? SP FROM SP kU_ScanSource (SP? '(' SP? kU_Options SP? ')')? (SP? oC_Where)? ;
oC_YieldItem
: ( oC_Variable SP AS SP )? oC_Variable ;
oC_YieldItems
: oC_YieldItem ( SP? ',' SP? oC_YieldItem )* ;
kU_InQueryCall
: CALL SP oC_FunctionInvocation (SP? oC_Where)? ( SP? YIELD SP oC_YieldItems )? ;
oC_Match
: ( OPTIONAL SP )? MATCH SP? oC_Pattern ( SP oC_Where )? ( SP kU_Hint )? ;
kU_Hint
: HINT SP kU_JoinNode;
kU_JoinNode
: kU_JoinNode SP JOIN SP kU_JoinNode
| kU_JoinNode ( SP MULTI_JOIN SP oC_SchemaName)+
| '(' SP? kU_JoinNode SP? ')'
| oC_SchemaName ;
oC_Unwind : UNWIND SP? oC_Expression SP AS SP oC_Variable ;
oC_Create
: CREATE SP? oC_Pattern ;
// For unknown reason, openCypher use oC_PatternPart instead of oC_Pattern. There should be no difference in terms of planning.
// So we choose to be consistent with oC_Create and use oC_Pattern instead.
oC_Merge : MERGE SP? oC_Pattern ( SP oC_MergeAction )* ;
oC_MergeAction
: ( ON SP MATCH SP oC_Set )
| ( ON SP CREATE SP oC_Set )
;
oC_Set
: SET SP? oC_SetItem ( SP? ',' SP? oC_SetItem )*
| SET SP? oC_Atom SP? '=' SP? kU_Properties;
oC_SetItem
: ( oC_PropertyExpression SP? '=' SP? oC_Expression ) ;
oC_Delete
: ( DETACH SP )? DELETE SP? oC_Expression ( SP? ',' SP? oC_Expression )*;
oC_With
: WITH oC_ProjectionBody ( SP? oC_Where )? ;
oC_Return
: RETURN oC_ProjectionBody ;
oC_ProjectionBody
: ( SP? DISTINCT )? SP oC_ProjectionItems (SP oC_Order )? ( SP oC_Skip )? ( SP oC_Limit )? ;
oC_ProjectionItems
: ( STAR ( SP? ',' SP? oC_ProjectionItem )* )
| ( oC_ProjectionItem ( SP? ',' SP? oC_ProjectionItem )* )
;
STAR : '*' ;
oC_ProjectionItem
: ( oC_Expression SP AS SP oC_Variable )
| oC_Expression
;
// ORDER BY clause: one or more sort items separated by commas.
// Whitespace is optional on both sides of the comma, matching the other
// comma-separated lists in this grammar (e.g. oC_ProjectionItems), so that
// `ORDER BY a , b` parses the same way as `ORDER BY a, b`.
oC_Order
: ORDER SP BY SP oC_SortItem ( SP? ',' SP? oC_SortItem )* ;
oC_Skip
: L_SKIP SP oC_Expression ;
L_SKIP : ( 'S' | 's' ) ( 'K' | 'k' ) ( 'I' | 'i' ) ( 'P' | 'p' ) ;
oC_Limit
: LIMIT SP oC_Expression ;
oC_SortItem
: oC_Expression ( SP? ( ASCENDING | ASC | DESCENDING | DESC ) )? ;
oC_Where
: WHERE SP oC_Expression ;
oC_Pattern
: oC_PatternPart ( SP? ',' SP? oC_PatternPart )* ;
oC_PatternPart
: ( oC_Variable SP? '=' SP? oC_AnonymousPatternPart )
| oC_AnonymousPatternPart ;
oC_AnonymousPatternPart
: oC_PatternElement ;
oC_PatternElement
: ( oC_NodePattern ( SP? oC_PatternElementChain )* )
| ( '(' oC_PatternElement ')' )
;
oC_NodePattern
: '(' SP? ( oC_Variable SP? )? ( oC_NodeLabels SP? )? ( kU_Properties SP? )? ')' ;
oC_PatternElementChain
: oC_RelationshipPattern SP? oC_NodePattern ;
oC_RelationshipPattern
: ( oC_LeftArrowHead SP? oC_Dash SP? oC_RelationshipDetail? SP? oC_Dash )
| ( oC_Dash SP? oC_RelationshipDetail? SP? oC_Dash SP? oC_RightArrowHead )
| ( oC_Dash SP? oC_RelationshipDetail? SP? oC_Dash )
;
oC_RelationshipDetail
: '[' SP? ( oC_Variable SP? )? ( oC_RelationshipTypes SP? )? ( kU_RecursiveDetail SP? )? ( kU_Properties SP? )? ']' ;
// The original oC_Properties definition is oC_MapLiteral | oC_Parameter.
// We choose to not support parameter as properties which will be the decision for a long time.
// We then substitute with oC_MapLiteral definition. We create oC_MapLiteral only when we decide to add MAP type.
kU_Properties
: '{' SP? ( oC_PropertyKeyName SP? ':' SP? oC_Expression SP? ( ',' SP? oC_PropertyKeyName SP? ':' SP? oC_Expression SP? )* )? '}';
oC_RelationshipTypes
: ':' SP? oC_RelTypeName ( SP? '|' ':'? SP? oC_RelTypeName )* ;
oC_NodeLabels
: ':' SP? oC_LabelName ( SP? ('|' ':'? | ':') SP? oC_LabelName )* ;
kU_RecursiveDetail
: '*' ( SP? kU_RecursiveType)? ( SP? oC_RangeLiteral )? ( SP? kU_RecursiveComprehension )? ;
kU_RecursiveType
: (ALL SP)? WSHORTEST SP? '(' SP? oC_PropertyKeyName SP? ')'
| SHORTEST
| ALL SP SHORTEST
| TRAIL
| ACYCLIC ;
oC_RangeLiteral
: oC_LowerBound? SP? DOTDOT SP? oC_UpperBound?
| oC_IntegerLiteral ;
kU_RecursiveComprehension
: '(' SP? oC_Variable SP? ',' SP? oC_Variable ( SP? '|' SP? oC_Where SP? )? ( SP? '|' SP? kU_RecursiveProjectionItems SP? ',' SP? kU_RecursiveProjectionItems SP? )? ')' ;
kU_RecursiveProjectionItems
: '{' SP? oC_ProjectionItems? SP? '}' ;
oC_LowerBound
: DecimalInteger ;
oC_UpperBound
: DecimalInteger ;
oC_LabelName
: oC_SchemaName ;
oC_RelTypeName
: oC_SchemaName ;
oC_Expression
: oC_OrExpression ;
oC_OrExpression
: oC_XorExpression ( SP OR SP oC_XorExpression )* ;
oC_XorExpression
: oC_AndExpression ( SP XOR SP oC_AndExpression )* ;
oC_AndExpression
: oC_NotExpression ( SP AND SP oC_NotExpression )* ;
oC_NotExpression
: ( NOT SP? )* oC_ComparisonExpression;
// Comparison between two bitwise-or level operands (the operator part is optional).
// Alternative 2 matches the `!=` token (INVALID_NOT_EQUAL) and calls
// notifyInvalidNotEqualOperator on it.
// Alternative 3 matches a chain of two or more comparison operators
// (e.g. a < b < c) and calls notifyNonBinaryComparison on the first token.
// Both extra alternatives look like error-reporting hooks rather than supported syntax.
oC_ComparisonExpression
: kU_BitwiseOrOperatorExpression ( SP? kU_ComparisonOperator SP? kU_BitwiseOrOperatorExpression )?
| kU_BitwiseOrOperatorExpression ( SP? INVALID_NOT_EQUAL SP? kU_BitwiseOrOperatorExpression ) { notifyInvalidNotEqualOperator($INVALID_NOT_EQUAL); }
| kU_BitwiseOrOperatorExpression SP? kU_ComparisonOperator SP? kU_BitwiseOrOperatorExpression ( SP? kU_ComparisonOperator SP? kU_BitwiseOrOperatorExpression )+ { notifyNonBinaryComparison($ctx->start); }
;
kU_ComparisonOperator : '=' | '<>' | '<' | '<=' | '>' | '>=' ;
INVALID_NOT_EQUAL : '!=' ;
kU_BitwiseOrOperatorExpression
: kU_BitwiseAndOperatorExpression ( SP? '|' SP? kU_BitwiseAndOperatorExpression )* ;
kU_BitwiseAndOperatorExpression
: kU_BitShiftOperatorExpression ( SP? '&' SP? kU_BitShiftOperatorExpression )* ;
kU_BitShiftOperatorExpression
: oC_AddOrSubtractExpression ( SP? kU_BitShiftOperator SP? oC_AddOrSubtractExpression )* ;
kU_BitShiftOperator : '>>' | '<<' ;
oC_AddOrSubtractExpression
: oC_MultiplyDivideModuloExpression ( SP? kU_AddOrSubtractOperator SP? oC_MultiplyDivideModuloExpression )* ;
kU_AddOrSubtractOperator : '+' | '-' ;
oC_MultiplyDivideModuloExpression
: oC_PowerOfExpression ( SP? kU_MultiplyDivideModuloOperator SP? oC_PowerOfExpression )* ;
kU_MultiplyDivideModuloOperator : '*' | '/' | '%' ;
oC_PowerOfExpression
: oC_StringListNullOperatorExpression ( SP? '^' SP? oC_StringListNullOperatorExpression )* ;
oC_StringListNullOperatorExpression
: oC_UnaryAddSubtractOrFactorialExpression ( oC_StringOperatorExpression | oC_ListOperatorExpression+ | oC_NullOperatorExpression )? ;
oC_ListOperatorExpression
: ( SP IN SP? oC_PropertyOrLabelsExpression )
| ( '[' oC_Expression ']' )
| ( '[' oC_Expression? ( COLON | DOTDOT ) oC_Expression? ']' ) ;
COLON : ':' ;
DOTDOT : '..' ;
oC_StringOperatorExpression
: ( oC_RegularExpression | ( SP STARTS SP WITH ) | ( SP ENDS SP WITH ) | ( SP CONTAINS ) ) SP? oC_PropertyOrLabelsExpression ;
oC_RegularExpression
: SP? '=~' ;
oC_NullOperatorExpression
: ( SP IS SP NULL )
| ( SP IS SP NOT SP NULL ) ;
MINUS : '-' ;
FACTORIAL : '!' ;
oC_UnaryAddSubtractOrFactorialExpression
: ( MINUS SP? )* oC_PropertyOrLabelsExpression (SP? FACTORIAL)? ;
oC_PropertyOrLabelsExpression
: oC_Atom ( SP? oC_PropertyLookup )* ;
oC_Atom
: oC_Literal
| oC_Parameter
| oC_CaseExpression
| oC_ParenthesizedExpression
| oC_FunctionInvocation
| oC_PathPatterns
| oC_ExistCountSubquery
| oC_Variable
| oC_Quantifier
;
oC_Quantifier
: ( ALL SP? '(' SP? oC_FilterExpression SP? ')' )
| ( ANY SP? '(' SP? oC_FilterExpression SP? ')' )
| ( NONE SP? '(' SP? oC_FilterExpression SP? ')' )
| ( SINGLE SP? '(' SP? oC_FilterExpression SP? ')' )
;
oC_FilterExpression
: oC_IdInColl SP oC_Where ;
oC_IdInColl
: oC_Variable SP IN SP oC_Expression ;
oC_Literal
: oC_NumberLiteral
| StringLiteral
| oC_BooleanLiteral
| NULL
| oC_ListLiteral
| kU_StructLiteral
;
oC_BooleanLiteral
: TRUE
| FALSE
;
oC_ListLiteral
: '[' SP? ( oC_Expression SP? ( kU_ListEntry SP? )* )? ']' ;
kU_ListEntry
: ',' SP? oC_Expression? ;
kU_StructLiteral
: '{' SP? kU_StructField SP? ( ',' SP? kU_StructField SP? )* '}' ;
kU_StructField
: ( oC_SymbolicName | StringLiteral ) SP? ':' SP? oC_Expression ;
oC_ParenthesizedExpression
: '(' SP? oC_Expression SP? ')' ;
oC_FunctionInvocation
: COUNT SP? '(' SP? '*' SP? ')'
| CAST SP? '(' SP? kU_FunctionParameter SP? ( ( AS SP? kU_DataType ) | ( ',' SP? kU_FunctionParameter ) ) SP? ')'
| oC_FunctionName SP? '(' SP? ( DISTINCT SP? )? ( kU_FunctionParameter SP? ( ',' SP? kU_FunctionParameter SP? )* )? ')' ;
oC_FunctionName
: oC_SymbolicName ;
kU_FunctionParameter
: ( oC_SymbolicName SP? ':' '=' SP? )? oC_Expression
| kU_LambdaParameter ;
kU_LambdaParameter
: kU_LambdaVars SP? '-' '>' SP? oC_Expression SP? ;
kU_LambdaVars
: oC_SymbolicName
| '(' SP? oC_SymbolicName SP? ( ',' SP? oC_SymbolicName SP?)* ')' ;
oC_PathPatterns
: oC_NodePattern ( SP? oC_PatternElementChain )+;
oC_ExistCountSubquery
: (EXISTS | COUNT) SP? '{' SP? MATCH SP? oC_Pattern ( SP? oC_Where )? ( SP? kU_Hint )? SP? '}' ;
oC_PropertyLookup
: '.' SP? ( oC_PropertyKeyName | STAR ) ;
oC_CaseExpression
: ( ( CASE ( SP? oC_CaseAlternative )+ ) | ( CASE SP? oC_Expression ( SP? oC_CaseAlternative )+ ) ) ( SP? ELSE SP? oC_Expression )? SP? END ;
oC_CaseAlternative
: WHEN SP? oC_Expression SP? THEN SP? oC_Expression ;
oC_Variable
: oC_SymbolicName ;
StringLiteral
: ( '"' ( StringLiteral_0 | EscapedChar )* '"' )
| ( '\'' ( StringLiteral_1 | EscapedChar )* '\'' )
;
EscapedChar
: '\\' ( '\\' | '\'' | '"' | ( 'B' | 'b' ) | ( 'F' | 'f' ) | ( 'N' | 'n' ) | ( 'R' | 'r' ) | ( 'T' | 't' ) | ( ( 'X' | 'x' ) ( HexDigit HexDigit ) ) | ( ( 'U' | 'u' ) ( HexDigit HexDigit HexDigit HexDigit ) ) | ( ( 'U' | 'u' ) ( HexDigit HexDigit HexDigit HexDigit HexDigit HexDigit HexDigit HexDigit ) ) ) ;
oC_NumberLiteral
: oC_DoubleLiteral
| oC_IntegerLiteral
;
oC_Parameter
: '$' ( oC_SymbolicName | DecimalInteger ) ;
oC_PropertyExpression
: oC_Atom SP? oC_PropertyLookup ;
oC_PropertyKeyName
: oC_SchemaName ;
oC_IntegerLiteral
: DecimalInteger ;
DecimalInteger
: ZeroDigit
| ( NonZeroDigit ( Digit )* )
;
HexLetter
: ( 'A' | 'a' )
| ( 'B' | 'b' )
| ( 'C' | 'c' )
| ( 'D' | 'd' )
| ( 'E' | 'e' )
| ( 'F' | 'f' )
;
HexDigit
: Digit
| HexLetter
;
Digit
: ZeroDigit
| NonZeroDigit
;
NonZeroDigit
: NonZeroOctDigit
| '8'
| '9'
;
NonZeroOctDigit
: '1'
| '2'
| '3'
| '4'
| '5'
| '6'
| '7'
;
ZeroDigit
: '0' ;
oC_DoubleLiteral
: ExponentDecimalReal
| RegularDecimalReal
;
ExponentDecimalReal
: ( ( Digit )+ | ( ( Digit )+ '.' ( Digit )+ ) | ( '.' ( Digit )+ ) ) ( 'E' | 'e' ) '-'? ( Digit )+ ;
RegularDecimalReal
: ( Digit )* '.' ( Digit )+ ;
oC_SchemaName
: oC_SymbolicName ;
// Any identifier-like name: a plain identifier, a backtick-escaped identifier,
// a single hex letter, or a keyword that is allowed as a name.
// The embedded action calls notifyEmptyToken when the escaped name is exactly
// two backticks, i.e. an empty identifier.
// NOTE(review): HexLetter is presumably listed because a lone a-f/A-F character
// is lexed as HexLetter rather than UnescapedSymbolicName; confirm against the
// generated lexer.
oC_SymbolicName
: UnescapedSymbolicName
| EscapedSymbolicName {if ($EscapedSymbolicName.text == "``") { notifyEmptyToken($EscapedSymbolicName); }}
| HexLetter
| kU_NonReservedKeywords
;
// Keywords that may also be used as symbolic names (see oC_SymbolicName).
// Each keyword must appear only once; a repeated alternative can never be
// selected and only adds noise.
// example of BEGIN and END: TCKWith2.Scenario1
kU_NonReservedKeywords
: COMMENT
| ADD
| ALTER
| AS
| ATTACH
| BEGIN
| BY
| CALL
| CHECKPOINT
| COMMIT
| CONTAINS
| COPY
| COUNT
| CYCLE
| DATABASE
| DECIMAL
| DELETE
| DETACH
| DROP
| EXPLAIN
| EXPORT
| EXTENSION
| FORCE
| GRAPH
| IF
| IS
| IMPORT
| INCREMENT
| KEY
| LOAD
| LOGICAL
| MATCH
| MAXVALUE
| MERGE
| MINVALUE
| NO
| NODE
| PROJECT
| READ
| REL
| RENAME
| RETURN
| ROLLBACK
| ROLE
| SEQUENCE
| SET
| START
| STRUCT
| L_SKIP
| LIMIT
| TRANSACTION
| TYPE
| USE
| UNINSTALL
| UPDATE
| WRITE
| FROM
| TO
| YIELD
| USER
| PASSWORD
| MAP
;
UnescapedSymbolicName
: IdentifierStart ( IdentifierPart )* ;
IdentifierStart
: ID_Start
| Pc
;
IdentifierPart
: ID_Continue
| Sc
;
EscapedSymbolicName
: ( '`' ( EscapedSymbolicName_0 )* '`' )+ ;
SP
: ( WHITESPACE )+ ;
WHITESPACE
: SPACE
| TAB
| LF
| VT
| FF
| CR
| FS
| GS
| RS
| US
| '\u1680'
| '\u180e'
| '\u2000'
| '\u2001'
| '\u2002'
| '\u2003'
| '\u2004'
| '\u2005'
| '\u2006'
| '\u2008'
| '\u2009'
| '\u200a'
| '\u2028'
| '\u2029'
| '\u205f'
| '\u3000'
| '\u00a0'
| '\u2007'
| '\u202f'
| CypherComment
;
CypherComment
: ( '/*' ( Comment_1 | ( '*' Comment_2 ) )* '*/' )
| ( '//' ( Comment_3 )* CR? ( LF | EOF ) )
;
oC_LeftArrowHead
: '<'
| '\u27e8'
| '\u3008'
| '\ufe64'
| '\uff1c'
;
oC_RightArrowHead
: '>'
| '\u27e9'
| '\u3009'
| '\ufe65'
| '\uff1e'
;
oC_Dash
: '-'
| '\u00ad'
| '\u2010'
| '\u2011'
| '\u2012'
| '\u2013'
| '\u2014'
| '\u2015'
| '\u2212'
| '\ufe58'
| '\ufe63'
| '\uff0d'
;
fragment FF : [\f] ;
fragment EscapedSymbolicName_0 : ~[`] ;
fragment RS : [\u001E] ;
fragment ID_Continue : [\p{ID_Continue}] ;
fragment Comment_1 : ~[*] ;
fragment StringLiteral_1 : ~['\\] ;
fragment Comment_3 : ~[\n\r] ;
fragment Comment_2 : ~[/] ;
fragment GS : [\u001D] ;
fragment FS : [\u001C] ;
fragment CR : [\r] ;
fragment Sc : [\p{Sc}] ;
fragment SPACE : [ ] ;
fragment Pc : [\p{Pc}] ;
fragment TAB : [\t] ;
fragment StringLiteral_0 : ~["\\] ;
fragment LF : [\n] ;
fragment VT : [\u000B] ;
fragment US : [\u001F] ;
fragment ID_Start : [\p{ID_Start}] ;
// This is used to capture unknown lexer input (e.g. !) to avoid parser exception.
Unknown : .;

View File

@@ -0,0 +1 @@
Neither `Cypher.g4` nor `keywords.txt` can be used on its own to generate Ladybug's grammar. Instead, the two files are combined to generate `scripts/antlr4/Cypher.g4`, which can be used directly.

View File

@@ -0,0 +1,115 @@
ACYCLIC
ANY
ADD
ALL
ALTER
AND
AS
ASC
ASCENDING
ATTACH
BEGIN
BY
CALL
CASE
CAST
CHECKPOINT
COLUMN
COMMENT
COMMIT
COMMIT_SKIP_CHECKPOINT
CONTAINS
COPY
COUNT
CREATE
CYCLE
DATABASE
DBTYPE
DEFAULT
DELETE
DESC
DESCENDING
DETACH
DISTINCT
DROP
ELSE
END
ENDS
EXISTS
EXPLAIN
EXPORT
EXTENSION
FALSE
FROM
FORCE
GLOB
GRAPH
GROUP
HEADERS
HINT
IMPORT
IF
IN
INCREMENT
INSTALL
IS
JOIN
KEY
LIMIT
LOAD
LOGICAL
MACRO
MATCH
MAXVALUE
MERGE
MINVALUE
MULTI_JOIN
NO
NODE
NOT
NONE
NULL
ON
ONLY
OPTIONAL
OR
ORDER
PRIMARY
PROFILE
PROJECT
READ
REL
RENAME
RETURN
ROLLBACK
ROLLBACK_SKIP_CHECKPOINT
SEQUENCE
SET
SHORTEST
START
STARTS
STRUCT
TABLE
THEN
TO
TRAIL
TRANSACTION
TRUE
TYPE
UNION
UNWIND
UNINSTALL
UPDATE
USE
WHEN
WHERE
WITH
WRITE
WSHORTEST
XOR
SINGLE
YIELD
USER
PASSWORD
ROLE
MAP

View File

@@ -0,0 +1,22 @@
# Descend into the binder sub-components first.
add_subdirectory(bind)
add_subdirectory(bind_expression)
add_subdirectory(ddl)
add_subdirectory(expression)
add_subdirectory(query)
add_subdirectory(rewriter)
add_subdirectory(visitor)
# Sources that live directly in this directory are compiled into an OBJECT
# library (object files only, no archive).
add_library(lbug_binder
OBJECT
binder.cpp
binder_scope.cpp
bound_statement_result.cpp
bound_scan_source.cpp
bound_statement_rewriter.cpp
bound_statement_visitor.cpp
expression_binder.cpp
expression_visitor.cpp)
# Append this library's object files to ALL_OBJECT_FILES and publish the
# updated list to the parent directory's scope.
set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:lbug_binder>
PARENT_SCOPE)

View File

@@ -0,0 +1,32 @@
add_subdirectory(copy)
add_subdirectory(ddl)
add_subdirectory(read)
# OBJECT library holding the statement binders that live in this directory.
# Each source file is listed exactly once (bind_table_function.cpp used to be
# listed twice).
add_library(
        lbug_binder_bind
        OBJECT
        bind_attach_database.cpp
        bind_create_macro.cpp
        bind_ddl.cpp
        bind_detach_database.cpp
        bind_explain.cpp
        bind_extension_clause.cpp
        bind_file_scan.cpp
        bind_graph_pattern.cpp
        bind_projection_clause.cpp
        bind_query.cpp
        bind_reading_clause.cpp
        bind_standalone_call.cpp
        bind_table_function.cpp
        bind_transaction.cpp
        bind_updating_clause.cpp
        bind_extension.cpp
        bind_export_database.cpp
        bind_import_database.cpp
        bind_use_database.cpp
        bind_standalone_call_function.cpp)
set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:lbug_binder_bind>
PARENT_SCOPE)

View File

@@ -0,0 +1,36 @@
#include "binder/binder.h"
#include "binder/bound_attach_database.h"
#include "common/exception/binder.h"
#include "common/string_utils.h"
#include "parser/attach_database.h"
#include "parser/expression/parsed_literal_expression.h"
namespace lbug {
namespace binder {
// Converts the parsed ATTACH description into its bound form.
// Every option must be a literal expression; its value is stored under the
// option name. Throws BinderException for a non-literal option, or when the
// database type (compared upper-cased) is ATTACHED_LBUG_DB_TYPE and no alias
// was given.
static AttachInfo bindAttachInfo(const parser::AttachInfo& attachInfo) {
    binder::AttachOption boundOptions;
    for (auto& [optionName, optionExpr] : attachInfo.options) {
        const bool isLiteral =
            optionExpr->getExpressionType() == common::ExpressionType::LITERAL;
        if (!isLiteral) {
            throw common::BinderException{"Attach option must be a literal expression."};
        }
        auto* literalExpr = optionExpr->constPtrCast<parser::ParsedLiteralExpression>();
        auto optionValue = literalExpr->getValue();
        boundOptions.options.emplace(optionName, std::move(optionValue));
    }

    const bool isLbugDatabase =
        common::StringUtils::getUpper(attachInfo.dbType) == common::ATTACHED_LBUG_DB_TYPE;
    if (isLbugDatabase && attachInfo.dbAlias.empty()) {
        throw common::BinderException{"Attaching a lbug database without an alias is not allowed."};
    }

    return binder::AttachInfo{attachInfo.dbPath, attachInfo.dbAlias, attachInfo.dbType,
        std::move(boundOptions)};
}
// Binds an ATTACH statement: the parsed attach info is validated and converted
// by bindAttachInfo, then wrapped in a BoundAttachDatabase.
std::unique_ptr<BoundStatement> Binder::bindAttachDatabase(const parser::Statement& statement) {
    auto& parsedAttach = statement.constCast<parser::AttachDatabase>();
    auto attachInfo = bindAttachInfo(parsedAttach.getAttachInfo());
    auto boundStatement = std::make_unique<BoundAttachDatabase>(std::move(attachInfo));
    return boundStatement;
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,35 @@
#include "binder/binder.h"
#include "binder/bound_create_macro.h"
#include "catalog/catalog.h"
#include "common/exception/binder.h"
#include "common/string_format.h"
#include "common/string_utils.h"
#include "parser/create_macro.h"
#include "transaction/transaction.h"
using namespace lbug::common;
using namespace lbug::parser;
namespace lbug {
namespace binder {
std::unique_ptr<BoundStatement> Binder::bindCreateMacro(const Statement& statement) const {
auto& createMacro = ku_dynamic_cast<const CreateMacro&>(statement);
auto macroName = createMacro.getMacroName();
StringUtils::toUpper(macroName);
if (catalog::Catalog::Get(*clientContext)
->containsMacro(transaction::Transaction::Get(*clientContext), macroName)) {
throw BinderException{stringFormat("Macro {} already exists.", macroName)};
}
parser::default_macro_args defaultArgs;
for (auto& defaultArg : createMacro.getDefaultArgs()) {
defaultArgs.emplace_back(defaultArg.first, defaultArg.second->copy());
}
auto scalarMacro =
std::make_unique<function::ScalarMacroFunction>(createMacro.getMacroExpression()->copy(),
createMacro.getPositionalArgs(), std::move(defaultArgs));
return std::make_unique<BoundCreateMacro>(std::move(macroName), std::move(scalarMacro));
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,492 @@
#include "binder/binder.h"
#include "binder/ddl/bound_alter.h"
#include "binder/ddl/bound_create_sequence.h"
#include "binder/ddl/bound_create_table.h"
#include "binder/ddl/bound_create_type.h"
#include "binder/ddl/bound_drop.h"
#include "binder/expression_visitor.h"
#include "catalog/catalog.h"
#include "catalog/catalog_entry/node_table_catalog_entry.h"
#include "catalog/catalog_entry/sequence_catalog_entry.h"
#include "common/enums/extend_direction_util.h"
#include "common/exception/binder.h"
#include "common/exception/message.h"
#include "common/string_format.h"
#include "common/system_config.h"
#include "common/types/types.h"
#include "function/cast/functions/cast_from_string_functions.h"
#include "function/sequence/sequence_functions.h"
#include "main/client_context.h"
#include "parser/ddl/alter.h"
#include "parser/ddl/create_sequence.h"
#include "parser/ddl/create_table.h"
#include "parser/ddl/create_table_info.h"
#include "parser/ddl/create_type.h"
#include "parser/ddl/drop.h"
#include "parser/expression/parsed_function_expression.h"
#include "parser/expression/parsed_literal_expression.h"
#include "transaction/transaction.h"
using namespace lbug::common;
using namespace lbug::parser;
using namespace lbug::catalog;
namespace lbug {
namespace binder {
static void validatePropertyName(const std::vector<PropertyDefinition>& definitions) {
case_insensitve_set_t nameSet;
for (auto& definition : definitions) {
if (nameSet.contains(definition.getName())) {
throw BinderException(stringFormat(
"Duplicated column name: {}, column name must be unique.", definition.getName()));
}
if (Binder::reservedInColumnName(definition.getName())) {
throw BinderException(
stringFormat("{} is a reserved property name.", definition.getName()));
}
nameSet.insert(definition.getName());
}
}
std::vector<PropertyDefinition> Binder::bindPropertyDefinitions(
const std::vector<ParsedPropertyDefinition>& parsedDefinitions, const std::string& tableName) {
std::vector<PropertyDefinition> definitions;
for (auto& def : parsedDefinitions) {
auto type = LogicalType::convertFromString(def.getType(), clientContext);
auto defaultExpr =
resolvePropertyDefault(def.defaultExpr.get(), type, tableName, def.getName());
auto boundExpr = expressionBinder.bindExpression(*defaultExpr);
if (boundExpr->dataType != type) {
expressionBinder.implicitCast(boundExpr, type);
}
auto columnDefinition = ColumnDefinition(def.getName(), std::move(type));
definitions.emplace_back(std::move(columnDefinition), std::move(defaultExpr));
}
validatePropertyName(definitions);
return definitions;
}
// Produces the parsed default expression for a property.
//  - explicit default given: returns a copy of it, unless the column is SERIAL,
//    in which case a BinderException is thrown;
//  - no default, SERIAL column: returns a NextValFunction call on the serial
//    name derived from the table and property names;
//  - no default, other types: returns a NULL literal of the column type.
std::unique_ptr<ParsedExpression> Binder::resolvePropertyDefault(ParsedExpression* parsedDefault,
    const LogicalType& type, const std::string& tableName, const std::string& propertyName) {
    const bool isSerial = type.getLogicalTypeID() == LogicalTypeID::SERIAL;

    if (parsedDefault != nullptr) {
        if (isSerial) {
            throw BinderException("No DEFAULT value should be set for SERIAL columns");
        }
        return parsedDefault->copy();
    }

    // No default provided.
    if (!isSerial) {
        return std::make_unique<ParsedLiteralExpression>(Value::createNullValue(type), "NULL");
    }
    auto sequenceName = SequenceCatalogEntry::getSerialName(tableName, propertyName);
    auto sequenceNameExpr = std::make_unique<ParsedLiteralExpression>(Value(sequenceName), "");
    return std::make_unique<ParsedFunctionExpression>(function::NextValFunction::name,
        std::move(sequenceNameExpr), "" /* rawName */);
}
static void validatePrimaryKey(const std::string& pkColName,
const std::vector<PropertyDefinition>& definitions) {
uint32_t primaryKeyIdx = UINT32_MAX;
for (auto i = 0u; i < definitions.size(); i++) {
if (definitions[i].getName() == pkColName) {
primaryKeyIdx = i;
}
}
if (primaryKeyIdx == UINT32_MAX) {
throw BinderException(
"Primary key " + pkColName + " does not match any of the predefined node properties.");
}
const auto& pkType = definitions[primaryKeyIdx].getType();
if (!pkType.isInternalType()) {
throw BinderException(ExceptionMessage::invalidPKType(pkType.toString()));
}
switch (pkType.getPhysicalType()) {
case PhysicalTypeID::UINT8:
case PhysicalTypeID::UINT16:
case PhysicalTypeID::UINT32:
case PhysicalTypeID::UINT64:
case PhysicalTypeID::INT8:
case PhysicalTypeID::INT16:
case PhysicalTypeID::INT32:
case PhysicalTypeID::INT64:
case PhysicalTypeID::INT128:
case PhysicalTypeID::UINT128:
case PhysicalTypeID::STRING:
case PhysicalTypeID::FLOAT:
case PhysicalTypeID::DOUBLE:
break;
default:
throw BinderException(ExceptionMessage::invalidPKType(pkType.toString()));
}
}
// Dispatches on the parsed table type: NODE tables and REL tables each have
// their own binder. Any other table type is treated as unreachable.
BoundCreateTableInfo Binder::bindCreateTableInfo(const CreateTableInfo* info) {
    const auto tableType = info->type;
    if (tableType == TableType::NODE) {
        return bindCreateNodeTableInfo(info);
    }
    if (tableType == TableType::REL) {
        return bindCreateRelTableGroupInfo(info);
    }
    KU_UNREACHABLE;
}
// Binds CREATE NODE TABLE info: binds the property definitions, validates the
// primary key against them, and packages everything into a NODE_TABLE_ENTRY
// BoundCreateTableInfo.
BoundCreateTableInfo Binder::bindCreateNodeTableInfo(const CreateTableInfo* info) {
    auto boundProperties = bindPropertyDefinitions(info->propertyDefinitions, info->tableName);
    auto& nodeExtraInfo = info->extraInfo->constCast<ExtraCreateNodeTableInfo>();
    const auto& primaryKeyName = nodeExtraInfo.pKName;
    validatePrimaryKey(primaryKeyName, boundProperties);

    auto boundNodeInfo = std::make_unique<BoundExtraCreateNodeTableInfo>(primaryKeyName,
        std::move(boundProperties));
    return BoundCreateTableInfo(CatalogEntryType::NODE_TABLE_ENTRY, info->tableName,
        info->onConflict, std::move(boundNodeInfo), clientContext->useInternalCatalogEntry());
}
// Throws BinderException unless `entry` is a node table catalog entry.
void Binder::validateNodeTableType(const TableCatalogEntry* entry) {
    const bool isNodeTable = entry->getType() == CatalogEntryType::NODE_TABLE_ENTRY;
    if (isNodeTable) {
        return;
    }
    throw BinderException(stringFormat("{} is not of type NODE.", entry->getName()));
}
// Throws BinderException when the catalog, queried with the context's
// transaction, does not contain a table named `tableName`.
void Binder::validateTableExistence(const main::ClientContext& context,
    const std::string& tableName) {
    auto activeTransaction = transaction::Transaction::Get(context);
    const bool tableExists = Catalog::Get(context)->containsTable(activeTransaction, tableName);
    if (tableExists) {
        return;
    }
    throw BinderException{stringFormat("Table {} does not exist.", tableName)};
}
// Throws BinderException when `entry` has no property named `columnName`.
void Binder::validateColumnExistence(const TableCatalogEntry* entry,
    const std::string& columnName) {
    const bool columnExists = entry->containsProperty(columnName);
    if (columnExists) {
        return;
    }
    throw BinderException{
        stringFormat("Column {} does not exist in table {}.", columnName, entry->getName())};
}
// Reads the REL storage-direction option from `options`.
// Returns the direction parsed from the option's string form, or
// DEFAULT_EXTEND_DIRECTION when the option is absent.
static ExtendDirection getStorageDirection(const case_insensitive_map_t<Value>& options) {
    const auto optionIt = options.find(TableOptionConstants::REL_STORAGE_DIRECTION_OPTION);
    if (optionIt == options.end()) {
        return DEFAULT_EXTEND_DIRECTION;
    }
    return ExtendDirectionUtil::fromString(optionIt->second.toString());
}
std::vector<PropertyDefinition> Binder::bindRelPropertyDefinitions(const CreateTableInfo& info) {
std::vector<PropertyDefinition> propertyDefinitions;
propertyDefinitions.emplace_back(
ColumnDefinition(InternalKeyword::ID, LogicalType::INTERNAL_ID()));
for (auto& definition : bindPropertyDefinitions(info.propertyDefinitions, info.tableName)) {
propertyDefinitions.push_back(definition.copy());
}
return propertyDefinitions;
}
// Binds CREATE REL TABLE info.
// Binds the properties, derives the forward/backward multiplicities and the
// storage direction from the parsed extra info, then resolves every FROM-TO
// pair to node table IDs. Both endpoints must be node tables and each pair may
// appear only once; otherwise a BinderException is thrown.
BoundCreateTableInfo Binder::bindCreateRelTableGroupInfo(const CreateTableInfo* info) {
    auto relProperties = bindRelPropertyDefinitions(*info);
    auto& relExtraInfo = info->extraInfo->constCast<ExtraCreateRelTableGroupInfo>();
    auto fwdMultiplicity = RelMultiplicityUtils::getFwd(relExtraInfo.relMultiplicity);
    auto bwdMultiplicity = RelMultiplicityUtils::getBwd(relExtraInfo.relMultiplicity);
    auto parsedOptions = bindParsingOptions(relExtraInfo.options);
    auto direction = getStorageDirection(parsedOptions);

    // Bind from to pairs. The set detects duplicates; the vector keeps the
    // pairs in declaration order.
    node_table_id_pair_set_t seenPairs;
    std::vector<NodeTableIDPair> orderedPairs;
    for (auto& [fromTableName, toTableName] : relExtraInfo.srcDstTablePairs) {
        auto fromEntry = bindNodeTableEntry(fromTableName);
        validateNodeTableType(fromEntry);
        auto toEntry = bindNodeTableEntry(toTableName);
        validateNodeTableType(toEntry);

        NodeTableIDPair idPair{fromEntry->getTableID(), toEntry->getTableID()};
        const bool isDuplicate = seenPairs.contains(idPair);
        if (isDuplicate) {
            throw BinderException(
                stringFormat("Found duplicate FROM-TO {}-{} pairs.", fromTableName, toTableName));
        }
        seenPairs.insert(idPair);
        orderedPairs.emplace_back(idPair);
    }

    auto boundRelInfo = std::make_unique<BoundExtraCreateRelTableGroupInfo>(
        std::move(relProperties), fwdMultiplicity, bwdMultiplicity, direction,
        std::move(orderedPairs));
    return BoundCreateTableInfo(CatalogEntryType::REL_GROUP_ENTRY, info->tableName,
        info->onConflict, std::move(boundRelInfo), clientContext->useInternalCatalogEntry());
}
std::unique_ptr<BoundStatement> Binder::bindCreateTable(const Statement& statement) {
auto& createTable = statement.constCast<CreateTable>();
if (createTable.getSource()) {
return bindCreateTableAs(createTable);
}
auto boundCreateInfo = bindCreateTableInfo(createTable.getInfo());
return std::make_unique<BoundCreateTable>(std::move(boundCreateInfo),
BoundStatementResult::createSingleStringColumnResult());
}
// Binds a CREATE TABLE statement that carries a source query.
//
// The inner query is bound first; its result column names and types become the property
// definitions of the new table. The returned BoundCreateTable additionally carries bound
// copy-from info (see setCopyInfo below) built from the same source.
//
// Throws a BinderException if the inner query returns no columns, or if a REL table lists more
// than one FROM/TO pair.
std::unique_ptr<BoundStatement> Binder::bindCreateTableAs(const Statement& statement) {
auto& createTable = statement.constCast<CreateTable>();
auto boundInnerQuery = bindQuery(*createTable.getSource()->statement.get());
auto innerQueryResult = boundInnerQuery->getStatementResult();
auto columnNames = innerQueryResult->getColumnNames();
auto columnTypes = innerQueryResult->getColumnTypes();
// One property per result column of the inner query, in result order.
std::vector<PropertyDefinition> propertyDefinitions;
propertyDefinitions.reserve(columnNames.size());
for (size_t i = 0; i < columnNames.size(); ++i) {
propertyDefinitions.emplace_back(
ColumnDefinition(std::string(columnNames[i]), columnTypes[i].copy()));
}
// Must run before columnNames[0] is read in the NODE branch below.
if (columnNames.empty()) {
throw BinderException("Subquery returns no columns");
}
auto createInfo = createTable.getInfo();
switch (createInfo->type) {
case TableType::NODE: {
// For now the first result column is always used as the primary key column.
auto pkName = columnNames[0];
validatePrimaryKey(pkName, propertyDefinitions);
// propertyDefinitions is still read here, so it may only be moved afterwards.
auto boundCopyFromInfo = bindCopyNodeFromInfo(createInfo->tableName, propertyDefinitions,
createTable.getSource(), options_t{}, columnNames, columnTypes, false /* byColumn */);
auto boundExtraInfo =
std::make_unique<BoundExtraCreateNodeTableInfo>(pkName, std::move(propertyDefinitions));
auto boundCreateInfo = BoundCreateTableInfo(CatalogEntryType::NODE_TABLE_ENTRY,
createInfo->tableName, createInfo->onConflict, std::move(boundExtraInfo),
clientContext->useInternalCatalogEntry());
auto boundCreateTable = std::make_unique<BoundCreateTable>(std::move(boundCreateInfo),
BoundStatementResult::createSingleStringColumnResult());
boundCreateTable->setCopyInfo(std::move(boundCopyFromInfo));
return boundCreateTable;
}
case TableType::REL: {
auto& extraInfo = createInfo->extraInfo->constCast<ExtraCreateRelTableGroupInfo>();
// Multiple FROM/TO pairs are currently not supported for CREATE REL TABLE AS.
if (extraInfo.srcDstTablePairs.size() > 1) {
throw BinderException(
"Multiple FROM/TO pairs are not supported for CREATE REL TABLE AS.");
}
// Prepend the internal ID column ahead of the columns derived from the query.
propertyDefinitions.insert(propertyDefinitions.begin(),
PropertyDefinition(ColumnDefinition(InternalKeyword::ID, LogicalType::INTERNAL_ID())));
auto catalog = Catalog::Get(*clientContext);
auto transaction = transaction::Transaction::Get(*clientContext);
// NOTE(review): srcDstTablePairs[0] assumes at least one FROM/TO pair is present; confirm
// that the parser guarantees this.
auto fromTable =
catalog->getTableCatalogEntry(transaction, extraInfo.srcDstTablePairs[0].first)
->ptrCast<NodeTableCatalogEntry>();
auto toTable =
catalog->getTableCatalogEntry(transaction, extraInfo.srcDstTablePairs[0].second)
->ptrCast<NodeTableCatalogEntry>();
auto boundCreateInfo = bindCreateRelTableGroupInfo(createInfo);
auto boundCopyFromInfo = bindCopyRelFromInfo(createInfo->tableName, propertyDefinitions,
createTable.getSource(), options_t{}, columnNames, columnTypes, fromTable, toTable);
// Replace the property definitions bound from the declared table info with the ones
// derived from the query result (internal ID column included).
boundCreateInfo.extraInfo->ptrCast<BoundExtraCreateTableInfo>()->propertyDefinitions =
std::move(propertyDefinitions);
auto boundCreateTable = std::make_unique<BoundCreateTable>(std::move(boundCreateInfo),
BoundStatementResult::createSingleStringColumnResult());
boundCreateTable->setCopyInfo(std::move(boundCopyFromInfo));
return boundCreateTable;
}
default: {
KU_UNREACHABLE;
}
}
}
std::unique_ptr<BoundStatement> Binder::bindCreateType(const Statement& statement) const {
auto createType = statement.constPtrCast<CreateType>();
auto name = createType->getName();
LogicalType type = LogicalType::convertFromString(createType->getDataType(), clientContext);
auto transaction = transaction::Transaction::Get(*clientContext);
if (Catalog::Get(*clientContext)->containsType(transaction, name)) {
throw BinderException{stringFormat("Duplicated type name: {}.", name)};
}
return std::make_unique<BoundCreateType>(std::move(name), std::move(type));
}
std::unique_ptr<BoundStatement> Binder::bindCreateSequence(const Statement& statement) const {
auto& createSequence = statement.constCast<CreateSequence>();
auto info = createSequence.getInfo();
auto sequenceName = info.sequenceName;
int64_t startWith = 0;
int64_t increment = 0;
int64_t minValue = 0;
int64_t maxValue = 0;
auto transaction = transaction::Transaction::Get(*clientContext);
switch (info.onConflict) {
case ConflictAction::ON_CONFLICT_THROW: {
if (Catalog::Get(*clientContext)->containsSequence(transaction, sequenceName)) {
throw BinderException(sequenceName + " already exists in catalog.");
}
} break;
default:
break;
}
auto literal = ku_string_t{info.increment.c_str(), info.increment.length()};
if (!function::CastString::tryCast(literal, increment)) {
throw BinderException("Out of bounds: SEQUENCE accepts integers within INT64.");
}
if (increment == 0) {
throw BinderException("INCREMENT must be non-zero.");
}
if (info.minValue == "") {
minValue = increment > 0 ? 1 : std::numeric_limits<int64_t>::min();
} else {
literal = ku_string_t{info.minValue.c_str(), info.minValue.length()};
if (!function::CastString::tryCast(literal, minValue)) {
throw BinderException("Out of bounds: SEQUENCE accepts integers within INT64.");
}
}
if (info.maxValue == "") {
maxValue = increment > 0 ? std::numeric_limits<int64_t>::max() : -1;
} else {
literal = ku_string_t{info.maxValue.c_str(), info.maxValue.length()};
if (!function::CastString::tryCast(literal, maxValue)) {
throw BinderException("Out of bounds: SEQUENCE accepts integers within INT64.");
}
}
if (info.startWith == "") {
startWith = increment > 0 ? minValue : maxValue;
} else {
literal = ku_string_t{info.startWith.c_str(), info.startWith.length()};
if (!function::CastString::tryCast(literal, startWith)) {
throw BinderException("Out of bounds: SEQUENCE accepts integers within INT64.");
}
}
if (maxValue < minValue) {
throw BinderException("SEQUENCE MAXVALUE should be greater than or equal to MINVALUE.");
}
if (startWith < minValue || startWith > maxValue) {
throw BinderException("SEQUENCE START value should be between MINVALUE and MAXVALUE.");
}
auto boundInfo = BoundCreateSequenceInfo(sequenceName, startWith, increment, minValue, maxValue,
info.cycle, info.onConflict, false /* isInternal */);
return std::make_unique<BoundCreateSequence>(std::move(boundInfo));
}
std::unique_ptr<BoundStatement> Binder::bindDrop(const Statement& statement) {
auto& drop = statement.constCast<Drop>();
return std::make_unique<BoundDrop>(drop.getDropInfo());
}
std::unique_ptr<BoundStatement> Binder::bindAlter(const Statement& statement) {
auto& alter = statement.constCast<Alter>();
switch (alter.getInfo()->type) {
case AlterType::RENAME: {
return bindRenameTable(statement);
}
case AlterType::ADD_PROPERTY: {
return bindAddProperty(statement);
}
case AlterType::DROP_PROPERTY: {
return bindDropProperty(statement);
}
case AlterType::RENAME_PROPERTY: {
return bindRenameProperty(statement);
}
case AlterType::COMMENT: {
return bindCommentOn(statement);
}
case AlterType::ADD_FROM_TO_CONNECTION:
case AlterType::DROP_FROM_TO_CONNECTION: {
return bindAlterFromToConnection(statement);
}
default: {
KU_UNREACHABLE;
}
}
}
std::unique_ptr<BoundStatement> Binder::bindRenameTable(const Statement& statement) const {
auto& alter = statement.constCast<Alter>();
auto info = alter.getInfo();
auto extraInfo = ku_dynamic_cast<ExtraRenameTableInfo*>(info->extraInfo.get());
auto tableName = info->tableName;
auto newName = extraInfo->newName;
auto boundExtraInfo = std::make_unique<BoundExtraRenameTableInfo>(newName);
auto boundInfo =
BoundAlterInfo(AlterType::RENAME, tableName, std::move(boundExtraInfo), info->onConflict);
return std::make_unique<BoundAlter>(std::move(boundInfo));
}
std::unique_ptr<BoundStatement> Binder::bindAddProperty(const Statement& statement) {
auto& alter = statement.constCast<Alter>();
auto info = alter.getInfo();
auto extraInfo = info->extraInfo->ptrCast<ExtraAddPropertyInfo>();
auto tableName = info->tableName;
auto propertyName = extraInfo->propertyName;
auto type = LogicalType::convertFromString(extraInfo->dataType, clientContext);
auto columnDefinition = ColumnDefinition(propertyName, type.copy());
auto defaultExpr =
resolvePropertyDefault(extraInfo->defaultValue.get(), type, tableName, propertyName);
auto boundDefault = expressionBinder.bindExpression(*defaultExpr);
boundDefault = expressionBinder.implicitCastIfNecessary(boundDefault, type);
if (ConstantExpressionVisitor::needFold(*boundDefault)) {
boundDefault = expressionBinder.foldExpression(boundDefault);
}
auto propertyDefinition =
PropertyDefinition(std::move(columnDefinition), std::move(defaultExpr));
auto boundExtraInfo = std::make_unique<BoundExtraAddPropertyInfo>(std::move(propertyDefinition),
std::move(boundDefault));
auto boundInfo = BoundAlterInfo(AlterType::ADD_PROPERTY, tableName, std::move(boundExtraInfo),
info->onConflict);
return std::make_unique<BoundAlter>(std::move(boundInfo));
}
std::unique_ptr<BoundStatement> Binder::bindDropProperty(const Statement& statement) const {
auto& alter = statement.constCast<Alter>();
auto info = alter.getInfo();
auto extraInfo = info->extraInfo->constPtrCast<ExtraDropPropertyInfo>();
auto tableName = info->tableName;
auto propertyName = extraInfo->propertyName;
auto boundExtraInfo = std::make_unique<BoundExtraDropPropertyInfo>(propertyName);
auto boundInfo = BoundAlterInfo(AlterType::DROP_PROPERTY, tableName, std::move(boundExtraInfo),
info->onConflict);
return std::make_unique<BoundAlter>(std::move(boundInfo));
}
std::unique_ptr<BoundStatement> Binder::bindRenameProperty(const Statement& statement) const {
auto& alter = statement.constCast<Alter>();
auto info = alter.getInfo();
auto extraInfo = info->extraInfo->constPtrCast<ExtraRenamePropertyInfo>();
auto tableName = info->tableName;
auto propertyName = extraInfo->propertyName;
auto newName = extraInfo->newName;
auto boundExtraInfo = std::make_unique<BoundExtraRenamePropertyInfo>(newName, propertyName);
auto boundInfo = BoundAlterInfo(AlterType::RENAME_PROPERTY, tableName,
std::move(boundExtraInfo), info->onConflict);
return std::make_unique<BoundAlter>(std::move(boundInfo));
}
std::unique_ptr<BoundStatement> Binder::bindCommentOn(const Statement& statement) const {
auto& alter = statement.constCast<Alter>();
auto info = alter.getInfo();
auto extraInfo = info->extraInfo->constPtrCast<ExtraCommentInfo>();
auto tableName = info->tableName;
auto comment = extraInfo->comment;
auto boundExtraInfo = std::make_unique<BoundExtraCommentInfo>(comment);
auto boundInfo =
BoundAlterInfo(AlterType::COMMENT, tableName, std::move(boundExtraInfo), info->onConflict);
return std::make_unique<BoundAlter>(std::move(boundInfo));
}
std::unique_ptr<BoundStatement> Binder::bindAlterFromToConnection(
const Statement& statement) const {
auto& alter = statement.constCast<Alter>();
auto info = alter.getInfo();
auto extraInfo = info->extraInfo->constPtrCast<ExtraAddFromToConnection>();
auto tableName = info->tableName;
auto srcTableEntry = bindNodeTableEntry(extraInfo->srcTableName);
auto dstTableEntry = bindNodeTableEntry(extraInfo->dstTableName);
auto srcTableID = srcTableEntry->getTableID();
auto dstTableID = dstTableEntry->getTableID();
auto boundExtraInfo = std::make_unique<BoundExtraAlterFromToConnection>(srcTableID, dstTableID);
auto boundInfo =
BoundAlterInfo(info->type, tableName, std::move(boundExtraInfo), info->onConflict);
return std::make_unique<BoundAlter>(std::move(boundInfo));
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,14 @@
#include "binder/binder.h"
#include "binder/bound_detach_database.h"
#include "parser/detach_database.h"
namespace lbug {
namespace binder {
// Binds a DETACH DATABASE statement; only the database name is carried over.
std::unique_ptr<BoundStatement> Binder::bindDetachDatabase(const parser::Statement& statement) {
    const auto& detachStatement = statement.constCast<parser::DetachDatabase>();
    auto boundDetach = std::make_unique<BoundDetachDatabase>(detachStatement.getDBName());
    return boundDetach;
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,16 @@
#include "binder/binder.h"
#include "binder/bound_explain.h"
#include "parser/explain_statement.h"
namespace lbug {
namespace binder {
// Binds an EXPLAIN statement: the wrapped statement is bound first and then paired with the
// requested explain type.
std::unique_ptr<BoundStatement> Binder::bindExplain(const parser::Statement& statement) {
    const auto& explainStatement = statement.constCast<parser::ExplainStatement>();
    auto innerBoundStatement = bind(*explainStatement.getStatementToExplain());
    auto boundExplain = std::make_unique<BoundExplain>(std::move(innerBoundStatement),
        explainStatement.getExplainType());
    return boundExplain;
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,169 @@
#include "binder/bound_export_database.h"
#include "binder/query/bound_regular_query.h"
#include "catalog/catalog.h"
#include "catalog/catalog_entry/index_catalog_entry.h"
#include "catalog/catalog_entry/node_table_catalog_entry.h"
#include "catalog/catalog_entry/rel_group_catalog_entry.h"
#include "common/exception/binder.h"
#include "common/file_system/virtual_file_system.h"
#include "common/string_utils.h"
#include "main/client_context.h"
#include "parser/parser.h"
#include "parser/port_db.h"
#include "parser/query/regular_query.h"
#include "transaction/transaction.h"
using namespace lbug::binder;
using namespace lbug::common;
using namespace lbug::parser;
using namespace lbug::catalog;
using namespace lbug::transaction;
using namespace lbug::storage;
namespace lbug {
namespace binder {
// Determines the export file format from `options`.
//
// Without a format option the result is PARQUET with the default format string. Otherwise the
// option must be a string; it is upper-cased, converted to a file type, and removed from
// `options`.
FileTypeInfo getFileType(case_insensitive_map_t<Value>& options) {
    if (!options.contains(PortDBConstants::EXPORT_FORMAT_OPTION)) {
        return FileTypeInfo{FileType::PARQUET, PortDBConstants::DEFAULT_EXPORT_FORMAT_OPTION};
    }
    const auto& formatValue = options.at(PortDBConstants::EXPORT_FORMAT_OPTION);
    if (formatValue.getDataType().getLogicalTypeID() != LogicalTypeID::STRING) {
        throw BinderException("The type of format option must be a string.");
    }
    auto formatName = formatValue.getValue<std::string>();
    StringUtils::toUpper(formatName);
    // Build the result before erasing: the option is only removed once it has been converted.
    FileTypeInfo selected{FileTypeUtils::fromString(formatName), formatName};
    options.erase(PortDBConstants::EXPORT_FORMAT_OPTION);
    return selected;
}
// Parses and binds `query` (expected to contain exactly one statement) and records the bound
// query together with its result column names and types in `tableData`.
//
// The query is bound with the client context's "use internal catalog entry" flag switched on.
// A scope guard switches the flag off again on every exit path, so a binding failure no longer
// leaves the flag set on the context for later statements.
void bindExportTableData(ExportedTableData& tableData, const std::string& query,
    main::ClientContext* context, Binder* binder) {
    // Switches the flag on at construction and off at destruction.
    struct InternalCatalogEntryScope {
        main::ClientContext* ctx;
        explicit InternalCatalogEntryScope(main::ClientContext* clientContext)
            : ctx{clientContext} {
            ctx->setUseInternalCatalogEntry(true /* useInternalCatalogEntry */);
        }
        ~InternalCatalogEntryScope() {
            ctx->setUseInternalCatalogEntry(false /* useInternalCatalogEntry */);
        }
        InternalCatalogEntryScope(const InternalCatalogEntryScope&) = delete;
        InternalCatalogEntryScope& operator=(const InternalCatalogEntryScope&) = delete;
    };

    auto parsedStatement = Parser::parseQuery(query);
    KU_ASSERT(parsedStatement.size() == 1);
    auto parsedQuery = parsedStatement[0]->constPtrCast<RegularQuery>();
    // The guard lives only for the duration of bindQuery, as in the original flag handling.
    auto boundQuery = [&] {
        InternalCatalogEntryScope scope{context};
        return binder->bindQuery(*parsedQuery);
    }();
    auto columns = boundQuery->getStatementResult()->getColumns();
    for (auto& column : columns) {
        // Prefer the alias when the column has one; otherwise use its textual form.
        auto columnName = column->hasAlias() ? column->getAlias() : column->toString();
        tableData.columnNames.push_back(columnName);
        tableData.columnTypes.push_back(column->getDataType().copy());
    }
    tableData.regularQuery = std::move(boundQuery);
}
// Builds the query that returns every property of all nodes in the given table. The table name
// is wrapped in backticks.
static std::string getExportNodeTableDataQuery(const TableCatalogEntry& entry) {
    const auto& tableName = entry.getName();
    return stringFormat("match (a:`{}`) return a.*", tableName);
}
// Builds the query that exports one FROM/TO connection of a relationship group. It returns the
// primary key of the source node, the primary key of the destination node, and then every
// property of the relationship.
static std::string getExportRelTableDataQuery(const TableCatalogEntry& relGroupEntry,
    const NodeTableCatalogEntry& srcEntry, const NodeTableCatalogEntry& dstEntry) {
    auto exportQuery = stringFormat("match (a:`{}`)-[r:`{}`]->(b:`{}`) return a.{},b.{},r.*;",
        srcEntry.getName(), relGroupEntry.getName(), dstEntry.getName(),
        srcEntry.getPrimaryKeyName(), dstEntry.getPrimaryKeyName());
    return exportQuery;
}
static std::vector<ExportedTableData> getExportInfo(const Catalog& catalog,
main::ClientContext* context, Binder* binder, FileTypeInfo& fileTypeInfo) {
auto transaction = Transaction::Get(*context);
std::vector<ExportedTableData> exportData;
for (auto entry : catalog.getNodeTableEntries(transaction, false /*useInternal*/)) {
ExportedTableData tableData;
tableData.tableName = entry->getName();
tableData.fileName =
entry->getName() + "." + StringUtils::getLower(fileTypeInfo.fileTypeStr);
auto query = getExportNodeTableDataQuery(*entry);
bindExportTableData(tableData, query, context, binder);
exportData.push_back(std::move(tableData));
}
for (auto entry : catalog.getRelGroupEntries(transaction, false /* useInternal */)) {
auto& relGroupEntry = entry->constCast<RelGroupCatalogEntry>();
for (auto& info : relGroupEntry.getRelEntryInfos()) {
ExportedTableData tableData;
auto srcTableID = info.nodePair.srcTableID;
auto dstTableID = info.nodePair.dstTableID;
auto& srcEntry = catalog.getTableCatalogEntry(transaction, srcTableID)
->constCast<NodeTableCatalogEntry>();
auto& dstEntry = catalog.getTableCatalogEntry(transaction, dstTableID)
->constCast<NodeTableCatalogEntry>();
tableData.tableName = entry->getName();
tableData.fileName =
stringFormat("{}_{}_{}.{}", relGroupEntry.getName(), srcEntry.getName(),
dstEntry.getName(), StringUtils::getLower(fileTypeInfo.fileTypeStr));
auto query = getExportRelTableDataQuery(relGroupEntry, srcEntry, dstEntry);
bindExportTableData(tableData, query, context, binder);
exportData.push_back(std::move(tableData));
}
}
for (auto indexEntry : catalog.getIndexEntries(transaction)) {
// Export
ExportedTableData tableData;
auto entry = indexEntry->getTableEntryToExport(context);
if (entry == nullptr) {
continue;
}
KU_ASSERT(entry->getTableType() == TableType::NODE);
tableData.tableName = entry->getName();
tableData.fileName =
entry->getName() + "." + StringUtils::getLower(fileTypeInfo.fileTypeStr);
auto query = getExportNodeTableDataQuery(*entry);
bindExportTableData(tableData, query, context, binder);
exportData.push_back(std::move(tableData));
}
return exportData;
}
// Reports whether the export was requested as schema-only, and removes the schema-only option
// from `parsedOptions` in either case.
//
// Throws a BinderException when the option is present with a non-BOOL value, or when it is true
// and the statement carries any other parsing option.
static bool schemaOnly(case_insensitive_map_t<Value>& parsedOptions,
    const parser::ExportDB& exportDB) {
    // Use the map's exact element type: a `std::pair<std::string, Value>` parameter would force
    // an implicit conversion, i.e. a copy of key and value, for every option visited.
    using option_entry_t = case_insensitive_map_t<Value>::value_type;
    auto isSchemaOnlyOption = [](const option_entry_t& option) -> bool {
        if (option.first != PortDBConstants::SCHEMA_ONLY_OPTION) {
            return false;
        }
        if (option.second.getDataType() != LogicalType::BOOL()) {
            throw common::BinderException{common::stringFormat(
                "The '{}' option must have a BOOL value.", PortDBConstants::SCHEMA_ONLY_OPTION)};
        }
        return option.second.getValue<bool>();
    };
    // Keys are unique, so at most one entry can match the option name.
    const bool exportSchemaOnly =
        std::any_of(parsedOptions.begin(), parsedOptions.end(), isSchemaOnlyOption);
    if (exportSchemaOnly && exportDB.getParsingOptionsRef().size() != 1) {
        throw common::BinderException{
            common::stringFormat("When '{}' option is set to true in export "
                                 "database, no other options are allowed.",
                PortDBConstants::SCHEMA_ONLY_OPTION)};
    }
    parsedOptions.erase(PortDBConstants::SCHEMA_ONLY_OPTION);
    return exportSchemaOnly;
}
std::unique_ptr<BoundStatement> Binder::bindExportDatabaseClause(const Statement& statement) {
auto& exportDB = statement.constCast<ExportDB>();
auto parsedOptions = bindParsingOptions(exportDB.getParsingOptionsRef());
auto fileTypeInfo = getFileType(parsedOptions);
switch (fileTypeInfo.fileType) {
case FileType::CSV:
case FileType::PARQUET:
break;
default:
throw BinderException("Export database currently only supports csv and parquet files.");
}
auto exportSchemaOnly = schemaOnly(parsedOptions, exportDB);
if (!exportSchemaOnly && fileTypeInfo.fileType != FileType::CSV && parsedOptions.size() != 0) {
throw BinderException{"Only export to csv can have options."};
}
auto exportData =
getExportInfo(*Catalog::Get(*clientContext), clientContext, this, fileTypeInfo);
auto boundFilePath = VirtualFileSystem::GetUnsafe(*clientContext)
->expandPath(clientContext, exportDB.getFilePath());
return std::make_unique<BoundExportDatabase>(boundFilePath, fileTypeInfo, std::move(exportData),
std::move(parsedOptions), exportSchemaOnly);
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,79 @@
#include "binder/binder.h"
#include "binder/bound_extension_statement.h"
#include "common/exception/binder.h"
#include "common/file_system/local_file_system.h"
#include "common/string_utils.h"
#include "extension/extension.h"
#include "parser/extension_statement.h"
using namespace lbug::parser;
namespace lbug {
namespace binder {
// Validates an INSTALL request: only official extensions may be installed. Throws a
// BinderException for anything else.
static void bindInstallExtension(const ExtensionAuxInfo& auxInfo) {
    if (ExtensionUtils::isOfficialExtension(auxInfo.path)) {
        return;
    }
    throw common::BinderException(
        common::stringFormat("{} is not an official extension.\nNon-official extensions "
                             "can be installed directly by: `LOAD EXTENSION [EXTENSION_PATH]`.",
            auxInfo.path));
}
// Validates a LOAD request.
//
// An official extension must already have its library present at the local extension path.
// Anything else is treated as a path that must exist on the local file system. Throws a
// BinderException when the respective check fails.
static void bindLoadExtension(main::ClientContext* context, const ExtensionAuxInfo& auxInfo) {
    auto localFileSystem = common::LocalFileSystem("");
    if (!ExtensionUtils::isOfficialExtension(auxInfo.path)) {
        if (localFileSystem.fileOrPathExists(auxInfo.path, nullptr /* clientContext */)) {
            return;
        }
        throw common::BinderException(
            common::stringFormat("The extension {} is neither an official extension, nor does "
                                 "the extension path: '{}' exists.",
                auxInfo.path, auxInfo.path));
    }
    // Official extensions are looked up by their lower-cased name.
    auto extensionName = common::StringUtils::getLower(auxInfo.path);
    const bool isInstalled = localFileSystem.fileOrPathExists(
        ExtensionUtils::getLocalPathForExtensionLib(context, extensionName));
    if (isInstalled) {
        return;
    }
    throw common::BinderException(
        common::stringFormat("Extension: {} is an official extension and has not been "
                             "installed.\nYou can install it by: install {}.",
            extensionName, extensionName));
}
// Validates an UNINSTALL request: only official extensions may be uninstalled. Throws a
// BinderException for anything else.
static void bindUninstallExtension(const ExtensionAuxInfo& auxInfo) {
    if (ExtensionUtils::isOfficialExtension(auxInfo.path)) {
        return;
    }
    throw common::BinderException(
        common::stringFormat("The extension {} is not an official extension.\nOnly official "
                             "extensions can be uninstalled.",
            auxInfo.path));
}
std::unique_ptr<BoundStatement> Binder::bindExtension(const Statement& statement) {
#ifdef __WASM__
throw common::BinderException{"Extensions are not available in the WASM environment"};
#endif
auto extensionStatement = statement.constPtrCast<ExtensionStatement>();
auto auxInfo = extensionStatement->getAuxInfo();
switch (auxInfo->action) {
case ExtensionAction::INSTALL:
bindInstallExtension(*auxInfo);
break;
case ExtensionAction::LOAD:
bindLoadExtension(clientContext, *auxInfo);
break;
case ExtensionAction::UNINSTALL:
bindUninstallExtension(*auxInfo);
break;
default:
KU_UNREACHABLE;
}
if (ExtensionUtils::isOfficialExtension(auxInfo->path)) {
common::StringUtils::toLower(auxInfo->path);
}
return std::make_unique<BoundExtensionStatement>(std::move(auxInfo));
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,21 @@
#include "binder/binder.h"
#include "extension/binder_extension.h"
using namespace lbug::common;
using namespace lbug::parser;
namespace lbug {
namespace binder {
// Offers the statement to each registered binder extension in order and returns the first
// non-null result. Reaching the end without a result is treated as unreachable.
std::unique_ptr<BoundStatement> Binder::bindExtensionClause(const parser::Statement& statement) {
    for (auto& extensionBinder : binderExtensions) {
        if (auto bound = extensionBinder->bind(statement); bound) {
            return bound;
        }
    }
    KU_UNREACHABLE;
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,312 @@
#include "binder/binder.h"
#include "binder/bound_scan_source.h"
#include "binder/expression/literal_expression.h"
#include "binder/expression/parameter_expression.h"
#include "common/exception/binder.h"
#include "common/exception/copy.h"
#include "common/exception/message.h"
#include "common/file_system/local_file_system.h"
#include "common/file_system/virtual_file_system.h"
#include "common/string_format.h"
#include "common/string_utils.h"
#include "extension/extension_manager.h"
#include "function/table/bind_input.h"
#include "main/client_context.h"
#include "main/database_manager.h"
#include "parser/expression/parsed_function_expression.h"
#include "parser/scan_source.h"
using namespace lbug::parser;
using namespace lbug::binder;
using namespace lbug::common;
using namespace lbug::function;
using namespace lbug::catalog;
namespace lbug {
namespace binder {
// Derives the file type of a single path from its extension. The returned type string is the
// extension without its first character; an empty extension yields an empty string.
FileTypeInfo bindSingleFileType(const main::ClientContext* context, const std::string& filePath) {
    std::filesystem::path path(filePath);
    auto extension = VirtualFileSystem::GetUnsafe(*context)->getFileExtension(path);
    auto fileType = FileTypeUtils::getFileTypeFromExtension(extension);
    // Skip the first character only when there is one to skip.
    const uint64_t suffixStart = extension.empty() ? 0 : 1;
    return FileTypeInfo{fileType, extension.substr(suffixStart)};
}
// Determines the common file type of `filePaths`. Returns UNKNOWN with an empty type string for
// an empty list, and throws a CopyException when the paths do not all share one file type.
FileTypeInfo Binder::bindFileTypeInfo(const std::vector<std::string>& filePaths) const {
    FileTypeInfo expectedFileType{FileType::UNKNOWN, "" /* fileTypeStr */};
    for (const auto& filePath : filePaths) {
        auto currentFileType = bindSingleFileType(clientContext, filePath);
        if (expectedFileType.fileType == FileType::UNKNOWN) {
            // Nothing to compare against yet: adopt this file's type as the expectation.
            expectedFileType = currentFileType;
            continue;
        }
        if (currentFileType.fileType != expectedFileType.fileType) {
            throw CopyException("Loading files with different types is not currently supported.");
        }
    }
    return expectedFileType;
}
// Expands every entry of `filePaths` through the virtual file system's glob and returns the
// concatenated matches. Throws a BinderException when a pattern matches no file.
//
// Non-local paths are passed through unchanged while the HTTPFS extension is not loaded (see the
// workaround note below).
std::vector<std::string> Binder::bindFilePaths(const std::vector<std::string>& filePaths) const {
    // This is a temporary workaround because we use duckdb to read from iceberg/delta/azure.
    // When we read delta/iceberg/azure tables from s3/httpfs, we don't have the httpfs
    // extension loaded meaning that we cannot handle remote paths. So we pass the file path to
    // duckdb for validation when we bindFileScanSource.
    //
    // The set of loaded extensions does not depend on the path being processed, so it is
    // inspected once here rather than once per path.
    const auto& loadedExtensions =
        extension::ExtensionManager::Get(*clientContext)->getLoadedExtensions();
    const bool httpfsExtensionLoaded =
        std::any_of(loadedExtensions.begin(), loadedExtensions.end(),
            [](const auto& extension) { return extension.getExtensionName() == "HTTPFS"; });

    std::vector<std::string> boundFilePaths;
    for (const auto& filePath : filePaths) {
        if (!httpfsExtensionLoaded && !LocalFileSystem::isLocalPath(filePath)) {
            boundFilePaths.push_back(filePath);
            continue;
        }
        auto globbedFilePaths =
            VirtualFileSystem::GetUnsafe(*clientContext)->glob(clientContext, filePath);
        if (globbedFilePaths.empty()) {
            throw BinderException{
                stringFormat("No file found that matches the pattern: {}.", filePath)};
        }
        // The glob result is a local temporary, so its strings can be moved out.
        for (auto& globbedPath : globbedFilePaths) {
            boundFilePaths.push_back(std::move(globbedPath));
        }
    }
    return boundFilePaths;
}
// Binds parsed options into a map from upper-cased option name to literal value. Every option
// expression is expected to bind to a literal (asserted).
case_insensitive_map_t<Value> Binder::bindParsingOptions(const options_t& parsingOptions) {
    case_insensitive_map_t<Value> boundOptions;
    for (const auto& [optionName, optionExpr] : parsingOptions) {
        auto upperName = optionName;
        StringUtils::toUpper(upperName);
        auto boundExpr = expressionBinder.bindExpression(*optionExpr);
        KU_ASSERT(boundExpr->expressionType == ExpressionType::LITERAL);
        auto literal = ku_dynamic_cast<LiteralExpression*>(boundExpr.get());
        boundOptions.insert({upperName, literal->getValue()});
    }
    return boundOptions;
}
// Dispatches a scan source to the binder that matches its source type. All arguments are
// forwarded unchanged.
std::unique_ptr<BoundBaseScanSource> Binder::bindScanSource(const BaseScanSource* source,
    const options_t& options, const std::vector<std::string>& columnNames,
    const std::vector<LogicalType>& columnTypes) {
    const auto sourceType = source->type;
    switch (sourceType) {
    case ScanSourceType::FILE:
        return bindFileScanSource(*source, options, columnNames, columnTypes);
    case ScanSourceType::QUERY:
        return bindQueryScanSource(*source, options, columnNames, columnTypes);
    case ScanSourceType::OBJECT:
        return bindObjectScanSource(*source, options, columnNames, columnTypes);
    case ScanSourceType::TABLE_FUNC:
        return bindTableFuncScanSource(*source, options, columnNames, columnTypes);
    case ScanSourceType::PARAM:
        return bindParameterScanSource(*source, options, columnNames, columnTypes);
    default:
        KU_UNREACHABLE;
    }
}
// Reports whether the first of `filePaths` exists and the virtual file system wants it to be
// handled through a dedicated function. Only the first path is inspected.
//
// Returns false for an empty list instead of indexing into it.
bool handleFileViaFunction(main::ClientContext* context, std::vector<std::string> filePaths) {
    if (filePaths.empty()) {
        return false;
    }
    auto vfs = VirtualFileSystem::GetUnsafe(*context);
    const auto& firstPath = filePaths.front();
    // Short-circuit: the file system is only asked about handling when the path exists.
    return vfs->fileOrPathExists(firstPath, context) && vfs->handleFileViaFunction(firstPath);
}
// Binds a file scan source into a table-function based scan.
//
// Steps: expand the file paths, bind the options, decide the file type (an explicit file-format
// option wins over the type derived from the paths), pick the scan function, and finally call
// that function's bind callback with the expected column names and types.
//
// Throws a BinderException when a file type is known and a local path is not an existing file.
std::unique_ptr<BoundBaseScanSource> Binder::bindFileScanSource(const BaseScanSource& scanSource,
const options_t& options, const std::vector<std::string>& columnNames,
const std::vector<LogicalType>& columnTypes) {
auto fileSource = scanSource.constPtrCast<FileScanSource>();
auto filePaths = bindFilePaths(fileSource->filePaths);
auto boundOptions = bindParsingOptions(options);
FileTypeInfo fileTypeInfo;
// An explicit file-format option takes precedence over the type derived from the paths.
if (boundOptions.contains(FileScanInfo::FILE_FORMAT_OPTION_NAME)) {
auto fileFormat = boundOptions.at(FileScanInfo::FILE_FORMAT_OPTION_NAME).toString();
fileTypeInfo = FileTypeInfo{FileTypeUtils::fromString(fileFormat), fileFormat};
} else {
fileTypeInfo = bindFileTypeInfo(filePaths);
}
// If a concrete FileType was determined, make sure each local path is a file and not
// something else (e.g. an existing directory). Non-local paths are not checked here.
if (fileTypeInfo.fileType != FileType::UNKNOWN) {
for (const auto& filePath : filePaths) {
if (!LocalFileSystem::fileExists(filePath) && LocalFileSystem::isLocalPath(filePath)) {
throw BinderException{stringFormat("Provided path is not a file: {}.", filePath)};
}
}
}
// The format option has been consumed above; the remaining options go to the scan info.
boundOptions.erase(FileScanInfo::FILE_FORMAT_OPTION_NAME);
// Bind file configuration
auto fileScanInfo = std::make_unique<FileScanInfo>(std::move(fileTypeInfo), filePaths);
fileScanInfo->options = std::move(boundOptions);
// Use the virtual file system's handler function when it claims the first path; otherwise
// select the scan function from the file type.
TableFunction func;
if (handleFileViaFunction(clientContext, filePaths)) {
func = VirtualFileSystem::GetUnsafe(*clientContext)->getHandleFunction(filePaths[0]);
} else {
func = getScanFunction(fileScanInfo->fileTypeInfo, *fileScanInfo);
}
// Bind table function
// NOTE(review): filePaths[0] assumes at least one bound path; confirm callers never pass an
// empty path list.
auto bindInput = TableFuncBindInput();
bindInput.addLiteralParam(Value::createValue(filePaths[0]));
auto extraInput = std::make_unique<ExtraScanTableFuncBindInput>();
extraInput->fileScanInfo = fileScanInfo->copy();
extraInput->expectedColumnNames = columnNames;
extraInput->expectedColumnTypes = LogicalType::copy(columnTypes);
// Points at the local `func`; bindInput is itself local to this function.
extraInput->tableFunction = &func;
bindInput.extraInput = std::move(extraInput);
bindInput.binder = this;
auto bindData = func.bindFunc(clientContext, &bindInput);
auto info = BoundTableScanInfo(func, std::move(bindData));
return std::make_unique<BoundTableScanSource>(ScanSourceType::FILE, std::move(info));
}
// Binds a scan source backed by a query. The query must return exactly as many columns as
// `columnNames` holds; each result column is then aliased to the expected name at the same
// position. The expected column types are not used.
std::unique_ptr<BoundBaseScanSource> Binder::bindQueryScanSource(const BaseScanSource& scanSource,
    const options_t& options, const std::vector<std::string>& columnNames,
    const std::vector<LogicalType>&) {
    auto querySource = scanSource.constPtrCast<QueryScanSource>();
    auto boundQuery = bind(*querySource->statement);
    auto resultColumns = boundQuery->getStatementResult()->getColumns();
    const bool countMatches = resultColumns.size() == columnNames.size();
    if (!countMatches) {
        throw BinderException(stringFormat("Query returns {} columns but {} columns were expected.",
            resultColumns.size(), columnNames.size()));
    }
    for (auto idx = 0u; idx < resultColumns.size(); ++idx) {
        resultColumns[idx]->setAlias(columnNames[idx]);
    }
    auto scanInfo = BoundQueryScanSourceInfo(bindParsingOptions(options));
    return std::make_unique<BoundQueryScanSource>(std::move(boundQuery), std::move(scanInfo));
}
// Looks up `tableName` in the catalog of the attached database `dbName` and returns the scan
// function of that table entry.
static TableFunction getObjectScanFunc(const std::string& dbName, const std::string& tableName,
    main::ClientContext* clientContext) {
    // Bind external database table
    auto attachedCatalog =
        main::DatabaseManager::Get(*clientContext)->getAttachedDatabase(dbName)->getCatalog();
    auto activeTransaction = transaction::Transaction::Get(*clientContext);
    auto tableEntry = attachedCatalog->getTableCatalogEntry(activeTransaction, tableName);
    return tableEntry->ptrCast<TableCatalogEntry>()->getScanFunction();
}
// Wraps `func` and `bindData` into a BoundTableScanInfo.
//
// With an empty `columnTypes` the bind data is used as is. Otherwise the number of bound columns
// must equal `columnTypes.size()`, and every bound column is replaced by an invisible variable
// carrying the expected name, which also replaces the original expression in the binder's scope.
BoundTableScanInfo bindTableScanSourceInfo(Binder& binder, TableFunction func,
    const std::string& sourceName, std::unique_ptr<TableFuncBindData> bindData,
    const std::vector<std::string>& columnNames, const std::vector<LogicalType>& columnTypes) {
    if (columnTypes.empty()) {
        return BoundTableScanInfo(func, std::move(bindData));
    }
    if (bindData->getNumColumns() != columnTypes.size()) {
        throw BinderException(stringFormat("{} has {} columns but {} columns were expected.",
            sourceName, bindData->getNumColumns(), columnTypes.size()));
    }
    expression_vector renamedColumns;
    for (auto idx = 0u; idx < bindData->getNumColumns(); ++idx) {
        auto renamed = binder.createInvisibleVariable(columnNames[idx],
            bindData->columns[idx]->getDataType());
        binder.replaceExpressionInScope(bindData->columns[idx]->toString(), columnNames[idx],
            renamed);
        renamedColumns.push_back(renamed);
    }
    bindData->columns = renamedColumns;
    return BoundTableScanInfo(func, std::move(bindData));
}
std::unique_ptr<BoundBaseScanSource> Binder::bindParameterScanSource(
const BaseScanSource& scanSource, const options_t& options,
const std::vector<std::string>& columnNames, const std::vector<LogicalType>& columnTypes) {
auto paramSource = scanSource.constPtrCast<ParameterScanSource>();
auto paramExpr = expressionBinder.bindParameterExpression(*paramSource->paramExpression);
auto scanSourceValue = paramExpr->constCast<ParameterExpression>().getValue();
if (scanSourceValue.getDataType().getLogicalTypeID() != LogicalTypeID::POINTER) {
throw BinderException(stringFormat(
"Trying to scan from unsupported data type {}. The only parameter types that can be "
"scanned from are pandas/polars dataframes and pyarrow tables.",
scanSourceValue.getDataType().toString()));
}
TableFunction func;
std::unique_ptr<TableFuncBindData> bindData;
auto bindInput = TableFuncBindInput();
bindInput.binder = this;
// Bind external object as table
auto replacementData =
clientContext->tryReplaceByHandle(scanSourceValue.getValue<scan_replace_handle_t>());
func = replacementData->func;
auto replaceExtraInput = std::make_unique<ExtraScanTableFuncBindInput>();
replaceExtraInput->fileScanInfo.options = bindParsingOptions(options);
replacementData->bindInput.extraInput = std::move(replaceExtraInput);
replacementData->bindInput.binder = this;
bindData = func.bindFunc(clientContext, &replacementData->bindInput);
auto info = bindTableScanSourceInfo(*this, func, paramExpr->toString(), std::move(bindData),
columnNames, columnTypes);
return std::make_unique<BoundTableScanSource>(ScanSourceType::OBJECT, std::move(info));
}
std::unique_ptr<BoundBaseScanSource> Binder::bindObjectScanSource(const BaseScanSource& scanSource,
const options_t& options, const std::vector<std::string>& columnNames,
const std::vector<LogicalType>& columnTypes) {
auto objectSource = scanSource.constPtrCast<ObjectScanSource>();
TableFunction func;
std::unique_ptr<TableFuncBindData> bindData;
std::string objectName;
auto bindInput = TableFuncBindInput();
bindInput.binder = this;
if (objectSource->objectNames.size() == 1) {
// Bind external object as table
objectName = objectSource->objectNames[0];
auto replacementData = clientContext->tryReplaceByName(objectName);
if (replacementData != nullptr) { // Replace as python object
func = replacementData->func;
auto replaceExtraInput = std::make_unique<ExtraScanTableFuncBindInput>();
replaceExtraInput->fileScanInfo.options = bindParsingOptions(options);
replacementData->bindInput.extraInput = std::move(replaceExtraInput);
replacementData->bindInput.binder = this;
bindData = func.bindFunc(clientContext, &replacementData->bindInput);
} else if (main::DatabaseManager::Get(*clientContext)->hasDefaultDatabase()) {
auto dbName = main::DatabaseManager::Get(*clientContext)->getDefaultDatabase();
func = getObjectScanFunc(dbName, objectSource->objectNames[0], clientContext);
bindData = func.bindFunc(clientContext, &bindInput);
} else {
throw BinderException(ExceptionMessage::variableNotInScope(objectName));
}
} else if (objectSource->objectNames.size() == 2) {
// Bind external database table
objectName = objectSource->objectNames[0] + "." + objectSource->objectNames[1];
func = getObjectScanFunc(objectSource->objectNames[0], objectSource->objectNames[1],
clientContext);
bindData = func.bindFunc(clientContext, &bindInput);
} else {
// LCOV_EXCL_START
throw BinderException(stringFormat("Cannot find object {}.",
StringUtils::join(objectSource->objectNames, ",")));
// LCOV_EXCL_STOP
}
auto info = bindTableScanSourceInfo(*this, func, objectName, std::move(bindData), columnNames,
columnTypes);
return std::make_unique<BoundTableScanSource>(ScanSourceType::OBJECT, std::move(info));
}
std::unique_ptr<BoundBaseScanSource> Binder::bindTableFuncScanSource(
const BaseScanSource& scanSource, const options_t& options,
const std::vector<std::string>& columnNames, const std::vector<LogicalType>& columnTypes) {
if (!options.empty()) {
throw common::BinderException{"No option is supported when copying from table functions."};
}
auto tableFuncScanSource = scanSource.constPtrCast<TableFuncScanSource>();
auto& parsedFuncExpression =
tableFuncScanSource->functionExpression->constCast<parser::ParsedFunctionExpression>();
auto boundTableFunc = bindTableFunc(parsedFuncExpression.getFunctionName(),
*tableFuncScanSource->functionExpression, {} /* yieldVariables */);
auto& tableFunc = boundTableFunc.func;
auto info = bindTableScanSourceInfo(*this, tableFunc, tableFunc.name,
std::move(boundTableFunc.bindData), columnNames, columnTypes);
return std::make_unique<BoundTableScanSource>(ScanSourceType::OBJECT, std::move(info));
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,695 @@
#include "binder/binder.h"
#include "binder/expression/expression_util.h"
#include "binder/expression/path_expression.h"
#include "binder/expression/property_expression.h"
#include "binder/expression_visitor.h"
#include "catalog/catalog.h"
#include "catalog/catalog_entry/node_table_catalog_entry.h"
#include "catalog/catalog_entry/rel_group_catalog_entry.h"
#include "common/enums/rel_direction.h"
#include "common/exception/binder.h"
#include "common/string_format.h"
#include "common/utils.h"
#include "function/cast/functions/cast_from_string_functions.h"
#include "function/gds/rec_joins.h"
#include "function/rewrite_function.h"
#include "function/schema/vector_node_rel_functions.h"
#include "main/client_context.h"
#include "transaction/transaction.h"
using namespace lbug::common;
using namespace lbug::parser;
using namespace lbug::catalog;
namespace lbug {
namespace binder {
// A graph pattern contains node/rel and a set of key-value pairs associated with the variable. We
// bind node/rel as query graph and key-value pairs as a separate collection. This collection is
// interpreted in two different ways.
// - In MATCH clause, these are additional predicates to WHERE clause.
// - In UPDATE clause, these are properties to set.
// We do not store key-value pairs in query graph primarily because we will merge key-value
// pairs with other predicates specified in WHERE clause.
// Binds every pattern element of `graphPattern` as a query graph, merging connected query graphs
// into one collection, and returns the finalized collection wrapped in a BoundGraphPattern.
BoundGraphPattern Binder::bindGraphPattern(const std::vector<PatternElement>& graphPattern) {
    auto collection = QueryGraphCollection();
    for (auto& element : graphPattern) {
        auto elementGraph = bindPatternElement(element);
        collection.addAndMergeQueryGraphIfConnected(std::move(elementGraph));
    }
    collection.finalize();
    auto result = BoundGraphPattern();
    result.queryGraphCollection = std::move(collection);
    return result;
}
// Grammar ensures pattern element is always connected and thus can be bound as a query graph.
// Nodes and rels are bound left to right along the chain; when the element is named, the
// collected nodes/rels are additionally bound as a path expression and added to scope.
QueryGraph Binder::bindPatternElement(const PatternElement& patternElement) {
    auto queryGraph = QueryGraph();
    expression_vector pathChildren;
    auto prevNode = bindQueryNode(*patternElement.getFirstNodePattern(), queryGraph);
    pathChildren.push_back(prevNode);
    const auto numChains = patternElement.getNumPatternElementChains();
    for (auto chainIdx = 0u; chainIdx < numChains; ++chainIdx) {
        auto chain = patternElement.getPatternElementChain(chainIdx);
        // The right node is bound before the rel because the rel needs both endpoints.
        auto nextNode = bindQueryNode(*chain->getNodePattern(), queryGraph);
        auto rel = bindQueryRel(*chain->getRelPattern(), prevNode, nextNode, queryGraph);
        pathChildren.push_back(rel);
        pathChildren.push_back(nextNode);
        prevNode = nextNode;
    }
    if (!patternElement.hasPathName()) {
        return queryGraph;
    }
    auto pathName = patternElement.getPathName();
    addToScope(pathName, createPath(pathName, pathChildren));
    return queryGraph;
}
// Builds the RECURSIVE_REL logical type: a struct with a NODES field (list of `nodeType`) and a
// RELS field (list of `relType`).
static LogicalType getRecursiveRelLogicalType(const LogicalType& nodeType,
    const LogicalType& relType) {
    std::vector<StructField> fields;
    fields.emplace_back(InternalKeyword::NODES, LogicalType::LIST(nodeType.copy()));
    fields.emplace_back(InternalKeyword::RELS, LogicalType::LIST(relType.copy()));
    return LogicalType::RECURSIVE_REL(std::move(fields));
}
// Appends to `structFields` a copy of every field of `structType` whose name has not been seen
// yet; `set` tracks the names already collected and is updated in place.
static void extraFieldFromStructType(const LogicalType& structType,
    std::unordered_set<std::string>& set, std::vector<StructField>& structFields) {
    for (auto& field : StructType::getFields(structType)) {
        // insert() reports through `.second` whether the name was newly added.
        const bool isNewName = set.insert(field.getName()).second;
        if (isNewName) {
            structFields.emplace_back(field.getName(), field.getType().copy());
        }
    }
}
std::shared_ptr<Expression> Binder::createPath(const std::string& pathName,
const expression_vector& children) {
std::unordered_set<std::string> nodeFieldNameSet;
std::vector<StructField> nodeFields;
std::unordered_set<std::string> relFieldNameSet;
std::vector<StructField> relFields;
for (auto& child : children) {
if (ExpressionUtil::isNodePattern(*child)) {
auto& node = child->constCast<NodeExpression>();
extraFieldFromStructType(node.getDataType(), nodeFieldNameSet, nodeFields);
} else if (ExpressionUtil::isRelPattern(*child)) {
auto rel = ku_dynamic_cast<RelExpression*>(child.get());
extraFieldFromStructType(rel->getDataType(), relFieldNameSet, relFields);
} else if (ExpressionUtil::isRecursiveRelPattern(*child)) {
auto recursiveRel = ku_dynamic_cast<RelExpression*>(child.get());
auto recursiveInfo = recursiveRel->getRecursiveInfo();
extraFieldFromStructType(recursiveInfo->node->getDataType(), nodeFieldNameSet,
nodeFields);
extraFieldFromStructType(recursiveInfo->rel->getDataType(), relFieldNameSet, relFields);
} else {
KU_UNREACHABLE;
}
}
auto nodeType = LogicalType::NODE(std::move(nodeFields));
auto relType = LogicalType::REL(std::move(relFields));
auto uniqueName = getUniqueExpressionName(pathName);
return std::make_shared<PathExpression>(getRecursiveRelLogicalType(nodeType, relType),
uniqueName, pathName, std::move(nodeType), std::move(relType), children);
}
// Returns the distinct property names over all `entries`, in order of first appearance.
static std::vector<std::string> getPropertyNames(const std::vector<TableCatalogEntry*>& entries) {
    std::vector<std::string> orderedNames;
    std::unordered_set<std::string> seenNames;
    for (auto& entry : entries) {
        for (auto& property : entry->getProperties()) {
            // Only the first occurrence of a name is recorded.
            if (seenNames.insert(property.getName()).second) {
                orderedNames.push_back(property.getName());
            }
        }
    }
    return orderedNames;
}
static std::shared_ptr<PropertyExpression> createPropertyExpression(const std::string& propertyName,
const std::string& uniqueVariableName, const std::string& rawVariableName,
const std::vector<TableCatalogEntry*>& entries) {
table_id_map_t<SingleLabelPropertyInfo> infos;
std::vector<LogicalType> dataTypes;
for (auto& entry : entries) {
bool exists = false;
if (entry->containsProperty(propertyName)) {
exists = true;
dataTypes.push_back(entry->getProperty(propertyName).getType().copy());
}
// Bind isPrimaryKey
auto isPrimaryKey = false;
if (entry->getTableType() == TableType::NODE) {
auto nodeEntry = entry->constPtrCast<NodeTableCatalogEntry>();
isPrimaryKey = nodeEntry->getPrimaryKeyName() == propertyName;
}
auto info = SingleLabelPropertyInfo(exists, isPrimaryKey);
infos.insert({entry->getTableID(), std::move(info)});
}
LogicalType maxType = LogicalTypeUtils::combineTypes(dataTypes);
return std::make_shared<PropertyExpression>(std::move(maxType), propertyName,
uniqueVariableName, rawVariableName, std::move(infos));
}
// Validates that the direction used in the rel pattern is compatible with the extend directions
// available for `rel`:
// - a directed pattern requires the FWD extend direction;
// - an undirected pattern requires all NUM_REL_DIRECTIONS extend directions.
// Throws BinderException otherwise.
static void checkRelDirectionTypeAgainstStorageDirection(const RelExpression* rel) {
    const auto directionType = rel->getDirectionType();
    if (directionType == RelDirectionType::SINGLE) {
        // Directed pattern is in the fwd direction
        if (!containsValue(rel->getExtendDirections(), ExtendDirection::FWD)) {
            throw BinderException(stringFormat("Querying table matched in rel pattern '{}' with "
                                               "bwd-only storage direction isn't supported.",
                rel->toString()));
        }
        return;
    }
    if (directionType == RelDirectionType::BOTH) {
        if (rel->getExtendDirections().size() < NUM_REL_DIRECTIONS) {
            throw BinderException(
                stringFormat("Undirected rel pattern '{}' has at least one matched rel table with "
                             "storage type 'fwd' or 'bwd'. Undirected rel patterns are only "
                             "supported if every matched rel table has storage type 'both'.",
                    rel->toString()));
        }
        return;
    }
    KU_UNREACHABLE;
}
// Binds a rel pattern between `leftNode` and `rightNode` and adds it to `queryGraph`.
// Re-using a variable name that is already in scope is rejected. Source/destination are derived
// from the arrow direction, the rel is created as recursive or non-recursive depending on the
// pattern's rel type, and (for non-recursive rels) the `{key: value}` pairs are bound as
// property data expressions. A named rel is added to scope.
// Throws BinderException on rebinding or on a direction/storage-direction mismatch.
std::shared_ptr<RelExpression> Binder::bindQueryRel(const RelPattern& relPattern,
    const std::shared_ptr<NodeExpression>& leftNode,
    const std::shared_ptr<NodeExpression>& rightNode, QueryGraph& queryGraph) {
    auto parsedName = relPattern.getVariableName();
    const bool isRecursive = QueryRelTypeUtils::isRecursive(relPattern.getRelType());
    if (scope.contains(parsedName)) {
        // Run the data type validation on the existing variable first, then reject the rebind.
        auto existing = scope.getExpression(parsedName);
        auto expectedTypeID = isRecursive ? LogicalTypeID::RECURSIVE_REL : LogicalTypeID::REL;
        ExpressionUtil::validateDataType(*existing, expectedTypeID);
        throw BinderException("Bind relationship " + parsedName +
                              " to relationship with same name is not supported.");
    }
    auto entries = bindRelGroupEntries(relPattern.getTableNames());
    // bind src & dst node
    RelDirectionType directionType = RelDirectionType::UNKNOWN;
    std::shared_ptr<NodeExpression> srcNode;
    std::shared_ptr<NodeExpression> dstNode;
    switch (relPattern.getDirection()) {
    case ArrowDirection::LEFT:
        srcNode = rightNode;
        dstNode = leftNode;
        directionType = RelDirectionType::SINGLE;
        break;
    case ArrowDirection::RIGHT:
        srcNode = leftNode;
        dstNode = rightNode;
        directionType = RelDirectionType::SINGLE;
        break;
    case ArrowDirection::BOTH:
        // For both direction, left and right will be written with the same label set. So either
        // one being src will be correct.
        srcNode = leftNode;
        dstNode = rightNode;
        directionType = RelDirectionType::BOTH;
        break;
    default:
        KU_UNREACHABLE;
    }
    // bind variable length
    std::shared_ptr<RelExpression> queryRel;
    if (isRecursive) {
        queryRel = createRecursiveQueryRel(relPattern, entries, srcNode, dstNode, directionType);
    } else {
        queryRel = createNonRecursiveQueryRel(relPattern.getVariableName(), entries, srcNode,
            dstNode, directionType);
        for (auto& [key, parsedValue] : relPattern.getPropertyKeyVals()) {
            auto property = expressionBinder.bindNodeOrRelPropertyExpression(*queryRel, key);
            auto value = expressionBinder.bindExpression(*parsedValue);
            value = expressionBinder.implicitCastIfNecessary(value, property->dataType);
            queryRel->addPropertyDataExpr(key, std::move(value));
        }
    }
    queryRel->setLeftNode(leftNode);
    queryRel->setRightNode(rightNode);
    queryRel->setAlias(parsedName);
    if (!parsedName.empty()) {
        addToScope(parsedName, queryRel);
    }
    queryGraph.addQueryRel(queryRel);
    checkRelDirectionTypeAgainstStorageDirection(queryRel.get());
    return queryRel;
}
static std::vector<StructField> getBaseNodeStructFields() {
std::vector<StructField> fields;
fields.emplace_back(InternalKeyword::ID, LogicalType::INTERNAL_ID());
fields.emplace_back(InternalKeyword::LABEL, LogicalType::STRING());
return fields;
}
static std::vector<StructField> getBaseRelStructFields() {
std::vector<StructField> fields;
fields.emplace_back(InternalKeyword::SRC, LogicalType::INTERNAL_ID());
fields.emplace_back(InternalKeyword::DST, LogicalType::INTERNAL_ID());
fields.emplace_back(InternalKeyword::LABEL, LogicalType::STRING());
return fields;
}
// Constructs a property expression named `propertyName` of type `type` on the node/rel pattern
// `child`. The property is recorded as non-existent and non-primary-key for every table entry
// of the pattern.
// `child` must be a PATTERN expression (asserted).
static std::shared_ptr<PropertyExpression> construct(LogicalType type,
    const std::string& propertyName, const Expression& child) {
    KU_ASSERT(child.expressionType == ExpressionType::PATTERN);
    auto& patternExpr = child.constCast<NodeOrRelExpression>();
    auto variableName = patternExpr.getVariableName();
    auto uniqueName = patternExpr.getUniqueName();
    // Assign an invalid property id for virtual property.
    table_id_map_t<SingleLabelPropertyInfo> infos;
    for (auto& entry : patternExpr.getEntries()) {
        infos.insert({entry->getTableID(),
            SingleLabelPropertyInfo(false /* exists */, false /* isPrimaryKey */)});
    }
    // The result is shared, so allocate it with make_shared directly (one allocation) rather
    // than building a unique_ptr and converting it; this matches createPropertyExpression.
    return std::make_shared<PropertyExpression>(std::move(type), propertyName, uniqueName,
        variableName, std::move(infos));
}
// Creates a non-recursive RelExpression named `parsedName` over `entries`.
// The rel's struct type consists of the base rel fields plus one field per distinct property of
// the entries; when there are no entries, only an internal ID field/property is added.
// An undirected rel additionally gets a BOOL direction variable. The label expression is set
// through LabelFunction::rewriteFunc.
std::shared_ptr<RelExpression> Binder::createNonRecursiveQueryRel(const std::string& parsedName,
    const std::vector<TableCatalogEntry*>& entries, std::shared_ptr<NodeExpression> srcNode,
    std::shared_ptr<NodeExpression> dstNode, RelDirectionType directionType) {
    auto uniqueName = getUniqueExpressionName(parsedName);
    const bool hasNoEntries = entries.empty();
    // Bind properties
    auto relFields = getBaseRelStructFields();
    std::vector<std::shared_ptr<PropertyExpression>> properties;
    if (hasNoEntries) {
        relFields.emplace_back(InternalKeyword::ID, LogicalType::INTERNAL_ID());
    } else {
        for (auto& propertyName : getPropertyNames(entries)) {
            auto& property = properties.emplace_back(
                createPropertyExpression(propertyName, uniqueName, parsedName, entries));
            relFields.emplace_back(property->getPropertyName(), property->getDataType().copy());
        }
    }
    auto queryRel = std::make_shared<RelExpression>(LogicalType::REL(std::move(relFields)),
        uniqueName, parsedName, entries, std::move(srcNode), std::move(dstNode), directionType,
        QueryRelType::NON_RECURSIVE);
    queryRel->setAlias(parsedName);
    if (hasNoEntries) {
        // The ID property can only be constructed once the rel expression itself exists.
        queryRel->addPropertyExpression(
            construct(LogicalType::INTERNAL_ID(), InternalKeyword::ID, *queryRel));
    }
    // `properties` is empty when there are no entries, so this loop is a no-op in that case.
    for (auto& property : properties) {
        queryRel->addPropertyExpression(property);
    }
    // Bind internal expressions.
    if (directionType == RelDirectionType::BOTH) {
        auto directionName = queryRel->getUniqueName() + InternalKeyword::DIRECTION;
        queryRel->setDirectionExpr(
            expressionBinder.createVariableExpression(LogicalType::BOOL(), directionName));
    }
    auto input = function::RewriteFunctionBindInput(clientContext, &expressionBinder, {queryRel});
    queryRel->setLabelExpression(function::LabelFunction::rewriteFunc(input));
    return queryRel;
}
static void bindProjectionListAsStructField(const expression_vector& projectionList,
std::vector<StructField>& fields) {
for (auto& expression : projectionList) {
if (expression->expressionType != ExpressionType::PROPERTY) {
throw BinderException(stringFormat("Unsupported projection item {} on recursive rel.",
expression->toString()));
}
auto& property = expression->constCast<PropertyExpression>();
fields.emplace_back(property.getPropertyName(), property.getDataType().copy());
}
}
// Accepts only the fixed-width integer types (8/16/32/64 bit, signed and unsigned), DOUBLE and
// FLOAT as the weight type of a weighted shortest path; throws BinderException for any other
// type.
static void checkWeightedShortestPathSupportedType(const LogicalType& type) {
    switch (type.getLogicalTypeID()) {
    case LogicalTypeID::INT8:
    case LogicalTypeID::UINT8:
    case LogicalTypeID::INT16:
    case LogicalTypeID::UINT16:
    case LogicalTypeID::INT32:
    case LogicalTypeID::UINT32:
    case LogicalTypeID::INT64:
    case LogicalTypeID::UINT64:
    case LogicalTypeID::DOUBLE:
    case LogicalTypeID::FLOAT:
        break;
    default:
        throw BinderException(stringFormat(
            "{} weight type is not supported for weighted shortest path.", type.toString()));
    }
}
// Creates the RelExpression for a recursive rel pattern.
// Visible steps, in order:
//   1. collect the node tables referenced as src/dst by the rel group `entries`;
//   2. in a cleared scope, bind the intermediate node and rel of the recursive pattern together
//      with their projection lists;
//   3. bind the `{key: value}` pairs and the WHERE expression of the recursive pattern, sorting
//      each conjunct into a node predicate or a rel predicate;
//   4. restore the outer scope and build the recursive RelExpression, its RJBindData and its
//      RecursiveInfo.
// Throws BinderException when a WHERE conjunct depends on both the node and the rel, when a
// non-literal conjunct depends on neither, or through the helpers called below.
std::shared_ptr<RelExpression> Binder::createRecursiveQueryRel(const parser::RelPattern& relPattern,
const std::vector<TableCatalogEntry*>& entries, std::shared_ptr<NodeExpression> srcNode,
std::shared_ptr<NodeExpression> dstNode, RelDirectionType directionType) {
auto catalog = Catalog::Get(*clientContext);
auto transaction = transaction::Transaction::Get(*clientContext);
// Gather every node table that is a source or destination of one of the rel groups.
table_catalog_entry_set_t nodeEntrySet;
for (auto entry : entries) {
auto& relGroupEntry = entry->constCast<RelGroupCatalogEntry>();
for (auto id : relGroupEntry.getSrcNodeTableIDSet()) {
nodeEntrySet.insert(catalog->getTableCatalogEntry(transaction, id));
}
for (auto id : relGroupEntry.getDstNodeTableIDSet()) {
nodeEntrySet.insert(catalog->getTableCatalogEntry(transaction, id));
}
}
auto nodeEntries = std::vector<TableCatalogEntry*>{nodeEntrySet.begin(), nodeEntrySet.end()};
auto recursivePatternInfo = relPattern.getRecursiveInfo();
// The intermediate node/rel are bound in an emptied scope; the saved scope is put back by
// restoreScope() further down.
// NOTE(review): an exception thrown before restoreScope() leaves the scope cleared - confirm
// whether the binder is ever reused after a failed bind.
auto prevScope = saveScope();
scope.clear();
// Bind intermediate node.
auto node = createQueryNode(recursivePatternInfo->nodeName, nodeEntries);
addToScope(node->toString(), node);
auto nodeFields = getBaseNodeStructFields();
auto nodeProjectionList = bindRecursivePatternNodeProjectionList(*recursivePatternInfo, *node);
bindProjectionListAsStructField(nodeProjectionList, nodeFields);
// Narrow the node's data type to the base fields plus the projected properties.
node->setDataType(LogicalType::NODE(std::move(nodeFields)));
// A second node expression built from the same name and entries; it is stored as
// recursiveInfo->nodeCopy below and is not added to scope here.
auto nodeCopy = createQueryNode(recursivePatternInfo->nodeName, nodeEntries);
// Bind intermediate rel
auto rel = createNonRecursiveQueryRel(recursivePatternInfo->relName, entries,
nullptr /* srcNode */, nullptr /* dstNode */, directionType);
addToScope(rel->toString(), rel);
auto relProjectionList = bindRecursivePatternRelProjectionList(*recursivePatternInfo, *rel);
auto relFields = getBaseRelStructFields();
// The internal ID field is always part of the intermediate rel type; the rel projection list
// itself skips internal-ID properties when no explicit projection is given.
relFields.emplace_back(InternalKeyword::ID, LogicalType::INTERNAL_ID());
bindProjectionListAsStructField(relProjectionList, relFields);
rel->setDataType(LogicalType::REL(std::move(relFields)));
// Bind predicates in {}, e.g. [e* {date=1999-01-01}]
std::shared_ptr<Expression> relPredicate = nullptr;
for (auto& [propertyName, rhs] : relPattern.getPropertyKeyVals()) {
auto boundLhs = expressionBinder.bindNodeOrRelPropertyExpression(*rel, propertyName);
auto boundRhs = expressionBinder.bindExpression(*rhs);
boundRhs = expressionBinder.implicitCastIfNecessary(boundRhs, boundLhs->dataType);
auto predicate = expressionBinder.createEqualityComparisonExpression(boundLhs, boundRhs);
relPredicate = expressionBinder.combineBooleanExpressions(ExpressionType::AND, relPredicate,
predicate);
}
// Bind predicates in (r, n | WHERE )
bool emptyRecursivePattern = false;
std::shared_ptr<Expression> nodePredicate = nullptr;
if (recursivePatternInfo->whereExpression != nullptr) {
// The label-function literal rewrite is switched off only while the WHERE expression is bound.
// NOTE(review): the flag is not reset if bindWhereExpression throws - confirm this is fine.
expressionBinder.config.disableLabelFunctionLiteralRewrite = true;
auto wherePredicate = bindWhereExpression(*recursivePatternInfo->whereExpression);
expressionBinder.config.disableLabelFunctionLiteralRewrite = false;
// Classify each AND-conjunct by which of the intermediate variables it references.
for (auto& predicate : wherePredicate->splitOnAND()) {
auto collector = DependentVarNameCollector();
collector.visit(predicate);
auto dependentVariableNames = collector.getVarNames();
auto dependOnNode = dependentVariableNames.contains(node->getUniqueName());
auto dependOnRel = dependentVariableNames.contains(rel->getUniqueName());
if (dependOnNode && dependOnRel) {
throw BinderException(
stringFormat("Cannot evaluate {} because it depends on both {} and {}.",
predicate->toString(), node->toString(), rel->toString()));
} else if (dependOnNode) {
nodePredicate = expressionBinder.combineBooleanExpressions(ExpressionType::AND,
nodePredicate, predicate);
} else if (dependOnRel) {
relPredicate = expressionBinder.combineBooleanExpressions(ExpressionType::AND,
relPredicate, predicate);
} else {
if (!ExpressionUtil::isBoolLiteral(*predicate)) {
throw BinderException(stringFormat(
"Cannot evaluate {} because it does not depend on {} or {}. Treating it as "
"a node or relationship predicate is ambiguous.",
predicate->toString(), node->toString(), rel->toString()));
}
// If predicate is true literal, we ignore.
// If predicate is false literal, we mark this recursive relationship as empty
// and later in planner we replace it with EmptyResult.
if (!ExpressionUtil::getLiteralValue<bool>(*predicate)) {
emptyRecursivePattern = true;
}
}
}
}
// Bind rel
restoreScope(std::move(prevScope));
auto parsedName = relPattern.getVariableName();
// An empty recursive pattern is represented by a rel expression without any table entries.
auto prunedRelEntries = entries;
if (emptyRecursivePattern) {
prunedRelEntries.clear();
}
auto queryRel = std::make_shared<RelExpression>(
getRecursiveRelLogicalType(node->getDataType(), rel->getDataType()),
getUniqueExpressionName(parsedName), parsedName, prunedRelEntries, std::move(srcNode),
std::move(dstNode), directionType, relPattern.getRelType());
// Bind graph entry.
// The graph entry is built from the entries of the intermediate node/rel (not from
// prunedRelEntries); each rel info also carries the rel expression and the rel predicate.
auto graphEntry = graph::NativeGraphEntry();
for (auto nodeEntry : node->getEntries()) {
graphEntry.nodeInfos.emplace_back(nodeEntry);
}
for (auto relEntry : rel->getEntries()) {
graphEntry.relInfos.emplace_back(relEntry, rel, relPredicate);
}
auto bindData = std::make_unique<function::RJBindData>(graphEntry.copy());
// Bind lower upper bound.
auto [lowerBound, upperBound] = bindVariableLengthRelBound(relPattern);
bindData->lowerBound = lowerBound;
bindData->upperBound = upperBound;
// Bind semantic.
bindData->semantic = QueryRelTypeUtils::getPathSemantic(queryRel->getRelType());
// Bind path related expressions.
bindData->lengthExpr = construct(LogicalType::INT64(), InternalKeyword::LENGTH, *queryRel);
bindData->pathNodeIDsExpr =
createInvisibleVariable("pathNodeIDs", LogicalType::LIST(LogicalType::INTERNAL_ID()));
bindData->pathEdgeIDsExpr =
createInvisibleVariable("pathEdgeIDs", LogicalType::LIST(LogicalType::INTERNAL_ID()));
// The per-edge direction list is only bound for undirected patterns.
if (queryRel->getDirectionType() == RelDirectionType::BOTH) {
bindData->directionExpr =
createInvisibleVariable("pathEdgeDirections", LogicalType::LIST(LogicalType::BOOL()));
}
// Bind weighted path related expressions.
if (QueryRelTypeUtils::isWeighted(queryRel->getRelType())) {
auto propertyExpr = expressionBinder.bindNodeOrRelPropertyExpression(*rel,
recursivePatternInfo->weightPropertyName);
checkWeightedShortestPathSupportedType(propertyExpr->getDataType());
bindData->weightPropertyExpr = propertyExpr;
// The output cost variable is always DOUBLE, independent of the weight property's type.
bindData->weightOutputExpr =
createInvisibleVariable(parsedName + "_cost", LogicalType::DOUBLE());
}
// Hand everything bound above over to the rel expression.
auto recursiveInfo = std::make_unique<RecursiveInfo>();
recursiveInfo->node = node;
recursiveInfo->nodeCopy = nodeCopy;
recursiveInfo->rel = rel;
recursiveInfo->nodePredicate = std::move(nodePredicate);
recursiveInfo->relPredicate = std::move(relPredicate);
recursiveInfo->nodeProjectionList = std::move(nodeProjectionList);
recursiveInfo->relProjectionList = std::move(relProjectionList);
recursiveInfo->function = QueryRelTypeUtils::getFunction(queryRel->getRelType());
recursiveInfo->bindData = std::move(bindData);
queryRel->setRecursiveInfo(std::move(recursiveInfo));
return queryRel;
}
// Returns the projection list for the intermediate node of a recursive pattern: the explicitly
// projected expressions when the pattern has a projection, otherwise all property expressions
// of `expr`.
expression_vector Binder::bindRecursivePatternNodeProjectionList(
    const RecursiveRelPatternInfo& info, const NodeOrRelExpression& expr) {
    expression_vector projection;
    if (info.hasProjection) {
        for (auto& parsedItem : info.nodeProjectionList) {
            projection.push_back(expressionBinder.bindExpression(*parsedItem));
        }
        return projection;
    }
    for (auto& property : expr.getPropertyExpressions()) {
        projection.push_back(property);
    }
    return projection;
}
// Returns the projection list for the intermediate rel of a recursive pattern: the explicitly
// projected expressions when the pattern has a projection, otherwise all property expressions
// of `expr` except internal-ID properties.
expression_vector Binder::bindRecursivePatternRelProjectionList(const RecursiveRelPatternInfo& info,
    const NodeOrRelExpression& expr) {
    expression_vector projection;
    if (info.hasProjection) {
        for (auto& parsedItem : info.relProjectionList) {
            projection.push_back(expressionBinder.bindExpression(*parsedItem));
        }
        return projection;
    }
    for (auto& property : expr.getPropertyExpressions()) {
        if (!property->isInternalID()) {
            projection.push_back(property);
        }
    }
    return projection;
}
// Parses and validates the lower/upper bound of a variable-length rel pattern.
// The upper bound defaults to the client's `varLengthMaxDepth` when the pattern omits it.
// Throws BinderException when lower > upper, when upper exceeds the configured maximum, or when
// a SHORTEST / ALL_SHORTEST pattern has a lower bound other than 1.
std::pair<uint64_t, uint64_t> Binder::bindVariableLengthRelBound(const RelPattern& relPattern) {
    auto recursiveInfo = relPattern.getRecursiveInfo();
    const auto& lowerText = recursiveInfo->lowerBound;
    const auto& upperText = recursiveInfo->upperBound;
    uint32_t lowerBound = 0;
    function::CastString::operation(ku_string_t{lowerText.c_str(), lowerText.length()},
        lowerBound);
    auto maxDepth = clientContext->getClientConfig()->varLengthMaxDepth;
    auto upperBound = maxDepth;
    if (!upperText.empty()) {
        function::CastString::operation(ku_string_t{upperText.c_str(), upperText.length()},
            upperBound);
    }
    if (lowerBound > upperBound) {
        throw BinderException(stringFormat("Lower bound of rel {} is greater than upperBound.",
            relPattern.getVariableName()));
    }
    if (upperBound > maxDepth) {
        throw BinderException(stringFormat("Upper bound of rel {} exceeds maximum: {}.",
            relPattern.getVariableName(), std::to_string(maxDepth)));
    }
    const bool isShortestPath = relPattern.getRelType() == QueryRelType::ALL_SHORTEST ||
                                relPattern.getRelType() == QueryRelType::SHORTEST;
    if (isShortestPath && lowerBound != 1) {
        throw BinderException("Lower bound of shortest/all_shortest path must be 1.");
    }
    return std::make_pair(lowerBound, upperBound);
}
// Binds a node pattern and adds the resulting node to `queryGraph`.
// - Name not in scope: a new node is created (and added to scope when it is named).
// - Name in scope and already a node pattern: that node is reused; labels given in this pattern
//   are added to its entries.
// - Name in scope but not a node pattern: the scope's node replacement is used if one exists,
//   otherwise a BinderException is thrown.
// `{key: value}` pairs of the pattern are bound as property data expressions on the node.
std::shared_ptr<NodeExpression> Binder::bindQueryNode(const NodePattern& nodePattern,
    QueryGraph& queryGraph) {
    auto parsedName = nodePattern.getVariableName();
    std::shared_ptr<NodeExpression> queryNode;
    if (!scope.contains(parsedName)) {
        queryNode = createQueryNode(nodePattern);
        if (!parsedName.empty()) {
            addToScope(parsedName, queryNode);
        }
    } else { // bind to node in scope
        auto existing = scope.getExpression(parsedName);
        if (ExpressionUtil::isNodePattern(*existing)) {
            queryNode = std::static_pointer_cast<NodeExpression>(existing);
            // E.g. MATCH (a:person) MATCH (a:organisation)
            // We bind to a single node with both labels
            if (!nodePattern.getTableNames().empty()) {
                queryNode->addEntries(bindNodeTableEntries(nodePattern.getTableNames()));
            }
        } else {
            if (!scope.hasNodeReplacement(parsedName)) {
                throw BinderException(stringFormat("Cannot bind {} as node pattern.", parsedName));
            }
            queryNode = scope.getNodeReplacement(parsedName);
            queryNode->addPropertyDataExpr(InternalKeyword::ID, queryNode->getInternalID());
        }
    }
    for (auto& [key, parsedValue] : nodePattern.getPropertyKeyVals()) {
        auto property = expressionBinder.bindNodeOrRelPropertyExpression(*queryNode, key);
        auto value = expressionBinder.bindExpression(*parsedValue);
        value = expressionBinder.forceCast(value, property->dataType);
        queryNode->addPropertyDataExpr(key, std::move(value));
    }
    queryGraph.addQueryNode(queryNode);
    return queryNode;
}
// Creates a node expression for `nodePattern`, using the pattern's variable name and the node
// table entries bound from its table names.
std::shared_ptr<NodeExpression> Binder::createQueryNode(const NodePattern& nodePattern) {
    auto variableName = nodePattern.getVariableName();
    auto nodeEntries = bindNodeTableEntries(nodePattern.getTableNames());
    return createQueryNode(variableName, nodeEntries);
}
std::shared_ptr<NodeExpression> Binder::createQueryNode(const std::string& parsedName,
const std::vector<TableCatalogEntry*>& entries) {
auto uniqueName = getUniqueExpressionName(parsedName);
// Bind properties.
auto structFields = getBaseNodeStructFields();
std::vector<std::shared_ptr<PropertyExpression>> propertyExpressions;
for (auto& propertyName : getPropertyNames(entries)) {
auto property = createPropertyExpression(propertyName, uniqueName, parsedName, entries);
structFields.emplace_back(property->getPropertyName(), property->getDataType().copy());
propertyExpressions.push_back(std::move(property));
}
auto queryNode = std::make_shared<NodeExpression>(LogicalType::NODE(std::move(structFields)),
uniqueName, parsedName, entries);
queryNode->setAlias(parsedName);
for (auto& property : propertyExpressions) {
queryNode->addPropertyExpression(property);
}
// Bind internal expressions
queryNode->setInternalID(
construct(LogicalType::INTERNAL_ID(), InternalKeyword::ID, *queryNode));
auto input = function::RewriteFunctionBindInput(clientContext, &expressionBinder, {queryNode});
queryNode->setLabelExpression(function::LabelFunction::rewriteFunc(input));
return queryNode;
}
static std::vector<TableCatalogEntry*> sortEntries(const table_catalog_entry_set_t& set) {
std::vector<TableCatalogEntry*> entries;
for (auto entry : set) {
entries.push_back(entry);
}
std::sort(entries.begin(), entries.end(),
[](const TableCatalogEntry* a, const TableCatalogEntry* b) {
return a->getTableID() < b->getTableID();
});
return entries;
}
std::vector<TableCatalogEntry*> Binder::bindNodeTableEntries(
const std::vector<std::string>& tableNames) const {
auto transaction = transaction::Transaction::Get(*clientContext);
auto catalog = Catalog::Get(*clientContext);
auto useInternal = clientContext->useInternalCatalogEntry();
table_catalog_entry_set_t entrySet;
if (tableNames.empty()) { // Rewrite as all node tables in database.
for (auto entry : catalog->getNodeTableEntries(transaction, useInternal)) {
entrySet.insert(entry);
}
} else {
for (auto& name : tableNames) {
auto entry = bindNodeTableEntry(name);
if (entry->getType() != CatalogEntryType::NODE_TABLE_ENTRY) {
throw BinderException(
stringFormat("Cannot bind {} as a node pattern label.", entry->getName()));
}
entrySet.insert(entry);
}
}
return sortEntries(entrySet);
}
// Looks up the catalog entry of the table called `name`.
// Throws BinderException when the catalog does not contain such a table.
TableCatalogEntry* Binder::bindNodeTableEntry(const std::string& name) const {
    auto txn = transaction::Transaction::Get(*clientContext);
    auto catalog = Catalog::Get(*clientContext);
    auto useInternal = clientContext->useInternalCatalogEntry();
    if (catalog->containsTable(txn, name, useInternal)) {
        return catalog->getTableCatalogEntry(txn, name, useInternal);
    }
    throw BinderException(stringFormat("Table {} does not exist.", name));
}
// Resolves the labels of a relationship pattern to catalog entries, sorted by table ID.
// An empty `tableNames` is rewritten as "all rel groups in the database". Otherwise every
// name must refer to an existing rel group.
// Throws BinderException if a table does not exist or is not a rel group.
std::vector<TableCatalogEntry*> Binder::bindRelGroupEntries(
    const std::vector<std::string>& tableNames) const {
    auto transaction = transaction::Transaction::Get(*clientContext);
    auto catalog = Catalog::Get(*clientContext);
    auto useInternal = clientContext->useInternalCatalogEntry();
    table_catalog_entry_set_t entrySet;
    if (tableNames.empty()) { // Rewrite as all rel groups in database.
        for (auto entry : catalog->getRelGroupEntries(transaction, useInternal)) {
            entrySet.insert(entry);
        }
        return sortEntries(entrySet);
    }
    for (auto& name : tableNames) {
        // The existence check uses the same `useInternal` flag as the lookup below (and as
        // bindNodeTableEntry), so both agree on which tables are visible and a missing table
        // is always reported through the BinderException thrown here.
        if (!catalog->containsTable(transaction, name, useInternal)) {
            throw BinderException(stringFormat("Table {} does not exist.", name));
        }
        auto entry = catalog->getTableCatalogEntry(transaction, name, useInternal);
        if (entry->getType() != CatalogEntryType::REL_GROUP_ENTRY) {
            throw BinderException(stringFormat("Cannot bind {} as a relationship pattern label.",
                entry->getName()));
        }
        entrySet.insert(entry);
    }
    return sortEntries(entrySet);
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,144 @@
#include "binder/binder.h"
#include "binder/bound_import_database.h"
#include "common/copier_config/csv_reader_config.h"
#include "common/exception/binder.h"
#include "common/file_system/virtual_file_system.h"
#include "main/client_context.h"
#include "parser/copy.h"
#include "parser/parser.h"
#include "parser/port_db.h"
using namespace lbug::common;
using namespace lbug::parser;
namespace lbug {
namespace binder {
// Reads the whole file `fileName` located in directory `boundFilePath` and returns its content.
// A missing copy file or index file yields an empty string; any other missing file raises a
// BinderException.
static std::string getQueryFromFile(VirtualFileSystem* vfs, const std::string& boundFilePath,
    const std::string& fileName, main::ClientContext* context) {
    auto fullPath = vfs->joinPath(boundFilePath, fileName);
    if (!vfs->fileOrPathExists(fullPath, context)) {
        const bool isOptionalFile = fileName == PortDBConstants::COPY_FILE_NAME ||
                                    fileName == PortDBConstants::INDEX_FILE_NAME;
        if (isOptionalFile) {
            return "";
        }
        throw BinderException(stringFormat("File {} does not exist.", fullPath));
    }
    auto handle = vfs->openFile(fullPath, FileOpenFlags(FileFlags::READ_ONLY
#ifdef _WIN32
                                                        | FileFlags::BINARY
#endif
                                              ));
    auto numBytes = handle->getFileSize();
    auto contents = std::make_unique<char[]>(numBytes);
    handle->readFile(contents.get(), numBytes);
    return std::string(contents.get(), numBytes);
}
// Renders the explicit column list of a COPY FROM statement as "(`a`,`b`,...)".
// Returns an empty string when the statement names no columns.
static std::string getColumnNamesToCopy(const CopyFrom& copyFrom) {
    std::string quotedColumns;
    for (auto& columnName : copyFrom.getCopyColumnInfo().columnNames) {
        // Every column except the first is preceded by a comma.
        if (!quotedColumns.empty()) {
            quotedColumns += ",";
        }
        quotedColumns += "`" + columnName + "`";
    }
    if (quotedColumns.empty()) {
        return quotedColumns;
    }
    return stringFormat("({})", quotedColumns);
}
// Resolves the file path recorded in a COPY statement against the import directory.
//
// boundFilePath: the directory being imported.
// filePath: the path as written in the copy script.
// Returns `filePath` unchanged when it is absolute, otherwise `boundFilePath + "/" + filePath`.
// On Windows, every backslash of a joined path is additionally doubled.
static std::string getCopyFilePath(const std::string& boundFilePath, const std::string& filePath) {
    // Note:
    // Unix absolute path starts with '/'
    // Windows absolute path starts with a drive letter followed by ':' (e.g. "C:")
    // This code path is for backward compatibility, we used to export the absolute path for
    // csv files to copy.cypher files.
    const bool isUnixAbsolute = !filePath.empty() && filePath[0] == '/';
    // std::isalpha has undefined behaviour for negative char values (e.g. the first byte of a
    // UTF-8 encoded file name), so the character is converted to unsigned char first.
    const bool hasDrivePrefix = filePath.size() >= 2 &&
                                std::isalpha(static_cast<unsigned char>(filePath[0])) != 0 &&
                                filePath[1] == ':';
    if (isUnixAbsolute || hasDrivePrefix) {
        return filePath;
    }
    auto path = boundFilePath + "/" + filePath;
#if defined(_WIN32)
    // TODO(Ziyi): This is a temporary workaround because our parser requires input cypher queries
    // to escape all special characters in string literal. E.g. The user input query is: 'IMPORT
    // DATABASE 'C:\\db\\uw''. The parser removes any escaped characters and this function accepts
    // the path parameter as 'C:\db\uw'. Then the ImportDatabase operator gives the file path to
    // antlr4 parser directly without escaping any special characters in the path, which causes a
    // parser exception. However, the parser exception is not thrown properly which leads to the
    // undefined behaviour.
    size_t pos = 0;
    while ((pos = path.find('\\', pos)) != std::string::npos) {
        path.replace(pos, 1, "\\\\");
        pos += 2;
    }
#endif
    return path;
}
std::unique_ptr<BoundStatement> Binder::bindImportDatabaseClause(const Statement& statement) {
auto& importDB = statement.constCast<ImportDB>();
auto fs = VirtualFileSystem::GetUnsafe(*clientContext);
auto boundFilePath = fs->expandPath(clientContext, importDB.getFilePath());
if (!fs->fileOrPathExists(boundFilePath, clientContext)) {
throw BinderException(stringFormat("Directory {} does not exist.", boundFilePath));
}
std::string finalQueryStatements;
finalQueryStatements +=
getQueryFromFile(fs, boundFilePath, PortDBConstants::SCHEMA_FILE_NAME, clientContext);
// replace the path in copy from statements with the bound path
auto copyQuery =
getQueryFromFile(fs, boundFilePath, PortDBConstants::COPY_FILE_NAME, clientContext);
if (!copyQuery.empty()) {
auto parsedStatements = Parser::parseQuery(copyQuery);
for (auto& parsedStatement : parsedStatements) {
KU_ASSERT(parsedStatement->getStatementType() == StatementType::COPY_FROM);
auto& copyFromStatement = parsedStatement->constCast<CopyFrom>();
KU_ASSERT(copyFromStatement.getSource()->type == ScanSourceType::FILE);
auto filePaths =
copyFromStatement.getSource()->constPtrCast<FileScanSource>()->filePaths;
KU_ASSERT(filePaths.size() == 1);
auto fileTypeInfo = bindFileTypeInfo(filePaths);
std::string query;
auto copyFilePath = getCopyFilePath(boundFilePath, filePaths[0]);
auto columnNames = getColumnNamesToCopy(copyFromStatement);
auto parsingOptions = bindParsingOptions(copyFromStatement.getParsingOptions());
std::unordered_map<std::string, std::string> copyFromOptions;
if (parsingOptions.contains(CopyConstants::FROM_OPTION_NAME)) {
KU_ASSERT(parsingOptions.contains(CopyConstants::TO_OPTION_NAME));
copyFromOptions[CopyConstants::FROM_OPTION_NAME] = stringFormat("'{}'",
parsingOptions.at(CopyConstants::FROM_OPTION_NAME).getValue<std::string>());
copyFromOptions[CopyConstants::TO_OPTION_NAME] = stringFormat("'{}'",
parsingOptions.at(CopyConstants::TO_OPTION_NAME).getValue<std::string>());
parsingOptions.erase(CopyConstants::FROM_OPTION_NAME);
parsingOptions.erase(CopyConstants::TO_OPTION_NAME);
}
if (fileTypeInfo.fileType == FileType::CSV) {
auto csvConfig = CSVReaderConfig::construct(parsingOptions);
csvConfig.option.autoDetection = false;
auto optionsMap = csvConfig.option.toOptionsMap(csvConfig.parallel);
if (!copyFromOptions.empty()) {
optionsMap.insert(copyFromOptions.begin(), copyFromOptions.end());
}
query =
stringFormat("COPY `{}` {} FROM \"{}\" {};", copyFromStatement.getTableName(),
columnNames, copyFilePath, CSVOption::toCypher(optionsMap));
} else {
query =
stringFormat("COPY `{}` {} FROM \"{}\" {};", copyFromStatement.getTableName(),
columnNames, copyFilePath, CSVOption::toCypher(copyFromOptions));
}
finalQueryStatements += query;
}
}
return std::make_unique<BoundImportDatabase>(boundFilePath, finalQueryStatements,
getQueryFromFile(fs, boundFilePath, PortDBConstants::INDEX_FILE_NAME, clientContext));
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,308 @@
#include "binder/binder.h"
#include "binder/expression/expression_util.h"
#include "binder/expression/lambda_expression.h"
#include "binder/expression_visitor.h"
#include "binder/query/return_with_clause/bound_return_clause.h"
#include "binder/query/return_with_clause/bound_with_clause.h"
#include "common/exception/binder.h"
#include "parser/expression/parsed_property_expression.h"
#include "parser/query/return_with_clause/with_clause.h"
using namespace lbug::common;
using namespace lbug::parser;
namespace lbug {
namespace binder {
// Throws BinderException at the first column name that occurs more than once.
void validateColumnNamesAreUnique(const std::vector<std::string>& columnNames) {
    std::unordered_set<std::string> seenNames;
    for (auto& columnName : columnNames) {
        // insert() reports through `.second` whether the name was new.
        const bool isNewName = seenNames.insert(columnName).second;
        if (!isNewName) {
            throw BinderException("Multiple result columns with the same name " + columnName +
                                  " are not supported.");
        }
    }
}
// Returns one column name per expression: the alias at the same index when it is non-empty,
// otherwise the expression's toString() text.
std::vector<std::string> getColumnNames(const expression_vector& exprs,
    const std::vector<std::string>& aliases) {
    std::vector<std::string> names;
    for (auto idx = 0u; idx < exprs.size(); ++idx) {
        const auto& alias = aliases[idx];
        names.push_back(alias.empty() ? exprs[idx]->toString() : alias);
    }
    return names;
}
// Rejects a WITH projection that has ORDER BY but neither SKIP nor LIMIT.
static void validateOrderByFollowedBySkipOrLimitInWithClause(
    const BoundProjectionBody& boundProjectionBody) {
    if (!boundProjectionBody.hasOrderByExpressions()) {
        return;
    }
    if (boundProjectionBody.hasSkip() || boundProjectionBody.hasLimit()) {
        return;
    }
    throw BinderException("In WITH clause, ORDER BY must be followed by SKIP or LIMIT.");
}
// Binds a WITH clause. Every projected expression must carry an alias and the resulting column
// names must be unique. After the projection body is bound, the scope is replaced by exactly
// the projected (alias, expression) pairs, and the optional WHERE predicate is bound against
// that new scope.
BoundWithClause Binder::bindWithClause(const WithClause& withClause) {
    auto parsedBody = withClause.getProjectionBody();
    auto [boundExprs, exprAliases] = bindProjectionList(*parsedBody);
    // WITH requires an explicit name for every projected expression.
    for (auto& exprAlias : exprAliases) {
        if (exprAlias.empty()) {
            throw BinderException("Expression in WITH must be aliased (use AS).");
        }
    }
    validateColumnNamesAreUnique(getColumnNames(boundExprs, exprAliases));
    auto boundBody = bindProjectionBody(*parsedBody, boundExprs, exprAliases);
    validateOrderByFollowedBySkipOrLimitInWithClause(boundBody);
    // Only the projected expressions are visible after WITH.
    scope.clear();
    for (auto idx = 0u; idx < boundExprs.size(); ++idx) {
        addToScope(exprAliases[idx], boundExprs[idx]);
    }
    auto boundClause = BoundWithClause(std::move(boundBody));
    if (withClause.hasWhereExpression()) {
        boundClause.setWhereExpression(bindWhereExpression(*withClause.getWhereExpression()));
    }
    return boundClause;
}
// Binds a RETURN clause: binds the projection list and body, and registers every projected
// expression as a result column under its column name.
BoundReturnClause Binder::bindReturnClause(const ReturnClause& returnClause) {
    auto parsedBody = returnClause.getProjectionBody();
    auto [boundExprs, exprAliases] = bindProjectionList(*parsedBody);
    auto resultColumnNames = getColumnNames(boundExprs, exprAliases);
    auto boundBody = bindProjectionBody(*parsedBody, boundExprs, exprAliases);
    auto result = BoundStatementResult();
    KU_ASSERT(resultColumnNames.size() == boundExprs.size());
    for (auto idx = 0u; idx < resultColumnNames.size(); ++idx) {
        result.addColumn(resultColumnNames[idx], boundExprs[idx]);
    }
    return BoundReturnClause(std::move(boundBody), std::move(result));
}
// Collects the outermost aggregate function expressions inside `expression`.
// An expression whose alias is already in `scope` contributes nothing, and the search does not
// descend below an aggregate once one is found.
static expression_vector getAggregateExpressions(const std::shared_ptr<Expression>& expression,
    const BinderScope& scope) {
    expression_vector aggregates;
    const bool isAlreadyInScope =
        expression->hasAlias() && scope.contains(expression->getAlias());
    if (isAlreadyInScope) {
        return aggregates;
    }
    if (expression->expressionType == ExpressionType::AGGREGATE_FUNCTION) {
        aggregates.push_back(expression);
        return aggregates;
    }
    for (auto& child : ExpressionChildrenCollector::collectChildren(*expression)) {
        auto nested = getAggregateExpressions(child, scope);
        aggregates.insert(aggregates.end(), nested.begin(), nested.end());
    }
    return aggregates;
}
std::pair<expression_vector, std::vector<std::string>> Binder::bindProjectionList(
const ProjectionBody& projectionBody) {
expression_vector projectionExprs;
std::vector<std::string> aliases;
for (auto& parsedExpr : projectionBody.getProjectionExpressions()) {
if (parsedExpr->getExpressionType() == ExpressionType::STAR) {
// Rewrite star expression as all expression in scope.
if (scope.empty()) {
throw BinderException(
"RETURN or WITH * is not allowed when there are no variables in scope.");
}
for (auto& expr : scope.getExpressions()) {
projectionExprs.push_back(expr);
aliases.push_back(expr->getAlias());
}
} else if (parsedExpr->getExpressionType() == ExpressionType::PROPERTY) {
auto& propExpr = parsedExpr->constCast<ParsedPropertyExpression>();
if (propExpr.isStar()) {
// Rewrite property star expression
for (auto& expr : expressionBinder.bindPropertyStarExpression(*parsedExpr)) {
projectionExprs.push_back(expr);
aliases.push_back("");
}
} else {
auto expr = expressionBinder.bindExpression(*parsedExpr);
projectionExprs.push_back(expr);
aliases.push_back(parsedExpr->getAlias());
}
} else {
auto expr = expressionBinder.bindExpression(*parsedExpr);
projectionExprs.push_back(expr);
aliases.push_back(parsedExpr->hasAlias() ? parsedExpr->getAlias() : expr->getAlias());
}
}
return {projectionExprs, aliases};
}
// Expression visitor that records aggregate function expressions. Its child traversal stops
// at aggregate functions instead of descending into their children.
class NestedAggCollector final : public ExpressionVisitor {
public:
    // Aggregate expressions handed to visitAggFunctionExpr.
    expression_vector exprs;

protected:
    void visitAggFunctionExpr(std::shared_ptr<Expression> expr) override {
        exprs.push_back(expr);
    }

    void visitChildren(const Expression& expr) override {
        switch (expr.expressionType) {
        case ExpressionType::AGGREGATE_FUNCTION:
            // Children of an aggregate are deliberately not traversed: nested aggregation is
            // validated recursively, e.g. in
            //   WITH SUM(1) AS x WITH SUM(x) AS y RETURN SUM(y)
            // validating SUM(y) only needs to look at y, not at x.
            break;
        case ExpressionType::CASE_ELSE:
            visitCaseExprChildren(expr);
            break;
        case ExpressionType::LAMBDA:
            visit(expr.constCast<LambdaExpression>().getFunctionExpr());
            break;
        default:
            for (auto& child : expr.getChildren()) {
                visit(child);
            }
            break;
        }
    }
};
static void validateNestedAggregate(const Expression& expr, const BinderScope& scope) {
KU_ASSERT(expr.expressionType == ExpressionType::AGGREGATE_FUNCTION);
if (expr.getNumChildren() == 0) { // Skip COUNT(*)
return;
}
auto collector = NestedAggCollector();
collector.visit(expr.getChild(0));
for (auto& childAgg : collector.exprs) {
if (!scope.contains(childAgg->getAlias())) {
throw BinderException(
stringFormat("Expression {} contains nested aggregation.", expr.toString()));
}
}
}
// Builds the BoundProjectionBody of a RETURN/WITH clause from the already bound projection
// expressions and their aliases (both vectors must have the same size).
// Steps, in order:
//   1. split the projection expressions into aggregate-bearing ones and group-by keys, and
//      assign each expression its alias;
//   2. if there are aggregates, validate them for nested aggregation and register the
//      group-by / aggregate expressions;
//   3. bind ORDER BY, SKIP and LIMIT when present.
// Side effects visible here: sets the alias of every projection expression, and mutates `scope`
// while binding ORDER BY.
BoundProjectionBody Binder::bindProjectionBody(const parser::ProjectionBody& projectionBody,
const expression_vector& projectionExprs, const std::vector<std::string>& aliases) {
expression_vector groupByExprs;
expression_vector aggregateExprs;
KU_ASSERT(projectionExprs.size() == aliases.size());
for (auto i = 0u; i < projectionExprs.size(); ++i) {
auto expr = projectionExprs[i];
// Aggregates are collected BEFORE the alias is assigned below: getAggregateExpressions
// inspects the expression's current alias against the scope.
auto aggExprs = getAggregateExpressions(expr, scope);
if (!aggExprs.empty()) {
for (auto& agg : aggExprs) {
aggregateExprs.push_back(agg);
}
} else {
// An expression without any aggregate acts as a group-by key.
groupByExprs.push_back(expr);
}
expr->setAlias(aliases[i]);
}
auto boundProjectionBody = BoundProjectionBody(projectionBody.getIsDistinct());
boundProjectionBody.setProjectionExpressions(projectionExprs);
// Group-by and aggregate expressions are only registered when at least one aggregate exists.
if (!aggregateExprs.empty()) {
for (auto& expr : aggregateExprs) {
validateNestedAggregate(*expr, scope);
}
if (!groupByExprs.empty()) {
// TODO(Xiyang): we can remove augment group by. But make sure we test sufficient
// including edge case and bug before release.
// Augmentation: for every node/rel pattern used as a group-by key, its internal ID is
// appended as an additional key.
expression_vector augmentedGroupByExpressions = groupByExprs;
for (auto& expression : groupByExprs) {
if (ExpressionUtil::isNodePattern(*expression)) {
auto& node = expression->constCast<NodeExpression>();
augmentedGroupByExpressions.push_back(node.getInternalID());
} else if (ExpressionUtil::isRelPattern(*expression)) {
auto& rel = expression->constCast<RelExpression>();
augmentedGroupByExpressions.push_back(rel.getInternalID());
}
}
boundProjectionBody.setGroupByExpressions(std::move(augmentedGroupByExpressions));
}
boundProjectionBody.setAggregateExpressions(std::move(aggregateExprs));
}
// Bind order by
if (projectionBody.hasOrderByExpressions()) {
// Cypher rule of ORDER BY expression scope: if projection contains aggregation, only
// expressions in projection are available. Otherwise, expressions before projection are
// also available
expression_vector orderByExprs;
if (boundProjectionBody.hasAggregateExpressions() || boundProjectionBody.isDistinct()) {
// Aggregation or DISTINCT: the scope is replaced by the projection expressions, each
// registered under its parsed alias or, without alias, its textual form.
scope.clear();
KU_ASSERT(projectionBody.getProjectionExpressions().size() == projectionExprs.size());
std::vector<std::string> tmpAliases;
for (auto& expr : projectionBody.getProjectionExpressions()) {
tmpAliases.push_back(expr->hasAlias() ? expr->getAlias() : expr->toString());
}
addToScope(tmpAliases, projectionExprs);
// The flag is raised only for the duration of the ORDER BY binding.
// NOTE(review): it is not reset if bindOrderByExpressions throws -- confirm this is
// acceptable for the lifetime of this binder.
expressionBinder.config.bindOrderByAfterAggregate = true;
orderByExprs = bindOrderByExpressions(projectionBody.getOrderByExpressions());
expressionBinder.config.bindOrderByAfterAggregate = false;
} else {
// No aggregation and no DISTINCT: the projection aliases are added on top of the
// existing scope.
addToScope(aliases, projectionExprs);
orderByExprs = bindOrderByExpressions(projectionBody.getOrderByExpressions());
}
boundProjectionBody.setOrderByExpressions(std::move(orderByExprs),
projectionBody.getSortOrders());
}
// Bind skip
if (projectionBody.hasSkipExpression()) {
boundProjectionBody.setSkipNumber(
bindSkipLimitExpression(*projectionBody.getSkipExpression()));
}
// Bind limit
if (projectionBody.hasLimitExpression()) {
boundProjectionBody.setLimitNumber(
bindSkipLimitExpression(*projectionBody.getLimitExpression()));
}
return boundProjectionBody;
}
// Returns false for the logical types that cannot be used as an ORDER BY key (graph, nested
// and pointer types) and true for every other type.
static bool isOrderByKeyTypeSupported(const LogicalType& dataType) {
    bool isSupported = true;
    switch (dataType.getLogicalTypeID()) {
    case LogicalTypeID::NODE:
    case LogicalTypeID::REL:
    case LogicalTypeID::RECURSIVE_REL:
    case LogicalTypeID::INTERNAL_ID:
    case LogicalTypeID::LIST:
    case LogicalTypeID::ARRAY:
    case LogicalTypeID::STRUCT:
    case LogicalTypeID::MAP:
    case LogicalTypeID::UNION:
    case LogicalTypeID::POINTER:
        isSupported = false;
        break;
    default:
        break;
    }
    return isSupported;
}
// Binds the ORDER BY keys in order.
// Throws BinderException for the first key whose data type cannot be ordered.
expression_vector Binder::bindOrderByExpressions(
    const std::vector<std::unique_ptr<ParsedExpression>>& parsedExprs) {
    expression_vector orderKeys;
    for (auto& parsedKey : parsedExprs) {
        auto boundKey = expressionBinder.bindExpression(*parsedKey);
        if (isOrderByKeyTypeSupported(boundKey->dataType)) {
            orderKeys.push_back(std::move(boundKey));
            continue;
        }
        throw BinderException(stringFormat("Cannot order by {}. Order by {} is not supported.",
            boundKey->toString(), boundKey->dataType.toString()));
    }
    return orderKeys;
}
// Binds the operand of SKIP or LIMIT.
// Throws BinderException unless the bound expression is a literal or a parameter.
std::shared_ptr<Expression> Binder::bindSkipLimitExpression(const ParsedExpression& expression) {
    auto bound = expressionBinder.bindExpression(expression);
    const auto boundType = bound->expressionType;
    const bool isLiteralOrParameter =
        boundType == ExpressionType::LITERAL || boundType == ExpressionType::PARAMETER;
    if (!isLiteralOrParameter) {
        throw BinderException(
            "The number of rows to skip/limit must be a parameter/literal expression.");
    }
    return bound;
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,106 @@
#include "binder/binder.h"
#include "binder/expression/expression_util.h"
#include "binder/query/return_with_clause/bound_return_clause.h"
#include "binder/query/return_with_clause/bound_with_clause.h"
#include "common/exception/binder.h"
#include "parser/query/regular_query.h"
using namespace lbug::common;
using namespace lbug::parser;
namespace lbug {
namespace binder {
// Checks that every single query of a UNION produces the same number of columns as the first
// one, and validates each column against the data type of the first query's column at the same
// position. Nothing is checked when there are fewer than two single queries.
void validateUnionColumnsOfTheSameType(
    const std::vector<NormalizedSingleQuery>& normalizedSingleQueries) {
    if (normalizedSingleQueries.size() < 2) {
        return;
    }
    auto referenceColumns = normalizedSingleQueries[0].getStatementResult()->getColumns();
    for (auto queryIdx = 1u; queryIdx < normalizedSingleQueries.size(); queryIdx++) {
        auto candidateColumns =
            normalizedSingleQueries[queryIdx].getStatementResult()->getColumns();
        if (referenceColumns.size() != candidateColumns.size()) {
            throw BinderException("The number of columns to union/union all must be the same.");
        }
        // Column by column, the candidate is validated against the first query's data type.
        for (auto colIdx = 0u; colIdx < referenceColumns.size(); colIdx++) {
            ExpressionUtil::validateDataType(*candidateColumns[colIdx],
                referenceColumns[colIdx]->getDataType());
        }
    }
}
// Throws BinderException when a regular query mixes UNION and UNION ALL, i.e. when some but
// not all of the operators between its single queries are UNION ALL.
void validateIsAllUnionOrUnionAll(const BoundRegularQuery& regularQuery) {
    // One union operator sits between each pair of neighbouring single queries.
    const auto numUnionOperators = regularQuery.getNumSingleQueries() - 1;
    auto numUnionAll = 0u;
    for (auto idx = 0u; idx < numUnionOperators; idx++) {
        if (regularQuery.getIsUnionAll(idx)) {
            ++numUnionAll;
        }
    }
    const bool isMixed = numUnionAll != 0 && numUnionAll < numUnionOperators;
    if (isMixed) {
        throw BinderException("Union and union all can not be used together.");
    }
}
std::unique_ptr<BoundRegularQuery> Binder::bindQuery(const Statement& statement) {
auto& regularQuery = statement.constCast<RegularQuery>();
std::vector<NormalizedSingleQuery> normalizedSingleQueries;
for (auto i = 0u; i < regularQuery.getNumSingleQueries(); i++) {
// Don't clear scope within bindSingleQuery() yet because it is also used for subquery
// binding.
scope.clear();
normalizedSingleQueries.push_back(bindSingleQuery(*regularQuery.getSingleQuery(i)));
}
validateUnionColumnsOfTheSameType(normalizedSingleQueries);
KU_ASSERT(!normalizedSingleQueries.empty());
auto boundRegularQuery = std::make_unique<BoundRegularQuery>(regularQuery.getIsUnionAll(),
normalizedSingleQueries[0].getStatementResult()->copy());
for (auto& normalizedSingleQuery : normalizedSingleQueries) {
boundRegularQuery->addSingleQuery(std::move(normalizedSingleQuery));
}
validateIsAllUnionOrUnionAll(*boundRegularQuery);
return boundRegularQuery;
}
// Binds one single query: first every WITH-terminated query part, then the trailing reading
// and updating clauses plus the optional RETURN clause as the last query part. Without a
// RETURN clause the statement result is the empty result.
NormalizedSingleQuery Binder::bindSingleQuery(const SingleQuery& singleQuery) {
    auto boundQuery = NormalizedSingleQuery();
    for (auto partIdx = 0u; partIdx < singleQuery.getNumQueryParts(); ++partIdx) {
        boundQuery.appendQueryPart(bindQueryPart(*singleQuery.getQueryPart(partIdx)));
    }
    // Everything after the last WITH forms the final query part.
    auto finalPart = NormalizedQueryPart();
    for (auto clauseIdx = 0u; clauseIdx < singleQuery.getNumReadingClauses(); clauseIdx++) {
        finalPart.addReadingClause(bindReadingClause(*singleQuery.getReadingClause(clauseIdx)));
    }
    for (auto clauseIdx = 0u; clauseIdx < singleQuery.getNumUpdatingClauses(); ++clauseIdx) {
        finalPart.addUpdatingClause(
            bindUpdatingClause(*singleQuery.getUpdatingClause(clauseIdx)));
    }
    auto result = BoundStatementResult();
    if (!singleQuery.hasReturnClause()) {
        result = BoundStatementResult::createEmptyResult();
    } else {
        auto boundReturn = bindReturnClause(*singleQuery.getReturnClause());
        finalPart.setProjectionBody(boundReturn.getProjectionBody()->copy());
        result = boundReturn.getStatementResult()->copy();
    }
    boundQuery.appendQueryPart(std::move(finalPart));
    boundQuery.setStatementResult(std::move(result));
    return boundQuery;
}
// Binds a query part: its reading clauses, its updating clauses, and the WITH clause that
// terminates it (projection body plus optional WHERE predicate).
NormalizedQueryPart Binder::bindQueryPart(const QueryPart& queryPart) {
    auto boundPart = NormalizedQueryPart();
    for (auto clauseIdx = 0u; clauseIdx < queryPart.getNumReadingClauses(); clauseIdx++) {
        boundPart.addReadingClause(bindReadingClause(*queryPart.getReadingClause(clauseIdx)));
    }
    for (auto clauseIdx = 0u; clauseIdx < queryPart.getNumUpdatingClauses(); ++clauseIdx) {
        boundPart.addUpdatingClause(bindUpdatingClause(*queryPart.getUpdatingClause(clauseIdx)));
    }
    auto boundWith = bindWithClause(*queryPart.getWithClause());
    boundPart.setProjectionBody(boundWith.getProjectionBody()->copy());
    if (boundWith.hasWhereExpression()) {
        boundPart.setProjectionBodyPredicate(boundWith.getWhereExpression());
    }
    return boundPart;
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,30 @@
#include "binder/binder.h"
#include "parser/query/reading_clause/reading_clause.h"
using namespace lbug::common;
using namespace lbug::parser;
namespace lbug {
namespace binder {
// Dispatches a reading clause to the binder that handles its clause type.
std::unique_ptr<BoundReadingClause> Binder::bindReadingClause(const ReadingClause& readingClause) {
    switch (readingClause.getClauseType()) {
    case ClauseType::MATCH:
        return bindMatchClause(readingClause);
    case ClauseType::UNWIND:
        return bindUnwindClause(readingClause);
    case ClauseType::IN_QUERY_CALL:
        return bindInQueryCall(readingClause);
    case ClauseType::LOAD_FROM:
        return bindLoadFrom(readingClause);
    default:
        KU_UNREACHABLE;
    }
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,44 @@
#include "binder/binder.h"
#include "binder/bound_standalone_call.h"
#include "binder/expression/expression_util.h"
#include "binder/expression_visitor.h"
#include "common/cast.h"
#include "common/exception/binder.h"
#include "main/client_context.h"
#include "main/db_config.h"
#include "parser/standalone_call.h"
using namespace lbug::common;
namespace lbug {
namespace binder {
// Binds a standalone CALL that sets an option to a literal value.
// The option is looked up among the DBConfig options first and the client context's extension
// options second. The value is implicitly cast to the option's parameter type, except that a
// floating point value is rejected for an integral option.
// Throws BinderException for an unknown option name or a rejected value.
std::unique_ptr<BoundStatement> Binder::bindStandaloneCall(const parser::Statement& statement) {
    auto& call = ku_dynamic_cast<const parser::StandaloneCall&>(statement);
    const main::Option* option = main::DBConfig::getOptionByName(call.getOptionName());
    if (option == nullptr) {
        // Not a DBConfig option: fall back to the extension options.
        option = clientContext->getExtensionOption(call.getOptionName());
        if (option == nullptr) {
            throw BinderException{"Invalid option name: " + call.getOptionName() + "."};
        }
    }
    auto value = expressionBinder.bindExpression(*call.getOptionValue());
    ExpressionUtil::validateExpressionType(*value, ExpressionType::LITERAL);
    const bool valueIsFloatingPoint =
        LogicalTypeUtils::isFloatingPoint(value->dataType.getLogicalTypeID());
    if (valueIsFloatingPoint &&
        LogicalTypeUtils::isIntegral(LogicalType(option->parameterType))) {
        throw BinderException{stringFormat(
            "Expression {} has data type {} but expected {}. Implicit cast is not supported.",
            value->toString(), LogicalTypeUtils::toString(value->dataType.getLogicalTypeID()),
            LogicalTypeUtils::toString(option->parameterType))};
    }
    value = expressionBinder.implicitCastIfNecessary(value, LogicalType(option->parameterType));
    if (ConstantExpressionVisitor::needFold(*value)) {
        value = expressionBinder.foldExpression(value);
    }
    return std::make_unique<BoundStandaloneCall>(option, std::move(value));
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,35 @@
#include "binder/binder.h"
#include "binder/bound_standalone_call_function.h"
#include "catalog/catalog.h"
#include "common/exception/binder.h"
#include "main/client_context.h"
#include "parser/expression/parsed_function_expression.h"
#include "parser/standalone_call_function.h"
#include "transaction/transaction.h"
using namespace lbug::common;
namespace lbug {
namespace binder {
// Binds a CALL of a table function that is used without a RETURN statement.
// Throws BinderException unless the function is registered as a standalone table function.
std::unique_ptr<BoundStatement> Binder::bindStandaloneCallFunction(
    const parser::Statement& statement) {
    auto& call = statement.constCast<parser::StandaloneCallFunction>();
    auto& parsedFunc =
        call.getFunctionExpression()->constCast<parser::ParsedFunctionExpression>();
    auto functionName = parsedFunc.getFunctionName();
    auto catalog = catalog::Catalog::Get(*clientContext);
    auto txn = transaction::Transaction::Get(*clientContext);
    auto functionEntry =
        catalog->getFunctionEntry(txn, functionName, clientContext->useInternalCatalogEntry());
    KU_ASSERT(functionEntry);
    const bool isStandaloneTableFunc =
        functionEntry->getType() == catalog::CatalogEntryType::STANDALONE_TABLE_FUNCTION_ENTRY;
    if (!isStandaloneTableFunc) {
        throw BinderException(
            "Only standalone table functions can be called without return statement.");
    }
    auto scanInfo = bindTableFunc(functionName, parsedFunc, {} /* yieldVariables */);
    return std::make_unique<BoundStandaloneCallFunction>(std::move(scanInfo));
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,77 @@
#include "binder/binder.h"
#include "binder/bound_table_scan_info.h"
#include "binder/expression/expression_util.h"
#include "binder/expression/literal_expression.h"
#include "catalog/catalog.h"
#include "function/built_in_function_utils.h"
#include "main/client_context.h"
#include "transaction/transaction.h"
using namespace lbug::common;
using namespace lbug::function;
namespace lbug {
namespace binder {
// Binds a call to the table function `tableFuncName`.
// The children of `expr` are the arguments: children without alias are positional parameters,
// children with alias are optional (named) parameters. The overload is matched on the
// positional parameter types, literal positional parameters are cast and folded, and the
// function's bindFunc is invoked with the collected input.
BoundTableScanInfo Binder::bindTableFunc(const std::string& tableFuncName,
    const parser::ParsedExpression& expr, std::vector<parser::YieldVariable> yieldVariables) {
    auto catalog = catalog::Catalog::Get(*clientContext);
    auto txn = transaction::Transaction::Get(*clientContext);
    auto functionEntry = catalog->getFunctionEntry(txn, tableFuncName,
        clientContext->useInternalCatalogEntry());
    expression_vector positionalParams;
    std::vector<LogicalType> positionalParamTypes;
    optional_params_t optionalParams;
    expression_vector optionalParamsLegacy;
    for (auto argIdx = 0u; argIdx < expr.getNumChildren(); argIdx++) {
        auto& parsedArg = *expr.getChild(argIdx);
        auto boundArg = expressionBinder.bindExpression(parsedArg);
        ExpressionUtil::validateExpressionType(*boundArg,
            {ExpressionType::LITERAL, ExpressionType::PARAMETER, ExpressionType::PATTERN});
        if (parsedArg.hasAlias()) {
            // Named argument: literal values also go into the keyed optional parameters.
            if (boundArg->expressionType == ExpressionType::LITERAL) {
                auto literalArg = boundArg->constPtrCast<LiteralExpression>();
                optionalParams.emplace(parsedArg.getAlias(), literalArg->getValue());
            }
            boundArg->setAlias(parsedArg.getAlias());
            optionalParamsLegacy.push_back(boundArg);
            continue;
        }
        positionalParams.push_back(boundArg);
        positionalParamTypes.push_back(boundArg->getDataType().copy());
    }
    auto matchedFunc = BuiltInFunctionsUtils::matchFunction(tableFuncName, positionalParamTypes,
        functionEntry->ptrCast<catalog::FunctionCatalogEntry>());
    auto tableFunc = matchedFunc->constPtrCast<TableFunction>();
    std::vector<LogicalType> inputTypes;
    if (!tableFunc->inferInputTypes) {
        // For functions which don't have nested type parameters, the types declared in the
        // function signature are used directly.
        for (auto declaredTypeID : tableFunc->parameterTypeIDs) {
            inputTypes.push_back(LogicalType(declaredTypeID));
        }
    } else {
        // For functions which take in nested data types, the input parameters are needed to
        // detect the input types. (E.g. query_hnsw_index takes in an ARRAY which needs the
        // user input parameters to decide the array dimension).
        inputTypes = tableFunc->inferInputTypes(positionalParams);
    }
    for (auto paramIdx = 0u; paramIdx < positionalParams.size(); ++paramIdx) {
        const auto declaredTypeID = tableFunc->parameterTypeIDs[paramIdx];
        auto& positional = positionalParams[paramIdx];
        // Only literals whose declared type is not ANY are cast and folded.
        if (positional->expressionType == ExpressionType::LITERAL &&
            declaredTypeID != LogicalTypeID::ANY) {
            positional = expressionBinder.foldExpression(
                expressionBinder.implicitCastIfNecessary(positional, inputTypes[paramIdx]));
        }
    }
    auto bindInput = TableFuncBindInput();
    bindInput.params = std::move(positionalParams);
    bindInput.optionalParams = std::move(optionalParams);
    bindInput.optionalParamsLegacy = std::move(optionalParamsLegacy);
    bindInput.binder = this;
    bindInput.yieldVariables = std::move(yieldVariables);
    return BoundTableScanInfo{*tableFunc, tableFunc->bindFunc(clientContext, &bindInput)};
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,16 @@
#include "binder/binder.h"
#include "binder/bound_transaction_statement.h"
#include "parser/transaction_statement.h"
using namespace lbug::parser;
namespace lbug {
namespace binder {
std::unique_ptr<BoundStatement> Binder::bindTransaction(const Statement& statement) {
auto& transactionStatement = statement.constCast<TransactionStatement>();
return std::make_unique<BoundTransactionStatement>(transactionStatement.getTransactionAction());
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,387 @@
#include "binder/binder.h"
#include "binder/expression/expression_util.h"
#include "binder/expression/property_expression.h"
#include "binder/query/query_graph_label_analyzer.h"
#include "binder/query/updating_clause/bound_delete_clause.h"
#include "binder/query/updating_clause/bound_insert_clause.h"
#include "binder/query/updating_clause/bound_merge_clause.h"
#include "binder/query/updating_clause/bound_set_clause.h"
#include "catalog/catalog.h"
#include "catalog/catalog_entry/index_catalog_entry.h"
#include "catalog/catalog_entry/node_table_catalog_entry.h"
#include "catalog/catalog_entry/rel_group_catalog_entry.h"
#include "common/assert.h"
#include "common/exception/binder.h"
#include "common/string_format.h"
#include "parser/query/updating_clause/delete_clause.h"
#include "parser/query/updating_clause/insert_clause.h"
#include "parser/query/updating_clause/merge_clause.h"
#include "parser/query/updating_clause/set_clause.h"
#include "transaction/transaction.h"
using namespace lbug::common;
using namespace lbug::parser;
using namespace lbug::catalog;
namespace lbug {
namespace binder {
// Dispatches an updating clause to the binder that handles its clause type.
std::unique_ptr<BoundUpdatingClause> Binder::bindUpdatingClause(
    const UpdatingClause& updatingClause) {
    switch (updatingClause.getClauseType()) {
    case ClauseType::INSERT:
        return bindInsertClause(updatingClause);
    case ClauseType::MERGE:
        return bindMergeClause(updatingClause);
    case ClauseType::SET:
        return bindSetClause(updatingClause);
    case ClauseType::DELETE_:
        return bindDeleteClause(updatingClause);
    default:
        KU_UNREACHABLE;
    }
}
// Collects the names (toString()) of every expression in `scope` that is either a node/rel
// pattern, or a VARIABLE expression for which the scope reports a node replacement.
static std::unordered_set<std::string> populatePatternsScope(const BinderScope& scope) {
    std::unordered_set<std::string> patternNames;
    for (auto& expr : scope.getExpressions()) {
        const bool isPattern =
            ExpressionUtil::isNodePattern(*expr) || ExpressionUtil::isRelPattern(*expr);
        // Only consulted for non-pattern expressions, mirroring an if / else-if chain.
        const bool isReplacedVariable = !isPattern &&
                                        expr->expressionType == ExpressionType::VARIABLE &&
                                        scope.hasNodeReplacement(expr->toString());
        if (isPattern || isReplacedVariable) {
            patternNames.insert(expr->toString());
        }
    }
    return patternNames;
}
std::unique_ptr<BoundUpdatingClause> Binder::bindInsertClause(
const UpdatingClause& updatingClause) {
auto& insertClause = updatingClause.constCast<InsertClause>();
auto patternsScope = populatePatternsScope(scope);
// bindGraphPattern will update scope.
auto boundGraphPattern = bindGraphPattern(insertClause.getPatternElementsRef());
auto insertInfos = bindInsertInfos(boundGraphPattern.queryGraphCollection, patternsScope);
return std::make_unique<BoundInsertClause>(std::move(insertInfos));
}
// Flattens the property data expressions (the right-hand sides of the key/value pairs) of
// every pattern in every query graph of `collection` into a single vector.
static expression_vector getColumnDataExprs(QueryGraphCollection& collection) {
    expression_vector dataExprs;
    for (auto graphIdx = 0u; graphIdx < collection.getNumQueryGraphs(); ++graphIdx) {
        auto graph = collection.getQueryGraph(graphIdx);
        for (auto& pattern : graph->getAllPatterns()) {
            for (auto& keyAndData : pattern->getPropertyDataExprRef()) {
                dataExprs.push_back(keyAndData.second);
            }
        }
    }
    return dataExprs;
}
std::unique_ptr<BoundUpdatingClause> Binder::bindMergeClause(const UpdatingClause& updatingClause) {
auto& mergeClause = updatingClause.constCast<MergeClause>();
auto patternsScope = populatePatternsScope(scope);
// bindGraphPattern will update scope.
auto boundGraphPattern = bindGraphPattern(mergeClause.getPatternElementsRef());
auto columnDataExprs = getColumnDataExprs(boundGraphPattern.queryGraphCollection);
// Rewrite key value pairs in MATCH clause as predicate
rewriteMatchPattern(boundGraphPattern);
auto existenceMark =
expressionBinder.createVariableExpression(LogicalType::BOOL(), std::string("__existence"));
auto distinctMark =
expressionBinder.createVariableExpression(LogicalType::BOOL(), std::string("__distinct"));
auto createInfos = bindInsertInfos(boundGraphPattern.queryGraphCollection, patternsScope);
auto boundMergeClause =
std::make_unique<BoundMergeClause>(columnDataExprs, std::move(existenceMark),
std::move(distinctMark), std::move(boundGraphPattern.queryGraphCollection),
std::move(boundGraphPattern.where), std::move(createInfos));
if (mergeClause.hasOnMatchSetItems()) {
for (auto& [lhs, rhs] : mergeClause.getOnMatchSetItemsRef()) {
auto setPropertyInfo = bindSetPropertyInfo(lhs.get(), rhs.get());
boundMergeClause->addOnMatchSetPropertyInfo(std::move(setPropertyInfo));
}
}
if (mergeClause.hasOnCreateSetItems()) {
for (auto& [lhs, rhs] : mergeClause.getOnCreateSetItemsRef()) {
auto setPropertyInfo = bindSetPropertyInfo(lhs.get(), rhs.get());
boundMergeClause->addOnCreateSetPropertyInfo(std::move(setPropertyInfo));
}
}
return boundMergeClause;
}
// Builds the insert infos for every node and relationship of `queryGraphCollection` that has
// to be created. Anonymous patterns are always created; named patterns are created only the
// first time their name is seen, where names in `patternsInScope_` count as already seen.
// Throws BinderException if nothing is left to create.
std::vector<BoundInsertInfo> Binder::bindInsertInfos(QueryGraphCollection& queryGraphCollection,
    const std::unordered_set<std::string>& patternsInScope_) {
    // Working copy: names are added as patterns get scheduled for creation.
    std::unordered_set<std::string> seenNames(patternsInScope_);
    std::vector<BoundInsertInfo> insertInfos;
    auto labelAnalyzer = QueryGraphLabelAnalyzer(*clientContext, true /* throwOnViolate */);
    // True for anonymous patterns and for names recorded here for the first time.
    const auto needsCreation = [&seenNames](const std::string& variableName) {
        return variableName.empty() || seenNames.insert(variableName).second;
    };
    for (auto graphIdx = 0u; graphIdx < queryGraphCollection.getNumQueryGraphs(); ++graphIdx) {
        auto graph = queryGraphCollection.getQueryGraphUnsafe(graphIdx);
        // Ensure query graph does not violate declared schema.
        labelAnalyzer.pruneLabel(*graph);
        for (auto nodeIdx = 0u; nodeIdx < graph->getNumQueryNodes(); ++nodeIdx) {
            auto queryNode = graph->getQueryNode(nodeIdx);
            if (needsCreation(queryNode->getVariableName())) {
                bindInsertNode(queryNode, insertInfos);
            }
        }
        for (auto relIdx = 0u; relIdx < graph->getNumQueryRels(); ++relIdx) {
            auto queryRel = graph->getQueryRel(relIdx);
            if (needsCreation(queryRel->getVariableName())) {
                bindInsertRel(queryRel, insertInfos);
            }
        }
    }
    if (insertInfos.empty()) {
        throw BinderException("Cannot resolve any node or relationship to create.");
    }
    return insertInfos;
}
// Throws BinderException when `node` supplies no data expression for the primary key of
// `nodeTableEntry` and the expression stored at the primary key's position in `defaultExprs`
// is a NULL literal.
static void validatePrimaryKeyExistence(const NodeTableCatalogEntry* nodeTableEntry,
    const NodeExpression& node, const expression_vector& defaultExprs) {
    auto pkName = nodeTableEntry->getPrimaryKeyName();
    // Fetched unconditionally: at() is bounds-checked and is evaluated for every node.
    auto pkExpr = defaultExprs.at(nodeTableEntry->getPrimaryKeyID());
    if (node.hasPropertyDataExpr(pkName)) {
        return;
    }
    if (!ExpressionUtil::isNullLiteral(*pkExpr)) {
        return;
    }
    throw BinderException(stringFormat("Create node {} expects primary key {} as input.",
        node.toString(), pkName));
}
void Binder::bindInsertNode(std::shared_ptr<NodeExpression> node,
std::vector<BoundInsertInfo>& infos) {
if (node->isMultiLabeled()) {
throw BinderException(
"Create node " + node->toString() + " with multiple node labels is not supported.");
}
if (node->isEmpty()) {
throw BinderException(
"Create node " + node->toString() + " with empty node labels is not supported.");
}
KU_ASSERT(node->getNumEntries() == 1);
auto entry = node->getEntry(0);
KU_ASSERT(entry->getTableType() == TableType::NODE);
auto insertInfo = BoundInsertInfo(TableType::NODE, node);
for (auto& property : node->getPropertyExpressions()) {
if (property->hasProperty(entry->getTableID())) {
insertInfo.columnExprs.push_back(property);
}
}
insertInfo.columnDataExprs =
bindInsertColumnDataExprs(node->getPropertyDataExprRef(), entry->getProperties());
auto nodeEntry = entry->ptrCast<NodeTableCatalogEntry>();
validatePrimaryKeyExistence(nodeEntry, *node, insertInfo.columnDataExprs);
// Check extension secondary index loaded
auto catalog = Catalog::Get(*clientContext);
auto transaction = transaction::Transaction::Get(*clientContext);
for (auto indexEntry : catalog->getIndexEntries(transaction, nodeEntry->getTableID())) {
if (!indexEntry->isLoaded()) {
throw BinderException(stringFormat(
"Trying to insert into an index on table {} but its extension is not loaded.",
nodeEntry->getName()));
}
}
infos.push_back(std::move(insertInfo));
}
// Among the rel group entries of `rel`, returns the single one that has rel entry info for
// the (srcEntry, dstEntry) table pair. Throws BinderException when none or several match.
static TableCatalogEntry* tryPruneMultiLabeled(const RelExpression& rel,
    const TableCatalogEntry& srcEntry, const TableCatalogEntry& dstEntry) {
    std::vector<TableCatalogEntry*> matches;
    for (auto& relEntry : rel.getEntries()) {
        KU_ASSERT(relEntry->getType() == CatalogEntryType::REL_GROUP_ENTRY);
        auto& relGroup = relEntry->constCast<RelGroupCatalogEntry>();
        if (!relGroup.hasRelEntryInfo(srcEntry.getTableID(), dstEntry.getTableID())) {
            continue;
        }
        matches.push_back(relEntry);
    }
    if (matches.empty()) {
        throw BinderException(
            stringFormat("Cannot find a valid label in {} that connects {} and {}.", rel.toString(),
                srcEntry.getName(), dstEntry.getName()));
    }
    if (matches.size() > 1) {
        throw BinderException(stringFormat(
            "Create rel {} with multiple rel labels is not supported.", rel.toString()));
    }
    return matches.front();
}
void Binder::bindInsertRel(std::shared_ptr<RelExpression> rel,
std::vector<BoundInsertInfo>& infos) {
if (rel->isBoundByMultiLabeledNode()) {
throw BinderException(stringFormat(
"Create rel {} bound by multiple node labels is not supported.", rel->toString()));
}
if (rel->getDirectionType() == RelDirectionType::BOTH) {
throw BinderException(stringFormat("Create undirected relationship is not supported. "
"Try create 2 directed relationships instead."));
}
if (ExpressionUtil::isRecursiveRelPattern(*rel)) {
throw BinderException(stringFormat("Cannot create recursive rel {}.", rel->toString()));
}
TableCatalogEntry* entry = nullptr;
if (!rel->isMultiLabeled()) {
KU_ASSERT(rel->getNumEntries() == 1);
entry = rel->getEntry(0);
} else {
auto srcEntry = rel->getSrcNode()->getEntry(0);
auto dstEntry = rel->getDstNode()->getEntry(0);
entry = tryPruneMultiLabeled(*rel, *srcEntry, *dstEntry);
}
rel->setEntries(std::vector{entry});
auto insertInfo = BoundInsertInfo(TableType::REL, rel);
// Because we might prune entries, some property exprs may belong to pruned entry
for (auto& p : entry->getProperties()) {
insertInfo.columnExprs.push_back(rel->getPropertyExpression(p.getName()));
}
insertInfo.columnDataExprs =
bindInsertColumnDataExprs(rel->getPropertyDataExprRef(), entry->getProperties());
infos.push_back(std::move(insertInfo));
}
// Produces one data expression per property definition, in definition order: the expression
// supplied in `propertyDataExprs` when there is one for that property name, otherwise the
// bound default expression. Each result is implicitly cast to the property type if necessary.
expression_vector Binder::bindInsertColumnDataExprs(
    const case_insensitive_map_t<std::shared_ptr<Expression>>& propertyDataExprs,
    const std::vector<PropertyDefinition>& propertyDefinitions) {
    expression_vector columnData;
    for (auto& definition : propertyDefinitions) {
        std::shared_ptr<Expression> value;
        const auto supplied = propertyDataExprs.find(definition.getName());
        if (supplied != propertyDataExprs.end()) {
            value = supplied->second;
        } else {
            value = expressionBinder.bindExpression(*definition.defaultExpr);
        }
        columnData.push_back(
            expressionBinder.implicitCastIfNecessary(value, definition.getType()));
    }
    return columnData;
}
std::unique_ptr<BoundUpdatingClause> Binder::bindSetClause(const UpdatingClause& updatingClause) {
auto& setClause = updatingClause.constCast<SetClause>();
auto boundSetClause = std::make_unique<BoundSetClause>();
for (auto& setItem : setClause.getSetItemsRef()) {
boundSetClause->addInfo(bindSetPropertyInfo(setItem.first.get(), setItem.second.get()));
}
return boundSetClause;
}
// Binds a single `column = columnData` SET item.
// The first child of `column` must bind to a node or rel pattern. Throws BinderException when
// it does not, when the property is covered by an unloaded index, or (nodes only) when the
// property is a primary key.
BoundSetPropertyInfo Binder::bindSetPropertyInfo(const ParsedExpression* column,
    const ParsedExpression* columnData) {
    auto pattern = expressionBinder.bindExpression(*column->getChild(0));
    const bool targetsNode = ExpressionUtil::isNodePattern(*pattern);
    if (!targetsNode && !ExpressionUtil::isRelPattern(*pattern)) {
        throw BinderException(
            stringFormat("Cannot set expression {} with type {}. Expect node or rel pattern.",
                pattern->toString(), ExpressionTypeUtil::toString(pattern->expressionType)));
    }
    auto boundItem = bindSetItem(column, columnData);
    auto boundColumn = boundItem.first;
    auto boundColumnData = boundItem.second;
    auto& nodeOrRel = pattern->constCast<NodeOrRelExpression>();
    auto& property = boundColumn->constCast<PropertyExpression>();
    // Check secondary index constraint
    auto catalog = Catalog::Get(*clientContext);
    auto transaction = transaction::Transaction::Get(*clientContext);
    for (auto tableEntry : nodeOrRel.getEntries()) {
        // When setting multi labeled node, skip checking if property is not in current table.
        if (property.hasProperty(tableEntry->getTableID())) {
            auto propertyID = tableEntry->getPropertyID(property.getPropertyName());
            if (catalog->containsUnloadedIndex(transaction, tableEntry->getTableID(),
                    propertyID)) {
                throw BinderException(
                    stringFormat("Cannot set property {} in table {} because it is used in one or more "
                                 "indexes which is unloaded.",
                        property.getPropertyName(), tableEntry->getName()));
            }
        }
    }
    if (!targetsNode) {
        return BoundSetPropertyInfo(TableType::REL, pattern, boundColumn, boundColumnData);
    }
    // Check primary key constraint
    for (auto tableEntry : nodeOrRel.getEntries()) {
        if (property.isPrimaryKey(tableEntry->getTableID())) {
            throw BinderException(
                stringFormat("Cannot set property {} in table {} because it is used as primary "
                             "key. Try delete and then insert.",
                    property.getPropertyName(), tableEntry->getName()));
        }
    }
    return BoundSetPropertyInfo(TableType::NODE, pattern, boundColumn, boundColumnData);
}
// Binds both sides of a SET item and implicitly casts the value to the data type of the
// bound column when necessary. Returns the (column, value) pair.
expression_pair Binder::bindSetItem(const ParsedExpression* column,
    const ParsedExpression* columnData) {
    auto lhs = expressionBinder.bindExpression(*column);
    auto rhs = expressionBinder.bindExpression(*columnData);
    auto castedRhs = expressionBinder.implicitCastIfNecessary(rhs, lhs->dataType);
    return make_pair(std::move(lhs), std::move(castedRhs));
}
std::unique_ptr<BoundUpdatingClause> Binder::bindDeleteClause(
const UpdatingClause& updatingClause) {
auto& deleteClause = updatingClause.constCast<DeleteClause>();
auto deleteType = deleteClause.getDeleteClauseType();
auto boundDeleteClause = std::make_unique<BoundDeleteClause>();
for (auto i = 0u; i < deleteClause.getNumExpressions(); ++i) {
auto pattern = expressionBinder.bindExpression(*deleteClause.getExpression(i));
if (ExpressionUtil::isNodePattern(*pattern)) {
auto deleteNodeInfo = BoundDeleteInfo(deleteType, TableType::NODE, pattern);
auto& node = pattern->constCast<NodeExpression>();
auto catalog = Catalog::Get(*clientContext);
auto transaction = transaction::Transaction::Get(*clientContext);
for (auto entry : node.getEntries()) {
for (auto index : catalog->getIndexEntries(transaction, entry->getTableID())) {
if (!index->isLoaded()) {
throw BinderException(
stringFormat("Trying to delete from an index on table {} but its "
"extension is not loaded.",
entry->getName()));
}
}
}
boundDeleteClause->addInfo(std::move(deleteNodeInfo));
} else if (ExpressionUtil::isRelPattern(*pattern)) {
// LCOV_EXCL_START
if (deleteClause.getDeleteClauseType() == DeleteNodeType::DETACH_DELETE) {
throw BinderException("Detach delete on rel tables is not supported.");
}
// LCOV_EXCL_STOP
auto rel = pattern->constPtrCast<RelExpression>();
if (rel->getDirectionType() == RelDirectionType::BOTH) {
throw BinderException("Delete undirected rel is not supported.");
}
auto deleteRel = BoundDeleteInfo(deleteType, TableType::REL, pattern);
boundDeleteClause->addInfo(std::move(deleteRel));
} else {
throw BinderException(stringFormat(
"Cannot delete expression {} with type {}. Expect node or rel pattern.",
pattern->toString(), ExpressionTypeUtil::toString(pattern->expressionType)));
}
}
return boundDeleteClause;
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,14 @@
#include "binder/binder.h"
#include "binder/bound_use_database.h"
#include "parser/use_database.h"
namespace lbug {
namespace binder {
// Binds a USE DATABASE statement; the bound statement only carries the database name.
std::unique_ptr<BoundStatement> Binder::bindUseDatabase(const parser::Statement& statement) {
    // Bind by reference: plain `auto` here would copy the whole parsed statement.
    const auto& useDatabase = statement.constCast<parser::UseDatabase>();
    return std::make_unique<BoundUseDatabase>(useDatabase.getDBName());
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,8 @@
# Object library holding the COPY TO / COPY FROM binder sources. Its object files are
# appended to ALL_OBJECT_FILES and published to the parent scope.
add_library(lbug_binder_bind_copy
OBJECT
bind_copy_to.cpp
bind_copy_from.cpp)
set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:lbug_binder_bind_copy>
PARENT_SCOPE)

View File

@@ -0,0 +1,344 @@
#include "binder/binder.h"
#include "binder/copy/bound_copy_from.h"
#include "catalog/catalog.h"
#include "catalog/catalog_entry/index_catalog_entry.h"
#include "catalog/catalog_entry/node_table_catalog_entry.h"
#include "catalog/catalog_entry/rel_group_catalog_entry.h"
#include "common/exception/binder.h"
#include "common/string_format.h"
#include "common/string_utils.h"
#include "parser/copy.h"
#include "transaction/transaction.h"
using namespace lbug::binder;
using namespace lbug::catalog;
using namespace lbug::common;
using namespace lbug::parser;
using namespace lbug::function;
namespace lbug {
namespace binder {
// Throws a BinderException reporting that `tableName` does not exist. Never returns, which
// the [[noreturn]] attribute makes visible to callers and to the compiler.
[[noreturn]] static void throwTableNotExist(const std::string& tableName) {
    throw BinderException(stringFormat("Table {} does not exist.", tableName));
}
std::unique_ptr<BoundStatement> Binder::bindLegacyCopyRelGroupFrom(const Statement& statement) {
auto& copyFrom = statement.constCast<CopyFrom>();
auto catalog = Catalog::Get(*clientContext);
auto transaction = transaction::Transaction::Get(*clientContext);
auto tableName = copyFrom.getTableName();
auto tableNameParts = common::StringUtils::split(tableName, "_");
if (tableNameParts.size() != 3 || !catalog->containsTable(transaction, tableNameParts[0])) {
throwTableNotExist(tableName);
}
auto entry = catalog->getTableCatalogEntry(transaction, tableNameParts[0]);
if (entry->getType() != CatalogEntryType::REL_GROUP_ENTRY) {
throwTableNotExist(tableName);
}
auto relGroupEntry = entry->ptrCast<RelGroupCatalogEntry>();
try {
return bindCopyRelFrom(copyFrom, *relGroupEntry, tableNameParts[1], tableNameParts[2]);
} catch (Exception& e) {
throwTableNotExist(tableName);
return nullptr;
}
}
// Binds a COPY FROM statement.
// - Table name not in the catalog: falls back to the legacy rel group naming scheme.
// - Node table: binds a node copy.
// - Rel group with a single from/to pair: binds a rel copy for that pair.
// - Rel group with several pairs: the FROM and TO parsing options select the pair; a
//   BinderException is thrown when either option is missing.
// Any other catalog entry type reaches KU_UNREACHABLE.
std::unique_ptr<BoundStatement> Binder::bindCopyFromClause(const Statement& statement) {
    auto& copyStatement = statement.constCast<CopyFrom>();
    auto tableName = copyStatement.getTableName();
    auto catalog = Catalog::Get(*clientContext);
    auto transaction = transaction::Transaction::Get(*clientContext);
    if (!catalog->containsTable(transaction, tableName)) {
        return bindLegacyCopyRelGroupFrom(statement);
    }
    auto tableEntry = catalog->getTableCatalogEntry(transaction, tableName);
    switch (tableEntry->getType()) {
    case CatalogEntryType::NODE_TABLE_ENTRY: {
        return bindCopyNodeFrom(statement, *tableEntry->ptrCast<NodeTableCatalogEntry>());
    }
    case CatalogEntryType::REL_GROUP_ENTRY: {
        auto entry = tableEntry->ptrCast<RelGroupCatalogEntry>();
        KU_ASSERT(entry->getNumRelTables() > 0);
        if (entry->getNumRelTables() == 1) {
            // Only one from/to pair exists, so it can be taken straight from the catalog.
            auto fromToNodePair = entry->getSingleRelEntryInfo().nodePair;
            auto fromTable = catalog->getTableCatalogEntry(transaction, fromToNodePair.srcTableID);
            auto toTable = catalog->getTableCatalogEntry(transaction, fromToNodePair.dstTableID);
            return bindCopyRelFrom(statement, *entry, fromTable->getName(), toTable->getName());
        } else {
            // Several pairs: the statement has to name the pair through its options.
            auto options = bindParsingOptions(copyStatement.getParsingOptions());
            if (!options.contains(CopyConstants::FROM_OPTION_NAME) ||
                !options.contains(CopyConstants::TO_OPTION_NAME)) {
                throw BinderException(stringFormat(
                    "The table {} has multiple FROM and TO pairs defined in the schema. A "
                    "specific pair of FROM and TO options is expected when copying data "
                    "into "
                    "the {} table.",
                    tableName, tableName));
            }
            auto from = options.at(CopyConstants::FROM_OPTION_NAME).getValue<std::string>();
            auto to = options.at(CopyConstants::TO_OPTION_NAME).getValue<std::string>();
            return bindCopyRelFrom(statement, *entry, from, to);
        }
    }
    default: {
        KU_UNREACHABLE;
    }
    }
}
// Forward declarations of the file-local helpers that fill `columnNames` / `columnTypes`
// with the columns a COPY FROM source is expected to provide. They are defined further down
// in this file and declared here because the binder methods below call them first.
static void bindExpectedNodeColumns(const NodeTableCatalogEntry& entry,
const CopyFromColumnInfo& info, std::vector<std::string>& columnNames,
std::vector<LogicalType>& columnTypes);
static void bindExpectedRelColumns(const RelGroupCatalogEntry& entry,
const NodeTableCatalogEntry& fromEntry, const NodeTableCatalogEntry& toEntry,
const CopyFromColumnInfo& info, std::vector<std::string>& columnNames,
std::vector<LogicalType>& columnTypes);
// Chooses how `property` is populated from the scanned `columns`:
// - REFERENCE: a column with the property's name and type exists and is used as is;
// - CAST: a column with the property's name exists but has another type, so it is force-cast;
// - DEFAULT: no column carries the property's name, so its default expression is bound.
static std::pair<ColumnEvaluateType, std::shared_ptr<Expression>> matchColumnExpression(
    const expression_vector& columns, const PropertyDefinition& property,
    ExpressionBinder& expressionBinder) {
    for (auto& candidate : columns) {
        if (!(property.getName() == candidate->toString())) {
            continue;
        }
        if (candidate->dataType == property.getType()) {
            return {ColumnEvaluateType::REFERENCE, candidate};
        }
        return {ColumnEvaluateType::CAST,
            expressionBinder.forceCast(candidate, property.getType())};
    }
    return {ColumnEvaluateType::DEFAULT, expressionBinder.bindExpression(*property.defaultExpr)};
}
// Builds the bound COPY FROM info for a node table.
// Binds the scan source against the expected columns, then decides for every table property
// whether it is read from the source, cast, or filled from its default expression. The
// source's warning columns are appended after the property columns.
// Throws BinderException when `byColumn` is requested for a file source that is not NPY.
BoundCopyFromInfo Binder::bindCopyNodeFromInfo(std::string tableName,
    const std::vector<PropertyDefinition>& properties, const BaseScanSource* source,
    const options_t& parsingOptions, const std::vector<std::string>& expectedColumnNames,
    const std::vector<LogicalType>& expectedColumnTypes, bool byColumn) {
    auto boundSource =
        bindScanSource(source, parsingOptions, expectedColumnNames, expectedColumnTypes);
    expression_vector warningDataExprs = boundSource->getWarningColumns();
    if (boundSource->type == ScanSourceType::FILE) {
        auto bindData = boundSource->constCast<BoundTableScanSource>()
                            .info.bindData->constPtrCast<ScanFileBindData>();
        if (byColumn && bindData->fileScanInfo.fileTypeInfo.fileType != FileType::NPY) {
            throw BinderException(stringFormat("Copy by column with {} file type is not supported.",
                bindData->fileScanInfo.fileTypeInfo.fileTypeStr));
        }
    }
    expression_vector columns;
    std::vector<ColumnEvaluateType> evaluateTypes;
    for (auto& property : properties) {
        auto [evaluateType, column] =
            matchColumnExpression(boundSource->getColumns(), property, expressionBinder);
        columns.push_back(column);
        evaluateTypes.push_back(evaluateType);
    }
    columns.insert(columns.end(), warningDataExprs.begin(), warningDataExprs.end());
    auto offset =
        createInvisibleVariable(std::string(InternalKeyword::ROW_OFFSET), LogicalType::INT64());
    // `tableName` is a by-value parameter that is not used afterwards, so hand it over
    // instead of copying it again.
    return BoundCopyFromInfo(std::move(tableName), TableType::NODE, std::move(boundSource),
        std::move(offset), std::move(columns), std::move(evaluateTypes), nullptr /* extraInfo */);
}
// Binds COPY FROM into a node table.
// Throws BinderException when an index on the table is not loaded; otherwise derives the
// expected source columns from the catalog entry and builds the bound copy statement.
std::unique_ptr<BoundStatement> Binder::bindCopyNodeFrom(const Statement& statement,
    NodeTableCatalogEntry& nodeTableEntry) {
    auto& parsedCopy = statement.constCast<CopyFrom>();
    // Check extension secondary index loaded
    auto catalog = Catalog::Get(*clientContext);
    auto transaction = transaction::Transaction::Get(*clientContext);
    for (auto index : catalog->getIndexEntries(transaction, nodeTableEntry.getTableID())) {
        if (index->isLoaded()) {
            continue;
        }
        throw BinderException(stringFormat(
            "Trying to insert into an index on table {} but its extension is not loaded.",
            nodeTableEntry.getName()));
    }
    // Bind expected columns based on catalog information.
    std::vector<std::string> sourceColumnNames;
    std::vector<LogicalType> sourceColumnTypes;
    bindExpectedNodeColumns(nodeTableEntry, parsedCopy.getCopyColumnInfo(), sourceColumnNames,
        sourceColumnTypes);
    auto copyInfo = bindCopyNodeFromInfo(nodeTableEntry.getName(), nodeTableEntry.getProperties(),
        parsedCopy.getSource(), parsedCopy.getParsingOptions(), sourceColumnNames,
        sourceColumnTypes, parsedCopy.byColumn());
    return std::make_unique<BoundCopyFrom>(std::move(copyInfo));
}
// Returns copies of the parsing options of `copyFrom`, leaving out the FROM and TO options:
// those two select the relationship's node table pair and are not forwarded to the scan
// source.
static options_t getScanSourceOptions(const CopyFrom& copyFrom) {
    // Read-only lookup table, hence const.
    static const case_insensitve_set_t copyFromPairsOptions = {CopyConstants::FROM_OPTION_NAME,
        CopyConstants::TO_OPTION_NAME};
    options_t options;
    for (auto& option : copyFrom.getParsingOptions()) {
        if (copyFromPairsOptions.contains(option.first)) {
            continue;
        }
        options.emplace(option.first, option.second->copy());
    }
    return options;
}
// Builds the bound COPY FROM info for a relationship table.
// The first two source columns are treated as the keys of the source and destination nodes;
// each is force-cast to the matching expected column type when the types differ and is
// resolved to a node offset through an index lookup on `fromTable` / `toTable`. The output
// columns are: source offset, destination offset, row offset, one column per property
// (skipping the first property, the internal ID), then the source's warning columns.
BoundCopyFromInfo Binder::bindCopyRelFromInfo(std::string tableName,
    const std::vector<PropertyDefinition>& properties, const BaseScanSource* source,
    const options_t& parsingOptions, const std::vector<std::string>& expectedColumnNames,
    const std::vector<LogicalType>& expectedColumnTypes, const NodeTableCatalogEntry* fromTable,
    const NodeTableCatalogEntry* toTable) {
    auto boundSource =
        bindScanSource(source, parsingOptions, expectedColumnNames, expectedColumnTypes);
    expression_vector warningDataExprs = boundSource->getWarningColumns();
    auto columns = boundSource->getColumns();
    auto offset =
        createInvisibleVariable(std::string(InternalKeyword::ROW_OFFSET), LogicalType::INT64());
    auto srcOffset = createVariable(std::string(InternalKeyword::SRC_OFFSET), LogicalType::INT64());
    auto dstOffset = createVariable(std::string(InternalKeyword::DST_OFFSET), LogicalType::INT64());
    expression_vector columnExprs{srcOffset, dstOffset, offset};
    std::vector<ColumnEvaluateType> evaluateTypes{ColumnEvaluateType::REFERENCE,
        ColumnEvaluateType::REFERENCE, ColumnEvaluateType::REFERENCE};
    for (auto i = 1u; i < properties.size(); ++i) { // skip internal ID
        auto& property = properties[i];
        // Reuse the column list fetched above rather than asking the source again.
        auto [evaluateType, column] = matchColumnExpression(columns, property, expressionBinder);
        columnExprs.push_back(column);
        evaluateTypes.push_back(evaluateType);
    }
    columnExprs.insert(columnExprs.end(), warningDataExprs.begin(), warningDataExprs.end());
    std::shared_ptr<Expression> srcKey = nullptr, dstKey = nullptr;
    if (expectedColumnTypes[0] != columns[0]->getDataType()) {
        srcKey = expressionBinder.forceCast(columns[0], expectedColumnTypes[0]);
    } else {
        srcKey = columns[0];
    }
    if (expectedColumnTypes[1] != columns[1]->getDataType()) {
        dstKey = expressionBinder.forceCast(columns[1], expectedColumnTypes[1]);
    } else {
        dstKey = columns[1];
    }
    auto srcLookUpInfo =
        IndexLookupInfo(fromTable->getTableID(), srcOffset, srcKey, warningDataExprs);
    auto dstLookUpInfo =
        IndexLookupInfo(toTable->getTableID(), dstOffset, dstKey, warningDataExprs);
    auto lookupInfos = std::vector<IndexLookupInfo>{srcLookUpInfo, dstLookUpInfo};
    auto internalIDColumnIndices = std::vector<idx_t>{0, 1, 2};
    auto extraCopyRelInfo = std::make_unique<ExtraBoundCopyRelInfo>(fromTable->getName(),
        toTable->getName(), internalIDColumnIndices, lookupInfos);
    // The bound source and the by-value table name are not used past this point, so they are
    // moved into the result instead of being deep-copied.
    return BoundCopyFromInfo(std::move(tableName), TableType::REL, std::move(boundSource), offset,
        std::move(columnExprs), std::move(evaluateTypes), std::move(extraCopyRelInfo));
}
// Binds COPY FROM into the (fromTableName, toTableName) pair of a rel group.
// Throws BinderException when copy-by-column is requested or when the rel group has no rel
// entry info for that pair of node tables.
std::unique_ptr<BoundStatement> Binder::bindCopyRelFrom(const Statement& statement,
    RelGroupCatalogEntry& relGroupEntry, const std::string& fromTableName,
    const std::string& toTableName) {
    auto& parsedCopy = statement.constCast<CopyFrom>();
    if (parsedCopy.byColumn()) {
        throw BinderException(
            stringFormat("Copy by column is not supported for relationship table."));
    }
    // Bind from to tables
    auto catalog = Catalog::Get(*clientContext);
    auto transaction = transaction::Transaction::Get(*clientContext);
    auto srcTable =
        catalog->getTableCatalogEntry(transaction, fromTableName)->ptrCast<NodeTableCatalogEntry>();
    auto dstTable =
        catalog->getTableCatalogEntry(transaction, toTableName)->ptrCast<NodeTableCatalogEntry>();
    if (relGroupEntry.getRelEntryInfo(srcTable->getTableID(), dstTable->getTableID()) ==
        nullptr) {
        throw BinderException(stringFormat("Rel table {} does not contain {}-{} from-to pair.",
            relGroupEntry.getName(), srcTable->getName(), dstTable->getName()));
    }
    // Bind expected columns based on catalog information.
    std::vector<std::string> sourceColumnNames;
    std::vector<LogicalType> sourceColumnTypes;
    bindExpectedRelColumns(relGroupEntry, *srcTable, *dstTable, parsedCopy.getCopyColumnInfo(),
        sourceColumnNames, sourceColumnTypes);
    // Bind info
    auto copyInfo = bindCopyRelFromInfo(relGroupEntry.getName(), relGroupEntry.getProperties(),
        parsedCopy.getSource(), getScanSourceOptions(parsedCopy), sourceColumnNames,
        sourceColumnTypes, srcTable, dstTable);
    return std::make_unique<BoundCopyFrom>(std::move(copyInfo));
}
// True when `property` is the internal ID property, which is skipped when the input column
// order is given explicitly.
static bool skipPropertyInFile(const PropertyDefinition& property) {
    return property.getName() == InternalKeyword::ID;
}
// True when `property` is skipped while falling back to the schema's columns: either its
// type is SERIAL or it is the internal ID property.
static bool skipPropertyInSchema(const PropertyDefinition& property) {
    return property.getType().getLogicalTypeID() == LogicalTypeID::SERIAL ||
           property.getName() == InternalKeyword::ID;
}
// Appends to `columnNames` / `columnTypes` the columns a COPY FROM source must provide.
// With an explicit input column order, the listed columns are used in that order (minus the
// internal ID); a BinderException is thrown for a duplicated name or a name the table does
// not contain. Without one, the schema's properties are used, minus SERIAL and internal ID
// properties.
static void bindExpectedColumns(const TableCatalogEntry& entry, const CopyFromColumnInfo& info,
    std::vector<std::string>& columnNames, std::vector<LogicalType>& columnTypes) {
    if (!info.inputColumnOrder) {
        // No column specified. Fall back to schema columns.
        for (auto& property : entry.getProperties()) {
            if (!skipPropertyInSchema(property)) {
                columnNames.push_back(property.getName());
                columnTypes.push_back(property.getType().copy());
            }
        }
        return;
    }
    // All duplicates are rejected before any column is looked up in the table.
    std::unordered_set<std::string> seenNames;
    for (auto& inputName : info.columnNames) {
        if (!seenNames.insert(inputName).second) {
            throw BinderException(
                stringFormat("Detect duplicate column name {} during COPY.", inputName));
        }
    }
    // Search column data type for each input column.
    for (auto& inputName : info.columnNames) {
        if (!entry.containsProperty(inputName)) {
            throw BinderException(stringFormat("Table {} does not contain column {}.",
                entry.getName(), inputName));
        }
        auto& property = entry.getProperty(inputName);
        if (!skipPropertyInFile(property)) {
            columnNames.push_back(inputName);
            columnTypes.push_back(property.getType().copy());
        }
    }
}
// Fills the (initially empty) output vectors with the columns expected from a COPY FROM
// source for a node table.
void bindExpectedNodeColumns(const NodeTableCatalogEntry& entry, const CopyFromColumnInfo& info,
    std::vector<std::string>& columnNames, std::vector<LogicalType>& columnTypes) {
    // Both outputs must start out empty.
    KU_ASSERT(columnTypes.empty() && columnNames.empty());
    bindExpectedColumns(entry, info, columnNames, columnTypes);
}
void bindExpectedRelColumns(const RelGroupCatalogEntry& entry,
const NodeTableCatalogEntry& fromEntry, const NodeTableCatalogEntry& toEntry,
const CopyFromColumnInfo& info, std::vector<std::string>& columnNames,
std::vector<LogicalType>& columnTypes) {
KU_ASSERT(columnNames.empty() && columnTypes.empty());
columnNames.push_back("from");
columnNames.push_back("to");
auto srcPKColumnType = fromEntry.getPrimaryKeyDefinition().getType().copy();
if (srcPKColumnType.getLogicalTypeID() == LogicalTypeID::SERIAL) {
srcPKColumnType = LogicalType::INT64();
}
auto dstPKColumnType = toEntry.getPrimaryKeyDefinition().getType().copy();
if (dstPKColumnType.getLogicalTypeID() == LogicalTypeID::SERIAL) {
dstPKColumnType = LogicalType::INT64();
}
columnTypes.push_back(std::move(srcPKColumnType));
columnTypes.push_back(std::move(dstPKColumnType));
bindExpectedColumns(entry, info, columnNames, columnTypes);
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,48 @@
#include "binder/binder.h"
#include "binder/copy/bound_copy_to.h"
#include "catalog/catalog.h"
#include "common/exception/catalog.h"
#include "common/exception/runtime.h"
#include "function/built_in_function_utils.h"
#include "parser/copy.h"
#include "transaction/transaction.h"
using namespace lbug::common;
using namespace lbug::parser;
namespace lbug {
namespace binder {
std::unique_ptr<BoundStatement> Binder::bindCopyToClause(const Statement& statement) {
auto& copyToStatement = statement.constCast<CopyTo>();
auto boundFilePath = copyToStatement.getFilePath();
auto fileTypeInfo = bindFileTypeInfo({boundFilePath});
std::vector<std::string> columnNames;
auto parsedQuery = copyToStatement.getStatement();
auto query = bindQuery(*parsedQuery);
auto columns = query->getStatementResult()->getColumns();
auto fileTypeStr = fileTypeInfo.fileTypeStr;
auto name = stringFormat("COPY_{}", fileTypeStr);
catalog::CatalogEntry* entry = nullptr;
try {
entry = catalog::Catalog::Get(*clientContext)
->getFunctionEntry(transaction::Transaction::Get(*clientContext), name);
} catch (CatalogException& exception) {
throw RuntimeException{common::stringFormat(
"Exporting query result to the '{}' file is currently not supported.", fileTypeStr)};
}
auto exportFunc = function::BuiltInFunctionsUtils::matchFunction(name,
entry->ptrCast<catalog::FunctionCatalogEntry>())
->constPtrCast<function::ExportFunction>();
for (auto& column : columns) {
auto columnName = column->hasAlias() ? column->getAlias() : column->toString();
columnNames.push_back(columnName);
}
function::ExportFuncBindInput bindInput{std::move(columnNames), std::move(boundFilePath),
bindParsingOptions(copyToStatement.getParsingOptions())};
auto bindData = exportFunc->bind(bindInput);
return std::make_unique<BoundCopyTo>(std::move(bindData), *exportFunc, std::move(query));
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,7 @@
# Object library holding the DDL binder source. Its object files are appended to
# ALL_OBJECT_FILES and published to the parent scope.
add_library(lbug_binder_bind_ddl
OBJECT
bound_create_table_info.cpp)
set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:lbug_binder_bind_ddl>
PARENT_SCOPE)

View File

@@ -0,0 +1,43 @@
#include "binder/ddl/bound_create_table_info.h"
#include "catalog/catalog_entry/catalog_entry_type.h"
#include "catalog/catalog_entry/table_catalog_entry.h"
using namespace lbug::parser;
using namespace lbug::common;
using namespace lbug::catalog;
namespace lbug {
namespace binder {
// Renders a one-line description of the table being created:
//   "Create Node Table: <name>,Properties: <p1>, <p2>, "
//   "Create Relationship Table: <name>,Properties: <p1>, <p2>, "
// Every property name is followed by ", ". Other entry types yield an empty string.
std::string BoundCreateTableInfo::toString() const {
    std::string result = "";
    // Shared tail for both table kinds. The relationship branch used to omit the ","
    // separator, gluing "Properties: " directly onto the table name.
    const auto appendProperties = [&result](const auto& propertyDefinitions) {
        result += ",Properties: ";
        for (auto& definition : propertyDefinitions) {
            result += definition.getName();
            result += ", ";
        }
    };
    switch (type) {
    case CatalogEntryType::NODE_TABLE_ENTRY: {
        result += "Create Node Table: ";
        result += tableName;
        auto nodeInfo = extraInfo->ptrCast<BoundExtraCreateNodeTableInfo>();
        appendProperties(nodeInfo->propertyDefinitions);
    } break;
    case CatalogEntryType::REL_GROUP_ENTRY: {
        result += "Create Relationship Table: ";
        result += tableName;
        auto relGroupInfo = extraInfo->ptrCast<BoundExtraCreateRelTableGroupInfo>();
        appendProperties(relGroupInfo->propertyDefinitions);
    } break;
    default:
        break;
    }
    return result;
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,10 @@
# Object library compiled from the reading-clause binder sources listed below
# (in-query call, LOAD FROM, MATCH and UNWIND).
add_library(lbug_binder_bind_read
OBJECT
bind_in_query_call.cpp
bind_load_from.cpp
bind_match.cpp
bind_unwind.cpp)
# Append this target's object files to ALL_OBJECT_FILES and publish the
# extended list to the parent directory scope.
set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:lbug_binder_bind_read>
PARENT_SCOPE)

View File

@@ -0,0 +1,45 @@
#include "binder/binder.h"
#include "binder/query/reading_clause/bound_table_function_call.h"
#include "catalog/catalog.h"
#include "common/exception/binder.h"
#include "parser/expression/parsed_function_expression.h"
#include "parser/query/reading_clause/in_query_call_clause.h"
#include "transaction/transaction.h"
using namespace lbug::common;
using namespace lbug::catalog;
using namespace lbug::parser;
using namespace lbug::function;
using namespace lbug::catalog;
namespace lbug {
namespace binder {
std::unique_ptr<BoundReadingClause> Binder::bindInQueryCall(const ReadingClause& readingClause) {
auto& call = readingClause.constCast<InQueryCallClause>();
auto expr = call.getFunctionExpression();
auto functionExpr = expr->constPtrCast<ParsedFunctionExpression>();
auto functionName = functionExpr->getFunctionName();
std::unique_ptr<BoundReadingClause> boundReadingClause;
auto transaction = transaction::Transaction::Get(*clientContext);
auto entry = Catalog::Get(*clientContext)->getFunctionEntry(transaction, functionName);
switch (entry->getType()) {
case CatalogEntryType::TABLE_FUNCTION_ENTRY: {
auto boundTableFunction =
bindTableFunc(functionName, *functionExpr, call.getYieldVariables());
boundReadingClause =
std::make_unique<BoundTableFunctionCall>(std::move(boundTableFunction));
} break;
default:
throw BinderException(
stringFormat("{} is not a table or algorithm function.", functionName));
}
if (call.hasWherePredicate()) {
auto wherePredicate = bindWhereExpression(*call.getWherePredicate());
boundReadingClause->setPredicate(std::move(wherePredicate));
}
return boundReadingClause;
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,64 @@
#include "binder/binder.h"
#include "binder/bound_scan_source.h"
#include "binder/expression/expression_util.h"
#include "binder/query/reading_clause/bound_load_from.h"
#include "common/exception/binder.h"
#include "parser/query/reading_clause/load_from.h"
#include "parser/scan_source.h"
using namespace lbug::function;
using namespace lbug::common;
using namespace lbug::parser;
using namespace lbug::catalog;
namespace lbug {
namespace binder {
std::unique_ptr<BoundReadingClause> Binder::bindLoadFrom(const ReadingClause& readingClause) {
auto& loadFrom = readingClause.constCast<LoadFrom>();
auto source = loadFrom.getSource();
std::unique_ptr<BoundLoadFrom> boundLoadFrom;
std::vector<std::string> columnNames;
std::vector<LogicalType> columnTypes;
for (auto& [name, type] : loadFrom.getColumnDefinitions()) {
columnNames.push_back(name);
columnTypes.push_back(LogicalType::convertFromString(type, clientContext));
}
switch (source->type) {
case ScanSourceType::OBJECT: {
auto objectSource = source->ptrCast<ObjectScanSource>();
auto boundScanSource = bindObjectScanSource(*objectSource, loadFrom.getParsingOptions(),
columnNames, columnTypes);
auto& scanInfo = boundScanSource->constCast<BoundTableScanSource>().info;
boundLoadFrom = std::make_unique<BoundLoadFrom>(scanInfo.copy());
} break;
case ScanSourceType::FILE: {
auto boundScanSource =
bindFileScanSource(*source, loadFrom.getParsingOptions(), columnNames, columnTypes);
auto& scanInfo = boundScanSource->constCast<BoundTableScanSource>().info;
boundLoadFrom = std::make_unique<BoundLoadFrom>(scanInfo.copy());
} break;
case ScanSourceType::PARAM: {
auto boundScanSource = bindParameterScanSource(*source, loadFrom.getParsingOptions(),
columnNames, columnTypes);
auto& scanInfo = boundScanSource->constCast<BoundTableScanSource>().info;
boundLoadFrom = std::make_unique<BoundLoadFrom>(scanInfo.copy());
} break;
default:
throw BinderException(stringFormat("LOAD FROM subquery is not supported."));
}
if (!columnTypes.empty()) {
auto info = boundLoadFrom->getInfo();
for (auto i = 0u; i < columnTypes.size(); ++i) {
ExpressionUtil::validateDataType(*info->bindData->columns[i], columnTypes[i]);
}
}
if (loadFrom.hasWherePredicate()) {
auto wherePredicate = bindWhereExpression(*loadFrom.getWherePredicate());
boundLoadFrom->setPredicate(std::move(wherePredicate));
}
return boundLoadFrom;
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,126 @@
#include "binder/binder.h"
#include "binder/query/reading_clause/bound_match_clause.h"
#include "common/exception/binder.h"
#include "parser/query/reading_clause/match_clause.h"
using namespace lbug::common;
using namespace lbug::parser;
namespace lbug {
namespace binder {
// Recursively gathers the pattern stored at every leaf of the join-hint tree rooted at `node`
// into `set`.
static void collectHintPattern(const BoundJoinHintNode& node, binder::expression_set& set) {
    if (!node.isLeaf()) {
        for (auto& subtree : node.children) {
            collectHintPattern(*subtree, set);
        }
        return;
    }
    set.insert(node.nodeOrRel);
}
static void validateHintCompleteness(const BoundJoinHintNode& root, const QueryGraph& queryGraph) {
binder::expression_set set;
collectHintPattern(root, set);
for (auto& nodeOrRel : queryGraph.getAllPatterns()) {
if (nodeOrRel->getVariableName().empty()) {
throw BinderException(
"Cannot hint join order in a match patter with anonymous node or relationship.");
}
if (!set.contains(nodeOrRel)) {
throw BinderException(
stringFormat("Cannot find {} in join hint.", nodeOrRel->toString()));
}
}
}
std::unique_ptr<BoundReadingClause> Binder::bindMatchClause(const ReadingClause& readingClause) {
auto& matchClause = readingClause.constCast<MatchClause>();
auto boundGraphPattern = bindGraphPattern(matchClause.getPatternElementsRef());
if (matchClause.hasWherePredicate()) {
boundGraphPattern.where = bindWhereExpression(*matchClause.getWherePredicate());
}
rewriteMatchPattern(boundGraphPattern);
auto boundMatch = std::make_unique<BoundMatchClause>(
std::move(boundGraphPattern.queryGraphCollection), matchClause.getMatchClauseType());
if (matchClause.hasHint()) {
boundMatch->setHint(
bindJoinHint(*boundMatch->getQueryGraphCollection(), *matchClause.getHint()));
}
boundMatch->setPredicate(boundGraphPattern.where);
return boundMatch;
}
// Binds a join hint for a match pattern consisting of a single query graph and checks that the
// hint covers every pattern of that graph. Throws a BinderException when the collection holds more
// than one query graph.
std::shared_ptr<BoundJoinHintNode> Binder::bindJoinHint(
    const QueryGraphCollection& queryGraphCollection, const JoinHintNode& joinHintNode) {
    const bool isDisconnected = queryGraphCollection.getNumQueryGraphs() > 1;
    if (isDisconnected) {
        throw BinderException("Join hint on disconnected match pattern is not supported.");
    }
    auto boundHint = bindJoinNode(joinHintNode);
    validateHintCompleteness(*boundHint, *queryGraphCollection.getQueryGraph(0));
    return boundHint;
}
std::shared_ptr<BoundJoinHintNode> Binder::bindJoinNode(const JoinHintNode& joinHintNode) {
if (joinHintNode.isLeaf()) {
std::shared_ptr<Expression> pattern = nullptr;
if (scope.contains(joinHintNode.variableName)) {
pattern = scope.getExpression(joinHintNode.variableName);
}
if (pattern == nullptr || pattern->expressionType != ExpressionType::PATTERN) {
throw BinderException(stringFormat("Cannot bind {} to a node or relationship pattern",
joinHintNode.variableName));
}
return std::make_shared<BoundJoinHintNode>(std::move(pattern));
}
auto node = std::make_shared<BoundJoinHintNode>();
for (auto& child : joinHintNode.children) {
node->addChild(bindJoinNode(*child));
}
return node;
}
// Rewrites a bound graph pattern in place:
//  1. Each self-loop relationship gets a freshly created destination node, and an equality
//     predicate between the internal IDs of its source and the new destination is recorded.
//  2. Each property key/value pair attached to a pattern becomes an equality predicate.
// All generated predicates are AND-ed with the existing `where` predicate and the result is
// stored back into boundGraphPattern.where.
void Binder::rewriteMatchPattern(BoundGraphPattern& boundGraphPattern) {
// Rewrite self loop edge
// e.g. rewrite (a)-[e]->(a) as (a)-[e]->(b) WHERE id(a) = id(b)
expression_vector selfLoopEdgePredicates;
auto& graphCollection = boundGraphPattern.queryGraphCollection;
for (auto i = 0u; i < graphCollection.getNumQueryGraphs(); ++i) {
auto queryGraph = graphCollection.getQueryGraphUnsafe(i);
for (auto& queryRel : queryGraph->getQueryRels()) {
if (!queryRel->isSelfLoop()) {
continue;
}
auto src = queryRel->getSrcNode();
auto dst = queryRel->getDstNode();
// The replacement destination reuses the original destination's variable name and entries.
auto newDst = createQueryNode(dst->getVariableName(), dst->getEntries());
queryGraph->addQueryNode(newDst);
queryRel->setDstNode(newDst);
auto predicate = expressionBinder.createEqualityComparisonExpression(
src->getInternalID(), newDst->getInternalID());
selfLoopEdgePredicates.push_back(std::move(predicate));
}
}
// Fold the self-loop predicates into the current WHERE predicate.
auto where = boundGraphPattern.where;
for (auto& predicate : selfLoopEdgePredicates) {
where = expressionBinder.combineBooleanExpressions(ExpressionType::AND, predicate, where);
}
// Rewrite key value pairs in MATCH clause as predicate
for (auto i = 0u; i < graphCollection.getNumQueryGraphs(); ++i) {
auto queryGraph = graphCollection.getQueryGraphUnsafe(i);
for (auto& pattern : queryGraph->getAllPatterns()) {
for (auto& [propertyName, rhs] : pattern->getPropertyDataExprRef()) {
auto propertyExpr =
expressionBinder.bindNodeOrRelPropertyExpression(*pattern, propertyName);
auto predicate =
expressionBinder.createEqualityComparisonExpression(propertyExpr, rhs);
where = expressionBinder.combineBooleanExpressions(ExpressionType::AND, predicate,
where);
}
}
}
boundGraphPattern.where = std::move(where);
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,55 @@
#include "binder/binder.h"
#include "binder/expression/expression_util.h"
#include "binder/query/reading_clause/bound_unwind_clause.h"
#include "parser/query/reading_clause/unwind_clause.h"
using namespace lbug::parser;
using namespace lbug::common;
namespace lbug {
namespace binder {
// Returns true for a parameter expression whose type is still ANY (e.g. UNWIND $1): the LIST
// type check has to be skipped because the actual parameter value is not known yet.
static bool skipDataTypeValidation(const Expression& expr) {
    if (expr.expressionType != ExpressionType::PARAMETER) {
        return false;
    }
    return expr.getDataType().getLogicalTypeID() == LogicalTypeID::ANY;
}
// Binds an UNWIND clause.
// The unwound expression is bound first; an ARRAY is implicitly cast to a LIST of the same child
// type. The alias variable is then created with a type that depends on the expression:
//  - untyped parameter (see skipDataTypeValidation): alias is ANY;
//  - NULL literal: expression is cast to LIST(STRING) and the alias is STRING;
//  - otherwise: the expression must be a LIST and the alias takes the list's child type.
// If the scope has memorized table IDs under the expression's alias, a query node named after the
// UNWIND alias is created, registered as a node replacement, and its internal ID is passed to the
// bound clause; otherwise the ID expression stays null.
std::unique_ptr<BoundReadingClause> Binder::bindUnwindClause(const ReadingClause& readingClause) {
auto& unwindClause = readingClause.constCast<UnwindClause>();
auto boundExpression = expressionBinder.bindExpression(*unwindClause.getExpression());
auto aliasName = unwindClause.getAlias();
std::shared_ptr<Expression> alias;
// Normalize ARRAY input to LIST so the remaining logic only deals with lists.
if (boundExpression->getDataType().getLogicalTypeID() == LogicalTypeID::ARRAY) {
auto targetType =
LogicalType::LIST(ArrayType::getChildType(boundExpression->dataType).copy());
boundExpression = expressionBinder.implicitCast(boundExpression, targetType);
}
if (!skipDataTypeValidation(*boundExpression)) {
if (ExpressionUtil::isNullLiteral(*boundExpression)) {
// For the `unwind NULL` clause, we assign the STRING[] type to the NULL literal.
alias = createVariable(aliasName, LogicalType::STRING());
boundExpression = expressionBinder.implicitCast(boundExpression,
LogicalType::LIST(LogicalType::STRING()));
} else {
ExpressionUtil::validateDataType(*boundExpression, LogicalTypeID::LIST);
// Replace any remaining ANY in the list type with STRING before deriving the alias type.
boundExpression = expressionBinder.implicitCastIfNecessary(boundExpression,
LogicalTypeUtils::purgeAny(boundExpression->dataType, LogicalType::STRING()));
alias = createVariable(aliasName, ListType::getChildType(boundExpression->dataType));
}
} else {
alias = createVariable(aliasName, LogicalType::ANY());
}
std::shared_ptr<Expression> idExpr = nullptr;
// NOTE(review): the memorized entries appear to come from an aliased collect() over a node
// (see bindAggregateFunctionExpression) — confirm.
if (scope.hasMemorizedTableIDs(boundExpression->getAlias())) {
auto entries = scope.getMemorizedTableEntries(boundExpression->getAlias());
auto node = createQueryNode(aliasName, entries);
idExpr = node->getInternalID();
scope.addNodeReplacement(node);
}
return make_unique<BoundUnwindClause>(std::move(boundExpression), std::move(alias),
std::move(idExpr));
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,18 @@
# Object library compiled from the expression-binder sources listed below
# (one source file per kind of expression).
add_library(
lbug_binder_bind_expression
OBJECT
bind_boolean_expression.cpp
bind_case_expression.cpp
bind_comparison_expression.cpp
bind_function_expression.cpp
bind_literal_expression.cpp
bind_null_operator_expression.cpp
bind_parameter_expression.cpp
bind_property_expression.cpp
bind_subquery_expression.cpp
bind_variable_expression.cpp
bind_lambda_expression.cpp)
# Append this target's object files to ALL_OBJECT_FILES and publish the
# extended list to the parent directory scope.
set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:lbug_binder_bind_expression>
PARENT_SCOPE)

View File

@@ -0,0 +1,57 @@
#include "binder/expression/scalar_function_expression.h"
#include "binder/expression_binder.h"
#include "function/boolean/vector_boolean_functions.h"
using namespace lbug::common;
using namespace lbug::parser;
using namespace lbug::function;
namespace lbug {
namespace binder {
std::shared_ptr<Expression> ExpressionBinder::bindBooleanExpression(
const ParsedExpression& parsedExpression) {
expression_vector children;
for (auto i = 0u; i < parsedExpression.getNumChildren(); ++i) {
children.push_back(bindExpression(*parsedExpression.getChild(i)));
}
return bindBooleanExpression(parsedExpression.getExpressionType(), children);
}
// Builds a boolean scalar-function expression of the given type over `children`. Every child is
// implicitly cast to BOOL if necessary, and the exec/select functions are resolved for the
// operator. The result type is always BOOL.
std::shared_ptr<Expression> ExpressionBinder::bindBooleanExpression(ExpressionType expressionType,
    const expression_vector& children) {
    // Every operand of a boolean operator is declared as BOOL.
    const std::vector<LogicalTypeID> operandTypeIDs(children.size(), LogicalTypeID::BOOL);
    expression_vector castedChildren;
    for (auto& operand : children) {
        castedChildren.push_back(implicitCastIfNecessary(operand, LogicalType::BOOL()));
    }
    scalar_func_exec_t execFunc;
    VectorBooleanFunction::bindExecFunction(expressionType, castedChildren, execFunc);
    scalar_func_select_t selectFunc;
    VectorBooleanFunction::bindSelectFunction(expressionType, castedChildren, selectFunc);
    auto functionName = ExpressionTypeUtil::toString(expressionType);
    auto function = std::make_unique<ScalarFunction>(functionName, operandTypeIDs,
        LogicalTypeID::BOOL, execFunc, selectFunc);
    auto uniqueName = ScalarFunctionExpression::getUniqueName(functionName, castedChildren);
    auto bindData = std::make_unique<FunctionBindData>(LogicalType::BOOL());
    return std::make_shared<ScalarFunctionExpression>(expressionType, std::move(function),
        std::move(bindData), std::move(castedChildren), uniqueName);
}
// Combines two optional boolean expressions with the given operator. A null operand is treated as
// absent: the other operand is returned unchanged (which may itself be null).
std::shared_ptr<Expression> ExpressionBinder::combineBooleanExpressions(
    ExpressionType expressionType, std::shared_ptr<Expression> left,
    std::shared_ptr<Expression> right) {
    if (left == nullptr) {
        return right;
    }
    if (right == nullptr) {
        return left;
    }
    return bindBooleanExpression(expressionType,
        expression_vector{std::move(left), std::move(right)});
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,75 @@
#include "binder/binder.h"
#include "binder/expression/case_expression.h"
#include "binder/expression/expression_util.h"
#include "binder/expression_binder.h"
#include "parser/expression/parsed_case_expression.h"
using namespace lbug::common;
using namespace lbug::parser;
namespace lbug {
namespace binder {
// Binds a CASE expression.
// The result type is resolved first: the THEN branches are bound and the type of the last branch
// whose type is not ANY is taken; if all of them are ANY, the ELSE branch (when present) supplies
// the type. The ELSE branch (or a NULL literal when absent) and every THEN branch are then cast to
// that result type. For the simple form "CASE x WHEN v ...", each WHEN is rewritten into a
// comparison against x; for the searched form each WHEN is cast to BOOL.
// Note that THEN expressions (and possibly the ELSE expression) are bound more than once: once
// while resolving the result type and again while building the bound expression.
std::shared_ptr<Expression> ExpressionBinder::bindCaseExpression(
const ParsedExpression& parsedExpression) {
auto& parsedCaseExpression = parsedExpression.constCast<ParsedCaseExpression>();
auto resultType = LogicalType::ANY();
// Resolve result type by checking each then expression type.
for (auto i = 0u; i < parsedCaseExpression.getNumCaseAlternative(); ++i) {
auto alternative = parsedCaseExpression.getCaseAlternative(i);
auto boundThen = bindExpression(*alternative->thenExpression);
if (boundThen->getDataType().getLogicalTypeID() != LogicalTypeID::ANY) {
resultType = boundThen->getDataType().copy();
}
}
// Resolve result type by else expression if above resolving fails.
if (resultType.getLogicalTypeID() == LogicalTypeID::ANY &&
parsedCaseExpression.hasElseExpression()) {
auto elseExpression = bindExpression(*parsedCaseExpression.getElseExpression());
resultType = elseExpression->getDataType().copy();
}
auto name = binder->getUniqueExpressionName(parsedExpression.getRawName());
// bind ELSE ...
std::shared_ptr<Expression> elseExpression;
if (parsedCaseExpression.hasElseExpression()) {
elseExpression = bindExpression(*parsedCaseExpression.getElseExpression());
} else {
// A missing ELSE behaves like ELSE NULL.
elseExpression = createNullLiteralExpression();
}
elseExpression = implicitCastIfNecessary(elseExpression, resultType);
auto boundCaseExpression =
make_shared<CaseExpression>(resultType.copy(), std::move(elseExpression), name);
// bind WHEN ... THEN ...
if (parsedCaseExpression.hasCaseExpression()) {
// Simple form: CASE <expr> WHEN <value> THEN ...
auto boundCase = bindExpression(*parsedCaseExpression.getCaseExpression());
for (auto i = 0u; i < parsedCaseExpression.getNumCaseAlternative(); ++i) {
auto caseAlternative = parsedCaseExpression.getCaseAlternative(i);
auto boundWhen = bindExpression(*caseAlternative->whenExpression);
boundWhen = implicitCastIfNecessary(boundWhen, boundCase->dataType);
// rewrite "CASE a.age WHEN 1" as "CASE WHEN a.age = 1"
if (ExpressionUtil::isNullLiteral(*boundWhen)) {
// NOTE(review): the IS_NULL operand here is the WHEN literal itself rather than
// boundCase — confirm this is the intended semantics for "WHEN NULL".
boundWhen = bindNullOperatorExpression(ExpressionType::IS_NULL,
expression_vector{boundWhen});
} else {
boundWhen = bindComparisonExpression(ExpressionType::EQUALS,
expression_vector{boundCase, boundWhen});
}
auto boundThen = bindExpression(*caseAlternative->thenExpression);
boundThen = implicitCastIfNecessary(boundThen, resultType);
boundCaseExpression->addCaseAlternative(boundWhen, boundThen);
}
} else {
// Searched form: CASE WHEN <predicate> THEN ...
for (auto i = 0u; i < parsedCaseExpression.getNumCaseAlternative(); ++i) {
auto caseAlternative = parsedCaseExpression.getCaseAlternative(i);
auto boundWhen = bindExpression(*caseAlternative->whenExpression);
boundWhen = implicitCastIfNecessary(boundWhen, LogicalType::BOOL());
auto boundThen = bindExpression(*caseAlternative->thenExpression);
boundThen = implicitCastIfNecessary(boundThen, resultType);
boundCaseExpression->addCaseAlternative(boundWhen, boundThen);
}
}
return boundCaseExpression;
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,96 @@
#include "binder/binder.h"
#include "binder/expression/expression_util.h"
#include "binder/expression/scalar_function_expression.h"
#include "binder/expression_binder.h"
#include "catalog/catalog.h"
#include "common/exception/binder.h"
#include "function/built_in_function_utils.h"
#include "transaction/transaction.h"
using namespace lbug::common;
using namespace lbug::catalog;
using namespace lbug::parser;
using namespace lbug::function;
namespace lbug {
namespace binder {
std::shared_ptr<Expression> ExpressionBinder::bindComparisonExpression(
const ParsedExpression& parsedExpression) {
expression_vector children;
for (auto i = 0u; i < parsedExpression.getNumChildren(); ++i) {
auto child = bindExpression(*parsedExpression.getChild(i));
children.push_back(std::move(child));
}
return bindComparisonExpression(parsedExpression.getExpressionType(), children);
}
// Returns true when the expression's data type is NODE or REL.
static bool isNodeOrRel(const Expression& expression) {
    const auto typeID = expression.getDataType().getLogicalTypeID();
    return typeID == LogicalTypeID::NODE || typeID == LogicalTypeID::REL;
}
// Binds a binary comparison over two already-bound children.
// - If both children are node/rel expressions, the comparison is rebound on their internal IDs.
// - Otherwise a common type is computed for the children; a BinderException is thrown when the
//   types cannot be combined. If the combined type is still ANY, INT8 is used.
// - The comparison function matching the combined type is looked up in the catalog, children
//   whose type differs from the combined type are force-cast, and the function's bindFunc (if
//   any) is invoked before the function is copied into the resulting expression.
std::shared_ptr<Expression> ExpressionBinder::bindComparisonExpression(
ExpressionType expressionType, const expression_vector& children) {
// Rewrite node or rel comparison
KU_ASSERT(children.size() == 2);
if (isNodeOrRel(*children[0]) && isNodeOrRel(*children[1])) {
expression_vector newChildren;
newChildren.push_back(children[0]->constCast<NodeOrRelExpression>().getInternalID());
newChildren.push_back(children[1]->constCast<NodeOrRelExpression>().getInternalID());
return bindComparisonExpression(expressionType, newChildren);
}
auto catalog = Catalog::Get(*context);
auto transaction = transaction::Transaction::Get(*context);
auto functionName = ExpressionTypeUtil::toString(expressionType);
LogicalType combinedType(LogicalTypeID::ANY);
if (!ExpressionUtil::tryCombineDataType(children, combinedType)) {
throw BinderException(stringFormat("Type Mismatch: Cannot compare types {} and {}",
children[0]->dataType.toString(), children[1]->dataType.toString()));
}
// Both sides are untyped: fall back to INT8 as the comparison type.
if (combinedType.getLogicalTypeID() == LogicalTypeID::ANY) {
combinedType = LogicalType(LogicalTypeID::INT8);
}
// The function is matched as if every child already had the combined type.
std::vector<LogicalType> childrenTypes;
for (auto i = 0u; i < children.size(); i++) {
childrenTypes.push_back(combinedType.copy());
}
auto entry =
catalog->getFunctionEntry(transaction, functionName)->ptrCast<FunctionCatalogEntry>();
auto function = BuiltInFunctionsUtils::matchFunction(functionName, childrenTypes, entry)
->ptrCast<ScalarFunction>();
expression_vector childrenAfterCast;
for (auto i = 0u; i < children.size(); ++i) {
if (children[i]->dataType != combinedType) {
childrenAfterCast.push_back(forceCast(children[i], combinedType));
} else {
childrenAfterCast.push_back(children[i]);
}
}
if (function->bindFunc) {
// Resolve exec and select function if necessary
// Only used for decimal at the moment. See `bindDecimalCompare`.
// NOTE(review): bindFunc receives the matched function pointer itself; the copy is only
// taken afterwards — confirm that mutating the matched function here is intended.
function->bindFunc({childrenAfterCast, function, nullptr,
std::vector<std::string>{} /* optionalParams */});
}
auto bindData = std::make_unique<FunctionBindData>(LogicalType(function->returnTypeID));
auto uniqueExpressionName =
ScalarFunctionExpression::getUniqueName(function->name, childrenAfterCast);
return std::make_shared<ScalarFunctionExpression>(expressionType, function->copy(),
std::move(bindData), std::move(childrenAfterCast), uniqueExpressionName);
}
std::shared_ptr<Expression> ExpressionBinder::createEqualityComparisonExpression(
std::shared_ptr<Expression> left, std::shared_ptr<Expression> right) {
return bindComparisonExpression(ExpressionType::EQUALS,
expression_vector{std::move(left), std::move(right)});
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,215 @@
#include "binder/binder.h"
#include "binder/expression/aggregate_function_expression.h"
#include "binder/expression/scalar_function_expression.h"
#include "binder/expression_binder.h"
#include "catalog/catalog.h"
#include "common/exception/binder.h"
#include "function/built_in_function_utils.h"
#include "function/cast/vector_cast_functions.h"
#include "function/rewrite_function.h"
#include "function/scalar_macro_function.h"
#include "parser/expression/parsed_expression_visitor.h"
#include "parser/expression/parsed_function_expression.h"
#include "transaction/transaction.h"
using namespace lbug::common;
using namespace lbug::parser;
using namespace lbug::function;
using namespace lbug::catalog;
namespace lbug {
namespace binder {
// Binds a parsed function call by looking up the function's catalog entry and dispatching on its
// type: scalar function, rewrite function, aggregate function or scalar macro. Any other entry
// type raises a BinderException.
std::shared_ptr<Expression> ExpressionBinder::bindFunctionExpression(const ParsedExpression& expr) {
    auto parsedFunction = expr.constPtrCast<ParsedFunctionExpression>();
    auto functionName = parsedFunction->getNormalizedFunctionName();
    auto transaction = transaction::Transaction::Get(*context);
    auto catalog = Catalog::Get(*context);
    auto entry = catalog->getFunctionEntry(transaction, functionName);
    const auto entryType = entry->getType();
    if (entryType == CatalogEntryType::SCALAR_FUNCTION_ENTRY) {
        return bindScalarFunctionExpression(expr, functionName);
    }
    if (entryType == CatalogEntryType::REWRITE_FUNCTION_ENTRY) {
        return bindRewriteFunctionExpression(expr);
    }
    if (entryType == CatalogEntryType::AGGREGATE_FUNCTION_ENTRY) {
        return bindAggregateFunctionExpression(expr, functionName,
            parsedFunction->getIsDistinct());
    }
    if (entryType == CatalogEntryType::SCALAR_MACRO_ENTRY) {
        return bindMacroExpression(expr, functionName);
    }
    throw BinderException(
        stringFormat("{} is a {}. Scalar function, aggregate function or macro was expected. ",
            functionName, CatalogEntryTypeUtils::toString(entryType)));
}
std::shared_ptr<Expression> ExpressionBinder::bindScalarFunctionExpression(
const ParsedExpression& parsedExpression, const std::string& functionName) {
expression_vector children;
for (auto i = 0u; i < parsedExpression.getNumChildren(); ++i) {
auto expr = bindExpression(*parsedExpression.getChild(i));
if (parsedExpression.getChild(i)->hasAlias()) {
expr->setAlias(parsedExpression.getChild(i)->getAlias());
}
children.push_back(expr);
}
return bindScalarFunctionExpression(children, functionName,
parsedExpression.constCast<ParsedFunctionExpression>().getOptionalArguments());
}
static std::vector<LogicalType> getTypes(const expression_vector& exprs) {
std::vector<LogicalType> result;
for (auto& expr : exprs) {
result.push_back(expr->getDataType().copy());
}
return result;
}
// Binds a scalar function call over already-bound children.
// The function is matched in the catalog by name and child types and copied. If the second of two
// children is a LAMBDA expression, the function must support list lambdas and the lambda is bound
// against the first child. For the cast function, a null bind result means no cast is needed and
// the first child is returned as is. For every other function, children are implicitly cast to
// the parameter types reported by the bind data, or, when those are empty, to the function's
// declared parameter type IDs.
std::shared_ptr<Expression> ExpressionBinder::bindScalarFunctionExpression(
const expression_vector& children, const std::string& functionName,
std::vector<std::string> optionalArguments) {
auto catalog = Catalog::Get(*context);
auto transaction = transaction::Transaction::Get(*context);
auto childrenTypes = getTypes(children);
auto entry = catalog->getFunctionEntry(transaction, functionName);
// Work on a private copy of the matched function.
auto function = BuiltInFunctionsUtils::matchFunction(functionName, childrenTypes,
entry->ptrCast<FunctionCatalogEntry>())
->ptrCast<ScalarFunction>()
->copy();
if (children.size() == 2 && children[1]->expressionType == ExpressionType::LAMBDA) {
if (!function->isListLambda) {
throw BinderException(stringFormat("{} does not support lambda input.", functionName));
}
bindLambdaExpression(*children[0], *children[1]);
}
expression_vector childrenAfterCast;
std::unique_ptr<FunctionBindData> bindData;
auto bindInput =
ScalarBindFuncInput{children, function.get(), context, std::move(optionalArguments)};
if (functionName == CastAnyFunction::name) {
bindData = function->bindFunc(bindInput);
if (bindData == nullptr) { // No need to cast.
// TODO(Xiyang): We should return a deep copy otherwise the same expression might
// appear in the final projection list repeatedly.
// E.g. RETURN cast([NULL], "INT64[1][]"), cast([NULL], "INT64[1][][]")
return children[0];
}
auto childAfterCast = children[0];
// An untyped (ANY) cast input is first cast to STRING.
if (children[0]->getDataType().getLogicalTypeID() == LogicalTypeID::ANY) {
childAfterCast = implicitCastIfNecessary(children[0], LogicalType::STRING());
}
childrenAfterCast.push_back(std::move(childAfterCast));
} else {
if (function->bindFunc) {
bindData = function->bindFunc(bindInput);
} else {
// Without a bind function the result type is the function's declared return type.
bindData = std::make_unique<FunctionBindData>(LogicalType(function->returnTypeID));
}
if (!bindData->paramTypes.empty()) {
// NOTE(review): indexes paramTypes by child position — assumes paramTypes has at least
// children.size() entries; confirm.
for (auto i = 0u; i < children.size(); ++i) {
childrenAfterCast.push_back(
implicitCastIfNecessary(children[i], bindData->paramTypes[i]));
}
} else {
for (auto i = 0u; i < children.size(); ++i) {
// Variable-length functions use their first parameter type for every child.
auto id = function->isVarLength ? function->parameterTypeIDs[0] :
function->parameterTypeIDs[i];
auto type = LogicalType(id);
childrenAfterCast.push_back(implicitCastIfNecessary(children[i], type));
}
}
}
auto uniqueExpressionName =
ScalarFunctionExpression::getUniqueName(function->name, childrenAfterCast);
return std::make_shared<ScalarFunctionExpression>(ExpressionType::FUNCTION, std::move(function),
std::move(bindData), std::move(childrenAfterCast), uniqueExpressionName);
}
std::shared_ptr<Expression> ExpressionBinder::bindRewriteFunctionExpression(
const ParsedExpression& expr) {
auto& funcExpr = expr.constCast<ParsedFunctionExpression>();
expression_vector children;
for (auto i = 0u; i < expr.getNumChildren(); ++i) {
children.push_back(bindExpression(*expr.getChild(i)));
}
auto childrenTypes = getTypes(children);
auto functionName = funcExpr.getNormalizedFunctionName();
auto transaction = transaction::Transaction::Get(*context);
auto entry = Catalog::Get(*context)->getFunctionEntry(transaction, functionName);
auto match = BuiltInFunctionsUtils::matchFunction(functionName, childrenTypes,
entry->ptrCast<FunctionCatalogEntry>());
auto function = match->constPtrCast<RewriteFunction>();
KU_ASSERT(function->rewriteFunc != nullptr);
auto input = RewriteFunctionBindInput(context, this, children);
return function->rewriteFunc(input);
}
// Binds an aggregate function call.
// The children are bound, the aggregate function is matched in the catalog by name, child types
// and distinctness, and a copy of it is used from then on. The function may rewrite its
// parameters via paramRewriteFunc. For an aliased collect() over a NODE, the node's table entries
// are memorized in the binder scope under that alias. The bind data comes from the function's
// bindFunc when present, otherwise it carries the declared return type.
std::shared_ptr<Expression> ExpressionBinder::bindAggregateFunctionExpression(
const ParsedExpression& parsedExpression, const std::string& functionName, bool isDistinct) {
std::vector<LogicalType> childrenTypes;
expression_vector children;
for (auto i = 0u; i < parsedExpression.getNumChildren(); ++i) {
auto child = bindExpression(*parsedExpression.getChild(i));
childrenTypes.push_back(child->dataType.copy());
children.push_back(std::move(child));
}
auto transaction = transaction::Transaction::Get(*context);
auto entry = Catalog::Get(*context)->getFunctionEntry(transaction, functionName);
auto function = BuiltInFunctionsUtils::matchAggregateFunction(functionName, childrenTypes,
isDistinct, entry->ptrCast<FunctionCatalogEntry>())
->copy();
if (function.paramRewriteFunc) {
function.paramRewriteFunc(children);
}
// Remember which tables a collected node can come from, keyed by the aggregate's alias.
// NOTE(review): presumably consumed by a later UNWIND over this alias (see bindUnwindClause) —
// confirm.
if (functionName == CollectFunction::name && parsedExpression.hasAlias() &&
children[0]->getDataType().getLogicalTypeID() == LogicalTypeID::NODE) {
auto& node = children[0]->constCast<NodeExpression>();
binder->scope.memorizeTableEntries(parsedExpression.getAlias(), node.getEntries());
}
auto uniqueExpressionName =
AggregateFunctionExpression::getUniqueName(function.name, children, function.isDistinct);
// With no children the name is additionally passed through the binder's unique-name
// generator.
if (children.empty()) {
uniqueExpressionName = binder->getUniqueExpressionName(uniqueExpressionName);
}
std::unique_ptr<FunctionBindData> bindData;
if (function.bindFunc) {
auto bindInput = ScalarBindFuncInput{children, &function, context,
std::vector<std::string>{} /* optionalParams */};
bindData = function.bindFunc(bindInput);
} else {
bindData = std::make_unique<FunctionBindData>(LogicalType(function.returnTypeID));
}
return std::make_shared<AggregateFunctionExpression>(std::move(function), std::move(bindData),
std::move(children), uniqueExpressionName);
}
// Expands and binds a scalar macro call. The supplied arguments are mapped onto the macro's
// parameters (positional parameters first, then parameters with default values in declaration
// order), substituted into a copy of the macro body, and the resulting expression is bound.
// Throws a BinderException when too many or too few arguments are supplied.
std::shared_ptr<Expression> ExpressionBinder::bindMacroExpression(
    const ParsedExpression& parsedExpression, const std::string& macroName) {
    auto transaction = transaction::Transaction::Get(*context);
    auto macro = Catalog::Get(*context)->getScalarMacroFunction(transaction, macroName);
    auto macroBody = macro->expression->copy();
    // Starts out holding the default values; supplied arguments overwrite them below.
    auto argumentsByName = macro->getDefaultParameterVals();
    auto& macroCall = parsedExpression.constCast<ParsedFunctionExpression>();
    auto positionalArgs = macro->getPositionalArgs();
    const auto numSupplied = macroCall.getNumChildren();
    const auto numPositional = positionalArgs.size();
    const bool tooMany = numSupplied > macro->getNumArgs();
    const bool tooFew = numSupplied < numPositional;
    if (tooMany || tooFew) {
        throw BinderException{"Invalid number of arguments for macro " + macroName + "."};
    }
    // Positional parameters take the leading arguments.
    for (auto idx = 0u; idx < numPositional; ++idx) {
        argumentsByName[positionalArgs[idx]] = macroCall.getChild(idx);
    }
    // Remaining arguments override defaulted parameters, in order.
    for (auto idx = numPositional; idx < numSupplied; ++idx) {
        auto parameterName = macro->getDefaultParameterName(idx - numPositional);
        argumentsByName[parameterName] = macroCall.getChild(idx);
    }
    auto replacer = MacroParameterReplacer(argumentsByName);
    return bindExpression(*replacer.replace(std::move(macroBody)));
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,37 @@
#include "binder/binder.h"
#include "binder/expression/expression_util.h"
#include "binder/expression/lambda_expression.h"
#include "parser/expression/parsed_lambda_expression.h"
using namespace lbug::common;
using namespace lbug::parser;
namespace lbug {
namespace binder {
// Binds the body of a lambda against its list input. `lambdaInput` must be a LIST; each lambda
// variable is declared with the list's child type in a temporary scope, the body is bound, and
// the previous scope is restored. The lambda expression is then cast to the body's type and
// receives the bound body.
void ExpressionBinder::bindLambdaExpression(const Expression& lambdaInput,
    Expression& lambdaExpr) const {
    ExpressionUtil::validateDataType(lambdaInput, LogicalTypeID::LIST);
    auto& elementType = ListType::getChildType(lambdaInput.getDataType());
    auto& lambda = lambdaExpr.cast<LambdaExpression>();
    auto& parsedLambda = lambda.getParsedLambdaExpr()->constCast<ParsedLambdaExpression>();
    // Lambda variables must only be visible while the body is being bound.
    auto outerScope = binder->saveScope();
    for (auto& variableName : parsedLambda.getVarNames()) {
        binder->createVariable(variableName, elementType);
    }
    auto boundBody =
        binder->getExpressionBinder()->bindExpression(*parsedLambda.getFunctionExpr());
    binder->restoreScope(std::move(outerScope));
    lambda.cast(boundBody->getDataType().copy());
    lambda.setFunctionExpr(std::move(boundBody));
}
// Wraps a parsed lambda in a LambdaExpression that holds a copy of the parsed expression and a
// unique name derived from its raw text. The body is not bound here.
std::shared_ptr<Expression> ExpressionBinder::bindLambdaExpression(
    const parser::ParsedExpression& parsedExpr) const {
    auto lambdaName = getUniqueName(parsedExpr.getRawName());
    auto parsedCopy = parsedExpr.copy();
    return std::make_shared<LambdaExpression>(std::move(parsedCopy), lambdaName);
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,43 @@
#include "binder/binder.h"
#include "binder/expression/literal_expression.h"
#include "binder/expression_binder.h"
#include "parser/expression/parsed_literal_expression.h"
using namespace lbug::parser;
using namespace lbug::common;
namespace lbug {
namespace binder {
// Binds a parsed literal. Null values go through the NULL-literal factory, every other value
// through the regular literal factory.
std::shared_ptr<Expression> ExpressionBinder::bindLiteralExpression(
    const ParsedExpression& parsedExpression) const {
    auto& parsedLiteral = parsedExpression.constCast<ParsedLiteralExpression>();
    auto literalValue = parsedLiteral.getValue();
    return literalValue.isNull() ? createNullLiteralExpression(literalValue) :
                                   createLiteralExpression(literalValue);
}
// Creates a LiteralExpression holding `value`. Its unique name is derived from the value's string
// form via Binder::getUniqueExpressionName.
std::shared_ptr<Expression> ExpressionBinder::createLiteralExpression(const Value& value) const {
    auto uniqueName = binder->getUniqueExpressionName(value.toString());
    // make_shared (as used by the sibling factories) allocates object and control block together,
    // instead of building a unique_ptr and converting it to shared_ptr afterwards.
    return std::make_shared<LiteralExpression>(value, uniqueName);
}
std::shared_ptr<Expression> ExpressionBinder::createLiteralExpression(
const std::string& strVal) const {
return createLiteralExpression(Value(strVal));
}
// Creates a literal expression around Value::createNullValue(), uniquely named from "NULL".
std::shared_ptr<Expression> ExpressionBinder::createNullLiteralExpression() const {
    auto nullName = binder->getUniqueExpressionName("NULL");
    return std::make_shared<LiteralExpression>(Value::createNullValue(), std::move(nullName));
}
// Creates a literal expression around the given `value`, uniquely named from "NULL".
// bindLiteralExpression calls this when `value.isNull()` is true.
std::shared_ptr<Expression> ExpressionBinder::createNullLiteralExpression(
    const Value& value) const {
    auto nullName = binder->getUniqueExpressionName("NULL");
    return std::make_shared<LiteralExpression>(value, std::move(nullName));
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,48 @@
#include "binder/expression/scalar_function_expression.h"
#include "binder/expression_binder.h"
#include "function/null/vector_null_functions.h"
using namespace lbug::common;
using namespace lbug::parser;
using namespace lbug::function;
namespace lbug {
namespace binder {
std::shared_ptr<Expression> ExpressionBinder::bindNullOperatorExpression(
const ParsedExpression& parsedExpression) {
expression_vector children;
for (auto i = 0u; i < parsedExpression.getNumChildren(); ++i) {
children.push_back(bindExpression(*parsedExpression.getChild(i)));
}
return bindNullOperatorExpression(parsedExpression.getExpressionType(), children);
}
// Builds a BOOL-returning ScalarFunctionExpression for a null operator over `children`.
// Children whose type is ANY are passed through implicitCastIfNecessary with BOOL as the target;
// all other children are used unchanged. The input type ids recorded for the function are the
// ones observed before that cast.
std::shared_ptr<Expression> ExpressionBinder::bindNullOperatorExpression(
    ExpressionType expressionType, const expression_vector& children) {
    std::vector<LogicalTypeID> inputTypeIDs;
    expression_vector castedChildren;
    for (auto& child : children) {
        // Capture the type id before the child is (possibly) cast.
        const auto childTypeID = child->getDataType().getLogicalTypeID();
        inputTypeIDs.push_back(childTypeID);
        if (childTypeID != LogicalTypeID::ANY) {
            castedChildren.push_back(child);
            continue;
        }
        castedChildren.push_back(implicitCastIfNecessary(child, LogicalType::BOOL()));
    }
    auto functionName = ExpressionTypeUtil::toString(expressionType);
    // Resolve the exec and select implementations for this operator and its children.
    function::scalar_func_exec_t execFunc;
    function::VectorNullFunction::bindExecFunction(expressionType, castedChildren, execFunc);
    function::scalar_func_select_t selectFunc;
    function::VectorNullFunction::bindSelectFunction(expressionType, castedChildren, selectFunc);
    auto bindData = std::make_unique<function::FunctionBindData>(LogicalType::BOOL());
    // The unique name must be computed before the children are moved into the expression.
    auto uniqueExpressionName =
        ScalarFunctionExpression::getUniqueName(functionName, castedChildren);
    auto scalarFunc = std::make_unique<ScalarFunction>(functionName, inputTypeIDs,
        LogicalTypeID::BOOL, execFunc, selectFunc);
    return std::make_shared<ScalarFunctionExpression>(expressionType, std::move(scalarFunc),
        std::move(bindData), std::move(castedChildren), uniqueExpressionName);
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,26 @@
#include "binder/expression/parameter_expression.h"
#include "binder/expression_binder.h"
#include "common/exception/binder.h"
#include "parser/expression/parsed_parameter_expression.h"
using namespace lbug::common;
using namespace lbug::parser;
namespace lbug {
namespace binder {
std::shared_ptr<Expression> ExpressionBinder::bindParameterExpression(
const ParsedExpression& parsedExpression) {
auto& parsedParameterExpression = parsedExpression.constCast<ParsedParameterExpression>();
auto parameterName = parsedParameterExpression.getParameterName();
if (knownParameters.contains(parameterName)) {
return make_shared<ParameterExpression>(parameterName, *knownParameters.at(parameterName));
}
// LCOV_EXCL_START
throw BinderException(
stringFormat("Cannot find parameter {}. This should not happen.", parameterName));
// LCOV_EXCL_STOP
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,124 @@
#include "binder/binder.h"
#include "binder/expression/expression_util.h"
#include "binder/expression/node_rel_expression.h"
#include "binder/expression_binder.h"
#include "common/cast.h"
#include "common/exception/binder.h"
#include "common/string_format.h"
#include "function/struct/vector_struct_functions.h"
#include "parser/expression/parsed_property_expression.h"
using namespace lbug::common;
using namespace lbug::parser;
using namespace lbug::catalog;
namespace lbug {
namespace binder {
// True when `expression` is classified as a node pattern or a rel pattern by ExpressionUtil.
static bool isNodeOrRelPattern(const Expression& expression) {
    if (ExpressionUtil::isNodePattern(expression)) {
        return true;
    }
    return ExpressionUtil::isRelPattern(expression);
}
// True when the expression's logical type is NODE, REL or STRUCT.
static bool isStructPattern(const Expression& expression) {
    switch (expression.getDataType().getLogicalTypeID()) {
    case LogicalTypeID::NODE:
    case LogicalTypeID::REL:
    case LogicalTypeID::STRUCT:
        return true;
    default:
        return false;
    }
}
// Expands a `child.*` expression into one expression per property/field of the bound child.
// Node/rel patterns and struct-like types are supported; anything else raises BinderException.
expression_vector ExpressionBinder::bindPropertyStarExpression(
    const parser::ParsedExpression& parsedExpression) {
    auto boundChild = bindExpression(*parsedExpression.getChild(0));
    // The node/rel pattern check must run first: isStructPattern also accepts NODE and REL types.
    if (isNodeOrRelPattern(*boundChild)) {
        return bindNodeOrRelPropertyStarExpression(*boundChild);
    }
    if (isStructPattern(*boundChild)) {
        return bindStructPropertyStarExpression(boundChild);
    }
    throw BinderException(stringFormat("Cannot bind property for expression {} with type {}.",
        boundChild->toString(), ExpressionTypeUtil::toString(boundChild->expressionType)));
}
// Returns the property expressions of a node/rel pattern, leaving out any property whose name is
// reserved for property lookup (see Binder::reservedInPropertyLookup).
expression_vector ExpressionBinder::bindNodeOrRelPropertyStarExpression(const Expression& child) {
    auto& pattern = child.constCast<NodeOrRelExpression>();
    expression_vector visibleProperties;
    for (auto& propertyExpr : pattern.getPropertyExpressions()) {
        if (!Binder::reservedInPropertyLookup(propertyExpr->getPropertyName())) {
            visibleProperties.push_back(propertyExpr);
        }
    }
    return visibleProperties;
}
// Produces one struct-property expression per field of `child`'s struct type, in field order.
expression_vector ExpressionBinder::bindStructPropertyStarExpression(
    const std::shared_ptr<Expression>& child) {
    const auto& structType = child->getDataType();
    expression_vector fieldExprs;
    for (auto& structField : StructType::getFields(structType)) {
        auto fieldExpr = bindStructPropertyExpression(child, structField.getName());
        fieldExprs.push_back(std::move(fieldExpr));
    }
    return fieldExprs;
}
std::shared_ptr<Expression> ExpressionBinder::bindPropertyExpression(
const ParsedExpression& parsedExpression) {
auto& propertyExpression = parsedExpression.constCast<ParsedPropertyExpression>();
if (propertyExpression.isStar()) {
throw BinderException(stringFormat("Cannot bind {} as a single property expression.",
parsedExpression.toString()));
}
auto propertyName = propertyExpression.getPropertyName();
auto child = bindExpression(*parsedExpression.getChild(0));
ExpressionUtil::validateDataType(*child,
std::vector<LogicalTypeID>{LogicalTypeID::NODE, LogicalTypeID::REL, LogicalTypeID::STRUCT,
LogicalTypeID::ANY});
if (config.bindOrderByAfterAggregate) {
// See the declaration of this field for more information.
return bindStructPropertyExpression(child, propertyName);
}
if (isNodeOrRelPattern(*child)) {
if (Binder::reservedInPropertyLookup(propertyName)) {
// Note we don't expose direct access to internal properties in case user tries to
// modify them. However, we can expose indirect read-only access through function e.g.
// ID().
throw BinderException(
propertyName + " is reserved for system usage. External access is not allowed.");
}
return bindNodeOrRelPropertyExpression(*child, propertyName);
} else if (isStructPattern(*child)) {
return bindStructPropertyExpression(child, propertyName);
} else if (child->getDataType().getLogicalTypeID() == LogicalTypeID::ANY) {
return createVariableExpression(LogicalType::ANY(), binder->getUniqueExpressionName(""));
} else {
throw BinderException(stringFormat("Cannot bind property for expression {} with type {}.",
child->toString(), ExpressionTypeUtil::toString(child->expressionType)));
}
}
std::shared_ptr<Expression> ExpressionBinder::bindNodeOrRelPropertyExpression(
const Expression& child, const std::string& propertyName) {
auto& nodeOrRel = child.constCast<NodeOrRelExpression>();
// TODO(Xiyang): we should be able to remove l97-l100 after removing propertyDataExprs from node
// & rel expression.
if (propertyName == InternalKeyword::ID &&
child.dataType.getLogicalTypeID() == common::LogicalTypeID::NODE) {
auto& node = ku_dynamic_cast<const NodeExpression&>(child);
return node.getInternalID();
}
if (!nodeOrRel.hasPropertyExpression(propertyName)) {
throw BinderException(
"Cannot find property " + propertyName + " for " + child.toString() + ".");
}
// We always create new object when binding expression except when referring to an existing
// alias when binding variables.
return nodeOrRel.getPropertyExpression(propertyName)->copy();
}
std::shared_ptr<Expression> ExpressionBinder::bindStructPropertyExpression(
std::shared_ptr<Expression> child, const std::string& propertyName) {
auto children = expression_vector{std::move(child), createLiteralExpression(propertyName)};
return bindScalarFunctionExpression(children, function::StructExtractFunctions::name);
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,74 @@
#include "binder/binder.h"
#include "binder/expression/aggregate_function_expression.h"
#include "binder/expression/subquery_expression.h"
#include "binder/expression_binder.h"
#include "catalog/catalog.h"
#include "common/types/value/value.h"
#include "function/aggregate/count_star.h"
#include "function/built_in_function_utils.h"
#include "parser/expression/parsed_subquery_expression.h"
#include "transaction/transaction.h"
using namespace lbug::parser;
using namespace lbug::common;
using namespace lbug::function;
namespace lbug {
namespace binder {
// Binds a COUNT or EXISTS subquery expression.
// The subquery's graph pattern (and optional WHERE clause) is bound while the caller's scope is
// saved; that scope is restored just before returning. The resulting SubqueryExpression carries:
//   - the bound query graph collection and WHERE predicate,
//   - a COUNT(*) aggregate expression,
//   - a projection expression (COUNT(*) for COUNT, COUNT(*) > 0 for EXISTS) that shares the
//     subquery's unique name,
//   - an optional join hint.
// The result type is INT64 for COUNT subqueries and BOOL otherwise.
// NOTE(review): the scope is only restored on the normal return path; if any bind step throws,
// restoreScope is not reached in this function — confirm callers do not reuse the binder then.
std::shared_ptr<Expression> ExpressionBinder::bindSubqueryExpression(
const ParsedExpression& parsedExpr) {
auto& subqueryExpr = ku_dynamic_cast<const ParsedSubqueryExpression&>(parsedExpr);
// Snapshot the scope before the pattern is bound; it is restored at the end of this function.
auto prevScope = binder->saveScope();
auto boundGraphPattern = binder->bindGraphPattern(subqueryExpr.getPatternElements());
if (subqueryExpr.hasWhereClause()) {
boundGraphPattern.where = binder->bindWhereExpression(*subqueryExpr.getWhereClause());
}
binder->rewriteMatchPattern(boundGraphPattern);
auto subqueryType = subqueryExpr.getSubqueryType();
// COUNT yields an INT64; every other subquery type handled here yields a BOOL.
auto dataType =
subqueryType == SubqueryType::COUNT ? LogicalType::INT64() : LogicalType::BOOL();
auto rawName = subqueryExpr.getRawName();
auto uniqueName = binder->getUniqueExpressionName(rawName);
auto boundSubqueryExpr = make_shared<SubqueryExpression>(subqueryType, std::move(dataType),
std::move(boundGraphPattern.queryGraphCollection), uniqueName, std::move(rawName));
boundSubqueryExpr->setWhereExpression(boundGraphPattern.where);
// Bind projection
// Look up the COUNT(*) aggregate in the catalog and match it with an empty input type list.
auto entry = catalog::Catalog::Get(*context)->getFunctionEntry(
transaction::Transaction::Get(*context), CountStarFunction::name);
auto function = BuiltInFunctionsUtils::matchAggregateFunction(CountStarFunction::name,
std::vector<LogicalType>{}, false, entry->ptrCast<catalog::FunctionCatalogEntry>());
auto bindData = std::make_unique<FunctionBindData>(LogicalType(function->returnTypeID));
auto countStarExpr =
std::make_shared<AggregateFunctionExpression>(function->copy(), std::move(bindData),
expression_vector{}, binder->getUniqueExpressionName(CountStarFunction::name));
boundSubqueryExpr->setCountStarExpr(countStarExpr);
std::shared_ptr<Expression> projectionExpr;
switch (subqueryType) {
case SubqueryType::COUNT: {
// Rewrite COUNT subquery as COUNT(*)
projectionExpr = countStarExpr;
} break;
case SubqueryType::EXISTS: {
// Rewrite EXISTS subquery as COUNT(*) > 0
auto literalExpr = createLiteralExpression(Value(static_cast<int64_t>(0)));
projectionExpr = bindComparisonExpression(ExpressionType::GREATER_THAN,
expression_vector{countStarExpr, literalExpr});
} break;
default:
KU_UNREACHABLE;
}
// Use the same unique identifier for projection & subquery expression. We will replace subquery
// expression with projection expression during processing.
// For COUNT, projectionExpr is the same object as countStarExpr, so this renames that
// expression as well.
projectionExpr->setUniqueName(uniqueName);
boundSubqueryExpr->setProjectionExpr(projectionExpr);
if (subqueryExpr.hasHint()) {
boundSubqueryExpr->setHint(binder->bindJoinHint(
*boundSubqueryExpr->getQueryGraphCollection(), *subqueryExpr.getHint()));
}
binder->restoreScope(std::move(prevScope));
return boundSubqueryExpr;
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,41 @@
#include "binder/binder.h"
#include "binder/expression/variable_expression.h"
#include "binder/expression_binder.h"
#include "common/exception/binder.h"
#include "common/exception/message.h"
#include "parser/expression/parsed_variable_expression.h"
using namespace lbug::common;
using namespace lbug::parser;
namespace lbug {
namespace binder {
std::shared_ptr<Expression> ExpressionBinder::bindVariableExpression(
const ParsedExpression& parsedExpression) const {
auto& variableExpression = ku_dynamic_cast<const ParsedVariableExpression&>(parsedExpression);
auto variableName = variableExpression.getVariableName();
return bindVariableExpression(variableName);
}
// Returns the expression registered under `varName` in the binder's current scope.
// Throws BinderException (variable-not-in-scope message) when the name is unknown.
std::shared_ptr<Expression> ExpressionBinder::bindVariableExpression(
    const std::string& varName) const {
    if (!binder->scope.contains(varName)) {
        throw BinderException(ExceptionMessage::variableNotInScope(varName));
    }
    return binder->scope.getExpression(varName);
}
// string_view convenience overload: materializes the name as a std::string and forwards to the
// std::string overload.
std::shared_ptr<Expression> ExpressionBinder::createVariableExpression(LogicalType logicalType,
    std::string_view name) const {
    std::string ownedName(name);
    return createVariableExpression(std::move(logicalType), std::move(ownedName));
}
// Creates a VariableExpression of the given type. Its unique name is generated from `name` via
// Binder::getUniqueExpressionName, and `name` itself is passed on as the third argument.
std::shared_ptr<Expression> ExpressionBinder::createVariableExpression(LogicalType logicalType,
    std::string name) const {
    // Generate the unique name while `name` is still intact, then hand `name` over.
    auto uniqueName = binder->getUniqueExpressionName(name);
    return std::make_shared<VariableExpression>(std::move(logicalType), std::move(uniqueName),
        std::move(name));
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,278 @@
#include "binder/binder.h"
#include "binder/bound_statement_rewriter.h"
#include "catalog/catalog.h"
#include "common/copier_config/csv_reader_config.h"
#include "common/exception/binder.h"
#include "common/string_format.h"
#include "common/string_utils.h"
#include "function/built_in_function_utils.h"
#include "function/table/table_function.h"
#include "parser/statement.h"
#include "processor/operator/persistent/reader/csv/parallel_csv_reader.h"
#include "processor/operator/persistent/reader/csv/serial_csv_reader.h"
#include "processor/operator/persistent/reader/npy/npy_reader.h"
#include "processor/operator/persistent/reader/parquet/parquet_reader.h"
#include "transaction/transaction.h"
using namespace lbug::catalog;
using namespace lbug::common;
using namespace lbug::function;
using namespace lbug::parser;
using namespace lbug::processor;
namespace lbug {
namespace binder {
std::unique_ptr<BoundStatement> Binder::bind(const Statement& statement) {
std::unique_ptr<BoundStatement> boundStatement;
switch (statement.getStatementType()) {
case StatementType::CREATE_TABLE: {
boundStatement = bindCreateTable(statement);
} break;
case StatementType::CREATE_TYPE: {
boundStatement = bindCreateType(statement);
} break;
case StatementType::CREATE_SEQUENCE: {
boundStatement = bindCreateSequence(statement);
} break;
case StatementType::COPY_FROM: {
boundStatement = bindCopyFromClause(statement);
} break;
case StatementType::COPY_TO: {
boundStatement = bindCopyToClause(statement);
} break;
case StatementType::DROP: {
boundStatement = bindDrop(statement);
} break;
case StatementType::ALTER: {
boundStatement = bindAlter(statement);
} break;
case StatementType::QUERY: {
boundStatement = bindQuery(statement);
} break;
case StatementType::STANDALONE_CALL: {
boundStatement = bindStandaloneCall(statement);
} break;
case StatementType::STANDALONE_CALL_FUNCTION: {
boundStatement = bindStandaloneCallFunction(statement);
} break;
case StatementType::EXPLAIN: {
boundStatement = bindExplain(statement);
} break;
case StatementType::CREATE_MACRO: {
boundStatement = bindCreateMacro(statement);
} break;
case StatementType::TRANSACTION: {
boundStatement = bindTransaction(statement);
} break;
case StatementType::EXTENSION: {
boundStatement = bindExtension(statement);
} break;
case StatementType::EXPORT_DATABASE: {
boundStatement = bindExportDatabaseClause(statement);
} break;
case StatementType::IMPORT_DATABASE: {
boundStatement = bindImportDatabaseClause(statement);
} break;
case StatementType::ATTACH_DATABASE: {
boundStatement = bindAttachDatabase(statement);
} break;
case StatementType::DETACH_DATABASE: {
boundStatement = bindDetachDatabase(statement);
} break;
case StatementType::USE_DATABASE: {
boundStatement = bindUseDatabase(statement);
} break;
case StatementType::EXTENSION_CLAUSE: {
boundStatement = bindExtensionClause(statement);
} break;
default: {
KU_UNREACHABLE;
}
}
BoundStatementRewriter::rewrite(*boundStatement, *clientContext);
return boundStatement;
}
// Binds a WHERE predicate and implicitly casts it to BOOL when necessary.
// Returns the expression produced by the cast step, so that a cast which wraps the predicate in a
// new expression (rather than retyping it in place) is not lost.
std::shared_ptr<Expression> Binder::bindWhereExpression(const ParsedExpression& parsedExpression) {
    auto whereExpression = expressionBinder.bindExpression(parsedExpression);
    // implicitCastIfNecessary returns the (possibly new) cast expression; previously that result
    // was discarded and the un-cast predicate was returned.
    return expressionBinder.implicitCastIfNecessary(whereExpression, LogicalType::BOOL());
}
std::shared_ptr<Expression> Binder::createVariable(std::string_view name, LogicalTypeID typeID) {
return createVariable(std::string(name), LogicalType{typeID});
}
std::shared_ptr<Expression> Binder::createVariable(const std::string& name,
LogicalTypeID logicalTypeID) {
return createVariable(name, LogicalType{logicalTypeID});
}
std::shared_ptr<Expression> Binder::createVariable(const std::string& name,
const LogicalType& dataType) {
if (scope.contains(name)) {
throw BinderException("Variable " + name + " already exists.");
}
auto expression = expressionBinder.createVariableExpression(dataType.copy(), name);
expression->setAlias(name);
addToScope(name, expression);
return expression;
}
std::shared_ptr<Expression> Binder::createInvisibleVariable(const std::string& name,
const LogicalType& dataType) const {
auto expression = expressionBinder.createVariableExpression(dataType.copy(), name);
expression->setAlias(name);
return expression;
}
// Creates one scoped variable per (name, type) pair, in order. `names` and `types` must have the
// same length (asserted).
expression_vector Binder::createVariables(const std::vector<std::string>& names,
    const std::vector<common::LogicalType>& types) {
    KU_ASSERT(names.size() == types.size());
    expression_vector created;
    auto idx = 0u;
    while (idx < names.size()) {
        created.push_back(createVariable(names[idx], types[idx]));
        ++idx;
    }
    return created;
}
// Creates one invisible (not added to scope) variable per (name, type) pair, in order. `names`
// and `types` must have the same length (asserted).
expression_vector Binder::createInvisibleVariables(const std::vector<std::string>& names,
    const std::vector<LogicalType>& types) const {
    KU_ASSERT(names.size() == types.size());
    expression_vector created;
    auto idx = 0u;
    while (idx < names.size()) {
        created.push_back(createInvisibleVariable(names[idx], types[idx]));
        ++idx;
    }
    return created;
}
// Returns "_<id>_<name>", where <id> is the current value of lastExpressionId; the counter is
// incremented as a side effect.
std::string Binder::getUniqueExpressionName(const std::string& name) {
    std::string uniqueName = "_";
    uniqueName += std::to_string(lastExpressionId++);
    uniqueName += "_";
    uniqueName += name;
    return uniqueName;
}
struct ReservedNames {
// Column name that might conflict with internal names.
static std::unordered_set<std::string> getColumnNames() {
return {
InternalKeyword::ID,
InternalKeyword::LABEL,
InternalKeyword::SRC,
InternalKeyword::DST,
InternalKeyword::DIRECTION,
InternalKeyword::LENGTH,
InternalKeyword::NODES,
InternalKeyword::RELS,
InternalKeyword::PLACE_HOLDER,
StringUtils::getUpper(InternalKeyword::ROW_OFFSET),
StringUtils::getUpper(InternalKeyword::SRC_OFFSET),
StringUtils::getUpper(InternalKeyword::DST_OFFSET),
};
}
// Properties that should be hidden from user access.
static std::unordered_set<std::string> getPropertyLookupName() {
return {
InternalKeyword::ID,
};
}
};
// True when the upper-cased `name` is one of ReservedNames::getColumnNames().
bool Binder::reservedInColumnName(const std::string& name) {
    const auto reserved = ReservedNames::getColumnNames();
    return reserved.contains(StringUtils::getUpper(name));
}
// True when the upper-cased `name` is one of ReservedNames::getPropertyLookupName().
bool Binder::reservedInPropertyLookup(const std::string& name) {
    const auto reserved = ReservedNames::getPropertyLookupName();
    return reserved.contains(StringUtils::getUpper(name));
}
// Registers each (name, expression) pair in the scope, in order. `names` and `exprs` must have
// the same length (asserted).
void Binder::addToScope(const std::vector<std::string>& names, const expression_vector& exprs) {
    KU_ASSERT(names.size() == exprs.size());
    auto idx = 0u;
    while (idx < names.size()) {
        addToScope(names[idx], exprs[idx]);
        ++idx;
    }
}
// Registers `expr` under `name` in the current scope.
void Binder::addToScope(const std::string& name, std::shared_ptr<Expression> expr) {
    // TODO(Xiyang): assert name not in scope.
    // Note to Xiyang: this TODO may no longer stand. Adding the assertion was tried and it failed
    // a few tests, so the TODO should be revisited before acting on it.
    scope.addExpression(name, std::move(expr));
}
// Returns a copy of the current scope, suitable for a later restoreScope().
BinderScope Binder::saveScope() const {
    auto snapshot = scope.copy();
    return snapshot;
}
// Replaces the current scope with `prevScope` (typically a snapshot taken by saveScope()).
void Binder::restoreScope(BinderScope prevScope) {
    this->scope = std::move(prevScope);
}
// Replaces the scope entry registered under `oldName` with `expression`, re-registering it under
// `newName`. See BinderScope::replaceExpression for the exact semantics.
void Binder::replaceExpressionInScope(const std::string& oldName, const std::string& newName,
    std::shared_ptr<Expression> expression) {
    // `expression` is a by-value sink parameter: move it on instead of copying the shared_ptr
    // (which would cost an extra atomic reference-count increment/decrement).
    scope.replaceExpression(oldName, newName, std::move(expression));
}
// Selects the table scan function to use for a file of the given type.
// The function is looked up by name in the catalog and matched against a single STRING input
// argument. A copy of the matched TableFunction is returned.
//   - PARQUET / NPY: the dedicated scan function.
//   - CSV: the parallel CSV scan if the config allows it and none of the files is compressed,
//     otherwise the serial CSV scan.
//   - UNKNOWN: a function named "<fileTypeStr>_SCAN"; any failure while resolving it is
//     reported as a BinderException.
TableFunction Binder::getScanFunction(const FileTypeInfo& typeInfo,
const FileScanInfo& fileScanInfo) const {
Function* func = nullptr;
// Every scan function handled here is matched against one STRING argument.
std::vector<LogicalType> inputTypes;
inputTypes.push_back(LogicalType::STRING());
auto catalog = Catalog::Get(*clientContext);
auto transaction = transaction::Transaction::Get(*clientContext);
switch (typeInfo.fileType) {
case FileType::PARQUET: {
auto entry = catalog->getFunctionEntry(transaction, ParquetScanFunction::name);
func = BuiltInFunctionsUtils::matchFunction(ParquetScanFunction::name, inputTypes,
entry->ptrCast<FunctionCatalogEntry>());
} break;
case FileType::NPY: {
auto entry = catalog->getFunctionEntry(transaction, NpyScanFunction::name);
func = BuiltInFunctionsUtils::matchFunction(NpyScanFunction::name, inputTypes,
entry->ptrCast<FunctionCatalogEntry>());
} break;
case FileType::CSV: {
// True if at least one of the input files is reported as compressed by the file system.
bool containCompressedCSV = std::any_of(fileScanInfo.filePaths.begin(),
fileScanInfo.filePaths.end(), [&](const auto& file) {
return VirtualFileSystem::GetUnsafe(*clientContext)->isCompressedFile(file);
});
auto csvConfig = CSVReaderConfig::construct(fileScanInfo.options);
// Parallel CSV scanning is only allowed:
// 1. No newline character inside the csv body.
// 2. The CSV file to scan is not compressed (because we couldn't perform seek in such
// case).
// 3. Not explicitly set by the user to use the serial csv reader.
auto name = (csvConfig.parallel && !containCompressedCSV) ? ParallelCSVScan::name :
SerialCSVScan::name;
auto entry = catalog->getFunctionEntry(transaction, name);
func = BuiltInFunctionsUtils::matchFunction(name, inputTypes,
entry->ptrCast<FunctionCatalogEntry>());
} break;
case FileType::UNKNOWN: {
try {
// Fall back to a scan function named after the file type string.
auto name = stringFormat("{}_SCAN", typeInfo.fileTypeStr);
auto entry = catalog->getFunctionEntry(transaction, name);
func = BuiltInFunctionsUtils::matchFunction(name, inputTypes,
entry->ptrCast<FunctionCatalogEntry>());
} catch (...) {
// NOTE(review): this catches every exception type, so unrelated failures are also
// reported as one of the two messages below.
if (typeInfo.fileTypeStr == "") {
throw BinderException{"Cannot infer the format of the given file. Please "
"set the file format explicitly by (file_format=<type>)."};
}
throw BinderException{
stringFormat("Cannot load from file type {}. If this file type is part of a lbug "
"extension please load the extension then try again.",
typeInfo.fileTypeStr)};
}
} break;
default:
KU_UNREACHABLE;
}
// Returned by value: the caller receives its own copy of the matched function.
return *func->ptrCast<TableFunction>();
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,27 @@
#include "binder/binder_scope.h"
namespace lbug {
namespace binder {
// Appends `expression` to the scope and maps `varName` to its position.
// The name-to-index map is filled with insert(), which leaves an existing entry for `varName`
// untouched; the expression is appended to `expressions` either way.
void BinderScope::addExpression(const std::string& varName,
    std::shared_ptr<Expression> expression) {
    const auto nextIdx = expressions.size();
    nameToExprIdx.insert({varName, nextIdx});
    expressions.push_back(std::move(expression));
}
// Overwrites the expression stored for `oldName` with `expression` and re-keys its slot from
// `oldName` to `newName`. `oldName` must be present in the scope (asserted).
void BinderScope::replaceExpression(const std::string& oldName, const std::string& newName,
    std::shared_ptr<Expression> expression) {
    KU_ASSERT(nameToExprIdx.contains(oldName));
    // Copy the slot index out before the old map entry is erased.
    const auto slot = nameToExprIdx.at(oldName);
    expressions[slot] = std::move(expression);
    nameToExprIdx.erase(oldName);
    nameToExprIdx.insert({newName, slot});
}
// Empties the scope: drops the name-to-index map and all stored expressions.
void BinderScope::clear() {
    nameToExprIdx.clear();
    expressions.clear();
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,36 @@
#include "binder/bound_scan_source.h"
using namespace lbug::common;
namespace lbug {
namespace binder {
// Returns the warning-data columns of a FILE scan source: the last `numWarningDataColumns`
// entries of the bind data's column list, in their original order. Every other source type
// yields an empty vector.
expression_vector BoundTableScanSource::getWarningColumns() const {
    expression_vector warningColumns;
    auto& columns = info.bindData->columns;
    if (type == ScanSourceType::FILE) {
        auto fileBindData = info.bindData->constPtrCast<function::ScanFileBindData>();
        // Walk the trailing columns by their distance from the end of the list.
        for (auto fromEnd = fileBindData->numWarningDataColumns; fromEnd >= 1; --fromEnd) {
            KU_ASSERT(fromEnd < columns.size());
            warningColumns.push_back(columns[columns.size() - fromEnd]);
        }
    }
    return warningColumns;
}
// Forwards to the bind data's getIgnoreErrorsOption().
bool BoundTableScanSource::getIgnoreErrorsOption() const {
    const bool ignoreErrors = info.bindData->getIgnoreErrorsOption();
    return ignoreErrors;
}
// Returns the ignore-errors option stored in `info.options` as a bool, or
// CopyConstants::DEFAULT_IGNORE_ERRORS when the option is absent.
bool BoundQueryScanSource::getIgnoreErrorsOption() const {
    if (!info.options.contains(CopyConstants::IGNORE_ERRORS_OPTION_NAME)) {
        return CopyConstants::DEFAULT_IGNORE_ERRORS;
    }
    return info.options.at(CopyConstants::IGNORE_ERRORS_OPTION_NAME).getValue<bool>();
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,20 @@
#include "binder/bound_statement_result.h"
#include "binder/expression/literal_expression.h"
using namespace lbug::common;
namespace lbug {
namespace binder {
// Builds a result with exactly one column named `columnName`. The column is a STRING literal
// expression whose value and unique name are both `columnName`.
BoundStatementResult BoundStatementResult::createSingleStringColumnResult(
    const std::string& columnName) {
    auto columnValue = Value(LogicalType::STRING(), columnName);
    auto columnExpr = std::make_shared<LiteralExpression>(std::move(columnValue), columnName);
    auto singleColumnResult = BoundStatementResult();
    singleColumnResult.addColumn(columnName, columnExpr);
    return singleColumnResult;
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,27 @@
#include "binder/bound_statement_rewriter.h"
#include "binder/rewriter/match_clause_pattern_label_rewriter.h"
#include "binder/rewriter/normalized_query_part_match_rewriter.h"
#include "binder/rewriter/with_clause_projection_rewriter.h"
#include "binder/visitor/default_type_solver.h"
namespace lbug {
namespace binder {
// Runs the post-bind passes over `boundStatement`, in this fixed order:
//   1. WithClauseProjectionRewriter
//   2. NormalizedQueryPartMatchRewriter
//   3. MatchClausePatternLabelRewriter
//   4. DefaultTypeSolver
// The three rewriters go through visitUnsafe(); the type solver goes through visit().
void BoundStatementRewriter::rewrite(BoundStatement& boundStatement,
    main::ClientContext& clientContext) {
    WithClauseProjectionRewriter projectionPass{};
    projectionPass.visitUnsafe(boundStatement);

    NormalizedQueryPartMatchRewriter matchPass(&clientContext);
    matchPass.visitUnsafe(boundStatement);

    MatchClausePatternLabelRewriter labelPass(clientContext);
    labelPass.visitUnsafe(boundStatement);

    DefaultTypeSolver typeSolver{};
    typeSolver.visit(boundStatement);
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,222 @@
#include "binder/bound_statement_visitor.h"
#include "binder/bound_explain.h"
#include "binder/copy/bound_copy_from.h"
#include "binder/copy/bound_copy_to.h"
#include "binder/query/bound_regular_query.h"
#include "common/cast.h"
using namespace lbug::common;
namespace lbug {
namespace binder {
// Read-only dispatch: forwards `statement` to the visit hook matching its statement type.
// A statement type without a case below reaches KU_UNREACHABLE.
void BoundStatementVisitor::visit(const BoundStatement& statement) {
    switch (statement.getStatementType()) {
    case StatementType::QUERY:
        visitRegularQuery(statement);
        break;
    case StatementType::CREATE_SEQUENCE:
        visitCreateSequence(statement);
        break;
    case StatementType::DROP:
        visitDrop(statement);
        break;
    case StatementType::CREATE_TABLE:
        visitCreateTable(statement);
        break;
    case StatementType::CREATE_TYPE:
        visitCreateType(statement);
        break;
    case StatementType::ALTER:
        visitAlter(statement);
        break;
    case StatementType::COPY_FROM:
        visitCopyFrom(statement);
        break;
    case StatementType::COPY_TO:
        visitCopyTo(statement);
        break;
    case StatementType::STANDALONE_CALL:
        visitStandaloneCall(statement);
        break;
    case StatementType::EXPLAIN:
        visitExplain(statement);
        break;
    case StatementType::CREATE_MACRO:
        visitCreateMacro(statement);
        break;
    case StatementType::TRANSACTION:
        visitTransaction(statement);
        break;
    case StatementType::EXTENSION:
        visitExtension(statement);
        break;
    case StatementType::EXPORT_DATABASE:
        visitExportDatabase(statement);
        break;
    case StatementType::IMPORT_DATABASE:
        visitImportDatabase(statement);
        break;
    case StatementType::ATTACH_DATABASE:
        visitAttachDatabase(statement);
        break;
    case StatementType::DETACH_DATABASE:
        visitDetachDatabase(statement);
        break;
    case StatementType::USE_DATABASE:
        visitUseDatabase(statement);
        break;
    case StatementType::STANDALONE_CALL_FUNCTION:
        visitStandaloneCallFunction(statement);
        break;
    case StatementType::EXTENSION_CLAUSE:
        visitExtensionClause(statement);
        break;
    default:
        KU_UNREACHABLE;
    }
}
// Mutable dispatch: only QUERY statements are visited; every other statement type is ignored.
void BoundStatementVisitor::visitUnsafe(BoundStatement& statement) {
    if (statement.getStatementType() == StatementType::QUERY) {
        visitRegularQueryUnsafe(statement);
    }
}
void BoundStatementVisitor::visitCopyFrom(const BoundStatement& statement) {
auto& copyFrom = ku_dynamic_cast<const BoundCopyFrom&>(statement);
if (copyFrom.getInfo()->source->type == ScanSourceType::QUERY) {
auto querySource = ku_dynamic_cast<BoundQueryScanSource*>(copyFrom.getInfo()->source.get());
visit(*querySource->statement);
}
}
// Visits the regular query held by a COPY TO statement.
void BoundStatementVisitor::visitCopyTo(const BoundStatement& statement) {
    auto& boundCopyTo = ku_dynamic_cast<const BoundCopyTo&>(statement);
    visitRegularQuery(*boundCopyTo.getRegularQuery());
}
// Visits every single query of a regular query, in index order.
void BoundStatementVisitor::visitRegularQuery(const BoundStatement& statement) {
    auto& query = ku_dynamic_cast<const BoundRegularQuery&>(statement);
    auto queryIdx = 0u;
    while (queryIdx < query.getNumSingleQueries()) {
        visitSingleQuery(*query.getSingleQuery(queryIdx));
        ++queryIdx;
    }
}
// Mutable counterpart of visitRegularQuery: visits every single query through its unsafe
// accessor, in index order.
void BoundStatementVisitor::visitRegularQueryUnsafe(BoundStatement& statement) {
    auto& query = statement.cast<BoundRegularQuery>();
    auto queryIdx = 0u;
    while (queryIdx < query.getNumSingleQueries()) {
        visitSingleQueryUnsafe(*query.getSingleQueryUnsafe(queryIdx));
        ++queryIdx;
    }
}
// Visits every query part of a single query, in index order.
void BoundStatementVisitor::visitSingleQuery(const NormalizedSingleQuery& singleQuery) {
    auto partIdx = 0u;
    while (partIdx < singleQuery.getNumQueryParts()) {
        visitQueryPart(*singleQuery.getQueryPart(partIdx));
        ++partIdx;
    }
}
// Mutable counterpart of visitSingleQuery: visits every query part through its unsafe accessor,
// in index order.
void BoundStatementVisitor::visitSingleQueryUnsafe(NormalizedSingleQuery& singleQuery) {
    auto partIdx = 0u;
    while (partIdx < singleQuery.getNumQueryParts()) {
        visitQueryPartUnsafe(*singleQuery.getQueryPartUnsafe(partIdx));
        ++partIdx;
    }
}
// Visits a query part: all reading clauses, then all updating clauses, then (if present) the
// projection body followed by its predicate (if present).
void BoundStatementVisitor::visitQueryPart(const NormalizedQueryPart& queryPart) {
    for (auto readIdx = 0u; readIdx < queryPart.getNumReadingClause(); ++readIdx) {
        visitReadingClause(*queryPart.getReadingClause(readIdx));
    }
    for (auto updateIdx = 0u; updateIdx < queryPart.getNumUpdatingClause(); ++updateIdx) {
        visitUpdatingClause(*queryPart.getUpdatingClause(updateIdx));
    }
    if (!queryPart.hasProjectionBody()) {
        return;
    }
    visitProjectionBody(*queryPart.getProjectionBody());
    // The predicate is only considered when a projection body exists.
    if (queryPart.hasProjectionBodyPredicate()) {
        visitProjectionBodyPredicate(queryPart.getProjectionBodyPredicate());
    }
}
// Mutable counterpart of visitQueryPart. Only reading clauses go through the unsafe hook;
// updating clauses, the projection body and its predicate use the regular hooks.
void BoundStatementVisitor::visitQueryPartUnsafe(NormalizedQueryPart& queryPart) {
    for (auto readIdx = 0u; readIdx < queryPart.getNumReadingClause(); ++readIdx) {
        visitReadingClauseUnsafe(*queryPart.getReadingClause(readIdx));
    }
    for (auto updateIdx = 0u; updateIdx < queryPart.getNumUpdatingClause(); ++updateIdx) {
        visitUpdatingClause(*queryPart.getUpdatingClause(updateIdx));
    }
    if (!queryPart.hasProjectionBody()) {
        return;
    }
    visitProjectionBody(*queryPart.getProjectionBody());
    // The predicate is only considered when a projection body exists.
    if (queryPart.hasProjectionBodyPredicate()) {
        visitProjectionBodyPredicate(queryPart.getProjectionBodyPredicate());
    }
}
void BoundStatementVisitor::visitExplain(const BoundStatement& statement) {
visit(*(statement.constCast<BoundExplain>()).getStatementToExplain());
}
// Dispatches a reading clause to the hook for its clause type (MATCH, UNWIND,
// TABLE_FUNCTION_CALL or LOAD_FROM). Any other clause type reaches KU_UNREACHABLE.
void BoundStatementVisitor::visitReadingClause(const BoundReadingClause& readingClause) {
    switch (readingClause.getClauseType()) {
    case ClauseType::MATCH:
        visitMatch(readingClause);
        break;
    case ClauseType::UNWIND:
        visitUnwind(readingClause);
        break;
    case ClauseType::TABLE_FUNCTION_CALL:
        visitTableFunctionCall(readingClause);
        break;
    case ClauseType::LOAD_FROM:
        visitLoadFrom(readingClause);
        break;
    default:
        KU_UNREACHABLE;
    }
}
// Mutable counterpart of visitReadingClause. Only MATCH goes through its unsafe hook; UNWIND,
// TABLE_FUNCTION_CALL and LOAD_FROM use the regular hooks. Any other clause type reaches
// KU_UNREACHABLE.
void BoundStatementVisitor::visitReadingClauseUnsafe(BoundReadingClause& readingClause) {
    switch (readingClause.getClauseType()) {
    case ClauseType::MATCH:
        visitMatchUnsafe(readingClause);
        break;
    case ClauseType::UNWIND:
        visitUnwind(readingClause);
        break;
    case ClauseType::TABLE_FUNCTION_CALL:
        visitTableFunctionCall(readingClause);
        break;
    case ClauseType::LOAD_FROM:
        visitLoadFrom(readingClause);
        break;
    default:
        KU_UNREACHABLE;
    }
}
// Dispatches an updating clause to the visitor that matches its clause type. Any clause type
// other than SET, DELETE_, INSERT or MERGE hits KU_UNREACHABLE.
void BoundStatementVisitor::visitUpdatingClause(const BoundUpdatingClause& updatingClause) {
    switch (updatingClause.getClauseType()) {
    case ClauseType::SET:
        visitSet(updatingClause);
        return;
    case ClauseType::DELETE_:
        visitDelete(updatingClause);
        return;
    case ClauseType::INSERT:
        visitInsert(updatingClause);
        return;
    case ClauseType::MERGE:
        visitMerge(updatingClause);
        return;
    default:
        KU_UNREACHABLE;
    }
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,9 @@
# Object library holding the binder's DDL sources (compiled, but not archived or linked here).
add_library(
lbug_binder_ddl
OBJECT
bound_alter_info.cpp
property_definition.cpp)
# Append this library's object files to ALL_OBJECT_FILES and publish the updated list to the
# parent directory scope.
set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:lbug_binder_ddl>
PARENT_SCOPE)

View File

@@ -0,0 +1,43 @@
#include "binder/ddl/bound_alter_info.h"
namespace lbug {
namespace binder {
// Builds a human-readable description of this ALTER operation. Alter types without a dedicated
// case yield just the "Operation: " prefix.
std::string BoundAlterInfo::toString() const {
    std::string description = "Operation: ";
    switch (alterType) {
    case common::AlterType::RENAME: {
        auto* info = common::ku_dynamic_cast<BoundExtraRenameTableInfo*>(extraInfo.get());
        description += "Rename Table ";
        description += tableName;
        description += " to ";
        description += info->newName;
    } break;
    case common::AlterType::ADD_PROPERTY: {
        auto* info = common::ku_dynamic_cast<BoundExtraAddPropertyInfo*>(extraInfo.get());
        description += "Add Property ";
        description += info->propertyDefinition.getName();
        description += " to Table ";
        description += tableName;
    } break;
    case common::AlterType::DROP_PROPERTY: {
        auto* info = common::ku_dynamic_cast<BoundExtraDropPropertyInfo*>(extraInfo.get());
        description += "Drop Property ";
        description += info->propertyName;
        description += " from Table ";
        description += tableName;
    } break;
    case common::AlterType::RENAME_PROPERTY: {
        auto* info = common::ku_dynamic_cast<BoundExtraRenamePropertyInfo*>(extraInfo.get());
        description += "Rename Property ";
        description += info->oldName;
        description += " to ";
        description += info->newName;
        description += " in Table ";
        description += tableName;
    } break;
    case common::AlterType::COMMENT: {
        description += "Comment on Table ";
        description += tableName;
    } break;
    default:
        break;
    }
    return description;
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,34 @@
#include "binder/ddl/property_definition.h"
#include "common/serializer/deserializer.h"
#include "common/serializer/serializer.h"
#include "parser/expression/parsed_literal_expression.h"
using namespace lbug::common;
using namespace lbug::parser;
namespace lbug {
namespace binder {
// Creates a property definition from a column definition; the default expression is a parsed
// literal built from a null value with the raw text "NULL".
PropertyDefinition::PropertyDefinition(ColumnDefinition columnDefinition)
    : columnDefinition{std::move(columnDefinition)} {
    auto nullLiteral =
        std::make_unique<ParsedLiteralExpression>(Value::createNullValue(), "NULL");
    defaultExpr = std::move(nullLiteral);
}
// Writes the property in the order deserialize() reads it back: column name, column type, then
// the default expression.
void PropertyDefinition::serialize(Serializer& serializer) const {
    const auto& column = columnDefinition;
    serializer.serializeValue(column.name);
    column.type.serialize(serializer);
    defaultExpr->serialize(serializer);
}
// Reads a property definition back from the deserializer: column name, column type, then the
// default expression, in that order.
PropertyDefinition PropertyDefinition::deserialize(Deserializer& deserializer) {
    std::string columnName;
    deserializer.deserializeValue(columnName);
    auto columnType = LogicalType::deserialize(deserializer);
    ColumnDefinition column(columnName, std::move(columnType));
    auto defaultValueExpr = ParsedExpression::deserialize(deserializer);
    return PropertyDefinition(std::move(column), std::move(defaultValueExpr));
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,19 @@
# Object library holding the binder's expression sources (compiled, but not archived or linked
# here).
add_library(
lbug_binder_expression
OBJECT
aggregate_function_expression.cpp
case_expression.cpp
expression.cpp
expression_util.cpp
literal_expression.cpp
node_expression.cpp
node_rel_expression.cpp
parameter_expression.cpp
property_expression.cpp
rel_expression.cpp
scalar_function_expression.cpp
variable_expression.cpp)
# Append this library's object files to ALL_OBJECT_FILES and publish the updated list to the
# parent directory scope.
set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:lbug_binder_expression>
PARENT_SCOPE)

View File

@@ -0,0 +1,22 @@
#include "binder/expression/aggregate_function_expression.h"
#include "binder/expression/expression_util.h"
using namespace lbug::common;
namespace lbug {
namespace binder {
std::string AggregateFunctionExpression::toStringInternal() const {
return stringFormat("{}({}{})", function.name, function.isDistinct ? "DISTINCT " : "",
ExpressionUtil::toString(children));
}
// Builds a unique name of the form NAME(children), where children are joined by their unique
// names and prefixed with "DISTINCT " when isDistinct is set.
std::string AggregateFunctionExpression::getUniqueName(const std::string& functionName,
    const expression_vector& children, bool isDistinct) {
    auto childrenUniqueName = ExpressionUtil::getUniqueName(children);
    return stringFormat("{}({}{})", functionName, isDistinct ? "DISTINCT " : "",
        childrenUniqueName);
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,17 @@
#include "binder/expression/case_expression.h"
namespace lbug {
namespace binder {
// Renders the expression as "CASE WHEN <when> THEN <then> [WHEN ... THEN ...] ELSE <else>".
// Every alternative is preceded by a single space so that consecutive alternatives no longer run
// together (previously "... THEN xWHEN y ..."). Output for a single alternative is unchanged.
std::string CaseExpression::toStringInternal() const {
    std::string result = "CASE";
    for (const auto& caseAlternative : caseAlternatives) {
        result += " WHEN " + caseAlternative->whenExpression->toString() + " THEN " +
                  caseAlternative->thenExpression->toString();
    }
    result += " ELSE " + elseExpression->toString();
    return result;
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,34 @@
#include "binder/expression/expression.h"
#include "common/exception/binder.h"
using namespace lbug::common;
namespace lbug {
namespace binder {
// Defaulted destructor, defined out of line in this translation unit.
Expression::~Expression() = default;
// Base implementation of cast: always throws a BinderException naming this expression. The
// requested type is ignored.
void Expression::cast(const LogicalType&) {
    // LCOV_EXCL_START
    auto message =
        stringFormat("Data type of expression {} should not be modified.", toString());
    throw BinderException(message);
    // LCOV_EXCL_STOP
}
// Flattens a tree of AND expressions into its conjuncts, in child order. An expression that is
// not an AND is returned as a single-element vector containing itself.
expression_vector Expression::splitOnAND() {
    if (expressionType != ExpressionType::AND) {
        return expression_vector{shared_from_this()};
    }
    expression_vector conjuncts;
    for (auto& child : children) {
        auto childConjuncts = child->splitOnAND();
        conjuncts.insert(conjuncts.end(), childConjuncts.begin(), childConjuncts.end());
    }
    return conjuncts;
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,568 @@
#include "binder/expression/expression_util.h"
#include <algorithm>
#include "binder/binder.h"
#include "binder/expression/literal_expression.h"
#include "binder/expression/node_rel_expression.h"
#include "binder/expression/parameter_expression.h"
#include "common/exception/binder.h"
#include "common/exception/runtime.h"
#include "common/type_utils.h"
#include "common/types/value/nested.h"
using namespace lbug::common;
namespace lbug {
namespace binder {
// Returns, in input order, the expressions whose logical type ID equals dataTypeID.
expression_vector ExpressionUtil::getExpressionsWithDataType(const expression_vector& expressions,
    LogicalTypeID dataTypeID) {
    expression_vector matching;
    for (const auto& candidate : expressions) {
        if (candidate->dataType.getLogicalTypeID() != dataTypeID) {
            continue;
        }
        matching.push_back(candidate);
    }
    return matching;
}
// Returns the index of the first expression whose unique name equals the target's, or
// INVALID_IDX when there is none.
uint32_t ExpressionUtil::find(const Expression* target, const expression_vector& expressions) {
    auto idx = 0u;
    while (idx < expressions.size()) {
        if (expressions[idx]->getUniqueName() == target->getUniqueName()) {
            return idx;
        }
        ++idx;
    }
    return INVALID_IDX;
}
// Joins the display strings of the expressions with "," (no spaces). An empty vector yields an
// empty string.
std::string ExpressionUtil::toString(const expression_vector& expressions) {
    std::string joined;
    for (auto idx = 0u; idx < expressions.size(); ++idx) {
        if (idx != 0) {
            joined += ",";
        }
        joined += expressions[idx]->toString();
    }
    return joined;
}
std::string ExpressionUtil::toStringOrdered(const expression_vector& expressions) {
auto expressions_ = expressions;
std::sort(expressions_.begin(), expressions_.end(),
[](const std::shared_ptr<Expression>& a, const std::shared_ptr<Expression>& b) {
return a->toString() < b->toString();
});
return toString(expressions_);
}
// Joins the "lhs=rhs" renderings of the pairs with "," (no spaces). An empty vector yields an
// empty string.
std::string ExpressionUtil::toString(const std::vector<expression_pair>& expressionPairs) {
    std::string joined;
    for (auto idx = 0u; idx < expressionPairs.size(); ++idx) {
        if (idx != 0) {
            joined += ",";
        }
        joined += toString(expressionPairs[idx]);
    }
    return joined;
}
// Renders a pair of expressions as "<first>=<second>".
std::string ExpressionUtil::toString(const expression_pair& expressionPair) {
    std::string rendered = expressionPair.first->toString();
    rendered += "=";
    rendered += expressionPair.second->toString();
    return rendered;
}
// Joins the unique names of the expressions with "," (no spaces). An empty vector yields an
// empty string.
std::string ExpressionUtil::getUniqueName(const expression_vector& expressions) {
    std::string joined;
    for (auto idx = 0u; idx < expressions.size(); ++idx) {
        if (idx != 0) {
            joined += ",";
        }
        joined += expressions[idx]->getUniqueName();
    }
    return joined;
}
// Returns, in input order, every expression that compares unequal to exprToExclude.
expression_vector ExpressionUtil::excludeExpression(const expression_vector& exprs,
    const Expression& exprToExclude) {
    expression_vector kept;
    for (const auto& candidate : exprs) {
        if (*candidate != exprToExclude) {
            kept.push_back(candidate);
        }
    }
    return kept;
}
// Returns, in input order, every expression that is not a member of expressionsToExclude
// (membership is decided by expression_set).
expression_vector ExpressionUtil::excludeExpressions(const expression_vector& expressions,
    const expression_vector& expressionsToExclude) {
    expression_set excluded;
    for (const auto& toExclude : expressionsToExclude) {
        excluded.insert(toExclude);
    }
    expression_vector kept;
    for (const auto& candidate : expressions) {
        if (excluded.contains(candidate)) {
            continue;
        }
        kept.push_back(candidate);
    }
    return kept;
}
// Returns a copy of each expression's data type, in input order.
logical_type_vec_t ExpressionUtil::getDataTypes(const expression_vector& expressions) {
    std::vector<LogicalType> dataTypes;
    dataTypes.reserve(expressions.size());
    for (const auto& expr : expressions) {
        dataTypes.push_back(expr->getDataType().copy());
    }
    return dataTypes;
}
// Returns the expressions with duplicates removed, keeping the first occurrence of each and
// preserving input order (duplicates are detected through expression_set).
expression_vector ExpressionUtil::removeDuplication(const expression_vector& expressions) {
    expression_set seen;
    expression_vector unique;
    for (const auto& candidate : expressions) {
        if (!seen.contains(candidate)) {
            seen.insert(candidate);
            unique.push_back(candidate);
        }
    }
    return unique;
}
// True when the expression is a PATTERN whose node-or-rel expression reports isEmpty().
bool ExpressionUtil::isEmptyPattern(const Expression& expression) {
    return expression.expressionType == ExpressionType::PATTERN &&
           expression.constCast<NodeOrRelExpression>().isEmpty();
}
// True when the expression is a PATTERN whose data type is NODE.
// (The stray ';' that followed the function body has been dropped.)
bool ExpressionUtil::isNodePattern(const Expression& expression) {
    if (expression.expressionType != ExpressionType::PATTERN) {
        return false;
    }
    return expression.dataType.getLogicalTypeID() == LogicalTypeID::NODE;
}
// True when the expression is a PATTERN whose data type is REL.
bool ExpressionUtil::isRelPattern(const Expression& expression) {
    return expression.expressionType == ExpressionType::PATTERN &&
           expression.dataType.getLogicalTypeID() == LogicalTypeID::REL;
}
// True when the expression is a PATTERN whose data type is RECURSIVE_REL.
bool ExpressionUtil::isRecursiveRelPattern(const Expression& expression) {
    return expression.expressionType == ExpressionType::PATTERN &&
           expression.dataType.getLogicalTypeID() == LogicalTypeID::RECURSIVE_REL;
}
// True when the expression is a LITERAL whose value is null.
bool ExpressionUtil::isNullLiteral(const Expression& expression) {
    return expression.expressionType == ExpressionType::LITERAL &&
           expression.constCast<LiteralExpression>().getValue().isNull();
}
// True when the expression is a LITERAL whose data type is BOOL.
bool ExpressionUtil::isBoolLiteral(const Expression& expression) {
    return expression.expressionType == ExpressionType::LITERAL &&
           expression.dataType == LogicalType::BOOL();
}
bool ExpressionUtil::isFalseLiteral(const Expression& expression) {
if (expression.expressionType != ExpressionType::LITERAL) {
return false;
}
return expression.constCast<LiteralExpression>().getValue().getValue<bool>() == false;
}
// True when the expression is a LITERAL or PARAMETER whose value is a LIST with zero children.
// Every other expression type yields false.
bool ExpressionUtil::isEmptyList(const Expression& expression) {
    auto candidate = Value::createNullValue();
    if (expression.expressionType == ExpressionType::LITERAL) {
        candidate = expression.constCast<LiteralExpression>().getValue();
    } else if (expression.expressionType == ExpressionType::PARAMETER) {
        candidate = expression.constCast<ParameterExpression>().getValue();
    } else {
        return false;
    }
    return candidate.getDataType().getLogicalTypeID() == LogicalTypeID::LIST &&
           candidate.getChildrenSize() == 0;
}
// Throws a BinderException unless the expression has exactly the expected expression type.
void ExpressionUtil::validateExpressionType(const Expression& expr, ExpressionType expectedType) {
    if (expr.expressionType != expectedType) {
        throw BinderException(stringFormat("{} has type {} but {} was expected.", expr.toString(),
            ExpressionTypeUtil::toString(expr.expressionType),
            ExpressionTypeUtil::toString(expectedType)));
    }
}
void ExpressionUtil::validateExpressionType(const Expression& expr,
std::vector<ExpressionType> expectedType) {
if (std::find(expectedType.begin(), expectedType.end(), expr.expressionType) !=
expectedType.end()) {
return;
}
std::string expectedTypesStr = "";
std::for_each(expectedType.begin(), expectedType.end(),
[&expectedTypesStr](ExpressionType type) {
expectedTypesStr += expectedTypesStr.empty() ? ExpressionTypeUtil::toString(type) :
"," + ExpressionTypeUtil::toString(type);
});
throw BinderException(stringFormat("{} has type {} but {} was expected.", expr.toString(),
ExpressionTypeUtil::toString(expr.expressionType), expectedTypesStr));
}
// Throws a BinderException unless the expression's data type equals expectedType.
void ExpressionUtil::validateDataType(const Expression& expr, const LogicalType& expectedType) {
    const bool matches = expr.getDataType() == expectedType;
    if (!matches) {
        throw BinderException(
            stringFormat("{} has data type {} but {} was expected.", expr.toString(),
                expr.getDataType().toString(), expectedType.toString()));
    }
}
// Throws a BinderException unless the expression's logical type ID equals expectedTypeID.
void ExpressionUtil::validateDataType(const Expression& expr, LogicalTypeID expectedTypeID) {
    const bool matches = expr.getDataType().getLogicalTypeID() == expectedTypeID;
    if (!matches) {
        throw BinderException(
            stringFormat("{} has data type {} but {} was expected.", expr.toString(),
                expr.getDataType().toString(), LogicalTypeUtils::toString(expectedTypeID)));
    }
}
void ExpressionUtil::validateDataType(const Expression& expr,
const std::vector<LogicalTypeID>& expectedTypeIDs) {
auto targetsSet =
std::unordered_set<LogicalTypeID>{expectedTypeIDs.begin(), expectedTypeIDs.end()};
if (targetsSet.contains(expr.getDataType().getLogicalTypeID())) {
return;
}
throw BinderException(stringFormat("{} has data type {} but {} was expected.", expr.toString(),
expr.getDataType().toString(), LogicalTypeUtils::toString(expectedTypeIDs)));
}
// Typed accessors for literal expressions. Each specialization first validates that `expr` is a
// LITERAL and that its data type matches the requested C++ type (the validators throw a
// BinderException otherwise), then returns the literal's value.
template<>
uint64_t ExpressionUtil::getLiteralValue(const Expression& expr) {
    validateExpressionType(expr, ExpressionType::LITERAL);
    validateDataType(expr, LogicalType::UINT64());
    const auto& literal = expr.constCast<LiteralExpression>();
    return literal.getValue().getValue<uint64_t>();
}
template<>
int64_t ExpressionUtil::getLiteralValue(const Expression& expr) {
    validateExpressionType(expr, ExpressionType::LITERAL);
    validateDataType(expr, LogicalType::INT64());
    const auto& literal = expr.constCast<LiteralExpression>();
    return literal.getValue().getValue<int64_t>();
}
template<>
bool ExpressionUtil::getLiteralValue(const Expression& expr) {
    validateExpressionType(expr, ExpressionType::LITERAL);
    validateDataType(expr, LogicalType::BOOL());
    const auto& literal = expr.constCast<LiteralExpression>();
    return literal.getValue().getValue<bool>();
}
template<>
std::string ExpressionUtil::getLiteralValue(const Expression& expr) {
    validateExpressionType(expr, ExpressionType::LITERAL);
    validateDataType(expr, LogicalType::STRING());
    const auto& literal = expr.constCast<LiteralExpression>();
    return literal.getValue().getValue<std::string>();
}
template<>
double ExpressionUtil::getLiteralValue(const Expression& expr) {
    validateExpressionType(expr, ExpressionType::LITERAL);
    validateDataType(expr, LogicalType::DOUBLE());
    const auto& literal = expr.constCast<LiteralExpression>();
    return literal.getValue().getValue<double>();
}
// Decides whether `type` can be treated as `target` without a cast.
// For primitive types, two types are compatible if they have the same id.
// For nested types, two types are compatible if they have the same id and their children are also
// compatible.
// E.g. [NULL] is compatible with [1,2]
// E.g. {a: NULL, b: NULL} is compatible with {a: [1,2], b: ['c']}
// Note that the check is not symmetric: only `type` (not `target`) is allowed to be ANY.
static bool compatible(const LogicalType& type, const LogicalType& target) {
// Internal and non-internal types never mix, even when `type` is ANY.
if (type.isInternalType() != target.isInternalType()) {
return false;
}
// ANY on the source side matches every target.
if (type.getLogicalTypeID() == LogicalTypeID::ANY) {
return true;
}
if (type.getLogicalTypeID() != target.getLogicalTypeID()) {
return false;
}
// Same type id from here on; nested types additionally require compatible children.
switch (type.getLogicalTypeID()) {
case LogicalTypeID::LIST: {
return compatible(ListType::getChildType(type), ListType::getChildType(target));
}
case LogicalTypeID::ARRAY: {
return compatible(ArrayType::getChildType(type), ArrayType::getChildType(target));
}
case LogicalTypeID::STRUCT: {
if (StructType::getNumFields(type) != StructType::getNumFields(target)) {
return false;
}
// Fields are compared position by position; field names are not looked at here.
for (auto i = 0u; i < StructType::getNumFields(type); ++i) {
if (!compatible(StructType::getField(type, i).getType(),
StructType::getField(target, i).getType())) {
return false;
}
}
return true;
}
// These types are never reported as compatible by this function, even with an identical id.
case LogicalTypeID::DECIMAL:
case LogicalTypeID::UNION:
case LogicalTypeID::MAP:
case LogicalTypeID::NODE:
case LogicalTypeID::REL:
case LogicalTypeID::RECURSIVE_REL:
return false;
default:
return true;
}
}
// Handle special cases where value can be compatible to a type. This happens when a value is a
// nested value but does not have any child.
// E.g. [] is compatible with [1,2]
// Falls back to the type-level compatible() check for everything that is not a LIST, ARRAY or
// MAP value.
static bool compatible(const Value& value, const LogicalType& targetType) {
if (value.isNull()) { // Value is null. We can safely change its type.
return true;
}
if (value.getDataType().getLogicalTypeID() != targetType.getLogicalTypeID()) {
return false;
}
switch (value.getDataType().getLogicalTypeID()) {
case LogicalTypeID::LIST: {
// NOTE(review): the exact meaning of hasNoneNullChildren() is defined outside this file;
// presumably it is false when the value has no non-null child -- confirm.
if (!value.hasNoneNullChildren()) { // Empty list free to change.
return true;
}
// Every child value must itself be compatible with the target's child type.
for (auto i = 0u; i < NestedVal::getChildrenSize(&value); ++i) {
if (!compatible(*NestedVal::getChildVal(&value, i),
ListType::getChildType(targetType))) {
return false;
}
}
return true;
}
case LogicalTypeID::ARRAY: {
if (!value.hasNoneNullChildren()) { // Empty array free to change.
return true;
}
for (auto i = 0u; i < NestedVal::getChildrenSize(&value); ++i) {
if (!compatible(*NestedVal::getChildVal(&value, i),
ArrayType::getChildType(targetType))) {
return false;
}
}
return true;
}
case LogicalTypeID::MAP: {
if (!value.hasNoneNullChildren()) { // Empty map free to change.
return true;
}
const auto& keyType = MapType::getKeyType(targetType);
const auto& valType = MapType::getValueType(targetType);
// Each child of a map value is expected to have exactly two children: key at index 0 and
// value at index 1.
for (auto i = 0u; i < NestedVal::getChildrenSize(&value); ++i) {
auto childVal = NestedVal::getChildVal(&value, i);
KU_ASSERT(NestedVal::getChildrenSize(childVal) == 2);
auto key = NestedVal::getChildVal(childVal, 0);
auto val = NestedVal::getChildVal(childVal, 1);
if (!compatible(*key, keyType) || !compatible(*val, valType)) {
return false;
}
}
return true;
}
default:
break;
}
return compatible(value.getDataType(), targetType);
}
bool ExpressionUtil::tryCombineDataType(const expression_vector& expressions, LogicalType& result) {
std::vector<Value> secondaryValues;
std::vector<LogicalType> primaryTypes;
for (auto& expr : expressions) {
if (expr->expressionType != ExpressionType::LITERAL) {
primaryTypes.push_back(expr->getDataType().copy());
continue;
}
auto literalExpr = expr->constPtrCast<LiteralExpression>();
if (literalExpr->getValue().allowTypeChange()) {
secondaryValues.push_back(literalExpr->getValue());
continue;
}
primaryTypes.push_back(expr->getDataType().copy());
}
if (!LogicalTypeUtils::tryGetMaxLogicalType(primaryTypes, result)) {
return false;
}
for (auto& value : secondaryValues) {
if (compatible(value, result)) {
continue;
}
if (!LogicalTypeUtils::tryGetMaxLogicalType(result, value.getDataType(), result)) {
return false;
}
}
return true;
}
// Reports whether the expression is compatible with targetType according to compatible(): the
// value-level check is used for literals and parameters, the type-level check for everything
// else.
bool ExpressionUtil::canCastStatically(const Expression& expr, const LogicalType& targetType) {
    if (expr.expressionType == ExpressionType::LITERAL) {
        auto literalValue = expr.constPtrCast<LiteralExpression>()->getValue();
        return compatible(literalValue, targetType);
    }
    if (expr.expressionType == ExpressionType::PARAMETER) {
        auto parameterValue = expr.constPtrCast<ParameterExpression>()->getValue();
        return compatible(parameterValue, targetType);
    }
    return compatible(expr.getDataType(), targetType);
}
// True for literals, and for parameters whose data type is not ANY. False for every other
// expression type.
bool ExpressionUtil::canEvaluateAsLiteral(const Expression& expr) {
    if (expr.expressionType == ExpressionType::LITERAL) {
        return true;
    }
    return expr.expressionType == ExpressionType::PARAMETER &&
           expr.getDataType().getLogicalTypeID() != LogicalTypeID::ANY;
}
// Returns the value carried by a literal or parameter expression. Callers must make sure
// canEvaluateAsLiteral(expr) holds (asserted); any other expression type hits KU_UNREACHABLE.
Value ExpressionUtil::evaluateAsLiteralValue(const Expression& expr) {
    KU_ASSERT(canEvaluateAsLiteral(expr));
    auto evaluated = Value::createDefaultValue(expr.dataType);
    if (expr.expressionType == ExpressionType::LITERAL) {
        evaluated = expr.constCast<LiteralExpression>().getValue();
    } else if (expr.expressionType == ExpressionType::PARAMETER) {
        evaluated = expr.constCast<ParameterExpression>().getValue();
    } else {
        KU_UNREACHABLE;
    }
    return evaluated;
}
// Evaluates a SKIP/LIMIT expression to a row count. Throws a RuntimeException when the value is
// not of an integer type or is negative.
uint64_t ExpressionUtil::evaluateAsSkipLimit(const Expression& expr) {
    auto errorMsg = "The number of rows to skip/limit must be a non-negative integer.";
    auto value = evaluateAsLiteralValue(expr);
    uint64_t rowCount = INVALID_LIMIT;
    TypeUtils::visit(
        value.getDataType(),
        // Integer types: reject negatives, otherwise convert to uint64_t.
        [&]<IntegerTypes T>(T) {
            const auto raw = value.getValue<T>();
            if (raw < 0) {
                throw RuntimeException{errorMsg};
            }
            rowCount = static_cast<uint64_t>(raw);
        },
        // Every other type is rejected.
        [&](auto) { throw RuntimeException{errorMsg}; });
    return rowCount;
}
// Extracts a T from `value` after checking that its data type equals targetType (a
// RuntimeException naming the expression's alias is thrown otherwise). When a validation
// callback is supplied, it is invoked with the extracted value before returning.
template<typename T>
T ExpressionUtil::getExpressionVal(const Expression& expr, const Value& value,
    const LogicalType& targetType, validate_param_func<T> validateParamFunc) {
    if (value.getDataType() != targetType) {
        throw RuntimeException{common::stringFormat("Parameter: {} must be a {} literal.",
            expr.getAlias(), targetType.toString())};
    }
    T extracted = value.getValue<T>();
    if (validateParamFunc == nullptr) {
        return extracted;
    }
    validateParamFunc(extracted);
    return extracted;
}
// Evaluates a literal/parameter expression to a T of logical type `type`.
// Throws a RuntimeException when the expression cannot be evaluated as a literal. When the
// expression's data type differs from `type`, its value is wrapped in a literal expression,
// implicitly cast to `type` and folded before the value is extracted.
template<typename T>
T ExpressionUtil::evaluateLiteral(main::ClientContext* context,
    std::shared_ptr<Expression> expression, const common::LogicalType& type,
    validate_param_func<T> validateParamFunc) {
    if (!canEvaluateAsLiteral(*expression)) {
        std::string errMsg;
        if (expression->expressionType == ExpressionType::PARAMETER) {
            errMsg = common::stringFormat(
                "The expression: '{}' is a parameter expression. Please assign it a value.",
                expression->toString());
        } else {
            errMsg =
                common::stringFormat("The expression: '{}' must be a parameter/literal expression.",
                    expression->toString());
        }
        throw RuntimeException{errMsg};
    }
    if (expression->getDataType() != type) {
        binder::Binder binder{context};
        auto wrappedLiteral = std::make_shared<LiteralExpression>(
            ExpressionUtil::evaluateAsLiteralValue(*expression), expression->getUniqueName());
        expression = binder.getExpressionBinder()->implicitCast(wrappedLiteral, type.copy());
        expression = binder.getExpressionBinder()->foldExpression(expression);
    }
    auto evaluated = evaluateAsLiteralValue(*expression);
    return getExpressionVal(*expression, evaluated, type, validateParamFunc);
}
// Returns `expr` unchanged when it already has targetType; otherwise returns the expression
// after an implicit cast to targetType followed by folding.
std::shared_ptr<Expression> ExpressionUtil::applyImplicitCastingIfNecessary(
    main::ClientContext* context, std::shared_ptr<Expression> expr,
    common::LogicalType targetType) {
    if (expr->getDataType() == targetType) {
        return expr;
    }
    binder::Binder binder{context};
    auto casted = binder.getExpressionBinder()->implicitCastIfNecessary(expr, targetType);
    return binder.getExpressionBinder()->foldExpression(casted);
}
// Explicit instantiations of the templates defined above, so that their definitions can stay in
// this translation unit.
// getExpressionVal is instantiated for std::string, double, int64_t and bool.
template LBUG_API std::string ExpressionUtil::getExpressionVal(const Expression& expr,
const common::Value& value, const common::LogicalType& targetType,
validate_param_func<std::string> validateParamFunc);
template LBUG_API double ExpressionUtil::getExpressionVal(const Expression& expr,
const common::Value& value, const common::LogicalType& targetType,
validate_param_func<double> validateParamFunc);
template LBUG_API int64_t ExpressionUtil::getExpressionVal(const Expression& expr,
const common::Value& value, const common::LogicalType& targetType,
validate_param_func<int64_t> validateParamFunc);
template LBUG_API bool ExpressionUtil::getExpressionVal(const Expression& expr,
const common::Value& value, const common::LogicalType& targetType,
validate_param_func<bool> validateParamFunc);
// evaluateLiteral is instantiated for std::string, double, int64_t, bool and uint64_t.
template LBUG_API std::string ExpressionUtil::evaluateLiteral<std::string>(
main::ClientContext* context, std::shared_ptr<Expression> expression,
const common::LogicalType& type, validate_param_func<std::string> validateParamFunc);
template LBUG_API double ExpressionUtil::evaluateLiteral<double>(main::ClientContext* context,
std::shared_ptr<Expression> expression, const LogicalType& type,
validate_param_func<double> validateParamFunc);
template LBUG_API int64_t ExpressionUtil::evaluateLiteral<int64_t>(main::ClientContext* context,
std::shared_ptr<Expression> expression, const LogicalType& type,
validate_param_func<int64_t> validateParamFunc);
template LBUG_API bool ExpressionUtil::evaluateLiteral<bool>(main::ClientContext* context,
std::shared_ptr<Expression> expression, const LogicalType& type,
validate_param_func<bool> validateParamFunc);
template LBUG_API uint64_t ExpressionUtil::evaluateLiteral<uint64_t>(main::ClientContext* context,
std::shared_ptr<Expression> expression, const LogicalType& type,
validate_param_func<uint64_t> validateParamFunc);
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,24 @@
#include "binder/expression/literal_expression.h"
#include "common/exception/binder.h"
using namespace lbug::common;
namespace lbug {
namespace binder {
// Changes the literal's data type (and the data type of its value) to `type`. Only permitted
// when the value allows a type change; otherwise a BinderException is thrown.
void LiteralExpression::cast(const LogicalType& type) {
    if (value.allowTypeChange()) {
        dataType = type.copy();
        value.setDataType(type);
        return;
    }
    // Safeguard: reaching this point means a literal's type would be changed unexpectedly.
    // LCOV_EXCL_START
    throw BinderException(
        stringFormat("Cannot change literal expression data type from {} to {}.",
            dataType.toString(), type.toString()));
    // LCOV_EXCL_STOP
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,18 @@
#include "binder/expression/node_expression.h"
namespace lbug {
namespace binder {
// Defaulted destructor, defined out of line in this translation unit.
NodeExpression::~NodeExpression() = default;
// Returns the first property expression that is the primary key of the given table. A matching
// property is expected to exist; otherwise KU_UNREACHABLE is hit.
std::shared_ptr<Expression> NodeExpression::getPrimaryKey(common::table_id_t tableID) const {
    for (const auto& candidate : propertyExprs) {
        if (!candidate->isPrimaryKey(tableID)) {
            continue;
        }
        return candidate;
    }
    KU_UNREACHABLE;
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,44 @@
#include "binder/expression/node_rel_expression.h"
#include "catalog/catalog_entry/table_catalog_entry.h"
using namespace lbug::catalog;
using namespace lbug::common;
namespace lbug {
namespace binder {
// Returns the table ID of every catalog entry, in entry order.
table_id_vector_t NodeOrRelExpression::getTableIDs() const {
    table_id_vector_t tableIDs;
    for (const auto& catalogEntry : entries) {
        tableIDs.push_back(catalogEntry->getTableID());
    }
    return tableIDs;
}
// Returns the table IDs of all catalog entries as a set.
table_id_set_t NodeOrRelExpression::getTableIDsSet() const {
    table_id_set_t tableIDs;
    for (const auto& catalogEntry : entries) {
        tableIDs.insert(catalogEntry->getTableID());
    }
    return tableIDs;
}
// Appends the given catalog entries, skipping any whose table ID is already present.
// The set of seen table IDs is updated as entries are appended, so a table ID that occurs more
// than once within `entries_` itself is also added only once (previously only IDs present
// before the call were filtered out).
void NodeOrRelExpression::addEntries(const std::vector<TableCatalogEntry*>& entries_) {
    auto seenTableIDs = getTableIDsSet();
    for (auto& entry : entries_) {
        const auto tableID = entry->getTableID();
        if (seenTableIDs.contains(tableID)) {
            continue;
        }
        seenTableIDs.insert(tableID);
        entries.push_back(entry);
    }
}
// Registers a property expression under its property name. The name must not have been
// registered before (asserted); its index is the position the expression takes in propertyExprs.
void NodeOrRelExpression::addPropertyExpression(std::shared_ptr<PropertyExpression> property) {
    auto name = property->getPropertyName();
    KU_ASSERT(!propertyNameToIdx.contains(name));
    const auto nextIdx = propertyExprs.size();
    propertyNameToIdx.insert({name, nextIdx});
    propertyExprs.push_back(std::move(property));
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,23 @@
#include "binder/expression/parameter_expression.h"
#include "common/exception/binder.h"
namespace lbug {
using namespace common;
namespace binder {
// Changes the parameter's data type (and the data type of its value) to `type`. Only permitted
// while the current data type still contains ANY; otherwise a BinderException is thrown.
void ParameterExpression::cast(const LogicalType& type) {
    if (dataType.containsAny()) {
        dataType = type.copy();
        value.setDataType(type);
        return;
    }
    // LCOV_EXCL_START
    throw BinderException(
        stringFormat("Cannot change parameter expression data type from {} to {}.",
            dataType.toString(), type.toString()));
    // LCOV_EXCL_STOP
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,31 @@
#include "binder/expression/property_expression.h"
using namespace lbug::common;
using namespace lbug::catalog;
namespace lbug {
namespace binder {
bool PropertyExpression::isPrimaryKey() const {
for (auto& [id, info] : infos) {
if (!info.isPrimaryKey) {
return false;
}
}
return true;
}
// True when an info exists for the given table and marks this property as its primary key;
// false when the table is unknown.
bool PropertyExpression::isPrimaryKey(table_id_t tableID) const {
    const auto found = infos.find(tableID);
    return found != infos.end() && found->second.isPrimaryKey;
}
// Returns the `exists` flag recorded for the given table. An info for the table must be present
// (asserted).
bool PropertyExpression::hasProperty(table_id_t tableID) const {
    KU_ASSERT(infos.contains(tableID));
    const auto& tableInfo = infos.at(tableID);
    return tableInfo.exists;
}
} // namespace binder
} // namespace lbug

View File

@@ -0,0 +1,88 @@
#include "binder/expression/rel_expression.h"
#include "catalog/catalog_entry/rel_group_catalog_entry.h"
#include "catalog/catalog_entry/table_catalog_entry.h"
#include "common/enums/extend_direction_util.h"
#include "common/exception/binder.h"
#include "common/utils.h"
using namespace lbug::common;
namespace lbug {
namespace binder {
// True when the relationship has more than one catalog entry, or when one of its rel group
// entries contains more than one rel table.
bool RelExpression::isMultiLabeled() const {
    if (entries.size() > 1) {
        return true;
    }
    for (auto& catalogEntry : entries) {
        const auto numRelTables =
            catalogEntry->ptrCast<catalog::RelGroupCatalogEntry>()->getNumRelTables();
        if (numRelTables > 1) {
            return true;
        }
    }
    return false;
}
// Returns toString() followed by a label for the shortest-path variants and, for recursive
// relationship types, the "<lowerBound>..<upperBound>" range.
std::string RelExpression::detailsToString() const {
    // Label appended for the shortest-path variants; empty for every other rel type.
    const char* pathLabel = "";
    switch (relType) {
    case QueryRelType::SHORTEST:
        pathLabel = "SHORTEST";
        break;
    case QueryRelType::ALL_SHORTEST:
        pathLabel = "ALL SHORTEST";
        break;
    case QueryRelType::WEIGHTED_SHORTEST:
        pathLabel = "WEIGHTED SHORTEST";
        break;
    case QueryRelType::ALL_WEIGHTED_SHORTEST:
        pathLabel = "ALL WEIGHTED SHORTEST";
        break;
    default:
        break;
    }
    std::string details = toString();
    details += pathLabel;
    if (QueryRelTypeUtils::isRecursive(relType)) {
        const auto& bindData = recursiveInfo->bindData;
        details += std::to_string(bindData->lowerBound);
        details += "..";
        details += std::to_string(bindData->upperBound);
    }
    return details;
}
std::vector<ExtendDirection> RelExpression::getExtendDirections() const {
std::vector<ExtendDirection> ret;
for (const auto direction : {ExtendDirection::FWD, ExtendDirection::BWD}) {
const bool addDirection = std::all_of(entries.begin(), entries.end(),
[direction](const catalog::TableCatalogEntry* tableEntry) {
const auto* entry = tableEntry->constPtrCast<catalog::RelGroupCatalogEntry>();
return common::containsValue(entry->getRelDataDirections(),
ExtendDirectionUtil::getRelDataDirection(direction));
});
if (addDirection) {
ret.push_back(direction);
}
}
if (ret.empty()) {
throw BinderException(stringFormat(
"There are no common storage directions among the rel "
"tables matched by pattern '{}' (some tables have storage direction 'fwd' "
"while others have storage direction 'bwd'). Scanning different tables matching the "
"same pattern in different directions is currently unsupported.",
toString()));
}
return ret;
}
std::vector<table_id_t> RelExpression::getInnerRelTableIDs() const {
std::vector<table_id_t> innerTableIDs;
for (auto& entry : entries) {
for (auto& info : entry->cast<catalog::RelGroupCatalogEntry>().getRelEntryInfos()) {
innerTableIDs.push_back(info.oid);
}
}
return innerTableIDs;
}
} // namespace binder
} // namespace lbug

Some files were not shown because too many files have changed in this diff Show More