mirror of
https://github.com/kopia/kopia.git
synced 2026-01-22 13:27:57 -05:00
204 lines
6.2 KiB
Go
204 lines
6.2 KiB
Go
package cli
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/sha256"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"encoding/hex"
|
|
"encoding/pem"
|
|
"fmt"
|
|
"math/big"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"time"
|
|
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
const oneDay = 24 * time.Hour
|
|
|
|
var (
|
|
serverStartTLSGenerateCert = serverStartCommand.Flag("tls-generate-cert", "Generate TLS certificate").Hidden().Bool()
|
|
serverStartTLSCertFile = serverStartCommand.Flag("tls-cert-file", "TLS certificate PEM").String()
|
|
serverStartTLSKeyFile = serverStartCommand.Flag("tls-key-file", "TLS key PEM file").String()
|
|
serverStartTLSGenerateRSAKeySize = serverStartCommand.Flag("tls-generate-rsa-key-size", "TLS RSA Key size (bits)").Hidden().Default("4096").Int()
|
|
serverStartTLSGenerateCertValidDays = serverStartCommand.Flag("tls-generate-cert-valid-days", "How long should the TLS certificate be valid").Default("3650").Hidden().Int()
|
|
serverStartTLSGenerateCertNames = serverStartCommand.Flag("tls-generate-cert-name", "Host names/IP addresses to generate TLS certificate for").Default("127.0.0.1").Hidden().Strings()
|
|
)
|
|
|
|
func generateServerCertificate(ctx context.Context) (*x509.Certificate, *rsa.PrivateKey, error) {
|
|
log(ctx).Infof("generating new TLS certificate")
|
|
|
|
priv, err := rsa.GenerateKey(rand.Reader, *serverStartTLSGenerateRSAKeySize)
|
|
if err != nil {
|
|
return nil, nil, errors.Wrap(err, "unable to generate RSA key")
|
|
}
|
|
|
|
notBefore := time.Now()
|
|
notAfter := notBefore.Add(time.Duration(*serverStartTLSGenerateCertValidDays) * oneDay)
|
|
|
|
serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
|
if err != nil {
|
|
return nil, nil, errors.Wrap(err, "unable to generate serial number")
|
|
}
|
|
|
|
template := x509.Certificate{
|
|
SerialNumber: serialNumber,
|
|
Subject: pkix.Name{
|
|
Organization: []string{"Kopia"},
|
|
},
|
|
NotBefore: notBefore,
|
|
NotAfter: notAfter,
|
|
|
|
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
|
BasicConstraintsValid: true,
|
|
}
|
|
|
|
for _, n := range *serverStartTLSGenerateCertNames {
|
|
if ip := net.ParseIP(n); ip != nil {
|
|
log(ctx).Infof("adding alternative IP to certificate: %v", ip)
|
|
template.IPAddresses = append(template.IPAddresses, ip)
|
|
} else {
|
|
log(ctx).Infof("adding alternative DNS name to certificate: %v", n)
|
|
template.DNSNames = append(template.DNSNames, n)
|
|
}
|
|
}
|
|
|
|
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, priv.Public(), priv)
|
|
if err != nil {
|
|
return nil, nil, errors.Wrap(err, "failed to create certificate")
|
|
}
|
|
|
|
cert, err := x509.ParseCertificate(derBytes)
|
|
if err != nil {
|
|
return nil, nil, errors.Wrap(err, "failed to parse certificate")
|
|
}
|
|
|
|
return cert, priv, nil
|
|
}
|
|
|
|
func writePrivateKeyToFile(fname string, priv *rsa.PrivateKey) error {
|
|
f, err := os.Create(fname)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer f.Close() //nolint:errcheck
|
|
|
|
privBytes, err := x509.MarshalPKCS8PrivateKey(priv)
|
|
if err != nil {
|
|
return errors.Wrap(err, "Unable to marshal private key")
|
|
}
|
|
|
|
if err := pem.Encode(f, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil {
|
|
return errors.Wrap(err, "Failed to write data to")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func writeCertificateToFile(fname string, cert *x509.Certificate) error {
|
|
f, err := os.Create(fname)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer f.Close() //nolint:errcheck
|
|
|
|
if err := pem.Encode(f, &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}); err != nil {
|
|
return errors.Wrap(err, "Failed to write data")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func startServerWithOptionalTLS(ctx context.Context, httpServer *http.Server) error {
|
|
l, err := net.Listen("tcp", httpServer.Addr)
|
|
if err != nil {
|
|
return errors.Wrap(err, "listen error")
|
|
}
|
|
defer l.Close() //nolint:errcheck
|
|
|
|
httpServer.Addr = l.Addr().String()
|
|
|
|
return startServerWithOptionalTLSAndListener(ctx, httpServer, l)
|
|
}
|
|
|
|
func startServerWithOptionalTLSAndListener(ctx context.Context, httpServer *http.Server, listener net.Listener) error {
|
|
// generate and save to PEM files
|
|
if *serverStartTLSGenerateCert && *serverStartTLSCertFile != "" && *serverStartTLSKeyFile != "" {
|
|
if _, err := os.Stat(*serverStartTLSCertFile); err == nil {
|
|
return errors.Errorf("TLS cert file already exists: %q", *serverStartTLSCertFile)
|
|
}
|
|
|
|
if _, err := os.Stat(*serverStartTLSKeyFile); err == nil {
|
|
return errors.Errorf("TLS key file already exists: %q", *serverStartTLSKeyFile)
|
|
}
|
|
|
|
cert, key, err := generateServerCertificate(ctx)
|
|
if err != nil {
|
|
return errors.Wrap(err, "unable to generate server cert")
|
|
}
|
|
|
|
log(ctx).Infof("writing TLS certificate to %v", *serverStartTLSCertFile)
|
|
|
|
if err := writeCertificateToFile(*serverStartTLSCertFile, cert); err != nil {
|
|
return errors.Wrap(err, "unable to write private key")
|
|
}
|
|
|
|
log(ctx).Infof("writing TLS private key to %v", *serverStartTLSKeyFile)
|
|
|
|
if err := writePrivateKeyToFile(*serverStartTLSKeyFile, key); err != nil {
|
|
return errors.Wrap(err, "unable to write private key")
|
|
}
|
|
}
|
|
|
|
switch {
|
|
case *serverStartTLSCertFile != "" && *serverStartTLSKeyFile != "":
|
|
// PEM files provided
|
|
fmt.Fprintf(os.Stderr, "SERVER ADDRESS: https://%v\n", httpServer.Addr)
|
|
showServerUIPrompt()
|
|
|
|
return httpServer.ServeTLS(listener, *serverStartTLSCertFile, *serverStartTLSKeyFile)
|
|
|
|
case *serverStartTLSGenerateCert:
|
|
// PEM files not provided, generate in-memory TLS cert/key but don't persit.
|
|
cert, key, err := generateServerCertificate(ctx)
|
|
if err != nil {
|
|
return errors.Wrap(err, "unable to generate server cert")
|
|
}
|
|
|
|
httpServer.TLSConfig = &tls.Config{
|
|
Certificates: []tls.Certificate{
|
|
{
|
|
Certificate: [][]byte{cert.Raw},
|
|
PrivateKey: key,
|
|
},
|
|
},
|
|
}
|
|
|
|
fingerprint := sha256.Sum256(cert.Raw)
|
|
fmt.Fprintf(os.Stderr, "SERVER CERT SHA256: %v\n", hex.EncodeToString(fingerprint[:]))
|
|
fmt.Fprintf(os.Stderr, "SERVER ADDRESS: https://%v\n", httpServer.Addr)
|
|
showServerUIPrompt()
|
|
|
|
return httpServer.ServeTLS(listener, "", "")
|
|
|
|
default:
|
|
fmt.Fprintf(os.Stderr, "SERVER ADDRESS: http://%v\n", httpServer.Addr)
|
|
showServerUIPrompt()
|
|
|
|
return httpServer.Serve(listener)
|
|
}
|
|
}
|
|
|
|
func showServerUIPrompt() {
|
|
if *serverStartUI {
|
|
printStderr("\nOpen the address above in a web browser to use the UI.\n")
|
|
}
|
|
}
|