Forwarder versioning

Co-authored-by: Gelu Vrabie <gelu@exolabs.net>

Author: Gelu Vrabie
Committed: 2025-08-18 15:08:50 +01:00 (via GitHub)
Parent: ea3eeea826
Commit: 345fafd80d
5 changed files with 103 additions and 41 deletions


@@ -45,7 +45,7 @@ build: regenerate-protobufs
 # Build the Go forwarder binary
 build-forwarder:
-	cd networking/forwarder && go build -buildvcs=false -o ../../build/forwarder .
+	HASH=$(uv run scripts/hashdir.py) && cd networking/forwarder && go build -buildvcs=false -o ../../build/forwarder -ldflags "-X 'main.SourceHash=${HASH}'" .

 # Run forwarder tests
 test-forwarder:
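Note: the -X linker flag overwrites a package-level string variable at link time, so the version hash travels inside the binary with no generated source files. A minimal standalone sketch of the mechanism (hypothetical example, not the forwarder's actual main.go):

package main

import "fmt"

// Overridden at link time via:
//   go build -ldflags "-X 'main.SourceHash=abc123'" .
var SourceHash = "dev"

func main() {
	fmt.Println("SourceHash:", SourceHash) // "abc123" when injected, "dev" otherwise
}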


@@ -123,13 +123,13 @@ class Master:
         if len(events) == 0:
             await asyncio.sleep(0.01)
             return
-        self.logger.info(f"got events: {events}")
+        self.logger.debug(f"got events: {events}")
         # 3. for each event, apply it to the state
         for event_from_log in events:
-            print(f"applying event: {event_from_log}")
+            self.logger.debug(f"applying event: {event_from_log}")
             self.state = apply(self.state, event_from_log)
-        self.logger.info(f"state: {self.state.model_dump_json()}")
+        self.logger.debug(f"state: {self.state.model_dump_json()}")

         # TODO: This can be done in a better place. But for now, we use this to check if any running instances have been broken.
         write_events: list[Event] = []


@@ -13,9 +13,15 @@ import (
 var nodeID = flag.String("node-id", "", "Node ID (defaults to FORWARDER_NODE_ID env var or a new UUID)")
 var eventsDBPath = flag.String("events-db", "", "Path to the worker events SQLite database")

+var SourceHash = "dev"
+
 func main() {
 	flag.Parse()
+
+	log.Printf("SourceHash: %s\n", SourceHash)
+	os.Setenv("SOURCE_HASH", SourceHash)
+
 	id := *nodeID
 	if id != "" {
 		forwarder.SetNodeId(id)


@@ -5,10 +5,10 @@ import (
 	"context"
 	"crypto/sha256"
 	"encoding/json"
+	"fmt"
 	"log"
 	"net"
+	"os"
-	"sort"
 	"strings"
 	"sync"
 	"time"
@@ -65,42 +65,18 @@ const (
 	mdnsSlowInterval = 30 * time.Second
 )

-func sortAddrs(addrs []multiaddr.Multiaddr) []multiaddr.Multiaddr {
-	s := make([]multiaddr.Multiaddr, len(addrs))
-	copy(s, addrs)
-	sort.Slice(s, func(i, j int) bool {
-		return s[i].String() < s[j].String()
-	})
-	return s
+var rendezvousTag string
+
+func computeRendezvousTag() string {
+	sum := sha256.Sum256([]byte("forwarder_network/" + os.Getenv("SOURCE_HASH")))
+	return fmt.Sprintf("forwarder_network-%x", sum[:8])
 }
-func addrsChanged(a, b []multiaddr.Multiaddr) bool {
-	if len(a) != len(b) {
-		return true
+
+func getRendezvousTag() string {
+	if rendezvousTag == "" {
+		rendezvousTag = computeRendezvousTag()
 	}
-	sa := sortAddrs(a)
-	sb := sortAddrs(b)
-	for i := range sa {
-		if !sa[i].Equal(sb[i]) {
-			return true
-		}
-	}
-	return false
-}
-func canonicalAddr(a multiaddr.Multiaddr) string {
-	cs := multiaddr.Split(a)
-	out := make([]multiaddr.Multiaddrer, 0, len(cs))
-	for _, c := range cs {
-		for _, p := range c.Protocols() {
-			if p.Code == multiaddr.P_P2P {
-				goto NEXT
-			}
-		}
-		out = append(out, c.Multiaddr())
-	NEXT:
-	}
-	return multiaddr.Join(out...).String()
+	return rendezvousTag
 }

 func ipString(a multiaddr.Multiaddr) string {
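The mDNS rendezvous string is now derived from the build's source hash, so forwarders from different builds simply never discover each other. A quick sketch of the same derivation (tagFor is a hypothetical helper mirroring computeRendezvousTag above, with the source hash passed in):

package main

import (
	"crypto/sha256"
	"fmt"
)

// tagFor mirrors computeRendezvousTag, parameterized on the source hash.
func tagFor(sourceHash string) string {
	sum := sha256.Sum256([]byte("forwarder_network/" + sourceHash))
	return fmt.Sprintf("forwarder_network-%x", sum[:8])
}

func main() {
	fmt.Println(tagFor("dev"))    // one tag...
	fmt.Println(tagFor("abc123")) // ...a different tag: mismatched builds never rendezvous
}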
@@ -385,7 +361,7 @@ func getNode(ctx context.Context) {
 	opts = append(opts, libp2p.Identity(priv))
 	opts = append(opts, libp2p.Security(noise.ID, noise.New))

-	pskHash := sha256.Sum256([]byte("forwarder_network"))
+	pskHash := sha256.Sum256([]byte("forwarder_network/" + os.Getenv("SOURCE_HASH")))
 	psk := pnet.PSK(pskHash[:])
 	opts = append(opts, libp2p.PrivateNetwork(psk))
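Scoping the pre-shared key to the source hash pushes version isolation below discovery: even if two mismatched forwarders did find each other, the libp2p private-network handshake would fail. A sketch of the key derivation (assuming SOURCE_HASH has been set as in main.go above):

package main

import (
	"crypto/sha256"
	"fmt"
	"os"
)

func main() {
	// SHA-256 yields exactly the 32 bytes a libp2p pre-shared key requires.
	pskHash := sha256.Sum256([]byte("forwarder_network/" + os.Getenv("SOURCE_HASH")))
	fmt.Printf("PSK (%d bytes): %x\n", len(pskHash), pskHash)
}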
@@ -416,7 +392,7 @@ func getNode(ctx context.Context) {
 		log.Fatalf("failed to create pubsub: %v", err)
 	}

-	rendezvous := "forwarder_network"
+	rendezvous := getRendezvousTag()
 	notifee := &discoveryNotifee{h: node}
 	mdnsSer = mdns.NewMdnsService(node, rendezvous, notifee)
 	if err := mdnsSer.Start(); err != nil {
@@ -534,7 +510,7 @@ func forceRestartMDNS(reason string) {
 	if mdnsSer != nil && node != nil {
 		log.Printf("Restarting mDNS (%s)", reason)
 		old := mdnsSer
-		rendezvous := "forwarder_network"
+		rendezvous := getRendezvousTag()
 		notifee := &discoveryNotifee{h: node}
 		newMdns := mdns.NewMdnsService(node, rendezvous, notifee)
 		if err := newMdns.Start(); err != nil {

scripts/hashdir.py (new file, 80 lines)

@@ -0,0 +1,80 @@
import hashlib
import os
import sys

EXCLUDE_DIRS = {".git", "build", "vendor", ".idea", ".vscode", "__pycache__"}


def norm_rel(path: str, base: str) -> str:
    """Forwarder-root-relative path with '/' separators."""
    abs_path = os.path.abspath(path)
    abs_base = os.path.abspath(base)
    rel = os.path.relpath(abs_path, abs_base)
    return rel.replace(os.sep, "/")


def collect_files(arg_path: str) -> tuple[str, list[str]]:
    # Resolve forwarder_root and src_root from the provided path
    p = os.path.abspath(arg_path)
    if not os.path.isdir(p):
        sys.stderr.write(f"error: path must be a directory: {arg_path}\n")
        sys.exit(2)
    if os.path.basename(p) == "src":
        forwarder_root = os.path.dirname(p)
        src_root = p
    else:
        forwarder_root = p
        src_root = os.path.join(forwarder_root, "src")

    files = []
    # 1) Include .go files under src, excluding *_test.go
    if os.path.isdir(src_root):
        for root, dirs, filenames in os.walk(src_root):
            # prune excluded dirs
            dirs[:] = [d for d in dirs if d not in EXCLUDE_DIRS]
            for name in filenames:
                # strict .go, exclude *_test.go
                if not name.lower().endswith(".go"):
                    continue
                if name.lower().endswith("_test.go"):
                    continue
                files.append(os.path.join(root, name))

    # 2) Add go.mod, go.sum, main.go from the forwarder root
    for name in ("go.mod", "go.sum", "main.go"):
        pth = os.path.join(forwarder_root, name)
        if os.path.isfile(pth):
            # defensive: exclude *_test.go at root too
            if name.lower().endswith("_test.go"):
                continue
            files.append(pth)

    # Deduplicate and sort deterministically by forwarder-root-relative path
    files: list[str] = sorted(set(files), key=lambda f: norm_rel(f, forwarder_root))
    return forwarder_root, files


def hash_files(forwarder_root: str, files: list[str]) -> str:
    h = hashlib.sha256()
    for fp in files:
        rel = norm_rel(fp, forwarder_root)
        h.update(b"F\x00")
        h.update(rel.encode("utf-8"))
        h.update(b"\x00")
        with open(fp, "rb") as f:
            for chunk in iter(lambda: f.read(256 * 1024), b""):
                h.update(chunk)
        h.update(b"\n")
    return h.hexdigest()


def main():
    if len(sys.argv) > 1:
        arg = sys.argv[1]
    else:
        arg = os.path.join("networking", "forwarder", "src")
    forwarder_root, files = collect_files(arg)
    digest = hash_files(forwarder_root, files)
    # print without trailing newline (easier to capture in shell)
    sys.stdout.write(digest)


if __name__ == "__main__":
    main()
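The F\x00 / \x00 / \n separators commit the digest to both the path and the contents of every file, so moving bytes between files or renaming a file changes the hash, not just editing contents. A rough Go equivalent of the framing, over a hypothetical in-memory file set:

package main

import (
	"crypto/sha256"
	"fmt"
)

func main() {
	// Hypothetical sorted file set standing in for the forwarder sources.
	paths := []string{"go.mod", "src/node.go"}
	contents := map[string]string{
		"go.mod":      "module forwarder\n",
		"src/node.go": "package forwarder\n",
	}

	h := sha256.New()
	for _, p := range paths {
		h.Write([]byte("F\x00")) // per-file header
		h.Write([]byte(p))       // root-relative path
		h.Write([]byte{0})       // path/content separator
		h.Write([]byte(contents[p]))
		h.Write([]byte("\n")) // end-of-file marker
	}
	fmt.Printf("%x\n", h.Sum(nil))
}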