connector: save channel access hashes in more places
Signed-off-by: Sumner Evans <sumner.evans@automattic.com>
This commit is contained in:
+28
-7
@@ -7,26 +7,47 @@ import (
|
|||||||
"github.com/gotd/td/tg"
|
"github.com/gotd/td/tg"
|
||||||
)
|
)
|
||||||
|
|
||||||
type hasUpdates interface {
|
type hasUserUpdates interface {
|
||||||
GetUsers() []tg.UserClass
|
GetUsers() []tg.UserClass
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wrapper for API calls that return a response with updates.
|
type hasUpdates interface {
|
||||||
func APICallWithUpdates[U hasUpdates](ctx context.Context, t *TelegramClient, fn func() (U, error)) (U, error) {
|
hasUserUpdates
|
||||||
|
GetChats() []tg.ChatClass
|
||||||
|
}
|
||||||
|
|
||||||
|
func APICallWithOnlyUserUpdates[U hasUserUpdates](ctx context.Context, t *TelegramClient, fn func() (U, error)) (U, error) {
|
||||||
resp, err := fn()
|
resp, err := fn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return resp, err
|
return *new(U), err
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO do we also need to expand this to chats and messages?
|
|
||||||
for _, user := range resp.GetUsers() {
|
for _, user := range resp.GetUsers() {
|
||||||
user, ok := user.(*tg.User)
|
user, ok := user.(*tg.User)
|
||||||
if !ok {
|
if !ok {
|
||||||
return resp, fmt.Errorf("user is %T not *tg.User", user)
|
return *new(U), fmt.Errorf("user is %T not *tg.User", user)
|
||||||
}
|
}
|
||||||
_, err := t.updateGhost(ctx, user.ID, user)
|
_, err := t.updateGhost(ctx, user.ID, user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return resp, err
|
return *new(U), err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wrapper for API calls that return a response with updates.
|
||||||
|
func APICallWithUpdates[U hasUpdates](ctx context.Context, t *TelegramClient, fn func() (U, error)) (U, error) {
|
||||||
|
resp, err := APICallWithOnlyUserUpdates(ctx, t, fn)
|
||||||
|
if err != nil {
|
||||||
|
return *new(U), err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range resp.GetChats() {
|
||||||
|
if channel, ok := c.(*tg.Channel); ok {
|
||||||
|
if err := t.ScopedStore.SetAccessHash(ctx, channel.ID, channel.AccessHash); err != nil {
|
||||||
|
return *new(U), err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -155,7 +155,7 @@ func (t *TelegramClient) SearchUsers(ctx context.Context, query string) (resp []
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *TelegramClient) GetContactList(ctx context.Context) (resp []*bridgev2.ResolveIdentifierResponse, err error) {
|
func (t *TelegramClient) GetContactList(ctx context.Context) (resp []*bridgev2.ResolveIdentifierResponse, err error) {
|
||||||
contacts, err := APICallWithUpdates(ctx, t, func() (*tg.ContactsContacts, error) {
|
contacts, err := APICallWithOnlyUserUpdates(ctx, t, func() (*tg.ContactsContacts, error) {
|
||||||
c, err := t.client.API().ContactsGetContacts(ctx, t.cachedContactsHash)
|
c, err := t.client.API().ContactsGetContacts(ctx, t.cachedContactsHash)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -167,16 +167,16 @@ func (s *ScopedStore) SetChannelAccessHash(ctx context.Context, userID, channelI
|
|||||||
|
|
||||||
var ErrNoAccessHash = errors.New("access hash not found")
|
var ErrNoAccessHash = errors.New("access hash not found")
|
||||||
|
|
||||||
func (s *ScopedStore) GetAccessHash(ctx context.Context, userID int64) (accessHash int64, err error) {
|
func (s *ScopedStore) GetAccessHash(ctx context.Context, entityID int64) (accessHash int64, err error) {
|
||||||
err = s.db.QueryRow(ctx, getAccessHashQuery, s.telegramUserID, userID).Scan(&accessHash)
|
err = s.db.QueryRow(ctx, getAccessHashQuery, s.telegramUserID, entityID).Scan(&accessHash)
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
err = ErrNoAccessHash
|
err = ErrNoAccessHash
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ScopedStore) SetAccessHash(ctx context.Context, userID, accessHash int64) (err error) {
|
func (s *ScopedStore) SetAccessHash(ctx context.Context, entityID, accessHash int64) (err error) {
|
||||||
_, err = s.db.Exec(ctx, setAccessHashQuery, s.telegramUserID, userID, accessHash)
|
_, err = s.db.Exec(ctx, setAccessHashQuery, s.telegramUserID, entityID, accessHash)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -320,7 +320,14 @@ func (t *TelegramClient) updateGhost(ctx context.Context, userID int64, user *tg
|
|||||||
|
|
||||||
func (t *TelegramClient) onEntityUpdate(ctx context.Context, e tg.Entities) error {
|
func (t *TelegramClient) onEntityUpdate(ctx context.Context, e tg.Entities) error {
|
||||||
for userID, user := range e.Users {
|
for userID, user := range e.Users {
|
||||||
t.updateGhost(ctx, userID, user)
|
if _, err := t.updateGhost(ctx, userID, user); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for channelID, channel := range e.Channels {
|
||||||
|
if err := t.ScopedStore.SetAccessHash(ctx, channelID, channel.AccessHash); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user