provisioning: implement legacy QR endpoint

Signed-off-by: Sumner Evans <sumner.evans@automattic.com>
This commit is contained in:
Sumner Evans
2024-08-27 15:45:25 -06:00
parent 4d9ad4f0af
commit c2d94947ee
4 changed files with 108 additions and 5 deletions
@@ -17,11 +17,13 @@
package main
import (
"context"
"encoding/json"
"fmt"
"net/http"
"sync"
"github.com/gorilla/websocket"
"github.com/rs/zerolog"
"go.mau.fi/util/exhttp"
"maunium.net/go/mautrix/bridgev2"
@@ -62,6 +64,92 @@ type legacyLogin struct {
var inflightLegacyLoginsLock sync.RWMutex
var inflightLegacyLogins = map[id.UserID]*legacyLogin{}
var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true
},
Subprotocols: []string{"net.maunium.telegram.login"},
}
func legacyProvLoginQR(w http.ResponseWriter, r *http.Request) {
log := zerolog.Ctx(r.Context()).With().Str("prov_method", "qr_login").Logger()
ctx := log.WithContext(r.Context())
user := m.Matrix.Provisioning.GetUser(r)
resp := response{Username: user.MXID}
var err error
var loginProcess bridgev2.LoginProcess
var nextStep *bridgev2.LoginStep
if loginProcess, err = c.CreateLogin(ctx, user, connector.LoginFlowIDQR); err != nil {
exhttp.WriteJSONResponse(w, http.StatusInternalServerError, resp.WithError("create_login_failed", fmt.Sprintf("Failed to create a QR login process: %s", err.Error())))
} else if nextStep, err = loginProcess.Start(ctx); err != nil {
exhttp.WriteJSONResponse(w, http.StatusInternalServerError, resp.WithError("start_login_failed", fmt.Sprintf("Failed to start login process: %s", err.Error())))
} else if nextStep.StepID != connector.LoginStepIDShowQR {
exhttp.WriteJSONResponse(w, http.StatusInternalServerError, resp.WithError("unexpected_step", fmt.Sprintf("Unexpected first step %s", nextStep.StepID)))
}
ws, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Err(err).Msg("Failed to upgrade connection to websocket")
return
}
defer func() {
err := ws.Close()
if err != nil {
log.Debug().Err(err).Msg("Error closing websocket")
}
}()
go func() {
// Read everything so SetCloseHandler() works
for {
_, _, err = ws.ReadMessage()
if err != nil {
break
}
}
}()
ctx, cancel := context.WithCancel(context.Background())
ws.SetCloseHandler(func(code int, text string) error {
log.Debug().Int("close_code", code).Msg("Login websocket closed, cancelling login")
cancel()
return nil
})
for {
switch nextStep.StepID {
case connector.LoginStepIDShowQR:
nextStep, err = loginProcess.(bridgev2.LoginProcessDisplayAndWait).Wait(ctx)
if err != nil {
ws.WriteJSON(map[string]any{
"success": false,
"error": "qr_login_failed",
"message": fmt.Sprintf("Failed to login using QR code: %s", err),
})
return
}
ws.WriteJSON(map[string]any{"code": nextStep.DisplayAndWaitParams.Data})
case connector.LoginStepIDComplete:
ws.WriteJSON(map[string]any{"success": true})
return
case connector.LoginStepIDPassword:
inflightLegacyLoginsLock.Lock()
inflightLegacyLogins[user.MXID] = &legacyLogin{Process: loginProcess, NextStep: nextStep}
inflightLegacyLoginsLock.Unlock()
ws.WriteJSON(map[string]any{"success": false, "error": "password-needed"})
return
default:
ws.WriteJSON(map[string]any{
"success": false,
"error": "unexpected_step",
"message": fmt.Sprintf("Unexpected step in QR code login process %s", nextStep.StepID),
})
return
}
}
}
func legacyProvLoginRequestCode(w http.ResponseWriter, r *http.Request) {
log := zerolog.Ctx(r.Context()).With().Str("prov_step", "request_code").Logger()
ctx := log.WithContext(r.Context())
+16 -1
View File
@@ -19,6 +19,8 @@ package main
import (
"encoding/base64"
"fmt"
"net/http"
"strings"
"go.mau.fi/util/dbutil/litestream"
"maunium.net/go/mautrix/bridgev2/bridgeconfig"
@@ -60,7 +62,20 @@ func main() {
versionWithoutCommit := m.Version
m.PostStart = func() {
if m.Matrix.Provisioning != nil {
// m.Matrix.Provisioning.Router.HandleFunc("/v1/user/{userID}/login/qr", legacyProvLoginQR)
m.Matrix.Provisioning.GetAuthFromRequest = func(r *http.Request) string {
if !strings.HasSuffix(r.URL.Path, "/login/qr") {
return ""
}
authParts := strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), ",")
for _, part := range authParts {
part = strings.TrimSpace(part)
if strings.HasPrefix(part, "net.maunium.telegram.auth-") {
return strings.TrimPrefix(part, "net.maunium.telegram.auth-")
}
}
return ""
}
m.Matrix.Provisioning.Router.HandleFunc("/v1/user/{userID}/login/qr", legacyProvLoginQR)
m.Matrix.Provisioning.Router.HandleFunc("/v1/user/{userID}/login/request_code", legacyProvLoginRequestCode)
m.Matrix.Provisioning.Router.HandleFunc("/v1/user/{userID}/login/send_code", legacyProvLoginSendCode)
m.Matrix.Provisioning.Router.HandleFunc("/v1/user/{userID}/login/send_password", legacyProvLoginSendPassword)