"""Nested Sets"""

import operator
from functools import reduce

from django.core import serializers
from django.db import connection, models
from django.db.models import Q
from django.utils.translation import gettext_noop as _

from treebeard.exceptions import InvalidMoveToDescendant, NodeAlreadySaved
from treebeard.models import Node


def get_result_class(cls):
    """
    For the given model class, determine what class we should use for the
    nodes returned by its tree methods (such as get_children).

    Usually this will be trivially the same as the initial model class,
    but there are special cases when model inheritance is in use:

    * If the model extends another via multi-table inheritance, we need to
      use whichever ancestor originally implemented the tree behaviour (i.e.
      the one which defines the 'lft'/'rgt' fields). We can't use the
      subclass, because it's not guaranteed that the other nodes reachable
      from the current one will be instances of the same subclass.

    * If the model is a proxy model, the returned nodes should also use
      the proxy class.

    :param cls: a model class that has an 'lft' field
    :returns: ``cls`` itself when it is a proxy for the class that defines
        'lft', otherwise the class that defines 'lft'
    """
    base_class = cls._meta.get_field('lft').model
    if cls._meta.proxy_for_model == base_class:
        return cls
    return base_class


def merge_deleted_counters(c1, c2):
    """
    Merge return values from Django's Queryset.delete() method.

    :param c1: a ``(total, {label: count})`` tuple
    :param c2: a ``(total, {label: count})`` tuple
    :returns: a new tuple holding the summed totals and a dictionary with
        the per-label counts of both inputs added together
    """
    object_counts = {
        key: c1[1].get(key, 0) + c2[1].get(key, 0)
        for key in set(c1[1]) | set(c2[1])
    }
    return (c1[0] + c2[0], object_counts)


class NS_NodeQuerySet(models.query.QuerySet):
    """
    Custom queryset for the tree node manager.

    Needed only for the customized delete method.
    """

    def delete(self, *args, removed_ranges=None, deleted_counter=None, **kwargs):
        """
        Custom delete method, will remove all descendant nodes to ensure a
        consistent tree (no orphans)

        :param removed_ranges: list of ``(tree_id, lft, rgt)`` tuples for the
            subtrees being removed. It is supplied by this method itself on
            its second pass; when it is ``None`` the ranges are computed
            from the nodes in the queryset.
        :param deleted_counter: running ``(total, {label: count})`` result
            to merge the outcome of this deletion into.

        :returns: tuple of the number of objects deleted and a dictionary
                  with the number of deletions per object type
        """
        model = get_result_class(self.model)

        if deleted_counter is None:
            deleted_counter = (0, {})

        if removed_ranges is not None:
            # we already know the children, let's call the default django
            # delete method and let it handle the removal of the user's
            # foreign keys...
            result = super().delete(*args, **kwargs)
            deleted_counter = merge_deleted_counters(deleted_counter, result)
            cursor = model._get_database_cursor('write')

            # Now closing the gap (Celko's trees book, page 62)
            # We do this for every gap that was left in the tree when the nodes
            # were removed.  If many nodes were removed, we're going to update
            # the same nodes over and over again. This would be probably
            # cheaper precalculating the gapsize per intervals, or just do a
            # complete reordering of the tree (uses COUNT)...
            # The ranges are processed in descending order so that closing
            # one gap does not shift a range that is still pending.
            for tree_id, drop_lft, drop_rgt in sorted(removed_ranges, reverse=True):
                sql, params = model._get_close_gap_sql(drop_lft, drop_rgt, tree_id)
                cursor.execute(sql, params)
        else:
            # we'll have to manually run through all the nodes that are going
            # to be deleted and remove nodes from the list if an ancestor is
            # already getting removed, since that would be redundant
            removed = {}
            for node in self.order_by('tree_id', 'lft'):
                already_covered = any(
                    node.is_descendant_of(rnode) for rnode in removed.values()
                )
                if not already_covered:
                    removed[node.pk] = node

            # ok, got the minimal list of nodes to remove...
            # we must also remove their descendants
            toremove = []
            ranges = []
            for node in removed.values():
                toremove.append(
                    Q(lft__range=(node.lft, node.rgt)) & Q(tree_id=node.tree_id))
                ranges.append((node.tree_id, node.lft, node.rgt))
            if toremove:
                deleted_counter = model.objects.filter(
                    reduce(operator.or_, toremove)
                ).delete(removed_ranges=ranges, deleted_counter=deleted_counter)
        return deleted_counter

    delete.alters_data = True
    delete.queryset_only = True


class NS_NodeManager(models.Manager):
    """Custom manager for nodes in a Nested Sets tree."""

    def get_queryset(self):
        """Sets the custom queryset as the default."""
        return NS_NodeQuerySet(self.model).order_by('tree_id', 'lft')


class NS_Node(Node):
    """Abstract model to create your own Nested Sets Trees."""

    node_order_by = []

    lft = models.PositiveIntegerField(db_index=True)
    rgt = models.PositiveIntegerField(db_index=True)
    tree_id = models.PositiveIntegerField(db_index=True)
    depth = models.PositiveIntegerField(db_index=True)

    objects = NS_NodeManager()

    @classmethod
    def add_root(cls, **kwargs):
        """
        Adds a root node to the tree.

        :param kwargs: either the field values for the new node, or a single
            ``instance`` keyword holding an unsaved node object
        :returns: the saved node
        :raise NodeAlreadySaved: when the passed ``instance`` is not in the
            "adding" state
        """
        # do we have a root node already?
        last_root = cls.get_last_root_node()

        if last_root and last_root.node_order_by:
            # there are root nodes and node_order_by has been set
            # delegate sorted insertion to add_sibling
            return last_root.add_sibling('sorted-sibling', **kwargs)

        if last_root:
            # adding the new root node as the last one
            newtree_id = last_root.tree_id + 1
        else:
            # adding the first root node
            newtree_id = 1

        if len(kwargs) == 1 and 'instance' in kwargs:
            # adding the passed (unsaved) instance to the tree
            newobj = kwargs['instance']
            if not newobj._state.adding:
                raise NodeAlreadySaved("Attempted to add a tree node that is "
                                       "already in the database")
        else:
            # creating the new object
            newobj = get_result_class(cls)(**kwargs)

        newobj.depth = 1
        newobj.tree_id = newtree_id
        newobj.lft = 1
        newobj.rgt = 2
        # saving the instance before returning it
        newobj.save()
        return newobj

    @classmethod
    def _move_right(cls, tree_id, rgt, lftmove=False, incdec=2):
        """
        Builds the UPDATE statement that shifts lft/rgt values inside a tree.

        Rows of tree ``tree_id`` whose rgt is >= ``rgt`` get ``incdec`` added
        to rgt; their lft gets ``incdec`` added when it is > ``rgt`` (or
        >= ``rgt`` when ``lftmove`` is true).

        :returns: a ``(sql, params)`` tuple; ``params`` is always empty since
            the values are interpolated into the statement
        """
        if lftmove:
            lftop = '>='
        else:
            lftop = '>'
        sql = 'UPDATE %(table)s '\
              ' SET lft = CASE WHEN lft %(lftop)s %(parent_rgt)d '\
              ' THEN lft %(incdec)+d '\
              ' ELSE lft END, '\
              ' rgt = CASE WHEN rgt >= %(parent_rgt)d '\
              ' THEN rgt %(incdec)+d '\
              ' ELSE rgt END '\
              ' WHERE rgt >= %(parent_rgt)d AND '\
              ' tree_id = %(tree_id)s' % {
                  'table': connection.ops.quote_name(
                      get_result_class(cls)._meta.db_table),
                  'parent_rgt': rgt,
                  'tree_id': tree_id,
                  'lftop': lftop,
                  'incdec': incdec}
        return sql, []

    @classmethod
    def _move_tree_right(cls, tree_id):
        """
        Builds the UPDATE statement that increments by one the tree_id of
        every row whose tree_id is >= ``tree_id``.

        :returns: a ``(sql, params)`` tuple; ``params`` is always empty
        """
        sql = 'UPDATE %(table)s '\
              ' SET tree_id = tree_id+1 '\
              ' WHERE tree_id >= %(tree_id)d' % {
                  'table': connection.ops.quote_name(
                      get_result_class(cls)._meta.db_table),
                  'tree_id': tree_id}
        return sql, []

    def add_child(self, **kwargs):
        """
        Adds a child to the node.

        :param kwargs: either the field values for the new node, or a single
            ``instance`` keyword holding an unsaved node object
        :returns: the saved child node
        :raise NodeAlreadySaved: when the passed ``instance`` is not in the
            "adding" state
        """
        if not self.is_leaf():
            # there are child nodes, delegate insertion to add_sibling
            if self.node_order_by:
                pos = 'sorted-sibling'
            else:
                pos = 'last-sibling'
            last_child = self.get_last_child()
            last_child._cached_parent_obj = self
            return last_child.add_sibling(pos, **kwargs)

        # we're adding the first child of this node
        sql, params = self.__class__._move_right(self.tree_id,
                                                 self.rgt, False, 2)

        if len(kwargs) == 1 and 'instance' in kwargs:
            # adding the passed (unsaved) instance to the tree
            newobj = kwargs['instance']
            if not newobj._state.adding:
                raise NodeAlreadySaved("Attempted to add a tree node that is "
                                       "already in the database")
        else:
            # creating a new object
            newobj = get_result_class(self.__class__)(**kwargs)

        newobj.tree_id = self.tree_id
        newobj.depth = self.depth + 1
        newobj.lft = self.lft + 1
        newobj.rgt = self.lft + 2

        # this is just to update the cache
        self.rgt += 2

        newobj._cached_parent_obj = self

        cursor = self._get_database_cursor('write')
        cursor.execute(sql, params)

        # saving the instance before returning it
        newobj.save()

        return newobj

    def add_sibling(self, pos=None, **kwargs):
        """
        Adds a new node as a sibling to the current node object.

        :param pos: position relative to the current node; after
            normalisation the code handles 'first-sibling', 'left', 'right',
            'last-sibling' and 'sorted-sibling'
        :param kwargs: either the field values for the new node, or a single
            ``instance`` keyword holding an unsaved node object
        :returns: the saved sibling node
        :raise NodeAlreadySaved: when the passed ``instance`` is not in the
            "adding" state
        """
        pos = self._prepare_pos_var_for_add_sibling(pos)

        if len(kwargs) == 1 and 'instance' in kwargs:
            # adding the passed (unsaved) instance to the tree
            newobj = kwargs['instance']
            if not newobj._state.adding:
                raise NodeAlreadySaved("Attempted to add a tree node that is "
                                       "already in the database")
        else:
            # creating a new object
            newobj = get_result_class(self.__class__)(**kwargs)

        newobj.depth = self.depth

        sql = None
        target = self

        if target.is_root():
            # a root sibling is a new tree: lft/rgt are fixed and only the
            # tree_id has to be worked out
            newobj.lft = 1
            newobj.rgt = 2
            if pos == 'sorted-sibling':
                siblings = list(target.get_sorted_pos_queryset(
                    target.get_siblings(), newobj))
                if siblings:
                    pos = 'left'
                    target = siblings[0]
                else:
                    pos = 'last-sibling'

            last_root = target.__class__.get_last_root_node()
            if (
                    (pos == 'last-sibling') or
                    (pos == 'right' and target == last_root)
            ):
                newobj.tree_id = last_root.tree_id + 1
            else:
                newpos = {'first-sibling': 1,
                          'left': target.tree_id,
                          'right': target.tree_id + 1}[pos]
                sql, params = target.__class__._move_tree_right(newpos)

                newobj.tree_id = newpos
        else:
            newobj.tree_id = target.tree_id

            if pos == 'sorted-sibling':
                siblings = list(target.get_sorted_pos_queryset(
                    target.get_siblings(), newobj))
                if siblings:
                    pos = 'left'
                    target = siblings[0]
                else:
                    pos = 'last-sibling'

            if pos in ('left', 'right', 'first-sibling'):
                siblings = list(target.get_siblings())

                if pos == 'right':
                    if target == siblings[-1]:
                        pos = 'last-sibling'
                    else:
                        # "right of target" becomes "left of the sibling
                        # that follows target"
                        pos = 'left'
                        found = False
                        for node in siblings:
                            if found:
                                target = node
                                break
                            elif node == target:
                                found = True
                if pos == 'left':
                    if target == siblings[0]:
                        pos = 'first-sibling'
                if pos == 'first-sibling':
                    target = siblings[0]

            move_right = self.__class__._move_right

            if pos == 'last-sibling':
                newpos = target.get_parent().rgt
                sql, params = move_right(target.tree_id, newpos, False, 2)
            elif pos == 'first-sibling':
                newpos = target.lft
                sql, params = move_right(target.tree_id, newpos - 1, False, 2)
            elif pos == 'left':
                newpos = target.lft
                sql, params = move_right(target.tree_id, newpos, True, 2)

            newobj.lft = newpos
            newobj.rgt = newpos + 1

        # saving the instance before returning it
        if sql:
            cursor = self._get_database_cursor('write')
            cursor.execute(sql, params)
        newobj.save()

        return newobj

    def move(self, target, pos=None):
        """
        Moves the current node and all its descendants to a new position
        relative to another node.

        :param target: the node used as reference for the new position
        :param pos: position relative to ``target``; after normalisation the
            code handles the '*-child' and '*-sibling' values as well as
            'left' and 'right'
        :returns: ``None``
        :raise InvalidMoveToDescendant: when the (possibly adjusted) target
            is a descendant of the node being moved
        """
        pos = self._prepare_pos_var_for_move(pos)
        cls = get_result_class(self.__class__)

        parent = None

        if pos in ('first-child', 'last-child', 'sorted-child'):
            # moving to a child
            if target.is_leaf():
                parent = target
                pos = 'last-child'
            else:
                target = target.get_last_child()
                pos = {'first-child': 'first-sibling',
                       'last-child': 'last-sibling',
                       'sorted-child': 'sorted-sibling'}[pos]

        if target.is_descendant_of(self):
            raise InvalidMoveToDescendant(
                _("Can't move node to a descendant."))

        if self == target and (
            (pos == 'left') or
            (pos in ('right', 'last-sibling') and
             target == target.get_last_sibling()) or
            (pos == 'first-sibling' and
             target == target.get_first_sibling())):
            # special cases, not actually moving the node so no need to UPDATE
            return

        if pos == 'sorted-sibling':
            siblings = list(target.get_sorted_pos_queryset(
                target.get_siblings(), self))
            if siblings:
                pos = 'left'
                target = siblings[0]
            else:
                pos = 'last-sibling'
        if pos in ('left', 'right', 'first-sibling'):
            siblings = list(target.get_siblings())

            if pos == 'right':
                if target == siblings[-1]:
                    pos = 'last-sibling'
                else:
                    # "right of target" becomes "left of the sibling that
                    # follows target"
                    pos = 'left'
                    found = False
                    for node in siblings:
                        if found:
                            target = node
                            break
                        elif node == target:
                            found = True
            if pos == 'left':
                if target == siblings[0]:
                    pos = 'first-sibling'
            if pos == 'first-sibling':
                target = siblings[0]

        # ok let's move this
        cursor = self._get_database_cursor('write')
        move_right = cls._move_right
        gap = self.rgt - self.lft + 1
        sql = None
        target_tree = target.tree_id

        # first make a hole
        if pos == 'last-child':
            newpos = parent.rgt
            sql, params = move_right(target.tree_id, newpos, False, gap)
        elif target.is_root():
            newpos = 1
            if pos == 'last-sibling':
                target_tree = target.get_siblings().reverse()[0].tree_id + 1
            elif pos == 'first-sibling':
                target_tree = 1
                sql, params = cls._move_tree_right(1)
            elif pos == 'left':
                sql, params = cls._move_tree_right(target.tree_id)
        else:
            if pos == 'last-sibling':
                newpos = target.get_parent().rgt
                sql, params = move_right(target.tree_id, newpos, False, gap)
            elif pos == 'first-sibling':
                newpos = target.lft
                sql, params = move_right(target.tree_id,
                                         newpos - 1, False, gap)
            elif pos == 'left':
                newpos = target.lft
                sql, params = move_right(target.tree_id, newpos, True, gap)

        if sql:
            cursor.execute(sql, params)

        # we reload 'self' because lft/rgt may have changed

        fromobj = cls.objects.get(pk=self.pk)

        depthdiff = target.depth - fromobj.depth
        if parent:
            depthdiff += 1

        # move the tree to the hole
        sql = "UPDATE %(table)s "\
              " SET tree_id = %(target_tree)d, "\
              " lft = lft + %(jump)d , "\
              " rgt = rgt + %(jump)d , "\
              " depth = depth + %(depthdiff)d "\
              " WHERE tree_id = %(from_tree)d AND "\
              " lft BETWEEN %(fromlft)d AND %(fromrgt)d" % {
                  'table': connection.ops.quote_name(cls._meta.db_table),
                  'from_tree': fromobj.tree_id,
                  'target_tree': target_tree,
                  'jump': newpos - fromobj.lft,
                  'depthdiff': depthdiff,
                  'fromlft': fromobj.lft,
                  'fromrgt': fromobj.rgt}
        cursor.execute(sql, [])

        # close the gap
        sql, params = cls._get_close_gap_sql(fromobj.lft,
                                             fromobj.rgt, fromobj.tree_id)
        cursor.execute(sql, params)

    @classmethod
    def _get_close_gap_sql(cls, drop_lft, drop_rgt, tree_id):
        """
        Builds the UPDATE statement that closes the gap left by a removed
        ``drop_lft``..``drop_rgt`` range in tree ``tree_id``.

        Every lft/rgt value greater than ``drop_lft`` is reduced by the size
        of the gap (``drop_rgt - drop_lft + 1``).

        :returns: a ``(sql, params)`` tuple; ``params`` is always empty
        """
        sql = 'UPDATE %(table)s '\
              ' SET lft = CASE '\
              ' WHEN lft > %(drop_lft)d '\
              ' THEN lft - %(gapsize)d '\
              ' ELSE lft END, '\
              ' rgt = CASE '\
              ' WHEN rgt > %(drop_lft)d '\
              ' THEN rgt - %(gapsize)d '\
              ' ELSE rgt END '\
              ' WHERE (lft > %(drop_lft)d '\
              ' OR rgt > %(drop_lft)d) AND '\
              ' tree_id=%(tree_id)d' % {
                  'table': connection.ops.quote_name(
                      get_result_class(cls)._meta.db_table),
                  'gapsize': drop_rgt - drop_lft + 1,
                  'drop_lft': drop_lft,
                  'tree_id': tree_id}
        return sql, []

    @classmethod
    def load_bulk(cls, bulk_data, parent=None, keep_ids=False):
        """
        Loads a list/dictionary structure to the tree.

        :param bulk_data: list of dictionaries, each with a 'data' entry
            (field values) and optionally a 'children' list of the same
            structure
        :param parent: node under which the data is added; root nodes are
            created when it is not given
        :param keep_ids: when true, the primary key stored in each
            dictionary is reused for the created node
        :returns: list with the primary keys of the added nodes
        """
        cls = get_result_class(cls)

        # tree, iterative preorder
        added = []
        if parent:
            parent_id = parent.pk
        else:
            parent_id = None
        # stack of nodes to analyze
        stack = [(parent_id, node) for node in bulk_data[::-1]]
        foreign_keys = cls.get_foreign_keys()
        pk_field = cls._meta.pk.attname
        while stack:
            parent_id, node_struct = stack.pop()
            # shallow copy of the data structure so it doesn't persist...
            node_data = node_struct['data'].copy()
            cls._process_foreign_keys(foreign_keys, node_data)
            if keep_ids:
                node_data[pk_field] = node_struct[pk_field]
            if parent_id:
                # the parent is reloaded because adding earlier nodes
                # changes lft/rgt values in the database
                parent = cls.objects.get(pk=parent_id)
                node_obj = parent.add_child(**node_data)
            else:
                node_obj = cls.add_root(**node_data)
            added.append(node_obj.pk)
            if 'children' in node_struct:
                # extending the stack with the current node as the parent of
                # the new nodes
                stack.extend([
                    (node_obj.pk, node)
                    for node in node_struct['children'][::-1]
                ])
        return added

    def get_children(self):
        """:returns: A queryset of all the node's children"""
        return self.get_descendants().filter(depth=self.depth + 1)

    def get_depth(self):
        """:returns: the depth (level) of the node"""
        return self.depth

    def is_leaf(self):
        """:returns: True if the node is a leaf node (else, returns False)"""
        return self.rgt - self.lft == 1

    def get_root(self):
        """:returns: the root node for the current node object."""
        if self.lft == 1:
            return self
        return get_result_class(self.__class__).objects.get(
            tree_id=self.tree_id, lft=1)

    def is_root(self):
        """:returns: True if the node is a root node (else, returns False)"""
        return self.lft == 1

    def get_siblings(self):
        """
        :returns: A queryset of all the node's siblings, including the node
            itself.
        """
        if self.lft == 1:
            return self.get_root_nodes()
        return self.get_parent(True).get_children()

    @classmethod
    def dump_bulk(cls, parent=None, keep_ids=True):
        """
        Dumps a tree branch to a python data structure.

        :param parent: node whose branch is dumped; every tree is dumped
            when it is not given
        :param keep_ids: when true, the primary key of each node is stored
            in its dictionary
        :returns: list of dictionaries with a 'data' entry and, for nodes
            that have them, a 'children' list of the same structure
        """
        qset = cls._get_serializable_model().get_tree(parent)
        ret, lnk = [], {}
        pk_field = cls._meta.pk.attname
        for pyobj in qset:
            serobj = serializers.serialize('python', [pyobj])[0]
            # django's serializer stores the attributes in 'fields'
            fields = serobj['fields']
            depth = fields['depth']
            # this will be useless in load_bulk
            del fields['lft']
            del fields['rgt']
            del fields['depth']
            del fields['tree_id']
            if pk_field in fields:
                # this happens immediately after a load_bulk
                del fields[pk_field]

            newobj = {'data': fields}
            if keep_ids:
                newobj[pk_field] = serobj['pk']

            if (not parent and depth == 1) or\
               (parent and depth == parent.depth):
                ret.append(newobj)
            else:
                parentobj = pyobj.get_parent()
                parentser = lnk[parentobj.pk]
                parentser.setdefault('children', []).append(newobj)
            lnk[pyobj.pk] = newobj
        return ret

    @classmethod
    def get_tree(cls, parent=None):
        """
        :returns: A *queryset* of nodes ordered as DFS, including the parent.
                  If no parent is given, all trees are returned.
        """
        cls = get_result_class(cls)

        if parent is None:
            # return the entire tree
            return cls.objects.all()
        if parent.is_leaf():
            return cls.objects.filter(pk=parent.pk)
        return cls.objects.filter(
            tree_id=parent.tree_id,
            lft__range=(parent.lft, parent.rgt - 1))

    def get_descendants(self):
        """
        :returns: A queryset of all the node's descendants as DFS, doesn't
            include the node itself
        """
        if self.is_leaf():
            return get_result_class(self.__class__).objects.none()
        return self.__class__.get_tree(self).exclude(pk=self.pk)

    def get_descendant_count(self):
        """:returns: the number of descendants of a node."""
        # each descendant takes two slots (its lft and rgt) between this
        # node's own lft and rgt; floor division keeps the count an int
        return (self.rgt - self.lft - 1) // 2

    def get_ancestors(self):
        """
        :returns: A queryset containing the current node object's ancestors,
            starting by the root node and descending to the parent.
        """
        if self.is_root():
            return get_result_class(self.__class__).objects.none()
        return get_result_class(self.__class__).objects.filter(
            tree_id=self.tree_id,
            lft__lt=self.lft,
            rgt__gt=self.rgt)

    def is_descendant_of(self, node):
        """
        :returns: ``True`` if the node is a descendant of another node given
            as an argument, else, returns ``False``
        """
        return (
            self.tree_id == node.tree_id and
            self.lft > node.lft and
            self.rgt < node.rgt
        )

    def get_parent(self, update=False):
        """
        :param update: when true, the cached parent is discarded and the
            parent is looked up again
        :returns: the parent node of the current node object, or ``None``
            for a root node.
            Caches the result in the object itself to help in loops.
        """
        if self.is_root():
            return None
        try:
            if update:
                del self._cached_parent_obj
            else:
                return self._cached_parent_obj
        except AttributeError:
            # nothing cached yet, fall through to the lookup below
            pass
        # parent = our most direct ancestor
        self._cached_parent_obj = self.get_ancestors().reverse()[0]
        return self._cached_parent_obj

    @classmethod
    def get_root_nodes(cls):
        """:returns: A queryset containing the root nodes in the tree."""
        return get_result_class(cls).objects.filter(lft=1)

    class Meta:
        """Abstract model."""
        abstract = True