client: Wait before returning from disconnect

This commit is contained in:
Toni Spets
2025-04-09 10:35:08 +03:00
parent 10b8c4b635
commit 538f2a2ec0
3 changed files with 21 additions and 7 deletions
+11 -5
View File
@@ -66,6 +66,7 @@ type TelegramClient struct {
updatesManager *updates.Manager updatesManager *updates.Manager
clientCtx context.Context clientCtx context.Context
clientCancel context.CancelFunc clientCancel context.CancelFunc
clientCloseC <-chan struct{}
appConfigLock sync.Mutex appConfigLock sync.Mutex
appConfig map[string]any appConfig map[string]any
@@ -370,11 +371,13 @@ func NewTelegramClient(ctx context.Context, tc *TelegramConnector, login *bridge
// connectTelegramClient blocks until client is connected, calling Run // connectTelegramClient blocks until client is connected, calling Run
// internally. // internally.
// Technique from: https://github.com/gotd/contrib/blob/master/bg/connect.go // Technique from: https://github.com/gotd/contrib/blob/master/bg/connect.go
func connectTelegramClient(ctx context.Context, cancel context.CancelFunc, client *telegram.Client) error { func connectTelegramClient(ctx context.Context, cancel context.CancelFunc, client *telegram.Client) (<-chan struct{}, error) {
errC := make(chan error, 1) errC := make(chan error, 1)
initDone := make(chan struct{}) initDone := make(chan struct{})
closeC := make(chan struct{})
go func() { go func() {
defer close(errC) defer close(errC)
defer close(closeC)
errC <- client.Run(ctx, func(ctx context.Context) error { errC <- client.Run(ctx, func(ctx context.Context) error {
close(initDone) close(initDone)
<-ctx.Done() <-ctx.Done()
@@ -388,13 +391,13 @@ func connectTelegramClient(ctx context.Context, cancel context.CancelFunc, clien
select { select {
case <-ctx.Done(): // context canceled case <-ctx.Done(): // context canceled
cancel() cancel()
return fmt.Errorf("context cancelled before init done: %w", ctx.Err()) return nil, fmt.Errorf("context cancelled before init done: %w", ctx.Err())
case err := <-errC: // startup timeout case err := <-errC: // startup timeout
cancel() cancel()
return fmt.Errorf("client connection timeout: %w", err) return nil, fmt.Errorf("client connection timeout: %w", err)
case <-initDone: // init done case <-initDone: // init done
} }
return nil return closeC, nil
} }
func (t *TelegramClient) onDead() { func (t *TelegramClient) onDead() {
@@ -497,7 +500,7 @@ func (t *TelegramClient) Connect(ctx context.Context) {
var err error var err error
t.clientCtx, t.clientCancel = context.WithCancel(ctx) t.clientCtx, t.clientCancel = context.WithCancel(ctx)
if err = connectTelegramClient(t.clientCtx, t.clientCancel, t.client); err != nil { if t.clientCloseC, err = connectTelegramClient(t.clientCtx, t.clientCancel, t.client); err != nil {
t.sendBadCredentialsOrUnknownError(err) t.sendBadCredentialsOrUnknownError(err)
return return
} }
@@ -536,6 +539,9 @@ func (t *TelegramClient) Disconnect() {
if t.clientCancel != nil { if t.clientCancel != nil {
t.clientCancel() t.clientCancel()
} }
if t.clientCloseC != nil {
<-t.clientCloseC
}
} }
func (t *TelegramClient) getInputUser(ctx context.Context, id int64) (*tg.InputUser, error) { func (t *TelegramClient) getInputUser(ctx context.Context, id int64) (*tg.InputUser, error) {
+5 -1
View File
@@ -44,6 +44,7 @@ type PhoneLogin struct {
authClient *telegram.Client authClient *telegram.Client
authClientCtx context.Context authClientCtx context.Context
authClientCancel context.CancelFunc authClientCancel context.CancelFunc
authClientCloseC <-chan struct{}
phone string phone string
hash string hash string
@@ -55,6 +56,9 @@ func (p *PhoneLogin) Cancel() {
if p.authClientCancel != nil { if p.authClientCancel != nil {
p.authClientCancel() p.authClientCancel()
} }
if p.authClientCloseC != nil {
<-p.authClientCloseC
}
} }
func (p *PhoneLogin) Start(ctx context.Context) (*bridgev2.LoginStep, error) { func (p *PhoneLogin) Start(ctx context.Context) (*bridgev2.LoginStep, error) {
@@ -85,7 +89,7 @@ func (p *PhoneLogin) SubmitUserInput(ctx context.Context, input map[string]strin
}) })
var err error var err error
p.authClientCtx, p.authClientCancel = context.WithTimeoutCause(log.WithContext(context.Background()), time.Hour, errors.New("phone login took over one hour")) p.authClientCtx, p.authClientCancel = context.WithTimeoutCause(log.WithContext(context.Background()), time.Hour, errors.New("phone login took over one hour"))
if err = connectTelegramClient(p.authClientCtx, p.authClientCancel, p.authClient); err != nil { if p.authClientCloseC, err = connectTelegramClient(p.authClientCtx, p.authClientCancel, p.authClient); err != nil {
return nil, err return nil, err
} }
sentCode, err := p.authClient.Auth().SendCode(p.authClientCtx, p.phone, auth.SendCodeOptions{}) sentCode, err := p.authClient.Auth().SendCode(p.authClientCtx, p.phone, auth.SendCodeOptions{})
+5 -1
View File
@@ -47,6 +47,7 @@ type QRLogin struct {
authClientCtx context.Context authClientCtx context.Context
authClientCancel context.CancelFunc authClientCancel context.CancelFunc
authClientCloseC <-chan struct{}
auth chan qrAuthResult auth chan qrAuthResult
qrToken chan qrlogin.Token qrToken chan qrlogin.Token
@@ -61,6 +62,9 @@ func (q *QRLogin) Cancel() {
if q.authClientCancel != nil { if q.authClientCancel != nil {
q.authClientCancel() q.authClientCancel()
} }
if q.authClientCloseC != nil {
<-q.authClientCloseC
}
} }
func (q *QRLogin) Start(ctx context.Context) (*bridgev2.LoginStep, error) { func (q *QRLogin) Start(ctx context.Context) (*bridgev2.LoginStep, error) {
@@ -85,7 +89,7 @@ func (q *QRLogin) Start(ctx context.Context) (*bridgev2.LoginStep, error) {
var err error var err error
q.authClientCtx, q.authClientCancel = context.WithTimeoutCause(log.WithContext(context.Background()), time.Hour, errors.New("phone login took over one hour")) q.authClientCtx, q.authClientCancel = context.WithTimeoutCause(log.WithContext(context.Background()), time.Hour, errors.New("phone login took over one hour"))
if err = connectTelegramClient(q.authClientCtx, q.authClientCancel, q.authClient); err != nil { if q.authClientCloseC, err = connectTelegramClient(q.authClientCtx, q.authClientCancel, q.authClient); err != nil {
return nil, err return nil, err
} }