diff --git a/backend/go/stablediffusion-ggml/Makefile b/backend/go/stablediffusion-ggml/Makefile index 0d5361e55..dadced715 100644 --- a/backend/go/stablediffusion-ggml/Makefile +++ b/backend/go/stablediffusion-ggml/Makefile @@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1) # stablediffusion.cpp (ggml) STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp -STABLEDIFFUSION_GGML_VERSION?=f16a110f8776398ef23a2a6b7b57522c2471637a +STABLEDIFFUSION_GGML_VERSION?=1d6cb0f8c33ddadf1bff8aff40ec2e5b1ccb4940 CMAKE_ARGS+=-DGGML_MAX_NAME=128 diff --git a/backend/go/stablediffusion-ggml/gosd.cpp b/backend/go/stablediffusion-ggml/gosd.cpp index 47b519a5e..f010f73a6 100644 --- a/backend/go/stablediffusion-ggml/gosd.cpp +++ b/backend/go/stablediffusion-ggml/gosd.cpp @@ -27,107 +27,7 @@ #include #include -// Names of the sampler method, same order as enum sample_method in stable-diffusion.h -const char* sample_method_str[] = { - "euler", - "euler_a", - "heun", - "dpm2", - "dpm++2s_a", - "dpm++2m", - "dpm++2mv2", - "ipndm", - "ipndm_v", - "lcm", - "ddim_trailing", - "tcd", - "res_multistep", - "res_2s", -}; -static_assert(std::size(sample_method_str) == SAMPLE_METHOD_COUNT, "sample method mismatch"); - -// Names of the sigma schedule overrides, same order as sample_schedule in stable-diffusion.h -const char* schedulers[] = { - "discrete", - "karras", - "exponential", - "ays", - "gits", - "sgm_uniform", - "simple", - "smoothstep", - "kl_optimal", - "lcm", - "bong_tangent", -}; - -static_assert(std::size(schedulers) == SCHEDULER_COUNT, "schedulers mismatch"); - -// New enum string arrays -const char* rng_type_str[] = { - "std_default", - "cuda", - "cpu", -}; -static_assert(std::size(rng_type_str) == RNG_TYPE_COUNT, "rng type mismatch"); - -const char* prediction_str[] = { - "epsilon", - "v", - "edm_v", - "flow", - "flux_flow", - "flux2_flow", -}; -static_assert(std::size(prediction_str) == PREDICTION_COUNT, "prediction mismatch"); - -const char* lora_apply_mode_str[] = { - "auto", - "immediately", - "at_runtime", -}; -static_assert(std::size(lora_apply_mode_str) == LORA_APPLY_MODE_COUNT, "lora apply mode mismatch"); - -constexpr const char* sd_type_str[] = { - "f32", // 0 - "f16", // 1 - "q4_0", // 2 - "q4_1", // 3 - nullptr, // 4 - nullptr, // 5 - "q5_0", // 6 - "q5_1", // 7 - "q8_0", // 8 - "q8_1", // 9 - "q2_k", // 10 - "q3_k", // 11 - "q4_k", // 12 - "q5_k", // 13 - "q6_k", // 14 - "q8_k", // 15 - "iq2_xxs", // 16 - "iq2_xs", // 17 - "iq3_xxs", // 18 - "iq1_s", // 19 - "iq4_nl", // 20 - "iq3_s", // 21 - "iq2_s", // 22 - "iq4_xs", // 23 - "i8", // 24 - "i16", // 25 - "i32", // 26 - "i64", // 27 - "f64", // 28 - "iq1_m", // 29 - "bf16", // 30 - nullptr, nullptr, nullptr, // 31-33 - "tq1_0", // 34 - "tq2_0", // 35 - nullptr, nullptr, nullptr, // 36-38 - "mxfp4" // 39 -}; -static_assert(std::size(sd_type_str) == SD_TYPE_COUNT, "sd type mismatch"); sd_ctx_params_t ctx_params; sd_ctx_t* sd_c; @@ -596,75 +496,45 @@ int load_model(const char *model, char *model_path, char* options[], int threads if (!strcmp(optname, "flow_shift")) flow_shift = atof(optval); if (!strcmp(optname, "rng_type")) { - int found = -1; - for (int m = 0; m < RNG_TYPE_COUNT; m++) { - if (!strcmp(optval, rng_type_str[m])) { - found = m; - break; - } - } - if (found != -1) { - rng_type = (rng_type_t)found; + rng_type_t parsed = str_to_rng_type(optval); + if (parsed != RNG_TYPE_COUNT) { + rng_type = parsed; fprintf(stderr, "Found rng_type: %s\n", optval); } else { fprintf(stderr, "Invalid rng_type: %s, using default\n", optval); } } if (!strcmp(optname, "sampler_rng_type")) { - int found = -1; - for (int m = 0; m < RNG_TYPE_COUNT; m++) { - if (!strcmp(optval, rng_type_str[m])) { - found = m; - break; - } - } - if (found != -1) { - sampler_rng_type = (rng_type_t)found; + rng_type_t parsed = str_to_rng_type(optval); + if (parsed != RNG_TYPE_COUNT) { + sampler_rng_type = parsed; fprintf(stderr, "Found sampler_rng_type: %s\n", optval); } else { fprintf(stderr, "Invalid sampler_rng_type: %s, using default\n", optval); } } if (!strcmp(optname, "prediction")) { - int found = -1; - for (int m = 0; m < PREDICTION_COUNT; m++) { - if (!strcmp(optval, prediction_str[m])) { - found = m; - break; - } - } - if (found != -1) { - prediction = (prediction_t)found; + prediction_t parsed = str_to_prediction(optval); + if (parsed != PREDICTION_COUNT) { + prediction = parsed; fprintf(stderr, "Found prediction: %s\n", optval); } else { fprintf(stderr, "Invalid prediction: %s, using default\n", optval); } } if (!strcmp(optname, "lora_apply_mode")) { - int found = -1; - for (int m = 0; m < LORA_APPLY_MODE_COUNT; m++) { - if (!strcmp(optval, lora_apply_mode_str[m])) { - found = m; - break; - } - } - if (found != -1) { - lora_apply_mode = (lora_apply_mode_t)found; + lora_apply_mode_t parsed = str_to_lora_apply_mode(optval); + if (parsed != LORA_APPLY_MODE_COUNT) { + lora_apply_mode = parsed; fprintf(stderr, "Found lora_apply_mode: %s\n", optval); } else { fprintf(stderr, "Invalid lora_apply_mode: %s, using default\n", optval); } } if (!strcmp(optname, "wtype")) { - int found = -1; - for (int m = 0; m < SD_TYPE_COUNT; m++) { - if (sd_type_str[m] && !strcmp(optval, sd_type_str[m])) { - found = m; - break; - } - } - if (found != -1) { - wtype = (sd_type_t)found; + sd_type_t parsed = str_to_sd_type(optval); + if (parsed != SD_TYPE_COUNT) { + wtype = parsed; fprintf(stderr, "Found wtype: %s\n", optval); } else { fprintf(stderr, "Invalid wtype: %s, using default\n", optval); @@ -735,27 +605,25 @@ int load_model(const char *model, char *model_path, char* options[], int threads fprintf (stderr, "Created context: OK\n"); int sample_method_found = -1; - for (int m = 0; m < SAMPLE_METHOD_COUNT; m++) { - if (!strcmp(sampler, sample_method_str[m])) { - sample_method_found = m; - fprintf(stderr, "Found sampler: %s\n", sampler); - } + sample_method_t sm = str_to_sample_method(sampler); + if (sm != SAMPLE_METHOD_COUNT) { + sample_method_found = (int)sm; + fprintf(stderr, "Found sampler: %s\n", sampler); } if (sample_method_found == -1) { sample_method_found = sd_get_default_sample_method(sd_ctx); - fprintf(stderr, "Invalid sample method, using default: %s\n", sample_method_str[sample_method_found]); + fprintf(stderr, "Invalid sample method, using default: %s\n", sd_sample_method_name((sample_method_t)sample_method_found)); } sample_method = (sample_method_t)sample_method_found; - for (int d = 0; d < SCHEDULER_COUNT; d++) { - if (!strcmp(scheduler_str, schedulers[d])) { - scheduler = (scheduler_t)d; - fprintf (stderr, "Found scheduler: %s\n", scheduler_str); - } + scheduler_t sched = str_to_scheduler(scheduler_str); + if (sched != SCHEDULER_COUNT) { + scheduler = sched; + fprintf(stderr, "Found scheduler: %s\n", scheduler_str); } if (scheduler == SCHEDULER_COUNT) { - scheduler = sd_get_default_scheduler(sd_ctx, sample_method); - fprintf(stderr, "Invalid scheduler, using default: %s\n", schedulers[scheduler]); + scheduler = sd_get_default_scheduler(sd_ctx, sample_method); + fprintf(stderr, "Invalid scheduler, using default: %s\n", sd_scheduler_name(scheduler)); } sd_c = sd_ctx;