Skip to content

Commit

Permalink
fixups
Browse files Browse the repository at this point in the history
  • Loading branch information
connorjward committed Jan 13, 2025
1 parent 0c7794b commit 31aaa3e
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 13 deletions.
12 changes: 6 additions & 6 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ jobs:
-o faulthandler_timeout=3660 \
-v unit-tests integration-tests examples
timeout-minutes: 120
- name: Prepare logs
if: always()
run: |
mkdir logs
cd /tmp/pytest-of-firedrake/pytest-0/
find . -name "*.log" -exec cp --parents {} /__w/gusto/gusto/logs/ \;
- name: Test serial netCDF
run: |
. /home/firedrake/firedrake/bin/activate
Expand All @@ -69,12 +75,6 @@ jobs:
export PYOP2_CFLAGS=-O0
python -m pytest -n 3 -v integration-tests/model/test_nc_outputting.py
timeout-minutes: 10
- name: Prepare logs
if: always()
run: |
mkdir logs
cd /tmp/pytest-of-firedrake/pytest-0/
find . -name "*.log" -exec cp --parents {} /__w/gusto/gusto/logs/ \;
- name: Upload artifact
if: always()
uses: actions/upload-pages-artifact@v3
Expand Down
13 changes: 6 additions & 7 deletions gusto/core/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,15 +807,15 @@ def create_nc_dump(self, filename, space_names):
if nc_field_file:
nc_field_file.createDimension(f'coords_{space_name}', ndofs)

for i, (coord_name, coord_field) in enumerate(zip(self.domain.coords.coords_name, coord_fields)):
for coord_name, coord_field in zip(self.domain.coords.coords_name, coord_fields):
if nc_field_file:
nc_field_file.createVariable(f'{coord_name}_{space_name}', float, f'coords_{space_name}')

if nc_supports_parallel:
start, stop = self.domain.coords.parallel_array_lims[space_name]
nc_field_file.variables[f'{coord_name}_{space_name}'][start:stop] = coord_field.dat.data_ro
else:
global_coord_field = gather_field_data(coord_field, i, self.domain)
global_coord_field = gather_field_data(coord_field, self.domain)
if comm.rank == 0:
nc_field_file.variables[f'{coord_name}_{space_name}'][...] = global_coord_field

Expand Down Expand Up @@ -849,7 +849,7 @@ def write_nc_dump(self, t):
nc_field_file['time'][self.field_t_idx] = t

# Loop through output field data here
for i, field in enumerate(self.to_dump):
for field in self.to_dump:
field_name = field.name()
space_name = field.function_space().name

Expand All @@ -868,7 +868,7 @@ def write_nc_dump(self, t):
start, stop = self.domain.coords.parallel_array_lims[space_name]
nc_field_file[field_name]['field_values'][start:stop, self.field_t_idx] = field.dat.data_ro
else:
global_field_data = gather_field_data(field, i, self.domain)
global_field_data = gather_field_data(field, self.domain)
if comm.rank == 0:
nc_field_file[field_name]['field_values'][:, self.field_t_idx] = global_field_data

Expand Down Expand Up @@ -964,12 +964,11 @@ def make_nc_dataset(filename, access, comm):
return nc_field_file, nc_supports_parallel


def gather_field_data(field, field_index, domain):
def gather_field_data(field, domain):
"""Gather global field data into a single array on rank 0.
Args:
field (:class:`firedrake.Function`): The field to gather.
field_index (int): Index used to identify the field.
domain (:class:`Domain`): The domain.
Returns:
Expand All @@ -983,7 +982,7 @@ def gather_field_data(field, field_index, domain):
comm = domain.mesh.comm

if comm.size == 1:
return field.dat.data_ro
return field.dat.data_ro.copy()

gathered_data = comm.gather(field.dat.data_ro)
if comm.rank == 0:
Expand Down

0 comments on commit 31aaa3e

Please sign in to comment.