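Accepted keys: 'all', 'config', 'R_abs_gt', 'template', 'anisotropy', 'cropped_noisy', 'outliers'.
- If 'all' is among the keys, or more than three keys are given:
return the whole data dict.
- If no key is given: one of 'anisotropy', 'cropped_noisy' or 'outliers' is picked at random.
- If one of 'anisotropy', 'cropped_noisy' or 'outliers' is among the keys:
return [degraded views, degradation value, R_abs_gt] (+ optionally config, then template).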
Returns:
- list[Tensor | float] | dict – the requested values as a list, or the whole data dict.
Source code in `src/polar/example/utils.py` (lines 20–54):
54 | def load_sample_data(*keys: str) -> list[Tensor | float] | dict:
""" Accepted keys:
'all', 'config', 'R_abs_gt', 'template', 'anisotropy', 'cropped_noisy', 'outliers'
- If 'all' or len(keys) > 3:
return the whole data dict
- If ('anisotropy', 'cropped_noisy', 'outliers') in keys:
return (degraded views, degradation value, R_abs_gt) (+ optionally config, template)
Returns:
TODO
"""
here = Path(__file__).parent
data = torch.load(here / 'data.pt')
if 'all' in keys or len(keys) > 3:
return data
degradations = ('anisotropy', 'cropped_noisy', 'outliers')
if len(keys) == 0:
i = randint(0, 2)
keys = (degradations[i],)
msg = "Only one degradations can be specified. Call `load_sample_data` several times or pass 'all'."
assert sum([k in degradations for k in keys]) <= 1, msg
return_values = list()
views = None
for k in degradations:
if k in keys:
views = data['views'][k]
value = data['config'][k]
if views is not None:
return_values.append(views)
return_values.append(value)
return_values.append(data['R_abs_gt'])
for key in ('config', 'template'):
if key in keys:
return_values.append(data[key])
return return_values
|