summaryrefslogtreecommitdiff
path: root/mercurial/wireproto.py
blob: 2b44fd64178177c1e58482ef2c4964c5e276bdcf (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
# wireproto.py - generic wire protocol support functions
#
# Copyright 2005-2010 Matt Mackall <mpm@selenic.com>
#
# This software may be used and distributed according to the terms of the
# GNU General Public License version 2 or any later version.

import urllib, tempfile, os, sys
from i18n import _
from node import bin, hex
import changegroup as changegroupmod
import peer, error, encoding, util, store
import discovery, phases

# abstract batching support

class future(object):
    '''single-assignment slot whose content is supplied after creation

    The stored result is exposed as the "value" attribute, which does not
    exist until set() has been called.
    '''
    def set(self, value):
        '''store value; raise error.RepoError if one was already stored'''
        alreadyset = util.safehasattr(self, 'value')
        if alreadyset:
            raise error.RepoError("future is already set")
        self.value = value

class batcher(object):
    '''base class for batches of commands submittable in a single request

    Any method invoked on an instance is not executed; it is recorded and a
    future for its result is returned. Calling submit() performs all the
    recorded calls and fills in their respective futures.
    '''
    def __init__(self):
        # queued calls, in invocation order: (name, args, opts, future)
        self.calls = []
    def __getattr__(self, name):
        # only reached for attributes not found normally, i.e. for the
        # names of the commands being queued
        def queue(*args, **opts):
            pending = future()
            entry = (name, args, opts, pending)
            self.calls.append(entry)
            return pending
        return queue
    def submit(self):
        '''perform the queued calls; a no-op here, subclasses override'''

class localbatch(batcher):
    '''runs every queued call directly against a local object'''
    def __init__(self, local):
        batcher.__init__(self)
        # object whose methods are looked up by the queued call names
        self.local = local
    def submit(self):
        '''invoke each queued call on self.local, in order, and store the
        return value in the matching future'''
        for name, args, opts, resref in self.calls:
            method = getattr(self.local, name)
            resref.set(method(*args, **opts))

class remotebatch(batcher):
    '''batches the queued calls; uses as few roundtrips as possible'''
    def __init__(self, remote):
        '''remote must support _submitbatch(encbatch) and
        _submitone(op, encargs)'''
        batcher.__init__(self)
        self.remote = remote
    def submit(self):
        '''run the queued calls in order

        Consecutive batchable calls are accumulated and sent through one
        _submitbatch request. A non-batchable call first flushes whatever
        has accumulated and is then invoked directly.
        '''
        pendingreq, pendingrsp = [], []
        for name, args, opts, resref in self.calls:
            method = getattr(self.remote, name)
            coroutinefn = getattr(method, 'batchable', None)
            if coroutinefn is None:
                # plain method: send what is queued so far, then call it
                if pendingreq:
                    self._submitreq(pendingreq, pendingrsp)
                    pendingreq, pendingrsp = [], []
                resref.set(method(*args, **opts))
                continue
            coroutine = coroutinefn(method.im_self, *args, **opts)
            encargsorres, encresref = coroutine.next()
            if not encresref:
                # no future handed back: the first item is already the
                # final, locally computed result
                resref.set(encargsorres)
                continue
            pendingreq.append((name, encargsorres,))
            pendingrsp.append((coroutine, encresref, resref,))
        if pendingreq:
            self._submitreq(pendingreq, pendingrsp)
    def _submitreq(self, req, rsp):
        '''send req as one batch, then resume each coroutine in rsp with its
        encoded result and store the decoded value in the caller's future'''
        encresults = self.remote._submitbatch(req)
        for encres, (coroutine, encresref, resref) in zip(encresults, rsp):
            encresref.set(encres)
            resref.set(coroutine.next())

def batchable(f):
    '''annotation for batchable methods

    A decorated method must be written as a coroutine of this shape:

    @batchable
    def sample(self, one, two=None):
        # Results that can be computed locally are yielded first, paired
        # with None instead of a future:
        if not one:
            yield "a local result", None
        # Otherwise encode the arguments for the wire protocol:
        encargs = [('one', encode(one),), ('two', encode(two),)]
        # ... create the future the encoded result will be injected into:
        encresref = future()
        # ... and hand both back:
        yield encargs, encresref
        # When resumed, the future holds the encoded result. Decode it:
        yield decode(encresref.value)

    The returned function behaves like a plain method: it drives the
    coroutine itself, sending the encoded arguments through
    self._submitone(). The undecorated coroutine stays reachable as the
    attribute "batchable", which remotebatch uses to run the encoding and
    decoding phases separately.
    '''
    def plain(*args, **opts):
        coroutine = f(*args, **opts)
        encargsorres, encresref = coroutine.next()
        if encresref:
            # args[0] is the instance the method was invoked on
            peer_ = args[0]
            encresref.set(peer_._submitone(f.func_name, encargsorres))
            return coroutine.next()
        return encargsorres # a local result in this case
    plain.batchable = f
    return plain

# list of nodes encoding / decoding

def decodelist(l, sep=' '):
    '''split l on sep and convert every piece with bin()

    A false value for l (such as the empty string) gives an empty list.
    '''
    if not l:
        return []
    return [bin(piece) for piece in l.split(sep)]

def encodelist(l, sep=' '):
    '''convert every item of l with hex() and join the results with sep'''
    encoded = [hex(item) for item in l]
    return sep.join(encoded)

# batched call argument encoding

def escapearg(plain):
    '''prefix every ':', ',', ';' and '=' in plain with a ':'

    The colon is handled first so that the colons added for the other
    characters are not themselves doubled.
    '''
    escaped = plain
    for char in (':', ',', ';', '='):
        escaped = escaped.replace(char, ':' + char)
    return escaped

def unescapearg(escaped):
    '''undo escapearg(): strip the ':' in front of '=', ';', ',' and ':'

    The sequences are processed in the reverse of the order escapearg()
    applied them, so doubled colons are collapsed last.
    '''
    plain = escaped
    for char in ('=', ';', ',', ':'):
        plain = plain.replace(':' + char, char)
    return plain

# client side

def todict(**args):
    '''return the keyword arguments this was called with as a dict'''
    return dict(args)

class wirepeer(peer.peerrepository):
    '''client side of the wire protocol

    Each method encodes its arguments, issues a named command through
    self._call, self._callstream or self._callpush, and decodes the reply.
    Those transport methods (and _abort, _decompress) are not defined in
    this class; presumably concrete peers provide them.
    '''

    def batch(self):
        '''return a remotebatch that queues calls against this peer'''
        return remotebatch(self)
    def _submitbatch(self, req):
        '''send several commands through a single "batch" request

        req is a list of (op, argsdict) pairs; argsdict maps argument names
        to already-encoded string values. Returns the list of raw results,
        one string per command, in request order.

        Wire format: commands are joined with ';', each one being
        "<op> <name>=<value>,<name>=<value>...". The server-side batch
        handler unescapes every argument value and escapes every result, so
        values are escaped here before sending and results are unescaped
        after splitting, keeping both ends symmetric.
        '''
        cmds = []
        for op, argsdict in req:
            # names are sent as-is: the server only unescapes the values
            args = ','.join(['%s=%s' % (name, escapearg(value))
                             for name, value in argsdict.iteritems()])
            cmds.append('%s %s' % (op, args))
        rsp = self._call("batch", cmds=';'.join(cmds))
        return [unescapearg(encres) for encres in rsp.split(';')]
    def _submitone(self, op, args):
        '''run the single command op, passing the args dict to self._call
        as keyword arguments, and return its result'''
        return self._call(op, **args)

    @batchable
    def lookup(self, key):
        '''resolve key on the remote and produce the result of bin() on
        the returned data

        Batchable coroutine. The reply is parsed as "<flag> <data>"
        followed by one trailing character; a zero flag means data is an
        error message, which is passed to self._abort in an
        error.RepoError.
        '''
        self.requirecap('lookup', _('look up remote revision'))
        reply = future()
        encargs = todict(key=encoding.fromlocal(key))
        yield encargs, reply
        # drop the trailing character before splitting flag from payload
        flag, payload = reply.value[:-1].split(" ", 1)
        if int(flag):
            yield bin(payload)
        self._abort(error.RepoError(payload))

    @batchable
    def heads(self):
        '''produce the list of nodes decoded from the remote's reply

        Batchable coroutine taking no wire arguments. The last character
        of the reply is dropped before decoding. A ValueError while
        decoding is reported through self._abort as an
        error.ResponseError.
        '''
        reply = future()
        yield {}, reply
        raw = reply.value
        try:
            yield decodelist(raw[:-1])
        except ValueError:
            self._abort(error.ResponseError(_("unexpected response:"), raw))

    @batchable
    def known(self, nodes):
        '''ask the remote about nodes; produce a list of booleans

        Batchable coroutine. Every character of the reply is converted
        with int() and then bool(). A character that int() rejects is
        reported through self._abort as an error.ResponseError.
        '''
        reply = future()
        yield todict(nodes=encodelist(nodes)), reply
        raw = reply.value
        try:
            yield [bool(int(flag)) for flag in raw]
        except ValueError:
            self._abort(error.ResponseError(_("unexpected response:"), raw))

    @batchable
    def branchmap(self):
        '''produce a dict mapping local branch names to lists of heads

        Batchable coroutine taking no wire arguments. Each line of the
        reply is "<url-quoted branch name> <space separated nodes>". A
        malformed reply is reported through self._abort as an
        error.ResponseError.
        '''
        f = future()
        yield {}, f
        d = f.value
        try:
            branchmap = {}
            for branchpart in d.splitlines():
                branchname, branchheads = branchpart.split(' ', 1)
                branchname = encoding.tolocal(urllib.unquote(branchname))
                branchheads = decodelist(branchheads)
                branchmap[branchname] = branchheads
            yield branchmap
        except (TypeError, ValueError):
            # ValueError: a line without a space cannot be unpacked into
            # name and heads. TypeError was already handled here;
            # presumably it comes from bin() rejecting malformed hex.
            self._abort(error.ResponseError(_("unexpected response:"), d))

    def branches(self, nodes):
        '''send the "branches" command for nodes

        Returns one tuple of decoded nodes per line of the reply. A
        ValueError while decoding is reported through self._abort as an
        error.ResponseError.
        '''
        reply = self._call("branches", nodes=encodelist(nodes))
        try:
            return [tuple(decodelist(line)) for line in reply.splitlines()]
        except ValueError:
            self._abort(error.ResponseError(_("unexpected response:"), reply))

    def between(self, pairs):
        '''send the "between" command for pairs, eight pairs per request

        Each pair is encoded with '-' between its nodes and the pairs of
        one request are joined with spaces. Returns a list holding, for
        every line of every reply, the decoded list of nodes (an empty
        list for an empty line).
        '''
        chunksize = 8 # avoid giant requests
        found = []
        for start in xrange(0, len(pairs), chunksize):
            chunk = pairs[start:start + chunksize]
            encoded = " ".join([encodelist(p, '-') for p in chunk])
            reply = self._call("between", pairs=encoded)
            try:
                for line in reply.splitlines():
                    if line:
                        found.append(decodelist(line))
                    else:
                        found.append([])
            except ValueError:
                self._abort(error.ResponseError(_("unexpected response:"),
                                                reply))
        return found

    @batchable
    def pushkey(self, namespace, key, old, new):
        '''update key in namespace from old to new on the remote

        Batchable coroutine. Produces False locally when the remote lacks
        the 'pushkey' capability. Otherwise the first line of the reply is
        an integer status, produced as a boolean, and the remaining lines
        are echoed through self.ui.status with a "remote: " prefix.

        Raises error.ResponseError if the status line is not an integer.
        '''
        if not self.capable('pushkey'):
            yield False, None
        reply = future()
        self.ui.debug('preparing pushkey for "%s:%s"\n' % (namespace, key))
        encargs = todict(namespace=encoding.fromlocal(namespace),
                         key=encoding.fromlocal(key),
                         old=encoding.fromlocal(old),
                         new=encoding.fromlocal(new))
        yield encargs, reply
        status, output = reply.value.split('\n', 1)
        try:
            ok = bool(int(status))
        except ValueError:
            raise error.ResponseError(
                _('push failed (unexpected response):'), status)
        # keep line endings so the remote output is echoed verbatim
        for line in output.splitlines(True):
            self.ui.status(_('remote: '), line)
        yield ok

    @batchable
    def listkeys(self, namespace):
        '''produce a dict of the keys the remote holds in namespace

        Batchable coroutine. Produces an empty dict locally when the
        remote lacks the 'pushkey' capability. Each line of the reply is
        "<key>\\t<value>"; both are converted with encoding.tolocal.
        '''
        if not self.capable('pushkey'):
            yield {}, None
        reply = future()
        self.ui.debug('preparing listkeys for "%s"\n' % namespace)
        yield todict(namespace=encoding.fromlocal(namespace)), reply
        keys = {}
        for line in reply.value.splitlines():
            name, value = line.split('\t')
            keys[encoding.tolocal(name)] = encoding.tolocal(value)
        yield keys

    def stream_out(self):
        '''issue the "stream_out" command and return whatever
        self._callstream produces for it'''
        return self._callstream('stream_out')

    def changegroup(self, nodes, kind):
        '''request a changegroup rooted at nodes

        nodes is sent encoded as the "roots" argument; kind is not sent.
        Returns a changegroupmod.unbundle10 built on the decompressed
        reply stream.
        '''
        stream = self._callstream("changegroup", roots=encodelist(nodes))
        plain = self._decompress(stream)
        return changegroupmod.unbundle10(plain, 'UN')

    def changegroupsubset(self, bases, heads, kind):
        '''request a changegroup described by bases and heads

        Requires the 'changegroupsubset' capability; kind is not sent.
        Returns a changegroupmod.unbundle10 built on the decompressed
        reply stream.
        '''
        self.requirecap('changegroupsubset', _('look up remote changes'))
        stream = self._callstream("changegroupsubset",
                                  bases=encodelist(bases),
                                  heads=encodelist(heads))
        plain = self._decompress(stream)
        return changegroupmod.unbundle10(plain, 'UN')

    def getbundle(self, source, heads=None, common=None):
        '''request a bundle from the remote

        Requires the 'getbundle' capability. heads and common are only
        sent when they are not None; source is not sent. Returns a
        changegroupmod.unbundle10 built on the decompressed reply stream.
        '''
        self.requirecap('getbundle', _('look up remote changes'))
        opts = {}
        for name, nodes in (('heads', heads), ('common', common)):
            if nodes is not None:
                opts[name] = encodelist(nodes)
        stream = self._callstream("getbundle", **opts)
        plain = self._decompress(stream)
        return changegroupmod.unbundle10(plain, 'UN')

    def unbundle(self, cg, heads, source):
        '''Send cg (a readable file-like object representing the
        changegroup to push, typically a chunkbuffer object) to the
        remote server as a bundle. Return an integer indicating the
        result of the push (see localrepository.addchangegroup()).

        Unless heads is ['force'], a remote with the 'unbundlehash'
        capability is sent 'hashed' plus the sha1 digest of the sorted
        heads instead of the heads themselves. Output returned by the
        remote is echoed through self.ui.status. Raises
        error.ResponseError when the result is empty or not an integer.'''

        if heads == ['force'] or not self.capable('unbundlehash'):
            heads = encodelist(heads)
        else:
            digest = util.sha1(''.join(sorted(heads))).digest()
            heads = encodelist(['hashed', digest])

        ret, output = self._callpush("unbundle", cg, heads=heads)
        if ret == "":
            # with no result code, the output carries the failure text
            raise error.ResponseError(_('push failed:'), output)
        try:
            ret = int(ret)
        except ValueError:
            raise error.ResponseError(
                _('push failed (unexpected response):'), ret)

        for line in output.splitlines(True):
            self.ui.status(_('remote: '), line)
        return ret

    def debugwireargs(self, one, two, three=None, four=None, five=None):
        '''send the "debugwireargs" command and return its reply

        one and two are always sent; three and four only when they are
        not None. five is accepted but never sent.
        '''
        # don't pass optional arguments left at their default value
        opts = {}
        for name, value in (('three', three), ('four', four)):
            if value is not None:
                opts[name] = value
        return self._call('debugwireargs', one=one, two=two, **opts)

# server side

class streamres(object):
    '''command result wrapping a generator of response chunks

    Handlers in this module return it around proto.groupchunks(...) or
    around the stream_out generator.
    '''
    def __init__(self, gen):
        # iterable producing the pieces of the response
        self.gen = gen

class pushres(object):
    '''command result wrapping the outcome of a push

    unbundle() returns it around the value from repo.addchangegroup (or 0
    when that call aborted).
    '''
    def __init__(self, res):
        self.res = res

class pusherr(object):
    '''command result wrapping the reason a push was refused

    unbundle() returns pusherr('unsynced changes') when its head check
    fails.
    '''
    def __init__(self, res):
        self.res = res

class ooberror(object):
    '''command result wrapping an error message

    batch() stops at the first handler result of this type and returns it
    unchanged instead of escaping and joining it with the other results.
    (The name presumably stands for "out-of-band error".)
    '''
    def __init__(self, message):
        self.message = message

def dispatch(repo, proto, command):
    '''run the handler registered for command and return its result

    The handler and its argument spec come from the commands table; the
    arguments themselves are obtained from proto.getargs(spec). An
    unregistered command raises KeyError.
    '''
    handler, argspec = commands[command]
    wireargs = proto.getargs(argspec)
    return handler(repo, proto, *wireargs)

def options(cmd, keys, others):
    '''move the entries named in keys out of others into a new dict

    others is modified in place: every accepted entry is removed from it.
    If anything is left over afterwards, a message naming cmd and the
    leftover argument names is written to stderr; processing continues
    regardless. Returns the dict of accepted entries.
    '''
    accepted = {}
    for name in keys:
        if name not in others:
            continue
        accepted[name] = others.pop(name)
    if others:
        leftovers = ",".join(others)
        sys.stderr.write("abort: %s got unexpected arguments %s\n"
                         % (cmd, leftovers))
    return accepted

def batch(repo, proto, cmds, others):
    '''run several commands received in one "batch" request

    cmds holds the commands joined with ';', each one being
    "<op> <name>=<value>,<name>=<value>..." with values escaped as done by
    escapearg(). Every command is looked up in the commands table and
    called with the arguments its spec names; a '*' in the spec receives a
    dict of the supplied arguments not named elsewhere in the spec.

    Returns the escaped results joined with ';'. If a handler returns an
    ooberror, processing stops and that object is returned as-is.

    An unknown op, or an argument named by the spec but not supplied,
    raises KeyError. others is not used.
    '''
    res = []
    for pair in cmds.split(';'):
        op, args = pair.split(' ', 1)
        vals = {}
        for a in args.split(','):
            if a:
                # split on the first '=' only: an escaped '=' inside the
                # value is spelled ':=' and has to stay with the value
                n, v = a.split('=', 1)
                vals[n] = unescapearg(v)
        func, spec = commands[op]
        if spec:
            keys = spec.split()
            data = {}
            for k in keys:
                if k == '*':
                    star = {}
                    for key in vals.keys():
                        if key not in keys:
                            star[key] = vals[key]
                    data['*'] = star
                else:
                    data[k] = vals[k]
            result = func(repo, proto, *[data[k] for k in keys])
        else:
            result = func(repo, proto)
        if isinstance(result, ooberror):
            return result
        res.append(escapearg(result))
    return ';'.join(res)

def between(repo, proto, pairs):
    '''serve "between": decode the space-separated, '-'-joined pairs, pass
    them to repo.between and return one line of encoded nodes per result'''
    decoded = [decodelist(pair, '-') for pair in pairs.split(" ")]
    lines = [encodelist(nodes) + "\n" for nodes in repo.between(decoded)]
    return "".join(lines)

def branchmap(repo, proto):
    '''serve "branchmap": one "<url-quoted branch name> <encoded heads>"
    line per entry of discovery.visiblebranchmap(repo), joined with
    newlines (no trailing newline)'''
    lines = []
    for branch, nodes in discovery.visiblebranchmap(repo).iteritems():
        quoted = urllib.quote(encoding.fromlocal(branch))
        lines.append('%s %s' % (quoted, encodelist(nodes)))
    return '\n'.join(lines)

def branches(repo, proto, nodes):
    '''serve "branches": decode nodes, pass them to repo.branches and
    return one line of encoded nodes per result'''
    results = repo.branches(decodelist(nodes))
    lines = [encodelist(entry) + "\n" for entry in results]
    return "".join(lines)

def capabilities(repo, proto):
    '''return the space-separated capability tokens of this server

    A fixed set of tokens is always present. When streaming is allowed,
    'stream-preferred' is added if server.preferuncompressed is set, plus
    either 'stream' or a 'streamreqs=' token listing the required formats.
    'unbundle=' (bundle types) and 'httpheader=1024' come last.
    '''
    caps = ['lookup', 'changegroupsubset', 'branchmap', 'pushkey', 'known',
            'getbundle', 'unbundlehash', 'batch']
    if _allowstream(repo.ui):
        if repo.ui.configbool('server', 'preferuncompressed', False):
            caps.append('stream-preferred')
        requiredformats = repo.requirements & repo.supportedformats
        if requiredformats - set(('revlogv1',)):
            # something besides revlogv1 is required: add 'streamreqs'
            # detailing our local revlog format
            caps.append('streamreqs=%s' % ','.join(requiredformats))
        else:
            # our local revlogs are just revlogv1: plain 'stream' cap
            caps.append('stream')
    caps.append('unbundle=%s' % ','.join(changegroupmod.bundlepriority))
    caps.append('httpheader=1024')
    return ' '.join(caps)

def changegroup(repo, proto, roots):
    '''serve "changegroup": build a changegroup from the decoded roots and
    return it, chunked by proto.groupchunks, inside a streamres'''
    cg = repo.changegroup(decodelist(roots), 'serve')
    chunks = proto.groupchunks(cg)
    return streamres(chunks)

def changegroupsubset(repo, proto, bases, heads):
    '''serve "changegroupsubset": build a changegroup from the decoded
    bases and heads and return it, chunked by proto.groupchunks, inside a
    streamres'''
    basenodes = decodelist(bases)
    headnodes = decodelist(heads)
    cg = repo.changegroupsubset(basenodes, headnodes, 'serve')
    chunks = proto.groupchunks(cg)
    return streamres(chunks)

def debugwireargs(repo, proto, one, two, others):
    '''serve "debugwireargs": forward one, two and the optional 'three'
    and 'four' found in others to repo.debugwireargs'''
    # only accept optional args from the known set
    accepted = options('debugwireargs', ['three', 'four'], others)
    return repo.debugwireargs(one, two, **accepted)

def getbundle(repo, proto, others):
    '''serve "getbundle": decode the optional 'heads' and 'common' found in
    others, pass them to repo.getbundle and return the result, chunked by
    proto.groupchunks, inside a streamres'''
    raw = options('getbundle', ['heads', 'common'], others)
    decoded = dict([(name, decodelist(value))
                    for name, value in raw.iteritems()])
    cg = repo.getbundle('serve', **decoded)
    chunks = proto.groupchunks(cg)
    return streamres(chunks)

def heads(repo, proto):
    '''serve "heads": the encoded result of discovery.visibleheads(repo)
    followed by a newline'''
    visible = discovery.visibleheads(repo)
    return "%s\n" % encodelist(visible)

def hello(repo, proto):
    '''the hello command returns a set of lines describing various
    interesting things about the server, in an RFC822-like format.
    Currently the only one defined is "capabilities", which
    consists of a line in the form:

    capabilities: space separated list of tokens

    The tokens are those returned by capabilities().
    '''
    caps = capabilities(repo, proto)
    return "capabilities: %s\n" % caps

def listkeys(repo, proto, namespace):
    '''serve "listkeys": one "<key>\\t<value>" line per entry that
    repo.listkeys returns for namespace, joined with newlines (no trailing
    newline)'''
    entries = repo.listkeys(encoding.tolocal(namespace)).items()
    lines = []
    for k, v in entries:
        lines.append('%s\t%s' % (encoding.fromlocal(k),
                                 encoding.fromlocal(v)))
    return '\n'.join(lines)

def lookup(repo, proto, key):
    try:
        k = encoding.tolocal(key)
        c = repo[k]
        if c.phase() == phases.secret:
            raise error.RepoLookupError(_("unknown revision '%s'") % k)
        r = c.hex()
        success = 1
    except Exception, inst:
        r = str(inst)
        success = 0
    return "%s %s\n" % (success, r)

def known(repo, proto, nodes, others):
    '''serve "known": a string with one "1" or "0" per item that
    repo.known returns for the decoded nodes; others is ignored'''
    flags = []
    for present in repo.known(decodelist(nodes)):
        if present:
            flags.append("1")
        else:
            flags.append("0")
    return ''.join(flags)

def pushkey(repo, proto, namespace, key, old, new):
    '''serve "pushkey": convert the arguments to the local encoding, call
    repo.pushkey and return its result as an integer followed by a newline

    new is left unconverted when it is 20 characters long, changes under
    string-escape encoding and does not decode as UTF-8.
    '''
    # compatibility with pre-1.8 clients which were accidentally
    # sending raw binary nodes rather than utf-8-encoded hex
    looksbinary = len(new) == 20 and new.encode('string-escape') != new
    if not looksbinary:
        new = encoding.tolocal(new) # normal path
    else:
        # looks like it could be a binary node
        try:
            new.decode('utf-8')
            new = encoding.tolocal(new) # but cleanly decodes as UTF-8
        except UnicodeDecodeError:
            pass # binary, leave unmodified

    localns = encoding.tolocal(namespace)
    localkey = encoding.tolocal(key)
    localold = encoding.tolocal(old)
    r = repo.pushkey(localns, localkey, localold, new)
    return '%s\n' % int(r)

def _allowstream(ui):
    '''return the boolean server.uncompressed setting (default True),
    read with untrusted=True; stream() and capabilities() use it to decide
    whether streaming is offered'''
    return ui.configbool('server', 'uncompressed', True, untrusted=True)

def stream(repo, proto):
    '''If the server supports streaming clone, it advertises the "stream"
    capability with a value representing the version and flags of the repo
    it is serving. Client checks to see if it understands the format.

    The format is simple: the server first writes a status line ("0" on
    success), then a line with the number of files and the total number
    of bytes to be transferred (separated by a space). Then, for each
    file, the server first writes the filename and filesize (separated by
    the null character), then the file contents.

    Returns the string '1\\n' when streaming is disabled and '2\\n' when
    the repository lock cannot be acquired; otherwise a streamres wrapping
    the generator that produces the data described above.
    '''

    if not _allowstream(repo.ui):
        return '1\n'

    entries = []
    total_bytes = 0
    try:
        # get consistent snapshot of repo, lock during scan
        lock = repo.lock()
        try:
            repo.ui.debug('scanning\n')
            for name, ename, size in repo.store.walk():
                entries.append((name, size))
                total_bytes += size
        finally:
            lock.release()
    except error.LockError:
        return '2\n' # error: 2

    def streamer(repo, entries, total):
        '''stream out all metadata files in repository.'''
        yield '0\n' # success
        # report the total handed in by the caller rather than reaching
        # into the enclosing scope for it
        repo.ui.debug('%d files, %d bytes to transfer\n' %
                      (len(entries), total))
        yield '%d %d\n' % (len(entries), total)
        for name, size in entries:
            repo.ui.debug('sending %s (%d bytes)\n' % (name, size))
            # partially encode name over the wire for backwards compat
            yield '%s\0%d\n' % (store.encodedir(name), size)
            # never send more than the size recorded during the scan
            for chunk in util.filechunkiter(repo.sopener(name), limit=size):
                yield chunk

    return streamres(streamer(repo, entries, total_bytes))

def unbundle(repo, proto, heads):
    '''serve "unbundle": receive a bundle and add it to the repository

    heads is the encoded list the client believes to be the current heads;
    it may decode to ['force'] or to ['hashed', <sha1 digest of the sorted
    heads>] instead of the heads themselves.

    Returns pusherr('unsynced changes') if the client's view of the heads
    does not match the repository, either before the transfer or after the
    lock has been taken. Otherwise returns pushres(r) where r is the
    result of repo.addchangegroup, or 0 if that call raised util.Abort
    (the abort message is written to stderr).
    '''
    their_heads = decodelist(heads)

    def check_heads():
        # true if the push is forced or the client's heads (or their
        # hash) match the currently visible heads
        heads = discovery.visibleheads(repo)
        heads_hash = util.sha1(''.join(sorted(heads))).digest()
        return (their_heads == ['force'] or their_heads == heads or
                their_heads == ['hashed', heads_hash])

    proto.redirect()

    # fail early if possible
    if not check_heads():
        return pusherr('unsynced changes')

    # write bundle data to temporary file because it can be big
    fd, tempname = tempfile.mkstemp(prefix='hg-unbundle-')
    fp = os.fdopen(fd, 'wb+')
    r = 0
    try:
        # the transfer happens before the lock is taken
        proto.getfile(fp)
        lock = repo.lock()
        try:
            if not check_heads():
                # someone else committed/pushed/unbundled while we
                # were transferring data
                return pusherr('unsynced changes')

            # push can proceed
            fp.seek(0)
            gen = changegroupmod.readbundle(fp, None)

            try:
                r = repo.addchangegroup(gen, 'serve', proto._client())
            except util.Abort, inst:
                # r keeps its initial value of 0 in this case
                sys.stderr.write("abort: %s\n" % inst)
        finally:
            lock.release()
        return pushres(r)

    finally:
        # runs on every exit path once the temporary file exists
        fp.close()
        os.unlink(tempname)

# Table of the wire commands served by dispatch() and batch():
# command name -> (handler, argument spec).
#
# The spec is a space-separated list of argument names; dispatch() hands
# it to proto.getargs() and passes the returned values to the handler after
# repo and proto. An empty spec means the handler takes no wire arguments.
# In batch(), a '*' entry receives a dict of all supplied arguments that
# are not named elsewhere in the spec.
commands = {
    'batch': (batch, 'cmds *'),
    'between': (between, 'pairs'),
    'branchmap': (branchmap, ''),
    'branches': (branches, 'nodes'),
    'capabilities': (capabilities, ''),
    'changegroup': (changegroup, 'roots'),
    'changegroupsubset': (changegroupsubset, 'bases heads'),
    'debugwireargs': (debugwireargs, 'one two *'),
    'getbundle': (getbundle, '*'),
    'heads': (heads, ''),
    'hello': (hello, ''),
    'known': (known, 'nodes *'),
    'listkeys': (listkeys, 'namespace'),
    'lookup': (lookup, 'key'),
    'pushkey': (pushkey, 'namespace key old new'),
    'stream_out': (stream, ''),
    'unbundle': (unbundle, 'heads'),
}