from __future__ import print_function import numpy as np import yt from astropy.table import Table yt.enable_parallelism() from yt.utilities.parallel_tools.parallel_analysis_interface import \ _get_comm, \ parallel_objects, \ parallel_root_only, \ communication_system def calc_ang_mom_and_fluxes(): ### set up the table of all the stuff we want data = Table(names=('redshift', 'radius', \ 'net_mass_flux', 'net_metal_flux'), \ dtype=('f8', 'f8', 'f8','f8')) # perform calculation over two outputs outs = ['RD0008/RD0008','RD0009/RD0009'] storage = {} ############################################################################################# ## with the 2, I was trying to tell it to break the threads up ## into two groups - half for each dataset ## but maybe that's not the right way to think about it? ############################################################################################# for sto,snap in yt.parallel_objects(outs,2,storage=storage): ds = yt.load(snap) # create all the regions zsnap = ds.get_parameter('CosmologyCurrentRedshift') halo_center = [0.5,0.5,0.5] refine_width_code = 0.25 refine_width = ds.quan(refine_width_code,'code_length').in_units('kpc').value ### Radial bins for the flux calculation radii = refine_width*0.5*np.arange(0.9, 0.1, -0.1) # 0.5 because radius big_sphere = ds.sphere(halo_center,(refine_width_code,'code_length')) # we want to subtract the bulk velocity from the radial velocities bulk_velocity = big_sphere.quantities["BulkVelocity"]() # find number of cells for the FRB cell_size = np.unique(big_sphere['dx'].in_units('kpc'))[2] box_width = ds.quan(0.9*refine_width_code,'code_length') nbins = int(np.ceil(box_width/cell_size).value) halo_center = ds.arr(halo_center,'code_length') xL,xR = halo_center[0]-box_width/2.,halo_center[0]+box_width/2. yL,yR = halo_center[1]-box_width/2.,halo_center[1]+box_width/2. zL,zR = halo_center[2]-box_width/2.,halo_center[2]+box_width/2. jnbins = complex(0,nbins) box = ds.r[xL:xR:jnbins,yL:yR:jnbins,zL:zR:jnbins] box.set_field_parameter("center",halo_center) box.set_field_parameter("bulk_velocity",bulk_velocity) ### OK, now want to call the fields that we'll need for the fluxes cell_mass = box['cell_mass'].to("Msun").flatten() metal_mass = box[('gas', 'metal_mass')].to("Msun").flatten() radius = box['radius'].to("kpc").flatten() radial_velocity = box['radial_velocity'].to('kpc/yr').flatten() ############################################################################################ ## ok -- I think doing it this way should make each radius it's own thread table1 = np.zeros((len(radii),4)) for rad in parallel_objects(radii): ############################################################################################# if rad != np.max(radii): idI = np.where(radii == rad)[0] if rad == radii[-1]: minrad,maxrad = ds.quan(0.,'kpc'),ds.quan(rad,'kpc') else: maxrad,minrad = ds.quan(rad,'kpc'),ds.quan(radii[idI[0]+1],'kpc') dr = maxrad - minrad idR = np.where((radius >= minrad) & (radius < maxrad))[0] gas_flux = (np.sum(cell_mass[idR]*radial_velocity[idR])/dr).to("Msun/yr") metal_flux = (np.sum(metal_mass[idR]*radial_velocity[idR])/dr).to("Msun/yr") table1[idI,:] = [ds.current_redshift, rad, gas_flux, metal_flux] ############################################################################################# # Make sure the table is being updated across threads, then store it in the # global dictionary ############################################################################################# comm = communication_system.communicators[-1] for i in range(table1.shape[0]): table1[i,:] = comm.mpi_allreduce(table1[i,:], op="sum") sto.result = table1 sto.result_id = str(ds) ############################################################################################# ############################################################################################# # once all of the threads are finished, I want it to recombine all of the # dataset calculations into the same file ############################################################################################# if yt.is_root(): for key in storage.keys(): table1 = storage[key] for i in range(table1.shape[0]-1): data.add_row(table1[i+1,:]) tablename = 'testing_parallel.hdf5' data.write(tablename,path='all_data',overwrite=True) return "whooooo angular momentum wheeeeeeee" #----------------------------------------------------------------------------------------------------- calc_ang_mom_and_fluxes()