Compare commits

...

65 Commits
v0.1 ... v0.2.2

Author SHA1 Message Date
Jakob Borg
c549e413a2 Close tmpfiles earlier (ref #2) 2014-01-01 16:31:52 -05:00
Jakob Borg
63a05ff6fa Command line option to ignore index cache 2014-01-01 16:31:35 -05:00
Jakob Borg
89a5aac6ea Use gzip compression for index cache 2014-01-01 16:31:04 -05:00
Jakob Borg
232d715c37 Fix broken --cfg flag 2014-01-01 08:49:55 -05:00
Jakob Borg
1c4e710adc Build windows binaries 2014-01-01 08:18:11 -05:00
Jakob Borg
7fdea0dd93 Close even if we don't have a connection 2014-01-01 08:09:17 -05:00
Jakob Borg
5b84b72d15 Await completion of pull round before starting next (ref #2) 2014-01-01 08:02:12 -05:00
Jakob Borg
7e0be89052 Simplify index sending, prevent ping timeout 2013-12-31 21:22:49 -05:00
Jakob Borg
632bcae856 Mostly lock free receive loop 2013-12-30 22:10:54 -05:00
Jakob Borg
fd56123acf Send index in chunks of 1000 to avoid lengthy blocking 2013-12-30 22:05:00 -05:00
Jakob Borg
a2a2e1d466 Atomically replace local index cache 2013-12-30 22:04:30 -05:00
Jakob Borg
d4c5786a14 Change default queue parameters to optimize better for small files 2013-12-30 21:35:20 -05:00
Jakob Borg
42ad9f8b02 Increase ping timeout 2013-12-30 21:32:20 -05:00
Jakob Borg
0f6b34160c Propagate and log reason for connection close 2013-12-30 21:25:45 -05:00
Jakob Borg
7e3b29e3e0 Remove source info in log by default 2013-12-30 21:25:34 -05:00
Jakob Borg
2f660aff7a Improve no such node error messages 2013-12-30 20:55:33 -05:00
Jakob Borg
af3e64a5a7 Remove broken Ping latency measurement 2013-12-30 20:52:36 -05:00
Jakob Borg
9560265adc Always continue walk in the face of errors (fixes #1) 2013-12-30 19:50:04 -05:00
Jakob Borg
4097528aa2 Don't crash on zero nodes in pull 2013-12-30 19:49:25 -05:00
Jakob Borg
71d50a50f4 Make sure to always close directory fd 2013-12-30 19:30:59 -05:00
Jakob Borg
ec0489a8ea Improve log message consistency 2013-12-30 15:31:41 -05:00
Jakob Borg
7948d046d1 Fix locking around close events 2013-12-30 15:27:20 -05:00
Jakob Borg
223bdbb9aa Improve/fix buffer handling 2013-12-30 15:06:44 -05:00
Jakob Borg
726afc915a Clarify installing / usage 2013-12-30 10:19:57 -05:00
Jakob Borg
86c0a527fd Include README & LICENSE in build 2013-12-30 10:11:10 -05:00
Jakob Borg
bb0fd87550 Don't print mysterious version message 2013-12-30 10:05:27 -05:00
Jakob Borg
673ab42c3c Remove race / unnecessary check 2013-12-30 10:05:13 -05:00
Jakob Borg
4543bfb837 Don't include .ini in build 2013-12-30 10:04:51 -05:00
Jakob Borg
005b207737 Atomic connection stats updates 2013-12-30 09:53:54 -05:00
Jakob Borg
bceacf04ca Better hash error messages 2013-12-30 09:36:41 -05:00
Jakob Borg
707e992f19 Print model statistics 2013-12-30 09:30:29 -05:00
Jakob Borg
1c757db153 Avoid deadlock in index exchange by more fine grained locking 2013-12-30 09:22:34 -05:00
Jakob Borg
001a6724ec Build artefacts in build dir 2013-12-30 09:02:18 -05:00
Jakob Borg
976baff44f Memory usage optimizations 2013-12-29 20:33:57 -05:00
Jakob Borg
469e96126a Cleanup SeedIndex 2013-12-29 19:49:40 -05:00
Jakob Borg
24efbe7d33 Woops 2013-12-29 19:20:36 -05:00
Jakob Borg
704e0fa6b8 Improve puller somewhat 2013-12-29 12:18:59 -05:00
Jakob Borg
c70fef1208 Typo / cleanup 2013-12-29 10:23:43 -05:00
Jakob Borg
454e672d42 Handle calls on closed connection 2013-12-28 10:33:18 -05:00
Jakob Borg
647fdcf6a5 Discovery 2013-12-28 08:56:18 -05:00
Jakob Borg
e75e68faa0 Don't send initial index twice, use more fetchers 2013-12-28 08:45:18 -05:00
Jakob Borg
74c27ad4e2 Index Updates 2013-12-28 08:10:36 -05:00
Jakob Borg
cf04e101b9 Lock tracing, fixes 2013-12-24 20:31:25 -05:00
Jakob Borg
4151972d3e Index broadcast pacing 2013-12-24 15:21:03 -05:00
Jakob Borg
3dc199d8df README 2013-12-24 11:53:24 -05:00
Jakob Borg
fc4b23fbc6 Locking/Ping cleanup 2013-12-24 11:45:16 -05:00
Jakob Borg
064bfd366f Don't complain about expected timeout 2013-12-24 11:15:21 -05:00
Jakob Borg
f5ea00b297 Don't accumulate goroutines forever 2013-12-24 11:10:49 -05:00
Jakob Borg
746d52930d Report transfer stats 2013-12-23 12:28:19 -05:00
Jakob Borg
cd2040a7d2 Pull in go-flags, modified to build on Solaris 2013-12-23 11:18:46 -05:00
Jakob Borg
f2d8b68278 External discover 2013-12-22 21:35:05 -05:00
Jakob Borg
31ea72dbb3 Perform external queries 2013-12-22 17:13:59 -05:00
Jakob Borg
e48222ada0 Send external announcements 2013-12-22 16:29:23 -05:00
Jakob Borg
8e65d36691 Build script 2013-12-22 00:16:49 +01:00
Jakob Borg
7d235a454d Refactor length check 2013-12-21 23:52:20 +01:00
Jakob Borg
5c1db4f0f4 Close on unknown message type 2013-12-21 08:15:19 +01:00
Jakob Borg
8d3aa97047 Close on version mismatch 2013-12-21 08:06:54 +01:00
Jakob Borg
f5987fba32 Error handling, testing 2013-12-21 07:52:32 +01:00
Jakob Borg
eba1c9e649 Command line flags 2013-12-18 19:36:28 +01:00
Jakob Borg
f774b0a5dc Error handling 2013-12-18 18:29:15 +01:00
Jakob Borg
251b109d14 Follow symlinks in repo 2013-12-15 21:20:50 +01:00
Jakob Borg
bef9ccfa71 Do ping check after 5 minute inactivity 2013-12-15 16:19:45 +01:00
Jakob Borg
768a7d5052 Simplify async results 2013-12-15 15:58:27 +01:00
Jakob Borg
e86296884a Crash for explainable reason when protocol is out of sync (version skew) 2013-12-15 13:18:03 +01:00
Jakob Borg
8589a0fb40 Don't crash on reading empty index 2013-12-15 13:12:32 +01:00
65 changed files with 6719 additions and 563 deletions

2
.gitignore vendored
View File

@@ -1 +1,3 @@
syncthing
*.tar.gz
build

View File

@@ -39,16 +39,18 @@ The following features are _currently implemented and working_:
* Static configuration of cluster nodes.
* Automatic discovery of cluster nodes on the local network. See
[discover.go](https://github.com/calmh/syncthing/blob/master/discover/discover.go)
for the protocol specification.
* Automatic discovery of cluster nodes. See [discover.go][discover.go]
for the protocol specification. Discovery on the LAN is performed by
broadcasts, Internet wide discovery is performed with the assistance
of a global server.
* Handling of deleted files. Deletes can be propagated or ignored per
client.
The following features are _not yet implemented but planned_:
* Synchronizing multiple unrelated directory trees by following
symlinks directly below the repository level.
* Syncing multiple directories from the same syncthing instance.
The following features are _not yet implemented but planned_:
* Change detection by listening to file system notifications instead of
periodic scanning.
@@ -58,19 +60,16 @@ The following features are _not yet implemented but planned_:
The following features are _not implemented but may be implemented_ in
the future:
* Automatic remote node discovery using a DHT. This is not technically
very difficult but requires one or more globally reachable root
nodes. This is open for discussion -- perhaps we can piggyback on an
existing DHT, or root nodes need to be established in some other
manner.
* Syncing multiple directories from the same syncthing instance.
* Automatic NAT handling via UPNP. Required for the above, not very
useful without it.
* Automatic NAT handling via UPNP.
* Conflict resolution. Currently whichever file has the newest
modification time "wins". The correct behavior in the face of
conflicts is open for discussion.
[discover.go]: (https://github.com/calmh/syncthing/blob/master/discover/discover.go
Security
--------
@@ -86,11 +85,21 @@ fingerprint is computed as the SHA-1 hash of the certificate and
displayed in BASE32 encoding to form a compact yet convenient string.
Currently SHA-1 is deemed secure against preimage attacks.
Usage
=====
Installing
==========
Download the appropriate precompiled binary from the
[releases](https://github.com/calmh/syncthing/releases) page. Untar and
put the `syncthing` binary somewhere convenient in your `$PATH`.
If you are a developer and have Go 1.2 installed you can also install
the latest version from source:
`go get github.com/calmh/syncthing`
Usage
=====
Check out the options:
```
@@ -105,17 +114,17 @@ Run syncthing to let it create it's config directory and certificate:
```
$ syncthing
11:34:13 tls.go:61: OK: wrote cert.pem
11:34:13 tls.go:67: OK: wrote key.pem
11:34:13 main.go:85: INFO: Version v0.1-40-gbb0fd87
11:34:13 tls.go:61: OK: Created TLS certificate file
11:34:13 tls.go:67: OK: Created TLS key file
11:34:13 main.go:66: INFO: My ID: NCTBZAAHXR6ZZP3D7SL3DLYFFQERMW4Q
11:34:13 main.go:90: FATAL: No config file
```
Take note of the "My ID: ..." line. Perform the same operation on
another computer (or the same computer but with a different `--home` for
testing) to create another node. Take note of that ID as well, and
create a config file `~/.syncthing/syncthing.ini` looking something like
this:
another computer to create another node. Take note of that ID as well,
and create a config file `~/.syncthing/syncthing.ini` looking something
like this:
```
[repository]
@@ -129,26 +138,33 @@ CUGAE43Y5N64CRJU26YFH6MTWPSBLSUL = dynamic
This assumes that the first node is reachable on either of the two
addresses listed (perhaps one internal and one port-forwarded external)
and that the other node is not normally reachable from the outside. Save
this config file, identically, to both nodes. If both nodes are running
on the same network, you can set all addresses to 'dynamic' and they
will find each other by local node discovery.
this config file, identically, to both nodes.
Start syncthing on both nodes. If you're running both on the same
computer, one needs a different repository directory (in the config
file) and listening port (set as a command line paramter). For the
cautious, one side can be set to be read only.
If the nodes are running on the same network, or reachable on port 22000
from the outside world, you can set all addresses to "dynamic" and they
will find each other using automatic discovery. (This discovery,
including port numbers, can be tweaked or disabled using command line
options.)
Start syncthing on both nodes. For the cautious, one side can be set to
be read only.
```
$ syncthing --ro
13:30:55 main.go:85: INFO: Version v0.1-40-gbb0fd87
13:30:55 main.go:102: INFO: My ID: NCTBZAAHXR6ZZP3D7SL3DLYFFQERMW4Q
13:30:55 main.go:149: INFO: Initial repository scan in progress
13:30:59 main.go:153: INFO: Listening for incoming connections
13:30:59 main.go:157: INFO: Attempting to connect to other nodes
13:30:59 main.go:247: INFO: Starting local discovery
13:30:59 main.go:165: OK: Ready to synchronize
13:31:04 discover.go:113: INFO: Discovered node CUGAE43Y5N64CRJU26YFH6MTWPSBLSUL at 172.16.32.24:23456
13:31:14 main.go:296: OK: Connected to node CUGAE43Y5N64CRJU26YFH6MTWPSBLSUL
13:31:04 discover.go:113: INFO: Discovered node CUGAE43Y5N64CRJU26YFH6MTWPSBLSUL at 172.16.32.24:22000
13:31:14 main.go:296: INFO: Connected to node CUGAE43Y5N64CRJU26YFH6MTWPSBLSUL
13:31:19 main.go:345: INFO: Transferred 139 KiB in (14 KiB/s), 139 KiB out (14 KiB/s)
13:32:20 model.go:94: INFO: CUGAE43Y5N64CRJU26YFH6MTWPSBLSUL: 263.4 KB/s in, 69.1 KB/s out
13:32:20 model.go:104: INFO: 18289 files, 24.24 GB in cluster
13:32:20 model.go:111: INFO: 17132 files, 22.39 GB in local repo
13:32:20 model.go:117: INFO: 1157 files, 1.84 GB to synchronize
...
```
You should see the synchronization start and then finish a short while

View File

@@ -1,13 +1,26 @@
package buffers
var buffers = make(chan []byte, 32)
const (
largeMin = 1024
)
var (
smallBuffers = make(chan []byte, 32)
largeBuffers = make(chan []byte, 32)
)
func Get(size int) []byte {
var ch = largeBuffers
if size < largeMin {
ch = smallBuffers
}
var buf []byte
select {
case buf = <-buffers:
case buf = <-ch:
default:
}
if len(buf) < size {
return make([]byte, size)
}
@@ -15,12 +28,18 @@ func Get(size int) []byte {
}
func Put(buf []byte) {
if cap(buf) == 0 {
buf = buf[:cap(buf)]
if len(buf) == 0 {
return
}
buf = buf[:cap(buf)]
var ch = largeBuffers
if len(buf) < largeMin {
ch = smallBuffers
}
select {
case buffers <- buf:
case ch <- buf:
default:
}
}

39
build.sh Executable file
View File

@@ -0,0 +1,39 @@
#!/bin/bash
version=$(git describe --always)
go test ./... || exit 1
rm -rf build
mkdir -p build || exit 1
for goos in darwin linux freebsd ; do
for goarch in amd64 386 ; do
echo "$goos-$goarch"
export GOOS="$goos"
export GOARCH="$goarch"
export name="syncthing-$goos-$goarch"
go build -ldflags "-X main.Version $version" \
&& mkdir -p "$name" \
&& cp syncthing "build/$name" \
&& cp README.md LICENSE "$name" \
&& mv syncthing "$name" \
&& tar zcf "build/$name.tar.gz" "$name" \
&& rm -r "$name"
done
done
for goos in windows ; do
for goarch in amd64 386 ; do
echo "$goos-$goarch"
export GOOS="$goos"
export GOARCH="$goarch"
export name="syncthing-$goos-$goarch"
go build -ldflags "-X main.Version $version" \
&& mkdir -p "$name" \
&& cp syncthing.exe "build/$name.exe" \
&& cp README.md LICENSE "$name" \
&& zip -qr "build/$name.zip" "$name" \
&& rm -r "$name"
done
done

1
discover/cmd/discosrv/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
discosrv

View File

@@ -0,0 +1,74 @@
package main
import (
"log"
"net"
"sync"
"github.com/calmh/syncthing/discover"
)
type Node struct {
IP []byte
Port uint16
}
var (
nodes = make(map[string]Node)
lock sync.Mutex
)
func main() {
addr, _ := net.ResolveUDPAddr("udp", ":22025")
conn, err := net.ListenUDP("udp", addr)
if err != nil {
panic(err)
}
var buf = make([]byte, 1024)
for {
n, addr, err := conn.ReadFromUDP(buf)
if err != nil {
panic(err)
}
pkt, err := discover.DecodePacket(buf[:n])
if err != nil {
log.Println("Warning:", err)
continue
}
switch pkt.Magic {
case 0x20121025:
// Announcement
//lock.Lock()
ip := addr.IP.To4()
if ip == nil {
ip = addr.IP.To16()
}
node := Node{ip, uint16(pkt.Port)}
log.Println("<-", pkt.ID, node)
nodes[pkt.ID] = node
//lock.Unlock()
case 0x19760309:
// Query
//lock.Lock()
node, ok := nodes[pkt.ID]
//lock.Unlock()
if ok {
pkt := discover.Packet{
Magic: 0x20121025,
ID: pkt.ID,
Port: node.Port,
IP: node.IP,
}
_, _, err = conn.WriteMsgUDP(discover.EncodePacket(pkt), nil, addr)
if err != nil {
log.Println("Warning:", err)
} else {
log.Println("->", pkt.ID, node)
}
}
}
}
}

View File

@@ -4,118 +4,277 @@ served by something more standardized, such as mDNS / DNS-SD. In practice, this
was much easier and quicker to get up and running.
The mode of operation is to periodically (currently once every 30 seconds)
transmit a broadcast UDP packet to the well known port number 21025. The packet
has the following format:
broadcast an Announcement packet to UDP port 21025. The packet has the
following format:
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Magic Number |
| Magic Number (0x20121025) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Port Number | Length of NodeID |
| Port Number | Reserved |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Length of NodeID |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/ /
\ NodeID (variable length) \
/ /
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Length of IP |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/ /
\ IP (variable length) \
/ /
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
This is the XDR encoding of:
struct Announcement {
unsigned int Magic;
unsigned short Port;
string NodeID<>;
}
(Hence NodeID is padded to a multiple of 32 bits)
The sending node's address is not encoded in local announcement -- the Length
of IP field is set to zero and the address is taken to be the source address of
the announcement. In announcement packets sent by a discovery server in
response to a query, the IP is present and the length is either 4 (IPv4) or 16
(IPv6).
Every time such a packet is received, a local table that maps NodeID to Address
is updated. When the local node wants to connect to another node with the
address specification 'dynamic', this table is consulted.
For external discovery, an identical packet is sent every 30 minutes to the
external discovery server. The server keeps information for up to 60 minutes.
To query the server, and UDP packet with the format below is sent.
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Magic Number (0x19760309) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Length of NodeID |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/ /
\ NodeID (variable length) \
/ /
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
The sending node's address is not encoded -- it is taken to be the source
address of the announcement. Every time such a packet is received, a local
table that maps NodeID to Address is updated. When the local node wants to
connect to another node with the address specification 'dynamic', this table is
consulted.
This is the XDR encoding of:
struct Announcement {
unsigned int Magic;
string NodeID<>;
}
(Hence NodeID is padded to a multiple of 32 bits)
It is answered with an announcement packet for the queried node ID if the
information is available. There is no answer for queries about unknown nodes. A
reasonable timeout is recommended instead. (This, combined with server side
rate limits for packets per source IP and queries per node ID, prevents the
server from being used as an amplifier in a DDoS attack.)
*/
package discover
import (
"encoding/binary"
"fmt"
"log"
"net"
"sync"
"time"
"github.com/calmh/syncthing/buffers"
)
const (
AnnouncementPort = 21025
AnnouncementMagic = 0x20121025
QueryMagic = 0x19760309
)
type Discoverer struct {
MyID string
ListenPort int
BroadcastIntv time.Duration
MyID string
ListenPort int
BroadcastIntv time.Duration
ExtListenPort int
ExtBroadcastIntv time.Duration
conn *net.UDPConn
registry map[string]string
registryLock sync.RWMutex
extServer string
}
func NewDiscoverer(id string, port int) (*Discoverer, error) {
local4 := &net.UDPAddr{IP: net.IP{0, 0, 0, 0}, Port: 21025}
// We tolerate a certain amount of errors because we might be running on
// laptops that sleep and wake, have intermittent network connectivity, etc.
// When we hit this many errors in succession, we stop.
const maxErrors = 30
func NewDiscoverer(id string, port int, extPort int, extServer string) (*Discoverer, error) {
local4 := &net.UDPAddr{IP: net.IP{0, 0, 0, 0}, Port: AnnouncementPort}
conn, err := net.ListenUDP("udp4", local4)
if err != nil {
return nil, err
}
disc := &Discoverer{
MyID: id,
ListenPort: port,
BroadcastIntv: 30 * time.Second,
conn: conn,
registry: make(map[string]string),
MyID: id,
ListenPort: port,
BroadcastIntv: 30 * time.Second,
ExtListenPort: extPort,
ExtBroadcastIntv: 1800 * time.Second,
conn: conn,
registry: make(map[string]string),
extServer: extServer,
}
go disc.sendAnnouncements()
go disc.recvAnnouncements()
if disc.ListenPort > 0 {
disc.sendAnnouncements()
}
if len(disc.extServer) > 0 && disc.ExtListenPort > 0 {
disc.sendExtAnnouncements()
}
return disc, nil
}
func (d *Discoverer) sendAnnouncements() {
remote4 := &net.UDPAddr{IP: net.IP{255, 255, 255, 255}, Port: 21025}
remote4 := &net.UDPAddr{IP: net.IP{255, 255, 255, 255}, Port: AnnouncementPort}
idbs := []byte(d.MyID)
buf := make([]byte, 4+4+4+len(idbs))
buf := EncodePacket(Packet{AnnouncementMagic, uint16(d.ListenPort), d.MyID, nil})
go d.writeAnnouncements(buf, remote4, d.BroadcastIntv)
}
binary.BigEndian.PutUint32(buf, uint32(0x121025))
binary.BigEndian.PutUint16(buf[4:], uint16(d.ListenPort))
binary.BigEndian.PutUint16(buf[6:], uint16(len(idbs)))
copy(buf[8:], idbs)
for {
_, _, err := d.conn.WriteMsgUDP(buf, nil, remote4)
if err != nil {
panic(err)
}
time.Sleep(d.BroadcastIntv)
func (d *Discoverer) sendExtAnnouncements() {
extIP, err := net.ResolveUDPAddr("udp", d.extServer+":22025")
if err != nil {
log.Printf("discover/external: %v; no external announcements", err)
return
}
buf := EncodePacket(Packet{AnnouncementMagic, uint16(d.ExtListenPort), d.MyID, nil})
go d.writeAnnouncements(buf, extIP, d.ExtBroadcastIntv)
}
func (d *Discoverer) writeAnnouncements(buf []byte, remote *net.UDPAddr, intv time.Duration) {
var errCounter = 0
var err error
for errCounter < maxErrors {
_, _, err = d.conn.WriteMsgUDP(buf, nil, remote)
if err != nil {
log.Println("discover/write: warning:", err)
errCounter++
} else {
errCounter = 0
}
time.Sleep(intv)
}
log.Println("discover/write: %v: stopping due to too many errors:", remote, err)
}
func (d *Discoverer) recvAnnouncements() {
var buf = make([]byte, 1024)
for {
_, addr, err := d.conn.ReadFromUDP(buf)
var errCounter = 0
var err error
for errCounter < maxErrors {
n, addr, err := d.conn.ReadFromUDP(buf)
if err != nil {
panic(err)
}
magic := binary.BigEndian.Uint32(buf)
if magic != 0x121025 {
errCounter++
time.Sleep(time.Second)
continue
}
port := binary.BigEndian.Uint16(buf[4:])
l := binary.BigEndian.Uint16(buf[6:])
idbs := buf[8 : l+8]
id := string(idbs)
if id != d.MyID {
nodeAddr := fmt.Sprintf("%s:%d", addr.IP.String(), port)
pkt, err := DecodePacket(buf[:n])
if err != nil || pkt.Magic != AnnouncementMagic {
errCounter++
time.Sleep(time.Second)
continue
}
errCounter = 0
if pkt.ID != d.MyID {
nodeAddr := fmt.Sprintf("%s:%d", addr.IP.String(), pkt.Port)
d.registryLock.Lock()
if d.registry[id] != nodeAddr {
d.registry[id] = nodeAddr
if d.registry[pkt.ID] != nodeAddr {
d.registry[pkt.ID] = nodeAddr
}
d.registryLock.Unlock()
}
}
log.Println("discover/read: stopping due to too many errors:", err)
}
func (d *Discoverer) externalLookup(node string) (string, bool) {
extIP, err := net.ResolveUDPAddr("udp", d.extServer+":22025")
if err != nil {
log.Printf("discover/external: %v; no external lookup", err)
return "", false
}
conn, err := net.DialUDP("udp", nil, extIP)
if err != nil {
log.Printf("discover/external: %v; no external lookup", err)
return "", false
}
defer conn.Close()
err = conn.SetDeadline(time.Now().Add(5 * time.Second))
if err != nil {
log.Printf("discover/external: %v; no external lookup", err)
return "", false
}
_, err = conn.Write(EncodePacket(Packet{QueryMagic, 0, node, nil}))
if err != nil {
log.Printf("discover/external: %v; no external lookup", err)
return "", false
}
var buf = buffers.Get(256)
defer buffers.Put(buf)
n, err := conn.Read(buf)
if err != nil {
if err, ok := err.(net.Error); ok && err.Timeout() {
// Expected if the server doesn't know about requested node ID
return "", false
}
log.Printf("discover/external/read: %v; no external lookup", err)
return "", false
}
pkt, err := DecodePacket(buf[:n])
if err != nil {
log.Printf("discover/external/read: %v; no external lookup", err)
return "", false
}
if pkt.Magic != AnnouncementMagic {
log.Printf("discover/external/read: bad magic; no external lookup", err)
return "", false
}
return fmt.Sprintf("%s:%d", ipStr(pkt.IP), pkt.Port), true
}
func (d *Discoverer) Lookup(node string) (string, bool) {
d.registryLock.Lock()
defer d.registryLock.Unlock()
addr, ok := d.registry[node]
return addr, ok
d.registryLock.Unlock()
if ok {
return addr, true
} else if len(d.extServer) != 0 {
// We might want to cache this, but not permanently so it needs some intelligence
return d.externalLookup(node)
}
return "", false
}

160
discover/encoding.go Normal file
View File

@@ -0,0 +1,160 @@
package discover
import (
"encoding/binary"
"errors"
"fmt"
)
type Packet struct {
Magic uint32 // AnnouncementMagic or QueryMagic
Port uint16 // unset if magic == QueryMagic
ID string
IP []byte // zero length in local announcements
}
var (
errBadMagic = errors.New("bad magic")
errFormat = errors.New("incorrect packet format")
)
func EncodePacket(pkt Packet) []byte {
if l := len(pkt.IP); l != 0 && l != 4 && l != 16 {
// bad ip format
return nil
}
var idbs = []byte(pkt.ID)
var l = 4 + 4 + len(idbs) + pad(len(idbs))
if pkt.Magic == AnnouncementMagic {
l += 4 + 4 + len(pkt.IP)
}
var buf = make([]byte, l)
var offset = 0
binary.BigEndian.PutUint32(buf[offset:], pkt.Magic)
offset += 4
if pkt.Magic == AnnouncementMagic {
binary.BigEndian.PutUint16(buf[offset:], uint16(pkt.Port))
offset += 4
}
binary.BigEndian.PutUint32(buf[offset:], uint32(len(idbs)))
offset += 4
copy(buf[offset:], idbs)
offset += len(idbs) + pad(len(idbs))
if pkt.Magic == AnnouncementMagic {
binary.BigEndian.PutUint32(buf[offset:], uint32(len(pkt.IP)))
offset += 4
copy(buf[offset:], pkt.IP)
offset += len(pkt.IP)
}
return buf
}
func DecodePacket(buf []byte) (*Packet, error) {
var p Packet
var offset int
if len(buf) < 4 {
// short packet
return nil, errFormat
}
p.Magic = binary.BigEndian.Uint32(buf[offset:])
offset += 4
if p.Magic != AnnouncementMagic && p.Magic != QueryMagic {
return nil, errBadMagic
}
if p.Magic == AnnouncementMagic {
// Port Number
if len(buf) < offset+4 {
// short packet
return nil, errFormat
}
p.Port = binary.BigEndian.Uint16(buf[offset:])
offset += 2
reserved := binary.BigEndian.Uint16(buf[offset:])
if reserved != 0 {
return nil, errFormat
}
offset += 2
}
// Node ID
if len(buf) < offset+4 {
// short packet
return nil, errFormat
}
l := binary.BigEndian.Uint32(buf[offset:])
offset += 4
if len(buf) < offset+int(l)+pad(int(l)) {
// short packet
return nil, errFormat
}
idbs := buf[offset : offset+int(l)]
p.ID = string(idbs)
offset += int(l) + pad(int(l))
if p.Magic == AnnouncementMagic {
// IP
if len(buf) < offset+4 {
// short packet
return nil, errFormat
}
l = binary.BigEndian.Uint32(buf[offset:])
offset += 4
if l != 0 && l != 4 && l != 16 {
// weird ip length
return nil, errFormat
}
if len(buf) < offset+int(l) {
// short packet
return nil, errFormat
}
if l > 0 {
p.IP = buf[offset : offset+int(l)]
offset += int(l)
}
}
if len(buf[offset:]) > 0 {
// extra data
return nil, errFormat
}
return &p, nil
}
func pad(l int) int {
d := l % 4
if d == 0 {
return 0
}
return 4 - d
}
func ipStr(ip []byte) string {
switch len(ip) {
case 4:
return fmt.Sprintf("%d.%d.%d.%d", ip[0], ip[1], ip[2], ip[3])
case 16:
return fmt.Sprintf("%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x",
ip[0], ip[1], ip[2], ip[3],
ip[4], ip[5], ip[6], ip[7],
ip[8], ip[9], ip[10], ip[11],
ip[12], ip[13], ip[14], ip[15])
default:
return ""
}
}

138
discover/encoding_test.go Normal file
View File

@@ -0,0 +1,138 @@
package discover
import (
"bytes"
"reflect"
"testing"
)
var testdata = []struct {
data []byte
packet *Packet
err error
}{
{
[]byte{0x20, 0x12, 0x10, 0x25,
0x12, 0x34, 0x00, 0x00,
0x00, 0x00, 0x00, 0x05,
0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00},
&Packet{
Magic: 0x20121025,
Port: 0x1234,
ID: "hello",
},
nil,
},
{
[]byte{0x20, 0x12, 0x10, 0x25,
0x34, 0x56, 0x00, 0x00,
0x00, 0x00, 0x00, 0x08,
0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x21, 0x21, 0x21,
0x00, 0x00, 0x00, 0x04,
0x01, 0x02, 0x03, 0x04},
&Packet{
Magic: 0x20121025,
Port: 0x3456,
ID: "hello!!!",
IP: []byte{1, 2, 3, 4},
},
nil,
},
{
[]byte{0x19, 0x76, 0x03, 0x09,
0x00, 0x00, 0x00, 0x06,
0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x21, 0x00, 0x00},
&Packet{
Magic: 0x19760309,
ID: "hello!",
},
nil,
},
{
[]byte{0x20, 0x12, 0x10, 0x25,
0x12, 0x34, 0x12, 0x34, // reserved bits not set to zero
0x00, 0x00, 0x00, 0x06,
0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x21, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00},
nil,
errFormat,
},
{
[]byte{0x20, 0x12, 0x10, 0x25,
0x12, 0x34, 0x00, 0x00,
0x00, 0x00, 0x00, 0x06,
0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x21, // missing padding
0x00, 0x00, 0x00, 0x00},
nil,
errFormat,
},
{
[]byte{0x19, 0x77, 0x03, 0x09, // incorrect Magic
0x00, 0x00, 0x00, 0x06,
0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x21, 0x00, 0x00},
nil,
errBadMagic,
},
{
[]byte{0x19, 0x76, 0x03, 0x09,
0x6c, 0x6c, 0x6c, 0x6c, // length exceeds packet size
0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x21, 0x00, 0x00},
nil,
errFormat,
},
{
[]byte{0x19, 0x76, 0x03, 0x09,
0x00, 0x00, 0x00, 0x06,
0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x21, 0x00, 0x00,
0x23}, // extra data at the end
nil,
errFormat,
},
}
func TestDecodePacket(t *testing.T) {
for i, test := range testdata {
p, err := DecodePacket(test.data)
if err != test.err {
t.Errorf("%d: unexpected error %v", i, err)
} else {
if !reflect.DeepEqual(p, test.packet) {
t.Errorf("%d: incorrect packet\n%v\n%v", i, test.packet, p)
}
}
}
}
func TestEncodePacket(t *testing.T) {
for i, test := range testdata {
if test.err != nil {
continue
}
buf := EncodePacket(*test.packet)
if bytes.Compare(buf, test.data) != 0 {
t.Errorf("%d: incorrect encoded packet\n% x\n% 0x", i, test.data, buf)
}
}
}
var ipstrTests = []struct {
d []byte
s string
}{
{[]byte{192, 168, 34}, ""},
{[]byte{192, 168, 0, 34}, "192.168.0.34"},
{[]byte{0x20, 0x01, 0x12, 0x34,
0x34, 0x56, 0x56, 0x78,
0x78, 0x00, 0x00, 0xdc,
0x00, 0x00, 0x43, 0x54}, "2001:1234:3456:5678:7800:00dc:0000:4354"},
}
func TestIPStr(t *testing.T) {
for _, tc := range ipstrTests {
s1 := ipStr(tc.d)
if s1 != tc.s {
t.Errorf("Incorrect ipstr %q != %q", tc.s, s1)
}
}
}

View File

@@ -0,0 +1,26 @@
Copyright (c) 2012 Jesse van den Kieboom. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@@ -0,0 +1,128 @@
go-flags: a go library for parsing command line arguments
=========================================================
This library provides similar functionality to the builtin flag library of
go, but provides much more functionality and nicer formatting. From the
documentation:
Package flags provides an extensive command line option parser.
The flags package is similar in functionality to the go builtin flag package
but provides more options and uses reflection to provide a convenient and
succinct way of specifying command line options.
Supported features:
* Options with short names (-v)
* Options with long names (--verbose)
* Options with and without arguments (bool v.s. other type)
* Options with optional arguments and default values
* Multiple option groups each containing a set of options
* Generate and print well-formatted help message
* Passing remaining command line arguments after -- (optional)
* Ignoring unknown command line options (optional)
* Supports -I/usr/include -I=/usr/include -I /usr/include option argument specification
* Supports multiple short options -aux
* Supports all primitive go types (string, int{8..64}, uint{8..64}, float)
* Supports same option multiple times (can store in slice or last option counts)
* Supports maps
* Supports function callbacks
The flags package uses structs, reflection and struct field tags
to allow users to specify command line options. This results in very simple
and consise specification of your application options. For example:
type Options struct {
Verbose []bool `short:"v" long:"verbose" description:"Show verbose debug information"`
}
This specifies one option with a short name -v and a long name --verbose.
When either -v or --verbose is found on the command line, a 'true' value
will be appended to the Verbose field. e.g. when specifying -vvv, the
resulting value of Verbose will be {[true, true, true]}.
Example:
--------
var opts struct {
// Slice of bool will append 'true' each time the option
// is encountered (can be set multiple times, like -vvv)
Verbose []bool `short:"v" long:"verbose" description:"Show verbose debug information"`
// Example of automatic marshalling to desired type (uint)
Offset uint `long:"offset" description:"Offset"`
// Example of a callback, called each time the option is found.
Call func(string) `short:"c" description:"Call phone number"`
// Example of a required flag
Name string `short:"n" long:"name" description:"A name" required:"true"`
// Example of a value name
File string `short:"f" long:"file" description:"A file" value-name:"FILE"`
// Example of a pointer
Ptr *int `short:"p" description:"A pointer to an integer"`
// Example of a slice of strings
StringSlice []string `short:"s" description:"A slice of strings"`
// Example of a slice of pointers
PtrSlice []*string `long:"ptrslice" description:"A slice of pointers to string"`
// Example of a map
IntMap map[string]int `long:"intmap" description:"A map from string to int"`
}
// Callback which will invoke callto:<argument> to call a number.
// Note that this works just on OS X (and probably only with
// Skype) but it shows the idea.
opts.Call = func(num string) {
cmd := exec.Command("open", "callto:"+num)
cmd.Start()
cmd.Process.Release()
}
// Make some fake arguments to parse.
args := []string{
"-vv",
"--offset=5",
"-n", "Me",
"-p", "3",
"-s", "hello",
"-s", "world",
"--ptrslice", "hello",
"--ptrslice", "world",
"--intmap", "a:1",
"--intmap", "b:5",
"arg1",
"arg2",
"arg3",
}
// Parse flags from `args'. Note that here we use flags.ParseArgs for
// the sake of making a working example. Normally, you would simply use
// flags.Parse(&opts) which uses os.Args
args, err := flags.ParseArgs(&opts, args)
if err != nil {
panic(err)
os.Exit(1)
}
fmt.Printf("Verbosity: %v\n", opts.Verbose)
fmt.Printf("Offset: %d\n", opts.Offset)
fmt.Printf("Name: %s\n", opts.Name)
fmt.Printf("Ptr: %d\n", *opts.Ptr)
fmt.Printf("StringSlice: %v\n", opts.StringSlice)
fmt.Printf("PtrSlice: [%v %v]\n", *opts.PtrSlice[0], *opts.PtrSlice[1])
fmt.Printf("IntMap: [a:%v b:%v]\n", opts.IntMap["a"], opts.IntMap["b"])
fmt.Printf("Remaining args: %s\n", strings.Join(args, " "))
// Output: Verbosity: [true true]
// Offset: 5
// Name: Me
// Ptr: 3
// StringSlice: [hello world]
// PtrSlice: [hello world]
// IntMap: [a:1 b:5]
// Remaining args: arg1 arg2 arg3
More information can be found in the godocs: <http://godoc.org/github.com/jessevdk/go-flags>

View File

@@ -0,0 +1,82 @@
package flags
import (
"testing"
)
func assertString(t *testing.T, a string, b string) {
if a != b {
t.Errorf("Expected %#v, but got %#v", b, a)
}
}
func assertStringArray(t *testing.T, a []string, b []string) {
if len(a) != len(b) {
t.Errorf("Expected %#v, but got %#v", b, a)
return
}
for i, v := range a {
if b[i] != v {
t.Errorf("Expected %#v, but got %#v", b, a)
return
}
}
}
func assertBoolArray(t *testing.T, a []bool, b []bool) {
if len(a) != len(b) {
t.Errorf("Expected %#v, but got %#v", b, a)
return
}
for i, v := range a {
if b[i] != v {
t.Errorf("Expected %#v, but got %#v", b, a)
return
}
}
}
func assertParserSuccess(t *testing.T, data interface{}, args ...string) (*Parser, []string) {
parser := NewParser(data, Default&^PrintErrors)
ret, err := parser.ParseArgs(args)
if err != nil {
t.Fatalf("Unexpected parse error: %s", err)
return nil, nil
}
return parser, ret
}
func assertParseSuccess(t *testing.T, data interface{}, args ...string) []string {
_, ret := assertParserSuccess(t, data, args...)
return ret
}
func assertError(t *testing.T, err error, typ ErrorType, msg string) {
if err == nil {
t.Fatalf("Expected error: %s", msg)
return
}
if e, ok := err.(*Error); !ok {
t.Fatalf("Expected Error type, but got %#v", err)
return
} else {
if e.Type != typ {
t.Errorf("Expected error type {%s}, but got {%s}", typ, e.Type)
}
if e.Message != msg {
t.Errorf("Expected error message %#v, but got %#v", msg, e.Message)
}
}
}
func assertParseFail(t *testing.T, typ ErrorType, msg string, data interface{}, args ...string) {
parser := NewParser(data, Default&^PrintErrors)
_, err := parser.ParseArgs(args)
assertError(t, err, typ, msg)
}

View File

@@ -0,0 +1,16 @@
#!/bin/bash
set -e
echo '# linux arm7'
GOARM=7 GOARCH=arm GOOS=linux go build
echo '# linux arm5'
GOARM=5 GOARCH=arm GOOS=linux go build
echo '# windows 386'
GOARCH=386 GOOS=windows go build
echo '# windows amd64'
GOARCH=amd64 GOOS=windows go build
echo '# darwin'
GOARCH=amd64 GOOS=darwin go build
echo '# freebsd'
GOARCH=amd64 GOOS=freebsd go build

View File

@@ -0,0 +1,61 @@
package flags
func levenshtein(s string, t string) int {
if len(s) == 0 {
return len(t)
}
if len(t) == 0 {
return len(s)
}
var l1, l2, l3 int
if len(s) == 1 {
l1 = len(t) + 1
} else {
l1 = levenshtein(s[1:len(s)-1], t) + 1
}
if len(t) == 1 {
l2 = len(s) + 1
} else {
l2 = levenshtein(t[1:len(t)-1], s) + 1
}
l3 = levenshtein(s[1:len(s)], t[1:len(t)])
if s[0] != t[0] {
l3 += 1
}
if l2 < l1 {
l1 = l2
}
if l1 < l3 {
return l1
}
return l3
}
func closestChoice(cmd string, choices []string) (string, int) {
if len(choices) == 0 {
return "", 0
}
mincmd := -1
mindist := -1
for i, c := range choices {
l := levenshtein(cmd, c)
if mincmd < 0 || l < mindist {
mindist = l
mincmd = i
}
}
return choices[mincmd], mindist
}

View File

@@ -0,0 +1,84 @@
package flags
// Command represents an application command. Commands can be added to the
// parser (which itself is a command) and are selected/executed when its name
// is specified on the command line. The Command type embeds a Group and
// therefore also carries a set of command specific options.
type Command struct {
// Embedded, see Group for more information
*Group
// The name by which the command can be invoked
Name string
// The active sub command (set by parsing) or nil
Active *Command
commands []*Command
hasBuiltinHelpGroup bool
}
// Commander is an interface which can be implemented by any command added in
// the options. When implemented, the Execute method will be called for the last
// specified (sub)command providing the remaining command line arguments.
type Commander interface {
// Execute will be called for the last active (sub)command. The
// args argument contains the remaining command line arguments. The
// error that Execute returns will be eventually passed out of the
// Parse method of the Parser.
Execute(args []string) error
}
// Usage is an interface which can be implemented to show a custom usage string
// in the help message shown for a command.
type Usage interface {
// Usage is called for commands to allow customized printing of command
// usage in the generated help message.
Usage() string
}
// AddCommand adds a new command to the parser with the given name and data. The
// data needs to be a pointer to a struct from which the fields indicate which
// options are in the command. The provided data can implement the Command and
// Usage interfaces.
func (c *Command) AddCommand(command string, shortDescription string, longDescription string, data interface{}) (*Command, error) {
cmd := newCommand(command, shortDescription, longDescription, data)
if err := cmd.scan(); err != nil {
return nil, err
}
c.commands = append(c.commands, cmd)
return cmd, nil
}
// AddGroup adds a new group to the command with the given name and data. The
// data needs to be a pointer to a struct from which the fields indicate which
// options are in the group.
func (c *Command) AddGroup(shortDescription string, longDescription string, data interface{}) (*Group, error) {
group := newGroup(shortDescription, longDescription, data)
if err := group.scanType(c.scanSubCommandHandler(group)); err != nil {
return nil, err
}
c.groups = append(c.groups, group)
return group, nil
}
// Commands returns a list of subcommands of this command.
func (c *Command) Commands() []*Command {
return c.commands
}
// Find locates the subcommand with the given name and returns it. If no such
// command can be found Find will return nil.
func (c *Command) Find(name string) *Command {
for _, cc := range c.commands {
if cc.Name == name {
return cc
}
}
return nil
}

View File

@@ -0,0 +1,161 @@
package flags
import (
"reflect"
"sort"
"strings"
"unsafe"
)
type lookup struct {
shortNames map[string]*Option
longNames map[string]*Option
required map[*Option]bool
commands map[string]*Command
}
func newCommand(name string, shortDescription string, longDescription string, data interface{}) *Command {
return &Command{
Group: newGroup(shortDescription, longDescription, data),
Name: name,
}
}
func (c *Command) scanSubCommandHandler(parentg *Group) scanHandler {
f := func(realval reflect.Value, sfield *reflect.StructField) (bool, error) {
mtag := newMultiTag(string(sfield.Tag))
if err := mtag.Parse(); err != nil {
return true, err
}
subcommand := mtag.Get("command")
if len(subcommand) != 0 {
ptrval := reflect.NewAt(realval.Type(), unsafe.Pointer(realval.UnsafeAddr()))
shortDescription := mtag.Get("description")
longDescription := mtag.Get("long-description")
if _, err := c.AddCommand(subcommand, shortDescription, longDescription, ptrval.Interface()); err != nil {
return true, err
}
return true, nil
}
return parentg.scanSubGroupHandler(realval, sfield)
}
return f
}
func (c *Command) scan() error {
return c.scanType(c.scanSubCommandHandler(c.Group))
}
func (c *Command) eachCommand(f func(*Command), recurse bool) {
f(c)
for _, cc := range c.commands {
if recurse {
cc.eachCommand(f, true)
} else {
f(cc)
}
}
}
func (c *Command) eachActiveGroup(f func(g *Group)) {
c.eachGroup(f)
if c.Active != nil {
c.Active.eachActiveGroup(f)
}
}
func (c *Command) addHelpGroups(showHelp func() error) {
if !c.hasBuiltinHelpGroup {
c.addHelpGroup(showHelp)
c.hasBuiltinHelpGroup = true
}
for _, cc := range c.commands {
cc.addHelpGroups(showHelp)
}
}
func (c *Command) makeLookup() lookup {
ret := lookup{
shortNames: make(map[string]*Option),
longNames: make(map[string]*Option),
required: make(map[*Option]bool),
commands: make(map[string]*Command),
}
c.eachGroup(func(g *Group) {
for _, option := range g.options {
if option.Required && option.canCli() {
ret.required[option] = true
}
if option.ShortName != 0 {
ret.shortNames[string(option.ShortName)] = option
}
if len(option.LongName) > 0 {
ret.longNames[option.LongName] = option
}
}
})
for _, subcommand := range c.commands {
ret.commands[subcommand.Name] = subcommand
}
return ret
}
func (c *Command) groupByName(name string) *Group {
if grp := c.Group.groupByName(name); grp != nil {
return grp
}
for _, subc := range c.commands {
prefix := subc.Name + "."
if strings.HasPrefix(name, prefix) {
if grp := subc.groupByName(name[len(prefix):]); grp != nil {
return grp
}
} else if name == subc.Name {
return subc.Group
}
}
return nil
}
type commandList []*Command
func (c commandList) Less(i, j int) bool {
return c[i].Name < c[j].Name
}
func (c commandList) Len() int {
return len(c)
}
func (c commandList) Swap(i, j int) {
c[i], c[j] = c[j], c[i]
}
func (c *Command) sortedCommands() []*Command {
ret := make(commandList, len(c.commands))
copy(ret, c.commands)
sort.Sort(ret)
return []*Command(ret)
}

View File

@@ -0,0 +1,255 @@
package flags
import (
"testing"
)
func TestCommandInline(t *testing.T) {
var opts = struct {
Value bool `short:"v"`
Command struct {
G bool `short:"g"`
} `command:"cmd"`
}{}
p, ret := assertParserSuccess(t, &opts, "-v", "cmd", "-g")
assertStringArray(t, ret, []string{})
if p.Active == nil {
t.Errorf("Expected active command")
}
if !opts.Value {
t.Errorf("Expected Value to be true")
}
if !opts.Command.G {
t.Errorf("Expected Command.G to be true")
}
if p.Command.Find("cmd") != p.Active {
t.Errorf("Expected to find command `cmd' to be active")
}
}
func TestCommandInlineMulti(t *testing.T) {
var opts = struct {
Value bool `short:"v"`
C1 struct {
} `command:"c1"`
C2 struct {
G bool `short:"g"`
} `command:"c2"`
}{}
p, ret := assertParserSuccess(t, &opts, "-v", "c2", "-g")
assertStringArray(t, ret, []string{})
if p.Active == nil {
t.Errorf("Expected active command")
}
if !opts.Value {
t.Errorf("Expected Value to be true")
}
if !opts.C2.G {
t.Errorf("Expected C2.G to be true")
}
if p.Command.Find("c1") == nil {
t.Errorf("Expected to find command `c1'")
}
if c2 := p.Command.Find("c2"); c2 == nil {
t.Errorf("Expected to find command `c2'")
} else if c2 != p.Active {
t.Errorf("Expected to find command `c2' to be active")
}
}
func TestCommandFlagOrder1(t *testing.T) {
var opts = struct {
Value bool `short:"v"`
Command struct {
G bool `short:"g"`
} `command:"cmd"`
}{}
assertParseFail(t, ErrUnknownFlag, "unknown flag `g'", &opts, "-v", "-g", "cmd")
}
func TestCommandFlagOrder2(t *testing.T) {
var opts = struct {
Value bool `short:"v"`
Command struct {
G bool `short:"g"`
} `command:"cmd"`
}{}
assertParseFail(t, ErrUnknownFlag, "unknown flag `v'", &opts, "cmd", "-v", "-g")
}
func TestCommandEstimate(t *testing.T) {
var opts = struct {
Value bool `short:"v"`
Cmd1 struct {
} `command:"remove"`
Cmd2 struct {
} `command:"add"`
}{}
p := NewParser(&opts, None)
_, err := p.ParseArgs([]string{})
assertError(t, err, ErrRequired, "Please specify one command of: add or remove")
}
type testCommand struct {
G bool `short:"g"`
Executed bool
EArgs []string
}
func (c *testCommand) Execute(args []string) error {
c.Executed = true
c.EArgs = args
return nil
}
func TestCommandExecute(t *testing.T) {
var opts = struct {
Value bool `short:"v"`
Command testCommand `command:"cmd"`
}{}
assertParseSuccess(t, &opts, "-v", "cmd", "-g", "a", "b")
if !opts.Value {
t.Errorf("Expected Value to be true")
}
if !opts.Command.Executed {
t.Errorf("Did not execute command")
}
if !opts.Command.G {
t.Errorf("Expected Command.C to be true")
}
assertStringArray(t, opts.Command.EArgs, []string{"a", "b"})
}
func TestCommandClosest(t *testing.T) {
var opts = struct {
Value bool `short:"v"`
Cmd1 struct {
} `command:"remove"`
Cmd2 struct {
} `command:"add"`
}{}
assertParseFail(t, ErrRequired, "Unknown command `addd', did you mean `add'?", &opts, "-v", "addd")
}
func TestCommandAdd(t *testing.T) {
var opts = struct {
Value bool `short:"v"`
}{}
var cmd = struct {
G bool `short:"g"`
}{}
p := NewParser(&opts, Default)
c, err := p.AddCommand("cmd", "", "", &cmd)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
return
}
ret, err := p.ParseArgs([]string{"-v", "cmd", "-g", "rest"})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
return
}
assertStringArray(t, ret, []string{"rest"})
if !opts.Value {
t.Errorf("Expected Value to be true")
}
if !cmd.G {
t.Errorf("Expected Command.G to be true")
}
if p.Command.Find("cmd") != c {
t.Errorf("Expected to find command `cmd'")
}
if p.Commands()[0] != c {
t.Errorf("Espected command #v, but got #v", c, p.Commands()[0])
}
if c.Options()[0].ShortName != 'g' {
t.Errorf("Expected short name `g' but got %v", c.Options()[0].ShortName)
}
}
func TestCommandNestedInline(t *testing.T) {
var opts = struct {
Value bool `short:"v"`
Command struct {
G bool `short:"g"`
Nested struct {
N string `long:"n"`
} `command:"nested"`
} `command:"cmd"`
}{}
p, ret := assertParserSuccess(t, &opts, "-v", "cmd", "-g", "nested", "--n", "n", "rest")
assertStringArray(t, ret, []string{"rest"})
if !opts.Value {
t.Errorf("Expected Value to be true")
}
if !opts.Command.G {
t.Errorf("Expected Command.G to be true")
}
assertString(t, opts.Command.Nested.N, "n")
if c := p.Command.Find("cmd"); c == nil {
t.Errorf("Expected to find command `cmd'")
} else {
if c != p.Active {
t.Errorf("Expected `cmd' to be the active parser command")
}
if nested := c.Find("nested"); nested == nil {
t.Errorf("Expected to find command `nested'")
} else if nested != c.Active {
t.Errorf("Expected to find command `nested' to be the active `cmd' command")
}
}
}

View File

@@ -0,0 +1,315 @@
// Copyright 2012 Jesse van den Kieboom. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package flags
import (
"fmt"
"reflect"
"strconv"
"strings"
"time"
)
// Marshaler is the interface implemented by types that can marshal themselves
// to a string representation of the flag.
type Marshaler interface {
// MarshalFlag marshals a flag value to its string representation.
MarshalFlag() (string, error)
}
// Unmarshaler is the interface implemented by types that can unmarshal a flag
// argument to themselves. The provided value is directly passed from the
// command line.
type Unmarshaler interface {
// UnmarshalFlag unmarshals a string value representation to the flag
// value (which therefore needs to be a pointer receiver).
UnmarshalFlag(value string) error
}
func getBase(options multiTag, base int) (int, error) {
sbase := options.Get("base")
var err error
var ivbase int64
if sbase != "" {
ivbase, err = strconv.ParseInt(sbase, 10, 32)
base = int(ivbase)
}
return base, err
}
func convertMarshal(val reflect.Value) (bool, string, error) {
// Check first for the Marshaler interface
if val.Type().NumMethod() > 0 && val.CanInterface() {
if marshaler, ok := val.Interface().(Marshaler); ok {
ret, err := marshaler.MarshalFlag()
return true, ret, err
}
}
return false, "", nil
}
func convertToString(val reflect.Value, options multiTag) (string, error) {
if ok, ret, err := convertMarshal(val); ok {
return ret, err
}
tp := val.Type()
// Support for time.Duration
if tp == reflect.TypeOf((*time.Duration)(nil)).Elem() {
stringer := val.Interface().(fmt.Stringer)
return stringer.String(), nil
}
switch tp.Kind() {
case reflect.String:
return val.String(), nil
case reflect.Bool:
if val.Bool() {
return "true", nil
}
return "false", nil
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
base, _ := getBase(options, 10)
return strconv.FormatInt(val.Int(), base), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
base, _ := getBase(options, 10)
return strconv.FormatUint(val.Uint(), base), nil
case reflect.Float32, reflect.Float64:
return strconv.FormatFloat(val.Float(), 'g', -1, tp.Bits()), nil
case reflect.Slice:
if val.Len() == 0 {
return "", nil
}
ret := "["
for i := 0; i < val.Len(); i++ {
if i != 0 {
ret += ", "
}
item, err := convertToString(val.Index(i), options)
if err != nil {
return "", err
}
ret += item
}
return ret + "]", nil
case reflect.Map:
ret := "{"
for i, key := range val.MapKeys() {
if i != 0 {
ret += ", "
}
item, err := convertToString(val.MapIndex(key), options)
if err != nil {
return "", err
}
ret += item
}
return ret + "}", nil
case reflect.Ptr:
return convertToString(reflect.Indirect(val), options)
case reflect.Interface:
if !val.IsNil() {
return convertToString(val.Elem(), options)
}
}
return "", nil
}
func convertUnmarshal(val string, retval reflect.Value) (bool, error) {
if retval.Type().NumMethod() > 0 && retval.CanInterface() {
if unmarshaler, ok := retval.Interface().(Unmarshaler); ok {
return true, unmarshaler.UnmarshalFlag(val)
}
}
if retval.Type().Kind() != reflect.Ptr && retval.CanAddr() {
return convertUnmarshal(val, retval.Addr())
}
if retval.Type().Kind() == reflect.Interface && !retval.IsNil() {
return convertUnmarshal(val, retval.Elem())
}
return false, nil
}
func convert(val string, retval reflect.Value, options multiTag) error {
if ok, err := convertUnmarshal(val, retval); ok {
return err
}
tp := retval.Type()
// Support for time.Duration
if tp == reflect.TypeOf((*time.Duration)(nil)).Elem() {
parsed, err := time.ParseDuration(val)
if err != nil {
return err
}
retval.SetInt(int64(parsed))
return nil
}
switch tp.Kind() {
case reflect.String:
retval.SetString(val)
case reflect.Bool:
if val == "" {
retval.SetBool(true)
} else {
b, err := strconv.ParseBool(val)
if err != nil {
return err
}
retval.SetBool(b)
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
base, err := getBase(options, 10)
if err != nil {
return err
}
parsed, err := strconv.ParseInt(val, base, tp.Bits())
if err != nil {
return err
}
retval.SetInt(parsed)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
base, err := getBase(options, 10)
if err != nil {
return err
}
parsed, err := strconv.ParseUint(val, base, tp.Bits())
if err != nil {
return err
}
retval.SetUint(parsed)
case reflect.Float32, reflect.Float64:
parsed, err := strconv.ParseFloat(val, tp.Bits())
if err != nil {
return err
}
retval.SetFloat(parsed)
case reflect.Slice:
elemtp := tp.Elem()
elemvalptr := reflect.New(elemtp)
elemval := reflect.Indirect(elemvalptr)
if err := convert(val, elemval, options); err != nil {
return err
}
retval.Set(reflect.Append(retval, elemval))
case reflect.Map:
parts := strings.SplitN(val, ":", 2)
key := parts[0]
var value string
if len(parts) == 2 {
value = parts[1]
}
keytp := tp.Key()
keyval := reflect.New(keytp)
if err := convert(key, keyval, options); err != nil {
return err
}
valuetp := tp.Elem()
valueval := reflect.New(valuetp)
if err := convert(value, valueval, options); err != nil {
return err
}
if retval.IsNil() {
retval.Set(reflect.MakeMap(tp))
}
retval.SetMapIndex(reflect.Indirect(keyval), reflect.Indirect(valueval))
case reflect.Ptr:
if retval.IsNil() {
retval.Set(reflect.New(retval.Type().Elem()))
}
return convert(val, reflect.Indirect(retval), options)
case reflect.Interface:
if !retval.IsNil() {
return convert(val, retval.Elem(), options)
}
}
return nil
}
func wrapText(s string, l int, prefix string) string {
// Basic text wrapping of s at spaces to fit in l
var ret string
s = strings.TrimSpace(s)
for len(s) > l {
// Try to split on space
suffix := ""
pos := strings.LastIndex(s[:l], " ")
if pos < 0 {
pos = l - 1
suffix = "-\n"
}
if len(ret) != 0 {
ret += "\n" + prefix
}
ret += strings.TrimSpace(s[:pos]) + suffix
s = strings.TrimSpace(s[pos:])
}
if len(s) > 0 {
if len(ret) != 0 {
ret += "\n" + prefix
}
return ret + s
}
return ret
}

View File

@@ -0,0 +1,113 @@
package flags
import (
"fmt"
)
// ErrorType represents the type of error.
type ErrorType uint
const (
// ErrUnknown indicates a generic error.
ErrUnknown ErrorType = iota
// ErrExpectedArgument indicates that an argument was expected.
ErrExpectedArgument
// ErrUnknownFlag indicates an unknown flag.
ErrUnknownFlag
// ErrUnknownGroup indicates an unknown group.
ErrUnknownGroup
// ErrMarshal indicates a marshalling error while converting values.
ErrMarshal
// ErrHelp indicates that the builtin help was shown (the error
// contains the help message).
ErrHelp
// ErrNoArgumentForBool indicates that an argument was given for a
// boolean flag (which don't not take any arguments).
ErrNoArgumentForBool
// ErrRequired indicates that a required flag was not provided.
ErrRequired
// ErrShortNameTooLong indicates that a short flag name was specified,
// longer than one character.
ErrShortNameTooLong
// ErrDuplicatedFlag indicates that a short or long flag has been
// defined more than once
ErrDuplicatedFlag
// ErrTag indicates an error while parsing flag tags.
ErrTag
)
// String returns a string representation of the error type.
func (e ErrorType) String() string {
switch e {
case ErrUnknown:
return "unknown"
case ErrExpectedArgument:
return "expected argument"
case ErrUnknownFlag:
return "unknown flag"
case ErrUnknownGroup:
return "unknown group"
case ErrMarshal:
return "marshal"
case ErrHelp:
return "help"
case ErrNoArgumentForBool:
return "no argument for bool"
case ErrRequired:
return "required"
case ErrShortNameTooLong:
return "short name too long"
case ErrDuplicatedFlag:
return "duplicated flag"
case ErrTag:
return "tag"
}
return "unknown"
}
// Error represents a parser error. The error returned from Parse is of this
// type. The error contains both a Type and Message.
type Error struct {
// The type of error
Type ErrorType
// The error message
Message string
}
// Error returns the error's message
func (e *Error) Error() string {
return e.Message
}
func newError(tp ErrorType, message string) *Error {
return &Error{
Type: tp,
Message: message,
}
}
func newErrorf(tp ErrorType, format string, args ...interface{}) *Error {
return newError(tp, fmt.Sprintf(format, args...))
}
func wrapError(err error) *Error {
ret, ok := err.(*Error)
if !ok {
return newError(ErrUnknown, err.Error())
}
return ret
}

View File

@@ -0,0 +1,95 @@
// Example of use of the flags package.
package flags
import (
"fmt"
"os"
"os/exec"
"strings"
)
func Example() {
var opts struct {
// Slice of bool will append 'true' each time the option
// is encountered (can be set multiple times, like -vvv)
Verbose []bool `short:"v" long:"verbose" description:"Show verbose debug information"`
// Example of automatic marshalling to desired type (uint)
Offset uint `long:"offset" description:"Offset"`
// Example of a callback, called each time the option is found.
Call func(string) `short:"c" description:"Call phone number"`
// Example of a required flag
Name string `short:"n" long:"name" description:"A name" required:"true"`
// Example of a value name
File string `short:"f" long:"file" description:"A file" value-name:"FILE"`
// Example of a pointer
Ptr *int `short:"p" description:"A pointer to an integer"`
// Example of a slice of strings
StringSlice []string `short:"s" description:"A slice of strings"`
// Example of a slice of pointers
PtrSlice []*string `long:"ptrslice" description:"A slice of pointers to string"`
// Example of a map
IntMap map[string]int `long:"intmap" description:"A map from string to int"`
}
// Callback which will invoke callto:<argument> to call a number.
// Note that this works just on OS X (and probably only with
// Skype) but it shows the idea.
opts.Call = func(num string) {
cmd := exec.Command("open", "callto:"+num)
cmd.Start()
cmd.Process.Release()
}
// Make some fake arguments to parse.
args := []string{
"-vv",
"--offset=5",
"-n", "Me",
"-p", "3",
"-s", "hello",
"-s", "world",
"--ptrslice", "hello",
"--ptrslice", "world",
"--intmap", "a:1",
"--intmap", "b:5",
"arg1",
"arg2",
"arg3",
}
// Parse flags from `args'. Note that here we use flags.ParseArgs for
// the sake of making a working example. Normally, you would simply use
// flags.Parse(&opts) which uses os.Args
args, err := ParseArgs(&opts, args)
if err != nil {
panic(err)
os.Exit(1)
}
fmt.Printf("Verbosity: %v\n", opts.Verbose)
fmt.Printf("Offset: %d\n", opts.Offset)
fmt.Printf("Name: %s\n", opts.Name)
fmt.Printf("Ptr: %d\n", *opts.Ptr)
fmt.Printf("StringSlice: %v\n", opts.StringSlice)
fmt.Printf("PtrSlice: [%v %v]\n", *opts.PtrSlice[0], *opts.PtrSlice[1])
fmt.Printf("IntMap: [a:%v b:%v]\n", opts.IntMap["a"], opts.IntMap["b"])
fmt.Printf("Remaining args: %s\n", strings.Join(args, " "))
// Output: Verbosity: [true true]
// Offset: 5
// Name: Me
// Ptr: 3
// StringSlice: [hello world]
// PtrSlice: [hello world]
// IntMap: [a:1 b:5]
// Remaining args: arg1 arg2 arg3
}

View File

@@ -0,0 +1,23 @@
package main
import (
"fmt"
)
type AddCommand struct {
All bool `short:"a" long:"all" description:"Add all files"`
}
var addCommand AddCommand
func (x *AddCommand) Execute(args []string) error {
fmt.Printf("Adding (all=%v): %#v\n", x.All, args)
return nil
}
func init() {
parser.AddCommand("add",
"Add a file",
"The add command adds a file to the repository. Use -a to add all files.",
&addCommand)
}

View File

@@ -0,0 +1,75 @@
package main
import (
"errors"
"fmt"
"github.com/calmh/syncthing/github.com/jessevdk/go-flags"
"os"
"strconv"
"strings"
)
type EditorOptions struct {
Input string `short:"i" long:"input" description:"Input file" default:"-"`
Output string `short:"o" long:"output" description:"Output file" default:"-"`
}
type Point struct {
X, Y int
}
func (p *Point) UnmarshalFlag(value string) error {
parts := strings.Split(value, ",")
if len(parts) != 2 {
return errors.New("Expected two numbers separated by a ,")
}
x, err := strconv.ParseInt(parts[0], 10, 32)
if err != nil {
return err
}
y, err := strconv.ParseInt(parts[1], 10, 32)
if err != nil {
return err
}
p.X = int(x)
p.Y = int(y)
return nil
}
func (p Point) MarshalFlag() (string, error) {
return fmt.Sprintf("%d,%d", p.X, p.Y), nil
}
type Options struct {
// Example of verbosity with level
Verbose []bool `short:"v" long:"verbose" description:"Verbose output"`
// Example of optional value
User string `short:"u" long:"user" description:"User name" optional:"yes" optional-value:"pancake"`
// Example of map with multiple default values
Users map[string]string `long:"users" description:"User e-mail map" default:"system:system@example.org" default:"admin:admin@example.org"`
// Example of option group
Editor EditorOptions `group:"Editor Options"`
// Example of custom type Marshal/Unmarshal
Point Point `long:"point" description:"A x,y point" default:"1,2"`
}
var options Options
var parser = flags.NewParser(&options, flags.Default)
func main() {
if _, err := parser.Parse(); err != nil {
os.Exit(1)
}
}

View File

@@ -0,0 +1,23 @@
package main
import (
"fmt"
)
type RmCommand struct {
Force bool `short:"f" long:"force" description:"Force removal of files"`
}
var rmCommand RmCommand
func (x *RmCommand) Execute(args []string) error {
fmt.Printf("Removing (force=%v): %#v\n", x.Force, args)
return nil
}
func init() {
parser.AddCommand("rm",
"Remove a file",
"The rm command removes a file to the repository. Use -f to force removal of files.",
&rmCommand)
}

View File

@@ -0,0 +1,141 @@
// Copyright 2012 Jesse van den Kieboom. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package flags provides an extensive command line option parser.
// The flags package is similar in functionality to the go builtin flag package
// but provides more options and uses reflection to provide a convenient and
// succinct way of specifying command line options.
//
// Supported features:
// Options with short names (-v)
// Options with long names (--verbose)
// Options with and without arguments (bool v.s. other type)
// Options with optional arguments and default values
// Multiple option groups each containing a set of options
// Generate and print well-formatted help message
// Passing remaining command line arguments after -- (optional)
// Ignoring unknown command line options (optional)
// Supports -I/usr/include -I=/usr/include -I /usr/include option argument specification
// Supports multiple short options -aux
// Supports all primitive go types (string, int{8..64}, uint{8..64}, float)
// Supports same option multiple times (can store in slice or last option counts)
// Supports maps
// Supports function callbacks
//
// Additional features specific to Windows:
// Options with short names (/v)
// Options with long names (/verbose)
// Windows-style options with arguments use a colon as the delimiter
// Modify generated help message with Windows-style / options
//
// The flags package uses structs, reflection and struct field tags
// to allow users to specify command line options. This results in very simple
// and consise specification of your application options. For example:
//
// type Options struct {
// Verbose []bool `short:"v" long:"verbose" description:"Show verbose debug information"`
// }
//
// This specifies one option with a short name -v and a long name --verbose.
// When either -v or --verbose is found on the command line, a 'true' value
// will be appended to the Verbose field. e.g. when specifying -vvv, the
// resulting value of Verbose will be {[true, true, true]}.
//
// Slice options work exactly the same as primitive type options, except that
// whenever the option is encountered, a value is appended to the slice.
//
// Map options from string to primitive type are also supported. On the command
// line, you specify the value for such an option as key:value. For example
//
// type Options struct {
// AuthorInfo string[string] `short:"a"`
// }
//
// Then, the AuthorInfo map can be filled with something like
// -a name:Jesse -a "surname:van den Kieboom".
//
// Finally, for full control over the conversion between command line argument
// values and options, user defined types can choose to implement the Marshaler
// and Unmarshaler interfaces.
//
// Available field tags:
// short: the short name of the option (single character)
// long: the long name of the option
// description: the description of the option (optional)
// optional: whether an argument of the option is optional (optional)
// optional-value: the value of an optional option when the option occurs
// without an argument. This tag can be specified multiple
// times in the case of maps or slices (optional)
// default: the default value of an option. This tag can be specified
// multiple times in the case of slices or maps (optional).
// default-mask: when specified, this value will be displayed in the help
// instead of the actual default value. This is useful
// mostly for hiding otherwise sensitive information from
// showing up in the help. If default-mask takes the special
// value "-", then no default value will be shown at all
// (optional)
// required: whether an option is required to appear on the command
// line. If a required option is not present, the parser
// will return ErrRequired.
// base: a base (radix) used to convert strings to integer values,
// the default base is 10 (i.e. decimal) (optional)
// value-name: the name of the argument value (to be shown in the help,
// (optional)
// group: when specified on a struct field, makes the struct field
// a separate group with the given name (optional).
// command: when specified on a struct field, makes the struct field
// a (sub)command with the given name (optional).
//
// Either short: or long: must be specified to make the field eligible as an
// option.
//
//
// Option groups:
//
// Option groups are a simple way to semantically separate your options. The
// only real difference is in how your options will appear in the builtin
// generated help. All options in a particular group are shown together in the
// help under the name of the group.
//
// There are currently three ways to specify option groups.
//
// 1. Use NewNamedParser specifying the various option groups.
// 2. Use AddGroup to add a group to an existing parser.
// 3. Add a struct field to the toplevel options annotated with the
// group:"group-name" tag.
//
//
//
// Commands:
//
// The flags package also has basic support for commands. Commands are often
// used in monolithic applications that support various commands or actions.
// Take git for example, all of the add, commit, checkout, etc. are called
// commands. Using commands you can easily separate multiple functions of your
// application.
//
// There are currently two ways to specifiy a command.
//
// 1. Use AddCommand on an existing parser.
// 2. Add a struct field to your options struct annotated with the
// command:"command-name" tag.
//
// The most common, idiomatic way to implement commands is to define a global
// parser instance and implement each command in a separate file. These
// command files should define a go init function which calls AddCommand on
// the global parser.
//
// When parsing ends and there is an active command and that command implements
// the Commander interface, then its Execute method will be run with the
// remaining command line arguments.
//
// Command structs can have options which become valid to parse after the
// command has been specified on the command line. It is currently not valid
// to specify options from the parent level of the command after the command
// name has occurred. Thus, given a toplevel option "-v" and a command "add":
//
// Valid: ./app -v add
// Invalid: ./app add -v
//
package flags

View File

@@ -0,0 +1,80 @@
// Copyright 2012 Jesse van den Kieboom. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package flags
import (
"errors"
"strings"
)
// ErrNotPointerToStruct indicates that a provided data container is not
// a pointer to a struct. Only pointers to structs are valid data containers
// for options.
var ErrNotPointerToStruct = errors.New("provided data is not a pointer to struct")
// Group represents an option group. Option groups can be used to logically
// group options together under a description. Groups are only used to provide
// more structure to options both for the user (as displayed in the help message)
// and for you, since groups can be nested.
type Group struct {
// A short description of the group. The
// short description is primarily used in the builtin generated help
// message
ShortDescription string
// A long description of the group. The long
// description is primarily used to present information on commands
// (Command embeds Group) in the builtin generated help and man pages.
LongDescription string
// All the options in the group
options []*Option
// All the subgroups
groups []*Group
data interface{}
}
// AddGroup adds a new group to the command with the given name and data. The
// data needs to be a pointer to a struct from which the fields indicate which
// options are in the group.
func (g *Group) AddGroup(shortDescription string, longDescription string, data interface{}) (*Group, error) {
group := newGroup(shortDescription, longDescription, data)
if err := group.scan(); err != nil {
return nil, err
}
g.groups = append(g.groups, group)
return group, nil
}
// Groups returns the list of groups embedded in this group.
func (g *Group) Groups() []*Group {
return g.groups
}
// Options returns the list of options in this group.
func (g *Group) Options() []*Option {
return g.options
}
// Find locates the subgroup with the given short description and returns it.
// If no such group can be found Find will return nil. Note that the description
// is matched case insensitively.
func (g *Group) Find(shortDescription string) *Group {
lshortDescription := strings.ToLower(shortDescription)
var ret *Group
g.eachGroup(func(gg *Group) {
if gg != g && strings.ToLower(gg.ShortDescription) == lshortDescription {
ret = gg
}
})
return ret
}

View File

@@ -0,0 +1,263 @@
package flags
import (
"reflect"
"unicode/utf8"
"unsafe"
)
type scanHandler func(reflect.Value, *reflect.StructField) (bool, error)
func newGroup(shortDescription string, longDescription string, data interface{}) *Group {
return &Group{
ShortDescription: shortDescription,
LongDescription: longDescription,
data: data,
}
}
func (g *Group) optionByName(name string, namematch func(*Option, string) bool) *Option {
prio := 0
var retopt *Option
for _, opt := range g.options {
if namematch != nil && namematch(opt, name) && prio < 4 {
retopt = opt
prio = 4
}
if name == opt.field.Name && prio < 3 {
retopt = opt
prio = 3
}
if name == opt.LongName && prio < 2 {
retopt = opt
prio = 2
}
if opt.ShortName != 0 && name == string(opt.ShortName) && prio < 1 {
retopt = opt
prio = 1
}
}
return retopt
}
func (g *Group) storeDefaults() {
for _, option := range g.options {
// First. empty out the value
if len(option.Default) > 0 {
option.clear()
}
for _, d := range option.Default {
option.set(&d)
}
if !option.value.CanSet() {
continue
}
option.defaultValue = reflect.ValueOf(option.value.Interface())
}
}
func (g *Group) eachGroup(f func(*Group)) {
f(g)
for _, gg := range g.groups {
gg.eachGroup(f)
}
}
func (g *Group) scanStruct(realval reflect.Value, sfield *reflect.StructField, handler scanHandler) error {
stype := realval.Type()
if sfield != nil {
if ok, err := handler(realval, sfield); err != nil {
return err
} else if ok {
return nil
}
}
for i := 0; i < stype.NumField(); i++ {
field := stype.Field(i)
// PkgName is set only for non-exported fields, which we ignore
if field.PkgPath != "" {
continue
}
mtag := newMultiTag(string(field.Tag))
if err := mtag.Parse(); err != nil {
return err
}
// Skip fields with the no-flag tag
if mtag.Get("no-flag") != "" {
continue
}
// Dive deep into structs or pointers to structs
kind := field.Type.Kind()
fld := realval.Field(i)
if kind == reflect.Struct {
if err := g.scanStruct(fld, &field, handler); err != nil {
return err
}
} else if kind == reflect.Ptr && field.Type.Elem().Kind() == reflect.Struct {
if fld.IsNil() {
fld.Set(reflect.New(fld.Type().Elem()))
}
if err := g.scanStruct(reflect.Indirect(fld), &field, handler); err != nil {
return err
}
}
longname := mtag.Get("long")
shortname := mtag.Get("short")
// Need at least either a short or long name
if longname == "" && shortname == "" && mtag.Get("ini-name") == "" {
continue
}
short := rune(0)
rc := utf8.RuneCountInString(shortname)
if rc > 1 {
return newErrorf(ErrShortNameTooLong,
"short names can only be 1 character long, not `%s'",
shortname)
} else if rc == 1 {
short, _ = utf8.DecodeRuneInString(shortname)
}
description := mtag.Get("description")
def := mtag.GetMany("default")
optionalValue := mtag.GetMany("optional-value")
valueName := mtag.Get("value-name")
defaultMask := mtag.Get("default-mask")
optional := (mtag.Get("optional") != "")
required := (mtag.Get("required") != "")
option := &Option{
Description: description,
ShortName: short,
LongName: longname,
Default: def,
OptionalArgument: optional,
OptionalValue: optionalValue,
Required: required,
ValueName: valueName,
DefaultMask: defaultMask,
field: field,
value: realval.Field(i),
tag: mtag,
}
g.options = append(g.options, option)
}
return nil
}
func (g *Group) checkForDuplicateFlags() *Error {
shortNames := make(map[rune]*Option)
longNames := make(map[string]*Option)
var duplicateError *Error
g.eachGroup(func(g *Group) {
for _, option := range g.options {
if option.LongName != "" {
if otherOption, ok := longNames[option.LongName]; ok {
duplicateError = newErrorf(ErrDuplicatedFlag, "option `%s' uses the same long name as option `%s'", option, otherOption)
return
}
longNames[option.LongName] = option
}
if option.ShortName != 0 {
if otherOption, ok := shortNames[option.ShortName]; ok {
duplicateError = newErrorf(ErrDuplicatedFlag, "option `%s' uses the same short name as option `%s'", option, otherOption)
return
}
shortNames[option.ShortName] = option
}
}
})
return duplicateError
}
func (g *Group) scanSubGroupHandler(realval reflect.Value, sfield *reflect.StructField) (bool, error) {
mtag := newMultiTag(string(sfield.Tag))
if err := mtag.Parse(); err != nil {
return true, err
}
subgroup := mtag.Get("group")
if len(subgroup) != 0 {
ptrval := reflect.NewAt(realval.Type(), unsafe.Pointer(realval.UnsafeAddr()))
description := mtag.Get("description")
if _, err := g.AddGroup(subgroup, description, ptrval.Interface()); err != nil {
return true, err
}
return true, nil
}
return false, nil
}
func (g *Group) scanType(handler scanHandler) error {
// Get all the public fields in the data struct
ptrval := reflect.ValueOf(g.data)
if ptrval.Type().Kind() != reflect.Ptr {
panic(ErrNotPointerToStruct)
}
stype := ptrval.Type().Elem()
if stype.Kind() != reflect.Struct {
panic(ErrNotPointerToStruct)
}
realval := reflect.Indirect(ptrval)
if err := g.scanStruct(realval, nil, handler); err != nil {
return err
}
if err := g.checkForDuplicateFlags(); err != nil {
return err
}
return nil
}
func (g *Group) scan() error {
return g.scanType(g.scanSubGroupHandler)
}
func (g *Group) groupByName(name string) *Group {
if len(name) == 0 {
return g
}
return g.Find(name)
}

View File

@@ -0,0 +1,160 @@
package flags
import (
"testing"
)
func TestGroupInline(t *testing.T) {
var opts = struct {
Value bool `short:"v"`
Group struct {
G bool `short:"g"`
} `group:"Grouped Options"`
}{}
p, ret := assertParserSuccess(t, &opts, "-v", "-g")
assertStringArray(t, ret, []string{})
if !opts.Value {
t.Errorf("Expected Value to be true")
}
if !opts.Group.G {
t.Errorf("Expected Group.G to be true")
}
if p.Command.Group.Find("Grouped Options") == nil {
t.Errorf("Expected to find group `Grouped Options'")
}
}
func TestGroupAdd(t *testing.T) {
var opts = struct {
Value bool `short:"v"`
}{}
var grp = struct {
G bool `short:"g"`
}{}
p := NewParser(&opts, Default)
g, err := p.AddGroup("Grouped Options", "", &grp)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
return
}
ret, err := p.ParseArgs([]string{"-v", "-g", "rest"})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
return
}
assertStringArray(t, ret, []string{"rest"})
if !opts.Value {
t.Errorf("Expected Value to be true")
}
if !grp.G {
t.Errorf("Expected Group.G to be true")
}
if p.Command.Group.Find("Grouped Options") != g {
t.Errorf("Expected to find group `Grouped Options'")
}
if p.Groups()[1] != g {
t.Errorf("Espected group #v, but got #v", g, p.Groups()[0])
}
if g.Options()[0].ShortName != 'g' {
t.Errorf("Expected short name `g' but got %v", g.Options()[0].ShortName)
}
}
func TestGroupNestedInline(t *testing.T) {
var opts = struct {
Value bool `short:"v"`
Group struct {
G bool `short:"g"`
Nested struct {
N string `long:"n"`
} `group:"Nested Options"`
} `group:"Grouped Options"`
}{}
p, ret := assertParserSuccess(t, &opts, "-v", "-g", "--n", "n", "rest")
assertStringArray(t, ret, []string{"rest"})
if !opts.Value {
t.Errorf("Expected Value to be true")
}
if !opts.Group.G {
t.Errorf("Expected Group.G to be true")
}
assertString(t, opts.Group.Nested.N, "n")
if p.Command.Group.Find("Grouped Options") == nil {
t.Errorf("Expected to find group `Grouped Options'")
}
if p.Command.Group.Find("Nested Options") == nil {
t.Errorf("Expected to find group `Nested Options'")
}
}
func TestDuplicateShortFlags(t *testing.T) {
var opts struct {
Verbose []bool `short:"v" long:"verbose" description:"Show verbose debug information"`
Variables []string `short:"v" long:"variable" description:"Set a variable value."`
}
args := []string{
"--verbose",
"-v", "123",
"-v", "456",
}
_, err := ParseArgs(&opts, args)
if err == nil {
t.Errorf("Expected an error with type ErrDuplicatedFlag")
} else {
err2 := err.(*Error)
if err2.Type != ErrDuplicatedFlag {
t.Errorf("Expected an error with type ErrDuplicatedFlag")
}
}
}
func TestDuplicateLongFlags(t *testing.T) {
var opts struct {
Test1 []bool `short:"a" long:"testing" description:"Test 1"`
Test2 []string `short:"b" long:"testing" description:"Test 2."`
}
args := []string{
"--testing",
}
_, err := ParseArgs(&opts, args)
if err == nil {
t.Errorf("Expected an error with type ErrDuplicatedFlag")
} else {
err2 := err.(*Error)
if err2.Type != ErrDuplicatedFlag {
t.Errorf("Expected an error with type ErrDuplicatedFlag")
}
}
}

View File

@@ -0,0 +1,275 @@
// Copyright 2012 Jesse van den Kieboom. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package flags
import (
"bufio"
"bytes"
"fmt"
"io"
"reflect"
"strings"
"unicode/utf8"
)
type alignmentInfo struct {
maxLongLen int
hasShort bool
hasValueName bool
terminalColumns int
}
func (p *Parser) getAlignmentInfo() alignmentInfo {
ret := alignmentInfo{
maxLongLen: 0,
hasShort: false,
hasValueName: false,
terminalColumns: getTerminalColumns(),
}
if ret.terminalColumns <= 0 {
ret.terminalColumns = 80
}
p.eachActiveGroup(func(grp *Group) {
for _, info := range grp.options {
if info.ShortName != 0 {
ret.hasShort = true
}
lv := utf8.RuneCountInString(info.ValueName)
if lv != 0 {
ret.hasValueName = true
}
l := utf8.RuneCountInString(info.LongName) + lv
if l > ret.maxLongLen {
ret.maxLongLen = l
}
}
})
return ret
}
func (p *Parser) writeHelpOption(writer *bufio.Writer, option *Option, info alignmentInfo) {
line := &bytes.Buffer{}
distanceBetweenOptionAndDescription := 2
paddingBeforeOption := 2
line.WriteString(strings.Repeat(" ", paddingBeforeOption))
if option.ShortName != 0 {
line.WriteRune(defaultShortOptDelimiter)
line.WriteRune(option.ShortName)
} else if info.hasShort {
line.WriteString(" ")
}
descstart := info.maxLongLen + paddingBeforeOption + distanceBetweenOptionAndDescription
if info.hasShort {
descstart += 2
}
if info.maxLongLen > 0 {
descstart += 4
}
if info.hasValueName {
descstart += 3
}
if len(option.LongName) > 0 {
if option.ShortName != 0 {
line.WriteString(", ")
} else if info.hasShort {
line.WriteString(" ")
}
line.WriteString(defaultLongOptDelimiter)
line.WriteString(option.LongName)
}
if option.canArgument() {
line.WriteRune(defaultNameArgDelimiter)
if len(option.ValueName) > 0 {
line.WriteString(option.ValueName)
}
}
written := line.Len()
line.WriteTo(writer)
if option.Description != "" {
dw := descstart - written
writer.WriteString(strings.Repeat(" ", dw))
def := ""
defs := option.Default
if len(option.DefaultMask) != 0 {
if option.DefaultMask != "-" {
def = option.DefaultMask
}
} else if len(defs) == 0 && option.canArgument() {
var showdef bool
switch option.field.Type.Kind() {
case reflect.Func, reflect.Ptr:
showdef = !option.value.IsNil()
case reflect.Slice, reflect.String, reflect.Array:
showdef = option.value.Len() > 0
case reflect.Map:
showdef = !option.value.IsNil() && option.value.Len() > 0
default:
zeroval := reflect.Zero(option.field.Type)
showdef = !reflect.DeepEqual(zeroval.Interface(), option.value.Interface())
}
if showdef {
def, _ = convertToString(option.value, option.tag)
}
} else if len(defs) != 0 {
def = strings.Join(defs, ", ")
}
var desc string
if def != "" {
desc = fmt.Sprintf("%s (%v)", option.Description, def)
} else {
desc = option.Description
}
writer.WriteString(wrapText(desc,
info.terminalColumns-descstart,
strings.Repeat(" ", descstart)))
}
writer.WriteString("\n")
}
func maxCommandLength(s []*Command) int {
if len(s) == 0 {
return 0
}
ret := len(s[0].Name)
for _, v := range s[1:] {
l := len(v.Name)
if l > ret {
ret = l
}
}
return ret
}
// WriteHelp writes a help message containing all the possible options and
// their descriptions to the provided writer. Note that the HelpFlag parser
// option provides a convenient way to add a -h/--help option group to the
// command line parser which will automatically show the help messages using
// this method.
func (p *Parser) WriteHelp(writer io.Writer) {
if writer == nil {
return
}
wr := bufio.NewWriter(writer)
aligninfo := p.getAlignmentInfo()
cmd := p.Command
for cmd.Active != nil {
cmd = cmd.Active
}
if p.Name != "" {
wr.WriteString("Usage:\n")
wr.WriteString(" ")
allcmd := p.Command
for allcmd != nil {
var usage string
if allcmd == p.Command {
if len(p.Usage) != 0 {
usage = p.Usage
} else {
usage = "[OPTIONS]"
}
} else if us, ok := allcmd.data.(Usage); ok {
usage = us.Usage()
} else {
usage = fmt.Sprintf("[%s-OPTIONS]", allcmd.Name)
}
if len(usage) != 0 {
fmt.Fprintf(wr, " %s %s", allcmd.Name, usage)
} else {
fmt.Fprintf(wr, " %s", allcmd.Name)
}
allcmd = allcmd.Active
}
fmt.Fprintln(wr)
if len(cmd.LongDescription) != 0 {
fmt.Fprintln(wr)
t := wrapText(cmd.LongDescription,
aligninfo.terminalColumns,
"")
fmt.Fprintln(wr, t)
}
}
p.eachActiveGroup(func(grp *Group) {
first := true
for _, info := range grp.options {
if info.canCli() {
if first {
fmt.Fprintf(wr, "\n%s:\n", grp.ShortDescription)
first = false
}
p.writeHelpOption(wr, info, aligninfo)
}
}
})
scommands := cmd.sortedCommands()
if len(scommands) > 0 {
maxnamelen := maxCommandLength(scommands)
fmt.Fprintln(wr)
fmt.Fprintln(wr, "Available commands:")
for _, c := range scommands {
fmt.Fprintf(wr, " %s", c.Name)
if len(c.ShortDescription) > 0 {
pad := strings.Repeat(" ", maxnamelen-len(c.Name))
fmt.Fprintf(wr, "%s %s", pad, c.ShortDescription)
}
fmt.Fprintln(wr)
}
}
wr.Flush()
}

View File

@@ -0,0 +1,153 @@
package flags
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"os"
"os/exec"
"testing"
"time"
)
func helpDiff(a, b string) (string, error) {
atmp, err := ioutil.TempFile("", "help-diff")
if err != nil {
return "", err
}
btmp, err := ioutil.TempFile("", "help-diff")
if err != nil {
return "", err
}
if _, err := io.WriteString(atmp, a); err != nil {
return "", err
}
if _, err := io.WriteString(btmp, b); err != nil {
return "", err
}
ret, err := exec.Command("diff", "-u", "-d", "--label", "got", atmp.Name(), "--label", "expected", btmp.Name()).Output()
os.Remove(atmp.Name())
os.Remove(btmp.Name())
return string(ret), nil
}
type helpOptions struct {
Verbose []bool `short:"v" long:"verbose" description:"Show verbose debug information" ini-name:"verbose"`
Call func(string) `short:"c" description:"Call phone number" ini-name:"call"`
PtrSlice []*string `long:"ptrslice" description:"A slice of pointers to string"`
OnlyIni string `ini-name:"only-ini" description:"Option only available in ini"`
Other struct {
StringSlice []string `short:"s" description:"A slice of strings"`
IntMap map[string]int `long:"intmap" description:"A map from string to int" ini-name:"int-map"`
} `group:"Other Options"`
}
func TestHelp(t *testing.T) {
var opts helpOptions
p := NewNamedParser("TestHelp", HelpFlag)
p.AddGroup("Application Options", "The application options", &opts)
_, err := p.ParseArgs([]string{"--help"})
if err == nil {
t.Fatalf("Expected help error")
}
if e, ok := err.(*Error); !ok {
t.Fatalf("Expected flags.Error, but got %#T", err)
} else {
if e.Type != ErrHelp {
t.Errorf("Expected flags.ErrHelp type, but got %s", e.Type)
}
expected := `Usage:
TestHelp [OPTIONS]
Application Options:
-v, --verbose Show verbose debug information
-c= Call phone number
--ptrslice= A slice of pointers to string
Other Options:
-s= A slice of strings
--intmap= A map from string to int
Help Options:
-h, --help Show this help message
`
if e.Message != expected {
ret, err := helpDiff(e.Message, expected)
if err != nil {
t.Errorf("Unexpected diff error: %s", err)
t.Errorf("Unexpected help message, expected:\n\n%s\n\nbut got\n\n%s", expected, e.Message)
} else {
t.Errorf("Unexpected help message:\n\n%s", ret)
}
}
}
}
func TestMan(t *testing.T) {
var opts helpOptions
p := NewNamedParser("TestMan", HelpFlag)
p.ShortDescription = "Test manpage generation"
p.LongDescription = "This is a somewhat longer description of what this does"
p.AddGroup("Application Options", "The application options", &opts)
var buf bytes.Buffer
p.WriteManPage(&buf)
got := buf.String()
tt := time.Now()
expected := fmt.Sprintf(`.TH TestMan 1 "%s"
.SH NAME
TestMan \- Test manpage generation
.SH SYNOPSIS
\fBTestMan\fP [OPTIONS]
.SH DESCRIPTION
This is a somewhat longer description of what this does
.SH OPTIONS
.TP
\fB-v, --verbose\fP
Show verbose debug information
.TP
\fB-c\fP
Call phone number
.TP
\fB--ptrslice\fP
A slice of pointers to string
.TP
\fB-s\fP
A slice of strings
.TP
\fB--intmap\fP
A map from string to int
`, tt.Format("2 January 2006"))
if got != expected {
ret, err := helpDiff(got, expected)
if err != nil {
t.Errorf("Unexpected man page, expected:\n\n%s\n\nbut got\n\n%s", expected, got)
} else {
t.Errorf("Unexpected man page:\n\n%s", ret)
}
}
}

View File

@@ -0,0 +1,146 @@
package flags
import (
"fmt"
"io"
"os"
)
// IniError contains location information on where in the ini file an error
// occured.
type IniError struct {
// The error message.
Message string
// The filename of the file in which the error occurred.
File string
// The line number at which the error occurred.
LineNumber uint
}
// Error provides a "file:line: message" formatted message of the ini error.
func (x *IniError) Error() string {
return fmt.Sprintf("%s:%d: %s",
x.File,
x.LineNumber,
x.Message)
}
// IniOptions for writing ini files
type IniOptions uint
const (
// IniNone indicates no options.
IniNone IniOptions = 0
// IniIncludeDefaults indicates that default values should be written
// when writing options to an ini file.
IniIncludeDefaults = 1 << iota
// IniIncludeComments indicates that comments containing the description
// of an option should be written when writing options to an ini file.
IniIncludeComments
// IniDefault provides a default set of options.
IniDefault = IniIncludeComments
)
// IniParser is a utility to read and write flags options from and to ini
// files.
type IniParser struct {
parser *Parser
}
// NewIniParser creates a new ini parser for a given Parser.
func NewIniParser(p *Parser) *IniParser {
return &IniParser{
parser: p,
}
}
// IniParse is a convenience function to parse command line options with default
// settings from an ini file. The provided data is a pointer to a struct
// representing the default option group (named "Application Options"). For
// more control, use flags.NewParser.
func IniParse(filename string, data interface{}) error {
p := NewParser(data, Default)
return NewIniParser(p).ParseFile(filename)
}
// ParseFile parses flags from an ini formatted file. See Parse for more
// information on the ini file foramt. The returned errors can be of the type
// flags.Error or flags.IniError.
func (i *IniParser) ParseFile(filename string) error {
i.parser.storeDefaults()
ini, err := readIniFromFile(filename)
if err != nil {
return err
}
return i.parse(ini)
}
// Parse parses flags from an ini format. You can use ParseFile as a
// convenience function to parse from a filename instead of a general
// io.Reader.
//
// The format of the ini file is as follows:
//
// [Option group name]
// option = value
//
// Each section in the ini file represents an option group or command in the
// flags parser. The default flags parser option group (i.e. when using
// flags.Parse) is named 'Application Options'. The ini option name is matched
// in the following order:
//
// 1. Compared to the ini-name tag on the option struct field (if present)
// 2. Compared to the struct field name
// 3. Compared to the option long name (if present)
// 4. Compared to the option short name (if present)
//
// Sections for nested groups and commands can be addressed using a dot `.'
// namespacing notation (i.e [subcommand.Options]). Group section names are
// matched case insensitive.
//
// The returned errors can be of the type flags.Error or
// flags.IniError.
func (i *IniParser) Parse(reader io.Reader) error {
i.parser.storeDefaults()
ini, err := readIni(reader, "")
if err != nil {
return err
}
return i.parse(ini)
}
// WriteFile writes the flags as ini format into a file. See WriteIni
// for more information. The returned error occurs when the specified file
// could not be opened for writing.
func (i *IniParser) WriteFile(filename string, options IniOptions) error {
file, err := os.Create(filename)
if err != nil {
return err
}
defer file.Close()
i.Write(file, options)
return nil
}
// Write writes the current values of all the flags to an ini format.
// See Parse for more information on the ini file format. You typically
// call this only after settings have been parsed since the default values of each
// option are stored just before parsing the flags (this is only relevant when
// IniIncludeDefaults is _not_ set in options).
func (i *IniParser) Write(writer io.Writer, options IniOptions) {
writeIni(i, writer, options)
}

View File

@@ -0,0 +1,333 @@
package flags
import (
"bufio"
"fmt"
"io"
"os"
"reflect"
"strings"
)
type iniValue struct {
Name string
Value string
}
type iniSection []iniValue
type ini map[string]iniSection
func readFullLine(reader *bufio.Reader) (string, error) {
var line []byte
for {
l, more, err := reader.ReadLine()
if err != nil {
return "", err
}
if line == nil && !more {
return string(l), nil
}
line = append(line, l...)
if !more {
break
}
}
return string(line), nil
}
func optionIniName(option *Option) string {
name := option.tag.Get("_read-ini-name")
if len(name) != 0 {
return name
}
name = option.tag.Get("ini-name")
if len(name) != 0 {
return name
}
return option.field.Name
}
func writeGroupIni(group *Group, namespace string, writer io.Writer, options IniOptions) {
var sname string
if len(namespace) != 0 {
sname = namespace + "." + group.ShortDescription
} else {
sname = group.ShortDescription
}
sectionwritten := false
comments := (options & IniIncludeComments) != IniNone
for _, option := range group.options {
if option.isFunc() {
continue
}
if len(option.tag.Get("no-ini")) != 0 {
continue
}
val := option.value
if (options&IniIncludeDefaults) == IniNone &&
reflect.DeepEqual(val, option.defaultValue) {
continue
}
if !sectionwritten {
fmt.Fprintf(writer, "[%s]\n", sname)
sectionwritten = true
}
if comments {
fmt.Fprintf(writer, "; %s\n", option.Description)
}
oname := optionIniName(option)
switch val.Type().Kind() {
case reflect.Slice:
for idx := 0; idx < val.Len(); idx++ {
v, _ := convertToString(val.Index(idx), option.tag)
fmt.Fprintf(writer, "%s = %s\n", oname, v)
}
if val.Len() == 0 {
fmt.Fprintf(writer, "; %s =\n", oname)
}
case reflect.Map:
for _, key := range val.MapKeys() {
k, _ := convertToString(key, option.tag)
v, _ := convertToString(val.MapIndex(key), option.tag)
fmt.Fprintf(writer, "%s = %s:%s\n", oname, k, v)
}
if val.Len() == 0 {
fmt.Fprintf(writer, "; %s =\n", oname)
}
default:
v, _ := convertToString(val, option.tag)
if len(v) != 0 {
fmt.Fprintf(writer, "%s = %s\n", oname, v)
} else {
fmt.Fprintf(writer, "%s =\n", oname)
}
}
if comments {
fmt.Fprintln(writer)
}
}
if sectionwritten && !comments {
fmt.Fprintln(writer)
}
}
func writeCommandIni(command *Command, namespace string, writer io.Writer, options IniOptions) {
command.eachGroup(func(group *Group) {
writeGroupIni(group, namespace, writer, options)
})
for _, c := range command.commands {
var nns string
if len(namespace) != 0 {
nns = c.Name + "." + nns
} else {
nns = c.Name
}
writeCommandIni(c, nns, writer, options)
}
}
func writeIni(parser *IniParser, writer io.Writer, options IniOptions) {
writeCommandIni(parser.parser.Command, "", writer, options)
}
func readIniFromFile(filename string) (ini, error) {
file, err := os.Open(filename)
if err != nil {
return nil, err
}
defer file.Close()
return readIni(file, filename)
}
func readIni(contents io.Reader, filename string) (ini, error) {
ret := make(ini)
reader := bufio.NewReader(contents)
// Empty global section
section := make(iniSection, 0, 10)
sectionname := ""
ret[sectionname] = section
var lineno uint
for {
line, err := readFullLine(reader)
if err == io.EOF {
break
}
if err != nil {
return nil, err
}
lineno++
line = strings.TrimSpace(line)
// Skip empty lines and lines starting with ; (comments)
if len(line) == 0 || line[0] == ';' {
continue
}
if line[0] == '[' {
if line[0] != '[' || line[len(line)-1] != ']' {
return nil, &IniError{
Message: "malformed section header",
File: filename,
LineNumber: lineno,
}
}
name := strings.TrimSpace(line[1 : len(line)-1])
if len(name) == 0 {
return nil, &IniError{
Message: "empty section name",
File: filename,
LineNumber: lineno,
}
}
sectionname = name
section = ret[name]
if section == nil {
section = make(iniSection, 0, 10)
ret[name] = section
}
continue
}
// Parse option here
keyval := strings.SplitN(line, "=", 2)
if len(keyval) != 2 {
return nil, &IniError{
Message: fmt.Sprintf("malformed key=value (%s)", line),
File: filename,
LineNumber: lineno,
}
}
name := strings.TrimSpace(keyval[0])
value := strings.TrimSpace(keyval[1])
section = append(section, iniValue{
Name: name,
Value: value,
})
ret[sectionname] = section
}
return ret, nil
}
func (i *IniParser) matchingGroups(name string) []*Group {
if len(name) == 0 {
var ret []*Group
i.parser.eachGroup(func(g *Group) {
ret = append(ret, g)
})
return ret
}
g := i.parser.groupByName(name)
if g != nil {
return []*Group{g}
}
return nil
}
func (i *IniParser) parse(ini ini) error {
p := i.parser
for name, section := range ini {
groups := i.matchingGroups(name)
if len(groups) == 0 {
return newError(ErrUnknownGroup,
fmt.Sprintf("could not find option group `%s'", name))
}
for _, inival := range section {
var opt *Option
for _, group := range groups {
opt = group.optionByName(inival.Name, func(o *Option, n string) bool {
return strings.ToLower(o.tag.Get("ini-name")) == strings.ToLower(n)
})
if opt != nil && len(opt.tag.Get("no-ini")) != 0 {
opt = nil
}
if opt != nil {
break
}
}
if opt == nil {
if (p.Options & IgnoreUnknown) == None {
return newError(ErrUnknownFlag,
fmt.Sprintf("unknown option: %s", inival.Name))
}
continue
}
pval := &inival.Value
if !opt.canArgument() && len(inival.Value) == 0 {
pval = nil
}
if err := opt.set(pval); err != nil {
return wrapError(err)
}
opt.tag.Set("_read-ini-name", inival.Name)
}
}
return nil
}

View File

@@ -0,0 +1,170 @@
package flags
import (
"bytes"
"strings"
"testing"
)
func TestWriteIni(t *testing.T) {
var opts helpOptions
p := NewNamedParser("TestIni", Default)
p.AddGroup("Application Options", "The application options", &opts)
p.ParseArgs([]string{"-vv", "--intmap=a:2", "--intmap", "b:3"})
inip := NewIniParser(p)
var b bytes.Buffer
inip.Write(&b, IniDefault|IniIncludeDefaults)
got := b.String()
expected := `[Application Options]
; Show verbose debug information
verbose = true
verbose = true
; A slice of pointers to string
; PtrSlice =
; Option only available in ini
only-ini =
[Other Options]
; A slice of strings
; StringSlice =
; A map from string to int
int-map = a:2
int-map = b:3
`
if got != expected {
ret, err := helpDiff(got, expected)
if err != nil {
t.Errorf("Unexpected ini, expected:\n\n%s\n\nbut got\n\n%s", expected, got)
} else {
t.Errorf("Unexpected ini:\n\n%s", ret)
}
}
}
func TestReadIni(t *testing.T) {
var opts helpOptions
p := NewNamedParser("TestIni", Default)
p.AddGroup("Application Options", "The application options", &opts)
inip := NewIniParser(p)
inic := `
; Show verbose debug information
verbose = true
verbose = true
[Application Options]
; A slice of pointers to string
; PtrSlice =
[Other Options]
; A slice of strings
; StringSlice =
; A map from string to int
int-map = a:2
int-map = b:3
`
b := strings.NewReader(inic)
err := inip.Parse(b)
if err != nil {
t.Fatalf("Unexpected error: %s", err)
}
assertBoolArray(t, opts.Verbose, []bool{true, true})
if v, ok := opts.Other.IntMap["a"]; !ok {
t.Errorf("Expected \"a\" in Other.IntMap")
} else if v != 2 {
t.Errorf("Expected Other.IntMap[\"a\"] = 2, but got %v", v)
}
if v, ok := opts.Other.IntMap["b"]; !ok {
t.Errorf("Expected \"b\" in Other.IntMap")
} else if v != 3 {
t.Errorf("Expected Other.IntMap[\"b\"] = 3, but got %v", v)
}
}
func TestIniCommands(t *testing.T) {
var opts struct {
Value string `short:"v" long:"value"`
Add struct {
Name int `short:"n" long:"name" ini-name:"AliasName"`
Other struct {
O string `short:"o" long:"other"`
} `group:"Other Options"`
} `command:"add"`
}
p := NewNamedParser("TestIni", Default)
p.AddGroup("Application Options", "The application options", &opts)
inip := NewIniParser(p)
inic := `[Application Options]
value = some value
[add]
AliasName = 5
[add.Other Options]
other = subgroup
`
b := strings.NewReader(inic)
err := inip.Parse(b)
if err != nil {
t.Fatalf("Unexpected error: %s", err)
}
assertString(t, opts.Value, "some value")
if opts.Add.Name != 5 {
t.Errorf("Expected opts.Add.Name to be 5, but got %v", opts.Add.Name)
}
assertString(t, opts.Add.Other.O, "subgroup")
}
func TestIniNoIni(t *testing.T) {
var opts struct {
Value string `short:"v" long:"value" no-ini:"yes"`
}
p := NewNamedParser("TestIni", Default)
p.AddGroup("Application Options", "The application options", &opts)
inip := NewIniParser(p)
inic := `[Application Options]
value = some value
`
b := strings.NewReader(inic)
err := inip.Parse(b)
if err == nil {
t.Fatalf("Expected error")
}
assertError(t, err, ErrUnknownFlag, "unknown option: value")
}

View File

@@ -0,0 +1,85 @@
package flags
import (
"testing"
)
func TestLong(t *testing.T) {
var opts = struct {
Value bool `long:"value"`
}{}
ret := assertParseSuccess(t, &opts, "--value")
assertStringArray(t, ret, []string{})
if !opts.Value {
t.Errorf("Expected Value to be true")
}
}
func TestLongArg(t *testing.T) {
var opts = struct {
Value string `long:"value"`
}{}
ret := assertParseSuccess(t, &opts, "--value", "value")
assertStringArray(t, ret, []string{})
assertString(t, opts.Value, "value")
}
func TestLongArgEqual(t *testing.T) {
var opts = struct {
Value string `long:"value"`
}{}
ret := assertParseSuccess(t, &opts, "--value=value")
assertStringArray(t, ret, []string{})
assertString(t, opts.Value, "value")
}
func TestLongDefault(t *testing.T) {
var opts = struct {
Value string `long:"value" default:"value"`
}{}
ret := assertParseSuccess(t, &opts)
assertStringArray(t, ret, []string{})
assertString(t, opts.Value, "value")
}
func TestLongOptional(t *testing.T) {
var opts = struct {
Value string `long:"value" optional:"yes" optional-value:"value"`
}{}
ret := assertParseSuccess(t, &opts, "--value")
assertStringArray(t, ret, []string{})
assertString(t, opts.Value, "value")
}
func TestLongOptionalArg(t *testing.T) {
var opts = struct {
Value string `long:"value" optional:"yes" optional-value:"value"`
}{}
ret := assertParseSuccess(t, &opts, "--value", "no")
assertStringArray(t, ret, []string{"no"})
assertString(t, opts.Value, "value")
}
func TestLongOptionalArgEqual(t *testing.T) {
var opts = struct {
Value string `long:"value" optional:"yes" optional-value:"value"`
}{}
ret := assertParseSuccess(t, &opts, "--value=value", "no")
assertStringArray(t, ret, []string{"no"})
assertString(t, opts.Value, "value")
}

View File

@@ -0,0 +1,134 @@
package flags
import (
"fmt"
"io"
"strings"
"time"
)
func formatForMan(wr io.Writer, s string) {
for {
idx := strings.IndexRune(s, '`')
if idx < 0 {
fmt.Fprintf(wr, "%s", s)
break
}
fmt.Fprintf(wr, "%s", s[:idx])
s = s[idx+1:]
idx = strings.IndexRune(s, '\'')
if idx < 0 {
fmt.Fprintf(wr, "%s", s)
break
}
fmt.Fprintf(wr, "\\fB%s\\fP", s[:idx])
s = s[idx+1:]
}
}
func writeManPageOptions(wr io.Writer, grp *Group) {
grp.eachGroup(func(group *Group) {
for _, opt := range group.options {
if !opt.canCli() {
continue
}
fmt.Fprintln(wr, ".TP")
fmt.Fprintf(wr, "\\fB")
if opt.ShortName != 0 {
fmt.Fprintf(wr, "-%c", opt.ShortName)
}
if len(opt.LongName) != 0 {
if opt.ShortName != 0 {
fmt.Fprintf(wr, ", ")
}
fmt.Fprintf(wr, "--%s", opt.LongName)
}
fmt.Fprintln(wr, "\\fP")
formatForMan(wr, opt.Description)
fmt.Fprintln(wr, "")
}
})
}
func writeManPageSubCommands(wr io.Writer, name string, root *Command) {
commands := root.sortedCommands()
for _, c := range commands {
var nn string
if len(name) != 0 {
nn = name + " " + c.Name
} else {
nn = c.Name
}
writeManPageCommand(wr, nn, c)
}
}
func writeManPageCommand(wr io.Writer, name string, command *Command) {
fmt.Fprintf(wr, ".SS %s\n", name)
fmt.Fprintln(wr, command.ShortDescription)
if len(command.LongDescription) > 0 {
fmt.Fprintln(wr, "")
cmdstart := fmt.Sprintf("The %s command", command.Name)
if strings.HasPrefix(command.LongDescription, cmdstart) {
fmt.Fprintf(wr, "The \\fI%s\\fP command", command.Name)
formatForMan(wr, command.LongDescription[len(cmdstart):])
fmt.Fprintln(wr, "")
} else {
formatForMan(wr, command.LongDescription)
fmt.Fprintln(wr, "")
}
}
writeManPageOptions(wr, command.Group)
writeManPageSubCommands(wr, name, command)
}
// WriteManPage writes a basic man page in groff format to the specified
// writer.
func (p *Parser) WriteManPage(wr io.Writer) {
t := time.Now()
fmt.Fprintf(wr, ".TH %s 1 \"%s\"\n", p.Name, t.Format("2 January 2006"))
fmt.Fprintln(wr, ".SH NAME")
fmt.Fprintf(wr, "%s \\- %s\n", p.Name, p.ShortDescription)
fmt.Fprintln(wr, ".SH SYNOPSIS")
usage := p.Usage
if len(usage) == 0 {
usage = "[OPTIONS]"
}
fmt.Fprintf(wr, "\\fB%s\\fP %s\n", p.Name, usage)
fmt.Fprintln(wr, ".SH DESCRIPTION")
formatForMan(wr, p.LongDescription)
fmt.Fprintln(wr, "")
fmt.Fprintln(wr, ".SH OPTIONS")
writeManPageOptions(wr, p.Command.Group)
if len(p.commands) > 0 {
fmt.Fprintln(wr, ".SH COMMANDS")
writeManPageSubCommands(wr, "", p.Command)
}
}

View File

@@ -0,0 +1,78 @@
package flags
import (
"fmt"
"testing"
)
type marshalled bool
func (m *marshalled) UnmarshalFlag(value string) error {
if value == "yes" {
*m = true
} else if value == "no" {
*m = false
} else {
return fmt.Errorf("`%s' is not a valid value, please specify `yes' or `no'", value)
}
return nil
}
func (m marshalled) MarshalFlag() string {
if m {
return "yes"
}
return "no"
}
func TestMarshal(t *testing.T) {
var opts = struct {
Value marshalled `short:"v"`
}{}
ret := assertParseSuccess(t, &opts, "-v=yes")
assertStringArray(t, ret, []string{})
if !opts.Value {
t.Errorf("Expected Value to be true")
}
}
func TestMarshalDefault(t *testing.T) {
var opts = struct {
Value marshalled `short:"v" default:"yes"`
}{}
ret := assertParseSuccess(t, &opts)
assertStringArray(t, ret, []string{})
if !opts.Value {
t.Errorf("Expected Value to be true")
}
}
func TestMarshalOptional(t *testing.T) {
var opts = struct {
Value marshalled `short:"v" optional:"yes" optional-value:"yes"`
}{}
ret := assertParseSuccess(t, &opts, "-v")
assertStringArray(t, ret, []string{})
if !opts.Value {
t.Errorf("Expected Value to be true")
}
}
func TestMarshalError(t *testing.T) {
var opts = struct {
Value marshalled `short:"v"`
}{}
assertParseFail(t, ErrMarshal, "invalid argument for flag `-v' (expected flags.marshalled): `invalid' is not a valid value, please specify `yes' or `no'", &opts, "-vinvalid")
}

View File

@@ -0,0 +1,140 @@
package flags
import (
"strconv"
)
type multiTag struct {
value string
cache map[string][]string
}
func newMultiTag(v string) multiTag {
return multiTag{
value: v,
}
}
func (x *multiTag) scan() (map[string][]string, error) {
v := x.value
ret := make(map[string][]string)
// This is mostly copied from reflect.StructTag.Get
for v != "" {
i := 0
// Skip whitespace
for i < len(v) && v[i] == ' ' {
i++
}
v = v[i:]
if v == "" {
break
}
// Scan to colon to find key
i = 0
for i < len(v) && v[i] != ' ' && v[i] != ':' && v[i] != '"' {
i++
}
if i >= len(v) {
return nil, newErrorf(ErrTag, "expected `:' after key name, but got end of tag (in `%v`)", x.value)
}
if v[i] != ':' {
return nil, newErrorf(ErrTag, "expected `:' after key name, but got `%v' (in `%v`)", v[i], x.value)
}
if i+1 >= len(v) {
return nil, newErrorf(ErrTag, "expected `\"' to start tag value at end of tag (in `%v`)", x.value)
}
if v[i+1] != '"' {
return nil, newErrorf(ErrTag, "expected `\"' to start tag value, but got `%v' (in `%v`)", v[i+1], x.value)
}
name := v[:i]
v = v[i+1:]
// Scan quoted string to find value
i = 1
for i < len(v) && v[i] != '"' {
if v[i] == '\n' {
return nil, newErrorf(ErrTag, "unexpected newline in tag value `%v' (in `%v`)", name, x.value)
}
if v[i] == '\\' {
i++
}
i++
}
if i >= len(v) {
return nil, newErrorf(ErrTag, "expected end of tag value `\"' at end of tag (in `%v`)", x.value)
}
val, err := strconv.Unquote(v[:i+1])
if err != nil {
return nil, newErrorf(ErrTag, "Malformed value of tag `%v:%v` => %v (in `%v`)", name, v[:i+1], err, x.value)
}
v = v[i+1:]
ret[name] = append(ret[name], val)
}
return ret, nil
}
func (x *multiTag) Parse() error {
vals, err := x.scan()
x.cache = vals
return err
}
func (x *multiTag) cached() map[string][]string {
if x.cache == nil {
cache, _ := x.scan()
if cache == nil {
cache = make(map[string][]string)
}
x.cache = cache
}
return x.cache
}
func (x *multiTag) Get(key string) string {
c := x.cached()
if v, ok := c[key]; ok {
return v[len(v)-1]
}
return ""
}
func (x *multiTag) GetMany(key string) []string {
c := x.cached()
return c[key]
}
func (x *multiTag) Set(key string, value string) {
c := x.cached()
c[key] = []string{value}
}
func (x *multiTag) SetMany(key string, value []string) {
c := x.cached()
c[key] = value
}

View File

@@ -0,0 +1,95 @@
package flags
import (
"fmt"
"reflect"
"unicode/utf8"
)
// Option flag information. Contains a description of the option, short and
// long name as well as a default value and whether an argument for this
// flag is optional.
type Option struct {
// The description of the option flag. This description is shown
// automatically in the builtin help.
Description string
// The short name of the option (a single character). If not 0, the
// option flag can be 'activated' using -<ShortName>. Either ShortName
// or LongName needs to be non-empty.
ShortName rune
// The long name of the option. If not "", the option flag can be
// activated using --<LongName>. Either ShortName or LongName needs
// to be non-empty.
LongName string
// The default value of the option.
Default []string
// If true, specifies that the argument to an option flag is optional.
// When no argument to the flag is specified on the command line, the
// value of Default will be set in the field this option represents.
// This is only valid for non-boolean options.
OptionalArgument bool
// The optional value of the option. The optional value is used when
// the option flag is marked as having an OptionalArgument. This means
// that when the flag is specified, but no option argument is given,
// the value of the field this option represents will be set to
// OptionalValue. This is only valid for non-boolean options.
OptionalValue []string
// If true, the option _must_ be specified on the command line. If the
// option is not specified, the parser will generate an ErrRequired type
// error.
Required bool
// A name for the value of an option shown in the Help as --flag [ValueName]
ValueName string
// A mask value to show in the help instead of the default value. This
// is useful for hiding sensitive information in the help, such as
// passwords.
DefaultMask string
// The struct field which the option represents.
field reflect.StructField
// The struct field value which the option represents.
value reflect.Value
defaultValue reflect.Value
iniUsedName string
tag multiTag
}
// String converts an option to a human friendly readable string describing the
// option.
func (option *Option) String() string {
var s string
var short string
if option.ShortName != 0 {
data := make([]byte, utf8.RuneLen(option.ShortName))
utf8.EncodeRune(data, option.ShortName)
short = string(data)
if len(option.LongName) != 0 {
s = fmt.Sprintf("%s%s, %s%s",
string(defaultShortOptDelimiter), short,
defaultLongOptDelimiter, option.LongName)
} else {
s = fmt.Sprintf("%s%s", string(defaultShortOptDelimiter), short)
}
} else if len(option.LongName) != 0 {
s = fmt.Sprintf("%s%s", defaultLongOptDelimiter, option.LongName)
}
return s
}
// Value returns the option value as an interface{}.
func (option *Option) Value() interface{} {
return option.value.Interface()
}

View File

@@ -0,0 +1,125 @@
package flags
import (
"reflect"
)
// Set the value of an option to the specified value. An error will be returned
// if the specified value could not be converted to the corresponding option
// value type.
func (option *Option) set(value *string) error {
if option.isFunc() {
return option.call(value)
} else if value != nil {
return convert(*value, option.value, option.tag)
} else {
return convert("", option.value, option.tag)
}
return nil
}
func (option *Option) canCli() bool {
return option.ShortName != 0 || len(option.LongName) != 0
}
func (option *Option) canArgument() bool {
if u := option.isUnmarshaler(); u != nil {
return true
}
return !option.isBool()
}
func (option *Option) clear() {
tp := option.value.Type()
switch tp.Kind() {
case reflect.Func:
// Skip
case reflect.Map:
// Empty the map
option.value.Set(reflect.MakeMap(tp))
default:
zeroval := reflect.Zero(tp)
option.value.Set(zeroval)
}
}
func (option *Option) isUnmarshaler() Unmarshaler {
v := option.value
for {
if !v.CanInterface() {
return nil
}
i := v.Interface()
if u, ok := i.(Unmarshaler); ok {
return u
}
if !v.CanAddr() {
return nil
}
v = v.Addr()
}
return nil
}
func (option *Option) isBool() bool {
tp := option.value.Type()
for {
switch tp.Kind() {
case reflect.Bool:
return true
case reflect.Slice:
return (tp.Elem().Kind() == reflect.Bool)
case reflect.Func:
return tp.NumIn() == 0
case reflect.Ptr:
tp = tp.Elem()
default:
return false
}
}
return false
}
func (option *Option) isFunc() bool {
return option.value.Type().Kind() == reflect.Func
}
func (option *Option) call(value *string) error {
var retval []reflect.Value
if value == nil {
retval = option.value.Call(nil)
} else {
tp := option.value.Type().In(0)
val := reflect.New(tp)
val = reflect.Indirect(val)
if err := convert(*value, val, option.tag); err != nil {
return err
}
retval = option.value.Call([]reflect.Value{val})
}
if len(retval) == 1 && retval[0].Type() == reflect.TypeOf((*error)(nil)).Elem() {
if retval[0].Interface() == nil {
return nil
}
return retval[0].Interface().(error)
}
return nil
}

View File

@@ -0,0 +1,45 @@
package flags
import (
"testing"
)
func TestPassDoubleDash(t *testing.T) {
var opts = struct {
Value bool `short:"v"`
}{}
p := NewParser(&opts, PassDoubleDash)
ret, err := p.ParseArgs([]string{"-v", "--", "-v", "-g"})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
return
}
if !opts.Value {
t.Errorf("Expected Value to be true")
}
assertStringArray(t, ret, []string{"-v", "-g"})
}
func TestPassAfterNonOption(t *testing.T) {
var opts = struct {
Value bool `short:"v"`
}{}
p := NewParser(&opts, PassAfterNonOption)
ret, err := p.ParseArgs([]string{"-v", "arg", "-v", "-g"})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
return
}
if !opts.Value {
t.Errorf("Expected Value to be true")
}
assertStringArray(t, ret, []string{"arg", "-v", "-g"})
}

View File

@@ -0,0 +1,54 @@
// +build !windows
package flags
import (
"strings"
)
const (
defaultShortOptDelimiter = '-'
defaultLongOptDelimiter = "--"
defaultNameArgDelimiter = '='
)
func argumentIsOption(arg string) bool {
return len(arg) > 0 && arg[0] == '-'
}
// stripOptionPrefix returns the option without the prefix and whether or
// not the option is a long option or not.
func stripOptionPrefix(optname string) (prefix string, name string, islong bool) {
if strings.HasPrefix(optname, "--") {
return "--", optname[2:], true
} else if strings.HasPrefix(optname, "-") {
return "-", optname[1:], false
}
return "", optname, false
}
// splitOption attempts to split the passed option into a name and an argument.
// When there is no argument specified, nil will be returned for it.
func splitOption(prefix string, option string, islong bool) (string, *string) {
pos := strings.Index(option, "=")
if (islong && pos >= 0) || (!islong && pos == 1) {
rest := option[pos+1:]
return option[:pos], &rest
}
return option, nil
}
// addHelpGroup adds a new group that contains default help parameters.
func (c *Command) addHelpGroup(showHelp func() error) *Group {
var help struct {
ShowHelp func() error `short:"h" long:"help" description:"Show this help message"`
}
help.ShowHelp = showHelp
ret, _ := c.AddGroup("Help Options", "", &help)
return ret
}

View File

@@ -0,0 +1,85 @@
package flags
import (
"strings"
)
// Windows uses a front slash for both short and long options. Also it uses
// a colon for name/argument delimter.
const (
defaultShortOptDelimiter = '/'
defaultLongOptDelimiter = "/"
defaultNameArgDelimiter = ':'
)
func argumentIsOption(arg string) bool {
// Windows-style options allow front slash for the option
// delimiter.
return len(arg) > 0 && (arg[0] == '-' || arg[0] == '/')
}
// stripOptionPrefix returns the option without the prefix and whether or
// not the option is a long option or not.
func stripOptionPrefix(optname string) (prefix string, name string, islong bool) {
// Determine if the argument is a long option or not. Windows
// typically supports both long and short options with a single
// front slash as the option delimiter, so handle this situation
// nicely.
possplit := 0
if strings.HasPrefix(optname, "--") {
possplit = 2
islong = true
} else if strings.HasPrefix(optname, "-") {
possplit = 1
islong = false
} else if strings.HasPrefix(optname, "/") {
possplit = 1
islong = len(optname) > 2
}
return optname[:possplit], optname[possplit:], islong
}
// splitOption attempts to split the passed option into a name and an argument.
// When there is no argument specified, nil will be returned for it.
func splitOption(prefix string, option string, islong bool) (string, *string) {
if len(option) == 0 {
return option, nil
}
// Windows typically uses a colon for the option name and argument
// delimiter while POSIX typically uses an equals. Support both styles,
// but don't allow the two to be mixed. That is to say /foo:bar and
// --foo=bar are acceptable, but /foo=bar and --foo:bar are not.
var pos int
if prefix == "/" {
pos = strings.Index(option, ":")
} else if len(prefix) > 0 {
pos = strings.Index(option, "=")
}
if (islong && pos >= 0) || (!islong && pos == 1) {
rest := option[pos+1:]
return option[:pos], &rest
}
return option, nil
}
// addHelpGroup adds a new group that contains default help parameters.
func (c *Command) addHelpGroup(showHelp func() error) *Group {
// Windows CLI applications typically use /? for help, so make both
// that available as well as the POSIX style h and help.
var help struct {
ShowHelpWindows func() error `short:"?" description:"Show this help message"`
ShowHelpPosix func() error `short:"h" long:"help" description:"Show this help message"`
}
help.ShowHelpWindows = showHelp
help.ShowHelpPosix = showHelp
ret, _ := c.AddGroup("Help Options", "", &help)
return ret
}

View File

@@ -0,0 +1,212 @@
// Copyright 2012 Jesse van den Kieboom. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package flags
import (
"os"
"path"
)
// A Parser provides command line option parsing. It can contain several
// option groups each with their own set of options.
type Parser struct {
// Embedded, see Command for more information
*Command
// A usage string to be displayed in the help message.
Usage string
// Option flags changing the behavior of the parser.
Options Options
internalError error
}
// Options provides parser options that change the behavior of the option
// parser.
type Options uint
const (
// None indicates no options.
None Options = 0
// HelpFlag adds a default Help Options group to the parser containing
// -h and --help options. When either -h or --help is specified on the
// command line, the parser will return the special error of type
// ErrHelp. When PrintErrors is also specified, then the help message
// will also be automatically printed to os.Stderr.
HelpFlag = 1 << iota
// PassDoubleDash passes all arguments after a double dash, --, as
// remaining command line arguments (i.e. they will not be parsed for
// flags).
PassDoubleDash
// IgnoreUnknown ignores any unknown options and passes them as
// remaining command line arguments instead of generating an error.
IgnoreUnknown
// PrintErrors prints any errors which occurred during parsing to
// os.Stderr.
PrintErrors
// PassAfterNonOption passes all arguments after the first non option
// as remaining command line arguments. This is equivalent to strict
// POSIX processing.
PassAfterNonOption
// Default is a convenient default set of options which should cover
// most of the uses of the flags package.
Default = HelpFlag | PrintErrors | PassDoubleDash
)
// Parse is a convenience function to parse command line options with default
// settings. The provided data is a pointer to a struct representing the
// default option group (named "Application Options"). For more control, use
// flags.NewParser.
func Parse(data interface{}) ([]string, error) {
return NewParser(data, Default).Parse()
}
// ParseArgs is a convenience function to parse command line options with default
// settings. The provided data is a pointer to a struct representing the
// default option group (named "Application Options"). The args argument is
// the list of command line arguments to parse. If you just want to parse the
// default program command line arguments (i.e. os.Args), then use flags.Parse
// instead. For more control, use flags.NewParser.
func ParseArgs(data interface{}, args []string) ([]string, error) {
return NewParser(data, Default).ParseArgs(args)
}
// NewParser creates a new parser. It uses os.Args[0] as the application
// name and then calls Parser.NewNamedParser (see Parser.NewNamedParser for
// more details). The provided data is a pointer to a struct representing the
// default option group (named "Application Options"), or nil if the default
// group should not be added. The options parameter specifies a set of options
// for the parser.
func NewParser(data interface{}, options Options) *Parser {
ret := NewNamedParser(path.Base(os.Args[0]), options)
if data != nil {
_, ret.internalError = ret.AddGroup("Application Options", "", data)
}
return ret
}
// NewNamedParser creates a new parser. The appname is used to display the
// executable name in the builtin help message. Option groups and commands can
// be added to this parser by using AddGroup and AddCommand.
func NewNamedParser(appname string, options Options) *Parser {
return &Parser{
Command: newCommand(appname, "", "", nil),
Options: options,
}
}
// Parse parses the command line arguments from os.Args using Parser.ParseArgs.
// For more detailed information see ParseArgs.
func (p *Parser) Parse() ([]string, error) {
return p.ParseArgs(os.Args[1:])
}
// ParseArgs parses the command line arguments according to the option groups that
// were added to the parser. On successful parsing of the arguments, the
// remaining, non-option, arguments (if any) are returned. The returned error
// indicates a parsing error and can be used with PrintError to display
// contextual information on where the error occurred exactly.
//
// When the common help group has been added (AddHelp) and either -h or --help
// was specified in the command line arguments, a help message will be
// automatically printed. Furthermore, the special error type ErrHelp is returned.
// It is up to the caller to exit the program if so desired.
func (p *Parser) ParseArgs(args []string) ([]string, error) {
if p.internalError != nil {
return nil, p.internalError
}
p.eachCommand(func(c *Command) {
p.eachGroup(func(g *Group) {
g.storeDefaults()
})
}, true)
// Add builtin help group to all commands if necessary
if (p.Options & HelpFlag) != None {
p.addHelpGroups(p.showBuiltinHelp)
}
s := &parseState{
args: args,
retargs: make([]string, 0, len(args)),
command: p.Command,
lookup: p.makeLookup(),
}
for !s.eof() {
arg := s.pop()
// When PassDoubleDash is set and we encounter a --, then
// simply append all the rest as arguments and break out
if (p.Options&PassDoubleDash) != None && arg == "--" {
s.retargs = append(s.retargs, s.args...)
break
}
if !argumentIsOption(arg) {
// Note: this also sets s.err, so we can just check for
// nil here and use s.err later
if p.parseNonOption(s) != nil {
break
}
continue
}
var err error
var option *Option
prefix, optname, islong := stripOptionPrefix(arg)
optname, argument := splitOption(prefix, optname, islong)
if islong {
option, err = p.parseLong(s, optname, argument)
} else {
option, err = p.parseShort(s, optname, argument)
}
if err != nil {
ignoreUnknown := (p.Options & IgnoreUnknown) != None
parseErr := wrapError(err)
if !(parseErr.Type == ErrUnknownFlag && ignoreUnknown) {
s.err = parseErr
break
}
if ignoreUnknown {
s.retargs = append(s.retargs, arg)
}
} else {
delete(s.lookup.required, option)
}
}
if s.err == nil {
s.checkRequired()
}
if s.err != nil {
return nil, p.printError(s.err)
}
if len(s.command.commands) != 0 {
return nil, p.printError(s.estimateCommand())
} else if cmd, ok := s.command.data.(Commander); ok {
return nil, p.printError(cmd.Execute(s.retargs))
}
return s.retargs, nil
}

View File

@@ -0,0 +1,243 @@
package flags
import (
"bytes"
"fmt"
"os"
"strings"
"unicode/utf8"
)
type parseState struct {
arg string
args []string
retargs []string
err error
command *Command
lookup lookup
}
func (p *parseState) eof() bool {
return len(p.args) == 0
}
func (p *parseState) pop() string {
if p.eof() {
return ""
}
p.arg = p.args[0]
p.args = p.args[1:]
return p.arg
}
func (p *parseState) peek() string {
if p.eof() {
return ""
}
return p.args[0]
}
func (p *parseState) checkRequired() error {
required := p.lookup.required
if len(required) == 0 {
return nil
}
names := make([]string, 0, len(required))
for k := range required {
names = append(names, "`"+k.String()+"'")
}
var msg string
if len(names) == 1 {
msg = fmt.Sprintf("the required flag %s was not specified", names[0])
} else {
msg = fmt.Sprintf("the required flags %s and %s were not specified",
strings.Join(names[:len(names)-1], ", "), names[len(names)-1])
}
p.err = newError(ErrRequired, msg)
return p.err
}
func (p *parseState) estimateCommand() error {
commands := p.command.sortedCommands()
cmdnames := make([]string, len(commands))
for i, v := range commands {
cmdnames[i] = v.Name
}
var msg string
if len(p.retargs) != 0 {
c, l := closestChoice(p.retargs[0], cmdnames)
msg = fmt.Sprintf("Unknown command `%s'", p.retargs[0])
if float32(l)/float32(len(c)) < 0.5 {
msg = fmt.Sprintf("%s, did you mean `%s'?", msg, c)
} else if len(cmdnames) == 1 {
msg = fmt.Sprintf("%s. You should use the %s command",
msg,
cmdnames[0])
} else {
msg = fmt.Sprintf("%s. Please specify one command of: %s or %s",
msg,
strings.Join(cmdnames[:len(cmdnames)-1], ", "),
cmdnames[len(cmdnames)-1])
}
} else {
if len(cmdnames) == 1 {
msg = fmt.Sprintf("Please specify the %s command", cmdnames[0])
} else {
msg = fmt.Sprintf("Please specify one command of: %s or %s",
strings.Join(cmdnames[:len(cmdnames)-1], ", "),
cmdnames[len(cmdnames)-1])
}
}
return newError(ErrRequired, msg)
}
func (p *Parser) parseOption(s *parseState, name string, option *Option, canarg bool, argument *string) (retoption *Option, err error) {
if !option.canArgument() {
if argument != nil {
msg := fmt.Sprintf("bool flag `%s' cannot have an argument", option)
return option, newError(ErrNoArgumentForBool, msg)
}
err = option.set(nil)
} else if argument != nil {
err = option.set(argument)
} else if canarg && !s.eof() {
arg := s.pop()
err = option.set(&arg)
} else if option.OptionalArgument {
option.clear()
for _, v := range option.OptionalValue {
err = option.set(&v)
if err != nil {
break
}
}
} else {
msg := fmt.Sprintf("expected argument for flag `%s'", option)
err = newError(ErrExpectedArgument, msg)
}
if err != nil {
if _, ok := err.(*Error); !ok {
msg := fmt.Sprintf("invalid argument for flag `%s' (expected %s): %s",
option,
option.value.Type(),
err.Error())
err = newError(ErrMarshal, msg)
}
}
return option, err
}
func (p *Parser) parseLong(s *parseState, name string, argument *string) (option *Option, err error) {
if option := s.lookup.longNames[name]; option != nil {
// Only long options that are required can consume an argument
// from the argument list
canarg := !option.OptionalArgument
return p.parseOption(s, name, option, canarg, argument)
}
return nil, newError(ErrUnknownFlag, fmt.Sprintf("unknown flag `%s'", name))
}
func (p *Parser) splitShortConcatArg(s *parseState, optname string) (string, *string) {
c, n := utf8.DecodeRuneInString(optname)
if n == len(optname) {
return optname, nil
}
first := string(c)
if option := s.lookup.shortNames[first]; option != nil && option.canArgument() {
arg := optname[n:]
return first, &arg
}
return optname, nil
}
func (p *Parser) parseShort(s *parseState, optname string, argument *string) (option *Option, err error) {
if argument == nil {
optname, argument = p.splitShortConcatArg(s, optname)
}
for i, c := range optname {
shortname := string(c)
if option = s.lookup.shortNames[shortname]; option != nil {
// Only the last short argument can consume an argument from
// the arguments list, and only if it's non optional
canarg := (i+utf8.RuneLen(c) == len(optname)) && !option.OptionalArgument
if _, err := p.parseOption(s, shortname, option, canarg, argument); err != nil {
return option, err
}
} else {
return nil, newError(ErrUnknownFlag, fmt.Sprintf("unknown flag `%s'", shortname))
}
// Only the first option can have a concatted argument, so just
// clear argument here
argument = nil
}
return option, nil
}
func (p *Parser) parseNonOption(s *parseState) error {
if cmd := s.lookup.commands[s.arg]; cmd != nil {
if err := s.checkRequired(); err != nil {
return err
}
s.command.Active = cmd
s.command = cmd
s.lookup = cmd.makeLookup()
} else if (p.Options & PassAfterNonOption) != None {
// If PassAfterNonOption is set then all remaining arguments
// are considered positional
s.retargs = append(append(s.retargs, s.arg), s.args...)
s.args = []string{}
} else {
s.retargs = append(s.retargs, s.arg)
}
return nil
}
func (p *Parser) showBuiltinHelp() error {
var b bytes.Buffer
p.WriteHelp(&b)
return newError(ErrHelp, b.String())
}
func (p *Parser) printError(err error) error {
if err != nil && (p.Options&PrintErrors) != None {
fmt.Fprintln(os.Stderr, err)
}
return err
}

View File

@@ -0,0 +1,81 @@
package flags
import (
"testing"
)
func TestPointerBool(t *testing.T) {
var opts = struct {
Value *bool `short:"v"`
}{}
ret := assertParseSuccess(t, &opts, "-v")
assertStringArray(t, ret, []string{})
if !*opts.Value {
t.Errorf("Expected Value to be true")
}
}
func TestPointerString(t *testing.T) {
var opts = struct {
Value *string `short:"v"`
}{}
ret := assertParseSuccess(t, &opts, "-v", "value")
assertStringArray(t, ret, []string{})
assertString(t, *opts.Value, "value")
}
func TestPointerSlice(t *testing.T) {
var opts = struct {
Value *[]string `short:"v"`
}{}
ret := assertParseSuccess(t, &opts, "-v", "value1", "-v", "value2")
assertStringArray(t, ret, []string{})
assertStringArray(t, *opts.Value, []string{"value1", "value2"})
}
func TestPointerMap(t *testing.T) {
var opts = struct {
Value *map[string]int `short:"v"`
}{}
ret := assertParseSuccess(t, &opts, "-v", "k1:2", "-v", "k2:-5")
assertStringArray(t, ret, []string{})
if v, ok := (*opts.Value)["k1"]; !ok {
t.Errorf("Expected key \"k1\" to exist")
} else if v != 2 {
t.Errorf("Expected \"k1\" to be 2, but got %#v", v)
}
if v, ok := (*opts.Value)["k2"]; !ok {
t.Errorf("Expected key \"k2\" to exist")
} else if v != -5 {
t.Errorf("Expected \"k2\" to be -5, but got %#v", v)
}
}
type PointerGroup struct {
Value bool `short:"v"`
}
func TestPointerGroup(t *testing.T) {
var opts = struct {
Group *PointerGroup `group:"Group Options"`
}{}
ret := assertParseSuccess(t, &opts, "-v")
assertStringArray(t, ret, []string{})
if !opts.Group.Value {
t.Errorf("Expected Group.Value to be true")
}
}

View File

@@ -0,0 +1,169 @@
package flags
import (
"testing"
)
func TestShort(t *testing.T) {
var opts = struct {
Value bool `short:"v"`
}{}
ret := assertParseSuccess(t, &opts, "-v")
assertStringArray(t, ret, []string{})
if !opts.Value {
t.Errorf("Expected Value to be true")
}
}
func TestShortTooLong(t *testing.T) {
var opts = struct {
Value bool `short:"vv"`
}{}
assertParseFail(t, ErrShortNameTooLong, "short names can only be 1 character long, not `vv'", &opts)
}
func TestShortRequired(t *testing.T) {
var opts = struct {
Value bool `short:"v" required:"true"`
}{}
assertParseFail(t, ErrRequired, "the required flag `-v' was not specified", &opts)
}
func TestShortMultiConcat(t *testing.T) {
var opts = struct {
V bool `short:"v"`
O bool `short:"o"`
F bool `short:"f"`
}{}
ret := assertParseSuccess(t, &opts, "-vo", "-f")
assertStringArray(t, ret, []string{})
if !opts.V {
t.Errorf("Expected V to be true")
}
if !opts.O {
t.Errorf("Expected O to be true")
}
if !opts.F {
t.Errorf("Expected F to be true")
}
}
func TestShortMultiSlice(t *testing.T) {
var opts = struct {
Values []bool `short:"v"`
}{}
ret := assertParseSuccess(t, &opts, "-v", "-v")
assertStringArray(t, ret, []string{})
assertBoolArray(t, opts.Values, []bool{true, true})
}
func TestShortMultiSliceConcat(t *testing.T) {
var opts = struct {
Values []bool `short:"v"`
}{}
ret := assertParseSuccess(t, &opts, "-vvv")
assertStringArray(t, ret, []string{})
assertBoolArray(t, opts.Values, []bool{true, true, true})
}
func TestShortWithEqualArg(t *testing.T) {
var opts = struct {
Value string `short:"v"`
}{}
ret := assertParseSuccess(t, &opts, "-v=value")
assertStringArray(t, ret, []string{})
assertString(t, opts.Value, "value")
}
func TestShortWithArg(t *testing.T) {
var opts = struct {
Value string `short:"v"`
}{}
ret := assertParseSuccess(t, &opts, "-vvalue")
assertStringArray(t, ret, []string{})
assertString(t, opts.Value, "value")
}
func TestShortArg(t *testing.T) {
var opts = struct {
Value string `short:"v"`
}{}
ret := assertParseSuccess(t, &opts, "-v", "value")
assertStringArray(t, ret, []string{})
assertString(t, opts.Value, "value")
}
func TestShortMultiWithEqualArg(t *testing.T) {
var opts = struct {
F []bool `short:"f"`
Value string `short:"v"`
}{}
assertParseFail(t, ErrExpectedArgument, "expected argument for flag `-v'", &opts, "-ffv=value")
}
func TestShortMultiArg(t *testing.T) {
var opts = struct {
F []bool `short:"f"`
Value string `short:"v"`
}{}
ret := assertParseSuccess(t, &opts, "-ffv", "value")
assertStringArray(t, ret, []string{})
assertBoolArray(t, opts.F, []bool{true, true})
assertString(t, opts.Value, "value")
}
func TestShortMultiArgConcatFail(t *testing.T) {
var opts = struct {
F []bool `short:"f"`
Value string `short:"v"`
}{}
assertParseFail(t, ErrExpectedArgument, "expected argument for flag `-v'", &opts, "-ffvvalue")
}
func TestShortMultiArgConcat(t *testing.T) {
var opts = struct {
F []bool `short:"f"`
Value string `short:"v"`
}{}
ret := assertParseSuccess(t, &opts, "-vff")
assertStringArray(t, ret, []string{})
assertString(t, opts.Value, "ff")
}
func TestShortOptional(t *testing.T) {
var opts = struct {
F []bool `short:"f"`
Value string `short:"v" optional:"yes" optional-value:"value"`
}{}
ret := assertParseSuccess(t, &opts, "-fv", "f")
assertStringArray(t, ret, []string{"f"})
assertString(t, opts.Value, "value")
}

View File

@@ -0,0 +1,39 @@
package flags
import (
"testing"
)
func TestTagMissingColon(t *testing.T) {
var opts = struct {
Value bool `short`
}{}
assertParseFail(t, ErrTag, "expected `:' after key name, but got end of tag (in `short`)", &opts, "")
}
func TestTagMissingValue(t *testing.T) {
var opts = struct {
Value bool `short:`
}{}
assertParseFail(t, ErrTag, "expected `\"' to start tag value at end of tag (in `short:`)", &opts, "")
}
func TestTagMissingQuote(t *testing.T) {
var opts = struct {
Value bool `short:"v`
}{}
assertParseFail(t, ErrTag, "expected end of tag value `\"' at end of tag (in `short:\"v`)", &opts, "")
}
func TestTagNewline(t *testing.T) {
var opts = struct {
Value bool `long:"verbose" description:"verbose
something"`
}{}
assertParseFail(t, ErrTag, "unexpected newline in tag value `description' (in `long:\"verbose\" description:\"verbose\nsomething\"`)", &opts, "")
}

View File

@@ -0,0 +1,5 @@
package flags
func getTerminalColumns() int {
return 80
}

View File

@@ -0,0 +1,66 @@
package flags
import (
"testing"
)
func TestUnknownFlags(t *testing.T) {
var opts = struct {
Verbose []bool `short:"v" long:"verbose" description:"Verbose output"`
}{}
args := []string{
"-f",
}
p := NewParser(&opts, 0)
args, err := p.ParseArgs(args)
if err == nil {
t.Fatal("Expected error for unknown argument")
}
}
func TestIgnoreUnknownFlags(t *testing.T) {
var opts = struct {
Verbose []bool `short:"v" long:"verbose" description:"Verbose output"`
}{}
args := []string{
"hello",
"world",
"-v",
"--foo=bar",
"--verbose",
"-f",
}
p := NewParser(&opts, IgnoreUnknown)
args, err := p.ParseArgs(args)
if err != nil {
t.Fatal(err)
}
exargs := []string{
"hello",
"world",
"--foo=bar",
"-f",
}
issame := (len(args) == len(exargs))
if issame {
for i := 0; i < len(args); i++ {
if args[i] != exargs[i] {
issame = false
break
}
}
}
if !issame {
t.Fatalf("Expected %v but got %v", exargs, args)
}
}

43
locktrace.go Normal file
View File

@@ -0,0 +1,43 @@
//+build locktrace
package main
import (
"log"
"path"
"runtime"
"time"
)
var (
lockTime time.Time
)
func (m *Model) Lock() {
_, file, line, _ := runtime.Caller(1)
log.Printf("%s:%d: Lock()...", path.Base(file), line)
blockTime := time.Now()
m.RWMutex.Lock()
lockTime = time.Now()
log.Printf("%s:%d: ...Lock() [%.04f ms]", path.Base(file), line, time.Since(blockTime).Seconds()*1000)
}
func (m *Model) Unlock() {
_, file, line, _ := runtime.Caller(1)
m.RWMutex.Unlock()
log.Printf("%s:%d: Unlock() [%.04f ms]", path.Base(file), line, time.Since(lockTime).Seconds()*1000)
}
func (m *Model) RLock() {
_, file, line, _ := runtime.Caller(1)
log.Printf("%s:%d: RLock()...", path.Base(file), line)
blockTime := time.Now()
m.RWMutex.RLock()
log.Printf("%s:%d: ...RLock() [%.04f ms]", path.Base(file), line, time.Since(blockTime).Seconds()*1000)
}
func (m *Model) RUnlock() {
_, file, line, _ := runtime.Caller(1)
m.RWMutex.RUnlock()
log.Printf("%s:%d: RUnlock()", path.Base(file), line)
}

View File

@@ -6,21 +6,16 @@ import (
"os"
)
var debugEnabled = true
var logger = log.New(os.Stderr, "", log.Lshortfile|log.Ltime)
var logger = log.New(os.Stderr, "", log.Ltime)
func debugln(vals ...interface{}) {
if debugEnabled {
s := fmt.Sprintln(vals...)
logger.Output(2, "DEBUG: "+s)
}
s := fmt.Sprintln(vals...)
logger.Output(2, "DEBUG: "+s)
}
func debugf(format string, vals ...interface{}) {
if debugEnabled {
s := fmt.Sprintf(format, vals...)
logger.Output(2, "DEBUG: "+s)
}
s := fmt.Sprintf(format, vals...)
logger.Output(2, "DEBUG: "+s)
}
func infoln(vals ...interface{}) {

220
main.go
View File

@@ -1,6 +1,7 @@
package main
import (
"compress/gzip"
"crypto/sha1"
"crypto/tls"
"fmt"
@@ -16,35 +17,50 @@ import (
"github.com/calmh/ini"
"github.com/calmh/syncthing/discover"
flags "github.com/calmh/syncthing/github.com/jessevdk/go-flags"
"github.com/calmh/syncthing/protocol"
docopt "github.com/docopt/docopt.go"
)
type Options struct {
ConfDir string `short:"c" long:"cfg" description:"Configuration directory" default:"~/.syncthing" value-name:"DIR"`
Listen string `short:"l" long:"listen" description:"Listen address" default:":22000" value-name:"ADDR"`
ReadOnly bool `short:"r" long:"ro" description:"Repository is read only"`
Delete bool `short:"d" long:"delete" description:"Delete files deleted from cluster"`
Rehash bool `long:"rehash" description:"Ignore cache and rehash all files in repository"`
NoSymlinks bool `long:"no-symlinks" description:"Don't follow first level symlinks in the repo"`
Discovery DiscoveryOptions `group:"Discovery Options"`
Advanced AdvancedOptions `group:"Advanced Options"`
Debug DebugOptions `group:"Debugging Options"`
}
type DebugOptions struct {
LogSource bool `long:"log-source"`
TraceFile bool `long:"trace-file"`
TraceNet bool `long:"trace-net"`
TraceIdx bool `long:"trace-idx"`
Profiler string `long:"profiler" value-name:"ADDR"`
}
type DiscoveryOptions struct {
ExternalServer string `long:"ext-server" description:"External discovery server" value-name:"NAME" default:"syncthing.nym.se"`
ExternalPort int `short:"e" long:"ext-port" description:"External listen port" value-name:"PORT" default:"22000"`
NoExternalDiscovery bool `short:"n" long:"no-ext-announce" description:"Do not announce presence externally"`
NoLocalDiscovery bool `short:"N" long:"no-local-announce" description:"Do not announce presence locally"`
}
type AdvancedOptions struct {
RequestsInFlight int `long:"reqs-in-flight" description:"Parallell in flight requests per file" default:"4" value-name:"REQS"`
FilesInFlight int `long:"files-in-flight" description:"Parallell in flight file pulls" default:"8" value-name:"FILES"`
ScanInterval time.Duration `long:"scan-intv" description:"Repository scan interval" default:"60s" value-name:"INTV"`
ConnInterval time.Duration `long:"conn-intv" description:"Node reconnect interval" default:"60s" value-name:"INTV"`
}
var opts Options
var Version string = "unknown-dev"
const (
confDirName = ".syncthing"
confFileName = "syncthing.ini"
usage = `Usage:
syncthing [options]
Options:
-l <addr> Listening address [default: :22000]
-p <addr> Enable HTTP profiler on addr
--home <path> Home directory
--delete Delete files that were deleted on a peer node
--ro Local repository is read only
--scan-intv <s> Repository scan interval, in seconds [default: 60]
--conn-intv <s> Node reconnect interval, in seconds [default: 15]
--no-stats Don't print transfer statistics
Help Options:
-h, --help Show this help
--version Show version
Debug Options:
--trace-file Trace file operations
--trace-net Trace network operations
--trace-idx Trace sent indexes
`
)
var (
@@ -52,68 +68,39 @@ var (
nodeAddrs = make(map[string][]string)
)
// Options
var (
confDir = path.Join(getHomeDir(), confDirName)
addr string
prof string
readOnly bool
scanIntv int
connIntv int
traceNet bool
traceFile bool
traceIdx bool
printStats bool
doDelete bool
)
func main() {
// Useful for debugging; to be adjusted.
log.SetFlags(log.Ltime | log.Lshortfile)
arguments, _ := docopt.Parse(usage, nil, true, "syncthing 0.1", false)
addr = arguments["-l"].(string)
prof, _ = arguments["-p"].(string)
readOnly, _ = arguments["--ro"].(bool)
if arguments["--home"] != nil {
confDir, _ = arguments["--home"].(string)
_, err := flags.Parse(&opts)
if err != nil {
os.Exit(0)
}
if opts.Debug.TraceFile || opts.Debug.TraceIdx || opts.Debug.TraceNet || opts.Debug.LogSource {
logger = log.New(os.Stderr, "", log.Lshortfile|log.Ldate|log.Ltime|log.Lmicroseconds)
}
if strings.HasPrefix(opts.ConfDir, "~/") {
opts.ConfDir = strings.Replace(opts.ConfDir, "~", getHomeDir(), 1)
}
scanIntv, _ = strconv.Atoi(arguments["--scan-intv"].(string))
if scanIntv == 0 {
fatalln("Invalid --scan-intv")
}
connIntv, _ = strconv.Atoi(arguments["--conn-intv"].(string))
if connIntv == 0 {
fatalln("Invalid --conn-intv")
}
doDelete = arguments["--delete"].(bool)
traceFile = arguments["--trace-file"].(bool)
traceNet = arguments["--trace-net"].(bool)
traceIdx = arguments["--trace-idx"].(bool)
printStats = !arguments["--no-stats"].(bool)
infoln("Version", Version)
// Ensure that our home directory exists and that we have a certificate and key.
ensureDir(confDir)
cert, err := loadCert(confDir)
ensureDir(opts.ConfDir, 0700)
cert, err := loadCert(opts.ConfDir)
if err != nil {
newCertificate(confDir)
cert, err = loadCert(confDir)
newCertificate(opts.ConfDir)
cert, err = loadCert(opts.ConfDir)
fatalErr(err)
}
myID := string(certId(cert.Certificate[0]))
infoln("My ID:", myID)
if prof != "" {
okln("Profiler listening on", prof)
if opts.Debug.Profiler != "" {
go func() {
http.ListenAndServe(prof, nil)
err := http.ListenAndServe(opts.Debug.Profiler, nil)
if err != nil {
warnln(err)
}
}()
}
@@ -130,7 +117,7 @@ func main() {
// Load the configuration file, if it exists.
cf, err := os.Open(path.Join(confDir, confFileName))
cf, err := os.Open(path.Join(opts.ConfDir, confFileName))
if err != nil {
fatalln("No config file")
config = ini.Config{}
@@ -148,26 +135,30 @@ func main() {
nodeAddrs[nodeID] = addrs
}
ensureDir(dir, -1)
m := NewModel(dir)
// Walk the repository and update the local model before establishing any
// connections to other nodes.
infoln("Iniial repository scan in progress")
loadIndex(m)
if !opts.Rehash {
infoln("Loading index cache")
loadIndex(m)
}
infoln("Populating repository index")
updateLocalModel(m)
// Routine to listen for incoming connections
infoln("Listening for incoming connections")
go listen(myID, addr, m, cfg)
go listen(myID, opts.Listen, m, cfg)
// Routine to connect out to configured nodes
infoln("Attempting to connect to other nodes")
go connect(myID, addr, nodeAddrs, m, cfg)
go connect(myID, opts.Listen, nodeAddrs, m, cfg)
// Routine to pull blocks from other nodes to synchronize the local
// repository. Does not run when we are in read only (publish only) mode.
if !readOnly {
if !opts.ReadOnly {
infoln("Cleaning out incomplete synchronizations")
CleanTempFiles(dir)
okln("Ready to synchronize")
@@ -178,7 +169,7 @@ func main() {
// XXX: Should use some fsnotify mechanism.
go func() {
for {
time.Sleep(time.Duration(scanIntv) * time.Second)
time.Sleep(opts.Advanced.ScanInterval)
updateLocalModel(m)
}
}()
@@ -198,7 +189,7 @@ listen:
continue
}
if traceNet {
if opts.Debug.TraceNet {
debugln("NET: Connect from", conn.RemoteAddr())
}
@@ -224,14 +215,10 @@ listen:
for nodeID := range nodeAddrs {
if nodeID == remoteID {
nc := protocol.NewConnection(remoteID, conn, conn, m)
m.AddNode(nc)
okln("Connected to nodeID", remoteID, "(in)")
m.AddConnection(conn, remoteID)
continue listen
}
}
warnln("Connect from unknown node", remoteID)
conn.Close()
}
}
@@ -241,10 +228,22 @@ func connect(myID string, addr string, nodeAddrs map[string][]string, m *Model,
fatalErr(err)
port, _ := strconv.Atoi(portstr)
infoln("Starting local discovery")
disc, err := discover.NewDiscoverer(myID, port)
if opts.Discovery.NoLocalDiscovery {
port = -1
} else {
infoln("Sending local discovery announcements")
}
if opts.Discovery.NoExternalDiscovery {
opts.Discovery.ExternalPort = -1
} else {
infoln("Sending external discovery announcements")
}
disc, err := discover.NewDiscoverer(myID, port, opts.Discovery.ExternalPort, opts.Discovery.ExternalServer)
if err != nil {
warnln("No local discovery possible")
warnf("No discovery possible (%v)", err)
}
for {
@@ -267,12 +266,12 @@ func connect(myID string, addr string, nodeAddrs map[string][]string, m *Model,
}
}
if traceNet {
if opts.Debug.TraceNet {
debugln("NET: Dial", nodeID, addr)
}
conn, err := tls.Dial("tcp", addr, cfg)
if err != nil {
if traceNet {
if opts.Debug.TraceNet {
debugln("NET:", err)
}
continue
@@ -285,60 +284,65 @@ func connect(myID string, addr string, nodeAddrs map[string][]string, m *Model,
continue
}
nc := protocol.NewConnection(nodeID, conn, conn, m)
okln("Connected to node", remoteID, "(out)")
m.AddNode(nc)
if traceNet {
t0 := time.Now()
nc.Ping()
timing("NET: Ping reply", t0)
}
m.AddConnection(conn, remoteID)
continue nextNode
}
}
time.Sleep(time.Duration(connIntv) * time.Second)
time.Sleep(opts.Advanced.ConnInterval)
}
}
func updateLocalModel(m *Model) {
files := Walk(m.Dir(), m)
files := Walk(m.Dir(), m, !opts.NoSymlinks)
m.ReplaceLocal(files)
saveIndex(m)
}
func saveIndex(m *Model) {
fname := fmt.Sprintf("%x.idx", sha1.Sum([]byte(m.Dir())))
idxf, err := os.Create(path.Join(confDir, fname))
name := fmt.Sprintf("%x.idx.gz", sha1.Sum([]byte(m.Dir())))
fullName := path.Join(opts.ConfDir, name)
idxf, err := os.Create(fullName + ".tmp")
if err != nil {
return
}
protocol.WriteIndex(idxf, m.ProtocolIndex())
gzw := gzip.NewWriter(idxf)
protocol.WriteIndex(gzw, m.ProtocolIndex())
gzw.Close()
idxf.Close()
os.Rename(fullName+".tmp", fullName)
}
func loadIndex(m *Model) {
fname := fmt.Sprintf("%x.idx", sha1.Sum([]byte(m.Dir())))
idxf, err := os.Open(path.Join(confDir, fname))
fname := fmt.Sprintf("%x.idx.gz", sha1.Sum([]byte(m.Dir())))
idxf, err := os.Open(path.Join(opts.ConfDir, fname))
if err != nil {
return
}
defer idxf.Close()
idx, err := protocol.ReadIndex(idxf)
gzr, err := gzip.NewReader(idxf)
if err != nil {
return
}
defer gzr.Close()
idx, err := protocol.ReadIndex(gzr)
if err != nil {
return
}
m.SeedIndex(idx)
}
func ensureDir(dir string) {
func ensureDir(dir string, mode int) {
fi, err := os.Stat(dir)
if os.IsNotExist(err) {
err := os.MkdirAll(dir, 0700)
fatalErr(err)
} else if fi.Mode()&0077 != 0 {
err := os.Chmod(dir, 0700)
} else if mode >= 0 && err == nil && int(fi.Mode()&0777) != mode {
err := os.Chmod(dir, os.FileMode(mode))
fatalErr(err)
}
}

312
model.go
View File

@@ -9,11 +9,11 @@ The model has read and write locks. These must be acquired as appropriate by
public methods. To prevent deadlock situations, private methods should never
acquire locks, but document what locks they require.
TODO(jb): Keep global and per node transfer and performance statistics.
*/
import (
"fmt"
"io"
"os"
"path"
"sync"
@@ -25,30 +25,43 @@ import (
type Model struct {
sync.RWMutex
dir string
updated int64
dir string
global map[string]File // the latest version of each file as it exists in the cluster
local map[string]File // the files we currently have locally on disk
remote map[string]map[string]File
need map[string]bool // the files we need to update
nodes map[string]*protocol.Connection
rawConn map[string]io.ReadWriteCloser
updatedLocal int64 // timestamp of last update to local
updateGlobal int64 // timestamp of last update to remote
lastIdxBcast time.Time
lastIdxBcastRequest time.Time
}
const (
RemoteFetchers = 4
FlagDeleted = 1 << 12
FlagDeleted = 1 << 12
idxBcastHoldtime = 15 * time.Second // Wait at least this long after the last index modification
idxBcastMaxDelay = 120 * time.Second // Unless we've already waited this long
)
func NewModel(dir string) *Model {
m := &Model{
dir: dir,
global: make(map[string]File),
local: make(map[string]File),
remote: make(map[string]map[string]File),
need: make(map[string]bool),
nodes: make(map[string]*protocol.Connection),
dir: dir,
global: make(map[string]File),
local: make(map[string]File),
remote: make(map[string]map[string]File),
need: make(map[string]bool),
nodes: make(map[string]*protocol.Connection),
rawConn: make(map[string]io.ReadWriteCloser),
lastIdxBcast: time.Now(),
}
go m.printStatsLoop()
go m.broadcastIndexLoop()
return m
}
@@ -56,85 +69,155 @@ func (m *Model) Start() {
go m.puller()
}
func (m *Model) printStatsLoop() {
var lastUpdated int64
for {
time.Sleep(60 * time.Second)
m.RLock()
m.printConnectionStats()
if m.updatedLocal+m.updateGlobal > lastUpdated {
m.printModelStats()
lastUpdated = m.updatedLocal + m.updateGlobal
}
m.RUnlock()
}
}
func (m *Model) printConnectionStats() {
for node, conn := range m.nodes {
stats := conn.Statistics()
if stats.InBytesPerSec > 0 || stats.OutBytesPerSec > 0 {
infof("%s: %sB/s in, %sB/s out", node, toSI(stats.InBytesPerSec), toSI(stats.OutBytesPerSec))
}
}
}
func (m *Model) printModelStats() {
var tot int
for _, f := range m.global {
tot += f.Size()
}
infof("%6d files, %8sB in cluster", len(m.global), toSI(tot))
if len(m.need) > 0 {
tot = 0
for _, f := range m.local {
tot += f.Size()
}
infof("%6d files, %8sB in local repo", len(m.local), toSI(tot))
tot = 0
for n := range m.need {
tot += m.global[n].Size()
}
infof("%6d files, %8sB to synchronize", len(m.need), toSI(tot))
}
}
func toSI(n int) string {
if n > 1<<30 {
return fmt.Sprintf("%.02f G", float64(n)/(1<<30))
}
if n > 1<<20 {
return fmt.Sprintf("%.02f M", float64(n)/(1<<20))
}
if n > 1<<10 {
return fmt.Sprintf("%.01f K", float64(n)/(1<<10))
}
return fmt.Sprintf("%d ", n)
}
// Index is called when a new node is connected and we receive their full index.
func (m *Model) Index(nodeID string, fs []protocol.FileInfo) {
m.Lock()
defer m.Unlock()
if traceNet {
if opts.Debug.TraceNet {
debugf("NET IDX(in): %s: %d files", nodeID, len(fs))
}
m.remote[nodeID] = make(map[string]File)
for _, f := range fs {
if f.Flags&FlagDeleted != 0 && !doDelete {
if f.Flags&FlagDeleted != 0 && !opts.Delete {
// Files marked as deleted do not even enter the model
continue
}
mf := File{
Name: f.Name,
Flags: f.Flags,
Modified: int64(f.Modified),
m.remote[nodeID][f.Name] = fileFromFileInfo(f)
}
m.recomputeGlobal()
m.recomputeNeed()
m.printModelStats()
}
// IndexUpdate is called for incremental updates to connected nodes' indexes.
func (m *Model) IndexUpdate(nodeID string, fs []protocol.FileInfo) {
m.Lock()
defer m.Unlock()
if opts.Debug.TraceNet {
debugf("NET IDXUP(in): %s: %d files", nodeID, len(fs))
}
repo, ok := m.remote[nodeID]
if !ok {
return
}
for _, f := range fs {
if f.Flags&FlagDeleted != 0 && !opts.Delete {
// Files marked as deleted do not even enter the model
continue
}
var offset uint64
for _, b := range f.Blocks {
mf.Blocks = append(mf.Blocks, Block{
Offset: offset,
Length: b.Length,
Hash: b.Hash,
})
offset += uint64(b.Length)
}
m.remote[nodeID][f.Name] = mf
repo[f.Name] = fileFromFileInfo(f)
}
m.recomputeGlobal()
m.recomputeNeed()
}
// SeedIndex is called when our previously cached index is loaded from disk at startup.
func (m *Model) SeedIndex(fs []protocol.FileInfo) {
m.Lock()
defer m.Unlock()
m.local = make(map[string]File)
for _, f := range fs {
mf := File{
Name: f.Name,
Flags: f.Flags,
Modified: int64(f.Modified),
}
var offset uint64
for _, b := range f.Blocks {
mf.Blocks = append(mf.Blocks, Block{
Offset: offset,
Length: b.Length,
Hash: b.Hash,
})
offset += uint64(b.Length)
}
m.local[f.Name] = mf
m.local[f.Name] = fileFromFileInfo(f)
}
m.recomputeGlobal()
m.recomputeNeed()
m.printModelStats()
}
func (m *Model) Close(node string) {
func (m *Model) Close(node string, err error) {
m.Lock()
defer m.Unlock()
if traceNet {
debugf("NET CLOSE: %s", node)
conn, ok := m.rawConn[node]
if ok {
conn.Close()
} else {
warnln("Close on unknown connection for node", node)
}
if err != nil {
warnf("Disconnected from node %s: %v", node, err)
} else {
infoln("Disconnected from node", node)
}
delete(m.remote, node)
delete(m.nodes, node)
delete(m.rawConn, node)
m.recomputeGlobal()
m.recomputeNeed()
}
func (m *Model) Request(nodeID, name string, offset uint64, size uint32, hash []byte) ([]byte, error) {
if traceNet && nodeID != "<local>" {
if opts.Debug.TraceNet && nodeID != "<local>" {
debugf("NET REQ(in): %s: %q o=%d s=%d h=%x", nodeID, name, offset, size, hash)
}
fn := path.Join(m.dir, name)
@@ -155,10 +238,13 @@ func (m *Model) Request(nodeID, name string, offset uint64, size uint32, hash []
func (m *Model) RequestGlobal(nodeID, name string, offset uint64, size uint32, hash []byte) ([]byte, error) {
m.RLock()
nc := m.nodes[nodeID]
nc, ok := m.nodes[nodeID]
m.RUnlock()
if !ok {
return nil, fmt.Errorf("RequestGlobal: no such node: %s", nodeID)
}
if traceNet {
if opts.Debug.TraceNet {
debugf("NET REQ(out): %s: %q o=%d s=%d h=%x", nodeID, name, offset, size, hash)
}
@@ -191,19 +277,39 @@ func (m *Model) ReplaceLocal(fs []File) {
m.local = newLocal
m.recomputeGlobal()
m.recomputeNeed()
m.updated = time.Now().Unix()
go m.broadcastIndex()
m.updatedLocal = time.Now().Unix()
m.lastIdxBcastRequest = time.Now()
}
}
// Must be called with the read lock held.
func (m *Model) broadcastIndex() {
idx := m.protocolIndex()
for _, node := range m.nodes {
if traceNet {
debugf("NET IDX(out): %s: %d files", node.ID, len(idx))
func (m *Model) broadcastIndexLoop() {
for {
m.RLock()
bcastRequested := m.lastIdxBcastRequest.After(m.lastIdxBcast)
holdtimeExceeded := time.Since(m.lastIdxBcastRequest) > idxBcastHoldtime
m.RUnlock()
maxDelayExceeded := time.Since(m.lastIdxBcast) > idxBcastMaxDelay
if bcastRequested && (holdtimeExceeded || maxDelayExceeded) {
m.Lock()
var indexWg sync.WaitGroup
indexWg.Add(len(m.nodes))
idx := m.protocolIndex()
m.lastIdxBcast = time.Now()
for _, node := range m.nodes {
node := node
if opts.Debug.TraceNet {
debugf("NET IDX(out/loop): %s: %d files", node.ID, len(idx))
}
go func() {
node.Index(idx)
indexWg.Done()
}()
}
m.Unlock()
indexWg.Wait()
}
node.Index(idx)
time.Sleep(idxBcastHoldtime)
}
}
@@ -239,8 +345,8 @@ func (m *Model) UpdateLocal(f File) {
m.local[f.Name] = f
m.recomputeGlobal()
m.recomputeNeed()
m.updated = time.Now().Unix()
go m.broadcastIndex()
m.updatedLocal = time.Now().Unix()
m.lastIdxBcastRequest = time.Now()
}
}
@@ -290,7 +396,24 @@ func (m *Model) recomputeGlobal() {
}
}
m.global = newGlobal
// Figure out if anything actually changed
var updated bool
if len(newGlobal) != len(m.global) {
updated = true
} else {
for n, f0 := range newGlobal {
if f1, ok := m.global[n]; !ok || f0.Modified != f1.Modified {
updated = true
break
}
}
}
if updated {
m.updateGlobal = time.Now().Unix()
m.global = newGlobal
}
}
// Must be called with the write lock held.
@@ -335,18 +458,8 @@ func (m *Model) ProtocolIndex() []protocol.FileInfo {
func (m *Model) protocolIndex() []protocol.FileInfo {
var index []protocol.FileInfo
for _, f := range m.local {
mf := protocol.FileInfo{
Name: f.Name,
Flags: f.Flags,
Modified: int64(f.Modified),
}
for _, b := range f.Blocks {
mf.Blocks = append(mf.Blocks, protocol.BlockInfo{
Length: b.Length,
Hash: b.Hash,
})
}
if traceIdx {
mf := fileInfoFromFile(f)
if opts.Debug.TraceIdx {
var flagComment string
if mf.Flags&FlagDeleted != 0 {
flagComment = " (deleted)"
@@ -358,16 +471,57 @@ func (m *Model) protocolIndex() []protocol.FileInfo {
return index
}
func (m *Model) AddNode(node *protocol.Connection) {
func (m *Model) AddConnection(conn io.ReadWriteCloser, nodeID string) {
node := protocol.NewConnection(nodeID, conn, conn, m)
m.Lock()
m.nodes[node.ID] = node
m.nodes[nodeID] = node
m.rawConn[nodeID] = conn
m.Unlock()
infoln("Connected to node", nodeID)
m.RLock()
idx := m.protocolIndex()
m.RUnlock()
if traceNet {
debugf("NET IDX(out): %s: %d files", node.ID, len(idx))
}
node.Index(idx)
go func() {
node.Index(idx)
infoln("Sent initial index to node", nodeID)
}()
}
func fileFromFileInfo(f protocol.FileInfo) File {
var blocks []Block
var offset uint64
for _, b := range f.Blocks {
blocks = append(blocks, Block{
Offset: offset,
Length: b.Length,
Hash: b.Hash,
})
offset += uint64(b.Length)
}
return File{
Name: f.Name,
Flags: f.Flags,
Modified: int64(f.Modified),
Blocks: blocks,
}
}
func fileInfoFromFile(f File) protocol.FileInfo {
var blocks []protocol.BlockInfo
for _, b := range f.Blocks {
blocks = append(blocks, protocol.BlockInfo{
Length: b.Length,
Hash: b.Hash,
})
}
return protocol.FileInfo{
Name: f.Name,
Flags: f.Flags,
Modified: int64(f.Modified),
Blocks: blocks,
}
}

View File

@@ -6,19 +6,16 @@ Locking
=======
These methods are never called from the outside so don't follow the locking
policy in model.go. Instead, appropriate locks are acquired when needed and
held for as short a time as possible.
policy in model.go.
TODO(jb): Refactor this into smaller and cleaner pieces.
TODO(jb): Some kind of coalescing / rate limiting of index sending, so we don't
send hundreds of index updates in a short period if time when deleting files
etc.
TODO(jb): Increase performance by taking apparent peer bandwidth into account.
*/
import (
"bytes"
"errors"
"fmt"
"io"
"os"
@@ -33,8 +30,13 @@ func (m *Model) pullFile(name string) error {
m.RLock()
var localFile = m.local[name]
var globalFile = m.global[name]
var nodeIDs = m.whoHas(name)
m.RUnlock()
if len(nodeIDs) == 0 {
return fmt.Errorf("%s: no connected nodes with file available", name)
}
filename := path.Join(m.dir, name)
sdir := path.Dir(filename)
@@ -48,20 +50,20 @@ func (m *Model) pullFile(name string) error {
if err != nil {
return err
}
defer tmpFile.Close()
contentChan := make(chan content, 32)
var applyDone sync.WaitGroup
applyDone.Add(1)
go func() {
applyContent(contentChan, tmpFile)
tmpFile.Close()
applyDone.Done()
}()
local, remote := localFile.Blocks.To(globalFile.Blocks)
var fetchDone sync.WaitGroup
// One local copy routing
// One local copy routine
fetchDone.Add(1)
go func() {
@@ -80,60 +82,37 @@ func (m *Model) pullFile(name string) error {
// N remote copy routines
m.RLock()
var nodeIDs = m.whoHas(name)
m.RUnlock()
var remoteBlocksChan = make(chan Block)
go func() {
for _, block := range remote {
remoteBlocksChan <- block
}
close(remoteBlocksChan)
}()
var remoteBlocks = blockIterator{blocks: remote}
for i := 0; i < opts.Advanced.RequestsInFlight; i++ {
curNode := nodeIDs[i%len(nodeIDs)]
fetchDone.Add(1)
// XXX: This should be rewritten into something nicer that takes differing
// peer performance into account.
for i := 0; i < RemoteFetchers; i++ {
for _, nodeID := range nodeIDs {
fetchDone.Add(1)
go func(nodeID string) {
for block := range remoteBlocksChan {
data, err := m.RequestGlobal(nodeID, name, block.Offset, block.Length, block.Hash)
if err != nil {
break
}
contentChan <- content{
offset: int64(block.Offset),
data: data,
}
go func(nodeID string) {
for {
block, ok := remoteBlocks.Next()
if !ok {
break
}
fetchDone.Done()
}(nodeID)
}
data, err := m.RequestGlobal(nodeID, name, block.Offset, block.Length, block.Hash)
if err != nil {
break
}
contentChan <- content{
offset: int64(block.Offset),
data: data,
}
}
fetchDone.Done()
}(curNode)
}
fetchDone.Wait()
close(contentChan)
applyDone.Wait()
rf, err := os.Open(tmpFilename)
err = hashCheck(tmpFilename, globalFile.Blocks)
if err != nil {
return err
}
defer rf.Close()
writtenBlocks, err := Blocks(rf, BlockSize)
if err != nil {
return err
}
if len(writtenBlocks) != len(globalFile.Blocks) {
return fmt.Errorf("%s: incorrect number of blocks after sync", tmpFilename)
}
for i := range writtenBlocks {
if bytes.Compare(writtenBlocks[i].Hash, globalFile.Blocks[i].Hash) != 0 {
return fmt.Errorf("%s: hash mismatch after sync\n %v\n %v", tmpFilename, writtenBlocks[i], globalFile.Blocks[i])
}
return fmt.Errorf("%s: %s (deleting)", path.Base(name), err.Error())
}
err = os.Chtimes(tmpFilename, time.Unix(globalFile.Modified, 0), time.Unix(globalFile.Modified, 0))
@@ -151,44 +130,59 @@ func (m *Model) pullFile(name string) error {
func (m *Model) puller() {
for {
for {
var n string
var f File
m.RLock()
for n = range m.need {
break // just pick first name
}
if len(n) != 0 {
f = m.global[n]
}
m.RUnlock()
if len(n) == 0 {
// we got nothing
break
}
var err error
if f.Flags&FlagDeleted == 0 {
if traceFile {
debugf("FILE: Pull %q", n)
}
err = m.pullFile(n)
} else {
if traceFile {
debugf("FILE: Remove %q", n)
}
// Cheerfully ignore errors here
_ = os.Remove(path.Join(m.dir, n))
}
if err == nil {
m.UpdateLocal(f)
} else {
warnln(err)
}
}
time.Sleep(time.Second)
var ns []string
m.RLock()
for n := range m.need {
ns = append(ns, n)
}
m.RUnlock()
if len(ns) == 0 {
continue
}
var limiter = make(chan bool, opts.Advanced.FilesInFlight)
var allDone sync.WaitGroup
for _, n := range ns {
limiter <- true
allDone.Add(1)
go func(n string) {
defer func() {
allDone.Done()
<-limiter
}()
f, ok := m.GlobalFile(n)
if !ok {
return
}
var err error
if f.Flags&FlagDeleted == 0 {
if opts.Debug.TraceFile {
debugf("FILE: Pull %q", n)
}
err = m.pullFile(n)
} else {
if opts.Debug.TraceFile {
debugf("FILE: Remove %q", n)
}
// Cheerfully ignore errors here
_ = os.Remove(path.Join(m.dir, n))
}
if err == nil {
m.UpdateLocal(f)
} else {
warnln(err)
}
}(n)
}
allDone.Wait()
}
}
@@ -202,11 +196,53 @@ func applyContent(cc <-chan content, dst io.WriterAt) error {
for c := range cc {
_, err = dst.WriteAt(c.data, c.offset)
buffers.Put(c.data)
if err != nil {
return err
}
buffers.Put(c.data)
}
return nil
}
func hashCheck(name string, correct []Block) error {
rf, err := os.Open(name)
if err != nil {
return err
}
defer rf.Close()
current, err := Blocks(rf, BlockSize)
if err != nil {
return err
}
if len(current) != len(correct) {
return errors.New("incorrect number of blocks")
}
for i := range current {
if bytes.Compare(current[i].Hash, correct[i].Hash) != 0 {
return fmt.Errorf("hash mismatch: %x != %x", current[i], correct[i])
}
}
return nil
}
type blockIterator struct {
sync.Mutex
blocks []Block
}
func (i *blockIterator) Next() (b Block, ok bool) {
i.Lock()
defer i.Unlock()
if len(i.blocks) == 0 {
return
}
b, i.blocks = i.blocks[0], i.blocks[1:]
ok = true
return
}

View File

@@ -47,7 +47,7 @@ var testDataExpected = map[string]File{
func TestUpdateLocal(t *testing.T) {
m := NewModel("foo")
fs := Walk("testdata", m)
fs := Walk("testdata", m, false)
m.ReplaceLocal(fs)
if len(m.need) > 0 {
@@ -89,7 +89,7 @@ func TestUpdateLocal(t *testing.T) {
func TestRemoteUpdateExisting(t *testing.T) {
m := NewModel("foo")
fs := Walk("testdata", m)
fs := Walk("testdata", m, false)
m.ReplaceLocal(fs)
newFile := protocol.FileInfo{
@@ -97,7 +97,7 @@ func TestRemoteUpdateExisting(t *testing.T) {
Modified: time.Now().Unix(),
Blocks: []protocol.BlockInfo{{100, []byte("some hash bytes")}},
}
m.Index(string("42"), []protocol.FileInfo{newFile})
m.Index("42", []protocol.FileInfo{newFile})
if l := len(m.need); l != 1 {
t.Errorf("Model missing Need for one file (%d != 1)", l)
@@ -106,7 +106,7 @@ func TestRemoteUpdateExisting(t *testing.T) {
func TestRemoteAddNew(t *testing.T) {
m := NewModel("foo")
fs := Walk("testdata", m)
fs := Walk("testdata", m, false)
m.ReplaceLocal(fs)
newFile := protocol.FileInfo{
@@ -114,7 +114,7 @@ func TestRemoteAddNew(t *testing.T) {
Modified: time.Now().Unix(),
Blocks: []protocol.BlockInfo{{100, []byte("some hash bytes")}},
}
m.Index(string("42"), []protocol.FileInfo{newFile})
m.Index("42", []protocol.FileInfo{newFile})
if l1, l2 := len(m.need), 1; l1 != l2 {
t.Errorf("Model len(m.need) incorrect (%d != %d)", l1, l2)
@@ -123,7 +123,7 @@ func TestRemoteAddNew(t *testing.T) {
func TestRemoteUpdateOld(t *testing.T) {
m := NewModel("foo")
fs := Walk("testdata", m)
fs := Walk("testdata", m, false)
m.ReplaceLocal(fs)
oldTimeStamp := int64(1234)
@@ -132,16 +132,49 @@ func TestRemoteUpdateOld(t *testing.T) {
Modified: oldTimeStamp,
Blocks: []protocol.BlockInfo{{100, []byte("some hash bytes")}},
}
m.Index(string("42"), []protocol.FileInfo{newFile})
m.Index("42", []protocol.FileInfo{newFile})
if l1, l2 := len(m.need), 0; l1 != l2 {
t.Errorf("Model len(need) incorrect (%d != %d)", l1, l2)
}
}
func TestRemoteIndexUpdate(t *testing.T) {
m := NewModel("foo")
fs := Walk("testdata", m, false)
m.ReplaceLocal(fs)
foo := protocol.FileInfo{
Name: "foo",
Modified: time.Now().Unix(),
Blocks: []protocol.BlockInfo{{100, []byte("some hash bytes")}},
}
bar := protocol.FileInfo{
Name: "bar",
Modified: time.Now().Unix(),
Blocks: []protocol.BlockInfo{{100, []byte("some hash bytes")}},
}
m.Index("42", []protocol.FileInfo{foo})
if _, ok := m.need["foo"]; !ok {
t.Error("Model doesn't need 'foo'")
}
m.IndexUpdate("42", []protocol.FileInfo{bar})
if _, ok := m.need["foo"]; !ok {
t.Error("Model doesn't need 'foo'")
}
if _, ok := m.need["bar"]; !ok {
t.Error("Model doesn't need 'bar'")
}
}
func TestDelete(t *testing.T) {
m := NewModel("foo")
fs := Walk("testdata", m)
fs := Walk("testdata", m, false)
m.ReplaceLocal(fs)
if l1, l2 := len(m.local), len(fs); l1 != l2 {
@@ -231,7 +264,7 @@ func TestDelete(t *testing.T) {
func TestForgetNode(t *testing.T) {
m := NewModel("foo")
fs := Walk("testdata", m)
fs := Walk("testdata", m, false)
m.ReplaceLocal(fs)
if l1, l2 := len(m.local), len(fs); l1 != l2 {
@@ -249,7 +282,7 @@ func TestForgetNode(t *testing.T) {
Modified: time.Now().Unix(),
Blocks: []protocol.BlockInfo{{100, []byte("some hash bytes")}},
}
m.Index(string("42"), []protocol.FileInfo{newFile})
m.Index("42", []protocol.FileInfo{newFile})
if l1, l2 := len(m.local), len(fs); l1 != l2 {
t.Errorf("Model len(local) incorrect (%d != %d)", l1, l2)
@@ -261,7 +294,7 @@ func TestForgetNode(t *testing.T) {
t.Errorf("Model len(need) incorrect (%d != %d)", l1, l2)
}
m.Close(string("42"))
m.Close("42", nil)
if l1, l2 := len(m.local), len(fs); l1 != l2 {
t.Errorf("Model len(local) incorrect (%d != %d)", l1, l2)

View File

@@ -177,6 +177,14 @@ contents, but copies the Message ID from the Ping.
struct PongMessage {
}
### IndexUpdate (Type = 6)
This message has exactly the same structure as the Index message.
However instead of replacing the contents of the repository in the
model, the Index Update merely amends it with new or updated file
information. Any files not mentioned in an Index Update are left
unchanged.
Example Exchange
----------------

52
protocol/common_test.go Normal file
View File

@@ -0,0 +1,52 @@
package protocol
import "io"
type TestModel struct {
data []byte
name string
offset uint64
size uint32
hash []byte
closed bool
}
func (t *TestModel) Index(nodeID string, files []FileInfo) {
}
func (t *TestModel) IndexUpdate(nodeID string, files []FileInfo) {
}
func (t *TestModel) Request(nodeID, name string, offset uint64, size uint32, hash []byte) ([]byte, error) {
t.name = name
t.offset = offset
t.size = size
t.hash = hash
return t.data, nil
}
func (t *TestModel) Close(nodeID string, err error) {
t.closed = true
}
type ErrPipe struct {
io.PipeWriter
written int
max int
err error
closed bool
}
func (e *ErrPipe) Write(data []byte) (int, error) {
if e.closed {
return 0, e.err
}
if e.written+len(data) > e.max {
n, _ := e.PipeWriter.Write(data[:e.max-e.written])
e.PipeWriter.CloseWithError(e.err)
e.closed = true
return n, e.err
} else {
return e.PipeWriter.Write(data)
}
}

View File

@@ -1,7 +1,9 @@
package protocol
import (
"errors"
"io"
"sync/atomic"
"github.com/calmh/syncthing/buffers"
)
@@ -18,10 +20,18 @@ var padBytes = []byte{0, 0, 0}
type marshalWriter struct {
w io.Writer
tot int
tot uint64
err error
b [8]byte
}
// We will never encode nor expect to decode blobs larger than 10 MB. Check
// inserted to protect against attempting to allocate arbitrary amounts of
// memory when reading a corrupt message.
const maxBytesFieldLength = 10 * 1 << 20
var ErrFieldLengthExceeded = errors.New("Raw bytes field size exceeds limit")
func (w *marshalWriter) writeString(s string) {
w.writeBytes([]byte(s))
}
@@ -30,51 +40,58 @@ func (w *marshalWriter) writeBytes(bs []byte) {
if w.err != nil {
return
}
if len(bs) > maxBytesFieldLength {
w.err = ErrFieldLengthExceeded
return
}
w.writeUint32(uint32(len(bs)))
if w.err != nil {
return
}
_, w.err = w.w.Write(bs)
if p := pad(len(bs)); p > 0 {
w.w.Write(padBytes[:p])
if p := pad(len(bs)); w.err == nil && p > 0 {
_, w.err = w.w.Write(padBytes[:p])
}
w.tot += len(bs) + pad(len(bs))
atomic.AddUint64(&w.tot, uint64(len(bs)+pad(len(bs))))
}
func (w *marshalWriter) writeUint32(v uint32) {
if w.err != nil {
return
}
var b [4]byte
b[0] = byte(v >> 24)
b[1] = byte(v >> 16)
b[2] = byte(v >> 8)
b[3] = byte(v)
_, w.err = w.w.Write(b[:])
w.tot += 4
w.b[0] = byte(v >> 24)
w.b[1] = byte(v >> 16)
w.b[2] = byte(v >> 8)
w.b[3] = byte(v)
_, w.err = w.w.Write(w.b[:4])
atomic.AddUint64(&w.tot, 4)
}
func (w *marshalWriter) writeUint64(v uint64) {
if w.err != nil {
return
}
var b [8]byte
b[0] = byte(v >> 56)
b[1] = byte(v >> 48)
b[2] = byte(v >> 40)
b[3] = byte(v >> 32)
b[4] = byte(v >> 24)
b[5] = byte(v >> 16)
b[6] = byte(v >> 8)
b[7] = byte(v)
_, w.err = w.w.Write(b[:])
w.tot += 8
w.b[0] = byte(v >> 56)
w.b[1] = byte(v >> 48)
w.b[2] = byte(v >> 40)
w.b[3] = byte(v >> 32)
w.b[4] = byte(v >> 24)
w.b[5] = byte(v >> 16)
w.b[6] = byte(v >> 8)
w.b[7] = byte(v)
_, w.err = w.w.Write(w.b[:8])
atomic.AddUint64(&w.tot, 8)
}
func (w *marshalWriter) getTot() uint64 {
return atomic.LoadUint64(&w.tot)
}
type marshalReader struct {
r io.Reader
tot int
tot uint64
err error
b [8]byte
}
func (r *marshalReader) readString() string {
@@ -91,9 +108,13 @@ func (r *marshalReader) readBytes() []byte {
if r.err != nil {
return nil
}
if l > maxBytesFieldLength {
r.err = ErrFieldLengthExceeded
return nil
}
b := buffers.Get(l + pad(l))
_, r.err = io.ReadFull(r.r, b)
r.tot += int(l + pad(l))
atomic.AddUint64(&r.tot, uint64(l+pad(l)))
return b[:l]
}
@@ -101,19 +122,21 @@ func (r *marshalReader) readUint32() uint32 {
if r.err != nil {
return 0
}
var b [4]byte
_, r.err = io.ReadFull(r.r, b[:])
r.tot += 4
return uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24
_, r.err = io.ReadFull(r.r, r.b[:4])
atomic.AddUint64(&r.tot, 8)
return uint32(r.b[3]) | uint32(r.b[2])<<8 | uint32(r.b[1])<<16 | uint32(r.b[0])<<24
}
func (r *marshalReader) readUint64() uint64 {
if r.err != nil {
return 0
}
var b [8]byte
_, r.err = io.ReadFull(r.r, b[:])
r.tot += 8
return uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 |
uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56
_, r.err = io.ReadFull(r.r, r.b[:8])
atomic.AddUint64(&r.tot, 8)
return uint64(r.b[7]) | uint64(r.b[6])<<8 | uint64(r.b[5])<<16 | uint64(r.b[4])<<24 |
uint64(r.b[3])<<32 | uint64(r.b[2])<<40 | uint64(r.b[1])<<48 | uint64(r.b[0])<<56
}
func (r *marshalReader) getTot() uint64 {
return atomic.LoadUint64(&r.tot)
}

View File

@@ -48,9 +48,9 @@ func (w *marshalWriter) writeIndex(idx []FileInfo) {
}
func WriteIndex(w io.Writer, idx []FileInfo) (int, error) {
mw := marshalWriter{w, 0, nil}
mw := marshalWriter{w: w}
mw.writeIndex(idx)
return mw.tot, mw.err
return int(mw.getTot()), mw.err
}
func (w *marshalWriter) writeRequest(r request) {
@@ -69,25 +69,28 @@ func (r *marshalReader) readHeader() header {
}
func (r *marshalReader) readIndex() []FileInfo {
var files []FileInfo
nfiles := r.readUint32()
files := make([]FileInfo, nfiles)
for i := range files {
files[i].Name = r.readString()
files[i].Flags = r.readUint32()
files[i].Modified = int64(r.readUint64())
nblocks := r.readUint32()
blocks := make([]BlockInfo, nblocks)
for j := range blocks {
blocks[j].Length = r.readUint32()
blocks[j].Hash = r.readBytes()
if nfiles > 0 {
files = make([]FileInfo, nfiles)
for i := range files {
files[i].Name = r.readString()
files[i].Flags = r.readUint32()
files[i].Modified = int64(r.readUint64())
nblocks := r.readUint32()
blocks := make([]BlockInfo, nblocks)
for j := range blocks {
blocks[j].Length = r.readUint32()
blocks[j].Hash = r.readBytes()
}
files[i].Blocks = blocks
}
files[i].Blocks = blocks
}
return files
}
func ReadIndex(r io.Reader) ([]FileInfo, error) {
mr := marshalReader{r, 0, nil}
mr := marshalReader{r: r}
idx := mr.readIndex()
return idx, mr.err
}

View File

@@ -32,10 +32,10 @@ func TestIndex(t *testing.T) {
}
var buf = new(bytes.Buffer)
var wr = marshalWriter{buf, 0, nil}
var wr = marshalWriter{w: buf}
wr.writeIndex(idx)
var rd = marshalReader{buf, 0, nil}
var rd = marshalReader{r: buf}
var idx2 = rd.readIndex()
if !reflect.DeepEqual(idx, idx2) {
@@ -47,9 +47,9 @@ func TestRequest(t *testing.T) {
f := func(name string, offset uint64, size uint32, hash []byte) bool {
var buf = new(bytes.Buffer)
var req = request{name, offset, size, hash}
var wr = marshalWriter{buf, 0, nil}
var wr = marshalWriter{w: buf}
wr.writeRequest(req)
var rd = marshalReader{buf, 0, nil}
var rd = marshalReader{r: buf}
var req2 = rd.readRequest()
return req.name == req2.name &&
req.offset == req2.offset &&
@@ -64,9 +64,9 @@ func TestRequest(t *testing.T) {
func TestResponse(t *testing.T) {
f := func(data []byte) bool {
var buf = new(bytes.Buffer)
var wr = marshalWriter{buf, 0, nil}
var wr = marshalWriter{w: buf}
wr.writeResponse(data)
var rd = marshalReader{buf, 0, nil}
var rd = marshalReader{r: buf}
var read = rd.readResponse()
return bytes.Compare(read, data) == 0
}
@@ -98,7 +98,7 @@ func BenchmarkWriteIndex(b *testing.B) {
},
}
var wr = marshalWriter{ioutil.Discard, 0, nil}
var wr = marshalWriter{w: ioutil.Discard}
for i := 0; i < b.N; i++ {
wr.writeIndex(idx)
@@ -107,7 +107,7 @@ func BenchmarkWriteIndex(b *testing.B) {
func BenchmarkWriteRequest(b *testing.B) {
var req = request{"blah blah", 1231323, 13123123, []byte("hash hash hash")}
var wr = marshalWriter{ioutil.Discard, 0, nil}
var wr = marshalWriter{w: ioutil.Discard}
for i := 0; i < b.N; i++ {
wr.writeRequest(req)

View File

@@ -3,23 +3,23 @@ package protocol
import (
"compress/flate"
"errors"
"fmt"
"io"
"sync"
"time"
"github.com/calmh/syncthing/buffers"
)
const (
messageTypeReserved = iota
messageTypeIndex
messageTypeRequest
messageTypeResponse
messageTypePing
messageTypePong
messageTypeIndex = 1
messageTypeRequest = 2
messageTypeResponse = 3
messageTypePing = 4
messageTypePong = 5
messageTypeIndexUpdate = 6
)
var ErrClosed = errors.New("Connection closed")
type FileInfo struct {
Name string
Flags uint32
@@ -35,26 +35,47 @@ type BlockInfo struct {
type Model interface {
// An index was received from the peer node
Index(nodeID string, files []FileInfo)
// An index update was received from the peer node
IndexUpdate(nodeID string, files []FileInfo)
// A request was made by the peer node
Request(nodeID, name string, offset uint64, size uint32, hash []byte) ([]byte, error)
// The peer node closed the connection
Close(nodeID string)
Close(nodeID string, err error)
}
type Connection struct {
receiver Model
reader io.Reader
mreader *marshalReader
writer io.Writer
mwriter *marshalWriter
wLock sync.RWMutex
closed bool
closedLock sync.RWMutex
awaiting map[int]chan interface{}
nextId int
ID string
sync.RWMutex
ID string
receiver Model
reader io.Reader
mreader *marshalReader
writer io.Writer
mwriter *marshalWriter
closed bool
awaiting map[int]chan asyncResult
nextId int
indexSent map[string]int64
hasSentIndex bool
hasRecvdIndex bool
lastStatistics Statistics
statisticsLock sync.Mutex
}
var ErrClosed = errors.New("Connection closed")
type asyncResult struct {
val []byte
err error
}
const (
pingTimeout = 2 * time.Minute
pingIdleTime = 5 * time.Minute
)
func NewConnection(nodeID string, reader io.Reader, writer io.Writer, receiver Model) *Connection {
flrd := flate.NewReader(reader)
flwr, err := flate.NewWriter(writer, flate.BestSpeed)
@@ -63,74 +84,119 @@ func NewConnection(nodeID string, reader io.Reader, writer io.Writer, receiver M
}
c := Connection{
receiver: receiver,
reader: flrd,
mreader: &marshalReader{flrd, 0, nil},
writer: flwr,
mwriter: &marshalWriter{flwr, 0, nil},
awaiting: make(map[int]chan interface{}),
ID: nodeID,
receiver: receiver,
reader: flrd,
mreader: &marshalReader{r: flrd},
writer: flwr,
mwriter: &marshalWriter{w: flwr},
awaiting: make(map[int]chan asyncResult),
ID: nodeID,
lastStatistics: Statistics{At: time.Now()},
}
go c.readerLoop()
go c.pingerLoop()
return &c
}
// Index writes the list of file information to the connected peer node
func (c *Connection) Index(idx []FileInfo) {
c.wLock.Lock()
defer c.wLock.Unlock()
c.Lock()
var msgType int
if c.indexSent == nil {
// This is the first time we send an index.
msgType = messageTypeIndex
c.mwriter.writeHeader(header{0, c.nextId, messageTypeIndex})
c.nextId = (c.nextId + 1) & 0xfff
c.indexSent = make(map[string]int64)
for _, f := range idx {
c.indexSent[f.Name] = f.Modified
}
} else {
// We have sent one full index. Only send updates now.
msgType = messageTypeIndexUpdate
var diff []FileInfo
for _, f := range idx {
if modified, ok := c.indexSent[f.Name]; !ok || f.Modified != modified {
diff = append(diff, f)
c.indexSent[f.Name] = f.Modified
}
}
idx = diff
}
c.mwriter.writeHeader(header{0, c.nextId, msgType})
c.mwriter.writeIndex(idx)
c.flush()
err := c.flush()
c.nextId = (c.nextId + 1) & 0xfff
c.hasSentIndex = true
c.Unlock()
if err != nil {
c.Close(err)
return
} else if c.mwriter.err != nil {
c.Close(c.mwriter.err)
return
}
}
// Request returns the bytes for the specified block after fetching them from the connected peer.
func (c *Connection) Request(name string, offset uint64, size uint32, hash []byte) ([]byte, error) {
c.wLock.Lock()
rc := make(chan interface{})
c.Lock()
if c.closed {
c.Unlock()
return nil, ErrClosed
}
rc := make(chan asyncResult)
c.awaiting[c.nextId] = rc
c.mwriter.writeHeader(header{0, c.nextId, messageTypeRequest})
c.mwriter.writeRequest(request{name, offset, size, hash})
c.flush()
if c.mwriter.err != nil {
c.Unlock()
c.Close(c.mwriter.err)
return nil, c.mwriter.err
}
err := c.flush()
if err != nil {
c.Unlock()
c.Close(err)
return nil, err
}
c.nextId = (c.nextId + 1) & 0xfff
c.wLock.Unlock()
c.Unlock()
// Reading something that might be nil from a possibly closed channel...
// r0<~
var data []byte
i, ok := <-rc
if ok {
if d, ok := i.([]byte); ok {
data = d
}
res, ok := <-rc
if !ok {
return nil, ErrClosed
}
var err error
i, ok = <-rc
if ok {
if e, ok := i.(error); ok {
err = e
}
}
return data, err
return res.val, res.err
}
func (c *Connection) Ping() bool {
c.wLock.Lock()
rc := make(chan interface{})
c.Lock()
if c.closed {
c.Unlock()
return false
}
rc := make(chan asyncResult, 1)
c.awaiting[c.nextId] = rc
c.mwriter.writeHeader(header{0, c.nextId, messageTypePing})
c.flush()
err := c.flush()
if err != nil {
c.Unlock()
c.Close(err)
return false
} else if c.mwriter.err != nil {
c.Unlock()
c.Close(c.mwriter.err)
return false
}
c.nextId = (c.nextId + 1) & 0xfff
c.wLock.Unlock()
c.Unlock()
_, ok := <-rc
return ok
res, ok := <-rc
return ok && res.err == nil
}
func (c *Connection) Stop() {
@@ -140,100 +206,194 @@ type flusher interface {
Flush() error
}
func (c *Connection) flush() {
func (c *Connection) flush() error {
if f, ok := c.writer.(flusher); ok {
f.Flush()
return f.Flush()
}
return nil
}
func (c *Connection) close() {
c.closedLock.Lock()
func (c *Connection) Close(err error) {
c.Lock()
if c.closed {
c.Unlock()
return
}
c.closed = true
c.closedLock.Unlock()
c.wLock.Lock()
for _, ch := range c.awaiting {
close(ch)
}
c.awaiting = nil
c.wLock.Unlock()
c.receiver.Close(c.ID)
c.Unlock()
c.receiver.Close(c.ID, err)
}
func (c *Connection) isClosed() bool {
c.closedLock.RLock()
defer c.closedLock.RUnlock()
c.RLock()
defer c.RUnlock()
return c.closed
}
func (c *Connection) readerLoop() {
for !c.isClosed() {
loop:
for {
hdr := c.mreader.readHeader()
if c.mreader.err != nil {
c.close()
break
c.Close(c.mreader.err)
break loop
}
if hdr.version != 0 {
c.Close(fmt.Errorf("Protocol error: %s: unknown message version %#x", c.ID, hdr.version))
break loop
}
switch hdr.msgType {
case messageTypeIndex:
files := c.mreader.readIndex()
if c.mreader.err != nil {
c.close()
c.Close(c.mreader.err)
break loop
} else {
c.receiver.Index(c.ID, files)
}
c.Lock()
c.hasRecvdIndex = true
c.Unlock()
case messageTypeIndexUpdate:
files := c.mreader.readIndex()
if c.mreader.err != nil {
c.Close(c.mreader.err)
break loop
} else {
c.receiver.IndexUpdate(c.ID, files)
}
case messageTypeRequest:
c.processRequest(hdr.msgID)
req := c.mreader.readRequest()
if c.mreader.err != nil {
c.Close(c.mreader.err)
break loop
}
go c.processRequest(hdr.msgID, req)
case messageTypeResponse:
data := c.mreader.readResponse()
if c.mreader.err != nil {
c.close()
c.Close(c.mreader.err)
break loop
} else {
c.wLock.RLock()
c.Lock()
rc, ok := c.awaiting[hdr.msgID]
c.wLock.RUnlock()
delete(c.awaiting, hdr.msgID)
c.Unlock()
if ok {
rc <- data
rc <- c.mreader.err
delete(c.awaiting, hdr.msgID)
rc <- asyncResult{data, c.mreader.err}
close(rc)
}
}
case messageTypePing:
c.wLock.Lock()
c.Lock()
c.mwriter.writeUint32(encodeHeader(header{0, hdr.msgID, messageTypePong}))
c.flush()
c.wLock.Unlock()
err := c.flush()
c.Unlock()
if err != nil {
c.Close(err)
break loop
} else if c.mwriter.err != nil {
c.Close(c.mwriter.err)
break loop
}
case messageTypePong:
c.wLock.Lock()
if rc, ok := c.awaiting[hdr.msgID]; ok {
rc <- true
c.RLock()
rc, ok := c.awaiting[hdr.msgID]
c.RUnlock()
if ok {
rc <- asyncResult{}
close(rc)
c.Lock()
delete(c.awaiting, hdr.msgID)
c.Unlock()
}
c.wLock.Unlock()
default:
c.Close(fmt.Errorf("Protocol error: %s: unknown message type %#x", c.ID, hdr.msgType))
break loop
}
}
}
func (c *Connection) processRequest(msgID int) {
req := c.mreader.readRequest()
if c.mreader.err != nil {
c.close()
} else {
go func() {
data, _ := c.receiver.Request(c.ID, req.name, req.offset, req.size, req.hash)
c.wLock.Lock()
c.mwriter.writeUint32(encodeHeader(header{0, msgID, messageTypeResponse}))
c.mwriter.writeResponse(data)
buffers.Put(data)
c.flush()
c.wLock.Unlock()
}()
func (c *Connection) processRequest(msgID int, req request) {
data, _ := c.receiver.Request(c.ID, req.name, req.offset, req.size, req.hash)
c.Lock()
c.mwriter.writeUint32(encodeHeader(header{0, msgID, messageTypeResponse}))
c.mwriter.writeResponse(data)
err := c.flush()
c.Unlock()
buffers.Put(data)
if err != nil {
c.Close(err)
} else if c.mwriter.err != nil {
c.Close(c.mwriter.err)
}
}
func (c *Connection) pingerLoop() {
var rc = make(chan bool, 1)
for {
time.Sleep(pingIdleTime / 2)
c.RLock()
ready := c.hasRecvdIndex && c.hasSentIndex
c.RUnlock()
if ready {
go func() {
rc <- c.Ping()
}()
select {
case ok := <-rc:
if !ok {
c.Close(fmt.Errorf("Ping failure"))
}
case <-time.After(pingTimeout):
c.Close(fmt.Errorf("Ping timeout"))
}
}
}
}
type Statistics struct {
At time.Time
InBytesTotal int
InBytesPerSec int
OutBytesTotal int
OutBytesPerSec int
}
func (c *Connection) Statistics() Statistics {
c.statisticsLock.Lock()
defer c.statisticsLock.Unlock()
secs := time.Since(c.lastStatistics.At).Seconds()
rt := int(c.mreader.getTot())
wt := int(c.mwriter.getTot())
stats := Statistics{
At: time.Now(),
InBytesTotal: rt,
InBytesPerSec: int(float64(rt-c.lastStatistics.InBytesTotal) / secs),
OutBytesTotal: wt,
OutBytesPerSec: int(float64(wt-c.lastStatistics.OutBytesTotal) / secs),
}
c.lastStatistics = stats
return stats
}

View File

@@ -1,8 +1,11 @@
package protocol
import (
"errors"
"io"
"testing"
"testing/quick"
"time"
)
func TestHeaderFunctions(t *testing.T) {
@@ -35,3 +38,177 @@ func TestPad(t *testing.T) {
}
}
}
func TestPing(t *testing.T) {
ar, aw := io.Pipe()
br, bw := io.Pipe()
c0 := NewConnection("c0", ar, bw, nil)
c1 := NewConnection("c1", br, aw, nil)
if ok := c0.Ping(); !ok {
t.Error("c0 ping failed")
}
if ok := c1.Ping(); !ok {
t.Error("c1 ping failed")
}
}
func TestPingErr(t *testing.T) {
e := errors.New("Something broke")
for i := 0; i < 12; i++ {
for j := 0; j < 12; j++ {
m0 := &TestModel{}
m1 := &TestModel{}
ar, aw := io.Pipe()
br, bw := io.Pipe()
eaw := &ErrPipe{PipeWriter: *aw, max: i, err: e}
ebw := &ErrPipe{PipeWriter: *bw, max: j, err: e}
c0 := NewConnection("c0", ar, ebw, m0)
NewConnection("c1", br, eaw, m1)
res := c0.Ping()
if (i < 4 || j < 4) && res {
t.Errorf("Unexpected ping success; i=%d, j=%d", i, j)
} else if (i >= 8 && j >= 8) && !res {
t.Errorf("Unexpected ping fail; i=%d, j=%d", i, j)
}
}
}
}
func TestRequestResponseErr(t *testing.T) {
e := errors.New("Something broke")
var pass bool
for i := 0; i < 36; i++ {
for j := 0; j < 26; j++ {
m0 := &TestModel{data: []byte("response data")}
m1 := &TestModel{}
ar, aw := io.Pipe()
br, bw := io.Pipe()
eaw := &ErrPipe{PipeWriter: *aw, max: i, err: e}
ebw := &ErrPipe{PipeWriter: *bw, max: j, err: e}
NewConnection("c0", ar, ebw, m0)
c1 := NewConnection("c1", br, eaw, m1)
d, err := c1.Request("tn", 1234, 3456, []byte("hashbytes"))
if err == e || err == ErrClosed {
t.Logf("Error at %d+%d bytes", i, j)
if !m1.closed {
t.Error("c1 not closed")
}
time.Sleep(1 * time.Millisecond)
if !m0.closed {
t.Error("c0 not closed")
}
continue
}
if err != nil {
t.Error(err)
}
if string(d) != "response data" {
t.Errorf("Incorrect response data %q", string(d))
}
if m0.name != "tn" {
t.Error("Incorrect name %q", m0.name)
}
if m0.offset != 1234 {
t.Error("Incorrect offset %d", m0.offset)
}
if m0.size != 3456 {
t.Error("Incorrect size %d", m0.size)
}
if string(m0.hash) != "hashbytes" {
t.Error("Incorrect hash %q", m0.hash)
}
t.Logf("Pass at %d+%d bytes", i, j)
pass = true
}
}
if !pass {
t.Error("Never passed")
}
}
func TestVersionErr(t *testing.T) {
m0 := &TestModel{}
m1 := &TestModel{}
ar, aw := io.Pipe()
br, bw := io.Pipe()
c0 := NewConnection("c0", ar, bw, m0)
NewConnection("c1", br, aw, m1)
c0.mwriter.writeHeader(header{
version: 2,
msgID: 0,
msgType: 0,
})
c0.flush()
if !m1.closed {
t.Error("Connection should close due to unknown version")
}
}
func TestTypeErr(t *testing.T) {
m0 := &TestModel{}
m1 := &TestModel{}
ar, aw := io.Pipe()
br, bw := io.Pipe()
c0 := NewConnection("c0", ar, bw, m0)
NewConnection("c1", br, aw, m1)
c0.mwriter.writeHeader(header{
version: 0,
msgID: 0,
msgType: 42,
})
c0.flush()
if !m1.closed {
t.Error("Connection should close due to unknown message type")
}
}
func TestClose(t *testing.T) {
m0 := &TestModel{}
m1 := &TestModel{}
ar, aw := io.Pipe()
br, bw := io.Pipe()
c0 := NewConnection("c0", ar, bw, m0)
NewConnection("c1", br, aw, m1)
c0.Close(nil)
ok := c0.isClosed()
if !ok {
t.Fatal("Connection should be closed")
}
// None of these should panic, some should return an error
ok = c0.Ping()
if ok {
t.Error("Ping should not return true")
}
c0.Index(nil)
c0.Index(nil)
_, err := c0.Request("foo", 0, 0, nil)
if err == nil {
t.Error("Request should return an error")
}
}

View File

@@ -1,11 +0,0 @@
[repository]
dir = /Users/jb/Synced
# The nodes section lists the nodes that make up the cluster. The format is
# <certificate id> = <space separated list of addresses>
# The special address "dynamic" means that outbound connections will not be
# attempted, but inbound connections are accepted.
[nodes]
ITZXTZ7A32DWV3NLNR5W4M3CHVBW56NA = 172.16.32.1:22000 192.23.34.56:22000
CUGAE43Y5N64CRJU26YFH6MTWPSBLSUL = dynamic

4
tls.go
View File

@@ -58,11 +58,11 @@ func newCertificate(dir string) {
fatalErr(err)
pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
certOut.Close()
okln("wrote cert.pem")
okln("Created TLS certificate file")
keyOut, err := os.OpenFile(path.Join(dir, "key.pem"), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
fatalErr(err)
pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)})
keyOut.Close()
okln("wrote key.pem")
okln("Created TLS key file")
}

53
walk.go
View File

@@ -25,6 +25,13 @@ func (f File) Dump() {
fmt.Println()
}
func (f File) Size() (bytes int) {
for _, b := range f.Blocks {
bytes += int(b.Length)
}
return
}
func isTempName(name string) bool {
return strings.HasPrefix(path.Base(name), ".syncthing.")
}
@@ -38,7 +45,8 @@ func tempName(name string, modified int64) string {
func genWalker(base string, res *[]File, model *Model) filepath.WalkFunc {
return func(p string, info os.FileInfo, err error) error {
if err != nil {
return err
warnln(err)
return nil
}
if isTempName(p) {
@@ -48,12 +56,14 @@ func genWalker(base string, res *[]File, model *Model) filepath.WalkFunc {
if info.Mode()&os.ModeType == 0 {
rn, err := filepath.Rel(base, p)
if err != nil {
return err
warnln(err)
return nil
}
fi, err := os.Stat(p)
if err != nil {
return err
warnln(err)
return nil
}
modified := fi.ModTime().Unix()
@@ -62,18 +72,20 @@ func genWalker(base string, res *[]File, model *Model) filepath.WalkFunc {
// No change
*res = append(*res, hf)
} else {
if traceFile {
if opts.Debug.TraceFile {
debugf("FILE: Hash %q", p)
}
fd, err := os.Open(p)
if err != nil {
return err
warnln(err)
return nil
}
defer fd.Close()
blocks, err := Blocks(fd, BlockSize)
if err != nil {
return err
warnln(err)
return nil
}
f := File{
Name: rn,
@@ -89,13 +101,38 @@ func genWalker(base string, res *[]File, model *Model) filepath.WalkFunc {
}
}
func Walk(dir string, model *Model) []File {
func Walk(dir string, model *Model, followSymlinks bool) []File {
var files []File
fn := genWalker(dir, &files, model)
err := filepath.Walk(dir, fn)
if err != nil {
warnln(err)
}
if !opts.NoSymlinks {
d, err := os.Open(dir)
if err != nil {
warnln(err)
return files
}
defer d.Close()
fis, err := d.Readdir(-1)
if err != nil {
warnln(err)
return files
}
for _, fi := range fis {
if fi.Mode()&os.ModeSymlink != 0 {
err := filepath.Walk(path.Join(dir, fi.Name())+"/", fn)
if err != nil {
warnln(err)
}
}
}
}
return files
}
@@ -104,7 +141,7 @@ func cleanTempFile(path string, info os.FileInfo, err error) error {
return err
}
if info.Mode()&os.ModeType == 0 && isTempName(path) {
if traceFile {
if opts.Debug.TraceFile {
debugf("FILE: Remove %q", path)
}
os.Remove(path)

View File

@@ -18,7 +18,7 @@ var testdata = []struct {
func TestWalk(t *testing.T) {
m := new(Model)
files := Walk("testdata", m)
files := Walk("testdata", m, false)
if l1, l2 := len(files), len(testdata); l1 != l2 {
t.Fatalf("Incorrect number of walked files %d != %d", l1, l2)