Browse Source

Merge branch 'aardbei-integration'

Maarten van den Berg 5 years ago
parent
commit
8dda1a83b9

+ 3 - 0
.gitignore

7
 # Python egg metadata, regenerated from source files by setuptools.
7
 # Python egg metadata, regenerated from source files by setuptools.
8
 /*.egg-info
8
 /*.egg-info
9
 /*.egg
9
 /*.egg
10
+
11
+# Type checking cache
12
+.mypy_cache

+ 6 - 0
.isort.cfg

1
+[settings]
2
+multi_line_output=3
3
+include_trailing_comma=True
4
+force_grid_wrap=0
5
+use_parentheses=True
6
+line_length=88

+ 1 - 0
.python-version

1
+3.7.5

+ 354 - 0
piket_client/cli.py

1
+from typing import Optional
2
+
3
+import click
4
+from prettytable import PrettyTable
5
+
6
+from piket_client.model import (
7
+    AardbeiActivity,
8
+    ServerStatus,
9
+    NetworkError,
10
+    Consumption,
11
+    AardbeiPeopleDiff,
12
+    Person,
13
+    Settlement,
14
+    ConsumptionType,
15
+)
16
+
17
+
18
+@click.group()
19
+def cli():
20
+    """Poke coco from the command line."""
21
+    pass
22
+
23
+
24
+@cli.command()
25
+def status():
26
+    """Show the current status of the server."""
27
+
28
+    status = ServerStatus.is_server_running()
29
+
30
+    if isinstance(status, NetworkError):
31
+        print_error(f"Failed to get data from server, error {status.value}")
32
+        return
33
+
34
+    print_ok("Server is available.")
35
+
36
+    open_consumptions = ServerStatus.unsettled_consumptions()
37
+
38
+    if isinstance(open_consumptions, NetworkError):
39
+        print_error(
40
+            f"Failed to get unsettled consumptions, error {open_consumptions.value}"
41
+        )
42
+        return
43
+
44
+    click.echo(f"There are {open_consumptions.amount} unsettled consumptions.")
45
+
46
+    if open_consumptions.amount > 0:
47
+        click.echo(f"First at: {open_consumptions.first_timestamp.strftime('%c')}")
48
+        click.echo(f"Most recent at: {open_consumptions.last_timestamp.strftime('%c')}")
49
+
50
+
51
+@cli.group()
52
+def people():
53
+    pass
54
+
55
+
56
+@people.command("list")
57
+@click.option("--active/--inactive", default=None)
58
+def list_people(active: bool) -> None:
59
+    people = Person.get_all(active=active)
60
+
61
+    if isinstance(people, NetworkError):
62
+        print_error(f"Could not get people: {people.value}")
63
+        return
64
+
65
+    table = PrettyTable()
66
+    table.field_names = ["ID", "Full name", "Display name", "Active"]
67
+    table.align["ID"] = "r"
68
+    table.align["Full name"] = "l"
69
+    table.align["Display name"] = "l"
70
+    table.sortby = "Full name"
71
+
72
+    for p in people:
73
+        table.add_row([p.person_id, p.full_name, p.display_name, p.active])
74
+
75
+    print(table)
76
+
77
+
78
+@people.command("create")
79
+@click.option("--display-name", type=click.STRING)
80
+@click.argument("name", type=click.STRING)
81
+def create_person(name: str, display_name: str) -> None:
82
+    """Create a person."""
83
+    person = Person(full_name=name, display_name=display_name).create()
84
+
85
+    if isinstance(person, NetworkError):
86
+        print_error(f"Could not create Person: {person.value}")
87
+        return
88
+
89
+    print_ok(f'Created person "{name}" with ID {person.person_id}.')
90
+
91
+
92
+@people.command("rename")
93
+@click.argument("person-id", type=click.INT)
94
+@click.option("--new-full-name", type=click.STRING)
95
+@click.option("--new-display-name", type=click.STRING)
96
+def rename_person(
97
+    person_id: int, new_full_name: Optional[str], new_display_name: Optional[str],
98
+) -> None:
99
+
100
+    person = Person.get(person_id)
101
+
102
+    if person is None:
103
+        raise click.UsageError(f"Cannot find Person {person_id}!")
104
+
105
+    if new_full_name is None and new_display_name is None:
106
+        raise click.UsageError("No new full name or display name specified!")
107
+
108
+    new_person = person.rename(
109
+        new_full_name=new_full_name, new_display_name=new_display_name
110
+    )
111
+
112
+
113
+@cli.group()
114
+def settlements():
115
+    pass
116
+
117
+
118
+@settlements.command("show")
119
+@click.argument("settlement_id", type=click.INT)
120
+def show_settlement(settlement_id: int) -> None:
121
+    """Get and view the contents of a Settlement."""
122
+    s = Settlement.get(settlement_id)
123
+
124
+    if isinstance(s, NetworkError):
125
+        print_error(f"Could not get Settlement: {s.value}")
126
+        return
127
+
128
+    output_settlement_info(s)
129
+
130
+
131
+@settlements.command("create")
132
+@click.argument("name")
133
+def create_settlement(name: str) -> None:
134
+    """Create a new Settlement."""
135
+    s = Settlement.create(name)
136
+
137
+    if isinstance(s, NetworkError):
138
+        print_error(f"Could not create Settlement: {s.value}")
139
+        return
140
+
141
+    output_settlement_info(s)
142
+
143
+
144
+def output_settlement_info(s: Settlement) -> None:
145
+    click.echo(f'Settlement {s.settlement_id}, "{s.name}"')
146
+
147
+    click.echo(f"Summary:")
148
+    for key, value in s.consumption_summary.items():
149
+        click.echo(f" - {value['count']} {value['name']} ({key})")
150
+
151
+    ct_name_by_id = {key: value["name"] for key, value in s.consumption_summary.items()}
152
+
153
+    table = PrettyTable()
154
+    table.field_names = ["Name", *ct_name_by_id.values()]
155
+    table.sortby = "Name"
156
+    table.align = "r"
157
+    table.align["Name"] = "l"  # type: ignore
158
+
159
+    zero_fields = {k: "" for k in ct_name_by_id.values()}
160
+
161
+    for item in s.per_person_counts.values():
162
+        r = {"Name": item["full_name"], **zero_fields}
163
+        for key, value in item["counts"].items():
164
+            r[ct_name_by_id[key]] = value
165
+
166
+        table.add_row(r.values())
167
+
168
+    print(table)
169
+
170
+
171
+@cli.group()
172
+def consumption_types():
173
+    pass
174
+
175
+
176
+@consumption_types.command("list")
177
+def list_consumption_types() -> None:
178
+    active = ConsumptionType.get_all(active=True)
179
+    inactive = ConsumptionType.get_all(active=False)
180
+
181
+    if isinstance(active, NetworkError) or isinstance(inactive, NetworkError):
182
+        print_error("Could not get consumption types!")
183
+        return
184
+
185
+    table = PrettyTable()
186
+    table.field_names = ["ID", "Name", "Active"]
187
+    table.sortby = "ID"
188
+
189
+    for ct in active + inactive:
190
+        table.add_row([ct.consumption_type_id, ct.name, ct.active])
191
+
192
+    print(table)
193
+
194
+
195
+@consumption_types.command("create")
196
+@click.argument("name")
197
+def create_consumption_type(name: str) -> None:
198
+    ct = ConsumptionType(name=name).create()
199
+
200
+    if not isinstance(ct, NetworkError):
201
+        print_ok(f'Created consumption type "{name}" with ID {ct.consumption_type_id}.')
202
+
203
+
204
+@consumption_types.command("activate")
205
+@click.argument("consumption_type_id", type=click.INT)
206
+def activate_consumption_type(consumption_type_id: int) -> None:
207
+    ct = ConsumptionType.get(consumption_type_id)
208
+
209
+    if isinstance(ct, NetworkError):
210
+        print_error(f"Could not get ConsumptionType: {ct.value}")
211
+        return
212
+
213
+    result = ct.set_active(True)
214
+
215
+    if not isinstance(result, NetworkError):
216
+        print_ok(
217
+            f"Consumption type {ct.consumption_type_id} ({ct.name}) is now active."
218
+        )
219
+
220
+
221
+@consumption_types.command("deactivate")
222
+@click.argument("consumption_type_id", type=click.INT)
223
+def deactivate_consumption_type(consumption_type_id: int) -> None:
224
+    ct = ConsumptionType.get(consumption_type_id)
225
+
226
+    if isinstance(ct, NetworkError):
227
+        print_error(f"Could not get ConsumptionType: {ct.value}")
228
+        return
229
+
230
+    result = ct.set_active(False)
231
+
232
+    if not isinstance(result, NetworkError):
233
+        print_ok(
234
+            f"Consumption type {ct.consumption_type_id} ({ct.name}) is now inactive."
235
+        )
236
+
237
+
238
+def print_ok(msg: str) -> None:
239
+    click.echo(click.style(msg, fg="green"))
240
+
241
+
242
+def print_error(msg: str) -> None:
243
+    click.echo(click.style(msg, fg="red", bold=True), err=True)
244
+
245
+
246
+@cli.group()
247
+@click.option("--token", required=True, envvar="AARDBEI_TOKEN")
248
+@click.option("--endpoint", default="http://localhost:3000", envvar="AARDBEI_ENDPOINT")
249
+@click.pass_context
250
+def aardbei(ctx, token: str, endpoint: str) -> None:
251
+    ctx.ensure_object(dict)
252
+    ctx.obj["AardbeiToken"] = token
253
+    ctx.obj["AardbeiEndpoint"] = endpoint
254
+
255
+
256
+@aardbei.group("activities")
257
+def aardbei_activities() -> None:
258
+    pass
259
+
260
+
261
+@aardbei_activities.command("list")
262
+@click.pass_context
263
+def aardbei_list_activities(ctx) -> None:
264
+    acts = AardbeiActivity.get_available(
265
+        token=ctx.obj["AardbeiToken"], endpoint=ctx.obj["AardbeiEndpoint"]
266
+    )
267
+
268
+    if isinstance(acts, NetworkError):
269
+        print_error(f"Could not get activities: {acts.value}")
270
+        return
271
+
272
+    table = PrettyTable()
273
+    table.field_names = ["ID", "Name"]
274
+    table.align = "l"
275
+
276
+    for a in acts:
277
+        table.add_row([a.aardbei_id, a.name])
278
+
279
+    print(table)
280
+
281
+
282
+@aardbei_activities.command("apply")
283
+@click.argument("activity_id", type=click.INT)
284
+@click.pass_context
285
+def aardbei_apply_activity(ctx, activity_id: int) -> None:
286
+    result = AardbeiActivity.apply_activity(
287
+        token=ctx.obj["AardbeiToken"],
288
+        endpoint=ctx.obj["AardbeiEndpoint"],
289
+        activity_id=activity_id,
290
+    )
291
+
292
+    if isinstance(result, NetworkError):
293
+        print_error("Failed to apply activity: {result.value}")
294
+        return
295
+
296
+    print_ok(f"Activity applied. There are now {result} active people.")
297
+
298
+
299
+@aardbei.group("people")
300
+def aardbei_people() -> None:
301
+    pass
302
+
303
+
304
+@aardbei_people.command("diff")
305
+@click.pass_context
306
+def aardbei_diff_people(ctx) -> None:
307
+    diff = AardbeiPeopleDiff.get_diff(
308
+        token=ctx.obj["AardbeiToken"], endpoint=ctx.obj["AardbeiEndpoint"]
309
+    )
310
+
311
+    if isinstance(diff, NetworkError):
312
+        print_error(f"Could not get differences: {diff.value}")
313
+        return
314
+
315
+    if diff.num_changes == 0:
316
+        print_ok("There are no changes to apply.")
317
+        return
318
+
319
+    click.echo(f"There are {diff.num_changes} pending changes:")
320
+    show_diff(diff)
321
+
322
+
323
+@aardbei_people.command("sync")
324
+@click.pass_context
325
+def aardbei_sync_people(ctx) -> None:
326
+    diff = AardbeiPeopleDiff.sync(
327
+        token=ctx.obj["AardbeiToken"], endpoint=ctx.obj["AardbeiEndpoint"]
328
+    )
329
+
330
+    if isinstance(diff, NetworkError):
331
+        print_error(f"Could not apply differences: {diff.value}")
332
+        return
333
+
334
+    if diff.num_changes == 0:
335
+        print_ok("There were no changes to apply.")
336
+        return
337
+
338
+    print_ok(f"Applied {diff.num_changes} pending changes:")
339
+    show_diff(diff)
340
+
341
+
342
+def show_diff(diff: AardbeiPeopleDiff) -> None:
343
+    for name in diff.new_people:
344
+        click.echo(f" - Create local Person for {name}")
345
+
346
+    for name in diff.link_existing:
347
+        click.echo(f" - Link local and remote people for {name}")
348
+
349
+    for name in diff.altered_name:
350
+        click.echo(f" - Process name change for {name}")
351
+
352
+
353
+if __name__ == "__main__":
354
+    cli()

+ 55 - 27
piket_client/gui.py

2
 Provides the graphical front-end for Piket.
2
 Provides the graphical front-end for Piket.
3
 """
3
 """
4
 import collections
4
 import collections
5
+import itertools
5
 import logging
6
 import logging
7
+import math
6
 import os
8
 import os
7
 import sys
9
 import sys
10
+from typing import Deque, Iterator
8
 
11
 
9
 import qdarkstyle
12
 import qdarkstyle
10
 
13
 
24
     QWidget,
27
     QWidget,
25
 )
28
 )
26
 from PySide2.QtGui import QIcon
29
 from PySide2.QtGui import QIcon
27
-from PySide2.QtCore import QObject, QSize, Qt, Signal, Slot
30
+from PySide2.QtCore import QObject, QSize, Qt, Signal, Slot, QUrl
31
+from PySide2.QtMultimedia import QSoundEffect
28
 
32
 
29
 # pylint: enable=E0611
33
 # pylint: enable=E0611
30
 
34
 
33
 except ImportError:
37
 except ImportError:
34
     dbus = None
38
     dbus = None
35
 
39
 
36
-from piket_client.sound import PLOP_WAVE, UNDO_WAVE
40
+from piket_client.sound import PLOP_PATH, UNDO_PATH
37
 from piket_client.model import (
41
 from piket_client.model import (
38
     Person,
42
     Person,
39
     ConsumptionType,
43
     ConsumptionType,
40
     Consumption,
44
     Consumption,
41
     ServerStatus,
45
     ServerStatus,
46
+    NetworkError,
42
     Settlement,
47
     Settlement,
43
 )
48
 )
44
 import piket_client.logger
49
 import piket_client.logger
46
 LOG = logging.getLogger(__name__)
51
 LOG = logging.getLogger(__name__)
47
 
52
 
48
 
53
 
49
-def plop() -> None:
50
-    """ Asynchronously play the plop sound. """
51
-    PLOP_WAVE.play()
52
-
53
-
54
 class NameButton(QPushButton):
54
 class NameButton(QPushButton):
55
     """ Wraps a QPushButton to provide a counter. """
55
     """ Wraps a QPushButton to provide a counter. """
56
 
56
 
77
     @Slot()
77
     @Slot()
78
     def rebuild(self) -> None:
78
     def rebuild(self) -> None:
79
         """ Refresh the Person object and the label. """
79
         """ Refresh the Person object and the label. """
80
-        self.person = self.person.reload()
80
+        self.person = self.person.reload()  # type: ignore
81
         self.setText(self.current_label)
81
         self.setText(self.current_label)
82
 
82
 
83
     @property
83
     @property
96
         LOG.debug("Button clicked.")
96
         LOG.debug("Button clicked.")
97
         result = self.person.add_consumption(self.active_id)
97
         result = self.person.add_consumption(self.active_id)
98
         if result:
98
         if result:
99
-            plop()
99
+            self.window().play_plop()
100
             self.setText(self.current_label)
100
             self.setText(self.current_label)
101
             self.consumption_created.emit(result)
101
             self.consumption_created.emit(result)
102
         else:
102
         else:
148
         LOG.debug("Initializing NameButtons.")
148
         LOG.debug("Initializing NameButtons.")
149
 
149
 
150
         ps = Person.get_all(True)
150
         ps = Person.get_all(True)
151
-        num_columns = round(len(ps) / 10) + 1
151
+        assert not isinstance(ps, NetworkError)
152
+        num_columns = math.ceil(math.sqrt(len(ps)))
152
 
153
 
153
         if self.layout:
154
         if self.layout:
154
             LOG.debug("Removing %s widgets for rebuild", self.layout.count())
155
             LOG.debug("Removing %s widgets for rebuild", self.layout.count())
173
 
174
 
174
     consumption_type_changed = Signal(str)
175
     consumption_type_changed = Signal(str)
175
 
176
 
177
+    plop_loop: Iterator[QSoundEffect]
178
+    undo_loop: Iterator[QSoundEffect]
179
+
176
     def __init__(self) -> None:
180
     def __init__(self) -> None:
177
         LOG.debug("Initializing PiketMainWindow.")
181
         LOG.debug("Initializing PiketMainWindow.")
178
         super().__init__()
182
         super().__init__()
182
         self.toolbar = None
186
         self.toolbar = None
183
         self.osk = None
187
         self.osk = None
184
         self.undo_action = None
188
         self.undo_action = None
185
-        self.undo_queue = collections.deque([], 15)
189
+        self.undo_queue: Deque[Consumption] = collections.deque([], 15)
186
         self.init_ui()
190
         self.init_ui()
187
 
191
 
188
     def init_ui(self) -> None:
192
     def init_ui(self) -> None:
211
 
215
 
212
         # Initialize toolbar
216
         # Initialize toolbar
213
         self.toolbar = QToolBar()
217
         self.toolbar = QToolBar()
218
+        assert self.toolbar is not None
214
         self.toolbar.setToolButtonStyle(Qt.ToolButtonTextUnderIcon)
219
         self.toolbar.setToolButtonStyle(Qt.ToolButtonTextUnderIcon)
215
         self.toolbar.setIconSize(QSize(icon_size, icon_size))
220
         self.toolbar.setIconSize(QSize(icon_size, icon_size))
216
 
221
 
238
         self.toolbar.setContextMenuPolicy(Qt.PreventContextMenu)
243
         self.toolbar.setContextMenuPolicy(Qt.PreventContextMenu)
239
         self.toolbar.setFloatable(False)
244
         self.toolbar.setFloatable(False)
240
         self.toolbar.setMovable(False)
245
         self.toolbar.setMovable(False)
241
-        self.ct_ag = QActionGroup(self.toolbar)
246
+        self.ct_ag: QActionGroup = QActionGroup(self.toolbar)
242
         self.ct_ag.setExclusive(True)
247
         self.ct_ag.setExclusive(True)
243
 
248
 
244
         cts = ConsumptionType.get_all()
249
         cts = ConsumptionType.get_all()
287
 
292
 
288
         self.addToolBar(self.toolbar)
293
         self.addToolBar(self.toolbar)
289
 
294
 
295
+        # Load sounds
296
+        plops = [QSoundEffect(self) for _ in range(7)]
297
+        for qse in plops:
298
+            qse.setSource(QUrl.fromLocalFile(str(PLOP_PATH)))
299
+        self.plop_loop = itertools.cycle(plops)
300
+
301
+        undos = [QSoundEffect(self) for _ in range(5)]
302
+        for qse in undos:
303
+            qse.setSource(QUrl.fromLocalFile(str(UNDO_PATH)))
304
+        self.undo_loop = itertools.cycle(undos)
305
+
290
         # Initialize main widget
306
         # Initialize main widget
291
         self.main_widget = NameButtons(self.ct_ag.actions()[0].data(), self)
307
         self.main_widget = NameButtons(self.ct_ag.actions()[0].data(), self)
292
         self.consumption_type_changed.connect(self.main_widget.consumption_type_changed)
308
         self.consumption_type_changed.connect(self.main_widget.consumption_type_changed)
310
         """ Ask for a new Person and register it, then rebuild the central
326
         """ Ask for a new Person and register it, then rebuild the central
311
         widget. """
327
         widget. """
312
         inactive_persons = Person.get_all(False)
328
         inactive_persons = Person.get_all(False)
329
+        assert not isinstance(inactive_persons, NetworkError)
330
+
313
         inactive_persons.sort(key=lambda p: p.name)
331
         inactive_persons.sort(key=lambda p: p.name)
314
         inactive_names = [p.name for p in inactive_persons]
332
         inactive_names = [p.name for p in inactive_persons]
315
 
333
 
330
                 person.set_active(True)
348
                 person.set_active(True)
331
 
349
 
332
             else:
350
             else:
333
-                person = Person(name=name)
334
-                person = person.create()
351
+                person = Person(full_name=name, display_name=None,)
352
+                person.create()
335
 
353
 
354
+            assert self.main_widget is not None
336
             self.main_widget.init_ui()
355
             self.main_widget.init_ui()
337
 
356
 
338
     def add_consumption_type(self) -> None:
357
     def add_consumption_type(self) -> None:
343
         self.hide_keyboard()
362
         self.hide_keyboard()
344
 
363
 
345
         if ok and name:
364
         if ok and name:
346
-            ct = ConsumptionType(name=name)
347
-            ct = ct.create()
365
+            ct = ConsumptionType(name=name).create()
366
+            assert not isinstance(ct, NetworkError)
348
 
367
 
349
             action = QAction(
368
             action = QAction(
350
                 self.load_icon(ct.icon or "beer_bottle.svg"), ct.name, self.ct_ag
369
                 self.load_icon(ct.icon or "beer_bottle.svg"), ct.name, self.ct_ag
352
             action.setCheckable(True)
371
             action.setCheckable(True)
353
             action.setData(str(ct.consumption_type_id))
372
             action.setData(str(ct.consumption_type_id))
354
 
373
 
374
+            assert self.toolbar is not None
355
             self.toolbar.addAction(action)
375
             self.toolbar.addAction(action)
356
 
376
 
357
     def confirm_quit(self) -> None:
377
     def confirm_quit(self) -> None:
370
 
390
 
371
     def do_undo(self) -> None:
391
     def do_undo(self) -> None:
372
         """ Undo the last marked consumption. """
392
         """ Undo the last marked consumption. """
373
-        UNDO_WAVE.play()
393
+        next(self.undo_loop).play()
374
 
394
 
375
         to_undo = self.undo_queue.pop()
395
         to_undo = self.undo_queue.pop()
376
         LOG.warning("Undoing consumption %s", to_undo)
396
         LOG.warning("Undoing consumption %s", to_undo)
382
             self.undo_queue.append(to_undo)
402
             self.undo_queue.append(to_undo)
383
 
403
 
384
         elif not self.undo_queue:
404
         elif not self.undo_queue:
405
+            assert self.undo_action is not None
385
             self.undo_action.setDisabled(True)
406
             self.undo_action.setDisabled(True)
386
 
407
 
408
+        assert self.main_widget is not None
387
         self.main_widget.init_ui()
409
         self.main_widget.init_ui()
388
 
410
 
389
     @Slot(Consumption)
411
     @Slot(Consumption)
412
         icon = QIcon(os.path.join(self.icons_dir, filename))
434
         icon = QIcon(os.path.join(self.icons_dir, filename))
413
         return icon
435
         return icon
414
 
436
 
437
+    def play_plop(self) -> None:
438
+        next(self.plop_loop).play()
439
+
415
 
440
 
416
 def main() -> None:
441
 def main() -> None:
417
     """ Main entry point of GUI client. """
442
     """ Main entry point of GUI client. """
428
     app.setFont(font)
453
     app.setFont(font)
429
 
454
 
430
     # Test connectivity
455
     # Test connectivity
431
-    server_running, info = ServerStatus.is_server_running()
456
+    server_running = ServerStatus.is_server_running()
432
 
457
 
433
-    if not server_running:
434
-        LOG.critical("Could not connect to server", extra={"info": info})
458
+    if isinstance(server_running, NetworkError):
459
+        LOG.critical("Could not connect to server, error %s", server_running.value)
435
         QMessageBox.critical(
460
         QMessageBox.critical(
436
             None,
461
             None,
437
             "Help er is iets kapot",
462
             "Help er is iets kapot",
438
             "Kan niet starten omdat de server niet reageert, stuur een foto van "
463
             "Kan niet starten omdat de server niet reageert, stuur een foto van "
439
-            "dit naar Maarten: " + repr(info),
464
+            "dit naar Maarten: " + repr(server_running.value),
440
         )
465
         )
441
-        return 1
466
+        return
442
 
467
 
443
     # Load main window
468
     # Load main window
444
     main_window = PiketMainWindow()
469
     main_window = PiketMainWindow()
445
 
470
 
446
     # Test unsettled consumptions
471
     # Test unsettled consumptions
447
     status = ServerStatus.unsettled_consumptions()
472
     status = ServerStatus.unsettled_consumptions()
473
+    assert not isinstance(status, NetworkError)
448
 
474
 
449
-    unsettled = status["unsettled"]["amount"]
475
+    unsettled = status.amount
450
 
476
 
451
     if unsettled > 0:
477
     if unsettled > 0:
452
-        first = status["unsettled"]["first"]
478
+        assert status.first_timestamp is not None
479
+
480
+        first = status.first_timestamp
453
         first_date = first.strftime("%c")
481
         first_date = first.strftime("%c")
454
         ok = QMessageBox.information(
482
         ok = QMessageBox.information(
455
             None,
483
             None,
464
             name, ok = QInputDialog.getText(
492
             name, ok = QInputDialog.getText(
465
                 None,
493
                 None,
466
                 "Lijst afsluiten",
494
                 "Lijst afsluiten",
467
-                "Voer een naam in voor de lijst of druk op OK. Laat de datum " "staan.",
495
+                "Voer een naam in voor de lijst of druk op OK. Laat de datum staan.",
468
                 QLineEdit.Normal,
496
                 QLineEdit.Normal,
469
                 f"{first.strftime('%Y-%m-%d')}",
497
                 f"{first.strftime('%Y-%m-%d')}",
470
             )
498
             )
476
                     f'{item["count"]} {item["name"]}'
504
                     f'{item["count"]} {item["name"]}'
477
                     for item in settlement.consumption_summary.values()
505
                     for item in settlement.consumption_summary.values()
478
                 ]
506
                 ]
479
-                info = ", ".join(info)
507
+                info2 = ", ".join(info)
480
                 QMessageBox.information(
508
                 QMessageBox.information(
481
-                    None, "Lijst afgesloten", f"VO! Op deze lijst stonden: {info}"
509
+                    None, "Lijst afgesloten", f"VO! Op deze lijst stonden: {info2}"
482
                 )
510
                 )
483
 
511
 
484
                 main_window = PiketMainWindow()
512
                 main_window = PiketMainWindow()

+ 363 - 127
piket_client/model.py

1
 """
1
 """
2
 Provides access to the models stored in the database, via the server.
2
 Provides access to the models stored in the database, via the server.
3
 """
3
 """
4
+from __future__ import annotations
5
+
4
 import datetime
6
 import datetime
7
+import enum
5
 import logging
8
 import logging
6
-from typing import NamedTuple, Sequence
9
+from dataclasses import dataclass
10
+from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
7
 from urllib.parse import urljoin
11
 from urllib.parse import urljoin
8
 
12
 
9
 import requests
13
 import requests
10
 
14
 
11
-
12
 LOG = logging.getLogger(__name__)
15
 LOG = logging.getLogger(__name__)
13
 
16
 
14
 SERVER_URL = "http://127.0.0.1:5000"
17
 SERVER_URL = "http://127.0.0.1:5000"
15
 DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%f"
18
 DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%f"
16
 
19
 
17
 
20
 
21
+class NetworkError(enum.Enum):
22
+    """Represents errors that might occur when communicating with the server."""
23
+
24
+    HttpFailure = "http_failure"
25
+    """Returned when the server returns a non-successful status code."""
26
+
27
+    ConnectionFailure = "connection_failure"
28
+    """Returned when we can't connect to the server at all."""
29
+
30
+    InvalidData = "invalid_data"
31
+
32
+
18
 class ServerStatus:
33
 class ServerStatus:
19
     """ Provides helper classes to check whether the server is up. """
34
     """ Provides helper classes to check whether the server is up. """
20
 
35
 
21
     @classmethod
36
     @classmethod
22
-    def is_server_running(cls) -> bool:
37
+    def is_server_running(cls) -> Union[bool, NetworkError]:
23
         try:
38
         try:
24
             req = requests.get(urljoin(SERVER_URL, "ping"))
39
             req = requests.get(urljoin(SERVER_URL, "ping"))
25
-
26
-            if req.status_code == 200:
27
-                return True, req.content
28
-            return False, req.content
40
+            req.raise_for_status()
29
 
41
 
30
         except requests.ConnectionError as ex:
42
         except requests.ConnectionError as ex:
31
-            return False, ex
43
+            LOG.exception(ex)
44
+            return NetworkError.ConnectionFailure
45
+
46
+        except requests.HTTPError as ex:
47
+            LOG.exception(ex)
48
+            return NetworkError.HttpFailure
49
+
50
+        return True
51
+
52
+    @dataclass(frozen=True)
53
+    class OpenConsumptions:
54
+        amount: int
55
+        first_timestamp: Optional[datetime.datetime]
56
+        last_timestamp: Optional[datetime.datetime]
32
 
57
 
33
     @classmethod
58
     @classmethod
34
-    def unsettled_consumptions(cls) -> dict:
35
-        req = requests.get(urljoin(SERVER_URL, "status"))
59
+    def unsettled_consumptions(cls) -> Union[OpenConsumptions, NetworkError]:
60
+        try:
61
+            req = requests.get(urljoin(SERVER_URL, "status"))
62
+            req.raise_for_status()
63
+            data = req.json()
36
 
64
 
37
-        data = req.json()
65
+        except requests.ConnectionError as e:
66
+            LOG.exception(e)
67
+            return NetworkError.ConnectionFailure
38
 
68
 
39
-        if data["unsettled"]["amount"]:
40
-            data["unsettled"]["first"] = datetime.datetime.strptime(
41
-                data["unsettled"]["first"], DATETIME_FORMAT
42
-            )
43
-            data["unsettled"]["last"] = datetime.datetime.strptime(
44
-                data["unsettled"]["last"], DATETIME_FORMAT
69
+        except requests.HTTPError as e:
70
+            LOG.exception(e)
71
+            return NetworkError.HttpFailure
72
+
73
+        except ValueError as e:
74
+            LOG.exception(e)
75
+            return NetworkError.InvalidData
76
+
77
+        amount: int = data["unsettled"]["amount"]
78
+
79
+        if amount == 0:
80
+            return cls.OpenConsumptions(
81
+                amount=0, first_timestamp=None, last_timestamp=None
45
             )
82
             )
46
 
83
 
47
-        return data
84
+        first = datetime.datetime.fromisoformat(data["unsettled"]["first"])
85
+        last = datetime.datetime.fromisoformat(data["unsettled"]["last"])
86
+
87
+        return cls.OpenConsumptions(
88
+            amount=amount, first_timestamp=first, last_timestamp=last
89
+        )
48
 
90
 
49
 
91
 
50
 class Person(NamedTuple):
92
 class Person(NamedTuple):
51
     """ Represents a Person, as retrieved from the database. """
93
     """ Represents a Person, as retrieved from the database. """
52
 
94
 
53
-    name: str
95
+    full_name: str
96
+    display_name: Optional[str]
54
     active: bool = True
97
     active: bool = True
55
-    person_id: int = None
98
+    person_id: Optional[int] = None
56
     consumptions: dict = {}
99
     consumptions: dict = {}
57
 
100
 
58
-    def add_consumption(self, type_id: str) -> bool:
101
+    @property
102
+    def name(self) -> str:
103
+        return self.display_name or self.full_name
104
+
105
+    def add_consumption(self, type_id: str) -> Optional[Consumption]:
59
         """ Register a consumption for this Person. """
106
         """ Register a consumption for this Person. """
60
         req = requests.post(
107
         req = requests.post(
61
             urljoin(SERVER_URL, f"people/{self.person_id}/add_consumption/{type_id}")
108
             urljoin(SERVER_URL, f"people/{self.person_id}/add_consumption/{type_id}")
70
                     req.status_code,
117
                     req.status_code,
71
                     data,
118
                     data,
72
                 )
119
                 )
73
-                return False
120
+                return None
74
 
121
 
75
             self.consumptions.update(data["person"]["consumptions"])
122
             self.consumptions.update(data["person"]["consumptions"])
76
 
123
 
81
                 req.status_code,
128
                 req.status_code,
82
                 req.content,
129
                 req.content,
83
             )
130
             )
84
-            return False
131
+            return None
85
 
132
 
86
-    def create(self) -> "Person":
133
+    def create(self) -> Union[Person, NetworkError]:
87
         """ Create a new Person from the current attributes. As tuples are
134
         """ Create a new Person from the current attributes. As tuples are
88
         immutable, a new Person with the correct id is returned. """
135
         immutable, a new Person with the correct id is returned. """
89
-        req = requests.post(
90
-            urljoin(SERVER_URL, "people"),
91
-            json={"person": {"name": self.name, "active": True}},
136
+
137
+        try:
138
+            req = requests.post(
139
+                urljoin(SERVER_URL, "people"),
140
+                json={
141
+                    "person": {
142
+                        "full_name": self.full_name,
143
+                        "display_name": self.display_name,
144
+                        "active": True,
145
+                    }
146
+                },
147
+            )
148
+            req.raise_for_status()
149
+            data = req.json()
150
+            return Person.from_dict(data["person"])
151
+
152
+        except requests.ConnectionError as e:
153
+            LOG.exception(e)
154
+            return NetworkError.ConnectionFailure
155
+
156
+        except requests.HTTPError as e:
157
+            LOG.exception(e)
158
+            return NetworkError.HttpFailure
159
+
160
+        except ValueError as e:
161
+            LOG.exception(e)
162
+            return NetworkError.InvalidData
163
+
164
+    def rename(
165
+        self, new_full_name: Optional[str], new_display_name: Optional[str]
166
+    ) -> Optional[Person]:
167
+        person_payload: Dict[str, str] = {}
168
+
169
+        if new_full_name is not None:
170
+            person_payload["full_name"] = new_full_name
171
+
172
+        if new_display_name is not None:
173
+            person_payload["display_name"] = new_display_name
174
+
175
+        req = requests.patch(
176
+            urljoin(SERVER_URL, f"people/{self.person_id}"),
177
+            json={"person": person_payload},
92
         )
178
         )
93
 
179
 
94
         try:
180
         try:
95
             data = req.json()
181
             data = req.json()
96
         except ValueError:
182
         except ValueError:
97
             LOG.error(
183
             LOG.error(
98
-                "Did not get JSON on adding Person (%s): %s",
184
+                "Did not get JSON on updating Person (%s): %s",
99
                 req.status_code,
185
                 req.status_code,
100
                 req.content,
186
                 req.content,
101
             )
187
             )
102
             return None
188
             return None
103
 
189
 
104
-        if "error" in data or req.status_code != 201:
105
-            LOG.error("Could not create Person (%s): %s", req.status_code, data)
190
+        if "error" in data or req.status_code != 200:
191
+            LOG.error("Could not update Person (%s): %s", req.status_code, data)
106
             return None
192
             return None
107
 
193
 
108
         return Person.from_dict(data["person"])
194
         return Person.from_dict(data["person"])
109
 
195
 
110
-    def set_active(self, new_state=True) -> "Person":
196
+    def set_active(self, new_state=True) -> Optional[Person]:
111
         req = requests.patch(
197
         req = requests.patch(
112
             urljoin(SERVER_URL, f"people/{self.person_id}"),
198
             urljoin(SERVER_URL, f"people/{self.person_id}"),
113
             json={"person": {"active": new_state}},
199
             json={"person": {"active": new_state}},
130
         return Person.from_dict(data["person"])
216
         return Person.from_dict(data["person"])
131
 
217
 
132
     @classmethod
218
     @classmethod
133
-    def get(cls, person_id: int) -> "Person":
219
+    def get(cls, person_id: int) -> Optional[Person]:
134
         """ Retrieve a Person by id. """
220
         """ Retrieve a Person by id. """
135
         req = requests.get(urljoin(SERVER_URL, f"/people/{person_id}"))
221
         req = requests.get(urljoin(SERVER_URL, f"/people/{person_id}"))
136
 
222
 
154
             return None
240
             return None
155
 
241
 
156
     @classmethod
242
     @classmethod
157
-    def get_all(cls, active=None) -> ["Person"]:
243
+    def get_all(cls, active=None) -> Union[List[Person], NetworkError]:
158
         """ Get all active People. """
244
         """ Get all active People. """
159
         params = {}
245
         params = {}
160
         if active is not None:
246
         if active is not None:
161
             params["active"] = int(active)
247
             params["active"] = int(active)
162
 
248
 
163
-        req = requests.get(urljoin(SERVER_URL, "/people"), params=params)
164
-
165
         try:
249
         try:
250
+            req = requests.get(urljoin(SERVER_URL, "/people"), params=params)
251
+            req.raise_for_status()
166
             data = req.json()
252
             data = req.json()
253
+            return [Person.from_dict(item) for item in data["people"]]
167
 
254
 
168
-            if "error" in data:
169
-                LOG.warning("Could not get people (%s): %s", req.status_code, data)
255
+        except requests.ConnectionError as e:
256
+            LOG.exception(e)
257
+            return NetworkError.ConnectionFailure
170
 
258
 
171
-            return [Person.from_dict(item) for item in data["people"]]
259
+        except requests.HTTPError as e:
260
+            LOG.exception(e)
261
+            return NetworkError.HttpFailure
172
 
262
 
173
-        except ValueError:
174
-            LOG.error(
175
-                "Did not get JSON from server on getting People (%s): %s",
176
-                req.status_code,
177
-                req.content,
178
-            )
179
-            return None
263
+        except ValueError as e:
264
+            LOG.exception(e)
265
+            return NetworkError.InvalidData
180
 
266
 
181
     @classmethod
267
     @classmethod
182
     def from_dict(cls, data: dict) -> "Person":
268
     def from_dict(cls, data: dict) -> "Person":
183
         """ Reconstruct a Person object from a dict. """
269
         """ Reconstruct a Person object from a dict. """
184
         return Person(
270
         return Person(
185
-            name=data["name"],
271
+            full_name=data["full_name"],
272
+            display_name=data["display_name"],
186
             active=data["active"],
273
             active=data["active"],
187
             person_id=data["person_id"],
274
             person_id=data["person_id"],
188
             consumptions=data["consumptions"],
275
             consumptions=data["consumptions"],
206
         )
293
         )
207
 
294
 
208
     @classmethod
295
     @classmethod
209
-    def get_all(cls) -> ["Export"]:
296
+    def get_all(cls) -> Optional[List[Export]]:
210
         """ Get a list of all existing Exports. """
297
         """ Get a list of all existing Exports. """
211
         req = requests.get(urljoin(SERVER_URL, "exports"))
298
         req = requests.get(urljoin(SERVER_URL, "exports"))
212
 
299
 
227
         return [cls.from_dict(e) for e in data["exports"]]
314
         return [cls.from_dict(e) for e in data["exports"]]
228
 
315
 
229
     @classmethod
316
     @classmethod
230
-    def get(cls, export_id: int) -> "Export":
317
+    def get(cls, export_id: int) -> Optional[Export]:
231
         """ Retrieve one Export. """
318
         """ Retrieve one Export. """
232
         req = requests.get(urljoin(SERVER_URL, f"exports/{export_id}"))
319
         req = requests.get(urljoin(SERVER_URL, f"exports/{export_id}"))
233
 
320
 
250
         return cls.from_dict(data["export"])
337
         return cls.from_dict(data["export"])
251
 
338
 
252
     @classmethod
339
     @classmethod
253
-    def create(cls) -> "Export":
340
+    def create(cls) -> Optional[Export]:
254
         """ Create a new Export, containing all un-exported Settlements. """
341
         """ Create a new Export, containing all un-exported Settlements. """
255
         req = requests.post(urljoin(SERVER_URL, "exports"))
342
         req = requests.post(urljoin(SERVER_URL, "exports"))
256
 
343
 
277
     """ Represents a stored ConsumptionType. """
364
     """ Represents a stored ConsumptionType. """
278
 
365
 
279
     name: str
366
     name: str
280
-    consumption_type_id: int = None
281
-    icon: str = None
367
+    consumption_type_id: Optional[int] = None
368
+    icon: Optional[str] = None
369
+    active: bool = True
282
 
370
 
283
-    def create(self) -> "ConsumptionType":
371
+    def create(self) -> Union[ConsumptionType, NetworkError]:
284
         """ Create a new ConsumptionType from the current attributes. As tuples
372
         """ Create a new ConsumptionType from the current attributes. As tuples
285
         are immutable, a new ConsumptionType with the correct id is returned.
373
         are immutable, a new ConsumptionType with the correct id is returned.
286
         """
374
         """
287
-        req = requests.post(
288
-            urljoin(SERVER_URL, "consumption_types"),
289
-            json={"consumption_type": {"name": self.name, "icon": self.icon}},
290
-        )
291
-
292
         try:
375
         try:
293
-            data = req.json()
294
-        except ValueError:
295
-            LOG.error(
296
-                "Did not get JSON on adding ConsumptionType (%s): %s",
297
-                req.status_code,
298
-                req.content,
376
+            req = requests.post(
377
+                urljoin(SERVER_URL, "consumption_types"),
378
+                json={"consumption_type": {"name": self.name, "icon": self.icon}},
299
             )
379
             )
300
-            return None
301
 
380
 
302
-        if "error" in data or req.status_code != 201:
303
-            LOG.error(
304
-                "Could not create ConsumptionType (%s): %s", req.status_code, data
305
-            )
306
-            return None
381
+            req.raise_for_status()
382
+            data = req.json()
383
+            return ConsumptionType.from_dict(data["consumption_type"])
307
 
384
 
308
-        return ConsumptionType.from_dict(data["consumption_type"])
385
+        except requests.ConnectionError as e:
386
+            LOG.exception(e)
387
+            return NetworkError.ConnectionFailure
388
+
389
+        except requests.HTTPError as e:
390
+            LOG.exception(e)
391
+            return NetworkError.HttpFailure
392
+
393
+        except ValueError as e:
394
+            LOG.exception(e)
395
+            return NetworkError.InvalidData
309
 
396
 
310
     @classmethod
397
     @classmethod
311
-    def get(cls, consumption_type_id: int) -> "ConsumptionType":
398
+    def get(cls, consumption_type_id: int) -> Union[ConsumptionType, NetworkError]:
312
         """ Retrieve a ConsumptionType by id. """
399
         """ Retrieve a ConsumptionType by id. """
313
-        req = requests.get(
314
-            urljoin(SERVER_URL, f"/consumption_types/{consumption_type_id}")
315
-        )
316
-
317
         try:
400
         try:
401
+            req = requests.get(
402
+                urljoin(SERVER_URL, f"/consumption_types/{consumption_type_id}")
403
+            )
404
+            req.raise_for_status()
318
             data = req.json()
405
             data = req.json()
319
 
406
 
320
-            if "error" in data:
321
-                LOG.warning(
322
-                    "Could not get consumption type %s (%s): %s",
323
-                    consumption_type_id,
324
-                    req.status_code,
325
-                    data,
326
-                )
327
-                return None
407
+        except requests.ConnectionError as e:
408
+            LOG.exception(e)
409
+            return NetworkError.ConnectionFailure
328
 
410
 
329
-            return cls.from_dict(data["consumption_type"])
411
+        except requests.HTTPError as e:
412
+            LOG.exception(e)
413
+            return NetworkError.HttpFailure
330
 
414
 
331
-        except ValueError:
332
-            LOG.error(
333
-                "Did not get JSON from server on getting consumption type (%s): %s",
334
-                req.status_code,
335
-                req.content,
336
-            )
337
-            return None
415
+        except ValueError as e:
416
+            LOG.exception(e)
417
+            return NetworkError.InvalidData
338
 
418
 
339
-    @classmethod
340
-    def get_all(cls) -> ["ConsumptionType"]:
341
-        """ Get all active ConsumptionTypes. """
342
-        req = requests.get(urljoin(SERVER_URL, "/consumption_types"))
419
+        return cls.from_dict(data["consumption_type"])
343
 
420
 
421
+    @classmethod
422
+    def get_all(cls, active: bool = True) -> Union[List[ConsumptionType], NetworkError]:
423
+        """ Get the list of ConsumptionTypes. """
344
         try:
424
         try:
425
+            req = requests.get(
426
+                urljoin(SERVER_URL, "/consumption_types"),
427
+                params={"active": int(active)},
428
+            )
429
+            req.raise_for_status()
430
+
345
             data = req.json()
431
             data = req.json()
346
 
432
 
347
-            if "error" in data:
348
-                LOG.warning(
349
-                    "Could not get consumption types (%s): %s", req.status_code, data
350
-                )
433
+        except requests.ConnectionError as e:
434
+            LOG.exception(e)
435
+            return NetworkError.ConnectionFailure
351
 
436
 
352
-            return [cls.from_dict(item) for item in data["consumption_types"]]
437
+        except requests.HTTPError as e:
438
+            LOG.exception(e)
439
+            return NetworkError.HttpFailure
353
 
440
 
354
-        except ValueError:
355
-            LOG.error(
356
-                "Did not get JSON from server on getting ConsumptionTypes (%s): %s",
357
-                req.status_code,
358
-                req.content,
359
-            )
360
-            return None
441
+        except ValueError as e:
442
+            LOG.exception(e)
443
+            return NetworkError.InvalidData
444
+
445
+        return [cls.from_dict(x) for x in data["consumption_types"]]
361
 
446
 
362
     @classmethod
447
     @classmethod
363
     def from_dict(cls, data: dict) -> "ConsumptionType":
448
     def from_dict(cls, data: dict) -> "ConsumptionType":
366
             name=data["name"],
451
             name=data["name"],
367
             consumption_type_id=data["consumption_type_id"],
452
             consumption_type_id=data["consumption_type_id"],
368
             icon=data.get("icon"),
453
             icon=data.get("icon"),
454
+            active=data["active"],
369
         )
455
         )
370
 
456
 
457
+    def set_active(self, active: bool) -> Union[ConsumptionType, NetworkError]:
458
+        """Update the 'active' attribute."""
459
+        try:
460
+            req = requests.patch(
461
+                urljoin(SERVER_URL, f"/consumption_types/{self.consumption_type_id}"),
462
+                json={"consumption_type": {"active": active}},
463
+            )
464
+            req.raise_for_status()
465
+            data = req.json()
466
+
467
+        except requests.ConnectionError as e:
468
+            LOG.exception(e)
469
+            return NetworkError.ConnectionFailure
470
+
471
+        except requests.HTTPError as e:
472
+            LOG.exception(e)
473
+            return NetworkError.HttpFailure
474
+
475
+        except ValueError as e:
476
+            LOG.exception(e)
477
+            return NetworkError.InvalidData
478
+
479
+        return self.from_dict(data["consumption_type"])
480
+
371
 
481
 
372
 class Consumption(NamedTuple):
482
 class Consumption(NamedTuple):
373
     """ Represents a stored Consumption. """
483
     """ Represents a stored Consumption. """
377
     consumption_type_id: int
487
     consumption_type_id: int
378
     created_at: datetime.datetime
488
     created_at: datetime.datetime
379
     reversed: bool = False
489
     reversed: bool = False
380
-    settlement_id: int = None
490
+    settlement_id: Optional[int] = None
381
 
491
 
382
     @classmethod
492
     @classmethod
383
     def from_dict(cls, data: dict) -> "Consumption":
493
     def from_dict(cls, data: dict) -> "Consumption":
391
             reversed=data["reversed"],
501
             reversed=data["reversed"],
392
         )
502
         )
393
 
503
 
394
-    def reverse(self) -> "Consumption":
504
+    def reverse(self) -> Optional[Consumption]:
395
         """ Reverse this consumption. """
505
         """ Reverse this consumption. """
396
         req = requests.delete(
506
         req = requests.delete(
397
             urljoin(SERVER_URL, f"/consumptions/{self.consumption_id}")
507
             urljoin(SERVER_URL, f"/consumptions/{self.consumption_id}")
407
                     req.status_code,
517
                     req.status_code,
408
                     data,
518
                     data,
409
                 )
519
                 )
410
-                return False
520
+                return None
411
 
521
 
412
             return Consumption.from_dict(data["consumption"])
522
             return Consumption.from_dict(data["consumption"])
413
 
523
 
417
                 req.status_code,
527
                 req.status_code,
418
                 req.content,
528
                 req.content,
419
             )
529
             )
420
-            return False
530
+            return None
421
 
531
 
422
 
532
 
423
 class Settlement(NamedTuple):
533
 class Settlement(NamedTuple):
425
 
535
 
426
     settlement_id: int
536
     settlement_id: int
427
     name: str
537
     name: str
428
-    consumption_summary: dict
429
-    count_info: dict = {}
538
+    consumption_summary: Dict[str, Any]
539
+    count_info: Dict[str, Any] = {}
540
+    per_person_counts: Dict[str, Any] = {}
430
 
541
 
431
     @classmethod
542
     @classmethod
432
     def from_dict(cls, data: dict) -> "Settlement":
543
     def from_dict(cls, data: dict) -> "Settlement":
434
             settlement_id=data["settlement_id"],
545
             settlement_id=data["settlement_id"],
435
             name=data["name"],
546
             name=data["name"],
436
             consumption_summary=data["consumption_summary"],
547
             consumption_summary=data["consumption_summary"],
437
-            count_info=data.get("count_info", {}),
548
+            count_info=data["count_info"],
549
+            per_person_counts=data["per_person_counts"],
438
         )
550
         )
439
 
551
 
440
     @classmethod
552
     @classmethod
446
         return cls.from_dict(req.json()["settlement"])
558
         return cls.from_dict(req.json()["settlement"])
447
 
559
 
448
     @classmethod
560
     @classmethod
449
-    def get(cls, settlement_id: int) -> "Settlement":
450
-        req = requests.get(urljoin(SERVER_URL, f"/settlements/{settlement_id}"))
451
-
561
+    def get(cls, settlement_id: int) -> Union[Settlement, NetworkError]:
452
         try:
562
         try:
563
+            req = requests.get(urljoin(SERVER_URL, f"/settlements/{settlement_id}"))
564
+            req.raise_for_status()
453
             data = req.json()
565
             data = req.json()
454
-        except ValueError:
455
-            LOG.error(
456
-                "Did not get JSON on retrieving Settlement (%s): %s",
457
-                req.status_code,
458
-                req.content,
459
-            )
460
-            return None
461
 
566
 
462
-        if "error" in data or req.status_code != 200:
463
-            LOG.error("Could not get Export (%s): %s", req.status_code, data)
464
-            return None
567
+        except ValueError as e:
568
+            LOG.exception(e)
569
+            return NetworkError.InvalidData
570
+
571
+        except requests.ConnectionError as e:
572
+            LOG.exception(e)
573
+            return NetworkError.ConnectionFailure
574
+
575
+        except requests.HTTPError as e:
576
+            LOG.exception(e)
577
+            return NetworkError.HttpFailure
465
 
578
 
466
         data["settlement"]["count_info"] = data["count_info"]
579
         data["settlement"]["count_info"] = data["count_info"]
467
 
580
 
468
         return cls.from_dict(data["settlement"])
581
         return cls.from_dict(data["settlement"])
582
+
583
+
584
+@dataclass(frozen=True)
585
+class AardbeiActivity:
586
+    aardbei_id: int
587
+    name: str
588
+
589
+    @classmethod
590
+    def from_dict(cls, data: Dict[str, Any]) -> AardbeiActivity:
591
+        return cls(data["activity"]["id"], data["activity"]["name"])
592
+
593
+    @classmethod
594
+    def get_available(
595
+        cls, token: str, endpoint: str
596
+    ) -> Union[List[AardbeiActivity], NetworkError]:
597
+        try:
598
+            req = requests.post(
599
+                urljoin(SERVER_URL, "/aardbei/get_activities"),
600
+                json={"endpoint": endpoint, "token": token},
601
+            )
602
+
603
+            req.raise_for_status()
604
+            return [cls.from_dict(x) for x in req.json()["activities"]]
605
+
606
+        except requests.ConnectionError as e:
607
+            LOG.exception(e)
608
+            return NetworkError.ConnectionFailure
609
+
610
+        except requests.HTTPError as e:
611
+            LOG.exception(e)
612
+            return NetworkError.HttpFailure
613
+
614
+        except ValueError as e:
615
+            LOG.exception(e)
616
+            return NetworkError.InvalidData
617
+
618
+    @classmethod
619
+    def apply_activity(
620
+        cls, token: str, endpoint: str, activity_id: int
621
+    ) -> Union[int, NetworkError]:
622
+        try:
623
+            req = requests.post(
624
+                urljoin(SERVER_URL, "/aardbei/apply_activity"),
625
+                json={"activity_id": activity_id, "token": token, "endpoint": endpoint},
626
+            )
627
+            req.raise_for_status()
628
+            data = req.json()
629
+
630
+            return data["activity"]["response_counts"]["present"]
631
+
632
+        except requests.ConnectionError as e:
633
+            LOG.exception(e)
634
+            return NetworkError.ConnectionFailure
635
+
636
+        except requests.HTTPError as e:
637
+            LOG.exception(e)
638
+            return NetworkError.HttpFailure
639
+
640
+        except ValueError as e:
641
+            LOG.exception(e)
642
+            return NetworkError.InvalidData
643
+
644
+
645
+@dataclass(frozen=True)
646
+class AardbeiPeopleDiff:
647
+    altered_name: List[str]
648
+    link_existing: List[str]
649
+    new_people: List[str]
650
+    num_changes: int
651
+
652
+    @classmethod
653
+    def from_dict(cls, data: Dict[str, Any]) -> AardbeiPeopleDiff:
654
+        return cls(**data)
655
+
656
+    @classmethod
657
+    def get_diff(
658
+        cls, token: str, endpoint: str
659
+    ) -> Union[AardbeiPeopleDiff, NetworkError]:
660
+        try:
661
+            req = requests.post(
662
+                urljoin(SERVER_URL, "/aardbei/diff_people"),
663
+                json={"endpoint": endpoint, "token": token},
664
+            )
665
+            req.raise_for_status()
666
+            data = req.json()
667
+
668
+            return cls.from_dict(data)
669
+
670
+        except requests.ConnectionError as e:
671
+            LOG.exception(e)
672
+            return NetworkError.ConnectionFailure
673
+
674
+        except requests.HTTPError as e:
675
+            LOG.exception(e)
676
+            return NetworkError.HttpFailure
677
+
678
+        except ValueError as e:
679
+            LOG.exception(e)
680
+            return NetworkError.InvalidData
681
+
682
+    @classmethod
683
+    def sync(cls, token: str, endpoint: str) -> Union[AardbeiPeopleDiff, NetworkError]:
684
+        try:
685
+            req = requests.post(
686
+                urljoin(SERVER_URL, "/aardbei/sync_people"),
687
+                json={"endpoint": endpoint, "token": token},
688
+            )
689
+            req.raise_for_status()
690
+            data = req.json()
691
+
692
+            return cls.from_dict(data)
693
+
694
+        except requests.ConnectionError as e:
695
+            LOG.exception(e)
696
+            return NetworkError.ConnectionFailure
697
+
698
+        except requests.HTTPError as e:
699
+            LOG.exception(e)
700
+            return NetworkError.HttpFailure
701
+
702
+        except ValueError as e:
703
+            LOG.exception(e)
704
+            return NetworkError.InvalidData

+ 13 - 14
piket_client/set_active.py

2
 Provides a helper tool to (de-)activate multiple people at once.
2
 Provides a helper tool to (de-)activate multiple people at once.
3
 """
3
 """
4
 
4
 
5
+import math
5
 import sys
6
 import sys
6
 
7
 
7
 # pylint: disable=E0611
8
 # pylint: disable=E0611
9
+import qdarkstyle
10
+from PySide2.QtCore import QObject, QSize, Qt, Signal, Slot
11
+from PySide2.QtGui import QIcon
8
 from PySide2.QtWidgets import (
12
 from PySide2.QtWidgets import (
9
     QAction,
13
     QAction,
10
     QActionGroup,
14
     QActionGroup,
19
     QToolBar,
23
     QToolBar,
20
     QWidget,
24
     QWidget,
21
 )
25
 )
22
-from PySide2.QtGui import QIcon
23
-from PySide2.QtCore import QObject, QSize, Qt, Signal, Slot
24
 
26
 
25
-# pylint: enable=E0611
27
+from piket_client.model import NetworkError, Person, ServerStatus
26
 
28
 
27
-import qdarkstyle
28
-
29
-from piket_client.model import Person, ServerStatus
29
+# pylint: enable=E0611
30
 
30
 
31
 
31
 
32
 class ActivationButton(QPushButton):
32
 class ActivationButton(QPushButton):
55
 
55
 
56
     def init_ui(self) -> None:
56
     def init_ui(self) -> None:
57
         ps = Person.get_all()
57
         ps = Person.get_all()
58
-        num_columns = round(len(ps) / 10) + 1
58
+        assert not isinstance(ps, NetworkError)
59
+        num_columns = math.ceil(math.sqrt(len(ps)))
59
 
60
 
60
         for index, person in enumerate(ps):
61
         for index, person in enumerate(ps):
61
             button = ActivationButton(person, self)
62
             button = ActivationButton(person, self)
66
     def __init__(self) -> None:
67
     def __init__(self) -> None:
67
         super().__init__()
68
         super().__init__()
68
 
69
 
69
-        self.toolbar = None
70
+        self.toolbar = QToolBar()
70
 
71
 
71
         self.init_ui()
72
         self.init_ui()
72
 
73
 
79
         icon_size = font_metrics.height() * 1.45
80
         icon_size = font_metrics.height() * 1.45
80
 
81
 
81
         # Toolbar
82
         # Toolbar
82
-        self.toolbar = QToolBar()
83
         self.toolbar.setToolButtonStyle(Qt.ToolButtonTextUnderIcon)
83
         self.toolbar.setToolButtonStyle(Qt.ToolButtonTextUnderIcon)
84
         self.toolbar.setIconSize(QSize(icon_size, icon_size))
84
         self.toolbar.setIconSize(QSize(icon_size, icon_size))
85
 
85
 
112
     app.setFont(font)
112
     app.setFont(font)
113
 
113
 
114
     # Test connectivity
114
     # Test connectivity
115
-    server_running, info = ServerStatus.is_server_running()
115
+    server_running = ServerStatus.is_server_running()
116
 
116
 
117
-    if not server_running:
118
-        LOG.critical("Could not connect to server", extra={"info": info})
117
+    if not isinstance(server_running, bool):
119
         QMessageBox.critical(
118
         QMessageBox.critical(
120
             None,
119
             None,
121
             "Help er is iets kapot",
120
             "Help er is iets kapot",
122
             "Kan niet starten omdat de server niet reageert, stuur een foto van "
121
             "Kan niet starten omdat de server niet reageert, stuur een foto van "
123
-            "dit naar Maarten: " + repr(info),
122
+            "dit naar Maarten: " + repr(server_running.value),
124
         )
123
         )
125
-        return 1
124
+        return
126
 
125
 
127
     # Load main window
126
     # Load main window
128
     main_window = ActiveStateMainWindow()
127
     main_window = ActiveStateMainWindow()

+ 6 - 8
piket_client/sound.py

2
 Provides functions related to playing sounds.
2
 Provides functions related to playing sounds.
3
 """
3
 """
4
 
4
 
5
-import os
5
+import pathlib
6
 
6
 
7
-import simpleaudio as sa
8
 
7
 
9
-
10
-SOUNDS_DIR = os.path.join(os.path.dirname(__file__), "sounds")
8
+SOUND_PATH = pathlib.Path(__file__).parent / "sounds"
11
 """ Contains the absolute path to the sounds directory. """
9
 """ Contains the absolute path to the sounds directory. """
12
 
10
 
13
-PLOP_WAVE = sa.WaveObject.from_wave_file(os.path.join(SOUNDS_DIR, "plop.wav"))
14
-""" SimpleAudio WaveObject containing the plop sound. """
11
+PLOP_PATH = SOUND_PATH / "plop.wav"
12
+""" Path to the "plop" sound. """
15
 
13
 
16
-UNDO_WAVE = sa.WaveObject.from_wave_file(os.path.join(SOUNDS_DIR, "undo.wav"))
17
-""" SimpleAudio WaveObject containing the undo sound. """
14
+UNDO_PATH = SOUND_PATH / "undo.wav"
15
+""" Path to the "undo" sound". """

+ 9 - 487
piket_server/__init__.py

2
 Piket server, handles events generated by the client.
2
 Piket server, handles events generated by the client.
3
 """
3
 """
4
 
4
 
5
-import datetime
6
-import os
7
-
8
-from sqlalchemy.exc import SQLAlchemyError
9
-from sqlalchemy import func
10
-from flask import Flask, jsonify, abort, request
11
-from flask_sqlalchemy import SQLAlchemy
12
-
13
-
14
-DATA_HOME = os.environ.get("XDG_DATA_HOME", "~/.local/share")
15
-CONFIG_DIR = os.path.join(DATA_HOME, "piket_server")
16
-DB_PATH = os.path.expanduser(os.path.join(CONFIG_DIR, "database.sqlite3"))
17
-DB_URL = f"sqlite:///{DB_PATH}"
18
-
19
-app = Flask("piket_server")
20
-app.config["SQLALCHEMY_DATABASE_URI"] = DB_URL
21
-app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
22
-db = SQLAlchemy(app)
23
-
24
-
25
-# ---------- Models ----------
26
-class Person(db.Model):
27
-    """ Represents a person to be shown on the lists. """
28
-
29
-    __tablename__ = "people"
30
-
31
-    person_id = db.Column(db.Integer, primary_key=True)
32
-    name = db.Column(db.String, nullable=False)
33
-    active = db.Column(db.Boolean, nullable=False, default=False)
34
-
35
-    consumptions = db.relationship("Consumption", backref="person", lazy=True)
36
-
37
-    def __repr__(self) -> str:
38
-        return f"<Person {self.person_id}: {self.name}>"
39
-
40
-    @property
41
-    def as_dict(self) -> dict:
42
-        return {
43
-            "person_id": self.person_id,
44
-            "active": self.active,
45
-            "name": self.name,
46
-            "consumptions": {
47
-                ct.consumption_type_id: Consumption.query.filter_by(person=self)
48
-                .filter_by(settlement=None)
49
-                .filter_by(consumption_type=ct)
50
-                .filter_by(reversed=False)
51
-                .count()
52
-                for ct in ConsumptionType.query.all()
53
-            },
54
-        }
55
-
56
-
57
-class Export(db.Model):
58
-    """ Represents a set of exported Settlements. """
59
-
60
-    __tablename__ = "exports"
61
-
62
-    export_id = db.Column(db.Integer, primary_key=True)
63
-    created_at = db.Column(
64
-        db.DateTime, default=datetime.datetime.utcnow, nullable=False
65
-    )
66
-
67
-    settlements = db.relationship("Settlement", backref="export", lazy=True)
68
-
69
-    @property
70
-    def as_dict(self) -> dict:
71
-        return {
72
-            "export_id": self.export_id,
73
-            "created_at": self.created_at.isoformat(),
74
-            "settlement_ids": [s.settlement_id for s in self.settlements],
75
-        }
76
-
77
-
78
-class Settlement(db.Model):
79
-    """ Represents a settlement of the list. """
80
-
81
-    __tablename__ = "settlements"
82
-
83
-    settlement_id = db.Column(db.Integer, primary_key=True)
84
-    name = db.Column(db.String, nullable=False)
85
-    export_id = db.Column(db.Integer, db.ForeignKey("exports.export_id"), nullable=True)
86
-
87
-    consumptions = db.relationship("Consumption", backref="settlement", lazy=True)
88
-
89
-    def __repr__(self) -> str:
90
-        return f"<Settlement {self.settlement_id}: {self.name}>"
91
-
92
-    @property
93
-    def as_dict(self) -> dict:
94
-        return {
95
-            "settlement_id": self.settlement_id,
96
-            "name": self.name,
97
-            "consumption_summary": self.consumption_summary,
98
-            "unique_people": self.unique_people,
99
-        }
100
-
101
-    @property
102
-    def unique_people(self) -> int:
103
-        q = (
104
-            Consumption.query.filter_by(settlement=self)
105
-            .filter_by(reversed=False)
106
-            .group_by(Consumption.person_id)
107
-            .count()
108
-        )
109
-        return q
110
-
111
-    @property
112
-    def consumption_summary(self) -> dict:
113
-        q = (
114
-            Consumption.query.filter_by(settlement=self)
115
-            .filter_by(reversed=False)
116
-            .group_by(Consumption.consumption_type_id)
117
-            .order_by(ConsumptionType.name)
118
-            .outerjoin(ConsumptionType)
119
-            .with_entities(
120
-                Consumption.consumption_type_id,
121
-                ConsumptionType.name,
122
-                func.count(Consumption.consumption_id),
123
-            )
124
-            .all()
125
-        )
126
-
127
-        return {r[0]: {"name": r[1], "count": r[2]} for r in q}
128
-
129
-    @property
130
-    def per_person(self) -> dict:
131
-        # Get keys of seen consumption_types
132
-        c_types = self.consumption_summary.keys()
133
-
134
-        result = {}
135
-        for type in c_types:
136
-            c_type = ConsumptionType.query.get(type)
137
-            result[type] = {"consumption_type": c_type.as_dict, "counts": {}}
138
-
139
-            q = (
140
-                Consumption.query.filter_by(settlement=self)
141
-                .filter_by(reversed=False)
142
-                .filter_by(consumption_type=c_type)
143
-                .group_by(Consumption.person_id)
144
-                .order_by(Person.name)
145
-                .outerjoin(Person)
146
-                .with_entities(
147
-                    Person.person_id,
148
-                    Person.name,
149
-                    func.count(Consumption.consumption_id),
150
-                )
151
-                .all()
152
-            )
153
-
154
-            for row in q:
155
-                result[type]["counts"][row[0]] = {"name": row[1], "count": row[2]}
156
-
157
-        return result
158
-
159
-
160
-class ConsumptionType(db.Model):
161
-    """ Represents a type of consumption to be counted. """
162
-
163
-    __tablename__ = "consumption_types"
164
-
165
-    consumption_type_id = db.Column(db.Integer, primary_key=True)
166
-    name = db.Column(db.String, nullable=False)
167
-    icon = db.Column(db.String)
168
-    active = db.Column(db.Boolean, default=True)
169
-
170
-    consumptions = db.relationship("Consumption", backref="consumption_type", lazy=True)
171
-
172
-    def __repr__(self) -> str:
173
-        return f"<ConsumptionType: {self.name}>"
174
-
175
-    @property
176
-    def as_dict(self) -> dict:
177
-        return {
178
-            "consumption_type_id": self.consumption_type_id,
179
-            "name": self.name,
180
-            "icon": self.icon,
181
-        }
182
-
183
-
184
-class Consumption(db.Model):
185
-    """ Represent one consumption to be counted. """
186
-
187
-    __tablename__ = "consumptions"
188
-
189
-    consumption_id = db.Column(db.Integer, primary_key=True)
190
-    person_id = db.Column(db.Integer, db.ForeignKey("people.person_id"), nullable=True)
191
-    consumption_type_id = db.Column(
192
-        db.Integer,
193
-        db.ForeignKey("consumption_types.consumption_type_id"),
194
-        nullable=False,
195
-    )
196
-    settlement_id = db.Column(
197
-        db.Integer, db.ForeignKey("settlements.settlement_id"), nullable=True
198
-    )
199
-    created_at = db.Column(
200
-        db.DateTime, default=datetime.datetime.utcnow, nullable=False
201
-    )
202
-    reversed = db.Column(db.Boolean, default=False, nullable=False)
203
-
204
-    def __repr__(self) -> str:
205
-        return f"<Consumption: {self.consumption_type.name} for {self.person.name}>"
206
-
207
-    @property
208
-    def as_dict(self) -> dict:
209
-        return {
210
-            "consumption_id": self.consumption_id,
211
-            "person_id": self.person_id,
212
-            "consumption_type_id": self.consumption_type_id,
213
-            "settlement_id": self.settlement_id,
214
-            "created_at": self.created_at.isoformat(),
215
-            "reversed": self.reversed,
216
-        }
217
-
218
-
219
-# ---------- Models ----------
220
-
221
-
222
-@app.route("/ping")
223
-def ping() -> None:
224
-    """ Return a status ping. """
225
-    return "Pong"
226
-
227
-
228
-@app.route("/status")
229
-def status() -> None:
230
-    """ Return a status dict with info about the database. """
231
-    unsettled_q = Consumption.query.filter_by(settlement=None).filter_by(reversed=False)
232
-
233
-    unsettled = unsettled_q.count()
234
-
235
-    first = None
236
-    last = None
237
-    if unsettled:
238
-        last = (
239
-            unsettled_q.order_by(Consumption.created_at.desc())
240
-            .first()
241
-            .created_at.isoformat()
242
-        )
243
-        first = (
244
-            unsettled_q.order_by(Consumption.created_at.asc())
245
-            .first()
246
-            .created_at.isoformat()
247
-        )
248
-
249
-    return jsonify({"unsettled": {"amount": unsettled, "first": first, "last": last}})
250
-
251
-
252
-# Person
253
-@app.route("/people", methods=["GET"])
254
-def get_people():
255
-    """ Return a list of currently known people. """
256
-    people = Person.query.order_by(Person.name).all()
257
-    q = Person.query.order_by(Person.name)
258
-    if request.args.get("active"):
259
-        active_status = request.args.get("active", type=int)
260
-        q = q.filter_by(active=active_status)
261
-    people = q.all()
262
-    result = [person.as_dict for person in people]
263
-    return jsonify(people=result)
264
-
265
-
266
-@app.route("/people/<int:person_id>", methods=["GET"])
267
-def get_person(person_id: int):
268
-    person = Person.query.get_or_404(person_id)
269
-
270
-    return jsonify(person=person.as_dict)
271
-
272
-
273
-@app.route("/people", methods=["POST"])
274
-def add_person():
275
-    """
276
-    Add a new person.
277
-
278
-    Required parameters:
279
-    - name (str)
280
-    """
281
-    json = request.get_json()
282
-
283
-    if not json:
284
-        return jsonify({"error": "Could not parse JSON."}), 400
285
-
286
-    data = json.get("person") or {}
287
-    person = Person(name=data.get("name"), active=data.get("active", False))
288
-
289
-    try:
290
-        db.session.add(person)
291
-        db.session.commit()
292
-    except SQLAlchemyError:
293
-        return jsonify({"error": "Invalid arguments for Person."}), 400
294
-
295
-    return jsonify(person=person.as_dict), 201
296
-
297
-
298
-@app.route("/people/<int:person_id>/add_consumption", methods=["POST"])
299
-def add_consumption(person_id: int):
300
-    person = Person.query.get_or_404(person_id)
301
-
302
-    consumption = Consumption(person=person, consumption_type_id=1)
303
-    try:
304
-        db.session.add(consumption)
305
-        db.session.commit()
306
-    except SQLAlchemyError:
307
-        return (
308
-            jsonify(
309
-                {"error": "Invalid Consumption parameters.", "person": person.as_dict}
310
-            ),
311
-            400,
312
-        )
313
-
314
-    return jsonify(person=person.as_dict, consumption=consumption.as_dict), 201
315
-
316
-
317
-@app.route("/people/<int:person_id>", methods=["PATCH"])
318
-def update_person(person_id: int):
319
-    person = Person.query.get_or_404(person_id)
320
-
321
-    data = request.json["person"]
322
-
323
-    if "active" in data:
324
-        person.active = data["active"]
325
-
326
-        db.session.add(person)
327
-        db.session.commit()
328
-
329
-        return jsonify(person=person.as_dict)
330
-
331
-
332
-@app.route("/people/<int:person_id>/add_consumption/<int:ct_id>", methods=["POST"])
333
-def add_consumption2(person_id: int, ct_id: int):
334
-    person = Person.query.get_or_404(person_id)
335
-
336
-    consumption = Consumption(person=person, consumption_type_id=ct_id)
337
-    try:
338
-        db.session.add(consumption)
339
-        db.session.commit()
340
-    except SQLAlchemyError:
341
-        return (
342
-            jsonify(
343
-                {"error": "Invalid Consumption parameters.", "person": person.as_dict}
344
-            ),
345
-            400,
346
-        )
347
-
348
-    return jsonify(person=person.as_dict, consumption=consumption.as_dict), 201
349
-
350
-
351
-@app.route("/consumptions/<int:consumption_id>", methods=["DELETE"])
352
-def reverse_consumption(consumption_id: int):
353
-    """ Reverse a consumption. """
354
-    consumption = Consumption.query.get_or_404(consumption_id)
355
-
356
-    if consumption.reversed:
357
-        return (
358
-            jsonify(
359
-                {
360
-                    "error": "Consumption already reversed",
361
-                    "consumption": consumption.as_dict,
362
-                }
363
-            ),
364
-            409,
365
-        )
366
-
367
-    try:
368
-        consumption.reversed = True
369
-        db.session.add(consumption)
370
-        db.session.commit()
371
-
372
-    except SQLAlchemyError:
373
-        return jsonify({"error": "Database error."}), 500
374
-
375
-    return jsonify(consumption=consumption.as_dict), 200
376
-
377
-
378
-# ConsumptionType
379
-@app.route("/consumption_types", methods=["GET"])
380
-def get_consumption_types():
381
-    """ Return a list of currently active consumption types. """
382
-    ctypes = ConsumptionType.query.filter_by(active=True).all()
383
-    result = [ct.as_dict for ct in ctypes]
384
-    return jsonify(consumption_types=result)
385
-
386
-
387
-@app.route("/consumption_types/<int:consumption_type_id>", methods=["GET"])
388
-def get_consumption_type(consumption_type_id: int):
389
-    ct = ConsumptionType.query.get_or_404(consumption_type_id)
390
-
391
-    return jsonify(consumption_type=ct.as_dict)
392
-
393
-
394
-@app.route("/consumption_types", methods=["POST"])
395
-def add_consumption_type():
396
-    """ Add a new ConsumptionType.  """
397
-    json = request.get_json()
398
-
399
-    if not json:
400
-        return jsonify({"error": "Could not parse JSON."}), 400
401
-
402
-    data = json.get("consumption_type") or {}
403
-    ct = ConsumptionType(name=data.get("name"), icon=data.get("icon"))
404
-
405
-    try:
406
-        db.session.add(ct)
407
-        db.session.commit()
408
-    except SQLAlchemyError:
409
-        return jsonify({"error": "Invalid arguments for ConsumptionType."}), 400
410
-
411
-    return jsonify(consumption_type=ct.as_dict), 201
412
-
413
-
414
-# Settlement
415
-@app.route("/settlements", methods=["GET"])
416
-def get_settlements():
417
-    """ Return a list of the active Settlements. """
418
-    result = Settlement.query.all()
419
-    return jsonify(settlements=[s.as_dict for s in result])
420
-
421
-
422
-@app.route("/settlements/<int:settlement_id>", methods=["GET"])
423
-def get_settlement(settlement_id: int):
424
-    """ Show full details for a single Settlement. """
425
-    s = Settlement.query.get_or_404(settlement_id)
426
-
427
-    per_person = s.per_person
428
-
429
-    return jsonify(settlement=s.as_dict, count_info=per_person)
430
-
431
-
432
-@app.route("/settlements", methods=["POST"])
433
-def add_settlement():
434
-    """ Create a Settlement, and link all un-settled Consumptions to it. """
435
-    json = request.get_json()
436
-
437
-    if not json:
438
-        return jsonify({"error": "Could not parse JSON."}), 400
439
-
440
-    data = json.get("settlement") or {}
441
-    s = Settlement(name=data["name"])
442
-
443
-    db.session.add(s)
444
-    db.session.commit()
445
-
446
-    Consumption.query.filter_by(settlement=None).update(
447
-        {"settlement_id": s.settlement_id}
448
-    )
449
-
450
-    db.session.commit()
451
-
452
-    return jsonify(settlement=s.as_dict)
453
-
454
-
455
-# Export
456
-@app.route("/exports", methods=["GET"])
457
-def get_exports():
458
-    """ Return a list of the created Exports. """
459
-    result = Export.query.all()
460
-    return jsonify(exports=[e.as_dict for e in result])
461
-
462
-
463
-@app.route("/exports/<int:export_id>", methods=["GET"])
464
-def get_export(export_id: int):
465
-    """ Return an overview for the given Export. """
466
-    e = Export.query.get_or_404(export_id)
467
-
468
-    ss = [s.as_dict for s in e.settlements]
469
-
470
-    return jsonify(export=e.as_dict, settlements=ss)
471
-
472
-
473
-@app.route("/exports", methods=["POST"])
474
-def add_export():
475
-    """ Create an Export, and link all un-exported Settlements to it. """
476
-    # Assert that there are Settlements to be exported.
477
-    s_count = Settlement.query.filter_by(export=None).count()
478
-    if s_count == 0:
479
-        return jsonify(error="No un-exported Settlements."), 403
480
-
481
-    e = Export()
482
-
483
-    db.session.add(e)
484
-    db.session.commit()
485
-
486
-    Settlement.query.filter_by(export=None).update({"export_id": e.export_id})
487
-    db.session.commit()
488
-
489
-    ss = [s.as_dict for s in e.settlements]
490
-
491
-    return jsonify(export=e.as_dict, settlements=ss), 201
5
+from piket_server.flask import app
6
+
7
+import piket_server.routes.general
8
+import piket_server.routes.people
9
+import piket_server.routes.consumptions
10
+import piket_server.routes.consumption_types
11
+import piket_server.routes.settlements
12
+import piket_server.routes.exports
13
+import piket_server.routes.aardbei

+ 679 - 0
piket_server/aardbei_sync.py

1
+from __future__ import annotations
2
+
3
+import datetime
4
+import json
5
+import logging
6
+import sys
7
+from dataclasses import asdict, dataclass
8
+from enum import Enum
9
+from typing import Any, Dict, List, NewType, Optional, Tuple, Union
10
+
11
+import requests
12
+
13
+from piket_server.flask import db
14
+from piket_server.models import Person
15
+from piket_server.util import fmt_datetime
16
+
17
+# AARDBEI_ENDPOINT = "https://aardbei.app"
18
+AARDBEI_ENDPOINT = "http://localhost:3000"
19
+log = logging.getLogger(__name__)
20
+
21
+ActivityId = NewType("ActivityId", int)
22
+PersonId = NewType("PersonId", int)
23
+MemberId = NewType("MemberId", int)
24
+ParticipantId = NewType("ParticipantId", int)
25
+
26
+
27
+@dataclass(frozen=True)
28
+class AardbeiPerson:
29
+    """
30
+    Contains the data on a Person as exposed by Aardbei.
31
+
32
+    A Person represents a person in the real world, and maps to a Person in the local database.
33
+    """
34
+
35
+    aardbei_id: PersonId
36
+    full_name: str
37
+
38
+    @classmethod
39
+    def from_aardbei_dict(cls, data: Dict[str, Any]) -> AardbeiPerson:
40
+        """
41
+        Load from a dictionary provided by Aardbei.
42
+
43
+        >>> AardbeiPerson.from_aardbei_dict(
44
+          {"person": {"aardbei_id": 1, "full_name": "Henkie Kraggelwenk"}}
45
+        )
46
+        AardbeiPerson(aardbei_id=AardbeiId(1), full_name="Henkie Kraggelwenk")
47
+        """
48
+
49
+        d = data["person"]
50
+        return cls(full_name=d["full_name"], aardbei_id=PersonId(d["id"]))
51
+
52
+    @property
53
+    def as_json_dict(self) -> Dict[str, Any]:
54
+        """
55
+        Serialize to a dictionary as provided by Aardbei.
56
+
57
+        >>> AardbeiPerson(aardbei_id=AardbeiId(1), full_name="Henkie Kraggelwenk").as_json_dict
58
+        {"person": {"id": 1, "full_name": "Henkie Kraggelwenk"}}
59
+        """
60
+
61
+        return {"person": {"id": self.aardbei_id, "full_name": self.full_name}}
62
+
63
+
64
+@dataclass(frozen=True)
65
+class AardbeiMember:
66
+    """
67
+    Contains the data on a Member exposed by Aardbei.
68
+
69
+    A Member represents the membership of a Person in a Group in Aardbei.
70
+    """
71
+
72
+    person: AardbeiPerson
73
+    aardbei_id: MemberId
74
+    is_leader: bool
75
+    display_name: str
76
+
77
+    @classmethod
78
+    def from_aardbei_dict(cls, data: Dict[str, Any]) -> AardbeiMember:
79
+        """
80
+        Load from a dictionary provided by Aardbei.
81
+
82
+        >>> from_aardbei_dict({
83
+            "member": {
84
+                "person": {
85
+                    "full_name": "Roer Kuggelvork",
86
+                    "id": 2,
87
+                },
88
+                "id": 23,
89
+                "is_leader": False,
90
+                "display_name": "Roer",
91
+            },
92
+        })
93
+        AardbeiMember(
94
+            person=AardbeiPerson(aardbei_id=PersonId(2), full_name="Roer Kuggelvork"),
95
+            aardbei_id=MemberId(23),
96
+            is_leader=False,
97
+            display_name="Roer",
98
+        )
99
+        """
100
+
101
+        d = data["member"]
102
+        person = AardbeiPerson.from_aardbei_dict(d)
103
+        return cls(
104
+            person=person,
105
+            aardbei_id=MemberId(d["id"]),
106
+            is_leader=d["is_leader"],
107
+            display_name=d["display_name"],
108
+        )
109
+
110
+    @property
111
+    def as_json_dict(self) -> Dict[str, Any]:
112
+        """
113
+        Serialize to a dict as provided by Aardbei.
114
+
115
+        >>> AardbeiMember(
116
+            person=AardbeiPerson(aardbei_id=PersonId(2), full_name="Roer Kuggelvork"),
117
+            aardbei_id=MemberId(23),
118
+            is_leader=False,
119
+            display_name="Roer",
120
+        )
121
+        {
122
+            "member": {
123
+                "person": {
124
+                    "full_name": "Roer Kuggelvork",
125
+                    "id": 2,
126
+                },
127
+                "id": 23,
128
+                "is_leader": False,
129
+                "display_name": "Roer",
130
+            }
131
+        }
132
+        """
133
+        res = {
134
+            "id": self.aardbei_id,
135
+            "is_leader": self.is_leader,
136
+            "display_name": self.display_name,
137
+        }
138
+        res.update(self.person.as_json_dict)
139
+        return res
140
+
141
+
142
+@dataclass(frozen=True)
143
+class AardbeiParticipant:
144
+    """
145
+    Represents a Participant as exposed by Aardbei.
146
+
147
+    A Participant represents the participation of a Person (optionally as a Member in a Group) in an Activity.
148
+    """
149
+
150
+    person: AardbeiPerson
151
+    member: Optional[AardbeiMember]
152
+    aardbei_id: ParticipantId
153
+    attending: bool
154
+    is_organizer: bool
155
+    notes: Optional[str]
156
+
157
+    @property
158
+    def name(self) -> str:
159
+        """
160
+        Return the name to show for this Participant.
161
+        This is the display_name if a Member is present, else the Participant's Person's full name.
162
+        """
163
+        if self.member is not None:
164
+            return self.member.display_name
165
+
166
+        return self.person.full_name
167
+
168
+    @classmethod
169
+    def from_aardbei_dict(cls, data: Dict[str, Any]) -> AardbeiParticipant:
170
+        """
171
+        Load from a dictionary as provided by Aardbei.
172
+        """
173
+        d = data["participant"]
174
+        person = AardbeiPerson.from_aardbei_dict(d)
175
+
176
+        member: Optional[AardbeiMember] = None
177
+        if d["member"] is not None:
178
+            member = AardbeiMember.from_aardbei_dict(d)
179
+
180
+        aardbei_id = ParticipantId(d["id"])
181
+
182
+        return cls(
183
+            person=person,
184
+            member=member,
185
+            aardbei_id=aardbei_id,
186
+            attending=d["attending"],
187
+            is_organizer=d["is_organizer"],
188
+            notes=d["notes"],
189
+        )
190
+
191
+    @property
192
+    def as_json_dict(self) -> Dict[str, Any]:
193
+        """
194
+        Serialize to a dict as provided by Aardbei.
195
+        """
196
+        res = {
197
+            "participant": {
198
+                "id": self.aardbei_id,
199
+                "attending": self.attending,
200
+                "is_organizer": self.is_organizer,
201
+                "notes": self.notes,
202
+            }
203
+        }
204
+        res.update(self.person.as_json_dict)
205
+        if self.member is not None:
206
+            res.update(self.member.as_json_dict)
207
+
208
+        return res
209
+
210
+
211
+class NoResponseAction(Enum):
212
+    """Represents the "no response action" attribute of Activities in Aardbei."""
213
+
214
+    Present = "present"
215
+    Absent = "absent"
216
+
217
+
218
+@dataclass(frozen=True)
219
+class ResponseCounts:
220
+    """Represents the "response counts" attribute of Activities in Aardbei."""
221
+
222
+    present: int
223
+    absent: int
224
+    unknown: int
225
+
226
+    @classmethod
227
+    def from_aardbei_dict(cls, data: Dict[str, int]) -> ResponseCounts:
228
+        """Load from a dict as provided by Aardbei."""
229
+        return cls(
230
+            present=data["present"], absent=data["absent"], unknown=data["unknown"]
231
+        )
232
+
233
+    @property
234
+    def as_json_dict(self) -> Dict[str, int]:
235
+        """Serialize to a dict as provided by Aardbei."""
236
+        return {"present": self.present, "absent": self.absent, "unknown": self.unknown}
237
+
238
+
239
+@dataclass(frozen=True)
240
+class SparseAardbeiActivity:
241
+    aardbei_id: ActivityId
242
+    name: str
243
+    description: str
244
+    location: str
245
+    start: datetime.datetime
246
+    end: Optional[datetime.datetime]
247
+    deadline: Optional[datetime.datetime]
248
+    reminder_at: Optional[datetime.datetime]
249
+    no_response_action: NoResponseAction
250
+    response_counts: ResponseCounts
251
+
252
+    def distance(self, reference: datetime.datetime) -> datetime.timedelta:
253
+        """Calculate how long ago this Activity ended / how much time until it starts."""
254
+        if self.end is not None:
255
+            if reference > self.start and reference < self.end:
256
+                return datetime.timedelta(seconds=0)
257
+
258
+            elif reference < self.start:
259
+                return self.start - reference
260
+
261
+            elif reference > self.end:
262
+                return reference - self.end
263
+
264
+        if reference > self.start:
265
+            return reference - self.start
266
+
267
+        return self.start - reference
268
+
269
+    @classmethod
270
+    def from_aardbei_dict(cls, data: Dict[str, Any]) -> SparseAardbeiActivity:
271
+        """Load from a dict as provided by Aardbei."""
272
+        start: datetime.datetime = datetime.datetime.fromisoformat(
273
+            data["activity"]["start"]
274
+        )
275
+        end: Optional[datetime.datetime] = None
276
+
277
+        if data["activity"]["end"] is not None:
278
+            end = datetime.datetime.fromisoformat(data["activity"]["end"])
279
+
280
+        deadline: Optional[datetime.datetime] = None
281
+        if data["activity"]["deadline"] is not None:
282
+            deadline = datetime.datetime.fromisoformat(data["activity"]["deadline"])
283
+
284
+        reminder_at: Optional[datetime.datetime] = None
285
+        if data["activity"]["reminder_at"] is not None:
286
+            reminder_at = datetime.datetime.fromisoformat(
287
+                data["activity"]["reminder_at"]
288
+            )
289
+
290
+        no_response_action = NoResponseAction(data["activity"]["no_response_action"])
291
+
292
+        response_counts = ResponseCounts.from_aardbei_dict(
293
+            data["activity"]["response_counts"]
294
+        )
295
+
296
+        return cls(
297
+            aardbei_id=ActivityId(data["activity"]["id"]),
298
+            name=data["activity"]["name"],
299
+            description=data["activity"]["description"],
300
+            location=data["activity"]["location"],
301
+            start=start,
302
+            end=end,
303
+            deadline=deadline,
304
+            reminder_at=reminder_at,
305
+            no_response_action=no_response_action,
306
+            response_counts=response_counts,
307
+        )
308
+
309
+    @property
310
+    def as_json_dict(self) -> Dict[str, Any]:
311
+        """Serialize to a dict as provided by Aardbei."""
312
+        return {
313
+            "activity": {
314
+                "id": self.aardbei_id,
315
+                "name": self.name,
316
+                "description": self.description,
317
+                "location": self.location,
318
+                "start": fmt_datetime(self.start),
319
+                "end": fmt_datetime(self.end),
320
+                "deadline": fmt_datetime(self.deadline),
321
+                "reminder_at": fmt_datetime(self.reminder_at),
322
+                "no_response_action": self.no_response_action.value,
323
+                "response_counts": self.response_counts.as_json_dict,
324
+            }
325
+        }
326
+
327
+
328
+@dataclass(frozen=True)
329
+class AardbeiActivity(SparseAardbeiActivity):
330
+    """Contains the data of an Activity as exposed by Aardbei."""
331
+
332
+    participants: List[AardbeiParticipant]
333
+
334
+    @classmethod
335
+    def from_aardbei_dict(cls, data: Dict[str, Any]) -> AardbeiActivity:
336
+        """Load from a dict as provided by Aardbei."""
337
+        # Ugly: This is a copy of the Sparse variant with added participants.
338
+        # This is not ideal, but I don't care enough to fix this right now.
339
+        participants: List[AardbeiParticipant] = [
340
+            AardbeiParticipant.from_aardbei_dict(x)
341
+            for x in data["activity"]["participants"]
342
+        ]
343
+
344
+        start: datetime.datetime = datetime.datetime.fromisoformat(
345
+            data["activity"]["start"]
346
+        )
347
+        end: Optional[datetime.datetime] = None
348
+
349
+        if data["activity"]["end"] is not None:
350
+            end = datetime.datetime.fromisoformat(data["activity"]["end"])
351
+
352
+        deadline: Optional[datetime.datetime] = None
353
+        if data["activity"]["deadline"] is not None:
354
+            deadline = datetime.datetime.fromisoformat(data["activity"]["deadline"])
355
+
356
+        reminder_at: Optional[datetime.datetime] = None
357
+        if data["activity"]["reminder_at"] is not None:
358
+            reminder_at = datetime.datetime.fromisoformat(
359
+                data["activity"]["reminder_at"]
360
+            )
361
+
362
+        no_response_action = NoResponseAction(data["activity"]["no_response_action"])
363
+
364
+        response_counts = ResponseCounts.from_aardbei_dict(
365
+            data["activity"]["response_counts"]
366
+        )
367
+
368
+        return cls(
369
+            aardbei_id=ActivityId(data["activity"]["id"]),
370
+            name=data["activity"]["name"],
371
+            description=data["activity"]["description"],
372
+            location=data["activity"]["location"],
373
+            start=start,
374
+            end=end,
375
+            deadline=deadline,
376
+            reminder_at=reminder_at,
377
+            no_response_action=no_response_action,
378
+            response_counts=response_counts,
379
+            participants=participants,
380
+        )
381
+
382
+    @property
383
+    def as_json_dict(self) -> Dict[str, Any]:
384
+        """Serialize to a dict as provided by Aardbei."""
385
+        res = super().as_json_dict
386
+        res["participants"] = [p.as_json_dict for p in self.participants]
387
+        return res
388
+
389
+
390
+@dataclass(frozen=True)
391
+class AardbeiMatch:
392
+    """Represents a match between a local Person and a Person present in Aardbei's data."""
393
+
394
+    local: Person
395
+    remote: AardbeiMember
396
+
397
+
398
+@dataclass(frozen=True)
399
+class AardbeiLink:
400
+    """Represents a set of differences between the local state and Aardbei's set of people."""
401
+
402
+    matches: List[AardbeiMatch]
403
+    """People that exist on both sides, but aren't linked in the people table."""
404
+    altered_name: List[AardbeiMatch]
405
+    """People that are already linked but changed one of their names."""
406
+    remote_only: List[AardbeiMember]
407
+    """People that only exist on the remote."""
408
+
409
+    @property
410
+    def num_changes(self) -> int:
411
+        """Return the amount of mismatching people between Aardbei and the local state."""
412
+        return len(self.matches) + len(self.altered_name) + len(self.remote_only)
413
+
414
+
415
+class AardbeiSyncError(Enum):
416
+    """Represents errors that might occur when retrieving data from Aardbei."""
417
+
418
+    CantConnect = "connect_fail"
419
+    HTTPError = "http_fail"
420
+
421
+
422
+def get_aardbei_people(
423
+    token: str, endpoint: str = AARDBEI_ENDPOINT
424
+) -> Union[List[AardbeiMember], AardbeiSyncError]:
425
+    """Retrieve the set of People in a Group from Aardbei, and parse this to
426
+    AardbeiPerson objects. Return a AardbeiSyncError if something fails."""
427
+    try:
428
+        resp: requests.Response = requests.get(
429
+            f"{endpoint}/api/groups/0/",
430
+            headers={"Authorization": f"Group {token}"},
431
+        )
432
+        resp.raise_for_status()
433
+
434
+    except requests.ConnectionError as e:
435
+        log.exception("Can't connect to endpoint %s", endpoint)
436
+        return AardbeiSyncError.CantConnect
437
+
438
+    except requests.HTTPError:
439
+        return AardbeiSyncError.HTTPError
440
+
441
+    members = resp.json()["group"]["members"]
442
+
443
+    return [AardbeiMember.from_aardbei_dict(x) for x in members]
444
+
445
+
446
+def match_local_aardbei(aardbei_members: List[AardbeiMember]) -> AardbeiLink:
447
+    """Inspect the local state and compare it with the set of given
448
+    AardbeiMembers (containing AardbeiPersons). Return a AardbeiLink that
449
+    indicates which local people don't match the remote state."""
450
+
451
+    matches: List[AardbeiMatch] = []
452
+    altered_name: List[AardbeiMatch] = []
453
+    remote_only: List[AardbeiMember] = []
454
+
455
+    for member in aardbei_members:
456
+        p: Optional[Person] = Person.query.filter_by(
457
+            aardbei_id=member.person.aardbei_id
458
+        ).one_or_none()
459
+
460
+        if p is not None:
461
+            if (
462
+                p.full_name != member.person.full_name
463
+                or p.display_name != member.display_name
464
+            ):
465
+                altered_name.append(AardbeiMatch(p, member))
466
+
467
+            else:
468
+                logging.info(
469
+                    "OK: %s / %s (L%s/R%s)",
470
+                    p.full_name,
471
+                    p.display_name,
472
+                    p.person_id,
473
+                    p.aardbei_id,
474
+                )
475
+
476
+            continue
477
+
478
+        p = Person.query.filter_by(full_name=member.person.full_name).one_or_none()
479
+
480
+        if p is not None:
481
+            matches.append(AardbeiMatch(p, member))
482
+        else:
483
+            remote_only.append(member)
484
+
485
+    return AardbeiLink(matches, altered_name, remote_only)
486
+
487
+
488
+def link_matches(matches: List[AardbeiMatch]) -> None:
489
+    """
490
+    Update local people to add the remote ID to the local state.
491
+    This only enqueues the changes in the local SQLAlchemy session, committing
492
+    needs to be done separately.
493
+    """
494
+
495
+    for match in matches:
496
+        match.local.aardbei_id = match.remote.person.aardbei_id
497
+        match.local.display_name = match.remote.display_name
498
+        logging.info(
499
+            "Linking local %s (%s) to remote %s (%s)",
500
+            match.local.full_name,
501
+            match.local.person_id,
502
+            match.remote.display_name,
503
+            match.remote.person.aardbei_id,
504
+        )
505
+
506
+        db.session.add(match.local)
507
+
508
+
509
+def create_missing(missing: List[AardbeiMember]) -> None:
510
+    """
511
+    Create local people for all remote people that don't exist locally.
512
+    This only enqueues the changes in the local SQLAlchemy session, committing
513
+    needs to be done separately.
514
+    """
515
+
516
+    for member in missing:
517
+        pnew = Person(
518
+            full_name=member.person.full_name,
519
+            display_name=member.display_name,
520
+            aardbei_id=member.person.aardbei_id,
521
+            active=False,
522
+        )
523
+        logging.info(
524
+            "Creating new person for %s / %s (%s)",
525
+            member.person.full_name,
526
+            member.display_name,
527
+            member.person.aardbei_id,
528
+        )
529
+        db.session.add(pnew)
530
+
531
+
532
+def update_names(matches: List[AardbeiMatch]) -> None:
533
+    """
534
+    Update the local full and display names of people that were already linked
535
+    to a remote person, and who changed names on the remote.
536
+
537
+    This only enqueues the changes in the local SQLAlchemy session, committing
538
+    needs to be done separately.
539
+    """
540
+
541
+    for match in matches:
542
+        p = match.local
543
+        member = match.remote
544
+        aardbei_person = member.person
545
+
546
+        changed = False
547
+
548
+        if p.full_name != aardbei_person.full_name:
549
+            logging.info(
550
+                "Updating %s (L%s/R%s) full name %s to %s",
551
+                aardbei_person.full_name,
552
+                p.person_id,
553
+                aardbei_person.aardbei_id,
554
+                p.full_name,
555
+                aardbei_person.full_name,
556
+            )
557
+            p.full_name = aardbei_person.full_name
558
+            changed = True
559
+
560
+        if p.display_name != member.display_name:
561
+            logging.info(
562
+                "Updating %s (L%s/R%s) display name %s to %s",
563
+                p.full_name,
564
+                p.person_id,
565
+                aardbei_person.aardbei_id,
566
+                p.display_name,
567
+                member.display_name,
568
+            )
569
+            p.display_name = member.display_name
570
+            changed = True
571
+
572
+        assert changed, "got match but didn't update anything"
573
+
574
+        db.session.add(p)
575
+
576
+
577
+def get_activities(
578
+    token: str, endpoint: str = AARDBEI_ENDPOINT
579
+) -> Union[List[SparseAardbeiActivity], AardbeiSyncError]:
580
+    """
581
+    Get the list of activities present on the remote and return these
582
+    activities, ordered by the temporal distance to the current time.
583
+    """
584
+
585
+    result: List[SparseAardbeiActivity] = []
586
+
587
+    for category in ("upcoming", "current", "previous"):
588
+        try:
589
+            resp = requests.get(
590
+                f"{endpoint}/api/groups/0/{category}_activities",
591
+                headers={"Authorization": f"Group {token}"},
592
+            )
593
+
594
+            resp.raise_for_status()
595
+
596
+        except requests.HTTPError as e:
597
+            log.exception(e)
598
+            return AardbeiSyncError.HTTPError
599
+
600
+        except requests.ConnectionError as e:
601
+            log.exception(e)
602
+            return AardbeiSyncError.CantConnect
603
+
604
+        for item in resp.json():
605
+            result.append(SparseAardbeiActivity.from_aardbei_dict(item))
606
+
607
+    now = datetime.datetime.now(datetime.timezone.utc)
608
+    result.sort(key=lambda x: SparseAardbeiActivity.distance(x, now))
609
+    return result
610
+
611
+
612
+def get_activity(
613
+    activity_id: ActivityId, token: str, endpoint: str
614
+) -> Union[AardbeiActivity, AardbeiSyncError]:
615
+    """
616
+    Get all data (including participants) from the remote about one activity
617
+    with a given ID.
618
+    """
619
+
620
+    try:
621
+        resp = requests.get(
622
+            f"{endpoint}/api/activities/{activity_id}",
623
+            headers={"Authorization": f"Group {token}"},
624
+        )
625
+
626
+        resp.raise_for_status()
627
+
628
+    except requests.HTTPError as e:
629
+        log.exception(e)
630
+        return AardbeiSyncError.HTTPError
631
+
632
+    except requests.ConnectionError as e:
633
+        return AardbeiSyncError.CantConnect
634
+
635
+    return AardbeiActivity.from_aardbei_dict(resp.json())
636
+
637
+
638
+def match_activity(activity: AardbeiActivity) -> None:
639
+    """
640
+    Update the local state to have mark all people present at the given
641
+    activity as active, and all other people as inactive.
642
+    """
643
+    ps = activity.participants
644
+    pids: List[PersonId] = [p.person.aardbei_id for p in ps if p.attending]
645
+
646
+    Person.query.update(values={"active": False})
647
+    Person.query.filter(Person.aardbei_id.in_(pids)).update(
648
+        values={"active": True}, synchronize_session="fetch"
649
+    )
650
+
651
+
652
+if __name__ == "__main__":
653
+    logging.basicConfig(level=logging.DEBUG)
654
+
655
+    token = input("Token: ")
656
+    aardbei_people = get_aardbei_people(token)
657
+
658
+    if isinstance(aardbei_people, AardbeiSyncError):
659
+        logging.error("Could not get people: %s", aardbei_people.value)
660
+        sys.exit(1)
661
+
662
+    activities = get_activities(token)
663
+
664
+    if isinstance(activities, AardbeiSyncError):
665
+        logging.error("Could not get activities: %s", activities.value)
666
+        sys.exit(1)
667
+
668
+    link = match_local_aardbei(aardbei_people)
669
+
670
+    link_matches(link.matches)
671
+    create_missing(link.remote_only)
672
+    update_names(link.altered_name)
673
+
674
+    confirm = input("Commit? Y/N")
675
+    if confirm.lower() == "y":
676
+        print("Committing.")
677
+        db.session.commit()
678
+    else:
679
+        print("Not committing.")

+ 9 - 11
piket_server/alembic/env.py

12
 # This line sets up loggers basically.
12
 # This line sets up loggers basically.
13
 fileConfig(config.config_file_name)
13
 fileConfig(config.config_file_name)
14
 
14
 
15
-# add your model's MetaData object here
16
-# for 'autogenerate' support
17
-# from myapp import mymodel
18
-# target_metadata = mymodel.Base.metadata
19
-import piket_server
20
-
21
-target_metadata = piket_server.db.Model.metadata
22
-
23
 # other values from the config, defined by the needs of env.py,
15
 # other values from the config, defined by the needs of env.py,
24
 # can be acquired:
16
 # can be acquired:
25
 # my_important_option = config.get_main_option("my_important_option")
17
 # my_important_option = config.get_main_option("my_important_option")
26
 # ... etc.
18
 # ... etc.
27
-from piket_server import CONFIG_DIR, DB_URL
19
+from piket_server.flask import CONFIG_DIR, DB_URL, db
20
+
21
+# add your model's MetaData object here
22
+# for 'autogenerate' support
23
+# from myapp import mymodel
24
+# target_metadata = mymodel.Base.metadata
25
+target_metadata = db.Model.metadata
28
 
26
 
29
 os.makedirs(os.path.expanduser(CONFIG_DIR), mode=0o744, exist_ok=True)
27
 os.makedirs(os.path.expanduser(CONFIG_DIR), mode=0o744, exist_ok=True)
30
 
28
 
31
 config.file_config["alembic"]["sqlalchemy.url"] = DB_URL
29
 config.file_config["alembic"]["sqlalchemy.url"] = DB_URL
32
 
30
 
33
 
31
 
34
-def run_migrations_offline():
32
+def run_migrations_offline() -> None:
35
     """Run migrations in 'offline' mode.
33
     """Run migrations in 'offline' mode.
36
 
34
 
37
     This configures the context with just a URL
35
     This configures the context with just a URL
50
         context.run_migrations()
48
         context.run_migrations()
51
 
49
 
52
 
50
 
53
-def run_migrations_online():
51
+def run_migrations_online() -> None:
54
     """Run migrations in 'online' mode.
52
     """Run migrations in 'online' mode.
55
 
53
 
56
     In this scenario we need to create an Engine
54
     In this scenario we need to create an Engine

+ 36 - 0
piket_server/alembic/versions/6a5989118ee3_enable_unique_constraints.py

1
+"""Enable unique constraints
2
+
3
+Revision ID: 6a5989118ee3
4
+Revises: cca57457a0a6
5
+Create Date: 2019-09-22 17:04:01.945713
6
+
7
+"""
8
+from alembic import op
9
+import sqlalchemy as sa
10
+
11
+
12
+# revision identifiers, used by Alembic.
13
+revision = "6a5989118ee3"
14
+down_revision = "cca57457a0a6"
15
+branch_labels = None
16
+depends_on = None
17
+
18
+
19
+def upgrade():
20
+    with op.batch_alter_table("consumption_types") as batch_op:
21
+        batch_op.create_unique_constraint("uc_consumption_types_name", ["name"])
22
+
23
+    with op.batch_alter_table("people") as batch_op2:
24
+        batch_op2.create_unique_constraint("uc_people_aardbei_id", ["aardbei_id"])
25
+        batch_op2.create_unique_constraint("uc_people_full_name", ["full_name"])
26
+        batch_op2.create_unique_constraint("uc_people_display_name", ["display_name"])
27
+
28
+
29
+def downgrade():
30
+    with op.batch_alter_table("people") as batch_op2:
31
+        batch_op2.drop_constraint("uc_people_display_name", type_="unique")
32
+        batch_op2.drop_constraint("uc_people_full_name", type_="unique")
33
+        batch_op2.drop_constraint("uc_people_aardbei_id", type_="unique")
34
+
35
+    with op.batch_alter_table("consumption_types") as batch_op:
36
+        batch_op.drop_constraint("uc_consumption_types_name", type_="unique")

+ 30 - 0
piket_server/alembic/versions/cca57457a0a6_add_aardbei_fields.py

1
+"""Add Aardbei fields
2
+
3
+Revision ID: cca57457a0a6
4
+Revises: 2f3a49058a67
5
+Create Date: 2019-09-05 21:38:28.489281
6
+
7
+"""
8
+from alembic import op
9
+import sqlalchemy as sa
10
+
11
+
12
+# revision identifiers, used by Alembic.
13
+revision = "cca57457a0a6"
14
+down_revision = "2f3a49058a67"
15
+branch_labels = None
16
+depends_on = None
17
+
18
+
19
+def upgrade():
20
+    with op.batch_alter_table("people") as batch_op:
21
+        batch_op.alter_column("name", new_column_name="full_name")
22
+        batch_op.add_column(sa.Column("aardbei_id", sa.Integer(), nullable=True))
23
+        batch_op.add_column(sa.Column("display_name", sa.String(), nullable=True))
24
+
25
+
26
+def downgrade():
27
+    with op.batch_alter_table("people") as batch_op:
28
+        batch_op.alter_column("full_name", new_column_name="name")
29
+        batch_op.drop_column("aardbei_id")
30
+        batch_op.drop_column("display_name")

+ 19 - 0
piket_server/flask.py

1
+"""
2
+Defines the Flask object used to run the server.
3
+"""
4
+
5
+import os
6
+from typing import Any
7
+
8
+from flask import Flask
9
+from flask_sqlalchemy import SQLAlchemy  # type: ignore
10
+
11
+DATA_HOME = os.environ.get("XDG_DATA_HOME", "~/.local/share")
12
+CONFIG_DIR = os.path.join(DATA_HOME, "piket_server")
13
+DB_PATH = os.path.expanduser(os.path.join(CONFIG_DIR, "database.sqlite3"))
14
+DB_URL = f"sqlite:///{DB_PATH}"
15
+
16
+app = Flask("piket_server")
17
+app.config["SQLALCHEMY_DATABASE_URI"] = DB_URL
18
+app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
19
+db: Any = SQLAlchemy(app)

+ 250 - 0
piket_server/models.py

1
+"""
2
+Defines database models used by the server.
3
+"""
4
+
5
+import datetime
6
+from typing import List, Dict, Any
7
+from collections import defaultdict
8
+
9
+from sqlalchemy import func
10
+from sqlalchemy.exc import SQLAlchemyError
11
+
12
+from piket_server.flask import db
13
+
14
+
15
+class Person(db.Model):
16
+    """ Represents a person to be shown on the lists. """
17
+
18
+    __tablename__ = "people"
19
+
20
+    person_id = db.Column(db.Integer, primary_key=True)
21
+    full_name = db.Column(db.String, nullable=False, unique=True)
22
+    display_name = db.Column(db.String, nullable=True, unique=True)
23
+    aardbei_id = db.Column(db.Integer, nullable=True, unique=True)
24
+    active = db.Column(db.Boolean, nullable=False, default=False)
25
+
26
+    consumptions = db.relationship("Consumption", backref="person", lazy=True)
27
+
28
+    def __repr__(self) -> str:
29
+        return f"<Person {self.person_id}: {self.full_name}>"
30
+
31
+    @property
32
+    def as_dict(self) -> dict:
33
+        return {
34
+            "person_id": self.person_id,
35
+            "active": self.active,
36
+            "full_name": self.full_name,
37
+            "display_name": self.display_name,
38
+            "consumptions": {
39
+                ct.consumption_type_id: Consumption.query.filter_by(person=self)
40
+                .filter_by(settlement=None)
41
+                .filter_by(consumption_type=ct)
42
+                .filter_by(reversed=False)
43
+                .count()
44
+                for ct in ConsumptionType.query.all()
45
+            },
46
+        }
47
+
48
+
49
+class Export(db.Model):
50
+    """ Represents a set of exported Settlements. """
51
+
52
+    __tablename__ = "exports"
53
+
54
+    export_id = db.Column(db.Integer, primary_key=True)
55
+    created_at = db.Column(
56
+        db.DateTime, default=datetime.datetime.utcnow, nullable=False
57
+    )
58
+
59
+    settlements = db.relationship("Settlement", backref="export", lazy=True)
60
+
61
+    @property
62
+    def as_dict(self) -> dict:
63
+        return {
64
+            "export_id": self.export_id,
65
+            "created_at": self.created_at.isoformat(),
66
+            "settlement_ids": [s.settlement_id for s in self.settlements],
67
+        }
68
+
69
+
70
+class Settlement(db.Model):
71
+    """ Represents a settlement of the list. """
72
+
73
+    __tablename__ = "settlements"
74
+
75
+    settlement_id = db.Column(db.Integer, primary_key=True)
76
+    name = db.Column(db.String, nullable=False)
77
+    export_id = db.Column(db.Integer, db.ForeignKey("exports.export_id"), nullable=True)
78
+
79
+    consumptions = db.relationship("Consumption", backref="settlement", lazy=True)
80
+
81
+    def __repr__(self) -> str:
82
+        return f"<Settlement {self.settlement_id}: {self.name}>"
83
+
84
+    @property
85
+    def as_dict(self) -> dict:
86
+        return {
87
+            "settlement_id": self.settlement_id,
88
+            "name": self.name,
89
+            "consumption_summary": self.consumption_summary,
90
+            "unique_people": self.unique_people,
91
+            "per_person_counts": self.per_person_counts,
92
+            "count_info": self.per_person,
93
+        }
94
+
95
+    @property
96
+    def unique_people(self) -> int:
97
+        q = (
98
+            Consumption.query.filter_by(settlement=self)
99
+            .filter_by(reversed=False)
100
+            .group_by(Consumption.person_id)
101
+            .count()
102
+        )
103
+        return q
104
+
105
+    @property
106
+    def consumption_summary(self) -> dict:
107
+        q = (
108
+            Consumption.query.filter_by(settlement=self)
109
+            .filter_by(reversed=False)
110
+            .group_by(Consumption.consumption_type_id)
111
+            .order_by(ConsumptionType.name)
112
+            .outerjoin(ConsumptionType)
113
+            .with_entities(
114
+                Consumption.consumption_type_id,
115
+                ConsumptionType.name,
116
+                func.count(Consumption.consumption_id),
117
+            )
118
+            .all()
119
+        )
120
+
121
+        return {r[0]: {"name": r[1], "count": r[2]} for r in q}
122
+
123
+    @property
124
+    def per_person(self) -> dict:
125
+        # Get keys of seen consumption_types
126
+        c_types = self.consumption_summary.keys()
127
+
128
+        result = {}
129
+        for type in c_types:
130
+            c_type = ConsumptionType.query.get(type)
131
+            result[type] = {"consumption_type": c_type.as_dict, "counts": {}}
132
+
133
+            q = (
134
+                Consumption.query.filter_by(settlement=self)
135
+                .filter_by(reversed=False)
136
+                .filter_by(consumption_type=c_type)
137
+                .group_by(Consumption.person_id)
138
+                .order_by(Person.full_name)
139
+                .outerjoin(Person)
140
+                .with_entities(
141
+                    Person.person_id,
142
+                    Person.full_name,
143
+                    func.count(Consumption.consumption_id),
144
+                )
145
+                .all()
146
+            )
147
+
148
+            for row in q:
149
+                result[type]["counts"][row[0]] = {"name": row[1], "count": row[2]}
150
+
151
+        return result
152
+
153
+    @property
154
+    def per_person_counts(self) -> Dict[int, Any]:
155
+        """
156
+        Output a more usable dict containing for each person in the settlement
157
+        how many of each consumption type was counted.
158
+        """
159
+
160
+        q = (
161
+            Consumption.query.filter_by(settlement=self)
162
+            .filter_by(reversed=False)
163
+            .group_by(Consumption.person_id)
164
+            .group_by(Consumption.consumption_type_id)
165
+            .group_by(Person.full_name)
166
+            .outerjoin(Person)
167
+            .with_entities(
168
+                Consumption.person_id,
169
+                Person.full_name,
170
+                Consumption.consumption_type_id,
171
+                func.count(Consumption.consumption_id),
172
+            )
173
+            .all()
174
+        )
175
+
176
+        res: Dict[int, Any] = defaultdict(dict)
177
+
178
+        for row in q:
179
+            item = res[row[0]]
180
+            item["full_name"] = row[1]
181
+            if not item.get("counts"):
182
+                item["counts"] = {}
183
+
184
+            item["counts"][row[2]] = row[3]
185
+
186
+        return res
187
+
188
+
189
+
190
+
191
+class ConsumptionType(db.Model):
192
+    """ Represents a type of consumption to be counted. """
193
+
194
+    __tablename__ = "consumption_types"
195
+
196
+    consumption_type_id = db.Column(db.Integer, primary_key=True)
197
+    name = db.Column(db.String, nullable=False, unique=True)
198
+    icon = db.Column(db.String)
199
+    active = db.Column(db.Boolean, default=True)
200
+
201
+    consumptions = db.relationship("Consumption", backref="consumption_type", lazy=True)
202
+
203
+    def __repr__(self) -> str:
204
+        return f"<ConsumptionType: {self.name}>"
205
+
206
+    @property
207
+    def as_dict(self) -> dict:
208
+        return {
209
+            "consumption_type_id": self.consumption_type_id,
210
+            "name": self.name,
211
+            "icon": self.icon,
212
+            "active": self.active,
213
+        }
214
+
215
+
216
+class Consumption(db.Model):
217
+    """ Represent one consumption to be counted. """
218
+
219
+    __tablename__ = "consumptions"
220
+
221
+    consumption_id = db.Column(db.Integer, primary_key=True)
222
+    person_id = db.Column(db.Integer, db.ForeignKey("people.person_id"), nullable=True)
223
+    consumption_type_id = db.Column(
224
+        db.Integer,
225
+        db.ForeignKey("consumption_types.consumption_type_id"),
226
+        nullable=False,
227
+    )
228
+    settlement_id = db.Column(
229
+        db.Integer, db.ForeignKey("settlements.settlement_id"), nullable=True
230
+    )
231
+    created_at = db.Column(
232
+        db.DateTime, default=datetime.datetime.utcnow, nullable=False
233
+    )
234
+    reversed = db.Column(db.Boolean, default=False, nullable=False)
235
+
236
+    def __repr__(self) -> str:
237
+        return (
238
+            f"<Consumption: {self.consumption_type.name} for {self.person.full_name}>"
239
+        )
240
+
241
+    @property
242
+    def as_dict(self) -> dict:
243
+        return {
244
+            "consumption_id": self.consumption_id,
245
+            "person_id": self.person_id,
246
+            "consumption_type_id": self.consumption_type_id,
247
+            "settlement_id": self.settlement_id,
248
+            "created_at": self.created_at.isoformat(),
249
+            "reversed": self.reversed,
250
+        }

+ 0 - 0
piket_server/routes/__init__.py


+ 121 - 0
piket_server/routes/aardbei.py

1
+from typing import Any, Dict, List, Tuple, Union
2
+
3
+from flask import request
4
+
5
+from piket_server.aardbei_sync import (
6
+    AARDBEI_ENDPOINT,
7
+    ActivityId,
8
+    get_activity,
9
+    AardbeiLink,
10
+    AardbeiSyncError,
11
+    create_missing,
12
+    get_aardbei_people,
13
+    match_activity,
14
+    get_activities,
15
+    link_matches,
16
+    match_local_aardbei,
17
+    update_names,
18
+)
19
+from piket_server.flask import app, db
20
+
21
+
22
+def common_prepare_aardbei_sync(
23
+    token: str, endpoint: str
24
+) -> Union[AardbeiSyncError, AardbeiLink]:
25
+    aardbei_people = get_aardbei_people(token, endpoint)
26
+
27
+    if isinstance(aardbei_people, AardbeiSyncError):
28
+        return aardbei_people
29
+
30
+    aardbei_activities = get_activities(token, endpoint)
31
+
32
+    if isinstance(aardbei_activities, AardbeiSyncError):
33
+        return aardbei_activities
34
+
35
+    return match_local_aardbei(aardbei_people)
36
+
37
+
38
+@app.route("/aardbei/diff_people", methods=["POST"])
39
+def aardbei_diff() -> Tuple[Dict[str, Any], int]:
40
+    data: Dict[str, str] = request.json
41
+    link = common_prepare_aardbei_sync(
42
+        data["token"], data.get("endpoint", AARDBEI_ENDPOINT)
43
+    )
44
+
45
+    if isinstance(link, AardbeiSyncError):
46
+        return {"error": link.value}, 503
47
+
48
+    return (
49
+        {
50
+            "num_changes": link.num_changes,
51
+            "new_people": [member.person.full_name for member in link.remote_only],
52
+            "link_existing": [match.local.full_name for match in link.matches],
53
+            "altered_name": [match.local.full_name for match in link.matches],
54
+        },
55
+        200,
56
+    )
57
+
58
+
59
+@app.route("/aardbei/sync_people", methods=["POST"])
60
+def aardbei_apply() -> Union[Tuple[Dict[str, Any], int]]:
61
+    data: Dict[str, str] = request.json
62
+    link = common_prepare_aardbei_sync(
63
+        data["token"], data.get("endpoint", AARDBEI_ENDPOINT)
64
+    )
65
+
66
+    if isinstance(link, AardbeiSyncError):
67
+        return {"error": link.value}, 503
68
+
69
+    link_matches(link.matches)
70
+    create_missing(link.remote_only)
71
+    update_names(link.altered_name)
72
+
73
+    db.session.commit()
74
+
75
+    return (
76
+        {
77
+            "num_changes": link.num_changes,
78
+            "new_people": [member.person.full_name for member in link.remote_only],
79
+            "link_existing": [match.local.full_name for match in link.matches],
80
+            "altered_name": [match.local.full_name for match in link.altered_name],
81
+        },
82
+        200,
83
+    )
84
+
85
+
86
+@app.route("/aardbei/get_activities", methods=["POST"])
87
+def aardbei_get_activities() -> Tuple[Dict[str, object], int]:
88
+    data: Dict[str, str] = request.json
89
+    activities = get_activities(data["token"], data.get("endpoint", AARDBEI_ENDPOINT))
90
+
91
+    if isinstance(activities, AardbeiSyncError):
92
+        return {"error": activities.value}, 503
93
+
94
+    return {"activities": [x.as_json_dict for x in activities]}, 200
95
+
96
+
97
+@app.route("/aardbei/apply_activity", methods=["POST"])
98
+def aardbei_apply_activity() -> Tuple[Dict[str, Any], int]:
99
+    data: Dict[str, Union[str, int]] = request.json
100
+    aid = data["activity_id"]
101
+    token = data["token"]
102
+    endpoint = data["endpoint"]
103
+
104
+    if not isinstance(aid, int):
105
+        return {"error": "nonnumeric_activity_id"}, 400
106
+
107
+    if not isinstance(token, str):
108
+        return {"error": "illtyped_token"}, 400
109
+
110
+    if not isinstance(endpoint, str):
111
+        return {"error": "illtyped_endpoint"}, 400
112
+
113
+    activity = get_activity(activity_id=ActivityId(aid), token=token, endpoint=endpoint)
114
+
115
+    if isinstance(activity, AardbeiSyncError):
116
+        return {"error": activity.value}, 503
117
+
118
+    match_activity(activity)
119
+    db.session.commit()
120
+
121
+    return (activity.as_json_dict, 200)

+ 64 - 0
piket_server/routes/consumption_types.py

1
+"""
2
+Provides routes related to managing ConsumptionType objects.
3
+"""
4
+
5
+from sqlalchemy.exc import SQLAlchemyError
6
+from flask import jsonify, request
7
+
8
+from piket_server.models import ConsumptionType
9
+from piket_server.flask import app, db
10
+
11
+
12
+@app.route("/consumption_types", methods=["GET"])
13
+def get_consumption_types():
14
+    """ Return a list of currently active consumption types. """
15
+    try:
16
+        active = int(request.args.get("active", 1))
17
+
18
+    except ValueError:
19
+        return {}, 400
20
+
21
+    ctypes = ConsumptionType.query.filter_by(active=active).all()
22
+    result = [ct.as_dict for ct in ctypes]
23
+    return jsonify(consumption_types=result)
24
+
25
+
26
+@app.route("/consumption_types/<int:consumption_type_id>", methods=["GET"])
27
+def get_consumption_type(consumption_type_id: int):
28
+    ct = ConsumptionType.query.get_or_404(consumption_type_id)
29
+
30
+    return jsonify(consumption_type=ct.as_dict)
31
+
32
+
33
+@app.route("/consumption_types", methods=["POST"])
34
+def add_consumption_type():
35
+    """ Add a new ConsumptionType.  """
36
+    json = request.get_json()
37
+
38
+    if not json:
39
+        return jsonify({"error": "Could not parse JSON."}), 400
40
+
41
+    data = json.get("consumption_type") or {}
42
+    ct = ConsumptionType(name=data.get("name"), icon=data.get("icon"))
43
+
44
+    try:
45
+        db.session.add(ct)
46
+        db.session.commit()
47
+    except SQLAlchemyError:
48
+        return jsonify({"error": "Invalid arguments for ConsumptionType."}), 400
49
+
50
+    return jsonify(consumption_type=ct.as_dict), 201
51
+
52
+
53
+@app.route("/consumption_types/<int:consumption_type_id>", methods=["PATCH"])
54
+def activate_consumption_type(consumption_type_id: int):
55
+    ct = ConsumptionType.query.get_or_404(consumption_type_id)
56
+
57
+    data = request.json["consumption_type"]
58
+    new_active = data.get("active", True)
59
+
60
+    ct.active = new_active
61
+    db.session.add(ct)
62
+    db.session.commit()
63
+
64
+    return jsonify(consumption_type=ct.as_dict), 200

+ 36 - 0
piket_server/routes/consumptions.py

1
+"""
2
+Provides routes related to Consumption objects.
3
+"""
4
+
5
+from flask import jsonify
6
+from sqlalchemy.exc import SQLAlchemyError
7
+
8
+from piket_server.flask import app, db
9
+from piket_server.models import Consumption
10
+
11
+
12
+@app.route("/consumptions/<int:consumption_id>", methods=["DELETE"])
13
+def reverse_consumption(consumption_id: int):
14
+    """ Reverse a consumption. """
15
+    consumption = Consumption.query.get_or_404(consumption_id)
16
+
17
+    if consumption.reversed:
18
+        return (
19
+            jsonify(
20
+                {
21
+                    "error": "Consumption already reversed",
22
+                    "consumption": consumption.as_dict,
23
+                }
24
+            ),
25
+            409,
26
+        )
27
+
28
+    try:
29
+        consumption.reversed = True
30
+        db.session.add(consumption)
31
+        db.session.commit()
32
+
33
+    except SQLAlchemyError:
34
+        return jsonify({"error": "Database error."}), 500
35
+
36
+    return jsonify(consumption=consumption.as_dict), 200

+ 46 - 0
piket_server/routes/exports.py

1
+"""
2
+Provides routes for managing Export objects.
3
+"""
4
+
5
+from flask import jsonify
6
+from sqlalchemy.exc import SQLAlchemyError
7
+
8
+from piket_server.flask import app, db
9
+from piket_server.models import Export, Settlement
10
+
11
+@app.route("/exports", methods=["GET"])
12
+def get_exports():
13
+    """ Return a list of the created Exports. """
14
+    result = Export.query.all()
15
+    return jsonify(exports=[e.as_dict for e in result])
16
+
17
+
18
+@app.route("/exports/<int:export_id>", methods=["GET"])
19
+def get_export(export_id: int):
20
+    """ Return an overview for the given Export. """
21
+    e = Export.query.get_or_404(export_id)
22
+
23
+    ss = [s.as_dict for s in e.settlements]
24
+
25
+    return jsonify(export=e.as_dict, settlements=ss)
26
+
27
+
28
+@app.route("/exports", methods=["POST"])
29
+def add_export():
30
+    """ Create an Export, and link all un-exported Settlements to it. """
31
+    # Assert that there are Settlements to be exported.
32
+    s_count = Settlement.query.filter_by(export=None).count()
33
+    if s_count == 0:
34
+        return jsonify(error="No un-exported Settlements."), 403
35
+
36
+    e = Export()
37
+
38
+    db.session.add(e)
39
+    db.session.commit()
40
+
41
+    Settlement.query.filter_by(export=None).update({"export_id": e.export_id})
42
+    db.session.commit()
43
+
44
+    ss = [s.as_dict for s in e.settlements]
45
+
46
+    return jsonify(export=e.as_dict, settlements=ss), 201

+ 38 - 0
piket_server/routes/general.py

1
+"""
2
+Provides general routes.
3
+"""
4
+
5
+from flask import jsonify
6
+
7
+from piket_server.flask import app
8
+from piket_server.models import Consumption
9
+
10
+
11
+@app.route("/ping")
12
+def ping() -> str:
13
+    """ Return a status ping. """
14
+    return "Pong"
15
+
16
+
17
+@app.route("/status")
18
+def status():
19
+    """ Return a status dict with info about the database. """
20
+    unsettled_q = Consumption.query.filter_by(settlement=None).filter_by(reversed=False)
21
+
22
+    unsettled = unsettled_q.count()
23
+
24
+    first = None
25
+    last = None
26
+    if unsettled:
27
+        last = (
28
+            unsettled_q.order_by(Consumption.created_at.desc())
29
+            .first()
30
+            .created_at.isoformat()
31
+        )
32
+        first = (
33
+            unsettled_q.order_by(Consumption.created_at.asc())
34
+            .first()
35
+            .created_at.isoformat()
36
+        )
37
+
38
+    return jsonify({"unsettled": {"amount": unsettled, "first": first, "last": last}})

+ 122 - 0
piket_server/routes/people.py

1
+"""
2
+Provides routes related to managing Person objects.
3
+"""
4
+
5
+from flask import jsonify, request
6
+from sqlalchemy.exc import SQLAlchemyError
7
+
8
+from piket_server.models import Consumption, Person
9
+from piket_server.flask import app, db
10
+
11
+
12
+@app.route("/people", methods=["GET"])
13
+def get_people():
14
+    """ Return a list of currently known people. """
15
+    people = Person.query.order_by(Person.full_name).all()
16
+    q = Person.query.order_by(Person.full_name)
17
+    if request.args.get("active"):
18
+        active_status = request.args.get("active", type=int)
19
+        q = q.filter_by(active=active_status)
20
+    people = q.all()
21
+    result = [person.as_dict for person in people]
22
+    return jsonify(people=result)
23
+
24
+
25
+@app.route("/people/<int:person_id>", methods=["GET"])
26
+def get_person(person_id: int):
27
+    person = Person.query.get_or_404(person_id)
28
+
29
+    return jsonify(person=person.as_dict)
30
+
31
+
32
+@app.route("/people", methods=["POST"])
33
+def add_person():
34
+    """
35
+    Add a new person.
36
+
37
+    Required parameters:
38
+    - name (str)
39
+    """
40
+    json = request.get_json()
41
+
42
+    if not json:
43
+        return jsonify({"error": "Could not parse JSON."}), 400
44
+
45
+    data = json.get("person") or {}
46
+    person = Person(
47
+        full_name=data.get("full_name"),
48
+        active=data.get("active", False),
49
+        display_name=data.get("display_name", None),
50
+    )
51
+
52
+    try:
53
+        db.session.add(person)
54
+        db.session.commit()
55
+    except SQLAlchemyError:
56
+        return jsonify({"error": "Invalid arguments for Person."}), 400
57
+
58
+    return jsonify(person=person.as_dict), 201
59
+
60
+
61
+@app.route("/people/<int:person_id>", methods=["PATCH"])
62
+def update_person(person_id: int):
63
+    person = Person.query.get_or_404(person_id)
64
+
65
+    data = request.json["person"]
66
+    changed = False
67
+
68
+    if "active" in data:
69
+        person.active = data["active"]
70
+        changed = True
71
+
72
+    if "full_name" in data:
73
+        person.full_name = data["full_name"]
74
+        changed = True
75
+
76
+    if "display_name" in data:
77
+        person.display_name = data["display_name"]
78
+        changed = True
79
+
80
+    if changed:
81
+        db.session.add(person)
82
+        db.session.commit()
83
+
84
+    return jsonify(person=person.as_dict)
85
+
86
+
87
+@app.route("/people/<int:person_id>/add_consumption", methods=["POST"])
88
+def add_consumption(person_id: int):
89
+    person = Person.query.get_or_404(person_id)
90
+
91
+    consumption = Consumption(person=person, consumption_type_id=1)
92
+    try:
93
+        db.session.add(consumption)
94
+        db.session.commit()
95
+    except SQLAlchemyError:
96
+        return (
97
+            jsonify(
98
+                {"error": "Invalid Consumption parameters.", "person": person.as_dict}
99
+            ),
100
+            400,
101
+        )
102
+
103
+    return jsonify(person=person.as_dict, consumption=consumption.as_dict), 201
104
+
105
+
106
+@app.route("/people/<int:person_id>/add_consumption/<int:ct_id>", methods=["POST"])
107
+def add_consumption2(person_id: int, ct_id: int):
108
+    person = Person.query.get_or_404(person_id)
109
+
110
+    consumption = Consumption(person=person, consumption_type_id=ct_id)
111
+    try:
112
+        db.session.add(consumption)
113
+        db.session.commit()
114
+    except SQLAlchemyError:
115
+        return (
116
+            jsonify(
117
+                {"error": "Invalid Consumption parameters.", "person": person.as_dict}
118
+            ),
119
+            400,
120
+        )
121
+
122
+    return jsonify(person=person.as_dict, consumption=consumption.as_dict), 201

+ 49 - 0
piket_server/routes/settlements.py

1
+"""
2
+Provides routes for managing Settlement objects.
3
+"""
4
+
5
+from sqlalchemy.exc import SQLAlchemyError
6
+from flask import jsonify, request
7
+
8
+from piket_server.flask import app, db
9
+from piket_server.models import Consumption, Settlement
10
+
11
+
12
+@app.route("/settlements", methods=["GET"])
13
+def get_settlements():
14
+    """ Return a list of the active Settlements. """
15
+    result = Settlement.query.all()
16
+    return jsonify(settlements=[s.as_dict for s in result])
17
+
18
+
19
+@app.route("/settlements/<int:settlement_id>", methods=["GET"])
20
+def get_settlement(settlement_id: int):
21
+    """ Show full details for a single Settlement. """
22
+    s = Settlement.query.get_or_404(settlement_id)
23
+
24
+    per_person = s.per_person
25
+
26
+    return jsonify(settlement=s.as_dict, count_info=per_person)
27
+
28
+
29
+@app.route("/settlements", methods=["POST"])
30
+def add_settlement():
31
+    """ Create a Settlement, and link all un-settled Consumptions to it. """
32
+    json = request.get_json()
33
+
34
+    if not json:
35
+        return jsonify({"error": "Could not parse JSON."}), 400
36
+
37
+    data = json.get("settlement") or {}
38
+    s = Settlement(name=data["name"])
39
+
40
+    db.session.add(s)
41
+    db.session.commit()
42
+
43
+    Consumption.query.filter_by(settlement=None).update(
44
+        {"settlement_id": s.settlement_id}
45
+    )
46
+
47
+    db.session.commit()
48
+
49
+    return jsonify(settlement=s.as_dict, count_info=s.per_person)

+ 4 - 3
piket_server/seed.py

6
 import csv
6
 import csv
7
 import os
7
 import os
8
 
8
 
9
-from piket_server import db, Person, Settlement, ConsumptionType, Consumption
9
+from piket_server.models import Person, Settlement, ConsumptionType, Consumption
10
+from piket_server.flask import db
10
 
11
 
11
 
12
 
12
 def main():
13
 def main():
52
         print("All data removed. Recreating database...")
53
         print("All data removed. Recreating database...")
53
         db.create_all()
54
         db.create_all()
54
 
55
 
55
-        from alembic.config import Config
56
-        from alembic import command
56
+        from alembic.config import Config  # type: ignore
57
+        from alembic import command  # type: ignore
57
 
58
 
58
         alembic_cfg = Config(os.path.join(os.path.dirname(__file__), "alembic.ini"))
59
         alembic_cfg = Config(os.path.join(os.path.dirname(__file__), "alembic.ini"))
59
         command.stamp(alembic_cfg, "head")
60
         command.stamp(alembic_cfg, "head")

+ 10 - 0
piket_server/util.py

1
+import datetime
2
+from typing import Optional
3
+
4
+
5
+def fmt_datetime(x: Optional[datetime.datetime]) -> Optional[str]:
6
+    """Format a datetime as ISO 8601, if it's not None."""
7
+    if x is not None:
8
+        return x.isoformat()
9
+
10
+    return None

+ 3 - 2
setup.py

18
     entry_points={
18
     entry_points={
19
         "console_scripts": [
19
         "console_scripts": [
20
             "piket-client=piket_client.gui:main",
20
             "piket-client=piket_client.gui:main",
21
+            "piket-cli=piket_client.cli:cli",
21
             "piket-seed=piket_server.seed:main",
22
             "piket-seed=piket_server.seed:main",
22
         ]
23
         ]
23
     },
24
     },
24
     install_requires=[],
25
     install_requires=[],
25
     extras_require={
26
     extras_require={
26
-        "dev": ["black", "pylint"],
27
+        "dev": ["black", "pylint", "mypy", "isort"],
27
         "server": ["Flask", "SQLAlchemy", "Flask-SQLAlchemy", "alembic", "uwsgi"],
28
         "server": ["Flask", "SQLAlchemy", "Flask-SQLAlchemy", "alembic", "uwsgi"],
28
-        "client": ["PySide2", "qdarkstyle>=2.6.0", "requests", "simpleaudio"],
29
+        "client": ["PySide2", "qdarkstyle>=2.6.0", "requests", "simpleaudio", "click", "prettytable"],
29
         "osk": ["dbus-python"],
30
         "osk": ["dbus-python"],
30
         "sentry": ["raven"],
31
         "sentry": ["raven"],
31
     },
32
     },