diff --git a/pkg/connector/startchat.go b/pkg/connector/startchat.go index 6177ff1a..8f26bff8 100644 --- a/pkg/connector/startchat.go +++ b/pkg/connector/startchat.go @@ -18,6 +18,7 @@ package connector import ( "context" + "errors" "fmt" "regexp" "strconv" @@ -28,6 +29,7 @@ import ( "maunium.net/go/mautrix/bridgev2" "go.mau.fi/mautrix-telegram/pkg/connector/ids" + "go.mau.fi/mautrix-telegram/pkg/connector/store" "go.mau.fi/mautrix-telegram/pkg/gotd/telegram/query/hasher" "go.mau.fi/mautrix-telegram/pkg/gotd/tg" ) @@ -58,6 +60,12 @@ func (t *TelegramClient) getResolveIdentifierResponseForUser(ctx context.Context } func (t *TelegramClient) getResolveIdentifierResponseForUserID(ctx context.Context, userID int64) (resp *bridgev2.ResolveIdentifierResponse, err error) { + _, err = t.ScopedStore.GetAccessHash(ctx, ids.PeerTypeUser, userID) + if errors.Is(err, store.ErrNoAccessHash) { + return nil, fmt.Errorf("%w: %w", bridgev2.ErrResolveIdentifierTryNext, err) + } else if err != nil { + return nil, fmt.Errorf("failed to get access hash from store: %w", err) + } networkUserID := ids.MakeUserID(userID) resp = &bridgev2.ResolveIdentifierResponse{ UserID: networkUserID, @@ -116,34 +124,36 @@ func (t *TelegramClient) ResolveIdentifier(ctx context.Context, identifier strin entityType, userID, err := t.ScopedStore.GetEntityIDByUsername(ctx, match[1]) if entityType == ids.PeerTypeUser && (err == nil || userID != 0) { // We know this username. - return t.getResolveIdentifierResponseForUserID(ctx, userID) - } else { - // We don't know this username, try to resolve the username from - // Telegram. - resolved, err := APICallWithUpdates(ctx, t, func() (*tg.ContactsResolvedPeer, error) { - return t.client.API().ContactsResolveUsername(ctx, &tg.ContactsResolveUsernameRequest{ - Username: match[1], - }) - }) - if err != nil { - if tg.IsUsernameNotOccupied(err) { - log.Info().Msg("Username not found in database") - return nil, nil - } else { - return nil, fmt.Errorf("failed to resolve username: %w", err) - } + resp, err := t.getResolveIdentifierResponseForUserID(ctx, userID) + if err == nil || !errors.Is(err, store.ErrNoAccessHash) { + return resp, err } - peer, ok := resolved.GetPeer().(*tg.PeerUser) - if !ok { - return nil, fmt.Errorf("unexpected peer type: %T", resolved.GetPeer()) - } - for _, user := range resolved.GetUsers() { - if user.GetID() == peer.GetUserID() { - return t.getResolveIdentifierResponseForUser(ctx, user) - } - } - return nil, fmt.Errorf("peer user not found in contact resolved response") } + // We don't know this username, try to resolve the username from + // Telegram. + resolved, err := APICallWithUpdates(ctx, t, func() (*tg.ContactsResolvedPeer, error) { + return t.client.API().ContactsResolveUsername(ctx, &tg.ContactsResolveUsernameRequest{ + Username: match[1], + }) + }) + if err != nil { + if tg.IsUsernameNotOccupied(err) { + log.Info().Msg("Username not found in database") + return nil, nil + } else { + return nil, fmt.Errorf("failed to resolve username: %w", err) + } + } + peer, ok := resolved.GetPeer().(*tg.PeerUser) + if !ok { + return nil, fmt.Errorf("unexpected peer type: %T", resolved.GetPeer()) + } + for _, user := range resolved.GetUsers() { + if user.GetID() == peer.GetUserID() { + return t.getResolveIdentifierResponseForUser(ctx, user) + } + } + return nil, fmt.Errorf("peer user not found in contact resolved response") } else { return nil, fmt.Errorf("invalid identifier: %s (must be a phone number, username, or Telegram user ID)", identifier) } diff --git a/pkg/connector/store/scoped_store.go b/pkg/connector/store/scoped_store.go index dcab0c24..9a799359 100644 --- a/pkg/connector/store/scoped_store.go +++ b/pkg/connector/store/scoped_store.go @@ -227,7 +227,7 @@ var ErrNoAccessHash = errors.New("access hash not found") func (s *ScopedStore) GetAccessHash(ctx context.Context, entityType ids.PeerType, entityID int64) (accessHash int64, err error) { err = s.db.QueryRow(ctx, getAccessHashQuery, s.telegramUserID, entityType, entityID).Scan(&accessHash) if errors.Is(err, sql.ErrNoRows) { - err = ErrNoAccessHash + err = fmt.Errorf("%w for %d", ErrNoAccessHash, entityID) } return }