Skip to content

Sample dataset

`load_sample_data(*keys: str) -> list[Tensor | float] | dict`

Accepted keys

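`'all'`, `'config'`, `'R_abs_gt'`, `'template'`, `'anisotropy'`, `'cropped_noisy'`, `'outliers'`.

- If `'all'` is passed, or more than three keys are given, the whole data dict is returned.
- If no key is passed, one of the degradations is chosen at random.
- If one of the degradations (`'anisotropy'`, `'cropped_noisy'`, `'outliers'`) is passed, the result is `[degraded views, degradation value, R_abs_gt]`. It is optionally followed by `config` and `template`, in that order, when those keys are also passed. Only one degradation may be requested per call.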

Returns:

  • `list[Tensor | float] | dict` – either the whole data dict, or a list of the requested items in the order described above.

Source code in `src/polar/example/utils.py` (lines 20–54):
def load_sample_data(*keys: str) -> list[Tensor | float] | dict:
    """ Accepted keys:
        'all', 'config', 'R_abs_gt', 'template', 'anisotropy', 'cropped_noisy', 'outliers'
        - If 'all' or len(keys) > 3:
            return the whole data dict
        - If ('anisotropy', 'cropped_noisy', 'outliers') in keys:
            return (degraded views, degradation value, R_abs_gt) (+ optionally config, template)

    Returns:
        TODO
    """
    here = Path(__file__).parent
    data = torch.load(here / 'data.pt')
    if 'all' in keys or len(keys) > 3:
        return data
    degradations = ('anisotropy', 'cropped_noisy', 'outliers')
    if len(keys) == 0:
        i = randint(0, 2)
        keys = (degradations[i],)
    msg = "Only one degradations can be specified. Call `load_sample_data` several times or pass 'all'."
    assert sum([k in degradations for k in keys]) <= 1, msg
    return_values = list()
    views = None
    for k in degradations:
        if k in keys:
            views = data['views'][k]
            value = data['config'][k]
    if views is not None:
        return_values.append(views)
        return_values.append(value)
        return_values.append(data['R_abs_gt'])
    for key in ('config', 'template'):
        if key in keys:
            return_values.append(data[key])
    return return_values