From 6a33e8ee0816896b64ef492874f6da6b236c84ae Mon Sep 17 00:00:00 2001 From: Salvo 'LtWorf' Tomaselli Date: Wed, 3 Jun 2020 07:01:52 +0200 Subject: [PATCH] Improve type annotations --- relational/relation.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/relational/relation.py b/relational/relation.py index 1f564fd..346730d 100644 --- a/relational/relation.py +++ b/relational/relation.py @@ -22,7 +22,8 @@ import csv from itertools import chain, repeat from collections import deque -from typing import List, Union, Set +from typing import * +from pathlib import Path from relational.rtypes import * @@ -53,11 +54,11 @@ class Relation: ''' __hash__ = None # type: None - def __init__(self, filename: str = '') -> None: + def __init__(self, filename: Optional[Union[str, Path]] = None) -> None: self._readonly = False self.content: Set[tuple] = set() - if len(filename) == 0: # Empty relation + if filename is None: # Empty relation self.header = Header([]) return with open(filename) as fp: @@ -73,7 +74,7 @@ class Relation: self._readonly = True copy._readonly = True - def _make_writable(self, copy_content : bool = True) -> None: + def _make_writable(self, copy_content: bool = True) -> None: '''If this relation is marked as readonly, this method will copy the content to make it writable too @@ -92,7 +93,7 @@ class Relation: def __contains__(self, key): return key in self.content - def save(self, filename: str) -> None: + def save(self, filename: Union[Path, str]) -> None: ''' Saves the relation in a file. Will save using the csv format as defined in RFC4180. @@ -169,7 +170,7 @@ class Relation: newt.content.add(i + j) return newt - def projection(self, * attributes) -> 'Relation': + def projection(self, *attributes) -> 'Relation': ''' Can be called in two different ways: a.projection('field1','field2') @@ -200,7 +201,7 @@ class Relation: newt.content.add(tuple(row)) return newt - def rename(self, params: 'Relation') -> 'Relation': + def rename(self, params: Dict[str, str]) -> 'Relation': ''' Takes a dictionary. @@ -505,7 +506,7 @@ class Header(tuple): def __repr__(self): return "Header(%s)" % super(Header, self).__repr__() - def rename(self, params) -> 'Header': + def rename(self, params: Dict[str, str]) -> 'Header': '''Returns a new header, with renamed fields. params is a dictionary of {old:new} names @@ -525,15 +526,15 @@ class Header(tuple): '''Returns how many attributes this header has in common with a given one''' return len(set(self).intersection(set(other))) - def union(self, other) -> set: + def union(self, other: 'Header') -> Set[str]: '''Returns the union of the sets of attributes with another header.''' return set(self).union(set(other)) - def intersection(self, other) -> set: + def intersection(self, other: 'Header') -> Set[str]: '''Returns the set of common attributes with another header.''' return set(self).intersection(set(other)) - def getAttributesId(self, param) -> List[int]: + def getAttributesId(self, param: Iterable[str]) -> List[int]: '''Returns a list with numeric index corresponding to field's name''' try: return [self.index(i) for i in param]