root/branches/mk/cheesecake/model.py

Revision 144, 22.2 kB (checked in by mk, 7 years ago)

Fixed nasty code parsing bug.

Line 
1 """
2 Code borrowed from Michael Hudson's docextractor package with the author's
3 permission.
4
5 The original code is available at http://codespeak.net/svn/user/mwh/docextractor/.
6
7 Changes:
8   * do not print warnings to stdout (in System.warning)
9   * collect all function calls
10 """
11
12
13 from compiler import ast
14 import sys
15 import os
16 import cPickle as pickle
17 import __builtin__
18 import sets
19
20 from compiler.transformer import parse, parseFile
21 from compiler.visitor import walk
22
23 import ast_pp
24
25
def get_call_name(node):
    """Return the dotted name of the callable invoked by a CallFunc node.

    `node` must be an `ast.CallFunc`; its `.node` attribute (the expression
    being called) is turned into a name:

      * a plain `ast.Name` yields its identifier, e.g. ``foo``;
      * an `ast.Getattr` chain yields a dotted name, e.g. ``self.foo.bar``.

    Returns None when the callee has no static name, for example when it
    is itself the result of a call (``foo()()``) or when the object whose
    attribute is called cannot be named (``foo().bar()``, ``"".join(x)``).
    """
    assert isinstance(node, ast.CallFunc)

    def get_name(node):
        # Resolve one expression (or a Getattr's child tuple) to a name.
        if isinstance(node, ast.Name):
            return node.name
        if isinstance(node, str):
            return node
        if isinstance(node, ast.Getattr):
            # asList() gives (expression, attribute name).
            return get_name(node.asList())
        if isinstance(node, tuple):
            if len(node) == 1:
                return node[0]
            base = get_name(node[0])
            if base is None:
                # The receiver has no static name; formatting it would
                # produce a bogus "None.<attr>" entry, so give up instead.
                return None
            return "%s.%s" % (base, node[-1])
        # CallFunc and every other node type: no static name.
        return None

    return get_name(node.node)
47
def get_function_calls(node, fc):
    """Record the name of every call found underneath `node` into `fc`.

    `fc` is a dict used as a set: each resolvable callee name (as returned
    by get_call_name) is stored as a key with the value 1.  Anything that
    is not an `ast.Node` is ignored.  The walk is recursive and covers all
    children returned by `getChildren()`.
    """
    if isinstance(node, ast.Node):
        for kid in node.getChildren():
            callee = isinstance(kid, ast.CallFunc) and get_call_name(kid)
            if callee:
                fc[callee] = 1

            get_function_calls(kid, fc)
59
60
class Documentable(object):
    """Base class for every object tracked by a System.

    Attributes set by __init__:
      system    -- the owning System
      prefix    -- string prepended to `name` to build the full name
      name      -- the object's own (undotted) name
      docstring -- the docstring text, or None
      parent    -- the enclosing Documentable, or None

    Attributes set by setup():
      contents        -- dict mapping child name -> child object
      orderedcontents -- list of children
      _name2fullname  -- dict mapping a local name -> full dotted name
    """
    def __init__(self, system, prefix, name, docstring, parent=None):
        self.system = system
        self.prefix = prefix
        self.name = name
        self.docstring = docstring
        self.parent = parent
        self.setup()
    def setup(self):
        """Create the empty per-object containers; subclasses extend this."""
        self.contents = {}
        self.orderedcontents = []
        self._name2fullname = {}
    def fullName(self):
        """Return prefix + name."""
        return self.prefix + self.name
    def shortdocstring(self):
        """Return the docstring abbreviated for display.

        Trailing whitespace is stripped; anything longer than 20
        characters is reduced to its first 8 and last 8 characters joined
        by '...'.  A false docstring (None or '') is returned unchanged.
        """
        docstring = self.docstring
        if docstring:
            docstring = docstring.rstrip()
            if len(docstring) > 20:
                docstring = docstring[:8] + '...' + docstring[-8:]
        return docstring
    def __repr__(self):
        return "%s %r"%(self.__class__.__name__, self.fullName())
    def name2fullname(self, name):
        """Map a local name to a full name, delegating to the parent.

        Looks in this object's own _name2fullname first and otherwise
        asks self.parent.  There is no guard for parent being None here.
        """
        if name in self._name2fullname:
            return self._name2fullname[name]
        else:
            return self.parent.name2fullname(name)

    def resolveDottedName(self, dottedname, verbose=False):
        """Resolve `dottedname`, as seen from this object, to an object.

        Returns the resolved object, or None if any step fails.  With
        verbose=True a diagnostic line is printed for each failure and
        for the final success.
        """
        parts = dottedname.split('.')
        obj = self
        system = self.system
        # Walk outwards through the parents until some scope knows the
        # first component of the name.
        while parts[0] not in obj._name2fullname:
            obj = obj.parent
            if obj is None:
                # Ran out of enclosing scopes: treat the first component
                # as a top-level full name, in this system first...
                if parts[0] in system.allobjects:
                    obj = system.allobjects[parts[0]]
                    break
                # ...and then in each of the additional systems.
                for othersys in system.moresystems:
                    if parts[0] in othersys.allobjects:
                        obj = othersys.allobjects[parts[0]]
                        break
                else:
                    # for/else: no system knows the name.
                    if verbose:
                        print "1 didn't find %r from %r"%(dottedname,
                                                      self.fullName())
                    return None
                # Found in another system; leave the while loop.
                break
        else:
            # while/else: runs only when the loop ended without a break,
            # i.e. some scope maps parts[0] to a full name.
            fn = obj._name2fullname[parts[0]]
            if fn in system.allobjects:
                obj = system.allobjects[fn]
            else:
                if verbose:
                    print "1.5 didn't find %r from %r"%(dottedname,
                                                        self.fullName())
                return None
        # Descend through the remaining components via `contents`.
        for p in parts[1:]:
            if p not in obj.contents:
                if verbose:
                    print "2 didn't find %r from %r"%(dottedname,
                                                      self.fullName())
                return None
            obj = obj.contents[p]
        if verbose:
            print dottedname, '->', obj.fullName(), 'in', self.fullName()
        return obj

    def dottedNameToFullName(self, dottedname):
        """Expand the first component of `dottedname` to its full name.

        The first component is looked up in _name2fullname of this object
        and then of each parent in turn; the rest of the name is appended
        unchanged.  If no scope knows the first component, `dottedname`
        is returned as given.
        """
        if '.' not in dottedname:
            start, rest = dottedname, ''
        else:
            start, rest = dottedname.split('.', 1)
            rest = '.' + rest
        obj = self
        while start not in obj._name2fullname:
            obj = obj.parent
            if obj is None:
                return dottedname
        return obj._name2fullname[start] + rest

    def __getstate__(self):
        """Return a pickle state with object references replaced by names.

        References to other Documentables are stored as full-name
        strings under a key with a one-character marker prepended:

          '$' + key -- the value was a single Documentable
          '@' + key -- the value was a non-empty list containing only
                       Documentables and/or None
          '!' + key -- the value was a non-empty dict whose values were
                       all Documentables

        Everything else is stored under its original key unchanged.
        System.__setstate__ reverses the encoding.
        """
        # this is so very, very evil.
        # see doc/extreme-pickling-pain.txt for more.
        r = {}
        for k, v in self.__dict__.iteritems():
            if isinstance(v, Documentable):
                r['$'+k] = v.fullName()
            elif isinstance(v, list) and v:
                for vv in v:
                    if vv is not None and not isinstance(vv, Documentable):
                        # A foreign element: keep the list as it is.
                        r[k] = v
                        break
                else:
                    rr = []
                    for vv in v:
                        if vv is None:
                            rr.append(vv)
                        else:
                            rr.append(vv.fullName())
                    r['@'+k] = rr
            elif isinstance(v, dict) and v:
                for vv in v.itervalues():
                    if not isinstance(vv, Documentable):
                        # A foreign value: keep the dict as it is.
                        r[k] = v
                        break
                else:
                    rr = {}
                    for kk, vv in v.iteritems():
                        rr[kk] = vv.fullName()
                    r['!'+k] = rr
            else:
                r[k] = v
        return r
176
class Package(Documentable):
    """Documentable for a package, as created by System.pushPackage."""
    kind = "Package"

    def name2fullname(self, name):
        """Always raise NameError: a package never resolves local names."""
        raise NameError()
181
182
class Module(Documentable):
    """Documentable for a module; the outermost scope for name lookup."""
    kind = "Module"

    def name2fullname(self, name):
        """Map a name used in this module to a full name.

        Names recorded in _name2fullname are translated.  Any other name
        is returned unchanged; unless it is a builtin, an "optimistic
        name resolution" warning is recorded on the system first.
        """
        try:
            return self._name2fullname[name]
        except KeyError:
            pass
        if name not in __builtin__.__dict__:
            self.system.warning("optimistic name resolution", name)
        return name
193
194
class Class(Documentable):
    """Documentable for a class definition.

    In addition to the base containers, setup() creates four lists:
    bases (full names), rawbases (base expressions as written),
    baseobjects (resolved base objects or None) and subclasses.
    """
    kind = "Class"

    def setup(self):
        super(Class, self).setup()
        for attr in ('bases', 'rawbases', 'baseobjects', 'subclasses'):
            setattr(self, attr, [])
203
204
class Function(Documentable):
    """Documentable for a function definition (see System.pushFunction)."""
    # Label for this kind of object; not read anywhere in this file,
    # presumably consumed by reporting code elsewhere.
    kind = "Function"
207
208
209 class ModuleVistor(object):
210     def __init__(self, system, modname):
211         self.system = system
212         self.modname = modname
213         self.morenodes = []
214
215     def default(self, node):
216         for child in node.getChildNodes():
217             self.visit(child)
218
219     def postpone(self, docable, node):
220         self.morenodes.append((docable, node))
221
222     def visitModule(self, node):
223         if self.system.current and self.modname in self.system.current.contents:
224             m = self.system.current.contents[self.modname]
225             assert m.docstring is None
226             m.docstring = node.doc
227             self.system.push(m, node)
228             self.default(node)
229             self.system.pop(m)
230         else:
231             if not self.system.current:
232                 roots = [x for x in self.system.rootobjects if x.name == self.modname]
233                 if roots:
234                     mod, = roots
235                     self.system.push(mod, node)
236                     self.default(node)
237                     self.system.pop(mod)
238                     return
239             self.system.pushModule(self.modname, node.doc)
240             self.default(node)
241             self.system.popModule()
242
243     def visitClass(self, node):
244         cls = self.system.pushClass(node.name, node.doc)
245         if node.lineno is not None:
246             cls.linenumber = node.lineno
247         for n in node.bases:
248             str_base = ast_pp.pp(n)
249             cls.rawbases.append(str_base)
250             base = cls.dottedNameToFullName(str_base)
251             cls.bases.append(base)
252         self.default(node)
253         self.system.popClass()
254
255     def visitFrom(self, node):
256         modname = expandModname(self.system, node.modname)
257         name2fullname = self.system.current._name2fullname
258         for fromname, asname in node.names:
259             if fromname == '*':
260                 self.system.warning("import *", modname)
261                 if modname not in self.system.allobjects:
262                     return
263                 mod = self.system.allobjects[modname]
264                 # this might fail if you have an import-* cycle, or if
265                 # you're just not running the import star finder to
266                 # save time (not that this is possibly without
267                 # commenting stuff out yet, but...)
268                 if isinstance(mod, Package):
269                     self.system.warning("import * from a package", modname)
270                     return
271                 if mod.processed:
272                     for n in mod.contents:
273                         name2fullname[n] = modname + '.' + n
274                 else:
275                     self.system.warning("unresolvable import *", modname)
276                 return
277             if asname is None:
278                 asname = fromname
279             name2fullname[asname] = modname + '.' + fromname
280
281     def visitImport(self, node):
282         name2fullname = self.system.current._name2fullname
283         for fromname, asname in node.names:
284             fullname = expandModname(self.system, fromname)
285             if asname is None:
286                 asname = fromname.split('.', 1)[0]
287                 # aaaaargh! python sucks.
288                 parts = fullname.split('.')
289                 for i, part in enumerate(fullname.split('.')[::-1]):
290                     if part == asname:
291                         fullname = '.'.join(parts[:len(parts)-i])
292                         name2fullname[asname] = fullname
293                         break
294                 else:
295                     name2fullname[asname] = '.'.join(parts)
296             else:
297                 name2fullname[asname] = fullname
298
299     def visitFunction(self, node):
300         fc = {}
301         get_function_calls(node, fc)
302         func = self.system.pushFunction(node.name, node.doc, fc)
303         if node.lineno is not None:
304             func.linenumber = node.lineno
305         # ast.Function has a pretty lame representation of
306         # arguments. Let's convert it to a nice concise format
307         # somewhat like what inspect.getargspec returns
308         argnames = node.argnames[:]
309         kwname = starargname = None
310         if node.kwargs:
311             kwname = argnames.pop(-1)
312         if node.varargs:
313             starargname = argnames.pop(-1)
314         defaults = []
315         for default in node.defaults:
316             try:
317                 defaults.append(ast_pp.pp(default))
318             except (KeyboardInterrupt, SystemExit):
319                 raise
320             except Exception, e:
321                 self.system.warning("unparseable default", "%s: %s %r"%(e.__class__.__name__,
322                                                                        e, default))
323                 defaults.append('???')
324         # argh, convert unpacked-arguments from tuples to lists,
325         # because that's what getargspec uses and the unit test
326         # compares it
327         argnames2 = []
328         for argname in argnames:
329             if isinstance(argname, tuple):
330                 argname = list(argname)
331             argnames2.append(argname)
332         func.argspec = (argnames2, starargname, kwname, tuple(defaults))
333         self.postpone(func, node.code)
334         self.system.popFunction()
335
# Values taken by System.state, in the order the processing functions
# in this module move through them.
states = ['blank', 'preparse', 'importstarred', 'parsed', 'finalized']
343
344
345 class System(object):
346     Class = Class
347     Module = Module
348     Package = Package
349     Function = Function
350     ModuleVistor = ModuleVistor
351
352     def __init__(self):
353         self.current = None
354         self._stack = []
355         self.allobjects = {}
356         self.orderedallobjects = []
357         self.rootobjects = []
358         self.warnings = {}
359         # importstargraph contains edges {importer:[imported]} but only
360         # for import * statements
361         self.importstargraph = {}
362         self.func_called = {}
363         self.state = 'blank'
364         self.packages = []
365         self.moresystems = []
366         self.urlprefix = ''
367
368     def _push(self, cls, name, docstring):
369         if self.current:
370             prefix = self.current.fullName() + '.'
371             parent = self.current
372         else:
373             prefix = ''
374             parent = None
375         obj = cls(self, prefix, name, docstring, parent)
376         if parent:
377             parent.orderedcontents.append(obj)
378             parent.contents[name] = obj
379             parent._name2fullname[name] = obj.fullName()
380         else:
381             self.rootobjects.append(obj)
382         self.current = obj
383         self.orderedallobjects.append(obj)
384         fullName = obj.fullName()
385         #print 'push', cls.__name__, fullName
386         if fullName in self.allobjects:
387             obj = self.handleDuplicate(obj)
388         else:
389             self.allobjects[obj.fullName()] = obj
390         return obj
391
392     def handleDuplicate(self, obj):
393         '''This is called when we see two objects with the same
394         .fullName(), for example:
395
396         class C:
397             if something:
398                 def meth(self):
399                     implementation 1
400             else:
401                 def meth(self):
402                     implementation 2
403
404         The default is that the second definition "wins".
405         '''
406         i = 0
407         fn = obj.fullName()
408         while (fn + ' ' + str(i)) in self.allobjects:
409             i += 1
410         prev = self.allobjects[obj.fullName()]
411         prev.name = obj.name + ' ' + str(i)
412         self.allobjects[prev.fullName()] = prev
413         self.warning("duplicate", self.allobjects[obj.fullName()])
414         self.allobjects[obj.fullName()] = obj
415         return obj
416
417
418     def _pop(self, cls):
419         assert isinstance(self.current, cls)
420 ##         if self.current.parent:
421 ##             print 'pop', self.current.fullName(), '->', self.current.parent.fullName()
422 ##         else:
423 ##             print 'pop', self.current.fullName(), '->', self.current.parent
424         self.current = self.current.parent
425
426     def push(self, obj, node=None):
427         self._stack.append(self.current)
428         self.current = obj
429
430     def pop(self, obj):
431         assert self.current is obj, "%r is not %r"%(self.current, obj)
432         self.current = self._stack.pop()
433
434     def pushClass(self, name, docstring):
435         return self._push(self.Class, name, docstring)
436     def popClass(self):
437         self._pop(self.Class)
438
439     def pushModule(self, name, docstring):
440         return self._push(self.Module, name, docstring)
441     def popModule(self):
442         self._pop(self.Module)
443
444     def pushFunction(self, name, docstring, func_called):
445         self.func_called.update(func_called)
446         return self._push(self.Function, name, docstring)
447     def popFunction(self):
448         self._pop(self.Function)
449
450     def pushPackage(self, name, docstring):
451         return self._push(self.Package, name, docstring)
452     def popPackage(self):
453         self._pop(self.Package)
454
455     def report(self):
456         for o in self.rootobjects:
457             self._report(o, '')
458
459     def _report(self, o, indent):
460         print indent, o
461         for o2 in o.orderedcontents:
462             self._report(o2, indent+'  ')
463
464     def resolveAlias(self, n):
465         if '.' not in n:
466             return n
467         mod, clsname = n.split('.')
468         if not mod or mod not in self.allobjects:
469             return n
470         m = self.allobjects[mod]
471         if not isinstance(m, Module):
472             return n
473         if clsname in m._name2fullname:
474             newname = m.name2fullname(clsname)
475             if newname not in self.allobjects:
476                 return self.resolveAlias(newname)
477             else:
478                 return newname
479
480     def resolveAliases(self):
481         for ob in self.orderedallobjects:
482             if not isinstance(ob, Class):
483                 continue
484             for i, b in enumerate(ob.bases):
485                 if b not in self.allobjects:
486                     ob.bases[i] = self.resolveAlias(b)
487
488     def warning(self, type, detail):
489         if self.current is not None:
490             fn = self.current.fullName()
491         else:
492             fn = '<None>'
493         self.warnings.setdefault(type, []).append((fn, detail))
494
495     def objectsOfType(self, cls):
496         for o in self.orderedallobjects:
497             if isinstance(o, cls):
498                 yield o
499
500     def finalStateComputations(self):
501         self.recordBasesAndSubclasses()
502
503     def recordBasesAndSubclasses(self):
504         for cls in self.objectsOfType(Class):
505             for n in cls.bases:
506                 o = cls.parent.resolveDottedName(n)
507                 cls.baseobjects.append(o)
508                 if o:
509                     o.subclasses.append(cls)
510
511     def __getstate__(self):
512         state = self.__dict__.copy()
513         del state['moresystems']
514         return state
515
516     def __setstate__(self, state):
517         self.moresystems = []
518         # this is so very, very evil.
519         # see doc/extreme-pickling-pain.txt for more.
520         self.__dict__.update(state)
521         for obj in self.orderedallobjects:
522             for k, v in obj.__dict__.copy().iteritems():
523                 if k.startswith('$'):
524                     del obj.__dict__[k]
525                     obj.__dict__[k[1:]] = self.allobjects[v]
526                 elif k.startswith('@'):
527                     n = []
528                     for vv in v:
529                         if vv is None:
530                             n.append(None)
531                         else:
532                             n.append(self.allobjects[vv])
533                     del obj.__dict__[k]
534                     obj.__dict__[k[1:]] = n
535                 elif k.startswith('!'):
536                     n = {}
537                     for kk, vv in v.iteritems():
538                         n[kk] = self.allobjects[vv]
539                     del obj.__dict__[k]
540                     obj.__dict__[k[1:]] = n
541
542
def expandModname(system, modname, givewarning=True):
    """Expand a possibly package-local module name to a full name.

    Starting at system.current, the nearest enclosing Package and then
    its ancestors are searched for a child named like the first component
    of `modname`.  If one is found, that child's full name plus the rest
    of `modname` is returned and, unless `givewarning` is false, a
    "local import" warning is recorded.  Otherwise `modname` is returned
    unchanged.
    """
    dot = modname.find('.')
    if dot < 0:
        prefix, suffix = modname, ''
    else:
        # suffix keeps its leading dot.
        prefix, suffix = modname[:dot], modname[dot:]

    scope = system.current
    # Climb to the nearest enclosing package, if there is one.
    while not (scope is None or isinstance(scope, Package)):
        scope = scope.parent
    # Continue outwards until some scope contains the first component.
    while scope is not None and prefix not in scope.contents:
        scope = scope.parent

    if scope is None:
        return prefix + suffix
    if givewarning:
        system.warning("local import", modname)
    return scope.contents[prefix].fullName() + suffix
562
class ImportStarFinder(object):
    """Visitor that records the ``from X import *`` edges of one module.

    Each such statement appends the expanded name of X to
    system.importstargraph[modfullname].
    """
    def __init__(self, system, modfullname):
        self.modfullname = modfullname
        self.system = system

    def visitFrom(self, node):
        """Add an edge if the first imported name of `node` is '*'."""
        first_name = node.names[0][0]
        if first_name != '*':
            return
        # givewarning=False: this pass must not record "local import".
        target = expandModname(self.system, node.modname, False)
        imported = self.system.importstargraph.setdefault(
            self.modfullname, [])
        imported.append(target)
573
def processModuleAst(ast, name, system):
    """Walk a module's AST and record its contents in `system`.

    A visitor built from system.ModuleVistor walks the tree first.  The
    nodes it postponed (function bodies) are then visited in the order
    they were queued, each with its owning object made current; visiting
    one may queue more.
    """
    visitor = system.ModuleVistor(system, name)
    walk(ast, visitor)
    while visitor.morenodes:
        docable, body = visitor.morenodes.pop(0)
        system.push(docable, body)
        visitor.visit(body)
        system.pop(docable)
582
583
584 def fromText(src, modname='<test>', system=None):
585     if system is None:
586         _system = System()
587     else:
588         _system = system
589     processModuleAst(parse(src), modname, _system)
590     if system is None:
591         _system.finalStateComputations()
592     return _system.rootobjects[0]
593
594
def preprocessDirectory(system, dirpath):
    """Create Package and Module objects for a directory tree.

    A Package named after the last path component is pushed, unless that
    component is empty.  Every ``*.py`` entry becomes an unprocessed
    Module that remembers its file path; every subdirectory containing an
    ``__init__.py``, except one named 'test', is handled recursively.
    Leaves the system in state 'preparse'.
    """
    assert system.state in ['blank', 'preparse']
    pkgname = os.path.basename(dirpath)
    package = None
    if pkgname:
        package = system.pushPackage(pkgname, None)
    for entry in os.listdir(dirpath):
        path = os.path.join(dirpath, entry)
        is_subpackage = (os.path.isdir(path)
                         and os.path.exists(os.path.join(path, '__init__.py'))
                         and entry != 'test')
        if is_subpackage:
            preprocessDirectory(system, path)
            continue
        if not entry.endswith('.py'):
            continue
        mod = system.pushModule(os.path.splitext(entry)[0], None)
        mod.filepath = path
        mod.processed = False
        system.popModule()
    if package:
        system.popPackage()
    system.state = 'preparse'
614
def findImportStars(system):
    """Populate system.importstargraph from every known module.

    Each Module's file is parsed and walked with an ImportStarFinder
    while the module's parent is the current object.  A file that cannot
    be parsed gets a "cannot parse" warning and is skipped.  Leaves the
    system in state 'importstarred'.
    """
    assert system.state in ['preparse']
    modlist = list(system.objectsOfType(Module))
    for mod in modlist:
        system.push(mod.parent)
        isf = ImportStarFinder(system, mod.fullName())
        try:
            tree = parseFile(mod.filepath)
        except (SyntaxError, ValueError):
            system.warning("cannot parse", mod.filepath)
        else:
            # Walk only a tree that was really parsed for this module;
            # walking after a failure would use an unbound variable or
            # the previous module's tree.
            walk(tree, isf)
        system.pop(mod.parent)
    system.state = 'importstarred'
628
def extractDocstrings(system):
    """Parse every known module and record its contents in `system`.

    Modules are handled in the order given by toposort over
    system.importstargraph, so that a module imported with ``import *``
    is processed before its importer where possible.  A file that cannot
    be parsed gets a "cannot parse" warning, is skipped and keeps
    ``processed == False``.  Leaves the system in state 'parsed'.
    """
    assert system.state in ['preparse', 'importstarred']
    # and so much more...
    modlist = list(system.objectsOfType(Module))
    newlist = toposort([m.fullName() for m in modlist], system.importstargraph)

    for fullname in newlist:
        mod = system.allobjects[fullname]
        system.push(mod.parent)
        try:
            tree = parseFile(mod.filepath)
        except (SyntaxError, ValueError):
            system.warning("cannot parse", mod.filepath)
        else:
            # Process only a tree that was really parsed for this module;
            # continuing after a failure would use an unbound variable or
            # the previous module's tree.
            processModuleAst(tree, mod.name, system)
            mod.processed = True
        system.pop(mod.parent)
    system.state = 'parsed'
646
def finalStateComputations(system):
    """Run the system's final computations: state 'parsed' -> 'finalized'."""
    assert system.state == 'parsed'
    system.finalStateComputations()
    system.state = 'finalized'
651
def processDirectory(system, dirpath):
    """Run the whole pipeline on `dirpath`, from scanning to finalizing."""
    preprocessDirectory(system, dirpath)
    # The remaining stages take only the system and must run in this order.
    for stage in (findImportStars, extractDocstrings, finalStateComputations):
        stage(system)
657
def toposort(input, edges):
    """Order the items of `input` so that dependencies come first.

    `edges` maps an item to the items it depends on; an item appears in
    the result after those of its dependencies that are in `input`.
    Dependencies not in `input` are ignored and duplicates in `input`
    are collapsed.  Cycles are not detected in any clever way: each item
    is simply emitted once.
    """
    ordered = []
    pending = dict.fromkeys(input)

    def emit(item):
        # Emit the still-pending dependencies of item, then item itself.
        for dep in edges.get(item, []):
            if dep in pending:
                pending.pop(dep)
                emit(dep)
        ordered.append(item)

    while pending:
        item, _ = pending.popitem()
        emit(item)
    return ordered
671
672
673 def main(systemcls, argv):
674     if '-r' in argv:
675         argv.remove('-r')
676         assert len(argv) == 1
677         system = systemcls()
678         processDirectory(system, argv[0])
679         pickle.dump(system, open('da.out', 'wb'), pickle.HIGHEST_PROTOCOL)
680         print
681         print 'warning summary:'
682         for k, v in system.warnings.iteritems():
683             print k, len(v)
684     else:
685         system = systemcls()
686         for fname in argv:
687             modname = os.path.splitext(os.path.basename(fname))[0] # XXX!
688             processModuleAst(parseFile(fname), modname, system)
689         system.report()
690
691
692
# Script entry point: run main() with the stock System class and the
# command-line arguments (program name excluded).
if __name__ == '__main__':
    main(System, sys.argv[1:])
Note: See TracBrowser for help on using the browser.