From 08fb29a0359c53c8b0034a1ffa9a6b677e616ae0 Mon Sep 17 00:00:00 2001 From: jp9000 Date: Wed, 6 Feb 2019 22:24:25 -0800 Subject: [PATCH] UI: Add Auth and OAuth classes Allows the ability to authenticate to a specific service. Typically via OAuth. --- UI/CMakeLists.txt | 11 ++ UI/auth-base.cpp | 71 ++++++++ UI/auth-base.hpp | 58 ++++++ UI/auth-oauth.cpp | 283 ++++++++++++++++++++++++++++++ UI/auth-oauth.hpp | 76 ++++++++ UI/data/locale/en-US.ini | 12 ++ UI/window-basic-main-outputs.cpp | 8 + UI/window-basic-main-profiles.cpp | 10 ++ UI/window-basic-main.cpp | 12 +- UI/window-basic-main.hpp | 6 + 10 files changed, 543 insertions(+), 4 deletions(-) create mode 100644 UI/auth-base.cpp create mode 100644 UI/auth-base.hpp create mode 100644 UI/auth-oauth.cpp create mode 100644 UI/auth-oauth.hpp diff --git a/UI/CMakeLists.txt b/UI/CMakeLists.txt index 04f8a973b..b88916aa0 100644 --- a/UI/CMakeLists.txt +++ b/UI/CMakeLists.txt @@ -105,6 +105,15 @@ elseif(UNIX) Qt5::X11Extras) endif() +if(BROWSER_AVAILABLE_INTERNAL) + list(APPEND obs_PLATFORM_SOURCES + auth-oauth.cpp + ) + list(APPEND obs_PLATFORM_HEADERS + auth-oauth.hpp + ) +endif() + set(obs_libffutil_SOURCES ../deps/libff/libff/ff-util.c ) @@ -150,6 +159,7 @@ set(obs_SOURCES window-log-reply.cpp window-projector.cpp window-remux.cpp + auth-base.cpp source-tree.cpp properties-view.cpp focus-list.cpp @@ -198,6 +208,7 @@ set(obs_HEADERS window-log-reply.hpp window-projector.hpp window-remux.hpp + auth-base.hpp source-tree.hpp properties-view.hpp properties-view.moc.hpp diff --git a/UI/auth-base.cpp b/UI/auth-base.cpp new file mode 100644 index 000000000..6a6bcfc01 --- /dev/null +++ b/UI/auth-base.cpp @@ -0,0 +1,71 @@ +#include "auth-base.hpp" +#include "window-basic-main.hpp" + +#include +#include + +struct AuthInfo { + Auth::Def def; + Auth::create_cb create; +}; + +static std::vector authDefs; + +void Auth::RegisterAuth(const Def &d, create_cb create) +{ + AuthInfo info = {d, create}; + authDefs.push_back(info); +} + +std::shared_ptr Auth::Create(const std::string &service) +{ + for (auto &a : authDefs) { + if (service.find(a.def.service) != std::string::npos) { + return a.create(); + } + } + + return nullptr; +} + +Auth::Type Auth::AuthType(const std::string &service) +{ + for (auto &a : authDefs) { + if (service.find(a.def.service) != std::string::npos) { + return a.def.type; + } + } + + return Type::None; +} + +void Auth::Load() +{ + OBSBasic *main = OBSBasic::Get(); + const char *typeStr = config_get_string(main->Config(), "Auth", "Type"); + if (!typeStr) typeStr = ""; + + main->auth = Create(typeStr); + if (main->auth) { + if (main->auth->LoadInternal()) { + main->auth->LoadUI(); + } + } +} + +void Auth::Save() +{ + OBSBasic *main = OBSBasic::Get(); + Auth *auth = main->auth.get(); + if (!auth) { + if (config_has_user_value(main->Config(), "Auth", "Type")) { + config_remove_value(main->Config(), "Auth", "Type"); + config_save_safe(main->Config(), "tmp", nullptr); + } + return; + } + + config_set_string(main->Config(), "Auth", "Type", auth->service()); + auth->SaveInternal(); + config_save_safe(main->Config(), "tmp", nullptr); +} diff --git a/UI/auth-base.hpp b/UI/auth-base.hpp new file mode 100644 index 000000000..8e04ec131 --- /dev/null +++ b/UI/auth-base.hpp @@ -0,0 +1,58 @@ +#pragma once + +#include +#include +#include + +class Auth : public QObject { + Q_OBJECT + +protected: + virtual void SaveInternal()=0; + virtual bool LoadInternal()=0; + + bool firstLoad = true; + + struct ErrorInfo { + std::string message; + std::string error; + + ErrorInfo(std::string message_, std::string error_) + : message(message_), error(error_) + {} + }; + +public: + enum class Type { + None, + OAuth_StreamKey + }; + + struct Def { + std::string service; + Type type; + }; + + typedef std::function ()> create_cb; + + inline Auth(const Def &d) : def(d) {} + virtual ~Auth() {} + + inline Type type() const {return def.type;} + inline const char *service() const {return def.service.c_str();} + + virtual void LoadUI() {} + + virtual void OnStreamConfig() {} + + static std::shared_ptr Create(const std::string &service); + static Type AuthType(const std::string &service); + static void Load(); + static void Save(); + +protected: + static void RegisterAuth(const Def &d, create_cb create); + +private: + Def def; +}; diff --git a/UI/auth-oauth.cpp b/UI/auth-oauth.cpp new file mode 100644 index 000000000..6b7c45708 --- /dev/null +++ b/UI/auth-oauth.cpp @@ -0,0 +1,283 @@ +#include "auth-oauth.hpp" + +#include +#include +#include + +#include +#include + +#include "window-basic-main.hpp" +#include "remote-text.hpp" + +#include + +#include + +using namespace json11; + +#include +extern QCef *cef; +extern QCefCookieManager *panel_cookies; + +/* ------------------------------------------------------------------------- */ + +OAuthLogin::OAuthLogin(QWidget *parent, const std::string &url, bool token) + : QDialog (parent), + get_token (token) +{ + setWindowTitle("Auth"); + resize(700, 700); + + OBSBasic::InitBrowserPanelSafeBlock(true); + + cefWidget = cef->create_widget(nullptr, url, panel_cookies); + if (!cefWidget) { + fail = true; + return; + } + + connect(cefWidget, SIGNAL(titleChanged(const QString &)), + this, SLOT(setWindowTitle(const QString &))); + connect(cefWidget, SIGNAL(urlChanged(const QString &)), + this, SLOT(urlChanged(const QString &))); + + QPushButton *close = new QPushButton(QTStr("Cancel")); + connect(close, &QAbstractButton::clicked, + this, &QDialog::reject); + + QHBoxLayout *bottomLayout = new QHBoxLayout(); + bottomLayout->addStretch(); + bottomLayout->addWidget(close); + bottomLayout->addStretch(); + + QVBoxLayout *topLayout = new QVBoxLayout(this); + topLayout->addWidget(cefWidget); + topLayout->addLayout(bottomLayout); +} + +OAuthLogin::~OAuthLogin() +{ + delete cefWidget; +} + +void OAuthLogin::urlChanged(const QString &url) +{ + std::string uri = get_token ? "access_token=" : "code="; + int code_idx = url.indexOf(uri.c_str()); + if (code_idx == -1) + return; + + if (url.left(22) != "https://obsproject.com") + return; + + code_idx += (int)uri.size(); + + int next_idx = url.indexOf("&", code_idx); + if (next_idx != -1) + code = url.mid(code_idx, next_idx - code_idx); + else + code = url.right(url.size() - code_idx); + + accept(); +} + +/* ------------------------------------------------------------------------- */ + +struct OAuthInfo { + Auth::Def def; + OAuth::login_cb login; + OAuth::delete_cookies_cb delete_cookies; +}; + +static std::vector loginCBs; + +void OAuth::RegisterOAuth(const Def &d, create_cb create, login_cb login, + delete_cookies_cb delete_cookies) +{ + OAuthInfo info = {d, login, delete_cookies}; + loginCBs.push_back(info); + RegisterAuth(d, create); +} + +std::shared_ptr OAuth::Login(QWidget *parent, const std::string &service) +{ + for (auto &a : loginCBs) { + if (service.find(a.def.service) != std::string::npos) { + return a.login(parent); + } + } + + return nullptr; +} + +void OAuth::DeleteCookies(const std::string &service) +{ + for (auto &a : loginCBs) { + if (service.find(a.def.service) != std::string::npos) { + a.delete_cookies(); + } + } +} + +void OAuth::SaveInternal() +{ + OBSBasic *main = OBSBasic::Get(); + config_set_string(main->Config(), service(), "RefreshToken", + refresh_token.c_str()); + config_set_string(main->Config(), service(), "Token", token.c_str()); + config_set_uint(main->Config(), service(), "ExpireTime", expire_time); + config_set_int(main->Config(), service(), "ScopeVer", currentScopeVer); +} + +static inline std::string get_config_str( + OBSBasic *main, + const char *section, + const char *name) +{ + const char *val = config_get_string(main->Config(), section, name); + return val ? val : ""; +} + +bool OAuth::LoadInternal() +{ + OBSBasic *main = OBSBasic::Get(); + refresh_token = get_config_str(main, service(), "RefreshToken"); + token = get_config_str(main, service(), "Token"); + expire_time = config_get_uint(main->Config(), service(), "ExpireTime"); + currentScopeVer = (int)config_get_int(main->Config(), service(), + "ScopeVer"); + return implicit + ? !token.empty() + : !refresh_token.empty(); +} + +bool OAuth::TokenExpired() +{ + if (token.empty()) + return true; + if ((uint64_t)time(nullptr) > expire_time - 5) + return true; + return false; +} + +bool OAuth::GetToken(const char *url, const std::string &client_id, + int scope_ver, const std::string &auth_code, bool retry) +try { + std::string output; + std::string error; + std::string desc; + + if (currentScopeVer > 0 && currentScopeVer < scope_ver) { + if (RetryLogin()) { + return true; + } else { + QString title = QTStr("Auth.InvalidScope.Title"); + QString text = QTStr("Auth.InvalidScope.Text") + .arg(service()); + + QMessageBox::warning(OBSBasic::Get(), title, text); + } + } + + if (auth_code.empty() && !TokenExpired()) { + return true; + } + + std::string post_data; + post_data += "action=redirect&client_id="; + post_data += client_id; + + if (!auth_code.empty()) { + post_data += "&grant_type=authorization_code&code="; + post_data += auth_code; + } else { + post_data += "&grant_type=refresh_token&refresh_token="; + post_data += refresh_token; + } + + bool success = false; + + auto func = [&] () { + success = GetRemoteFile( + url, + output, + error, + nullptr, + "application/x-www-form-urlencoded", + post_data.c_str(), + std::vector(), + nullptr, + 5); + }; + + ExecuteFuncSafeBlockMsgBox( + func, + QTStr("Auth.Authing.Title"), + QTStr("Auth.Authing.Text").arg(service())); + if (!success || output.empty()) + throw ErrorInfo("Failed to get token from remote", error); + + Json json = Json::parse(output, error); + if (!error.empty()) + throw ErrorInfo("Failed to parse json", error); + + /* -------------------------- */ + /* error handling */ + + error = json["error"].string_value(); + if (!retry && error == "invalid_grant") { + if (RetryLogin()) { + return true; + } + } + if (!error.empty()) + throw ErrorInfo(error, json["error_description"].string_value()); + + /* -------------------------- */ + /* success! */ + + expire_time = (uint64_t)time(nullptr) + json["expires_in"].int_value(); + token = json["access_token"].string_value(); + if (token.empty()) + throw ErrorInfo("Failed to get token from remote", error); + + if (!auth_code.empty()) { + refresh_token = json["refresh_token"].string_value(); + if (refresh_token.empty()) + throw ErrorInfo("Failed to get refresh token from " + "remote", error); + + currentScopeVer = scope_ver; + } + + return true; + +} catch (ErrorInfo info) { + if (!retry) { + QString title = QTStr("Auth.AuthFailure.Title"); + QString text = QTStr("Auth.AuthFailure.Text") + .arg(service(), info.message.c_str(), info.error.c_str()); + + QMessageBox::warning(OBSBasic::Get(), title, text); + } + + blog(LOG_WARNING, "%s: %s: %s", + __FUNCTION__, + info.message.c_str(), + info.error.c_str()); + return false; +} + +void OAuthStreamKey::OnStreamConfig() +{ + OBSBasic *main = OBSBasic::Get(); + obs_service_t *service = main->GetService(); + + obs_data_t *settings = obs_service_get_settings(service); + + obs_data_set_string(settings, "key", key_.c_str()); + obs_service_update(service, settings); + + obs_data_release(settings); +} diff --git a/UI/auth-oauth.hpp b/UI/auth-oauth.hpp new file mode 100644 index 000000000..6379299d3 --- /dev/null +++ b/UI/auth-oauth.hpp @@ -0,0 +1,76 @@ +#pragma once + +#include +#include +#include + +#include "auth-base.hpp" + +class QCefWidget; + +class OAuthLogin : public QDialog { + Q_OBJECT + + QCefWidget *cefWidget = nullptr; + QString code; + bool get_token = false; + bool fail = false; + +public: + OAuthLogin(QWidget *parent, const std::string &url, bool token); + ~OAuthLogin(); + + inline QString GetCode() const {return code;} + inline bool LoadFail() const {return fail;} + +public slots: + void urlChanged(const QString &url); +}; + +class OAuth : public Auth { + Q_OBJECT + +public: + inline OAuth(const Def &d) : Auth(d) {} + + typedef std::function (QWidget *)> login_cb; + typedef std::function delete_cookies_cb; + + static std::shared_ptr Login(QWidget *parent, + const std::string &service); + static void DeleteCookies(const std::string &service); + + static void RegisterOAuth(const Def &d, create_cb create, + login_cb login, delete_cookies_cb delete_cookies); + +protected: + std::string refresh_token; + std::string token; + bool implicit = false; + uint64_t expire_time = 0; + int currentScopeVer = 0; + + virtual void SaveInternal() override; + virtual bool LoadInternal() override; + + virtual bool RetryLogin()=0; + bool TokenExpired(); + bool GetToken(const char *url, const std::string &client_id, + int scope_ver, + const std::string &auth_code = std::string(), + bool retry = false); +}; + +class OAuthStreamKey : public OAuth { + Q_OBJECT + +protected: + std::string key_; + +public: + inline OAuthStreamKey(const Def &d) : OAuth(d) {} + + inline const std::string &key() const {return key_;} + + virtual void OnStreamConfig() override; +}; diff --git a/UI/data/locale/en-US.ini b/UI/data/locale/en-US.ini index ce475bfd3..de07d7ddf 100644 --- a/UI/data/locale/en-US.ini +++ b/UI/data/locale/en-US.ini @@ -91,6 +91,18 @@ AlreadyRunning.Title="OBS is already running" AlreadyRunning.Text="OBS is already running! Unless you meant to do this, please shut down any existing instances of OBS before trying to run a new instance. If you have OBS set to minimize to the system tray, please check to see if it's still running there." AlreadyRunning.LaunchAnyway="Launch Anyway" +# Auth +Auth.Authing.Title="Authenticating.." +Auth.Authing.Text="Authenticating with %1, please wait.." +Auth.AuthFailure.Title="Authentication Failure" +Auth.AuthFailure.Text="Failed to authenticate with %1:\n\n%2: %3" +Auth.InvalidScope.Title="Authentication Required" +Auth.InvalidScope.Text="The authentication requirements for %1 have changed. Some features may not be available." +Auth.LoadingChannel.Title="Loading channel information.." +Auth.LoadingChannel.Text="Loading channel information for %1, please wait.." +Auth.ChannelFailure.Title="Failed to load channel" +Auth.ChannelFailure.Text="Failed to load channel information for %1\n\n%2: %3" + # copy filters Copy.Filters="Copy Filters" Paste.Filters="Paste Filters" diff --git a/UI/window-basic-main-outputs.cpp b/UI/window-basic-main-outputs.cpp index 121bf5a3f..030dc883c 100644 --- a/UI/window-basic-main-outputs.cpp +++ b/UI/window-basic-main-outputs.cpp @@ -651,6 +651,10 @@ bool SimpleOutput::StartStreaming(obs_service_t *service) if (!Active()) SetupOutputs(); + Auth *auth = main->GetAuth(); + if (auth) + auth->OnStreamConfig(); + /* --------------------- */ const char *type = obs_service_get_output_type(service); @@ -1426,6 +1430,10 @@ bool AdvancedOutput::StartStreaming(obs_service_t *service) if (!Active()) SetupOutputs(); + Auth *auth = main->GetAuth(); + if (auth) + auth->OnStreamConfig(); + /* --------------------- */ int trackIndex = config_get_int(main->Config(), "AdvOut", diff --git a/UI/window-basic-main-profiles.cpp b/UI/window-basic-main-profiles.cpp index f70e7b46e..049a4087a 100644 --- a/UI/window-basic-main-profiles.cpp +++ b/UI/window-basic-main-profiles.cpp @@ -232,7 +232,9 @@ bool OBSBasic::AddProfile(bool create_new, const char *title, const char *text, config_set_string(App()->GlobalConfig(), "Basic", "ProfileDir", newDir.c_str()); + Auth::Save(); if (create_new) { + auth.reset(); DestroyPanelCookieManager(); } else if (!rename) { DuplicateCurrentCookieProfile(config); @@ -456,6 +458,8 @@ void OBSBasic::on_actionRemoveProfile_triggered() config_set_string(App()->GlobalConfig(), "Basic", "ProfileDir", newDir); + Auth::Save(); + auth.reset(); DestroyPanelCookieManager(); config.Swap(basicConfig); @@ -471,6 +475,8 @@ void OBSBasic::on_actionRemoveProfile_triggered() UpdateTitleBar(); + Auth::Load(); + if (api) { api->on_event(OBS_FRONTEND_EVENT_PROFILE_LIST_CHANGED); api->on_event(OBS_FRONTEND_EVENT_PROFILE_CHANGED); @@ -615,6 +621,8 @@ void OBSBasic::ChangeProfile() config_set_string(App()->GlobalConfig(), "Basic", "ProfileDir", newDir); + Auth::Save(); + auth.reset(); DestroyPanelCookieManager(); config.Swap(basicConfig); @@ -624,6 +632,8 @@ void OBSBasic::ChangeProfile() config_save_safe(App()->GlobalConfig(), "tmp", nullptr); UpdateTitleBar(); + Auth::Load(); + CheckForSimpleModeX264Fallback(); blog(LOG_INFO, "Switched to profile '%s' (%s)", diff --git a/UI/window-basic-main.cpp b/UI/window-basic-main.cpp index d2f106846..76013997d 100644 --- a/UI/window-basic-main.cpp +++ b/UI/window-basic-main.cpp @@ -1786,6 +1786,8 @@ void OBSBasic::OnFirstLoad() } } #endif + + Auth::Load(); } void OBSBasic::DeferredLoad(const QString &file, int requeueCount) @@ -3656,10 +3658,6 @@ void OBSBasic::closeEvent(QCloseEvent *event) "BasicWindow", "geometry", saveGeometry().toBase64().constData()); - config_set_string(App()->GlobalConfig(), - "BasicWindow", "DockState", - saveState().toBase64().constData()); - if (outputHandler && outputHandler->Active()) { SetShowing(true); @@ -3688,7 +3686,13 @@ void OBSBasic::closeEvent(QCloseEvent *event) signalHandlers.clear(); + Auth::Save(); SaveProjectNow(); + auth.reset(); + + config_set_string(App()->GlobalConfig(), + "BasicWindow", "DockState", + saveState().toBase64().constData()); if (api) api->on_event(OBS_FRONTEND_EVENT_EXIT); diff --git a/UI/window-basic-main.hpp b/UI/window-basic-main.hpp index 7a64fa7d7..109012579 100644 --- a/UI/window-basic-main.hpp +++ b/UI/window-basic-main.hpp @@ -32,6 +32,7 @@ #include "window-basic-filters.hpp" #include "window-projector.hpp" #include "window-basic-about.hpp" +#include "auth-base.hpp" #include @@ -116,6 +117,7 @@ class OBSBasic : public OBSMainWindow { friend class OBSBasicStatusBar; friend class OBSBasicSourceSelect; friend class OBSBasicSettings; + friend class Auth; friend struct OBSStudioAPI; enum class MoveDir { @@ -136,6 +138,8 @@ class OBSBasic : public OBSMainWindow { private: obs_frontend_callbacks *api = nullptr; + std::shared_ptr auth; + std::vector volumes; std::vector signalHandlers; @@ -591,6 +595,8 @@ public: void SaveService(); bool LoadService(); + inline Auth *GetAuth() {return auth.get();} + inline void EnableOutputs(bool enable) { if (enable) {