Compare commits

...

18 Commits

Author SHA1 Message Date
Eva Ho
4bd2afd22f address comment 2026-01-22 14:25:52 -05:00
Eva Ho
af05f2e58e fix test 2026-01-22 13:26:59 -05:00
Eva Ho
d4bf4603ec updater: add more test converage to cover auto updater 2026-01-22 13:26:59 -05:00
Eva Ho
c89f69894f fix tests 2026-01-22 13:26:59 -05:00
Eva Ho
fca5ee8526 fix tests 2026-01-22 13:26:59 -05:00
Eva Ho
e264a88b15 fix test 2026-01-22 13:26:59 -05:00
Eva Ho
bdc17714f9 address comments 2026-01-22 13:26:59 -05:00
Eva Ho
f1095cd043 address comment 2026-01-22 13:26:59 -05:00
Eva Ho
0fc4f9fa4e address comment 2026-01-22 13:26:59 -05:00
Eva Ho
58970ae7c5 fix: gofmt formatting in updater_test.go 2026-01-22 13:26:59 -05:00
Eva Ho
b1f8f3e79d fix test 2026-01-22 13:26:59 -05:00
Eva Ho
9f1baf621a fix format 2026-01-22 13:26:59 -05:00
Eva Ho
016b423639 fix test 2026-01-22 13:26:59 -05:00
Eva Ho
87b4ec129a fix test 2026-01-22 13:26:59 -05:00
Eva Ho
e6b377928b clean up 2026-01-22 13:26:59 -05:00
Eva Ho
a54fd25355 fix behaviour when switching between enabled and disabled 2026-01-22 13:26:59 -05:00
Eva Ho
802aa06680 fix test 2026-01-22 13:26:59 -05:00
Eva Ho
abdb778a1d app: add upgrade configuration to settings page 2026-01-22 13:26:59 -05:00
12 changed files with 613 additions and 59 deletions

View File

@@ -253,6 +253,8 @@ func main() {
done <- osrv.Run(octx)
}()
upd := &updater.Updater{Store: st}
uiServer := ui.Server{
Token: token,
Restart: func() {
@@ -267,6 +269,10 @@ func main() {
ToolRegistry: toolRegistry,
Dev: devMode,
Logger: slog.Default(),
Updater: upd,
UpdateAvailableFunc: func() {
UpdateAvailable("")
},
}
srv := &http.Server{
@@ -284,8 +290,13 @@ func main() {
slog.Debug("background desktop server done")
}()
updater := &updater.Updater{Store: st}
updater.StartBackgroundUpdaterChecker(ctx, UpdateAvailable)
upd.StartBackgroundUpdaterChecker(ctx, UpdateAvailable)
// Check for pending updates on startup (show tray notification if update is ready)
if updater.IsUpdatePending() {
slog.Debug("update pending on startup, showing tray notification")
UpdateAvailable("")
}
hasCompletedFirstRun, err := st.HasCompletedFirstRun()
if err != nil {
@@ -348,6 +359,18 @@ func startHiddenTasks() {
// CLI triggered app startup use-case
slog.Info("deferring pending update for fast startup")
} else {
// Check if auto-update is enabled before automatically upgrading
st := &store.Store{}
settings, err := st.Settings()
if err != nil {
slog.Warn("failed to load settings for upgrade check", "error", err)
} else if !settings.AutoUpdateEnabled {
slog.Info("auto-update disabled, skipping automatic upgrade at startup")
// Still show tray notification so user knows update is ready
UpdateAvailable("")
return
}
if err := updater.DoUpgradeAtStartup(); err != nil {
slog.Info("unable to perform upgrade at startup", "error", err)
// Make sure the restart to upgrade menu shows so we can attempt an interactive upgrade to get authorization

View File

@@ -9,12 +9,12 @@ import (
"strings"
"time"
sqlite3 "github.com/mattn/go-sqlite3"
_ "github.com/mattn/go-sqlite3"
)
// currentSchemaVersion defines the current database schema version.
// Increment this when making schema changes that require migrations.
const currentSchemaVersion = 12
const currentSchemaVersion = 13
// database wraps the SQLite connection.
// SQLite handles its own locking for concurrent access:
@@ -85,6 +85,7 @@ func (db *database) init() error {
think_enabled BOOLEAN NOT NULL DEFAULT 0,
think_level TEXT NOT NULL DEFAULT '',
remote TEXT NOT NULL DEFAULT '', -- deprecated
auto_update_enabled BOOLEAN NOT NULL DEFAULT 1,
schema_version INTEGER NOT NULL DEFAULT %d
);
@@ -244,6 +245,12 @@ func (db *database) migrate() error {
return fmt.Errorf("migrate v11 to v12: %w", err)
}
version = 12
case 12:
// add auto_update_enabled column to settings table
if err := db.migrateV12ToV13(); err != nil {
return fmt.Errorf("migrate v12 to v13: %w", err)
}
version = 13
default:
// If we have a version we don't recognize, just set it to current
// This might happen during development
@@ -452,6 +459,21 @@ func (db *database) migrateV11ToV12() error {
return nil
}
// migrateV12ToV13 adds the auto_update_enabled column to the settings table
func (db *database) migrateV12ToV13() error {
_, err := db.conn.Exec(`ALTER TABLE settings ADD COLUMN auto_update_enabled BOOLEAN NOT NULL DEFAULT 1`)
if err != nil && !duplicateColumnError(err) {
return fmt.Errorf("add auto_update_enabled column: %w", err)
}
_, err = db.conn.Exec(`UPDATE settings SET schema_version = 13`)
if err != nil {
return fmt.Errorf("update schema version: %w", err)
}
return nil
}
// cleanupOrphanedData removes orphaned records that may exist due to the foreign key bug
func (db *database) cleanupOrphanedData() error {
_, err := db.conn.Exec(`
@@ -482,19 +504,11 @@ func (db *database) cleanupOrphanedData() error {
}
func duplicateColumnError(err error) bool {
if sqlite3Err, ok := err.(sqlite3.Error); ok {
return sqlite3Err.Code == sqlite3.ErrError &&
strings.Contains(sqlite3Err.Error(), "duplicate column name")
}
return false
return err != nil && strings.Contains(err.Error(), "duplicate column name")
}
func columnNotExists(err error) bool {
if sqlite3Err, ok := err.(sqlite3.Error); ok {
return sqlite3Err.Code == sqlite3.ErrError &&
strings.Contains(sqlite3Err.Error(), "no such column")
}
return false
return err != nil && strings.Contains(err.Error(), "no such column")
}
func (db *database) getAllChats() ([]Chat, error) {
@@ -1108,9 +1122,9 @@ func (db *database) getSettings() (Settings, error) {
var s Settings
err := db.conn.QueryRow(`
SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, airplane_mode, turbo_enabled, websearch_enabled, selected_model, sidebar_open, think_enabled, think_level
SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, airplane_mode, turbo_enabled, websearch_enabled, selected_model, sidebar_open, think_enabled, think_level, auto_update_enabled
FROM settings
`).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.AirplaneMode, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.ThinkEnabled, &s.ThinkLevel)
`).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.AirplaneMode, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.ThinkEnabled, &s.ThinkLevel, &s.AutoUpdateEnabled)
if err != nil {
return Settings{}, fmt.Errorf("get settings: %w", err)
}
@@ -1121,8 +1135,8 @@ func (db *database) getSettings() (Settings, error) {
func (db *database) setSettings(s Settings) error {
_, err := db.conn.Exec(`
UPDATE settings
SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, airplane_mode = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, think_enabled = ?, think_level = ?
`, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.AirplaneMode, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, s.ThinkEnabled, s.ThinkLevel)
SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, airplane_mode = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, think_enabled = ?, think_level = ?, auto_update_enabled = ?
`, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.AirplaneMode, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, s.ThinkEnabled, s.ThinkLevel, s.AutoUpdateEnabled)
if err != nil {
return fmt.Errorf("set settings: %w", err)
}

View File

@@ -169,6 +169,9 @@ type Settings struct {
// SidebarOpen indicates if the chat sidebar is open
SidebarOpen bool
// AutoUpdateEnabled indicates if automatic updates should be downloaded
AutoUpdateEnabled bool
}
type Store struct {

View File

@@ -413,6 +413,7 @@ export class Settings {
ThinkLevel: string;
SelectedModel: string;
SidebarOpen: boolean;
AutoUpdateEnabled: boolean;
constructor(source: any = {}) {
if ('string' === typeof source) source = JSON.parse(source);
@@ -431,6 +432,7 @@ export class Settings {
this.ThinkLevel = source["ThinkLevel"];
this.SelectedModel = source["SelectedModel"];
this.SidebarOpen = source["SidebarOpen"];
this.AutoUpdateEnabled = source["AutoUpdateEnabled"];
}
}
export class SettingsResponse {
@@ -467,6 +469,46 @@ export class HealthResponse {
this.healthy = source["healthy"];
}
}
export class UpdateInfo {
currentVersion: string;
availableVersion: string;
updateAvailable: boolean;
updateDownloaded: boolean;
constructor(source: any = {}) {
if ('string' === typeof source) source = JSON.parse(source);
this.currentVersion = source["currentVersion"];
this.availableVersion = source["availableVersion"];
this.updateAvailable = source["updateAvailable"];
this.updateDownloaded = source["updateDownloaded"];
}
}
export class UpdateCheckResponse {
updateInfo: UpdateInfo;
constructor(source: any = {}) {
if ('string' === typeof source) source = JSON.parse(source);
this.updateInfo = this.convertValues(source["updateInfo"], UpdateInfo);
}
convertValues(a: any, classs: any, asMap: boolean = false): any {
if (!a) {
return a;
}
if (Array.isArray(a)) {
return (a as any[]).map(elem => this.convertValues(elem, classs));
} else if ("object" === typeof a) {
if (asMap) {
for (const key of Object.keys(a)) {
a[key] = new classs(a[key]);
}
return a;
}
return new classs(a);
}
return a;
}
}
export class User {
id: string;
email: string;

View File

@@ -414,3 +414,54 @@ export async function fetchHealth(): Promise<boolean> {
return false;
}
}
export async function getCurrentVersion(): Promise<string> {
try {
const response = await fetch(`${API_BASE}/api/version`, {
method: "GET",
headers: {
"Content-Type": "application/json",
},
});
if (response.ok) {
const data = await response.json();
return data.version || "Unknown";
}
return "Unknown";
} catch (error) {
console.error("Error fetching version:", error);
return "Unknown";
}
}
export async function checkForUpdate(): Promise<{
currentVersion: string;
availableVersion: string;
updateAvailable: boolean;
updateDownloaded: boolean;
}> {
const response = await fetch(`${API_BASE}/api/v1/update/check`, {
method: "GET",
headers: {
"Content-Type": "application/json",
},
});
if (!response.ok) {
throw new Error("Failed to check for update");
}
const data = await response.json();
return data.updateInfo;
}
export async function installUpdate(): Promise<void> {
const response = await fetch(`${API_BASE}/api/v1/update/install`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
});
if (!response.ok) {
const error = await response.text();
throw new Error(error || "Failed to install update");
}
}

View File

@@ -14,12 +14,13 @@ import {
XMarkIcon,
CogIcon,
ArrowLeftIcon,
ArrowDownTrayIcon,
} from "@heroicons/react/20/solid";
import { Settings as SettingsType } from "@/gotypes";
import { useNavigate } from "@tanstack/react-router";
import { useUser } from "@/hooks/useUser";
import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query";
import { getSettings, updateSettings } from "@/api";
import { getSettings, updateSettings, checkForUpdate } from "@/api";
function AnimatedDots() {
return (
@@ -39,6 +40,12 @@ export default function Settings() {
const queryClient = useQueryClient();
const [showSaved, setShowSaved] = useState(false);
const [restartMessage, setRestartMessage] = useState(false);
const [updateInfo, setUpdateInfo] = useState<{
currentVersion: string;
availableVersion: string;
updateAvailable: boolean;
updateDownloaded: boolean;
} | null>(null);
const {
user,
isAuthenticated,
@@ -76,6 +83,10 @@ export default function Settings() {
useEffect(() => {
refetchUser();
// Check for updates on mount
checkForUpdate()
.then(setUpdateInfo)
.catch((err) => console.error("Error checking for update:", err));
}, []); // eslint-disable-line react-hooks/exhaustive-deps
useEffect(() => {
@@ -344,6 +355,58 @@ export default function Settings() {
{/* Local Configuration */}
<div className="relative overflow-hidden rounded-xl bg-white dark:bg-neutral-800">
<div className="space-y-4 p-4">
{/* Auto Update */}
<Field>
<div className="flex items-start justify-between gap-4">
<div className="flex items-start space-x-3 flex-1">
<ArrowDownTrayIcon className="mt-1 h-5 w-5 flex-shrink-0 text-black dark:text-neutral-100" />
<div className="flex-1">
<Label>Auto-download updates</Label>
<Description>
{settings.AutoUpdateEnabled ? (
<>
Automatically downloads updates when available.
<div className="mt-2 text-xs text-zinc-600 dark:text-zinc-400">
Current version: {updateInfo?.currentVersion || "Loading..."}
</div>
</>
) : (
<>
Manually download updates.
<div className="mt-3 p-3 bg-zinc-50 dark:bg-zinc-900 rounded-lg border border-zinc-200 dark:border-zinc-800">
<div className="space-y-2 text-sm">
<div className="flex justify-between">
<span className="text-zinc-600 dark:text-zinc-400">Current version: {updateInfo?.currentVersion || "Loading..."}</span>
</div>
{updateInfo?.availableVersion && (
<div className="flex justify-between">
<span className="text-zinc-600 dark:text-zinc-400">Available version: {updateInfo?.availableVersion}</span>
</div>
)}
</div>
<a
href="https://ollama.com/download"
target="_blank"
rel="noopener noreferrer"
className="mt-3 inline-block text-sm text-neutral-600 dark:text-neutral-400 underline"
>
Download new version
</a>
</div>
</>
)}
</Description>
</div>
</div>
<div className="flex-shrink-0">
<Switch
checked={settings.AutoUpdateEnabled}
onChange={(checked) => handleChange("AutoUpdateEnabled", checked)}
/>
</div>
</div>
</Field>
{/* Expose Ollama */}
<Field>
<div className="flex items-start justify-between gap-4">

View File

@@ -100,6 +100,17 @@ type HealthResponse struct {
Healthy bool `json:"healthy"`
}
type UpdateInfo struct {
CurrentVersion string `json:"currentVersion"`
AvailableVersion string `json:"availableVersion"`
UpdateAvailable bool `json:"updateAvailable"`
UpdateDownloaded bool `json:"updateDownloaded"`
}
type UpdateCheckResponse struct {
UpdateInfo UpdateInfo `json:"updateInfo"`
}
type User struct {
ID string `json:"id"`
Email string `json:"email"`

View File

@@ -28,6 +28,7 @@ import (
"github.com/ollama/ollama/app/tools"
"github.com/ollama/ollama/app/types/not"
"github.com/ollama/ollama/app/ui/responses"
"github.com/ollama/ollama/app/updater"
"github.com/ollama/ollama/app/version"
ollamaAuth "github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig"
@@ -106,6 +107,10 @@ type Server struct {
// Dev is true if the server is running in development mode
Dev bool
// Updater for checking and downloading updates
Updater *updater.Updater
UpdateAvailableFunc func()
}
func (s *Server) log() *slog.Logger {
@@ -284,6 +289,8 @@ func (s *Server) Handler() http.Handler {
mux.Handle("POST /api/v1/model/upstream", handle(s.modelUpstream))
mux.Handle("GET /api/v1/settings", handle(s.getSettings))
mux.Handle("POST /api/v1/settings", handle(s.settings))
mux.Handle("GET /api/v1/update/check", handle(s.checkForUpdate))
mux.Handle("POST /api/v1/update/install", handle(s.installUpdate))
// Ollama proxy endpoints
ollamaProxy := s.ollamaProxy()
@@ -1448,6 +1455,24 @@ func (s *Server) settings(w http.ResponseWriter, r *http.Request) error {
return fmt.Errorf("failed to save settings: %w", err)
}
// Handle auto-update toggle changes
if old.AutoUpdateEnabled != settings.AutoUpdateEnabled {
if !settings.AutoUpdateEnabled {
// Auto-update disabled: cancel any ongoing download
if s.Updater != nil {
s.Updater.CancelOngoingDownload()
}
} else {
// Auto-update re-enabled: show notification if update is already staged, or trigger immediate check
if (updater.IsUpdatePending() || updater.UpdateDownloaded) && s.UpdateAvailableFunc != nil {
s.UpdateAvailableFunc()
} else if s.Updater != nil {
// Trigger the background checker to run immediately
s.Updater.TriggerImmediateCheck()
}
}
}
if old.ContextLength != settings.ContextLength ||
old.Models != settings.Models ||
old.Expose != settings.Expose {
@@ -1524,6 +1549,73 @@ func (s *Server) modelUpstream(w http.ResponseWriter, r *http.Request) error {
return json.NewEncoder(w).Encode(response)
}
func (s *Server) checkForUpdate(w http.ResponseWriter, r *http.Request) error {
currentVersion := version.Version
if s.Updater == nil {
return fmt.Errorf("updater not available")
}
updateAvailable, updateVersion, err := s.Updater.CheckForUpdate(r.Context())
if err != nil {
s.log().Warn("failed to check for update", "error", err)
// Don't return error, just log it and continue with no update available
}
response := responses.UpdateCheckResponse{
UpdateInfo: responses.UpdateInfo{
CurrentVersion: currentVersion,
AvailableVersion: updateVersion,
UpdateAvailable: updateAvailable,
UpdateDownloaded: updater.UpdateDownloaded,
},
}
w.Header().Set("Content-Type", "application/json")
return json.NewEncoder(w).Encode(response)
}
func (s *Server) installUpdate(w http.ResponseWriter, r *http.Request) error {
if r.Method != "POST" {
return fmt.Errorf("method not allowed")
}
if s.Updater == nil {
s.log().Error("install failed: updater not available")
return fmt.Errorf("updater not available")
}
// Check if update is downloaded
if !updater.UpdateDownloaded {
s.log().Error("install failed: no update downloaded")
return fmt.Errorf("no update downloaded")
}
// Send response before restarting
response := map[string]any{
"success": true,
"message": "Installing update and restarting...",
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(response); err != nil {
return err
}
// Give the response time to be sent
time.Sleep(500 * time.Millisecond)
// Trigger the upgrade and restart
go func() {
time.Sleep(500 * time.Millisecond)
if err := s.Updater.InstallAndRestart(); err != nil {
s.log().Error("failed to install update", "error", err)
}
}()
return nil
}
func userAgent() string {
buildinfo, _ := debug.ReadBuildInfo()

View File

@@ -19,6 +19,7 @@ import (
"runtime"
"strconv"
"strings"
"sync"
"time"
"github.com/ollama/ollama/app/store"
@@ -58,7 +59,8 @@ func (u *Updater) checkForUpdate(ctx context.Context) (bool, UpdateResponse) {
query := requestURL.Query()
query.Add("os", runtime.GOOS)
query.Add("arch", runtime.GOARCH)
query.Add("version", version.Version)
currentVersion := version.Version
query.Add("version", currentVersion)
query.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
// The original macOS app used to use the device ID
@@ -131,15 +133,27 @@ func (u *Updater) checkForUpdate(ctx context.Context) (bool, UpdateResponse) {
}
func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateResponse) error {
// Create a cancellable context for this download
downloadCtx, cancel := context.WithCancel(ctx)
u.cancelDownloadLock.Lock()
u.cancelDownload = cancel
u.cancelDownloadLock.Unlock()
defer func() {
u.cancelDownloadLock.Lock()
u.cancelDownload = nil
u.cancelDownloadLock.Unlock()
cancel()
}()
// Do a head first to check etag info
req, err := http.NewRequestWithContext(ctx, http.MethodHead, updateResp.UpdateURL, nil)
req, err := http.NewRequestWithContext(downloadCtx, http.MethodHead, updateResp.UpdateURL, nil)
if err != nil {
return err
}
// In case of slow downloads, continue the update check in the background
bgctx, cancel := context.WithCancel(ctx)
defer cancel()
bgctx, bgcancel := context.WithCancel(downloadCtx)
defer bgcancel()
go func() {
for {
select {
@@ -176,6 +190,7 @@ func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateRespo
_, err = os.Stat(stageFilename)
if err == nil {
slog.Info("update already downloaded", "bundle", stageFilename)
UpdateDownloaded = true
return nil
}
@@ -244,34 +259,99 @@ func cleanupOldDownloads(stageDir string) {
}
type Updater struct {
Store *store.Store
Store *store.Store
cancelDownload context.CancelFunc
cancelDownloadLock sync.Mutex
checkNow chan struct{}
}
// CancelOngoingDownload cancels any currently running download
func (u *Updater) CancelOngoingDownload() {
u.cancelDownloadLock.Lock()
defer u.cancelDownloadLock.Unlock()
if u.cancelDownload != nil {
slog.Info("cancelling ongoing update download")
u.cancelDownload()
u.cancelDownload = nil
}
}
// TriggerImmediateCheck signals the background checker to check for updates immediately
func (u *Updater) TriggerImmediateCheck() {
if u.checkNow != nil {
select {
case u.checkNow <- struct{}{}:
default:
// Check already pending, no need to queue another
}
}
}
func (u *Updater) StartBackgroundUpdaterChecker(ctx context.Context, cb func(string) error) {
u.checkNow = make(chan struct{}, 1)
go func() {
// Don't blast an update message immediately after startup
time.Sleep(UpdateCheckInitialDelay)
slog.Info("beginning update checker", "interval", UpdateCheckInterval)
ticker := time.NewTicker(UpdateCheckInterval)
defer ticker.Stop()
for {
available, resp := u.checkForUpdate(ctx)
if available {
err := u.DownloadNewRelease(ctx, resp)
if err != nil {
slog.Error(fmt.Sprintf("failed to download new release: %s", err))
} else {
err = cb(resp.UpdateVersion)
if err != nil {
slog.Warn(fmt.Sprintf("failed to register update available with tray: %s", err))
}
}
}
select {
case <-ctx.Done():
slog.Debug("stopping background update checker")
return
default:
time.Sleep(UpdateCheckInterval)
case <-u.checkNow:
// Immediate check triggered
case <-ticker.C:
// Regular interval check
}
// Always check for updates
available, resp := u.checkForUpdate(ctx)
if !available {
continue
}
// Update is available - check if auto-update is enabled for downloading
settings, err := u.Store.Settings()
if err != nil {
slog.Error("failed to load settings", "error", err)
continue
}
if !settings.AutoUpdateEnabled {
// Auto-update disabled - don't download, just log
slog.Debug("update available but auto-update disabled", "version", resp.UpdateVersion)
continue
}
// Auto-update is enabled - download
err = u.DownloadNewRelease(ctx, resp)
if err != nil {
slog.Error("failed to download new release", "error", err)
continue
}
// Download successful - show tray notification (regardless of toggle state)
err = cb(resp.UpdateVersion)
if err != nil {
slog.Warn("failed to register update available with tray", "error", err)
}
}
}()
}
func (u *Updater) CheckForUpdate(ctx context.Context) (bool, string, error) {
available, resp := u.checkForUpdate(ctx)
return available, resp.UpdateVersion, nil
}
func (u *Updater) InstallAndRestart() error {
if !UpdateDownloaded {
return fmt.Errorf("no update downloaded")
}
slog.Info("installing update and restarting")
return DoUpgrade(true)
}

View File

@@ -11,6 +11,7 @@ import (
"log/slog"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
@@ -85,7 +86,17 @@ func TestBackgoundChecker(t *testing.T) {
UpdateCheckURLBase = server.URL + "/update.json"
updater := &Updater{Store: &store.Store{}}
defer updater.Store.Close() // Ensure database is closed
defer updater.Store.Close()
settings, err := updater.Store.Settings()
if err != nil {
t.Fatal(err)
}
settings.AutoUpdateEnabled = true
if err := updater.Store.SetSettings(settings); err != nil {
t.Fatal(err)
}
updater.StartBackgroundUpdaterChecker(ctx, cb)
select {
case <-stallTimer.C:
@@ -99,3 +110,187 @@ func TestBackgoundChecker(t *testing.T) {
}
}
}
func TestAutoUpdateDisabledSkipsDownload(t *testing.T) {
UpdateStageDir = t.TempDir()
var downloadAttempted atomic.Bool
done := make(chan struct{})
ctx, cancel := context.WithCancel(t.Context())
defer cancel()
UpdateCheckInitialDelay = 5 * time.Millisecond
UpdateCheckInterval = 5 * time.Millisecond
VerifyDownload = func() error {
return nil
}
var server *httptest.Server
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/update.json" {
w.Write([]byte(
fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`,
server.URL+"/9.9.9/"+Installer)))
} else if r.URL.Path == "/9.9.9/"+Installer {
downloadAttempted.Store(true)
buf := &bytes.Buffer{}
zw := zip.NewWriter(buf)
zw.Close()
io.Copy(w, buf)
}
}))
defer server.Close()
UpdateCheckURLBase = server.URL + "/update.json"
updater := &Updater{Store: &store.Store{}}
defer updater.Store.Close()
// Ensure auto-update is disabled
settings, err := updater.Store.Settings()
if err != nil {
t.Fatal(err)
}
settings.AutoUpdateEnabled = false
if err := updater.Store.SetSettings(settings); err != nil {
t.Fatal(err)
}
cb := func(ver string) error {
t.Fatal("callback should not be called when auto-update is disabled")
return nil
}
updater.StartBackgroundUpdaterChecker(ctx, cb)
// Wait enough time for multiple check cycles
time.Sleep(50 * time.Millisecond)
close(done)
if downloadAttempted.Load() {
t.Fatal("download should not be attempted when auto-update is disabled")
}
}
func TestCancelOngoingDownload(t *testing.T) {
UpdateStageDir = t.TempDir()
downloadStarted := make(chan struct{})
downloadCancelled := make(chan struct{})
ctx := t.Context()
VerifyDownload = func() error {
return nil
}
var server *httptest.Server
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/update.json" {
w.Write([]byte(
fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`,
server.URL+"/9.9.9/"+Installer)))
} else if r.URL.Path == "/9.9.9/"+Installer {
if r.Method == http.MethodHead {
w.Header().Set("Content-Length", "1000000")
w.WriteHeader(http.StatusOK)
return
}
// Signal that download has started
close(downloadStarted)
// Wait for cancellation or timeout
select {
case <-r.Context().Done():
close(downloadCancelled)
return
case <-time.After(5 * time.Second):
t.Error("download was not cancelled in time")
}
}
}))
defer server.Close()
UpdateCheckURLBase = server.URL + "/update.json"
updater := &Updater{Store: &store.Store{}}
defer updater.Store.Close()
_, resp := updater.checkForUpdate(ctx)
// Start download in goroutine
go func() {
_ = updater.DownloadNewRelease(ctx, resp)
}()
// Wait for download to start
select {
case <-downloadStarted:
case <-time.After(2 * time.Second):
t.Fatal("download did not start in time")
}
// Cancel the download
updater.CancelOngoingDownload()
// Verify cancellation was received
select {
case <-downloadCancelled:
// Success
case <-time.After(2 * time.Second):
t.Fatal("download cancellation was not received by server")
}
}
func TestTriggerImmediateCheck(t *testing.T) {
UpdateStageDir = t.TempDir()
checkCount := atomic.Int32{}
checkDone := make(chan struct{}, 10)
ctx, cancel := context.WithCancel(t.Context())
defer cancel()
// Set a very long interval so only TriggerImmediateCheck causes checks
UpdateCheckInitialDelay = 1 * time.Millisecond
UpdateCheckInterval = 1 * time.Hour
VerifyDownload = func() error {
return nil
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/update.json" {
checkCount.Add(1)
select {
case checkDone <- struct{}{}:
default:
}
// Return no update available
w.WriteHeader(http.StatusNoContent)
}
}))
defer server.Close()
UpdateCheckURLBase = server.URL + "/update.json"
updater := &Updater{Store: &store.Store{}}
defer updater.Store.Close()
cb := func(ver string) error {
return nil
}
updater.StartBackgroundUpdaterChecker(ctx, cb)
// Wait for goroutine to start and pass initial delay
time.Sleep(10 * time.Millisecond)
// With 1 hour interval, no check should have happened yet
initialCount := checkCount.Load()
// Trigger immediate check
updater.TriggerImmediateCheck()
// Wait for the triggered check
select {
case <-checkDone:
case <-time.After(2 * time.Second):
t.Fatal("triggered check did not happen")
}
finalCount := checkCount.Load()
if finalCount <= initialCount {
t.Fatalf("TriggerImmediateCheck did not cause additional check: initial=%d, final=%d", initialCount, finalCount)
}
}

View File

@@ -369,25 +369,6 @@ func (t *winTray) addSeparatorMenuItem(menuItemId, parentId uint32) error {
return nil
}
// func (t *winTray) hideMenuItem(menuItemId, parentId uint32) error {
// const ERROR_SUCCESS syscall.Errno = 0
// t.muMenus.RLock()
// menu := uintptr(t.menus[parentId])
// t.muMenus.RUnlock()
// res, _, err := pRemoveMenu.Call(
// menu,
// uintptr(menuItemId),
// MF_BYCOMMAND,
// )
// if res == 0 && err.(syscall.Errno) != ERROR_SUCCESS {
// return err
// }
// t.delFromVisibleItems(parentId, menuItemId)
// return nil
// }
func (t *winTray) showMenu() error {
p := point{}
boolRet, _, err := pGetCursorPos.Call(uintptr(unsafe.Pointer(&p)))

View File

@@ -51,7 +51,6 @@ const (
IMAGE_ICON = 1 // Loads an icon
LR_DEFAULTSIZE = 0x00000040 // Loads default-size icon for windows(SM_CXICON x SM_CYICON) if cx, cy are set to zero
LR_LOADFROMFILE = 0x00000010 // Loads the stand-alone image from the file
MF_BYCOMMAND = 0x00000000
MFS_DISABLED = 0x00000003
MFT_SEPARATOR = 0x00000800
MFT_STRING = 0x00000000