move gotd fork into repo. (#111)
- update to latest telegram layer - remove some references to fields in tg.Entities that don't exist in the schema - originally added here: https://github.com/beeper/td/commit/820929062a2ba0104397bc01235ab58a9cff780e - referenced here - https://github.com/mautrix/telegramgo/commit/124f0967ed195b5a380c9bd02e170ada9710dde3 - https://github.com/mautrix/telegramgo/commit/4205047aab2e0639217148b5d125bfaab668bd8e
This commit is contained in:
@@ -0,0 +1,136 @@
|
||||
package file
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
|
||||
)
|
||||
|
||||
func getLocation(loc tg.InputFileLocationClass) (string, error) {
|
||||
v, ok := loc.(interface {
|
||||
GetLocalID() int
|
||||
GetVolumeID() int64
|
||||
})
|
||||
if !ok {
|
||||
return "", tgerr.New(400, tg.ErrFileIDInvalid)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%d_%d", v.GetLocalID(), v.GetVolumeID()), nil
|
||||
}
|
||||
|
||||
func (m *Service) openLocation(loc tg.InputFileLocationClass) (File, error) {
|
||||
name, err := getLocation(loc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f, err := m.storage.Open(name)
|
||||
if err != nil {
|
||||
return nil, tgerr.New(400, tg.ErrFileIDInvalid)
|
||||
}
|
||||
|
||||
return f, nil
|
||||
}
|
||||
|
||||
func (m *Service) getPart(loc tg.InputFileLocationClass, offset int64, limit int) ([]byte, error) {
|
||||
f, err := m.openLocation(loc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r := make([]byte, limit)
|
||||
n, err := f.ReadAt(r, offset)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "read from storage")
|
||||
}
|
||||
|
||||
return r[:n], nil
|
||||
}
|
||||
|
||||
func (m *Service) UploadGetFile(ctx context.Context, request *tg.UploadGetFileRequest) (tg.UploadFileClass, error) {
|
||||
data, err := m.getPart(request.Location, request.Offset, request.Limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &tg.UploadFile{
|
||||
Type: &tg.StorageFilePartial{},
|
||||
Mtime: 0,
|
||||
Bytes: data,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func countHashes(data []byte, offset int64, partSize int) []tg.FileHash {
|
||||
actions := data
|
||||
batchSize := partSize
|
||||
batches := make([][]byte, 0, (len(actions)+batchSize-1)/batchSize)
|
||||
|
||||
for batchSize < len(actions) {
|
||||
actions, batches = actions[batchSize:], append(batches, actions[0:batchSize:batchSize])
|
||||
}
|
||||
batches = append(batches, actions)
|
||||
|
||||
currentRange := make([]tg.FileHash, 0, 10)
|
||||
for _, batch := range batches {
|
||||
currentRange = append(currentRange, tg.FileHash{
|
||||
Offset: offset,
|
||||
Limit: partSize,
|
||||
Hash: crypto.SHA256(batch),
|
||||
})
|
||||
offset += int64(len(batch))
|
||||
}
|
||||
return currentRange
|
||||
}
|
||||
|
||||
func divAndCeil(a, b int) int {
|
||||
r := a / b
|
||||
if a%b != 0 {
|
||||
r++
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// computeBatch computes hash range number for given offset.
|
||||
func computeBatch(offset int64, rangeSize, partSize int) int {
|
||||
// Compute number of parts in partSize from offset.
|
||||
parts := divAndCeil(int(offset+1), partSize)
|
||||
// Compute number of hash ranges in rangeSize.
|
||||
batches := divAndCeil(parts, rangeSize)
|
||||
|
||||
return batches
|
||||
}
|
||||
|
||||
func (m *Service) UploadGetFileHashes(
|
||||
ctx context.Context,
|
||||
request *tg.UploadGetFileHashesRequest,
|
||||
) ([]tg.FileHash, error) {
|
||||
f, err := m.openLocation(request.Location)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if request.Offset >= int64(f.Size()) {
|
||||
return nil, nil
|
||||
}
|
||||
partSize := m.hashPartSize
|
||||
rangeSize := m.hashRangeSize
|
||||
batch := computeBatch(request.Offset, rangeSize, partSize)
|
||||
|
||||
low := (batch - 1) * rangeSize * partSize
|
||||
high := batch * rangeSize * partSize
|
||||
|
||||
r := make([]byte, high-low)
|
||||
n, err := f.ReadAt(r, int64(low))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r = r[:n]
|
||||
|
||||
return countHashes(r, int64(low), partSize), nil
|
||||
}
|
||||
Reference in New Issue
Block a user